hab_comm.c 6.3 KB


  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Copyright (c) 2016-2019, The Linux Foundation. All rights reserved.
  4. * Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved.
  5. */
  6. #include "hab.h"
  /*
 * A message queued by physical_channel_send() and delivered by lb_kthread().
 * Allocated in one piece: the struct followed by payload_size payload bytes.
 */
struct loopback_msg {
	struct list_head node;		/* link in loopback_dev.msg_list */
	int payload_size;		/* bytes in payload[]; may be zero */
	struct hab_header header;	/* copy of the sender's header */
	char payload[];			/* copied from the sender's buffer when one is given */
};
  /*
 * Control block shared between the pchan owner and lb_kthread().
 * NOTE(review): stop/bexited are plain ints read and written from different
 * threads without READ_ONCE()/WRITE_ONCE(); confirm that is intended.
 */
struct lb_thread_struct {
	int stop; /* set by creator (habhyp_commdev_dealloc()) to ask the thread to exit */
	int bexited; /* set by thread (lb_kthread()) just before it returns */
	void *data; /* thread private data: the owning struct physical_channel */
};
  /*
 * Per-pchan loopback state; loopback_pchan_create() stores it in
 * physical_channel.hyp_data.
 */
struct loopback_dev {
	spinlock_t io_lock;		/* taken (_bh) around msg_list/msg_cnt/current_msg updates */
	struct list_head msg_list;	/* loopback_msg entries waiting for lb_kthread() */
	int msg_cnt;			/* number of entries on msg_list */
	struct task_struct *kthread; /* creator's thread handle (from kthread_run()) */
	struct lb_thread_struct thread_data; /* thread private data */
	wait_queue_head_t thread_queue;	/* lb_kthread() sleeps here; physical_channel_send() wakes it */
	struct loopback_msg *current_msg; /* message lb_kthread() is delivering, else NULL; read by physical_channel_read() */
};
  27. static int lb_thread_queue_empty(struct loopback_dev *dev)
  28. {
  29. int ret;
  30. spin_lock_bh(&dev->io_lock);
  31. ret = list_empty(&dev->msg_list);
  32. spin_unlock_bh(&dev->io_lock);
  33. return ret;
  34. }
  35. int lb_kthread(void *d)
  36. {
  37. struct lb_thread_struct *p = (struct lb_thread_struct *)d;
  38. struct physical_channel *pchan = (struct physical_channel *)p->data;
  39. struct loopback_dev *dev = pchan->hyp_data;
  40. int ret = 0;
  41. while (!p->stop) {
  42. schedule();
  43. ret = wait_event_interruptible(dev->thread_queue,
  44. !lb_thread_queue_empty(dev) ||
  45. p->stop);
  46. spin_lock_bh(&dev->io_lock);
  47. while (!list_empty(&dev->msg_list)) {
  48. struct loopback_msg *msg = NULL;
  49. msg = list_first_entry(&dev->msg_list,
  50. struct loopback_msg, node);
  51. dev->current_msg = msg;
  52. list_del(&msg->node);
  53. dev->msg_cnt--;
  54. ret = hab_msg_recv(pchan, &msg->header);
  55. if (ret) {
  56. pr_err("failed %d msg handling sz %d header %d %d %d, %d %X %d, total %d\n",
  57. ret, msg->payload_size,
  58. HAB_HEADER_GET_ID(msg->header),
  59. HAB_HEADER_GET_TYPE(msg->header),
  60. HAB_HEADER_GET_SIZE(msg->header),
  61. msg->header.session_id,
  62. msg->header.signature,
  63. msg->header.sequence, dev->msg_cnt);
  64. }
  65. kfree(msg);
  66. dev->current_msg = NULL;
  67. }
  68. spin_unlock_bh(&dev->io_lock);
  69. }
  70. p->bexited = 1;
  71. pr_debug("exit kthread\n");
  72. return 0;
  73. }
  74. int physical_channel_send(struct physical_channel *pchan,
  75. struct hab_header *header,
  76. void *payload, unsigned int flags)
  77. {
  78. int size = HAB_HEADER_GET_SIZE(*header); /* payload size */
  79. struct timespec64 ts = {0};
  80. struct loopback_msg *msg = NULL;
  81. struct loopback_dev *dev = pchan->hyp_data;
  82. /* Only used in virtio arch */
  83. (void)flags;
  84. msg = kmalloc(size + sizeof(*msg), GFP_KERNEL);
  85. if (!msg)
  86. return -ENOMEM;
  87. memcpy(&msg->header, header, sizeof(*header));
  88. msg->payload_size = size; /* payload size could be zero */
  89. if (size && payload) {
  90. if (HAB_HEADER_GET_TYPE(*header) == HAB_PAYLOAD_TYPE_PROFILE) {
  91. struct habmm_xing_vm_stat *pstat =
  92. (struct habmm_xing_vm_stat *)payload;
  93. ktime_get_ts64(&ts);
  94. pstat->tx_sec = ts.tv_sec;
  95. pstat->tx_usec = ts.tv_nsec/NSEC_PER_USEC;
  96. }
  97. memcpy(msg->payload, payload, size);
  98. }
  99. spin_lock_bh(&dev->io_lock);
  100. list_add_tail(&msg->node, &dev->msg_list);
  101. dev->msg_cnt++;
  102. spin_unlock_bh(&dev->io_lock);
  103. wake_up_interruptible(&dev->thread_queue);
  104. return 0;
  105. }
  106. /* loopback read is only used during open */
  107. int physical_channel_read(struct physical_channel *pchan,
  108. void *payload,
  109. size_t read_size)
  110. {
  111. struct loopback_dev *dev = pchan->hyp_data;
  112. struct loopback_msg *msg = dev->current_msg;
  113. if (read_size) {
  114. if (read_size != msg->payload_size) {
  115. pr_err("read size mismatch requested %zd, received %d\n",
  116. read_size, msg->payload_size);
  117. memcpy(payload, msg->payload, min(((int)read_size),
  118. msg->payload_size));
  119. } else {
  120. memcpy(payload, msg->payload, read_size);
  121. }
  122. } else {
  123. read_size = 0;
  124. }
  125. return read_size;
  126. }
  127. /* pchan is directly added into the hab_device */
  128. int loopback_pchan_create(struct hab_device *dev, char *pchan_name)
  129. {
  130. int result;
  131. struct physical_channel *pchan = NULL;
  132. struct loopback_dev *lb_dev = NULL;
  133. pchan = hab_pchan_alloc(dev, LOOPBACK_DOM);
  134. if (!pchan) {
  135. result = -ENOMEM;
  136. goto err;
  137. }
  138. pchan->closed = 0;
  139. strscpy(pchan->name, pchan_name, sizeof(pchan->name));
  140. lb_dev = kzalloc(sizeof(*lb_dev), GFP_KERNEL);
  141. if (!lb_dev) {
  142. result = -ENOMEM;
  143. goto err;
  144. }
  145. spin_lock_init(&lb_dev->io_lock);
  146. INIT_LIST_HEAD(&lb_dev->msg_list);
  147. init_waitqueue_head(&lb_dev->thread_queue);
  148. lb_dev->thread_data.data = pchan;
  149. lb_dev->kthread = kthread_run(lb_kthread, &lb_dev->thread_data,
  150. pchan->name);
  151. if (IS_ERR(lb_dev->kthread)) {
  152. result = PTR_ERR(lb_dev->kthread);
  153. pr_err("failed to create kthread for %s, ret %d\n",
  154. pchan->name, result);
  155. goto err;
  156. }
  157. pchan->hyp_data = lb_dev;
  158. return 0;
  159. err:
  160. kfree(lb_dev);
  161. kfree(pchan);
  162. return result;
  163. }
  /*
 * physical_channel_rx_dispatch() - receive dispatch hook; a no-op for loopback.
 * @data: ignored.
 *
 * Loopback messages are handed to hab_msg_recv() by lb_kthread(), so there is
 * nothing to dispatch here.
 */
