# Owner(s): ["oncall: distributed"] import contextlib import sys from copy import deepcopy from functools import partial import torch import torch.distributed as dist import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, offload_wrapper, ) from torch.distributed.fsdp import ShardingStrategy from torch.distributed.fsdp.fully_sharded_data_parallel import ( CPUOffload, FullyShardedDataParallel as FSDP, ) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, TEST_WITH_DEV_DBG_ASAN, ) from torch.utils.checkpoint import checkpoint if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) if TEST_WITH_DEV_DBG_ASAN: print( "Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr, ) sys.exit(0) _save_on_cpu_called = False def get_patched_save_on_cpu(): orig_save_on_cpu = ( torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu ) def patched_save_on_cpu(*args, **kwargs): global _save_on_cpu_called _save_on_cpu_called = True return orig_save_on_cpu(*args, **kwargs) return patched_save_on_cpu @contextlib.contextmanager def patch_save_on_cpu(new_save_on_cpu): orig_save_on_cpu = ( torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu ) torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = ( new_save_on_cpu ) try: yield finally: torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = ( orig_save_on_cpu ) class TestFSDPCheckpoint(FSDPTest): class SequentialModule(nn.Module): def __init__( self, checkpoint_layer=False, offload_activations=False, wrap_fsdp=False, *fsdp_args, **fsdp_kwargs, ): torch.manual_seed(0) torch.cuda.manual_seed(0) super().__init__() l1 = nn.Linear(3, 3).cuda() l2 = nn.Linear(3, 3).cuda() l3 = nn.Linear(3, 3).cuda() if checkpoint_layer: if offload_activations: ckpt_wrapper = offload_wrapper else: ckpt_wrapper = checkpoint_wrapper l1 = ckpt_wrapper(l1) l2 = ckpt_wrapper(l2) l3 = ckpt_wrapper(l3) fsdp_wrapper = partial( _maybe_wrap_fsdp, *fsdp_args, wrap_fsdp=wrap_fsdp, **fsdp_kwargs ) self.ffn = nn.Sequential( fsdp_wrapper(l1), fsdp_wrapper(l2), fsdp_wrapper(l3), ) def forward(self, x): return self.ffn(x) def _verify_parity(self, losses, outputs, models): assert losses assert outputs assert models for l, o in zip(losses[1:], outputs[1:]): self.assertEqual(losses[0], l) self.assertEqual(outputs[0], o) # Verify grads ref_model = models[0] ref_grads = [p.grad for p in ref_model.parameters()] for m in models[1:]: grads = [p.grad for p in m.parameters()] for ref_g, g in zip(ref_grads, grads): self.assertEqual(ref_g, g) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)], ) @parametrize("offload_activations", [True, False]) @parametrize("use_orig_params", [False, True]) def test_checkpoint_fsdp_wrapping( self, cpu_offload: CPUOffload, offload_activations: bool, use_orig_params: bool, ): # Test checkpoint(FSDP(layer1), FSDP(layer2), ....) 
        if offload_activations:
            wrapper_to_use = offload_wrapper
        else:
            wrapper_to_use = checkpoint_wrapper
        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
        ckpt_sequential_wrapped_fsdp = wrapper_to_use(
            TestFSDPCheckpoint.SequentialModule(
                wrap_fsdp=True,
                **fsdp_kwargs,
            ),
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True,
            offload_activations=offload_activations,
            wrap_fsdp=True,
            **fsdp_kwargs,
        )

        baseline = TestFSDPCheckpoint.SequentialModule(
            wrap_fsdp=True,
            **fsdp_kwargs,
        )

        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.
        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        global _save_on_cpu_called
        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            for i in range(2):
                losses = []
                outputs = []
                for m in models:
                    check_offload = m != baseline and i == 0 and offload_activations
                    if check_offload:
                        self.assertFalse(_save_on_cpu_called)
                    out = m(inp)
                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                        _save_on_cpu_called = False
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)

                self._verify_parity(losses, outputs, models)

        dist.barrier()

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    @parametrize("use_orig_params", [False, True])
    def test_basic_checkpoint_end_to_end(
        self,
        cpu_offload: CPUOffload,
        offload_activations: bool,
        use_orig_params: bool,
    ):
        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
        global _save_on_cpu_called
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
            # Runs FSDP with no checkpointing
            fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs)
            # Runs checkpoint-wrapped FSDP
            if offload_activations:
                wrapper_to_use = offload_wrapper
            else:
                wrapper_to_use = checkpoint_wrapper

            checkpointed_fsdp = wrapper_to_use(
                FSDP(deepcopy(seq), **fsdp_kwargs),
            )
            # Runs FSDP-wrapped checkpointed module
            fsdp_wrapped_checkpoint = FSDP(
                wrapper_to_use(deepcopy(seq)),
                **fsdp_kwargs,
            )
            # Runs FSDP with manual calls to checkpoint.
            fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs)
            # note that reentrant-based checkpointing requires inputs to have grad
            # flag set.
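            # The four variants built above (plain FSDP, checkpoint-wrapped
            # FSDP, FSDP-wrapped checkpoint, and FSDP with manual
            # torch.utils.checkpoint calls) run on the same input below and
            # are checked for parity of losses, outputs, and gradients.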
            inp = torch.randn(
                10, 3, device=torch.cuda.current_device(), requires_grad=True
            )

            models = [
                fsdp_only_seq,
                checkpointed_fsdp,
                fsdp_wrapped_checkpoint,
                fsdp_call_checkpoint,
            ]
            # Ensure _save_on_cpu is not yet called
            self.assertFalse(_save_on_cpu_called)
            for i in range(6):
                losses = []
                outputs = []
                for m in models:
                    check_offload = (
                        m != fsdp_only_seq and i == 0 and offload_activations
                    )
                    if m == fsdp_call_checkpoint:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        offload_ctx = (
                            get_patched_save_on_cpu()(pin_memory=True)
                            if offload_activations
                            else contextlib.nullcontext()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp, use_reentrant=True)
                    else:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        out = m(inp)

                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)
                    _save_on_cpu_called = False

                self._verify_parity(losses, outputs, models)

        dist.barrier()


instantiate_parametrized_tests(TestFSDPCheckpoint)


class CheckpointModule(nn.Module):
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.seq = nn.Sequential(*[nn.Linear(100, 100) for _ in range(4)])
        self.checkpoint = checkpoint
        self.use_reentrant = use_reentrant

    def forward(self, x):
        return (
            checkpoint(self.seq, x, use_reentrant=self.use_reentrant)
            if self.checkpoint
            else self.seq(x)
        )


class ModelWithCheckpointSubmodule(nn.Module):
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.s1 = CheckpointModule(checkpoint, use_reentrant)
        self.s2 = CheckpointModule(checkpoint, use_reentrant)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.relu(self.s2(self.s1(self.l1(x)))))


class TestModel(nn.Module):
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.checkpoint1 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.checkpoint2 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.relu(self.checkpoint2(self.checkpoint1(self.l1(x)))))


class TestFSDPCheckpointSubmodule(FSDPTest):
    # TODO: grad value checks occasionally fail when use_reentrant = True
    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [False])
    def test_checkpoint_submodule(self, use_reentrant: bool):
        model = TestModel(use_reentrant=use_reentrant).cuda()
        model_ac = deepcopy(model)

        for _, m in model_ac.named_modules():
            if isinstance(m, CheckpointModule):
                m.checkpoint = True

        self.assertTrue(model_ac.checkpoint1.s1.checkpoint)
        self.assertTrue(model_ac.checkpoint2.s2.checkpoint)

        fsdp_kwargs = {
            "device_id": torch.cuda.current_device(),
            "sharding_strategy": ShardingStrategy.NO_SHARD,
        }

        # Wrap the non-checkpointing model's submodules with FSDP
        model.checkpoint1 = FSDP(module=model.checkpoint1, **fsdp_kwargs)
        model.checkpoint2 = FSDP(module=model.checkpoint2, **fsdp_kwargs)

        # Wrap the checkpointing model's submodules with FSDP
        model_ac.checkpoint1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs)
        model_ac.checkpoint2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs)

        x = torch.randn(2, 100, device="cuda")
        model(x).sum().backward()
        model_ac(x).sum().backward()

        for (n1, p1), (n2, p2) in zip(
            model.named_parameters(), model_ac.named_parameters()
        ):
            self.assertEqual(n1, n2)
            self.assertTrue(p1.grad.allclose(p2.grad))
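# Register the parametrized variants (e.g. use_reentrant) of the submodule test.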
instantiate_parametrized_tests(TestFSDPCheckpointSubmodule)


if __name__ == "__main__":
    run_tests()