udp_bpf.c 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. // SPDX-License-Identifier: GPL-2.0
  2. /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
#include <linux/sched/signal.h>
#include <linux/skmsg.h>
#include <net/sock.h>
#include <net/udp.h>
#include <net/inet_common.h>
#include "udp_impl.h"
  8. static struct proto *udpv6_prot_saved __read_mostly;
  9. static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
  10. int flags, int *addr_len)
  11. {
  12. #if IS_ENABLED(CONFIG_IPV6)
  13. if (sk->sk_family == AF_INET6)
  14. return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len);
  15. #endif
  16. return udp_prot.recvmsg(sk, msg, len, flags, addr_len);
  17. }
  18. static bool udp_sk_has_data(struct sock *sk)
  19. {
  20. return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
  21. !skb_queue_empty(&sk->sk_receive_queue);
  22. }
  23. static bool psock_has_data(struct sk_psock *psock)
  24. {
  25. return !skb_queue_empty(&psock->ingress_skb) ||
  26. !sk_psock_queue_empty(psock);
  27. }
  28. #define udp_msg_has_data(__sk, __psock) \
  29. ({ udp_sk_has_data(__sk) || psock_has_data(__psock); })
  30. static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
  31. long timeo)
  32. {
  33. DEFINE_WAIT_FUNC(wait, woken_wake_function);
  34. int ret = 0;
  35. if (sk->sk_shutdown & RCV_SHUTDOWN)
  36. return 1;
  37. if (!timeo)
  38. return ret;
  39. add_wait_queue(sk_sleep(sk), &wait);
  40. sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  41. ret = udp_msg_has_data(sk, psock);
  42. if (!ret) {
  43. wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
  44. ret = udp_msg_has_data(sk, psock);
  45. }
  46. sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  47. remove_wait_queue(sk_sleep(sk), &wait);
  48. return ret;
  49. }
  50. static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
  51. int flags, int *addr_len)
  52. {
  53. struct sk_psock *psock;
  54. int copied, ret;
  55. if (unlikely(flags & MSG_ERRQUEUE))
  56. return inet_recv_error(sk, msg, len, addr_len);
  57. if (!len)
  58. return 0;
  59. psock = sk_psock_get(sk);
  60. if (unlikely(!psock))
  61. return sk_udp_recvmsg(sk, msg, len, flags, addr_len);
  62. if (!psock_has_data(psock)) {
  63. ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
  64. goto out;
  65. }
  66. msg_bytes_ready:
  67. copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
  68. if (!copied) {
  69. long timeo;
  70. int data;
  71. timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
  72. data = udp_msg_wait_data(sk, psock, timeo);
  73. if (data) {
  74. if (psock_has_data(psock))
  75. goto msg_bytes_ready;
  76. ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
  77. goto out;
  78. }
  79. copied = -EAGAIN;
  80. }
  81. ret = copied;
  82. out:
  83. sk_psock_put(sk, psock);
  84. return ret;
  85. }
  86. enum {
  87. UDP_BPF_IPV4,
  88. UDP_BPF_IPV6,
  89. UDP_BPF_NUM_PROTS,
  90. };
  91. static DEFINE_SPINLOCK(udpv6_prot_lock);
  92. static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
  93. static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
  94. {
  95. *prot = *base;
  96. prot->close = sock_map_close;
  97. prot->recvmsg = udp_bpf_recvmsg;
  98. prot->sock_is_readable = sk_msg_is_readable;
  99. }
  100. static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
  101. {
  102. if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
  103. spin_lock_bh(&udpv6_prot_lock);
  104. if (likely(ops != udpv6_prot_saved)) {
  105. udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
  106. smp_store_release(&udpv6_prot_saved, ops);
  107. }
  108. spin_unlock_bh(&udpv6_prot_lock);
  109. }
  110. }
  111. static int __init udp_bpf_v4_build_proto(void)
  112. {
  113. udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
  114. return 0;
  115. }
  116. late_initcall(udp_bpf_v4_build_proto);
  117. int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
  118. {
  119. int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
  120. if (restore) {
  121. sk->sk_write_space = psock->saved_write_space;
  122. sock_replace_proto(sk, psock->sk_proto);
  123. return 0;
  124. }
  125. if (sk->sk_family == AF_INET6)
  126. udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
  127. sock_replace_proto(sk, &udp_bpf_prots[family]);
  128. return 0;
  129. }
  130. EXPORT_SYMBOL_GPL(udp_bpf_update_proto);