• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2"""
The APIs in this file are exposed as `functorch.*`. They are thin wrappers
around the `torch.func.*` APIs that emit deprecation warnings -- we're trying
to move people to the `torch.func.*` equivalents.

NB: We don't use *args, **kwargs in the signatures because that changes the
documentation.
9"""
10
11import textwrap
12import warnings
13from typing import Any, Callable, Optional, Tuple, Union
14
15import torch._functorch.apis as apis
16import torch._functorch.eager_transforms as _impl
17import torch._functorch.make_functional as _nn_impl
18import torch.nn as nn
19from torch._functorch.eager_transforms import argnums_t
20from torch._functorch.vmap import in_dims_t, out_dims_t
21
22
def get_warning(api, new_api=None, replace_newlines=False):
    """Build the deprecation message for a ``functorch.<api>`` entry point.

    Args:
        api: Name of the deprecated function as exposed under ``functorch``.
        new_api: Fully qualified replacement to recommend. When ``None``,
            ``torch.func.<api>`` is recommended.
        replace_newlines: When true, every newline is removed so the message
            is a single line; otherwise the embedded line breaks are kept.

    Returns:
        The deprecation message as a string.
    """
    replacement = f"torch.func.{api}" if new_api is None else new_api
    segments = (
        "We've integrated functorch into PyTorch. As the final step of the \n",
        f"integration, `functorch.{api}` is deprecated as of PyTorch \n",
        "2.0 and will be deleted in a future version of PyTorch >= 2.3. \n",
        f"Please use `{replacement}` instead; see the PyTorch 2.0 release notes \n",
        "and/or the `torch.func` migration guide for more details \n",
        "https://pytorch.org/docs/main/func.migrating.html",
    )
    message = "".join(segments)
    if not replace_newlines:
        return message
    # Each segment ends in " \n", so dropping the newlines leaves the words
    # separated by a single space.
    return message.replace("\n", "")
37
38
def warn_deprecated(api, new_api=None):
    """Emit a single-line ``FutureWarning`` that ``functorch.<api>`` is deprecated.

    Args:
        api: Name of the deprecated ``functorch`` function.
        new_api: Replacement to recommend; forwarded to ``get_warning``.
    """
    # stacklevel=3 skips this helper (1) and the functorch wrapper that called
    # it (2), so the warning is reported at the wrapper's call site.
    warnings.warn(
        get_warning(api, new_api, replace_newlines=True),
        category=FutureWarning,
        stacklevel=3,
    )
42
43
def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
    """Give a deprecated wrapper the docstring of the API it forwards to.

    The wrapper's ``__doc__`` becomes the source docstring followed by an
    indented ``.. warning::`` block containing the deprecation message.

    Args:
        functorch_api: The deprecated wrapper whose ``__doc__`` is assigned.
        torch_func_api: Function whose docstring is copied. When ``None``, the
            attribute of ``_impl`` with the wrapper's ``__name__`` is used.
        new_api_name: Replacement name passed on to ``get_warning``.
    """
    name = functorch_api.__name__
    source_api = getattr(_impl, name) if torch_func_api is None else torch_func_api
    base_doc = source_api.__doc__
    # Docstrings can be absent (e.g. stripped by `python -OO`); leave the
    # wrapper untouched in that case.
    # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO
    if base_doc is None:
        return

    # The message is indented once under the directive, then the whole note is
    # indented again before being appended to the source docstring.
    note_body = textwrap.indent(get_warning(name, new_api_name), "    ")
    note = textwrap.indent("\n.. warning::\n\n" + note_body, "    ")
    functorch_api.__doc__ = base_doc + note
56
57
def vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    randomness: str = "error",
    *,
    chunk_size=None,
) -> Callable:
    # Deprecated `functorch.vmap`: warn (recommending `torch.vmap`), then
    # forward every argument unchanged to `apis.vmap`.
    # No docstring literal here: `__doc__` is assigned by the module-level
    # `setup_docs(vmap, ...)` call.
    warn_deprecated("vmap", "torch.vmap")
    vmapped = apis.vmap(
        func,
        in_dims,
        out_dims,
        randomness,
        chunk_size=chunk_size,
    )
    return vmapped
68
69
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    # Deprecated `functorch.grad`: warn, then delegate to `apis.grad` with the
    # same positional arguments. `__doc__` is assigned by `setup_docs(grad, ...)`
    # at module level, so no docstring literal is used here.
    warn_deprecated("grad")
    grad_fn = apis.grad(func, argnums, has_aux)
    return grad_fn
73
74
def grad_and_value(
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
) -> Callable:
    # Deprecated `functorch.grad_and_value`: warn, then delegate to
    # `apis.grad_and_value`. `__doc__` is assigned by the module-level
    # `setup_docs(grad_and_value, ...)` call.
    warn_deprecated("grad_and_value")
    wrapped = apis.grad_and_value(func, argnums, has_aux)
    return wrapped
