# mypy: allow-untyped-defs
import torch
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode
from torch.utils._pytree import tree_map_only


__all__ = ["trace_wrapped"]


# trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist:
# if you trace through this call with make_fx, we will not actually trace into fn; instead,
# we will insert a call_function node to fn directly into the graph.
# (Unlike make_fx, Dynamo WILL inline into fn.)
# You can think of this as a one-off allow_in_graph equivalent for proxy tensor tracing.
#
# Because proxy tensor tracing does not actually run the function, there are
# requirements on the behavior of fn. We are still figuring it out, but here is the current state:
#
# 1) fn SHOULD only take a single argument, which must be a tensor.
# 2) fn MUST return a new tensor with the same metadata as the original tensor
#    (e.g., zeros_like(input) is a permissible implementation of fn).
#    This is verified via an extra assert that is inserted into the traced graph.
# 3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors
#    participating in proxy tensor tracing (it MAY mutate other tensors and it MAY mutate Python state).
# These requirements stem from the fact that we need to continue performing proxy tensor tracing,
# which assumes accurate fake tensor metadata, without actually running fn.
# In the future, we may allow a "meta" function to be associated with fn to support more interesting
# input-output patterns.
#
# Note that tensors / Python state are allowed to be mutated.
# This relaxed constraint is not always sound, but it is sound for the backward tracing with fake
# tensors that takes place in AOTAutograd, because the backward pass is guaranteed not to depend on
# concrete tensor values (thanks to fake tensors) or on Python state (the autograd engine does not
# depend on Python).
#
# The intended use case for this function is to allow AOTAutograd to defer complex
# backward hooks to compiled autograd. AOTAutograd performs a make_fx trace that preserves
# the function call as-is in the graph, and only when Dynamo traces through the backward graph in
# compiled autograd do we inline into the function.
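#
# A minimal illustrative sketch (the names `my_hook` and `region` below are
# made up for this comment and do not exist in this module):
#
#     def my_hook(grad):
#         # must return a fresh tensor with the same metadata as `grad`
#         return torch.zeros_like(grad)
#
#     def region(grad):
#         return trace_wrapped(grad, fn=my_hook)
#
# Tracing `region` with make_fx records a single call_function node for the
# wrapped hook (its output modeled as a tensor with `grad`'s metadata) rather
# than tracing into `my_hook`; Dynamo only inlines `my_hook` later, when
# compiled autograd processes the backward graph.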


def trace_wrapped(*args, **kwargs):
    with torch.no_grad():
        return _trace_wrapped_op(*args, **kwargs)


class TraceWrapped(HigherOrderOperator):
    def __init__(self):
        super().__init__("trace_wrapped")

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)


# TODO(jansel): need to ensure this does not get DCEed
_trace_wrapped_op = TraceWrapped()


def _assert_meta(grad, size, stride, dtype):
    assert grad.size() == size, "size mismatch"
    assert grad.stride() == stride, "stride mismatch"
    assert grad.dtype == dtype, "dtype mismatch"
    return grad


@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode)
def inner_trace(mode, *args, bw_state=None, **kwargs):
    # Call target recorded in the graph: it closes over kwargs (including fn)
    # and re-dispatches the HOP when the node is eventually executed.
    def self_invoke(*args, **dyn_kwargs):
        with torch.no_grad():
            return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs)

    def unwrap_proxies(x):
        if isinstance(x, torch.Tensor):
            return mode.tracer.unwrap_proxy(x)
        if isinstance(x, (list, tuple)):
            return type(x)(map(unwrap_proxies, x))
        if x is None:
            return None
        raise AssertionError(f"unhandled type: {type(x)}")

    proxy_kwargs = {}
    if bw_state is not None:
        assert isinstance(bw_state, BackwardState) and bw_state.proxy is not None
        proxy_kwargs["bw_state"] = bw_state.proxy
    # Record a single call_function node instead of tracing into fn.
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        self_invoke,
        unwrap_proxies(args),
        proxy_kwargs,
        name="trace_wrapped",
    )

    if args[0] is None:
        grad = args[1]  # module backward hooks
    else:
        grad = args[0]  # other backward hooks
    # fn is never run here; stand in for its output with fresh tensors that
    # carry the same metadata as the incoming grad.
    grad = tree_map_only(torch.Tensor, torch.empty_like, grad)
    track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
    return grad


@_trace_wrapped_op.py_impl(FakeTensorMode)
def inner_fake(*args, **kwargs):
    raise RuntimeError("This op should never be invoked here")


@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def _trace_wrapped_op_dense(*args, fn, **kwargs):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    return fn(*args, **kwargs)


_trace_wrapped_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_trace_wrapped_op, deferred_error=True)
)


@_trace_wrapped_op.py_functionalize_impl
def _trace_wrapped_functionalized(ctx, *args, **kwargs):
    unwrapped_args = ctx.unwrap_tensors(args)
    with ctx.redispatch_to_next():
        return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs))