1# mypy: allow-untyped-defs 2import torch 3from torch._C import DispatchKey 4from torch._higher_order_ops.utils import autograd_not_implemented 5from torch._ops import HigherOrderOperator 6from torch._subclasses import FakeTensorMode 7from torch.fx.experimental._backward_state import BackwardState 8from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree 9from torch.utils._python_dispatch import _get_current_dispatch_mode 10from torch.utils._pytree import tree_map_only 11 12 13__all__ = ["trace_wrapped"] 14 15 16# trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist: 17# if you make_fx trace through this call, we will not actually trace into fn; instead, 18# we will directly insert it as a call_function to fn in the graph. 19# (Unlike make_fx, Dynamo WILL inline into fn.) 20# You can think of this as a one off allow_in_graph equivalent for proxy tensor tracing. 21# 22# Because proxy tensor tracing does not actually run the function, there are 23# requirements on the behavior of fn. We are still figuring it out, but here is the current state: 24# 25# 1) fn SHOULD only take a single argument, which must be a tensor 26# 2) fn MUST return a new tensor with the same metadata as the original tensor 27# (e.g., zeros_like(input) is a permissible implementation of fn). 28# This is verified via an extra assert that is inserted into the traced graph. 29# 3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors 30# participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state) 31# These requirements stem from the requirement that we need to continue performing proxy tensor tracing, 32# which assumes accurate fake tensor metadata, without actually running fn. 33# In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns. 34# 35# Note that tensors / Python state are allowed to be mutated. 36# This is relaxed constraint is not always sound, but it is sound for backward tracing with fake 37# tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete 38# tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python). 39# 40# The intended use case for this function is to allow AOTAutograd to defer complex 41# backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves 42# the function call as is in the graph, and only when we Dynamo through the backward graph in 43# compiled autograd do we inline into the function. 44 45 46def trace_wrapped(*args, **kwargs): 47 with torch.no_grad(): 48 return _trace_wrapped_op(*args, **kwargs) 49 50 51class TraceWrapped(HigherOrderOperator): 52 def __init__(self): 53 super().__init__("trace_wrapped") 54 55 def __call__(self, *args, **kwargs): 56 return super().__call__(*args, **kwargs) 57 58 59# TODO(jansel): need to ensure this does not get DCEed 60_trace_wrapped_op = TraceWrapped() 61 62 63def _assert_meta(grad, size, stride, dtype): 64 assert grad.size() == size, "size mismatch" 65 assert grad.stride() == stride, "stride mismatch" 66 assert grad.dtype == dtype, "dtype mismatch" 67 return grad 68 69 70@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode) 71def inner_trace(mode, *args, bw_state=None, **kwargs): 72 def self_invoke(*args, **dyn_kwargs): 73 with torch.no_grad(): 74 return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs) 75 76 def unwrap_proxies(x): 77 if isinstance(x, torch.Tensor): 78 return mode.tracer.unwrap_proxy(x) 79 if isinstance(x, (list, tuple)): 80 return type(x)(map(unwrap_proxies, x)) 81 if x is None: 82 return None 83 raise AssertionError(f"unhandled type: {type(x)}") 84 85 proxy_kwargs = {} 86 if bw_state is not None: 87 assert isinstance(bw_state, BackwardState) and bw_state.proxy is not None 88 proxy_kwargs["bw_state"] = bw_state.proxy 89 out_proxy = mode.tracer.create_proxy( 90 "call_function", 91 self_invoke, 92 unwrap_proxies(args), 93 proxy_kwargs, 94 name="trace_wrapped", 95 ) 96 97 if args[0] is None: 98 grad = args[1] # module backward hooks 99 else: 100 grad = args[0] # other backward hooks 101 grad = tree_map_only(torch.Tensor, torch.empty_like, grad) 102 track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer) 103 return grad 104 105 106@_trace_wrapped_op.py_impl(FakeTensorMode) 107def inner_fake(*args, **kwargs): 108 raise RuntimeError("This op should never be invoked here") 109 110 111@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd) 112def _trace_wrapped_op_dense(*args, fn, **kwargs): 113 mode = _get_current_dispatch_mode() 114 assert mode is None, "Mode should never be enabled for CPU/CUDA key" 115 return fn(*args, **kwargs) 116 117 118_trace_wrapped_op.py_impl(DispatchKey.Autograd)( 119 autograd_not_implemented(_trace_wrapped_op, deferred_error=True) 120) 121 122 123@_trace_wrapped_op.py_functionalize_impl 124def _trace_wrapped_functionalized(ctx, *args, **kwargs): 125 unwrapped_args = ctx.unwrap_tensors(args) 126 with ctx.redispatch_to_next(): 127 return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs)) 128