Explorar o código

Merge e1ee568286de61f58f64fabf92a8db4133397101 on remote branch

Change-Id: Ie4612f83c78f2d96366da4436a6b685858e429e7
Linux Build Service Account %!s(int64=2) %!d(string=hai) anos
pai
achega
e50fe3fc15

+ 21 - 0
BUILD.bazel

@@ -11,6 +11,8 @@ ddk_headers(
     hdrs = glob([
         "include/linux/smcinvoke*.h",
         "include/linux/IClientE*.h",
+        "include/linux/ITrustedCameraDriver.h",
+        "include/linux/CTrustedCameraDriver.h",
         "linux/misc/qseecom_kernel.h",
         "linux/misc/qseecom_priv.h"
     ]),
@@ -25,6 +27,25 @@ ddk_headers(
     ],
     includes = ["linux"]
 )
+ddk_headers(
+    name = "hdcp_qseecom_dlkm",
+    hdrs = glob([
+        "linux/*.h",
+        "linux/misc/qseecom_kernel.h",
+        "hdcp_qseecom/*.h",
+        "config/*.h"
+    ]),
+    includes = [".","linux","config"]
+)
+
+ddk_headers(
+    name = "qcedev_local_headers",
+    hdrs = glob([
+        "linux/*.h",
+        "crypto-qti/*.h"
+    ]),
+    includes = [".", "crypto-qti"]
+)
 
 load("pineapple.bzl", "define_pineapple")
 

+ 4 - 1
crypto-qti/qcedev.c

@@ -22,7 +22,7 @@
 #include <linux/debugfs.h>
 #include <linux/scatterlist.h>
 #include <linux/crypto.h>
-#include "linux/platform_data/qcom_crypto_device.h"
+#include "linux/qcom_crypto_device.h"
 #include "linux/qcedev.h"
 #include <linux/interconnect.h>
 #include <linux/delay.h>
@@ -2685,8 +2685,11 @@ static int qcedev_remove(struct platform_device *pdev)
 	podev = platform_get_drvdata(pdev);
 	if (!podev)
 		return 0;
+
+	qcedev_ce_high_bw_req(podev, true);
 	if (podev->qce)
 		qce_close(podev->qce);
+	qcedev_ce_high_bw_req(podev, false);
 
 	if (podev->icc_path)
 		icc_put(podev->icc_path);

+ 1 - 1
crypto-qti/qcedevi.h

@@ -11,7 +11,7 @@
 #include <linux/interrupt.h>
 #include <linux/cdev.h>
 #include <crypto/hash.h>
-#include "linux/platform_data/qcom_crypto_device.h"
+#include "linux/qcom_crypto_device.h"
 #include "linux/fips_status.h"
 #include "qce.h"
 #include "qcedev_smmu.h"

+ 1 - 1
crypto-qti/qcrypto.c

@@ -25,7 +25,7 @@
 #include <linux/sched.h>
 #include <linux/init.h>
 #include <linux/cache.h>
-#include "linux/platform_data/qcom_crypto_device.h"
+#include "linux/qcom_crypto_device.h"
 #include <linux/interconnect.h>
 #include <linux/hardirq.h>
 #include "linux/qcrypto.h"

+ 1 - 0
include/linux/smcinvoke.h

@@ -11,6 +11,7 @@
 
 #define SMCINVOKE_USERSPACE_OBJ_NULL	-1
 #define DEFAULT_CB_OBJ_THREAD_CNT	4
