• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from collections import namedtuple
2import contextlib
3import itertools
4import os
5import pickle
6import sys
7from textwrap import dedent
8import threading
9import time
10import unittest
11
12from test import support
13from test.support import script_helper
14
15
16interpreters = support.import_module('_xxsubinterpreters')
17
18
19##################################
20# helpers
21
22def _captured_script(script):
23    r, w = os.pipe()
24    indented = script.replace('\n', '\n                ')
25    wrapped = dedent(f"""
26        import contextlib
27        with open({w}, 'w') as spipe:
28            with contextlib.redirect_stdout(spipe):
29                {indented}
30        """)
31    return wrapped, open(r)
32
33
def _run_output(interp, request, shared=None):
    """Run *request* in *interp* and return what it wrote to stdout.

    *shared* is handed to ``run_string()`` unchanged.
    """
    wrapped, reader = _captured_script(request)
    try:
        interpreters.run_string(interp, wrapped, shared)
        captured = reader.read()
    finally:
        # Close the read end whether or not the run succeeded.
        reader.close()
    return captured
39
40
@contextlib.contextmanager
def _running(interp):
    """Keep *interp* busy in a background thread for the with-block.

    A thread runs a script in *interp* that blocks reading from a pipe.
    When the with-block is left, the write end is written and closed,
    which lets the script finish, and the thread is joined.

    The release happens in a ``finally`` clause so that a failing
    assertion inside the with-block cannot leave the thread (and the
    interpreter it is running) blocked forever.
    """
    r, w = os.pipe()

    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}) as rpipe:
                rpipe.read()
            """))

    t = threading.Thread(target=run)
    t.start()

    try:
        yield
    finally:
        # Closing the write end is what ends the script's read().
        with open(w, 'w') as spipe:
            spipe.write('done')
        t.join()
59
60
61#@contextmanager
62#def run_threaded(id, source, **shared):
63#    def run():
64#        run_interp(id, source, **shared)
65#    t = threading.Thread(target=run)
66#    t.start()
67#    yield
68#    t.join()
69
70
def run_interp(id, source, **shared):
    """Run *source* in interpreter *id*.

    Keyword arguments are collected into a dict and forwarded to
    ``_run_interp()`` as its *shared* argument.
    """
    _run_interp(id, source, shared)
73
74
def _run_interp(id, source, shared, _mainns={}):
    """Run the dedented *source* in interpreter *id*.

    For any interpreter other than the main one the code goes through
    ``run_string()`` together with *shared*.  For the main interpreter
    the code is exec'ed directly in *_mainns*; *shared* is not used on
    that path.  The default *_mainns* dict is created once, so names set
    by one call are still there on the next.

    Raises RuntimeError if *id* is the main interpreter but the caller is
    not running in it.
    """
    code = dedent(source)
    main_id = interpreters.get_main()
    if not (main_id == id):
        interpreters.run_string(id, code, shared)
        return
    if interpreters.get_current() != main_id:
        raise RuntimeError
    # XXX Run a func?
    exec(code, _mainns)
85
86
class Interpreter(namedtuple('Interpreter', 'name id')):
    """A (name, id) pair describing an interpreter used by the tests."""

    @classmethod
    def from_raw(cls, raw):
        # An existing instance is returned as-is; a str is treated as a
        # name; anything else is unsupported.
        if isinstance(raw, cls):
            return raw
        if isinstance(raw, str):
            return cls(raw)
        raise NotImplementedError

    def __new__(cls, name=None, id=None):
        main = interpreters.get_main()
        if id == main:
            # The main ID was given: the only acceptable name is "main".
            if name and name != 'main':
                raise ValueError(
                    'name mismatch (expected "main", got "{}")'.format(name))
            name = name or 'main'
            id = main
        elif id is not None:
            # Some other ID was given: "main" is not a valid name for it.
            if name == 'main':
                raise ValueError('name mismatch (unexpected "main")')
            name = name or 'interp'
            if not isinstance(id, interpreters.InterpreterID):
                id = interpreters.InterpreterID(id)
        elif not name or name == 'main':
            # No ID and no (other) name: this is the main interpreter.
            name, id = 'main', main
        else:
            # A name without an ID: create a fresh interpreter for it.
            id = interpreters.create()
        return super().__new__(cls, name, id)
121
122
123# XXX expect_channel_closed() is unnecessary once we improve exc propagation.
124
@contextlib.contextmanager
def expect_channel_closed():
    """Require the with-block to raise ``ChannelClosedError``.

    The expected exception is swallowed.  If the block finishes without
    raising, AssertionError is raised instead.  The error is raised
    explicitly rather than through an ``assert`` statement so that the
    check still happens when Python runs with ``-O``.
    """
    try:
        yield
    except interpreters.ChannelClosedError:
        pass
    else:
        raise AssertionError('channel not closed')
133
134
class ChannelAction(namedtuple('ChannelAction', 'action end interp')):
    """An (action, end, interp) triple describing one channel operation.

    A falsy *end* becomes 'both' and a falsy *interp* becomes 'main'.
    The combination is validated on construction; ValueError carries the
    offending value.
    """

    def __new__(cls, action, end=None, interp=None):
        return super().__new__(cls, action, end or 'both', interp or 'main')

    def __init__(self, *args, **kwargs):
        action, end, interp = self
        # Which ends are acceptable depends on the action.
        if action == 'use':
            valid_ends = ('same', 'opposite', 'send', 'recv')
        elif action in ('close', 'force-close'):
            valid_ends = ('both', 'same', 'opposite', 'send', 'recv')
        else:
            raise ValueError(action)
        if end not in valid_ends:
            raise ValueError(end)
        if interp not in ('main', 'same', 'other', 'extra'):
            raise ValueError(interp)

    def resolve_end(self, end):
        # 'same' and 'opposite' are relative to the *end* passed in;
        # anything else is returned as stored.
        mine = self.end
        if mine == 'same':
            return end
        if mine != 'opposite':
            return mine
        if end == 'send':
            return 'recv'
        return 'send'

    def resolve_interp(self, interp, other, extra):
        which = self.interp
        if which == 'same':
            return interp
        if which == 'other' or which == 'extra':
            chosen = other if which == 'other' else extra
            if chosen is None:
                raise RuntimeError
            return chosen
        if which == 'main':
            # Pick whichever of *interp* / *other* is named "main".
            if interp.name == 'main':
                return interp
            if other and other.name == 'main':
                return other
            raise RuntimeError
        # Per __init__(), there aren't any others.
184
185
class ChannelState(namedtuple('ChannelState', 'pending closed')):
    """Immutable (pending, closed) snapshot of a channel's expected state.

    Every method returns a state object and never modifies ``self``.
    """

    def __new__(cls, pending=0, *, closed=False):
        return super().__new__(cls, pending, closed)

    def incr(self):
        # One more pending item; the closed flag is carried over.
        cls = type(self)
        return cls(self.pending + 1, closed=self.closed)

    def decr(self):
        # One fewer pending item; the closed flag is carried over.
        cls = type(self)
        return cls(self.pending - 1, closed=self.closed)

    def close(self, *, force=True):
        # An already-closed state is returned unchanged unless a forced
        # close still has pending items to drop.
        if self.closed and (not force or self.pending == 0):
            return self
        remaining = 0 if force else self.pending
        return type(self)(remaining, closed=True)
203
204
def run_action(cid, action, end, state, *, hideclosed=True):
    """Apply *action* to channel *cid* and return the resulting state.

    The work is done by ``_run_action()``.  A ``ChannelClosedError`` is
    turned into ``state.close()``; it is re-raised only when *hideclosed*
    is false and the failure was not expected.

    A failure is expected when *state* is already closed, except for a
    'use' on the 'recv' end while items are still pending.  If a failure
    was expected but the action succeeded, an exception is raised.
    """
    draining = action == 'use' and end == 'recv' and state.pending
    expectfail = bool(state.closed) and not draining

    try:
        result = _run_action(cid, action, end, state)
    except interpreters.ChannelClosedError:
        if not hideclosed and not expectfail:
            raise
        result = state.close()
    else:
        if expectfail:
            # Same reporting style as _run_action() uses for a missing
            # ChannelEmptyError.
            raise Exception('expected ChannelClosedError')
    return result
224
225
226def _run_action(cid, action, end, state):
227    if action == 'use':
228        if end == 'send':
229            interpreters.channel_send(cid, b'spam')
230            return state.incr()
231        elif end == 'recv':
232            if not state.pending:
233                try:
234                    interpreters.channel_recv(cid)
235                except interpreters.ChannelEmptyError:
236                    return state
237                else:
238                    raise Exception('expected ChannelEmptyError')
239            else:
240                interpreters.channel_recv(cid)
241                return state.decr()
242        else:
243            raise ValueError(end)
244    elif action == 'close':
245        kwargs = {}
246        if end in ('recv', 'send'):
247            kwargs[end] = True
248        interpreters.channel_close(cid, **kwargs)
249        return state.close()
250    elif action == 'force-close':
251        kwargs = {
252            'force': True,
253            }
254        if end in ('recv', 'send'):
255            kwargs[end] = True
256        interpreters.channel_close(cid, **kwargs)
257        return state.close(force=True)
258    else:
259        raise ValueError(action)
260
261
def clean_up_interpreters():
    """Destroy every interpreter except the main one (best effort).

    The main interpreter is identified with ``interpreters.get_main()``
    instead of a hard-coded ID.  A RuntimeError from ``destroy()`` is
    ignored.
    """
    main = interpreters.get_main()
    for interp_id in interpreters.list_all():
        if interp_id == main:
            continue
        try:
            interpreters.destroy(interp_id)
        except RuntimeError:
            pass  # already destroyed
270
271
def clean_up_channels():
    """Destroy every channel that is still listed (best effort)."""
    for channel_id in interpreters.channel_list_all():
        # A channel that is already gone is not an error here.
        with contextlib.suppress(interpreters.ChannelNotFoundError):
            interpreters.channel_destroy(channel_id)
278
279
class TestBase(unittest.TestCase):
    # Base class for the tests that create interpreters or channels:
    # whatever a test leaves behind is destroyed after it runs.

    def tearDown(self):
        clean_up_interpreters()
        clean_up_channels()
285
286
287##################################
288# misc. tests
289
class IsShareableTests(unittest.TestCase):

    def test_default_shareables(self):
        # Each of these objects must be reported as shareable.
        singletons = [None]
        builtin_objects = [b'spam', 'spam', 10, -10]
        for candidate in singletons + builtin_objects:
            with self.subTest(candidate):
                shareable = interpreters.is_shareable(candidate)
                self.assertTrue(shareable)

    def test_not_shareable(self):
        # None of these objects may be reported as shareable.
        class Cheese:
            def __init__(self, name):
                self.name = name
            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        singletons = [True, False, NotImplemented, ...]
        builtin_types_and_objects = [
            type,
            object,
            object(),
            Exception(),
            100.0,
        ]
        user_defined = [
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        ]
        candidates = singletons + builtin_types_and_objects + user_defined
        for candidate in candidates:
            with self.subTest(repr(candidate)):
                shareable = interpreters.is_shareable(candidate)
                self.assertFalse(shareable)
338
339
class ShareableTypeTests(unittest.TestCase):

    def setUp(self):
        super().setUp()
        # Every test pushes its values through a channel of its own.
        self.cid = interpreters.channel_create()

    def tearDown(self):
        interpreters.channel_destroy(self.cid)
        super().tearDown()

    def _roundtrip(self, obj):
        # Send *obj* through the test channel and return what comes out.
        interpreters.channel_send(self.cid, obj)
        return interpreters.channel_recv(self.cid)

    def _assert_values(self, values):
        # Each value must come back equal and with exactly the same type.
        for expected in values:
            with self.subTest(expected):
                received = self._roundtrip(expected)

                self.assertEqual(received, expected)
                self.assertIs(type(received), type(expected))
                # XXX Check the following in the channel tests?
                #self.assertIsNot(got, obj)

    def test_singletons(self):
        for singleton in (None,):
            with self.subTest(singleton):
                received = self._roundtrip(singleton)

                # XXX What about between interpreters?
                self.assertIs(received, singleton)

    def test_types(self):
        samples = [
            b'spam',
            9999,
            self.cid,
        ]
        self._assert_values(samples)

    def test_bytes(self):
        encoded = (i.to_bytes(2, 'little', signed=True)
                   for i in range(-1, 258))
        self._assert_values(encoded)

    def test_strs(self):
        samples = ['hello world', '你好世界', '']
        self._assert_values(samples)

    def test_int(self):
        extremes = [sys.maxsize, -sys.maxsize - 1]
        self._assert_values(itertools.chain(range(-1, 258), extremes))

    def test_non_shareable_int(self):
        # Integers outside the tested range must be rejected on send.
        out_of_range = [
            sys.maxsize + 1,
            -sys.maxsize - 2,
            2**1000,
        ]
        for value in out_of_range:
            with self.subTest(value):
                self.assertRaises(OverflowError,
                                  interpreters.channel_send, self.cid, value)
398
399
400##################################
401# interpreter tests
402
class ListAllTests(TestBase):

    def test_initial(self):
        # With nothing created, only the main interpreter is listed.
        expected = [interpreters.get_main()]
        self.assertEqual(interpreters.list_all(), expected)

    def test_after_creating(self):
        # New interpreters are listed after main, in creation order.
        expected = [interpreters.get_main()]
        expected.append(interpreters.create())
        expected.append(interpreters.create())
        self.assertEqual(interpreters.list_all(), expected)

    def test_after_destroying(self):
        # A destroyed interpreter no longer shows up in the listing.
        main = interpreters.get_main()
        doomed = interpreters.create()
        survivor = interpreters.create()
        interpreters.destroy(doomed)
        self.assertEqual(interpreters.list_all(), [main, survivor])
424
425
class GetCurrentTests(TestBase):

    def test_main(self):
        # Called from the test itself, get_current() reports main.
        expected = interpreters.get_main()
        current = interpreters.get_current()
        self.assertEqual(current, expected)
        self.assertIsInstance(current, interpreters.InterpreterID)

    def test_subinterpreter(self):
        # Called inside a subinterpreter, get_current() reports that
        # subinterpreter's own ID.
        main = interpreters.get_main()
        interp = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            cur = _interpreters.get_current()
            print(cur)
            assert isinstance(cur, _interpreters.InterpreterID)
            """)
        reported = int(_run_output(interp, script).strip())
        # The unpacking requires exactly two listed interpreters.
        _, expected = interpreters.list_all()
        self.assertEqual(reported, expected)
        self.assertNotEqual(reported, main)
