# Owner(s): ["module: dynamo"] import torch import torch._dynamo.test_case from torch._C._dynamo.eval_frame import set_eval_frame from torch._guards import CompileId def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): local = 1 return { "local": local, "arg1": arg1, "positional_only_arg": positional_only_arg, "keyword_only_arg": keyword_only_arg, "kwargs": kwargs, } def varkwargs_code1(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): # remove a local variable: local = 1 return { "local": 1, "arg1": arg1, "positional_only_arg": positional_only_arg, "keyword_only_arg": keyword_only_arg, "kwargs": kwargs, } def varkwargs_code2(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): # introduce a local variable local1 = 0 local2 = 1 return { "local": local1 + local2, "arg1": arg1, "positional_only_arg": positional_only_arg, "keyword_only_arg": keyword_only_arg, "kwargs": kwargs, } def target_with_varargs(arg1, /, positional_only_arg, *varargs, **kwargs): local = 1 return { "local": local, "arg1": arg1, "positional_only_arg": positional_only_arg, "varargs": varargs, "kwargs": kwargs, } def varargs_code1(arg1, /, positional_only_arg, *varargs, **kwargs): # remove a local variable: local = 1 return { "local": 1, "arg1": arg1, "positional_only_arg": positional_only_arg, "varargs": varargs, "kwargs": kwargs, } def varargs_code2(arg1, /, positional_only_arg, *varargs, **kwargs): # introduce a local variable local1 = 0 local2 = 1 return { "local": local1 + local2, "arg1": arg1, "positional_only_arg": positional_only_arg, "varargs": varargs, "kwargs": kwargs, } class FrameInitTests(torch._dynamo.test_case.TestCase): def test_frame_init(self): code_map1 = { target_with_varargs.__code__: varargs_code1.__code__, target_with_varkwargs.__code__: varkwargs_code1.__code__, } code_map2 = { target_with_varargs.__code__: varargs_code2.__code__, target_with_varkwargs.__code__: varkwargs_code2.__code__, } def callback1(frame, cache_entry, frame_state): if frame.f_code in code_map1: transformed_code = code_map1[frame.f_code] return torch._dynamo.types.GuardedCode( transformed_code, lambda f_locals: True, CompileId(0, 0) ) return None def callback2(frame, cache_entry, frame_state): if frame.f_code in code_map2: transformed_code = code_map2[frame.f_code] return torch._dynamo.types.GuardedCode( transformed_code, lambda f_locals: True, CompileId(0, 0) ) return None for callback in [callback1, callback2]: torch._dynamo.reset() expected_varargs_output = target_with_varargs( 1, 2, 3, 4, name1=1, name2=2, name3=3 ) expected_kwargs_output = target_with_varkwargs( 1, 2, keyword_only_arg=1, name2=2, name3=3 ) original = set_eval_frame(callback1) real_varargs_output = target_with_varargs( 1, 2, 3, 4, name1=1, name2=2, name3=3 ) real_kwargs_output = target_with_varkwargs( 1, 2, keyword_only_arg=1, name2=2, name3=3 ) self.assertEqual(real_varargs_output, expected_varargs_output) self.assertEqual(real_kwargs_output, expected_kwargs_output) set_eval_frame(original) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()