/*
 * Copyright (c) 2023 Institute of Parallel And Distributed Systems (IPADS), Shanghai Jiao Tong University (SJTU)
 * Licensed under the Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *     http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
 * PURPOSE.
 * See the Mulan PSL v2 for more details.
 */
12 #include <ipc/channel.h>
13 #include <sched/sched.h>
14 #include <mm/uaccess.h>
15 #include <common/util.h>
16
17 #define GTASK_PID (4)
18 #define GTASK_TID (0xa)
19
/*
 * 1. If the msg queue is empty, the server thread is inserted into the
 *    thread queue and is not scheduled until a client thread enqueues it.
 * 2. Otherwise, the server thread directly consumes the msg and returns.
 */
/*
 * Server-side receive on a channel.
 *
 * @channel:  channel to receive from; the caller must be its creater.
 * @recv_buf: user buffer that receives the message payload.
 * @recv_len: capacity of recv_buf in bytes.
 * @msg_hdl:  per-server message handle that records this receive.
 * @info:     user buffer that receives a struct src_msginfo about the sender.
 * @timeout:  OS_NO_WAIT for non-blocking; any other value blocks.
 *
 * Returns 0 on success, -EINVAL if the caller is not the creater or the
 * channel was stopped, -EFAULT on copy failure, and E_EX_TIMER_TIMEOUT when
 * non-blocking with no queued message.  On the blocking path this function
 * does NOT return: it parks the thread and switches away; the sending client
 * later resumes the server directly (see __tee_msg_send).
 */
