// SPDX-License-Identifier: GPL-2.0-only /* * Copyright (c) 2023 Qualcomm Innovation Center, Inc. All rights reserved. */ #include #include #include #include #include #include #include #include #include #include #include #include #include "af_qmsgq.h" #ifndef AF_QMSGQ #define AF_QMSGQ 27 #endif #ifndef PF_QMSGQ #define PF_QMSGQ AF_QMSGQ #endif struct qmsgq_cb { u32 src_cid; u32 src_port; u32 dst_cid; u32 dst_port; }; static const struct qmsgq_endpoint *registered_ep; static DEFINE_MUTEX(qmsgq_register_mutex); /* auto-bind range */ #define QMSGQ_MIN_EPH_SOCKET 0x4000 #define QMSGQ_MAX_EPH_SOCKET 0x7fff #define QMSGQ_EPH_PORT_RANGE \ XA_LIMIT(QMSGQ_MIN_EPH_SOCKET, QMSGQ_MAX_EPH_SOCKET) /* local port allocation management */ static DEFINE_XARRAY_ALLOC(qmsgq_ports); u32 qmsgq_ports_next = QMSGQ_MIN_EPH_SOCKET; static DEFINE_SPINLOCK(qmsgq_port_lock); /* The default peer timeout indicates how long we will wait for a peer response * to a control message. */ #define QMSGQ_DEFAULT_CONNECT_TIMEOUT (2 * HZ) #define QMSGQ_DEFAULT_BUFFER_SIZE (1024 * 256) #define QMSGQ_DEFAULT_BUFFER_MAX_SIZE (1024 * 256) #define QMSGQ_DEFAULT_BUFFER_MIN_SIZE 128 static void qmsgq_deassign_ep(struct qmsgq_sock *qsk) { if (!qsk->ep) return; qsk->ep->destruct(qsk); module_put(qsk->ep->module); qsk->ep = NULL; } int qmsgq_assign_ep(struct qmsgq_sock *qsk, struct qmsgq_sock *psk) { const struct qmsgq_endpoint *new_ep; int ret; new_ep = registered_ep; if (qsk->ep) { if (qsk->ep == new_ep) return 0; qsk->ep->release(qsk); qmsgq_deassign_ep(qsk); } /* We increase the module refcnt to prevent the transport unloading * while there are open sockets assigned to it. */ if (!new_ep || !try_module_get(new_ep->module)) return -ENODEV; ret = new_ep->init(qsk, psk); if (ret) { module_put(new_ep->module); return ret; } qsk->ep = new_ep; return 0; } static bool qmsgq_find_cid(unsigned int cid) { if (registered_ep && cid == registered_ep->get_local_cid()) return true; return false; } static bool sock_type_connectible(u16 type) { return (type == SOCK_STREAM) || (type == SOCK_SEQPACKET); } static struct qmsgq_sock *qmsgq_port_lookup(int port) { struct qmsgq_sock *qsk; unsigned long flags; spin_lock_irqsave(&qmsgq_port_lock, flags); qsk = xa_load(&qmsgq_ports, port); if (qsk) sock_hold(qsk_sk(qsk)); spin_unlock_irqrestore(&qmsgq_port_lock, flags); return qsk; } static void qmsgq_port_put(struct qmsgq_sock *qsk) { sock_put(qsk_sk(qsk)); } static void qmsgq_port_remove(struct qmsgq_sock *qsk) { int port = qsk->local_addr.svm_port; unsigned long flags; __sock_put(qsk_sk(qsk)); spin_lock_irqsave(&qmsgq_port_lock, flags); xa_erase(&qmsgq_ports, port); spin_unlock_irqrestore(&qmsgq_port_lock, flags); } static int qmsgq_port_assign(struct qmsgq_sock *qsk, int *port) { int rc; if (!*port || *port < 0) { rc = xa_alloc_cyclic(&qmsgq_ports, port, qsk, QMSGQ_EPH_PORT_RANGE, &qmsgq_ports_next, GFP_ATOMIC); } else if (*port < QMSGQ_MIN_EPH_SOCKET && !capable(CAP_NET_ADMIN)) { rc = -EACCES; } else { rc = xa_insert(&qmsgq_ports, *port, qsk, GFP_ATOMIC); } if (rc == -EBUSY) return -EADDRINUSE; else if (rc < 0) return rc; sock_hold(qsk_sk(qsk)); return 0; } static int qmsgq_send_shutdown(struct sock *sk, int mode) { struct qmsgq_sock *qsk = sk_qsk(sk); if (!qsk->ep) return -ENODEV; return qsk->ep->shutdown(qsk, mode); } static void qmsgq_connect_timeout(struct work_struct *work) { } static void qmsgq_pending_work(struct work_struct *work) { } /* Bind socket to address. * * Socket should be locked upon call. */ static int __qmsgq_bind(struct socket *sock, const struct sockaddr_vm *addr, int zapped) { struct qmsgq_sock *qsk = sk_qsk(sock->sk); struct sock *sk = sock->sk; unsigned long flags; int port; int rc; /* rebinding ok */ if (!zapped && addr->svm_port == qsk->local_addr.svm_port) return 0; if (addr->svm_cid != VMADDR_CID_ANY && !qmsgq_find_cid(addr->svm_cid)) return -EADDRNOTAVAIL; spin_lock_irqsave(&qmsgq_port_lock, flags); port = addr->svm_port; rc = qmsgq_port_assign(qsk, &port); spin_unlock_irqrestore(&qmsgq_port_lock, flags); if (rc) return rc; /* unbind previous, if any */ if (!zapped) qmsgq_port_remove(qsk); vsock_addr_init(&qsk->local_addr, VMADDR_CID_HOST, port); sock_reset_flag(sk, SOCK_ZAPPED); return 0; } /* Auto bind to an ephemeral port. */ static int qmsgq_autobind(struct socket *sock) { struct sock *sk = sock->sk; struct sockaddr_vm addr; if (!sock_flag(sk, SOCK_ZAPPED)) return 0; vsock_addr_init(&addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); return __qmsgq_bind(sock, &addr, 1); } static int qmsgq_bind(struct socket *sock, struct sockaddr *addr, int len) { struct sockaddr_vm *vm_addr; struct sock *sk = sock->sk; int rc; if (vsock_addr_cast(addr, len, &vm_addr) != 0) return -EINVAL; lock_sock(sk); rc = __qmsgq_bind(sock, vm_addr, sock_flag(sk, SOCK_ZAPPED)); release_sock(sk); return rc; } static int qmsgq_dgram_connect(struct socket *sock, struct sockaddr *addr, int addr_len, int flags) { struct sockaddr_vm *remote_addr; struct qmsgq_sock *qsk; struct sock *sk; int rc; sk = sock->sk; qsk = sk_qsk(sk); rc = vsock_addr_cast(addr, addr_len, &remote_addr); if (rc == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) { lock_sock(sk); vsock_addr_init(&qsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); sock->state = SS_UNCONNECTED; release_sock(sk); return 0; } else if (rc != 0) { return -EINVAL; } lock_sock(sk); rc = qmsgq_autobind(sock); if (rc) goto out; if (!qsk->ep->dgram_allow(remote_addr->svm_cid, remote_addr->svm_port)) { rc = -EINVAL; goto out; } memcpy(&qsk->remote_addr, remote_addr, sizeof(qsk->remote_addr)); sock->state = SS_CONNECTED; out: release_sock(sk); return rc; } static int qmsgq_getname(struct socket *sock, struct sockaddr *addr, int peer) { struct sockaddr_vm *vm_addr = NULL; struct sock *sk = sock->sk; struct qmsgq_sock *qsk; int rc = 0; qsk = sk_qsk(sk); lock_sock(sk); if (peer) { if (sock->state != SS_CONNECTED) { rc = -ENOTCONN; goto out; } vm_addr = &qsk->remote_addr; } else { vm_addr = &qsk->local_addr; } if (!vm_addr) { rc = -EINVAL; goto out; } BUILD_BUG_ON(sizeof(*vm_addr) > 128); memcpy(addr, vm_addr, sizeof(*vm_addr)); rc = sizeof(*vm_addr); out: release_sock(sk); return rc; } static int qmsgq_shutdown(struct socket *sock, int mode) { struct sock *sk; int rc; /* User level uses SHUT_RD (0) and SHUT_WR (1), but the kernel uses * RCV_SHUTDOWN (1) and SEND_SHUTDOWN (2), so we must increment mode * here like the other address families do. Note also that the * increment makes SHUT_RDWR (2) into RCV_SHUTDOWN | SEND_SHUTDOWN (3), * which is what we want. */ mode++; if ((mode & ~SHUTDOWN_MASK) || !mode) return -EINVAL; /* If this is a connection oriented socket and it is not connected then * bail out immediately. If it is a DGRAM socket then we must first * kick the socket so that it wakes up from any sleeping calls, for * example recv(), and then afterwards return the error. */ sk = sock->sk; lock_sock(sk); if (sock->state == SS_UNCONNECTED) { rc = -ENOTCONN; if (sock_type_connectible(sk->sk_type)) goto out; } else { sock->state = SS_DISCONNECTING; rc = 0; } /* Receive and send shutdowns are treated alike. */ mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN); if (mode) { sk->sk_shutdown |= mode; sk->sk_state_change(sk); if (sock_type_connectible(sk->sk_type)) { sock_reset_flag(sk, SOCK_DONE); qmsgq_send_shutdown(sk, mode); } } out: release_sock(sk); return rc; } static __poll_t qmsgq_poll(struct file *file, struct socket *sock, poll_table *wait) { struct sock *sk = sock->sk; struct qmsgq_sock *qsk; __poll_t mask; qsk = sk_qsk(sk); poll_wait(file, sk_sleep(sk), wait); mask = 0; if (sk->sk_err) /* Signify that there has been an error on this socket. */ mask |= EPOLLERR; /* INET sockets treat local write shutdown and peer write shutdown as a * case of EPOLLHUP set. */ if (sk->sk_shutdown == SHUTDOWN_MASK || ((sk->sk_shutdown & SEND_SHUTDOWN) && (qsk->peer_shutdown & SEND_SHUTDOWN))) { mask |= EPOLLHUP; } if (sk->sk_shutdown & RCV_SHUTDOWN || qsk->peer_shutdown & SEND_SHUTDOWN) { mask |= EPOLLRDHUP; } if (sock->type == SOCK_DGRAM) { /* For datagram sockets we can read if there is something in * the queue and write as long as the socket isn't shutdown for * sending. */ if (!skb_queue_empty_lockless(&sk->sk_receive_queue) || (sk->sk_shutdown & RCV_SHUTDOWN)) { mask |= EPOLLIN | EPOLLRDNORM; } if (!(sk->sk_shutdown & SEND_SHUTDOWN)) mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; } /* TODO Connected POLL */ return mask; } static int qmsgq_dgram_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) { const struct qmsgq_endpoint *ep; struct sockaddr_vm *remote_addr; struct sock *sk = sock->sk; struct qmsgq_sock *qsk; int rc = 0; if (msg->msg_flags & MSG_OOB) return -EOPNOTSUPP; qsk = sk_qsk(sk); lock_sock(sk); ep = qsk->ep; rc = qmsgq_autobind(sock); if (rc) goto out; if (msg->msg_name) { rc = vsock_addr_cast(msg->msg_name, msg->msg_namelen, &remote_addr); if (rc) goto out; } else if (sock->state == SS_CONNECTED) { remote_addr = &qsk->remote_addr; } else { rc = -EINVAL; goto out; } if (remote_addr->svm_cid == VMADDR_CID_ANY) remote_addr->svm_cid = ep->get_local_cid(); if (!vsock_addr_bound(remote_addr)) { rc = -EINVAL; goto out; } if (!ep->dgram_allow(remote_addr->svm_cid, remote_addr->svm_port)) { rc = -EINVAL; goto out; } rc = ep->dgram_enqueue(qsk, remote_addr, msg, len); if (!rc) rc = len; out: release_sock(sk); return rc; } static int qmsgq_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, int flags) { DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name); struct sock *sk = sock->sk; struct qmsgq_sock *qsk; struct sk_buff *skb; struct qmsgq_cb *cb; int copied; int rc = 0; qsk = sk_qsk(sk); if (sock_flag(sk, SOCK_ZAPPED)) { pr_err("%s: Invalid socket error\n", __func__); return -EADDRNOTAVAIL; } skb = skb_recv_datagram(sk, flags, &rc); if (!skb) return rc; lock_sock(sk); cb = (struct qmsgq_cb *)skb->cb; copied = skb->len; if (copied > len) { copied = len; msg->msg_flags |= MSG_TRUNC; } /* Place the datagram payload in the user's iovec. */ rc = skb_copy_datagram_msg(skb, 0, msg, copied); if (rc < 0) { pr_err("%s: skb_copy_datagram_msg failed: %d\n", __func__, rc); goto out; } rc = copied; if (vm_addr) { vsock_addr_init(vm_addr, VMADDR_CID_HOST, cb->src_port); msg->msg_namelen = sizeof(*vm_addr); } out: skb_free_datagram(sk, skb); release_sock(sk); return rc; } static void __qmsgq_release(struct sock *sk, int level) { if (sk) { struct qmsgq_sock *qsk = sk_qsk(sk); lock_sock_nested(sk, level); if (qsk->ep) qsk->ep->release(qsk); if (!sock_flag(sk, SOCK_DEAD)) sk->sk_state_change(sk); if (!sock_flag(sk, SOCK_ZAPPED)) qmsgq_port_remove(qsk); sock_orphan(sk); sk->sk_shutdown = SHUTDOWN_MASK; skb_queue_purge(&sk->sk_receive_queue); release_sock(sk); sock_put(sk); } } static int qmsgq_release(struct socket *sock) { __qmsgq_release(sock->sk, 0); sock->sk = NULL; sock->state = SS_FREE; return 0; } static const struct proto_ops qmsgq_dgram_ops = { .owner = THIS_MODULE, .family = AF_QMSGQ, .release = qmsgq_release, .bind = qmsgq_bind, .connect = qmsgq_dgram_connect, .socketpair = sock_no_socketpair, .accept = sock_no_accept, .getname = qmsgq_getname, .poll = qmsgq_poll, .ioctl = sock_no_ioctl, .listen = sock_no_listen, .shutdown = qmsgq_shutdown, .sendmsg = qmsgq_dgram_sendmsg, .recvmsg = qmsgq_dgram_recvmsg, .mmap = sock_no_mmap, .sendpage = sock_no_sendpage, }; static struct proto qmsgq_proto = { .name = "QMSGQ", .owner = THIS_MODULE, .obj_size = sizeof(struct qmsgq_sock), }; static void sk_qsk_destruct(struct sock *sk) { struct qmsgq_sock *qsk = sk_qsk(sk); qmsgq_deassign_ep(qsk); /* When clearing these addresses, there's no need to set the family and * possibly register the address family with the kernel. */ vsock_addr_init(&qsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); vsock_addr_init(&qsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); put_cred(qsk->owner); } static int qmsgq_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) { int err; err = sock_queue_rcv_skb(sk, skb); if (err) kfree_skb(skb); return err; } static struct sock *__qmsgq_create(struct net *net, struct socket *sock, struct sock *parent, gfp_t priority, unsigned short type, int kern) { struct qmsgq_sock *psk; struct qmsgq_sock *qsk; struct sock *sk; sk = sk_alloc(net, AF_QMSGQ, priority, &qmsgq_proto, kern); if (!sk) return NULL; sock_init_data(sock, sk); if (!sock) sk->sk_type = type; qsk = sk_qsk(sk); vsock_addr_init(&qsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); vsock_addr_init(&qsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); sk->sk_destruct = sk_qsk_destruct; sk->sk_backlog_rcv = qmsgq_queue_rcv_skb; sock_reset_flag(sk, SOCK_DONE); sock_set_flag(sk, SOCK_ZAPPED); INIT_LIST_HEAD(&qsk->bound_table); INIT_LIST_HEAD(&qsk->connected_table); qsk->listener = NULL; INIT_LIST_HEAD(&qsk->pending_links); INIT_LIST_HEAD(&qsk->accept_queue); qsk->rejected = false; qsk->sent_request = false; qsk->ignore_connecting_rst = false; qsk->peer_shutdown = 0; INIT_DELAYED_WORK(&qsk->connect_work, qmsgq_connect_timeout); INIT_DELAYED_WORK(&qsk->pending_work, qmsgq_pending_work); psk = parent ? sk_qsk(parent) : NULL; if (parent) { qsk->trusted = psk->trusted; qsk->owner = get_cred(psk->owner); qsk->connect_timeout = psk->connect_timeout; qsk->buffer_size = psk->buffer_size; qsk->buffer_min_size = psk->buffer_min_size; qsk->buffer_max_size = psk->buffer_max_size; security_sk_clone(parent, sk); } else { qsk->trusted = ns_capable_noaudit(&init_user_ns, CAP_NET_ADMIN); qsk->owner = get_current_cred(); qsk->connect_timeout = QMSGQ_DEFAULT_CONNECT_TIMEOUT; qsk->buffer_size = QMSGQ_DEFAULT_BUFFER_SIZE; qsk->buffer_min_size = QMSGQ_DEFAULT_BUFFER_MIN_SIZE; qsk->buffer_max_size = QMSGQ_DEFAULT_BUFFER_MAX_SIZE; } return sk; } static int qmsgq_create(struct net *net, struct socket *sock, int protocol, int kern) { struct qmsgq_sock *qsk; struct sock *sk; int rc; if (!sock) return -EINVAL; if (protocol && protocol != PF_QMSGQ) return -EPROTONOSUPPORT; switch (sock->type) { case SOCK_DGRAM: sock->ops = &qmsgq_dgram_ops; break; default: return -ESOCKTNOSUPPORT; } sock->state = SS_UNCONNECTED; sk = __qmsgq_create(net, sock, NULL, GFP_KERNEL, 0, kern); if (!sk) return -ENOMEM; qsk = sk_qsk(sk); if (sock->type == SOCK_DGRAM) { rc = qmsgq_assign_ep(qsk, NULL); if (rc < 0) { sock_put(sk); return rc; } } return 0; } int qmsgq_post(const struct qmsgq_endpoint *ep, struct sockaddr_vm *src, struct sockaddr_vm *dst, void *data, int len) { struct qmsgq_sock *qsk; struct qmsgq_cb *cb; struct sk_buff *skb; int rc; skb = alloc_skb_with_frags(0, len, 0, &rc, GFP_KERNEL); if (!skb) { pr_err("%s: Unable to get skb with len:%lu\n", __func__, len); return -ENOMEM; } cb = (struct qmsgq_cb *)skb->cb; cb->src_cid = src->svm_cid; cb->src_port = src->svm_port; cb->dst_cid = dst->svm_cid; cb->dst_port = dst->svm_port; skb->data_len = len; skb->len = len; skb_store_bits(skb, 0, data, len); qsk = qmsgq_port_lookup(dst->svm_port); if (!qsk || qsk->ep != ep) { pr_err("%s: invalid dst port:%d\n", __func__, dst->svm_port); kfree_skb(skb); return -EINVAL; } if (sock_queue_rcv_skb(qsk_sk(qsk), skb)) { pr_err("%s: sock_queue_rcv_skb failed\n", __func__); qmsgq_port_put(qsk); kfree_skb(skb); return -EINVAL; } qmsgq_port_put(qsk); return 0; } EXPORT_SYMBOL(qmsgq_post); int qmsgq_endpoint_register(const struct qmsgq_endpoint *ep) { int rc = 0; if (!ep) return -EINVAL; mutex_lock(&qmsgq_register_mutex); if (registered_ep) { rc = -EBUSY; goto error; } registered_ep = ep; error: mutex_unlock(&qmsgq_register_mutex); return rc; } EXPORT_SYMBOL(qmsgq_endpoint_register); void qmsgq_endpoint_unregister(const struct qmsgq_endpoint *ep) { mutex_lock(&qmsgq_register_mutex); if (registered_ep == ep) ep = NULL; mutex_unlock(&qmsgq_register_mutex); } EXPORT_SYMBOL(qmsgq_endpoint_unregister); static const struct net_proto_family qmsgq_family = { .owner = THIS_MODULE, .family = AF_QMSGQ, .create = qmsgq_create, }; static int __init qmsgq_proto_init(void) { int rc; registered_ep = NULL; rc = proto_register(&qmsgq_proto, 1); if (rc) return rc; rc = sock_register(&qmsgq_family); if (rc) goto err_proto; return 0; err_proto: proto_unregister(&qmsgq_proto); return rc; } static void __exit qmsgq_proto_fini(void) { sock_unregister(qmsgq_family.family); proto_unregister(&qmsgq_proto); } module_init(qmsgq_proto_init); module_exit(qmsgq_proto_fini); MODULE_DESCRIPTION("QTI Gunyah MSGQ Socket driver"); MODULE_LICENSE("GPL"); MODULE_ALIAS_NETPROTO(PF_QMSGQ);