psock_snd.c 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. // SPDX-License-Identifier: GPL-2.0
  2. #define _GNU_SOURCE
  3. #include <arpa/inet.h>
  4. #include <errno.h>
  5. #include <error.h>
  6. #include <fcntl.h>
  7. #include <limits.h>
  8. #include <linux/filter.h>
  9. #include <linux/bpf.h>
  10. #include <linux/if_packet.h>
  11. #include <linux/if_vlan.h>
  12. #include <linux/virtio_net.h>
  13. #include <net/if.h>
  14. #include <net/ethernet.h>
  15. #include <netinet/ip.h>
  16. #include <netinet/udp.h>
  17. #include <poll.h>
  18. #include <sched.h>
  19. #include <stdbool.h>
  20. #include <stdint.h>
  21. #include <stdio.h>
  22. #include <stdlib.h>
  23. #include <string.h>
  24. #include <sys/mman.h>
  25. #include <sys/socket.h>
  26. #include <sys/stat.h>
  27. #include <sys/types.h>
  28. #include <unistd.h>
  29. #include "psock_lib.h"
  30. static bool cfg_use_bind;
  31. static bool cfg_use_csum_off;
  32. static bool cfg_use_csum_off_bad;
  33. static bool cfg_use_dgram;
  34. static bool cfg_use_gso;
  35. static bool cfg_use_qdisc_bypass;
  36. static bool cfg_use_vlan;
  37. static bool cfg_use_vnet;
  38. static char *cfg_ifname = "lo";
  39. static int cfg_mtu = 1500;
  40. static int cfg_payload_len = DATA_LEN;
  41. static int cfg_truncate_len = INT_MAX;
  42. static uint16_t cfg_port = 8000;
  43. /* test sending up to max mtu + 1 */
  44. #define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
  45. static char tbuf[TEST_SZ], rbuf[TEST_SZ];
  /*
   * Add up @num_u16 consecutive 16-bit words starting at @start.
   * Carries are NOT folded; the caller folds the returned raw sum.
   * A non-positive @num_u16 yields 0.
   */
static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
{
	const uint16_t *word = start;
	unsigned long total = 0;
	int left = num_u16;

	while (left > 0) {
		total += *word++;
		left--;
	}

	return total;
}
  /*
   * Internet checksum over @num_u16 16-bit words at @start, seeded with the
   * partial sum @sum. Carries above bit 15 are folded back in until the
   * value fits in 16 bits; the one's complement of that is returned.
   */
static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
			      unsigned long sum)
{
	unsigned long acc = sum + add_csum_hword(start, num_u16);

	/* end-around carry: repeat until nothing remains above bit 15 */
	for (; acc > 0xffff; acc = (acc & 0xffff) + (acc >> 16))
		;

	return (uint16_t)~acc;
}
  62. static int build_vnet_header(void *header)
  63. {
  64. struct virtio_net_hdr *vh = header;
  65. vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
  66. if (cfg_use_csum_off) {
  67. vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
  68. vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
  69. vh->csum_offset = __builtin_offsetof(struct udphdr, check);
  70. /* position check field exactly one byte beyond end of packet */
  71. if (cfg_use_csum_off_bad)
  72. vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
  73. vh->csum_offset - 1;
  74. }
  75. if (cfg_use_gso) {
  76. vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
  77. vh->gso_size = cfg_mtu - sizeof(struct iphdr);
  78. }
  79. return sizeof(*vh);
  80. }
  81. static int build_eth_header(void *header)
  82. {
  83. struct ethhdr *eth = header;
  84. if (cfg_use_vlan) {
  85. uint16_t *tag = header + ETH_HLEN;
  86. eth->h_proto = htons(ETH_P_8021Q);
  87. tag[1] = htons(ETH_P_IP);
  88. return ETH_HLEN + 4;
  89. }
  90. eth->h_proto = htons(ETH_P_IP);
  91. return ETH_HLEN;
  92. }
  /*
   * Write a 20-byte IPv4 header (no options) at @header for a UDP datagram
   * carrying @payload_len bytes of payload, from 172.17.0.2 to 172.17.0.1
   * with TTL 8, and fill in its header checksum.
   *
   * Every field the checksum covers is assigned here, so the result does
   * not depend on what the buffer held before.
   *
   * Returns the header length in bytes.
   */