80
81
def vjp(func: Callable, *primals, has_aux: bool = False):
    # Deprecated `functorch.vjp`: warn, then forward `func`, the unpacked
    # primals and `has_aux` to `_impl.vjp`. `__doc__` is assigned by the
    # module-level `setup_docs(vjp)` call.
    warn_deprecated("vjp")
    result = _impl.vjp(func, *primals, has_aux=has_aux)
    return result
85
86
def jvp(
    func: Callable,
    primals: Any,
    tangents: Any,
    *,
    strict: bool = False,
    has_aux: bool = False,
):
    # Deprecated `functorch.jvp`: warn, then forward to `_impl.jvp`, passing
    # `strict` and `has_aux` by keyword. `__doc__` is assigned by the
    # module-level `setup_docs(jvp)` call.
    warn_deprecated("jvp")
    result = _impl.jvp(
        func,
        primals,
        tangents,
        strict=strict,
        has_aux=has_aux,
    )
    return result
97
98
def jacrev(
    func: Callable,
    argnums: Union[int, Tuple[int]] = 0,
    *,
    has_aux=False,
    chunk_size: Optional[int] = None,
    _preallocate_and_copy=False,
):
    # Deprecated `functorch.jacrev`: warn, then forward to `_impl.jacrev`.
    # `__doc__` is assigned by the module-level `setup_docs(jacrev)` call.
    warn_deprecated("jacrev")
    # Collect the keyword-only options once and pass them through unchanged.
    options = {
        "has_aux": has_aux,
        "chunk_size": chunk_size,
        "_preallocate_and_copy": _preallocate_and_copy,
    }
    return _impl.jacrev(func, argnums, **options)
115
116
def jacfwd(
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
    *,
    randomness: str = "error",
):
    # Deprecated `functorch.jacfwd`: warn, then forward to `_impl.jacfwd` with
    # `randomness` passed by keyword. `__doc__` is assigned by the module-level
    # `setup_docs(jacfwd)` call.
    warn_deprecated("jacfwd")
    jac_fn = _impl.jacfwd(func, argnums, has_aux, randomness=randomness)
    return jac_fn
126
127
def hessian(func, argnums=0):
    # Deprecated `functorch.hessian`: warn, then forward to `_impl.hessian`.
    # `__doc__` is assigned by the module-level `setup_docs(hessian)` call.
    warn_deprecated("hessian")
    hess_fn = _impl.hessian(func, argnums=argnums)
    return hess_fn
131
132
def functionalize(func: Callable, *, remove: str = "mutations") -> Callable:
    # Deprecated `functorch.functionalize`: warn, then forward to
    # `_impl.functionalize`. `__doc__` is assigned by the module-level
    # `setup_docs(functionalize)` call.
    warn_deprecated("functionalize")
    functionalized = _impl.functionalize(func, remove=remove)
    return functionalized
136
137
def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
    # Deprecated `functorch.make_functional`: warn (recommending
    # `torch.func.functional_call`), then forward to `_nn_impl.make_functional`.
    # `__doc__` is assigned by the module-level `setup_docs(make_functional, ...)`
    # call.
    warn_deprecated("make_functional", "torch.func.functional_call")
    result = _nn_impl.make_functional(model, disable_autograd_tracking)
    return result
141
142
def make_functional_with_buffers(
    model: nn.Module,
    disable_autograd_tracking: bool = False,
):
    # Deprecated `functorch.make_functional_with_buffers`: warn (recommending
    # `torch.func.functional_call`), then forward to
    # `_nn_impl.make_functional_with_buffers`. `__doc__` is assigned by a
    # module-level `setup_docs(...)` call.
    warn_deprecated("make_functional_with_buffers", "torch.func.functional_call")
    result = _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking)
    return result
148
149
def combine_state_for_ensemble(models):
    # Deprecated `functorch.combine_state_for_ensemble`: warn (recommending
    # `torch.func.stack_module_state`), then forward to
    # `_nn_impl.combine_state_for_ensemble`. `__doc__` is assigned by a
    # module-level `setup_docs(...)` call.
    warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state")
    combined = _nn_impl.combine_state_for_ensemble(models)
    return combined
153
154
# Attach documentation to every deprecated wrapper: each call copies the
# docstring of the implementation the wrapper forwards to and appends a
# deprecation warning note. Wrappers whose implementation lives in `_impl`
# under the same name rely on setup_docs' default lookup.
setup_docs(vmap, apis.vmap, "torch.vmap")
setup_docs(grad, apis.grad)
setup_docs(grad_and_value, apis.grad_and_value)
setup_docs(vjp)
setup_docs(jvp)
setup_docs(jacrev)
setup_docs(jacfwd)
setup_docs(hessian)
setup_docs(functionalize)
setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call")
# Use the docstring of the function this wrapper actually forwards to
# (make_functional_with_buffers), not make_functional's.
setup_docs(
    make_functional_with_buffers,
    _nn_impl.make_functional_with_buffers,
    "torch.func.functional_call",
)
setup_docs(
    combine_state_for_ensemble,
    _nn_impl.combine_state_for_ensemble,
    "torch.func.stack_module_state",
)
173