# Owner(s): ["module: dynamo"] import collections import dis import sys import unittest import torch import torch._dynamo.test_case from torch._dynamo import bytecode_analysis, bytecode_transformation from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312 class BytecodeTests(torch._dynamo.test_case.TestCase): @skipIfNotPy311 def test_linetable_311_writer1(self): def fn(): a = 10 b = 20 # prevent LOAD_FAST_LOAD_FAST in 3.13 by wrapping b with g() c = a + g(b) f = "linetable_writer" return f"Test if {f} generates correct co_linetable: {c}" keys = bytecode_transformation.get_code_keys() code_options = {k: getattr(fn.__code__, k) for k in keys} result = bytecode_transformation.clean_and_assemble_instructions( bytecode_transformation.cleaned_instructions(fn.__code__), keys, code_options, ) l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions()) self.assertEqual(len(l1), len(l2)) for p1, p2 in zip(l1, l2): self.assertEqual(p1, p2) # TODO co_lnotab is deprecated in 3.12 and will be removed in 3.14 # In 3.11+,. it is computed lazily from other linetable attributes (e.g. co_linetable), # so we do not set this attribute ourselves. self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab) @skipIfNotPy311 def test_linetable_311_writer2(self): """ test large ops (LOAD_METHOD) and EXTENDED_ARGS fn_str is in the form: def fn(): ... x0 = 1 x1 = 1 ... l = [x0, x1, ...] """ fn_str = f"""\ def fn(): foo.bar(1, 2, 3) {str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))} l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}] """ locals = {} exec(fn_str, {}, locals) fn = locals["fn"] orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn)))) self.assertIn("EXTENDED_ARG", orig_inst_str) load_method_str = "LOAD_ATTR" if sys.version_info >= (3, 12) else "LOAD_METHOD" self.assertIn(load_method_str, orig_inst_str) keys = bytecode_transformation.get_code_keys() code_options = {k: getattr(fn.__code__, k) for k in keys} result = bytecode_transformation.clean_and_assemble_instructions( bytecode_transformation.cleaned_instructions(fn.__code__), keys, code_options, ) new_inst_str = "\n".join(list(map(str, result[0]))) self.assertIn("EXTENDED_ARG", new_inst_str) self.assertIn(load_method_str, new_inst_str) l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions()) self.assertEqual(len(l1), len(l2)) for p1, p2 in zip(l1, l2): self.assertEqual(p1, p2) self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab) @unittest.skipIf( sys.version_info < (3, 10) or sys.version_info >= (3, 11), "linetable test for Python 3.10", ) def test_linetable_310_writer(self): def fn(): a = 10 b = 20 c = a + b f = "linetable_writer" return f"Test if {f} generates correct co_linetable: {c}" inst = dis.get_instructions(fn) result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) self.assertTrue(result[1] == fn.__code__.co_linetable) @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10") def test_lnotab_writer(self): def fn(): a = 10 b = 20 c = a + b f = "lnotab_writer" return f"Test if {f} generates correct co_lnotab: {c}" inst = dis.get_instructions(fn) result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) self.assertTrue(result[1] == fn.__code__.co_lnotab) def test_if_tensor_is_none(self): """ Python 3.11 adds new jump instructions that check if TOS is None. We do not support these instructions. """ def f(x, y): z = 1 if x is None: z *= 2 if y is not None: z *= 3 return z opt_f = torch._dynamo.optimize("eager", nopython=True)(f) self.assertEqual(opt_f(None, torch.ones(2)), 6) if sys.version_info >= (3, 11): insts = bytecode_transformation.cleaned_instructions(f.__code__) for inst in insts: self.assertNotIn("_NONE", inst.opname) @skipIfNotPy311 def test_py311_jump_offset(self): new_inst = bytecode_transformation.create_instruction consts = (None, 1, 2, 3, 4) def create_test_code(jump_opname, target_idx): targets = [ new_inst("LOAD_CONST", argval=1), new_inst("LOAD_CONST", argval=3), ] jump_to_target_inst = new_inst(jump_opname, target=targets[target_idx]) """ pseudocode of generated bytecode: def test_py311_fn(): goto target1 target0: return 1 target1: goto [target0/target2] (via fwd or bwd jump) return 2 target2: return 3 return 4 """ # test with LOAD_GLOBAL since it has a different instruction size insts = [ new_inst("RESUME", arg=0), new_inst("JUMP_FORWARD", target=jump_to_target_inst), targets[0], new_inst("LOAD_GLOBAL", arg=0, argval="print"), new_inst("POP_TOP"), new_inst("RETURN_VALUE"), jump_to_target_inst, new_inst("LOAD_CONST", argval=2), new_inst("LOAD_GLOBAL", arg=0, argval="print"), new_inst("POP_TOP"), new_inst("RETURN_VALUE"), targets[1], new_inst("RETURN_VALUE"), new_inst("LOAD_CONST", argval=4), new_inst("RETURN_VALUE"), ] code_options = collections.OrderedDict( [ ("co_argcount", 0), ("co_posonlyargcount", 0), ("co_kwonlyargcount", 0), ("co_nlocals", 0), ("co_stacksize", 2), ("co_flags", 3), ("co_code", b""), ("co_consts", consts), ("co_names", ("print",)), ("co_varnames", ()), ("co_filename", __file__), ("co_name", "test_py311_fn"), ("co_qualname", "test_py311_fn"), ("co_firstlineno", 1), ("co_linetable", b""), ("co_exceptiontable", b""), ("co_freevars", ()), ("co_cellvars", ()), ] ) return bytecode_transformation.clean_and_assemble_instructions( insts, list(code_options.keys()), code_options, ) # format: jump_opname, target_idx, expected forward jump, expected return value test_args = ( ("JUMP_FORWARD", 0, False, 1), ("JUMP_FORWARD", 1, True, 3), ("JUMP_BACKWARD", 0, False, 1), ("JUMP_BACKWARD", 1, True, 3), ) for test in test_args: insts, code = create_test_code(test[0], test[1]) # check if offset of latest jump instruction is forward/backward for inst in reversed(insts): if inst.opname.startswith("JUMP"): if test[2]: self.assertIn("FORWARD", inst.opname) else: self.assertIn("BACKWARD", inst.opname) break # run the code and check result def dummy_fn(): pass dummy_fn.__code__ = code self.assertEqual(dummy_fn(), test[3]) dummy_opt = torch._dynamo.optimize("eager")(dummy_fn) self.assertEqual(dummy_opt(), test[3]) def test_exception_table_encode_varint(self): # these numbers have no real meaning to them nums = [ 0b111_101010_000000, 0b1100_111000_010101_101010, ] b = bytecode_transformation.encode_exception_table_varint( nums[0] ) + bytecode_transformation.encode_exception_table_varint(nums[1]) nums_new = [] b_iter = iter(bytes(b)) while True: try: nums_new.append( bytecode_transformation.decode_exception_table_varint(b_iter) ) except StopIteration: break self.assertEqual(nums, nums_new) @skipIfNotPy311 def test_exception_table_parsing(self): def fn(): try: with a(): b() c() except Exception: d() finally: e() f() tab = bytecode_transformation.parse_exception_table( fn.__code__.co_exceptiontable ) b = bytecode_transformation.assemble_exception_table(tab) self.assertEqual(b, fn.__code__.co_exceptiontable) @skipIfNotPy311 def test_exception_table_e2e(self): def fn(): try: with a(): b() c() except Exception: d() finally: e() f() def nothing(*args): pass code = bytecode_transformation.transform_code_object(fn.__code__, nothing) self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable) @skipIfNotPy311 def test_exception_table_e2e_2(self): # last instructions of an exn_table entry is a large instruction # i.e., LOAD_GLOBAL a def fn(): try: return a except Exception: pass def nothing(*args): pass code = bytecode_transformation.transform_code_object(fn.__code__, nothing) self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable) @skipIfNotPy311 def test_exception_table_entry_propagation(self): insts = [] for _ in range(10): insts.append(bytecode_transformation.create_instruction("NOP")) insts[8].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[0], insts[9], insts[0], 0, True ) insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[0], insts[0], insts[1], 0, True ) insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[0], insts[2], insts[2], 0, True ) insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[4], insts[6], insts[3], 0, True ) insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[9], insts[9], insts[4], 0, True ) insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[7], insts[9], insts[5], 0, True ) bytecode_transformation.propagate_inst_exn_table_entries(insts) expected = [1, 2, 2, 0, 3, 3, 3, 5, 5, 4] for inst, exp in zip(insts, expected): self.assertIsNotNone(inst.exn_tab_entry) self.assertIs(inst.exn_tab_entry.target, insts[exp]) @skipIfNotPy311 def test_compute_exception_table_nested(self): insts = [] for _ in range(20): insts.append(bytecode_transformation.create_instruction("NOP")) insts[10].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[1], insts[10], insts[0], 0, True ) insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[1], insts[1], insts[1], 0, True ) insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[1], insts[3], insts[2], 0, True ) insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[5], insts[7], insts[3], 0, True ) insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[10], insts[10], insts[4], 0, True ) insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[8], insts[10], insts[5], 0, True ) insts[14].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[13], insts[17], insts[6], 0, True ) insts[16].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( insts[15], insts[16], insts[7], 0, True ) bytecode_transformation.update_offsets(insts) tab = bytecode_transformation.compute_exception_table(insts) expected = [ (1, 1, 1), (2, 3, 2), (4, 4, 0), (5, 7, 3), (8, 9, 5), (10, 10, 4), (13, 14, 6), (15, 16, 7), (17, 17, 6), ] self.assertEqual(len(tab), len(expected)) for entry, exp in zip(tab, expected): self.assertEqual(entry.start, exp[0] * 2) self.assertEqual(entry.end, exp[1] * 2) self.assertEqual(entry.target, exp[2] * 2) @skipIfNotPy311 def test_remove_dead_code_with_exn_table_entries(self): create_instruction = bytecode_transformation.create_instruction target1 = create_instruction("NOP") target2 = create_instruction("NOP") target3 = create_instruction("NOP") exn_start = create_instruction("NOP") exn_end = create_instruction("NOP") insts = [ create_instruction("JUMP_FORWARD", target=target1), exn_start, # dead target1, create_instruction("JUMP_FORWARD", target=target3), exn_end, # dead target2, target3, ] exn_start.exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( exn_start, exn_end, target2, 0, True ) bytecode_transformation.propagate_inst_exn_table_entries(insts) insts = bytecode_analysis.remove_dead_code(insts) self.assertEqual(len(insts), 5) self.assertNotIn(exn_start, insts) self.assertNotIn(exn_end, insts) self.assertIn(target2, insts) self.assertIn(target3, insts) bytecode_transformation.update_offsets(insts) tab = bytecode_transformation.compute_exception_table(insts) self.assertEqual(len(tab), 1) self.assertEqual(tab[0].start, 2) self.assertEqual(tab[0].end, 4) self.assertEqual(tab[0].target, 6) def test_bytecode_from_template(self): def fn(d1): for k, v in d1.items(): d2[k] = v varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"} insts = bytecode_transformation.bytecode_from_template(fn, varname_map) for inst in insts: self.assertIsNone(inst.starts_line) if inst.opname.startswith("LOAD"): self.assertNotIn(inst.argval, varname_map) if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"): self.assertIsNone(inst.arg) self.assertFalse(inst.opname.startswith("RETURN")) @skipIfNotPy311 def test_bytecode_from_template_noprefix(self): # Test that 3.11+ prefix instructions are removed def gen_fn(): cl = None def fn(): return cl return fn fn = gen_fn() dis_insts = list(dis.get_instructions(fn)) names = {inst.opname for inst in dis_insts} self.assertIn("RESUME", names) self.assertIn("COPY_FREE_VARS", names) insts = bytecode_transformation.bytecode_from_template(fn) names = {inst.opname for inst in insts} self.assertNotIn("RESUME", names) self.assertNotIn("COPY_FREE_VARS", names) def test_bytecode_from_template_noreturn1(self): # Test that functions with multiple returns will have their # returns replaced with jumps to the end def fn(): if x: return y z = 3 return z dis_insts = list(dis.get_instructions(fn)) dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts)) self.assertGreater(len(dis_returns), 1) self.assertTrue(dis_insts[-1].opname.startswith("RETURN")) insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) self.assertEqual(insts[-1].opname, "NOP") self.assertEqual(len(dis_insts), len(insts)) for i0, i1 in zip(dis_insts, insts): if i0.opname.startswith("RETURN"): if i1 is insts[-1]: continue self.assertIn("JUMP", i1.opname) self.assertIs(i1.target, insts[-1]) # Should work with 3.10, but testing with 3.11+ is sufficient. # In 3.8, `fn` ends with a RETURN_VALUE. @skipIfNotPy311 def test_bytecode_from_template_noreturn2(self): # Test function that doesn't end with RETURN_VALUE def fn(): if x: return x if x: return x raise RuntimeError dis_insts = list(dis.get_instructions(fn)) self.assertFalse(dis_insts[-1].opname.startswith("RETURN")) insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) self.assertEqual(insts[-1].opname, "NOP") self.assertEqual(insts[-2].opname, dis_insts[-1].opname) self.assertEqual(len(dis_insts) + 1, len(insts)) for i0, i1 in zip(dis_insts, insts): if i0.opname.startswith("RETURN"): self.assertIn("JUMP", i1.opname) self.assertIs(i1.target, insts[-1]) @skipIfNotPy312 def test_bytecode_from_template_noreturn_const(self): # Test 3.12+ RETURN_CONST def fn(): if x: return 1 return 0 dis_insts = list(dis.get_instructions(fn)) dis_return_consts = list( filter(lambda x: x.opname == "RETURN_CONST", dis_insts) ) self.assertGreater(len(dis_return_consts), 1) self.assertTrue(dis_insts[-1].opname == "RETURN_CONST") insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) self.assertEqual(insts[-1].opname, "NOP") insts_i = 0 for i, inst in enumerate(dis_insts): if inst.opname == "RETURN_CONST": self.assertEqual(insts[insts_i].opname, "LOAD_CONST") insts_i += 1 if insts_i != len(insts) - 1: self.assertIn("JUMP", insts[insts_i].opname) self.assertIs(insts[insts_i].target, insts[-1]) insts_i += 1 class BytecodeHookTests(torch._dynamo.test_case.TestCase): def test_bytecode_hook(self): def fn(a, b): return a - b * 10 def hook(code, out_code): print(code) print(out_code) return code torch._dynamo.reset() handle = torch._dynamo.convert_frame.register_bytecode_hook(hook) try: opt_fn = torch.compile(fn) for i in range(2, 12): opt_fn(torch.randn(i), torch.randn(i)) finally: handle.remove() if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()