1# mypy: allow-untyped-defs 2# NOTE: This file is referenced by name at 3# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. 4# introduced by https://github.com/pytorch/pytorch/pull/98894. 5# If this file is renamed, moved, etc please update the reference there! 6 7from __future__ import annotations 8 9import contextlib 10import functools 11import inspect 12from typing import Any, Callable, Mapping, Sequence 13 14import torch._dynamo 15import torch.export as torch_export 16import torch.fx 17import torch.onnx 18from torch.onnx._internal import _exporter_legacy, io_adapter 19from torch.utils import _pytree as pytree 20 21 22class _PyTreeExtensionContext: 23 """Context manager to register PyTree extension.""" 24 25 _extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]] 26 27 def __init__(self) -> None: 28 self._extensions = {} 29 # Register PyTree extension for HuggingFace model output. 30 self._register_huggingface_model_output_extension() 31 32 def __enter__(self): 33 for class_type, (flatten_func, unflatten_func) in self._extensions.items(): 34 pytree._private_register_pytree_node( 35 class_type, 36 flatten_func, 37 unflatten_func, 38 ) 39 return self 40 41 def __exit__(self, exc_type, exc_val, exc_tb): 42 for class_type in self._extensions: 43 pytree.SUPPORTED_NODES.pop(class_type) 44 45 def register_pytree_node( 46 self, 47 class_type: type, 48 flatten_func: pytree.FlattenFunc, 49 unflatten_func: pytree.UnflattenFunc, 50 ): 51 """Register PyTree extension for a custom python type. 52 53 Args: 54 class_type: The custom python type. 55 flatten_func: The flatten function. 56 unflatten_func: The unflatten function. 57 58 Raises: 59 AssertionError: If the custom python type is already registered. 60 """ 61 if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions: 62 # PyTree node already registered. 63 # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after 64 # https://github.com/huggingface/transformers/pull/25358. 65 return 66 self._extensions[class_type] = (flatten_func, unflatten_func) 67 68 def _register_huggingface_model_output_extension(self): 69 try: 70 from transformers import modeling_outputs # type: ignore[import] 71 except ImportError as e: 72 return 73 74 def model_output_flatten( 75 output: modeling_outputs.ModelOutput, 76 ) -> tuple[list[Any], pytree.Context]: 77 return list(output.values()), (type(output), list(output.keys())) 78 79 def model_output_unflatten( 80 values: list[Any], context: pytree.Context 81 ) -> modeling_outputs.ModelOutput: 82 output_type, keys = context 83 return output_type(**dict(zip(keys, values))) 84 85 # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'. 86 named_model_output_classes = inspect.getmembers( 87 modeling_outputs, 88 lambda x: ( 89 inspect.isclass(x) 90 and issubclass(x, modeling_outputs.ModelOutput) 91 and x is not modeling_outputs.ModelOutput 92 ), 93 ) 94 95 for _, class_type in named_model_output_classes: 96 self.register_pytree_node( 97 class_type, 98 model_output_flatten, 99 model_output_unflatten, # type: ignore[arg-type ] 100 ) 101 102 103class DynamoFlattenOutputStep(io_adapter.FlattenOutputStep): 104 """Flatten nested collection and custom python types and return a flat list of elements. 105 106 Extended from :class:`io_adapter.FlattenOutputStep` to support flattening arbitrary 107 types via pytree extension. By default this supports many common user defined python 108 types such as :class:`ModelOutput` from HuggingFace transformers. 109 110 The pytree extension can be customized by passing in a ``_PyTreeExtensionContext`` 111 object. See :meth:`_PyTreeExtensionContext.register_pytree_node`. 112 """ 113 114 def __init__(self, pytree_extension_context: _PyTreeExtensionContext | None = None): 115 super().__init__() 116 self._pytree_extension_context = ( 117 pytree_extension_context or _PyTreeExtensionContext() 118 ) 119 120 def apply( 121 self, 122 model_outputs: Any, 123 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 124 ) -> Sequence[Any]: 125 """Flatten the model outputs, under the context of pytree extension.""" 126 with self._pytree_extension_context: 127 return super().apply(model_outputs, model=model) 128 129 130def _wrap_model_with_output_adapter( 131 model: torch.nn.Module | Callable, 132 output_adapter: DynamoFlattenOutputStep, 133) -> Callable: 134 """Wrap model with output adapter. 135 136 This is a helper function to enable :func:`dynamo.export` on models that produce 137 custom user defined types outputs. It wraps the model with an output adapter to 138 convert the outputs to :func:`dynamo.export` compatible types, i.e. :class:`torch.Tensor`. 139 140 The adapting logic is controlled by ``output_adapter``. 141 142 Args: 143 model: PyTorch model or function. 144 output_adapter: Output adapter to apply to model output. 145 Returns: 146 Wrapped model. 147 """ 148 model_func = model.forward if isinstance(model, torch.nn.Module) else model 149 150 # Preserve original function signature. 151 @functools.wraps(model_func) 152 def wrapped(*args, **kwargs): 153 return output_adapter.apply(model_func(*args, **kwargs), model=model) 154 155 return wrapped 156 157 158class DynamoExport(_exporter_legacy.FXGraphExtractor): 159 """Generates a FX GraphModule using torch.dynamo.export API 160 Args: 161 aten_graph: If True, exports a graph with ATen operators. 162 If False, exports a graph with Python operators. 163 """ 164 165 def __init__( 166 self, 167 aten_graph: bool | None = None, 168 ): 169 super().__init__() 170 self.aten_graph = aten_graph or True 171 172 def generate_fx( 173 self, 174 options: _exporter_legacy.ResolvedExportOptions, 175 model: torch.nn.Module | Callable, 176 model_args: Sequence[Any], 177 model_kwargs: Mapping[str, Any], 178 ) -> torch.fx.GraphModule: 179 # `dynamo.export` does not recognize custom user defined classes as output type. 180 # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types, 181 # i.e. :class:`torch.Tensor`. 182 dynamo_flatten_output_step = DynamoFlattenOutputStep() 183 wrapped_model = _wrap_model_with_output_adapter( 184 model, dynamo_flatten_output_step 185 ) 186 # Record the output adapter step. 187 self.output_adapter.append_step(dynamo_flatten_output_step) 188 189 # Translate callable to FX graph. 190 # 191 fake_mode = ( 192 options.fake_context.fake_mode 193 if options.fake_context 194 else contextlib.nullcontext() 195 ) 196 fx_mode = "symbolic" if options.dynamic_shapes else "fake" 197 with fake_mode: # type: ignore[attr-defined] 198 graph_module, graph_guard = torch._dynamo.export( 199 wrapped_model, 200 tracing_mode=fx_mode, 201 )( 202 *model_args, 203 **model_kwargs, 204 ) 205 del graph_guard # Unused 206 torch._dynamo.reset() 207 208 # Export FX graph to ONNX ModelProto. 209 self.input_adapter.append_step( 210 io_adapter.FlattenInputWithTreeSpecValidationInputStep() 211 ) 212 213 updated_model_args = self.input_adapter.apply( 214 *model_args, model=model, **model_kwargs 215 ) 216 217 return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value] 218 219 def pre_export_passes( 220 self, 221 options: _exporter_legacy.ResolvedExportOptions, 222 original_model: torch.nn.Module | Callable, 223 fx_module: torch.fx.GraphModule, 224 fx_module_args: Sequence[Any], 225 ): 226 return _exporter_legacy.common_pre_export_passes( 227 options, original_model, fx_module, fx_module_args 228 ) 229