# mypy: allow-untyped-defs
import operator
import traceback
import typing
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from functorch.experimental.control_flow import _unstack_pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    PropagateUnbackedSymInts,
)
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree


__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]


Argument = Any
Value = Any
Fn = Callable[..., Any]
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


_TORCH_SYM_OPS: Set[Callable] = {
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}


class ExportPassBaseError(RuntimeError):
    pass


class _ExportPassBaseDeprecatedDoNotUse(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.
    """

    @staticmethod
    def _create_dummy_node_metadata():
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})

    class ExportTracer(PythonKeyTracer):
        def __init__(
            self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen
        ) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:  # type: ignore[override]
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(
            self,
            node: torch.fx.Node,
            value: Argument,
        ) -> None:
            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[
                FakeTensor,
                torch.SymInt,
                torch.SymFloat,
                torch.SymBool,
                int,
                float,
                bool,
                str,
                None,
            ]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported "
                            "right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(
                    x,
                    (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str),
                ):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding
            # FakeTensor.
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)

    class ExportInterpreter(fx.Interpreter):
        def __init__(
            self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule
        ) -> None:
            super().__init__(gm)
            self.callback = callback
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(
            self,
            target: str,  # type: ignore[override]
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {"_operator", "math"}:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]  # type: ignore[override]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]  # type: ignore[override]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        self.interpreter = PropagateUnbackedSymInts(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self._initialized = True
        self.node_debug_str: typing.Optional[str] = None

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(
            kind, target, args_proxy, kwargs_proxy, name=name
        )
        res_proxy.node.meta.update(meta.data)
        if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
            if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
                res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)

        def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]
    def on_attr(self, attr: ProxyValue) -> None:
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        self.tracer.set_metadata(arg_proxy.node, arg)
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", target, args, {}, meta)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        prev_tracer, self.tracer = self.tracer, self.ExportTracer(
            self, graph_module.graph._codegen
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(  # type: ignore[assignment]
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
        with fx_traceback.preserve_node_meta():
            interpreter.run(*inputs_data)

        new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

        self.tracer = prev_tracer
        self.interpreter = prev_interpreter
        return PassResult(
            new_graph_module,
            True,
        )

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)

        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor modes detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result
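

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): a minimal
# subclass that rewrites aten.add.Tensor calls into aten.mul.Tensor.
# `call_operator` is the extension point that `ExportInterpreter.call_function`
# dispatches operator calls to above; the subclass and demo module names here
# are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ReplaceAddWithMul(_ExportPassBaseDeprecatedDoNotUse):
        def call_operator(self, op, args, kwargs, meta):
            # Intercept each operator call as it is re-traced and swap the target.
            if op == torch.ops.aten.add.Tensor:
                op = torch.ops.aten.mul.Tensor
            return super().call_operator(op, args, kwargs, meta)

    class _Demo(torch.nn.Module):
        def forward(self, x):
            return x + x  # exports as aten.add.Tensor

    ep = torch.export.export(_Demo(), (torch.randn(2),))
    # PassBase.__call__ invokes `call()` above and returns a PassResult.
    result = _ReplaceAddWithMul()(ep.graph_module)
    print(result.graph_module.graph)  # add.Tensor has been rewritten to mul.Tensor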