447
448
class GetMainTests(TestBase):

    def test_from_main(self):
        # The unpacking requires that exactly one interpreter is listed.
        [only] = interpreters.list_all()
        main = interpreters.get_main()
        self.assertEqual(main, only)
        self.assertIsInstance(main, interpreters.InterpreterID)

    def test_from_subinterpreter(self):
        # get_main() gives the same answer when asked from inside a
        # subinterpreter.
        [only] = interpreters.list_all()
        interp = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            main = _interpreters.get_main()
            print(main)
            assert isinstance(main, _interpreters.InterpreterID)
            """)
        reported = int(_run_output(interp, script).strip())
        self.assertEqual(reported, only)
468
469
class IsRunningTests(TestBase):

    def test_main(self):
        main = interpreters.get_main()
        running = interpreters.is_running(main)
        self.assertTrue(running)

    def test_subinterpreter(self):
        # Not running before, running during, not running after.
        interp = interpreters.create()
        before = interpreters.is_running(interp)
        self.assertFalse(before)

        with _running(interp):
            during = interpreters.is_running(interp)
            self.assertTrue(during)
        after = interpreters.is_running(interp)
        self.assertFalse(after)

    def test_from_subinterpreter(self):
        # A subinterpreter asked about itself reports that it is running.
        interp = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp}):
                print(True)
            else:
                print(False)
            """)
        printed = _run_output(interp, script)
        self.assertEqual(printed.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        self.assertRaises(RuntimeError, interpreters.is_running, interp)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.is_running, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(ValueError, interpreters.is_running, -1)
508
509
class InterpreterIDTests(TestBase):

    def test_with_int(self):
        id = interpreters.InterpreterID(10, force=True)

        self.assertEqual(int(id), 10)

    def test_coerce_id(self):
        # The value is taken from __index__(), even though the object is
        # a str subclass (plain str is rejected in test_bad_id).
        class Int(str):
            def __index__(self):
                return 10

        id = interpreters.InterpreterID(Int(), force=True)
        self.assertEqual(int(id), 10)

    def test_bad_id(self):
        self.assertRaises(TypeError, interpreters.InterpreterID, object())
        self.assertRaises(TypeError, interpreters.InterpreterID, 10.0)
        self.assertRaises(TypeError, interpreters.InterpreterID, '10')
        self.assertRaises(TypeError, interpreters.InterpreterID, b'10')
        self.assertRaises(ValueError, interpreters.InterpreterID, -1)
        self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64)

    def test_does_not_exist(self):
        # The probe value is derived from a real interpreter ID; the
        # original used a channel ID, which lives in a different ID space
        # and only worked by coincidence.
        # NOTE(review): assumes the ID right after the newest interpreter
        # is not in use yet -- confirm against the ID allocation scheme.
        id = interpreters.create()
        with self.assertRaises(RuntimeError):
            interpreters.InterpreterID(int(id) + 1)  # unforced

    def test_str(self):
        id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(str(id), '10')

    def test_repr(self):
        id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(repr(id), 'InterpreterID(10)')

    def test_equality(self):
        id1 = interpreters.create()
        id2 = interpreters.InterpreterID(int(id1))
        id3 = interpreters.create()

        # Equal to itself, to another ID object with the same value, and
        # to ints/floats holding that value (in either operand order).
        self.assertTrue(id1 == id1)
        self.assertTrue(id1 == id2)
        self.assertTrue(id1 == int(id1))
        self.assertTrue(int(id1) == id1)
        self.assertTrue(id1 == float(int(id1)))
        self.assertTrue(float(int(id1)) == id1)
        # Not equal to near-miss numbers, strings, or a different ID.
        self.assertFalse(id1 == float(int(id1)) + 0.1)
        self.assertFalse(id1 == str(int(id1)))
        self.assertFalse(id1 == 2**1000)
        self.assertFalse(id1 == float('inf'))
        self.assertFalse(id1 == 'spam')
        self.assertFalse(id1 == id3)

        self.assertFalse(id1 != id1)
        self.assertFalse(id1 != id2)
        self.assertTrue(id1 != id3)
