Browse Source

smcinvoke: Make memory objects release inline with callback objects

Release memory object from userspace when a call to release comes in
kernel to prevent the memory leaks in the cases where mem objects
are passed as a response inside a callback call.

Change-Id: I5ce57b6be90e71e255a890895d5f2859312ba1e4
Signed-off-by: Anmolpreet Kaur <[email protected]>
Anmolpreet Kaur 2 years ago
parent
commit
b38291450d
1 changed files with 72 additions and 12 deletions
  1. 72 12
      smcinvoke/smcinvoke.c

+ 72 - 12
smcinvoke/smcinvoke.c

@@ -279,6 +279,8 @@ struct smcinvoke_mem_obj {
 	struct list_head list;
 	bool is_smcinvoke_created_shmbridge;
 	uint64_t shmbridge_handle;
+	struct smcinvoke_server_info *server;
+	int32_t mem_obj_user_fd;
 };
 
 static LIST_HEAD(g_bridge_postprocess);
@@ -404,7 +406,7 @@ static struct smcinvoke_mem_obj *find_mem_obj_locked(uint16_t mem_obj_id,
 				(mem_obj->mem_region_id == mem_obj_id)) ||
 				(!is_mem_rgn_obj &&
 				(mem_obj->mem_map_obj_id == mem_obj_id)))
-			return mem_obj;
+				return mem_obj;
 	}
 	return NULL;
 }
@@ -649,6 +651,7 @@ static inline void free_mem_obj_locked(struct smcinvoke_mem_obj *mem_obj)
 	struct smcinvoke_shmbridge_deregister_pending_list *entry = NULL;
 
 	list_del(&mem_obj->list);
+	kfree(mem_obj->server);
 	kfree(mem_obj);
 	mem_obj = NULL;
 	mutex_unlock(&g_smcinvoke_lock);
@@ -732,7 +735,6 @@ static int release_mem_obj_locked(int32_t tzhandle)
 		kref_put(&mem_obj->mem_regn_ref_cnt, del_mem_regn_obj_locked);
 	else
 		kref_put(&mem_obj->mem_map_obj_ref_cnt, del_mem_map_obj_locked);
-
 	return OBJECT_OK;
 }
 
