# Owner(s): ["oncall: distributed"] import sys import torch import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( CUDAInitMode, FSDPInitMode, FSDPTest, NestedWrappedModule, TransformerWithSharedParams, ) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) if TEST_WITH_DEV_DBG_ASAN: print( "Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr, ) sys.exit(0) class TestApply(FSDPTest): @property def world_size(self): return 2 @torch.no_grad() def _init_linear_weights(self, m): if type(m) == nn.Linear: m.weight.fill_(1.0) m.bias.fill_(1.0) def check_weights(self, fsdp, expected_tensor_fn, check): with FSDP.summon_full_params(fsdp, recurse=True): linear_modules = [ module for module in fsdp.modules() if type(module) == nn.Linear ] for module in linear_modules: for param in module.parameters(): expected = expected_tensor_fn(param) check(param, expected, f"Got {param} but expected {expected}") def _check_apply(self, fsdp): # Assert linear weights are not all 1.0 self.check_weights( fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertNotEqual ) fsdp.apply(self._init_linear_weights) # Ensure all weights are 1.0 self.check_weights( fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertEqual ) @skip_if_lt_x_gpu(2) def test_nested_module_apply(self): """Tests that ``apply()`` modifies parameter values in-place on a non-FSDP-root nested FSDP-wrapped model.""" nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_AFTER, ) self._check_apply(nested_wrapped_module) @skip_if_lt_x_gpu(2) def test_transformer_module_apply(self): """Tests that ``apply()`` modifies parameter values in-place on an FSDP-wrapped transformer model with shared parameters.""" transformer = TransformerWithSharedParams.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_AFTER, ) self._check_apply(transformer) @skip_if_lt_x_gpu(2) def test_apply_in_summon_raises_error(self): """Tests that calling ``apply()`` on an FSDP instance inside the ``summon_full_params()`` context raises an error.""" transformer = TransformerWithSharedParams.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_AFTER, ) with transformer.summon_full_params(transformer): with self.assertRaisesRegex(ValueError, "expected to be in states"): transformer.apply(self._init_linear_weights) if __name__ == "__main__": run_tests()