static int __tee_msg_receive(struct channel *channel, void *recv_buf,
                             size_t recv_len, struct msg_hdl *msg_hdl,
                             void *info, int timeout)
{
        struct msg_entry *msg_entry;
        size_t copy_len;
        int ret;

        /* Only the channel's creater may act as the server side. */
        if (channel->creater != current_cap_group) {
                return -EINVAL;
        }

        /* Lock order: channel first, then msg_hdl (matches __tee_msg_send). */
        lock(&channel->lock);
        lock(&msg_hdl->lock);

        if (channel->state == CHANNEL_INVALID) {
                unlock(&msg_hdl->lock);
                unlock(&channel->lock);
                return -EINVAL;
        }

        if (list_empty(&channel->msg_queue)) {
                kdebug("%s: list_empty(&channel->msg_queue)\n", __func__);
                if (timeout == OS_NO_WAIT) {
                        unlock(&msg_hdl->lock);
                        unlock(&channel->lock);
                        return E_EX_TIMER_TIMEOUT;
                }
                /*
                 * No pending message: record where incoming data must be
                 * delivered, then park this server thread on the channel's
                 * thread queue until a client sends.
                 */
                msg_hdl->server_msg_record.server = current_thread;
                msg_hdl->server_msg_record.recv_buf = recv_buf;
                msg_hdl->server_msg_record.recv_len = recv_len;
                msg_hdl->server_msg_record.info = info;

                list_append(&msg_hdl->thread_queue_node, &channel->thread_queue);

                unlock(&msg_hdl->lock);
                unlock(&channel->lock);

                /* obj_put due to noreturn */
                obj_put(msg_hdl);
                obj_put(channel);

                /* Block; a client wakes us via sched_enqueue in __tee_msg_send. */
                current_thread->thread_ctx->state = TS_WAITING;
                sched();
                eret_to_thread(switch_context());
                BUG_ON(1);
        } else {
                kdebug("%s: !list_empty(&channel->msg_queue)\n", __func__);
                /* Consume the oldest queued message directly. */
                msg_entry = list_entry(
                        channel->msg_queue.next, struct msg_entry, msg_queue_node);
                /* Truncate to the receiver's capacity if the payload is larger. */
                copy_len = MIN(msg_entry->client_msg_record.send_len, recv_len);
                ret = copy_to_user(
                        recv_buf, msg_entry->client_msg_record.ksend_buf, copy_len);
                if (ret < 0) {
                        ret = -EFAULT;
                        goto out_unlock;
                }

                ret = copy_to_user(info,
                                   (char *)&msg_entry->client_msg_record.info,
                                   sizeof(struct src_msginfo));
                if (ret < 0) {
                        ret = -EFAULT;
                        goto out_unlock;
                }
                list_del(&msg_entry->msg_queue_node);

                /* Keep the client record so a later __tee_msg_reply can find the caller. */
                memcpy(&msg_hdl->client_msg_record,
                       &msg_entry->client_msg_record,
                       sizeof(struct client_msg_record));
                ret = 0;

                kfree(msg_entry->client_msg_record.ksend_buf);
                kfree(msg_entry);

        out_unlock:
                unlock(&msg_hdl->lock);
                unlock(&channel->lock);
                return ret;
        }
}
107
108 /*
109 * Client calling __tee_msg_send will cause
110 * 1. if thread queue is empty, client thread's msg will be inserted
111 * into msg queue
112 * 2. if not, client thread will directly choose and sched_enqueue one
113 * server thread
114 */
__tee_msg_send(struct channel * channel,struct client_msg_record * client_msg_record)115 static int __tee_msg_send(struct channel *channel,
116 struct client_msg_record *client_msg_record)
117 {
118 struct msg_hdl *msg_hdl;
119 struct msg_entry *msg_entry;
120 struct thread *client, *server;
121 size_t copy_len;
122 int ret = 0;
123
124 lock(&channel->lock);
125
126 if (channel->state == CHANNEL_INVALID) {
127 ret = -EINVAL;
128 goto out;
129 }
130
131 if (list_empty(&channel->thread_queue)) {
132 kdebug("%s: list_empty(&channel->thread_queue)\n", __func__);
133 msg_entry = kmalloc(sizeof(*msg_entry));
134
135 memcpy(&msg_entry->client_msg_record,
136 client_msg_record,
137 sizeof(*client_msg_record));
138
139 list_append(&msg_entry->msg_queue_node, &channel->msg_queue);
140 } else {
141 kdebug("%s: !list_empty(&channel->thread_queue)\n", __func__);
142 msg_hdl = list_entry(
143 channel->thread_queue.next, struct msg_hdl, thread_queue_node);
144
145 lock(&msg_hdl->lock);
146 list_del(&msg_hdl->thread_queue_node);
147
148 memcpy(&msg_hdl->client_msg_record,
149 client_msg_record,
150 sizeof(*client_msg_record));
151
152 server = msg_hdl->server_msg_record.server;
153 client = current_thread;
154
155 current_thread = server;
156 switch_thread_vmspace_to(server);
157
158 copy_len = MIN(client_msg_record->send_len,
159 msg_hdl->server_msg_record.recv_len);
160 ret = copy_to_user(msg_hdl->server_msg_record.recv_buf,
161 client_msg_record->ksend_buf,
162 copy_len);
163 if (ret < 0) {
164 goto out_copy;
165 }
166
167 ret = copy_to_user(msg_hdl->server_msg_record.info,
168 (char *)&msg_hdl->client_msg_record.info,
169 sizeof(struct src_msginfo));
170
171 out_copy:
172 current_thread = client;
173 switch_thread_vmspace_to(client);
174 if (ret < 0) {
175 ret = -EFAULT;
176 goto out_unlock;
177 }
178
179 kfree(client_msg_record->ksend_buf);
180 arch_set_thread_return(server, 0);
181 server->thread_ctx->state = TS_INTER;
182 BUG_ON(sched_enqueue(server));
183 kdebug("%s: enqueued %s\n", __func__, server->cap_group->cap_group_name);
184 out_unlock:
185 unlock(&msg_hdl->lock);
186 }
187
188 out:
189 unlock(&channel->lock);
190 return ret;
191 }
192
193 /*
194 * After client calling __tee_msg_send, client thread will be record in
195 * msg_hdl and not be sched until server thread __tee_msg_reply to enqueue it
196 */
/*
 * Synchronous call: send a message on the channel, then block until the
 * server thread replies via __tee_msg_reply.
 *
 * @channel:  target channel.
 * @send_buf/@send_len: user payload to send.
 * @recv_buf/@recv_len: user buffer the reply will be copied into.
 * @timeout:  NOTE(review): currently unused on this path.
 *
 * On success this function does NOT return normally: after __tee_msg_send
 * succeeds it drops the channel reference and blocks; the server's reply
 * sets this thread's return value and re-enqueues it.  On failure, returns
 * -ENOMEM/-EFAULT or the error from __tee_msg_send.
 */
