Browse Source

ASoC: dsp: add spin lock to protect ac

There is a possible use after free for the ac variable
that causes kernel panic. Add spinlock to fix race
condition to avoid this issue.

Change-Id: I71638c269f572ff1eaad87682fa000cd1dad407d
Signed-off-by: Meng Wang <[email protected]>
Meng Wang 7 years ago
parent
commit
9730cdd47e
1 changed files with 93 additions and 29 deletions
  1. 93 29
      dsp/q6asm.c

+ 93 - 29
dsp/q6asm.c

@@ -91,8 +91,13 @@ struct asm_mmap {
 };
 
 static struct asm_mmap this_mmap;
+
+struct audio_session {
+	struct audio_client *ac;
+	spinlock_t session_lock;
+};
 /* session id: 0 reserved */
-static struct audio_client *session[ASM_ACTIVE_STREAMS_ALLOWED + 1];
+static struct audio_session session[ASM_ACTIVE_STREAMS_ALLOWED + 1];
 
 struct asm_buffer_node {
 	struct list_head list;
@@ -530,8 +535,8 @@ static int q6asm_session_alloc(struct audio_client *ac)
 	int n;
 
 	for (n = 1; n <= ASM_ACTIVE_STREAMS_ALLOWED; n++) {
-		if (!session[n]) {
-			session[n] = ac;
+		if (!(session[n].ac)) {
+			session[n].ac = ac;
 			return n;
 		}
 	}
@@ -539,25 +544,39 @@ static int q6asm_session_alloc(struct audio_client *ac)
 	return -ENOMEM;
 }
 
-static bool q6asm_is_valid_audio_client(struct audio_client *ac)
+static int q6asm_get_session_id_from_audio_client(struct audio_client *ac)
 {
 	int n;
 
 	for (n = 1; n <= ASM_ACTIVE_STREAMS_ALLOWED; n++) {
-		if (session[n] == ac)
-			return 1;
+		if (session[n].ac == ac)
+			return n;
 	}
 	return 0;
 }
 
+static bool q6asm_is_valid_audio_client(struct audio_client *ac)
+{
+	return q6asm_get_session_id_from_audio_client(ac) ? 1 : 0;
+}
+
 static void q6asm_session_free(struct audio_client *ac)
 {
+	int session_id;
+
 	pr_debug("%s: sessionid[%d]\n", __func__, ac->session);
+	session_id = ac->session;
 	rtac_remove_popp_from_adm_devices(ac->session);
-	session[ac->session] = 0;
+	spin_lock_bh(&(session[session_id].session_lock));
+	session[ac->session].ac = NULL;
 	ac->session = 0;
 	ac->perf_mode = LEGACY_PCM_MODE;
 	ac->fptr_cache_ops = NULL;
+	ac->cb = NULL;
+	ac->priv = NULL;
+	kfree(ac);
+	ac = NULL;
+	spin_unlock_bh(&(session[session_id].session_lock));
 }
 
 static uint32_t q6asm_get_next_buf(struct audio_client *ac,
@@ -1088,8 +1107,6 @@ void q6asm_audio_client_free(struct audio_client *ac)
 	pr_debug("%s: APR De-Register\n", __func__);
 
 /*done:*/
-	kfree(ac);
-	ac = NULL;
 	mutex_unlock(&session_lock);
 }
 EXPORT_SYMBOL(q6asm_audio_client_free);
@@ -1252,6 +1269,7 @@ struct audio_client *q6asm_audio_client_alloc(app_cb cb, void *priv)
 	if (n <= 0) {
 		pr_err("%s: ASM Session alloc fail n=%d\n", __func__, n);
 		mutex_unlock(&session_lock);
+		kfree(ac);
 		goto fail_session;
 	}
 	ac->session = n;
@@ -1328,7 +1346,6 @@ fail_apr2:
 fail_apr1:
 	q6asm_session_free(ac);
 fail_session:
-	kfree(ac);
 	return NULL;
 }
 EXPORT_SYMBOL(q6asm_audio_client_alloc);
@@ -1351,11 +1368,11 @@ struct audio_client *q6asm_get_audio_client(int session_id)
 		goto err;
 	}
 
-	if (!session[session_id]) {
+	if (!(session[session_id].ac)) {
 		pr_err("%s: session not active: %d\n", __func__, session_id);
 		goto err;
 	}
-	return session[session_id];
+	return session[session_id].ac;
 err:
 	return NULL;
 }
@@ -1591,6 +1608,8 @@ static int32_t q6asm_srvc_callback(struct apr_client_data *data, void *priv)
 	struct audio_client *ac = NULL;
 	struct audio_port_data *port;
 
+	int session_id;
+
 	if (!data) {
 		pr_err("%s: Invalid CB\n", __func__);
 		return 0;
@@ -1632,12 +1651,20 @@ static int32_t q6asm_srvc_callback(struct apr_client_data *data, void *priv)
 		return 0;
 	}
 	asm_token.token = data->token;
