• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* interpreters module */
2 /* low-level access to interpreter primitives */
3 
4 #ifndef Py_BUILD_CORE_BUILTIN
5 #  define Py_BUILD_CORE_MODULE 1
6 #endif
7 
8 #include "Python.h"
9 #include "pycore_crossinterp.h"   // struct _xid
10 #include "pycore_interp.h"        // _PyInterpreterState_LookUpID()
11 #include "pycore_pystate.h"       // _PyInterpreterState_GetIDObject()
12 
13 #ifdef MS_WINDOWS
14 #define WIN32_LEAN_AND_MEAN
15 #include <windows.h>        // SwitchToThread()
16 #elif defined(HAVE_SCHED_H)
17 #include <sched.h>          // sched_yield()
18 #endif
19 
20 #define REGISTERS_HEAP_TYPES
21 #define HAS_UNBOUND_ITEMS
22 #include "_interpreters_common.h"
23 #undef HAS_UNBOUND_ITEMS
24 #undef REGISTERS_HEAP_TYPES
25 
26 
27 /*
28 This module has the following process-global state:
29 
30 _globals (static struct globals):
31     mutex (PyMutex)
32     module_count (int)
33     channels (struct _channels):
34         numopen (int64_t)
35         next_id (int64_t)
36         mutex (PyThread_type_lock)
37         head (linked list of struct _channelref *):
38             cid (int64_t)
39             objcount (Py_ssize_t)
40             next (struct _channelref *):
41                 ...
42             chan (struct _channel *):
43                 open (int)
44                 mutex (PyThread_type_lock)
45                 closing (struct _channel_closing *):
46                     ref (struct _channelref *):
47                         ...
48                 ends (struct _channelends *):
49                     numsendopen (int64_t)
50                     numrecvopen (int64_t)
51                     send (struct _channelend *):
52                         interpid (int64_t)
53                         open (int)
54                         next (struct _channelend *)
55                     recv (struct _channelend *):
56                         ...
57                 queue (struct _channelqueue *):
58                     count (int64_t)
59                     first (struct _channelitem *):
60                         next (struct _channelitem *):
61                             ...
62                         data (_PyCrossInterpreterData *):
63                             data (void *)
64                             obj (PyObject *)
65                             interpid (int64_t)
66                             new_object (xid_newobjectfunc)
67                             free (xid_freefunc)
68                     last (struct _channelitem *):
69                         ...
70 
71 The above state includes the following allocations by the module:
72 
73 * 1 top-level mutex (to protect the rest of the state)
74 * for each channel:
75    * 1 struct _channelref
76    * 1 struct _channel
77    * 0-1 struct _channel_closing
78    * 1 struct _channelends
79    * 2 struct _channelend
80    * 1 struct _channelqueue
81 * for each item in each channel:
82    * 1 struct _channelitem
83    * 1 _PyCrossInterpreterData
84 
85 The only objects in that global state are the references held by each
86 channel's queue, which are safely managed via the _PyCrossInterpreterData_*()
87 API.  The module does not create any objects that are shared globally.
88 */
89 
#define MODULE_NAME _interpchannels
#define MODULE_NAME_STR Py_STRINGIFY(MODULE_NAME)
#define MODINIT_FUNC_NAME RESOLVE_MODINIT_FUNC_NAME(MODULE_NAME)

/* Allocate / free one object of the process-global channel state
   (uses the raw allocator, PyMem_RawMalloc()/PyMem_RawFree()). */
#define GLOBAL_MALLOC(TYPE) PyMem_RawMalloc(sizeof(TYPE))
#define GLOBAL_FREE(VAR) PyMem_RawFree(VAR)

/* Bit flags for _release_xid_data(); combine them with '|'. */
#define XID_IGNORE_EXC (1 << 0)
#define XID_FREE (1 << 1)
103 
104 static int
_release_xid_data(_PyCrossInterpreterData * data,int flags)105 _release_xid_data(_PyCrossInterpreterData *data, int flags)
106 {
107     int ignoreexc = flags & XID_IGNORE_EXC;
108     PyObject *exc;
109     if (ignoreexc) {
110         exc = PyErr_GetRaisedException();
111     }
112     int res;
113     if (flags & XID_FREE) {
114         res = _PyCrossInterpreterData_ReleaseAndRawFree(data);
115     }
116     else {
117         res = _PyCrossInterpreterData_Release(data);
118     }
119     if (res < 0) {
120         /* The owning interpreter is already destroyed. */
121         if (ignoreexc) {
122             // XXX Emit a warning?
123             PyErr_Clear();
124         }
125     }
126     if (flags & XID_FREE) {
127         /* Either way, we free the data. */
128     }
129     if (ignoreexc) {
130         PyErr_SetRaisedException(exc);
131     }
132     return res;
133 }
134 
135 
136 static PyInterpreterState *
_get_current_interp(void)137 _get_current_interp(void)
138 {
139     // PyInterpreterState_Get() aborts if lookup fails, so don't need
140     // to check the result for NULL.
141     return PyInterpreterState_Get();
142 }
143 
144 static PyObject *
_get_current_module(void)145 _get_current_module(void)
146 {
147     PyObject *name = PyUnicode_FromString(MODULE_NAME_STR);
148     if (name == NULL) {
149         return NULL;
150     }
151     PyObject *mod = PyImport_GetModule(name);
152     Py_DECREF(name);
153     if (mod == NULL) {
154         return NULL;
155     }
156     assert(mod != Py_None);
157     return mod;
158 }
159 
160 static PyObject *
get_module_from_owned_type(PyTypeObject * cls)161 get_module_from_owned_type(PyTypeObject *cls)
162 {
163     assert(cls != NULL);
164     return _get_current_module();
165     // XXX Use the more efficient API now that we use heap types:
166     //return PyType_GetModule(cls);
167 }
168 
169 static struct PyModuleDef moduledef;
170 
171 static PyObject *
get_module_from_type(PyTypeObject * cls)172 get_module_from_type(PyTypeObject *cls)
173 {
174     assert(cls != NULL);
175     return _get_current_module();
176     // XXX Use the more efficient API now that we use heap types:
177     //return PyType_GetModuleByDef(cls, &moduledef);
178 }
179 
180 static PyObject *
add_new_exception(PyObject * mod,const char * name,PyObject * base)181 add_new_exception(PyObject *mod, const char *name, PyObject *base)
182 {
183     assert(!PyObject_HasAttrStringWithError(mod, name));
184     PyObject *exctype = PyErr_NewException(name, base, NULL);
185     if (exctype == NULL) {
186         return NULL;
187     }
188     int res = PyModule_AddType(mod, (PyTypeObject *)exctype);
189     if (res < 0) {
190         Py_DECREF(exctype);
191         return NULL;
192     }
193     return exctype;
194 }
195 
196 #define ADD_NEW_EXCEPTION(MOD, NAME, BASE) \
197     add_new_exception(MOD, MODULE_NAME_STR "." Py_STRINGIFY(NAME), BASE)
198 
199 static int
wait_for_lock(PyThread_type_lock mutex,PY_TIMEOUT_T timeout)200 wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
201 {
202     PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout);
203     if (res == PY_LOCK_INTR) {
204         /* KeyboardInterrupt, etc. */
205         assert(PyErr_Occurred());
206         return -1;
207     }
208     else if (res == PY_LOCK_FAILURE) {
209         assert(!PyErr_Occurred());
210         assert(timeout > 0);
211         PyErr_SetString(PyExc_TimeoutError, "timed out");
212         return -1;
213     }
214     assert(res == PY_LOCK_ACQUIRED);
215     PyThread_release_lock(mutex);
216     return 0;
217 }
218 
219 
220 /* module state *************************************************************/
221 
222 typedef struct {
223     /* Added at runtime by interpreters module. */
224     PyTypeObject *send_channel_type;
225     PyTypeObject *recv_channel_type;
226 
227     /* heap types */
228     PyTypeObject *ChannelInfoType;
229     PyTypeObject *ChannelIDType;
230 
231     /* exceptions */
232     PyObject *ChannelError;
233     PyObject *ChannelNotFoundError;
234     PyObject *ChannelClosedError;
235     PyObject *ChannelEmptyError;
236     PyObject *ChannelNotEmptyError;
237 } module_state;
238 
239 static inline module_state *
get_module_state(PyObject * mod)240 get_module_state(PyObject *mod)
241 {
242     assert(mod != NULL);
243     module_state *state = PyModule_GetState(mod);
244     assert(state != NULL);
245     return state;
246 }
247 
248 static module_state *
_get_current_module_state(void)249 _get_current_module_state(void)
250 {
251     PyObject *mod = _get_current_module();
252     if (mod == NULL) {
253         // XXX import it?
254         PyErr_SetString(PyExc_RuntimeError,
255                         MODULE_NAME_STR " module not imported yet");
256         return NULL;
257     }
258     module_state *state = get_module_state(mod);
259     Py_DECREF(mod);
260     return state;
261 }
262 
/* GC support: call visit on every object reference held in state.
   Returns 0, or the first non-zero value returned by visit (Py_VISIT
   returns early in that case). */
static int
traverse_module_state(module_state *state, visitproc visit, void *arg)
{
    /* external types */
    Py_VISIT(state->send_channel_type);
    Py_VISIT(state->recv_channel_type);

    /* heap types */
    Py_VISIT(state->ChannelInfoType);
    Py_VISIT(state->ChannelIDType);

    /* exceptions */
    Py_VISIT(state->ChannelError);
    Py_VISIT(state->ChannelNotFoundError);
    Py_VISIT(state->ChannelClosedError);
    Py_VISIT(state->ChannelEmptyError);
    Py_VISIT(state->ChannelNotEmptyError);

    return 0;
}
283 
284 static void
clear_xid_types(module_state * state)285 clear_xid_types(module_state *state)
286 {
287     /* external types */
288     if (state->send_channel_type != NULL) {
289         (void)clear_xid_class(state->send_channel_type);
290         Py_CLEAR(state->send_channel_type);
291     }
292     if (state->recv_channel_type != NULL) {
293         (void)clear_xid_class(state->recv_channel_type);
294         Py_CLEAR(state->recv_channel_type);
295     }
296 
297     /* heap types */
298     if (state->ChannelIDType != NULL) {
299         (void)clear_xid_class(state->ChannelIDType);
300         Py_CLEAR(state->ChannelIDType);
301     }
302 }
303 
304 static int
clear_module_state(module_state * state)305 clear_module_state(module_state *state)
306 {
307     clear_xid_types(state);
308 
309     /* heap types */
310     Py_CLEAR(state->ChannelInfoType);
311 
312     /* exceptions */
313     Py_CLEAR(state->ChannelError);
314     Py_CLEAR(state->ChannelNotFoundError);
315     Py_CLEAR(state->ChannelClosedError);
316     Py_CLEAR(state->ChannelEmptyError);
317     Py_CLEAR(state->ChannelNotEmptyError);
318 
319     return 0;
320 }
321 
322 
323 /* channel-specific code ****************************************************/
324 
/* Which end(s) of a channel an operation applies to. */
#define CHANNEL_SEND (1)
#define CHANNEL_BOTH (0)
#define CHANNEL_RECV (-1)


/* Internal channel error codes (all negative); handle_channel_error()
   turns them into Python exceptions. */

#define ERR_CHANNEL_NOT_FOUND (-2)
#define ERR_CHANNEL_CLOSED (-3)
#define ERR_CHANNEL_INTERP_CLOSED (-4)
#define ERR_CHANNEL_EMPTY (-5)
#define ERR_CHANNEL_NOT_EMPTY (-6)
#define ERR_CHANNEL_MUTEX_INIT (-7)
#define ERR_CHANNELS_MUTEX_INIT (-8)
#define ERR_NO_NEXT_CHANNEL_ID (-9)
#define ERR_CHANNEL_CLOSED_WAITING (-10)
341 
342 static int
exceptions_init(PyObject * mod)343 exceptions_init(PyObject *mod)
344 {
345     module_state *state = get_module_state(mod);
346     if (state == NULL) {
347         return -1;
348     }
349 
350 #define ADD(NAME, BASE) \
351     do { \
352         assert(state->NAME == NULL); \
353         state->NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \
354         if (state->NAME == NULL) { \
355             return -1; \
356         } \
357     } while (0)
358 
359     // A channel-related operation failed.
360     ADD(ChannelError, PyExc_RuntimeError);
361     // An operation tried to use a channel that doesn't exist.
362     ADD(ChannelNotFoundError, state->ChannelError);
363     // An operation tried to use a closed channel.
364     ADD(ChannelClosedError, state->ChannelError);
365     // An operation tried to pop from an empty channel.
366     ADD(ChannelEmptyError, state->ChannelError);
367     // An operation tried to close a non-empty channel.
368     ADD(ChannelNotEmptyError, state->ChannelError);
369 #undef ADD
370 
371     return 0;
372 }
373 
374 static int
handle_channel_error(int err,PyObject * mod,int64_t cid)375 handle_channel_error(int err, PyObject *mod, int64_t cid)
376 {
377     if (err == 0) {
378         assert(!PyErr_Occurred());
379         return 0;
380     }
381     assert(err < 0);
382     module_state *state = get_module_state(mod);
383     assert(state != NULL);
384     if (err == ERR_CHANNEL_NOT_FOUND) {
385         PyErr_Format(state->ChannelNotFoundError,
386                      "channel %" PRId64 " not found", cid);
387     }
388     else if (err == ERR_CHANNEL_CLOSED) {
389         PyErr_Format(state->ChannelClosedError,
390                      "channel %" PRId64 " is closed", cid);
391     }
392     else if (err == ERR_CHANNEL_CLOSED_WAITING) {
393         PyErr_Format(state->ChannelClosedError,
394                      "channel %" PRId64 " has closed", cid);
395     }
396     else if (err == ERR_CHANNEL_INTERP_CLOSED) {
397         PyErr_Format(state->ChannelClosedError,
398                      "channel %" PRId64 " is already closed", cid);
399     }
400     else if (err == ERR_CHANNEL_EMPTY) {
401         PyErr_Format(state->ChannelEmptyError,
402                      "channel %" PRId64 " is empty", cid);
403     }
404     else if (err == ERR_CHANNEL_NOT_EMPTY) {
405         PyErr_Format(state->ChannelNotEmptyError,
406                      "channel %" PRId64 " may not be closed "
407                      "if not empty (try force=True)",
408                      cid);
409     }
410     else if (err == ERR_CHANNEL_MUTEX_INIT) {
411         PyErr_SetString(state->ChannelError,
412                         "can't initialize mutex for new channel");
413     }
414     else if (err == ERR_CHANNELS_MUTEX_INIT) {
415         PyErr_SetString(state->ChannelError,
416                         "can't initialize mutex for channel management");
417     }
418     else if (err == ERR_NO_NEXT_CHANNEL_ID) {
419         PyErr_SetString(state->ChannelError,
420                         "failed to get a channel ID");
421     }
422     else {
423         assert(PyErr_Occurred());
424     }
425     return 1;
426 }
427 
428 
429 /* the channel queue */
430 
431 typedef uintptr_t _channelitem_id_t;
432 
433 typedef struct wait_info {
434     PyThread_type_lock mutex;
435     enum {
436         WAITING_NO_STATUS = 0,
437         WAITING_ACQUIRED = 1,
438         WAITING_RELEASING = 2,
439         WAITING_RELEASED = 3,
440     } status;
441     int received;
442     _channelitem_id_t itemid;
443 } _waiting_t;
444 
445 static int
_waiting_init(_waiting_t * waiting)446 _waiting_init(_waiting_t *waiting)
447 {
448     PyThread_type_lock mutex = PyThread_allocate_lock();
449     if (mutex == NULL) {
450         PyErr_NoMemory();
451         return -1;
452     }
453 
454     *waiting = (_waiting_t){
455         .mutex = mutex,
456         .status = WAITING_NO_STATUS,
457     };
458     return 0;
459 }
460 
461 static void
_waiting_clear(_waiting_t * waiting)462 _waiting_clear(_waiting_t *waiting)
463 {
464     assert(waiting->status != WAITING_ACQUIRED
465            && waiting->status != WAITING_RELEASING);
466     if (waiting->mutex != NULL) {
467         PyThread_free_lock(waiting->mutex);
468         waiting->mutex = NULL;
469     }
470 }
471 
472 static _channelitem_id_t
_waiting_get_itemid(_waiting_t * waiting)473 _waiting_get_itemid(_waiting_t *waiting)
474 {
475     return waiting->itemid;
476 }
477 
478 static void
_waiting_acquire(_waiting_t * waiting)479 _waiting_acquire(_waiting_t *waiting)
480 {
481     assert(waiting->status == WAITING_NO_STATUS);
482     PyThread_acquire_lock(waiting->mutex, NOWAIT_LOCK);
483     waiting->status = WAITING_ACQUIRED;
484 }
485 
/* Release the lock taken by _waiting_acquire() and record whether the item
   was received.

   status moves ACQUIRED -> RELEASING -> RELEASED; _waiting_finish_releasing()
   spins while it is RELEASING.  received may only change 0 -> 1 (asserted). */