static int __tee_msg_call(struct channel *channel, void *send_buf,
                          size_t send_len, void *recv_buf, size_t recv_len,
                          struct timespec *timeout)
{
        struct client_msg_record client_msg_record;
        void *ksend_buf;
        int ret;

        kdebug("%s: %s calls %s\n",
               __func__,
               current_cap_group->cap_group_name,
               channel->creater->cap_group_name);

        /* Stage the payload in a kernel buffer the server can copy from later. */
        if ((ksend_buf = kmalloc(send_len)) == NULL) {
                ret = -ENOMEM;
                goto out_fail;
        }
        ret = copy_from_user(ksend_buf, send_buf, send_len);
        if (ret < 0) {
                ret = -EFAULT;
                goto out_free_ksend_buf;
        }

        client_msg_record.client = current_thread;
        client_msg_record.ksend_buf = ksend_buf;
        client_msg_record.send_len = send_len;
        client_msg_record.recv_buf = recv_buf;
        client_msg_record.recv_len = recv_len;
        client_msg_record.info.msg_type = MSG_TYPE_CALL;
        client_msg_record.info.src_pid = current_cap_group->pid;
        client_msg_record.info.src_tid = current_thread->cap;
        /* Calls from the GTASK thread are presented to servers as pid/tid 0. */
        if (current_cap_group->pid == GTASK_PID
            && current_thread->cap == GTASK_TID) {
                client_msg_record.info.src_pid = 0;
                client_msg_record.info.src_tid = 0;
        }

        /* On success ksend_buf ownership has moved to the send path. */
        ret = __tee_msg_send(channel, &client_msg_record);
        if (ret != 0) {
                goto out_free_ksend_buf;
        }

        /* obj_put due to noreturn */
        obj_put(channel);

        /* Block until __tee_msg_reply re-enqueues this thread. */
        current_thread->thread_ctx->state = TS_WAITING;
        sched();
        eret_to_thread(switch_context());
        BUG_ON(1);

out_free_ksend_buf:
        kfree(ksend_buf);

out_fail:
        return ret;
}
253
254 /* Enqueue blocking client thread */
/*
 * Server-side reply: copy the reply payload into the blocked client's
 * buffer and wake the client recorded in msg_hdl.
 *
 * @msg_hdl:   message handle holding the client record of a pending CALL.
 * @reply_buf/@reply_len: user payload to deliver to the client.
 *
 * Returns 0 on success, -EINVAL if there is no pending CALL (e.g. already
 * replied, or the record is a NOTIF), -ENOMEM/-EFAULT on buffer errors.
 */
static int __tee_msg_reply(struct msg_hdl *msg_hdl, void *reply_buf,
                           size_t reply_len)
{
        struct thread *client, *server;
        void *kreply_buf;
        size_t copy_len;
        int ret = 0;

        kdebug("%s: %s replies to %s\n",
               __func__,
               current_cap_group->cap_group_name,
               msg_hdl->client_msg_record.client->cap_group->cap_group_name);

        lock(&msg_hdl->lock);

        /* Only an outstanding CALL may be replied to. */
        if (msg_hdl->client_msg_record.info.msg_type != MSG_TYPE_CALL) {
                ret = -EINVAL;
                goto out;
        }

        /* Stage the reply in a kernel buffer before switching vmspaces. */
        if ((kreply_buf = kmalloc(reply_len)) == NULL) {
                ret = -ENOMEM;
                goto out;
        }
        ret = copy_from_user(kreply_buf, reply_buf, reply_len);
        if (ret < 0) {
                ret = -EFAULT;
                goto out_free_kreply_buf;
        }

        client = msg_hdl->client_msg_record.client;
        server = current_thread;

        /* Impersonate the client so copy_to_user targets its address space. */
        current_thread = client;
        switch_thread_vmspace_to(client);

        /* Truncate the reply to the client's receive capacity. */
        copy_len = MIN(reply_len, msg_hdl->client_msg_record.recv_len);
        ret =
            copy_to_user(msg_hdl->client_msg_record.recv_buf, kreply_buf, copy_len);

        /* Restore the server's context before any return path. */
        current_thread = server;
        switch_thread_vmspace_to(server);
        if (ret < 0) {
                ret = -EFAULT;
                goto out_free_kreply_buf;
        }

        /*
         * Wait for client's kernel stack to make sure that
         * sched_enqueue(client) executes after sched() in client's
         * __tee_msg_call. Note that wait_for_kernel_stack executes with
         * server's msg_hdl locked. Client should NOT hold the lock of msg_hdl,
         * which is established under the assumption that msg_hdl cap should not
         * be distributed to others.
         */
        wait_for_kernel_stack(client);

        arch_set_thread_return(client, 0);
        client->thread_ctx->state = TS_INTER;
        BUG_ON(sched_enqueue(client));

        /* A call cannot be replied successfully twice. */
        msg_hdl->client_msg_record.info.msg_type = MSG_TYPE_INVALID;

out_free_kreply_buf:
        kfree(kreply_buf);
out:
        unlock(&msg_hdl->lock);
        return ret;
}
325
326 /* __tee_msg_send and return directly */
/*
 * One-way notification: send a message via __tee_msg_send and return
 * immediately without blocking for a reply.
 *
 * Returns 0 on success; -ENOMEM/-EFAULT or the error from __tee_msg_send
 * on failure (ksend_buf is freed here on failure, by the send path or the
 * eventual receiver on success).
 */
