# mypy: allow-untyped-defs
import operator
import traceback
import typing
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from functorch.experimental.control_flow import _unstack_pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    PropagateUnbackedSymInts,
)
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree


__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]


Argument = Any
Value = Any
Fn = Callable[..., Any]
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


_TORCH_SYM_OPS: Set[Callable] = {
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}


class ExportPassBaseError(RuntimeError):
    pass


class _ExportPassBaseDeprecatedDoNotUse(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.
    """
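    # A minimal usage sketch (the subclass and the op rewrite below are
    # illustrative, not part of this module): subclasses typically override
    # ``call_operator`` and let the interpreter retrace everything else.
    #
    #     class _ReplaceAddWithSubPass(_ExportPassBaseDeprecatedDoNotUse):
    #         def call_operator(self, op, args, kwargs, meta):
    #             if op == torch.ops.aten.add.Tensor:
    #                 return super().call_operator(
    #                     torch.ops.aten.sub.Tensor, args, kwargs, meta
    #                 )
    #             return super().call_operator(op, args, kwargs, meta)
    #
    #     result = _ReplaceAddWithSubPass()(graph_module)  # PassResult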

    @staticmethod
    def _create_dummy_node_metadata() -> NodeMetadata:
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})

    class ExportTracer(PythonKeyTracer):
        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:  # type: ignore[override]
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node
        def set_metadata(
            self, node: torch.fx.Node, value: Argument,
        ) -> None:
            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported "
                            "right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)

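    # ExportInterpreter replays the original graph node by node, routing each
    # node kind to the corresponding hook on ``callback``
    # (placeholder / call_operator / call_sym / call_cond / ... / output).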
    class ExportInterpreter(fx.Interpreter):
        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
            super().__init__(gm)
            self.callback = callback
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(
            self,
            target: str,  # type: ignore[override]
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> Argument:
            # Unwraps the ProxyValue returned by the callback; the interpreter
            # itself works with plain data.
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {"_operator", "math"}:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]  # type: ignore[override]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]  # type: ignore[override]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        self.interpreter = PropagateUnbackedSymInts(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self._initialized = True
        self.node_debug_str: typing.Optional[str] = None

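    # ``_fx`` is the workhorse: it executes the target on the unwrapped data
    # via ``self.interpreter``, records a matching proxy node via
    # ``self.tracer``, and returns both bundled in a ProxyValue so concrete
    # values and graph nodes stay in lockstep during retracing.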
    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
        res_proxy.node.meta.update(meta.data)
        if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
            if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
                res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

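    # Reconstructs example inputs for the module's placeholders, preferring a
    # recorded "val" FakeTensor and falling back to "tensor_meta" entries.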
    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)

        def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]

    def on_attr(self, attr: ProxyValue) -> None:
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        self.tracer.set_metadata(arg_proxy.node, arg)
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", target, args, {}, meta)

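    # Control-flow ops get dedicated hooks: each branch submodule is retraced
    # through ``call_submodule`` before the higher-order op is re-emitted.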
    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

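    # For map_impl, tracing the body once only requires a single slice of the
    # stacked inputs; ``_unstack_pytree(...)[0]`` peels off that first slice.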
    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

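    # Retraces a (sub)module with a fresh ExportTracer and a plain
    # fx.Interpreter, then restores the previous tracer/interpreter so that
    # nested calls (e.g. from call_cond/call_map) compose correctly.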
    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        prev_tracer, self.tracer = self.tracer, self.ExportTracer(
            self, graph_module.graph._codegen
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(  # type: ignore[assignment]
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
        with fx_traceback.preserve_node_meta():
            interpreter.run(*inputs_data)

        new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

        self.tracer = prev_tracer
        self.interpreter = prev_interpreter
        return PassResult(
            new_graph_module,
            True,
        )

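    # Entry point invoked by PassBase.__call__: infers a single shared
    # FakeTensorMode from the inputs (creating one if none is found) and
    # retraces the whole module under it.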
    def call(self, graph_module: fx.GraphModule) -> PassResult:
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)

        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor modes detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result