+#define SMCINVOKE_TZ_MIN_BUF_SIZE	4096
 
 struct smcinvoke_buf {
 	__u64 addr;

+ 0 - 0
linux/platform_data/qcom_crypto_device.h → linux/qcom_crypto_device.h


+ 5 - 0
pineapple.bzl

@@ -6,6 +6,11 @@ def define_pineapple():
         modules = [
             "smcinvoke_dlkm",
             "tz_log_dlkm",
+            "hdcp_qseecom_dlkm",
+            "qce50_dlkm",
+            "qcedev-mod_dlkm",
+            "qrng_dlkm",
+            "qcrypto-msm_dlkm"
         ],
         extra_options = [
             "CONFIG_QCOM_SMCINVOKE"]

+ 45 - 0
securemsm_modules.bzl

@@ -1,6 +1,9 @@
 SMCINVOKE_PATH = "smcinvoke"
 QSEECOM_PATH = "qseecom"
 TZLOG_PATH = "tz_log"
+HDCP_PATH = "hdcp"
+QCEDEV_PATH = "crypto-qti"
+QRNG_PATH = "qrng"
 
 # This dictionary holds all the securemsm-kernel  modules included by calling register_securemsm_module
 securemsm_modules = {}
@@ -76,3 +79,45 @@ register_securemsm_module(
     path = TZLOG_PATH,
     default_srcs = ["tz_log.c"],
 )
+
+register_securemsm_module(
+    name = "hdcp_qseecom_dlkm",
+    path = HDCP_PATH,
+    default_srcs = ["hdcp_qseecom.c"],
+    deps = [":hdcp_qseecom_dlkm","%b_smcinvoke_dlkm"],
+    srcs = ["config/sec-kernel_defconfig.h"],
+    copts = ["-include", "config/sec-kernel_defconfig.h"],
+)
+
+register_securemsm_module(
+    name = "qce50_dlkm",
+    path = QCEDEV_PATH,
+    default_srcs = ["qce50.c"],
+    deps = [":qcedev_local_headers"],
+)
+
+register_securemsm_module(
+    name = "qcedev-mod_dlkm",
+    path = QCEDEV_PATH,
+    default_srcs = [
+                "qcedev.c",
+                "qcedev_smmu.c",
+                "compat_qcedev.c"],
+    deps = [":qcedev_local_headers",
+            "%b_qce50_dlkm"],
+)
+
+register_securemsm_module(
+    name = "qrng_dlkm",
+    path = QRNG_PATH,
+    default_srcs = ["msm_rng.c"],
+    deps = [":qcedev_local_headers"],
+)
+
+register_securemsm_module(
+    name = "qcrypto-msm_dlkm",
+    path = QCEDEV_PATH,
+    default_srcs = ["qcrypto.c"],
+    deps = [":qcedev_local_headers",
+            "%b_qce50_dlkm"],
+)

+ 66 - 28
smcinvoke/smcinvoke.c

@@ -47,7 +47,6 @@
 #define SMCINVOKE_DEV				"smcinvoke"
 #define SMCINVOKE_TZ_ROOT_OBJ			1
 #define SMCINVOKE_TZ_OBJ_NULL			0
-#define SMCINVOKE_TZ_MIN_BUF_SIZE		4096
 #define SMCINVOKE_ARGS_ALIGN_SIZE		(sizeof(uint64_t))
 #define SMCINVOKE_NEXT_AVAILABLE_TXN		0
 #define SMCINVOKE_REQ_PLACED			1
@@ -334,7 +333,6 @@ struct smcinvoke_mem_obj {
 	uint64_t p_addr;
 	size_t p_addr_len;
 	struct list_head list;
-	bool is_smcinvoke_created_shmbridge;
 	uint64_t shmbridge_handle;
 	struct smcinvoke_server_info *server;
 	int32_t mem_obj_user_fd;
@@ -390,6 +388,8 @@ static int prepare_send_scm_msg(const uint8_t *in_buf, phys_addr_t in_paddr,
 		struct qtee_shm *in_shm, struct qtee_shm *out_shm);
 
 static void process_piggyback_data(void *buf, size_t buf_size);
+static void add_mem_obj_info_to_async_side_channel_locked(void *buf, size_t buf_size, struct list_head *l_pending_mem_obj);
+static void delete_pending_async_list_locked(struct list_head *l_pending_mem_obj);
 
 static void destroy_cb_server(struct kref *kref)
 {
@@ -673,7 +673,7 @@ static void __wakeup_postprocess_kthread(struct smcinvoke_worker_thread *smcinvo
 static int smcinvoke_postprocess_kthread_func(void *data)
 {
 	struct smcinvoke_worker_thread *smcinvoke_wrk_trd = data;
-	const char *tag;
+	static const char *const tag[] = {"shmbridge","object","adci","invalid"};
 
 	if (!smcinvoke_wrk_trd) {
 		pr_err("Bad input.\n");
@@ -688,21 +688,18 @@ static int smcinvoke_postprocess_kthread_func(void *data)
 			== POST_KT_WAKEUP));
 		switch (smcinvoke_wrk_trd->type) {
 		case SHMB_WORKER_THREAD:
-			tag = "shmbridge";
 			pr_debug("kthread to %s postprocess is called %d\n",
-			tag, atomic_read(&smcinvoke_wrk_trd->postprocess_kthread_state));
+			tag[SHMB_WORKER_THREAD], atomic_read(&smcinvoke_wrk_trd->postprocess_kthread_state));
 			smcinvoke_shmbridge_post_process();
 			break;
 		case OBJECT_WORKER_THREAD:
-			tag = "object";
 			pr_debug("kthread to %s postprocess is called %d\n",
-			tag, atomic_read(&smcinvoke_wrk_trd->postprocess_kthread_state));
+			tag[OBJECT_WORKER_THREAD], atomic_read(&smcinvoke_wrk_trd->postprocess_kthread_state));
 			smcinvoke_object_post_process();
 			break;
 		case ADCI_WORKER_THREAD:
-			tag = "adci";
 			pr_debug("kthread to %s postprocess is called %d\n",
-			tag, atomic_read(&smcinvoke_wrk_trd->postprocess_kthread_state));
+			tag[ADCI_WORKER_THREAD], atomic_read(&smcinvoke_wrk_trd->postprocess_kthread_state));
 			smcinvoke_start_adci_thread();
 			break;
 		default:
@@ -722,7 +719,7 @@ static int smcinvoke_postprocess_kthread_func(void *data)
 		atomic_set(&smcinvoke_wrk_trd->postprocess_kthread_state,
 			POST_KT_SLEEP);
 	}
-	pr_warn("kthread to %s postprocess stopped\n", tag);
+	pr_warn("kthread(worker_thread) processed, worker_thread type is %d \n", smcinvoke_wrk_trd->type);
 
 	return 0;
 }