void physical_channel_rx_dispatch(unsigned long data)
{
	(void)data;
}
  167. int habhyp_commdev_alloc(void **commdev, int is_be, char *name,
  168. int vmid_remote, struct hab_device *mmid_device)
  169. {
  170. struct physical_channel *pchan;
  171. int ret = loopback_pchan_create(mmid_device, name);
  172. if (ret) {
  173. pr_err("failed to create %s pchan in mmid device %s, ret %d, pchan cnt %d\n",
  174. name, mmid_device->name, ret, mmid_device->pchan_cnt);
  175. *commdev = NULL;
  176. } else {
  177. pr_debug("loopback physical channel on %s return %d, loopback mode(%d), total pchan %d\n",
  178. name, ret, hab_driver.b_loopback,
  179. mmid_device->pchan_cnt);
  180. pchan = hab_pchan_find_domid(mmid_device,
  181. HABCFG_VMID_DONT_CARE);
  182. *commdev = pchan;
  183. hab_pchan_put(pchan);
  184. pr_debug("pchan %s vchans %d refcnt %d\n",
  185. pchan->name, pchan->vcnt, get_refcnt(pchan->refcount));
  186. }
  187. return ret;
  188. }
  189. int habhyp_commdev_dealloc(void *commdev)
  190. {
  191. struct physical_channel *pchan = commdev;
  192. struct loopback_dev *dev = pchan->hyp_data;
  193. struct loopback_msg *msg, *tmp;
  194. int ret;
  195. spin_lock_bh(&dev->io_lock);
  196. if (!list_empty(&dev->msg_list) || dev->msg_cnt) {
  197. pr_err("pchan %s msg leak cnt %d\n", pchan->name, dev->msg_cnt);
  198. list_for_each_entry_safe(msg, tmp, &dev->msg_list, node) {
  199. list_del(&msg->node);
  200. dev->msg_cnt--;
  201. kfree(msg);
  202. }
  203. pr_debug("pchan %s msg cnt %d now\n",
  204. pchan->name, dev->msg_cnt);
  205. }
  206. spin_unlock_bh(&dev->io_lock);
  207. dev->thread_data.stop = 1;
  208. ret = kthread_stop(dev->kthread);
  209. while (!dev->thread_data.bexited)
  210. schedule();
  211. dev->kthread = NULL;
  212. /* hyp_data is freed in pchan */
  213. if (get_refcnt(pchan->refcount) > 1) {
  214. pr_warn("potential leak pchan %s vchans %d refcnt %d\n",
  215. pchan->name, pchan->vcnt, get_refcnt(pchan->refcount));
  216. }
  217. hab_pchan_put((struct physical_channel *)commdev);
  218. return 0;
  219. }