# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from contextlib import contextmanager, nullcontext
from typing import Any, ContextManager, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    _checkpoint_without_reentrant_generator,
    _DEFAULT_DETERMINISM_MODE,
)

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """

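    # Turn the hooks off for the duration of the wrapped region so that the
    # forward recomputation triggered during backward does not re-enter the
    # checkpoint hooks and start a nested checkpoint.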
    with user_ctx if user_ctx else nullcontext():
        orig_enable_hook = checkpoint.state(module).enable_hook
        checkpoint.state(module).enable_hook = False
        try:
            yield
        finally:
            checkpoint.state(module).enable_hook = orig_enable_hook


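# ``@contract()`` attaches a per-module state object to this API; the hooks
# below use ``checkpoint.state(module)`` to stash the hook-enable flag and the
# in-flight checkpoint generator for each wrapped module.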
@contract()
def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing to.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()

    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint")

    use_reentrant = kwargs.pop("use_reentrant", False)
    if use_reentrant:
        raise NotImplementedError(
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead."
        )
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    user_context_fns = kwargs.pop("context_fn", None)
    determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
    debug = kwargs.pop("debug", False)

    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def forward_pre_hook(
        module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> None:
        if checkpoint.state(module).enable_hook:

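            # The checkpoint ``context_fn`` protocol returns two context
            # managers: the first wraps the original forward, the second wraps
            # the recomputation during backward. ``_no_hook`` is layered into
            # the recomputation context so the recomputed forward does not
            # re-trigger these hooks.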
            def context_fns():
                if user_context_fns is not None:
                    ctx1, ctx2 = user_context_fns()
                    return ctx1, _no_hook(module, ctx2)
                else:
                    return nullcontext(), _no_hook(module)

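            # The non-reentrant checkpoint logic is written as a generator
            # with a single yield: advancing it here runs the setup before the
            # wrapped forward executes, and the post-forward hook below
            # advances it again to run the teardown.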
            checkpoint.state(
                module
            )._ac_generator = _checkpoint_without_reentrant_generator(
                module,
                preserve_rng_state,
                context_fns,
                determinism_check,
                debug,
                *args,
                **kwargs,
            )
            next(checkpoint.state(module)._ac_generator)

    def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
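            # Advance the generator a second time; it must stop immediately
            # (StopIteration), which means the checkpoint teardown ran. Any
            # further yield indicates a bug in the generator pairing.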
            try:
                next(checkpoint.state(module)._ac_generator)
            except StopIteration:
                pass
            else:
                raise RuntimeError(
                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
                )

        # Ensure that we no longer hold on to the generator. always_call=True
        # helps ensure this is cleared even if the forward pass raises.
        checkpoint.state(module)._ac_generator = None

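    # Register the checkpoint logic as hooks: the pre-hook receives forward
    # kwargs (with_kwargs=True), and the post-hook is prepended and marked
    # always_call=True so the cleanup above runs even if forward raises.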
    checkpoint.state(module).enable_hook = True
    module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
    return module