# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-strict import copy import json import traceback from contextlib import contextmanager from dataclasses import asdict, dataclass from typing import ( Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, ) import executorch.extension.pytree as ex_pytree import torch import torch._dynamo as torchdynamo import torch.fx as fx import torch.fx._pytree as fx_pytree import torch.utils._pytree as pytree from executorch.exir.common import ( extract_out_arguments, format_schema_name, no_dispatch, setting_python_recursive_limit, ) from executorch.exir.error import ExportError, ExportErrorType, InternalError from executorch.exir.graph_module import LeafValue from executorch.exir.operator.convert import is_out_variant from executorch.exir.types import ValueSpec from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass # @manual from torch._decomp import get_decompositions from torch._dynamo.guards import Guard from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor from torch.export import default_decompositions from torch.func import functionalize from torch.fx.operator_schemas import normalize_function from torch.utils._pytree import TreeSpec from typing_extensions import TypeAlias Value: TypeAlias = Union[ LeafValue, Tuple["Value", ...], List["Value"], Dict[str, "Value"], ] torchdynamo_enabled = False def get_stacktrace() -> List[Dict[str, str]]: """ Get the current stacktrace (between trace() and __torch_dispatch__()) Include the filename, function name, line number, and source code from the start of the function to the given instruction. Return: A list of stacktraces for each instruction along with the source code context surrounding each instruction """ stacktrace = traceback.extract_stack() # The stacktrace typically looks like this: # # 1. I stack frames from the top level runner (e.g., the # test suite runner) # 2. J frames in executorch/exir/tracer.py setting up tracing # (call this INIT_EXIR) # 3. K frames in user model code (this is what we want to save!) # 4. 1 frame in executorch/exir/tracer.py __torch_function__ # returning to tracer (call this TRACE_EXIR) # 5. H frames in executorch/exir/tracer.py AND torch/_tensor.py # doing all of the internal tracer handling # # The PyE tests assert that executorch/exir/tracer.py never shows # up in the user provided stack traces, so we must oblige them. # # Assumptions: # - Reentrant tracing is not a thing. Thus, the first time # executorch/exir/tracer.py shows up in the trace, we know # THAT is the point at which we start tracing. (An alternative # is that the tracer entry point could record the stack trace # at this time, but I didn't do this.) # # Our plan is to do a miniature stack machine traversing these # stack machines. # Remove parts before the trace function and parts after entering # __torch_dispatch__. Defaults to returning the entire stack trace. init_exir_end = 0 trace_exir_start = None # A miniature state machine, referring to the frame segments described # above. The locations are closed-open interval. FIND_INIT_EXIR_START, FIND_INIT_EXIR_END, FIND_TRACE_EXIR_START = range(3) state = FIND_INIT_EXIR_START for i, frame in enumerate(stacktrace): if state == FIND_INIT_EXIR_START: if "executorch/exir/tracer.py" in frame.filename: state = FIND_INIT_EXIR_END elif state == FIND_INIT_EXIR_END: if "executorch/exir/tracer.py" not in frame.filename: init_exir_end = i state = FIND_TRACE_EXIR_START elif state == FIND_TRACE_EXIR_START: if "executorch/exir/tracer.py" in frame.filename: trace_exir_start = i break stacktrace = stacktrace[init_exir_end:trace_exir_start] # Get the source code from the errored line to it contexts: List[str] = [] for s in stacktrace: try: with open(s.filename) as file: # pyre-fixme[6]: For 1st param expected `Union[SupportsTrunc, bytes, # str, SupportsInt, SupportsIndex]` but got `Optional[int]`. lineno = int(s.lineno) # Get the source code 5 lines above/below the current instruction file_contents = [ str(index + 1) + line for index, line in enumerate(file.readlines()) ] file_contents_above = "".join( file_contents[max(lineno - 5, 0) : lineno] ) file_contents_below = "".join( file_contents[lineno : min(lineno + 5, len(file_contents))] ) context = ( file_contents_above + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + file_contents_below ) contexts.append(context) except FileNotFoundError: contexts.append("") # torch.fx stack preservation logic expects strings to # be passed around. Working with dictionary is lot easier # to convert to string and vice versa. frames: List[Dict[str, str]] = [] for i, frame in enumerate(stacktrace): frames.append( { "filename": str(frame.filename), "lineno": str(frame.lineno), "name": str(frame.name), "line": str(frame.line), "context": contexts[i], } ) return frames def unwrap_functional(t: torch.Tensor) -> torch.Tensor: assert isinstance(t, torch.Tensor) return _maybe_unwrap_functional_tensor(t, reapply_views=False) def unwrap_proxy(t: LeafValue) -> Union[LeafValue, torch.fx.Proxy]: if not isinstance(t, torch.Tensor): return t t = unwrap_functional(t) return t.proxy if isinstance(t, PythonTensor) else t def single_return( output: LeafValue, proxy: torch.fx.Proxy, wrapper: Callable[..., LeafValue], ) -> LeafValue: if isinstance(output, torch.Tensor): return wrapper(output, proxy) return output def tree_return( outputs: Value, proxy: torch.fx.Proxy, wrapper: Callable[..., LeafValue], meta_type: Callable[..., Iterable[ValueSpec]] = tuple, ) -> Value: i: int = 0 def wrap(o: LeafValue) -> LeafValue: nonlocal i ret = single_return(o, proxy[i], wrapper) i += 1 return ret return pytree.tree_map(wrap, outputs) class DummyProxy: def __init__(self) -> None: class DummyNode: def __init__(self): self.meta = {} self.node = DummyNode() def __getitem__(self, key: str) -> "DummyProxy": return DummyProxy() class PythonTensor(torch.Tensor): """ A wrapper tensor subclass used in the DispatchTracer to keep track of proxies to construct the FX graph. Wrapping something in PythonTensor implicitly detaches gradients. If something required grad, we will collect it as if it were a leaf. A consequence of detaching in this way is you need to maintain a parameter cache when translating tensors into PythonTensor, so you don't create multiple copies of a gradient (they are aliased, but they would count as independent leaves). An alternate strategy would be to avoid implicitly detaching and instead "catch" gradients as they exit the PythonTensor boundary. """ __slots__ = ["proxy", "is_immutable"] @staticmethod def __new__( cls, elem: torch.Tensor, proxy: torch.fx.Proxy, is_immutable: bool = False ) -> torch.Tensor: # assert not elem.requires_grad or not torch.is_grad_enabled() r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) assert isinstance(r, PythonTensor) r.is_immutable: bool = is_immutable r.update_proxy(proxy) return r def update_proxy(self, proxy: torch.fx.Proxy) -> None: self.proxy = proxy def __repr__(self, *, tensor_contents: None = None) -> str: with no_dispatch(): return f"PythonTensor({self.as_subclass(torch.Tensor)})" @classmethod def __torch_function__( cls, # pyre-ignore: Missing parameter annotation [2] func, # pyre-ignore: Missing parameter annotation [2] types, args: Tuple[Value, ...] = (), kwargs: Optional[Dict[str, Value]] = None, ) -> Value: if kwargs is None: kwargs = {} if torch.is_inference_mode_enabled(): if func is torch.nn.functional.layer_norm: args, kwargs = normalize_function(func, args, kwargs) # pyre-fixme[23] input, normalized_shape = args normalized_shape = list(normalized_shape) return cls.__torch_dispatch__( torch.ops.aten.layer_norm.default, types, (input, normalized_shape), kwargs, ) elif func is torch.nn.functional.linear: return cls.__torch_dispatch__( torch.ops.aten.linear.default, types, args, kwargs ) with DisableTorchFunctionSubclass(): return func(*args, **kwargs) @classmethod def __torch_dispatch__( # noqa: C901 cls, func_overload: torch._ops.OpOverload, # pyre-ignore: Missing parameter annotation [2] types, args: Tuple[Value, ...] = (), kwargs: Optional[Dict[str, Value]] = None, ) -> Value: """ This function is invoked every time an aten operation is called. Args: func_overload: The function that was called that invoked this torch_dispatch call types: args: Arguments that were passed into the function. Each argument has type PythonTensor. kwargs: Keyword arguments that were passed into the function. Each argument has type PythonTensor. """ func = func_overload.overloadpacket kwargs = kwargs or {} if is_out_variant(func._qualified_op_name, func_overload._overloadname): out_args = extract_out_arguments(func_overload._schema, kwargs) out_args_iter = [out_args] if not isinstance(out_args, list) else out_args for out_arg_name, out_arg_val in out_args_iter: if isinstance(out_arg_val, PythonTensor) and out_arg_val.is_immutable: raise RuntimeError( "Immutable tensor `{}` is potentially getting modified by {}".format( out_arg_name, format_schema_name(func_overload._schema) ) ) # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`. proxy_args = ex_pytree.tree_map(unwrap_proxy, args) # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`. proxy_kwargs = ex_pytree.tree_map(unwrap_proxy, kwargs) # Get the output of the function g = _EnableTorchFunction() try: proxy_out = ( func_overload(*proxy_args, **proxy_kwargs) if DispatchTracer.get() or torchdynamo_enabled # Disable node creation when no tracer is active. else DummyProxy() ) finally: del g with no_dispatch(): real_out = func_overload(*args, **kwargs) # Kind of a hacky way to test if an op is in-place or not if func.__name__[-1] == "_" and func.__name__[0] != "_": if isinstance(args[0], PythonTensor): args[0].proxy = proxy_out if not torch.fx.traceback.has_preserved_node_meta(): proxy_out.node.meta["stack_trace"] = json.dumps(get_stacktrace()) # Wrap the output tensors with the PythonTensor subclass to propagate to # future tracing def wrap_with_proxy(e: LeafValue, proxy: torch.fx.Proxy) -> LeafValue: # Some ops (like native_batch_norm_backward) return undefined tensors that get # converted into None in python. # As the function signature expects tensors, if we directly return these None # tensors back to C++, we'll error. if e is None: e = torch.empty(()) if isinstance(e, torch.Tensor): return PythonTensor(e, proxy) # Inplace and out-variant ops may return one of their arguments, which is already # a PythonTensor. In this case, we need to update the PythonTensor's associated # proxy to the newly created proxy. if isinstance(e, PythonTensor): e.update_proxy(proxy) return e return e retval = None if not isinstance(real_out, (list, tuple)): retval = single_return(real_out, proxy_out, wrap_with_proxy) else: retval = tree_return(real_out, proxy_out, wrap_with_proxy, type(real_out)) return retval @contextmanager def using_tracer(tracer: Optional["DispatchTracer"]) -> Generator[None, None, None]: """ Set the "current" global tracer within the scope of using_tracer context manager. Since various things we want to capture today with torch_dispatch does not "trap" into dispatcher really (for example, cond() and shape()), we need a separate singleton tracer exposed to user space in addition to Dispatcher to trigger graph capturing. """ global TRACER TRACER, prev = tracer, TRACER try: yield finally: TRACER = prev class DispatchTracer(fx.Tracer): def __init__(self) -> None: super().__init__() self.root: torch.nn.Module = torch.nn.Module() self.tensor_attrs: Dict[torch.Tensor, str] = {} self.submodules: Dict[fx.GraphModule, str] = {} def call_module( self, m: torch.nn.Module, forward: Callable[..., Value], args: Tuple[Value, ...], kwargs: Dict[str, Value], ) -> Value: return forward(*args, **kwargs) def _module_getattr( self, attr: str, attr_val: Value, parameter_proxy_cache: Dict[str, torch.Tensor] ) -> Value: if isinstance(attr_val, torch.nn.Parameter): for n, p in self.root.named_parameters(): if attr_val is p: if n not in parameter_proxy_cache: proxy = self.create_proxy("get_attr", n, (), {}) parameter_proxy_cache[n] = PythonTensor(attr_val, proxy) return parameter_proxy_cache[n] return attr_val return attr_val def create_arg(self, a: Value) -> torch.fx.Node: # noqa: C901 if isinstance(a, torch.nn.Parameter): for n, p in self.root.named_parameters(): if a is p: return self.create_node("get_attr", n, (), {}) qualname: Optional[str] = None if not qualname: i = 0 while True: qualname = f"_param_constant{i}" if not hasattr(self.root, qualname): break i += 1 setattr(self.root, qualname, a) return self.create_node("get_attr", qualname, (), {}) if isinstance(a, torch.Tensor): qualname: Optional[str] = self.tensor_attrs.get(a) if not qualname: i = 0 while True: qualname = f"_tensor_constant{i}" if not hasattr(self.root, qualname): break i += 1 self.tensor_attrs[a] = qualname self.root.register_buffer(qualname, a) return self.create_node("get_attr", qualname, (), {}) # higher-order operator if isinstance(a, fx.GraphModule): if a not in self.submodules: name_submodule = f"submodule_{len(self.submodules)}" self.root.add_module(name_submodule, a) self.submodules[a] = name_submodule return self.create_node("get_attr", self.submodules[a], (), {}) return super().create_arg(a) # pyre-fixme[7] @staticmethod def get() -> "DispatchTracer": return TRACER def trace( # pyre-fixme[14,15] self, root: Callable[..., Value], concrete_args: Tuple[Value, ...] = (), in_spec: Optional[TreeSpec] = None, ) -> Value: """ Traces the given graph module. """ with using_tracer(self): return self._trace(root, concrete_args=concrete_args, in_spec=in_spec) def _trace( self, root: Callable[..., Value], concrete_args: Tuple[Value, ...], in_spec: Optional[TreeSpec], ) -> Value: self.root = torch.nn.Module() root_fn = root tracer_cls = getattr(self, "__class__", None) self.graph = fx.Graph(tracer_cls=tracer_cls) # Don't support module, so tensor_attrs is always empty self.tensor_attrs = {} # Wrap all inputs as a PythonTensor subclass and insert them into the FX # graph as placeholder nodes def wrap(arg: Value, i: int) -> Value: placeholder = self.create_proxy("placeholder", f"ph_{i}", (), {}) if isinstance(arg, torch.Tensor): return PythonTensor(arg, placeholder, is_immutable=True) else: # torch._assert( # placeholder == arg, # f"ph_{i} has been specialized to have value {arg}", # ) return arg tree_args = [wrap(arg, i) for i, arg in enumerate(concrete_args)] if in_spec: tree_args = pytree.tree_unflatten(tree_args, in_spec) tree_out = root_fn(*tree_args) out_args, _ = pytree.tree_flatten(tree_out) def unwrap(out: LeafValue) -> Union[LeafValue, torch.fx.Proxy]: # it's legit for a model to return a list of items some of which # are None if out is None: return None if not isinstance(out, torch.Tensor): raise TypeError( f"Expect model to return torch.Tensor, got type: '{type(out)}' (value: {out})." ) return unwrap_proxy(out) returns = [unwrap(out) for out in out_args] return_annotation = None # some ops like torch.sub doesn't have annotations if hasattr(root_fn, "__annotations__"): return_annotation = root_fn.__annotations__.get("return", None) self.create_proxy( "output", "output", (returns,), {}, type_expr=return_annotation, ) self.submodule_paths = None return tree_out TRACER: Optional[DispatchTracer] = None TORCHDYNAMO_ENABLED: bool = False @contextmanager def using_dynamo(val: bool) -> Generator[None, None, None]: global TORCHDYNAMO_ENABLED TORCHDYNAMO_ENABLED, prev = val, TORCHDYNAMO_ENABLED try: yield finally: TORCHDYNAMO_ENABLED = prev def flattened_dispatch_trace( f: Callable[..., Value], args: Tuple[LeafValue, ...], guards: Set[Guard], in_spec: Optional[TreeSpec] = None, enable_functionalization: bool = True, ) -> Tuple[torch.fx.GraphModule, Value]: if not isinstance(args, tuple): raise TypeError(f"Expecting 'args' to be a tuple, got: {type(args)}") tracer = DispatchTracer() if enable_functionalization: f = functionalize(f, remove="mutations_and_views") tree_out = tracer.trace(f, concrete_args=args, in_spec=in_spec) name = type(f).__name__ if isinstance(f, torch.nn.Module) else f.__name__ gm = torch.fx.GraphModule(tracer.root, tracer.graph, name) return (gm, tree_out) @dataclass class ExirDynamoConfig: """ Manage Exir-specific configurations of Dynamo. """ allow_rnn: bool = True verbose: bool = True assume_static_by_default: bool = False def flatten_output(gm: torch.fx.GraphModule) -> None: """ Modifies the output nodes in the submodules to return the result as a flattened list. This keeps it consistent with the result of EXIR's tracer """ for node in reversed(gm.graph.nodes): if node.op == "output": assert len(node.args) == 1 outputs = node.args[0] returns, _ = pytree.tree_flatten(outputs) node.args = (returns,) return raise RuntimeError(f"Could not find an output node in {gm.graph}") def _default_decomposition_table( _use_old_decomp_table=False, ) -> Dict[torch._ops.OpOverload, Callable[..., Value]]: if _use_old_decomp_table: decomp_opset = [ torch.ops.aten.log_sigmoid_forward, torch.ops.aten.ones, torch.ops.aten.arange.default, torch.ops.aten.arange.start, torch.ops.aten.transpose, ] # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e... return get_decompositions(decomp_opset) # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir.... return default_decompositions() def dynamo_trace( f: Callable[..., Value], # pyre-ignore args: Tuple[Any, ...], aten_graph: bool, tracing_mode: str = "real", dynamo_config: Optional[ExirDynamoConfig] = None, # pyre-ignore dynamic_shapes: Optional[List[Any]] = None, _use_old_decomp_table: bool = False, ) -> Tuple[torch.fx.GraphModule, Set[Guard]]: """ TODO: Once we fully migrate to torchdynamo frontend, we will remove this config option alltogether. For now, it helps with quick experiments with playing around with TorchDynamo """ if dynamo_config is None: dynamo_config = ExirDynamoConfig() with torchdynamo.config.patch( asdict(dynamo_config) ), setting_python_recursive_limit(2000): torchdynamo.reset() try: # TODO merge executorch functionalization with official # functionalization # pyre-fixme[7]: Expected `Tuple[GraphModule, Set[Guard]]` but got # `ExportResult`. return torchdynamo.export( f, aten_graph=aten_graph, tracing_mode=tracing_mode, assume_static_by_default=dynamo_config.assume_static_by_default, decomposition_table=( _default_decomposition_table(_use_old_decomp_table) if aten_graph else None ), dynamic_shapes=dynamic_shapes, )( *copy.deepcopy(args), ) except torchdynamo.exc.Unsupported as exc: raise ExportError( ExportErrorType.NOT_SUPPORTED, "The user code is using a feature we don't support. " "Please try torchdynamo.explain() to get possible the reasons", ) from exc except Exception as exc: raise InternalError( "torchdynamo internal error occured. Please see above stacktrace" ) from exc def dispatch_trace( f: Callable[..., Value], args: Tuple[Value, ...], ) -> torch.fx.GraphModule: """ Executes a given callable `f` with a given tuple of arguments. During execution, Tensor operations are recorded in a fx.GraphModule, which is then returned. Args: f: A `nn.Module` or a Python function that implements an ML program. args: A tuple of arguments of any type to be used as inputs for the tracing run. Returns: EXIR contained in a fx.GraphModule """ trace_func = f guards = set() if TORCHDYNAMO_ENABLED: # Copying args is safer in case downstream implementations of trace_func mutate them trace_func, guards = dynamo_trace(trace_func, args, False) # Copying args is safer in case downstream implementations of trace_func mutate them trace_args, in_spec = pytree.tree_flatten(args) in_args = copy.deepcopy(tuple(trace_args)) gm, tree_out = flattened_dispatch_trace( trace_func, in_args, guards, in_spec, enable_functionalization=False, ) _, out_spec = pytree.tree_flatten(tree_out) # pyre-fixme[16]: `GraphModule` has no attribute `in_spec`. gm.in_spec = in_spec # pyre-fixme[16]: `GraphModule` has no attribute `out_spec`. gm.out_spec = out_spec # TODO (tmanlaibaatar) This is bit clowny, but our # dispatch_trace sometimes creates unused node that # breaks functionalization. it seems too much trouble # to fix it properly since dispatch_trace will be deprecated soon. # Basically dispatch_trace struggles on: # def f(x: torch.Tensor) -> torch.Tensor: # return torch.ones(6, dtype=x.dtype) changed = gm.graph.eliminate_dead_code() if changed: gm.recompile() in_args = copy.deepcopy(tuple(trace_args)) assert callable(gm) # This wrapper is used for preserving the stacktrace # during second round of tracing. # pyre-ignore def graph_with_interpreter(*args): try: args = fx_pytree.tree_flatten_spec(args, gm.in_spec) # type: ignore[assignment] except Exception: _, received_spec = pytree.tree_flatten(args) raise RuntimeError( "Trying to flatten user inputs with exported input tree spec: \n" f"{gm.in_spec}\n" "but actually got inputs with tree spec of: \n" f"{received_spec}" ) with torch.fx.traceback.preserve_node_meta(): res = gm(*args) if gm.out_spec is not None: try: res = pytree.tree_unflatten(res, gm.out_spec) except Exception: _, received_spec = pytree.tree_flatten(res) raise RuntimeError( "Trying to flatten user outputs with exported output tree spec: \n" f"{gm.out_spec}\n" "but actually got outputs with tree spec of: \n" f"{received_spec}" ) return res gm, tree_out = flattened_dispatch_trace( graph_with_interpreter, in_args, guards, in_spec, enable_functionalization=True, ) gm.in_spec = in_spec gm.out_spec = out_spec return gm