567
568
class CreateTests(TestBase):

    def test_in_main(self):
        created = interpreters.create()
        self.assertIsInstance(created, interpreters.InterpreterID)

        self.assertIn(created, interpreters.list_all())

    @unittest.skip('enable this test when working on pystate.c')
    def test_unique_id(self):
        # IDs of destroyed interpreters must not be handed out again.
        seen = set()
        for _ in range(100):
            created = interpreters.create()
            interpreters.destroy(created)
            seen.add(created)

        self.assertEqual(len(seen), 100)

    def test_in_thread(self):
        gate = threading.Lock()
        created = None

        def create_then_wait():
            nonlocal created
            created = interpreters.create()
            # Block until the main thread lets go of the lock.
            with gate:
                pass

        worker = threading.Thread(target=create_then_wait)
        with gate:
            worker.start()
        worker.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            id = _interpreters.create()
            print(id)
            assert isinstance(id, _interpreters.InterpreterID)
            """)
        id2 = int(_run_output(id1, script).strip())

        self.assertEqual(set(interpreters.list_all()), {main, id1, id2})

    def test_in_threaded_subinterpreter(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = None

        def create_nested():
            nonlocal id2
            printed = _run_output(id1, dedent("""
                import _xxsubinterpreters as _interpreters
                id = _interpreters.create()
                print(id)
                """))
            id2 = int(printed.strip())

        worker = threading.Thread(target=create_nested)
        worker.start()
        worker.join()

        self.assertEqual(set(interpreters.list_all()), {main, id1, id2})

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        created = [interpreters.create() for _ in range(3)]
        # Now destroy them.
        for interp in created:
            interpreters.destroy(interp)
        # Finally, create another.
        newest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {newest})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        first = interpreters.create()
        kept = interpreters.create()
        third = interpreters.create()
        # Now destroy 2 of them.
        for interp in (first, third):
            interpreters.destroy(interp)
        # Finally, create another.
        newest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         before | {newest, kept})
660
661
class DestroyTests(TestBase):

    def test_one(self):
        # Destroying one interpreter leaves its siblings alone.
        first = interpreters.create()
        target = interpreters.create()
        last = interpreters.create()
        self.assertIn(target, interpreters.list_all())
        interpreters.destroy(target)
        self.assertNotIn(target, interpreters.list_all())
        self.assertIn(first, interpreters.list_all())
        self.assertIn(last, interpreters.list_all())

    def test_all(self):
        before = set(interpreters.list_all())
        created = {interpreters.create() for _ in range(3)}
        self.assertEqual(set(interpreters.list_all()), before | created)
        for interp in created:
            interpreters.destroy(interp)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        # The main interpreter cannot be destroyed, whether the attempt
        # comes from the current thread or from another one.
        main, = interpreters.list_all()

        def attempt():
            with self.assertRaises(RuntimeError):
                interpreters.destroy(main)

        attempt()

        worker = threading.Thread(target=attempt)
        worker.start()
        worker.join()

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        self.assertRaises(RuntimeError, interpreters.destroy, interp)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.destroy, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(ValueError, interpreters.destroy, -1)

    def test_from_current(self):
        # The script tries to destroy the interpreter it is running in;
        # afterwards that interpreter must still be listed.
        main, = interpreters.list_all()
        id = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            try:
                _interpreters.destroy({id})
            except RuntimeError:
                pass
            """)

        interpreters.run_string(id, script)
        remaining = set(interpreters.list_all())
        self.assertEqual(remaining, {main, id})

    def test_from_sibling(self):
        # One subinterpreter destroys another one.
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({id2})
            """)
        interpreters.run_string(id1, script)

        remaining = set(interpreters.list_all())
        self.assertEqual(remaining, {main, id1})

    def test_from_other_thread(self):
        interp = interpreters.create()

        worker = threading.Thread(target=interpreters.destroy,
                                  args=(interp,))
        worker.start()
        worker.join()

    def test_still_running(self):
        # The unpacking requires that only main exists beforehand.
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            running = interpreters.is_running(interp)
            self.assertTrue(running,
                            msg=f"Interp {interp} should be running before destruction.")

            expect_failure = self.assertRaises(
                RuntimeError,
                msg=f"Should not be able to destroy interp {interp} while it's still running.")
            with expect_failure:
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))
758
759
760class RunStringTests(TestBase):
761
    # Template for a script that writes a string to a file.  The two '{}'
    # placeholders are, in order, the file path and the text to write.
    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    # NOTE(review): presumably the path meant to be used with SCRIPT; no
    # use of either attribute is visible in this part of the file.
    FILENAME = 'spam'
767
    def setUp(self):
        super().setUp()
        # Every test runs its scripts in a freshly created interpreter.
        self.id = interpreters.create()
        # Optional object with a close() method; tearDown() closes it
        # when a test has set it.
        self._fs = None
772
    def tearDown(self):
        # Close whatever a test stored in self._fs before the base class
        # cleanup runs.
        if self._fs is not None:
            self._fs.close()
        super().tearDown()
777
778    def test_success(self):
779        script, file = _captured_script('print("it worked!", end="")')
780        with file:
781            interpreters.run_string(self.id, script)
782            out = file.read()
783
784        self.assertEqual(out, 'it worked!')
785
786    def test_in_thread(self):
787        script, file = _captured_script('print("it worked!", end="")')
788        with file:
789            def f():
790                interpreters.run_string(self.id, script)
791
792            t = threading.Thread(target=f)
793            t.start()
794            t.join()
795            out = file.read()
796
797        self.assertEqual(out, 'it worked!')
798
799    def test_create_thread(self):
800        subinterp = interpreters.create(isolated=False)
801        script, file = _captured_script("""
802            import threading
803            def f():
804                print('it worked!', end='')
805
806            t = threading.Thread(target=f)
807            t.start()
808            t.join()
809            """)
810        with file:
811            interpreters.run_string(subinterp, script)
812            out = file.read()
813
814        self.assertEqual(out, 'it worked!')
815
816    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
817    def test_fork(self):
818        import tempfile
819        with tempfile.NamedTemporaryFile('w+') as file:
820            file.write('')
821            file.flush()
822
823            expected = 'spam spam spam spam spam'
824            script = dedent(f"""
825                import os
826                try:
827                    os.fork()
828                except RuntimeError:
829                    with open('{file.name}', 'w') as out:
830                        out.write('{expected}')
831                """)
832            interpreters.run_string(self.id, script)
833
834            file.seek(0)
835            content = file.read()
836            self.assertEqual(content, expected)
837
838    def test_already_running(self):
839        with _running(self.id):
840            with self.assertRaises(RuntimeError):
841                interpreters.run_string(self.id, 'print("spam")')
842
843    def test_does_not_exist(self):
844        id = 0
845        while id in interpreters.list_all():
846            id += 1
847        with self.assertRaises(RuntimeError):
848            interpreters.run_string(id, 'print("spam")')
849
850    def test_error_id(self):
851        with self.assertRaises(ValueError):
852            interpreters.run_string(-1, 'print("spam")')
853
854    def test_bad_id(self):
855        with self.assertRaises(TypeError):
856            interpreters.run_string('spam', 'print("spam")')
857
858    def test_bad_script(self):
859        with self.assertRaises(TypeError):
860            interpreters.run_string(self.id, 10)
861
862    def test_bytes_for_script(self):
863        with self.assertRaises(TypeError):
864            interpreters.run_string(self.id, b'print("spam")')
865
866    @contextlib.contextmanager
867    def assert_run_failed(self, exctype, msg=None):
868        with self.assertRaises(interpreters.RunFailedError) as caught:
869            yield
870        if msg is None:
871            self.assertEqual(str(caught.exception).split(':')[0],
872                             str(exctype))
873        else:
874            self.assertEqual(str(caught.exception),
875                             "{}: {}".format(exctype, msg))
876
877    def test_invalid_syntax(self):
878        with self.assert_run_failed(SyntaxError):
879            # missing close paren
880            interpreters.run_string(self.id, 'print("spam"')
881
882    def test_failure(self):
883        with self.assert_run_failed(Exception, 'spam'):
884            interpreters.run_string(self.id, 'raise Exception("spam")')
885
886    def test_SystemExit(self):
887        with self.assert_run_failed(SystemExit, '42'):
888            interpreters.run_string(self.id, 'raise SystemExit(42)')
889
890    def test_sys_exit(self):
891        with self.assert_run_failed(SystemExit):
892            interpreters.run_string(self.id, dedent("""
893                import sys
894                sys.exit()
895                """))
896
897        with self.assert_run_failed(SystemExit, '42'):
898            interpreters.run_string(self.id, dedent("""
899                import sys
900                sys.exit(42)
901                """))
902
903    def test_with_shared(self):
904        r, w = os.pipe()
905
906        shared = {
907                'spam': b'ham',
908                'eggs': b'-1',
909                'cheddar': None,
910                }
911        script = dedent(f"""
912            eggs = int(eggs)
913            spam = 42
914            result = spam + eggs
915
916            ns = dict(vars())
917            del ns['__builtins__']
918            import pickle
919            with open({w}, 'wb') as chan:
920                pickle.dump(ns, chan)
921            """)
922        interpreters.run_string(self.id, script, shared)
923        with open(r, 'rb') as chan:
924            ns = pickle.load(chan)
925
926        self.assertEqual(ns['spam'], 42)
927        self.assertEqual(ns['eggs'], -1)
928        self.assertEqual(ns['result'], 41)
929        self.assertIsNone(ns['cheddar'])
930
931    def test_shared_overwrites(self):
932        interpreters.run_string(self.id, dedent("""
933            spam = 'eggs'
934            ns1 = dict(vars())
935            del ns1['__builtins__']
936            """))
937
938        shared = {'spam': b'ham'}
939        script = dedent(f"""
940            ns2 = dict(vars())
941            del ns2['__builtins__']
942        """)
943        interpreters.run_string(self.id, script, shared)
944
945        r, w = os.pipe()
946        script = dedent(f"""
947            ns = dict(vars())
948            del ns['__builtins__']
949            import pickle
950            with open({w}, 'wb') as chan:
951                pickle.dump(ns, chan)
952            """)
953        interpreters.run_string(self.id, script)
954        with open(r, 'rb') as chan:
955            ns = pickle.load(chan)
956
957        self.assertEqual(ns['ns1']['spam'], 'eggs')
958        self.assertEqual(ns['ns2']['spam'], b'ham')
959        self.assertEqual(ns['spam'], b'ham')
960
961    def test_shared_overwrites_default_vars(self):
962        r, w = os.pipe()
963
964        shared = {'__name__': b'not __main__'}
965        script = dedent(f"""
966            spam = 42
967
968            ns = dict(vars())
969            del ns['__builtins__']
970            import pickle
971            with open({w}, 'wb') as chan:
972                pickle.dump(ns, chan)
973            """)
974        interpreters.run_string(self.id, script, shared)
975        with open(r, 'rb') as chan:
976            ns = pickle.load(chan)
977
978        self.assertEqual(ns['__name__'], b'not __main__')
979
980    def test_main_reused(self):
981        r, w = os.pipe()
982        interpreters.run_string(self.id, dedent(f"""
983            spam = True
984
985            ns = dict(vars())
986            del ns['__builtins__']
987            import pickle
988            with open({w}, 'wb') as chan:
989                pickle.dump(ns, chan)
990            del ns, pickle, chan
991            """))
992        with open(r, 'rb') as chan:
993            ns1 = pickle.load(chan)
994
995        r, w = os.pipe()
996        interpreters.run_string(self.id, dedent(f"""
997            eggs = False
998
999            ns = dict(vars())
1000            del ns['__builtins__']
1001            import pickle
1002            with open({w}, 'wb') as chan:
1003                pickle.dump(ns, chan)
1004            """))
1005        with open(r, 'rb') as chan:
1006            ns2 = pickle.load(chan)
1007
1008        self.assertIn('spam', ns1)
1009        self.assertNotIn('eggs', ns1)
1010        self.assertIn('eggs', ns2)
1011        self.assertIn('spam', ns2)
1012
1013    def test_execution_namespace_is_main(self):
1014        r, w = os.pipe()
1015
1016        script = dedent(f"""
1017            spam = 42
1018
1019            ns = dict(vars())
1020            ns['__builtins__'] = str(ns['__builtins__'])
1021            import pickle
1022            with open({w}, 'wb') as chan:
1023                pickle.dump(ns, chan)
1024            """)
1025        interpreters.run_string(self.id, script)
1026        with open(r, 'rb') as chan:
1027            ns = pickle.load(chan)
1028
1029        ns.pop('__builtins__')
1030        ns.pop('__loader__')
1031        self.assertEqual(ns, {
1032            '__name__': '__main__',
1033            '__annotations__': {},
1034            '__doc__': None,
1035            '__package__': None,
1036            '__spec__': None,
1037            'spam': 42,
1038            })
1039
1040    # XXX Fix this test!
1041    @unittest.skip('blocking forever')
1042    def test_still_running_at_exit(self):
1043        script = dedent(f"""
1044        from textwrap import dedent
1045        import threading
1046        import _xxsubinterpreters as _interpreters
1047        id = _interpreters.create()
1048        def f():
1049            _interpreters.run_string(id, dedent('''
1050                import time
1051                # Give plenty of time for the main interpreter to finish.
1052                time.sleep(1_000_000)
1053                '''))
1054
1055        t = threading.Thread(target=f)
1056        t.start()
1057        """)
1058        with support.temp_dir() as dirname:
1059            filename = script_helper.make_script(dirname, 'interp', script)
1060            with script_helper.spawn_python(filename) as proc:
1061                retcode = proc.wait()
1062
1063        self.assertEqual(retcode, 0)
1064
1065
1066##################################
1067# channel tests
1068
class ChannelIDTests(TestBase):
    """Tests for the object returned by ``interpreters._channel_id()``."""

    def test_default_kwargs(self):
        """Without send/recv arguments the ID reports the 'both' end."""
        chan_id = interpreters._channel_id(10, force=True)

        self.assertEqual(int(chan_id), 10)
        self.assertEqual(chan_id.end, 'both')

    def test_with_kwargs(self):
        """The send/recv keywords select the end reported by the ID."""
        cases = [
            (dict(send=True), 'send'),
            (dict(send=True, recv=False), 'send'),
            (dict(recv=True), 'recv'),
            (dict(recv=True, send=False), 'recv'),
            (dict(send=True, recv=True), 'both'),
            ]
        for kwargs, expected_end in cases:
            chan_id = interpreters._channel_id(10, **kwargs, force=True)
            self.assertEqual(chan_id.end, expected_end)

    def test_coerce_id(self):
        """An object providing __index__() is accepted as the ID value."""
        class IndexableStr(str):
            def __index__(self):
                return 10

        chan_id = interpreters._channel_id(IndexableStr(), force=True)
        self.assertEqual(int(chan_id), 10)

    def test_bad_id(self):
        """Unsuitable ID values are rejected with the matching exception."""
        rejected = [
            (TypeError, object()),
            (TypeError, 10.0),
            (TypeError, '10'),
            (TypeError, b'10'),
            (ValueError, -1),
            (OverflowError, 2**64),
            ]
        for exc_type, bad_id in rejected:
            with self.assertRaises(exc_type):
                interpreters._channel_id(bad_id)

    def test_bad_kwargs(self):
        """Disabling both ends at once raises ValueError."""
        self.assertRaises(ValueError, interpreters._channel_id, 10,
                          send=False, recv=False)

    def test_does_not_exist(self):
        """An unforced ID for a channel that was never created fails."""
        existing = interpreters.channel_create()
        missing = int(existing) + 1
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(missing)  # unforced

    def test_str(self):
        """str() of an ID is the decimal channel number."""
        chan_id = interpreters._channel_id(10, force=True)
        self.assertEqual(str(chan_id), '10')

    def test_repr(self):
        """repr() mentions an end only when a single end was selected."""
        cases = [
            (dict(), 'ChannelID(10)'),
            (dict(send=True), 'ChannelID(10, send=True)'),
            (dict(recv=True), 'ChannelID(10, recv=True)'),
            (dict(send=True, recv=True), 'ChannelID(10)'),
            ]
        for kwargs, expected_repr in cases:
            chan_id = interpreters._channel_id(10, **kwargs, force=True)
            self.assertEqual(repr(chan_id), expected_repr)

    def test_equality(self):
        """Exercise == and != against IDs, numbers and unrelated objects."""
        first = interpreters.channel_create()
        number = int(first)
        same = interpreters._channel_id(number)
        other = interpreters.channel_create()

        # Equal: itself, another ID for the same channel, equal numbers
        # (with the ID on either side of the operator).
        self.assertTrue(first == first)
        self.assertTrue(first == same)
        self.assertTrue(first == number)
        self.assertTrue(number == first)
        self.assertTrue(first == float(number))
        self.assertTrue(float(number) == first)

        # Not equal: nearby floats, strings, huge/infinite numbers and a
        # different channel.
        self.assertFalse(first == float(number) + 0.1)
        self.assertFalse(first == str(number))
        self.assertFalse(first == 2**1000)
        self.assertFalse(first == float('inf'))
        self.assertFalse(first == 'spam')
        self.assertFalse(first == other)

        self.assertFalse(first != first)
        self.assertFalse(first != same)
        self.assertTrue(first != other)
1156
1157
class ChannelTests(TestBase):
    """Tests for the low-level channel functions of ``_xxsubinterpreters``.

    Covers channel creation, listing the interpreters associated with each
    end, sending/receiving, and closing channels.
    """

    def test_create_cid(self):
        """channel_create() returns a ChannelID instance."""
        cid = interpreters.channel_create()
        self.assertIsInstance(cid, interpreters.ChannelID)

    def test_sequential_ids(self):
        """Consecutively created channels get consecutive IDs."""
        before = interpreters.channel_list_all()
        id1 = interpreters.channel_create()
        id2 = interpreters.channel_create()
        id3 = interpreters.channel_create()
        after = interpreters.channel_list_all()

        self.assertEqual(id2, int(id1) + 1)
        self.assertEqual(id3, int(id2) + 1)
        self.assertEqual(set(after) - set(before), {id1, id2, id3})

    def test_ids_global(self):
        """Channels created in different interpreters share one ID sequence."""
        id1 = interpreters.create()
        out = _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(cid)
            """))
        cid1 = int(out.strip())

        id2 = interpreters.create()
        out = _run_output(id2, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(cid)
            """))
        cid2 = int(out.strip())

        # Both values were already converted to int above.
        self.assertEqual(cid2, cid1 + 1)

    def test_channel_list_interpreters_none(self):
        """Test listing interpreters for a channel with no associations."""
        # Test for channel with no associated interpreters.
        cid = interpreters.channel_create()
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [])
        self.assertEqual(recv_interps, [])

    def test_channel_list_interpreters_basic(self):
        """Test basic listing channel interpreters."""
        interp0 = interpreters.get_main()
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, "send")
        # Test for a channel that has one end associated to an interpreter.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(recv_interps, [])

        interp1 = interpreters.create()
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        # Test for channel that has both ends associated to an interpreter.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(recv_interps, [interp1])

    def test_channel_list_interpreters_multiple(self):
        """Test listing interpreters for a channel with many associations."""
        interp0 = interpreters.get_main()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        cid = interpreters.channel_create()

        # Two senders (main and interp1) and two receivers (interp2, interp3).
        interpreters.channel_send(cid, "send")
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, "send")
            """))
        _run_output(interp2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        _run_output(interp3, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(set(send_interps), {interp0, interp1})
        self.assertEqual(set(recv_interps), {interp2, interp3})

    def test_channel_list_interpreters_destroyed(self):
        """Test listing channel interpreters with a destroyed interpreter."""
        interp0 = interpreters.get_main()
        interp1 = interpreters.create()
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, "send")
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        # Should be one interpreter associated with each end.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(recv_interps, [interp1])

        interpreters.destroy(interp1)
        # Destroyed interpreter should not be listed.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(recv_interps, [])

    def test_channel_list_interpreters_released(self):
        """Test listing channel interpreters with a released channel."""
        # Set up one channel with main interpreter on the send end and two
        # subinterpreters on the receive end.
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, "data")
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        interpreters.channel_send(cid, "data")
        _run_output(interp2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        # Check the setup.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(send_interps), 1)
        self.assertEqual(len(recv_interps), 2)

        # Release the main interpreter from the send end.
        interpreters.channel_release(cid, send=True)
        # Send end should have no associated interpreters.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(send_interps), 0)
        self.assertEqual(len(recv_interps), 2)

        # Release one of the subinterpreters from the receive end.
        _run_output(interp2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_release({cid})
            """))
        # Receive end should have the released interpreter removed.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(send_interps), 0)
        self.assertEqual(recv_interps, [interp1])

    def test_channel_list_interpreters_closed(self):
        """Test listing channel interpreters with a closed channel."""
        cid = interpreters.channel_create()
        # Put something in the channel so that it's not empty.
        interpreters.channel_send(cid, "send")

        # Check initial state.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(send_interps), 1)
        self.assertEqual(len(recv_interps), 0)

        # Force close the channel.
        interpreters.channel_close(cid, force=True)
        # Both ends should raise an error.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=True)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=False)

    def test_channel_list_interpreters_closed_send_end(self):
        """Test listing channel interpreters with a channel's send end closed."""
        interp1 = interpreters.create()
        cid = interpreters.channel_create()
        # Put something in the channel so that it's not empty.
        interpreters.channel_send(cid, "send")

        # Check initial state.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(send_interps), 1)
        self.assertEqual(len(recv_interps), 0)

        # Close the send end of the channel.
        interpreters.channel_close(cid, send=True)
        # Send end should raise an error.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=True)
        # Receive end should not be closed (since channel is not empty).
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(recv_interps), 0)

        # Close the receive end of the channel from a subinterpreter.
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({cid}, force=True)
            """))
        # Both ends should raise an error.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=True)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=False)

    ####################

    def test_send_recv_main(self):
        """An object sent in the main interpreter comes back as a copy."""
        cid = interpreters.channel_create()
        orig = b'spam'
        interpreters.channel_send(cid, orig)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_same_interpreter(self):
        """Send and receive within a single subinterpreter."""
        id1 = interpreters.create()
        # The checks are the assert statements inside the script; its
        # captured output is not needed.
        _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            orig = b'spam'
            _interpreters.channel_send(cid, orig)
            obj = _interpreters.channel_recv(cid)
            assert obj is not orig
            assert obj == orig
            """))

    def test_send_recv_different_interpreters(self):
        """Receive in the main interpreter what a subinterpreter sent."""
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        _run_output(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_recv_different_threads(self):
        """Echo an object through a channel using a second thread."""
        cid = interpreters.channel_create()

        def f():
            # Poll until the main thread has sent something, then send the
            # received object back on the same channel.
            while True:
                try:
                    obj = interpreters.channel_recv(cid)
                    break
                except interpreters.ChannelEmptyError:
                    time.sleep(0.1)
            interpreters.channel_send(cid, obj)
        t = threading.Thread(target=f)
        t.start()

        interpreters.channel_send(cid, b'spam')
        t.join()
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_recv_different_interpreters_and_threads(self):
        """Exchange objects with a subinterpreter running in another thread."""
        cid = interpreters.channel_create()
        id1 = interpreters.create()

        def f():
            # The script polls for b'spam' and answers with b'eggs'; its
            # captured output is not needed.
            _run_output(id1, dedent(f"""
                import time
                import _xxsubinterpreters as _interpreters
                while True:
                    try:
                        obj = _interpreters.channel_recv({cid})
                        break
                    except _interpreters.ChannelEmptyError:
                        time.sleep(0.1)
                assert(obj == b'spam')
                _interpreters.channel_send({cid}, b'eggs')
                """))
        t = threading.Thread(target=f)
        t.start()

        interpreters.channel_send(cid, b'spam')
        t.join()
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'eggs')

    def test_send_not_found(self):
        """Sending on an unknown channel ID raises ChannelNotFoundError."""
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_send(10, b'spam')

    def test_recv_not_found(self):
        """Receiving on an unknown channel ID raises ChannelNotFoundError."""
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_recv(10)

    def test_recv_empty(self):
        """Receiving from an empty channel raises ChannelEmptyError."""
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelEmptyError):
            interpreters.channel_recv(cid)

    def test_recv_default(self):
        """channel_recv() returns the default only when nothing is queued."""
        default = object()
        cid = interpreters.channel_create()
        obj1 = interpreters.channel_recv(cid, default)
        interpreters.channel_send(cid, None)
        interpreters.channel_send(cid, 1)
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'eggs')
        obj2 = interpreters.channel_recv(cid, default)
        obj3 = interpreters.channel_recv(cid, default)
        obj4 = interpreters.channel_recv(cid)
        obj5 = interpreters.channel_recv(cid, default)
        obj6 = interpreters.channel_recv(cid, default)

        self.assertIs(obj1, default)
        self.assertIs(obj2, None)
        self.assertEqual(obj3, 1)
        self.assertEqual(obj4, b'spam')
        self.assertEqual(obj5, b'eggs')
        self.assertIs(obj6, default)

    def test_run_string_arg_unresolved(self):
        """A channel ID end passed via ``shared`` is usable in the script."""
        cid = interpreters.channel_create()
        interp = interpreters.create()

        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(cid.end)
            _interpreters.channel_send(cid, b'spam')
            """),
            dict(cid=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')

    # XXX For now there is no high-level channel into which the
    # sent channel ID can be converted...
    # Note: this test caused crashes on some buildbots (bpo-33615).
    @unittest.skip('disabled until high-level channels exist')
    def test_run_string_arg_resolved(self):
        """Like the unresolved variant, but with a resolved channel ID."""
        cid = interpreters.channel_create()
        cid = interpreters._channel_id(cid, _resolve=True)
        interp = interpreters.create()

        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(chan.id.end)
            _interpreters.channel_send(chan.id, b'spam')
            """),
            dict(chan=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')

    # close

    def test_close_single_user(self):
        """After closing an emptied channel, send and recv both fail."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_multiple_users(self):
        """Closing in main makes later sends from subinterpreters fail."""
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        interpreters.run_string(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_recv({cid})
            """))
        interpreters.channel_close(cid)
        # ``_interpreters`` is still bound in each subinterpreter from the
        # earlier run, so the scripts below do not import it again.
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id1, dedent(f"""
                _interpreters.channel_send({cid}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id2, dedent(f"""
                _interpreters.channel_send({cid}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))

    def test_close_multiple_times(self):
        """Closing an already closed channel raises ChannelClosedError."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_empty(self):
        """Every send/recv combination closes an emptied channel fully."""
        tests = [
            (False, False),
            (True, False),
            (False, True),
            (True, True),
            ]
        for send, recv in tests:
            with self.subTest((send, recv)):
                cid = interpreters.channel_create()
                interpreters.channel_send(cid, b'spam')
                interpreters.channel_recv(cid)
                interpreters.channel_close(cid, send=send, recv=recv)

                with self.assertRaises(interpreters.ChannelClosedError):
                    interpreters.channel_send(cid, b'eggs')
                with self.assertRaises(interpreters.ChannelClosedError):
                    interpreters.channel_recv(cid)

    def test_close_defaults_with_unused_items(self):
        """A default close of a non-empty channel is refused."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid)
        # The channel is still usable after the refused close.
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')

    def test_close_recv_with_unused_items_unforced(self):
        """An unforced recv-end close is refused until the channel is empty."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid, recv=True)
        # Drain the channel, then the close goes through.
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid, recv=True)

    def test_close_send_with_unused_items_unforced(self):
        """Closing the send end lets queued items be received first."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        # The two queued items can still be received...
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        # ...after which receiving reports the channel as closed.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_both_with_unused_items_unforced(self):
        """An unforced close of both ends is refused while items remain."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid, recv=True, send=True)
        # Drain the channel, then a close goes through.
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid, recv=True)

    def test_close_recv_with_unused_items_forced(self):
        """A forced recv-end close of a non-empty channel closes it."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, recv=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_send_with_unused_items_forced(self):
        """A forced send-end close of a non-empty channel closes it."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_both_with_unused_items_forced(self):
        """A forced close of both ends of a non-empty channel closes it."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True, recv=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_never_used(self):
        """A channel that was never used can be closed."""
        cid = interpreters.channel_create()
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_by_unassociated_interp(self):
        """An interpreter that never used the channel can force-close it."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({cid}, force=True)
            """))
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_used_multiple_times_by_single_user(self):
        """Force-close a channel that still holds unreceived items."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_channel_list_interpreters_invalid_channel(self):
        """Listing interpreters fails for unknown and for closed channels."""
        cid = interpreters.channel_create()
        # Test for invalid channel ID.
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_list_interpreters(1000, send=True)

        interpreters.channel_close(cid)
        # Test for a channel that has been closed.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=True)

    def test_channel_list_interpreters_invalid_args(self):
        """Omitting the ``send`` argument raises TypeError."""
        # Tests for invalid arguments passed to the API.
        cid = interpreters.channel_create()
        with self.assertRaises(TypeError):
            interpreters.channel_list_interpreters(cid)
1722
1723
class ChannelReleaseTests(TestBase):

    # Tests for interpreters.channel_release().
    # XXX Add more test coverage a la the tests for close().

    """
    - main / interp / other
    - run in: current thread / new thread / other thread / different threads
    - end / opposite
    - force / no force
    - used / not used  (associated / not associated)
    - empty / emptied / never emptied / partly emptied
    - closed / not closed
    - released / not released
    - creator (interp) / other
    - associated interpreter not running
    - associated interpreter destroyed
    """

    """
    use
    pre-release
    release
    after
    check
    """

    """
    release in:         main, interp1
    creator:            same, other (incl. interp2)

    use:                None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    pre-release:        None,send,recv,both in None,same,other(incl. interp2),same+other(incl. interp2),all
    pre-release forced: None,send,recv,both in None,same,other(incl. interp2),same+other(incl. interp2),all

    release:            same
    release forced:     same

    use after:          None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    release after:      None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    check released:     send/recv for same/other(incl. interp2)
    check closed:       send/recv for same/other(incl. interp2)
    """

    def test_single_user(self):
        # Main is the only user: after one send/recv round trip it
        # releases both ends, and further use must fail as "closed".
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid, send=True, recv=True)

        closed = interpreters.ChannelClosedError
        self.assertRaises(closed, interpreters.channel_send, cid, b'eggs')
        self.assertRaises(closed, interpreters.channel_recv, cid)

    def test_multiple_users(self):
        # One subinterpreter sends, a second one receives and releases,
        # then the sender releases too.  The receiver's printed repr of
        # the object is captured and checked at the end.
        cid = interpreters.channel_create()
        sender = interpreters.create()
        receiver = interpreters.create()
        interpreters.run_string(sender, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        captured = _run_output(receiver, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            _interpreters.channel_release({cid})
            print(repr(obj))
            """))
        # The sender still has "_interpreters" bound from its first run.
        interpreters.run_string(sender, dedent(f"""
            _interpreters.channel_release({cid})
            """))

        self.assertEqual(captured.strip(), "b'spam'")

    def test_no_kwargs(self):
        # Same scenario as test_single_user, but release is called
        # without the send/recv keyword arguments.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid)

        closed = interpreters.ChannelClosedError
        self.assertRaises(closed, interpreters.channel_send, cid, b'eggs')
        self.assertRaises(closed, interpreters.channel_recv, cid)

    def test_multiple_times(self):
        # Releasing both ends a second time must raise
        # ChannelClosedError.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid, send=True, recv=True)

        self.assertRaises(interpreters.ChannelClosedError,
                          interpreters.channel_release,
                          cid, send=True, recv=True)

    def test_with_unused_items(self):
        # Two objects are queued and never received before both ends
        # are released; a later recv must fail as "closed".
        cid = interpreters.channel_create()
        for payload in (b'spam', b'ham'):
            interpreters.channel_send(cid, payload)
        interpreters.channel_release(cid, send=True, recv=True)

        self.assertRaises(interpreters.ChannelClosedError,
                          interpreters.channel_recv, cid)

    def test_never_used(self):
        # The channel is released right after creation, without any
        # send or recv having happened.
        cid = interpreters.channel_create()
        interpreters.channel_release(cid)

        closed = interpreters.ChannelClosedError
        self.assertRaises(closed, interpreters.channel_send, cid, b'spam')
        self.assertRaises(closed, interpreters.channel_recv, cid)

    def test_by_unassociated_interp(self):
        # A release issued by an interpreter that never sent or received
        # on the channel is expected to leave main able to receive the
        # queued object; only main's own release closes the channel.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        bystander = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_release({cid})
            """)
        interpreters.run_string(bystander, script)
        received = interpreters.channel_recv(cid)
        interpreters.channel_release(cid)

        self.assertRaises(interpreters.ChannelClosedError,
                          interpreters.channel_send, cid, b'eggs')
        self.assertEqual(received, b'spam')

    def test_close_if_unassociated(self):
        # XXX Something's not right with this test...
        # The only interpreter that ever uses the channel sends one
        # object and releases; main then expects recv to report the
        # channel as closed.
        cid = interpreters.channel_create()
        user = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_send({cid}, b'spam')
            _interpreters.channel_release({cid})
            """)
        interpreters.run_string(user, script)

        self.assertRaises(interpreters.ChannelClosedError,
                          interpreters.channel_recv, cid)

    def test_partially(self):
        # XXX Is partial close too weird/confusing?
        # Only the send end is released, so the object that is still
        # queued must remain receivable.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, None)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_release(cid, send=True)
        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'spam')

    def test_used_multiple_times_by_single_user(self):
        # Three sends but only one recv before both ends are released;
        # both send and recv must then fail as "closed".
        cid = interpreters.channel_create()
        for _ in range(3):
            interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid, send=True, recv=True)

        closed = interpreters.ChannelClosedError
        self.assertRaises(closed, interpreters.channel_send, cid, b'eggs')
        self.assertRaises(closed, interpreters.channel_recv, cid)