static int build_ipv4_header(void *header, int payload_len)
{
	struct iphdr *iph = header;

	iph->ihl = 5;
	iph->version = 4;
	iph->tos = 0;
	iph->ttl = 8;
	iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
	iph->id = htons(1337);
	iph->frag_off = 0;
	iph->protocol = IPPROTO_UDP;
	/*
	 * 172u: 172 << 24 does not fit in a signed int, and a signed left
	 * shift that overflows is undefined behavior.
	 */
	iph->saddr = htonl((172u << 24) | (17 << 16) | 2);
	iph->daddr = htonl((172u << 24) | (17 << 16) | 1);

	/* the header checksum is computed with its own field set to zero */
	iph->check = 0;
	iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);

	return iph->ihl << 2;
}
  107. static int build_udp_header(void *header, int payload_len)
  108. {
  109. const int alen = sizeof(uint32_t);
  110. struct udphdr *udph = header;
  111. int len = sizeof(*udph) + payload_len;
  112. udph->source = htons(9);
  113. udph->dest = htons(cfg_port);
  114. udph->len = htons(len);
  115. if (cfg_use_csum_off)
  116. udph->check = build_ip_csum(header - (2 * alen), alen,
  117. htons(IPPROTO_UDP) + udph->len);
  118. else
  119. udph->check = 0;
  120. return sizeof(*udph);
  121. }
  122. static int build_packet(int payload_len)
  123. {
  124. int off = 0;
  125. off += build_vnet_header(tbuf);
  126. off += build_eth_header(tbuf + off);
  127. off += build_ipv4_header(tbuf + off, payload_len);
  128. off += build_udp_header(tbuf + off, payload_len);
  129. if (off + payload_len > sizeof(tbuf))
  130. error(1, 0, "payload length exceeds max");
  131. memset(tbuf + off, DATA_CHAR, payload_len);
  132. return off + payload_len;
  133. }
  134. static void do_bind(int fd)
  135. {
  136. struct sockaddr_ll laddr = {0};
  137. laddr.sll_family = AF_PACKET;
  138. laddr.sll_protocol = htons(ETH_P_IP);
  139. laddr.sll_ifindex = if_nametoindex(cfg_ifname);
  140. if (!laddr.sll_ifindex)
  141. error(1, errno, "if_nametoindex");
  142. if (bind(fd, (void *)&laddr, sizeof(laddr)))
  143. error(1, errno, "bind");
  144. }
  145. static void do_send(int fd, char *buf, int len)
  146. {
  147. int ret;
  148. if (!cfg_use_vnet) {
  149. buf += sizeof(struct virtio_net_hdr);
  150. len -= sizeof(struct virtio_net_hdr);
  151. }
  152. if (cfg_use_dgram) {
  153. buf += ETH_HLEN;
  154. len -= ETH_HLEN;
  155. }
  156. if (cfg_use_bind) {
  157. ret = write(fd, buf, len);
  158. } else {
  159. struct sockaddr_ll laddr = {0};
  160. laddr.sll_protocol = htons(ETH_P_IP);
  161. laddr.sll_ifindex = if_nametoindex(cfg_ifname);
  162. if (!laddr.sll_ifindex)
  163. error(1, errno, "if_nametoindex");
  164. ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
  165. }
  166. if (ret == -1)
  167. error(1, errno, "write");
  168. if (ret != len)
  169. error(1, 0, "write: %u %u", ret, len);
  170. fprintf(stderr, "tx: %u\n", ret);
  171. }
  172. static int do_tx(void)
  173. {
  174. const int one = 1;
  175. int fd, len;
  176. fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
  177. if (fd == -1)
  178. error(1, errno, "socket t");
  179. if (cfg_use_bind)
  180. do_bind(fd);
  181. if (cfg_use_qdisc_bypass &&
  182. setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
  183. error(1, errno, "setsockopt qdisc bypass");
  184. if (cfg_use_vnet &&
  185. setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
  186. error(1, errno, "setsockopt vnet");
  187. len = build_packet(cfg_payload_len);
  188. if (cfg_truncate_len < len)
  189. len = cfg_truncate_len;
  190. do_send(fd, tbuf, len);
  191. if (close(fd))
  192. error(1, errno, "close t");
  193. return len;
  194. }
  195. static int setup_rx(void)
  196. {
  197. struct timeval tv = { .tv_usec = 100 * 1000 };
  198. struct sockaddr_in raddr = {0};
  199. int fd;
  200. fd = socket(PF_INET, SOCK_DGRAM, 0);
  201. if (fd == -1)
  202. error(1, errno, "socket r");
  203. if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
  204. error(1, errno, "setsockopt rcv timeout");
  205. raddr.sin_family = AF_INET;
  206. raddr.sin_port = htons(cfg_port);
  207. raddr.sin_addr.s_addr = htonl(INADDR_ANY);
  208. if (bind(fd, (void *)&raddr, sizeof(raddr)))
  209. error(1, errno, "bind r");
  210. return fd;
  211. }
  212. static void do_rx(int fd, int expected_len, char *expected)
  213. {
  214. int ret;
  215. ret = recv(fd, rbuf, sizeof(rbuf), 0);
  216. if (ret == -1)
  217. error(1, errno, "recv");
  218. if (ret != expected_len)
  219. error(1, 0, "recv: %u != %u", ret, expected_len);
  220. if (memcmp(rbuf, expected, ret))
  221. error(1, 0, "recv: data mismatch");
  222. fprintf(stderr, "rx: %u\n", ret);
  223. }
  /*
   * Open a raw packet socket to observe the transmitted frame on cfg_ifname.
   * It has a 100 ms receive timeout. A filter is attached via
   * pair_udp_setfilter() from psock_lib.h before the socket is bound with
   * do_bind(). Returns the socket descriptor.
   */
static int setup_sniffer(void)
{
	const struct timeval timeout = { .tv_usec = 100 * 1000 };
	int sniff_fd = socket(PF_PACKET, SOCK_RAW, 0);

	if (sniff_fd == -1)
		error(1, errno, "socket p");

	if (setsockopt(sniff_fd, SOL_SOCKET, SO_RCVTIMEO,
		       &timeout, sizeof(timeout)))
		error(1, errno, "setsockopt rcv timeout");

	pair_udp_setfilter(sniff_fd);
	do_bind(sniff_fd);

	return sniff_fd;
}
  237. static void parse_opts(int argc, char **argv)
  238. {
  239. int c;
  240. while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
  241. switch (c) {
  242. case 'b':
  243. cfg_use_bind = true;
  244. break;
  245. case 'c':
  246. cfg_use_csum_off = true;
  247. break;
  248. case 'C':
  249. cfg_use_csum_off_bad = true;
  250. break;
  251. case 'd':
  252. cfg_use_dgram = true;
  253. break;
  254. case 'g':
  255. cfg_use_gso = true;
  256. break;
  257. case 'l':
  258. cfg_payload_len = strtoul(optarg, NULL, 0);
  259. break;
  260. case 'q':
  261. cfg_use_qdisc_bypass = true;
  262. break;
  263. case 't':
  264. cfg_truncate_len = strtoul(optarg, NULL, 0);
  265. break;
  266. case 'v':
  267. cfg_use_vnet = true;
  268. break;
  269. case 'V':
  270. cfg_use_vlan = true;
  271. break;
  272. default:
  273. error(1, 0, "%s: parse error", argv[0]);
  274. }
  275. }
  276. if (cfg_use_vlan && cfg_use_dgram)
  277. error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
  278. if (cfg_use_csum_off && !cfg_use_vnet)
  279. error(1, 0, "option csum offload (-c) requires vnet (-v)");
  280. if (cfg_use_csum_off_bad && !cfg_use_csum_off)
  281. error(1, 0, "option csum bad (-C) requires csum offload (-c)");
  282. if (cfg_use_gso && !cfg_use_csum_off)
  283. error(1, 0, "option gso (-g) requires csum offload (-c)");
  284. }
  285. static void run_test(void)
  286. {
  287. int fdr, fds, total_len;
  288. fdr = setup_rx();
  289. fds = setup_sniffer();
  290. total_len = do_tx();
  291. /* BPF filter accepts only this length, vlan changes MAC */
  292. if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
  293. do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
  294. tbuf + sizeof(struct virtio_net_hdr));
  295. do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
  296. if (close(fds))
  297. error(1, errno, "close s");
  298. if (close(fdr))
  299. error(1, errno, "close r");
  300. }
  /*
   * Run one setup shell command. @what labels the step in error messages.
   * system() sets errno only when it returns -1. Any other non-zero value
   * is the command's own failure status and is reported without errno.
   */
static void run_setup_cmd(const char *what, const char *cmd)
{
	int status = system(cmd);

	if (status == -1)
		error(1, errno, "%s", what);
	if (status != 0)
		error(1, 0, "%s: failed with status %d", what, status);
}

/*
 * Prepare lo for the test, then send and verify one frame.
 * Setup sets the MTU to 1500 (the default cfg_mtu), adds 172.17.0.1/24
 * (the destination build_ipv4_header() uses) and enables accept_local.
 * Prints "OK" and returns 0 on success; every failure exits via error().
 */
int main(int argc, char **argv)
{
	parse_opts(argc, argv);

	run_setup_cmd("ip link set mtu", "ip link set dev lo mtu 1500");
	run_setup_cmd("ip addr add", "ip addr add dev lo 172.17.0.1/24");
	run_setup_cmd("sysctl lo.accept_local",
		      "sysctl -w net.ipv4.conf.lo.accept_local=1");

	run_test();

	fprintf(stderr, "OK\n\n");
	return 0;
}