static int __tee_msg_notify(struct channel *channel, void *send_buf,
                            size_t send_len)
{
        /*
         * Fix: zero-initialize the record.  Notifications never set
         * recv_buf/recv_len, yet the whole struct is memcpy'd into the
         * server's msg_hdl, which previously propagated stack garbage.
         */
        struct client_msg_record client_msg_record = {0};
        void *ksend_buf;
        int ret;

        kdebug("%s: %s notifies %s\n",
               __func__,
               current_cap_group->cap_group_name,
               channel->creater->cap_group_name);

        /* Stage the payload in a kernel buffer the server can copy from. */
        if ((ksend_buf = kmalloc(send_len)) == NULL) {
                ret = -ENOMEM;
                goto out_fail;
        }
        ret = copy_from_user(ksend_buf, send_buf, send_len);
        if (ret < 0) {
                ret = -EFAULT;
                goto out_free_ksend_buf;
        }

        client_msg_record.client = current_thread;
        client_msg_record.ksend_buf = ksend_buf;
        client_msg_record.send_len = send_len;
        client_msg_record.info.msg_type = MSG_TYPE_NOTIF;
        client_msg_record.info.src_pid = current_cap_group->pid;
        client_msg_record.info.src_tid = current_thread->cap;
        /* Notifications from the GTASK thread are presented as pid/tid 0. */
        if (current_cap_group->pid == GTASK_PID
            && current_thread->cap == GTASK_TID) {
                client_msg_record.info.src_pid = 0;
                client_msg_record.info.src_tid = 0;
        }

        ret = __tee_msg_send(channel, &client_msg_record);
        if (ret != 0) {
                goto out_free_ksend_buf;
        }

        return 0;

out_free_ksend_buf:
        kfree(ksend_buf);

out_fail:
        return ret;
}
374
375 /* Wake up all blocking clients in the msg_queue */
__wake_up_all_clients(struct channel * channel)376 static int __wake_up_all_clients(struct channel *channel)
377 {
378 struct msg_entry *entry;
379 struct thread *client;
380
381 for_each_in_list (
382 entry, struct msg_entry, msg_queue_node, &channel->msg_queue) {
383 if (entry->client_msg_record.info.msg_type == MSG_TYPE_CALL) {
384 client = entry->client_msg_record.client;
385 BUG_ON(client->thread_ctx->state != TS_WAITING
386 && client->thread_ctx->state != TS_EXIT);
387 if (client->thread_ctx->state == TS_WAITING) {
388 arch_set_thread_return(client, -EINVAL);
389 client->thread_ctx->state = TS_INTER;
390 BUG_ON(sched_enqueue(client));
391 }
392 }
393 }
394
395 return 0;
396 }
397
398 /*
399 * Destroy waiting nodes in the msg queue of the given channel
400 * which belong to the given cap_group
401 */
__destory_waiting_node(struct channel * channel,struct cap_group * cap_group)402 static int __destory_waiting_node(struct channel *channel,
403 struct cap_group *cap_group)
404 {
405 struct msg_entry *entry;
406
407 for_each_in_list (
408 entry, struct msg_entry, msg_queue_node, &channel->msg_queue) {
409 if (entry->client_msg_record.client->cap_group == cap_group) {
410 list_del(&entry->msg_queue_node);
411 kfree(entry->client_msg_record.ksend_buf);
412 kfree(entry);
413 }
414 }
415
416 return 0;
417 }
418
419 /*
420 * close_channel will be called if
421 * 1. channel's creater calls sys_tee_msg_stop_channel
422 * 2. recycler calls sys_cap_group_recycle
423 */
close_channel(struct channel * channel,struct cap_group * cap_group)424 int close_channel(struct channel *channel, struct cap_group *cap_group)
425 {
426 lock(&channel->lock);
427 if (channel->creater == cap_group) {
428 channel->state = CHANNEL_INVALID;
429 __wake_up_all_clients(channel);
430 } else {
431 __destory_waiting_node(channel, cap_group);
432 }
433 unlock(&channel->lock);
434 return 0;
435 }
436
sys_tee_msg_create_msg_hdl(void)437 int sys_tee_msg_create_msg_hdl(void)
438 {
439 struct msg_hdl *msg_hdl = NULL;
440 int msg_hdl_cap = 0;
441 int ret = 0;
442
443 msg_hdl = obj_alloc(TYPE_MSG_HDL, sizeof(*msg_hdl));
444 if (!msg_hdl) {
445 ret = -ENOMEM;
446 goto out_fail;
447 }
448
449 /* init msg_hdl */
450 lock_init(&msg_hdl->lock);
451
452 msg_hdl_cap = cap_alloc(current_cap_group, msg_hdl);
453 if (msg_hdl_cap < 0) {
454 ret = msg_hdl_cap;
455 goto out_free_obj;
456 }
457
458 return msg_hdl_cap;
459
460 out_free_obj:
461 obj_free(msg_hdl);
462
463 out_fail:
464 return ret;
465 }
466
sys_tee_msg_create_channel(void)467 int sys_tee_msg_create_channel(void)
468 {
469 struct channel *channel = NULL;
470 int channel_cap = 0;
471 int ret = 0;
472
473 channel = obj_alloc(TYPE_CHANNEL, sizeof(*channel));
474 if (!channel) {
475 ret = -ENOMEM;
476 goto out_fail;
477 }
478
479 /* init channel */
480 lock_init(&channel->lock);
481 init_list_head(&channel->msg_queue);
482 init_list_head(&channel->thread_queue);
483 channel->creater = current_cap_group;
484 channel->state = CHANNEL_VALID;
485
486 channel_cap = cap_alloc(current_cap_group, channel);
487 if (channel_cap < 0) {
488 ret = channel_cap;
489 goto out_free_obj;
490 }
491
492 return channel_cap;
493
494 out_free_obj:
495 obj_free(channel);
496
497 out_fail:
498 return ret;
499 }
500
sys_tee_msg_stop_channel(int channel_cap)501 int sys_tee_msg_stop_channel(int channel_cap)
502 {
503 struct channel *channel;
504 int ret;
505
506 channel = obj_get(current_cap_group, channel_cap, TYPE_CHANNEL);
507 if (channel == NULL) {
508 ret = -ECAPBILITY;
509 goto out_fail_get_channel;
510 }
511
512 if (channel->creater == current_cap_group) {
513 ret = close_channel(channel, current_cap_group);
514 } else {
515 ret = -EINVAL;
516 }
517
518 obj_put(channel);
519
520 out_fail_get_channel:
521 return ret;
522 }
523
/*
 * Syscall wrapper for __tee_msg_receive: validate the user buffers,
 * resolve both capabilities, and forward.  Returns the receive result,
 * -EINVAL for bad user ranges, or -ECAPBILITY for bad caps.
 */