1887
1888
class ChannelCloseFixture(namedtuple('ChannelCloseFixture',
                                     'end interp other extra creator')):
    """One permutation of channel end, interpreters and channel creator.

    The tuple fields hold the channel end under test ('send' or 'recv'),
    three interpreters (the one running the operation, another one and
    an extra one) and the role/name of the interpreter that creates the
    channel.  On top of that each instance tracks the most recently
    recorded channel state, the interpreters it has already prepared and
    a name -> interpreter lookup table.
    """

    # Set this to True to avoid creating interpreters, e.g. when
    # scanning through test permutations without running them.
    QUICK = False

    def __new__(cls, end, interp, other, extra, creator):
        assert end in ('send', 'recv')
        known = {}
        if not cls.QUICK:
            # Turn the raw values into Interpreter objects, in field
            # order, and index them by name.
            interp, other, extra = (
                Interpreter.from_raw(raw) for raw in (interp, other, extra)
            )
            known = {each.name: each for each in (interp, other, extra)}
        # A falsy creator means the channel is created by "interp".
        self = super().__new__(cls, end, interp, other, extra,
                               creator or 'same')
        self._prepped = set()
        self._state = ChannelState()
        self._known = known
        return self

    @property
    def state(self):
        """The channel state most recently stored by record_action()."""
        return self._state

    @property
    def cid(self):
        """The channel ID, creating the channel on first access."""
        if not hasattr(self, '_cid'):
            owner = self._get_interpreter(self.creator)
            self._cid = self._new_channel(owner)
        return self._cid

    def get_interpreter(self, interp):
        """Resolve *interp* (a role or name) and prepare the interpreter."""
        resolved = self._get_interpreter(interp)
        self._prep_interpreter(resolved)
        return resolved

    def expect_closed_error(self, end=None):
        """Return True if using *end* is expected to hit a closed channel.

        *end* defaults to this fixture's own end.  The recv end is not
        considered closed while the recorded state's "closed" value
        equals 'send'.
        """
        which = self.end if end is None else end
        closed = self.state.closed
        if which == 'recv' and closed == 'send':
            return False
        return bool(closed)

    def prep_interpreter(self, interp):
        """Prepare an already-resolved interpreter object for actions."""
        self._prep_interpreter(interp)

    def record_action(self, action, result):
        """Store *result* as the current state (*action* is not used)."""
        self._state = result

    def clean_up(self):
        """Run the module's interpreter cleanup, then channel cleanup."""
        clean_up_interpreters()
        clean_up_channels()

    # internal methods

    def _new_channel(self, creator):
        # Create the channel in the interpreter that is supposed to own
        # it and return its ID to the caller (which runs in main).
        if creator.name != 'main':
            ch = interpreters.channel_create()
            run_interp(creator.id, f"""
                import _xxsubinterpreters
                cid = _xxsubinterpreters.channel_create()
                # We purposefully send back an int to avoid tying the
                # channel to the other interpreter.
                _xxsubinterpreters.channel_send({ch}, int(cid))
                del _xxsubinterpreters
                """)
            self._cid = interpreters.channel_recv(ch)
            return self._cid
        return interpreters.channel_create()

    def _get_interpreter(self, interp):
        # The role names map straight onto tuple fields ...
        if interp in ('same', 'interp'):
            return self.interp
        if interp == 'other':
            return self.other
        if interp == 'extra':
            return self.extra
        # ... anything else is an interpreter name, looked up in (and
        # created into, when missing) the table of known interpreters.
        name = interp
        if name not in self._known:
            self._known[name] = Interpreter(name)
        return self._known[name]

    def _prep_interpreter(self, interp):
        # Each interpreter is prepared at most once; main needs no setup.
        already_done = interp.id in self._prepped
        if already_done:
            return
        self._prepped.add(interp.id)
        if interp.name == 'main':
            return
        # Bind the names that later action scripts rely on: the
        # "interpreters" and "helpers" modules, ChannelState and "cid".
        run_interp(interp.id, f"""
            import _xxsubinterpreters as interpreters
            import test.test__xxsubinterpreters as helpers
            ChannelState = helpers.ChannelState
            try:
                cid
            except NameError:
                cid = interpreters._channel_id({self.cid})
            """)