static void
_waiting_release(_waiting_t *waiting, int received)
{
    assert(waiting->mutex != NULL);
    assert(waiting->status == WAITING_ACQUIRED);
    assert(!waiting->received);

    waiting->status = WAITING_RELEASING;
    PyThread_release_lock(waiting->mutex);
    // NOTE(review): received is written after the lock is released, so a
    // woken thread presumably must call _waiting_finish_releasing() before
    // reading it.
    if (waiting->received != received) {
        assert(received == 1);
        waiting->received = received;
    }
    waiting->status = WAITING_RELEASED;
}
501 
/* Spin, yielding the CPU where a yield primitive is available, until a
   concurrent _waiting_release() has moved status past WAITING_RELEASING. */
static void
_waiting_finish_releasing(_waiting_t *waiting)
{
    // NOTE(review): status is a plain, non-atomic field.  If neither yield
    // branch below is compiled in, the loop body is empty and the compiler
    // could hoist the load; consider an atomic load here.
    while (waiting->status == WAITING_RELEASING) {
#ifdef MS_WINDOWS
        SwitchToThread();
#elif defined(HAVE_SCHED_H)
        sched_yield();
#endif
    }
}
513 
514 struct _channelitem;
515 
516 typedef struct _channelitem {
517     /* The interpreter that added the item to the queue.
518        The actual bound interpid is found in item->data.
519        This is necessary because item->data might be NULL,
520        meaning the interpreter has been destroyed. */
521     int64_t interpid;
522     _PyCrossInterpreterData *data;
523     _waiting_t *waiting;
524     int unboundop;
525     struct _channelitem *next;
526 } _channelitem;
527 
528 static inline _channelitem_id_t
_channelitem_ID(_channelitem * item)529 _channelitem_ID(_channelitem *item)
530 {
531     return (_channelitem_id_t)item;
532 }
533 
534 static void
_channelitem_init(_channelitem * item,int64_t interpid,_PyCrossInterpreterData * data,_waiting_t * waiting,int unboundop)535 _channelitem_init(_channelitem *item,
536                   int64_t interpid, _PyCrossInterpreterData *data,
537                   _waiting_t *waiting, int unboundop)
538 {
539     if (interpid < 0) {
540         interpid = _get_interpid(data);
541     }
542     else {
543         assert(data == NULL
544                || _PyCrossInterpreterData_INTERPID(data) < 0
545                || interpid == _PyCrossInterpreterData_INTERPID(data));
546     }
547     *item = (_channelitem){
548         .interpid = interpid,
549         .data = data,
550         .waiting = waiting,
551         .unboundop = unboundop,
552     };
553     if (waiting != NULL) {
554         waiting->itemid = _channelitem_ID(item);
555     }
556 }
557 
558 static void
_channelitem_clear_data(_channelitem * item,int removed)559 _channelitem_clear_data(_channelitem *item, int removed)
560 {
561     if (item->data != NULL) {
562         // It was allocated in channel_send().
563         (void)_release_xid_data(item->data, XID_IGNORE_EXC & XID_FREE);
564         item->data = NULL;
565     }
566 
567     if (item->waiting != NULL && removed) {
568         if (item->waiting->status == WAITING_ACQUIRED) {
569             _waiting_release(item->waiting, 0);
570         }
571         item->waiting = NULL;
572     }
573 }
574 
575 static void
_channelitem_clear(_channelitem * item)576 _channelitem_clear(_channelitem *item)
577 {
578     item->next = NULL;
579     _channelitem_clear_data(item, 1);
580 }
581 
582 static _channelitem *
_channelitem_new(int64_t interpid,_PyCrossInterpreterData * data,_waiting_t * waiting,int unboundop)583 _channelitem_new(int64_t interpid, _PyCrossInterpreterData *data,
584                  _waiting_t *waiting, int unboundop)
585 {
586     _channelitem *item = GLOBAL_MALLOC(_channelitem);
587     if (item == NULL) {
588         PyErr_NoMemory();
589         return NULL;
590     }
591     _channelitem_init(item, interpid, data, waiting, unboundop);
592     return item;
593 }
594 
595 static void
_channelitem_free(_channelitem * item)596 _channelitem_free(_channelitem *item)
597 {
598     _channelitem_clear(item);
599     GLOBAL_FREE(item);
600 }
601 
602 static void
_channelitem_free_all(_channelitem * item)603 _channelitem_free_all(_channelitem *item)
604 {
605     while (item != NULL) {
606         _channelitem *last = item;
607         item = item->next;
608         _channelitem_free(last);
609     }
610 }
611 
612 static void
_channelitem_popped(_channelitem * item,_PyCrossInterpreterData ** p_data,_waiting_t ** p_waiting,int * p_unboundop)613 _channelitem_popped(_channelitem *item,
614                     _PyCrossInterpreterData **p_data, _waiting_t **p_waiting,
615                     int *p_unboundop)
616 {
617     assert(item->waiting == NULL || item->waiting->status == WAITING_ACQUIRED);
618     *p_data = item->data;
619     *p_waiting = item->waiting;
620     *p_unboundop = item->unboundop;
621     // We clear them here, so they won't be released in _channelitem_clear().
622     item->data = NULL;
623     item->waiting = NULL;
624     _channelitem_free(item);
625 }
626 
627 static int
_channelitem_clear_interpreter(_channelitem * item)628 _channelitem_clear_interpreter(_channelitem *item)
629 {
630     assert(item->interpid >= 0);
631     if (item->data == NULL) {
632         // Its interpreter was already cleared (or it was never bound).
633         // For UNBOUND_REMOVE it should have been freed at that time.
634         assert(item->unboundop != UNBOUND_REMOVE);
635         return 0;
636     }
637     assert(_PyCrossInterpreterData_INTERPID(item->data) == item->interpid);
638 
639     switch (item->unboundop) {
640     case UNBOUND_REMOVE:
641         // The caller must free/clear it.
642         return 1;
643     case UNBOUND_ERROR:
644     case UNBOUND_REPLACE:
645         // We won't need the cross-interpreter data later
646         // so we completely throw it away.
647         _channelitem_clear_data(item, 0);
648         return 0;
649     default:
650         Py_FatalError("not reachable");
651         return -1;
652     }
653 }
654 
655 
656 typedef struct _channelqueue {
657     int64_t count;
658     _channelitem *first;
659     _channelitem *last;
660 } _channelqueue;
661 
662 static _channelqueue *
_channelqueue_new(void)663 _channelqueue_new(void)
664 {
665     _channelqueue *queue = GLOBAL_MALLOC(_channelqueue);
666     if (queue == NULL) {
667         PyErr_NoMemory();
668         return NULL;
669     }
670     queue->count = 0;
671     queue->first = NULL;
672     queue->last = NULL;
673     return queue;
674 }
675 
676 static void
_channelqueue_clear(_channelqueue * queue)677 _channelqueue_clear(_channelqueue *queue)
678 {
679     _channelitem_free_all(queue->first);
680     queue->count = 0;
681     queue->first = NULL;
682     queue->last = NULL;
683 }
684 
685 static void
_channelqueue_free(_channelqueue * queue)686 _channelqueue_free(_channelqueue *queue)
687 {
688     _channelqueue_clear(queue);
689     GLOBAL_FREE(queue);
690 }
691 
692 static int
_channelqueue_put(_channelqueue * queue,int64_t interpid,_PyCrossInterpreterData * data,_waiting_t * waiting,int unboundop)693 _channelqueue_put(_channelqueue *queue,
694                   int64_t interpid, _PyCrossInterpreterData *data,
695                   _waiting_t *waiting, int unboundop)
696 {
697     _channelitem *item = _channelitem_new(interpid, data, waiting, unboundop);
698     if (item == NULL) {
699         return -1;
700     }
701 
702     queue->count += 1;
703     if (queue->first == NULL) {
704         queue->first = item;
705     }
706     else {
707         queue->last->next = item;
708     }
709     queue->last = item;
710 
711     if (waiting != NULL) {
712         _waiting_acquire(waiting);
713     }
714 
715     return 0;
716 }
717 
718 static int
_channelqueue_get(_channelqueue * queue,_PyCrossInterpreterData ** p_data,_waiting_t ** p_waiting,int * p_unboundop)719 _channelqueue_get(_channelqueue *queue,
720                   _PyCrossInterpreterData **p_data, _waiting_t **p_waiting,
721                   int *p_unboundop)
722 {
723     _channelitem *item = queue->first;
724     if (item == NULL) {
725         return ERR_CHANNEL_EMPTY;
726     }
727     queue->first = item->next;
728     if (queue->last == item) {
729         queue->last = NULL;
730     }
731     queue->count -= 1;
732 
733     _channelitem_popped(item, p_data, p_waiting, p_unboundop);
734     return 0;
735 }
736 
737 static int
_channelqueue_find(_channelqueue * queue,_channelitem_id_t itemid,_channelitem ** p_item,_channelitem ** p_prev)738 _channelqueue_find(_channelqueue *queue, _channelitem_id_t itemid,
739                    _channelitem **p_item, _channelitem **p_prev)
740 {
741     _channelitem *prev = NULL;
742     _channelitem *item = NULL;
743     if (queue->first != NULL) {
744         if (_channelitem_ID(queue->first) == itemid) {
745             item = queue->first;
746         }
747         else {
748             prev = queue->first;
749             while (prev->next != NULL) {
750                 if (_channelitem_ID(prev->next) == itemid) {
751                     item = prev->next;
752                     break;
753                 }
754                 prev = prev->next;
755             }
756             if (item == NULL) {
757                 prev = NULL;
758             }
759         }
760     }
761     if (p_item != NULL) {
762         *p_item = item;
763     }
764     if (p_prev != NULL) {
765         *p_prev = prev;
766     }
767     return (item != NULL);
768 }
769 
770 static void
_channelqueue_remove(_channelqueue * queue,_channelitem_id_t itemid,_PyCrossInterpreterData ** p_data,_waiting_t ** p_waiting)771 _channelqueue_remove(_channelqueue *queue, _channelitem_id_t itemid,
772                      _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
773 {
774     _channelitem *prev = NULL;
775     _channelitem *item = NULL;
776     int found = _channelqueue_find(queue, itemid, &item, &prev);
777     if (!found) {
778         return;
779     }
780 
781     assert(item->waiting != NULL);
782     assert(!item->waiting->received);
783     if (prev == NULL) {
784         assert(queue->first == item);
785         queue->first = item->next;
786     }
787     else {
788         assert(queue->first != item);
789         assert(prev->next == item);
790         prev->next = item->next;
791     }
792     item->next = NULL;
793 
794     if (queue->last == item) {
795         queue->last = prev;
796     }
797     queue->count -= 1;
798 
799     int unboundop;
800     _channelitem_popped(item, p_data, p_waiting, &unboundop);
801 }
802 
803 static void
_channelqueue_clear_interpreter(_channelqueue * queue,int64_t interpid)804 _channelqueue_clear_interpreter(_channelqueue *queue, int64_t interpid)
805 {
806     _channelitem *prev = NULL;
807     _channelitem *next = queue->first;
808     while (next != NULL) {
809         _channelitem *item = next;
810         next = item->next;
811         int remove = (item->interpid == interpid)
812             ? _channelitem_clear_interpreter(item)
813             : 0;
814         if (remove) {
815             _channelitem_free(item);
816             if (prev == NULL) {
817                 queue->first = next;
818             }
819             else {
820                 prev->next = next;
821             }
822             queue->count -= 1;
823         }
824         else {
825             prev = item;
826         }
827     }
828 }
829 
830 
831 /* channel-interpreter associations */
832 
struct _channelend;

/* One interpreter's association with one end (send or recv) of a channel.
   Entries for the same end form a singly linked list. */
typedef struct _channelend {
    struct _channelend *next;
    int64_t interpid;
    int open;   // 0 once the association has been released
} _channelend;
840 
841 static _channelend *
_channelend_new(int64_t interpid)842 _channelend_new(int64_t interpid)
843 {
844     _channelend *end = GLOBAL_MALLOC(_channelend);
845     if (end == NULL) {
846         PyErr_NoMemory();
847         return NULL;
848     }
849     end->next = NULL;
850     end->interpid = interpid;
851     end->open = 1;
852     return end;
853 }
854 
855 static void
_channelend_free(_channelend * end)856 _channelend_free(_channelend *end)
857 {
858     GLOBAL_FREE(end);
859 }
860 
861 static void
_channelend_free_all(_channelend * end)862 _channelend_free_all(_channelend *end)
863 {
864     while (end != NULL) {
865         _channelend *last = end;
866         end = end->next;
867         _channelend_free(last);
868     }
869 }
870 
871 static _channelend *
_channelend_find(_channelend * first,int64_t interpid,_channelend ** pprev)872 _channelend_find(_channelend *first, int64_t interpid, _channelend **pprev)
873 {
874     _channelend *prev = NULL;
875     _channelend *end = first;
876     while (end != NULL) {
877         if (end->interpid == interpid) {
878             break;
879         }
880         prev = end;
881         end = end->next;
882     }
883     if (pprev != NULL) {
884         *pprev = prev;
885     }
886     return end;
887 }
888 
889 typedef struct _channelassociations {
890     // Note that the list entries are never removed for interpreter
891     // for which the channel is closed.  This should not be a problem in
892     // practice.  Also, a channel isn't automatically closed when an
893     // interpreter is destroyed.
894     int64_t numsendopen;
895     int64_t numrecvopen;
896     _channelend *send;
897     _channelend *recv;
898 } _channelends;
899 
900 static _channelends *
_channelends_new(void)901 _channelends_new(void)
902 {
903     _channelends *ends = GLOBAL_MALLOC(_channelends);
904     if (ends== NULL) {
905         return NULL;
906     }
907     ends->numsendopen = 0;
908     ends->numrecvopen = 0;
909     ends->send = NULL;
910     ends->recv = NULL;
911     return ends;
912 }
913 
914 static void
_channelends_clear(_channelends * ends)915 _channelends_clear(_channelends *ends)
916 {
917     _channelend_free_all(ends->send);
918     ends->send = NULL;
919     ends->numsendopen = 0;
920 
921     _channelend_free_all(ends->recv);
922     ends->recv = NULL;
923     ends->numrecvopen = 0;
924 }
925 
926 static void
_channelends_free(_channelends * ends)927 _channelends_free(_channelends *ends)
928 {
929     _channelends_clear(ends);
930     GLOBAL_FREE(ends);
931 }
932 
933 static _channelend *
_channelends_add(_channelends * ends,_channelend * prev,int64_t interpid,int send)934 _channelends_add(_channelends *ends, _channelend *prev, int64_t interpid,
935                  int send)
936 {
937     _channelend *end = _channelend_new(interpid);
938     if (end == NULL) {
939         return NULL;
940     }
941 
942     if (prev == NULL) {
943         if (send) {
944             ends->send = end;
945         }
946         else {
947             ends->recv = end;
948         }
949     }
950     else {
951         prev->next = end;
952     }
953     if (send) {
954         ends->numsendopen += 1;
955     }
956     else {
957         ends->numrecvopen += 1;
958     }
959     return end;
960 }
961 
962 static int
_channelends_associate(_channelends * ends,int64_t interpid,int send)963 _channelends_associate(_channelends *ends, int64_t interpid, int send)
964 {
965     _channelend *prev;
966     _channelend *end = _channelend_find(send ? ends->send : ends->recv,
967                                         interpid, &prev);
968     if (end != NULL) {
969         if (!end->open) {
970             return ERR_CHANNEL_CLOSED;
971         }
972         // already associated
973         return 0;
974     }
975     if (_channelends_add(ends, prev, interpid, send) == NULL) {
976         return -1;
977     }
978     return 0;
979 }
980 
981 static int
_channelends_is_open(_channelends * ends)982 _channelends_is_open(_channelends *ends)
983 {
984     if (ends->numsendopen != 0 || ends->numrecvopen != 0) {
985         // At least one interpreter is still associated with the channel
986         // (and hasn't been released).
987         return 1;
988     }
989     // XXX This is wrong if an end can ever be removed.
990     if (ends->send == NULL && ends->recv == NULL) {
991         // The channel has never had any interpreters associated with it.
992         return 1;
993     }
994     return 0;
995 }
996 
997 static void
_channelends_release_end(_channelends * ends,_channelend * end,int send)998 _channelends_release_end(_channelends *ends, _channelend *end, int send)
999 {
1000     end->open = 0;
1001     if (send) {
1002         ends->numsendopen -= 1;
1003     }
1004     else {
1005         ends->numrecvopen -= 1;
1006     }
1007 }
1008 
1009 static int
_channelends_release_interpreter(_channelends * ends,int64_t interpid,int which)1010 _channelends_release_interpreter(_channelends *ends, int64_t interpid, int which)
1011 {
1012     _channelend *prev;
1013     _channelend *end;
1014     if (which >= 0) {  // send/both
1015         end = _channelend_find(ends->send, interpid, &prev);
1016         if (end == NULL) {
1017             // never associated so add it
1018             end = _channelends_add(ends, prev, interpid, 1);
1019             if (end == NULL) {
1020                 return -1;
1021             }
1022         }
1023         _channelends_release_end(ends, end, 1);
1024     }
1025     if (which <= 0) {  // recv/both
1026         end = _channelend_find(ends->recv, interpid, &prev);
1027         if (end == NULL) {
1028             // never associated so add it
1029             end = _channelends_add(ends, prev, interpid, 0);
1030             if (end == NULL) {
1031                 return -1;
1032             }
1033         }
1034         _channelends_release_end(ends, end, 0);
1035     }
1036     return 0;
1037 }
1038 
1039 static void
_channelends_release_all(_channelends * ends,int which,int force)1040 _channelends_release_all(_channelends *ends, int which, int force)
1041 {
1042     // XXX Handle the ends.
1043     // XXX Handle force is True.
1044 
1045     // Ensure all the "send"-associated interpreters are closed.
1046     _channelend *end;
1047     for (end = ends->send; end != NULL; end = end->next) {
1048         _channelends_release_end(ends, end, 1);
1049     }
1050 
1051     // Ensure all the "recv"-associated interpreters are closed.
1052     for (end = ends->recv; end != NULL; end = end->next) {
1053         _channelends_release_end(ends, end, 0);
1054     }
1055 }
1056 
1057 static void
_channelends_clear_interpreter(_channelends * ends,int64_t interpid)1058 _channelends_clear_interpreter(_channelends *ends, int64_t interpid)
1059 {
1060     // XXX Actually remove the entries?
1061     _channelend *end;
1062     end = _channelend_find(ends->send, interpid, NULL);
1063     if (end != NULL) {
1064         _channelends_release_end(ends, end, 1);
1065     }
1066     end = _channelend_find(ends->recv, interpid, NULL);
1067     if (end != NULL) {
1068         _channelends_release_end(ends, end, 0);
1069     }
1070 }
1071 
1072 
/* each channel's state */

struct _channel;
struct _channel_closing;
static void _channel_clear_closing(struct _channel *);
static void _channel_finish_closing(struct _channel *);

/* Per-channel state, allocated by _channel_new() and freed by
   _channel_free(). */
typedef struct _channel {
    // Per-channel lock; owned by the channel (freed in _channel_free()).
    PyThread_type_lock mutex;
    // Items that were sent but not yet received.
    _channelqueue *queue;
    // Which interpreters are associated with the send/recv ends.
    _channelends *ends;
    struct {
        // Default passed to _channel_new(); reported by _channels_list_all().
        int unboundop;
    } defaults;
    // 1 while usable.  Set to 0 by _channel_release_all(), by _channel_next()
    // once a "closing" channel is drained, or by recomputation from the ends.
    int open;
    // Non-NULL after a close was requested while items were still queued
    // (see _channel_set_closing()); the channel is freed once drained.
    struct _channel_closing *closing;
} _channel_state;
1090 
1091 static _channel_state *
_channel_new(PyThread_type_lock mutex,int unboundop)1092 _channel_new(PyThread_type_lock mutex, int unboundop)
1093 {
1094     _channel_state *chan = GLOBAL_MALLOC(_channel_state);
1095     if (chan == NULL) {
1096         return NULL;
1097     }
1098     chan->mutex = mutex;
1099     chan->queue = _channelqueue_new();
1100     if (chan->queue == NULL) {
1101         GLOBAL_FREE(chan);
1102         return NULL;
1103     }
1104     chan->ends = _channelends_new();
1105     if (chan->ends == NULL) {
1106         _channelqueue_free(chan->queue);
1107         GLOBAL_FREE(chan);
1108         return NULL;
1109     }
1110     chan->defaults.unboundop = unboundop;
1111     chan->open = 1;
1112     chan->closing = NULL;
1113     return chan;
1114 }
1115 
1116 static void
_channel_free(_channel_state * chan)1117 _channel_free(_channel_state *chan)
1118 {
1119     _channel_clear_closing(chan);
1120     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1121     _channelqueue_free(chan->queue);
1122     _channelends_free(chan->ends);
1123     PyThread_release_lock(chan->mutex);
1124 
1125     PyThread_free_lock(chan->mutex);
1126     GLOBAL_FREE(chan);
1127 }
1128 
1129 static int
_channel_add(_channel_state * chan,int64_t interpid,_PyCrossInterpreterData * data,_waiting_t * waiting,int unboundop)1130 _channel_add(_channel_state *chan, int64_t interpid,
1131              _PyCrossInterpreterData *data, _waiting_t *waiting,
1132              int unboundop)
1133 {
1134     int res = -1;
1135     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1136 
1137     if (!chan->open) {
1138         res = ERR_CHANNEL_CLOSED;
1139         goto done;
1140     }
1141     if (_channelends_associate(chan->ends, interpid, 1) != 0) {
1142         res = ERR_CHANNEL_INTERP_CLOSED;
1143         goto done;
1144     }
1145 
1146     if (_channelqueue_put(chan->queue, interpid, data, waiting, unboundop) != 0) {
1147         goto done;
1148     }
1149     // Any errors past this point must cause a _waiting_release() call.
1150 
1151     res = 0;
1152 done:
1153     PyThread_release_lock(chan->mutex);
1154     return res;
1155 }
1156 
1157 static int
_channel_next(_channel_state * chan,int64_t interpid,_PyCrossInterpreterData ** p_data,_waiting_t ** p_waiting,int * p_unboundop)1158 _channel_next(_channel_state *chan, int64_t interpid,
1159               _PyCrossInterpreterData **p_data, _waiting_t **p_waiting,
1160               int *p_unboundop)
1161 {
1162     int err = 0;
1163     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1164 
1165     if (!chan->open) {
1166         err = ERR_CHANNEL_CLOSED;
1167         goto done;
1168     }
1169     if (_channelends_associate(chan->ends, interpid, 0) != 0) {
1170         err = ERR_CHANNEL_INTERP_CLOSED;
1171         goto done;
1172     }
1173 
1174     int empty = _channelqueue_get(chan->queue, p_data, p_waiting, p_unboundop);
1175     assert(!PyErr_Occurred());
1176     if (empty) {
1177         assert(empty == ERR_CHANNEL_EMPTY);
1178         if (chan->closing != NULL) {
1179             chan->open = 0;
1180         }
1181         err = ERR_CHANNEL_EMPTY;
1182         goto done;
1183     }
1184 
1185 done:
1186     PyThread_release_lock(chan->mutex);
1187     if (chan->queue->count == 0) {
1188         _channel_finish_closing(chan);
1189     }
1190     return err;
1191 }
1192 
1193 static void
_channel_remove(_channel_state * chan,_channelitem_id_t itemid)1194 _channel_remove(_channel_state *chan, _channelitem_id_t itemid)
1195 {
1196     _PyCrossInterpreterData *data = NULL;
1197     _waiting_t *waiting = NULL;
1198 
1199     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1200     _channelqueue_remove(chan->queue, itemid, &data, &waiting);
1201     PyThread_release_lock(chan->mutex);
1202 
1203     (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE);
1204     if (waiting != NULL) {
1205         _waiting_release(waiting, 0);
1206     }
1207 
1208     if (chan->queue->count == 0) {
1209         _channel_finish_closing(chan);
1210     }
1211 }
1212 
1213 static int
_channel_release_interpreter(_channel_state * chan,int64_t interpid,int end)1214 _channel_release_interpreter(_channel_state *chan, int64_t interpid, int end)
1215 {
1216     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1217 
1218     int res = -1;
1219     if (!chan->open) {
1220         res = ERR_CHANNEL_CLOSED;
1221         goto done;
1222     }
1223 
1224     if (_channelends_release_interpreter(chan->ends, interpid, end) != 0) {
1225         goto done;
1226     }
1227     chan->open = _channelends_is_open(chan->ends);
1228     // XXX Clear the queue if not empty?
1229     // XXX Activate the "closing" mechanism?
1230 
1231     res = 0;
1232 done:
1233     PyThread_release_lock(chan->mutex);
1234     return res;
1235 }
1236 
1237 static int
_channel_release_all(_channel_state * chan,int end,int force)1238 _channel_release_all(_channel_state *chan, int end, int force)
1239 {
1240     int res = -1;
1241     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1242 
1243     if (!chan->open) {
1244         res = ERR_CHANNEL_CLOSED;
1245         goto done;
1246     }
1247 
1248     if (!force && chan->queue->count > 0) {
1249         res = ERR_CHANNEL_NOT_EMPTY;
1250         goto done;
1251     }
1252     // XXX Clear the queue?
1253 
1254     chan->open = 0;
1255 
1256     // We *could* also just leave these in place, since we've marked
1257     // the channel as closed already.
1258     _channelends_release_all(chan->ends, end, force);
1259 
1260     res = 0;
1261 done:
1262     PyThread_release_lock(chan->mutex);
1263     return res;
1264 }
1265 
1266 static void
_channel_clear_interpreter(_channel_state * chan,int64_t interpid)1267 _channel_clear_interpreter(_channel_state *chan, int64_t interpid)
1268 {
1269     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1270 
1271     _channelqueue_clear_interpreter(chan->queue, interpid);
1272     _channelends_clear_interpreter(chan->ends, interpid);
1273     chan->open = _channelends_is_open(chan->ends);
1274 
1275     PyThread_release_lock(chan->mutex);
1276 }
1277 
1278 
/* the set of channels */

struct _channelref;

/* One entry in the global channel list (_channels.head). */
typedef struct _channelref {
    // Channel ID handed out by _channels_next_id().
    int64_t cid;
    // The channel state.  Set to NULL once the channel has been closed and
    // freed (_channels_close(), _channel_finish_closing()); the ref itself
    // stays in the list until it is removed.
    _channel_state *chan;
    // Next entry in the singly-linked list.
    struct _channelref *next;
    // The number of ChannelID objects referring to this channel.
    Py_ssize_t objcount;
} _channelref;
1290 
1291 static _channelref *
_channelref_new(int64_t cid,_channel_state * chan)1292 _channelref_new(int64_t cid, _channel_state *chan)
1293 {
1294     _channelref *ref = GLOBAL_MALLOC(_channelref);
1295     if (ref == NULL) {
1296         return NULL;
1297     }
1298     ref->cid = cid;
1299     ref->chan = chan;
1300     ref->next = NULL;
1301     ref->objcount = 0;
1302     return ref;
1303 }
1304 
1305 //static void
1306 //_channelref_clear(_channelref *ref)
1307 //{
1308 //    ref->cid = -1;
1309 //    ref->chan = NULL;
1310 //    ref->next = NULL;
1311 //    ref->objcount = 0;
1312 //}
1313 
1314 static void
_channelref_free(_channelref * ref)1315 _channelref_free(_channelref *ref)
1316 {
1317     if (ref->chan != NULL) {
1318         _channel_clear_closing(ref->chan);
1319     }
1320     //_channelref_clear(ref);
1321     GLOBAL_FREE(ref);
1322 }
1323 
1324 static _channelref *
_channelref_find(_channelref * first,int64_t cid,_channelref ** pprev)1325 _channelref_find(_channelref *first, int64_t cid, _channelref **pprev)
1326 {
1327     _channelref *prev = NULL;
1328     _channelref *ref = first;
1329     while (ref != NULL) {
1330         if (ref->cid == cid) {
1331             break;
1332         }
1333         prev = ref;
1334         ref = ref->next;
1335     }
1336     if (pprev != NULL) {
1337         *pprev = prev;
1338     }
1339     return ref;
1340 }
1341 
1342 
/* The global registry of channels. */
typedef struct _channels {
    // Guards every field below (and the list it heads).
    PyThread_type_lock mutex;
    // Singly-linked list of refs; new refs are pushed at the front.
    _channelref *head;
    // Number of refs in the list (including refs whose channel was closed).
    int64_t numopen;
    // ID to hand out next; negative once IDs are exhausted.
    int64_t next_id;
} _channels;
1349 
1350 static void
_channels_init(_channels * channels,PyThread_type_lock mutex)1351 _channels_init(_channels *channels, PyThread_type_lock mutex)
1352 {
1353     assert(mutex != NULL);
1354     assert(channels->mutex == NULL);
1355     *channels = (_channels){
1356         .mutex = mutex,
1357         .head = NULL,
1358         .numopen = 0,
1359         .next_id = 0,
1360     };
1361 }
1362 
1363 static void
_channels_fini(_channels * channels,PyThread_type_lock * p_mutex)1364 _channels_fini(_channels *channels, PyThread_type_lock *p_mutex)
1365 {
1366     PyThread_type_lock mutex = channels->mutex;
1367     assert(mutex != NULL);
1368 
1369     PyThread_acquire_lock(mutex, WAIT_LOCK);
1370     assert(channels->numopen == 0);
1371     assert(channels->head == NULL);
1372     *channels = (_channels){0};
1373     PyThread_release_lock(mutex);
1374 
1375     *p_mutex = mutex;
1376 }
1377 
1378 static int64_t
_channels_next_id(_channels * channels)1379 _channels_next_id(_channels *channels)  // needs lock
1380 {
1381     int64_t cid = channels->next_id;
1382     if (cid < 0) {
1383         /* overflow */
1384         return -1;
1385     }
1386     channels->next_id += 1;
1387     return cid;
1388 }
1389 
1390 static int
_channels_lookup(_channels * channels,int64_t cid,PyThread_type_lock * pmutex,_channel_state ** res)1391 _channels_lookup(_channels *channels, int64_t cid, PyThread_type_lock *pmutex,
1392                  _channel_state **res)
1393 {
1394     int err = -1;
1395     _channel_state *chan = NULL;
1396     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1397     if (pmutex != NULL) {
1398         *pmutex = NULL;
1399     }
1400 
1401     _channelref *ref = _channelref_find(channels->head, cid, NULL);
1402     if (ref == NULL) {
1403         err = ERR_CHANNEL_NOT_FOUND;
1404         goto done;
1405     }
1406     if (ref->chan == NULL || !ref->chan->open) {
1407         err = ERR_CHANNEL_CLOSED;
1408         goto done;
1409     }
1410 
1411     if (pmutex != NULL) {
1412         // The mutex will be closed by the caller.
1413         *pmutex = channels->mutex;
1414     }
1415 
1416     chan = ref->chan;
1417     err = 0;
1418 
1419 done:
1420     if (pmutex == NULL || *pmutex == NULL) {
1421         PyThread_release_lock(channels->mutex);
1422     }
1423     *res = chan;
1424     return err;
1425 }
1426 
1427 static int64_t
_channels_add(_channels * channels,_channel_state * chan)1428 _channels_add(_channels *channels, _channel_state *chan)
1429 {
1430     int64_t cid = -1;
1431     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1432 
1433     // Create a new ref.
1434     int64_t _cid = _channels_next_id(channels);
1435     if (_cid < 0) {
1436         cid = ERR_NO_NEXT_CHANNEL_ID;
1437         goto done;
1438     }
1439     _channelref *ref = _channelref_new(_cid, chan);
1440     if (ref == NULL) {
1441         goto done;
1442     }
1443 
1444     // Add it to the list.
1445     // We assume that the channel is a new one (not already in the list).
1446     ref->next = channels->head;
1447     channels->head = ref;
1448     channels->numopen += 1;
1449 
1450     cid = _cid;
1451 done:
1452     PyThread_release_lock(channels->mutex);
1453     return cid;
1454 }
1455 
1456 /* forward */
1457 static int _channel_set_closing(_channelref *, PyThread_type_lock);
1458 
1459 static int
_channels_close(_channels * channels,int64_t cid,_channel_state ** pchan,int end,int force)1460 _channels_close(_channels *channels, int64_t cid, _channel_state **pchan,
1461                 int end, int force)
1462 {
1463     int res = -1;
1464     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1465     if (pchan != NULL) {
1466         *pchan = NULL;
1467     }
1468 
1469     _channelref *ref = _channelref_find(channels->head, cid, NULL);
1470     if (ref == NULL) {
1471         res = ERR_CHANNEL_NOT_FOUND;
1472         goto done;
1473     }
1474 
1475     if (ref->chan == NULL) {
1476         res = ERR_CHANNEL_CLOSED;
1477         goto done;
1478     }
1479     else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) {
1480         res = ERR_CHANNEL_CLOSED;
1481         goto done;
1482     }
1483     else {
1484         int err = _channel_release_all(ref->chan, end, force);
1485         if (err != 0) {
1486             if (end == CHANNEL_SEND && err == ERR_CHANNEL_NOT_EMPTY) {
1487                 if (ref->chan->closing != NULL) {
1488                     res = ERR_CHANNEL_CLOSED;
1489                     goto done;
1490                 }
1491                 // Mark the channel as closing and return.  The channel
1492                 // will be cleaned up in _channel_next().
1493                 PyErr_Clear();
1494                 int err = _channel_set_closing(ref, channels->mutex);
1495                 if (err != 0) {
1496                     res = err;
1497                     goto done;
1498                 }
1499                 if (pchan != NULL) {
1500                     *pchan = ref->chan;
1501                 }
1502                 res = 0;
1503             }
1504             else {
1505                 res = err;
1506             }
1507             goto done;
1508         }
1509         if (pchan != NULL) {
1510             *pchan = ref->chan;
1511         }
1512         else  {
1513             _channel_free(ref->chan);
1514         }
1515         ref->chan = NULL;
1516     }
1517 
1518     res = 0;
1519 done:
1520     PyThread_release_lock(channels->mutex);
1521     return res;
1522 }
1523 
1524 static void
_channels_remove_ref(_channels * channels,_channelref * ref,_channelref * prev,_channel_state ** pchan)1525 _channels_remove_ref(_channels *channels, _channelref *ref, _channelref *prev,
1526                      _channel_state **pchan)
1527 {
1528     if (ref == channels->head) {
1529         channels->head = ref->next;
1530     }
1531     else {
1532         prev->next = ref->next;
1533     }
1534     channels->numopen -= 1;
1535 
1536     if (pchan != NULL) {
1537         *pchan = ref->chan;
1538     }
1539     _channelref_free(ref);
1540 }
1541 
1542 static int
_channels_remove(_channels * channels,int64_t cid,_channel_state ** pchan)1543 _channels_remove(_channels *channels, int64_t cid, _channel_state **pchan)
1544 {
1545     int res = -1;
1546     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1547 
1548     if (pchan != NULL) {
1549         *pchan = NULL;
1550     }
1551 
1552     _channelref *prev = NULL;
1553     _channelref *ref = _channelref_find(channels->head, cid, &prev);
1554     if (ref == NULL) {
1555         res = ERR_CHANNEL_NOT_FOUND;
1556         goto done;
1557     }
1558 
1559     _channels_remove_ref(channels, ref, prev, pchan);
1560 
1561     res = 0;
1562 done:
1563     PyThread_release_lock(channels->mutex);
1564     return res;
1565 }
1566 
1567 static int
_channels_add_id_object(_channels * channels,int64_t cid)1568 _channels_add_id_object(_channels *channels, int64_t cid)
1569 {
1570     int res = -1;
1571     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1572 
1573     _channelref *ref = _channelref_find(channels->head, cid, NULL);
1574     if (ref == NULL) {
1575         res = ERR_CHANNEL_NOT_FOUND;
1576         goto done;
1577     }
1578     ref->objcount += 1;
1579 
1580     res = 0;
1581 done:
1582     PyThread_release_lock(channels->mutex);
1583     return res;
1584 }
1585 
1586 static void
_channels_release_cid_object(_channels * channels,int64_t cid)1587 _channels_release_cid_object(_channels *channels, int64_t cid)
1588 {
1589     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1590 
1591     _channelref *prev = NULL;
1592     _channelref *ref = _channelref_find(channels->head, cid, &prev);
1593     if (ref == NULL) {
1594         // Already destroyed.
1595         goto done;
1596     }
1597     ref->objcount -= 1;
1598 
1599     // Destroy if no longer used.
1600     if (ref->objcount == 0) {
1601         _channel_state *chan = NULL;
1602         _channels_remove_ref(channels, ref, prev, &chan);
1603         if (chan != NULL) {
1604             _channel_free(chan);
1605         }
1606     }
1607 
1608 done:
1609     PyThread_release_lock(channels->mutex);
1610 }
1611 
/* One entry of the array returned by _channels_list_all(). */
struct channel_id_and_info {
    // The channel's ID (its ref's cid).
    int64_t id;
    // The channel's default unboundop (chan->defaults.unboundop).
    int unboundop;
};
1616 
1617 static struct channel_id_and_info *
_channels_list_all(_channels * channels,int64_t * count)1618 _channels_list_all(_channels *channels, int64_t *count)
1619 {
1620     struct channel_id_and_info *cids = NULL;
1621     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1622     struct channel_id_and_info *ids =
1623         PyMem_NEW(struct channel_id_and_info, (Py_ssize_t)(channels->numopen));
1624     if (ids == NULL) {
1625         goto done;
1626     }
1627     _channelref *ref = channels->head;
1628     for (int64_t i=0; ref != NULL; ref = ref->next, i++) {
1629         ids[i] = (struct channel_id_and_info){
1630             .id = ref->cid,
1631             .unboundop = ref->chan->defaults.unboundop,
1632         };
1633     }
1634     *count = channels->numopen;
1635 
1636     cids = ids;
1637 done:
1638     PyThread_release_lock(channels->mutex);
1639     return cids;
1640 }
1641 
1642 static void
_channels_clear_interpreter(_channels * channels,int64_t interpid)1643 _channels_clear_interpreter(_channels *channels, int64_t interpid)
1644 {
1645     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1646 
1647     _channelref *ref = channels->head;
1648     for (; ref != NULL; ref = ref->next) {
1649         if (ref->chan != NULL) {
1650             _channel_clear_interpreter(ref->chan, interpid);
1651         }
1652     }
1653 
1654     PyThread_release_lock(channels->mutex);
1655 }
1656 
1657 
/* support for closing non-empty channels */

