• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""This module includes tests of the code object representation.
2
3>>> def f(x):
4...     def g(y):
5...         return x + y
6...     return g
7...
8
9>>> dump(f.__code__)
10name: f
11argcount: 1
12posonlyargcount: 0
13kwonlyargcount: 0
14names: ()
15varnames: ('x', 'g')
16cellvars: ('x',)
17freevars: ()
18nlocals: 2
19flags: 3
20consts: ('None', '<code object g>')
21
22>>> dump(f(4).__code__)
23name: g
24argcount: 1
25posonlyargcount: 0
26kwonlyargcount: 0
27names: ()
28varnames: ('y',)
29cellvars: ()
30freevars: ('x',)
31nlocals: 1
32flags: 19
33consts: ('None',)
34
35>>> def h(x, y):
36...     a = x + y
37...     b = x - y
38...     c = a * b
39...     return c
40...
41
42>>> dump(h.__code__)
43name: h
44argcount: 2
45posonlyargcount: 0
46kwonlyargcount: 0
47names: ()
48varnames: ('x', 'y', 'a', 'b', 'c')
49cellvars: ()
50freevars: ()
51nlocals: 5
52flags: 3
53consts: ('None',)
54
55>>> def attrs(obj):
56...     print(obj.attr1)
57...     print(obj.attr2)
58...     print(obj.attr3)
59
60>>> dump(attrs.__code__)
61name: attrs
62argcount: 1
63posonlyargcount: 0
64kwonlyargcount: 0
65names: ('print', 'attr1', 'attr2', 'attr3')
66varnames: ('obj',)
67cellvars: ()
68freevars: ()
69nlocals: 1
70flags: 3
71consts: ('None',)
72
73>>> def optimize_away():
74...     'doc string'
75...     'not a docstring'
76...     53
77...     0x53
78
79>>> dump(optimize_away.__code__)
80name: optimize_away
81argcount: 0
82posonlyargcount: 0
83kwonlyargcount: 0
84names: ()
85varnames: ()
86cellvars: ()
87freevars: ()
88nlocals: 0
89flags: 3
90consts: ("'doc string'", 'None')
91
92>>> def keywordonly_args(a,b,*,k1):
93...     return a,b,k1
94...
95
96>>> dump(keywordonly_args.__code__)
97name: keywordonly_args
98argcount: 2
99posonlyargcount: 0
100kwonlyargcount: 1
101names: ()
102varnames: ('a', 'b', 'k1')
103cellvars: ()
104freevars: ()
105nlocals: 3
106flags: 3
107consts: ('None',)
108
109>>> def posonly_args(a,b,/,c):
110...     return a,b,c
111...
112
113>>> dump(posonly_args.__code__)
114name: posonly_args
115argcount: 3
116posonlyargcount: 2
117kwonlyargcount: 0
118names: ()
119varnames: ('a', 'b', 'c')
120cellvars: ()
121freevars: ()
122nlocals: 3
123flags: 3
124consts: ('None',)
125
126"""
127
128import copy
129import inspect
130import sys
131import threading
132import doctest
133import unittest
134import textwrap
135import weakref
136import dis
137
138try:
139    import ctypes
140except ImportError:
141    ctypes = None
142from test.support import (cpython_only,
143                          check_impl_detail, requires_debug_ranges,
144                          gc_collect, Py_GIL_DISABLED,
145                          suppress_immortalization,
146                          skip_if_suppress_immortalization)
147from test.support.script_helper import assert_python_ok
148from test.support import threading_helper, import_helper
149from test.support.bytecode_helper import instructions_with_positions
150from opcode import opmap, opname
151COPY_FREE_VARS = opmap['COPY_FREE_VARS']
152
153
def consts(t):
    """Generate a repr string for each item of *t*, safe for doctests.

    Code objects are abbreviated to ``<code object NAME>`` because their
    full repr is not stable enough to appear in expected doctest output.
    """
    for item in t:
        text = repr(item)
        if not text.startswith("<code object"):
            yield text
            continue
        yield "<code object %s>" % item.co_name
162
def dump(co):
    """Print a ``name: value`` line per code-object attribute, then its consts."""
    fields = ("name", "argcount", "posonlyargcount",
              "kwonlyargcount", "names", "varnames",
              "cellvars", "freevars", "nlocals", "flags")
    for field in fields:
        value = getattr(co, f"co_{field}")
        print(f"{field}: {value}")
    # Constants go through consts() so nested code objects print stably.
    print("consts:", tuple(consts(co.co_consts)))
170
# Needed for test_closure_injection below
# Defined at global scope to avoid implicitly closing over __class__
def external_getitem(self, i):
    # On its own this function has no __class__ cell, which zero-argument
    # super() relies on.  test_closure_injection rewrites this code object to
    # add a '__class__' free variable and builds a function whose closure
    # binds it to a list subclass; that injected cell is what lets the
    # super() call below resolve.
    return f"Foreign getitem: {super().__getitem__(i)}"
175
class CodeTest(unittest.TestCase):
    """Tests of code object construction, replace(), equality and hashing."""

    @cpython_only
    def test_newempty(self):
        _testcapi = import_helper.import_module("_testcapi")
        co = _testcapi.code_newempty("filename", "funcname", 15)
        self.assertEqual(co.co_filename, "filename")
        self.assertEqual(co.co_name, "funcname")
        self.assertEqual(co.co_firstlineno, 15)
        # Empty code object should raise, but not crash the VM
        with self.assertRaises(Exception):
            exec(co)

    @cpython_only
    def test_closure_injection(self):
        # From https://bugs.python.org/issue32176
        from types import FunctionType

        def create_closure(__class__):
            return (lambda: __class__).__closure__

        def new_code(c):
            '''A new code object with a __class__ cell added to freevars'''
            return c.replace(co_freevars=c.co_freevars + ('__class__',), co_code=bytes([COPY_FREE_VARS, 1])+c.co_code)

        def add_foreign_method(cls, name, f):
            code = new_code(f.__code__)
            assert not f.__closure__
            closure = create_closure(cls)
            defaults = f.__defaults__
            setattr(cls, name, FunctionType(code, globals(), name, defaults, closure))

        class List(list):
            pass

        add_foreign_method(List, "__getitem__", external_getitem)

        # Ensure the closure injection actually worked
        function = List.__getitem__
        class_ref = function.__closure__[0].cell_contents
        self.assertIs(class_ref, List)

        # Ensure the zero-arg super() call in the injected method works
        obj = List([1, 2, 3])
        self.assertEqual(obj[0], "Foreign getitem: 1")

    def test_constructor(self):
        def func(): pass
        co = func.__code__
        CodeType = type(co)

        # test code constructor
        CodeType(co.co_argcount,
                        co.co_posonlyargcount,
                        co.co_kwonlyargcount,
                        co.co_nlocals,
                        co.co_stacksize,
                        co.co_flags,
                        co.co_code,
                        co.co_consts,
                        co.co_names,
                        co.co_varnames,
                        co.co_filename,
                        co.co_name,
                        co.co_qualname,
                        co.co_firstlineno,
                        co.co_linetable,
                        co.co_exceptiontable,
                        co.co_freevars,
                        co.co_cellvars)

    def test_qualname(self):
        self.assertEqual(
            CodeTest.test_qualname.__code__.co_qualname,
            CodeTest.test_qualname.__qualname__
        )

    def test_replace(self):
        def func():
            x = 1
            return x
        code = func.__code__

        # different co_name, co_varnames, co_consts
        def func2():
            y = 2
            z = 3
            return y
        code2 = func2.__code__

        for attr, value in (
            ("co_argcount", 0),
            ("co_posonlyargcount", 0),
            ("co_kwonlyargcount", 0),
            ("co_nlocals", 1),
            ("co_stacksize", 1),
            ("co_flags", code.co_flags | inspect.CO_COROUTINE),
            ("co_firstlineno", 100),
            ("co_code", code2.co_code),
            ("co_consts", code2.co_consts),
            ("co_names", ("myname",)),
            ("co_varnames", ('spam',)),
            ("co_freevars", ("freevar",)),
            ("co_cellvars", ("cellvar",)),
            ("co_filename", "newfilename"),
            ("co_name", "newname"),
            ("co_linetable", code2.co_linetable),
        ):
            with self.subTest(attr=attr, value=value):
                new_code = code.replace(**{attr: value})
                self.assertEqual(getattr(new_code, attr), value)
                new_code = copy.replace(code, **{attr: value})
                self.assertEqual(getattr(new_code, attr), value)

        new_code = code.replace(co_varnames=code2.co_varnames,
                                co_nlocals=code2.co_nlocals)
        self.assertEqual(new_code.co_varnames, code2.co_varnames)
        self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
        new_code = copy.replace(code, co_varnames=code2.co_varnames,
                                co_nlocals=code2.co_nlocals)
        self.assertEqual(new_code.co_varnames, code2.co_varnames)
        self.assertEqual(new_code.co_nlocals, code2.co_nlocals)

    def test_nlocals_mismatch(self):
        def func():
            x = 1
            return x
        co = func.__code__
        # Precondition for the "- 1" case below; a unittest assertion is
        # used so the check is not stripped when running under -O.
        self.assertGreater(co.co_nlocals, 0)

        # First we try the constructor.
        CodeType = type(co)
        for diff in (-1, 1):
            with self.assertRaises(ValueError):
                CodeType(co.co_argcount,
                         co.co_posonlyargcount,
                         co.co_kwonlyargcount,
                         # This is the only change.
                         co.co_nlocals + diff,
                         co.co_stacksize,
                         co.co_flags,
                         co.co_code,
                         co.co_consts,
                         co.co_names,
                         co.co_varnames,
                         co.co_filename,
                         co.co_name,
                         co.co_qualname,
                         co.co_firstlineno,
                         co.co_linetable,
                         co.co_exceptiontable,
                         co.co_freevars,
                         co.co_cellvars,
                         )
        # Then we try the replace method.
        with self.assertRaises(ValueError):
            co.replace(co_nlocals=co.co_nlocals - 1)
        with self.assertRaises(ValueError):
            co.replace(co_nlocals=co.co_nlocals + 1)

    def test_shrinking_localsplus(self):
        # Check that PyCode_NewWithPosOnlyArgs resizes both
        # localsplusnames and localspluskinds, if an argument is a cell.
        def func(arg):
            return lambda: arg
        code = func.__code__
        newcode = code.replace(co_name="func")  # Should not raise SystemError
        self.assertEqual(code, newcode)

    def test_empty_linetable(self):
        def func():
            pass
        new_code = func.__code__.replace(co_linetable=b'')
        self.assertEqual(list(new_code.co_lines()), [])

    def test_co_lnotab_is_deprecated(self):  # TODO: remove in 3.14
        def func():
            pass

        with self.assertWarns(DeprecationWarning):
            func.__code__.co_lnotab

    def test_invalid_bytecode(self):
        def foo():
            pass

        # assert that opcode 229 is invalid
        self.assertEqual(opname[229], '<229>')

        # change first opcode to 0xe5 (=229)
        foo.__code__ = foo.__code__.replace(
            co_code=b'\xe5' + foo.__code__.co_code[1:])

        msg = "unknown opcode 229"
        with self.assertRaisesRegex(SystemError, msg):
            foo()

    @requires_debug_ranges()
    def test_co_positions_artificial_instructions(self):
        namespace = {}
        exec(textwrap.dedent("""\
        try:
            1/0
        except Exception as e:
            exc = e
        """), namespace)

        exc = namespace['exc']
        traceback = exc.__traceback__
        code = traceback.tb_frame.f_code

        artificial_instructions = []
        for instr, positions in instructions_with_positions(
            dis.get_instructions(code), code.co_positions()
        ):
            # If any of the positions is None, then all have to
            # be None as well for the case above. There are still
            # some places in the compiler, where the artificial instructions
            # get assigned the first_lineno but they don't have other positions.
            # There is no easy way of inferring them at that stage, so for now
            # we don't support it.
            self.assertIn(positions.count(None), [0, 3, 4])

            if not any(positions):
                artificial_instructions.append(instr)

        self.assertEqual(
            [
                (instruction.opname, instruction.argval)
                for instruction in artificial_instructions
            ],
            [
                ("PUSH_EXC_INFO", None),
                ("LOAD_CONST", None), # artificial 'None'
                ("STORE_NAME", "e"),  # XX: we know the location for this
                ("DELETE_NAME", "e"),
                ("RERAISE", 1),
                ("COPY", 3),
                ("POP_EXCEPT", None),
                ("RERAISE", 1)
            ]
        )

    def test_endline_and_columntable_none_when_no_debug_ranges(self):
        # Make sure that if `-X no_debug_ranges` is used, there is
        # minimal debug info
        code = textwrap.dedent("""
            def f():
                pass

            positions = f.__code__.co_positions()
            for line, end_line, column, end_column in positions:
                assert line == end_line
                assert column is None
                assert end_column is None
            """)
        assert_python_ok('-X', 'no_debug_ranges', '-c', code)

    def test_endline_and_columntable_none_when_no_debug_ranges_env(self):
        # Same as above but using the environment variable opt out.
        code = textwrap.dedent("""
            def f():
                pass

            positions = f.__code__.co_positions()
            for line, end_line, column, end_column in positions:
                assert line == end_line
                assert column is None
                assert end_column is None
            """)
        assert_python_ok('-c', code, PYTHONNODEBUGRANGES='1')

    # co_positions behavior when info is missing.

    @requires_debug_ranges()
    def test_co_positions_empty_linetable(self):
        def func():
            x = 1
        new_code = func.__code__.replace(co_linetable=b'')
        positions = new_code.co_positions()
        for line, end_line, column, end_column in positions:
            self.assertIsNone(line)
            self.assertEqual(end_line, new_code.co_firstlineno + 1)

    def test_code_equality(self):
        def f():
            try:
                a()
            except:
                b()
            else:
                c()
            finally:
                d()
        code_a = f.__code__
        code_b = code_a.replace(co_linetable=b"")
        code_c = code_a.replace(co_exceptiontable=b"")
        code_d = code_b.replace(co_exceptiontable=b"")
        self.assertNotEqual(code_a, code_b)
        self.assertNotEqual(code_a, code_c)
        self.assertNotEqual(code_a, code_d)
        self.assertNotEqual(code_b, code_c)
        self.assertNotEqual(code_b, code_d)
        self.assertNotEqual(code_c, code_d)

    def test_code_hash_uses_firstlineno(self):
        c1 = (lambda: 1).__code__
        c2 = (lambda: 1).__code__
        self.assertNotEqual(c1, c2)
        self.assertNotEqual(hash(c1), hash(c2))
        c3 = c1.replace(co_firstlineno=17)
        self.assertNotEqual(c1, c3)
        self.assertNotEqual(hash(c1), hash(c3))

    def test_code_hash_uses_order(self):
        # Swapping posonlyargcount and kwonlyargcount should change the hash.
        c = (lambda x, y, *, z=1, w=1: 1).__code__
        self.assertEqual(c.co_argcount, 2)
        self.assertEqual(c.co_posonlyargcount, 0)
        self.assertEqual(c.co_kwonlyargcount, 2)
        swapped = c.replace(co_posonlyargcount=2, co_kwonlyargcount=0)
        self.assertNotEqual(c, swapped)
        self.assertNotEqual(hash(c), hash(swapped))

    def test_code_hash_uses_bytecode(self):
        c = (lambda x, y: x + y).__code__
        d = (lambda x, y: x * y).__code__
        c1 = c.replace(co_code=d.co_code)
        self.assertNotEqual(c, c1)
        self.assertNotEqual(hash(c), hash(c1))

    @cpython_only
    def test_code_equal_with_instrumentation(self):
        """ GH-109052

        Make sure the instrumentation doesn't affect the code equality
        The validity of this test relies on the fact that "x is x" and
        "x in x" have only one different instruction and the instructions
        have the same argument.

        """
        code1 = compile("x is x", "example.py", "eval")
        code2 = compile("x in x", "example.py", "eval")
        sys._getframe().f_trace_opcodes = True
        # Remember the tracer that was active before (e.g. one installed by a
        # coverage tool) and always reinstate it, even if an assertion fails,
        # so the dummy tracer cannot leak into subsequent tests.
        old_trace = sys.gettrace()
        sys.settrace(lambda *args: None)
        try:
            exec(code1, {'x': []})
            exec(code2, {'x': []})
            self.assertNotEqual(code1, code2)
        finally:
            sys.settrace(old_trace)
527
528
def isinterned(s):
    """Return True if *s* is itself the interned string for its value.

    An equal but distinct copy of *s* is built and interned; the result is
    *s* only when *s* was already the entry in the intern table.
    """
    fresh_copy = ('_' + s + '_')[1:-1]
    canonical = sys.intern(fresh_copy)
    return canonical is s
531
class CodeConstsTest(unittest.TestCase):
    """Checks on interning of string and other constants in co_consts."""

    def find_const(self, consts, value):
        # Return the element of *consts* equal to *value* (the stored object
        # itself, so callers can inspect its identity); fail if absent.
        for v in consts:
            if v == value:
                return v
        self.assertIn(value, consts)  # raises an exception
        self.fail('Should never be reached')

    def assertIsInterned(self, s):
        if not isinterned(s):
            self.fail('String %r is not interned' % (s,))

    def assertIsNotInterned(self, s):
        if isinterned(s):
            self.fail('String %r is interned' % (s,))

    @cpython_only
    def test_interned_string(self):
        co = compile('res = "str_value"', '?', 'exec')
        v = self.find_const(co.co_consts, 'str_value')
        self.assertIsInterned(v)

    @cpython_only
    def test_interned_string_in_tuple(self):
        co = compile('res = ("str_value",)', '?', 'exec')
        v = self.find_const(co.co_consts, ('str_value',))
        self.assertIsInterned(v[0])

    @cpython_only
    def test_interned_string_in_frozenset(self):
        co = compile('res = a in {"str_value"}', '?', 'exec')
        v = self.find_const(co.co_consts, frozenset(('str_value',)))
        self.assertIsInterned(tuple(v)[0])

    @cpython_only
    def test_interned_string_default(self):
        def f(a='str_value'):
            return a
        self.assertIsInterned(f())

    @cpython_only
    @unittest.skipIf(Py_GIL_DISABLED, "free-threaded build interns all string constants")
    def test_interned_string_with_null(self):
        co = compile(r'res = "str\0value!"', '?', 'exec')
        v = self.find_const(co.co_consts, 'str\0value!')
        self.assertIsNotInterned(v)

    @cpython_only
    @unittest.skipUnless(Py_GIL_DISABLED, "does not intern all constants")
    @skip_if_suppress_immortalization()
    def test_interned_constants(self):
        # compile separately to avoid compile time de-duping

        # Named 'namespace' rather than 'globals' to avoid shadowing the
        # builtin of that name.
        namespace = {}
        exec(textwrap.dedent("""
            def func1():
                return (0.0, (1, 2, "hello"))
        """), namespace)

        exec(textwrap.dedent("""
            def func2():
                return (0.0, (1, 2, "hello"))
        """), namespace)

        self.assertIs(namespace["func1"](), namespace["func2"]())
592        exec(textwrap.dedent("""
593            def func2():
594                return (0.0, (1, 2, "hello"))
595        """), globals)
596
597        self.assertTrue(globals["func1"]() is globals["func2"]())
598
599
600class CodeWeakRefTest(unittest.TestCase):
601
602    @suppress_immortalization()
603    def test_basic(self):
604        # Create a code object in a clean environment so that we know we have
605        # the only reference to it left.
606        namespace = {}
607        exec("def f(): pass", globals(), namespace)
608        f = namespace["f"]
609        del namespace
610
611        self.called = False
612        def callback(code):
613            self.called = True
614
615        # f is now the last reference to the function, and through it, the code
616        # object.  While we hold it, check that we can create a weakref and
617        # deref it.  Then delete it, and check that the callback gets called and
618        # the reference dies.
619        coderef = weakref.ref(f.__code__, callback)
620        self.assertTrue(bool(coderef()))
621        del f
622        gc_collect()  # For PyPy or other GCs.
623        self.assertFalse(bool(coderef()))
624        self.assertTrue(self.called)
625
626# Python implementation of location table parsing algorithm
627def read(it):
628    return next(it)
629
630def read_varint(it):
631    b = read(it)
632    val = b & 63;
633    shift = 0;
634    while b & 64:
635        b = read(it)
636        shift += 6
637        val |= (b&63) << shift
638    return val
639
640def read_signed_varint(it):
641    uval = read_varint(it)
642    if uval & 1:
643        return -(uval >> 1)
644    else:
645        return uval >> 1
646
647def parse_location_table(code):
648    line = code.co_firstlineno
649    it = iter(code.co_linetable)
650    while True:
651        try:
652            first_byte = read(it)
653        except StopIteration:
654            return
655        code = (first_byte >> 3) & 15
656        length = (first_byte & 7) + 1
657        if code == 15:
658            yield (code, length, None, None, None, None)
659        elif code == 14:
660            line_delta = read_signed_varint(it)
661            line += line_delta
662            end_line = line + read_varint(it)
663            col = read_varint(it)
664            if col == 0:
665                col = None
666            else:
667                col -= 1
668            end_col = read_varint(it)
669            if end_col == 0:
670                end_col = None
671            else:
672                end_col -= 1
673            yield (code, length, line, end_line, col, end_col)
674        elif code == 13: # No column
675            line_delta = read_signed_varint(it)
676            line += line_delta
677            yield (code, length, line, line, None, None)
678        elif code in (10, 11, 12): # new line
679            line_delta = code - 10
680            line += line_delta
681            column = read(it)
682            end_column = read(it)
683            yield (code, length, line, line, column, end_column)
684        else:
685            assert (0 <= code < 10)
686            second_byte = read(it)
687            column = code << 3 | (second_byte >> 4)
688            yield (code, length, line, line, column, column + (second_byte & 15))
689
690def positions_from_location_table(code):
691    for _, length, line, end_line, col, end_col in parse_location_table(code):
692        for _ in range(length):
693            yield (line, end_line, col, end_col)
694
695def dedup(lst, prev=object()):
696    for item in lst:
697        if item != prev:
698            yield item
699            prev = item
700
701def lines_from_postions(positions):
702    return dedup(l for (l, _, _, _) in positions)
703
704def misshappen():
705    """
706
707
708
709
710
711    """
712    x = (
713
714
715        4
716
717        +
718
719        y
720
721    )
722    y = (
723        a
724        +
725            b
726                +
727
728                d
729        )
730    return q if (
731
732        x
733
734        ) else p
735
736def bug93662():
737    example_report_generation_message= (
738            """
739            """
740    ).strip()
741    raise ValueError()
742
743
744class CodeLocationTest(unittest.TestCase):
745
746    def check_positions(self, func):
747        pos1 = list(func.__code__.co_positions())
748        pos2 = list(positions_from_location_table(func.__code__))
749        for l1, l2 in zip(pos1, pos2):
750            self.assertEqual(l1, l2)
751        self.assertEqual(len(pos1), len(pos2))
752
753    def test_positions(self):
754        self.check_positions(parse_location_table)
755        self.check_positions(misshappen)
756        self.check_positions(bug93662)
757
758    def check_lines(self, func):
759        co = func.__code__
760        lines1 = [line for _, _, line in co.co_lines()]
761        self.assertEqual(lines1, list(dedup(lines1)))
762        lines2 = list(lines_from_postions(positions_from_location_table(co)))
763        for l1, l2 in zip(lines1, lines2):
764            self.assertEqual(l1, l2)
765        self.assertEqual(len(lines1), len(lines2))
766
767    def test_lines(self):
768        self.check_lines(parse_location_table)
769        self.check_lines(misshappen)
770        self.check_lines(bug93662)
771
772    @cpython_only
773    def test_code_new_empty(self):
774        # If this test fails, it means that the construction of PyCode_NewEmpty
775        # needs to be modified! Please update this test *and* PyCode_NewEmpty,
776        # so that they both stay in sync.
777        def f():
778            pass
779        PY_CODE_LOCATION_INFO_NO_COLUMNS = 13
780        f.__code__ = f.__code__.replace(
781            co_stacksize=1,
782            co_firstlineno=42,
783            co_code=bytes(
784                [
785                    dis.opmap["RESUME"], 0,
786                    dis.opmap["LOAD_ASSERTION_ERROR"], 0,
787                    dis.opmap["RAISE_VARARGS"], 1,
788                ]
789            ),
790            co_linetable=bytes(
791                [
792                    (1 << 7)
793                    | (PY_CODE_LOCATION_INFO_NO_COLUMNS << 3)
794                    | (3 - 1),
795                    0,
796                ]
797            ),
798        )
799        self.assertRaises(AssertionError, f)
800        self.assertEqual(
801            list(f.__code__.co_positions()),
802            3 * [(42, 42, None, None)],
803        )
804
805
806if check_impl_detail(cpython=True) and ctypes is not None:
807    py = ctypes.pythonapi
808    freefunc = ctypes.CFUNCTYPE(None,ctypes.c_voidp)
809
810    RequestCodeExtraIndex = py.PyUnstable_Eval_RequestCodeExtraIndex
811    RequestCodeExtraIndex.argtypes = (freefunc,)
812    RequestCodeExtraIndex.restype = ctypes.c_ssize_t
813
814    SetExtra = py.PyUnstable_Code_SetExtra
815    SetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t, ctypes.c_voidp)
816    SetExtra.restype = ctypes.c_int
817
818    GetExtra = py.PyUnstable_Code_GetExtra
819    GetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t,
820                         ctypes.POINTER(ctypes.c_voidp))
821    GetExtra.restype = ctypes.c_int
822
823    LAST_FREED = None
824    def myfree(ptr):
825        global LAST_FREED
826        LAST_FREED = ptr
827
828    FREE_FUNC = freefunc(myfree)
829    FREE_INDEX = RequestCodeExtraIndex(FREE_FUNC)
830
831    class CoExtra(unittest.TestCase):
832        def get_func(self):
833            # Defining a function causes the containing function to have a
834            # reference to the code object.  We need the code objects to go
835            # away, so we eval a lambda.
836            return eval('lambda:42')
837
838        def test_get_non_code(self):
839            f = self.get_func()
840
841            self.assertRaises(SystemError, SetExtra, 42, FREE_INDEX,
842                              ctypes.c_voidp(100))
843            self.assertRaises(SystemError, GetExtra, 42, FREE_INDEX,
844                              ctypes.c_voidp(100))
845
846        def test_bad_index(self):
847            f = self.get_func()
848            self.assertRaises(SystemError, SetExtra, f.__code__,
849                              FREE_INDEX+100, ctypes.c_voidp(100))
850            self.assertEqual(GetExtra(f.__code__, FREE_INDEX+100,
851                              ctypes.c_voidp(100)), 0)
852
853        @suppress_immortalization()
854        def test_free_called(self):
855            # Verify that the provided free function gets invoked
856            # when the code object is cleaned up.
857            f = self.get_func()
858
859            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(100))
860            del f
861            gc_collect()  # For free-threaded build
862            self.assertEqual(LAST_FREED, 100)
863
864        def test_get_set(self):
865            # Test basic get/set round tripping.
866            f = self.get_func()
867
868            extra = ctypes.c_voidp()
869
870            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(200))
871            # reset should free...
872            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(300))
873            self.assertEqual(LAST_FREED, 200)
874
875            extra = ctypes.c_voidp()
876            GetExtra(f.__code__, FREE_INDEX, extra)
877            self.assertEqual(extra.value, 300)
878            del f
879
880        @threading_helper.requires_working_threading()
881        @suppress_immortalization()
882        def test_free_different_thread(self):
883            # Freeing a code object on a different thread then
884            # where the co_extra was set should be safe.
885            f = self.get_func()
886            class ThreadTest(threading.Thread):
887                def __init__(self, f, test):
888                    super().__init__()
889                    self.f = f
890                    self.test = test
891                def run(self):
892                    del self.f
893                    gc_collect()
894                    # gh-117683: In the free-threaded build, the code object's
895                    # destructor may still be running concurrently in the main
896                    # thread.
897                    if not Py_GIL_DISABLED:
898                        self.test.assertEqual(LAST_FREED, 500)
899
900            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(500))
901            tt = ThreadTest(f, self)
902            del f
903            tt.start()
904            tt.join()
905            gc_collect()  # For free-threaded build
906            self.assertEqual(LAST_FREED, 500)
907
908
909def load_tests(loader, tests, pattern):
910    tests.addTest(doctest.DocTestSuite())
911    return tests
912
913
914if __name__ == "__main__":
915    unittest.main()
916