2000
2001
@unittest.skip('these tests take several hours to run')
class ExhaustiveChannelTests(TestBase):

    """
    - main / interp / other
    - run in: current thread / new thread / other thread / different threads
    - end / opposite
    - force / no force
    - used / not used  (associated / not associated)
    - empty / emptied / never emptied / partly emptied
    - closed / not closed
    - released / not released
    - creator (interp) / other
    - associated interpreter not running
    - associated interpreter destroyed

    - close after unbound
    """

    """
    use
    pre-close
    close
    after
    check
    """

    """
    close in:         main, interp1
    creator:          same, other, extra

    use:              None,send,recv,send/recv in None,same,other,same+other,all
    pre-close:        None,send,recv in None,same,other,same+other,all
    pre-close forced: None,send,recv in None,same,other,same+other,all

    close:            same
    close forced:     same

    use after:        None,send,recv,send/recv in None,same,other,extra,same+other,all
    close after:      None,send,recv,send/recv in None,same,other,extra,same+other,all
    check closed:     send/recv for same/other(incl. interp2)
    """

    # The two interpreter-role pairs every "use" and "close" action set
    # is generated for, in generation order.
    _ROLE_PAIRS = (('same', 'other'), ('other', 'extra'))

    def iter_action_sets(self):
        """Yield every list of ChannelActions to run before the final close.

        The order is: the empty list, then close-only sequences, then
        each "use" sequence on its own followed by that sequence
        combined with every close (and post-close) sequence.
        """
        # - used / not used  (associated / not associated)
        # - empty / emptied / never emptied / partly emptied
        # - closed / not closed
        # - released / not released

        # never used
        yield []

        # only pre-closed (and possible used after)
        yield from self._iter_close_and_after([])

        # used
        for pair in self._ROLE_PAIRS:
            for useactions in self._iter_use_action_sets(*pair):
                yield useactions
                yield from self._iter_close_and_after(useactions)

    def _iter_close_and_after(self, prefix):
        """Yield *prefix* extended by each close set and each follow-up.

        For every close action set (for both role pairs) this yields
        prefix + close, then prefix + close + post for every post-close
        action set.
        """
        for pair in self._ROLE_PAIRS:
            for closeactions in self._iter_close_action_sets(*pair):
                actions = prefix + closeactions
                yield actions
                for postactions in self._iter_post_close_action_sets():
                    yield actions + postactions

    def _iter_use_action_sets(self, interp1, interp2):
        """Yield lists of 'use' actions spread over the two given roles."""
        interps = (interp1, interp2)

        # only recv end used (first pass), never emptied (second pass)
        for end in ('recv', 'send'):
            yield [
                ChannelAction('use', end, interp1),
                ]
            yield [
                ChannelAction('use', end, interp2),
                ]
            yield [
                ChannelAction('use', end, interp1),
                ChannelAction('use', end, interp2),
                ]

        # partially emptied: two sends, one recv.  product() iterates in
        # the same order as the equivalent nested for-loops.
        for send1, send2, recv1 in itertools.product(interps, repeat=3):
            yield [
                ChannelAction('use', 'send', send1),
                ChannelAction('use', 'send', send2),
                ChannelAction('use', 'recv', recv1),
                ]

        # fully emptied: two sends, two recvs.
        for send1, send2, recv1, recv2 in itertools.product(interps,
                                                            repeat=4):
            yield [
                ChannelAction('use', 'send', send1),
                ChannelAction('use', 'send', send2),
                ChannelAction('use', 'recv', recv1),
                ChannelAction('use', 'recv', recv2),
                ]

    def _iter_close_action_sets(self, interp1, interp2):
        """Yield lists of close/force-close actions for the two roles."""
        ends = ('recv', 'send')
        interps = (interp1, interp2)

        # A single (force-)close of one end by one interpreter.
        for force in (True, False):
            op = 'force-close' if force else 'close'
            for interp in interps:
                for end in ends:
                    yield [
                        ChannelAction(op, end, interp),
                        ]

        # The recv end closed first, then the send end, for every
        # combination of operation and interpreter.
        ops = ('close', 'force-close')
        for recvop, sendop, recv, send in itertools.product(
                ops, ops, interps, interps):
            yield [
                ChannelAction(recvop, 'recv', recv),
                ChannelAction(sendop, 'send', send),
                ]

    def _iter_post_close_action_sets(self):
        """Yield single-action lists that use the channel after a close."""
        for interp in ('same', 'extra', 'other'):
            yield [
                ChannelAction('use', 'recv', interp),
                ]
            yield [
                ChannelAction('use', 'send', interp),
                ]

    def run_actions(self, fix, actions):
        """Run each action in *actions*, in order, against fixture *fix*."""
        for action in actions:
            self.run_action(fix, action)

    def run_action(self, fix, action, *, hideclosed=True):
        """Run one action in its target interpreter and record the result.

        The end and interpreter are resolved through the action's own
        resolve_end()/resolve_interp() helpers.  For main, the
        module-level run_action() function is called directly (the bare
        name refers to that function, not to this method).  For any
        other interpreter a script calling helpers.run_action() is run
        there and the resulting state is passed back over a temporary
        channel.
        """
        end = action.resolve_end(fix.end)
        interp = action.resolve_interp(fix.interp, fix.other, fix.extra)
        fix.prep_interpreter(interp)
        if interp.name == 'main':
            result = run_action(
                fix.cid,
                action.action,
                end,
                fix.state,
                hideclosed=hideclosed,
                )
        else:
            # The state comes back as two bytes objects: the pending
            # count as one little-endian byte (so it must fit in a
            # byte) and a non-empty/empty marker for "closed".
            _cid = interpreters.channel_create()
            run_interp(interp.id, f"""
                result = helpers.run_action(
                    {fix.cid},
                    {repr(action.action)},
                    {repr(end)},
                    {repr(fix.state)},
                    hideclosed={hideclosed},
                    )
                interpreters.channel_send({_cid}, result.pending.to_bytes(1, 'little'))
                interpreters.channel_send({_cid}, b'X' if result.closed else b'')
                """)
            result = ChannelState(
                pending=int.from_bytes(interpreters.channel_recv(_cid), 'little'),
                closed=bool(interpreters.channel_recv(_cid)),
                )
        fix.record_action(action, result)

    def iter_fixtures(self):
        """Yield a ChannelCloseFixture per interpreter combo, creator and end."""
        # XXX threads?
        combos = [
            ('main', 'interp', 'extra'),
            ('interp', 'main', 'extra'),
            ('interp1', 'interp2', 'extra'),
            ('interp1', 'interp2', 'main'),
        ]
        for interp, other, extra in combos:
            # NOTE(review): the notes at the top of this class list the
            # creators as "same, other, extra", but the third value here
            # is 'creator' -- confirm which one is intended.
            for creator in ('same', 'other', 'creator'):
                for end in ('send', 'recv'):
                    yield ChannelCloseFixture(end, interp, other, extra, creator)

    def _close(self, fix, *, force):
        """Close the fixture's end from 'same', expecting failure if closed."""
        op = 'force-close' if force else 'close'
        close = ChannelAction(op, fix.end, 'same')
        if fix.expect_closed_error():
            with self.assertRaises(interpreters.ChannelClosedError):
                self.run_action(fix, close, hideclosed=False)
        else:
            self.run_action(fix, close, hideclosed=False)

    def _assert_closed_in_interp(self, fix, interp=None):
        """Check that recv, send, close and forced close all fail as closed.

        With no interpreter (or with main) the checks run in the current
        interpreter; otherwise each check is run as its own script in
        *interp*, which must already have been prepared by the fixture.
        """
        if interp is None or interp.name == 'main':
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_recv(fix.cid)
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_send(fix.cid, b'spam')
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_close(fix.cid)
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_close(fix.cid, force=True)
        else:
            calls = (
                "channel_recv(cid)",
                "channel_send(cid, b'spam')",
                "channel_close(cid)",
                "channel_close(cid, force=True)",
            )
            for call in calls:
                run_interp(interp.id, f"""
                    with helpers.expect_channel_closed():
                        interpreters.{call}
                    """)

    def _assert_closed(self, fix):
        """Check that the channel is closed in every relevant interpreter."""
        self.assertTrue(fix.state.closed)

        # Drain whatever the recorded state says is still queued.
        for _ in range(fix.state.pending):
            interpreters.channel_recv(fix.cid)
        self._assert_closed_in_interp(fix)

        for role in ('same', 'other'):
            interp = fix.get_interpreter(role)
            if interp.name == 'main':
                # Already covered by the check above.
                continue
            self._assert_closed_in_interp(fix, interp)

        # 'fresh' is not one of the fixture's role names, so it is
        # resolved by name through the fixture's interpreter lookup.
        interp = fix.get_interpreter('fresh')
        self._assert_closed_in_interp(fix, interp)

    def _iter_close_tests(self, verbose=False, *, limit=1000):
        """Yield (index, fixture, actions) for each permutation to test.

        Progress is printed along the way.  Iteration stops once more
        than *limit* permutations have been produced.
        """
        i = 0
        for actions in self.iter_action_sets():
            print()
            for fix in self.iter_fixtures():
                i += 1
                if i > limit:
                    return
                # Progress output is grouped in runs of 6, which matches
                # the 3 creators x 2 ends that iter_fixtures() yields
                # per interpreter combination.
                if verbose:
                    if (i - 1) % 6 == 0:
                        print()
                    print(i, fix, '({} actions)'.format(len(actions)))
                else:
                    if (i - 1) % 6 == 0:
                        print(' ', end='')
                    print('.', end='', flush=True)
                yield i, fix, actions
            if verbose:
                print('---')
        print()

    # This is useful for scanning through the possible tests.
    def _skim_close_tests(self):
        """Walk all permutations without creating interpreters."""
        saved = ChannelCloseFixture.QUICK
        ChannelCloseFixture.QUICK = True
        try:
            for _ in self._iter_close_tests():
                pass
        finally:
            # Do not leave the class-wide switch flipped for later users
            # of ChannelCloseFixture.
            ChannelCloseFixture.QUICK = saved

    def _check_close(self, *, force):
        """Run every permutation, finishing with a (forced) close."""
        for i, fix, actions in self._iter_close_tests():
            with self.subTest('{} {}  {}'.format(i, fix, actions)):
                fix.prep_interpreter(fix.interp)
                self.run_actions(fix, actions)

                self._close(fix, force=force)

                self._assert_closed(fix)
            # XXX Things slow down if we have too many interpreters.
            fix.clean_up()

    def test_close(self):
        self._check_close(force=False)

    def test_force_close(self):
        self._check_close(force=True)
2319
2320
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
2323