int sys_tee_msg_receive(int channel_cap, void *recv_buf, size_t recv_len,
                        int msg_hdl_cap, void *info, int timeout)
{
        struct channel *channel;
        struct msg_hdl *msg_hdl;
        int ret;

        if (check_user_addr_range((vaddr_t)recv_buf, recv_len) != 0
            || check_user_addr_range((vaddr_t)info, sizeof(struct src_msginfo))
                       != 0) {
                return -EINVAL;
        }

        channel = obj_get(current_cap_group, channel_cap, TYPE_CHANNEL);
        if (channel == NULL) {
                return -ECAPBILITY;
        }

        msg_hdl = obj_get(current_cap_group, msg_hdl_cap, TYPE_MSG_HDL);
        if (msg_hdl == NULL) {
                obj_put(channel);
                return -ECAPBILITY;
        }

        /* If it blocks, __tee_msg_receive puts both refs itself and never returns. */
        ret = __tee_msg_receive(
                channel, recv_buf, recv_len, msg_hdl, info, timeout);

        obj_put(msg_hdl);
        obj_put(channel);
        return ret;
}
562
/*
 * Syscall wrapper for __tee_msg_call: validate both user buffers, resolve
 * the channel capability, and forward.  Returns the call result, -EINVAL
 * for bad user ranges, or -ECAPBILITY for a bad cap.
 */
