# mypy: ignore-errors

import functools
import inspect
import logging
import operator
import textwrap
import traceback
import types
import unittest
from typing import Dict, List, TYPE_CHECKING

import sympy

import torch._numpy as tnp
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import (
    guard_scalar,
    GuardOnDataDependentSymNode,
    has_free_symbols,
    is_symbolic,
    SymTypes,
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented, UserError, UserErrorType
from ..external_utils import call_hook_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import (
    fqn,
    get_custom_getattr,
    get_fake_value,
    get_real_value,
    guard_if_dyn,
    object_has_getattribute,
    product,
    proxy_args_kwargs,
    set_example_value,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

log = logging.getLogger(__name__)

# Ops that allow tensor <op> tensor comparisons
supported_tensor_comparison_ops = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
}
# Ops that allow tensor <op> None (constant/identity) comparisons
supported_const_comparison_ops = {
    "is": operator.is_,
    "is not": operator.is_not,
    "==": operator.eq,
    "!=": operator.ne,
}
supported_comparison_ops = {
    **supported_tensor_comparison_ops,
    **supported_const_comparison_ops,
}
supported_tensor_comparison_op_values = dict.fromkeys(
    supported_tensor_comparison_ops.values()
)
supported_const_comparison_op_values = dict.fromkeys(
    supported_const_comparison_ops.values()
)


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = {
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
        "is_sparse",
        "class_type",
        "specialized_value",
        "_is_name_set",
        *VariableTracker._nonvar_fields,
    }

    def get_real_value(self):
        """
        Get the actual value represented by this variable if the computation is
        run using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be slow and
        memory-intensive.
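
        A minimal sketch of how this is used (illustrative only; ``var`` stands
        in for a TensorVariable captured during tracing, and this is not a
        public API):

            real = var.get_real_value()  # re-executes the traced ops on the real inputs
            assert isinstance(real, torch.Tensor)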
""" return get_real_value(self.proxy.node, self.proxy.tracer) def __init__( self, proxy: torch.fx.Proxy, *, dtype, device, layout, ndim, requires_grad, is_quantized, is_sparse, class_type, has_grad_fn, size=None, stride=None, is_contiguous=None, _is_name_set=None, **kwargs, ) -> None: super().__init__(**kwargs) self.proxy = proxy self.dtype = dtype self.device = device self.layout = layout self.ndim = ndim self.size = size self.stride = stride self.requires_grad = requires_grad self.is_quantized = is_quantized self.is_contiguous = is_contiguous self.is_sparse = is_sparse self.class_type = class_type self.has_grad_fn = has_grad_fn if _is_name_set is None: # no need to rename inputs _is_name_set = self.proxy.node.op == "placeholder" self._is_name_set: bool = _is_name_set def debug_repr(self): # TODO: strip off fake tensor from repr here return repr(self.proxy.node.meta["example_value"]) def as_proxy(self): return self.proxy def python_type(self): return self.class_type @staticmethod def specialize(value: torch.Tensor): props = { "dtype": value.dtype, "device": value.device, "layout": value.layout, "ndim": int(value.ndim), "requires_grad": value.requires_grad, "is_quantized": value.is_quantized, "is_sparse": value.is_sparse, "class_type": type(value), } try: props["has_grad_fn"] = value.grad_fn is not None except Exception: # Workaround for issues with create_parameter_op in Dynamo. Reading # grad_fn should never cause an issue. props["has_grad_fn"] = False if is_sparse_any(value) and not has_free_symbols(value): props["size"] = tuple( [int(s) if is_symbolic(s) else s for s in value.size()] ) elif not has_free_symbols(value): # this is a fully static shape, and the keys on props here inform specialization. # We have to cast to int here, because these might get accessed as ConstantVariable, which has # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and # I'd like to keep it around for now. 
props["size"] = tuple( # the non is_symbolic case applies to the jagged layout # NestedTensor case as singleton ints are not symbolic [int(s) if is_symbolic(s) else s for s in value.size()] ) props["stride"] = tuple(value.stride()) if torch._C._functorch.is_batchedtensor(value): # Batched tensors does not support contiguity patterns, so # we refrain from computing the `is_contiguous` property props["is_contiguous"] = None else: props["is_contiguous"] = tuple( [ x for x in torch._prims_common._memory_formats if value.is_contiguous(memory_format=x) ] ) return props def dynamic_getattr(self, tx: "InstructionTranslator", name): fake_val = self.proxy.node.meta["example_value"] # For getattrs on tensors without sources, # we can do better than the default (creating a GetAttrVariable) # if: # (1) the tensor is a traceable tensor subclass # (2) We are getattr'ing an inner tensor from that subclass if not self.source and is_traceable_wrapper_subclass(fake_val): fake_val = self.proxy.node.meta["example_value"] attrs, ctx = fake_val.__tensor_flatten__() proxy = getattr(self.as_proxy(), name) example_value = getattr(fake_val, name) if name in attrs: # attrs returned from tensor_flatten are always tensors assert isinstance(example_value, torch.Tensor) from .builder import wrap_fx_proxy return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value) # any other attributes on the subclass (that are not methods) # are assumed to be constant metadata. elif not callable(example_value): from .builder import SourcelessBuilder return SourcelessBuilder.create(tx, example_value) if not (self.source and self.source.subguards_allowed()): raise NotImplementedError # For local source, we associate the real value. We use this real value # for implementing getattr fallthrough on the variable tracker base class. # Note - this scope construction is mirrored in guards # A subsequent PR will introduce a util. scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} try: # We raise in case we get a typerror bug w/ SuperSource. # SuperSource has bugs in it atm, and can produce code like # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__, # L['mod'].model.model.encoder.embed_positions)", scope) # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope. _input_associated_real_value = eval(self.source.name(), scope) except Exception as exc: raise NotImplementedError from exc if _input_associated_real_value is None: raise NotImplementedError if object_has_getattribute(_input_associated_real_value): raise NotImplementedError if get_custom_getattr(_input_associated_real_value): raise NotImplementedError real_value = getattr(_input_associated_real_value, name) if callable(real_value): # Callables have more nuanced handling, and we should let the existing system delegate here. # Raising was past behavior and so should always be sound to fall back. 
            # Note - at a certain point we may want to handle
            raise NotImplementedError
        from ..guards import GuardBuilder
        from .builder import VariableBuilder

        attr_source = AttrSource(self.source, name)
        install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
        return VariableBuilder(tx, attr_source)(real_value)

    def method_attr_ndim(self, tx):
        if self.ndim is not None:
            return ConstantVariable.create(self.ndim)
        else:
            return self.call_method(tx, "dim", [], {})

    def method_attr_dtype(self, tx):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype)

    def method_attr_device(self, tx):
        if self.device is not None:
            return ConstantVariable.create(self.device)

    def method_attr_layout(self, tx):
        if self.layout is not None:
            return ConstantVariable.create(self.layout)

    def method_attr_is_cuda(self, tx):
        if self.device is not None:
            return ConstantVariable.create(self.device.type == "cuda")

    def method_attr_shape(self, tx):
        if self.size is not None:
            sizes = [variables.ConstantVariable.create(x) for x in self.size]
            return SizeVariable(sizes)
        else:
            return self.call_method(tx, "size", [], {})

    def method_attr_requires_grad(self, tx):
        if self.requires_grad is not None:
            return ConstantVariable.create(self.requires_grad)

    def method_attr_is_quantized(self, tx):
        if self.is_quantized is not None:
            return ConstantVariable.create(self.is_quantized)

    def method_attr_is_sparse(self, tx):
        if self.is_sparse is not None:
            return ConstantVariable.create(self.is_sparse)

    def method_attr_data(self, tx):
        return variables.TorchInGraphFunctionVariable(
            torch._C._autograd._get_data_attr
        ).call_function(tx, [self], {})

    def method_attr_grad_fn(self, tx):
        if self.has_grad_fn:
            unimplemented("TensorVariable has a grad_fn")
        else:
            return variables.ConstantVariable(None)

    def method_attr__version(self, tx):
        from ..tensor_version_op import _tensor_version

        return variables.TorchInGraphFunctionVariable(_tensor_version).call_function(
            tx, [self], {}
        )

    def call_hasattr(self, tx: "InstructionTranslator", name):
        from . import GetAttrVariable
        from .builtin import BuiltinVariable

        try:
            var = BuiltinVariable(getattr).call_function(
                tx, [self, ConstantVariable(name)], {}
            )
            # In the event that TensorVariable returns NotImplemented,
            # BuiltinVariable.call_getattr returns a GetAttrVariable.
            ret_val = not isinstance(var, GetAttrVariable)
        except AttributeError:
            ret_val = False

        if self.source:
            install_guard(
                AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
            )

        return ConstantVariable(ret_val)

    def var_getattr(self, tx: "InstructionTranslator", name):
        from . import UserDefinedClassVariable

        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal getattr invocation {name} in strict mode")

        if name == "__class__":
            return UserDefinedClassVariable(self.python_type())

        handler = getattr(self, f"method_attr_{name}", None)
        result = handler(tx) if handler is not None else None

        # Add a guard for type matching. These guards are checked before tensor
        # guards: in some cases, a guard on a tensor attribute can be evaluated
        # first and break if the tensor is later changed to another type.
        if (
            result is not None
            and self.source
            and self.source.subguards_allowed()
            and not (
                name not in ("grad", "requires_grad") and result.is_python_constant()
            )
        ):
            install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
            result.source = AttrSource(self.source, name)

        # It's hard to get an inplace view (metadata mutation) on a graph input to
        # work properly across dynamo/aot/inductor, so just fall back.
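        # A hedged illustration of the check below (assuming the current tag
        # layout of torch.ops; in-place view ops such as aten.unsqueeze_ are
        # expected to carry torch.Tag.inplace_view):
        #
        #     op = torch.ops.aten.unsqueeze_
        #     overload = getattr(op, op.overloads()[0])
        #     torch.Tag.inplace_view in overload.tags  # -> True for in-place view ops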
if self.source is not None and hasattr(torch.ops.aten, name): fn = getattr(torch.ops.aten, name) if ( hasattr(fn, "overloads") and hasattr(fn, fn.overloads()[0]) and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags ): # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc. return variables.misc.DelayGraphBreakVariable( source=AttrSource(self.source, name) ) # For attributes (not methods) that were not caught in the special handling above, # (e.g. tensor.real), we handle these generically, assuming that the output type is # a tensor. if result is None and name != "grad": def try_generic_attr_handling(): from .builder import wrap_fx_proxy from .misc import GetAttrVariable try: static_attr = inspect.getattr_static(torch.Tensor, name) except AttributeError: return None # Make sure this is an attribute, not a method. # type(torch.Tensor.H) should be "getset_descriptor" # This is a because of CPython implementation, see THPVariableType: # these attributes are implemented under tp_getset, which appear # as `getset_descriptor`s, (compared to, say, methods which appear # as `method_descriptor`s) if type(static_attr) != types.GetSetDescriptorType: return None proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) if self.source is not None: return wrap_fx_proxy( tx=tx, proxy=proxy, source=AttrSource(self.source, name) ) else: return wrap_fx_proxy(tx=tx, proxy=proxy) result = try_generic_attr_handling() if result is None: result = self.dynamic_getattr(tx, name) if result is None: raise NotImplementedError return result def call_id(self, tx): if not self.source: unimplemented("call_id not supported for sourceless TensorVariable") # For local source, we associate the real value. We use this real value scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} try: _input_associated_real_value = eval(self.source.name(), scope) except Exception as exc: unimplemented(f"error getting associated real value: {exc}") if _input_associated_real_value is None: unimplemented("call_id without associated real value") install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) id_value = id(_input_associated_real_value) return ConstantVariable.create(id_value) def has_unpack_var_sequence(self, tx): return self.ndim > 0 def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None): from .builder import wrap_fx_proxy_cls if self.size: size_len = len(self.size) else: size_var = self.call_method(tx, "size", [], {}) assert isinstance(size_var, SizeVariable) size_len = len(size_var.items) # Ensure we don't unpack a scalar tensor. assert size_len != 0, "Can't unpack scalar tensors." if self.size: length = self.size[0] else: dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {}) # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through # symbolic_shapes, but that end up as int/sympy.Integer assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable)) if isinstance(dyn_length, SymNodeVariable): length = dyn_length.evaluate_expr(tx.output) else: length = dyn_length.value if idxes is None: idxes = range(length) else: assert ( len(idxes) == length ), f"Can't unpack a tensor of {length} rows into a tuple of {len(idxes)} elements." 
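        # Illustrative effect (a sketch, not executed here): unpacking `a, b = t`
        # for a tensor of size (2, ...) emits one indexing proxy per row, roughly:
        #
        #     a_var = wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[0])
        #     b_var = wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[1])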
return [ wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i]) for i in idxes ] def _strict_mode_banned_ops(self): return torch._dynamo.config._autograd_backward_strict_mode_banned_ops def call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops(): unimplemented(f"Illegal method invocation {name} in strict mode") """ Dispatch to a method-specific handler defined below. If the handler returns None (or doesn't exist) we put the method call in the graph. """ try: handler_method = getattr(self, f"method_{name}") except AttributeError: pass else: try: result = handler_method(*args, **kwargs) if result: return result except TypeError as e: unimplemented(f"unhandled args for {name}: {e}") from .builder import wrap_fx_proxy return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", name, *proxy_args_kwargs([self, *args], kwargs), ), ) def method_size(self, *args, **kwargs): return self._method_size_stride("size", *args, **kwargs) def method_stride(self, *args, **kwargs): return self._method_size_stride("stride", *args, **kwargs) def _method_size_stride(self, name, dim=None): dim = guard_if_dyn(dim) def make_const_size_variable(x, **options): return SizeVariable( [ConstantVariable.create(y, **options) for y in x], **options ) RetVariable = ( make_const_size_variable if name == "size" else ConstantVariable.create ) # Technically, this should not be necessary, but I'm including it # for enhanced BC, in case example_value is sometimes not set # (it really should always be set though!) if (r := getattr(self, name)) is not None: if dim is None: return RetVariable(r) else: return ConstantVariable.create(r[dim]) # It might still be constant! Consult the fake tensor and see if (fake := self.proxy.node.meta.get("example_value")) is not None: if dim is None: fake_r = getattr(fake, name)() if not has_free_symbols(fake_r): # int conversion for safety, in case a SymInt refined # to constant return RetVariable(tuple(int(r) for r in fake_r)) else: fake_r = getattr(fake, name)(dim) if not has_free_symbols(fake_r): return ConstantVariable.create(int(fake_r)) def method_numel(self): if self.size is not None: return ConstantVariable.create(product(self.size)) # It might still be constant! 
Consult the fake tensor and see if (fake := self.proxy.node.meta.get("example_value")) is not None: fake_r = fake.numel() if not has_free_symbols(fake_r): return ConstantVariable.create(int(fake_r)) method_nelement = method_numel def method_dim(self): if self.ndim is not None: return ConstantVariable.create(self.ndim) method_ndimension = method_dim def method_is_floating_point(self): if self.dtype is not None: return ConstantVariable.create(self.dtype.is_floating_point) def method_is_complex(self): if self.dtype is not None: return ConstantVariable.create(self.dtype.is_complex) def method_is_contiguous(self, memory_format=None): memory_format = ( memory_format.as_python_constant() if memory_format is not None else torch.contiguous_format ) if self.is_contiguous is not None: return ConstantVariable.create(memory_format in self.is_contiguous) elif (fake := self.proxy.node.meta.get("example_value")) is not None: return ConstantVariable.create( fake.is_contiguous(memory_format=memory_format) ) def method_type(self, dtype=None, non_blocking=False, **kwargs): if ( dtype is None and self.dtype is not None and isinstance(self.device, torch.device) ): tensortype = next( k for k, v in tensortype_to_dtype.items() if self.dtype in v ) if self.device.type == "cuda": return ConstantVariable.create(f"torch.cuda.{tensortype.__name__}") else: return ConstantVariable.create(f"torch.{tensortype.__name__}") elif ( dtype is not None and fqn(type(dtype.as_python_constant())) == "torch.tensortype" ): # torch.FloatTensor, etc. are all of type "torch.tensortype". # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type. # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args) tensor_type = dtype.as_python_constant() tensor_type_const = ConstantVariable.create(fqn(tensor_type)) from ..symbolic_convert import InstructionTranslator from .builder import wrap_fx_proxy tx = InstructionTranslator.current_tx() if non_blocking: kwargs = {"non_blocking": non_blocking, **kwargs} return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", "type", *proxy_args_kwargs([self, tensor_type_const], kwargs), ), ) def method_as_subclass(self, cls): if isinstance(cls, TensorSubclassVariable) and cls.source: from ..symbolic_convert import InstructionTranslator from .builder import VariableBuilder from .torch_function import TensorWithTFOverrideVariable tx = InstructionTranslator.current_tx() # [Note: __torch_function__] coerce this tensor variable into a TensorWithTFOverrideVariable # in eager, this is just a type change. This isn't sound if a __torch_function__ tensor subclass # defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call. # It is up to the user whether this is correct behavior or not. py_cls = cls.as_python_constant() torch_fn = VariableBuilder( tx, AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"), )(py_cls.__torch_function__.__func__) return TensorWithTFOverrideVariable.from_tensor_var( tx, self, py_cls, torch_fn ) def method_get_device(self): if isinstance(self.device, torch.device): index = self.device.index if self.device.type != "cpu" else -1 return ConstantVariable.create(index) def method_element_size(self): return ConstantVariable.create(self.dtype.itemsize) def method_numpy(self, *, force=False): if not config.trace_numpy: unimplemented("Tensor.numpy(). config.trace_numpy is False") if not np: unimplemented("Tensor.numpy(). 
NumPy is not available") if self.layout != torch.strided: raise TypeError( f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first" ) from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() # We don't check that the tensor is on CPU when force is False, as this # allows us to execute NumPy code on CUDA. Same for requires_grad=True if force and force.as_python_constant(): # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...) t = self.call_method(tx, "detach", [], {}) proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {}) else: # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable proxy = tx.output.create_proxy( "call_method", "view_as", *proxy_args_kwargs([self, self], {}) ) return NumpyNdarrayVariable.create(tx, proxy) def method_tolist(self): from ..symbolic_convert import InstructionTranslator from .builder import SourcelessBuilder tx = InstructionTranslator.current_tx() def tolist(tensor, sub_proxy): def wrap(i, sub_proxy): # Sigh, we forgot to gate this, so this data dependent is on # by default and is load bearing in CI with unittest.mock.patch.object( tx.fake_mode, "allow_scalar_outputs", True ): return SymNodeVariable.create( tx, sub_proxy.item(), ) if tensor.dtype not in [ torch.int8, torch.int16, torch.int32, torch.int64, ]: unimplemented("Input tensor for tolist must be an integer tensor") if tensor.dim() == 0: return wrap(tensor, sub_proxy) if tensor.dim() == 1: return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)] return [ tolist(sub_tensor, sub_proxy=sub_proxy[i]) for i, sub_tensor in enumerate(tensor) ] tensor = self.as_proxy().node.meta["example_value"] out = tolist(tensor, self.as_proxy()) return SourcelessBuilder.create(tx, out) def method_backward(self, *args, **kwargs): unimplemented("Tensor.backward") def method_data_ptr(self, *args, **kwargs): unimplemented("Tensor.data_ptr") def method_item(self, *args, **kwargs): if not config.capture_scalar_outputs: self._warn_capture_scalar_outputs() unimplemented("Tensor.item") @staticmethod @functools.lru_cache(None) def _warn_capture_scalar_outputs(): user_stack = torch._guards.TracingContext.extract_stack() user_stack_formatted = "".join(traceback.format_list(user_stack)) log.warning( textwrap.dedent( """\ Graph break from `Tensor.item()`, consider setting: torch._dynamo.config.capture_scalar_outputs = True or: env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 to include these operations in the captured graph. Graph break: from user code at: %s """ ), user_stack_formatted, ) def method___len__(self): from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() return self.call_method(tx, "size", [ConstantVariable.create(0)], {}) def method_addcmul_(self, tensor1, tensor2, *, value=None): from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() if value is not None: from .. 
import polyfills from .builder import SourcelessBuilder return tx.inline_user_function_return( SourcelessBuilder.create(tx, polyfills.addcmul_inplace), [self, tensor1, tensor2, value], {}, ) def method___setitem__(self, key, value): def has_bool_key(v): if isinstance(v, TensorVariable): return v.dtype in (torch.bool, torch.int8) elif isinstance(v, variables.TupleVariable): return any(has_bool_key(item) for item in v.items) else: return False if ( has_bool_key(key) and isinstance(value, TensorVariable) and value.requires_grad and torch.is_grad_enabled() ): unimplemented( "boolean masking setitem backwards, see https://github.com/pytorch/pytorch/issues/114123" ) from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() tx.output.create_proxy( "call_function", operator.setitem, *proxy_args_kwargs([self, key, value], {}), ) return ConstantVariable.create(None) def method_resize_(self, *args, **kwargs): unimplemented("Tensor.resize_") def method_resize_as_(self, *args, **kwargs): unimplemented("Tensor.resize_as_") def method_sparse_resize_(self, *args, **kwargs): unimplemented("Tensor.sparse_resize_") def method_sparse_resize_and_clear_(self, *args, **kwargs): unimplemented("Tensor.sparse_resize_and_clear_") def method_set_(self, *args, **kwargs): if len(args) > 1: # torch.Tensor.set_() has several overloads. # aten::set_.source_Tensor(Tensor) gets special handling # in AOTAutograd and functionalization, because it is the most common # overload and is used by FSDP. # graph-breaking on aten::set_source_Tensor_storage_offset for now, # unless we find that we need to make it work. unimplemented("Tensor.set_.source_Tensor_storage_offset") def method_add_(self, other, *, alpha=None): if alpha is not None: from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( tx, [other, alpha], {} ) return self.call_method(tx, "add_", [result], {}) def method_addcdiv_(self, tensor1, tensor2, *, value=None): from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() if value is not None: result = variables.TorchInGraphFunctionVariable(torch.div).call_function( tx, [tensor1, tensor2], {} ) result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( tx, [result, value], {} ) return self.call_method(tx, "add_", [result], {}) def method___contains__(self, arg): from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() # Rewrite __contains__ here so that downstream passes can trace through # without dealing with unbacked symbool. 
Roughly the code we translate is: # def __contains__(self, x): # return (x == self).any().item() result = variables.TorchInGraphFunctionVariable(torch.eq).call_function( tx, [self, arg], {} ) result = variables.TorchInGraphFunctionVariable(torch.any).call_function( tx, [result], {} ) return result.call_method(tx, "item", [], {}) def method_redistribute(self, *args, **kwargs): from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function # and rewrite args to have only proxyable args, then insert call_function args_as_value = [x.as_python_constant() for x in args] kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} def redistribute_fn_with_prim_types(x): return x.redistribute(*args_as_value, **kwargs_as_value) # attach the same function name for better debugging redistribute_fn_with_prim_types.__name__ = "prim_redistribute" from .builder import wrap_fx_proxy return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", redistribute_fn_with_prim_types, *proxy_args_kwargs([self], {}), ), ) def method_to_local(self, *args, **kwargs): from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function # and rewrite args to have only proxyable args, then insert call_function args_as_value = [x.as_python_constant() for x in args] kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} def to_local_fn_with_prim_types(x): return x.to_local(*args_as_value, **kwargs_as_value) # attach the same function name for better debugging to_local_fn_with_prim_types.__name__ = "prim_to_local" from .builder import wrap_fx_proxy return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", to_local_fn_with_prim_types, *proxy_args_kwargs([self], {}), ), ) def method_register_hook(self, *args, **kwargs): return self._method_register_hook("register_hook", *args, **kwargs) def method_register_post_accumulate_grad_hook(self, *args, **kwargs): return self._method_register_hook( "register_post_accumulate_grad_hook", *args, **kwargs ) def _method_register_hook(self, name: str, hook: VariableTracker): # Note - do not arbitrarily add hooks here - make sure they match the same contract # see [On tensor.register_hook] from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() if not self.source: if not compiled_autograd.compiled_autograd_enabled: # TODO(voz): # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary # python state. # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run # them in a compiled bwd without re-entering dynamo as compiled_autograd does. # # Discussion point 1 - Should we bypass this if nopython/fullgraph = True? # No. Because this was going to be a graph break anyway - this check does not # introduce new graph breaks where there were none. # # Discussion point 2 - Should we defer this check to backwards? # No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user # would have no recourse - their forward traces just fine, but will fail at backwards unless # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today) # then they have nothing they can do except disable compile. 
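                # Illustrative only (API shape may change; `compiler_fn` below is a
                # placeholder for a compiler callable): users typically opt in with
                #
                #     with torch._dynamo.compiled_autograd.enable(compiler_fn):
                #         loss.backward()
                #
                # which flips compiled_autograd_enabled and lets this path trace the
                # intermediate hook instead of graph breaking here.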
unimplemented( "Compilation of intermediate hooks requires compiled autograd" ) hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook) def _register_hook_trampoline(tensor, bw_state): register_hook = getattr(tensor, name) register_hook( functools.partial( trace_wrapped, fn=call_hook_from_backward_state, bw_state=bw_state, hook_name=hook_name, ) ) # TODO(jansel): returning None here is wrong, it should be # RemovableHandle, but we need some extra work to support # this properly. return None from .builder import wrap_fx_proxy return wrap_fx_proxy( tx, tx.output.create_proxy( "call_function", _register_hook_trampoline, (self.as_proxy(), bw_state_proxy), {}, ), ) handle_variable = variables.RemovableHandleVariable( mutable_local=variables.base.MutableLocal(), ) tx.output.side_effects.register_hook(self, hook, handle_variable, name) return handle_variable def method_requires_grad_(self, requires_grad=True): if requires_grad is not True: requires_grad = requires_grad.as_python_constant() if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad: unimplemented("Tensor.requires_grad_") else: return self def method_new(self, *args, **kwargs): # Convert x.new(torch.Size) into x.new_empty(torch.Size), # as Tensor.new acts differently with a Size input versus a tuple input. if (len(args) == 1 and isinstance(args[0], SizeVariable)) or ( len(args) >= 1 and all( isinstance(a, ConstantVariable) and a.python_type() == int for a in args ) ): from ..symbolic_convert import InstructionTranslator return self.call_method( InstructionTranslator.current_tx(), "new_empty", args, kwargs ) def method_untyped_storage(self): return UntypedStorageVariable( self, self.as_proxy().node.meta["example_value"].untyped_storage() ) def set_name_hint(self, name: str): if not self._is_name_set: self.proxy.node._rename(name) self._is_name_set = True class SymNodeVariable(VariableTracker): """ Represents a symbolic scalar, either int, float or bool. This is most commonly used to handle symbolic size computation, e.g., tensor.size(0), but it is also used to handle logic like float_tensor.item() or unspecialized float inputs. """ _nonvar_fields = { "proxy", "sym_num", *VariableTracker._nonvar_fields, } def debug_repr(self): return repr(self.sym_num) @classmethod def create(cls, tx, proxy, sym_num=None, **options): if sym_num is None: sym_num = get_fake_value(proxy.node, tx) if "example_value" in proxy.node.meta: assert proxy.node.meta["example_value"] == sym_num set_example_value(proxy.node, sym_num) if isinstance(sym_num, (sympy.Integer, int, bool)): sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num return ConstantVariable.create(sym_num) return SymNodeVariable(proxy, sym_num, **options) def __init__(self, proxy, sym_num, **kwargs) -> None: super().__init__(**kwargs) self.proxy = proxy # TODO: Should we allow non SymTypes here? 
Today it is allowed self.sym_num = sym_num self._tensor_var = None def python_type(self): if isinstance(self.sym_num, SymTypes): return self.sym_num.node.pytype else: return type(self.sym_num) def as_proxy(self): return self.proxy def as_tensor(self, tx): if self._tensor_var is None: from .builder import SourcelessBuilder self._tensor_var = SourcelessBuilder.create( tx, torch.scalar_tensor ).call_function(tx, [self], {}) return self._tensor_var def evaluate_expr(self, output_graph=None): try: return guard_scalar(self.sym_num) except GuardOnDataDependentSymNode as e: raise UserError( # noqa: B904 UserErrorType.ANTI_PATTERN, f"Consider annotating your code using torch._check*(). {str(e)}", case_name="constrain_as_size_example", ) def call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": from .builder import wrap_fx_proxy return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", name, *proxy_args_kwargs([self, *args], kwargs), ), ) class NumpyNdarrayVariable(TensorVariable): """ Represents a np.ndarray, but backed by torch Tensor via torch._numpy.ndarray. Use this for Tensor.numpy() call. """ @staticmethod def create(tx: "InstructionTranslator", proxy, **options): from .builder import wrap_fx_proxy_cls return wrap_fx_proxy_cls( target_cls=NumpyNdarrayVariable, tx=tx, proxy=proxy, **options, ) def var_getattr(self, tx: "InstructionTranslator", name): # NB: This INTENTIONALLY does not call super(), because there is # no intrinsic reason ndarray properties are related to Tensor # properties. The inheritance here is for implementation sharing. from ..utils import numpy_attr_wrapper from .builder import wrap_fx_proxy result = None example_value = self.as_proxy().node.meta["example_value"] example_ndarray = tnp.ndarray(example_value) def insert_into_graph(): return wrap_fx_proxy( tx, tx.output.create_proxy( "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {} ), ) if name in ["T", "real", "imag"]: proxy = tx.output.create_proxy( "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}, ) result = NumpyNdarrayVariable.create(tx, proxy) # These are awkward to implement. The standard playbook for torch._numpy # interop is to trace a call into the torch._numpy wrapper which works for # Tensor operations. However, we don't want to do this for calls # that don't return Tensors, because in those cases we may not want # to trace the attribute access into the graph at all (it is sort # of harmless to do so, because AOTAutograd will eliminate them, # but it's best not to trace them in to begin with.) But in any # case, tracing these into the graph is like trying to fit a square # peg into a round hole; best not to do it. So instead we # painstakingly implement these by hand # # NB: only ALWAYS specialized attributes can go here; notably, # size/shape not allowed! 
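        # For instance (illustrative): `ndim` and `itemsize` are plain Python ints
        # on torch._numpy arrays, e.g. tnp.ndarray(torch.zeros(2, 3)).ndim == 2,
        # so they can be burned in as constants below, whereas shape/size may
        # carry free symbols and need the graph fallback.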
elif name in ("ndim", "itemsize"): return ConstantVariable.create(getattr(example_ndarray, name)) elif name in ("shape", "stride"): if not has_free_symbols(r := getattr(example_ndarray, name)): return ConstantVariable.create(tuple(int(r) for r in r)) return insert_into_graph() elif name == "size": if not has_free_symbols(r := example_ndarray.size): return ConstantVariable.create(int(r)) return insert_into_graph() elif name in ["base", "flags", "dtype"]: unimplemented(f"TODO: add support for ndarray.{name}") elif name in ["__version__"]: unimplemented("delegate np.__version__ to NumPy") if result is None: raise NotImplementedError return result @staticmethod def patch_args(name, args, kwargs): if name == "clip": kwargs_rename = {"a_min": "min", "a_max": "max"} kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()} return args, kwargs def call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": from ..utils import numpy_method_wrapper args, kwargs = self.patch_args(name, args, kwargs) if name in ["__len__", "size", "tolist"]: # delegate back to TensorVariable return super().call_method(tx, name, args, kwargs) if name in ("tostring", "tobytes"): unimplemented(f"{name} is not modelled in torch._numpy") proxy = tx.output.create_proxy( "call_function", numpy_method_wrapper(name), *proxy_args_kwargs([self] + list(args), kwargs), ) return NumpyNdarrayVariable.create(tx, proxy) def python_type(self): return np.ndarray class UnspecializedPythonVariable(TensorVariable): """ This is a 1-element tensor represents unspecialized python float/int. """ _nonvar_fields = { "raw_value", "need_unwrap", *TensorVariable._nonvar_fields, } def __init__( self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs ) -> None: super().__init__(proxy, **kwargs) self.raw_value = raw_value self.need_unwrap = need_unwrap @classmethod def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True): # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance. return UnspecializedPythonVariable( **dict(tensor_variable.__dict__), raw_value=raw_value, need_unwrap=need_unwrap, ) class FakeItemVariable(TensorVariable): """An unspecialized python variable which prevents access to the underlying raw value. 
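    (A FakeTensor carries no real data, so there is no concrete Python scalar
    that could be handed back.)

    A minimal sketch of where this shows up (illustrative only; ``tensor_var``
    stands in for an existing TensorVariable):

        fake_item = FakeItemVariable.from_tensor_variable(tensor_var)
        fake_item.need_unwrap  # False: never unwrap to a raw Python number
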
    This is needed if item is called on a FakeTensor.
    """

    _nonvar_fields = {
        "need_unwrap",
        *TensorVariable._nonvar_fields,
    }

    def __init__(self, proxy: torch.fx.Proxy, **kwargs) -> None:
        need_unwrap = kwargs.pop("need_unwrap", False)
        super().__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))


class TensorSubclassVariable(VariableTracker):
    def __init__(self, value, *args, **kwargs) -> None:
        self.value = value
        super().__init__(*args, **kwargs)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if len(args) == 1 and isinstance(args[0], TensorVariable):
            from .builder import VariableBuilder
            from .torch_function import TensorWithTFOverrideVariable

            torch_fn = VariableBuilder(
                tx, AttrSource(self.source, "__torch_function__")
            )(self.value.__torch_function__)

            return TensorWithTFOverrideVariable.from_tensor_var(
                tx, args[0], self.value, torch_fn
            )

        return super().call_function(tx, args, kwargs)

    def as_python_constant(self):
        return self.value


class UntypedStorageVariable(VariableTracker):
    _nonvar_fields = {
        "example_value",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self,
        from_tensor: TensorVariable,
        example_value: torch.UntypedStorage,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.from_tensor = from_tensor
        # example_value will always have device="meta"
        self.example_value = example_value

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name == "size":
            assert not args
            assert not kwargs
            result = self.example_value.size()
            if not has_free_symbols(result):
                # avoid creating a node in the graph
                return ConstantVariable.create(int(result))
            else:
                from ..external_utils import untyped_storage_size
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function",
                        untyped_storage_size,
                        (self.from_tensor.as_proxy(),),
                        {},
                    ),
                )
        if name == "resize_" and len(args) == 1:
            assert not kwargs
            tx.output.create_proxy(
                "call_function",
                torch.ops.inductor.resize_storage_bytes_,
                (self.from_tensor.as_proxy(), args[0].as_proxy()),
                {},
            )
            return self
        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        codegen(self.from_tensor)
        codegen.load_method("untyped_storage")
        codegen.call_method(0)
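

# ---------------------------------------------------------------------------
# Illustrative usage (a sketch, not exercised by this module; assumes default
# dynamo configs). The variable trackers above are what let attribute and size
# accesses on traced tensors either fold to constants or stay symbolic:
#
#     @torch.compile(fullgraph=True)
#     def f(x):
#         if x.ndim == 2:           # TensorVariable.method_attr_ndim -> ConstantVariable
#             return x + x.size(0)  # size(0) becomes a ConstantVariable or SymNodeVariable
#         return x
#
#     f(torch.randn(4, 3))
# ---------------------------------------------------------------------------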