# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestInput(FSDPTest):
    @property
    def world_size(self):
        return 1

    @skip_if_lt_x_gpu(1)
    @parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
    def test_input_type(self, input_cls):
        """Test FSDP with input being a list or a dict, only single GPU."""

        class Model(Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()


instantiate_parametrized_tests(TestInput)

if __name__ == "__main__":
    run_tests()