# Owner(s): ["oncall: distributed"] import sys from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch import distributed as dist from torch.distributed.algorithms._comm_hooks import default_hooks from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.testing._internal.common_distributed import ( requires_nccl, requires_nccl_version, skip_but_pass_in_sandcastle_if, skip_if_lt_x_gpu, ) from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, ) if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) # bfloat16 is only supported by CUDA 11+ BFLOAT16_AVAILABLE = torch.cuda.is_available() and ( torch.version.cuda is not None or torch.version.hip is not None ) class Net(nn.Module): def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None): # to ensure determinism torch.manual_seed(0) torch.cuda.manual_seed(0) super().__init__() if has_wrapping: self.net = FSDP( nn.Sequential( nn.Linear(8, 16), nn.ReLU(), FSDP( nn.Linear(16, 8), device_id=torch.cuda.current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, ), ), device_id=torch.cuda.current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, ) else: self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8)) self.out = nn.Linear(8, 4) def forward(self, x): return self.out(F.relu(self.net(x))) class DummyState: __slots__ = ["process_group", "noise"] def __init__(self, process_group: dist.ProcessGroup, noise: int): self.process_group = process_group self.noise = noise class DummyHook: def dummy_hook_for_no_shard_fsdp(self, state: DummyState, grad: torch.Tensor): """ This communication hook is for illustration and testing purpose only. This communication hook is used during FSDP ``NO_SHARD`` training. It adds some noise to the provided ``grad`` parameter and uses ``all_reduce`` to communicate full, flattened, unsharded gradient. """ grad.add_(state.noise) dist.all_reduce(grad, group=state.process_group) def custom_reduce_scatter(self, output, input, group=None): """ This function is for illustrative purpose only. It is meant to implement a custom reduce-scatter of a flattened tensor to all processes in a group. Currently a no-op. """ def dummy_hook_for_sharded_fsdp( self, state: DummyState, grad: torch.Tensor, output: torch.Tensor ): """ This communication hook is for illustration and testing purposes only. This communication hook is used during FSDP ``FULL_SHARD`` or ``SHARD_GRAD_OP`` training. It adds some noise to the provided ``grad`` parameter, uses ``reduce_scatter`` for gradient communication and stores a sharded gradient in ``output``. """ grad.add_(state.noise) self.custom_reduce_scatter(output, grad, group=state.process_group) class TestCommunicationHooks(FSDPTest): @skip_if_lt_x_gpu(2) @parametrize( "sharding_strategy", [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ], ) def test_default_communication_hook_behavior( self, sharding_strategy: Optional[ShardingStrategy] ): """ Tests FSDP's default communication hook's behavior and correctness. This test creates a simple linear net with weight shape ``1 X N``, where ``N`` is the number of workers. For sharded cases, each worker gets 1 element of the weight parameter. This test checks that after backward, each worker has a proper value in its chunk of the gradient, or the whole gradient on every worker is equal to an expected value. Arguments: sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. """ out_dim = self.world_size net = torch.nn.Linear(1, out_dim, bias=False) inpt = torch.tensor([self.rank]).float().cuda(self.rank) net_default_hook = FSDP( net, device_id=torch.cuda.current_device(), sharding_strategy=sharding_strategy, ).to(self.rank) # Check that by default, `_comm_hook` is None for entry in FSDP.fsdp_modules(net_default_hook): self.assertEqual(entry._comm_hook, None) for _ in range(4): # Clear gradients net_default_hook.zero_grad() loss = net_default_hook(inpt).sum() loss.backward() # For each worker, the gradient on the weight should be worker_rank. grad = net_default_hook.params[0].grad expected_grad = ( sum(i for i in range(dist.get_world_size())) / dist.get_world_size() ) # Verify default hook produces expected gradients self.assertEqual( grad[0].item(), expected_grad, msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}", ) def _get_submodules(self, fsdp_net): return [ submodule for submodule in FSDP.fsdp_modules(fsdp_net) if not submodule.check_is_root() ] def _init_model(self, core, sharding_strategy, mixed_precision=None): device = torch.device("cuda") return FSDP( core, device_id=torch.cuda.current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, ).to(device) @skip_if_lt_x_gpu(2) @parametrize("has_wrapping", [True, False]) @parametrize( "sharding_strategy", [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ], ) def test_default_communication_hook_initialization( self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] ): """ Tests FSDP's communication hook interface behavior. Arguments: has_wrapping (bool): Configures wrapping of a module. sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. """ # Initialize a model fsdp_model_with_hook = self._init_model( Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), sharding_strategy=sharding_strategy, ) # Check that by default, `_comm_hook` is None for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook): self.assertEqual(fsdp_module._comm_hook, None) dummy_state = DummyState(process_group=None, noise=1234) dummy_hook = ( DummyHook.dummy_hook_for_no_shard_fsdp if sharding_strategy != ShardingStrategy.NO_SHARD else DummyHook.dummy_hook_for_sharded_fsdp ) fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) # Check that we can't register comm hook twice with self.assertRaisesRegex( AssertionError, "^A communication hook is already registered$" ): fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) # Check dummy hook was registered for the root and all submodules if any for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook): self.assertEqual(fsdp_module._comm_hook, dummy_hook) self.assertEqual(fsdp_module._comm_hook_state, dummy_state) @skip_if_lt_x_gpu(2) @parametrize( "sharding_strategy", [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ], ) def test_registering_hook_non_root( self, sharding_strategy: Optional[ShardingStrategy] ): """ Tests FSDP's communication hook registering for submodules. Make sure it can't be registered for non-root submodules. Currently tests only ``NO_SHARD`` strategy. Arguments: sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. """ fsdp_model_with_hook = self._init_model( Net(has_wrapping=True, sharding_strategy=sharding_strategy), sharding_strategy=sharding_strategy, ) dummy_state = DummyState(process_group=None, noise=1234) dummy_hook = ( DummyHook.dummy_hook_for_no_shard_fsdp if sharding_strategy != ShardingStrategy.NO_SHARD else DummyHook.dummy_hook_for_sharded_fsdp ) # Creating a list of non-root submodules to test submodules = self._get_submodules(fsdp_model_with_hook) # Check that assertion is raised for registering a comm hook on a non-root with self.assertRaisesRegex( AssertionError, "^register_comm_hook can only be called on a root instance.$", ): submodules[1].register_comm_hook(dummy_state, dummy_hook) @skip_if_lt_x_gpu(2) def test_registering_hook_hybrid_strategy(self): for sharding_strategy in ( ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2, ): model = Net(False, None, None).cuda() fsdp_model = FSDP( model, auto_wrap_policy=ModuleWrapPolicy({nn.Linear}), sharding_strategy=sharding_strategy, ) dummy_state = DummyState(process_group=None, noise=1234) dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp with self.assertRaisesRegex( AssertionError, "Communication hook is not supported for hybrid strategies", ): fsdp_model.register_comm_hook(dummy_state, dummy_hook) @skip_if_lt_x_gpu(2) @parametrize( "sharding_strategy", [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ], ) def test_registering_hook_submodules( self, sharding_strategy: Optional[ShardingStrategy] ): """ Tests FSDP's communication hook registering for submodules. Checks behavior if a hook was registered for a non-root submodule Currently tests only ``NO_SHARD`` strategy. Arguments: sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. """ fsdp_model_with_hook = self._init_model( Net(has_wrapping=True, sharding_strategy=sharding_strategy), sharding_strategy=sharding_strategy, ) dummy_state = DummyState(process_group=None, noise=1234) dummy_hook = ( DummyHook.dummy_hook_for_no_shard_fsdp if sharding_strategy != ShardingStrategy.NO_SHARD else DummyHook.dummy_hook_for_sharded_fsdp ) submodules = self._get_submodules(fsdp_model_with_hook) # Simulate a registration of a hook on a submodule submodules[1]._comm_hook = dummy_hook # Check that an error is raised when some of submodules have a non-default hook assigned with self.assertRaisesRegex( AssertionError, "^A communication hook is already registered$" ): fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) def _check_low_precision_hook( self, state, hook, sharding_strategy, dtype, has_wrapping ): # keep everything deterministic for input data torch.manual_seed(0) torch.cuda.manual_seed(0) fsdp_with_hook = self._init_model( Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), sharding_strategy=sharding_strategy, ) fsdp_with_hook.register_comm_hook(state, hook) mp_only_grad = MixedPrecision(reduce_dtype=dtype) fsdp_with_mp = self._init_model( Net( has_wrapping=has_wrapping, sharding_strategy=sharding_strategy, mixed_precision=mp_only_grad, ), sharding_strategy=sharding_strategy, mixed_precision=mp_only_grad, ) optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1) optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1) in_data = torch.rand(16, 8).cuda() fsdp_with_hook.train() fsdp_with_mp.train() loss_hook = fsdp_with_hook(in_data).sum() loss_mp = fsdp_with_mp(in_data).sum() loss_hook.backward() # Make sure grads were cast to the parameter's precision self.assertEqual(fsdp_with_hook.params[0].grad.dtype, state.parameter_type) loss_mp.backward() optim_hook.step() optim_mp.step() dist.barrier() for hook_param, mp_param in zip( fsdp_with_hook.parameters(), fsdp_with_mp.parameters() ): self.assertEqual(hook_param.grad, mp_param.grad) @requires_nccl() @skip_if_lt_x_gpu(2) @parametrize("has_wrapping", [True, False]) @parametrize( "sharding_strategy", [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ], ) def test_fp16_hook( self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] ): state = default_hooks.LowPrecisionState(process_group=_get_default_group()) hook = default_hooks.fp16_compress_hook self._check_low_precision_hook( state, hook, sharding_strategy, torch.float16, has_wrapping ) @requires_nccl() @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS") @skip_but_pass_in_sandcastle_if( not BFLOAT16_AVAILABLE, "BFloat16 is only supported by CUDA 11+", ) @skip_if_lt_x_gpu(2) @parametrize("has_wrapping", [True, False]) @parametrize( "sharding_strategy", [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ], ) def test_bf16_hook( self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] ): state = default_hooks.LowPrecisionState(process_group=_get_default_group()) hook = default_hooks.bf16_compress_hook self._check_low_precision_hook( state, hook, sharding_strategy, torch.bfloat16, has_wrapping ) instantiate_parametrized_tests(TestCommunicationHooks) if __name__ == "__main__": run_tests()