# Owner(s): ["oncall: distributed"]

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import sys
import unittest
from contextlib import nullcontext
from typing import Any, cast, List

import numpy as np

import torch
import torch.distributed as dist


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW, SGD
from torch.testing._internal import common_distributed
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    TEST_WITH_ASAN,
    TEST_WITH_DEV_DBG_ASAN,
)


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False


# Use NCCL when CUDA is available (except on Windows); otherwise fall back to
# Gloo, which covers both the Windows GPU case and the CPU-only case.
def _get_backend_for_tests():
    return (
        dist.Backend.NCCL
        if not IS_WINDOWS and torch.cuda.is_available()
        # Windows only has Gloo, but Gloo GPU works; Gloo CPU is used when no
        # GPUs are available
        else dist.Backend.GLOO
    )


BACKEND = _get_backend_for_tests()
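

# For context: the core ZeroRedundancyOptimizer usage pattern exercised by the
# tests below is roughly the following (an illustrative sketch only; `model`,
# `input`, the optimizer class, and the learning rate are placeholders):
#
#   optim = ZeroRedundancyOptimizer(
#       model.parameters(),
#       optimizer_class=torch.optim.SGD,
#       lr=0.01,
#   )
#   loss = model(input).sum()
#   loss.backward()
#   optim.step()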


@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.")
class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    @property
    def world_size(self):
        return 1

    def tearDown(self):
        try:
            torch.distributed.destroy_process_group()
        except AssertionError:
            pass
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def dist_init(self, rank, world_size=-1, backend=BACKEND):
        if world_size < 1:
            world_size = self.world_size
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend,
            store=store,
            rank=rank,
            world_size=world_size,
        )


# TODO: skip_but_pass_in_sandcastle_if does not work here.
@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.")
class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
    def test_state_dict(self):
        """Check that ZeroRedundancyOptimizer exposes the expected state dict
        interface, irrespective of the sharding."""
        self.dist_init(self.rank)
        LR1 = 0.1
        LR2 = 0.01
        MOMENTUM = 0.9
        RECIPIENT_RANK = 0  # rank 0 is the only rank since the world size is 1
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGD,
            lr=LR1,
            momentum=MOMENTUM,
        )
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=self.device))
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=self.device),
        )

        o.zero_grad()
        o.consolidate_state_dict(to=RECIPIENT_RANK)
        state_dict = o.state_dict()

        # Check that the state dict has keys compliant with PyTorch
        self.assertIn("param_groups", state_dict.keys())
        self.assertIn("state", state_dict.keys())

        # Check that the state has the expected keys
        self.assertEqual(state_dict["param_groups"][0]["lr"], LR1)
        self.assertEqual(state_dict["param_groups"][0]["momentum"], MOMENTUM)
        self.assertFalse(state_dict["param_groups"][0]["nesterov"])
        self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
        self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

        # Check that the state and the `param_groups` attribute are in sync
        for k in state_dict["param_groups"][0]:
            if k != "params":
                self.assertEqual(
                    state_dict["param_groups"][0][k],
                    o.param_groups[0][k],
                )

        # Check that the state is reloaded with the correct values and device
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR2)
        o.load_state_dict(state_dict)
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=self.device),
        )

        # We should be using `LR1` and not `LR2` after reloading, both within
        # the optimizer and as exposed by the `param_groups` attribute
        self.assertEqual(o.param_groups[0]["lr"], LR1)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.71], device=self.device))
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.9], device=self.device),
        )

        # Check that the exposed `param_groups` are on the proper device
        self.assertEqual(o.param_groups[0]["params"][0].device, x.device)

    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        NUM_ITERS = 5
        LR = 0.01
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR)
        o2 = torch.optim.SGD([x2], lr=LR)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(NUM_ITERS):
            x.backward()
            o.zero_grad()
            o.step()
            s.step()
            x2.backward()
            o2.zero_grad()
            o2.step()
            s2.step()
        self.assertEqual(x, x2)

    def test_step_with_kwargs(self):
        """Check that the ``step(**kwargs)`` interface is properly exposed."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithStepKWArg(torch.optim.SGD):
            def step(self, closure=None, kwarg=None):
                super().step()
                kwarg.append(5)

        kwarg: List[Any] = []
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGDWithStepKWArg,
            lr=LR,
        )
        x.backward()
        o.step(0, kwarg=kwarg)
        self.assertEqual(kwarg, [5])
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_step_with_extra_inner_key(self):
        """Check that ZeroRedundancyOptimizer wrapping an optimizer that adds
        extra keys to ``param_groups`` exposes those keys through ZeRO's own
        ``param_groups``."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithNewKey(torch.optim.SGD):
            # Dummy optimizer which adds a new key to the param groups
            def step(self, closure=None):
                super().step()
                self.param_groups[0]["new_key"] = 0.1

        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGDWithNewKey, lr=LR)
        x.backward()
        o.step()
        self.assertEqual(o.param_groups[0]["new_key"], 0.1)
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_step_without_closure(self):
        """Check that the ``step()`` method (without closure) is handled as
        expected."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithoutClosure(torch.optim.SGD):
            def step(self):
                return super().step()

        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGDWithoutClosure,
            lr=LR,
        )
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_zero_grad(self):
        """Check that the ``zero_grad`` method is properly handled."""
        self.dist_init(self.rank)
        LR = 0.01
        x = torch.rand(1)
        m = torch.nn.Linear(1, 1)
        o = ZeroRedundancyOptimizer(m.parameters(), optimizer_class=SGD, lr=LR)
        y = m(x)
        y.backward(x)
        self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
        self.assertNotEqual(m.bias.grad, torch.zeros_like(m.bias))
        o.zero_grad()
        self.assertIsNone(m.weight.grad)
        self.assertIsNone(m.bias.grad)

    def test_constructor(self):
        """Check the robustness of the ZeroRedundancyOptimizer constructor by
        passing different values for the ``params`` argument."""
        self.dist_init(self.rank)
        LR = 0.01
        m = torch.nn.Sequential(
            torch.nn.Linear(5, 10),
            torch.nn.Linear(10, 10),
            torch.nn.Linear(10, 10),
        )
        # Test various constructor inputs in the form: (input, expected error)
        ctor_inputs = [
            ([], ValueError),  # empty parameter list
            (torch.randn(1), TypeError),  # non-iterable: `torch.Tensor`
            (1.2, TypeError),  # non-iterable: `float`
            (
                [
                    {"params": [l.weight for l in m]},
                    {"params": [l.bias for l in m]},
                ],
                None,
            ),  # iterable of dict
            (
                list(m.parameters()) + [42],
                TypeError,
            ),  # iterable containing invalid type
            (m.parameters(), None),  # `params` as a generator
            (list(m.parameters()), None),  # `params` as a list
        ]
        for ctor_input, error in ctor_inputs:
            context = self.assertRaises(error) if error else nullcontext()
            with context:
                ZeroRedundancyOptimizer(
                    ctor_input,
                    optimizer_class=SGD,
                    lr=LR,
                )

        # Test constructing with multiple parameter groups more thoroughly
        WD = 0.01
        BETAS = (0.9, 0.999)
        EPS = 1e-8
        params = [
            {"params": [l.weight for l in m], "weight_decay": 0.0},
"weight_decay": 0.0}, 321 {"params": [l.bias for l in m], "weight_decay": WD}, 322 ] 323 o = ZeroRedundancyOptimizer( 324 params, 325 optimizer_class=AdamW, 326 lr=LR, 327 betas=BETAS, 328 eps=EPS, 329 ) 330 assert ( 331 len(o.param_groups) == 2 332 ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" 333 assert len(o.optim.param_groups) == 2, ( 334 "Expected 2 local optimizer param groups, but got " 335 f"{len(o.optim.param_groups)}" 336 ) 337 338 def test_same_dense_param_type(self): 339 """Check that ZeroRedundancyOptimizer raises an exception if the input 340 parameters include sparse tensors or different dense types. 341 342 NOTE: This test should be removed once support for sparse parameters 343 and varying parameter types is added. 344 """ 345 self.dist_init(self.rank) 346 LR = 0.01 347 inputs = [ 348 [torch.sparse_coo_tensor(size=(2, 3))], 349 [torch.FloatTensor(1), torch.DoubleTensor(1)], 350 [ 351 torch.FloatTensor(1), 352 torch.FloatTensor(1), 353 torch.sparse_coo_tensor(size=(2, 3)), 354 ], 355 ] 356 for input in inputs: 357 with self.assertRaises(ValueError): 358 ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=LR) 359 360 361class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): 362 @property 363 def device(self): 364 return ( 365 torch.device(self.rank) 366 if torch.cuda.is_available() 367 else torch.device("cpu") 368 ) 369 370 @property 371 def world_size(self): 372 return min(4, max(2, torch.cuda.device_count())) 373 374 @property 375 def context(self): 376 return ( 377 nullcontext() 378 if not torch.cuda.is_available() 379 else torch.cuda.device(self.rank) 380 ) 381 382 def _check_same_model_params( 383 self, 384 model_a: torch.nn.Module, 385 model_b: torch.nn.Module, 386 message: str = "", 387 ) -> None: 388 # Check that model parameters match 389 for p_a, p_b in zip(model_a.parameters(), model_b.parameters()): 390 torch.testing.assert_close( 391 p_a, 392 p_b, 393 atol=1e-3, 394 rtol=1e-5, 395 msg=f"Model parameters differ:\n{p_a} {p_b}\n" + message, 396 ) 397 # Check that model buffers match 398 for b_a, b_b in zip(model_a.buffers(), model_b.buffers()): 399 torch.testing.assert_close( 400 b_a, 401 b_b, 402 msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message, 403 ) 404 405 @common_distributed.skip_if_no_gpu 406 @common_distributed.skip_if_rocm_multiprocess 407 def test_step(self): 408 """Check that ZeroRedundancyOptimizer properly exposes the ``step()`` 409 interface.""" 410 self.dist_init(self.rank, world_size=self.world_size) 411 LR = 0.01 412 413 with self.context: 414 x = torch.tensor([float(self.rank + 1)], device=self.device) 415 m = torch.nn.Linear(1, 1) 416 m.weight.data = torch.tensor([[1.0]]) 417 m.bias.data = torch.tensor([2.0]) 418 m = m.to(self.device) 419 m_zero = copy.deepcopy(m).to(self.device) 420 421 o = SGD(m.parameters(), lr=LR) 422 o_zero = ZeroRedundancyOptimizer( 423 m_zero.parameters(), 424 optimizer_class=SGD, 425 lr=LR, 426 ) 427 428 y = m(x) 429 y.backward(x) 430 y_zero = m_zero(x) 431 y_zero.backward(x) 432 433 for p in m.parameters(): 434 dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM) 435 p.grad.data /= self.world_size 436 o.step() 437 for p in m_zero.parameters(): 438 dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM) 439 p.grad.data /= self.world_size 440 o_zero.step() 441 442 self.assertEqual(m.weight, m_zero.weight) 443 self.assertEqual(m.bias, m_zero.bias) 444 445 @common_distributed.skip_if_no_gpu 446 @common_distributed.skip_if_rocm_multiprocess 447 def test_step_with_closure(self): 448 
"""Check that ZeroRedundancyOptimizer properly exposes the 449 ``step(closure)`` interface.""" 450 self.dist_init(self.rank, world_size=self.world_size) 451 452 with self.context: 453 for bucket_view in [False, True]: 454 x_val = self.rank + 1 455 weight = 1.0 456 bias = 2.0 457 error = 1.0 458 target = torch.tensor( 459 [x_val * weight + bias + error], 460 device=self.device, 461 ) 462 loss_fn = torch.nn.L1Loss() 463 464 x = torch.tensor([float(x_val)], device=self.device) 465 m = torch.nn.Linear(1, 1) 466 m.weight.data = torch.tensor([[weight]]) 467 m.bias.data = torch.tensor([bias]) 468 m.to(self.device) 469 470 o = ZeroRedundancyOptimizer( 471 m.parameters(), 472 optimizer_class=SGD, 473 parameters_as_bucket_view=bucket_view, 474 lr=0.1, 475 ) 476 477 y = m(x) 478 y.backward(x) 479 for p in m.parameters(): 480 dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM) 481 p.grad.data /= self.world_size 482 483 def closure(): 484 o.zero_grad() 485 output = m(x) 486 loss = loss_fn(output, target) 487 loss.backward() 488 return loss 489 490 loss = o.step(closure=closure) 491 492 self.assertEqual(loss, torch.tensor(error)) 493 self.assertEqual(m.weight, torch.tensor([[1.1]])) 494 self.assertEqual(m.bias, torch.tensor([2.1])) 495 496 @common_distributed.skip_if_no_gpu 497 def test_lr_scheduler(self): 498 """Check that a normal PyTorch ``lr_scheduler`` is usable with 499 ZeroRedundancyOptimizer.""" 500 self.dist_init(self.rank) 501 x = torch.tensor([1.0], device=self.device, requires_grad=True) 502 x2 = torch.tensor([1.0], device=self.device, requires_grad=True) 503 o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01) 504 o2 = torch.optim.SGD([x2], lr=0.01) 505 s = torch.optim.lr_scheduler.StepLR(o, 1) 506 s2 = torch.optim.lr_scheduler.StepLR(o2, 1) 507 for _ in range(5): 508 x.backward() 509 o.zero_grad() 510 o.step() 511 s.step() 512 x2.backward() 513 o2.zero_grad() 514 o2.step() 515 s2.step() 516 self.assertEqual(x, x2) 517 518 def test_sharding(self): 519 """ 520 Check ZeroRedundancyOptimizer's parameter sharding at construction 521 time. 522 523 NOTE: The correctness of this test depends on the ZeRO implementation 524 using the sorted-greedy partitioning algorithm. For details, see 525 ``ZeroRedundancyOptimizer._partition_parameters()`` in 526 zero_redundancy_optimizer.py. 527 """ 528 self.dist_init(self.rank) 529 LR = 0.01 530 sizes = [9, 7, 5, 3] 531 params = [] 532 for size in sizes * self.world_size: 533 params.append(torch.rand(size, 1)) 534 o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR) 535 self.assertEqual( 536 sum(x.numel() for x in o.optim.param_groups[0]["params"]), 537 sum(sizes), 538 ) 539 540 def test_add_param_group(self): 541 """Check that ZeroRedundancyOptimizer properly handles adding a new 542 parameter group a posteriori and that all ranks get a shard of the 543 contained parameters. 544 545 NOTE: The correctness of this test depends on the ZeRO implementation 546 using the sorted-greedy partitioning algorithm. For details, see 547 ``ZeroRedundancyOptimizer._partition_parameters()`` in 548 zero_redundancy_optimizer.py. 
549 """ 550 self.dist_init(self.rank) 551 LR = 0.01 552 553 # Test with all parameters trainable to begin with 554 def all_trainable(): 555 params = [] 556 sizes = [9, 7, 5, 3] 557 sizes_world = sizes * self.world_size 558 for size in sizes_world[:-1]: 559 params.append(torch.rand(size, 1)) 560 561 # Make sure that the params are trainable so that they are factored 562 # into the size-based parameter partitioning 563 for p in params: 564 p.requires_grad = True 565 566 o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR) 567 self.assertEqual(len(o.param_groups), 1) 568 o.add_param_group({"params": [torch.rand(3, 1)]}) 569 # Verify that new group is added to the correct partition, making 570 # all partitions have the same elements 571 self.assertEqual(len(o.param_groups), 2) 572 self.assertEqual( 573 sum(x.numel() for g in o.optim.param_groups for x in g["params"]), 574 sum(sizes), 575 ) 576 self.assertEqual(len(o.optim.param_groups), 2) 577 578 # Test a pathological config with a first big non-trainable param 579 def some_trainable(): 580 params = [] 581 for size in [100, 3, 5, 2, 6, 4]: 582 params.append(torch.rand(size, 1)) 583 584 # Make sure that all but the first param are trainable so that they 585 # are factored into the size-based parameter partitioning 586 for p in params[1:]: 587 p.requires_grad = True 588 589 o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR) 590 self.assertEqual(len(o.param_groups), 1) 591 o.add_param_group({"params": [torch.rand(3, 1)]}) 592 self.assertEqual(len(o.param_groups), 2) 593 self.assertEqual(len(o.optim.param_groups), 2) 594 595 all_trainable() 596 some_trainable() 597 598 @common_distributed.skip_if_no_gpu 599 def test_multiple_param_groups(self): 600 """ 601 Check parity between constructing ZeRO with multiple parameter groups 602 upfront versus adding parameter groups to ZeRO after construction 603 versus a non-sharded optimizer. 
604 """ 605 self.dist_init(self.rank) 606 BATCH_SIZE, NUM_ITERS = 8, 3 607 INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5 608 WD, LR = 0.01, 0.01 609 model1 = torch.nn.Sequential( 610 torch.nn.Linear(INPUT_DIM, HIDDEN_DIM), 611 torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), 612 torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM), 613 ) 614 model2 = copy.deepcopy(model1) 615 model3 = copy.deepcopy(model1) 616 model1 = model1.to(self.device) 617 model2 = model2.to(self.device) 618 model3 = model3.to(self.device) 619 inputs = [ 620 torch.randn(BATCH_SIZE, INPUT_DIM).to(self.device) for _ in range(NUM_ITERS) 621 ] 622 # Construct `optim1` with both parameter groups upfront 623 optim1 = ZeroRedundancyOptimizer( 624 [ 625 {"params": [l.weight for l in model1], "weight_decay": 0.0}, 626 {"params": [l.bias for l in model1], "weight_decay": WD}, 627 ], 628 optimizer_class=AdamW, 629 lr=LR, 630 ) 631 # Construct `optim2` by adding the second parameter after 632 optim2 = ZeroRedundancyOptimizer( 633 [l.weight for l in model2], 634 optimizer_class=AdamW, 635 lr=LR, 636 weight_decay=0.0, 637 ) 638 optim2.add_param_group({"params": [l.bias for l in model2], "weight_decay": WD}) 639 # Construct `optim3` as a non-sharded optimizer 640 optim3 = AdamW( 641 [ 642 {"params": [l.weight for l in model3], "weight_decay": 0.0}, 643 {"params": [l.bias for l in model3], "weight_decay": WD}, 644 ], 645 lr=LR, 646 ) 647 # Check parity over a few iterations 648 for input in inputs: 649 for model, optim in ( 650 (model1, optim1), 651 (model2, optim2), 652 (model3, optim3), 653 ): 654 optim.zero_grad() 655 out = model(input) 656 loss = out.sum() 657 loss.backward() 658 optim.step() 659 for layer1, layer2, layer3 in zip(model1, model2, model3): 660 torch.testing.assert_close(layer1.weight, layer2.weight) 661 torch.testing.assert_close(layer1.weight, layer3.weight) 662 torch.testing.assert_close(layer1.bias, layer2.bias) 663 torch.testing.assert_close(layer1.bias, layer3.bias) 664 665 @common_distributed.skip_if_no_gpu 666 @common_distributed.skip_if_rocm_multiprocess 667 def test_collect_shards(self): 668 """Check the state consolidation mechanism and the state dict exposed 669 by ZeroRedundancyOptimizer.""" 670 self.dist_init(self.rank) 671 LR = 1e-3 672 MOMENTUM = 0.99 673 BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5 674 REFERENCE_RANK = 0 675 target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=self.device) 676 inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=self.device) 677 model = torch.nn.Sequential( 678 torch.nn.Linear(INPUT_DIM, HIDDEN_DIM), 679 torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM), 680 ).to(self.device) 681 loss_fn = torch.nn.L1Loss() 682 loss_fn.to(self.device) 683 optimizer = ZeroRedundancyOptimizer( 684 model.parameters(), 685 optimizer_class=SGD, 686 lr=LR, 687 momentum=MOMENTUM, # ensure there exists state to shard 688 ) 689 690 def closure(): 691 optimizer.zero_grad() 692 output = model(inputs) 693 loss = loss_fn(output, target) 694 loss.backward() 695 return loss 696 697 # Run a dummy step so that the optimizer state dict exists 698 _ = optimizer.step(closure=closure) 699 700 # Get the optimizer state on the reference rank 701 optimizer.consolidate_state_dict(to=REFERENCE_RANK) 702 if self.rank == REFERENCE_RANK: 703 # Check that the state has the correct size 704 optimizer_state_dict = optimizer.state_dict() 705 self.assertEqual( 706 len(optimizer_state_dict["state"]), 707 len(list(model.parameters())), 708 ) 709 else: 710 optimizer_state_dict = {} 711 712 # Load the optimizer state on all ranks 

    def test_nondefault_process_group(self):
        """Check that ZeroRedundancyOptimizer works with a non-default process
        group consisting only of even ranks."""
        # Skip the test if below the minimum world size since then the test is
        # trivial
        MIN_WORLD_SIZE = 4
        if self.world_size < MIN_WORLD_SIZE:
            common_distributed.logger.info(
                "Skipping `test_nondefault_process_group()` since world size "
                "of %s is less than %s",
                self.world_size,
                MIN_WORLD_SIZE,
            )
            return
        BACKEND = dist.Backend.GLOO
        self.dist_init(self.rank, self.world_size, BACKEND)
        # Use GPU if enough are available, or fall back to CPU otherwise, which
        # is fine since Gloo backend supports both
        if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size:
            device = torch.device(self.rank)
        else:
            device = torch.device("cpu")
        # Create a new process group consisting of the even ranks to exercise
        # the case where the global and local ranks do not necessarily match
        subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0]
        process_group = dist.new_group(
            ranks=subgroup_ranks,
            backend=BACKEND,
        )
        # Ranks not participating in the new process group are no longer needed
        if self.rank not in subgroup_ranks:
            return

        # Set different seeds across ranks so that each rank gets different
        # training data and hence the model sync check is meaningful
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)

        EPOCHS, BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 3, 20, 10, 5
        LR = 1e-3
        MOMENTUM = 0.99
        REFERENCE_RANK = 0
        assert (
            REFERENCE_RANK in subgroup_ranks
        ), "Reference rank must be in the new process group"
        loss_fn = torch.nn.L1Loss().to(device)

        def check(optimizer):
            for _ in range(EPOCHS):
                target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=device)
                inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=device)

                def closure():
                    optimizer.zero_grad()
                    output = model(inputs)
                    loss = loss_fn(output, target)
                    loss /= self.world_size
                    loss.backward()
                    dist.all_reduce(loss, group=process_group)
                    return loss

                _ = optimizer.step(closure=closure)

                # Check that the parameters match across ranks after a step
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        receptacle = (
                            [p.clone() for _ in subgroup_ranks]
                            if self.rank == REFERENCE_RANK
                            else []
                        )
                        dist.gather(
                            p,
                            receptacle,
                            dst=REFERENCE_RANK,
                            group=process_group,
                        )
                        if self.rank == REFERENCE_RANK:
                            reference_param = receptacle[0]
                            for param in receptacle[1:]:
                                torch.testing.assert_close(
                                    reference_param,
                                    param,
                                    msg="Models differ between ranks",
                                )

        model = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ).to(device)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,  # ensure there exists state to shard
            process_group=process_group,
        )
        check(optimizer)

    @common_distributed.skip_if_no_gpu
    @parametrize(
        "optimizer_class_str",
        ["Adam", "AdamW", "SGD"],
        # Use string to appease the internal test name parser
    )
    @parametrize(
        "maximize",
        [False, True],
    )
    def test_local_optimizer_parity(
        self,
        optimizer_class_str: str,
        maximize: bool,
    ):
        """When combined with DDP, check that a local optimizer gives the same
        results as wrapping that optimizer with ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        BATCHES = 20
        BATCH_SIZE = 64
        LR = 1e-3
        INPUT_DIM = 2
        HIDDEN_DIM = 3
        OUTPUT_DIM = 3
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)
        if optimizer_class_str == "Adam":
            optimizer_class = torch.optim.Adam
        elif optimizer_class_str == "AdamW":
            optimizer_class = torch.optim.AdamW
        elif optimizer_class_str == "SGD":
            optimizer_class = torch.optim.SGD
        else:
            assert 0, f"Unsupported optimizer class: {optimizer_class_str}"

        with self.context:
            # Define a base model with a different buffer for each rank
            model = torch.nn.Sequential(
                torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
            ).to(self.device)
            model.test_buffer = torch.nn.Buffer(
                torch.ones((1), device=self.device) * self.rank,
            )
            # Define models/optimizers for DDP with ZeRO and DDP with local
            # optimizer
            defaults = {"maximize": True} if maximize else {}
            sharded_optimizer = ZeroRedundancyOptimizer(
                params=model.parameters(),
                optimizer_class=optimizer_class,
                lr=LR,
                **defaults,
            )
            sharded_ddp_model = DDP(
                module=model,
                device_ids=[self.rank],
                broadcast_buffers=True,
                find_unused_parameters=True,
            )
            local_model = copy.deepcopy(model).to(self.device)
            ddp_optimizer = optimizer_class(
                local_model.parameters(),
                lr=LR,
                **defaults,
            )
            ddp_model = DDP(
                local_model,
                device_ids=[self.rank],
                broadcast_buffers=True,
                find_unused_parameters=True,
            )
            # Check that the model is properly synchronized between ranks
            # at construction time
            self._check_same_model_params(
                sharded_ddp_model,
                ddp_model,
                "Models differ from the start",
            )

            def check_step():
                input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM))

                def closure_ddp(input_tensor=input_tensor):
                    ddp_optimizer.zero_grad()
                    ddp_loss = ddp_model(input_tensor).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                def closure_sharded(input_tensor=input_tensor):
                    sharded_optimizer.zero_grad()
                    sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                    sharded_loss.backward()
                    return sharded_loss

                loss_ddp = cast(
                    torch.Tensor,
                    ddp_optimizer.step(closure=closure_ddp),
                )
                loss_sharded_optim = cast(
                    torch.Tensor,
                    sharded_optimizer.step(closure=closure_sharded),
                )
                torch.testing.assert_close(
                    loss_ddp,
                    loss_sharded_optim,
                    msg="Losses differ between local optimizer and ZeRO",
                )
                self._check_same_model_params(
                    sharded_ddp_model,
                    ddp_model,
                    "Models differ after a step",
                )

            # Check that parity is maintained
            for i in range(BATCHES):
                check_step()
                # For the second half of batches, change the parameter
                # trainability to further test parity
                if i > BATCHES // 2:
                    next(ddp_model.parameters()).requires_grad = bool(i % 2)
                    next(sharded_ddp_model.parameters()).requires_grad = bool(i % 2)

            # Check that the `state_dict` checkpoints are compatible between
            # the local optimizer and ZeRO
            REFERENCE_RANK = 0
            # - Get states
            ddp_state_dict = ddp_optimizer.state_dict()
            sharded_optimizer.consolidate_state_dict(to=REFERENCE_RANK)
            sharded_optim_state_dict = [
                sharded_optimizer.state_dict() if self.rank == REFERENCE_RANK else {}
            ]
            dist.broadcast_object_list(
                sharded_optim_state_dict,
                src=REFERENCE_RANK,
                group=dist.group.WORLD,
            )
            sharded_optim_state_dict = sharded_optim_state_dict[0]

            # - Cross-load the states
            # Run one step and check that the models are still the same
            ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)
            ddp_optimizer.load_state_dict(sharded_optim_state_dict)
            sharded_optimizer.load_state_dict(ddp_state_dict)
            check_step()

            # - Reload their respective states
            # Run one step and check that the models are still the same
            ddp_optimizer.load_state_dict(ddp_state_dict_ref)
            sharded_optimizer.load_state_dict(sharded_optim_state_dict)
            check_step()

    def _test_zero_join(self, device):
        """Check that the ZeRO join hook allows training with uneven inputs
        when using the given device."""
        NUM_INPUTS = 3
        NUM_EPOCHS = 2
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        world_size = self.world_size
        is_gpu = device.type == "cuda"
        backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
        self.dist_init(rank, world_size, backend)

        model = torch.nn.Sequential(
            torch.nn.Linear(2, 3),
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 3),
        )
        model.to(device)

        # DDP ensures correct gradients in data parallel training, so DDP with
        # local optimizers on uneven inputs should be equivalent to ZeRO on
        # uneven inputs with gradients being manually set
        ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
        local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        zero_model = copy.deepcopy(model)
        zero_model.to(device)
        zero_optim = ZeroRedundancyOptimizer(
            zero_model.parameters(),
            torch.optim.Adam,
            lr=LR,
        )
        loss_fn = torch.nn.MSELoss()

        # Use uneven inputs: rank i has i extra inputs
        inputs = [torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank)]
        labels = torch.randn(20, 3).to(device)

        # Save the gradients and parameters from DDP as the ground truth; do
        # so on the last-joining rank (in this case, the largest rank)
        grads_at_each_iter = []
        params_at_each_iter = []
        with ddp_model.join():
            for _ in range(NUM_EPOCHS):
                for input in inputs:
                    output = ddp_model(input)
                    loss_fn(output, labels).backward()
                    if rank == world_size - 1:
                        grads = []
                        for p in ddp_model.parameters():
                            grads.append(p.grad.detach().clone().to(device))
                    local_optim.step()
                    if rank == world_size - 1:
                        params = []
                        for p in ddp_model.parameters():
                            params.append(p.detach().clone().to(device))
                        grads_at_each_iter.append(grads)
                        params_at_each_iter.append(params)

        # Broadcast the saved gradients and parameters to all of the other
        # ranks (which joined early)
        grads_and_params = [grads_at_each_iter, params_at_each_iter]
        grads_and_params = _broadcast_object(
            grads_and_params,
            src_rank=world_size - 1,
            group=dist.group.WORLD,
            device=device,
        )
        grads_at_each_iter = grads_and_params[0]
        params_at_each_iter = grads_and_params[1]
        # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
        # once the latter supports loading to the destination device instead
        # of the source device

        # A process must still set the remaining gradients after joining, so we
        # define a join hook to do this before the ZeRO join hook
        class _JoinGradInfo:
            def __init__(self, grads):
                self.grads = grads  # remaining gradients to set (in order)
                self.index = 0

        class _SetGradsJoinHook(JoinHook):
            def __init__(self, zero_optim, grads):
                zero_optim._join_grad_info = _JoinGradInfo(grads)
                self.zero = zero_optim
                super().__init__()

            def main_hook(self):
                join_grad_info = self.zero._join_grad_info
                grads = self.zero._join_grad_info.grads[join_grad_info.index]
                join_grad_info.index += 1
                for p, grad in zip(self.zero._all_params, grads):
                    p.grad = grad.detach().clone().to(device)

        class _GradientSetter(Joinable):
            def __init__(self) -> None:
                super().__init__()

            def join_hook(self, **kwargs):
                assert "zero_optim" in kwargs
                assert "grads" in kwargs
                zero_optim = kwargs["zero_optim"]
                grads = kwargs["grads"]
                return _SetGradsJoinHook(zero_optim, grads)

            @property
            def join_device(self):
                return device

            @property
            def join_process_group(self):
                return dist.group.WORLD

        num_grads_after_joining = NUM_EPOCHS * (world_size - rank - 1)
        grads = grads_at_each_iter[-num_grads_after_joining:]
        gradient_setter = _GradientSetter()
        iter = 0
        with Join(
            [gradient_setter, zero_optim],
            zero_optim=zero_optim,
            grads=grads,
        ):
            for _ in range(NUM_EPOCHS):
                for input in inputs:
                    # Notify join context that this process has not joined
                    Join.notify_join_context(gradient_setter)
                    # Set gradients manually
                    for p, grad in zip(
                        zero_model.parameters(),
                        grads_at_each_iter[iter],
                    ):
                        p.grad = grad.detach().clone().to(device)
                    # Perform optimizer step and check parity
                    zero_optim.step()
                    for p, ddp_p in zip(
                        zero_model.parameters(),
                        params_at_each_iter[iter],
                    ):
                        torch.testing.assert_close(
                            p,
                            ddp_p,
                            msg="Parameters differ between using ZeRO and "
                            "local optimizer",
                        )
                    iter += 1

    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    def test_zero_join_gpu(self):
        """Check that the ZeRO join hook allows training with uneven inputs
        on GPU."""
        self._test_zero_join(self.device)

    @common_distributed.requires_gloo()
    def test_zero_join_cpu(self):
        """Check that the ZeRO join hook allows training with uneven inputs
        on CPU."""
        self._test_zero_join(torch.device("cpu"))

    def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
        # Use two processes each with two GPUs
        assert self.rank < 2
        NUM_EPOCHS = 2
        NUM_INPUTS = 4
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        class ModelParallelModel(torch.nn.Module):
            def __init__(self, dev0, dev1):
                super().__init__()
                self.dev0 = dev0
                self.dev1 = dev1
                self.net0 = torch.nn.Linear(10, 10).to(dev0)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5).to(dev1)

            def forward(self, x):
                x = x.to(self.dev0)
                x = self.relu(self.net0(x))
                x = x.to(self.dev1)
                return self.net1(x)

        class LocalModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net0 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5)

            def forward(self, x):
                return self.net1(self.relu(self.net0(x)))

        dev0 = torch.device(2 * self.rank)
        dev1 = torch.device(2 * self.rank + 1)
        mp_model = ModelParallelModel(dev0, dev1)
        ddp_model = DDP(mp_model)
        local_model = LocalModel().to(dev0)

        # Ensure the parameters are the same across the two models
        def copy_param(p):
            return torch.nn.Parameter(p.detach().clone().to(dev0))

        local_model.net0.weight = copy_param(mp_model.net0.weight)
        local_model.net0.bias = copy_param(mp_model.net0.bias)
        local_model.net1.weight = copy_param(mp_model.net1.weight)
        local_model.net1.bias = copy_param(mp_model.net1.bias)

        # Compare parity between DDP with model parallelism using ZeRO and
        # a local model using a local optimizer
        zero_optim = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            parameters_as_bucket_view=parameters_as_bucket_view,
            lr=LR,
        )
        local_optim = torch.optim.Adam(local_model.parameters(), lr=LR)
        inputs = [torch.randn(20, 10).to(dev0) for _ in range(NUM_INPUTS)]

        for _ in range(NUM_EPOCHS):
            for input in inputs:

                def closure_local():
                    local_optim.zero_grad()
                    local_loss = local_model(input).abs().sum()
                    local_loss.backward()
                    return local_loss

                def closure_ddp():
                    zero_optim.zero_grad()
                    ddp_loss = ddp_model(input).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                local_loss = cast(torch.Tensor, local_optim.step(closure=closure_local))
                ddp_loss = cast(torch.Tensor, zero_optim.step(closure=closure_ddp))

                # Increased tolerances are needed to pass when using TF32
                # See: https://github.com/pytorch/pytorch/issues/67764
                torch.testing.assert_close(
                    local_loss.cpu(),
                    ddp_loss.cpu(),
                    rtol=1e-03,
                    atol=1e-08,
                    msg="Losses differ between local optimizer and ZeRO",
                )

                for local_p, ddp_p in zip(
                    local_model.parameters(), ddp_model.parameters()
                ):
                    torch.testing.assert_close(
                        local_p.cpu(),
                        ddp_p.cpu(),
                        rtol=1e-03,
                        atol=1e-04,
                        msg="Models differ after a step",
                    )

    @common_distributed.skip_if_lt_x_gpu(4)
    @parametrize(
        "parameters_as_bucket_view",
        [False, True],
    )
    def test_zero_model_parallel(
        self,
        parameters_as_bucket_view: bool,
    ):
        """Check that ZeRO works with model parallelism where the model's
        layers are assigned to different devices."""
        if self.rank >= 2:
            return
        self.dist_init(self.rank, world_size=2)
        self._test_zero_model_parallel(parameters_as_bucket_view)

    def _test_ddp_zero_overlap(
        self,
        device,
        hook_constructor,
        gradient_as_bucket_view,
        static_graph,
        **kwargs,
    ):
        SGD_LR = 0.01
        SGD_MOMENTUM = 0.9
        SGD_WEIGHT_DECAY = 0.001
        NUM_INPUTS = 5
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        is_gpu = device.type == "cuda"
        if is_gpu:
            torch.cuda.set_device(device)
        models_to_test = [
            (
                torch.nn.Sequential(
                    torch.nn.Linear(1000, 2000),
                    torch.nn.Linear(2000, 500),
                ),
                [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)],
            )
        ]
        if HAS_TORCHVISION:
            models_to_test.append(
                (
                    torchvision.models.resnet50(),
                    [torch.randn(1, 3, 3, 1000).to(device) for _ in range(NUM_INPUTS)],
                )
            )
        for model, inputs in models_to_test:
            # Enable determinism in cudnn operators
            with torch.backends.cudnn.flags(
                enabled=True, deterministic=True, benchmark=False
            ):
                device_ids = [rank] if is_gpu else None
                # Set up the DDP model overlapping with ZeRO
                ddp_model_overlap = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view,
                )
                if static_graph:
                    ddp_model_overlap._set_static_graph()
                zero_optim = ZeroRedundancyOptimizer(
                    ddp_model_overlap.parameters(),
                    optimizer_class=torch.optim.SGD,
                    overlap_with_ddp=True,
                    lr=SGD_LR,
                    momentum=SGD_MOMENTUM,
                    weight_decay=SGD_WEIGHT_DECAY,
                )
                ddp_model_overlap.register_comm_hook(
                    None,
                    hook_constructor(
                        allreduce_hook,
                        ddp_model_overlap,
                        zero_optim,
                        **kwargs,
                    ),
                )

                # Set up the DDP model with local optimizer
                ddp_model_local = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view,
                )
                if static_graph:
                    ddp_model_local._set_static_graph()
                local_optim = torch.optim.SGD(
                    ddp_model_local.parameters(),
                    lr=SGD_LR,
                    momentum=SGD_MOMENTUM,
                    weight_decay=SGD_WEIGHT_DECAY,
                )

                # Check that the parameters match initially
                for p1, p2 in zip(
                    ddp_model_overlap.parameters(), ddp_model_local.parameters()
                ):
                    self.assertEqual(p1, p2)

                # Save the parameters to ensure they were updated
                init_params_overlap = copy.deepcopy(
                    list(ddp_model_overlap.parameters())
                )

                # Ensure that this test runs independently
                dist.barrier()

                # Run the DDP model overlapping with ZeRO
                # NOTE: Overlapping currently requires 2 or 3 warmup iterations
                # to ensure DDP buckets have been rebuilt (depending on the
                # value of `static_graph`)
                num_warmup_inputs = 2 if not static_graph else 3
                for input in inputs[:num_warmup_inputs]:
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()
                for input in inputs:
                    zero_optim.zero_grad()
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()

                # Run the DDP model with local optimizer
                for input in inputs:
                    local_optim.zero_grad()
                    output = ddp_model_local(input)
                    loss = output.sum()
                    loss.backward()
                    local_optim.step()
                dist.barrier()

                # Check that the parameters are equal
                for p1, p2 in zip(
                    ddp_model_overlap.parameters(), ddp_model_local.parameters()
                ):
                    self.assertEqual(p1, p2)

                # Check that the parameters were updated
                self.assertNotEqual(
                    init_params_overlap,
                    list(ddp_model_overlap.parameters()),
                )

                # Ensure that this test runs independently
                dist.barrier()
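
    # For reference, the DDP/ZeRO overlap wiring exercised above follows this
    # general shape (a minimal sketch only; `model`, `rank`, and the SGD
    # hyperparameters are placeholders):
    #
    #   ddp = DDP(model.to(rank), device_ids=[rank])
    #   zero = ZeroRedundancyOptimizer(
    #       ddp.parameters(),
    #       optimizer_class=torch.optim.SGD,
    #       overlap_with_ddp=True,
    #       lr=0.01,
    #   )
    #   ddp.register_comm_hook(None, hook_with_zero_step(allreduce_hook, ddp, zero))
    #
    # With `overlap_with_ddp=True`, the hook drives the local optimizer steps,
    # which is why the training loop above does not call `zero_optim.step()`.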

    # NOTE: The test is skipped if using Windows since functional optimizers
    # are not currently supported.
    @common_distributed.skip_if_win32()
    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm_multiprocess
    @parametrize(
        "use_gpu",
        [True],
        # Add `False` once the Gloo sync issue causing hangs is fixed
        # See: https://github.com/pytorch/pytorch/issues/62300
    )
    @parametrize(
        "use_interleaved_hook",
        [False, True],
    )
    @parametrize(
        "gradient_as_bucket_view",
        [False, True],
    )
    @parametrize(
        "static_graph",
        [False, True],
    )
    @parametrize(
        "shard_buckets",
        [False, True],
    )
    def test_ddp_zero_overlap(
        self,
        use_gpu: bool,
        use_interleaved_hook: bool,
        gradient_as_bucket_view: bool,
        static_graph: bool,
        shard_buckets: bool,
    ):
        """
        Check that overlapping DDP with ZeRO using the given method determined
        by ``hook_constructor`` and ``shard_buckets`` and using the given ZeRO
        and DDP arguments achieves parity with DDP using a local optimizer.
        """
        device = torch.device(self.rank) if use_gpu else torch.device("cpu")
        backend = _get_backend_for_tests()
        self.dist_init(self.rank, self.world_size, backend)
        hook_constructor = (
            hook_with_zero_step
            if not use_interleaved_hook
            else hook_with_zero_step_interleaved
        )

        self._test_ddp_zero_overlap(
            device,
            hook_constructor,
            gradient_as_bucket_view,
            static_graph,
            shard_buckets=shard_buckets,
        )


instantiate_parametrized_tests(TestZeroRedundancyOptimizerSingleRank)
instantiate_parametrized_tests(TestZeroRedundancyOptimizerDistributed)

if __name__ == "__main__":
    # ! unittest should not be used here, else the tests are not properly registered
    run_tests()