# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from contextlib import contextmanager, nullcontext
from typing import Any, ContextManager, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    _checkpoint_without_reentrant_generator,
    _DEFAULT_DETERMINISM_MODE,
)

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """

    with user_ctx if user_ctx else nullcontext():
        orig_enable_hook = checkpoint.state(module).enable_hook
        checkpoint.state(module).enable_hook = False
        try:
            yield
        finally:
            checkpoint.state(module).enable_hook = orig_enable_hook


@contract()
def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()

    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint")

    use_reentrant = kwargs.pop("use_reentrant", False)
    if use_reentrant:
        raise NotImplementedError(
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead."
        )
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    user_context_fns = kwargs.pop("context_fn", None)
    determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
    debug = kwargs.pop("debug", False)

    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def forward_pre_hook(
        module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> None:
        if checkpoint.state(module).enable_hook:

            def context_fns():
                if user_context_fns is not None:
                    ctx1, ctx2 = user_context_fns()
                    return ctx1, _no_hook(module, ctx2)
                else:
                    return nullcontext(), _no_hook(module)

            checkpoint.state(
                module
            )._ac_generator = _checkpoint_without_reentrant_generator(
                module,
                preserve_rng_state,
                context_fns,
                determinism_check,
                debug,
                *args,
                **kwargs,
            )
            next(checkpoint.state(module)._ac_generator)

    def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            try:
                next(checkpoint.state(module)._ac_generator)
            except StopIteration:
                pass
            else:
                raise RuntimeError(
                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
                )

        # Ensure that we no longer hold on to the generator. always_call=True helps ensure we
        # clear this even in the case of exception in fwd pass.
        checkpoint.state(module)._ac_generator = None

    checkpoint.state(module).enable_hook = True
    module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
    return module
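

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# it shows how the hook-based ``checkpoint`` above is applied to a single
# submodule and how the optional ``preserve_rng_state`` / ``context_fn`` kwargs
# consumed in the body are forwarded to
# ``_checkpoint_without_reentrant_generator``. The ``_DemoNet`` model, tensor
# shapes, and the ``__main__`` guard are made up for demonstration; in practice
# this API is imported via ``torch.distributed._composable`` rather than run
# as a script.
if __name__ == "__main__":

    class _DemoNet(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.l1 = nn.Linear(10, 10)
            self.l2 = nn.Linear(10, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.l2(self.l1(x))

    net = _DemoNet()
    # Checkpoint only l1; module structure and fully-qualified names stay
    # untouched because only forward hooks are registered on l1.
    checkpoint(
        net.l1,
        preserve_rng_state=True,
        context_fn=lambda: (nullcontext(), nullcontext()),
    )
    # l1's activations are recomputed during this backward pass.
    net(torch.zeros(2, 10)).sum().backward()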