/* SPDX-License-Identifier: GPL-2.0 */
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#ifndef _LINUX_SKMSG_H
#define _LINUX_SKMSG_H

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/scatterlist.h>
#include <linux/skbuff.h>

#include <net/sock.h>
#include <net/tcp.h>
#include <net/strparser.h>

#define MAX_MSG_FRAGS			MAX_SKB_FRAGS
#define NR_MSG_FRAG_IDS			(MAX_MSG_FRAGS + 1)
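
/*
 * sg.start/sg.end below index a ring of scatterlist entries. The ring is
 * sized NR_MSG_FRAG_IDS = MAX_MSG_FRAGS + 1 so that a full ring
 * (sk_msg_iter_dist(start, end) == MAX_MSG_FRAGS) remains distinguishable
 * from an empty one (start == end); see sk_msg_iter_dist() below.
 */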

enum __sk_action {
	__SK_DROP = 0,
	__SK_PASS,
	__SK_REDIRECT,
	__SK_NONE,
};

struct sk_msg_sg {
	u32				start;
	u32				curr;
	u32				end;
	u32				size;
	u32				copybreak;
	DECLARE_BITMAP(copy, MAX_MSG_FRAGS + 2);
	/* The extra two elements are used:
	 * 1) for chaining the front and rear sections when the list becomes
	 *    partitioned (i.e. end < start). The crypto APIs require the
	 *    chaining;
	 * 2) to chain trailer SG entries after the message.
	 */
	struct scatterlist		data[MAX_MSG_FRAGS + 2];
};

/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
	struct sk_msg_sg		sg;
	void				*data;
	void				*data_end;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				flags;
	struct sk_buff			*skb;
	struct sock			*sk_redir;
	struct sock			*sk;
	struct list_head		list;
};

struct sk_psock_progs {
	struct bpf_prog			*msg_parser;
	struct bpf_prog			*stream_parser;
	struct bpf_prog			*stream_verdict;
	struct bpf_prog			*skb_verdict;
};

enum sk_psock_state_bits {
	SK_PSOCK_TX_ENABLED,
	SK_PSOCK_RX_STRP_ENABLED,
};

struct sk_psock_link {
	struct list_head		list;
	struct bpf_map			*map;
	void				*link_raw;
};

struct sk_psock_work_state {
	u32				len;
	u32				off;
};

struct sk_psock {
	struct sock			*sk;
	struct sock			*sk_redir;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				eval;
	bool				redir_ingress; /* undefined if sk_redir is null */
	struct sk_msg			*cork;
	struct sk_psock_progs		progs;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
	struct strparser		strp;
#endif
	struct sk_buff_head		ingress_skb;
	struct list_head		ingress_msg;
	spinlock_t			ingress_lock;
	unsigned long			state;
	struct list_head		link;
	spinlock_t			link_lock;
	refcount_t			refcnt;
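
	/* Original socket callbacks, saved so they can be invoked from the
	 * psock handlers (see sk_psock_data_ready() below) and restored
	 * when the psock is detached.
	 */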
	void (*saved_unhash)(struct sock *sk);
	void (*saved_destroy)(struct sock *sk);
	void (*saved_close)(struct sock *sk, long timeout);
	void (*saved_write_space)(struct sock *sk);
	void (*saved_data_ready)(struct sock *sk);
	int  (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
				     bool restore);
	struct proto			*sk_proto;
	struct mutex			work_mutex;
	struct sk_psock_work_state	work_state;
	struct delayed_work		work;
	struct rcu_work			rwork;
};

int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce);
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len);
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
int sk_msg_free(struct sock *sk, struct sk_msg *msg);
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes);
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
		   int len, int flags);
bool sk_msg_is_readable(struct sock *sk);
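
/*
 * Typical sendmsg-side flow (a sketch only, modeled loosely on callers
 * such as tcp_bpf_sendmsg(); 'bytes' and 'iter' stand in for the caller's
 * byte count and iov_iter, and error handling is elided):
 *
 *	struct sk_msg msg;
 *
 *	sk_msg_init(&msg);
 *	err = sk_msg_alloc(sk, &msg, bytes, 0);
 *	if (!err)
 *		err = sk_msg_memcopy_from_iter(sk, &iter, &msg, bytes);
 *	if (err)
 *		sk_msg_free(sk, &msg);
 */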

static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
	WARN_ON(i == msg->sg.end && bytes);
}

static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
{
	if (psock->apply_bytes) {
		if (psock->apply_bytes < bytes)
			psock->apply_bytes = 0;
		else
			psock->apply_bytes -= bytes;
	}
}

static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
	return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
}
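
/*
 * Example: with the default MAX_SKB_FRAGS of 17, NR_MSG_FRAG_IDS is 18.
 * For a ring that has wrapped, say start == 16 and end == 2, the distance
 * is 2 + (18 - 16) == 4 live elements (indices 16, 17, 0, 1).
 */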

#define sk_msg_iter_var_prev(var)			\
	do {						\
		if (var == 0)				\
			var = NR_MSG_FRAG_IDS - 1;	\
		else					\
			var--;				\
	} while (0)

#define sk_msg_iter_var_next(var)			\
	do {						\
		var++;					\
		if (var == NR_MSG_FRAG_IDS)		\
			var = 0;			\
	} while (0)

#define sk_msg_iter_prev(msg, which)			\
	sk_msg_iter_var_prev(msg->sg.which)

#define sk_msg_iter_next(msg, which)			\
	sk_msg_iter_var_next(msg->sg.which)
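
/*
 * The BUILD_BUG_ON in sk_msg_init() keeps the ring size in sync with the
 * backing array: sg.data has MAX_MSG_FRAGS + 2 entries, of which the
 * first NR_MSG_FRAG_IDS form the ring that sg_init_marker() initializes,
 * and the final entry is reserved for chaining.
 */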
static inline void sk_msg_init(struct sk_msg *msg)
{
	BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
	memset(msg, 0, sizeof(*msg));
	sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}

static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
			       int which, u32 size)
{
	dst->sg.data[which] = src->sg.data[which];
	dst->sg.data[which].length = size;
	dst->sg.size += size;
	src->sg.size -= size;
	src->sg.data[which].length -= size;
	src->sg.data[which].offset += size;
}

static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
{
	memcpy(dst, src, sizeof(*src));
	sk_msg_init(src);
}

static inline bool sk_msg_full(const struct sk_msg *msg)
{
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}

static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}

static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
{
	return &msg->sg.data[which];
}

static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
{
	return msg->sg.data[which];
}

static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
{
	return sg_page(sk_msg_elem(msg, which));
}

static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
{
	return msg->flags & BPF_F_INGRESS;
}
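
/*
 * Compute msg->data/msg->data_end for direct access by BPF msg programs.
 * If the start element is marked in the copy bitmap, the pointers are
 * cleared instead, which denies direct access to that element's data.
 */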
static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
{
	struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);

	if (test_bit(msg->sg.start, msg->sg.copy)) {
		msg->data = NULL;
		msg->data_end = NULL;
	} else {
		msg->data = sg_virt(sge);
		msg->data_end = msg->data + sge->length;
	}
}
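
/*
 * Append a page fragment at the ring's end slot. The caller's reference
 * is not consumed: get_page() takes an extra reference here, and the new
 * element is marked in the copy bitmap.
 */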
static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
				   u32 len, u32 offset)
{
	struct scatterlist *sge;

	get_page(page);
	sge = sk_msg_elem(msg, msg->sg.end);
	sg_set_page(sge, page, len, offset);
	sg_unmark_end(sge);

	__set_bit(msg->sg.end, msg->sg.copy);
	msg->sg.size += len;
	sk_msg_iter_next(msg, end);
}

static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
{
	do {
		if (copy_state)
			__set_bit(i, msg->sg.copy);
		else
			__clear_bit(i, msg->sg.copy);
		sk_msg_iter_var_next(i);
		if (i == msg->sg.end)
			break;
	} while (1);
}

static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
{
	sk_msg_sg_copy(msg, start, true);
}

static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
{
	sk_msg_sg_copy(msg, start, false);
}

static inline struct sk_psock *sk_psock(const struct sock *sk)
{
	return __rcu_dereference_sk_user_data_with_flags(sk,
							 SK_USER_DATA_PSOCK);
}

static inline void sk_psock_set_state(struct sk_psock *psock,
				      enum sk_psock_state_bits bit)
{
	set_bit(bit, &psock->state);
}

static inline void sk_psock_clear_state(struct sk_psock *psock,
					enum sk_psock_state_bits bit)
{
	clear_bit(bit, &psock->state);
}

static inline bool sk_psock_test_state(const struct sk_psock *psock,
				       enum sk_psock_state_bits bit)
{
	return test_bit(bit, &psock->state);
}

static inline void sock_drop(struct sock *sk, struct sk_buff *skb)
{
	sk_drops_add(sk, skb);
	kfree_skb(skb);
}
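
/*
 * Ingress queue helpers. ingress_lock serializes producers and consumers;
 * a message queued after TX has been disabled (SK_PSOCK_TX_ENABLED
 * cleared) is freed rather than enqueued, since nothing will drain it.
 */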
static inline void sk_psock_queue_msg(struct sk_psock *psock,
				      struct sk_msg *msg)
{
	spin_lock_bh(&psock->ingress_lock);
	if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
		list_add_tail(&msg->list, &psock->ingress_msg);
	} else {
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
	spin_unlock_bh(&psock->ingress_lock);
}

static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	if (msg)
		list_del(&msg->list);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}

static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}

static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
					       struct sk_msg *msg)
{
	struct sk_msg *ret;

	spin_lock_bh(&psock->ingress_lock);
	if (list_is_last(&msg->list, &psock->ingress_msg))
		ret = NULL;
	else
		ret = list_next_entry(msg, list);
	spin_unlock_bh(&psock->ingress_lock);
	return ret;
}

static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
{
	return psock ? list_empty(&psock->ingress_msg) : true;
}

static inline void kfree_sk_msg(struct sk_msg *msg)
{
	if (msg->skb)
		consume_skb(msg->skb);
	kfree(msg);
}

static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
	struct sock *sk = psock->sk;

	sk->sk_err = err;
	sk_error_report(sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node);
void sk_psock_stop(struct sk_psock *psock);

#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
#else
static inline int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	return -EOPNOTSUPP;
}

static inline void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
}

static inline void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
}
#endif

void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg);

static inline struct sk_psock_link *sk_psock_init_link(void)
{
	return kzalloc(sizeof(struct sk_psock_link),
		       GFP_ATOMIC | __GFP_NOWARN);
}

static inline void sk_psock_free_link(struct sk_psock_link *link)
{
	kfree(link);
}

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);

static inline void sk_psock_cork_free(struct sk_psock *psock)
{
	if (psock->cork) {
		sk_msg_free(psock->sk, psock->cork);
		kfree(psock->cork);
		psock->cork = NULL;
	}
}

static inline void sk_psock_restore_proto(struct sock *sk,
					  struct sk_psock *psock)
{
	if (psock->psock_update_sk_prot)
		psock->psock_update_sk_prot(sk, psock, true);
}
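
/*
 * Take a reference on the psock attached to sk, if any. The RCU read
 * lock protects the sk_user_data lookup, and refcount_inc_not_zero()
 * guards against racing with the final put: a psock whose refcount has
 * already dropped to zero is treated as absent.
 */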
static inline struct sk_psock *sk_psock_get(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock && !refcount_inc_not_zero(&psock->refcnt))
		psock = NULL;
	rcu_read_unlock();
	return psock;
}

void sk_psock_drop(struct sock *sk, struct sk_psock *psock);

static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
{
	if (refcount_dec_and_test(&psock->refcnt))
		sk_psock_drop(sk, psock);
}

static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
{
	if (psock->saved_data_ready)
		psock->saved_data_ready(sk);
	else
		sk->sk_data_ready(sk);
}
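
/*
 * Program slot updates are lockless: psock_set_prog() unconditionally
 * swaps in the new program with xchg() and releases the old one, while
 * psock_replace_prog() uses cmpxchg() so the swap only succeeds if the
 * slot still holds the expected old program (-ENOENT otherwise).
 */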
static inline void psock_set_prog(struct bpf_prog **pprog,
				  struct bpf_prog *prog)
{
	prog = xchg(pprog, prog);
	if (prog)
		bpf_prog_put(prog);
}

static inline int psock_replace_prog(struct bpf_prog **pprog,
				     struct bpf_prog *prog,
				     struct bpf_prog *old)
{
	if (cmpxchg(pprog, old, prog) != old)
		return -ENOENT;

	if (old)
		bpf_prog_put(old);

	return 0;
}

static inline void psock_progs_drop(struct sk_psock_progs *progs)
{
	psock_set_prog(&progs->msg_parser, NULL);
	psock_set_prog(&progs->stream_parser, NULL);
	psock_set_prog(&progs->stream_verdict, NULL);
	psock_set_prog(&progs->skb_verdict, NULL);
}

int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);

static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
{
	if (!psock)
		return false;
	return !!psock->saved_data_ready;
}

static inline bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

#if IS_ENABLED(CONFIG_NET_SOCK_MSG)

#define BPF_F_STRPARSER	(1UL << 1)

/* We only have two bits so far. */
#define BPF_F_PTR_MASK	~(BPF_F_INGRESS | BPF_F_STRPARSER)
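
/*
 * skb->_sk_redir carries both the redirect socket pointer and the two
 * flag bits above. Stashing flags in the pointer's low bits relies on
 * struct sock allocations being at least 4-byte aligned, so those bits
 * of a valid pointer are always zero; skb_bpf_redirect_fetch() masks
 * them off before use.
 */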

static inline bool skb_bpf_strparser(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return sk_redir & BPF_F_STRPARSER;
}

static inline void skb_bpf_set_strparser(struct sk_buff *skb)
{
	skb->_sk_redir |= BPF_F_STRPARSER;
}

static inline bool skb_bpf_ingress(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return sk_redir & BPF_F_INGRESS;
}

static inline void skb_bpf_set_ingress(struct sk_buff *skb)
{
	skb->_sk_redir |= BPF_F_INGRESS;
}

static inline void skb_bpf_set_redir(struct sk_buff *skb, struct sock *sk_redir,
				     bool ingress)
{
	skb->_sk_redir = (unsigned long)sk_redir;
	if (ingress)
		skb->_sk_redir |= BPF_F_INGRESS;
}

static inline struct sock *skb_bpf_redirect_fetch(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return (struct sock *)(sk_redir & BPF_F_PTR_MASK);
}

static inline void skb_bpf_redirect_clear(struct sk_buff *skb)
{
	skb->_sk_redir = 0;
}
#endif /* CONFIG_NET_SOCK_MSG */
#endif /* _LINUX_SKMSG_H */