xdp_synproxy.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. // SPDX-License-Identifier: LGPL-2.1 OR BSD-2-Clause
  2. /* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. */
  3. #include <stdnoreturn.h>
  4. #include <stdlib.h>
  5. #include <stdio.h>
  6. #include <string.h>
  7. #include <errno.h>
  8. #include <unistd.h>
  9. #include <getopt.h>
  10. #include <signal.h>
  11. #include <sys/types.h>
  12. #include <bpf/bpf.h>
  13. #include <bpf/libbpf.h>
  14. #include <net/if.h>
  15. #include <linux/if_link.h>
  16. #include <linux/limits.h>
  17. static unsigned int ifindex;
  18. static __u32 attached_prog_id;
  19. static bool attached_tc;
  20. static void noreturn cleanup(int sig)
  21. {
  22. LIBBPF_OPTS(bpf_xdp_attach_opts, opts);
  23. int prog_fd;
  24. int err;
  25. if (attached_prog_id == 0)
  26. exit(0);
  27. if (attached_tc) {
  28. LIBBPF_OPTS(bpf_tc_hook, hook,
  29. .ifindex = ifindex,
  30. .attach_point = BPF_TC_INGRESS);
  31. err = bpf_tc_hook_destroy(&hook);
  32. if (err < 0) {
  33. fprintf(stderr, "Error: bpf_tc_hook_destroy: %s\n", strerror(-err));
  34. fprintf(stderr, "Failed to destroy the TC hook\n");
  35. exit(1);
  36. }
  37. exit(0);
  38. }
  39. prog_fd = bpf_prog_get_fd_by_id(attached_prog_id);
  40. if (prog_fd < 0) {
  41. fprintf(stderr, "Error: bpf_prog_get_fd_by_id: %s\n", strerror(-prog_fd));
  42. err = bpf_xdp_attach(ifindex, -1, 0, NULL);
  43. if (err < 0) {
  44. fprintf(stderr, "Error: bpf_set_link_xdp_fd: %s\n", strerror(-err));
  45. fprintf(stderr, "Failed to detach XDP program\n");
  46. exit(1);
  47. }
  48. } else {
  49. opts.old_prog_fd = prog_fd;
  50. err = bpf_xdp_attach(ifindex, -1, XDP_FLAGS_REPLACE, &opts);
  51. close(prog_fd);
  52. if (err < 0) {
  53. fprintf(stderr, "Error: bpf_set_link_xdp_fd_opts: %s\n", strerror(-err));
  54. /* Not an error if already replaced by someone else. */
  55. if (err != -EEXIST) {
  56. fprintf(stderr, "Failed to detach XDP program\n");
  57. exit(1);
  58. }
  59. }
  60. }
  61. exit(0);
  62. }
  63. static noreturn void usage(const char *progname)
  64. {
  65. fprintf(stderr, "Usage: %s [--iface <iface>|--prog <prog_id>] [--mss4 <mss ipv4> --mss6 <mss ipv6> --wscale <wscale> --ttl <ttl>] [--ports <port1>,<port2>,...] [--single] [--tc]\n",
  66. progname);
  67. exit(1);
  68. }
  69. static unsigned long parse_arg_ul(const char *progname, const char *arg, unsigned long limit)
  70. {
  71. unsigned long res;
  72. char *endptr;
  73. errno = 0;
  74. res = strtoul(arg, &endptr, 10);
  75. if (errno != 0 || *endptr != '\0' || arg[0] == '\0' || res > limit)
  76. usage(progname);
  77. return res;
  78. }
  79. static void parse_options(int argc, char *argv[], unsigned int *ifindex, __u32 *prog_id,
  80. __u64 *tcpipopts, char **ports, bool *single, bool *tc)
  81. {
  82. static struct option long_options[] = {
  83. { "help", no_argument, NULL, 'h' },
  84. { "iface", required_argument, NULL, 'i' },
  85. { "prog", required_argument, NULL, 'x' },
  86. { "mss4", required_argument, NULL, 4 },
  87. { "mss6", required_argument, NULL, 6 },
  88. { "wscale", required_argument, NULL, 'w' },
  89. { "ttl", required_argument, NULL, 't' },
  90. { "ports", required_argument, NULL, 'p' },
  91. { "single", no_argument, NULL, 's' },
  92. { "tc", no_argument, NULL, 'c' },
  93. { NULL, 0, NULL, 0 },
  94. };
  95. unsigned long mss4, wscale, ttl;
  96. unsigned long long mss6;
  97. unsigned int tcpipopts_mask = 0;
  98. if (argc < 2)
  99. usage(argv[0]);
  100. *ifindex = 0;
  101. *prog_id = 0;
  102. *tcpipopts = 0;
  103. *ports = NULL;
  104. *single = false;
  105. *tc = false;
  106. while (true) {
  107. int opt;
  108. opt = getopt_long(argc, argv, "", long_options, NULL);
  109. if (opt == -1)
  110. break;
  111. switch (opt) {
  112. case 'h':
  113. usage(argv[0]);
  114. break;
  115. case 'i':
  116. *ifindex = if_nametoindex(optarg);
  117. if (*ifindex == 0)
  118. usage(argv[0]);
  119. break;
  120. case 'x':
  121. *prog_id = parse_arg_ul(argv[0], optarg, UINT32_MAX);
  122. if (*prog_id == 0)
  123. usage(argv[0]);
  124. break;
  125. case 4:
  126. mss4 = parse_arg_ul(argv[0], optarg, UINT16_MAX);
  127. tcpipopts_mask |= 1 << 0;
  128. break;
  129. case 6:
  130. mss6 = parse_arg_ul(argv[0], optarg, UINT16_MAX);
  131. tcpipopts_mask |= 1 << 1;
  132. break;
  133. case 'w':
  134. wscale = parse_arg_ul(argv[0], optarg, 14);
  135. tcpipopts_mask |= 1 << 2;
  136. break;
  137. case 't':
  138. ttl = parse_arg_ul(argv[0], optarg, UINT8_MAX);
  139. tcpipopts_mask |= 1 << 3;
  140. break;
  141. case 'p':
  142. *ports = optarg;
  143. break;
  144. case 's':
  145. *single = true;
  146. break;
  147. case 'c':
  148. *tc = true;
  149. break;
  150. default:
  151. usage(argv[0]);
  152. }
  153. }
  154. if (optind < argc)
  155. usage(argv[0]);
  156. if (tcpipopts_mask == 0xf) {
  157. if (mss4 == 0 || mss6 == 0 || wscale == 0 || ttl == 0)
  158. usage(argv[0]);
  159. *tcpipopts = (mss6 << 32) | (ttl << 24) | (wscale << 16) | mss4;
  160. } else if (tcpipopts_mask != 0) {
  161. usage(argv[0]);
  162. }
  163. if (*ifindex != 0 && *prog_id != 0)
  164. usage(argv[0]);
  165. if (*ifindex == 0 && *prog_id == 0)
  166. usage(argv[0]);
  167. }
  168. static int syncookie_attach(const char *argv0, unsigned int ifindex, bool tc)
  169. {
  170. struct bpf_prog_info info = {};
  171. __u32 info_len = sizeof(info);
  172. char xdp_filename[PATH_MAX];
  173. struct bpf_program *prog;
  174. struct bpf_object *obj;
  175. int prog_fd;
  176. int err;
  177. snprintf(xdp_filename, sizeof(xdp_filename), "%s_kern.bpf.o", argv0);
  178. obj = bpf_object__open_file(xdp_filename, NULL);
  179. err = libbpf_get_error(obj);
  180. if (err < 0) {
  181. fprintf(stderr, "Error: bpf_object__open_file: %s\n", strerror(-err));
  182. return err;
  183. }
  184. err = bpf_object__load(obj);
  185. if (err < 0) {
  186. fprintf(stderr, "Error: bpf_object__open_file: %s\n", strerror(-err));
  187. return err;
  188. }
  189. prog = bpf_object__find_program_by_name(obj, tc ? "syncookie_tc" : "syncookie_xdp");
  190. if (!prog) {
  191. fprintf(stderr, "Error: bpf_object__find_program_by_name: program was not found\n");
  192. return -ENOENT;
  193. }
  194. prog_fd = bpf_program__fd(prog);
  195. err = bpf_obj_get_info_by_fd(prog_fd, &info, &info_len);
  196. if (err < 0) {
  197. fprintf(stderr, "Error: bpf_obj_get_info_by_fd: %s\n", strerror(-err));
  198. goto out;
  199. }
  200. attached_tc = tc;
  201. attached_prog_id = info.id;
  202. signal(SIGINT, cleanup);
  203. signal(SIGTERM, cleanup);
  204. if (tc) {
  205. LIBBPF_OPTS(bpf_tc_hook, hook,
  206. .ifindex = ifindex,
  207. .attach_point = BPF_TC_INGRESS);
  208. LIBBPF_OPTS(bpf_tc_opts, opts,
  209. .handle = 1,
  210. .priority = 1,
  211. .prog_fd = prog_fd);
  212. err = bpf_tc_hook_create(&hook);
  213. if (err < 0) {
  214. fprintf(stderr, "Error: bpf_tc_hook_create: %s\n",
  215. strerror(-err));
  216. goto fail;
  217. }
  218. err = bpf_tc_attach(&hook, &opts);
  219. if (err < 0) {
  220. fprintf(stderr, "Error: bpf_tc_attach: %s\n",
  221. strerror(-err));
  222. goto fail;
  223. }
  224. } else {
  225. err = bpf_xdp_attach(ifindex, prog_fd,
  226. XDP_FLAGS_UPDATE_IF_NOEXIST, NULL);
  227. if (err < 0) {
  228. fprintf(stderr, "Error: bpf_set_link_xdp_fd: %s\n",
  229. strerror(-err));
  230. goto fail;
  231. }
  232. }
  233. err = 0;
  234. out:
  235. bpf_object__close(obj);
  236. return err;
  237. fail:
  238. signal(SIGINT, SIG_DFL);
  239. signal(SIGTERM, SIG_DFL);
  240. attached_prog_id = 0;
  241. goto out;
  242. }
  243. static int syncookie_open_bpf_maps(__u32 prog_id, int *values_map_fd, int *ports_map_fd)
  244. {
  245. struct bpf_prog_info prog_info;
  246. __u32 map_ids[8];
  247. __u32 info_len;
  248. int prog_fd;
  249. int err;
  250. int i;
  251. *values_map_fd = -1;
  252. *ports_map_fd = -1;
  253. prog_fd = bpf_prog_get_fd_by_id(prog_id);
  254. if (prog_fd < 0) {
  255. fprintf(stderr, "Error: bpf_prog_get_fd_by_id: %s\n", strerror(-prog_fd));
  256. return prog_fd;
  257. }
  258. prog_info = (struct bpf_prog_info) {
  259. .nr_map_ids = 8,
  260. .map_ids = (__u64)(unsigned long)map_ids,
  261. };
  262. info_len = sizeof(prog_info);
  263. err = bpf_obj_get_info_by_fd(prog_fd, &prog_info, &info_len);
  264. if (err != 0) {
  265. fprintf(stderr, "Error: bpf_obj_get_info_by_fd: %s\n", strerror(-err));
  266. goto out;
  267. }
  268. if (prog_info.nr_map_ids < 2) {
  269. fprintf(stderr, "Error: Found %u BPF maps, expected at least 2\n",
  270. prog_info.nr_map_ids);
  271. err = -ENOENT;
  272. goto out;
  273. }
  274. for (i = 0; i < prog_info.nr_map_ids; i++) {
  275. struct bpf_map_info map_info = {};
  276. int map_fd;
  277. err = bpf_map_get_fd_by_id(map_ids[i]);
  278. if (err < 0) {
  279. fprintf(stderr, "Error: bpf_map_get_fd_by_id: %s\n", strerror(-err));
  280. goto err_close_map_fds;
  281. }
  282. map_fd = err;
  283. info_len = sizeof(map_info);
  284. err = bpf_obj_get_info_by_fd(map_fd, &map_info, &info_len);
  285. if (err != 0) {
  286. fprintf(stderr, "Error: bpf_obj_get_info_by_fd: %s\n", strerror(-err));
  287. close(map_fd);
  288. goto err_close_map_fds;
  289. }
  290. if (strcmp(map_info.name, "values") == 0) {
  291. *values_map_fd = map_fd;
  292. continue;
  293. }
  294. if (strcmp(map_info.name, "allowed_ports") == 0) {
  295. *ports_map_fd = map_fd;
  296. continue;
  297. }
  298. close(map_fd);
  299. }
  300. if (*values_map_fd != -1 && *ports_map_fd != -1) {
  301. err = 0;
  302. goto out;
  303. }
  304. err = -ENOENT;
  305. err_close_map_fds:
  306. if (*values_map_fd != -1)
  307. close(*values_map_fd);
  308. if (*ports_map_fd != -1)
  309. close(*ports_map_fd);
  310. *values_map_fd = -1;
  311. *ports_map_fd = -1;
  312. out:
  313. close(prog_fd);
  314. return err;
  315. }
  316. int main(int argc, char *argv[])
  317. {
  318. int values_map_fd, ports_map_fd;
  319. __u64 tcpipopts;
  320. bool firstiter;
  321. __u64 prevcnt;
  322. __u32 prog_id;
  323. char *ports;
  324. bool single;
  325. int err = 0;
  326. bool tc;
  327. parse_options(argc, argv, &ifindex, &prog_id, &tcpipopts, &ports,
  328. &single, &tc);
  329. if (prog_id == 0) {
  330. if (!tc) {
  331. err = bpf_xdp_query_id(ifindex, 0, &prog_id);
  332. if (err < 0) {
  333. fprintf(stderr, "Error: bpf_get_link_xdp_id: %s\n",
  334. strerror(-err));
  335. goto out;
  336. }
  337. }
  338. if (prog_id == 0) {
  339. err = syncookie_attach(argv[0], ifindex, tc);
  340. if (err < 0)
  341. goto out;
  342. prog_id = attached_prog_id;
  343. }
  344. }
  345. err = syncookie_open_bpf_maps(prog_id, &values_map_fd, &ports_map_fd);
  346. if (err < 0)
  347. goto out;
  348. if (ports) {
  349. __u16 port_last = 0;
  350. __u32 port_idx = 0;
  351. char *p = ports;
  352. fprintf(stderr, "Replacing allowed ports\n");
  353. while (p && *p != '\0') {
  354. char *token = strsep(&p, ",");
  355. __u16 port;
  356. port = parse_arg_ul(argv[0], token, UINT16_MAX);
  357. err = bpf_map_update_elem(ports_map_fd, &port_idx, &port, BPF_ANY);
  358. if (err != 0) {
  359. fprintf(stderr, "Error: bpf_map_update_elem: %s\n", strerror(-err));
  360. fprintf(stderr, "Failed to add port %u (index %u)\n",
  361. port, port_idx);
  362. goto out_close_maps;
  363. }
  364. fprintf(stderr, "Added port %u\n", port);
  365. port_idx++;
  366. }
  367. err = bpf_map_update_elem(ports_map_fd, &port_idx, &port_last, BPF_ANY);
  368. if (err != 0) {
  369. fprintf(stderr, "Error: bpf_map_update_elem: %s\n", strerror(-err));
  370. fprintf(stderr, "Failed to add the terminator value 0 (index %u)\n",
  371. port_idx);
  372. goto out_close_maps;
  373. }
  374. }
  375. if (tcpipopts) {
  376. __u32 key = 0;
  377. fprintf(stderr, "Replacing TCP/IP options\n");
  378. err = bpf_map_update_elem(values_map_fd, &key, &tcpipopts, BPF_ANY);
  379. if (err != 0) {
  380. fprintf(stderr, "Error: bpf_map_update_elem: %s\n", strerror(-err));
  381. goto out_close_maps;
  382. }
  383. }
  384. if ((ports || tcpipopts) && attached_prog_id == 0 && !single)
  385. goto out_close_maps;
  386. prevcnt = 0;
  387. firstiter = true;
  388. while (true) {
  389. __u32 key = 1;
  390. __u64 value;
  391. err = bpf_map_lookup_elem(values_map_fd, &key, &value);
  392. if (err != 0) {
  393. fprintf(stderr, "Error: bpf_map_lookup_elem: %s\n", strerror(-err));
  394. goto out_close_maps;
  395. }
  396. if (firstiter) {
  397. prevcnt = value;
  398. firstiter = false;
  399. }
  400. if (single) {
  401. printf("Total SYNACKs generated: %llu\n", value);
  402. break;
  403. }
  404. printf("SYNACKs generated: %llu (total %llu)\n", value - prevcnt, value);
  405. prevcnt = value;
  406. sleep(1);
  407. }
  408. out_close_maps:
  409. close(values_map_fd);
  410. close(ports_map_fd);
  411. out:
  412. return err == 0 ? 0 : 1;
  413. }