@@ -855,6 +857,8 @@ static struct smcinvoke_cb_txn *find_cbtxn_locked(
 {
 	int i = 0;
 	struct smcinvoke_cb_txn *cb_txn = NULL;
+	struct smcinvoke_mem_obj *mem_obj = NULL;
+	int32_t tzhandle = 0;
 
 	/*
 	 * Since HASH_BITS() does not work on pointers, we can't select hash
@@ -864,6 +868,12 @@ static struct smcinvoke_cb_txn *find_cbtxn_locked(
 		/* pick up 1st req */
 		hash_for_each(server->reqs_table, i, cb_txn, hash) {
 			kref_get(&cb_txn->ref_cnt);
+			tzhandle = (cb_txn->cb_req)->hdr.tzhandle;
+			if(TZHANDLE_IS_MEM_OBJ(tzhandle)) {
+				mem_obj= find_mem_obj_locked(TZHANDLE_GET_OBJID(tzhandle),
+								SMCINVOKE_MEM_RGN_OBJ);
+				kref_get(&mem_obj->mem_regn_ref_cnt);
+			}
 			hash_del(&cb_txn->hash);
 			return cb_txn;
 		}
@@ -872,6 +882,12 @@ static struct smcinvoke_cb_txn *find_cbtxn_locked(
 				server->responses_table, cb_txn, hash, txn_id) {
 			if (cb_txn->txn_id == txn_id) {
 				kref_get(&cb_txn->ref_cnt);
+				tzhandle = (cb_txn->cb_req)->hdr.tzhandle;
+				if(TZHANDLE_IS_MEM_OBJ(tzhandle)) {
+					mem_obj= find_mem_obj_locked(TZHANDLE_GET_OBJID(tzhandle),
+									SMCINVOKE_MEM_RGN_OBJ);
+					kref_get(&mem_obj->mem_regn_ref_cnt);
+				}
 				hash_del(&cb_txn->hash);
 				return cb_txn;
 			}
@@ -952,20 +968,30 @@ static bool is_remote_obj(int32_t uhandle, struct smcinvoke_file_data **tzobj,
 	return ret;
 }
 
-static int create_mem_obj(struct dma_buf *dma_buf, int32_t *mem_obj)
+static int create_mem_obj(struct dma_buf *dma_buf, int32_t *mem_obj,
+				int32_t server_id, int32_t user_handle)
 {
-	struct smcinvoke_mem_obj *t_mem_obj =
-			kzalloc(sizeof(*t_mem_obj), GFP_KERNEL);
+	struct smcinvoke_mem_obj *t_mem_obj = NULL;
+	struct smcinvoke_server_info *server_i = NULL;
 
+	t_mem_obj = kzalloc(sizeof(struct smcinvoke_mem_obj), GFP_KERNEL);
 	if (!t_mem_obj) {
 		dma_buf_put(dma_buf);
 		return -ENOMEM;
 	}
-
+	server_i = kzalloc(sizeof(struct smcinvoke_server_info),GFP_KERNEL);
+	if (!server_i) {
+		kfree(t_mem_obj);
+		dma_buf_put(dma_buf);
+		return -ENOMEM;
+	}
 	kref_init(&t_mem_obj->mem_regn_ref_cnt);
 	t_mem_obj->dma_buf = dma_buf;
 	mutex_lock(&g_smcinvoke_lock);
 	t_mem_obj->mem_region_id = next_mem_region_obj_id_locked();
+	server_i->server_id = server_id;
+	t_mem_obj->server = server_i;
+	t_mem_obj->mem_obj_user_fd = user_handle;
 	list_add_tail(&t_mem_obj->list, &g_mem_objs);
 	mutex_unlock(&g_smcinvoke_lock);
 	*mem_obj = TZHANDLE_MAKE_LOCAL(MEM_RGN_SRVR_ID,
@@ -1007,7 +1033,8 @@ static int get_tzhandle_from_uhandle(int32_t uhandle, int32_t server_fd,
 		struct smcinvoke_file_data *tzobj = NULL;
 
 		if (is_dma_fd(UHANDLE_GET_FD(uhandle), &dma_buf)) {
-			ret = create_mem_obj(dma_buf, tzhandle);
+			server_id = get_server_id(server_fd);
+			ret = create_mem_obj(dma_buf, tzhandle, server_id, uhandle);
 		} else if (is_remote_obj(UHANDLE_GET_FD(uhandle),
 				&tzobj, filp)) {
 			*tzhandle = tzobj->tzhandle;
@@ -1088,8 +1115,7 @@ static int get_uhandle_from_tzhandle(int32_t tzhandle, int32_t srvr_id,
 		if (mem_obj != NULL) {
 			int fd;
 
-			fd = dma_buf_fd(mem_obj->dma_buf, O_CLOEXEC);
-
+			fd = mem_obj->mem_obj_user_fd;
 			if (fd < 0)
 				goto exit_lock;
 			*uhandle = fd;
@@ -1380,6 +1406,8 @@ static void process_tzcb_req(void *buf, size_t buf_len, struct file **arr_filp)
 	struct smcinvoke_cb_txn *cb_txn = NULL;
 	struct smcinvoke_tzcb_req *cb_req = NULL, *tmp_cb_req = NULL;
 	struct smcinvoke_server_info *srvr_info = NULL;
+	struct smcinvoke_mem_obj *mem_obj = NULL;
+	uint16_t server_id = 0;
 
 	if (buf_len < sizeof(struct smcinvoke_tzcb_req)) {
 		pr_err("smaller buffer length : %u\n", buf_len);
@@ -1391,8 +1419,21 @@ static void process_tzcb_req(void *buf, size_t buf_len, struct file **arr_filp)
 	/* check whether it is to be served by kernel or userspace */
 	if (TZHANDLE_IS_KERNEL_OBJ(cb_req->hdr.tzhandle)) {
 		return process_kernel_obj(buf, buf_len);
-	} else if (TZHANDLE_IS_MEM_OBJ(cb_req->hdr.tzhandle)) {
+	} else if (TZHANDLE_IS_MEM_MAP_OBJ(cb_req->hdr.tzhandle)) {
+		/*
+		 * MEM_MAP memory object is created and owned by kernel,
+		 * hence its processing(handling deletion) is done in
+		 * kernel context.
+		 */
 		return process_mem_obj(buf, buf_len);
+	} else if(TZHANDLE_IS_MEM_RGN_OBJ(cb_req->hdr.tzhandle)) {
+		/*
+		 * MEM_RGN memory objects are created and owned by userspace,
+		 * and hence their deletion/handling requires going back to the
+		 * userspace, similar to that of callback objects. If we enter
+		 * this 'if' condition, its no-op here, and proceed similar to
+		 * case of callback objects.
+		 */
 	} else if (!TZHANDLE_IS_CB_OBJ(cb_req->hdr.tzhandle)) {
 		pr_err("Request object is not a callback object\n");
 		cb_req->result = OBJECT_ERROR_INVALID;
@@ -1434,8 +1475,19 @@ static void process_tzcb_req(void *buf, size_t buf_len, struct file **arr_filp)
 
 	mutex_lock(&g_smcinvoke_lock);
 	++cb_reqs_inflight;
-	srvr_info = get_cb_server_locked(
-			TZHANDLE_GET_SERVER(cb_req->hdr.tzhandle));
+
+	if(TZHANDLE_IS_MEM_RGN_OBJ(cb_req->hdr.tzhandle)) {
+		mem_obj= find_mem_obj_locked(TZHANDLE_GET_OBJID(cb_req->hdr.tzhandle),SMCINVOKE_MEM_RGN_OBJ);
+		if(!mem_obj) {
+			pr_err("mem obj with tzhandle : %d not found",cb_req->hdr.tzhandle);
+			goto out;
+		}
+		server_id = mem_obj->server->server_id;
+	} else {
+		server_id = TZHANDLE_GET_SERVER(cb_req->hdr.tzhandle);
+	}
+
+	srvr_info = get_cb_server_locked(server_id);
 	if (!srvr_info || srvr_info->state == SMCINVOKE_SERVER_STATE_DEFUNCT) {
 		/* ret equals Object_ERROR_DEFUNCT, at this point go to out */
 		if (!srvr_info)
@@ -1532,6 +1584,13 @@ out:
 			cb_req->hdr.counts, cb_reqs_inflight);
 
 	memcpy(buf, cb_req, buf_len);
+
+	if (TZHANDLE_IS_MEM_RGN_OBJ(cb_req->hdr.tzhandle)) {
+		mutex_unlock(&g_smcinvoke_lock);
+		process_mem_obj(buf, buf_len);
+		pr_err("ppid : %x, mem obj deleted\n", current->pid);
+		mutex_lock(&g_smcinvoke_lock);
+	}
 	kref_put(&cb_txn->ref_cnt, delete_cb_txn_locked);
 	if (srvr_info)
 		kref_put(&srvr_info->ref_cnt, destroy_cb_server);
@@ -2153,6 +2212,7 @@ static long process_accept_req(struct file *filp, unsigned int cmd,
 
 		cb_txn->state = SMCINVOKE_REQ_PROCESSED;
 		mutex_lock(&g_smcinvoke_lock);
+
 		kref_put(&cb_txn->ref_cnt, delete_cb_txn_locked);
 		mutex_unlock(&g_smcinvoke_lock);
 		wake_up(&server_info->rsp_wait_q);