tcp_bpf.c

// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>
#include <linux/util_macros.h>

#include <net/inet_common.h>
#include <net/tls.h>

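/* Advance copied_seq over an skb that was consumed by the sockmap path
 * instead of the regular receive path, and let TCP update its receive
 * side accounting. Empty skbs, non-TCP sockets and skbs that went
 * through the strparser are left alone.
 */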
void tcp_eat_skb(struct sock *sk, struct sk_buff *skb)
{
	struct tcp_sock *tcp;
	int copied;

	if (!skb || !skb->len || !sk_is_tcp(sk))
		return;

	if (skb_bpf_strparser(skb))
		return;

	tcp = tcp_sk(sk);
	copied = tcp->copied_seq + skb->len;
	WRITE_ONCE(tcp->copied_seq, copied);
	tcp_rcv_space_adjust(sk);
	__tcp_cleanup_rbuf(sk, skb->len);
}

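/* Transfer up to @apply_bytes of @msg onto the ingress queue of @psock:
 * the scatterlist entries are moved into a freshly allocated sk_msg,
 * the memory is charged to @sk, and any reader is woken via
 * sk_psock_data_ready(). If nothing could be transferred the temporary
 * sk_msg is freed and an error is returned; otherwise the (possibly
 * partial) copy is queued.
 */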
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes) {
				if (sge->length)
					sk_msg_iter_var_prev(i);
				break;
			}
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

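/* Push up to @apply_bytes of @msg out on @sk, using
 * kernel_sendpage_locked() with MSG_SENDPAGE_NOPOLICY when a TLS TX ULP
 * is present and do_tcp_sendpages() otherwise. Fully sent scatterlist
 * entries are released and, when @uncharge is set, their memory is
 * uncharged from the socket. Called with the socket lock held.
 */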
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

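/* Redirect up to @bytes of @msg to @sk: queue them on the psock ingress
 * queue when @ingress is set, otherwise push them out through the TCP
 * send path under the socket lock. Returns -EPIPE if @sk has no psock
 * attached.
 */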
int tcp_bpf_sendmsg_redir(struct sock *sk, bool ingress,
			  struct sk_msg *msg, u32 bytes, int flags)
{
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock))
		return -EPIPE;

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_SYSCALL
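/* Wait up to @timeo for data to show up on either the psock ingress_msg
 * list or the regular sk_receive_queue. Returns non-zero when data is
 * available or the receive side has been shut down, zero on timeout.
 */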
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			     long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty_lockless(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

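/* Check whether the next queued sk_msg is an empty message whose skb
 * carries a FIN, i.e. the peer closed the connection and there is no
 * more payload to deliver.
 */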
static bool is_next_msg_fin(struct sk_psock *psock)
{
	struct scatterlist *sge;
	struct sk_msg *msg_rx;
	int i;

	msg_rx = sk_psock_peek_msg(psock);
	i = msg_rx->sg.start;
	sge = sk_msg_elem(msg_rx, i);
	if (!sge->length) {
		struct sk_buff *skb = msg_rx->skb;

		if (skb && TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN)
			return true;
	}
	return false;
}

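/* recvmsg handler used when a stream/skb verdict program is attached
 * (the TCP_BPF_RX and TCP_BPF_TXRX configs). Data is read from the
 * psock ingress queue rather than sk_receive_queue, and copied_seq is
 * advanced here so TCP's sequence accounting stays consistent.
 */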
static int tcp_bpf_recvmsg_parser(struct sock *sk,
				  struct msghdr *msg,
				  size_t len,
				  int flags,
				  int *addr_len)
{
	struct tcp_sock *tcp = tcp_sk(sk);
	int peek = flags & MSG_PEEK;
	u32 seq = tcp->copied_seq;
	struct sk_psock *psock;
	int copied = 0;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);

	lock_sock(sk);

	/* We may have received data on the sk_receive_queue pre-accept, in
	 * which case we cannot use read_skb in this context because we
	 * haven't assigned a sk_socket yet and so have no link to the ops.
	 * The work-around is to check the sk_receive_queue and in these
	 * cases read skbs off the queue again. The read_skb hook is not
	 * running at this point because of lock_sock, so we avoid having
	 * multiple runners in read_skb.
	 */
	if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) {
		tcp_data_ready(sk);
		/* This handles the ENOMEM errors if we both receive data
		 * pre-accept and are already under memory pressure. At least
		 * let the user know to retry.
		 */
		if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) {
			copied = -EAGAIN;
			goto out;
		}
	}

msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	/* The typical case for EFAULT is that the socket was gracefully shut
	 * down with a FIN packet, so check for that here. The other case is
	 * some error on copy_page_to_iter, which would be unexpected.
	 * On a FIN, the correct return code is zero.
	 */
	if (copied == -EFAULT) {
		bool is_fin = is_next_msg_fin(psock);

		if (is_fin) {
			copied = 0;
			seq++;
			goto out;
		}
	}
	seq += copied;
	if (!copied) {
		long timeo;
		int data;

		if (sock_flag(sk, SOCK_DONE))
			goto out;

		if (sk->sk_err) {
			copied = sock_error(sk);
			goto out;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			goto out;

		if (sk->sk_state == TCP_CLOSE) {
			copied = -ENOTCONN;
			goto out;
		}

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		if (!timeo) {
			copied = -EAGAIN;
			goto out;
		}

		if (signal_pending(current)) {
			copied = sock_intr_errno(timeo);
			goto out;
		}

		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data && !sk_psock_queue_empty(psock))
			goto msg_bytes_ready;
		copied = -EAGAIN;
	}
out:
	if (!peek)
		WRITE_ONCE(tcp->copied_seq, seq);
	tcp_rcv_space_adjust(sk);
	if (copied > 0)
		__tcp_cleanup_rbuf(sk, copied);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied;
}

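/* recvmsg handler used when no stream/skb verdict program is attached
 * (the TCP_BPF_BASE and TCP_BPF_TX configs): prefer data queued on the
 * psock, but fall back to tcp_recvmsg() when the psock queue is empty
 * and the normal receive queue has data.
 */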
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);

	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	}

	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, flags, addr_len);
		}
		copied = -EAGAIN;
	}
	ret = copied;
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

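/* Run the msg verdict program over @msg (honouring cork_bytes and
 * apply_bytes) and act on its result: __SK_PASS pushes the data out on
 * @sk, __SK_REDIRECT hands it to the selected socket, and __SK_DROP
 * frees it and returns -EACCES. Called with the socket lock held; the
 * lock is dropped around the redirect, which takes the target socket's
 * own lock.
 */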
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg), redir_ingress;
	struct sock *sk_redir;
	u32 tosend, origsize, sent, delta = 0;
	u32 eval;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track the delta in msg size so it can be added to or
		 * subtracted from the copied size returned to the user on
		 * SK_DROP. This ensures the user doesn't get a positive
		 * return code with msg_cut_data and an SK_DROP verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;
	eval = __SK_NONE;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		redir_ingress = psock->redir_ingress;
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (!psock->apply_bytes) {
			/* Clean up before releasing the sock lock. */
			eval = psock->eval;
			psock->eval = __SK_NONE;
			psock->sk_redir = NULL;
		}
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);

		origsize = msg->sg.size;
		ret = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
					    msg, tosend, flags);
		sent = origsize - msg->sg.size;

		if (eval == __SK_REDIRECT)
			sock_put(sk_redir);

		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length) {
			if (eval == __SK_REDIRECT)
				sk_mem_charge(sk, tosend - sent);
			goto more_data;
		}
	}
	return ret;
}

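/* sendmsg handler installed for the TX configs. User data is copied
 * into an sk_msg (the pending cork buffer if there is one), optionally
 * accumulated until cork_bytes is satisfied, and then handed to
 * tcp_bpf_send_verdict() to run the msg verdict program.
 */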
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

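/* sendpage handler installed alongside tcp_bpf_sendmsg. The page is
 * added to an sk_msg by reference (no copy), charged to the socket and
 * pushed through the same cork/verdict logic as sendmsg.
 */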
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_RX,
	TCP_BPF_TXRX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

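/* Build the per-config proto variants for one address family. Each
 * variant starts from the kernel's base proto and overrides only the
 * callbacks sockmap needs: close/destroy/recvmsg/sock_is_readable for
 * every config, sendmsg/sendpage for the TX configs, and the
 * parser-aware recvmsg for the RX configs.
 */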
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE]			= *base;
	prot[TCP_BPF_BASE].destroy		= sock_map_destroy;
	prot[TCP_BPF_BASE].close		= sock_map_close;
	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].sock_is_readable	= sk_msg_is_readable;

	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;

	prot[TCP_BPF_RX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_RX].recvmsg		= tcp_bpf_recvmsg_parser;

	prot[TCP_BPF_TXRX]			= prot[TCP_BPF_TX];
	prot[TCP_BPF_TXRX].recvmsg		= tcp_bpf_recvmsg_parser;
}

static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg  == tcp_recvmsg &&
	       ops->sendmsg  == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

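/* Switch @sk between its original proto and the matching entry in
 * tcp_bpf_prots. The config is derived from the programs attached to
 * @psock (msg parser => TX, stream/skb verdict => RX, both => TXRX);
 * with @restore set the saved callbacks are put back instead.
 */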
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
		config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
	}

	if (restore) {
		if (inet_csk_has_ulp(sk)) {
			/* TLS does not have an unhash proto in SW cases,
			 * but we need to ensure we stop using the sock_map
			 * unhash routine because the associated psock is being
			 * removed. So use the original unhash handler.
			 */
			WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
		} else {
			sk->sk_write_space = psock->saved_write_space;
			/* Pairs with lockless read in sk_clone_lock() */
			sock_replace_proto(sk, psock->sk_proto);
		}
		return 0;
	}

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return -EINVAL;

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	/* Pairs with lockless read in sk_clone_lock() */
	sock_replace_proto(sk, &tcp_bpf_prots[family][config]);
	return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	struct proto *prot = newsk->sk_prot;

	if (is_insidevar(prot, tcp_bpf_prots))
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */