# Owner(s): ["oncall: distributed"]

import sys
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import (
    requires_nccl,
    requires_nccl_version,
    skip_but_pass_in_sandcastle_if,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# bfloat16 is only supported by CUDA 11+
BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
    torch.version.cuda is not None or torch.version.hip is not None
)


class Net(nn.Module):
    def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
        # to ensure determinism
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        super().__init__()

        if has_wrapping:
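            # Nested FSDP wrapping: the inner Linear(16, 8) gets its own FSDP
            # instance, giving the tests non-root FSDP submodules to check hook
            # registration against.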
            self.net = FSDP(
                nn.Sequential(
                    nn.Linear(8, 16),
                    nn.ReLU(),
                    FSDP(
                        nn.Linear(16, 8),
                        device_id=torch.cuda.current_device(),
                        sharding_strategy=sharding_strategy,
                        mixed_precision=mixed_precision,
                    ),
                ),
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
        else:
            self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))

        self.out = nn.Linear(8, 4)

    def forward(self, x):
        return self.out(F.relu(self.net(x)))


class DummyState:
    __slots__ = ["process_group", "noise"]

    def __init__(self, process_group: dist.ProcessGroup, noise: int):
        self.process_group = process_group
        self.noise = noise


class DummyHook:
    def dummy_hook_for_no_shard_fsdp(self, state: DummyState, grad: torch.Tensor):
        """
        This communication hook is for illustration and testing purposes only.
        It is used during FSDP ``NO_SHARD`` training. It adds some noise to the
        provided ``grad`` parameter and uses ``all_reduce`` to communicate the
        full, flattened, unsharded gradient.
        """
        grad.add_(state.noise)
        dist.all_reduce(grad, group=state.process_group)
    def custom_reduce_scatter(self, output, input, group=None):
        """
        This function is for illustrative purposes only.
        It is meant to implement a custom reduce-scatter
        of a flattened tensor to all processes in a group.
        Currently a no-op.
        """
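        # A minimal sketch of what a real implementation might look like,
        # assuming ``input`` is the full flattened gradient and ``output`` is a
        # preallocated shard of size ``input.numel() // group_size``:
        #
        #     dist.reduce_scatter_tensor(output, input, group=group)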

    def dummy_hook_for_sharded_fsdp(
        self, state: DummyState, grad: torch.Tensor, output: torch.Tensor
    ):
        """
        This communication hook is for illustration and testing purposes only.
        It is used during FSDP ``FULL_SHARD`` or ``SHARD_GRAD_OP`` training.
        It adds some noise to the provided ``grad`` parameter, uses
        ``reduce_scatter`` for gradient communication, and stores the sharded
        gradient in ``output``.
        """
        grad.add_(state.noise)
        self.custom_reduce_scatter(output, grad, group=state.process_group)


class TestCommunicationHooks(FSDPTest):
    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_default_communication_hook_behavior(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests the behavior and correctness of FSDP's default communication hook.
        This test creates a simple linear net whose weight has shape ``N x 1``,
        where ``N`` is the number of workers, so in the sharded cases each worker
        holds exactly one element of the weight parameter. The test checks that
        after backward, each worker's chunk of the gradient (or, for ``NO_SHARD``,
        the full gradient on every worker) equals the expected value.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """
        out_dim = self.world_size
        net = torch.nn.Linear(1, out_dim, bias=False)
        inpt = torch.tensor([self.rank]).float().cuda(self.rank)

        net_default_hook = FSDP(
            net,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
        ).to(self.rank)

        # Check that by default, `_comm_hook` is None
        for entry in FSDP.fsdp_modules(net_default_hook):
            self.assertEqual(entry._comm_hook, None)

        for _ in range(4):
            # Clear gradients
            net_default_hook.zero_grad()
            loss = net_default_hook(inpt).sum()
            loss.backward()

            # Each worker's local gradient on the weight equals its rank; the
            # default hook averages gradients across workers, so every chunk of
            # the gradient should equal the mean of all ranks.
            grad = net_default_hook.params[0].grad
            expected_grad = (
                sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
            )
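            # For example, with world_size=2 the local gradients are 0 and 1 and
            # every worker should end up with (0 + 1) / 2 = 0.5.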
            # Verify default hook produces expected gradients
            self.assertEqual(
                grad[0].item(),
                expected_grad,
                msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}",
            )

    def _get_submodules(self, fsdp_net):
        return [
            submodule
            for submodule in FSDP.fsdp_modules(fsdp_net)
            if not submodule.check_is_root()
        ]

    def _init_model(self, core, sharding_strategy, mixed_precision=None):
        device = torch.device("cuda")
        return FSDP(
            core,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        ).to(device)

    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_default_communication_hook_initialization(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook interface behavior.

        Arguments:
            has_wrapping (bool): Configures wrapping of a module.
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        # Initialize a model
        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )

        # Check that by default, `_comm_hook` is None
        for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(fsdp_module._comm_hook, None)

        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy == ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )

        fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

        # Check that a communication hook cannot be registered twice
        with self.assertRaisesRegex(
            AssertionError, "^A communication hook is already registered$"
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

        # Check that the dummy hook was registered for the root and all submodules, if any
        for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(fsdp_module._comm_hook, dummy_hook)
            self.assertEqual(fsdp_module._comm_hook_state, dummy_state)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_registering_hook_non_root(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook registration for submodules.
        Makes sure a hook cannot be registered on non-root submodules.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=True, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy == ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
        # Create a list of non-root submodules to test
        submodules = self._get_submodules(fsdp_model_with_hook)
        # Check that an error is raised when registering a comm hook on a non-root instance
        with self.assertRaisesRegex(
            AssertionError,
            "^register_comm_hook can only be called on a root instance.$",
        ):
            submodules[1].register_comm_hook(dummy_state, dummy_hook)

    @skip_if_lt_x_gpu(2)
    def test_registering_hook_hybrid_strategy(self):
        for sharding_strategy in (
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        ):
            model = Net(False, None, None).cuda()
            fsdp_model = FSDP(
                model,
                auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
                sharding_strategy=sharding_strategy,
            )
            dummy_state = DummyState(process_group=None, noise=1234)
            dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp
            with self.assertRaisesRegex(
                AssertionError,
                "Communication hook is not supported for hybrid strategies",
            ):
                fsdp_model.register_comm_hook(dummy_state, dummy_hook)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_registering_hook_submodules(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook registration for submodules.
        Checks that registration on the root fails if a hook has already been
        assigned to a non-root submodule.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=True, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy == ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
        submodules = self._get_submodules(fsdp_model_with_hook)

        # Simulate registration of a hook on a non-root submodule
        submodules[1]._comm_hook = dummy_hook
        # Check that an error is raised when some submodules already have a non-default hook assigned
        with self.assertRaisesRegex(
            AssertionError, "^A communication hook is already registered$"
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

    def _check_low_precision_hook(
        self, state, hook, sharding_strategy, dtype, has_wrapping
    ):
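        # Strategy: train one model whose gradients are reduced by the given
        # low-precision communication hook and a second model configured with
        # MixedPrecision(reduce_dtype=dtype), which casts only gradient
        # communication to ``dtype``. After one optimizer step, the gradients
        # of the two models should match.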
        # keep everything deterministic for input data
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        fsdp_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        fsdp_with_hook.register_comm_hook(state, hook)

        mp_only_grad = MixedPrecision(reduce_dtype=dtype)
        fsdp_with_mp = self._init_model(
            Net(
                has_wrapping=has_wrapping,
                sharding_strategy=sharding_strategy,
                mixed_precision=mp_only_grad,
            ),
            sharding_strategy=sharding_strategy,
            mixed_precision=mp_only_grad,
        )

        optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
        optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

        in_data = torch.rand(16, 8).cuda()
        fsdp_with_hook.train()
        fsdp_with_mp.train()
        loss_hook = fsdp_with_hook(in_data).sum()
        loss_mp = fsdp_with_mp(in_data).sum()
        loss_hook.backward()
        # Make sure grads were cast back to the parameter's precision
        self.assertEqual(fsdp_with_hook.params[0].grad.dtype, state.parameter_type)
        loss_mp.backward()
        optim_hook.step()
        optim_mp.step()

        dist.barrier()

        for hook_param, mp_param in zip(
            fsdp_with_hook.parameters(), fsdp_with_mp.parameters()
        ):
            self.assertEqual(hook_param.grad, mp_param.grad)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_fp16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
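        # The fp16 compression hook reduces gradients in torch.float16 and then
        # casts them back to the parameter dtype, so results should match
        # MixedPrecision(reduce_dtype=torch.float16).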
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.fp16_compress_hook

        self._check_low_precision_hook(
            state, hook, sharding_strategy, torch.float16, has_wrapping
        )

    @requires_nccl()
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
    @skip_but_pass_in_sandcastle_if(
        not BFLOAT16_AVAILABLE,
        "BFloat16 is only supported by CUDA 11+",
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_bf16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.bf16_compress_hook

        self._check_low_precision_hook(
            state, hook, sharding_strategy, torch.bfloat16, has_wrapping
        )


instantiate_parametrized_tests(TestCommunicationHooks)

if __name__ == "__main__":
    run_tests()