# Owner(s): ["oncall: distributed"]

import sys
from typing import List
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


NUM_ITERS = 2
DECODER_PARAM_FQNS = [
    "decoder.layers.{index}.self_attn.in_proj_weight",
    "decoder.layers.{index}.self_attn.in_proj_bias",
    "decoder.layers.{index}.self_attn.out_proj.weight",
    "decoder.layers.{index}.self_attn.out_proj.bias",
    "decoder.layers.{index}.multihead_attn.in_proj_weight",
    "decoder.layers.{index}.multihead_attn.in_proj_bias",
    "decoder.layers.{index}.multihead_attn.out_proj.weight",
    "decoder.layers.{index}.multihead_attn.out_proj.bias",
    "decoder.layers.{index}.linear1.weight",
    "decoder.layers.{index}.linear1.bias",
    "decoder.layers.{index}.linear2.weight",
    "decoder.layers.{index}.linear2.bias",
    "decoder.layers.{index}.norm1.weight",
    "decoder.layers.{index}.norm1.bias",
    "decoder.layers.{index}.norm2.weight",
    "decoder.layers.{index}.norm2.bias",
    "decoder.layers.{index}.norm3.weight",
    "decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
    "encoder.layers.{index}.self_attn.in_proj_weight",
    "encoder.layers.{index}.self_attn.in_proj_bias",
    "encoder.layers.{index}.self_attn.out_proj.weight",
    "encoder.layers.{index}.self_attn.out_proj.bias",
    "encoder.layers.{index}.linear1.weight",
    "encoder.layers.{index}.linear1.bias",
    "encoder.layers.{index}.linear2.weight",
    "encoder.layers.{index}.linear2.bias",
    "encoder.layers.{index}.norm1.weight",
    "encoder.layers.{index}.norm1.bias",
    "encoder.layers.{index}.norm2.weight",
    "encoder.layers.{index}.norm2.bias",
]
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device="cuda")
        tgt = torch.randn((20, 1, 1024), device="cuda")

        # monkey patch
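        # _get_handle_to_prefetch: record the FQNs of every handle it returns
        # during backward so the observed prefetch order can be compared
        # against the expected order for the configured prefetch mode.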
        all_handle_fqns: List[List[str]] = []

        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)

            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs prefixed from the root module
                # state._exec_order_data.param_to_fqn
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)
            return handle

        # flat params from the prefetched handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()
                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order
                    # encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order
                    # root -> decoder 5...0 -> encoder 5...0
                    # prefetch order
                    # decoder 5...0 -> encoder 5...0 -> None
                    # None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # +1 is for the above None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order (same as BACKWARD_PRE)
                    # encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order
                    # decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order
                    # decoder 4...0 -> encoder 5...0 -> None -> None
                    # 1st None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # 2nd None: when current_handle=root,
                    # get decoder 5 inside _get_handle_to_prefetch
                    # but not needed since decoder 5 is computed already
                    # +2 is for the above Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )

                # ith_prefetch: the 0th, 1st, 2nd, 3rd, ... recorded prefetch
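                # Each recorded entry should map to the layer expected at that
                # point in backward: decoder layers in reverse order, then
                # encoder layers in reverse order, then None once there is
                # nothing left to prefetch.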
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if ith_prefetch >= 0 and ith_prefetch < encoder_begin_index:
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        ith_prefetch >= encoder_begin_index
                        and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
                    else:
                        self.assertTrue(fqns is None)

                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # subtests reuse the same process group to shorten the total test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)


if __name__ == "__main__":
    run_tests()