/* Allocated by _channel_set_closing() when a close is requested while items
   are still queued; freed by _channel_clear_closing(). */
struct _channel_closing {
    // The registry entry whose "chan" is cleared once closing finishes.
    _channelref *ref;
};
1663 
1664 static int
_channel_set_closing(_channelref * ref,PyThread_type_lock mutex)1665 _channel_set_closing(_channelref *ref, PyThread_type_lock mutex) {
1666     _channel_state *chan = ref->chan;
1667     if (chan == NULL) {
1668         // already closed
1669         return 0;
1670     }
1671     int res = -1;
1672     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1673     if (chan->closing != NULL) {
1674         res = ERR_CHANNEL_CLOSED;
1675         goto done;
1676     }
1677     chan->closing = GLOBAL_MALLOC(struct _channel_closing);
1678     if (chan->closing == NULL) {
1679         goto done;
1680     }
1681     chan->closing->ref = ref;
1682 
1683     res = 0;
1684 done:
1685     PyThread_release_lock(chan->mutex);
1686     return res;
1687 }
1688 
1689 static void
_channel_clear_closing(_channel_state * chan)1690 _channel_clear_closing(_channel_state *chan) {
1691     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1692     if (chan->closing != NULL) {
1693         GLOBAL_FREE(chan->closing);
1694         chan->closing = NULL;
1695     }
1696     PyThread_release_lock(chan->mutex);
1697 }
1698 
1699 static void
_channel_finish_closing(_channel_state * chan)1700 _channel_finish_closing(_channel_state *chan) {
1701     struct _channel_closing *closing = chan->closing;
1702     if (closing == NULL) {
1703         return;
1704     }
1705     _channelref *ref = closing->ref;
1706     _channel_clear_closing(chan);
1707     // Do the things that would have been done in _channels_close().
1708     ref->chan = NULL;
1709     _channel_free(chan);
1710 }
1711 
1712 
1713 /* "high"-level channel-related functions */
1714 
1715 // Create a new channel.
1716 static int64_t
channel_create(_channels * channels,int unboundop)1717 channel_create(_channels *channels, int unboundop)
1718 {
1719     PyThread_type_lock mutex = PyThread_allocate_lock();
1720     if (mutex == NULL) {
1721         return ERR_CHANNEL_MUTEX_INIT;
1722     }
1723     _channel_state *chan = _channel_new(mutex, unboundop);
1724     if (chan == NULL) {
1725         PyThread_free_lock(mutex);
1726         return -1;
1727     }
1728     int64_t cid = _channels_add(channels, chan);
1729     if (cid < 0) {
1730         _channel_free(chan);
1731     }
1732     return cid;
1733 }
1734 
1735 // Completely destroy the channel.
1736 static int
channel_destroy(_channels * channels,int64_t cid)1737 channel_destroy(_channels *channels, int64_t cid)
1738 {
1739     _channel_state *chan = NULL;
1740     int err = _channels_remove(channels, cid, &chan);
1741     if (err != 0) {
1742         return err;
1743     }
1744     if (chan != NULL) {
1745         _channel_free(chan);
1746     }
1747     return 0;
1748 }
1749 
1750 // Push an object onto the channel.
1751 // The current interpreter gets associated with the send end of the channel.
1752 // Optionally request to be notified when it is received.
1753 static int
channel_send(_channels * channels,int64_t cid,PyObject * obj,_waiting_t * waiting,int unboundop)1754 channel_send(_channels *channels, int64_t cid, PyObject *obj,
1755              _waiting_t *waiting, int unboundop)
1756 {
1757     PyInterpreterState *interp = _get_current_interp();
1758     if (interp == NULL) {
1759         return -1;
1760     }
1761     int64_t interpid = PyInterpreterState_GetID(interp);
1762 
1763     // Look up the channel.
1764     PyThread_type_lock mutex = NULL;
1765     _channel_state *chan = NULL;
1766     int err = _channels_lookup(channels, cid, &mutex, &chan);
1767     if (err != 0) {
1768         return err;
1769     }
1770     assert(chan != NULL);
1771     // Past this point we are responsible for releasing the mutex.
1772 
1773     if (chan->closing != NULL) {
1774         PyThread_release_lock(mutex);
1775         return ERR_CHANNEL_CLOSED;
1776     }
1777 
1778     // Convert the object to cross-interpreter data.
1779     _PyCrossInterpreterData *data = GLOBAL_MALLOC(_PyCrossInterpreterData);
1780     if (data == NULL) {
1781         PyThread_release_lock(mutex);
1782         return -1;
1783     }
1784     if (_PyObject_GetCrossInterpreterData(obj, data) != 0) {
1785         PyThread_release_lock(mutex);
1786         GLOBAL_FREE(data);
1787         return -1;
1788     }
1789 
1790     // Add the data to the channel.
1791     int res = _channel_add(chan, interpid, data, waiting, unboundop);
1792     PyThread_release_lock(mutex);
1793     if (res != 0) {
1794         // We may chain an exception here:
1795         (void)_release_xid_data(data, 0);
1796         GLOBAL_FREE(data);
1797         return res;
1798     }
1799 
1800     return 0;
1801 }
1802 
1803 // Basically, un-send an object.
1804 static void
channel_clear_sent(_channels * channels,int64_t cid,_waiting_t * waiting)1805 channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting)
1806 {
1807     // Look up the channel.
1808     PyThread_type_lock mutex = NULL;
1809     _channel_state *chan = NULL;
1810     int err = _channels_lookup(channels, cid, &mutex, &chan);
1811     if (err != 0) {
1812         // The channel was already closed, etc.
1813         assert(waiting->status == WAITING_RELEASED);
1814         return;  // Ignore the error.
1815     }
1816     assert(chan != NULL);
1817     // Past this point we are responsible for releasing the mutex.
1818 
1819     _channelitem_id_t itemid = _waiting_get_itemid(waiting);
1820     _channel_remove(chan, itemid);
1821 
1822     PyThread_release_lock(mutex);
1823 }
1824 
1825 // Like channel_send(), but strictly wait for the object to be received.
1826 static int
channel_send_wait(_channels * channels,int64_t cid,PyObject * obj,int unboundop,PY_TIMEOUT_T timeout)1827 channel_send_wait(_channels *channels, int64_t cid, PyObject *obj,
1828                   int unboundop, PY_TIMEOUT_T timeout)
1829 {
1830     // We use a stack variable here, so we must ensure that &waiting
1831     // is not held by any channel item at the point this function exits.
1832     _waiting_t waiting;
1833     if (_waiting_init(&waiting) < 0) {
1834         assert(PyErr_Occurred());
1835         return -1;
1836     }
1837 
1838     /* Queue up the object. */
1839     int res = channel_send(channels, cid, obj, &waiting, unboundop);
1840     if (res < 0) {
1841         assert(waiting.status == WAITING_NO_STATUS);
1842         goto finally;
1843     }
1844 
1845     /* Wait until the object is received. */
1846     if (wait_for_lock(waiting.mutex, timeout) < 0) {
1847         assert(PyErr_Occurred());
1848         _waiting_finish_releasing(&waiting);
1849         /* The send() call is failing now, so make sure the item
1850            won't be received. */
1851         channel_clear_sent(channels, cid, &waiting);
1852         assert(waiting.status == WAITING_RELEASED);
1853         if (!waiting.received) {
1854             res = -1;
1855             goto finally;
1856         }
1857         // XXX Emit a warning if not a TimeoutError?
1858         PyErr_Clear();
1859     }
1860     else {
1861         _waiting_finish_releasing(&waiting);
1862         assert(waiting.status == WAITING_RELEASED);
1863         if (!waiting.received) {
1864             res = ERR_CHANNEL_CLOSED_WAITING;
1865             goto finally;
1866         }
1867     }
1868 
1869     /* success! */
1870     res = 0;
1871 
1872 finally:
1873     _waiting_clear(&waiting);
1874     return res;
1875 }
1876 
1877 // Pop the next object off the channel.  Fail if empty.
1878 // The current interpreter gets associated with the recv end of the channel.
1879 // XXX Support a "wait" mutex?
1880 static int
channel_recv(_channels * channels,int64_t cid,PyObject ** res,int * p_unboundop)1881 channel_recv(_channels *channels, int64_t cid, PyObject **res, int *p_unboundop)
1882 {
1883     int err;
1884     *res = NULL;
1885 
1886     PyInterpreterState *interp = _get_current_interp();
1887     if (interp == NULL) {
1888         // XXX Is this always an error?
1889         if (PyErr_Occurred()) {
1890             return -1;
1891         }
1892         return 0;
1893     }
1894     int64_t interpid = PyInterpreterState_GetID(interp);
1895 
1896     // Look up the channel.
1897     PyThread_type_lock mutex = NULL;
1898     _channel_state *chan = NULL;
1899     err = _channels_lookup(channels, cid, &mutex, &chan);
1900     if (err != 0) {
1901         return err;
1902     }
1903     assert(chan != NULL);
1904     // Past this point we are responsible for releasing the mutex.
1905 
1906     // Pop off the next item from the channel.
1907     _PyCrossInterpreterData *data = NULL;
1908     _waiting_t *waiting = NULL;
1909     err = _channel_next(chan, interpid, &data, &waiting, p_unboundop);
1910     PyThread_release_lock(mutex);
1911     if (err != 0) {
1912         return err;
1913     }
1914     else if (data == NULL) {
1915         // The item was unbound.
1916         assert(!PyErr_Occurred());
1917         *res = NULL;
1918         return 0;
1919     }
1920 
1921     // Convert the data back to an object.
1922     PyObject *obj = _PyCrossInterpreterData_NewObject(data);
1923     if (obj == NULL) {
1924         assert(PyErr_Occurred());
1925         // It was allocated in channel_send(), so we free it.
1926         (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE);
1927         if (waiting != NULL) {
1928             _waiting_release(waiting, 0);
1929         }
1930         return -1;
1931     }
1932     // It was allocated in channel_send(), so we free it.
1933     int release_res = _release_xid_data(data, XID_FREE);
1934     if (release_res < 0) {
1935         // The source interpreter has been destroyed already.
1936         assert(PyErr_Occurred());
1937         Py_DECREF(obj);
1938         if (waiting != NULL) {
1939             _waiting_release(waiting, 0);
1940         }
1941         return -1;
1942     }
1943 
1944     // Notify the sender.
1945     if (waiting != NULL) {
1946         _waiting_release(waiting, 1);
1947     }
1948 
1949     *res = obj;
1950     return 0;
1951 }
1952 
1953 // Disallow send/recv for the current interpreter.
1954 // The channel is marked as closed if no other interpreters
1955 // are currently associated.
1956 static int
channel_release(_channels * channels,int64_t cid,int send,int recv)1957 channel_release(_channels *channels, int64_t cid, int send, int recv)
1958 {
1959     PyInterpreterState *interp = _get_current_interp();
1960     if (interp == NULL) {
1961         return -1;
1962     }
1963     int64_t interpid = PyInterpreterState_GetID(interp);
1964 
1965     // Look up the channel.
1966     PyThread_type_lock mutex = NULL;
1967     _channel_state *chan = NULL;
1968     int err = _channels_lookup(channels, cid, &mutex, &chan);
1969     if (err != 0) {
1970         return err;
1971     }
1972     // Past this point we are responsible for releasing the mutex.
1973 
1974     // Close one or both of the two ends.
1975     int res = _channel_release_interpreter(chan, interpid, send-recv);
1976     PyThread_release_lock(mutex);
1977     return res;
1978 }
1979 
1980 // Close the channel (for all interpreters).  Fail if it's already closed.
1981 // Close immediately if it's empty.  Otherwise, disallow sending and
1982 // finally close once empty.  Optionally, immediately clear and close it.
1983 static int
channel_close(_channels * channels,int64_t cid,int end,int force)1984 channel_close(_channels *channels, int64_t cid, int end, int force)
1985 {
1986     return _channels_close(channels, cid, NULL, end, force);
1987 }
1988 
1989 // Return true if the identified interpreter is associated
1990 // with the given end of the channel.
1991 static int
channel_is_associated(_channels * channels,int64_t cid,int64_t interpid,int send)1992 channel_is_associated(_channels *channels, int64_t cid, int64_t interpid,
1993                        int send)
1994 {
1995     _channel_state *chan = NULL;
1996     int err = _channels_lookup(channels, cid, NULL, &chan);
1997     if (err != 0) {
1998         return err;
1999     }
2000     else if (send && chan->closing != NULL) {
2001         return ERR_CHANNEL_CLOSED;
2002     }
2003 
2004     _channelend *end = _channelend_find(send ? chan->ends->send : chan->ends->recv,
2005                                         interpid, NULL);
2006 
2007     return (end != NULL && end->open);
2008 }
2009 
2010 static int
_channel_get_count(_channels * channels,int64_t cid,Py_ssize_t * p_count)2011 _channel_get_count(_channels *channels, int64_t cid, Py_ssize_t *p_count)
2012 {
2013     PyThread_type_lock mutex = NULL;
2014     _channel_state *chan = NULL;
2015     int err = _channels_lookup(channels, cid, &mutex, &chan);
2016     if (err != 0) {
2017         return err;
2018     }
2019     assert(chan != NULL);
2020     int64_t count = chan->queue->count;
2021     PyThread_release_lock(mutex);
2022 
2023     *p_count = (Py_ssize_t)count;
2024     return 0;
2025 }
2026 
2027 
2028 /* channel info */
2029 
// Snapshot of a channel's state, filled in by _channel_get_info() and
// converted to a ChannelInfo object by new_channel_info().
struct channel_info {
    struct {
        // 1: closed; -1: closing
        // (0 means open.)
        int closed;
        // Interpreter counts across all ends of the channel.  The "*_only"
        // counts are for interpreters that have just one end; the "nboth*"
        // counts are for interpreters that have both a send and a recv end.
        struct {
            Py_ssize_t nsend_only;  // not released
            Py_ssize_t nsend_only_released;
            Py_ssize_t nrecv_only;  // not released
            Py_ssize_t nrecv_only_released;
            Py_ssize_t nboth;  // not released
            Py_ssize_t nboth_released;
            Py_ssize_t nboth_send_released;
            Py_ssize_t nboth_recv_released;
        } all;
        // The current interpreter's relation to each end.
        struct {
            // 1: associated; -1: released
            // (0 means the current interpreter has no entry for that end.)
            int send;
            int recv;
        } cur;
    } status;
    // Number of objects queued in the channel.
    int64_t count;
};
2052 
// Fill *info with a snapshot of channel "cid" as seen from the current
// interpreter.
//
// *info is zeroed first.  Returns 0 on success (including when the channel
// turns out to be closed, in which case only status.closed is set), -1 if
// there is no current interpreter, or ERR_CHANNEL_NOT_FOUND if no channel
// with that ID is registered.  channels->mutex is held for the whole scan.
static int
_channel_get_info(_channels *channels, int64_t cid, struct channel_info *info)
{
    int err = 0;
    *info = (struct channel_info){0};

    // Get the current interpreter.
    PyInterpreterState *interp = _get_current_interp();
    if (interp == NULL) {
        return -1;
    }
    int64_t interpid = PyInterpreterState_GetID(interp);

    // Hold the global lock until we're done.
    PyThread_acquire_lock(channels->mutex, WAIT_LOCK);

    // Find the channel.
    _channelref *ref = _channelref_find(channels->head, cid, NULL);
    if (ref == NULL) {
        err = ERR_CHANNEL_NOT_FOUND;
        goto finally;
    }
    _channel_state *chan = ref->chan;

    // Check if open.
    // A ref without channel state, or a channel no longer open, is reported
    // as closed with every count left at zero.
    if (chan == NULL) {
        info->status.closed = 1;
        goto finally;
    }
    if (!chan->open) {
        assert(chan->queue->count == 0);
        info->status.closed = 1;
        goto finally;
    }
    if (chan->closing != NULL) {
        assert(chan->queue->count > 0);
        info->status.closed = -1;
    }
    else {
        info->status.closed = 0;
    }

    // Get the number of queued objects.
    info->count = chan->queue->count;

    // Get the ends statuses.
    // Pass 1: tally every send end as "send only"; pass 2 moves interpreters
    // that also have a recv end out of those tallies into the nboth* ones.
    assert(info->status.cur.send == 0);
    assert(info->status.cur.recv == 0);
    _channelend *send = chan->ends->send;
    while (send != NULL) {
        if (send->interpid == interpid) {
            info->status.cur.send = send->open ? 1 : -1;
        }

        if (send->open) {
            info->status.all.nsend_only += 1;
        }
        else {
            info->status.all.nsend_only_released += 1;
        }
        send = send->next;
    }
    _channelend *recv = chan->ends->recv;
    while (recv != NULL) {
        if (recv->interpid == interpid) {
            info->status.cur.recv = recv->open ? 1 : -1;
        }

        // XXX This is O(n*n).  Why do we have 2 linked lists?
        // Look for a send end owned by the same interpreter (this inner
        // "send" shadows the exhausted outer cursor).
        _channelend *send = chan->ends->send;
        while (send != NULL) {
            if (send->interpid == recv->interpid) {
                break;
            }
            send = send->next;
        }
        if (send == NULL) {
            // The interpreter only has a recv end.
            if (recv->open) {
                info->status.all.nrecv_only += 1;
            }
            else {
                info->status.all.nrecv_only_released += 1;
            }
        }
        else {
            // The interpreter has both ends: count it under nboth* and undo
            // the "send only" tally it received in pass 1.
            // NOTE(review): "recv open, send released" is counted as
            // nboth_recv_released and "recv released, send open" as
            // nboth_send_released; the ChannelInfo field docs ("released from
            // the send/recv end") suggest the opposite mapping -- confirm
            // which is intended before relying on these two counts.
            if (recv->open) {
                if (send->open) {
                    info->status.all.nboth += 1;
                    info->status.all.nsend_only -= 1;
                }
                else {
                    info->status.all.nboth_recv_released += 1;
                    info->status.all.nsend_only_released -= 1;
                }
            }
            else {
                if (send->open) {
                    info->status.all.nboth_send_released += 1;
                    info->status.all.nsend_only -= 1;
                }
                else {
                    info->status.all.nboth_released += 1;
                    info->status.all.nsend_only_released -= 1;
                }
            }
        }
        recv = recv->next;
    }

finally:
    PyThread_release_lock(channels->mutex);
    return err;
}
2166 
2167 PyDoc_STRVAR(channel_info_doc,
2168 "ChannelInfo\n\
2169 \n\
2170 A named tuple of a channel's state.");
2171 
2172 static PyStructSequence_Field channel_info_fields[] = {
2173     {"open", "both ends are open"},
2174     {"closing", "send is closed, recv is non-empty"},
2175     {"closed", "both ends are closed"},
2176     {"count", "queued objects"},
2177 
2178     {"num_interp_send", "interpreters bound to the send end"},
2179     {"num_interp_send_released",
2180      "interpreters bound to the send end and released"},
2181 
2182     {"num_interp_recv", "interpreters bound to the send end"},
2183     {"num_interp_recv_released",
2184      "interpreters bound to the send end and released"},
2185 
2186     {"num_interp_both", "interpreters bound to both ends"},
2187     {"num_interp_both_released",
2188      "interpreters bound to both ends and released_from_both"},
2189     {"num_interp_both_send_released",
2190      "interpreters bound to both ends and released_from_the send end"},
2191     {"num_interp_both_recv_released",
2192      "interpreters bound to both ends and released_from_the recv end"},
2193 
2194     {"send_associated", "current interpreter is bound to the send end"},
2195     {"send_released", "current interpreter *was* bound to the send end"},
2196     {"recv_associated", "current interpreter is bound to the recv end"},
2197     {"recv_released", "current interpreter *was* bound to the recv end"},
2198     {0}
2199 };
2200 
2201 static PyStructSequence_Desc channel_info_desc = {
2202     .name = MODULE_NAME_STR ".ChannelInfo",
2203     .doc = channel_info_doc,
2204     .fields = channel_info_fields,
2205     .n_in_sequence = 8,
2206 };
2207 
2208 static PyObject *
new_channel_info(PyObject * mod,struct channel_info * info)2209 new_channel_info(PyObject *mod, struct channel_info *info)
2210 {
2211     module_state *state = get_module_state(mod);
2212     if (state == NULL) {
2213         return NULL;
2214     }
2215 
2216     assert(state->ChannelInfoType != NULL);
2217     PyObject *self = PyStructSequence_New(state->ChannelInfoType);
2218     if (self == NULL) {
2219         return NULL;
2220     }
2221 
2222     int pos = 0;
2223 #define SET_BOOL(val) \
2224     PyStructSequence_SET_ITEM(self, pos++, \
2225                               Py_NewRef(val ? Py_True : Py_False))
2226 #define SET_COUNT(val) \
2227     do { \
2228         PyObject *obj = PyLong_FromLongLong(val); \
2229         if (obj == NULL) { \
2230             Py_CLEAR(self); \
2231             return NULL; \
2232         } \
2233         PyStructSequence_SET_ITEM(self, pos++, obj); \
2234     } while(0)
2235     SET_BOOL(info->status.closed == 0);
2236     SET_BOOL(info->status.closed == -1);
2237     SET_BOOL(info->status.closed == 1);
2238     SET_COUNT(info->count);
2239     SET_COUNT(info->status.all.nsend_only);
2240     SET_COUNT(info->status.all.nsend_only_released);
2241     SET_COUNT(info->status.all.nrecv_only);
2242     SET_COUNT(info->status.all.nrecv_only_released);
2243     SET_COUNT(info->status.all.nboth);
2244     SET_COUNT(info->status.all.nboth_released);
2245     SET_COUNT(info->status.all.nboth_send_released);
2246     SET_COUNT(info->status.all.nboth_recv_released);
2247     SET_BOOL(info->status.cur.send == 1);
2248     SET_BOOL(info->status.cur.send == -1);
2249     SET_BOOL(info->status.cur.recv == 1);
2250     SET_BOOL(info->status.cur.recv == -1);
2251 #undef SET_COUNT
2252 #undef SET_BOOL
2253     assert(!PyErr_Occurred());
2254     return self;
2255 }
2256 
2257 
2258 /* ChannelID class */
2259 
// Instance layout of the ChannelID type.
typedef struct channelid {
    PyObject_HEAD
    // The channel's ID.
    int64_t cid;
    // Which end this ID refers to: CHANNEL_SEND, CHANNEL_RECV, or 0 for both.
    int end;
    // Copied into cross-interpreter data; when set, the receiving side tries
    // to produce a high-level channel end instead of a bare ChannelID.
    int resolve;
    // Registry passed to _channels_add_id_object() on creation and to
    // _channels_release_cid_object() on deallocation.
    _channels *channels;
} channelid;
2267 
// In/out argument for the "O&" converter channel_id_converter().
// The caller sets "module" (used to find the ChannelID type); on success
// the converter fills in "cid" and "end".
struct channel_id_converter_data {
    PyObject *module;
    int64_t cid;
    int end;
};
2273 
2274 static int
channel_id_converter(PyObject * arg,void * ptr)2275 channel_id_converter(PyObject *arg, void *ptr)
2276 {
2277     int64_t cid;
2278     int end = 0;
2279     struct channel_id_converter_data *data = ptr;
2280     module_state *state = get_module_state(data->module);
2281     assert(state != NULL);
2282     if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
2283         cid = ((channelid *)arg)->cid;
2284         end = ((channelid *)arg)->end;
2285     }
2286     else if (PyIndex_Check(arg)) {
2287         cid = PyLong_AsLongLong(arg);
2288         if (cid == -1 && PyErr_Occurred()) {
2289             return 0;
2290         }
2291         if (cid < 0) {
2292             PyErr_Format(PyExc_ValueError,
2293                         "channel ID must be a non-negative int, got %R", arg);
2294             return 0;
2295         }
2296     }
2297     else {
2298         PyErr_Format(PyExc_TypeError,
2299                      "channel ID must be an int, got %.100s",
2300                      Py_TYPE(arg)->tp_name);
2301         return 0;
2302     }
2303     data->cid = cid;
2304     data->end = end;
2305     return 1;
2306 }
2307 
2308 static int
newchannelid(PyTypeObject * cls,int64_t cid,int end,_channels * channels,int force,int resolve,channelid ** res)2309 newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels,
2310              int force, int resolve, channelid **res)
2311 {
2312     *res = NULL;
2313 
2314     channelid *self = PyObject_New(channelid, cls);
2315     if (self == NULL) {
2316         return -1;
2317     }
2318     self->cid = cid;
2319     self->end = end;
2320     self->resolve = resolve;
2321     self->channels = channels;
2322 
2323     int err = _channels_add_id_object(channels, cid);
2324     if (err != 0) {
2325         if (force && err == ERR_CHANNEL_NOT_FOUND) {
2326             assert(!PyErr_Occurred());
2327         }
2328         else {
2329             Py_DECREF((PyObject *)self);
2330             return err;
2331         }
2332     }
2333 
2334     *res = self;
2335     return 0;
2336 }
2337 
2338 static _channels * _global_channels(void);
2339 
2340 static PyObject *
_channelid_new(PyObject * mod,PyTypeObject * cls,PyObject * args,PyObject * kwds)2341 _channelid_new(PyObject *mod, PyTypeObject *cls,
2342                PyObject *args, PyObject *kwds)
2343 {
2344     static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
2345     int64_t cid;
2346     int end;
2347     struct channel_id_converter_data cid_data = {
2348         .module = mod,
2349     };
2350     int send = -1;
2351     int recv = -1;
2352     int force = 0;
2353     int resolve = 0;
2354     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2355                                      "O&|$pppp:ChannelID.__new__", kwlist,
2356                                      channel_id_converter, &cid_data,
2357                                      &send, &recv, &force, &resolve)) {
2358         return NULL;
2359     }
2360     cid = cid_data.cid;
2361     end = cid_data.end;
2362 
2363     // Handle "send" and "recv".
2364     if (send == 0 && recv == 0) {
2365         PyErr_SetString(PyExc_ValueError,
2366                         "'send' and 'recv' cannot both be False");
2367         return NULL;
2368     }
2369     else if (send == 1) {
2370         if (recv == 0 || recv == -1) {
2371             end = CHANNEL_SEND;
2372         }
2373         else {
2374             assert(recv == 1);
2375             end = 0;
2376         }
2377     }
2378     else if (recv == 1) {
2379         assert(send == 0 || send == -1);
2380         end = CHANNEL_RECV;
2381     }
2382 
2383     PyObject *cidobj = NULL;
2384     int err = newchannelid(cls, cid, end, _global_channels(),
2385                            force, resolve,
2386                            (channelid **)&cidobj);
2387     if (handle_channel_error(err, mod, cid)) {
2388         assert(cidobj == NULL);
2389         return NULL;
2390     }
2391     assert(cidobj != NULL);
2392     return cidobj;
2393 }
2394 
2395 static void
channelid_dealloc(PyObject * self)2396 channelid_dealloc(PyObject *self)
2397 {
2398     int64_t cid = ((channelid *)self)->cid;
2399     _channels *channels = ((channelid *)self)->channels;
2400 
2401     PyTypeObject *tp = Py_TYPE(self);
2402     tp->tp_free(self);
2403     /* "Instances of heap-allocated types hold a reference to their type."
2404      * See: https://docs.python.org/3.11/howto/isolating-extensions.html#garbage-collection-protocol
2405      * See: https://docs.python.org/3.11/c-api/typeobj.html#c.PyTypeObject.tp_traverse
2406     */
2407     // XXX Why don't we implement Py_TPFLAGS_HAVE_GC, e.g. Py_tp_traverse,
2408     // like we do for _abc._abc_data?
2409     Py_DECREF(tp);
2410 
2411     _channels_release_cid_object(channels, cid);
2412 }
2413 
2414 static PyObject *
channelid_repr(PyObject * self)2415 channelid_repr(PyObject *self)
2416 {
2417     PyTypeObject *type = Py_TYPE(self);
2418     const char *name = _PyType_Name(type);
2419 
2420     channelid *cidobj = (channelid *)self;
2421     const char *fmt;
2422     if (cidobj->end == CHANNEL_SEND) {
2423         fmt = "%s(%" PRId64 ", send=True)";
2424     }
2425     else if (cidobj->end == CHANNEL_RECV) {
2426         fmt = "%s(%" PRId64 ", recv=True)";
2427     }
2428     else {
2429         fmt = "%s(%" PRId64 ")";
2430     }
2431     return PyUnicode_FromFormat(fmt, name, cidobj->cid);
2432 }
2433 
2434 static PyObject *
channelid_str(PyObject * self)2435 channelid_str(PyObject *self)
2436 {
2437     channelid *cidobj = (channelid *)self;
2438     return PyUnicode_FromFormat("%" PRId64 "", cidobj->cid);
2439 }
2440 
2441 static PyObject *
channelid_int(PyObject * self)2442 channelid_int(PyObject *self)
2443 {
2444     channelid *cidobj = (channelid *)self;
2445     return PyLong_FromLongLong(cidobj->cid);
2446 }
2447 
2448 static Py_hash_t
channelid_hash(PyObject * self)2449 channelid_hash(PyObject *self)
2450 {
2451     channelid *cidobj = (channelid *)self;
2452     PyObject *pyid = PyLong_FromLongLong(cidobj->cid);
2453     if (pyid == NULL) {
2454         return -1;
2455     }
2456     Py_hash_t hash = PyObject_Hash(pyid);
2457     Py_DECREF(pyid);
2458     return hash;
2459 }
2460 
2461 static PyObject *
channelid_richcompare(PyObject * self,PyObject * other,int op)2462 channelid_richcompare(PyObject *self, PyObject *other, int op)
2463 {
2464     PyObject *res = NULL;
2465     if (op != Py_EQ && op != Py_NE) {
2466         Py_RETURN_NOTIMPLEMENTED;
2467     }
2468 
2469     PyObject *mod = get_module_from_type(Py_TYPE(self));
2470     if (mod == NULL) {
2471         return NULL;
2472     }
2473     module_state *state = get_module_state(mod);
2474     if (state == NULL) {
2475         goto done;
2476     }
2477 
2478     if (!PyObject_TypeCheck(self, state->ChannelIDType)) {
2479         res = Py_NewRef(Py_NotImplemented);
2480         goto done;
2481     }
2482 
2483     channelid *cidobj = (channelid *)self;
2484     int equal;
2485     if (PyObject_TypeCheck(other, state->ChannelIDType)) {
2486         channelid *othercidobj = (channelid *)other;
2487         equal = (cidobj->end == othercidobj->end) && (cidobj->cid == othercidobj->cid);
2488     }
2489     else if (PyLong_Check(other)) {
2490         /* Fast path */
2491         int overflow;
2492         long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow);
2493         if (othercid == -1 && PyErr_Occurred()) {
2494             goto done;
2495         }
2496         equal = !overflow && (othercid >= 0) && (cidobj->cid == othercid);
2497     }
2498     else if (PyNumber_Check(other)) {
2499         PyObject *pyid = PyLong_FromLongLong(cidobj->cid);
2500         if (pyid == NULL) {
2501             goto done;
2502         }
2503         res = PyObject_RichCompare(pyid, other, op);
2504         Py_DECREF(pyid);
2505         goto done;
2506     }
2507     else {
2508         res = Py_NewRef(Py_NotImplemented);
2509         goto done;
2510     }
2511 
2512     if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) {
2513         res = Py_NewRef(Py_True);
2514     }
2515     else {
2516         res = Py_NewRef(Py_False);
2517     }
2518 
2519 done:
2520     Py_DECREF(mod);
2521     return res;
2522 }
2523 
2524 static PyTypeObject * _get_current_channelend_type(int end);
2525 
2526 static PyObject *
_channelobj_from_cidobj(PyObject * cidobj,int end)2527 _channelobj_from_cidobj(PyObject *cidobj, int end)
2528 {
2529     PyObject *cls = (PyObject *)_get_current_channelend_type(end);
2530     if (cls == NULL) {
2531         return NULL;
2532     }
2533     PyObject *chan = PyObject_CallFunctionObjArgs(cls, cidobj, NULL);
2534     Py_DECREF(cls);
2535     if (chan == NULL) {
2536         return NULL;
2537     }
2538     return chan;
2539 }
2540 
// Cross-interpreter payload for a ChannelID: the plain-C fields copied out
// by _channelid_shared() and used by _channelid_from_xid() to rebuild the
// object in the receiving interpreter.
struct _channelid_xid {
    int64_t cid;
    int end;
    int resolve;
};
2546 
2547 static PyObject *
_channelid_from_xid(_PyCrossInterpreterData * data)2548 _channelid_from_xid(_PyCrossInterpreterData *data)
2549 {
2550     struct _channelid_xid *xid = \
2551                 (struct _channelid_xid *)_PyCrossInterpreterData_DATA(data);
2552 
2553     // It might not be imported yet, so we can't use _get_current_module().
2554     PyObject *mod = PyImport_ImportModule(MODULE_NAME_STR);
2555     if (mod == NULL) {
2556         return NULL;
2557     }
2558     assert(mod != Py_None);
2559     module_state *state = get_module_state(mod);
2560     if (state == NULL) {
2561         return NULL;
2562     }
2563 
2564     // Note that we do not preserve the "resolve" flag.
2565     PyObject *cidobj = NULL;
2566     int err = newchannelid(state->ChannelIDType, xid->cid, xid->end,
2567                            _global_channels(), 0, 0,
2568                            (channelid **)&cidobj);
2569     if (err != 0) {
2570         assert(cidobj == NULL);
2571         (void)handle_channel_error(err, mod, xid->cid);
2572         goto done;
2573     }
2574     assert(cidobj != NULL);
2575     if (xid->end == 0) {
2576         goto done;
2577     }
2578     if (!xid->resolve) {
2579         goto done;
2580     }
2581 
2582     /* Try returning a high-level channel end but fall back to the ID. */
2583     PyObject *chan = _channelobj_from_cidobj(cidobj, xid->end);
2584     if (chan == NULL) {
2585         PyErr_Clear();
2586         goto done;
2587     }
2588     Py_DECREF(cidobj);
2589     cidobj = chan;
2590 
2591 done:
2592     Py_DECREF(mod);
2593     return cidobj;
2594 }
2595 
2596 static int
_channelid_shared(PyThreadState * tstate,PyObject * obj,_PyCrossInterpreterData * data)2597 _channelid_shared(PyThreadState *tstate, PyObject *obj,
2598                   _PyCrossInterpreterData *data)
2599 {
2600     if (_PyCrossInterpreterData_InitWithSize(
2601             data, tstate->interp, sizeof(struct _channelid_xid), obj,
2602             _channelid_from_xid
2603             ) < 0)
2604     {
2605         return -1;
2606     }
2607     struct _channelid_xid *xid = \
2608                 (struct _channelid_xid *)_PyCrossInterpreterData_DATA(data);
2609     xid->cid = ((channelid *)obj)->cid;
2610     xid->end = ((channelid *)obj)->end;
2611     xid->resolve = ((channelid *)obj)->resolve;
2612     return 0;
2613 }
2614 
2615 static PyObject *
channelid_end(PyObject * self,void * end)2616 channelid_end(PyObject *self, void *end)
2617 {
2618     int force = 1;
2619     channelid *cidobj = (channelid *)self;
2620     if (end != NULL) {
2621         PyObject *obj = NULL;
2622         int err = newchannelid(Py_TYPE(self), cidobj->cid, *(int *)end,
2623                                cidobj->channels, force, cidobj->resolve,
2624                                (channelid **)&obj);
2625         if (err != 0) {
2626             assert(obj == NULL);
2627             PyObject *mod = get_module_from_type(Py_TYPE(self));
2628             if (mod == NULL) {
2629                 return NULL;
2630             }
2631             (void)handle_channel_error(err, mod, cidobj->cid);
2632             Py_DECREF(mod);
2633             return NULL;
2634         }
2635         assert(obj != NULL);
2636         return obj;
2637     }
2638 
2639     if (cidobj->end == CHANNEL_SEND) {
2640         return PyUnicode_InternFromString("send");
2641     }
2642     if (cidobj->end == CHANNEL_RECV) {
2643         return PyUnicode_InternFromString("recv");
2644     }
2645     return PyUnicode_InternFromString("both");
2646 }
2647 
2648 static int _channelid_end_send = CHANNEL_SEND;
2649 static int _channelid_end_recv = CHANNEL_RECV;
2650 
2651 static PyGetSetDef channelid_getsets[] = {
2652     {"end", (getter)channelid_end, NULL,
2653      PyDoc_STR("'send', 'recv', or 'both'")},
2654     {"send", (getter)channelid_end, NULL,
2655      PyDoc_STR("the 'send' end of the channel"), &_channelid_end_send},
2656     {"recv", (getter)channelid_end, NULL,
2657      PyDoc_STR("the 'recv' end of the channel"), &_channelid_end_recv},
2658     {NULL}
2659 };
2660 
2661 PyDoc_STRVAR(channelid_doc,
2662 "A channel ID identifies a channel and may be used as an int.");
2663 
2664 static PyType_Slot channelid_typeslots[] = {
2665     {Py_tp_dealloc, (destructor)channelid_dealloc},
2666     {Py_tp_doc, (void *)channelid_doc},
2667     {Py_tp_repr, (reprfunc)channelid_repr},
2668     {Py_tp_str, (reprfunc)channelid_str},
2669     {Py_tp_hash, channelid_hash},
2670     {Py_tp_richcompare, channelid_richcompare},
2671     {Py_tp_getset, channelid_getsets},
2672     // number slots
2673     {Py_nb_int, (unaryfunc)channelid_int},
2674     {Py_nb_index,  (unaryfunc)channelid_int},
2675     {0, NULL},
2676 };
2677 
2678 static PyType_Spec channelid_typespec = {
2679     .name = MODULE_NAME_STR ".ChannelID",
2680     .basicsize = sizeof(channelid),
2681     .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
2682               Py_TPFLAGS_DISALLOW_INSTANTIATION | Py_TPFLAGS_IMMUTABLETYPE),
2683     .slots = channelid_typeslots,
2684 };
2685 
2686 static PyTypeObject *
add_channelid_type(PyObject * mod)2687 add_channelid_type(PyObject *mod)
2688 {
2689     PyTypeObject *cls = (PyTypeObject *)PyType_FromModuleAndSpec(
2690                 mod, &channelid_typespec, NULL);
2691     if (cls == NULL) {
2692         return NULL;
2693     }
2694     if (PyModule_AddType(mod, cls) < 0) {
2695         Py_DECREF(cls);
2696         return NULL;
2697     }
2698     if (ensure_xid_class(cls, _channelid_shared) < 0) {
2699         Py_DECREF(cls);
2700         return NULL;
2701     }
2702     return cls;
2703 }
2704 
2705 
2706 /* SendChannel and RecvChannel classes */
2707 
2708 // XXX Use a new __xid__ protocol instead?
2709 
2710 static PyTypeObject *
_get_current_channelend_type(int end)2711 _get_current_channelend_type(int end)
2712 {
2713     module_state *state = _get_current_module_state();
2714     if (state == NULL) {
2715         return NULL;
2716     }
2717     PyTypeObject *cls;
2718     if (end == CHANNEL_SEND) {
2719         cls = state->send_channel_type;
2720     }
2721     else {
2722         assert(end == CHANNEL_RECV);
2723         cls = state->recv_channel_type;
2724     }
2725     if (cls == NULL) {
2726         // Force the module to be loaded, to register the type.
2727         PyObject *highlevel = PyImport_ImportModule("interpreters.channels");
2728         if (highlevel == NULL) {
2729             PyErr_Clear();
2730             highlevel = PyImport_ImportModule("test.support.interpreters.channels");
2731             if (highlevel == NULL) {
2732                 return NULL;
2733             }
2734         }
2735         Py_DECREF(highlevel);
2736         if (end == CHANNEL_SEND) {
2737             cls = state->send_channel_type;
2738         }
2739         else {
2740             cls = state->recv_channel_type;
2741         }
2742         assert(cls != NULL);
2743     }
2744     return cls;
2745 }
2746 
2747 static PyObject *
_channelend_from_xid(_PyCrossInterpreterData * data)2748 _channelend_from_xid(_PyCrossInterpreterData *data)
2749 {
2750     channelid *cidobj = (channelid *)_channelid_from_xid(data);
2751     if (cidobj == NULL) {
2752         return NULL;
2753     }
2754     PyTypeObject *cls = _get_current_channelend_type(cidobj->end);
2755     if (cls == NULL) {
2756         Py_DECREF(cidobj);
2757         return NULL;
2758     }
2759     PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cidobj);
2760     Py_DECREF(cidobj);
2761     return obj;
2762 }
2763 
2764 static int
_channelend_shared(PyThreadState * tstate,PyObject * obj,_PyCrossInterpreterData * data)2765 _channelend_shared(PyThreadState *tstate, PyObject *obj,
2766                     _PyCrossInterpreterData *data)
2767 {
2768     PyObject *cidobj = PyObject_GetAttrString(obj, "_id");
2769     if (cidobj == NULL) {
2770         return -1;
2771     }
2772     int res = _channelid_shared(tstate, cidobj, data);
2773     Py_DECREF(cidobj);
2774     if (res < 0) {
2775         return -1;
2776     }
2777     _PyCrossInterpreterData_SET_NEW_OBJECT(data, _channelend_from_xid);
2778     return 0;
2779 }
2780 
2781 static int
set_channelend_types(PyObject * mod,PyTypeObject * send,PyTypeObject * recv)2782 set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
2783 {
2784     module_state *state = get_module_state(mod);
2785     if (state == NULL) {
2786         return -1;
2787     }
2788 
2789     // Clear the old values if the .py module was reloaded.
2790     if (state->send_channel_type != NULL) {
2791         (void)clear_xid_class(state->send_channel_type);
2792         Py_CLEAR(state->send_channel_type);
2793     }
2794     if (state->recv_channel_type != NULL) {
2795         (void)clear_xid_class(state->recv_channel_type);
2796         Py_CLEAR(state->recv_channel_type);
2797     }
2798 
2799     // Add and register the types.
2800     state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
2801     state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);
2802     if (ensure_xid_class(send, _channelend_shared) < 0) {
2803         Py_CLEAR(state->send_channel_type);
2804         Py_CLEAR(state->recv_channel_type);
2805         return -1;
2806     }
2807     if (ensure_xid_class(recv, _channelend_shared) < 0) {
2808         (void)clear_xid_class(state->send_channel_type);
2809         Py_CLEAR(state->send_channel_type);
2810         Py_CLEAR(state->recv_channel_type);
2811         return -1;
2812     }
2813 
2814     return 0;
2815 }
2816 
2817 
2818 /* module level code ********************************************************/
2819 
/* globals is the process-global state for the module.  It holds all
   the data that we need to share between interpreters, so it cannot
   hold PyObject values. */
