• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Institute of Parallel And Distributed Systems (IPADS), Shanghai Jiao Tong University (SJTU)
3  * Licensed under the Mulan PSL v2.
4  * You can use this software according to the terms and conditions of the Mulan PSL v2.
5  * You may obtain a copy of Mulan PSL v2 at:
6  *     http://license.coscl.org.cn/MulanPSL2
7  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
8  * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
9  * PURPOSE.
10  * See the Mulan PSL v2 for more details.
11  */
12 #include <ipc/channel.h>
13 #include <sched/sched.h>
14 #include <mm/uaccess.h>
15 #include <common/util.h>
16 
/*
 * A sender whose cap_group pid equals GTASK_PID and whose thread cap equals
 * GTASK_TID is reported to receivers with src_pid = 0 and src_tid = 0.
 * NOTE(review): presumably identifies the TEE global task (gtask) — confirm.
 */
#define GTASK_PID (4)
#define GTASK_TID (0xa)
19 
20 /*
21  * 1. if msg queue is empty, server thread will be inserted into thread queue
22  *    and not be sched until client thread enqueue it
23  * 2. if not, server thread will directly consume the msg and return
24  */
__tee_msg_receive(struct channel * channel,void * recv_buf,size_t recv_len,struct msg_hdl * msg_hdl,void * info,int timeout)25 static int __tee_msg_receive(struct channel *channel, void *recv_buf,
26                              size_t recv_len, struct msg_hdl *msg_hdl,
27                              void *info, int timeout)
28 {
29     struct msg_entry *msg_entry;
30     size_t copy_len;
31     int ret;
32 
33     if (channel->creater != current_cap_group) {
34         return -EINVAL;
35     }
36 
37     lock(&channel->lock);
38     lock(&msg_hdl->lock);
39 
40     if (channel->state == CHANNEL_INVALID) {
41         unlock(&msg_hdl->lock);
42         unlock(&channel->lock);
43         return -EINVAL;
44     }
45 
46     if (list_empty(&channel->msg_queue)) {
47         kdebug("%s: list_empty(&channel->msg_queue)\n", __func__);
48         if (timeout == OS_NO_WAIT) {
49             unlock(&msg_hdl->lock);
50             unlock(&channel->lock);
51             return E_EX_TIMER_TIMEOUT;
52         }
53         msg_hdl->server_msg_record.server = current_thread;
54         msg_hdl->server_msg_record.recv_buf = recv_buf;
55         msg_hdl->server_msg_record.recv_len = recv_len;
56         msg_hdl->server_msg_record.info = info;
57 
58         list_append(&msg_hdl->thread_queue_node, &channel->thread_queue);
59 
60         unlock(&msg_hdl->lock);
61         unlock(&channel->lock);
62 
63 
64         /* obj_put due to noreturn */
65         obj_put(msg_hdl);
66         obj_put(channel);
67 
68         current_thread->thread_ctx->state = TS_WAITING;
69         sched();
70         eret_to_thread(switch_context());
71         BUG_ON(1);
72     } else {
73         kdebug("%s: !list_empty(&channel->msg_queue)\n", __func__);
74         msg_entry = list_entry(
75             channel->msg_queue.next, struct msg_entry, msg_queue_node);
76         copy_len = MIN(msg_entry->client_msg_record.send_len, recv_len);
77         ret = copy_to_user(
78             recv_buf, msg_entry->client_msg_record.ksend_buf, copy_len);
79         if (ret < 0) {
80             ret = -EFAULT;
81             goto out_unlock;
82         }
83 
84         ret = copy_to_user(info,
85                            (char *)&msg_entry->client_msg_record.info,
86                            sizeof(struct src_msginfo));
87         if (ret < 0) {
88             ret = -EFAULT;
89             goto out_unlock;
90         }
91         list_del(&msg_entry->msg_queue_node);
92 
93         memcpy(&msg_hdl->client_msg_record,
94                &msg_entry->client_msg_record,
95                sizeof(struct client_msg_record));
96         ret = 0;
97 
98         kfree(msg_entry->client_msg_record.ksend_buf);
99         kfree(msg_entry);
100 
101     out_unlock:
102         unlock(&msg_hdl->lock);
103         unlock(&channel->lock);
104         return ret;
105     }
106 }
107 
108 /*
109  * Client calling __tee_msg_send will cause
110  * 1. if thread queue is empty, client thread's msg will be inserted
111  *    into msg queue
112  * 2. if not, client thread will directly choose and sched_enqueue one
113  *    server thread
114  */
__tee_msg_send(struct channel * channel,struct client_msg_record * client_msg_record)115 static int __tee_msg_send(struct channel *channel,
116                           struct client_msg_record *client_msg_record)
117 {
118     struct msg_hdl *msg_hdl;
119     struct msg_entry *msg_entry;
120     struct thread *client, *server;
121     size_t copy_len;
122     int ret = 0;
123 
124     lock(&channel->lock);
125 
126     if (channel->state == CHANNEL_INVALID) {
127         ret = -EINVAL;
128         goto out;
129     }
130 
131     if (list_empty(&channel->thread_queue)) {
132         kdebug("%s: list_empty(&channel->thread_queue)\n", __func__);
133         msg_entry = kmalloc(sizeof(*msg_entry));
134 
135         memcpy(&msg_entry->client_msg_record,
136                client_msg_record,
137                sizeof(*client_msg_record));
138 
139         list_append(&msg_entry->msg_queue_node, &channel->msg_queue);
140     } else {
141         kdebug("%s: !list_empty(&channel->thread_queue)\n", __func__);
142         msg_hdl = list_entry(
143             channel->thread_queue.next, struct msg_hdl, thread_queue_node);
144 
145         lock(&msg_hdl->lock);
146         list_del(&msg_hdl->thread_queue_node);
147 
148         memcpy(&msg_hdl->client_msg_record,
149                client_msg_record,
150                sizeof(*client_msg_record));
151 
152         server = msg_hdl->server_msg_record.server;
153         client = current_thread;
154 
155         current_thread = server;
156         switch_thread_vmspace_to(server);
157 
158         copy_len = MIN(client_msg_record->send_len,
159                        msg_hdl->server_msg_record.recv_len);
160         ret = copy_to_user(msg_hdl->server_msg_record.recv_buf,
161                            client_msg_record->ksend_buf,
162                            copy_len);
163         if (ret < 0) {
164             goto out_copy;
165         }
166 
167         ret = copy_to_user(msg_hdl->server_msg_record.info,
168                            (char *)&msg_hdl->client_msg_record.info,
169                            sizeof(struct src_msginfo));
170 
171     out_copy:
172         current_thread = client;
173         switch_thread_vmspace_to(client);
174         if (ret < 0) {
175             ret = -EFAULT;
176             goto out_unlock;
177         }
178 
179         kfree(client_msg_record->ksend_buf);
180         arch_set_thread_return(server, 0);
181         server->thread_ctx->state = TS_INTER;
182         BUG_ON(sched_enqueue(server));
183         kdebug("%s: enqueued %s\n", __func__, server->cap_group->cap_group_name);
184     out_unlock:
185         unlock(&msg_hdl->lock);
186     }
187 
188 out:
189     unlock(&channel->lock);
190     return ret;
191 }
192 
193 /*
194  * After client calling __tee_msg_send, client thread will be record in
195  * msg_hdl and not be sched until server thread __tee_msg_reply to enqueue it
196  */
__tee_msg_call(struct channel * channel,void * send_buf,size_t send_len,void * recv_buf,size_t recv_len,struct timespec * timeout)197 static int __tee_msg_call(struct channel *channel, void *send_buf,
198                           size_t send_len, void *recv_buf, size_t recv_len,
199                           struct timespec *timeout)
200 {
201     struct client_msg_record client_msg_record;
202     void *ksend_buf;
203     int ret;
204 
205     kdebug("%s: %s calls %s\n",
206           __func__,
207           current_cap_group->cap_group_name,
208           channel->creater->cap_group_name);
209 
210     if ((ksend_buf = kmalloc(send_len)) == NULL) {
211         ret = -ENOMEM;
212         goto out_fail;
213     }
214     ret = copy_from_user(ksend_buf, send_buf, send_len);
215     if (ret < 0) {
216         ret = -EFAULT;
217         goto out_free_ksend_buf;
218     }
219 
220     client_msg_record.client = current_thread;
221     client_msg_record.ksend_buf = ksend_buf;
222     client_msg_record.send_len = send_len;
223     client_msg_record.recv_buf = recv_buf;
224     client_msg_record.recv_len = recv_len;
225     client_msg_record.info.msg_type = MSG_TYPE_CALL;
226     client_msg_record.info.src_pid = current_cap_group->pid;
227     client_msg_record.info.src_tid = current_thread->cap;
228     if (current_cap_group->pid == GTASK_PID
229         && current_thread->cap == GTASK_TID) {
230         client_msg_record.info.src_pid = 0;
231         client_msg_record.info.src_tid = 0;
232     }
233 
234     ret = __tee_msg_send(channel, &client_msg_record);
235     if (ret != 0) {
236         goto out_free_ksend_buf;
237     }
238 
239     /* obj_put due to noreturn */
240     obj_put(channel);
241 
242     current_thread->thread_ctx->state = TS_WAITING;
243     sched();
244     eret_to_thread(switch_context());
245     BUG_ON(1);
246 
247 out_free_ksend_buf:
248     kfree(ksend_buf);
249 
250 out_fail:
251     return ret;
252 }
253 
254 /* Enqueue blocking client thread */
__tee_msg_reply(struct msg_hdl * msg_hdl,void * reply_buf,size_t reply_len)255 static int __tee_msg_reply(struct msg_hdl *msg_hdl, void *reply_buf,
256                            size_t reply_len)
257 {
258     struct thread *client, *server;
259     void *kreply_buf;
260     size_t copy_len;
261     int ret = 0;
262 
263     kdebug("%s: %s replies to %s\n",
264           __func__,
265           current_cap_group->cap_group_name,
266           msg_hdl->client_msg_record.client->cap_group->cap_group_name);
267 
268     lock(&msg_hdl->lock);
269 
270     if (msg_hdl->client_msg_record.info.msg_type != MSG_TYPE_CALL) {
271         ret = -EINVAL;
272         goto out;
273     }
274 
275     if ((kreply_buf = kmalloc(reply_len)) == NULL) {
276         ret = -ENOMEM;
277         goto out;
278     }
279     ret = copy_from_user(kreply_buf, reply_buf, reply_len);
280     if (ret < 0) {
281         ret = -EFAULT;
282         goto out_free_kreply_buf;
283     }
284 
285     client = msg_hdl->client_msg_record.client;
286     server = current_thread;
287 
288     current_thread = client;
289     switch_thread_vmspace_to(client);
290 
291     copy_len = MIN(reply_len, msg_hdl->client_msg_record.recv_len);
292     ret =
293         copy_to_user(msg_hdl->client_msg_record.recv_buf, kreply_buf, copy_len);
294 
295     current_thread = server;
296     switch_thread_vmspace_to(server);
297     if (ret < 0) {
298         ret = -EFAULT;
299         goto out_free_kreply_buf;
300     }
301 
302     /*
303      * Wait for client's kernel stack to make sure that
304      * sched_enqueue(client) executes after sched() in client's
305      * __tee_msg_call. Note that wait_for_kernel_stack executes with
306      * server's msg_hdl locked. Client should NOT hold the lock of msg_hdl,
307      * which is established under the assumption that msg_hdl cap should not
308      * be distributed to others.
309      */
310     wait_for_kernel_stack(client);
311 
312     arch_set_thread_return(client, 0);
313     client->thread_ctx->state = TS_INTER;
314     BUG_ON(sched_enqueue(client));
315 
316     /* A call cannot be replied successfully twice. */
317     msg_hdl->client_msg_record.info.msg_type = MSG_TYPE_INVALID;
318 
319 out_free_kreply_buf:
320     kfree(kreply_buf);
321 out:
322     unlock(&msg_hdl->lock);
323     return ret;
324 }
325 
326 /* __tee_msg_send and return directly */
__tee_msg_notify(struct channel * channel,void * send_buf,size_t send_len)327 static int __tee_msg_notify(struct channel *channel, void *send_buf,
328                             size_t send_len)
329 {
330     struct client_msg_record client_msg_record;
331     void *ksend_buf;
332     int ret;
333 
334     kdebug("%s: %s notifies %s\n",
335           __func__,
336           current_cap_group->cap_group_name,
337           channel->creater->cap_group_name);
338 
339     if ((ksend_buf = kmalloc(send_len)) == NULL) {
340         ret = -ENOMEM;
341         goto out_fail;
342     }
343     ret = copy_from_user(ksend_buf, send_buf, send_len);
344     if (ret < 0) {
345         ret = -EFAULT;
346         goto out_free_ksend_buf;
347     }
348 
349     client_msg_record.client = current_thread;
350     client_msg_record.ksend_buf = ksend_buf;
351     client_msg_record.send_len = send_len;
352     client_msg_record.info.msg_type = MSG_TYPE_NOTIF;
353     client_msg_record.info.src_pid = current_cap_group->pid;
354     client_msg_record.info.src_tid = current_thread->cap;
355     if (current_cap_group->pid == GTASK_PID
356         && current_thread->cap == GTASK_TID) {
357         client_msg_record.info.src_pid = 0;
358         client_msg_record.info.src_tid = 0;
359     }
360 
361     ret = __tee_msg_send(channel, &client_msg_record);
362     if (ret != 0) {
363         goto out_free_ksend_buf;
364     }
365 
366     return 0;
367 
368 out_free_ksend_buf:
369     kfree(ksend_buf);
370 
371 out_fail:
372     return ret;
373 }
374 
375 /* Wake up all blocking clients in the msg_queue */
__wake_up_all_clients(struct channel * channel)376 static int __wake_up_all_clients(struct channel *channel)
377 {
378     struct msg_entry *entry;
379     struct thread *client;
380 
381     for_each_in_list (
382         entry, struct msg_entry, msg_queue_node, &channel->msg_queue) {
383         if (entry->client_msg_record.info.msg_type == MSG_TYPE_CALL) {
384             client = entry->client_msg_record.client;
385             BUG_ON(client->thread_ctx->state != TS_WAITING
386                    && client->thread_ctx->state != TS_EXIT);
387             if (client->thread_ctx->state == TS_WAITING) {
388                 arch_set_thread_return(client, -EINVAL);
389                 client->thread_ctx->state = TS_INTER;
390                 BUG_ON(sched_enqueue(client));
391             }
392         }
393     }
394 
395     return 0;
396 }
397 
398 /*
399  * Destroy waiting nodes in the msg queue of the given channel
400  * which belong to the given cap_group
401  */
__destory_waiting_node(struct channel * channel,struct cap_group * cap_group)402 static int __destory_waiting_node(struct channel *channel,
403                                   struct cap_group *cap_group)
404 {
405     struct msg_entry *entry;
406 
407     for_each_in_list (
408         entry, struct msg_entry, msg_queue_node, &channel->msg_queue) {
409         if (entry->client_msg_record.client->cap_group == cap_group) {
410             list_del(&entry->msg_queue_node);
411             kfree(entry->client_msg_record.ksend_buf);
412             kfree(entry);
413         }
414     }
415 
416     return 0;
417 }
418 
419 /*
420  * close_channel will be called if
421  * 1. channel's creater calls sys_tee_msg_stop_channel
422  * 2. recycler calls sys_cap_group_recycle
423  */
close_channel(struct channel * channel,struct cap_group * cap_group)424 int close_channel(struct channel *channel, struct cap_group *cap_group)
425 {
426     lock(&channel->lock);
427     if (channel->creater == cap_group) {
428         channel->state = CHANNEL_INVALID;
429         __wake_up_all_clients(channel);
430     } else {
431         __destory_waiting_node(channel, cap_group);
432     }
433     unlock(&channel->lock);
434     return 0;
435 }
436 
sys_tee_msg_create_msg_hdl(void)437 int sys_tee_msg_create_msg_hdl(void)
438 {
439     struct msg_hdl *msg_hdl = NULL;
440     int msg_hdl_cap = 0;
441     int ret = 0;
442 
443     msg_hdl = obj_alloc(TYPE_MSG_HDL, sizeof(*msg_hdl));
444     if (!msg_hdl) {
445         ret = -ENOMEM;
446         goto out_fail;
447     }
448 
449     /* init msg_hdl */
450     lock_init(&msg_hdl->lock);
451 
452     msg_hdl_cap = cap_alloc(current_cap_group, msg_hdl);
453     if (msg_hdl_cap < 0) {
454         ret = msg_hdl_cap;
455         goto out_free_obj;
456     }
457 
458     return msg_hdl_cap;
459 
460 out_free_obj:
461     obj_free(msg_hdl);
462 
463 out_fail:
464     return ret;
465 }
466 
sys_tee_msg_create_channel(void)467 int sys_tee_msg_create_channel(void)
468 {
469     struct channel *channel = NULL;
470     int channel_cap = 0;
471     int ret = 0;
472 
473     channel = obj_alloc(TYPE_CHANNEL, sizeof(*channel));
474     if (!channel) {
475         ret = -ENOMEM;
476         goto out_fail;
477     }
478 
479     /* init channel */
480     lock_init(&channel->lock);
481     init_list_head(&channel->msg_queue);
482     init_list_head(&channel->thread_queue);
483     channel->creater = current_cap_group;
484     channel->state = CHANNEL_VALID;
485 
486     channel_cap = cap_alloc(current_cap_group, channel);
487     if (channel_cap < 0) {
488         ret = channel_cap;
489         goto out_free_obj;
490     }
491 
492     return channel_cap;
493 
494 out_free_obj:
495     obj_free(channel);
496 
497 out_fail:
498     return ret;
499 }
500 
sys_tee_msg_stop_channel(int channel_cap)501 int sys_tee_msg_stop_channel(int channel_cap)
502 {
503     struct channel *channel;
504     int ret;
505 
506     channel = obj_get(current_cap_group, channel_cap, TYPE_CHANNEL);
507     if (channel == NULL) {
508         ret = -ECAPBILITY;
509         goto out_fail_get_channel;
510     }
511 
512     if (channel->creater == current_cap_group) {
513         ret = close_channel(channel, current_cap_group);
514     } else {
515         ret = -EINVAL;
516     }
517 
518     obj_put(channel);
519 
520 out_fail_get_channel:
521     return ret;
522 }
523 
sys_tee_msg_receive(int channel_cap,void * recv_buf,size_t recv_len,int msg_hdl_cap,void * info,int timeout)524 int sys_tee_msg_receive(int channel_cap, void *recv_buf, size_t recv_len,
525                         int msg_hdl_cap, void *info, int timeout)
526 {
527     struct channel *channel;
528     struct msg_hdl *msg_hdl;
529     int ret;
530 
531     if (check_user_addr_range((vaddr_t)recv_buf, recv_len) != 0) {
532         return -EINVAL;
533     }
534     if (check_user_addr_range((vaddr_t)info, sizeof(struct src_msginfo)) != 0) {
535         return -EINVAL;
536     }
537 
538     channel = obj_get(current_cap_group, channel_cap, TYPE_CHANNEL);
539     if (channel == NULL) {
540         ret = -ECAPBILITY;
541         goto out_fail_get_channel;
542     }
543     msg_hdl = obj_get(current_cap_group, msg_hdl_cap, TYPE_MSG_HDL);
544     if (msg_hdl == NULL) {
545         ret = -ECAPBILITY;
546         goto out_fail_get_msg_hdl;
547     }
548 
549     /* Assume __tee_msg_receive will obj_put channel & msg_hdl if noreturn
550      */
551     ret =
552         __tee_msg_receive(channel, recv_buf, recv_len, msg_hdl, info, timeout);
553 
554     obj_put(msg_hdl);
555 
556 out_fail_get_msg_hdl:
557     obj_put(channel);
558 
559 out_fail_get_channel:
560     return ret;
561 }
562 
sys_tee_msg_call(int channel_cap,void * send_buf,size_t send_len,void * recv_buf,size_t recv_len,struct timespec * timeout)563 int sys_tee_msg_call(int channel_cap, void *send_buf, size_t send_len,
564                      void *recv_buf, size_t recv_len, struct timespec *timeout)
565 {
566     struct channel *channel;
567     int ret;
568 
569     if (check_user_addr_range((vaddr_t)send_buf, send_len) != 0) {
570         return -EINVAL;
571     }
572     if (check_user_addr_range((vaddr_t)recv_buf, recv_len) != 0) {
573         return -EINVAL;
574     }
575 
576     channel = obj_get(current_cap_group, channel_cap, TYPE_CHANNEL);
577     if (channel == NULL) {
578         ret = -ECAPBILITY;
579         goto out_fail_get_channel;
580     }
581 
582     /* Assume __tee_msg_call will obj_put channel if noreturn */
583     ret = __tee_msg_call(
584         channel, send_buf, send_len, recv_buf, recv_len, timeout);
585 
586     obj_put(channel);
587 
588 out_fail_get_channel:
589     return ret;
590 }
591 
sys_tee_msg_reply(int msg_hdl_cap,void * reply_buf,size_t reply_len)592 int sys_tee_msg_reply(int msg_hdl_cap, void *reply_buf, size_t reply_len)
593 {
594     struct msg_hdl *msg_hdl;
595     int ret;
596 
597     if (check_user_addr_range((vaddr_t)reply_buf, reply_len) != 0) {
598         return -EINVAL;
599     }
600 
601     msg_hdl = obj_get(current_cap_group, msg_hdl_cap, TYPE_MSG_HDL);
602     if (msg_hdl == NULL) {
603         ret = -ECAPBILITY;
604         goto out_fail_get_msg_hdl;
605     }
606 
607     ret = __tee_msg_reply(msg_hdl, reply_buf, reply_len);
608 
609     obj_put(msg_hdl);
610 
611 out_fail_get_msg_hdl:
612     return ret;
613 }
614 
sys_tee_msg_notify(int channel_cap,void * send_buf,size_t send_len)615 int sys_tee_msg_notify(int channel_cap, void *send_buf, size_t send_len)
616 {
617     struct channel *channel;
618     int ret;
619 
620     if (check_user_addr_range((vaddr_t)send_buf, send_len) != 0) {
621         return -EINVAL;
622     }
623 
624     channel = obj_get(current_cap_group, channel_cap, TYPE_CHANNEL);
625     if (channel == NULL) {
626         ret = -ECAPBILITY;
627         goto out_fail_get_channel;
628     }
629 
630     ret = __tee_msg_notify(channel, send_buf, send_len);
631 
632     obj_put(channel);
633 
634 out_fail_get_channel:
635     return ret;
636 }
637 
channel_deinit(void * ptr)638 void channel_deinit(void *ptr)
639 {
640     struct msg_entry *entry;
641     struct channel *channel;
642 
643     channel = (struct channel *)ptr;
644 
645     for_each_in_list (
646         entry, struct msg_entry, msg_queue_node, &channel->msg_queue) {
647         list_del(&entry->msg_queue_node);
648         kfree(entry->client_msg_record.ksend_buf);
649         kfree(entry);
650     }
651 }
652 
/*
 * Release what a msg_hdl object owns: nothing. The kernel send buffer
 * referenced by a stored client record has already been freed by the path
 * that stored it, and the remaining fields only point at threads and user
 * buffers.
 */
void msg_hdl_deinit(void *ptr)
{
    (void)ptr;
}
656