# Owner(s): ["oncall: distributed"]

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import sys
import unittest
from contextlib import nullcontext
from typing import Any, cast, List

import numpy as np

import torch
import torch.distributed as dist


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW, SGD
from torch.testing._internal import common_distributed
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    TEST_WITH_ASAN,
    TEST_WITH_DEV_DBG_ASAN,
)


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False


# Use GLOO on GPU when running CUDA + Windows
def _get_backend_for_tests():
    return (
        dist.Backend.NCCL
        if not IS_WINDOWS and torch.cuda.is_available()
        # Windows only supports Gloo, but Gloo works on GPU; also fall back to
        # Gloo on CPU when no GPUs are available.
        else dist.Backend.GLOO
    )


BACKEND = _get_backend_for_tests()


@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.")
class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    @property
    def world_size(self):
        return 1

    def tearDown(self):
        try:
            torch.distributed.destroy_process_group()
        except AssertionError:
            pass
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def dist_init(self, rank, world_size=-1, backend=BACKEND):
        if world_size < 1:
            world_size = self.world_size
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend,
            store=store,
            rank=rank,
            world_size=world_size,
        )


# TODO: skip_but_pass_in_sandcastle_if does not work here.
@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.")
class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
    def test_state_dict(self):
        """Check that ZeroRedundancyOptimizer exposes the expected state dict
        interface, irrespective of the sharding."""
        self.dist_init(self.rank)
        LR1 = 0.1
        LR2 = 0.01
        MOMENTUM = 0.9
        RECIPIENT_RANK = 0  # rank 0 is the only rank since the world size is 1
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGD,
            lr=LR1,
            momentum=MOMENTUM,
        )
        x.backward()
        o.step()
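        # With grad = 1.0 from `x.backward()`, the first SGD step gives a
        # momentum buffer of 1.0 and x = 1.0 - LR1 * 1.0 = 0.9, which the
        # assertions below verify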
        self.assertEqual(x, torch.tensor([0.9], device=self.device))
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=self.device),
        )

        o.zero_grad()
        o.consolidate_state_dict(to=RECIPIENT_RANK)
        state_dict = o.state_dict()

        # Check that the state dict has keys compliant with PyTorch
        self.assertIn("param_groups", state_dict.keys())
        self.assertIn("state", state_dict.keys())

        # Check that the state has the expected keys
        self.assertEqual(state_dict["param_groups"][0]["lr"], LR1)
        self.assertEqual(state_dict["param_groups"][0]["momentum"], MOMENTUM)
        self.assertFalse(state_dict["param_groups"][0]["nesterov"])
        self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
        self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

        # Check that the state and the `param_groups` attribute are in sync
        for k in state_dict["param_groups"][0]:
            if k != "params":
                self.assertEqual(
                    state_dict["param_groups"][0][k],
                    o.param_groups[0][k],
                )

        # Check that the state is reloaded with the correct values and device
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR2)
        o.load_state_dict(state_dict)
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=self.device),
        )

        # We should be using `LR1` and not `LR2` after reloading, both within
        # the optimizer and as exposed by the `param_groups` attribute
        self.assertEqual(o.param_groups[0]["lr"], LR1)
        x.backward()
        o.step()
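        # With the restored momentum buffer (1.0) and grad = 1.0 again, this
        # step computes buf = 0.9 * 1.0 + 1.0 = 1.9 and
        # x = 0.9 - LR1 * 1.9 = 0.71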
        self.assertEqual(x, torch.tensor([0.71], device=self.device))
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.9], device=self.device),
        )

        # Check that the exposed `param_groups` are on the proper device
        self.assertEqual(o.param_groups[0]["params"][0].device, x.device)

    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        NUM_ITERS = 5
        LR = 0.01
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR)
        o2 = torch.optim.SGD([x2], lr=LR)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
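        # Step both optimizer/scheduler pairs in lockstep; ZeRO should follow
        # the plain SGD trajectory exactly, so x and x2 must stay equal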
        for _ in range(NUM_ITERS):
            x.backward()
            o.zero_grad()
            o.step()
            s.step()
            x2.backward()
            o2.zero_grad()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)

    def test_step_with_kwargs(self):
        """Check that the ``step(**kwargs)`` interface is properly exposed."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithStepKWArg(torch.optim.SGD):
            def step(self, closure=None, kwarg=None):
                super().step()
                kwarg.append(5)

        kwarg: List[Any] = []
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGDWithStepKWArg,
            lr=LR,
        )
        x.backward()
        o.step(0, kwarg=kwarg)
        self.assertEqual(kwarg, [5])
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_step_with_extra_inner_key(self):
        """Check that ZeroRedundancyOptimizer wrapping an optimizer that adds
        extra keys to ``param_groups`` exposes those keys through ZeRO's own
        ``param_groups``."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithNewKey(torch.optim.SGD):
            # Dummy optimizer which adds a new key to the param groups
            def step(self, closure=None):
                super().step()
                self.param_groups[0]["new_key"] = 0.1

        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGDWithNewKey, lr=LR)
        x.backward()
        o.step()
        self.assertEqual(o.param_groups[0]["new_key"], 0.1)
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_step_without_closure(self):
        """Check that the ``step()`` method (without closure) is handled as
        expected."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithoutClosure(torch.optim.SGD):
            def step(self):
                return super().step()

        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGDWithoutClosure,
            lr=LR,
        )
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_zero_grad(self):
        """Check that the ``zero_grad`` method is properly handled."""
        self.dist_init(self.rank)
        LR = 0.01
        x = torch.rand(1)
        m = torch.nn.Linear(1, 1)
        o = ZeroRedundancyOptimizer(m.parameters(), optimizer_class=SGD, lr=LR)
        y = m(x)
        y.backward(x)
        self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
        self.assertNotEqual(m.bias.grad, torch.zeros_like(m.bias))
        o.zero_grad()
        self.assertIsNone(m.weight.grad)
        self.assertIsNone(m.bias.grad)

    def test_constructor(self):
        """Check the robustness of the ZeroRedundancyOptimizer constructor by
        passing different values for the ``params`` argument."""
        self.dist_init(self.rank)
        LR = 0.01
        m = torch.nn.Sequential(
            torch.nn.Linear(5, 10),
            torch.nn.Linear(10, 10),
            torch.nn.Linear(10, 10),
        )
        # Test various constructor inputs in the form: (input, expected error)
        ctor_inputs = [
            ([], ValueError),  # empty parameter list
            (torch.randn(1), TypeError),  # non-iterable: `torch.Tensor`
            (1.2, TypeError),  # non-iterable: `float`
            (
                [
                    {"params": [l.weight for l in m]},
                    {"params": [l.bias for l in m]},
                ],
                None,
            ),  # iterable of dict
            (
                list(m.parameters()) + [42],
                TypeError,
            ),  # iterable containing invalid type
            (m.parameters(), None),  # `params` as a generator
            (list(m.parameters()), None),  # `params` as a list
        ]
        for ctor_input, error in ctor_inputs:
            context = self.assertRaises(error) if error else nullcontext()
            with context:
                ZeroRedundancyOptimizer(
                    ctor_input,
                    optimizer_class=SGD,
                    lr=LR,
                )

        # Test constructing with multiple parameter groups more thoroughly
        WD = 0.01
        BETAS = (0.9, 0.999)
        EPS = 1e-8
        params = [
            {"params": [l.weight for l in m], "weight_decay": 0.0},
            {"params": [l.bias for l in m], "weight_decay": WD},
        ]
        o = ZeroRedundancyOptimizer(
            params,
            optimizer_class=AdamW,
            lr=LR,
            betas=BETAS,
            eps=EPS,
        )
        assert (
            len(o.param_groups) == 2
        ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
        assert len(o.optim.param_groups) == 2, (
            "Expected 2 local optimizer param groups, but got "
            f"{len(o.optim.param_groups)}"
        )

    def test_same_dense_param_type(self):
        """Check that ZeroRedundancyOptimizer raises an exception if the input
        parameters include sparse tensors or different dense types.

        NOTE: This test should be removed once support for sparse parameters
        and varying parameter types is added.
        """
        self.dist_init(self.rank)
        LR = 0.01
        inputs = [
            [torch.sparse_coo_tensor(size=(2, 3))],
            [torch.FloatTensor(1), torch.DoubleTensor(1)],
            [
                torch.FloatTensor(1),
                torch.FloatTensor(1),
                torch.sparse_coo_tensor(size=(2, 3)),
            ],
        ]
        for input in inputs:
            with self.assertRaises(ValueError):
                ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=LR)


class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
    @property
    def device(self):
        return (
            torch.device(self.rank)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return min(4, max(2, torch.cuda.device_count()))

    @property
    def context(self):
        return (
            nullcontext()
            if not torch.cuda.is_available()
            else torch.cuda.device(self.rank)
        )

    def _check_same_model_params(
        self,
        model_a: torch.nn.Module,
        model_b: torch.nn.Module,
        message: str = "",
    ) -> None:
        # Check that model parameters match
        for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
            torch.testing.assert_close(
                p_a,
                p_b,
                atol=1e-3,
                rtol=1e-5,
                msg=f"Model parameters differ:\n{p_a} {p_b}\n" + message,
            )
        # Check that model buffers match
        for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
            torch.testing.assert_close(
                b_a,
                b_b,
                msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message,
            )

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm_multiprocess
    def test_step(self):
        """Check that ZeroRedundancyOptimizer properly exposes the ``step()``
        interface."""
        self.dist_init(self.rank, world_size=self.world_size)
        LR = 0.01

        with self.context:
            x = torch.tensor([float(self.rank + 1)], device=self.device)
            m = torch.nn.Linear(1, 1)
            m.weight.data = torch.tensor([[1.0]])
            m.bias.data = torch.tensor([2.0])
            m = m.to(self.device)
            m_zero = copy.deepcopy(m).to(self.device)

            o = SGD(m.parameters(), lr=LR)
            o_zero = ZeroRedundancyOptimizer(
                m_zero.parameters(),
                optimizer_class=SGD,
                lr=LR,
            )

            y = m(x)
            y.backward(x)
            y_zero = m_zero(x)
            y_zero.backward(x)

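            # Average the gradients across ranks manually (as DDP would) so
            # that the plain SGD and ZeRO updates see identical gradients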
            for p in m.parameters():
                dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                p.grad.data /= self.world_size
            o.step()
            for p in m_zero.parameters():
                dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                p.grad.data /= self.world_size
            o_zero.step()

            self.assertEqual(m.weight, m_zero.weight)
            self.assertEqual(m.bias, m_zero.bias)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm_multiprocess
    def test_step_with_closure(self):
        """Check that ZeroRedundancyOptimizer properly exposes the
        ``step(closure)`` interface."""
        self.dist_init(self.rank, world_size=self.world_size)

        with self.context:
            for bucket_view in [False, True]:
                x_val = self.rank + 1
                weight = 1.0
                bias = 2.0
                error = 1.0
                target = torch.tensor(
                    [x_val * weight + bias + error],
                    device=self.device,
                )
                loss_fn = torch.nn.L1Loss()

                x = torch.tensor([float(x_val)], device=self.device)
                m = torch.nn.Linear(1, 1)
                m.weight.data = torch.tensor([[weight]])
                m.bias.data = torch.tensor([bias])
                m.to(self.device)

                o = ZeroRedundancyOptimizer(
                    m.parameters(),
                    optimizer_class=SGD,
                    parameters_as_bucket_view=bucket_view,
                    lr=0.1,
                )

                y = m(x)
                y.backward(x)
                for p in m.parameters():
                    dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                    p.grad.data /= self.world_size

                def closure():
                    o.zero_grad()
                    output = m(x)
                    loss = loss_fn(output, target)
                    loss.backward()
                    return loss

                loss = o.step(closure=closure)

                self.assertEqual(loss, torch.tensor(error))
                self.assertEqual(m.weight, torch.tensor([[1.1]]))
                self.assertEqual(m.bias, torch.tensor([2.1]))

    @common_distributed.skip_if_no_gpu
    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        o2 = torch.optim.SGD([x2], lr=0.01)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(5):
            x.backward()
            o.zero_grad()
            o.step()
            s.step()
            x2.backward()
            o2.zero_grad()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)

    def test_sharding(self):
        """
        Check ZeroRedundancyOptimizer's parameter sharding at construction
        time.

        NOTE: The correctness of this test depends on the ZeRO implementation
        using the sorted-greedy partitioning algorithm. For details, see
        ``ZeroRedundancyOptimizer._partition_parameters()`` in
        zero_redundancy_optimizer.py.
        """
        self.dist_init(self.rank)
        LR = 0.01
        sizes = [9, 7, 5, 3]
        params = []
        for size in sizes * self.world_size:
            params.append(torch.rand(size, 1))
        o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
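        # With `world_size` copies of each size and the sorted-greedy
        # partitioning, each rank should be assigned one parameter of each
        # size, i.e. a local shard of sum(sizes) elements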
        self.assertEqual(
            sum(x.numel() for x in o.optim.param_groups[0]["params"]),
            sum(sizes),
        )

    def test_add_param_group(self):
        """Check that ZeroRedundancyOptimizer properly handles adding a new
        parameter group a posteriori and that all ranks get a shard of the
        contained parameters.

        NOTE: The correctness of this test depends on the ZeRO implementation
        using the sorted-greedy partitioning algorithm. For details, see
        ``ZeroRedundancyOptimizer._partition_parameters()`` in
        zero_redundancy_optimizer.py.
        """
        self.dist_init(self.rank)
        LR = 0.01

        # Test with all parameters trainable to begin with
        def all_trainable():
            params = []
            sizes = [9, 7, 5, 3]
            sizes_world = sizes * self.world_size
            for size in sizes_world[:-1]:
                params.append(torch.rand(size, 1))

            # Make sure that the params are trainable so that they are factored
            # into the size-based parameter partitioning
            for p in params:
                p.requires_grad = True

            o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
            self.assertEqual(len(o.param_groups), 1)
            o.add_param_group({"params": [torch.rand(3, 1)]})
            # Verify that new group is added to the correct partition, making
            # all partitions have the same elements
            self.assertEqual(len(o.param_groups), 2)
            self.assertEqual(
                sum(x.numel() for g in o.optim.param_groups for x in g["params"]),
                sum(sizes),
            )
            self.assertEqual(len(o.optim.param_groups), 2)

        # Test a pathological config with a first big non-trainable param
        def some_trainable():
            params = []
            for size in [100, 3, 5, 2, 6, 4]:
                params.append(torch.rand(size, 1))

            # Make sure that all but the first param are trainable so that they
            # are factored into the size-based parameter partitioning
            for p in params[1:]:
                p.requires_grad = True

            o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
            self.assertEqual(len(o.param_groups), 1)
            o.add_param_group({"params": [torch.rand(3, 1)]})
            self.assertEqual(len(o.param_groups), 2)
            self.assertEqual(len(o.optim.param_groups), 2)

        all_trainable()
        some_trainable()

    @common_distributed.skip_if_no_gpu
    def test_multiple_param_groups(self):
        """
        Check parity between constructing ZeRO with multiple parameter groups
        upfront versus adding parameter groups to ZeRO after construction
        versus a non-sharded optimizer.
        """
        self.dist_init(self.rank)
        BATCH_SIZE, NUM_ITERS = 8, 3
        INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5
        WD, LR = 0.01, 0.01
        model1 = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        )
        model2 = copy.deepcopy(model1)
        model3 = copy.deepcopy(model1)
        model1 = model1.to(self.device)
        model2 = model2.to(self.device)
        model3 = model3.to(self.device)
        inputs = [
            torch.randn(BATCH_SIZE, INPUT_DIM).to(self.device) for _ in range(NUM_ITERS)
        ]
        # Construct `optim1` with both parameter groups upfront
        optim1 = ZeroRedundancyOptimizer(
            [
                {"params": [l.weight for l in model1], "weight_decay": 0.0},
                {"params": [l.bias for l in model1], "weight_decay": WD},
            ],
            optimizer_class=AdamW,
            lr=LR,
        )
        # Construct `optim2` by adding the second parameter group after construction
        optim2 = ZeroRedundancyOptimizer(
            [l.weight for l in model2],
            optimizer_class=AdamW,
            lr=LR,
            weight_decay=0.0,
        )
        optim2.add_param_group({"params": [l.bias for l in model2], "weight_decay": WD})
        # Construct `optim3` as a non-sharded optimizer
        optim3 = AdamW(
            [
                {"params": [l.weight for l in model3], "weight_decay": 0.0},
                {"params": [l.bias for l in model3], "weight_decay": WD},
            ],
            lr=LR,
        )
        # Check parity over a few iterations
        for input in inputs:
            for model, optim in (
                (model1, optim1),
                (model2, optim2),
                (model3, optim3),
            ):
                optim.zero_grad()
                out = model(input)
                loss = out.sum()
                loss.backward()
                optim.step()
            for layer1, layer2, layer3 in zip(model1, model2, model3):
                torch.testing.assert_close(layer1.weight, layer2.weight)
                torch.testing.assert_close(layer1.weight, layer3.weight)
                torch.testing.assert_close(layer1.bias, layer2.bias)
                torch.testing.assert_close(layer1.bias, layer3.bias)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm_multiprocess
    def test_collect_shards(self):
        """Check the state consolidation mechanism and the state dict exposed
        by ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        LR = 1e-3
        MOMENTUM = 0.99
        BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5
        REFERENCE_RANK = 0
        target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=self.device)
        inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=self.device)
        model = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ).to(self.device)
        loss_fn = torch.nn.L1Loss()
        loss_fn.to(self.device)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,  # ensure there exists state to shard
        )

        def closure():
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            return loss

        # Run a dummy step so that the optimizer state dict exists
        _ = optimizer.step(closure=closure)

        # Get the optimizer state on the reference rank
        optimizer.consolidate_state_dict(to=REFERENCE_RANK)
        if self.rank == REFERENCE_RANK:
            # Check that the state has the correct size
            optimizer_state_dict = optimizer.state_dict()
            self.assertEqual(
                len(optimizer_state_dict["state"]),
                len(list(model.parameters())),
            )
        else:
            optimizer_state_dict = {}

        # Load the optimizer state on all ranks without any exceptions
        optimizer_state_dict = _broadcast_object(
            optimizer_state_dict,
            src_rank=REFERENCE_RANK,
            group=dist.group.WORLD,
            device=self.device,
        )
        optimizer.load_state_dict(optimizer_state_dict)

    def test_nondefault_process_group(self):
        """Check that ZeroRedundancyOptimizer works with a non-default process
        group consisting only of even ranks."""
        # Skip the test if below the minimum world size since then the test is
        # trivial
        MIN_WORLD_SIZE = 4
        if self.world_size < MIN_WORLD_SIZE:
            common_distributed.logger.info(
                "Skipping `test_nondefault_process_group()` since world size "
                "of %s is less than %s",
                self.world_size,
                MIN_WORLD_SIZE,
            )
            return
        BACKEND = dist.Backend.GLOO
        self.dist_init(self.rank, self.world_size, BACKEND)
        # Use GPU if enough are available, or fall back to CPU otherwise, which
        # is fine since Gloo backend supports both
        if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size:
            device = torch.device(self.rank)
        else:
            device = torch.device("cpu")
        # Create a new process group consisting of the even ranks to exercise
        # the case where the global and local ranks do not necessarily match
        subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0]
        process_group = dist.new_group(
            ranks=subgroup_ranks,
            backend=BACKEND,
        )
        # Ranks not participating in the new process group are no longer needed
        if self.rank not in subgroup_ranks:
            return

        # Set different seeds across ranks so that each rank gets different
        # training data and hence the model sync check is meaningful
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)

        EPOCHS, BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 3, 20, 10, 5
        LR = 1e-3
        MOMENTUM = 0.99
        REFERENCE_RANK = 0
        assert (
            REFERENCE_RANK in subgroup_ranks
        ), "Reference rank must be in the new process group"
        loss_fn = torch.nn.L1Loss().to(device)

        def check(optimizer):
            for _ in range(EPOCHS):
                target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=device)
                inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=device)

                def closure():
                    optimizer.zero_grad()
                    output = model(inputs)
                    loss = loss_fn(output, target)
                    loss /= self.world_size
                    loss.backward()
                    dist.all_reduce(loss, group=process_group)
                    return loss

                _ = optimizer.step(closure=closure)

                # Check that the parameters match across ranks after a step
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        receptacle = (
                            [p.clone() for _ in subgroup_ranks]
                            if self.rank == REFERENCE_RANK
                            else []
                        )
                        dist.gather(
                            p,
                            receptacle,
                            dst=REFERENCE_RANK,
                            group=process_group,
                        )
                        if self.rank == REFERENCE_RANK:
                            reference_param = receptacle[0]
                            for param in receptacle[1:]:
                                torch.testing.assert_close(
                                    reference_param,
                                    param,
                                    msg="Models differ between ranks",
                                )

        model = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ).to(device)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,  # ensure there exists state to shard
            process_group=process_group,
        )
        check(optimizer)

    @common_distributed.skip_if_no_gpu
    @parametrize(
        "optimizer_class_str",
        ["Adam", "AdamW", "SGD"],
        # Use string to appease the internal test name parser
    )
    @parametrize(
        "maximize",
        [False, True],
    )
    def test_local_optimizer_parity(
        self,
        optimizer_class_str: str,
        maximize: bool,
    ):
        """When combined with DDP, check that a local optimizer gives the same
        results as wrapping that optimizer with ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        BATCHES = 20
        BATCH_SIZE = 64
        LR = 1e-3
        INPUT_DIM = 2
        HIDDEN_DIM = 3
        OUTPUT_DIM = 3
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)
        if optimizer_class_str == "Adam":
            optimizer_class = torch.optim.Adam
        elif optimizer_class_str == "AdamW":
            optimizer_class = torch.optim.AdamW
        elif optimizer_class_str == "SGD":
            optimizer_class = torch.optim.SGD
        else:
            assert 0, f"Unsupported optimizer class: {optimizer_class_str}"

        with self.context:
            # Define a base model with a different buffer for each rank
            model = torch.nn.Sequential(
                torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
            ).to(self.device)
            model.test_buffer = torch.nn.Buffer(
                torch.ones((1), device=self.device) * self.rank,
            )
            # Define models/optimizers for DDP with ZeRO and DDP with local
            # optimizer
            defaults = {"maximize": True} if maximize else {}
            sharded_optimizer = ZeroRedundancyOptimizer(
                params=model.parameters(),
                optimizer_class=optimizer_class,
                lr=LR,
                **defaults,
            )
            sharded_ddp_model = DDP(
                module=model,
                device_ids=[self.rank],
                broadcast_buffers=True,
                find_unused_parameters=True,
            )
            local_model = copy.deepcopy(model).to(self.device)
            ddp_optimizer = optimizer_class(
                local_model.parameters(),
                lr=LR,
                **defaults,
            )
            ddp_model = DDP(
                local_model,
                device_ids=[self.rank],
                broadcast_buffers=True,
                find_unused_parameters=True,
            )
            # Check that the model is properly synchronized between ranks
            # at construction time
            self._check_same_model_params(
                sharded_ddp_model,
                ddp_model,
                "Models differ from the start",
            )

            def check_step():
                input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM))

                def closure_ddp(input_tensor=input_tensor):
                    ddp_optimizer.zero_grad()
                    ddp_loss = ddp_model(input_tensor).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                def closure_sharded(input_tensor=input_tensor):
                    sharded_optimizer.zero_grad()
                    sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                    sharded_loss.backward()
                    return sharded_loss

                loss_ddp = cast(
                    torch.Tensor,
                    ddp_optimizer.step(closure=closure_ddp),
                )
                loss_sharded_optim = cast(
                    torch.Tensor,
                    sharded_optimizer.step(closure=closure_sharded),
                )
                torch.testing.assert_close(
                    loss_ddp,
                    loss_sharded_optim,
                    msg="Losses differ between local optimizer and ZeRO",
                )
                self._check_same_model_params(
                    sharded_ddp_model,
                    ddp_model,
                    "Models differ after a step",
                )

            # Check that parity is maintained
            for i in range(BATCHES):
                check_step()
                # For the second half of batches, change the parameter
                # trainability to further test parity
                if i > BATCHES // 2:
                    next(ddp_model.parameters()).requires_grad = bool(i % 2)
                    next(sharded_ddp_model.parameters()).requires_grad = bool(i % 2)

            # Check that the `state_dict` checkpoints are compatible between
            # the local optimizer and ZeRO
            REFERENCE_RANK = 0
            # - Get states
            ddp_state_dict = ddp_optimizer.state_dict()
            sharded_optimizer.consolidate_state_dict(to=REFERENCE_RANK)
            sharded_optim_state_dict = [
                sharded_optimizer.state_dict() if self.rank == REFERENCE_RANK else {}
            ]
            dist.broadcast_object_list(
                sharded_optim_state_dict,
                src=REFERENCE_RANK,
                group=dist.group.WORLD,
            )
            sharded_optim_state_dict = sharded_optim_state_dict[0]

            # - Cross-load the states
            # Run one step and check that the models are still the same
            ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)
            ddp_optimizer.load_state_dict(sharded_optim_state_dict)
            sharded_optimizer.load_state_dict(ddp_state_dict)
            check_step()

            # - Reload their respective states
            # Run one step and check that the models are still the same
            ddp_optimizer.load_state_dict(ddp_state_dict_ref)
            sharded_optimizer.load_state_dict(sharded_optim_state_dict)
            check_step()

    def _test_zero_join(self, device):
        """Check that the ZeRO join hook allows training with uneven inputs
        when using the given device."""
        NUM_INPUTS = 3
        NUM_EPOCHS = 2
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        world_size = self.world_size
        is_gpu = device.type == "cuda"
        backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
        self.dist_init(rank, world_size, backend)

        model = torch.nn.Sequential(
            torch.nn.Linear(2, 3),
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 3),
        )
        model.to(device)

        # DDP ensures correct gradients in data parallel training, so DDP with
        # local optimizers on uneven inputs should be equivalent to ZeRO on
        # uneven inputs with gradients being manually set
        ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
        local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        zero_model = copy.deepcopy(model)
        zero_model.to(device)
        zero_optim = ZeroRedundancyOptimizer(
            zero_model.parameters(),
            torch.optim.Adam,
            lr=LR,
        )
        loss_fn = torch.nn.MSELoss()

        # Use uneven inputs: rank i has i extra inputs
        inputs = [torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank)]
        labels = torch.randn(20, 3).to(device)

        # Save the gradients and parameters from DDP as the ground truth; do
        # so on the last-joining rank (in this case, the largest rank)
        grads_at_each_iter = []
        params_at_each_iter = []
        with ddp_model.join():
            for _ in range(NUM_EPOCHS):
                for input in inputs:
                    output = ddp_model(input)
                    loss_fn(output, labels).backward()
                    if rank == world_size - 1:
                        grads = []
                        for p in ddp_model.parameters():
                            grads.append(p.grad.detach().clone().to(device))
                    local_optim.step()
                    if rank == world_size - 1:
                        params = []
                        for p in ddp_model.parameters():
                            params.append(p.detach().clone().to(device))
                        grads_at_each_iter.append(grads)
                        params_at_each_iter.append(params)

        # Broadcast the saved gradients and parameters to all of the other
        # ranks (which joined early)
        grads_and_params = [grads_at_each_iter, params_at_each_iter]
        grads_and_params = _broadcast_object(
            grads_and_params,
            src_rank=world_size - 1,
            group=dist.group.WORLD,
            device=device,
        )
        grads_at_each_iter = grads_and_params[0]
        params_at_each_iter = grads_and_params[1]
        # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
        # once the latter supports loading to the destination device instead
        # of the source device

        # A process must still set the remaining gradients after joining, so we
        # define a join hook to do this before the ZeRO join hook
        class _JoinGradInfo:
            def __init__(self, grads):
                self.grads = grads  # remaining gradients to set (in order)
                self.index = 0

        class _SetGradsJoinHook(JoinHook):
            def __init__(self, zero_optim, grads):
                zero_optim._join_grad_info = _JoinGradInfo(grads)
                self.zero = zero_optim
                super().__init__()

            def main_hook(self):
                join_grad_info = self.zero._join_grad_info
                grads = self.zero._join_grad_info.grads[join_grad_info.index]
                join_grad_info.index += 1
                for p, grad in zip(self.zero._all_params, grads):
                    p.grad = grad.detach().clone().to(device)

        class _GradientSetter(Joinable):
            def __init__(self) -> None:
                super().__init__()

            def join_hook(self, **kwargs):
                assert "zero_optim" in kwargs
                assert "grads" in kwargs
                zero_optim = kwargs["zero_optim"]
                grads = kwargs["grads"]
                return _SetGradsJoinHook(zero_optim, grads)

            @property
            def join_device(self):
                return device

            @property
            def join_process_group(self):
                return dist.group.WORLD

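        # Each rank processes NUM_INPUTS + rank inputs per epoch, so rank r
        # joins (world_size - 1 - r) iterations before the last rank in every
        # epoch; those are the iterations whose gradients must still be set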
        num_grads_after_joining = NUM_EPOCHS * (world_size - rank - 1)
        grads = grads_at_each_iter[-num_grads_after_joining:]
        gradient_setter = _GradientSetter()
        iter = 0
        with Join(
            [gradient_setter, zero_optim],
            zero_optim=zero_optim,
            grads=grads,
        ):
            for _ in range(NUM_EPOCHS):
                for input in inputs:
                    # Notify join context that this process has not joined
                    Join.notify_join_context(gradient_setter)
                    # Set gradients manually
                    for p, grad in zip(
                        zero_model.parameters(),
                        grads_at_each_iter[iter],
                    ):
                        p.grad = grad.detach().clone().to(device)
                    # Perform optimizer step and check parity
                    zero_optim.step()
                    for p, ddp_p in zip(
                        zero_model.parameters(),
                        params_at_each_iter[iter],
                    ):
                        torch.testing.assert_close(
                            p,
                            ddp_p,
                            msg="Parameters differ between using ZeRO and "
                            "local optimizer",
                        )
                    iter += 1

    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    def test_zero_join_gpu(self):
        """Check that the ZeRO join hook allows training with uneven inputs
        on GPU."""
        self._test_zero_join(self.device)

    @common_distributed.requires_gloo()
    def test_zero_join_cpu(self):
        """Check that the ZeRO join hook allows training with uneven inputs
        on CPU."""
        self._test_zero_join(torch.device("cpu"))

    def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
        # Use two processes each with two GPUs
        assert self.rank < 2
        NUM_EPOCHS = 2
        NUM_INPUTS = 4
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        class ModelParallelModel(torch.nn.Module):
            def __init__(self, dev0, dev1):
                super().__init__()
                self.dev0 = dev0
                self.dev1 = dev1
                self.net0 = torch.nn.Linear(10, 10).to(dev0)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5).to(dev1)

            def forward(self, x):
                x = x.to(self.dev0)
                x = self.relu(self.net0(x))
                x = x.to(self.dev1)
                return self.net1(x)

        class LocalModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net0 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5)

            def forward(self, x):
                return self.net1(self.relu(self.net0(x)))

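        # Each of the two processes uses a pair of GPUs: rank 0 gets devices
        # 0 and 1, and rank 1 gets devices 2 and 3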
        dev0 = torch.device(2 * self.rank)
        dev1 = torch.device(2 * self.rank + 1)
        mp_model = ModelParallelModel(dev0, dev1)
        ddp_model = DDP(mp_model)
        local_model = LocalModel().to(dev0)

        # Ensure the parameters are the same across the two models
        def copy_param(p):
            return torch.nn.Parameter(p.detach().clone().to(dev0))

        local_model.net0.weight = copy_param(mp_model.net0.weight)
        local_model.net0.bias = copy_param(mp_model.net0.bias)
        local_model.net1.weight = copy_param(mp_model.net1.weight)
        local_model.net1.bias = copy_param(mp_model.net1.bias)

        # Compare parity between DDP with model parallelism using ZeRO and
        # a local model using a local optimizer
        zero_optim = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            parameters_as_bucket_view=parameters_as_bucket_view,
            lr=LR,
        )
        local_optim = torch.optim.Adam(local_model.parameters(), lr=LR)
        inputs = [torch.randn(20, 10).to(dev0) for _ in range(NUM_INPUTS)]

        for _ in range(NUM_EPOCHS):
            for input in inputs:

                def closure_local():
                    local_optim.zero_grad()
                    local_loss = local_model(input).abs().sum()
                    local_loss.backward()
                    return local_loss

                def closure_ddp():
                    zero_optim.zero_grad()
                    ddp_loss = ddp_model(input).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                local_loss = cast(torch.Tensor, local_optim.step(closure=closure_local))
                ddp_loss = cast(torch.Tensor, zero_optim.step(closure=closure_ddp))

                # Increased tolerances are needed to pass when using TF32
                # See: https://github.com/pytorch/pytorch/issues/67764
                torch.testing.assert_close(
                    local_loss.cpu(),
                    ddp_loss.cpu(),
                    rtol=1e-03,
                    atol=1e-08,
                    msg="Losses differ between local optimizer and ZeRO",
                )

                for local_p, ddp_p in zip(
                    local_model.parameters(), ddp_model.parameters()
                ):
                    torch.testing.assert_close(
                        local_p.cpu(),
                        ddp_p.cpu(),
                        rtol=1e-03,
                        atol=1e-04,
                        msg="Models differ after a step",
                    )

    @common_distributed.skip_if_lt_x_gpu(4)
    @parametrize(
        "parameters_as_bucket_view",
        [False, True],
    )
    def test_zero_model_parallel(
        self,
        parameters_as_bucket_view: bool,
    ):
        """Check that ZeRO works with model parallelism where the model's
        layers are assigned to different devices."""
        if self.rank >= 2:
            return
        self.dist_init(self.rank, world_size=2)
        self._test_zero_model_parallel(parameters_as_bucket_view)

    def _test_ddp_zero_overlap(
        self,
        device,
        hook_constructor,
        gradient_as_bucket_view,
        static_graph,
        **kwargs,
    ):
        SGD_LR = 0.01
        SGD_MOMENTUM = 0.9
        SGD_WEIGHT_DECAY = 0.001
        NUM_INPUTS = 5
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        is_gpu = device.type == "cuda"
        if is_gpu:
            torch.cuda.set_device(device)
        models_to_test = [
            (
                torch.nn.Sequential(
                    torch.nn.Linear(1000, 2000),
                    torch.nn.Linear(2000, 500),
                ),
                [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)],
            )
        ]
        if HAS_TORCHVISION:
            models_to_test.append(
                (
                    torchvision.models.resnet50(),
                    [torch.randn(1, 3, 3, 1000).to(device) for _ in range(NUM_INPUTS)],
                )
            )
        for model, inputs in models_to_test:
            # Enable determinism in cudnn operators
            with torch.backends.cudnn.flags(
                enabled=True, deterministic=True, benchmark=False
            ):
                device_ids = [rank] if is_gpu else None
                # Set up the DDP model overlapping with ZeRO
                ddp_model_overlap = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view,
                )
                if static_graph:
                    ddp_model_overlap._set_static_graph()
                zero_optim = ZeroRedundancyOptimizer(
                    ddp_model_overlap.parameters(),
                    optimizer_class=torch.optim.SGD,
                    overlap_with_ddp=True,
                    lr=SGD_LR,
                    momentum=SGD_MOMENTUM,
                    weight_decay=SGD_WEIGHT_DECAY,
                )
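                # Register a communication hook that both all-reduces the
                # gradients and runs the corresponding ZeRO local optimizer
                # step, overlapping optimizer computation with the backward
                # pass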
                ddp_model_overlap.register_comm_hook(
                    None,
                    hook_constructor(
                        allreduce_hook,
                        ddp_model_overlap,
                        zero_optim,
                        **kwargs,
                    ),
                )

                # Set up the DDP model with local optimizer
                ddp_model_local = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view,
                )
                if static_graph:
                    ddp_model_local._set_static_graph()
                local_optim = torch.optim.SGD(
                    ddp_model_local.parameters(),
                    lr=SGD_LR,
                    momentum=SGD_MOMENTUM,
                    weight_decay=SGD_WEIGHT_DECAY,
                )

                # Check that the parameters match initially
                for p1, p2 in zip(
                    ddp_model_overlap.parameters(), ddp_model_local.parameters()
                ):
                    self.assertEqual(p1, p2)

                # Save the parameters to ensure they were updated
                init_params_overlap = copy.deepcopy(
                    list(ddp_model_overlap.parameters())
                )

                # Ensure that this test runs independently
                dist.barrier()

                # Run the DDP model overlapping with ZeRO
                # NOTE: Overlapping currently requires 2 or 3 warmup iterations
                # to ensure DDP buckets have been rebuilt (depending on the
                # value of `static_graph`)
                num_warmup_inputs = 2 if not static_graph else 3
                for input in inputs[:num_warmup_inputs]:
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()
                for input in inputs:
                    zero_optim.zero_grad()
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()

                # Run the DDP model with local optimizer
                for input in inputs:
                    local_optim.zero_grad()
                    output = ddp_model_local(input)
                    loss = output.sum()
                    loss.backward()
                    local_optim.step()
                dist.barrier()

                # Check that the parameters are equal
                for p1, p2 in zip(
                    ddp_model_overlap.parameters(), ddp_model_local.parameters()
                ):
                    self.assertEqual(p1, p2)

                # Check that the parameters were updated
                self.assertNotEqual(
                    init_params_overlap,
                    list(ddp_model_overlap.parameters()),
                )

                # Ensure that this test runs independently
                dist.barrier()

    # NOTE: The test is skipped if using Windows since functional optimizers
    # are not currently supported.
    @common_distributed.skip_if_win32()
    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm_multiprocess
    @parametrize(
        "use_gpu",
        [True],
        # Add `False` once the Gloo sync issue causing hangs is fixed
        # See: https://github.com/pytorch/pytorch/issues/62300
    )
    @parametrize(
        "use_interleaved_hook",
        [False, True],
    )
    @parametrize(
        "gradient_as_bucket_view",
        [False, True],
    )
    @parametrize(
        "static_graph",
        [False, True],
    )
    @parametrize(
        "shard_buckets",
        [False, True],
    )
    def test_ddp_zero_overlap(
        self,
        use_gpu: bool,
        use_interleaved_hook: bool,
        gradient_as_bucket_view: bool,
        static_graph: bool,
        shard_buckets: bool,
    ):
        """
        Check that overlapping DDP with ZeRO, using the method determined by
        ``hook_constructor`` and ``shard_buckets`` together with the given
        ZeRO and DDP arguments, achieves parity with DDP using a local
        optimizer.
        """
        device = torch.device(self.rank) if use_gpu else torch.device("cpu")
        backend = _get_backend_for_tests()
        self.dist_init(self.rank, self.world_size, backend)
        hook_constructor = (
            hook_with_zero_step
            if not use_interleaved_hook
            else hook_with_zero_step_interleaved
        )

        self._test_ddp_zero_overlap(
            device,
            hook_constructor,
            gradient_as_bucket_view,
            static_graph,
            shard_buckets=shard_buckets,
        )


instantiate_parametrized_tests(TestZeroRedundancyOptimizerSingleRank)
instantiate_parametrized_tests(TestZeroRedundancyOptimizerDistributed)

if __name__ == "__main__":
    # ! unittest should not be used here, else the tests are not properly registered
    run_tests()