# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import Dict

import torch

from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable


# Private sentinel used to tell "attribute is missing" apart from
# "attribute exists but its value is None".
_MISSING = object()


def _raise_hard_error_if_graph_break(reason):
    """Turn graph breaks raised by the decorated function into hard errors.

    Any ``Unsupported`` raised by the wrapped function is re-raised as
    ``UnsafeScriptObjectError`` carrying the original ``msg``. The original
    exception is chained as the cause (``raise ... from e``).

    Args:
        reason: Human-readable description of why a graph break is unsafe here.
            NOTE(review): this value is currently not used; it is not included
            in the raised error's message and only documents intent at the
            decoration site.

    Returns:
        A decorator that wraps a function with the conversion described above.
    """

    def deco(fn):
        @functools.wraps(fn)
        def graph_break_as_hard_error(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except Unsupported as e:
                raise UnsafeScriptObjectError(e.msg) from e

        return graph_break_as_hard_error

    return deco


class TorchScriptObjectVariable(UserDefinedObjectVariable):
    """Variable tracker for ``torch.ScriptObject`` values.

    Only callable attributes (methods) of the wrapped object may be accessed.
    They are exposed through ``call_torchbind``. Unsupported operations in
    ``var_getattr`` / ``call_method`` surface as ``UnsafeScriptObjectError``
    (via ``_raise_hard_error_if_graph_break``) instead of graph breaks.
    """

    # Class-level cache keyed by int; not referenced by any method of this class.
    _fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}

    @classmethod
    def is_matching_cls(cls, user_cls: type):
        """Return True if ``user_cls`` is a subclass of ``torch.ScriptObject``."""
        return issubclass(user_cls, torch.ScriptObject)

    @staticmethod
    def create(proxy, value, **options):
        """Build a ``TorchScriptObjectVariable``.

        ``options`` is forwarded to ``__init__`` and must therefore supply
        ``source``, since only ``proxy`` and ``value`` are passed positionally.
        """
        return TorchScriptObjectVariable(proxy, value, **options)

    def __init__(self, proxy, value, source, **kwargs) -> None:
        """Wrap ``value`` and remember its ``proxy`` and ``source``.

        Args:
            proxy: Proxy object whose ``node.meta`` is updated with ``value``.
            value: The script object being tracked.
            source: Where ``value`` came from; used to build attribute sources.
            **kwargs: Forwarded to ``UserDefinedObjectVariable.__init__``.
        """
        super().__init__(value, **kwargs)
        self.proxy = proxy
        # Attach the wrapped object to the proxy node's metadata.
        self.proxy.node.meta["example_value"] = value
        self.source = source

    def as_proxy(self):
        """Return the proxy this variable was constructed with."""
        return self.proxy

    @_raise_hard_error_if_graph_break(
        "Dynamo cannot safely trace script object due to graph break."
    )
    def var_getattr(self, tx, name: str) -> VariableTracker:
        """Resolve ``self.<name>``; only methods are supported.

        Returns:
            A ``TorchHigherOrderOperatorVariable`` for ``call_torchbind``,
            bound to this variable and ``name``. The following call is then
            expected to go through ``call_function``.

        Raises:
            UnsafeScriptObjectError: if ``name`` is not defined on the wrapped
                object, or is defined but not callable. This assumes
                ``unimplemented`` raises ``Unsupported``, which the decorator
                converts.
        """
        from torch._higher_order_ops.torchbind import call_torchbind

        from ..source import AttrSource
        from .higher_order_ops import TorchHigherOrderOperatorVariable

        # A sentinel (rather than None) keeps an attribute whose value is None
        # from being misreported as a missing method.
        method = getattr(self.value, name, _MISSING)
        if method is _MISSING:
            unimplemented(
                f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
            )

        # unimplemented() presumably never returns; reaching here means the
        # attribute exists, so only its callability remains to be checked.
        if not callable(method):
            unimplemented(
                "Only method calls on TorchScript objects can be supported safely."
                " Please use method calls instead of attribute access."
            )

        return TorchHigherOrderOperatorVariable.make(
            call_torchbind,
            source=AttrSource(self.source, name),
            script_obj_var=self,
            method_name=name,
        )

    # We only support method calls on script objects. Interpreting the bytecodes
    # should go through var_getattr then call_function instead of call_method.
    #
    # However, it's possible for call_method to be used directly e.g. for __setattr__.
    @_raise_hard_error_if_graph_break(
        "Dynamo cannot safely trace script object due to graph break."
    )
    def call_method(self, tx, name, args, kwargs):
        """Reject direct method calls on the script object.

        Always calls ``unimplemented``. The decorator turns the resulting
        ``Unsupported`` (presumed) into ``UnsafeScriptObjectError``.
        """
        unimplemented(f"call method {name} on script object is not safe.")