# Owner(s): ["module: unknown"] import functools import gc from typing import Union import torch import torch.nn as nn from torch.distributed._composable import checkpoint from torch.distributed._composable.fsdp import ( CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy, OffloadPolicy, ) from torch.distributed._tensor import init_device_mesh from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, CheckpointWrapper, ) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest, MLP from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, Transformer, TransformerBlock, ) def _init_cublas_workspace(dev: torch.device): lin = torch.nn.Linear(768, 768, device=dev) inp = torch.randn(1, 768, device=dev) lin(inp).sum().backward() del lin del inp def _reset_mem_stats(dev: torch.device): torch.cuda.empty_cache() torch.cuda.reset_accumulated_memory_stats(dev) torch.cuda.reset_peak_memory_stats(dev) class TestTrackerFullyShard1DTrainingCore(FSDPTest): @property def world_size(self) -> int: return min(4, torch.cuda.device_count()) @skip_if_lt_x_gpu(2) def test_tracker_multi_group_eager(self): """ Tests tracker accuracy when using multiple parameter groups for communication (for communication and computation overlap plus memory reduction) and different mixed precision policies. """ self.run_subtests( { "reshard_after_forward": [True, False], "offload_policy": [ CPUOffloadPolicy(pin_memory=False), OffloadPolicy(), ], "mp_policy": [ MixedPrecisionPolicy( param_dtype=torch.float16, reduce_dtype=torch.float32 ), ], }, self._test_tracker_multi_group, ) def _test_tracker_multi_group( self, reshard_after_forward: Union[bool, int], offload_policy: OffloadPolicy, mp_policy: MixedPrecisionPolicy, ): debug = False dev = torch.device(torch.cuda.current_device()) _init_cublas_workspace(dev) gc.collect() _reset_mem_stats(dev) mem_stats = torch.cuda.memory_stats(dev) pre_cuda_active = mem_stats["active_bytes.all.current"] torch.manual_seed(42) lin_dim, bsz = 2048, 8192 with torch.device(dev): model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)]) mesh = init_device_mesh("cuda", (self.world_size,)) fully_shard_fn = functools.partial( fully_shard, mesh=mesh, reshard_after_forward=reshard_after_forward, offload_policy=offload_policy, mp_policy=mp_policy, ) for mlp in model: fully_shard_fn(mlp) fully_shard_fn(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2) inp = torch.randn((bsz, lin_dim), device=dev) fmt = FSDPMemTracker(model, optim) fmt.track_inputs((inp,)) with fmt: for iter_idx in range(2): loss = model(inp).sum() loss.backward() optim.step() optim.zero_grad() if iter_idx == 0: fmt.reset_mod_stats() mem_stats = torch.cuda.memory_stats() tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active accuracy = tracker_max / cuda_max if self.rank == 0 and debug: print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") self.assertAlmostEqual( accuracy, 1.0, delta=0.1, msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}", ) del model del inp del optim @skip_if_lt_x_gpu(2) def test_tracker_non_root_forward_backward(self): """ Tests tracker accracy when running forward/backward through a non-root. 
""" debug = False dev = torch.device(torch.cuda.current_device()) _init_cublas_workspace(dev) gc.collect() _reset_mem_stats(dev) mem_stats = torch.cuda.memory_stats(dev) pre_cuda_active = mem_stats["active_bytes.all.current"] torch.manual_seed(42) lin_dim, bsz = 2048, 8 model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)]) for mlp in model: fully_shard(mlp) fully_shard(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) torch.manual_seed(42 + self.rank) inp = torch.randn((bsz, lin_dim), device=dev) fmt = FSDPMemTracker(model, optim) fmt.track_inputs((inp,)) with fmt: for iter_idx in range(2): nonroot_loss = model[0](inp).sum() nonroot_loss.backward() optim.step() optim.zero_grad() if iter_idx == 0: fmt.reset_mod_stats() mem_stats = torch.cuda.memory_stats() tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active accuracy = tracker_max / cuda_max if self.rank == 0 and debug: print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") self.assertAlmostEqual( accuracy, 1.0, delta=0.1, msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}", ) del inp del model del optim class TestTrackerFullyShard1DTrainingCompose(FSDPTest): @property def world_size(self) -> int: return min(torch.cuda.device_count(), 4) @skip_if_lt_x_gpu(2) def test_tracker_with_activation_checkpointing(self): """ Tests tracker accuracy when composing with activation checkpointing. """ self.run_subtests( { "reshard_after_forward": [True, False], "checkpoint_impl": ["composable", "wrapper"], }, self._test_tracker_with_activation_checkpointing, ) def _test_tracker_with_activation_checkpointing( self, reshard_after_forward: Union[bool, int], checkpoint_impl: str ): assert checkpoint_impl in ("composable", "wrapper") debug = False dev = torch.device(torch.cuda.current_device()) _init_cublas_workspace(dev) gc.collect() _reset_mem_stats(dev) mem_stats = torch.cuda.memory_stats(dev) pre_cuda_active = mem_stats["active_bytes.all.current"] torch.manual_seed(42) vocab_size = 8192 bsz, seq_len = 16, 512 with torch.device(dev): model_args = ModelArgs( n_layers=4, n_heads=4, vocab_size=vocab_size, max_seq_len=seq_len, dropout_p=0.1, ) model = Transformer(model_args) foreach = False fully_shard_fn = functools.partial( fully_shard, reshard_after_forward=reshard_after_forward, ) if checkpoint_impl == "wrapper": apply_activation_checkpointing( model, check_fn=lambda m: isinstance(m, TransformerBlock) ) for module in model.modules(): # Apply to `CheckpointWrapper`, which wraps `TransformerBlock` if isinstance(module, CheckpointWrapper): fully_shard_fn(module) else: for module in model.modules(): if isinstance(module, TransformerBlock): if checkpoint_impl == "composable": checkpoint(module) fully_shard_fn(module) fully_shard_fn(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach) torch.manual_seed(42 + self.rank) inp = torch.randint(0, vocab_size, (bsz, seq_len), device=dev) fmt = FSDPMemTracker(model, optim) fmt.track_inputs((inp,)) with fmt: for iter_idx in range(2): loss = model(inp).sum() loss.backward() optim.step() optim.zero_grad() if iter_idx == 0: fmt.reset_mod_stats() mem_stats = torch.cuda.memory_stats() tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active accuracy = tracker_max / cuda_max if self.rank == 0 and debug: print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") self.assertAlmostEqual( accuracy, 
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del inp
        del model
        del optim


if __name__ == "__main__":
    run_tests()