# mypy: allow-untyped-defs
# NOTE: This file is referenced by name at
#       /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES,
#       introduced by https://github.com/pytorch/pytorch/pull/98894.
#       If this file is renamed, moved, etc., please update the reference there!

from __future__ import annotations

import contextlib
import functools
import inspect
from typing import Any, Callable, Mapping, Sequence

import torch._dynamo
import torch.export as torch_export
import torch.fx
import torch.onnx
from torch.onnx._internal import _exporter_legacy, io_adapter
from torch.utils import _pytree as pytree


class _PyTreeExtensionContext:
    """Context manager to register PyTree extension."""

    _extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]]

    def __init__(self) -> None:
        self._extensions = {}
        # Register PyTree extension for HuggingFace model output.
        self._register_huggingface_model_output_extension()

    def __enter__(self):
        for class_type, (flatten_func, unflatten_func) in self._extensions.items():
            pytree._private_register_pytree_node(
                class_type,
                flatten_func,
                unflatten_func,
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for class_type in self._extensions:
            pytree.SUPPORTED_NODES.pop(class_type)

    def register_pytree_node(
        self,
        class_type: type,
        flatten_func: pytree.FlattenFunc,
        unflatten_func: pytree.UnflattenFunc,
    ):
        """Register PyTree extension for a custom python type.

        Args:
            class_type: The custom python type.
            flatten_func: The flatten function.
            unflatten_func: The unflatten function.

        If ``class_type`` is already registered, either with ``pytree`` directly
        or with this context, the call is a silent no-op rather than an error.
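
        The sketch below (with a hypothetical ``Point`` type that is not part of
        this module) shows that the registration only takes effect while the
        context is entered::

            from torch.utils import _pytree as pytree

            class Point:
                def __init__(self, x, y):
                    self.x, self.y = x, y

            ctx = _PyTreeExtensionContext()
            ctx.register_pytree_node(
                Point,
                lambda p: ([p.x, p.y], None),
                lambda values, context: Point(*values),
            )
            with ctx:
                leaves, spec = pytree.tree_flatten(Point(1, 2))  # leaves == [1, 2]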
        """
        if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions:
            # PyTree node already registered.
            # E.g., `huggingface/transformers` registers `ModelOutput` as a PyTree node
            # after https://github.com/huggingface/transformers/pull/25358.
            return
        self._extensions[class_type] = (flatten_func, unflatten_func)

    def _register_huggingface_model_output_extension(self):
        try:
            from transformers import modeling_outputs  # type: ignore[import]
        except ImportError:
            return

        def model_output_flatten(
            output: modeling_outputs.ModelOutput,
        ) -> tuple[list[Any], pytree.Context]:
            return list(output.values()), (type(output), list(output.keys()))

        def model_output_unflatten(
            values: list[Any], context: pytree.Context
        ) -> modeling_outputs.ModelOutput:
            output_type, keys = context
            return output_type(**dict(zip(keys, values)))

        # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
        named_model_output_classes = inspect.getmembers(
            modeling_outputs,
            lambda x: (
                inspect.isclass(x)
                and issubclass(x, modeling_outputs.ModelOutput)
                and x is not modeling_outputs.ModelOutput
            ),
        )

        for _, class_type in named_model_output_classes:
            self.register_pytree_node(
                class_type,
                model_output_flatten,
                model_output_unflatten,  # type: ignore[arg-type]
            )


class DynamoFlattenOutputStep(io_adapter.FlattenOutputStep):
    """Flatten nested collections and custom python types and return a flat list of elements.

    Extended from :class:`io_adapter.FlattenOutputStep` to support flattening arbitrary
    types via pytree extension. By default this supports many common user-defined python
    types such as :class:`ModelOutput` from HuggingFace transformers.

    The pytree extension can be customized by passing in a ``_PyTreeExtensionContext``
    object. See :meth:`_PyTreeExtensionContext.register_pytree_node`.
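
    A minimal sketch (the nested dict/list output below is illustrative; any
    pytree-compatible structure, including registered ``ModelOutput`` subclasses,
    is flattened the same way)::

        import torch

        step = DynamoFlattenOutputStep()
        outputs = {"logits": torch.zeros(2, 3), "hidden": [torch.ones(1)]}
        flat = step.apply(outputs)  # A flat sequence containing the two tensors.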
    """

    def __init__(self, pytree_extension_context: _PyTreeExtensionContext | None = None):
        super().__init__()
        self._pytree_extension_context = (
            pytree_extension_context or _PyTreeExtensionContext()
        )

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[Any]:
        """Flatten the model outputs, under the context of pytree extension."""
        with self._pytree_extension_context:
            return super().apply(model_outputs, model=model)


def _wrap_model_with_output_adapter(
    model: torch.nn.Module | Callable,
    output_adapter: DynamoFlattenOutputStep,
) -> Callable:
    """Wrap model with output adapter.

    This is a helper function to enable :func:`dynamo.export` on models that produce
    custom user-defined output types. It wraps the model with an output adapter to
    convert the outputs to :func:`dynamo.export` compatible types, i.e. :class:`torch.Tensor`.

    The adapting logic is controlled by ``output_adapter``.

    Args:
        model: PyTorch model or function.
        output_adapter: Output adapter to apply to model output.

    Returns:
        Wrapped model.
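
    A minimal sketch (``_TinyModel`` is illustrative; in practice the forward would
    return a custom type such as a HuggingFace ``ModelOutput``)::

        import torch

        class _TinyModel(torch.nn.Module):
            def forward(self, x):
                return {"out": x + 1}

        adapter = DynamoFlattenOutputStep()
        wrapped = _wrap_model_with_output_adapter(_TinyModel(), adapter)
        # ``wrapped`` returns a flat sequence of tensors that dynamo can handle.
        graph_module, _ = torch._dynamo.export(wrapped)(torch.ones(2))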
    """
    model_func = model.forward if isinstance(model, torch.nn.Module) else model

    # Preserve original function signature.
    @functools.wraps(model_func)
    def wrapped(*args, **kwargs):
        return output_adapter.apply(model_func(*args, **kwargs), model=model)

    return wrapped


class DynamoExport(_exporter_legacy.FXGraphExtractor):
    """Generates an FX GraphModule using the torch._dynamo.export API.

    Args:
        aten_graph: If True, exports a graph with ATen operators.
                    If False, exports a graph with Python operators.
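
    A minimal sketch (this extractor is normally constructed and driven by the
    dynamo-based ONNX export path, which supplies the resolved export options)::

        graph_extractor = DynamoExport(aten_graph=True)
        # ``generate_fx`` is then invoked by the exporter with the resolved
        # options, the model, and its example args/kwargs.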
    """

    def __init__(
        self,
        aten_graph: bool | None = None,
    ):
        super().__init__()
        # Default to an ATen-level graph when ``aten_graph`` is not specified.
        self.aten_graph = aten_graph if aten_graph is not None else True

    def generate_fx(
        self,
        options: _exporter_legacy.ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        # `dynamo.export` does not recognize custom user-defined classes as output types.
        # Apply a wrapper to adapt the outputs back to `dynamo.export`-compatible types,
        # i.e. :class:`torch.Tensor`.
        dynamo_flatten_output_step = DynamoFlattenOutputStep()
        wrapped_model = _wrap_model_with_output_adapter(
            model, dynamo_flatten_output_step
        )
        # Record the output adapter step.
        self.output_adapter.append_step(dynamo_flatten_output_step)

        # Translate callable to FX graph.
        # Trace under the fake tensor mode from the export options, if one was
        # provided; use "symbolic" tracing when dynamic shapes are requested and
        # "fake" tracing otherwise.
        fake_mode = (
            options.fake_context.fake_mode
            if options.fake_context
            else contextlib.nullcontext()
        )
        fx_mode = "symbolic" if options.dynamic_shapes else "fake"
        with fake_mode:  # type: ignore[attr-defined]
            graph_module, graph_guard = torch._dynamo.export(
                wrapped_model,
                tracing_mode=fx_mode,
            )(
                *model_args,
                **model_kwargs,
            )
        del graph_guard  # Unused
        # Reset dynamo to clear the compilation caches and guards left by this trace.
        torch._dynamo.reset()

        # Export FX graph to ONNX ModelProto.
        self.input_adapter.append_step(
            io_adapter.FlattenInputWithTreeSpecValidationInputStep()
        )

        updated_model_args = self.input_adapter.apply(
            *model_args, model=model, **model_kwargs
        )

        return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]

    def pre_export_passes(
        self,
        options: _exporter_legacy.ResolvedExportOptions,
        original_model: torch.nn.Module | Callable,
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ):
        return _exporter_legacy.common_pre_export_passes(
            options, original_model, fx_module, fx_module_args
        )