• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 /* interpreters module */
3 /* low-level access to interpreter primitives */
4 
5 #include "Python.h"
6 #include "frameobject.h"
7 #include "interpreteridobject.h"
8 
9 
10 static char *
_copy_raw_string(PyObject * strobj)11 _copy_raw_string(PyObject *strobj)
12 {
13     const char *str = PyUnicode_AsUTF8(strobj);
14     if (str == NULL) {
15         return NULL;
16     }
17     char *copied = PyMem_Malloc(strlen(str)+1);
18     if (copied == NULL) {
19         PyErr_NoMemory();
20         return NULL;
21     }
22     strcpy(copied, str);
23     return copied;
24 }
25 
26 static PyInterpreterState *
_get_current(void)27 _get_current(void)
28 {
29     // PyInterpreterState_Get() aborts if lookup fails, so don't need
30     // to check the result for NULL.
31     return PyInterpreterState_Get();
32 }
33 
34 
35 /* data-sharing-specific code ***********************************************/
36 
37 struct _sharednsitem {
38     char *name;
39     _PyCrossInterpreterData data;
40 };
41 
42 static void _sharednsitem_clear(struct _sharednsitem *);  // forward
43 
44 static int
_sharednsitem_init(struct _sharednsitem * item,PyObject * key,PyObject * value)45 _sharednsitem_init(struct _sharednsitem *item, PyObject *key, PyObject *value)
46 {
47     item->name = _copy_raw_string(key);
48     if (item->name == NULL) {
49         return -1;
50     }
51     if (_PyObject_GetCrossInterpreterData(value, &item->data) != 0) {
52         _sharednsitem_clear(item);
53         return -1;
54     }
55     return 0;
56 }
57 
58 static void
_sharednsitem_clear(struct _sharednsitem * item)59 _sharednsitem_clear(struct _sharednsitem *item)
60 {
61     if (item->name != NULL) {
62         PyMem_Free(item->name);
63         item->name = NULL;
64     }
65     _PyCrossInterpreterData_Release(&item->data);
66 }
67 
68 static int
_sharednsitem_apply(struct _sharednsitem * item,PyObject * ns)69 _sharednsitem_apply(struct _sharednsitem *item, PyObject *ns)
70 {
71     PyObject *name = PyUnicode_FromString(item->name);
72     if (name == NULL) {
73         return -1;
74     }
75     PyObject *value = _PyCrossInterpreterData_NewObject(&item->data);
76     if (value == NULL) {
77         Py_DECREF(name);
78         return -1;
79     }
80     int res = PyDict_SetItem(ns, name, value);
81     Py_DECREF(name);
82     Py_DECREF(value);
83     return res;
84 }
85 
86 typedef struct _sharedns {
87     Py_ssize_t len;
88     struct _sharednsitem* items;
89 } _sharedns;
90 
91 static _sharedns *
_sharedns_new(Py_ssize_t len)92 _sharedns_new(Py_ssize_t len)
93 {
94     _sharedns *shared = PyMem_NEW(_sharedns, 1);
95     if (shared == NULL) {
96         PyErr_NoMemory();
97         return NULL;
98     }
99     shared->len = len;
100     shared->items = PyMem_NEW(struct _sharednsitem, len);
101     if (shared->items == NULL) {
102         PyErr_NoMemory();
103         PyMem_Free(shared);
104         return NULL;
105     }
106     return shared;
107 }
108 
109 static void
_sharedns_free(_sharedns * shared)110 _sharedns_free(_sharedns *shared)
111 {
112     for (Py_ssize_t i=0; i < shared->len; i++) {
113         _sharednsitem_clear(&shared->items[i]);
114     }
115     PyMem_Free(shared->items);
116     PyMem_Free(shared);
117 }
118 
119 static _sharedns *
_get_shared_ns(PyObject * shareable)120 _get_shared_ns(PyObject *shareable)
121 {
122     if (shareable == NULL || shareable == Py_None) {
123         return NULL;
124     }
125     Py_ssize_t len = PyDict_Size(shareable);
126     if (len == 0) {
127         return NULL;
128     }
129 
130     _sharedns *shared = _sharedns_new(len);
131     if (shared == NULL) {
132         return NULL;
133     }
134     Py_ssize_t pos = 0;
135     for (Py_ssize_t i=0; i < len; i++) {
136         PyObject *key, *value;
137         if (PyDict_Next(shareable, &pos, &key, &value) == 0) {
138             break;
139         }
140         if (_sharednsitem_init(&shared->items[i], key, value) != 0) {
141             break;
142         }
143     }
144     if (PyErr_Occurred()) {
145         _sharedns_free(shared);
146         return NULL;
147     }
148     return shared;
149 }
150 
151 static int
_sharedns_apply(_sharedns * shared,PyObject * ns)152 _sharedns_apply(_sharedns *shared, PyObject *ns)
153 {
154     for (Py_ssize_t i=0; i < shared->len; i++) {
155         if (_sharednsitem_apply(&shared->items[i], ns) != 0) {
156             return -1;
157         }
158     }
159     return 0;
160 }
161 
/* A minimal, interpreter-independent record of an exception: its type
 * name and message as C strings (either may be NULL).  Both are released
 * with PyMem_Free() by _sharedexception_clear().
 *
 * Ultimately we'd like to preserve enough information about the exception
 * and traceback that we could re-constitute (or at least simulate, a la
 * traceback.TracebackException), and even chain, a copy of the exception
 * in the calling interpreter.
 */
struct _sharedexception {
    char *name;
    char *msg;
};
typedef struct _sharedexception _sharedexception;
171 
172 static _sharedexception *
_sharedexception_new(void)173 _sharedexception_new(void)
174 {
175     _sharedexception *err = PyMem_NEW(_sharedexception, 1);
176     if (err == NULL) {
177         PyErr_NoMemory();
178         return NULL;
179     }
180     err->name = NULL;
181     err->msg = NULL;
182     return err;
183 }
184 
185 static void
_sharedexception_clear(_sharedexception * exc)186 _sharedexception_clear(_sharedexception *exc)
187 {
188     if (exc->name != NULL) {
189         PyMem_Free(exc->name);
190     }
191     if (exc->msg != NULL) {
192         PyMem_Free(exc->msg);
193     }
194 }
195 
196 static void
_sharedexception_free(_sharedexception * exc)197 _sharedexception_free(_sharedexception *exc)
198 {
199     _sharedexception_clear(exc);
200     PyMem_Free(exc);
201 }
202 
203 static _sharedexception *
_sharedexception_bind(PyObject * exctype,PyObject * exc,PyObject * tb)204 _sharedexception_bind(PyObject *exctype, PyObject *exc, PyObject *tb)
205 {
206     assert(exctype != NULL);
207     char *failure = NULL;
208 
209     _sharedexception *err = _sharedexception_new();
210     if (err == NULL) {
211         goto finally;
212     }
213 
214     PyObject *name = PyUnicode_FromFormat("%S", exctype);
215     if (name == NULL) {
216         failure = "unable to format exception type name";
217         goto finally;
218     }
219     err->name = _copy_raw_string(name);
220     Py_DECREF(name);
221     if (err->name == NULL) {
222         if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
223             failure = "out of memory copying exception type name";
224         } else {
225             failure = "unable to encode and copy exception type name";
226         }
227         goto finally;
228     }
229 
230     if (exc != NULL) {
231         PyObject *msg = PyUnicode_FromFormat("%S", exc);
232         if (msg == NULL) {
233             failure = "unable to format exception message";
234             goto finally;
235         }
236         err->msg = _copy_raw_string(msg);
237         Py_DECREF(msg);
238         if (err->msg == NULL) {
239             if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
240                 failure = "out of memory copying exception message";
241             } else {
242                 failure = "unable to encode and copy exception message";
243             }
244             goto finally;
245         }
246     }
247 
248 finally:
249     if (failure != NULL) {
250         PyErr_Clear();
251         if (err->name != NULL) {
252             PyMem_Free(err->name);
253             err->name = NULL;
254         }
255         err->msg = failure;
256     }
257     return err;
258 }
259 
260 static void
_sharedexception_apply(_sharedexception * exc,PyObject * wrapperclass)261 _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)
262 {
263     if (exc->name != NULL) {
264         if (exc->msg != NULL) {
265             PyErr_Format(wrapperclass, "%s: %s",  exc->name, exc->msg);
266         }
267         else {
268             PyErr_SetString(wrapperclass, exc->name);
269         }
270     }
271     else if (exc->msg != NULL) {
272         PyErr_SetString(wrapperclass, exc->msg);
273     }
274     else {
275         PyErr_SetNone(wrapperclass);
276     }
277 }
278 
279 
280 /* channel-specific code ****************************************************/
281 
282 #define CHANNEL_SEND 1
283 #define CHANNEL_BOTH 0
284 #define CHANNEL_RECV -1
285 
286 static PyObject *ChannelError;
287 static PyObject *ChannelNotFoundError;
288 static PyObject *ChannelClosedError;
289 static PyObject *ChannelEmptyError;
290 static PyObject *ChannelNotEmptyError;
291 
292 static int
channel_exceptions_init(PyObject * ns)293 channel_exceptions_init(PyObject *ns)
294 {
295     // XXX Move the exceptions into per-module memory?
296 
297     // A channel-related operation failed.
298     ChannelError = PyErr_NewException("_xxsubinterpreters.ChannelError",
299                                       PyExc_RuntimeError, NULL);
300     if (ChannelError == NULL) {
301         return -1;
302     }
303     if (PyDict_SetItemString(ns, "ChannelError", ChannelError) != 0) {
304         return -1;
305     }
306 
307     // An operation tried to use a channel that doesn't exist.
308     ChannelNotFoundError = PyErr_NewException(
309             "_xxsubinterpreters.ChannelNotFoundError", ChannelError, NULL);
310     if (ChannelNotFoundError == NULL) {
311         return -1;
312     }
313     if (PyDict_SetItemString(ns, "ChannelNotFoundError", ChannelNotFoundError) != 0) {
314         return -1;
315     }
316 
317     // An operation tried to use a closed channel.
318     ChannelClosedError = PyErr_NewException(
319             "_xxsubinterpreters.ChannelClosedError", ChannelError, NULL);
320     if (ChannelClosedError == NULL) {
321         return -1;
322     }
323     if (PyDict_SetItemString(ns, "ChannelClosedError", ChannelClosedError) != 0) {
324         return -1;
325     }
326 
327     // An operation tried to pop from an empty channel.
328     ChannelEmptyError = PyErr_NewException(
329             "_xxsubinterpreters.ChannelEmptyError", ChannelError, NULL);
330     if (ChannelEmptyError == NULL) {
331         return -1;
332     }
333     if (PyDict_SetItemString(ns, "ChannelEmptyError", ChannelEmptyError) != 0) {
334         return -1;
335     }
336 
337     // An operation tried to close a non-empty channel.
338     ChannelNotEmptyError = PyErr_NewException(
339             "_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL);
340     if (ChannelNotEmptyError == NULL) {
341         return -1;
342     }
343     if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) {
344         return -1;
345     }
346 
347     return 0;
348 }
349 
350 /* the channel queue */
351 
352 struct _channelitem;
353 
354 typedef struct _channelitem {
355     _PyCrossInterpreterData *data;
356     struct _channelitem *next;
357 } _channelitem;
358 
359 static _channelitem *
_channelitem_new(void)360 _channelitem_new(void)
361 {
362     _channelitem *item = PyMem_NEW(_channelitem, 1);
363     if (item == NULL) {
364         PyErr_NoMemory();
365         return NULL;
366     }
367     item->data = NULL;
368     item->next = NULL;
369     return item;
370 }
371 
372 static void
_channelitem_clear(_channelitem * item)373 _channelitem_clear(_channelitem *item)
374 {
375     if (item->data != NULL) {
376         _PyCrossInterpreterData_Release(item->data);
377         PyMem_Free(item->data);
378         item->data = NULL;
379     }
380     item->next = NULL;
381 }
382 
383 static void
_channelitem_free(_channelitem * item)384 _channelitem_free(_channelitem *item)
385 {
386     _channelitem_clear(item);
387     PyMem_Free(item);
388 }
389 
390 static void
_channelitem_free_all(_channelitem * item)391 _channelitem_free_all(_channelitem *item)
392 {
393     while (item != NULL) {
394         _channelitem *last = item;
395         item = item->next;
396         _channelitem_free(last);
397     }
398 }
399 
400 static _PyCrossInterpreterData *
_channelitem_popped(_channelitem * item)401 _channelitem_popped(_channelitem *item)
402 {
403     _PyCrossInterpreterData *data = item->data;
404     item->data = NULL;
405     _channelitem_free(item);
406     return data;
407 }
408 
409 typedef struct _channelqueue {
410     int64_t count;
411     _channelitem *first;
412     _channelitem *last;
413 } _channelqueue;
414 
415 static _channelqueue *
_channelqueue_new(void)416 _channelqueue_new(void)
417 {
418     _channelqueue *queue = PyMem_NEW(_channelqueue, 1);
419     if (queue == NULL) {
420         PyErr_NoMemory();
421         return NULL;
422     }
423     queue->count = 0;
424     queue->first = NULL;
425     queue->last = NULL;
426     return queue;
427 }
428 
429 static void
_channelqueue_clear(_channelqueue * queue)430 _channelqueue_clear(_channelqueue *queue)
431 {
432     _channelitem_free_all(queue->first);
433     queue->count = 0;
434     queue->first = NULL;
435     queue->last = NULL;
436 }
437 
438 static void
_channelqueue_free(_channelqueue * queue)439 _channelqueue_free(_channelqueue *queue)
440 {
441     _channelqueue_clear(queue);
442     PyMem_Free(queue);
443 }
444 
445 static int
_channelqueue_put(_channelqueue * queue,_PyCrossInterpreterData * data)446 _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
447 {
448     _channelitem *item = _channelitem_new();
449     if (item == NULL) {
450         return -1;
451     }
452     item->data = data;
453 
454     queue->count += 1;
455     if (queue->first == NULL) {
456         queue->first = item;
457     }
458     else {
459         queue->last->next = item;
460     }
461     queue->last = item;
462     return 0;
463 }
464 
465 static _PyCrossInterpreterData *
_channelqueue_get(_channelqueue * queue)466 _channelqueue_get(_channelqueue *queue)
467 {
468     _channelitem *item = queue->first;
469     if (item == NULL) {
470         return NULL;
471     }
472     queue->first = item->next;
473     if (queue->last == item) {
474         queue->last = NULL;
475     }
476     queue->count -= 1;
477 
478     return _channelitem_popped(item);
479 }
480 
481 /* channel-interpreter associations */
482 
/* One interpreter's association with one end (send or recv) of a
 * channel; nodes form a singly linked list per end. */