// "mutex" is held by _globals_init()/_globals_fini() while they update
// "module_count", the number of initialized module instances.  "channels"
// is initialized when that count goes from 0 to 1 and finalized when it
// drops back to 0.
static struct globals {
    PyMutex mutex;
    int module_count;
    _channels channels;
} _globals = {0};
2828 
2829 static int
_globals_init(void)2830 _globals_init(void)
2831 {
2832     PyMutex_Lock(&_globals.mutex);
2833     assert(_globals.module_count >= 0);
2834     _globals.module_count++;
2835     if (_globals.module_count == 1) {
2836         // Called for the first time.
2837         PyThread_type_lock mutex = PyThread_allocate_lock();
2838         if (mutex == NULL) {
2839             _globals.module_count--;
2840             PyMutex_Unlock(&_globals.mutex);
2841             return ERR_CHANNELS_MUTEX_INIT;
2842         }
2843         _channels_init(&_globals.channels, mutex);
2844     }
2845     PyMutex_Unlock(&_globals.mutex);
2846     return 0;
2847 }
2848 
2849 static void
_globals_fini(void)2850 _globals_fini(void)
2851 {
2852     PyMutex_Lock(&_globals.mutex);
2853     assert(_globals.module_count > 0);
2854     _globals.module_count--;
2855     if (_globals.module_count == 0) {
2856         PyThread_type_lock mutex;
2857         _channels_fini(&_globals.channels, &mutex);
2858         assert(mutex != NULL);
2859         PyThread_free_lock(mutex);
2860     }
2861     PyMutex_Unlock(&_globals.mutex);
2862 }
2863 
2864 static _channels *
_global_channels(void)2865 _global_channels(void) {
2866     return &_globals.channels;
2867 }
2868 
2869 
2870 static void
clear_interpreter(void * data)2871 clear_interpreter(void *data)
2872 {
2873     if (_globals.module_count == 0) {
2874         return;
2875     }
2876     PyInterpreterState *interp = (PyInterpreterState *)data;
2877     assert(interp == _get_current_interp());
2878     int64_t interpid = PyInterpreterState_GetID(interp);
2879     _channels_clear_interpreter(&_globals.channels, interpid);
2880 }
2881 
2882 
2883 static PyObject *
channelsmod_create(PyObject * self,PyObject * args,PyObject * kwds)2884 channelsmod_create(PyObject *self, PyObject *args, PyObject *kwds)
2885 {
2886     static char *kwlist[] = {"unboundop", NULL};
2887     int unboundop;
2888     if (!PyArg_ParseTupleAndKeywords(args, kwds, "i:create", kwlist,
2889                                      &unboundop))
2890     {
2891         return NULL;
2892     }
2893     if (!check_unbound(unboundop)) {
2894         PyErr_Format(PyExc_ValueError,
2895                      "unsupported unboundop %d", unboundop);
2896         return NULL;
2897     }
2898 
2899     int64_t cid = channel_create(&_globals.channels, unboundop);
2900     if (cid < 0) {
2901         (void)handle_channel_error(-1, self, cid);
2902         return NULL;
2903     }
2904     module_state *state = get_module_state(self);
2905     if (state == NULL) {
2906         return NULL;
2907     }
2908     PyObject *cidobj = NULL;
2909     int err = newchannelid(state->ChannelIDType, cid, 0,
2910                            &_globals.channels, 0, 0,
2911                            (channelid **)&cidobj);
2912     if (handle_channel_error(err, self, cid)) {
2913         assert(cidobj == NULL);
2914         err = channel_destroy(&_globals.channels, cid);
2915         if (handle_channel_error(err, self, cid)) {
2916             // XXX issue a warning?
2917         }
2918         return NULL;
2919     }
2920     assert(cidobj != NULL);
2921     assert(((channelid *)cidobj)->channels != NULL);
2922     return cidobj;
2923 }
2924 
2925 PyDoc_STRVAR(channelsmod_create_doc,
2926 "channel_create(unboundop) -> cid\n\
2927 \n\
2928 Create a new cross-interpreter channel and return a unique generated ID.");
2929 
2930 static PyObject *
channelsmod_destroy(PyObject * self,PyObject * args,PyObject * kwds)2931 channelsmod_destroy(PyObject *self, PyObject *args, PyObject *kwds)
2932 {
2933     static char *kwlist[] = {"cid", NULL};
2934     int64_t cid;
2935     struct channel_id_converter_data cid_data = {
2936         .module = self,
2937     };
2938     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_destroy", kwlist,
2939                                      channel_id_converter, &cid_data)) {
2940         return NULL;
2941     }
2942     cid = cid_data.cid;
2943 
2944     int err = channel_destroy(&_globals.channels, cid);
2945     if (handle_channel_error(err, self, cid)) {
2946         return NULL;
2947     }
2948     Py_RETURN_NONE;
2949 }
2950 
2951 PyDoc_STRVAR(channelsmod_destroy_doc,
2952 "channel_destroy(cid)\n\
2953 \n\
2954 Close and finalize the channel.  Afterward attempts to use the channel\n\
2955 will behave as though it never existed.");
2956 
2957 static PyObject *
channelsmod_list_all(PyObject * self,PyObject * Py_UNUSED (ignored))2958 channelsmod_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
2959 {
2960     int64_t count = 0;
2961     struct channel_id_and_info *cids =
2962         _channels_list_all(&_globals.channels, &count);
2963     if (cids == NULL) {
2964         if (count == 0) {
2965             return PyList_New(0);
2966         }
2967         return NULL;
2968     }
2969     PyObject *ids = PyList_New((Py_ssize_t)count);
2970     if (ids == NULL) {
2971         goto finally;
2972     }
2973     module_state *state = get_module_state(self);
2974     if (state == NULL) {
2975         Py_DECREF(ids);
2976         ids = NULL;
2977         goto finally;
2978     }
2979     struct channel_id_and_info *cur = cids;
2980     for (int64_t i=0; i < count; cur++, i++) {
2981         PyObject *cidobj = NULL;
2982         int err = newchannelid(state->ChannelIDType, cur->id, 0,
2983                                &_globals.channels, 0, 0,
2984                                (channelid **)&cidobj);
2985         if (handle_channel_error(err, self, cur->id)) {
2986             assert(cidobj == NULL);
2987             Py_SETREF(ids, NULL);
2988             break;
2989         }
2990         assert(cidobj != NULL);
2991 
2992         PyObject *item = Py_BuildValue("Oi", cidobj, cur->unboundop);
2993         Py_DECREF(cidobj);
2994         if (item == NULL) {
2995             Py_SETREF(ids, NULL);
2996             break;
2997         }
2998         PyList_SET_ITEM(ids, (Py_ssize_t)i, item);
2999     }
3000 
3001 finally:
3002     PyMem_Free(cids);
3003     return ids;
3004 }
3005 
3006 PyDoc_STRVAR(channelsmod_list_all_doc,
3007 "channel_list_all() -> [cid]\n\
3008 \n\
3009 Return the list of all IDs for active channels.");
3010 
3011 static PyObject *
channelsmod_list_interpreters(PyObject * self,PyObject * args,PyObject * kwds)3012 channelsmod_list_interpreters(PyObject *self, PyObject *args, PyObject *kwds)
3013 {
3014     static char *kwlist[] = {"cid", "send", NULL};
3015     int64_t cid;            /* Channel ID */
3016     struct channel_id_converter_data cid_data = {
3017         .module = self,
3018     };
3019     int send = 0;           /* Send or receive end? */
3020     int64_t interpid;
3021     PyObject *ids, *interpid_obj;
3022     PyInterpreterState *interp;
3023 
3024     if (!PyArg_ParseTupleAndKeywords(
3025             args, kwds, "O&$p:channel_list_interpreters",
3026             kwlist, channel_id_converter, &cid_data, &send)) {
3027         return NULL;
3028     }
3029     cid = cid_data.cid;
3030 
3031     ids = PyList_New(0);
3032     if (ids == NULL) {
3033         goto except;
3034     }
3035 
3036     interp = PyInterpreterState_Head();
3037     while (interp != NULL) {
3038         interpid = PyInterpreterState_GetID(interp);
3039         assert(interpid >= 0);
3040         int res = channel_is_associated(&_globals.channels, cid, interpid, send);
3041         if (res < 0) {
3042             (void)handle_channel_error(res, self, cid);
3043             goto except;
3044         }
3045         if (res) {
3046             interpid_obj = _PyInterpreterState_GetIDObject(interp);
3047             if (interpid_obj == NULL) {
3048                 goto except;
3049             }
3050             res = PyList_Insert(ids, 0, interpid_obj);
3051             Py_DECREF(interpid_obj);
3052             if (res < 0) {
3053                 goto except;
3054             }
3055         }
3056         interp = PyInterpreterState_Next(interp);
3057     }
3058 
3059     goto finally;
3060 
3061 except:
3062     Py_CLEAR(ids);
3063 
3064 finally:
3065     return ids;
3066 }
3067 
3068 PyDoc_STRVAR(channelsmod_list_interpreters_doc,
3069 "channel_list_interpreters(cid, *, send) -> [id]\n\
3070 \n\
3071 Return the list of all interpreter IDs associated with an end of the channel.\n\
3072 \n\
3073 The 'send' argument should be a boolean indicating whether to use the send or\n\
3074 receive end.");
3075 
3076 
3077 static PyObject *
channelsmod_send(PyObject * self,PyObject * args,PyObject * kwds)3078 channelsmod_send(PyObject *self, PyObject *args, PyObject *kwds)
3079 {
3080     static char *kwlist[] = {"cid", "obj", "unboundop", "blocking", "timeout",
3081                              NULL};
3082     struct channel_id_converter_data cid_data = {
3083         .module = self,
3084     };
3085     PyObject *obj;
3086     int unboundop = UNBOUND_REPLACE;
3087     int blocking = 1;
3088     PyObject *timeout_obj = NULL;
3089     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|i$pO:channel_send", kwlist,
3090                                      channel_id_converter, &cid_data, &obj,
3091                                      &unboundop, &blocking, &timeout_obj))
3092     {
3093         return NULL;
3094     }
3095     if (!check_unbound(unboundop)) {
3096         PyErr_Format(PyExc_ValueError,
3097                      "unsupported unboundop %d", unboundop);
3098         return NULL;
3099     }
3100 
3101     int64_t cid = cid_data.cid;
3102     PY_TIMEOUT_T timeout;
3103     if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
3104         return NULL;
3105     }
3106 
3107     /* Queue up the object. */
3108     int err = 0;
3109     if (blocking) {
3110         err = channel_send_wait(&_globals.channels, cid, obj, unboundop, timeout);
3111     }
3112     else {
3113         err = channel_send(&_globals.channels, cid, obj, NULL, unboundop);
3114     }
3115     if (handle_channel_error(err, self, cid)) {
3116         return NULL;
3117     }
3118 
3119     Py_RETURN_NONE;
3120 }
3121 
3122 PyDoc_STRVAR(channelsmod_send_doc,
3123 "channel_send(cid, obj, *, blocking=True, timeout=None)\n\
3124 \n\
3125 Add the object's data to the channel's queue.\n\
3126 By default this waits for the object to be received.");
3127 
3128 static PyObject *
channelsmod_send_buffer(PyObject * self,PyObject * args,PyObject * kwds)3129 channelsmod_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
3130 {
3131     static char *kwlist[] = {"cid", "obj", "unboundop", "blocking", "timeout",
3132                              NULL};
3133     struct channel_id_converter_data cid_data = {
3134         .module = self,
3135     };
3136     PyObject *obj;
3137     int unboundop = UNBOUND_REPLACE;
3138     int blocking = 1;
3139     PyObject *timeout_obj = NULL;
3140     if (!PyArg_ParseTupleAndKeywords(args, kwds,
3141                                      "O&O|i$pO:channel_send_buffer", kwlist,
3142                                      channel_id_converter, &cid_data, &obj,
3143                                      &unboundop, &blocking, &timeout_obj)) {
3144         return NULL;
3145     }
3146     if (!check_unbound(unboundop)) {
3147         PyErr_Format(PyExc_ValueError,
3148                      "unsupported unboundop %d", unboundop);
3149         return NULL;
3150     }
3151 
3152     int64_t cid = cid_data.cid;
3153     PY_TIMEOUT_T timeout;
3154     if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
3155         return NULL;
3156     }
3157 
3158     PyObject *tempobj = PyMemoryView_FromObject(obj);
3159     if (tempobj == NULL) {
3160         return NULL;
3161     }
3162 
3163     /* Queue up the object. */
3164     int err = 0;
3165     if (blocking) {
3166         err = channel_send_wait(
3167                 &_globals.channels, cid, tempobj, unboundop, timeout);
3168     }
3169     else {
3170         err = channel_send(&_globals.channels, cid, tempobj, NULL, unboundop);
3171     }
3172     Py_DECREF(tempobj);
3173     if (handle_channel_error(err, self, cid)) {
3174         return NULL;
3175     }
3176 
3177     Py_RETURN_NONE;
3178 }
3179 
3180 PyDoc_STRVAR(channelsmod_send_buffer_doc,
3181 "channel_send_buffer(cid, obj, *, blocking=True, timeout=None)\n\
3182 \n\
3183 Add the object's buffer to the channel's queue.\n\
3184 By default this waits for the object to be received.");
3185 
3186 static PyObject *
channelsmod_recv(PyObject * self,PyObject * args,PyObject * kwds)3187 channelsmod_recv(PyObject *self, PyObject *args, PyObject *kwds)
3188 {
3189     static char *kwlist[] = {"cid", "default", NULL};
3190     int64_t cid;
3191     struct channel_id_converter_data cid_data = {
3192         .module = self,
3193     };
3194     PyObject *dflt = NULL;
3195     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|O:channel_recv", kwlist,
3196                                      channel_id_converter, &cid_data, &dflt)) {
3197         return NULL;
3198     }
3199     cid = cid_data.cid;
3200 
3201     PyObject *obj = NULL;
3202     int unboundop = 0;
3203     int err = channel_recv(&_globals.channels, cid, &obj, &unboundop);
3204     if (err == ERR_CHANNEL_EMPTY && dflt != NULL) {
3205         // Use the default.
3206         obj = Py_NewRef(dflt);
3207         err = 0;
3208     }
3209     else if (handle_channel_error(err, self, cid)) {
3210         return NULL;
3211     }
3212     else if (obj == NULL) {
3213         // The item was unbound.
3214         return Py_BuildValue("Oi", Py_None, unboundop);
3215     }
3216 
3217     PyObject *res = Py_BuildValue("OO", obj, Py_None);
3218     Py_DECREF(obj);
3219     return res;
3220 }
3221 
3222 PyDoc_STRVAR(channelsmod_recv_doc,
3223 "channel_recv(cid, [default]) -> (obj, unboundop)\n\
3224 \n\
3225 Return a new object from the data at the front of the channel's queue.\n\
3226 \n\
3227 If there is nothing to receive then raise ChannelEmptyError, unless\n\
3228 a default value is provided.  In that case return it.");
3229 
3230 static PyObject *
channelsmod_close(PyObject * self,PyObject * args,PyObject * kwds)3231 channelsmod_close(PyObject *self, PyObject *args, PyObject *kwds)
3232 {
3233     static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
3234     int64_t cid;
3235     struct channel_id_converter_data cid_data = {
3236         .module = self,
3237     };
3238     int send = 0;
3239     int recv = 0;
3240     int force = 0;
3241     if (!PyArg_ParseTupleAndKeywords(args, kwds,
3242                                      "O&|$ppp:channel_close", kwlist,
3243                                      channel_id_converter, &cid_data,
3244                                      &send, &recv, &force)) {
3245         return NULL;
3246     }
3247     cid = cid_data.cid;
3248 
3249     int err = channel_close(&_globals.channels, cid, send-recv, force);
3250     if (handle_channel_error(err, self, cid)) {
3251         return NULL;
3252     }
3253     Py_RETURN_NONE;
3254 }
3255 
3256 PyDoc_STRVAR(channelsmod_close_doc,
3257 "channel_close(cid, *, send=None, recv=None, force=False)\n\
3258 \n\
3259 Close the channel for all interpreters.\n\
3260 \n\
3261 If the channel is empty then the keyword args are ignored and both\n\
3262 ends are immediately closed.  Otherwise, if 'force' is True then\n\
3263 all queued items are released and both ends are immediately\n\
3264 closed.\n\
3265 \n\
3266 If the channel is not empty *and* 'force' is False then following\n\
3267 happens:\n\
3268 \n\
3269  * recv is True (regardless of send):\n\
3270    - raise ChannelNotEmptyError\n\
3271  * recv is None and send is None:\n\
3272    - raise ChannelNotEmptyError\n\
3273  * send is True and recv is not True:\n\
3274    - fully close the 'send' end\n\
3275    - close the 'recv' end to interpreters not already receiving\n\
3276    - fully close it once empty\n\
3277 \n\
3278 Closing an already closed channel results in a ChannelClosedError.\n\
3279 \n\
3280 Once the channel's ID has no more ref counts in any interpreter\n\
3281 the channel will be destroyed.");
3282 
3283 static PyObject *
channelsmod_release(PyObject * self,PyObject * args,PyObject * kwds)3284 channelsmod_release(PyObject *self, PyObject *args, PyObject *kwds)
3285 {
3286     // Note that only the current interpreter is affected.
3287     static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
3288     int64_t cid;
3289     struct channel_id_converter_data cid_data = {
3290         .module = self,
3291     };
3292     int send = 0;
3293     int recv = 0;
3294     int force = 0;
3295     if (!PyArg_ParseTupleAndKeywords(args, kwds,
3296                                      "O&|$ppp:channel_release", kwlist,
3297                                      channel_id_converter, &cid_data,
3298                                      &send, &recv, &force)) {
3299         return NULL;
3300     }
3301     cid = cid_data.cid;
3302     if (send == 0 && recv == 0) {
3303         send = 1;
3304         recv = 1;
3305     }
3306 
3307     // XXX Handle force is True.
3308     // XXX Fix implicit release.
3309 
3310     int err = channel_release(&_globals.channels, cid, send, recv);
3311     if (handle_channel_error(err, self, cid)) {
3312         return NULL;
3313     }
3314     Py_RETURN_NONE;
3315 }
3316 
3317 PyDoc_STRVAR(channelsmod_release_doc,
3318 "channel_release(cid, *, send=None, recv=None, force=True)\n\
3319 \n\
3320 Close the channel for the current interpreter.  'send' and 'recv'\n\
3321 (bool) may be used to indicate the ends to close.  By default both\n\
3322 ends are closed.  Closing an already closed end is a noop.");
3323 
3324 static PyObject *
channelsmod_get_count(PyObject * self,PyObject * args,PyObject * kwds)3325 channelsmod_get_count(PyObject *self, PyObject *args, PyObject *kwds)
3326 {
3327     static char *kwlist[] = {"cid", NULL};
3328     struct channel_id_converter_data cid_data = {
3329         .module = self,
3330     };
3331     if (!PyArg_ParseTupleAndKeywords(args, kwds,
3332                                      "O&:get_count", kwlist,
3333                                      channel_id_converter, &cid_data)) {
3334         return NULL;
3335     }
3336     int64_t cid = cid_data.cid;
3337 
3338     Py_ssize_t count = -1;
3339     int err = _channel_get_count(&_globals.channels, cid, &count);
3340     if (handle_channel_error(err, self, cid)) {
3341         return NULL;
3342     }
3343     assert(count >= 0);
3344     return PyLong_FromSsize_t(count);
3345 }
3346 
3347 PyDoc_STRVAR(channelsmod_get_count_doc,
3348 "get_count(cid)\n\
3349 \n\
3350 Return the number of items in the channel.");
3351 
3352 static PyObject *
channelsmod_get_info(PyObject * self,PyObject * args,PyObject * kwds)3353 channelsmod_get_info(PyObject *self, PyObject *args, PyObject *kwds)
3354 {
3355     static char *kwlist[] = {"cid", NULL};
3356     struct channel_id_converter_data cid_data = {
3357         .module = self,
3358     };
3359     if (!PyArg_ParseTupleAndKeywords(args, kwds,
3360                                      "O&:_get_info", kwlist,
3361                                      channel_id_converter, &cid_data)) {
3362         return NULL;
3363     }
3364     int64_t cid = cid_data.cid;
3365 
3366     struct channel_info info;
3367     int err = _channel_get_info(&_globals.channels, cid, &info);
3368     if (handle_channel_error(err, self, cid)) {
3369         return NULL;
3370     }
3371     return new_channel_info(self, &info);
3372 }
3373 
3374 PyDoc_STRVAR(channelsmod_get_info_doc,
3375 "get_info(cid)\n\
3376 \n\
3377 Return details about the channel.");
3378 
3379 static PyObject *
channelsmod_get_channel_defaults(PyObject * self,PyObject * args,PyObject * kwds)3380 channelsmod_get_channel_defaults(PyObject *self, PyObject *args, PyObject *kwds)
3381 {
3382     static char *kwlist[] = {"cid", NULL};
3383     struct channel_id_converter_data cid_data = {
3384         .module = self,
3385     };
3386     if (!PyArg_ParseTupleAndKeywords(args, kwds,
3387                                      "O&:get_channel_defaults", kwlist,
3388                                      channel_id_converter, &cid_data)) {
3389         return NULL;
3390     }
3391     int64_t cid = cid_data.cid;
3392 
3393     PyThread_type_lock mutex = NULL;
3394     _channel_state *channel = NULL;
3395     int err = _channels_lookup(&_globals.channels, cid, &mutex, &channel);
3396     if (handle_channel_error(err, self, cid)) {
3397         return NULL;
3398     }
3399     int unboundop = channel->defaults.unboundop;
3400     PyThread_release_lock(mutex);
3401 
3402     PyObject *defaults = Py_BuildValue("i", unboundop);
3403     return defaults;
3404 }
3405 
3406 PyDoc_STRVAR(channelsmod_get_channel_defaults_doc,
3407 "get_channel_defaults(cid)\n\
3408 \n\
3409 Return the channel's default values, set when it was created.");
3410 
3411 static PyObject *
channelsmod__channel_id(PyObject * self,PyObject * args,PyObject * kwds)3412 channelsmod__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
3413 {
3414     module_state *state = get_module_state(self);
3415     if (state == NULL) {
3416         return NULL;
3417     }
3418     PyTypeObject *cls = state->ChannelIDType;
3419 
3420     PyObject *mod = get_module_from_owned_type(cls);
3421     assert(mod == self);
3422     Py_DECREF(mod);
3423 
3424     return _channelid_new(self, cls, args, kwds);
3425 }
3426 
3427 static PyObject *
channelsmod__register_end_types(PyObject * self,PyObject * args,PyObject * kwds)3428 channelsmod__register_end_types(PyObject *self, PyObject *args, PyObject *kwds)
3429 {
3430     static char *kwlist[] = {"send", "recv", NULL};
3431     PyObject *send;
3432     PyObject *recv;
3433     if (!PyArg_ParseTupleAndKeywords(args, kwds,
3434                                      "OO:_register_end_types", kwlist,
3435                                      &send, &recv)) {
3436         return NULL;
3437     }
3438     if (!PyType_Check(send)) {
3439         PyErr_SetString(PyExc_TypeError, "expected a type for 'send'");
3440         return NULL;
3441     }
3442     if (!PyType_Check(recv)) {
3443         PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'");
3444         return NULL;
3445     }
3446     PyTypeObject *cls_send = (PyTypeObject *)send;
3447     PyTypeObject *cls_recv = (PyTypeObject *)recv;
3448 
3449     if (set_channelend_types(self, cls_send, cls_recv) < 0) {
3450         return NULL;
3451     }
3452 
3453     Py_RETURN_NONE;
3454 }
3455 
3456 static PyMethodDef module_functions[] = {
3457     {"create",                     _PyCFunction_CAST(channelsmod_create),
3458      METH_VARARGS | METH_KEYWORDS, channelsmod_create_doc},
3459     {"destroy",                    _PyCFunction_CAST(channelsmod_destroy),
3460      METH_VARARGS | METH_KEYWORDS, channelsmod_destroy_doc},
3461     {"list_all",                   channelsmod_list_all,
3462      METH_NOARGS,                  channelsmod_list_all_doc},
3463     {"list_interpreters",          _PyCFunction_CAST(channelsmod_list_interpreters),
3464      METH_VARARGS | METH_KEYWORDS, channelsmod_list_interpreters_doc},
3465     {"send",                       _PyCFunction_CAST(channelsmod_send),
3466      METH_VARARGS | METH_KEYWORDS, channelsmod_send_doc},
3467     {"send_buffer",                _PyCFunction_CAST(channelsmod_send_buffer),
3468      METH_VARARGS | METH_KEYWORDS, channelsmod_send_buffer_doc},
3469     {"recv",                       _PyCFunction_CAST(channelsmod_recv),
3470      METH_VARARGS | METH_KEYWORDS, channelsmod_recv_doc},
3471     {"close",                      _PyCFunction_CAST(channelsmod_close),
3472      METH_VARARGS | METH_KEYWORDS, channelsmod_close_doc},
3473     {"release",                    _PyCFunction_CAST(channelsmod_release),
3474      METH_VARARGS | METH_KEYWORDS, channelsmod_release_doc},
3475     {"get_count",                   _PyCFunction_CAST(channelsmod_get_count),
3476      METH_VARARGS | METH_KEYWORDS, channelsmod_get_count_doc},
3477     {"get_info",                   _PyCFunction_CAST(channelsmod_get_info),
3478      METH_VARARGS | METH_KEYWORDS, channelsmod_get_info_doc},
3479     {"get_channel_defaults",       _PyCFunction_CAST(channelsmod_get_channel_defaults),
3480      METH_VARARGS | METH_KEYWORDS, channelsmod_get_channel_defaults_doc},
3481     {"_channel_id",                _PyCFunction_CAST(channelsmod__channel_id),
3482      METH_VARARGS | METH_KEYWORDS, NULL},
3483     {"_register_end_types",        _PyCFunction_CAST(channelsmod__register_end_types),
3484      METH_VARARGS | METH_KEYWORDS, NULL},
3485 
3486     {NULL,                        NULL}           /* sentinel */
3487 };
3488 
3489 
3490 /* initialization function */
3491 
3492 PyDoc_STRVAR(module_doc,
3493 "This module provides primitive operations to manage Python interpreters.\n\
3494 The 'interpreters' module provides a more convenient interface.");
3495 
3496 static int
module_exec(PyObject * mod)3497 module_exec(PyObject *mod)
3498 {
3499     int err = _globals_init();
3500     if (handle_channel_error(err, mod, -1)) {
3501         return -1;
3502     }
3503 
3504     module_state *state = get_module_state(mod);
3505     if (state == NULL) {
3506         goto error;
3507     }
3508 
3509     /* Add exception types */
3510     if (exceptions_init(mod) != 0) {
3511         goto error;
3512     }
3513 
3514     /* Add other types */
3515 
3516     // ChannelInfo
3517     state->ChannelInfoType = PyStructSequence_NewType(&channel_info_desc);
3518     if (state->ChannelInfoType == NULL) {
3519         goto error;
3520     }
3521     if (PyModule_AddType(mod, state->ChannelInfoType) < 0) {
3522         goto error;
3523     }
3524 
3525     // ChannelID
3526     state->ChannelIDType = add_channelid_type(mod);
3527     if (state->ChannelIDType == NULL) {
3528         goto error;
3529     }
3530 
3531     /* Make sure chnnels drop objects owned by this interpreter. */
3532     PyInterpreterState *interp = _get_current_interp();
3533     PyUnstable_AtExit(interp, clear_interpreter, (void *)interp);
3534 
3535     return 0;
3536 
3537 error:
3538     if (state != NULL) {
3539         clear_xid_types(state);
3540     }
3541     _globals_fini();
3542     return -1;
3543 }
3544 
3545 static struct PyModuleDef_Slot module_slots[] = {
3546     {Py_mod_exec, module_exec},
3547     {Py_mod_multiple_interpreters, Py_MOD_PER_INTERPRETER_GIL_SUPPORTED},
3548     {Py_mod_gil, Py_MOD_GIL_NOT_USED},
3549     {0, NULL},
3550 };
3551 
3552 static int
module_traverse(PyObject * mod,visitproc visit,void * arg)3553 module_traverse(PyObject *mod, visitproc visit, void *arg)
3554 {
3555     module_state *state = get_module_state(mod);
3556     assert(state != NULL);
3557     traverse_module_state(state, visit, arg);
3558     return 0;
3559 }
3560 
3561 static int
module_clear(PyObject * mod)3562 module_clear(PyObject *mod)
3563 {
3564     module_state *state = get_module_state(mod);
3565     assert(state != NULL);
3566 
3567     // Now we clear the module state.
3568     clear_module_state(state);
3569     return 0;
3570 }
3571 
3572 static void
module_free(void * mod)3573 module_free(void *mod)
3574 {
3575     module_state *state = get_module_state(mod);
3576     assert(state != NULL);
3577 
3578     // Now we clear the module state.
3579     clear_module_state(state);
3580 
3581     _globals_fini();
3582 }
3583 
3584 static struct PyModuleDef moduledef = {
3585     .m_base = PyModuleDef_HEAD_INIT,
3586     .m_name = MODULE_NAME_STR,
3587     .m_doc = module_doc,
3588     .m_size = sizeof(module_state),
3589     .m_methods = module_functions,
3590     .m_slots = module_slots,
3591     .m_traverse = module_traverse,
3592     .m_clear = module_clear,
3593     .m_free = (freefunc)module_free,
3594 };
3595 
3596 PyMODINIT_FUNC
MODINIT_FUNC_NAME(void)3597 MODINIT_FUNC_NAME(void)
3598 {
3599     return PyModuleDef_Init(&moduledef);
3600 }
3601