util.c 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * vsock test utilities
  4. *
  5. * Copyright (C) 2017 Red Hat, Inc.
  6. *
  7. * Author: Stefan Hajnoczi <[email protected]>
  8. */
  9. #include <errno.h>
  10. #include <stdio.h>
  11. #include <stdint.h>
  12. #include <stdlib.h>
  13. #include <signal.h>
  14. #include <unistd.h>
  15. #include <assert.h>
  16. #include <sys/epoll.h>
  17. #include "timeout.h"
  18. #include "control.h"
  19. #include "util.h"
  20. /* Install signal handlers */
  21. void init_signals(void)
  22. {
  23. struct sigaction act = {
  24. .sa_handler = sigalrm,
  25. };
  26. sigaction(SIGALRM, &act, NULL);
  27. signal(SIGPIPE, SIG_IGN);
  28. }
  /* Parse a CID given as a decimal string.
 *
 * Returns the parsed CID. Terminates the process with EXIT_FAILURE if @str
 * contains no digits, has trailing characters, overflows strtoul(), or
 * names a value that does not fit in the unsigned int return type.
 */
unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	/* endptr == str: no digits at all (e.g. ""), where strtoul() returns 0
	 * without setting errno.
	 * (unsigned int)n != n: the value would be silently truncated on
	 * return when unsigned long is wider than unsigned int.
	 */
	if (errno || endptr == str || *endptr != '\0' ||
	    (unsigned int)n != n) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
  42. /* Wait for the remote to close the connection */
  43. void vsock_wait_remote_close(int fd)
  44. {
  45. struct epoll_event ev;
  46. int epollfd, nfds;
  47. epollfd = epoll_create1(0);
  48. if (epollfd == -1) {
  49. perror("epoll_create1");
  50. exit(EXIT_FAILURE);
  51. }
  52. ev.events = EPOLLRDHUP | EPOLLHUP;
  53. ev.data.fd = fd;
  54. if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) {
  55. perror("epoll_ctl");
  56. exit(EXIT_FAILURE);
  57. }
  58. nfds = epoll_wait(epollfd, &ev, 1, TIMEOUT * 1000);
  59. if (nfds == -1) {
  60. perror("epoll_wait");
  61. exit(EXIT_FAILURE);
  62. }
  63. if (nfds == 0) {
  64. fprintf(stderr, "epoll_wait timed out\n");
  65. exit(EXIT_FAILURE);
  66. }
  67. assert(nfds == 1);
  68. assert(ev.events & (EPOLLRDHUP | EPOLLHUP));
  69. assert(ev.data.fd == fd);
  70. close(epollfd);
  71. }
  72. /* Connect to <cid, port> and return the file descriptor. */
  73. static int vsock_connect(unsigned int cid, unsigned int port, int type)
  74. {
  75. union {
  76. struct sockaddr sa;
  77. struct sockaddr_vm svm;
  78. } addr = {
  79. .svm = {
  80. .svm_family = AF_VSOCK,
  81. .svm_port = port,
  82. .svm_cid = cid,
  83. },
  84. };
  85. int ret;
  86. int fd;
  87. control_expectln("LISTENING");
  88. fd = socket(AF_VSOCK, type, 0);
  89. timeout_begin(TIMEOUT);
  90. do {
  91. ret = connect(fd, &addr.sa, sizeof(addr.svm));
  92. timeout_check("connect");
  93. } while (ret < 0 && errno == EINTR);
  94. timeout_end();
  95. if (ret < 0) {
  96. int old_errno = errno;
  97. close(fd);
  98. fd = -1;
  99. errno = old_errno;
  100. }
  101. return fd;
  102. }
  /* Open a SOCK_STREAM vsock connection to <cid, port>.
 * Returns the fd, or -1 with errno set (see vsock_connect()).
 */
int vsock_stream_connect(unsigned int cid, unsigned int port)
{
	const int type = SOCK_STREAM;

	return vsock_connect(cid, port, type);
}

/* Open a SOCK_SEQPACKET vsock connection to <cid, port>.
 * Returns the fd, or -1 with errno set (see vsock_connect()).
 */
