1# Owner(s): ["oncall: distributed"] 2 3import sys 4 5import torch 6from torch import distributed as dist 7from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 8from torch.nn import Linear, Module 9from torch.nn.parallel import DistributedDataParallel 10from torch.optim import SGD 11from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 12from torch.testing._internal.common_fsdp import FSDPTest, get_full_params 13from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN 14 15 16if not dist.is_available(): 17 print("Distributed not available, skipping tests", file=sys.stderr) 18 sys.exit(0) 19 20if TEST_WITH_DEV_DBG_ASAN: 21 print( 22 "Skip dev-asan as torch + multiprocessing spawn have known issues", 23 file=sys.stderr, 24 ) 25 sys.exit(0) 26 27 28class Model(Module): 29 def __init__(self, wrap_fsdp): 30 super().__init__() 31 # keep everything deterministic for model initialization 32 torch.manual_seed(0) 33 self.inner = Linear(4, 4) 34 if wrap_fsdp: 35 self.inner = FSDP(self.inner) 36 self.outer = Linear(4, 5) 37 38 def forward(self, x): 39 # Forward twice. 40 i = self.inner(x) 41 j = self.inner(x) 42 return self.outer(i + j) 43 44 45class TestMultiForward(FSDPTest): 46 def _dist_train(self, wrap_fsdp): 47 # keep everything deterministic for input data 48 torch.manual_seed(0) 49 50 model = Model(wrap_fsdp).cuda() 51 if wrap_fsdp: 52 model = FSDP(model) 53 else: 54 model = DistributedDataParallel(model, device_ids=[self.rank]) 55 optim = SGD(model.parameters(), lr=0.1) 56 57 in_data = torch.rand(64, 4).cuda() 58 in_data.requires_grad = True 59 for _ in range(3): 60 out = model(in_data) 61 out.sum().backward() 62 optim.step() 63 optim.zero_grad() 64 65 if wrap_fsdp: 66 return get_full_params(model) 67 68 return list(model.parameters()) 69 70 @skip_if_lt_x_gpu(2) 71 def test_multi_forward(self): 72 # DDP 73 ddp_state = self._dist_train(wrap_fsdp=False) 74 75 # FSDP 76 fsdp_state = self._dist_train(wrap_fsdp=True) 77 78 self.assertEqual(ddp_state, fsdp_state) 79 80 81if __name__ == "__main__": 82 run_tests() 83