# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module, Sequential
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class InnerModel(Module):
    def __init__(self) -> None:
        super().__init__()
        # The inner model already contains a nested FSDP-wrapped submodule.
        self.layers = Sequential(FSDP(Linear(5, 5)))

    def forward(self, x):
        return self.layers(x)


class TestMultipleWrapping(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_multiple_wrapping(self):
        """
        This test simulates wrapping the module after training to run inference.
        It covers the case where, later in a session, the model is wrapped in
        FSDP again even though it already contains nested FSDP wrappers.
        """
        inner_model = InnerModel()
        model = FSDP(inner_model).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        # Train the outer FSDP-wrapped model for a few iterations.
        for _ in range(3):
            input = torch.rand((1, 5), dtype=torch.float).cuda()
            input.requires_grad = True
            output = model(input)
            output.sum().backward()
            optim.step()
            optim.zero_grad()

        # Run one more forward pass on the trained model for comparison.
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        output = model(input)

        # Rewrap the inner model in FSDP a second time; the nested FSDP
        # wrapper inside it is reused as-is.
        rewrapped_model = FSDP(inner_model).cuda()
        rewrapped_output = rewrapped_model(input)

        # The rewrapped model should produce the same output for the same input.
        self.assertEqual(output, rewrapped_output)


if __name__ == "__main__":
    run_tests()
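
# Illustrative sketch (comments only; nothing below executes). It restates the
# rewrapping pattern exercised by test_multiple_wrapping above, assuming the
# process group and GPUs are already set up by the FSDPTest harness:
#
#   inner_model = InnerModel()                   # holds a nested FSDP(Linear(5, 5))
#   model = FSDP(inner_model).cuda()             # outer wrap used for training
#   ... a few SGD steps on `model` ...
#   rewrapped_model = FSDP(inner_model).cuda()   # wrap the same module again
#   # For the same input, `rewrapped_model` is expected to match `model`'s output.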