int sys_tee_msg_call(int channel_cap, void *send_buf, size_t send_len,
                     void *recv_buf, size_t recv_len, struct timespec *timeout)
{
        struct channel *ch;
        int ret;

        if (check_user_addr_range((vaddr_t)send_buf, send_len) != 0
            || check_user_addr_range((vaddr_t)recv_buf, recv_len) != 0) {
                return -EINVAL;
        }

        ch = obj_get(current_cap_group, channel_cap, TYPE_CHANNEL);
        if (ch == NULL) {
                return -ECAPBILITY;
        }

        /* On the blocking path __tee_msg_call puts the channel ref itself. */
        ret = __tee_msg_call(ch, send_buf, send_len, recv_buf, recv_len, timeout);

        obj_put(ch);
        return ret;
}
591
/*
 * Syscall wrapper for __tee_msg_reply: validate the reply buffer, resolve
 * the msg_hdl capability, and forward.  Returns the reply result, -EINVAL
 * for a bad user range, or -ECAPBILITY for a bad cap.
 */
int sys_tee_msg_reply(int msg_hdl_cap, void *reply_buf, size_t reply_len)
{
        struct msg_hdl *hdl;
        int ret;

        if (check_user_addr_range((vaddr_t)reply_buf, reply_len) != 0) {
                return -EINVAL;
        }

        hdl = obj_get(current_cap_group, msg_hdl_cap, TYPE_MSG_HDL);
        if (hdl == NULL) {
                return -ECAPBILITY;
        }

        ret = __tee_msg_reply(hdl, reply_buf, reply_len);

        obj_put(hdl);
        return ret;
}
614
/*
 * Syscall wrapper for __tee_msg_notify: validate the send buffer, resolve
 * the channel capability, and forward.  Returns the notify result, -EINVAL
 * for a bad user range, or -ECAPBILITY for a bad cap.
 */
int sys_tee_msg_notify(int channel_cap, void *send_buf, size_t send_len)
{
        struct channel *ch;
        int ret;

        if (check_user_addr_range((vaddr_t)send_buf, send_len) != 0) {
                return -EINVAL;
        }

        ch = obj_get(current_cap_group, channel_cap, TYPE_CHANNEL);
        if (ch == NULL) {
                return -ECAPBILITY;
        }

        ret = __tee_msg_notify(ch, send_buf, send_len);

        obj_put(ch);
        return ret;
}
637
channel_deinit(void * ptr)638 void channel_deinit(void *ptr)
639 {
640 struct msg_entry *entry;
641 struct channel *channel;
642
643 channel = (struct channel *)ptr;
644
645 for_each_in_list (
646 entry, struct msg_entry, msg_queue_node, &channel->msg_queue) {
647 list_del(&entry->msg_queue_node);
648 kfree(entry->client_msg_record.ksend_buf);
649 kfree(entry);
650 }
651 }
652
/*
 * Object destructor for a msg_hdl.  A msg_hdl owns no dynamically
 * allocated resources, so there is nothing to release here.
 */
void msg_hdl_deinit(void *ptr)
{
}
656