# Owner(s): ["oncall: distributed"]

import sys
from typing import List
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


NUM_ITERS = 2
DECODER_PARAM_FQNS = [
    "decoder.layers.{index}.self_attn.in_proj_weight",
    "decoder.layers.{index}.self_attn.in_proj_bias",
    "decoder.layers.{index}.self_attn.out_proj.weight",
    "decoder.layers.{index}.self_attn.out_proj.bias",
    "decoder.layers.{index}.multihead_attn.in_proj_weight",
    "decoder.layers.{index}.multihead_attn.in_proj_bias",
    "decoder.layers.{index}.multihead_attn.out_proj.weight",
    "decoder.layers.{index}.multihead_attn.out_proj.bias",
    "decoder.layers.{index}.linear1.weight",
    "decoder.layers.{index}.linear1.bias",
    "decoder.layers.{index}.linear2.weight",
    "decoder.layers.{index}.linear2.bias",
    "decoder.layers.{index}.norm1.weight",
    "decoder.layers.{index}.norm1.bias",
    "decoder.layers.{index}.norm2.weight",
    "decoder.layers.{index}.norm2.bias",
    "decoder.layers.{index}.norm3.weight",
    "decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
    "encoder.layers.{index}.self_attn.in_proj_weight",
    "encoder.layers.{index}.self_attn.in_proj_bias",
    "encoder.layers.{index}.self_attn.out_proj.weight",
    "encoder.layers.{index}.self_attn.out_proj.bias",
    "encoder.layers.{index}.linear1.weight",
    "encoder.layers.{index}.linear1.bias",
    "encoder.layers.{index}.linear2.weight",
    "encoder.layers.{index}.linear2.bias",
    "encoder.layers.{index}.norm1.weight",
    "encoder.layers.{index}.norm1.bias",
    "encoder.layers.{index}.norm2.weight",
    "encoder.layers.{index}.norm2.bias",
]
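# The wrapped nn.Transformer below uses the default 6 encoder and 6 decoder
# layers, each auto-wrapped as its own FSDP unit. With BACKWARD_PRE, the
# recorded prefetches are decoder 5..0 then encoder 5..0 (12 total); with
# BACKWARD_POST, they are decoder 4..0 then encoder 5..0 (11 total). See the
# per-branch comments in _dist_train for the full derivation.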
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
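    """Checks which flat parameter handles FSDP prefetches during backward for
    each ``BackwardPrefetch`` mode by intercepting ``_get_handle_to_prefetch``."""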
    @property
    def world_size(self):
        return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
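        """Trains an FSDP-wrapped ``nn.Transformer`` for ``NUM_ITERS`` iterations
        while recording which handles ``_get_handle_to_prefetch`` returns during
        backward, then checks the recorded prefetch order against the expected
        per-layer order for ``backward_prefetch``."""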
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
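        # Wrap each transformer encoder/decoder layer in its own FSDP unit so
        # that each layer has its own flat parameter handle to prefetch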
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input (seeded differently per rank)
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device="cuda")
        tgt = torch.randn((20, 1, 1024), device="cuda")

        # Monkey-patch _get_handle_to_prefetch to record the FQNs of each
        # handle it returns during the backward pass
        all_handle_fqns: List[List[str]] = []

        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)

            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs are prefixed from the root module
                # (see state._exec_order_data.param_to_fqn)
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)
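            # Return the handle unchanged so prefetching itself is unaffected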
            return handle

        # The flat parameter FQNs from each prefetched handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS for some layer index
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()
                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    # equals the forward order:
                    #     encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order:
                    #     root -> decoder 5...0 -> encoder 5...0
                    # prefetch order:
                    #     decoder 5...0 -> encoder 5...0 -> None
                    # None: when current_handle=encoder 0,
                    #       _get_handle_to_prefetch returns None
                    # the +1 accounts for that None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    # equals the forward order (same as BACKWARD_PRE):
                    #     encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order:
                    #     decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order:
                    #     decoder 4...0 -> encoder 5...0 -> None -> None
                    # 1st None: when current_handle=encoder 0,
                    #           _get_handle_to_prefetch returns None
                    # 2nd None: when current_handle=root,
                    #           _get_handle_to_prefetch finds decoder 5, but it
                    #           is not needed since decoder 5 has already been
                    #           computed
                    # the +2 accounts for those two Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )

                # ith_prefetch indexes the recorded prefetches in order:
                # 0th, 1st, 2nd, ...
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if 0 <= ith_prefetch < encoder_begin_index:
                        # decoder layers are prefetched in reverse order
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        encoder_begin_index
                        <= ith_prefetch
                        <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        # encoder layers are also prefetched in reverse order
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
                    else:
                        # trailing None(s) recorded from _get_handle_to_prefetch
                        self.assertIsNone(fqns)

                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # subtests reuse the process group to shorten test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)


if __name__ == "__main__":
    run_tests()