@@ -808,7 +805,6 @@ static void queue_mem_obj_pending_async_locked(struct smcinvoke_mem_obj *mem_obj
 static inline void free_mem_obj_locked(struct smcinvoke_mem_obj *mem_obj)
 {
 	int ret = 0;
-	bool is_bridge_created = mem_obj->is_smcinvoke_created_shmbridge;
 	struct dma_buf *dmabuf_to_free = mem_obj->dma_buf;
 	uint64_t shmbridge_handle = mem_obj->shmbridge_handle;
 	struct smcinvoke_shmbridge_deregister_pending_list *entry = NULL;
@@ -819,7 +815,7 @@ static inline void free_mem_obj_locked(struct smcinvoke_mem_obj *mem_obj)
 	mem_obj = NULL;
 	mutex_unlock(&g_smcinvoke_lock);
 
-	if (is_bridge_created)
+	if (shmbridge_handle)
 		ret = qtee_shmbridge_deregister(shmbridge_handle);
 	if (ret) {
 		pr_err("Error:%d delete bridge failed leaking memory 0x%x\n",
@@ -1158,16 +1154,10 @@ static int smcinvoke_create_bridge(struct smcinvoke_mem_obj *mem_obj)
   ret = qtee_shmbridge_register(phys, size, vmid_list, perms_list, nelems,
       tz_perm, &mem_obj->shmbridge_handle);
 
-  if (ret == 0) {
-    /* In case of ret=0/success handle has to be freed in memobj release */
-    mem_obj->is_smcinvoke_created_shmbridge = true;
-  } else if (ret == -EEXIST) {
-    ret = 0;
-    goto exit;
-  } else {
-    pr_err("creation of shm bridge for mem_region_id %d failed ret %d\n",
-        mem_obj->mem_region_id, ret);
-    goto exit;
+  if (ret) {
+	  pr_err("creation of shm bridge for mem_region_id %d failed ret %d\n",
+			  mem_obj->mem_region_id, ret);
+	  goto exit;
   }
 
   trace_smcinvoke_create_bridge(mem_obj->shmbridge_handle, mem_obj->mem_region_id);
@@ -1531,7 +1521,7 @@ static void process_kernel_obj(void *buf, size_t buf_len)
 
 	switch (cb_req->hdr.op) {
 	case OBJECT_OP_MAP_REGION:
-		pr_debug("Received a request to map memory region\n");
+		pr_err("Received a request to map memory region\n");
 		cb_req->result = smcinvoke_process_map_mem_region_req(buf, buf_len);
 		break;
 	case OBJECT_OP_YIELD:
@@ -2198,6 +2188,13 @@ static int marshal_out_tzcb_req(const struct smcinvoke_accept *user_req,
 	int32_t tzhandles_to_release[OBJECT_COUNTS_MAX_OO] = {0};
 	struct smcinvoke_tzcb_req *tzcb_req = cb_txn->cb_req;
 	union smcinvoke_tz_args *tz_args = tzcb_req->args;
+	size_t tz_buf_offset = TZCB_BUF_OFFSET(tzcb_req);
+	LIST_HEAD(l_mem_objs_pending_async);    /* Holds new memory objects, to be later sent to TZ */
+	uint32_t max_offset = 0;
+	uint32_t buffer_size_max_offset = 0;
+	void* async_buf_begin;
+	size_t async_buf_size;
+	uint32_t offset = 0;
 
 	release_tzhandles(&cb_txn->cb_req->hdr.tzhandle, 1);
 	tzcb_req->result = user_req->result;
@@ -2207,6 +2204,16 @@ static int marshal_out_tzcb_req(const struct smcinvoke_accept *user_req,
                 ret = 0;
                 goto out;
         }
+
+	FOR_ARGS(i, tzcb_req->hdr.counts, BI) {
+
+		/* Find the max offset and the size of the buffer in that offset */
+		if (tz_args[i].b.offset > max_offset) {
+			max_offset = tz_args[i].b.offset;
+			buffer_size_max_offset = tz_args[i].b.size;
+		}
+	}
+
 	FOR_ARGS(i, tzcb_req->hdr.counts, BO) {
 		union smcinvoke_arg tmp_arg;
 
@@ -2224,6 +2231,12 @@ static int marshal_out_tzcb_req(const struct smcinvoke_accept *user_req,
 			ret = -EFAULT;
 			goto out;
 		}
+
+		/* Find the max offset and the size of the buffer in that offset */
+		if (tz_args[i].b.offset > max_offset) {
+			max_offset = tz_args[i].b.offset;
+			buffer_size_max_offset = tz_args[i].b.size;
+		}
 	}
 
 	FOR_ARGS(i, tzcb_req->hdr.counts, OO) {
@@ -2237,7 +2250,8 @@ static int marshal_out_tzcb_req(const struct smcinvoke_accept *user_req,
 		}
 		ret = get_tzhandle_from_uhandle(tmp_arg.o.fd,
 				tmp_arg.o.cb_server_fd, &arr_filp[i],
-				&(tz_args[i].handle), NULL);
+				&(tz_args[i].handle), &l_mem_objs_pending_async);
+
 		if (ret)
 			goto out;
 		tzhandles_to_release[i] = tz_args[i].handle;
@@ -2247,12 +2261,36 @@ static int marshal_out_tzcb_req(const struct smcinvoke_accept *user_req,
 	}
 	ret = 0;
 out:
-        FOR_ARGS(i, tzcb_req->hdr.counts, OI) {
-                if (TZHANDLE_IS_CB_OBJ(tz_args[i].handle))
-                        release_tzhandles(&tz_args[i].handle, 1);
-        }
+	FOR_ARGS(i, tzcb_req->hdr.counts, OI) {
+		if (TZHANDLE_IS_CB_OBJ(tz_args[i].handle))
+			release_tzhandles(&tz_args[i].handle, 1);
+	}
+
+	do {
+		if (mem_obj_async_support) {
+		/* We will be able to add the async information to the buffer beyond the data in the max offset, if exists.
+		 * If doesn't exist, we can add the async information after the header and the args. */
+		offset = (max_offset ? (max_offset + buffer_size_max_offset) : tz_buf_offset);
+		offset = size_align(offset, SMCINVOKE_ARGS_ALIGN_SIZE);
+		async_buf_begin = (uint8_t *)tzcb_req + offset;
+
+		if (async_buf_begin - (void *)tzcb_req > g_max_cb_buf_size) {
+			pr_err("Unable to add memory object info to the async channel\n");
+			break;
+		} else {
+			async_buf_size = g_max_cb_buf_size - (async_buf_begin - (void *)tzcb_req);
+		}
+
+		mutex_lock(&g_smcinvoke_lock);
+		add_mem_obj_info_to_async_side_channel_locked(async_buf_begin, async_buf_size, &l_mem_objs_pending_async);
+		delete_pending_async_list_locked(&l_mem_objs_pending_async);
+		mutex_unlock(&g_smcinvoke_lock);
+		}
+	} while (0);
+
 	if (ret)
 		release_tzhandles(tzhandles_to_release, OBJECT_COUNTS_MAX_OO);
+
 	return ret;
 }
 

+ 3 - 2
smcinvoke/smcinvoke_kernel.c

@@ -445,14 +445,15 @@ exit_free_cxt:
 
 static int __qseecom_shutdown_app(struct qseecom_handle **handle)
 {
-	struct qseecom_compat_context *cxt =
-		(struct qseecom_compat_context *)(*handle);
 
+	struct qseecom_compat_context *cxt = NULL;
 	if ((handle == NULL)  || (*handle == NULL)) {
 		pr_err("Handle is NULL\n");
 		return -EINVAL;
 	}
 
+	cxt = (struct qseecom_compat_context *)(*handle);
+
 	qtee_shmbridge_free_shm(&cxt->shm);
 	Object_release(cxt->app_controller);
 	Object_release(cxt->app_loader);

+ 9 - 0
smmu-proxy/qti-smmu-proxy-common.c

@@ -23,6 +23,7 @@ int smmu_proxy_get_csf_version(struct csf_version *csf_version)
 	struct Object client_env = {0};
 	struct Object sc_object;
 
+	/* Assumption is that cached_csf_version.arch_ver !=0 ==> other vals are set */
 	if (cached_csf_version.arch_ver != 0) {
 		csf_version->arch_ver = cached_csf_version.arch_ver;
 		csf_version->max_ver = cached_csf_version.max_ver;
@@ -52,6 +53,14 @@ int smmu_proxy_get_csf_version(struct csf_version *csf_version)
 	Object_release(sc_object);
 	Object_release(client_env);
 
+	/*
+	 * Once we set cached_csf_version.arch_ver, concurrent callers will get
+	 * the cached value.
+	 */
+	cached_csf_version.min_ver = csf_version->min_ver;
+	cached_csf_version.max_ver = csf_version->max_ver;
+	cached_csf_version.arch_ver = csf_version->arch_ver;
+
 	return ret;
 }
 EXPORT_SYMBOL(smmu_proxy_get_csf_version);

+ 1 - 1
smmu-proxy/qti-smmu-proxy-pvm.c

@@ -167,7 +167,7 @@ int smmu_proxy_map(struct device *client_dev, struct sg_table *proxy_iova,
 	resp = buf;
 
 	if (resp->hdr.ret) {
-		proxy_iova = ERR_PTR(resp->hdr.ret);
+		ret = resp->hdr.ret;
 		pr_err_ratelimited("%s: Map call failed on remote VM, rc: %d\n", __func__,
 				   resp->hdr.ret);
 		goto free_buf;

+ 63 - 26
smmu-proxy/qti-smmu-proxy-tvm.c

@@ -34,6 +34,37 @@ struct device *cb_devices[QTI_SMMU_PROXY_CB_IDS_LEN] = { 0 };
 
 struct task_struct *receiver_msgq_handler_thread;
 
+static int zero_dma_buf(struct dma_buf *dmabuf)
+{
+	int ret;
+	struct iosys_map vmap_struct = {0};
+
+	ret = dma_buf_vmap(dmabuf, &vmap_struct);
+	if (ret) {
+		pr_err("%s: dma_buf_vmap() failed with %d\n", __func__, ret);
+		return ret;
+	}
+
+	/* Use DMA_TO_DEVICE since we are not reading anything */
+	ret = dma_buf_begin_cpu_access(dmabuf, DMA_TO_DEVICE);
+	if (ret) {
+		pr_err("%s: dma_buf_begin_cpu_access() failed with %d\n", __func__, ret);
+		goto unmap;
+	}
+
+	memset(vmap_struct.vaddr, 0, dmabuf->size);
+	ret = dma_buf_end_cpu_access(dmabuf, DMA_TO_DEVICE);
+	if (ret)
+		pr_err("%s: dma_buf_end_cpu_access() failed with %d\n", __func__, ret);
+unmap:
+	dma_buf_vunmap(dmabuf, &vmap_struct);
+
+	if (ret)
+		pr_err("%s: Failed to properly zero the DMA-BUF\n", __func__);
+
+	return ret;
+}
+
 static int iommu_unmap_and_relinquish(u32 hdl)
 {
 	int cb_id, ret = 0;
@@ -75,8 +106,11 @@ static int iommu_unmap_and_relinquish(u32 hdl)
 		}
 	}
 
-	dma_buf_put(buf_state->dmabuf);
-	flush_delayed_fput();
+	ret = zero_dma_buf(buf_state->dmabuf);
+	if (!ret) {
+		dma_buf_put(buf_state->dmabuf);
+		flush_delayed_fput();
+	}
 
 	xa_erase(&buffer_state_arr, hdl);
 	kfree(buf_state);
@@ -139,16 +173,16 @@ struct sg_table *retrieve_and_iommu_map(struct mem_buf_retrieve_kernel_arg *retr
 	mutex_lock(&buffer_state_lock);
 	buf_state = xa_load(&buffer_state_arr, retrieve_arg->memparcel_hdl);
 	if (buf_state) {
+		if (buf_state->cb_info[cb_id].mapped) {
+			table = buf_state->cb_info[cb_id].sg_table;
+			goto unlock;
+		}
 		if (buf_state->locked) {
 			pr_err("%s: handle 0x%llx is locked!\n", __func__,
 			       retrieve_arg->memparcel_hdl);
 			ret = -EINVAL;
 			goto unlock_err;
 		}
-		if (buf_state->cb_info[cb_id].mapped) {
-			table = buf_state->cb_info[cb_id].sg_table;
-			goto unlock;
-		}
 
 		dmabuf = buf_state->dmabuf;
 	} else {
@@ -160,6 +194,12 @@ struct sg_table *retrieve_and_iommu_map(struct mem_buf_retrieve_kernel_arg *retr
 			goto unlock_err;
 		}
 
+		ret = zero_dma_buf(dmabuf);
+		if (ret) {
+			pr_err("%s: Failed to zero the DMA-BUF rc: %d\n", __func__, ret);
+			goto free_buf;
+		}
+
 		buf_state = kzalloc(sizeof(*buf_state), GFP_KERNEL);
 		if (!buf_state) {
 			pr_err("%s: Unable to allocate memory for buf_state\n",
@@ -448,7 +488,6 @@ int smmu_proxy_clear_all_buffers(void __user *context_bank_id_array,
 {
 	unsigned long handle;
 	struct smmu_proxy_buffer_state *buf_state;
-	struct iosys_map vmap_struct = {0};
 	__u32 cb_ids[QTI_SMMU_PROXY_CB_IDS_LEN];
 	int i, ret = 0;
 	bool found_mapped_cb;
@@ -485,30 +524,13 @@ int smmu_proxy_clear_all_buffers(void __user *context_bank_id_array,
 		if (!found_mapped_cb)
 			continue;
 
-		ret = dma_buf_vmap(buf_state->dmabuf, &vmap_struct);
+		ret = zero_dma_buf(buf_state->dmabuf);
 		if (ret) {
 			pr_err("%s: dma_buf_vmap() failed with %d\n", __func__, ret);
-			goto unlock;
-		}
-
-		/* Use DMA_TO_DEVICE since we are not reading anything */
-		ret = dma_buf_begin_cpu_access(buf_state->dmabuf, DMA_TO_DEVICE);
-		if (ret) {
-			pr_err("%s: dma_buf_begin_cpu_access() failed with %d\n", __func__, ret);
-			goto unmap;
-		}
-
-		memset(vmap_struct.vaddr, 0, buf_state->dmabuf->size);
-		ret = dma_buf_end_cpu_access(buf_state->dmabuf, DMA_TO_DEVICE);
-		if (ret)
-			pr_err("%s: dma_buf_end_cpu_access() failed with %d\n", __func__, ret);
-unmap:
-		dma_buf_vunmap(buf_state->dmabuf, &vmap_struct);
-		if (ret)
 			break;
+		}
 	}
 
-unlock:
 	mutex_unlock(&buffer_state_lock);
 	return ret;
 }
@@ -649,10 +671,18 @@ free_msgq:
 	return ret;
 }
 
+static int proxy_fault_handler(struct iommu_domain *domain, struct device *dev,
+			       unsigned long iova, int flags, void *token)
+{
+	dev_err(dev, "Context fault with IOVA %llx and fault flags %d\n", iova, flags);
+	return -EINVAL;
+}
+
 static int cb_probe_handler(struct device *dev)
 {
 	int ret;
 	unsigned int context_bank_id;
+	struct iommu_domain *domain;
 
 	ret = of_property_read_u32(dev->of_node, "qti,cb-id", &context_bank_id);
 	if (ret) {
@@ -682,6 +712,13 @@ static int cb_probe_handler(struct device *dev)
 		return ret;
 	}
 
+	domain = iommu_get_domain_for_dev(dev);
+	if (IS_ERR_OR_NULL(domain)) {
+		dev_err(dev, "%s: Failed to get iommu domain\n", __func__);
+		return -EINVAL;
+	}
+
+	iommu_set_fault_handler(domain, proxy_fault_handler, NULL);
 	cb_devices[context_bank_id] = dev;
 
 	return 0;