# mypy: allow-untyped-defs
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch.onnx._internal.fx import _pass


if TYPE_CHECKING:
    import torch.fx


class MovePlaceholderToFront(_pass.Transform):
    """This pass moves all placeholder nodes to the front of the graph node list.

    In torch.fx.Graph, a placeholder is a special assignment node. If it is not
    executed at the beginning of the graph, it could overwrite values already
    computed by upstream nodes.
    """

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        graph_module = self.module
        graph = graph_module.graph
        placeholders: list[torch.fx.Node] = []
        first_not_placeholder: torch.fx.Node | None = None
        for node in graph.nodes:
            if node.op == "placeholder":
                placeholders.append(node)
            if first_not_placeholder is None and node.op != "placeholder":
                first_not_placeholder = node
        if first_not_placeholder is None:
            # The graph contains only placeholders; nothing to move.
            return graph_module
        for placeholder in placeholders:
            first_not_placeholder.prepend(placeholder)
        return graph_module


class ReplaceGetAttrWithPlaceholder(_pass.Transform):
    """Replace get_attr nodes with placeholder nodes.

    The parameters and buffers accessed by the original get_attr nodes are
    recorded; they are useful when creating random inputs for the modified
    graph_module.
    """

    _replaced_attrs: tuple[torch.Tensor, ...] | None

    @property
    def replaced_attrs(self) -> tuple[torch.Tensor, ...]:
        """The tensors whose get_attr nodes were replaced by placeholders."""
        assert (
            self._replaced_attrs is not None
        ), "Must run ReplaceGetAttrWithPlaceholder first"
        return self._replaced_attrs

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        graph_module = self.module
        graph = graph_module.graph
        replaced_attrs: list[torch.Tensor] = []
        for node in graph.nodes:
            if node.op == "get_attr":
                replaced_attr: torch.Tensor | None = None
                # get_attr could retrieve either a parameter or a buffer, so
                # we need to try both.
                try:
                    replaced_attr = graph_module.get_parameter(node.target)
                except AttributeError:
                    # It's possible that the model author used a buffer instead
                    # of a parameter to store trainable weights. In that case,
                    # 1. get_parameter raises something like
                    #    AttributeError: `bias` is not an nn.Parameter.
                    # 2. get_buffer should work.
                    replaced_attr = graph_module.get_buffer(node.target)

                # Reassign the op type so that the get_attr node becomes a
                # placeholder node.
                node.op = "placeholder"
                # The target name of a placeholder must be a valid Python
                # identifier. Thus, we replace, e.g., "module.submodule.weight"
                # with "module_submodule_weight".
                node.target = node.target.replace(".", "_")
                # Assign a default value (None) to the new placeholder. This is
                # needed whenever "graph_module" has optional inputs, because
                # the new placeholder is appended after them. Assume the
                # original forward signature is
                #   def forward(self, x, y=None)
                # and the replaced get_attr node has target "z". The modified
                # signature should then be
                #   def forward(self, x, y=None, z=None)
                # Without the following line, it would be
                #   def forward(self, x, y=None, z)
                # which is not valid Python code.
                node.args = (None,)

                replaced_attrs.append(replaced_attr)

        self._replaced_attrs = tuple(replaced_attrs)

        return graph_module
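

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the passes above: it replays the two
# rewrites by hand on a toy module using public torch.fx APIs, so the effect
# on the generated forward() is easy to see. The names `_demo` and `_Toy` are
# hypothetical; the real passes run through the exporter's `_pass.Transform`
# machinery, which also handles diagnostics.
def _demo() -> None:
    import torch.fx

    class _Toy(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(3))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.weight

    gm = torch.fx.symbolic_trace(_Toy())

    # ReplaceGetAttrWithPlaceholder, by hand: turn the `weight` get_attr node
    # into an extra input with a None default.
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            node.op = "placeholder"
            node.target = node.target.replace(".", "_")
            node.args = (None,)

    # MovePlaceholderToFront, by hand: hoist every placeholder above the first
    # computation node so no placeholder follows a use of its value.
    first_not_placeholder = next(n for n in gm.graph.nodes if n.op != "placeholder")
    for node in list(gm.graph.nodes):
        if node.op == "placeholder":
            first_not_placeholder.prepend(node)

    gm.recompile()
    # The weight is now an ordinary input: forward(self, x, weight=None).
    print(gm(torch.zeros(3), torch.ones(3)))  # tensor([1., 1., 1.])


if __name__ == "__main__":
    _demo()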