typedef struct _channelend _channelend;
struct _channelend {
    _channelend *next;
    int64_t interp;
    int open;
};
490 
491 static _channelend *
_channelend_new(int64_t interp)492 _channelend_new(int64_t interp)
493 {
494     _channelend *end = PyMem_NEW(_channelend, 1);
495     if (end == NULL) {
496         PyErr_NoMemory();
497         return NULL;
498     }
499     end->next = NULL;
500     end->interp = interp;
501     end->open = 1;
502     return end;
503 }
504 
505 static void
_channelend_free(_channelend * end)506 _channelend_free(_channelend *end)
507 {
508     PyMem_Free(end);
509 }
510 
511 static void
_channelend_free_all(_channelend * end)512 _channelend_free_all(_channelend *end)
513 {
514     while (end != NULL) {
515         _channelend *last = end;
516         end = end->next;
517         _channelend_free(last);
518     }
519 }
520 
521 static _channelend *
_channelend_find(_channelend * first,int64_t interp,_channelend ** pprev)522 _channelend_find(_channelend *first, int64_t interp, _channelend **pprev)
523 {
524     _channelend *prev = NULL;
525     _channelend *end = first;
526     while (end != NULL) {
527         if (end->interp == interp) {
528             break;
529         }
530         prev = end;
531         end = end->next;
532     }
533     if (pprev != NULL) {
534         *pprev = prev;
535     }
536     return end;
537 }
538 
539 typedef struct _channelassociations {
540     // Note that the list entries are never removed for interpreter
541     // for which the channel is closed.  This should not be a problem in
542     // practice.  Also, a channel isn't automatically closed when an
543     // interpreter is destroyed.
544     int64_t numsendopen;
545     int64_t numrecvopen;
546     _channelend *send;
547     _channelend *recv;
548 } _channelends;
549 
550 static _channelends *
_channelends_new(void)551 _channelends_new(void)
552 {
553     _channelends *ends = PyMem_NEW(_channelends, 1);
554     if (ends== NULL) {
555         return NULL;
556     }
557     ends->numsendopen = 0;
558     ends->numrecvopen = 0;
559     ends->send = NULL;
560     ends->recv = NULL;
561     return ends;
562 }
563 
564 static void
_channelends_clear(_channelends * ends)565 _channelends_clear(_channelends *ends)
566 {
567     _channelend_free_all(ends->send);
568     ends->send = NULL;
569     ends->numsendopen = 0;
570 
571     _channelend_free_all(ends->recv);
572     ends->recv = NULL;
573     ends->numrecvopen = 0;
574 }
575 
576 static void
_channelends_free(_channelends * ends)577 _channelends_free(_channelends *ends)
578 {
579     _channelends_clear(ends);
580     PyMem_Free(ends);
581 }
582 
583 static _channelend *
_channelends_add(_channelends * ends,_channelend * prev,int64_t interp,int send)584 _channelends_add(_channelends *ends, _channelend *prev, int64_t interp,
585                  int send)
586 {
587     _channelend *end = _channelend_new(interp);
588     if (end == NULL) {
589         return NULL;
590     }
591 
592     if (prev == NULL) {
593         if (send) {
594             ends->send = end;
595         }
596         else {
597             ends->recv = end;
598         }
599     }
600     else {
601         prev->next = end;
602     }
603     if (send) {
604         ends->numsendopen += 1;
605     }
606     else {
607         ends->numrecvopen += 1;
608     }
609     return end;
610 }
611 
612 static int
_channelends_associate(_channelends * ends,int64_t interp,int send)613 _channelends_associate(_channelends *ends, int64_t interp, int send)
614 {
615     _channelend *prev;
616     _channelend *end = _channelend_find(send ? ends->send : ends->recv,
617                                         interp, &prev);
618     if (end != NULL) {
619         if (!end->open) {
620             PyErr_SetString(ChannelClosedError, "channel already closed");
621             return -1;
622         }
623         // already associated
624         return 0;
625     }
626     if (_channelends_add(ends, prev, interp, send) == NULL) {
627         return -1;
628     }
629     return 0;
630 }
631 
632 static int
_channelends_is_open(_channelends * ends)633 _channelends_is_open(_channelends *ends)
634 {
635     if (ends->numsendopen != 0 || ends->numrecvopen != 0) {
636         return 1;
637     }
638     if (ends->send == NULL && ends->recv == NULL) {
639         return 1;
640     }
641     return 0;
642 }
643 
644 static void
_channelends_close_end(_channelends * ends,_channelend * end,int send)645 _channelends_close_end(_channelends *ends, _channelend *end, int send)
646 {
647     end->open = 0;
648     if (send) {
649         ends->numsendopen -= 1;
650     }
651     else {
652         ends->numrecvopen -= 1;
653     }
654 }
655 
656 static int
_channelends_close_interpreter(_channelends * ends,int64_t interp,int which)657 _channelends_close_interpreter(_channelends *ends, int64_t interp, int which)
658 {
659     _channelend *prev;
660     _channelend *end;
661     if (which >= 0) {  // send/both
662         end = _channelend_find(ends->send, interp, &prev);
663         if (end == NULL) {
664             // never associated so add it
665             end = _channelends_add(ends, prev, interp, 1);
666             if (end == NULL) {
667                 return -1;
668             }
669         }
670         _channelends_close_end(ends, end, 1);
671     }
672     if (which <= 0) {  // recv/both
673         end = _channelend_find(ends->recv, interp, &prev);
674         if (end == NULL) {
675             // never associated so add it
676             end = _channelends_add(ends, prev, interp, 0);
677             if (end == NULL) {
678                 return -1;
679             }
680         }
681         _channelends_close_end(ends, end, 0);
682     }
683     return 0;
684 }
685 
686 static void
_channelends_close_all(_channelends * ends,int which,int force)687 _channelends_close_all(_channelends *ends, int which, int force)
688 {
689     // XXX Handle the ends.
690     // XXX Handle force is True.
691 
692     // Ensure all the "send"-associated interpreters are closed.
693     _channelend *end;
694     for (end = ends->send; end != NULL; end = end->next) {
695         _channelends_close_end(ends, end, 1);
696     }
697 
698     // Ensure all the "recv"-associated interpreters are closed.
699     for (end = ends->recv; end != NULL; end = end->next) {
700         _channelends_close_end(ends, end, 0);
701     }
702 }
703 
704 /* channels */
705 
706 struct _channel;
707 struct _channel_closing;
708 static void _channel_clear_closing(struct _channel *);
709 static void _channel_finish_closing(struct _channel *);
710 
711 typedef struct _channel {
712     PyThread_type_lock mutex;
713     _channelqueue *queue;
714     _channelends *ends;
715     int open;
716     struct _channel_closing *closing;
717 } _PyChannelState;
718 
719 static _PyChannelState *
_channel_new(void)720 _channel_new(void)
721 {
722     _PyChannelState *chan = PyMem_NEW(_PyChannelState, 1);
723     if (chan == NULL) {
724         return NULL;
725     }
726     chan->mutex = PyThread_allocate_lock();
727     if (chan->mutex == NULL) {
728         PyMem_Free(chan);
729         PyErr_SetString(ChannelError,
730                         "can't initialize mutex for new channel");
731         return NULL;
732     }
733     chan->queue = _channelqueue_new();
734     if (chan->queue == NULL) {
735         PyMem_Free(chan);
736         return NULL;
737     }
738     chan->ends = _channelends_new();
739     if (chan->ends == NULL) {
740         _channelqueue_free(chan->queue);
741         PyMem_Free(chan);
742         return NULL;
743     }
744     chan->open = 1;
745     chan->closing = NULL;
746     return chan;
747 }
748 
749 static void
_channel_free(_PyChannelState * chan)750 _channel_free(_PyChannelState *chan)
751 {
752     _channel_clear_closing(chan);
753     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
754     _channelqueue_free(chan->queue);
755     _channelends_free(chan->ends);
756     PyThread_release_lock(chan->mutex);
757 
758     PyThread_free_lock(chan->mutex);
759     PyMem_Free(chan);
760 }
761 
762 static int
_channel_add(_PyChannelState * chan,int64_t interp,_PyCrossInterpreterData * data)763 _channel_add(_PyChannelState *chan, int64_t interp,
764              _PyCrossInterpreterData *data)
765 {
766     int res = -1;
767     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
768 
769     if (!chan->open) {
770         PyErr_SetString(ChannelClosedError, "channel closed");
771         goto done;
772     }
773     if (_channelends_associate(chan->ends, interp, 1) != 0) {
774         goto done;
775     }
776 
777     if (_channelqueue_put(chan->queue, data) != 0) {
778         goto done;
779     }
780 
781     res = 0;
782 done:
783     PyThread_release_lock(chan->mutex);
784     return res;
785 }
786 
787 static _PyCrossInterpreterData *
_channel_next(_PyChannelState * chan,int64_t interp)788 _channel_next(_PyChannelState *chan, int64_t interp)
789 {
790     _PyCrossInterpreterData *data = NULL;
791     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
792 
793     if (!chan->open) {
794         PyErr_SetString(ChannelClosedError, "channel closed");
795         goto done;
796     }
797     if (_channelends_associate(chan->ends, interp, 0) != 0) {
798         goto done;
799     }
800 
801     data = _channelqueue_get(chan->queue);
802     if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
803         chan->open = 0;
804     }
805 
806 done:
807     PyThread_release_lock(chan->mutex);
808     if (chan->queue->count == 0) {
809         _channel_finish_closing(chan);
810     }
811     return data;
812 }
813 
814 static int
_channel_close_interpreter(_PyChannelState * chan,int64_t interp,int end)815 _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
816 {
817     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
818 
819     int res = -1;
820     if (!chan->open) {
821         PyErr_SetString(ChannelClosedError, "channel already closed");
822         goto done;
823     }
824 
825     if (_channelends_close_interpreter(chan->ends, interp, end) != 0) {
826         goto done;
827     }
828     chan->open = _channelends_is_open(chan->ends);
829 
830     res = 0;
831 done:
832     PyThread_release_lock(chan->mutex);
833     return res;
834 }
835 
836 static int
_channel_close_all(_PyChannelState * chan,int end,int force)837 _channel_close_all(_PyChannelState *chan, int end, int force)
838 {
839     int res = -1;
840     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
841 
842     if (!chan->open) {
843         PyErr_SetString(ChannelClosedError, "channel already closed");
844         goto done;
845     }
846 
847     if (!force && chan->queue->count > 0) {
848         PyErr_SetString(ChannelNotEmptyError,
849                         "may not be closed if not empty (try force=True)");
850         goto done;
851     }
852 
853     chan->open = 0;
854 
855     // We *could* also just leave these in place, since we've marked
856     // the channel as closed already.
857     _channelends_close_all(chan->ends, end, force);
858 
859     res = 0;
860 done:
861     PyThread_release_lock(chan->mutex);
862     return res;
863 }
864 
865 /* the set of channels */
866 
867 struct _channelref;
868 
869 typedef struct _channelref {
870     int64_t id;
871     _PyChannelState *chan;
872     struct _channelref *next;
873     Py_ssize_t objcount;
874 } _channelref;
875 
876 static _channelref *
_channelref_new(int64_t id,_PyChannelState * chan)877 _channelref_new(int64_t id, _PyChannelState *chan)
878 {
879     _channelref *ref = PyMem_NEW(_channelref, 1);
880     if (ref == NULL) {
881         return NULL;
882     }
883     ref->id = id;
884     ref->chan = chan;
885     ref->next = NULL;
886     ref->objcount = 0;
887     return ref;
888 }
889 
890 //static void
891 //_channelref_clear(_channelref *ref)
892 //{
893 //    ref->id = -1;
894 //    ref->chan = NULL;
895 //    ref->next = NULL;
896 //    ref->objcount = 0;
897 //}
898 
899 static void
_channelref_free(_channelref * ref)900 _channelref_free(_channelref *ref)
901 {
902     if (ref->chan != NULL) {
903         _channel_clear_closing(ref->chan);
904     }
905     //_channelref_clear(ref);
906     PyMem_Free(ref);
907 }
908 
909 static _channelref *
_channelref_find(_channelref * first,int64_t id,_channelref ** pprev)910 _channelref_find(_channelref *first, int64_t id, _channelref **pprev)
911 {
912     _channelref *prev = NULL;
913     _channelref *ref = first;
914     while (ref != NULL) {
915         if (ref->id == id) {
916             break;
917         }
918         prev = ref;
919         ref = ref->next;
920     }
921     if (pprev != NULL) {
922         *pprev = prev;
923     }
924     return ref;
925 }
926 
927 typedef struct _channels {
928     PyThread_type_lock mutex;
929     _channelref *head;
930     int64_t numopen;
931     int64_t next_id;
932 } _channels;
933 
934 static int
_channels_init(_channels * channels)935 _channels_init(_channels *channels)
936 {
937     if (channels->mutex == NULL) {
938         channels->mutex = PyThread_allocate_lock();
939         if (channels->mutex == NULL) {
940             PyErr_SetString(ChannelError,
941                             "can't initialize mutex for channel management");
942             return -1;
943         }
944     }
945     channels->head = NULL;
946     channels->numopen = 0;
947     channels->next_id = 0;
948     return 0;
949 }
950 
951 static int64_t
_channels_next_id(_channels * channels)952 _channels_next_id(_channels *channels)  // needs lock
953 {
954     int64_t id = channels->next_id;
955     if (id < 0) {
956         /* overflow */
957         PyErr_SetString(ChannelError,
958                         "failed to get a channel ID");
959         return -1;
960     }
961     channels->next_id += 1;
962     return id;
963 }
964 
965 static _PyChannelState *
_channels_lookup(_channels * channels,int64_t id,PyThread_type_lock * pmutex)966 _channels_lookup(_channels *channels, int64_t id, PyThread_type_lock *pmutex)
967 {
968     _PyChannelState *chan = NULL;
969     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
970     if (pmutex != NULL) {
971         *pmutex = NULL;
972     }
973 
974     _channelref *ref = _channelref_find(channels->head, id, NULL);
975     if (ref == NULL) {
976         PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id);
977         goto done;
978     }
979     if (ref->chan == NULL || !ref->chan->open) {
980         PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", id);
981         goto done;
982     }
983 
984     if (pmutex != NULL) {
985         // The mutex will be closed by the caller.
986         *pmutex = channels->mutex;
987     }
988 
989     chan = ref->chan;
990 done:
991     if (pmutex == NULL || *pmutex == NULL) {
992         PyThread_release_lock(channels->mutex);
993     }
994     return chan;
995 }
996 
997 static int64_t
_channels_add(_channels * channels,_PyChannelState * chan)998 _channels_add(_channels *channels, _PyChannelState *chan)
999 {
1000     int64_t cid = -1;
1001     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1002 
1003     // Create a new ref.
1004     int64_t id = _channels_next_id(channels);
1005     if (id < 0) {
1006         goto done;
1007     }
1008     _channelref *ref = _channelref_new(id, chan);
1009     if (ref == NULL) {
1010         goto done;
1011     }
1012 
1013     // Add it to the list.
1014     // We assume that the channel is a new one (not already in the list).
1015     ref->next = channels->head;
1016     channels->head = ref;
1017     channels->numopen += 1;
1018 
1019     cid = id;
1020 done:
1021     PyThread_release_lock(channels->mutex);
1022     return cid;
1023 }
1024 
1025 /* forward */
1026 static int _channel_set_closing(struct _channelref *, PyThread_type_lock);
1027 
1028 static int
_channels_close(_channels * channels,int64_t cid,_PyChannelState ** pchan,int end,int force)1029 _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan,
1030                 int end, int force)
1031 {
1032     int res = -1;
1033     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1034     if (pchan != NULL) {
1035         *pchan = NULL;
1036     }
1037 
1038     _channelref *ref = _channelref_find(channels->head, cid, NULL);
1039     if (ref == NULL) {
1040         PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", cid);
1041         goto done;
1042     }
1043 
1044     if (ref->chan == NULL) {
1045         PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid);
1046         goto done;
1047     }
1048     else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) {
1049         PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid);
1050         goto done;
1051     }
1052     else {
1053         if (_channel_close_all(ref->chan, end, force) != 0) {
1054             if (end == CHANNEL_SEND &&
1055                     PyErr_ExceptionMatches(ChannelNotEmptyError)) {
1056                 if (ref->chan->closing != NULL) {
1057                     PyErr_Format(ChannelClosedError,
1058                                  "channel %" PRId64 " closed", cid);
1059                     goto done;
1060                 }
1061                 // Mark the channel as closing and return.  The channel
1062                 // will be cleaned up in _channel_next().
1063                 PyErr_Clear();
1064                 if (_channel_set_closing(ref, channels->mutex) != 0) {
1065                     goto done;
1066                 }
1067                 if (pchan != NULL) {
1068                     *pchan = ref->chan;
1069                 }
1070                 res = 0;
1071             }
1072             goto done;
1073         }
1074         if (pchan != NULL) {
1075             *pchan = ref->chan;
1076         }
1077         else  {
1078             _channel_free(ref->chan);
1079         }
1080         ref->chan = NULL;
1081     }
1082 
1083     res = 0;
1084 done:
1085     PyThread_release_lock(channels->mutex);
1086     return res;
1087 }
1088 
1089 static void
_channels_remove_ref(_channels * channels,_channelref * ref,_channelref * prev,_PyChannelState ** pchan)1090 _channels_remove_ref(_channels *channels, _channelref *ref, _channelref *prev,
1091                      _PyChannelState **pchan)
1092 {
1093     if (ref == channels->head) {
1094         channels->head = ref->next;
1095     }
1096     else {
1097         prev->next = ref->next;
1098     }
1099     channels->numopen -= 1;
1100 
1101     if (pchan != NULL) {
1102         *pchan = ref->chan;
1103     }
1104     _channelref_free(ref);
1105 }
1106 
1107 static int
_channels_remove(_channels * channels,int64_t id,_PyChannelState ** pchan)1108 _channels_remove(_channels *channels, int64_t id, _PyChannelState **pchan)
1109 {
1110     int res = -1;
1111     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1112 
1113     if (pchan != NULL) {
1114         *pchan = NULL;
1115     }
1116 
1117     _channelref *prev = NULL;
1118     _channelref *ref = _channelref_find(channels->head, id, &prev);
1119     if (ref == NULL) {
1120         PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id);
1121         goto done;
1122     }
1123 
1124     _channels_remove_ref(channels, ref, prev, pchan);
1125 
1126     res = 0;
1127 done:
1128     PyThread_release_lock(channels->mutex);
1129     return res;
1130 }
1131 
1132 static int
_channels_add_id_object(_channels * channels,int64_t id)1133 _channels_add_id_object(_channels *channels, int64_t id)
1134 {
1135     int res = -1;
1136     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1137 
1138     _channelref *ref = _channelref_find(channels->head, id, NULL);
1139     if (ref == NULL) {
1140         PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id);
1141         goto done;
1142     }
1143     ref->objcount += 1;
1144 
1145     res = 0;
1146 done:
1147     PyThread_release_lock(channels->mutex);
1148     return res;
1149 }
1150 
1151 static void
_channels_drop_id_object(_channels * channels,int64_t id)1152 _channels_drop_id_object(_channels *channels, int64_t id)
1153 {
1154     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1155 
1156     _channelref *prev = NULL;
1157     _channelref *ref = _channelref_find(channels->head, id, &prev);
1158     if (ref == NULL) {
1159         // Already destroyed.
1160         goto done;
1161     }
1162     ref->objcount -= 1;
1163 
1164     // Destroy if no longer used.
1165     if (ref->objcount == 0) {
1166         _PyChannelState *chan = NULL;
1167         _channels_remove_ref(channels, ref, prev, &chan);
1168         if (chan != NULL) {
1169             _channel_free(chan);
1170         }
1171     }
1172 
1173 done:
1174     PyThread_release_lock(channels->mutex);
1175 }
1176 
1177 static int64_t *
_channels_list_all(_channels * channels,int64_t * count)1178 _channels_list_all(_channels *channels, int64_t *count)
1179 {
1180     int64_t *cids = NULL;
1181     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1182     int64_t *ids = PyMem_NEW(int64_t, (Py_ssize_t)(channels->numopen));
1183     if (ids == NULL) {
1184         goto done;
1185     }
1186     _channelref *ref = channels->head;
1187     for (int64_t i=0; ref != NULL; ref = ref->next, i++) {
1188         ids[i] = ref->id;
1189     }
1190     *count = channels->numopen;
1191 
1192     cids = ids;
1193 done:
1194     PyThread_release_lock(channels->mutex);
1195     return cids;
1196 }
1197 
1198 /* support for closing non-empty channels */
1199 
1200 struct _channel_closing {
1201     struct _channelref *ref;
1202 };
1203 
1204 static int
_channel_set_closing(struct _channelref * ref,PyThread_type_lock mutex)1205 _channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) {
1206     struct _channel *chan = ref->chan;
1207     if (chan == NULL) {
1208         // already closed
1209         return 0;
1210     }
1211     int res = -1;
1212     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1213     if (chan->closing != NULL) {
1214         PyErr_SetString(ChannelClosedError, "channel closed");
1215         goto done;
1216     }
1217     chan->closing = PyMem_NEW(struct _channel_closing, 1);
1218     if (chan->closing == NULL) {
1219         goto done;
1220     }
1221     chan->closing->ref = ref;
1222 
1223     res = 0;
1224 done:
1225     PyThread_release_lock(chan->mutex);
1226     return res;
1227 }
1228 
1229 static void
_channel_clear_closing(struct _channel * chan)1230 _channel_clear_closing(struct _channel *chan) {
1231     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1232     if (chan->closing != NULL) {
1233         PyMem_Free(chan->closing);
1234         chan->closing = NULL;
1235     }
1236     PyThread_release_lock(chan->mutex);
1237 }
1238 
1239 static void
_channel_finish_closing(struct _channel * chan)1240 _channel_finish_closing(struct _channel *chan) {
1241     struct _channel_closing *closing = chan->closing;
1242     if (closing == NULL) {
1243         return;
1244     }
1245     _channelref *ref = closing->ref;
1246     _channel_clear_closing(chan);
1247     // Do the things that would have been done in _channels_close().
1248     ref->chan = NULL;
1249     _channel_free(chan);
1250 }
1251 
1252 /* "high"-level channel-related functions */
1253 
1254 static int64_t
_channel_create(_channels * channels)1255 _channel_create(_channels *channels)
1256 {
1257     _PyChannelState *chan = _channel_new();
1258     if (chan == NULL) {
1259         return -1;
1260     }
1261     int64_t id = _channels_add(channels, chan);
1262     if (id < 0) {
1263         _channel_free(chan);
1264         return -1;
1265     }
1266     return id;
1267 }
1268 
1269 static int
_channel_destroy(_channels * channels,int64_t id)1270 _channel_destroy(_channels *channels, int64_t id)
1271 {
1272     _PyChannelState *chan = NULL;
1273     if (_channels_remove(channels, id, &chan) != 0) {
1274         return -1;
1275     }
1276     if (chan != NULL) {
1277         _channel_free(chan);
1278     }
1279     return 0;
1280 }
1281 
1282 static int
_channel_send(_channels * channels,int64_t id,PyObject * obj)1283 _channel_send(_channels *channels, int64_t id, PyObject *obj)
1284 {
1285     PyInterpreterState *interp = _get_current();
1286     if (interp == NULL) {
1287         return -1;
1288     }
1289 
1290     // Look up the channel.
1291     PyThread_type_lock mutex = NULL;
1292     _PyChannelState *chan = _channels_lookup(channels, id, &mutex);
1293     if (chan == NULL) {
1294         return -1;
1295     }
1296     // Past this point we are responsible for releasing the mutex.
1297 
1298     if (chan->closing != NULL) {
1299         PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", id);
1300         PyThread_release_lock(mutex);
1301         return -1;
1302     }
1303 
1304     // Convert the object to cross-interpreter data.
1305     _PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1);
1306     if (data == NULL) {
1307         PyThread_release_lock(mutex);
1308         return -1;
1309     }
1310     if (_PyObject_GetCrossInterpreterData(obj, data) != 0) {
1311         PyThread_release_lock(mutex);
1312         PyMem_Free(data);
1313         return -1;
1314     }
1315 
1316     // Add the data to the channel.
1317     int res = _channel_add(chan, PyInterpreterState_GetID(interp), data);
1318     PyThread_release_lock(mutex);
1319     if (res != 0) {
1320         _PyCrossInterpreterData_Release(data);
1321         PyMem_Free(data);
1322         return -1;
1323     }
1324 
1325     return 0;
1326 }
1327 
1328 static PyObject *
_channel_recv(_channels * channels,int64_t id)1329 _channel_recv(_channels *channels, int64_t id)
1330 {
1331     PyInterpreterState *interp = _get_current();
1332     if (interp == NULL) {
1333         return NULL;
1334     }
1335 
1336     // Look up the channel.
1337     PyThread_type_lock mutex = NULL;
1338     _PyChannelState *chan = _channels_lookup(channels, id, &mutex);
1339     if (chan == NULL) {
1340         return NULL;
1341     }
1342     // Past this point we are responsible for releasing the mutex.
1343 
1344     // Pop off the next item from the channel.
1345     _PyCrossInterpreterData *data = _channel_next(chan, PyInterpreterState_GetID(interp));
1346     PyThread_release_lock(mutex);
1347     if (data == NULL) {
1348         return NULL;
1349     }
1350 
1351     // Convert the data back to an object.
1352     PyObject *obj = _PyCrossInterpreterData_NewObject(data);
1353     _PyCrossInterpreterData_Release(data);
1354     PyMem_Free(data);
1355     if (obj == NULL) {
1356         return NULL;
1357     }
1358 
1359     return obj;
1360 }
1361 
1362 static int
_channel_drop(_channels * channels,int64_t id,int send,int recv)1363 _channel_drop(_channels *channels, int64_t id, int send, int recv)
1364 {
1365     PyInterpreterState *interp = _get_current();
1366     if (interp == NULL) {
1367         return -1;
1368     }
1369 
1370     // Look up the channel.
1371     PyThread_type_lock mutex = NULL;
1372     _PyChannelState *chan = _channels_lookup(channels, id, &mutex);
1373     if (chan == NULL) {
1374         return -1;
1375     }
1376     // Past this point we are responsible for releasing the mutex.
1377 
1378     // Close one or both of the two ends.
1379     int res = _channel_close_interpreter(chan, PyInterpreterState_GetID(interp), send-recv);
1380     PyThread_release_lock(mutex);
1381     return res;
1382 }
1383 
1384 static int
_channel_close(_channels * channels,int64_t id,int end,int force)1385 _channel_close(_channels *channels, int64_t id, int end, int force)
1386 {
1387     return _channels_close(channels, id, NULL, end, force);
1388 }
1389 
1390 static int
_channel_is_associated(_channels * channels,int64_t cid,int64_t interp,int send)1391 _channel_is_associated(_channels *channels, int64_t cid, int64_t interp,
1392                        int send)
1393 {
1394     _PyChannelState *chan = _channels_lookup(channels, cid, NULL);
1395     if (chan == NULL) {
1396         return -1;
1397     } else if (send && chan->closing != NULL) {
1398         PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid);
1399         return -1;
1400     }
1401 
1402     _channelend *end = _channelend_find(send ? chan->ends->send : chan->ends->recv,
1403                                         interp, NULL);
1404 
1405     return (end != NULL && end->open);
1406 }
1407 
1408 /* ChannelID class */
1409 
1410 static PyTypeObject ChannelIDtype;
1411 
1412 typedef struct channelid {
1413     PyObject_HEAD
1414     int64_t id;
1415     int end;
1416     int resolve;
1417     _channels *channels;
1418 } channelid;
1419 
1420 static int
channel_id_converter(PyObject * arg,void * ptr)1421 channel_id_converter(PyObject *arg, void *ptr)
1422 {
1423     int64_t cid;
1424     if (PyObject_TypeCheck(arg, &ChannelIDtype)) {
1425         cid = ((channelid *)arg)->id;
1426     }
1427     else if (PyIndex_Check(arg)) {
1428         cid = PyLong_AsLongLong(arg);
1429         if (cid == -1 && PyErr_Occurred()) {
1430             return 0;
1431         }
1432         if (cid < 0) {
1433             PyErr_Format(PyExc_ValueError,
1434                         "channel ID must be a non-negative int, got %R", arg);
1435             return 0;
1436         }
1437     }
1438     else {
1439         PyErr_Format(PyExc_TypeError,
1440                      "channel ID must be an int, got %.100s",
1441                      Py_TYPE(arg)->tp_name);
1442         return 0;
1443     }
1444     *(int64_t *)ptr = cid;
1445     return 1;
1446 }
1447 
1448 static channelid *
newchannelid(PyTypeObject * cls,int64_t cid,int end,_channels * channels,int force,int resolve)1449 newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels,
1450              int force, int resolve)
1451 {
1452     channelid *self = PyObject_New(channelid, cls);
1453     if (self == NULL) {
1454         return NULL;
1455     }
1456     self->id = cid;
1457     self->end = end;
1458     self->resolve = resolve;
1459     self->channels = channels;
1460 
1461     if (_channels_add_id_object(channels, cid) != 0) {
1462         if (force && PyErr_ExceptionMatches(ChannelNotFoundError)) {
1463             PyErr_Clear();
1464         }
1465         else {
1466             Py_DECREF((PyObject *)self);
1467             return NULL;
1468         }
1469     }
1470 
1471     return self;
1472 }
1473 
1474 static _channels * _global_channels(void);
1475 
1476 static PyObject *
channelid_new(PyTypeObject * cls,PyObject * args,PyObject * kwds)1477 channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
1478 {
1479     static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
1480     int64_t cid;
1481     int send = -1;
1482     int recv = -1;
1483     int force = 0;
1484     int resolve = 0;
1485     if (!PyArg_ParseTupleAndKeywords(args, kwds,
1486                                      "O&|$pppp:ChannelID.__new__", kwlist,
1487                                      channel_id_converter, &cid, &send, &recv, &force, &resolve))
1488         return NULL;
1489 
1490     // Handle "send" and "recv".
1491     if (send == 0 && recv == 0) {
1492         PyErr_SetString(PyExc_ValueError,
1493                         "'send' and 'recv' cannot both be False");
1494         return NULL;
1495     }
1496 
1497     int end = 0;
1498     if (send == 1) {
1499         if (recv == 0 || recv == -1) {
1500             end = CHANNEL_SEND;
1501         }
1502     }
1503     else if (recv == 1) {
1504         end = CHANNEL_RECV;
1505     }
1506 
1507     return (PyObject *)newchannelid(cls, cid, end, _global_channels(),
1508                                     force, resolve);
1509 }
1510 
1511 static void
channelid_dealloc(PyObject * v)1512 channelid_dealloc(PyObject *v)
1513 {
1514     int64_t cid = ((channelid *)v)->id;
1515     _channels *channels = ((channelid *)v)->channels;
1516     Py_TYPE(v)->tp_free(v);
1517 
1518     _channels_drop_id_object(channels, cid);
1519 }
1520 
1521 static PyObject *
channelid_repr(PyObject * self)1522 channelid_repr(PyObject *self)
1523 {
1524     PyTypeObject *type = Py_TYPE(self);
1525     const char *name = _PyType_Name(type);
1526 
1527     channelid *cid = (channelid *)self;
1528     const char *fmt;
1529     if (cid->end == CHANNEL_SEND) {
1530         fmt = "%s(%" PRId64 ", send=True)";
1531     }
1532     else if (cid->end == CHANNEL_RECV) {
1533         fmt = "%s(%" PRId64 ", recv=True)";
1534     }
1535     else {
1536         fmt = "%s(%" PRId64 ")";
1537     }
1538     return PyUnicode_FromFormat(fmt, name, cid->id);
1539 }
1540 
1541 static PyObject *
channelid_str(PyObject * self)1542 channelid_str(PyObject *self)
1543 {
1544     channelid *cid = (channelid *)self;
1545     return PyUnicode_FromFormat("%" PRId64 "", cid->id);
1546 }
1547 
1548 static PyObject *
channelid_int(PyObject * self)1549 channelid_int(PyObject *self)
1550 {
1551     channelid *cid = (channelid *)self;
1552     return PyLong_FromLongLong(cid->id);
1553 }
1554 
1555 static PyNumberMethods channelid_as_number = {
1556      0,                        /* nb_add */
1557      0,                        /* nb_subtract */
1558      0,                        /* nb_multiply */
1559      0,                        /* nb_remainder */
1560      0,                        /* nb_divmod */
1561      0,                        /* nb_power */
1562      0,                        /* nb_negative */
1563      0,                        /* nb_positive */
1564      0,                        /* nb_absolute */
1565      0,                        /* nb_bool */
1566      0,                        /* nb_invert */
1567      0,                        /* nb_lshift */
1568      0,                        /* nb_rshift */
1569      0,                        /* nb_and */
1570      0,                        /* nb_xor */
1571      0,                        /* nb_or */
1572      (unaryfunc)channelid_int, /* nb_int */
1573      0,                        /* nb_reserved */
1574      0,                        /* nb_float */
1575 
1576      0,                        /* nb_inplace_add */
1577      0,                        /* nb_inplace_subtract */
1578      0,                        /* nb_inplace_multiply */
1579      0,                        /* nb_inplace_remainder */
1580      0,                        /* nb_inplace_power */
1581      0,                        /* nb_inplace_lshift */
1582      0,                        /* nb_inplace_rshift */
1583      0,                        /* nb_inplace_and */
1584      0,                        /* nb_inplace_xor */
1585      0,                        /* nb_inplace_or */
1586 
1587      0,                        /* nb_floor_divide */
1588      0,                        /* nb_true_divide */
1589      0,                        /* nb_inplace_floor_divide */
1590      0,                        /* nb_inplace_true_divide */
1591 
1592      (unaryfunc)channelid_int, /* nb_index */
1593 };
1594 
1595 static Py_hash_t
channelid_hash(PyObject * self)1596 channelid_hash(PyObject *self)
1597 {
1598     channelid *cid = (channelid *)self;
1599     PyObject *id = PyLong_FromLongLong(cid->id);
1600     if (id == NULL) {
1601         return -1;
1602     }
1603     Py_hash_t hash = PyObject_Hash(id);
1604     Py_DECREF(id);
1605     return hash;
1606 }
1607 
1608 static PyObject *
channelid_richcompare(PyObject * self,PyObject * other,int op)1609 channelid_richcompare(PyObject *self, PyObject *other, int op)
1610 {
1611     if (op != Py_EQ && op != Py_NE) {
1612         Py_RETURN_NOTIMPLEMENTED;
1613     }
1614 
1615     if (!PyObject_TypeCheck(self, &ChannelIDtype)) {
1616         Py_RETURN_NOTIMPLEMENTED;
1617     }
1618 
1619     channelid *cid = (channelid *)self;
1620     int equal;
1621     if (PyObject_TypeCheck(other, &ChannelIDtype)) {
1622         channelid *othercid = (channelid *)other;
1623         equal = (cid->end == othercid->end) && (cid->id == othercid->id);
1624     }
1625     else if (PyLong_Check(other)) {
1626         /* Fast path */
1627         int overflow;
1628         long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow);
1629         if (othercid == -1 && PyErr_Occurred()) {
1630             return NULL;
1631         }
1632         equal = !overflow && (othercid >= 0) && (cid->id == othercid);
1633     }
1634     else if (PyNumber_Check(other)) {
1635         PyObject *pyid = PyLong_FromLongLong(cid->id);
1636         if (pyid == NULL) {
1637             return NULL;
1638         }
1639         PyObject *res = PyObject_RichCompare(pyid, other, op);
1640         Py_DECREF(pyid);
1641         return res;
1642     }
1643     else {
1644         Py_RETURN_NOTIMPLEMENTED;
1645     }
1646 
1647     if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) {
1648         Py_RETURN_TRUE;
1649     }
1650     Py_RETURN_FALSE;
1651 }
1652 
1653 static PyObject *
_channel_from_cid(PyObject * cid,int end)1654 _channel_from_cid(PyObject *cid, int end)
1655 {
1656     PyObject *highlevel = PyImport_ImportModule("interpreters");
1657     if (highlevel == NULL) {
1658         PyErr_Clear();
1659         highlevel = PyImport_ImportModule("test.support.interpreters");
1660         if (highlevel == NULL) {
1661             return NULL;
1662         }
1663     }
1664     const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" :
1665                                                   "SendChannel";
1666     PyObject *cls = PyObject_GetAttrString(highlevel, clsname);
1667     Py_DECREF(highlevel);
1668     if (cls == NULL) {
1669         return NULL;
1670     }
1671     PyObject *chan = PyObject_CallFunctionObjArgs(cls, cid, NULL);
1672     Py_DECREF(cls);
1673     if (chan == NULL) {
1674         return NULL;
1675     }
1676     return chan;
1677 }
1678 
1679 struct _channelid_xid {
1680     int64_t id;
1681     int end;
1682     int resolve;
1683 };
1684 
1685 static PyObject *
_channelid_from_xid(_PyCrossInterpreterData * data)1686 _channelid_from_xid(_PyCrossInterpreterData *data)
1687 {
1688     struct _channelid_xid *xid = (struct _channelid_xid *)data->data;
1689     // Note that we do not preserve the "resolve" flag.
1690     PyObject *cid = (PyObject *)newchannelid(&ChannelIDtype, xid->id, xid->end,
1691                                              _global_channels(), 0, 0);
1692     if (xid->end == 0) {
1693         return cid;
1694     }
1695     if (!xid->resolve) {
1696         return cid;
1697     }
1698 
1699     /* Try returning a high-level channel end but fall back to the ID. */
1700     PyObject *chan = _channel_from_cid(cid, xid->end);
1701     if (chan == NULL) {
1702         PyErr_Clear();
1703         return cid;
1704     }
1705     Py_DECREF(cid);
1706     return chan;
1707 }
1708 
1709 static int
_channelid_shared(PyObject * obj,_PyCrossInterpreterData * data)1710 _channelid_shared(PyObject *obj, _PyCrossInterpreterData *data)
1711 {
1712     struct _channelid_xid *xid = PyMem_NEW(struct _channelid_xid, 1);
1713     if (xid == NULL) {
1714         return -1;
1715     }
1716     xid->id = ((channelid *)obj)->id;
1717     xid->end = ((channelid *)obj)->end;
1718     xid->resolve = ((channelid *)obj)->resolve;
1719 
1720     data->data = xid;
1721     Py_INCREF(obj);
1722     data->obj = obj;
1723     data->new_object = _channelid_from_xid;
1724     data->free = PyMem_Free;
1725     return 0;
1726 }
1727 
1728 static PyObject *
channelid_end(PyObject * self,void * end)1729 channelid_end(PyObject *self, void *end)
1730 {
1731     int force = 1;
1732     channelid *cid = (channelid *)self;
1733     if (end != NULL) {
1734         return (PyObject *)newchannelid(Py_TYPE(self), cid->id, *(int *)end,
1735                                         cid->channels, force, cid->resolve);
1736     }
1737 
1738     if (cid->end == CHANNEL_SEND) {
1739         return PyUnicode_InternFromString("send");
1740     }
1741     if (cid->end == CHANNEL_RECV) {
1742         return PyUnicode_InternFromString("recv");
1743     }
1744     return PyUnicode_InternFromString("both");
1745 }
1746 
1747 static int _channelid_end_send = CHANNEL_SEND;
1748 static int _channelid_end_recv = CHANNEL_RECV;
1749 
1750 static PyGetSetDef channelid_getsets[] = {
1751     {"end", (getter)channelid_end, NULL,
1752      PyDoc_STR("'send', 'recv', or 'both'")},
1753     {"send", (getter)channelid_end, NULL,
1754      PyDoc_STR("the 'send' end of the channel"), &_channelid_end_send},
1755     {"recv", (getter)channelid_end, NULL,
1756      PyDoc_STR("the 'recv' end of the channel"), &_channelid_end_recv},
1757     {NULL}
1758 };
1759 
1760 PyDoc_STRVAR(channelid_doc,
1761 "A channel ID identifies a channel and may be used as an int.");
1762 
1763 static PyTypeObject ChannelIDtype = {
1764     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1765     "_xxsubinterpreters.ChannelID", /* tp_name */
1766     sizeof(channelid),              /* tp_basicsize */
1767     0,                              /* tp_itemsize */
1768     (destructor)channelid_dealloc,  /* tp_dealloc */
1769     0,                              /* tp_vectorcall_offset */
1770     0,                              /* tp_getattr */
1771     0,                              /* tp_setattr */
1772     0,                              /* tp_as_async */
1773     (reprfunc)channelid_repr,       /* tp_repr */
1774     &channelid_as_number,           /* tp_as_number */
1775     0,                              /* tp_as_sequence */
1776     0,                              /* tp_as_mapping */
1777     channelid_hash,                 /* tp_hash */
1778     0,                              /* tp_call */
1779     (reprfunc)channelid_str,        /* tp_str */
1780     0,                              /* tp_getattro */
1781     0,                              /* tp_setattro */
1782     0,                              /* tp_as_buffer */
1783     // Use Py_TPFLAGS_DISALLOW_INSTANTIATION so the type cannot be instantiated
1784     // from Python code.  We do this because there is a strong relationship
1785     // between channel IDs and the channel lifecycle, so this limitation avoids
1786     // related complications. Use the _channel_id() function instead.
1787     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE
1788         | Py_TPFLAGS_DISALLOW_INSTANTIATION, /* tp_flags */
1789     channelid_doc,                  /* tp_doc */
1790     0,                              /* tp_traverse */
1791     0,                              /* tp_clear */
1792     channelid_richcompare,          /* tp_richcompare */
1793     0,                              /* tp_weaklistoffset */
1794     0,                              /* tp_iter */
1795     0,                              /* tp_iternext */
1796     0,                              /* tp_methods */
1797     0,                              /* tp_members */
1798     channelid_getsets,              /* tp_getset */
1799 };
1800 
1801 
1802 /* interpreter-specific code ************************************************/
1803 
1804 static PyObject * RunFailedError = NULL;
1805 
1806 static int
interp_exceptions_init(PyObject * ns)1807 interp_exceptions_init(PyObject *ns)
1808 {
1809     // XXX Move the exceptions into per-module memory?
1810 
1811     if (RunFailedError == NULL) {
1812         // An uncaught exception came out of interp_run_string().
1813         RunFailedError = PyErr_NewException("_xxsubinterpreters.RunFailedError",
1814                                             PyExc_RuntimeError, NULL);
1815         if (RunFailedError == NULL) {
1816             return -1;
1817         }
1818         if (PyDict_SetItemString(ns, "RunFailedError", RunFailedError) != 0) {
1819             return -1;
1820         }
1821     }
1822 
1823     return 0;
1824 }
1825 
1826 static int
_is_running(PyInterpreterState * interp)1827 _is_running(PyInterpreterState *interp)
1828 {
1829     PyThreadState *tstate = PyInterpreterState_ThreadHead(interp);
1830     if (PyThreadState_Next(tstate) != NULL) {
1831         PyErr_SetString(PyExc_RuntimeError,
1832                         "interpreter has more than one thread");
1833         return -1;
1834     }
1835 
1836     assert(!PyErr_Occurred());
1837     PyFrameObject *frame = PyThreadState_GetFrame(tstate);
1838     if (frame == NULL) {
1839         return 0;
1840     }
1841 
1842     int executing = _PyFrame_IsExecuting(frame);
1843     Py_DECREF(frame);
1844 
1845     return executing;
1846 }
1847 
1848 static int
_ensure_not_running(PyInterpreterState * interp)1849 _ensure_not_running(PyInterpreterState *interp)
1850 {
1851     int is_running = _is_running(interp);
1852     if (is_running < 0) {
1853         return -1;
1854     }
1855     if (is_running) {
1856         PyErr_Format(PyExc_RuntimeError, "interpreter already running");
1857         return -1;
1858     }
1859     return 0;
1860 }
1861 
1862 static int
_run_script(PyInterpreterState * interp,const char * codestr,_sharedns * shared,_sharedexception ** exc)1863 _run_script(PyInterpreterState *interp, const char *codestr,
1864             _sharedns *shared, _sharedexception **exc)
1865 {
1866     PyObject *exctype = NULL;
1867     PyObject *excval = NULL;
1868     PyObject *tb = NULL;
1869 
1870     PyObject *main_mod = _PyInterpreterState_GetMainModule(interp);
1871     if (main_mod == NULL) {
1872         goto error;
1873     }
1874     PyObject *ns = PyModule_GetDict(main_mod);  // borrowed
1875     Py_DECREF(main_mod);
1876     if (ns == NULL) {
1877         goto error;
1878     }
1879     Py_INCREF(ns);
1880 
1881     // Apply the cross-interpreter data.
1882     if (shared != NULL) {
1883         if (_sharedns_apply(shared, ns) != 0) {
1884             Py_DECREF(ns);
1885             goto error;
1886         }
1887     }
1888 
1889     // Run the string (see PyRun_SimpleStringFlags).
1890     PyObject *result = PyRun_StringFlags(codestr, Py_file_input, ns, ns, NULL);
1891     Py_DECREF(ns);
1892     if (result == NULL) {
1893         goto error;
1894     }
1895     else {
1896         Py_DECREF(result);  // We throw away the result.
1897     }
1898 
1899     *exc = NULL;
1900     return 0;
1901 
1902 error:
1903     PyErr_Fetch(&exctype, &excval, &tb);
1904 
1905     _sharedexception *sharedexc = _sharedexception_bind(exctype, excval, tb);
1906     Py_XDECREF(exctype);
1907     Py_XDECREF(excval);
1908     Py_XDECREF(tb);
1909     if (sharedexc == NULL) {
1910         fprintf(stderr, "RunFailedError: script raised an uncaught exception");
1911         PyErr_Clear();
1912         sharedexc = NULL;
1913     }
1914     else {
1915         assert(!PyErr_Occurred());
1916     }
1917     *exc = sharedexc;
1918     return -1;
1919 }
1920 
1921 static int
_run_script_in_interpreter(PyInterpreterState * interp,const char * codestr,PyObject * shareables)1922 _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr,
1923                            PyObject *shareables)
1924 {
1925     if (_ensure_not_running(interp) < 0) {
1926         return -1;
1927     }
1928 
1929     _sharedns *shared = _get_shared_ns(shareables);
1930     if (shared == NULL && PyErr_Occurred()) {
1931         return -1;
1932     }
1933 
1934 #ifdef EXPERIMENTAL_ISOLATED_SUBINTERPRETERS
1935     // Switch to interpreter.
1936     PyThreadState *new_tstate = PyInterpreterState_ThreadHead(interp);
1937     PyThreadState *save1 = PyEval_SaveThread();
1938 
1939     (void)PyThreadState_Swap(new_tstate);
1940 
1941     // Run the script.
1942     _sharedexception *exc = NULL;
1943     int result = _run_script(interp, codestr, shared, &exc);
1944 
1945     // Switch back.
1946     PyEval_RestoreThread(save1);
1947 #else
1948     // Switch to interpreter.
1949     PyThreadState *save_tstate = NULL;
1950     if (interp != PyInterpreterState_Get()) {
1951         // XXX Using the "head" thread isn't strictly correct.
1952         PyThreadState *tstate = PyInterpreterState_ThreadHead(interp);
1953         // XXX Possible GILState issues?
1954         save_tstate = PyThreadState_Swap(tstate);
1955     }
1956 
1957     // Run the script.
1958     _sharedexception *exc = NULL;
1959     int result = _run_script(interp, codestr, shared, &exc);
1960 
1961     // Switch back.
1962     if (save_tstate != NULL) {
1963         PyThreadState_Swap(save_tstate);
1964     }
1965 #endif
1966 
1967     // Propagate any exception out to the caller.
1968     if (exc != NULL) {
1969         _sharedexception_apply(exc, RunFailedError);
1970         _sharedexception_free(exc);
1971     }
1972     else if (result != 0) {
1973         // We were unable to allocate a shared exception.
1974         PyErr_NoMemory();
1975     }
1976 
1977     if (shared != NULL) {
1978         _sharedns_free(shared);
1979     }
1980 
1981     return result;
1982 }
1983 
1984 
1985 /* module level code ********************************************************/
1986 
1987 /* globals is the process-global state for the module.  It holds all
1988    the data that we need to share between interpreters, so it cannot
1989    hold PyObject values. */
1990 static struct globals {
1991     _channels channels;
1992 } _globals = {{0}};
1993 
1994 static int
_init_globals(void)1995 _init_globals(void)
1996 {
1997     if (_channels_init(&_globals.channels) != 0) {
1998         return -1;
1999     }
2000     return 0;
2001 }
2002 
2003 static _channels *
_global_channels(void)2004 _global_channels(void) {
2005     return &_globals.channels;
2006 }
2007 
2008 static PyObject *
interp_create(PyObject * self,PyObject * args,PyObject * kwds)2009 interp_create(PyObject *self, PyObject *args, PyObject *kwds)
2010 {
2011 
2012     static char *kwlist[] = {"isolated", NULL};
2013     int isolated = 1;
2014     if (!PyArg_ParseTupleAndKeywords(args, kwds, "|$i:create", kwlist,
2015                                      &isolated)) {
2016         return NULL;
2017     }
2018 
2019     // Create and initialize the new interpreter.
2020     PyThreadState *save_tstate = PyThreadState_Get();
2021     // XXX Possible GILState issues?
2022     PyThreadState *tstate = _Py_NewInterpreter(isolated);
2023     PyThreadState_Swap(save_tstate);
2024     if (tstate == NULL) {
2025         /* Since no new thread state was created, there is no exception to
2026            propagate; raise a fresh one after swapping in the old thread
2027            state. */
2028         PyErr_SetString(PyExc_RuntimeError, "interpreter creation failed");
2029         return NULL;
2030     }
2031     PyInterpreterState *interp = PyThreadState_GetInterpreter(tstate);
2032     PyObject *idobj = _PyInterpreterState_GetIDObject(interp);
2033     if (idobj == NULL) {
2034         // XXX Possible GILState issues?
2035         save_tstate = PyThreadState_Swap(tstate);
2036         Py_EndInterpreter(tstate);
2037         PyThreadState_Swap(save_tstate);
2038         return NULL;
2039     }
2040     _PyInterpreterState_RequireIDRef(interp, 1);
2041     return idobj;
2042 }
2043 
2044 PyDoc_STRVAR(create_doc,
2045 "create() -> ID\n\
2046 \n\
2047 Create a new interpreter and return a unique generated ID.");
2048 
2049 
2050 static PyObject *
interp_destroy(PyObject * self,PyObject * args,PyObject * kwds)2051 interp_destroy(PyObject *self, PyObject *args, PyObject *kwds)
2052 {
2053     static char *kwlist[] = {"id", NULL};
2054     PyObject *id;
2055     // XXX Use "L" for id?
2056     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2057                                      "O:destroy", kwlist, &id)) {
2058         return NULL;
2059     }
2060 
2061     // Look up the interpreter.
2062     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
2063     if (interp == NULL) {
2064         return NULL;
2065     }
2066 
2067     // Ensure we don't try to destroy the current interpreter.
2068     PyInterpreterState *current = _get_current();
2069     if (current == NULL) {
2070         return NULL;
2071     }
2072     if (interp == current) {
2073         PyErr_SetString(PyExc_RuntimeError,
2074                         "cannot destroy the current interpreter");
2075         return NULL;
2076     }
2077 
2078     // Ensure the interpreter isn't running.
2079     /* XXX We *could* support destroying a running interpreter but
2080        aren't going to worry about it for now. */
2081     if (_ensure_not_running(interp) < 0) {
2082         return NULL;
2083     }
2084 
2085     // Destroy the interpreter.
2086     PyThreadState *tstate = PyInterpreterState_ThreadHead(interp);
2087     // XXX Possible GILState issues?
2088     PyThreadState *save_tstate = PyThreadState_Swap(tstate);
2089     Py_EndInterpreter(tstate);
2090     PyThreadState_Swap(save_tstate);
2091 
2092     Py_RETURN_NONE;
2093 }
2094 
2095 PyDoc_STRVAR(destroy_doc,
2096 "destroy(id)\n\
2097 \n\
2098 Destroy the identified interpreter.\n\
2099 \n\
2100 Attempting to destroy the current interpreter results in a RuntimeError.\n\
2101 So does an unrecognized ID.");
2102 
2103 
2104 static PyObject *
interp_list_all(PyObject * self,PyObject * Py_UNUSED (ignored))2105 interp_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
2106 {
2107     PyObject *ids, *id;
2108     PyInterpreterState *interp;
2109 
2110     ids = PyList_New(0);
2111     if (ids == NULL) {
2112         return NULL;
2113     }
2114 
2115     interp = PyInterpreterState_Head();
2116     while (interp != NULL) {
2117         id = _PyInterpreterState_GetIDObject(interp);
2118         if (id == NULL) {
2119             Py_DECREF(ids);
2120             return NULL;
2121         }
2122         // insert at front of list
2123         int res = PyList_Insert(ids, 0, id);
2124         Py_DECREF(id);
2125         if (res < 0) {
2126             Py_DECREF(ids);
2127             return NULL;
2128         }
2129 
2130         interp = PyInterpreterState_Next(interp);
2131     }
2132 
2133     return ids;
2134 }
2135 
2136 PyDoc_STRVAR(list_all_doc,
2137 "list_all() -> [ID]\n\
2138 \n\
2139 Return a list containing the ID of every existing interpreter.");
2140 
2141 
2142 static PyObject *
interp_get_current(PyObject * self,PyObject * Py_UNUSED (ignored))2143 interp_get_current(PyObject *self, PyObject *Py_UNUSED(ignored))
2144 {
2145     PyInterpreterState *interp =_get_current();
2146     if (interp == NULL) {
2147         return NULL;
2148     }
2149     return _PyInterpreterState_GetIDObject(interp);
2150 }
2151 
2152 PyDoc_STRVAR(get_current_doc,
2153 "get_current() -> ID\n\
2154 \n\
2155 Return the ID of current interpreter.");
2156 
2157 
2158 static PyObject *
interp_get_main(PyObject * self,PyObject * Py_UNUSED (ignored))2159 interp_get_main(PyObject *self, PyObject *Py_UNUSED(ignored))
2160 {
2161     // Currently, 0 is always the main interpreter.
2162     int64_t id = 0;
2163     return _PyInterpreterID_New(id);
2164 }
2165 
2166 PyDoc_STRVAR(get_main_doc,
2167 "get_main() -> ID\n\
2168 \n\
2169 Return the ID of main interpreter.");
2170 
2171 
2172 static PyObject *
interp_run_string(PyObject * self,PyObject * args,PyObject * kwds)2173 interp_run_string(PyObject *self, PyObject *args, PyObject *kwds)
2174 {
2175     static char *kwlist[] = {"id", "script", "shared", NULL};
2176     PyObject *id, *code;
2177     PyObject *shared = NULL;
2178     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2179                                      "OU|O:run_string", kwlist,
2180                                      &id, &code, &shared)) {
2181         return NULL;
2182     }
2183 
2184     // Look up the interpreter.
2185     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
2186     if (interp == NULL) {
2187         return NULL;
2188     }
2189 
2190     // Extract code.
2191     Py_ssize_t size;
2192     const char *codestr = PyUnicode_AsUTF8AndSize(code, &size);
2193     if (codestr == NULL) {
2194         return NULL;
2195     }
2196     if (strlen(codestr) != (size_t)size) {
2197         PyErr_SetString(PyExc_ValueError,
2198                         "source code string cannot contain null bytes");
2199         return NULL;
2200     }
2201 
2202     // Run the code in the interpreter.
2203     if (_run_script_in_interpreter(interp, codestr, shared) != 0) {
2204         return NULL;
2205     }
2206     Py_RETURN_NONE;
2207 }
2208 
2209 PyDoc_STRVAR(run_string_doc,
2210 "run_string(id, script, shared)\n\
2211 \n\
2212 Execute the provided string in the identified interpreter.\n\
2213 \n\
2214 See PyRun_SimpleStrings.");
2215 
2216 
2217 static PyObject *
object_is_shareable(PyObject * self,PyObject * args,PyObject * kwds)2218 object_is_shareable(PyObject *self, PyObject *args, PyObject *kwds)
2219 {
2220     static char *kwlist[] = {"obj", NULL};
2221     PyObject *obj;
2222     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2223                                      "O:is_shareable", kwlist, &obj)) {
2224         return NULL;
2225     }
2226 
2227     if (_PyObject_CheckCrossInterpreterData(obj) == 0) {
2228         Py_RETURN_TRUE;
2229     }
2230     PyErr_Clear();
2231     Py_RETURN_FALSE;
2232 }
2233 
2234 PyDoc_STRVAR(is_shareable_doc,
2235 "is_shareable(obj) -> bool\n\
2236 \n\
2237 Return True if the object's data may be shared between interpreters and\n\
2238 False otherwise.");
2239 
2240 
2241 static PyObject *
interp_is_running(PyObject * self,PyObject * args,PyObject * kwds)2242 interp_is_running(PyObject *self, PyObject *args, PyObject *kwds)
2243 {
2244     static char *kwlist[] = {"id", NULL};
2245     PyObject *id;
2246     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2247                                      "O:is_running", kwlist, &id)) {
2248         return NULL;
2249     }
2250 
2251     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
2252     if (interp == NULL) {
2253         return NULL;
2254     }
2255     int is_running = _is_running(interp);
2256     if (is_running < 0) {
2257         return NULL;
2258     }
2259     if (is_running) {
2260         Py_RETURN_TRUE;
2261     }
2262     Py_RETURN_FALSE;
2263 }
2264 
2265 PyDoc_STRVAR(is_running_doc,
2266 "is_running(id) -> bool\n\
2267 \n\
2268 Return whether or not the identified interpreter is running.");
2269 
2270 static PyObject *
channel_create(PyObject * self,PyObject * Py_UNUSED (ignored))2271 channel_create(PyObject *self, PyObject *Py_UNUSED(ignored))
2272 {
2273     int64_t cid = _channel_create(&_globals.channels);
2274     if (cid < 0) {
2275         return NULL;
2276     }
2277     PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, cid, 0,
2278                                             &_globals.channels, 0, 0);
2279     if (id == NULL) {
2280         if (_channel_destroy(&_globals.channels, cid) != 0) {
2281             // XXX issue a warning?
2282         }
2283         return NULL;
2284     }
2285     assert(((channelid *)id)->channels != NULL);
2286     return id;
2287 }
2288 
2289 PyDoc_STRVAR(channel_create_doc,
2290 "channel_create() -> cid\n\
2291 \n\
2292 Create a new cross-interpreter channel and return a unique generated ID.");
2293 
2294 static PyObject *
channel_destroy(PyObject * self,PyObject * args,PyObject * kwds)2295 channel_destroy(PyObject *self, PyObject *args, PyObject *kwds)
2296 {
2297     static char *kwlist[] = {"cid", NULL};
2298     int64_t cid;
2299     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_destroy", kwlist,
2300                                      channel_id_converter, &cid)) {
2301         return NULL;
2302     }
2303 
2304     if (_channel_destroy(&_globals.channels, cid) != 0) {
2305         return NULL;
2306     }
2307     Py_RETURN_NONE;
2308 }
2309 
2310 PyDoc_STRVAR(channel_destroy_doc,
2311 "channel_destroy(cid)\n\
2312 \n\
2313 Close and finalize the channel.  Afterward attempts to use the channel\n\
2314 will behave as though it never existed.");
2315 
2316 static PyObject *
channel_list_all(PyObject * self,PyObject * Py_UNUSED (ignored))2317 channel_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
2318 {
2319     int64_t count = 0;
2320     int64_t *cids = _channels_list_all(&_globals.channels, &count);
2321     if (cids == NULL) {
2322         if (count == 0) {
2323             return PyList_New(0);
2324         }
2325         return NULL;
2326     }
2327     PyObject *ids = PyList_New((Py_ssize_t)count);
2328     if (ids == NULL) {
2329         goto finally;
2330     }
2331     int64_t *cur = cids;
2332     for (int64_t i=0; i < count; cur++, i++) {
2333         PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, *cur, 0,
2334                                                 &_globals.channels, 0, 0);
2335         if (id == NULL) {
2336             Py_DECREF(ids);
2337             ids = NULL;
2338             break;
2339         }
2340         PyList_SET_ITEM(ids, i, id);
2341     }
2342 
2343 finally:
2344     PyMem_Free(cids);
2345     return ids;
2346 }
2347 
2348 PyDoc_STRVAR(channel_list_all_doc,
2349 "channel_list_all() -> [cid]\n\
2350 \n\
2351 Return the list of all IDs for active channels.");
2352 
2353 static PyObject *
channel_list_interpreters(PyObject * self,PyObject * args,PyObject * kwds)2354 channel_list_interpreters(PyObject *self, PyObject *args, PyObject *kwds)
2355 {
2356     static char *kwlist[] = {"cid", "send", NULL};
2357     int64_t cid;            /* Channel ID */
2358     int send = 0;           /* Send or receive end? */
2359     int64_t id;
2360     PyObject *ids, *id_obj;
2361     PyInterpreterState *interp;
2362 
2363     if (!PyArg_ParseTupleAndKeywords(
2364             args, kwds, "O&$p:channel_list_interpreters",
2365             kwlist, channel_id_converter, &cid, &send)) {
2366         return NULL;
2367     }
2368 
2369     ids = PyList_New(0);
2370     if (ids == NULL) {
2371         goto except;
2372     }
2373 
2374     interp = PyInterpreterState_Head();
2375     while (interp != NULL) {
2376         id = PyInterpreterState_GetID(interp);
2377         assert(id >= 0);
2378         int res = _channel_is_associated(&_globals.channels, cid, id, send);
2379         if (res < 0) {
2380             goto except;
2381         }
2382         if (res) {
2383             id_obj = _PyInterpreterState_GetIDObject(interp);
2384             if (id_obj == NULL) {
2385                 goto except;
2386             }
2387             res = PyList_Insert(ids, 0, id_obj);
2388             Py_DECREF(id_obj);
2389             if (res < 0) {
2390                 goto except;
2391             }
2392         }
2393         interp = PyInterpreterState_Next(interp);
2394     }
2395 
2396     goto finally;
2397 
2398 except:
2399     Py_XDECREF(ids);
2400     ids = NULL;
2401 
2402 finally:
2403     return ids;
2404 }
2405 
2406 PyDoc_STRVAR(channel_list_interpreters_doc,
2407 "channel_list_interpreters(cid, *, send) -> [id]\n\
2408 \n\
2409 Return the list of all interpreter IDs associated with an end of the channel.\n\
2410 \n\
2411 The 'send' argument should be a boolean indicating whether to use the send or\n\
2412 receive end.");
2413 
2414 
2415 static PyObject *
channel_send(PyObject * self,PyObject * args,PyObject * kwds)2416 channel_send(PyObject *self, PyObject *args, PyObject *kwds)
2417 {
2418     static char *kwlist[] = {"cid", "obj", NULL};
2419     int64_t cid;
2420     PyObject *obj;
2421     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist,
2422                                      channel_id_converter, &cid, &obj)) {
2423         return NULL;
2424     }
2425 
2426     if (_channel_send(&_globals.channels, cid, obj) != 0) {
2427         return NULL;
2428     }
2429     Py_RETURN_NONE;
2430 }
2431 
2432 PyDoc_STRVAR(channel_send_doc,
2433 "channel_send(cid, obj)\n\
2434 \n\
2435 Add the object's data to the channel's queue.");
2436 
2437 static PyObject *
channel_recv(PyObject * self,PyObject * args,PyObject * kwds)2438 channel_recv(PyObject *self, PyObject *args, PyObject *kwds)
2439 {
2440     static char *kwlist[] = {"cid", "default", NULL};
2441     int64_t cid;
2442     PyObject *dflt = NULL;
2443     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|O:channel_recv", kwlist,
2444                                      channel_id_converter, &cid, &dflt)) {
2445         return NULL;
2446     }
2447     Py_XINCREF(dflt);
2448 
2449     PyObject *obj = _channel_recv(&_globals.channels, cid);
2450     if (obj != NULL) {
2451         Py_XDECREF(dflt);
2452         return obj;
2453     } else if (PyErr_Occurred()) {
2454         Py_XDECREF(dflt);
2455         return NULL;
2456     } else if (dflt != NULL) {
2457         return dflt;
2458     } else {
2459         PyErr_Format(ChannelEmptyError, "channel %" PRId64 " is empty", cid);
2460         return NULL;
2461     }
2462 }
2463 
2464 PyDoc_STRVAR(channel_recv_doc,
2465 "channel_recv(cid, [default]) -> obj\n\
2466 \n\
2467 Return a new object from the data at the front of the channel's queue.\n\
2468 \n\
2469 If there is nothing to receive then raise ChannelEmptyError, unless\n\
2470 a default value is provided.  In that case return it.");
2471 
2472 static PyObject *
channel_close(PyObject * self,PyObject * args,PyObject * kwds)2473 channel_close(PyObject *self, PyObject *args, PyObject *kwds)
2474 {
2475     static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
2476     int64_t cid;
2477     int send = 0;
2478     int recv = 0;
2479     int force = 0;
2480     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2481                                      "O&|$ppp:channel_close", kwlist,
2482                                      channel_id_converter, &cid, &send, &recv, &force)) {
2483         return NULL;
2484     }
2485 
2486     if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) {
2487         return NULL;
2488     }
2489     Py_RETURN_NONE;
2490 }
2491 
2492 PyDoc_STRVAR(channel_close_doc,
2493 "channel_close(cid, *, send=None, recv=None, force=False)\n\
2494 \n\
2495 Close the channel for all interpreters.\n\
2496 \n\
2497 If the channel is empty then the keyword args are ignored and both\n\
2498 ends are immediately closed.  Otherwise, if 'force' is True then\n\
2499 all queued items are released and both ends are immediately\n\
2500 closed.\n\
2501 \n\
2502 If the channel is not empty *and* 'force' is False then following\n\
2503 happens:\n\
2504 \n\
2505  * recv is True (regardless of send):\n\
2506    - raise ChannelNotEmptyError\n\
2507  * recv is None and send is None:\n\
2508    - raise ChannelNotEmptyError\n\
2509  * send is True and recv is not True:\n\
2510    - fully close the 'send' end\n\
2511    - close the 'recv' end to interpreters not already receiving\n\
2512    - fully close it once empty\n\
2513 \n\
2514 Closing an already closed channel results in a ChannelClosedError.\n\
2515 \n\
2516 Once the channel's ID has no more ref counts in any interpreter\n\
2517 the channel will be destroyed.");
2518 
2519 static PyObject *
channel_release(PyObject * self,PyObject * args,PyObject * kwds)2520 channel_release(PyObject *self, PyObject *args, PyObject *kwds)
2521 {
2522     // Note that only the current interpreter is affected.
2523     static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
2524     int64_t cid;
2525     int send = 0;
2526     int recv = 0;
2527     int force = 0;
2528     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2529                                      "O&|$ppp:channel_release", kwlist,
2530                                      channel_id_converter, &cid, &send, &recv, &force)) {
2531         return NULL;
2532     }
2533     if (send == 0 && recv == 0) {
2534         send = 1;
2535         recv = 1;
2536     }
2537 
2538     // XXX Handle force is True.
2539     // XXX Fix implicit release.
2540 
2541     if (_channel_drop(&_globals.channels, cid, send, recv) != 0) {
2542         return NULL;
2543     }
2544     Py_RETURN_NONE;
2545 }
2546 
2547 PyDoc_STRVAR(channel_release_doc,
2548 "channel_release(cid, *, send=None, recv=None, force=True)\n\
2549 \n\
2550 Close the channel for the current interpreter.  'send' and 'recv'\n\
2551 (bool) may be used to indicate the ends to close.  By default both\n\
2552 ends are closed.  Closing an already closed end is a noop.");
2553 
2554 static PyObject *
channel__channel_id(PyObject * self,PyObject * args,PyObject * kwds)2555 channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
2556 {
2557     return channelid_new(&ChannelIDtype, args, kwds);
2558 }
2559 
2560 static PyMethodDef module_functions[] = {
2561     {"create",                    (PyCFunction)(void(*)(void))interp_create,
2562      METH_VARARGS | METH_KEYWORDS, create_doc},
2563     {"destroy",                   (PyCFunction)(void(*)(void))interp_destroy,
2564      METH_VARARGS | METH_KEYWORDS, destroy_doc},
2565     {"list_all",                  interp_list_all,
2566      METH_NOARGS, list_all_doc},
2567     {"get_current",               interp_get_current,
2568      METH_NOARGS, get_current_doc},
2569     {"get_main",                  interp_get_main,
2570      METH_NOARGS, get_main_doc},
2571     {"is_running",                (PyCFunction)(void(*)(void))interp_is_running,
2572      METH_VARARGS | METH_KEYWORDS, is_running_doc},
2573     {"run_string",                (PyCFunction)(void(*)(void))interp_run_string,
2574      METH_VARARGS | METH_KEYWORDS, run_string_doc},
2575 
2576     {"is_shareable",              (PyCFunction)(void(*)(void))object_is_shareable,
2577      METH_VARARGS | METH_KEYWORDS, is_shareable_doc},
2578 
2579     {"channel_create",            channel_create,
2580      METH_NOARGS, channel_create_doc},
2581     {"channel_destroy",           (PyCFunction)(void(*)(void))channel_destroy,
2582      METH_VARARGS | METH_KEYWORDS, channel_destroy_doc},
2583     {"channel_list_all",          channel_list_all,
2584      METH_NOARGS, channel_list_all_doc},
2585     {"channel_list_interpreters", (PyCFunction)(void(*)(void))channel_list_interpreters,
2586      METH_VARARGS | METH_KEYWORDS, channel_list_interpreters_doc},
2587     {"channel_send",              (PyCFunction)(void(*)(void))channel_send,
2588      METH_VARARGS | METH_KEYWORDS, channel_send_doc},
2589     {"channel_recv",              (PyCFunction)(void(*)(void))channel_recv,
2590      METH_VARARGS | METH_KEYWORDS, channel_recv_doc},
2591     {"channel_close",             (PyCFunction)(void(*)(void))channel_close,
2592      METH_VARARGS | METH_KEYWORDS, channel_close_doc},
2593     {"channel_release",           (PyCFunction)(void(*)(void))channel_release,
2594      METH_VARARGS | METH_KEYWORDS, channel_release_doc},
2595     {"_channel_id",               (PyCFunction)(void(*)(void))channel__channel_id,
2596      METH_VARARGS | METH_KEYWORDS, NULL},
2597 
2598     {NULL,                        NULL}           /* sentinel */
2599 };
2600 
2601 
2602 /* initialization function */
2603 
2604 PyDoc_STRVAR(module_doc,
2605 "This module provides primitive operations to manage Python interpreters.\n\
2606 The 'interpreters' module provides a more convenient interface.");
2607 
2608 static struct PyModuleDef interpretersmodule = {
2609     PyModuleDef_HEAD_INIT,
2610     "_xxsubinterpreters",  /* m_name */
2611     module_doc,            /* m_doc */
2612     -1,                    /* m_size */
2613     module_functions,      /* m_methods */
2614     NULL,                  /* m_slots */
2615     NULL,                  /* m_traverse */
2616     NULL,                  /* m_clear */
2617     NULL                   /* m_free */
2618 };
2619 
2620 
2621 PyMODINIT_FUNC
PyInit__xxsubinterpreters(void)2622 PyInit__xxsubinterpreters(void)
2623 {
2624     if (_init_globals() != 0) {
2625         return NULL;
2626     }
2627 
2628     /* Initialize types */
2629     if (PyType_Ready(&ChannelIDtype) != 0) {
2630         return NULL;
2631     }
2632 
2633     /* Create the module */
2634     PyObject *module = PyModule_Create(&interpretersmodule);
2635     if (module == NULL) {
2636         return NULL;
2637     }
2638 
2639     /* Add exception types */
2640     PyObject *ns = PyModule_GetDict(module);  // borrowed
2641     if (interp_exceptions_init(ns) != 0) {
2642         return NULL;
2643     }
2644     if (channel_exceptions_init(ns) != 0) {
2645         return NULL;
2646     }
2647 
2648     /* Add other types */
2649     Py_INCREF(&ChannelIDtype);
2650     if (PyDict_SetItemString(ns, "ChannelID", (PyObject *)&ChannelIDtype) != 0) {
2651         return NULL;
2652     }
2653     Py_INCREF(&_PyInterpreterID_Type);
2654     if (PyDict_SetItemString(ns, "InterpreterID", (PyObject *)&_PyInterpreterID_Type) != 0) {
2655         return NULL;
2656     }
2657 
2658     if (_PyCrossInterpreterData_RegisterClass(&ChannelIDtype, _channelid_shared)) {
2659         return NULL;
2660     }
2661 
2662     return module;
2663 }
2664