-	ac = q6asm_get_audio_client(asm_token._token.session_id);
+	session_id = asm_token._token.session_id;
+
+	if ((session_id > 0 && session_id <= ASM_ACTIVE_STREAMS_ALLOWED))
+		spin_lock(&(session[session_id].session_lock));
+
+	ac = q6asm_get_audio_client(session_id);
 	dir = q6asm_get_flag_from_token(&asm_token, ASM_DIRECTION_OFFSET);
 
 	if (!ac) {
 		pr_debug("%s: session[%d] already freed\n",
-			 __func__, asm_token._token.session_id);
+			 __func__, session_id);
+		if ((session_id > 0 &&
+			session_id <= ASM_ACTIVE_STREAMS_ALLOWED))
+			spin_unlock(&(session[session_id].session_lock));
 		return 0;
 	}
 
@@ -1688,6 +1715,9 @@ static int32_t q6asm_srvc_callback(struct apr_client_data *data, void *priv)
 						__func__, payload[0]);
 			break;
 		}
+		if ((session_id > 0 &&
+			session_id <= ASM_ACTIVE_STREAMS_ALLOWED))
+			spin_unlock(&(session[session_id].session_lock));
 		return 0;
 	}
 
@@ -1722,6 +1752,9 @@ static int32_t q6asm_srvc_callback(struct apr_client_data *data, void *priv)
 	if (ac->cb)
 		ac->cb(data->opcode, data->token,
 			data->payload, ac->priv);
+	if ((session_id > 0 && session_id <= ASM_ACTIVE_STREAMS_ALLOWED))
+		spin_unlock(&(session[session_id].session_lock));
+
 	return 0;
 }
 
@@ -1789,6 +1822,8 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 	struct msm_adsp_event_data *pp_event_package = NULL;
 	uint32_t payload_size = 0;
 
+	int session_id;
+
 	if (ac == NULL) {
 		pr_err("%s: ac NULL\n", __func__);
 		return -EINVAL;
@@ -1797,15 +1832,19 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 		pr_err("%s: data NULL\n", __func__);
 		return -EINVAL;
 	}
-	if (!q6asm_is_valid_audio_client(ac)) {
-		pr_err("%s: audio client pointer is invalid, ac = %pK\n",
-				__func__, ac);
+
+	session_id = q6asm_get_session_id_from_audio_client(ac);
+	if (session_id <= 0 || session_id > ASM_ACTIVE_STREAMS_ALLOWED) {
+		pr_err("%s: Session ID is invalid, session = %d\n", __func__,
+			session_id);
 		return -EINVAL;
 	}
+	spin_lock(&(session[session_id].session_lock));
 
-	if (ac->session <= 0 || ac->session > 8) {
-		pr_err("%s: Session ID is invalid, session = %d\n", __func__,
-			ac->session);
+	if (!q6asm_is_valid_audio_client(ac)) {
+		pr_err("%s: audio client pointer is invalid, ac = %pK\n",
+				__func__, ac);
+		spin_unlock(&(session[session_id].session_lock));
 		return -EINVAL;
 	}
 
@@ -1818,7 +1857,9 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 	}
 
 	if (data->opcode == RESET_EVENTS) {
+		spin_unlock(&(session[session_id].session_lock));
 		mutex_lock(&ac->cmd_lock);
+		spin_lock(&(session[session_id].session_lock));
 		atomic_set(&ac->reset, 1);
 		if (ac->apr == NULL) {
 			ac->apr = ac->apr2;
@@ -1839,6 +1880,7 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 		wake_up(&ac->time_wait);
 		wake_up(&ac->cmd_wait);
 		wake_up(&ac->mem_wait);
+		spin_unlock(&(session[session_id].session_lock));
 		mutex_unlock(&ac->cmd_lock);
 		return 0;
 	}
@@ -1853,6 +1895,7 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 	    (data->opcode != ASM_SESSION_EVENT_RX_UNDERFLOW)) {
 		if (payload == NULL) {
 			pr_err("%s: payload is null\n", __func__);
+			spin_unlock(&(session[session_id].session_lock));
 			return -EINVAL;
 		}
 		dev_vdbg(ac->dev, "%s: Payload = [0x%x] status[0x%x] opcode 0x%x\n",
@@ -1878,6 +1921,7 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 		ret = q6asm_is_valid_session(data, priv);
 		if (ret != 0) {
 			pr_err("%s: session invalid %d\n", __func__, ret);
+			spin_unlock(&(session[session_id].session_lock));
 			return ret;
 		}
 		case ASM_SESSION_CMD_SET_MTMX_STRTR_PARAMS_V2:
@@ -1917,6 +1961,8 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 								payload[1]);
 					wake_up(&ac->cmd_wait);
 				}
+				spin_unlock(
+					&(session[session_id].session_lock));
 				return 0;
 			}
 			if ((is_adsp_reg_event(payload[0]) >= 0) ||
@@ -1947,6 +1993,8 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 					atomic_set(&ac->mem_state, payload[1]);
 					wake_up(&ac->mem_wait);
 				}
+				spin_unlock(
+					&(session[session_id].session_lock));
 				return 0;
 			}
 			if (atomic_read(&ac->mem_state) && wakeup_flag) {
@@ -1994,6 +2042,8 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 							__func__, payload[0]);
 			break;
 		}
