# Owner(s): ["oncall: distributed"] import sys from torch import distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( CUDAInitMode, FSDPInitMode, FSDPTest, NestedWrappedModule, ) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) if TEST_WITH_DEV_DBG_ASAN: print( "Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr, ) sys.exit(0) class TestTraversal(FSDPTest): @property def world_size(self): return 2 @skip_if_lt_x_gpu(2) def test_fsdp_modules(self): nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_BEFORE, ) modules = FSDP.fsdp_modules(nested_wrapped_module) self.assertEqual( modules, [ nested_wrapped_module.module.get_submodule("1"), nested_wrapped_module.module.get_submodule("1").get_submodule("0"), nested_wrapped_module.module.get_submodule("2"), ], ) modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True) self.assertEqual( modules, [ nested_wrapped_module.module.get_submodule("1"), nested_wrapped_module.module.get_submodule("2"), ], ) if __name__ == "__main__": run_tests()