# Owner(s): ["oncall: distributed"] import torch from torch.distributed.fsdp._trace_utils import _ExecOrderTracer from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, run_tests, TestCase, ) class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight1 = torch.nn.Parameter(torch.randn(6, 6)) self.weight2 = torch.nn.Parameter(torch.randn(6, 6)) self.weight_unused = torch.nn.Parameter(torch.randn(2, 2)) self.layer0 = torch.nn.Linear(6, 6) self.layer1 = torch.nn.Linear(6, 6, bias=False) self.layer2 = torch.nn.Sequential( torch.nn.Linear(6, 3, bias=False), torch.nn.ReLU(), torch.nn.Linear(3, 6, bias=False), ) self.relu = torch.nn.ReLU() def forward(self, x: torch.Tensor, run_all_layers: bool) -> torch.Tensor: z = self.relu(self.layer0(x)) z = self.relu(self.layer2(z)) z = z @ self.weight1 if run_all_layers: z = self.relu(self.layer1(z)) z = z @ self.weight2 # Use `layer0` twice to check the handling of multiplicity in the # saved data structures z = self.relu(self.layer0(x)) return z class TestSymbolicTracing(TestCase): def test_symbolic_tracing_outputs(self): """ Tests running ``tracer.trace()`` inside ``patch_tracer()`` by checking the saved data structures. """ model = Model() tracer = torch.fx.Tracer() orig_call_module = tracer.call_module orig_create_proxy = tracer.create_proxy exec_order_tracer = _ExecOrderTracer() with exec_order_tracer.patch_tracer(tracer=tracer, root_module=model): concrete_args = {"run_all_layers": True} tracer.trace(model, concrete_args) # Check that the tracer methods are unchanged after exiting the context self.assertEqual(orig_call_module, tracer.call_module) self.assertEqual(orig_create_proxy, tracer.create_proxy) # Check `module_forward_order` correct_module_forward_order = [ model, model.layer0, model.relu, model.layer2, model.layer2[0], model.layer2[1], model.layer2[2], model.relu, model.layer1, model.relu, model.layer0, model.relu, ] exec_info = exec_order_tracer.exec_info self.assertEqual(exec_info.module_forward_order, correct_module_forward_order) # Check `module_to_param_usage_infos` self.assertEqual( exec_info.module_to_param_usage_infos[model], [ (model.layer0, list(model.layer0.named_parameters())), (model.layer2, list(model.layer2.named_parameters())), (model, [("weight1", model.weight1)]), (model.layer1, list(model.layer1.named_parameters())), (model, [("weight2", model.weight2)]), (model.layer0, list(model.layer0.named_parameters())), ], ) self.assertEqual( exec_info.module_to_param_usage_infos[model.layer0], [(model.layer0, list(model.layer0.named_parameters()))], ) self.assertEqual( exec_info.module_to_param_usage_infos[model.layer1], [(model.layer1, list(model.layer1.named_parameters()))], ) self.assertEqual( exec_info.module_to_param_usage_infos[model.layer2], [ (model.layer2[0], list(model.layer2[0].named_parameters())), (model.layer2[2], list(model.layer2[2].named_parameters())), ], ) self.assertEqual(exec_info.module_to_param_usage_infos[model.relu], []) # Check `param_forward_order` correct_param_order = [ model.layer0.weight, model.layer0.bias, model.layer2[0].weight, model.layer2[2].weight, model.weight1, model.layer1.weight, model.weight2, ] self.assertEqual(exec_info.param_forward_order, correct_param_order) # Check `visited_params` self.assertEqual( len(exec_info.visited_params), len(exec_info.param_forward_order) ) self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order)) instantiate_parametrized_tests(TestSymbolicTracing) if __name__ == "__main__": 
run_tests()