+
+		spin_unlock(&(session[session_id].session_lock));
 		return 0;
 	}
 
@@ -2008,6 +2058,8 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 			if (port->buf == NULL) {
 				pr_err("%s: Unexpected Write Done\n",
 								__func__);
+				spin_unlock(
+					&(session[session_id].session_lock));
 				return -EINVAL;
 			}
 			spin_lock_irqsave(&port->dsp_lock, dsp_flags);
@@ -2022,6 +2074,8 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 					__func__, payload[0], payload[1]);
 				spin_unlock_irqrestore(&port->dsp_lock,
 								dsp_flags);
+				spin_unlock(
+					&(session[session_id].session_lock));
 				return -EINVAL;
 			}
 			port->buf[buf_index].used = 1;
@@ -2092,6 +2146,8 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 		if (ac->io_mode & SYNC_IO_MODE) {
 			if (port->buf == NULL) {
 				pr_err("%s: Unexpected Write Done\n", __func__);
+				spin_unlock(
+					&(session[session_id].session_lock));
 				return -EINVAL;
 			}
 			spin_lock_irqsave(&port->dsp_lock, dsp_flags);
@@ -2166,8 +2222,10 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 		pr_debug("%s: ASM_STREAM_EVENT payload[0][0x%x] payload[1][0x%x]",
 				 __func__, payload[0], payload[1]);
 		i = is_adsp_raise_event(data->opcode);
-		if (i < 0)
+		if (i < 0) {
+			spin_unlock(&(session[session_id].session_lock));
 			return 0;
+		}
 
 		/* repack payload for asm_stream_pp_event
 		 * package is composed of event type + size + actual payload
@@ -2176,8 +2234,10 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 		pp_event_package = kzalloc(payload_size
 				+ sizeof(struct msm_adsp_event_data),
 				GFP_ATOMIC);
-		if (!pp_event_package)
+		if (!pp_event_package) {
+			spin_unlock(&(session[session_id].session_lock));
 			return -ENOMEM;
+		}
 
 		pp_event_package->event_type = i;
 		pp_event_package->payload_len = payload_size;
@@ -2186,6 +2246,7 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 		ac->cb(data->opcode, data->token,
 			(void *)pp_event_package, ac->priv);
 		kfree(pp_event_package);
+		spin_unlock(&(session[session_id].session_lock));
 		return 0;
 	case ASM_SESSION_CMDRSP_ADJUST_SESSION_CLOCK_V2:
 		pr_debug("%s: ASM_SESSION_CMDRSP_ADJUST_SESSION_CLOCK_V2 sesion %d status 0x%x msw %u lsw %u\n",
@@ -2211,7 +2272,7 @@ static int32_t q6asm_callback(struct apr_client_data *data, void *priv)
 	if (ac->cb)
 		ac->cb(data->opcode, data->token,
 			data->payload, ac->priv);
-
+	spin_unlock(&(session[session_id].session_lock));
 	return 0;
 }
 
@@ -9864,7 +9925,7 @@ int q6asm_get_apr_service_id(int session_id)
 		return -EINVAL;
 	}
 
-	return ((struct apr_svc *)session[session_id]->apr)->id;
+	return ((struct apr_svc *)(session[session_id].ac)->apr)->id;
 }
 
 int q6asm_get_asm_topology(int session_id)
@@ -9875,12 +9936,12 @@ int q6asm_get_asm_topology(int session_id)
 		pr_err("%s: invalid session_id = %d\n", __func__, session_id);
 		goto done;
 	}
-	if (session[session_id] == NULL) {
+	if (session[session_id].ac == NULL) {
 		pr_err("%s: session not created for session id = %d\n",
 		       __func__, session_id);
 		goto done;
 	}
-	topology = session[session_id]->topology;
+	topology = (session[session_id].ac)->topology;
 done:
 	return topology;
 }
@@ -9893,12 +9954,12 @@ int q6asm_get_asm_app_type(int session_id)
 		pr_err("%s: invalid session_id = %d\n", __func__, session_id);
 		goto done;
 	}
-	if (session[session_id] == NULL) {
+	if (session[session_id].ac == NULL) {
 		pr_err("%s: session not created for session id = %d\n",
 		       __func__, session_id);
 		goto done;
 	}
-	app_type = session[session_id]->app_type;
+	app_type = (session[session_id].ac)->app_type;
 done:
 	return app_type;
 }
@@ -10253,7 +10314,10 @@ int __init q6asm_init(void)
 
 	pr_debug("%s:\n", __func__);
 
-	memset(session, 0, sizeof(session));
+	memset(session, 0, sizeof(struct audio_session) *
+		(ASM_ACTIVE_STREAMS_ALLOWED + 1));
+	for (lcnt = 0; lcnt <= ASM_ACTIVE_STREAMS_ALLOWED; lcnt++)
+		spin_lock_init(&(session[lcnt].session_lock));
 	set_custom_topology = 1;
 
 	/*setup common client used for cal mem map */