# Owner(s): ["module: dynamo"]

import unittest

import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64
        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )

        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )

        fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )

    @requires_cuda
    def test_aot_graph_parser(self):
        from torch import device

        def forward(
            self,
            primals_1: "f32[1001, 6]",
            primals_2: "f32[1001]",
            primals_3: "f32[1001, 64]",
            primals_4: "f32[4190]",
            primals_5: "f32[4190]",
            primals_6: "f32[1739, 4190]",
            primals_48: "f32[6144, 4191]",
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
            lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant0
            )
            _tensor_constant0 = None
            index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy]
            )
            lift_fresh_copy = None
            _tensor_constant1: "i64[6]" = self._tensor_constant1
            lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant1
            )
            _tensor_constant1 = None
            index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy_1]
            )
            primals_48 = lift_fresh_copy_1 = None
            permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default(
                primals_2, index_1, permute
            )
            primals_2 = permute = None
            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
            sub = None
            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
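            # The amax -> sub -> exp -> sum -> div chain above is the
            # decomposed, numerically stable softmax of the addmm logits.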
            exp = None
            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
                [6144, 1001],
                1,
                dtype=torch.int32,
                layout=torch.strided,
                device=device(type="cuda", index=0),
                pin_memory=False,
            )
            iota: "i32[1001]" = torch.ops.prims.iota.default(
                1001,
                start=0,
                step=1,
                dtype=torch.int32,
                device=device(type="cuda"),
                requires_grad=False,
            )
            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
            full_default = iota = None
            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
                6144,
                start=0,
                step=1001,
                dtype=torch.int32,
                device=device(type="cuda", index=0),
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
            mul = None
            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
            div = None

            _embedding_bag = torch.ops.aten._embedding_bag.default(
                primals_3, view, iota_1, False, 0, False, view_1
            )

            return _embedding_bag

        # Synthesize example inputs from the dtype/shape annotations on the
        # graph signature above.
        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        forward(**kwargs)

    @requires_cuda
    def test_sym_aot_graph_parser(self):
        def forward(
            self,
            primals_1: "f32[1001, 6]",  # noqa: F821
            primals_2: "f32[s0]",  # noqa: F821
            primals_3: "Sym(s0)",  # noqa: F821
            primals_4: "f32[s1]",  # noqa: F821
            primals_5: "Sym(s1)",  # noqa: F821
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0

        kwargs = aot_graph_input_parser(
            forward, device="cuda", sym_shapes={"s0": 10}, default_sym_shape=5
        )

        # s0 is pinned to 10 via sym_shapes; s1 falls back to default_sym_shape (5).
        self.assertEqual(list(kwargs["primals_2"].shape), [10])
        self.assertEqual(kwargs["primals_3"], 10)

        self.assertEqual(list(kwargs["primals_4"].shape), [5])
        self.assertEqual(kwargs["primals_5"], 5)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()