1# Owner(s): ["module: unknown"] 2import functools 3import gc 4from typing import Union 5 6import torch 7import torch.nn as nn 8from torch.distributed._composable import checkpoint 9from torch.distributed._composable.fsdp import ( 10 CPUOffloadPolicy, 11 fully_shard, 12 MixedPrecisionPolicy, 13 OffloadPolicy, 14) 15from torch.distributed._tensor import init_device_mesh 16from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker 17from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 18 apply_activation_checkpointing, 19 CheckpointWrapper, 20) 21from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 22from torch.testing._internal.common_fsdp import FSDPTest, MLP 23from torch.testing._internal.common_utils import run_tests 24from torch.testing._internal.distributed._tensor.common_dtensor import ( 25 ModelArgs, 26 Transformer, 27 TransformerBlock, 28) 29 30 31def _init_cublas_workspace(dev: torch.device): 32 lin = torch.nn.Linear(768, 768, device=dev) 33 inp = torch.randn(1, 768, device=dev) 34 lin(inp).sum().backward() 35 del lin 36 del inp 37 38 39def _reset_mem_stats(dev: torch.device): 40 torch.cuda.empty_cache() 41 torch.cuda.reset_accumulated_memory_stats(dev) 42 torch.cuda.reset_peak_memory_stats(dev) 43 44 45class TestTrackerFullyShard1DTrainingCore(FSDPTest): 46 @property 47 def world_size(self) -> int: 48 return min(4, torch.cuda.device_count()) 49 50 @skip_if_lt_x_gpu(2) 51 def test_tracker_multi_group_eager(self): 52 """ 53 Tests tracker accuracy when using multiple parameter groups for 54 communication (for communication and computation overlap plus memory 55 reduction) and different mixed precision policies. 56 """ 57 self.run_subtests( 58 { 59 "reshard_after_forward": [True, False], 60 "offload_policy": [ 61 CPUOffloadPolicy(pin_memory=False), 62 OffloadPolicy(), 63 ], 64 "mp_policy": [ 65 MixedPrecisionPolicy( 66 param_dtype=torch.float16, reduce_dtype=torch.float32 67 ), 68 ], 69 }, 70 self._test_tracker_multi_group, 71 ) 72 73 def _test_tracker_multi_group( 74 self, 75 reshard_after_forward: Union[bool, int], 76 offload_policy: OffloadPolicy, 77 mp_policy: MixedPrecisionPolicy, 78 ): 79 debug = False 80 dev = torch.device(torch.cuda.current_device()) 81 _init_cublas_workspace(dev) 82 gc.collect() 83 _reset_mem_stats(dev) 84 mem_stats = torch.cuda.memory_stats(dev) 85 pre_cuda_active = mem_stats["active_bytes.all.current"] 86 torch.manual_seed(42) 87 lin_dim, bsz = 2048, 8192 88 with torch.device(dev): 89 model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)]) 90 mesh = init_device_mesh("cuda", (self.world_size,)) 91 fully_shard_fn = functools.partial( 92 fully_shard, 93 mesh=mesh, 94 reshard_after_forward=reshard_after_forward, 95 offload_policy=offload_policy, 96 mp_policy=mp_policy, 97 ) 98 for mlp in model: 99 fully_shard_fn(mlp) 100 fully_shard_fn(model) 101 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 102 inp = torch.randn((bsz, lin_dim), device=dev) 103 fmt = FSDPMemTracker(model, optim) 104 fmt.track_inputs((inp,)) 105 with fmt: 106 for iter_idx in range(2): 107 loss = model(inp).sum() 108 loss.backward() 109 optim.step() 110 optim.zero_grad() 111 if iter_idx == 0: 112 fmt.reset_mod_stats() 113 mem_stats = torch.cuda.memory_stats() 114 tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] 115 cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active 116 accuracy = tracker_max / cuda_max 117 if self.rank == 0 and debug: 118 print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") 119 self.assertAlmostEqual( 120 accuracy, 121 1.0, 122 delta=0.1, 123 msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}", 124 ) 125 del model 126 del inp 127 del optim 128 129 @skip_if_lt_x_gpu(2) 130 def test_tracker_non_root_forward_backward(self): 131 """ 132 Tests tracker accracy when running forward/backward through a non-root. 133 """ 134 debug = False 135 dev = torch.device(torch.cuda.current_device()) 136 _init_cublas_workspace(dev) 137 gc.collect() 138 _reset_mem_stats(dev) 139 mem_stats = torch.cuda.memory_stats(dev) 140 pre_cuda_active = mem_stats["active_bytes.all.current"] 141 torch.manual_seed(42) 142 lin_dim, bsz = 2048, 8 143 model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)]) 144 for mlp in model: 145 fully_shard(mlp) 146 fully_shard(model) 147 optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) 148 torch.manual_seed(42 + self.rank) 149 inp = torch.randn((bsz, lin_dim), device=dev) 150 fmt = FSDPMemTracker(model, optim) 151 fmt.track_inputs((inp,)) 152 with fmt: 153 for iter_idx in range(2): 154 nonroot_loss = model[0](inp).sum() 155 nonroot_loss.backward() 156 optim.step() 157 optim.zero_grad() 158 if iter_idx == 0: 159 fmt.reset_mod_stats() 160 mem_stats = torch.cuda.memory_stats() 161 tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] 162 cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active 163 accuracy = tracker_max / cuda_max 164 if self.rank == 0 and debug: 165 print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") 166 self.assertAlmostEqual( 167 accuracy, 168 1.0, 169 delta=0.1, 170 msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}", 171 ) 172 del inp 173 del model 174 del optim 175 176 177class TestTrackerFullyShard1DTrainingCompose(FSDPTest): 178 @property 179 def world_size(self) -> int: 180 return min(torch.cuda.device_count(), 4) 181 182 @skip_if_lt_x_gpu(2) 183 def test_tracker_with_activation_checkpointing(self): 184 """ 185 Tests tracker accuracy when composing with activation checkpointing. 186 """ 187 self.run_subtests( 188 { 189 "reshard_after_forward": [True, False], 190 "checkpoint_impl": ["composable", "wrapper"], 191 }, 192 self._test_tracker_with_activation_checkpointing, 193 ) 194 195 def _test_tracker_with_activation_checkpointing( 196 self, reshard_after_forward: Union[bool, int], checkpoint_impl: str 197 ): 198 assert checkpoint_impl in ("composable", "wrapper") 199 debug = False 200 dev = torch.device(torch.cuda.current_device()) 201 _init_cublas_workspace(dev) 202 gc.collect() 203 _reset_mem_stats(dev) 204 mem_stats = torch.cuda.memory_stats(dev) 205 pre_cuda_active = mem_stats["active_bytes.all.current"] 206 torch.manual_seed(42) 207 vocab_size = 8192 208 bsz, seq_len = 16, 512 209 with torch.device(dev): 210 model_args = ModelArgs( 211 n_layers=4, 212 n_heads=4, 213 vocab_size=vocab_size, 214 max_seq_len=seq_len, 215 dropout_p=0.1, 216 ) 217 model = Transformer(model_args) 218 foreach = False 219 fully_shard_fn = functools.partial( 220 fully_shard, 221 reshard_after_forward=reshard_after_forward, 222 ) 223 if checkpoint_impl == "wrapper": 224 apply_activation_checkpointing( 225 model, check_fn=lambda m: isinstance(m, TransformerBlock) 226 ) 227 for module in model.modules(): 228 # Apply to `CheckpointWrapper`, which wraps `TransformerBlock` 229 if isinstance(module, CheckpointWrapper): 230 fully_shard_fn(module) 231 else: 232 for module in model.modules(): 233 if isinstance(module, TransformerBlock): 234 if checkpoint_impl == "composable": 235 checkpoint(module) 236 fully_shard_fn(module) 237 fully_shard_fn(model) 238 optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach) 239 240 torch.manual_seed(42 + self.rank) 241 inp = torch.randint(0, vocab_size, (bsz, seq_len), device=dev) 242 fmt = FSDPMemTracker(model, optim) 243 fmt.track_inputs((inp,)) 244 with fmt: 245 for iter_idx in range(2): 246 loss = model(inp).sum() 247 loss.backward() 248 optim.step() 249 optim.zero_grad() 250 if iter_idx == 0: 251 fmt.reset_mod_stats() 252 mem_stats = torch.cuda.memory_stats() 253 tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] 254 cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active 255 accuracy = tracker_max / cuda_max 256 if self.rank == 0 and debug: 257 print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") 258 self.assertAlmostEqual( 259 accuracy, 260 1.0, 261 delta=0.1, 262 msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}", 263 ) 264 del inp 265 del model 266 del optim 267 268 269if __name__ == "__main__": 270 run_tests() 271