int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
{
	const int type = SOCK_SEQPACKET;

	return vsock_connect(cid, port, type);
}
  111. /* Listen on <cid, port> and return the first incoming connection. The remote
  112. * address is stored to clientaddrp. clientaddrp may be NULL.
  113. */
  114. static int vsock_accept(unsigned int cid, unsigned int port,
  115. struct sockaddr_vm *clientaddrp, int type)
  116. {
  117. union {
  118. struct sockaddr sa;
  119. struct sockaddr_vm svm;
  120. } addr = {
  121. .svm = {
  122. .svm_family = AF_VSOCK,
  123. .svm_port = port,
  124. .svm_cid = cid,
  125. },
  126. };
  127. union {
  128. struct sockaddr sa;
  129. struct sockaddr_vm svm;
  130. } clientaddr;
  131. socklen_t clientaddr_len = sizeof(clientaddr.svm);
  132. int fd;
  133. int client_fd;
  134. int old_errno;
  135. fd = socket(AF_VSOCK, type, 0);
  136. if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
  137. perror("bind");
  138. exit(EXIT_FAILURE);
  139. }
  140. if (listen(fd, 1) < 0) {
  141. perror("listen");
  142. exit(EXIT_FAILURE);
  143. }
  144. control_writeln("LISTENING");
  145. timeout_begin(TIMEOUT);
  146. do {
  147. client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
  148. timeout_check("accept");
  149. } while (client_fd < 0 && errno == EINTR);
  150. timeout_end();
  151. old_errno = errno;
  152. close(fd);
  153. errno = old_errno;
  154. if (client_fd < 0)
  155. return client_fd;
  156. if (clientaddr_len != sizeof(clientaddr.svm)) {
  157. fprintf(stderr, "unexpected addrlen from accept(2), %zu\n",
  158. (size_t)clientaddr_len);
  159. exit(EXIT_FAILURE);
  160. }
  161. if (clientaddr.sa.sa_family != AF_VSOCK) {
  162. fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
  163. clientaddr.sa.sa_family);
  164. exit(EXIT_FAILURE);
  165. }
  166. if (clientaddrp)
  167. *clientaddrp = clientaddr.svm;
  168. return client_fd;
  169. }
  /* Accept one SOCK_STREAM connection on <cid, port>; the peer address is
 * written to @clientaddrp when non-NULL (see vsock_accept()).
 */
int vsock_stream_accept(unsigned int cid, unsigned int port,
			struct sockaddr_vm *clientaddrp)
{
	const int type = SOCK_STREAM;

	return vsock_accept(cid, port, clientaddrp, type);
}

/* Accept one SOCK_SEQPACKET connection on <cid, port>; the peer address is
 * written to @clientaddrp when non-NULL (see vsock_accept()).
 */
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
			   struct sockaddr_vm *clientaddrp)
{
	const int type = SOCK_SEQPACKET;

