1# Owner(s): ["oncall: distributed"] 2 3import sys 4from typing import Optional 5 6import torch 7import torch.nn as nn 8import torch.nn.functional as F 9from torch import distributed as dist 10from torch.distributed.algorithms._comm_hooks import default_hooks 11from torch.distributed.distributed_c10d import _get_default_group 12from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision 13from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy 14from torch.distributed.fsdp.wrap import ModuleWrapPolicy 15from torch.testing._internal.common_distributed import ( 16 requires_nccl, 17 requires_nccl_version, 18 skip_but_pass_in_sandcastle_if, 19 skip_if_lt_x_gpu, 20) 21from torch.testing._internal.common_fsdp import FSDPTest 22from torch.testing._internal.common_utils import ( 23 instantiate_parametrized_tests, 24 parametrize, 25 run_tests, 26) 27 28 29if not dist.is_available(): 30 print("Distributed not available, skipping tests", file=sys.stderr) 31 sys.exit(0) 32 33# bfloat16 is only supported by CUDA 11+ 34BFLOAT16_AVAILABLE = torch.cuda.is_available() and ( 35 torch.version.cuda is not None or torch.version.hip is not None 36) 37 38 39class Net(nn.Module): 40 def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None): 41 # to ensure determinism 42 torch.manual_seed(0) 43 torch.cuda.manual_seed(0) 44 super().__init__() 45 46 if has_wrapping: 47 self.net = FSDP( 48 nn.Sequential( 49 nn.Linear(8, 16), 50 nn.ReLU(), 51 FSDP( 52 nn.Linear(16, 8), 53 device_id=torch.cuda.current_device(), 54 sharding_strategy=sharding_strategy, 55 mixed_precision=mixed_precision, 56 ), 57 ), 58 device_id=torch.cuda.current_device(), 59 sharding_strategy=sharding_strategy, 60 mixed_precision=mixed_precision, 61 ) 62 else: 63 self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8)) 64 65 self.out = nn.Linear(8, 4) 66 67 def forward(self, x): 68 return self.out(F.relu(self.net(x))) 69 70 71class DummyState: 72 __slots__ = ["process_group", "noise"] 73 74 def __init__(self, process_group: dist.ProcessGroup, noise: int): 75 self.process_group = process_group 76 self.noise = noise 77 78 79class DummyHook: 80 def dummy_hook_for_no_shard_fsdp(self, state: DummyState, grad: torch.Tensor): 81 """ 82 This communication hook is for illustration and testing purpose only. 83 This communication hook is used during FSDP ``NO_SHARD`` training. It adds some noise to 84 the provided ``grad`` parameter and uses ``all_reduce`` to communicate full, flattened, 85 unsharded gradient. 86 """ 87 grad.add_(state.noise) 88 dist.all_reduce(grad, group=state.process_group) 89 90 def custom_reduce_scatter(self, output, input, group=None): 91 """ 92 This function is for illustrative purpose only. 93 It is meant to implement a custom reduce-scatter 94 of a flattened tensor to all processes in a group. 95 Currently a no-op. 96 """ 97 98 def dummy_hook_for_sharded_fsdp( 99 self, state: DummyState, grad: torch.Tensor, output: torch.Tensor 100 ): 101 """ 102 This communication hook is for illustration and testing purposes only. 103 This communication hook is used during FSDP ``FULL_SHARD`` or ``SHARD_GRAD_OP`` training. 104 It adds some noise to the provided ``grad`` parameter, uses 105 ``reduce_scatter`` for gradient communication and stores a sharded gradient in ``output``. 106 """ 107 grad.add_(state.noise) 108 self.custom_reduce_scatter(output, grad, group=state.process_group) 109 110 111class TestCommunicationHooks(FSDPTest): 112 @skip_if_lt_x_gpu(2) 113 @parametrize( 114 "sharding_strategy", 115 [ 116 ShardingStrategy.NO_SHARD, 117 ShardingStrategy.FULL_SHARD, 118 ShardingStrategy.SHARD_GRAD_OP, 119 ], 120 ) 121 def test_default_communication_hook_behavior( 122 self, sharding_strategy: Optional[ShardingStrategy] 123 ): 124 """ 125 Tests FSDP's default communication hook's behavior and correctness. 126 This test creates a simple linear net with weight shape ``1 X N``, 127 where ``N`` is the number of workers. 128 For sharded cases, each worker gets 1 element of the weight parameter. This test 129 checks that after backward, each worker has a proper value in its chunk of 130 the gradient, or the whole gradient on every worker is equal to an expected value. 131 132 Arguments: 133 sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. 134 """ 135 out_dim = self.world_size 136 net = torch.nn.Linear(1, out_dim, bias=False) 137 inpt = torch.tensor([self.rank]).float().cuda(self.rank) 138 139 net_default_hook = FSDP( 140 net, 141 device_id=torch.cuda.current_device(), 142 sharding_strategy=sharding_strategy, 143 ).to(self.rank) 144 145 # Check that by default, `_comm_hook` is None 146 for entry in FSDP.fsdp_modules(net_default_hook): 147 self.assertEqual(entry._comm_hook, None) 148 149 for _ in range(4): 150 # Clear gradients 151 net_default_hook.zero_grad() 152 loss = net_default_hook(inpt).sum() 153 loss.backward() 154 155 # For each worker, the gradient on the weight should be worker_rank. 156 grad = net_default_hook.params[0].grad 157 expected_grad = ( 158 sum(i for i in range(dist.get_world_size())) / dist.get_world_size() 159 ) 160 # Verify default hook produces expected gradients 161 self.assertEqual( 162 grad[0].item(), 163 expected_grad, 164 msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}", 165 ) 166 167 def _get_submodules(self, fsdp_net): 168 return [ 169 submodule 170 for submodule in FSDP.fsdp_modules(fsdp_net) 171 if not submodule.check_is_root() 172 ] 173 174 def _init_model(self, core, sharding_strategy, mixed_precision=None): 175 device = torch.device("cuda") 176 return FSDP( 177 core, 178 device_id=torch.cuda.current_device(), 179 sharding_strategy=sharding_strategy, 180 mixed_precision=mixed_precision, 181 ).to(device) 182 183 @skip_if_lt_x_gpu(2) 184 @parametrize("has_wrapping", [True, False]) 185 @parametrize( 186 "sharding_strategy", 187 [ 188 ShardingStrategy.NO_SHARD, 189 ShardingStrategy.FULL_SHARD, 190 ShardingStrategy.SHARD_GRAD_OP, 191 ], 192 ) 193 def test_default_communication_hook_initialization( 194 self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] 195 ): 196 """ 197 Tests FSDP's communication hook interface behavior. 198 199 Arguments: 200 has_wrapping (bool): Configures wrapping of a module. 201 sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. 202 """ 203 204 # Initialize a model 205 fsdp_model_with_hook = self._init_model( 206 Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), 207 sharding_strategy=sharding_strategy, 208 ) 209 210 # Check that by default, `_comm_hook` is None 211 for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook): 212 self.assertEqual(fsdp_module._comm_hook, None) 213 214 dummy_state = DummyState(process_group=None, noise=1234) 215 dummy_hook = ( 216 DummyHook.dummy_hook_for_no_shard_fsdp 217 if sharding_strategy != ShardingStrategy.NO_SHARD 218 else DummyHook.dummy_hook_for_sharded_fsdp 219 ) 220 221 fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) 222 223 # Check that we can't register comm hook twice 224 with self.assertRaisesRegex( 225 AssertionError, "^A communication hook is already registered$" 226 ): 227 fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) 228 229 # Check dummy hook was registered for the root and all submodules if any 230 for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook): 231 self.assertEqual(fsdp_module._comm_hook, dummy_hook) 232 self.assertEqual(fsdp_module._comm_hook_state, dummy_state) 233 234 @skip_if_lt_x_gpu(2) 235 @parametrize( 236 "sharding_strategy", 237 [ 238 ShardingStrategy.NO_SHARD, 239 ShardingStrategy.FULL_SHARD, 240 ShardingStrategy.SHARD_GRAD_OP, 241 ], 242 ) 243 def test_registering_hook_non_root( 244 self, sharding_strategy: Optional[ShardingStrategy] 245 ): 246 """ 247 Tests FSDP's communication hook registering for submodules. 248 Make sure it can't be registered for non-root submodules. 249 Currently tests only ``NO_SHARD`` strategy. 250 251 Arguments: 252 sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. 253 """ 254 255 fsdp_model_with_hook = self._init_model( 256 Net(has_wrapping=True, sharding_strategy=sharding_strategy), 257 sharding_strategy=sharding_strategy, 258 ) 259 dummy_state = DummyState(process_group=None, noise=1234) 260 dummy_hook = ( 261 DummyHook.dummy_hook_for_no_shard_fsdp 262 if sharding_strategy != ShardingStrategy.NO_SHARD 263 else DummyHook.dummy_hook_for_sharded_fsdp 264 ) 265 # Creating a list of non-root submodules to test 266 submodules = self._get_submodules(fsdp_model_with_hook) 267 # Check that assertion is raised for registering a comm hook on a non-root 268 with self.assertRaisesRegex( 269 AssertionError, 270 "^register_comm_hook can only be called on a root instance.$", 271 ): 272 submodules[1].register_comm_hook(dummy_state, dummy_hook) 273 274 @skip_if_lt_x_gpu(2) 275 def test_registering_hook_hybrid_strategy(self): 276 for sharding_strategy in ( 277 ShardingStrategy.HYBRID_SHARD, 278 ShardingStrategy._HYBRID_SHARD_ZERO2, 279 ): 280 model = Net(False, None, None).cuda() 281 fsdp_model = FSDP( 282 model, 283 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}), 284 sharding_strategy=sharding_strategy, 285 ) 286 dummy_state = DummyState(process_group=None, noise=1234) 287 dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp 288 with self.assertRaisesRegex( 289 AssertionError, 290 "Communication hook is not supported for hybrid strategies", 291 ): 292 fsdp_model.register_comm_hook(dummy_state, dummy_hook) 293 294 @skip_if_lt_x_gpu(2) 295 @parametrize( 296 "sharding_strategy", 297 [ 298 ShardingStrategy.NO_SHARD, 299 ShardingStrategy.FULL_SHARD, 300 ShardingStrategy.SHARD_GRAD_OP, 301 ], 302 ) 303 def test_registering_hook_submodules( 304 self, sharding_strategy: Optional[ShardingStrategy] 305 ): 306 """ 307 Tests FSDP's communication hook registering for submodules. 308 Checks behavior if a hook was registered for a non-root submodule 309 Currently tests only ``NO_SHARD`` strategy. 310 311 Arguments: 312 sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. 313 """ 314 315 fsdp_model_with_hook = self._init_model( 316 Net(has_wrapping=True, sharding_strategy=sharding_strategy), 317 sharding_strategy=sharding_strategy, 318 ) 319 dummy_state = DummyState(process_group=None, noise=1234) 320 dummy_hook = ( 321 DummyHook.dummy_hook_for_no_shard_fsdp 322 if sharding_strategy != ShardingStrategy.NO_SHARD 323 else DummyHook.dummy_hook_for_sharded_fsdp 324 ) 325 submodules = self._get_submodules(fsdp_model_with_hook) 326 327 # Simulate a registration of a hook on a submodule 328 submodules[1]._comm_hook = dummy_hook 329 # Check that an error is raised when some of submodules have a non-default hook assigned 330 with self.assertRaisesRegex( 331 AssertionError, "^A communication hook is already registered$" 332 ): 333 fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) 334 335 def _check_low_precision_hook( 336 self, state, hook, sharding_strategy, dtype, has_wrapping 337 ): 338 # keep everything deterministic for input data 339 torch.manual_seed(0) 340 torch.cuda.manual_seed(0) 341 342 fsdp_with_hook = self._init_model( 343 Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), 344 sharding_strategy=sharding_strategy, 345 ) 346 fsdp_with_hook.register_comm_hook(state, hook) 347 348 mp_only_grad = MixedPrecision(reduce_dtype=dtype) 349 fsdp_with_mp = self._init_model( 350 Net( 351 has_wrapping=has_wrapping, 352 sharding_strategy=sharding_strategy, 353 mixed_precision=mp_only_grad, 354 ), 355 sharding_strategy=sharding_strategy, 356 mixed_precision=mp_only_grad, 357 ) 358 359 optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1) 360 optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1) 361 362 in_data = torch.rand(16, 8).cuda() 363 fsdp_with_hook.train() 364 fsdp_with_mp.train() 365 loss_hook = fsdp_with_hook(in_data).sum() 366 loss_mp = fsdp_with_mp(in_data).sum() 367 loss_hook.backward() 368 # Make sure grads were cast to the parameter's precision 369 self.assertEqual(fsdp_with_hook.params[0].grad.dtype, state.parameter_type) 370 loss_mp.backward() 371 optim_hook.step() 372 optim_mp.step() 373 374 dist.barrier() 375 376 for hook_param, mp_param in zip( 377 fsdp_with_hook.parameters(), fsdp_with_mp.parameters() 378 ): 379 self.assertEqual(hook_param.grad, mp_param.grad) 380 381 @requires_nccl() 382 @skip_if_lt_x_gpu(2) 383 @parametrize("has_wrapping", [True, False]) 384 @parametrize( 385 "sharding_strategy", 386 [ 387 ShardingStrategy.NO_SHARD, 388 ShardingStrategy.FULL_SHARD, 389 ShardingStrategy.SHARD_GRAD_OP, 390 ], 391 ) 392 def test_fp16_hook( 393 self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] 394 ): 395 state = default_hooks.LowPrecisionState(process_group=_get_default_group()) 396 hook = default_hooks.fp16_compress_hook 397 398 self._check_low_precision_hook( 399 state, hook, sharding_strategy, torch.float16, has_wrapping 400 ) 401 402 @requires_nccl() 403 @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS") 404 @skip_but_pass_in_sandcastle_if( 405 not BFLOAT16_AVAILABLE, 406 "BFloat16 is only supported by CUDA 11+", 407 ) 408 @skip_if_lt_x_gpu(2) 409 @parametrize("has_wrapping", [True, False]) 410 @parametrize( 411 "sharding_strategy", 412 [ 413 ShardingStrategy.NO_SHARD, 414 ShardingStrategy.FULL_SHARD, 415 ShardingStrategy.SHARD_GRAD_OP, 416 ], 417 ) 418 def test_bf16_hook( 419 self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] 420 ): 421 state = default_hooks.LowPrecisionState(process_group=_get_default_group()) 422 hook = default_hooks.bf16_compress_hook 423 424 self._check_low_precision_hook( 425 state, hook, sharding_strategy, torch.bfloat16, has_wrapping 426 ) 427 428 429instantiate_parametrized_tests(TestCommunicationHooks) 430 431if __name__ == "__main__": 432 run_tests() 433