	return vsock_accept(cid, port, clientaddrp, type);
}
  180. /* Transmit one byte and check the return value.
  181. *
  182. * expected_ret:
  183. * <0 Negative errno (for testing errors)
  184. * 0 End-of-file
  185. * 1 Success
  186. */
  187. void send_byte(int fd, int expected_ret, int flags)
  188. {
  189. const uint8_t byte = 'A';
  190. ssize_t nwritten;
  191. timeout_begin(TIMEOUT);
  192. do {
  193. nwritten = send(fd, &byte, sizeof(byte), flags);
  194. timeout_check("write");
  195. } while (nwritten < 0 && errno == EINTR);
  196. timeout_end();
  197. if (expected_ret < 0) {
  198. if (nwritten != -1) {
  199. fprintf(stderr, "bogus send(2) return value %zd\n",
  200. nwritten);
  201. exit(EXIT_FAILURE);
  202. }
  203. if (errno != -expected_ret) {
  204. perror("write");
  205. exit(EXIT_FAILURE);
  206. }
  207. return;
  208. }
  209. if (nwritten < 0) {
  210. perror("write");
  211. exit(EXIT_FAILURE);
  212. }
  213. if (nwritten == 0) {
  214. if (expected_ret == 0)
  215. return;
  216. fprintf(stderr, "unexpected EOF while sending byte\n");
  217. exit(EXIT_FAILURE);
  218. }
  219. if (nwritten != sizeof(byte)) {
  220. fprintf(stderr, "bogus send(2) return value %zd\n", nwritten);
  221. exit(EXIT_FAILURE);
  222. }
  223. }
  224. /* Receive one byte and check the return value.
  225. *
  226. * expected_ret:
  227. * <0 Negative errno (for testing errors)
  228. * 0 End-of-file
  229. * 1 Success
  230. */
  231. void recv_byte(int fd, int expected_ret, int flags)
  232. {
  233. uint8_t byte;
  234. ssize_t nread;
  235. timeout_begin(TIMEOUT);
  236. do {
  237. nread = recv(fd, &byte, sizeof(byte), flags);
  238. timeout_check("read");
  239. } while (nread < 0 && errno == EINTR);
  240. timeout_end();
  241. if (expected_ret < 0) {
  242. if (nread != -1) {
  243. fprintf(stderr, "bogus recv(2) return value %zd\n",
  244. nread);
  245. exit(EXIT_FAILURE);
  246. }
  247. if (errno != -expected_ret) {
  248. perror("read");
  249. exit(EXIT_FAILURE);
  250. }
  251. return;
  252. }
  253. if (nread < 0) {
  254. perror("read");
  255. exit(EXIT_FAILURE);
  256. }
  257. if (nread == 0) {
  258. if (expected_ret == 0)
  259. return;
  260. fprintf(stderr, "unexpected EOF while receiving byte\n");
  261. exit(EXIT_FAILURE);
  262. }
  263. if (nread != sizeof(byte)) {
  264. fprintf(stderr, "bogus recv(2) return value %zd\n", nread);
  265. exit(EXIT_FAILURE);
  266. }
  267. if (byte != 'A') {
  268. fprintf(stderr, "unexpected byte read %c\n", byte);
  269. exit(EXIT_FAILURE);
  270. }
  271. }
  272. /* Run test cases. The program terminates if a failure occurs. */
  273. void run_tests(const struct test_case *test_cases,
  274. const struct test_opts *opts)
  275. {
  276. int i;
  277. for (i = 0; test_cases[i].name; i++) {
  278. void (*run)(const struct test_opts *opts);
  279. char *line;
  280. printf("%d - %s...", i, test_cases[i].name);
  281. fflush(stdout);
  282. /* Full barrier before executing the next test. This
  283. * ensures that client and server are executing the
  284. * same test case. In particular, it means whoever is
  285. * faster will not see the peer still executing the
  286. * last test. This is important because port numbers
  287. * can be used by multiple test cases.
  288. */
  289. if (test_cases[i].skip)
  290. control_writeln("SKIP");
  291. else
  292. control_writeln("NEXT");
  293. line = control_readln();
  294. if (control_cmpln(line, "SKIP", false) || test_cases[i].skip) {
  295. printf("skipped\n");
  296. free(line);
  297. continue;
  298. }
  299. control_cmpln(line, "NEXT", true);
  300. free(line);
  301. if (opts->mode == TEST_MODE_CLIENT)
  302. run = test_cases[i].run_client;
  303. else
  304. run = test_cases[i].run_server;
  305. if (run)
  306. run(opts);
  307. printf("ok\n");
  308. }
  309. }
  310. void list_tests(const struct test_case *test_cases)
  311. {
  312. int i;
  313. printf("ID\tTest name\n");
  314. for (i = 0; test_cases[i].name; i++)
  315. printf("%d\t%s\n", i, test_cases[i].name);
  316. exit(EXIT_FAILURE);
  317. }
  318. void skip_test(struct test_case *test_cases, size_t test_cases_len,
  319. const char *test_id_str)
  320. {
  321. unsigned long test_id;
  322. char *endptr = NULL;
  323. errno = 0;
  324. test_id = strtoul(test_id_str, &endptr, 10);
  325. if (errno || *endptr != '\0') {
  326. fprintf(stderr, "malformed test ID \"%s\"\n", test_id_str);
  327. exit(EXIT_FAILURE);
  328. }
  329. if (test_id >= test_cases_len) {
  330. fprintf(stderr, "test ID (%lu) larger than the max allowed (%lu)\n",
  331. test_id, test_cases_len - 1);
  332. exit(EXIT_FAILURE);
  333. }
  334. test_cases[test_id].skip = true;
  335. }