# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. # pyre-strict # This file contains functions to remove operators from the graph. The removed # ops should belong to either of the following categories: # 1. The op should be redundant for inference (e.g., dropout). Such ops are grouped # together in 'RemoveRedundantOps'. Anyone running inference can add this class # in their pass list, and it should semantic-preserving transformation. # 2. The op should be redundant for Jarvis (e.g., contiguous). Such ops are grouped # together in 'CadenceRemoveNops'. The ops removed in this class might not be nop # in a context outside of Jarvis', so exercise caution while invoking this in a # pass list outside of Jarvis. import itertools import logging from dataclasses import dataclass, field from typing import Callable, cast, Dict, List, Optional, Sequence import torch import torch.fx from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, register_cadence_pass, ) from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue from executorch.exir.pass_manager import PassManager, PassType from executorch.exir.passes import dead_code_elimination_pass from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch.fx.node import Argument @register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveCloneOpsTransformImported(ExportPass): def call(self, graph_module: torch.fx.GraphModule) -> PassResult: finalize_passes: List[PassType] = [ RemoveCloneOpsTransform(), ] result = PassManager(passes=finalize_passes)(graph_module) dead_code_elimination_pass(result.graph_module) return result @register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveDetachCopyPass(ExportPass): def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.aten.detach_copy.default: return super().call_operator(op, args, kwargs, meta) assert len(args) == 1 return cast(ProxyValue, args[0]) # The following class consolidates passes to remove ops that are redundant: # either by the virtue of the operation they perform, or redundant in the # context of inference. class RemoveRedundantOps: passes = [ RemoveDetachCopyPass, ] @register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveZeroSizedCatArgsPass(ExportPass): def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.aten.cat.default: return super().call_operator(op, args, kwargs, meta) # Remove any zero-sized tensor arg to form a new args list. cat_inputs: list[ProxyValue] = [] for arg in cast(Sequence[ProxyValue], args[0]): if arg.to_tensor().numel() > 0: cat_inputs.append(arg) # If all the tensors were empty, we just return an empty tensor with # the right shape. if not cat_inputs: empty_shape = meta["val"].shape dtype = meta["val"].dtype return super().call_operator( exir_ops.edge.aten.full.default, (tuple(empty_shape), 0), {"dtype": dtype}, meta, ) # If there was only one tensor in the cat_inputs list, # we can safely erase this cat op. if len(cat_inputs) == 1: return cat_inputs[0] # Otherwise, we replace args[0] with cat_inputs. new_args = list(args) new_args[0] = cat_inputs return super().call_operator(op, tuple(new_args), kwargs, meta) @register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveNopExpandOpPass(ExportPass): """ For an expand op, if the operator shape matches the expand shape, then the expand is a nop. """ def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if get_edge_overload_packet(op) not in { exir_ops.edge.aten.expand_copy, exir_ops.edge.aten.expand, }: return super().call_operator(op, args, kwargs, meta) # Parse the args, and check for nop condition arg0 = cast(ProxyValue, args[0]) arg1 = cast(Sequence[int], args[1]) in_tensor = arg0.to_tensor() if list(in_tensor.shape) == list(arg1): return arg0 return super().call_operator(op, args, kwargs, meta) @register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveToOpsPass(ExportPass): # aten.to.* as of now are all nops for Jarvis def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op not in ( exir_ops.edge.aten.to.dtype, exir_ops.edge.aten.to.dtype_layout, ): return super().call_operator(op, args, kwargs, meta) logging.debug(f"Erasing to.dtype node (target = {op})") return cast(ProxyValue, args[0]) @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveZeroSizedConstantPadNd(ExportPass): def call_operator( self, op, # pyre-ignore args: tuple[ProxyValue, tuple[int, ...], Argument], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.aten.constant_pad_nd.default: return super().call_operator(op, args, kwargs, meta) input_tensor = args[0] padding = args[1] if any(x != 0 for x in padding): return super().call_operator(op, args, kwargs, meta) logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}") return input_tensor @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveNopSliceOrViewOpPass(ExportPass): """ Remove slice ops that are more like views, and view ops that do not change the shape """ def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op not in { exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.view_copy.default, }: return super().call_operator(op, args, kwargs, meta) arg0 = cast(ProxyValue, args[0]) out_shape = meta["val"].shape # If both arg_shape and out_shape are the same, this slice is a nop return ( arg0 if arg0.to_tensor().shape == out_shape else super().call_operator(op, args, kwargs, meta) ) @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveNopLinalgVectorNormOpPass(ExportPass): """ If the norm is applied over a dimension that is size 1, it can be eliminated. """ def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op not in { exir_ops.edge.aten.linalg_vector_norm.default, exir_ops.edge.cadence.linalg_vector_norm.default, }: return super().call_operator(op, args, kwargs, meta) # If the op has three args or less, it can't be a nop if len(args) <= 3: return super().call_operator(op, args, kwargs, meta) # If dim is None, or keepdim is False, it is not a nop dim = cast(Optional[tuple[int, ...]], args[2]) keepdim = cast(bool, args[3]) if dim is None or not keepdim: return super().call_operator(op, args, kwargs, meta) # If the norm has 4 args and keepdim is True, check if dim is not None # and if the dimensions in dim are size 1. If not, the norm is not a nop. t = cast(ProxyValue, args[0]) shape = t.to_tensor().shape if len(args) < 4: for d in dim: if shape[d] != 1: return super().call_operator(op, args, kwargs, meta) return t @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveNopSelectOpPass(ExportPass): """ A select op that selects from a dimension that is size 1 can be eliminated in a few cases. For example, ``` x = view (x, [1, 3, 16]) y = select(x, 0, 0) z = add(m, y) ``` The special thing about this pattern is the add op, which allows broadcasting. So adding an operand with shape [3, 16] is the same as adding an operand with shape [1, 3, 16]. Therefore, if m has the same shape as x, then this select op is a nop, and can be eliminated: ``` x = view (x, [1, 3, 16]) z = add(x, m) ``` """ # A set of binary operators that could require broadcasting, and are # critical to this transformation if their operand is select op. binary_broadcast_ops: set[EdgeOpOverload] = { exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.div.Tensor, } def __init__(self) -> None: super().__init__() self.op_sizes: dict[str, tuple[torch.Size, torch.Size]] = {} # For select, view, or any op in binary_broadcast_ops, record the shapes of # input and output tensors. def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: res = super().call_operator(op, args, kwargs, meta) # Unary ops: input and output if op in { exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.view_copy.default, }: arg0 = cast(ProxyValue, args[0]) self.op_sizes[res.node.name] = (arg0.to_tensor().shape, meta["val"].shape) # Binary ops: two inputs, output shape can be inferred elif op in self.binary_broadcast_ops: arg0 = cast(ProxyValue, args[0]) arg1 = cast(ProxyValue, args[1]) self.op_sizes[res.node.name] = ( arg0.to_tensor().shape, arg1.to_tensor().shape, ) return res # Eliminate nop select ops. We begin by inspecting the binary_broadcast_ops, # and check if their arg is a select op. def eliminate_nop_select_op(self, graph_module: torch.fx.GraphModule) -> None: for sel_node in graph_module.graph.nodes: # We are only interested in select ops if sel_node.target != exir_ops.edge.aten.select_copy.int: continue # The shape of the input/output operands for this select op should # have been precomputed. assert sel_node.name in self.op_sizes (sel_in_shape, sel_out_shape) = self.op_sizes[sel_node.name] # Get the select dimension sel_dim = ( sel_node.args[1] if sel_node.args[1] >= 0 else sel_node.args[1] + len(sel_in_shape) ) # If the input size along select dimension is not 1, bail. if sel_in_shape[sel_dim] != 1: continue # Get all the users of the select op that are either view, or # binary_broadcast_ops. users = [x for x in list(sel_node.users.keys()) if x.name in self.op_sizes] sel_in = sel_node.args[0] # Iterate over the users of select op, and remove the use of the # select op in the user if feasible. for node in users: args = list(node.args) for idx, sel_arg in enumerate(args): # Check if the arg is the select op if sel_arg != sel_node: continue # If the input of select has the same shape as the other arg # of the binary op, the select op can be bypassed. if sel_in_shape == self.op_sizes[node.name][(idx + 1) % 2]: args[idx] = sel_in # update the node's args node.args = tuple(args) graph_module.recompile() graph_module.graph.eliminate_dead_code() def call(self, graph_module: torch.fx.GraphModule) -> PassResult: result = SpecPropPass()(graph_module) assert result is not None result = super().call(result.graph_module) self.eliminate_nop_select_op(result.graph_module) return result @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveCloneOpPass(ExportPass): # If the op is a clone op, return the input and eliminate the op def call_operator( self, op, # pyre-ignore args: tuple[ProxyValue], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.aten.clone.default: return super().call_operator(op, args, kwargs, meta) return args[0] @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveContiguousOpPass(ExportPass): """ This is based on the assumption that all tensors are contiguous in ExecuTorch and after cadence passes, and we should revisit this if that assumption is no longer true. This causes the model to not be runnable with the arguments given to the original graph module. """ def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.aten.contiguous.default: return super().call_operator(op, args, kwargs, meta) assert len(args) == 1 return cast(ProxyValue, args[0]) @register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveAliasCopyOpPass(ExportPass): """ alias_copy is a no-op for Jarvis and can be removed. """ def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.aten.alias_copy.default: return super().call_operator(op, args, kwargs, meta) assert len(args) == 1 return cast(ProxyValue, args[0]) @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveNopRequantizeOpPass(ExportPass): """ For a requantize op, if the following three conditions are satisfied: 1. the in_scale matches the out_scale 2. the in_zero_point matches the out_zero_point 3. the dtypes of the input and output tensors are the same then the requantize op is redundant, and can be eliminated """ def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.cadence.requantize.default: return super().call_operator(op, args, kwargs, meta) # Parse the args (X, in_scale, in_zero_point, out_scale, out_zero_point, out_dtype) = cast( tuple[ProxyValue, int, float, int, float, torch.dtype], args ) in_dtype = X.to_tensor().dtype # Check the three conditions if ( in_scale == out_scale and in_zero_point == out_zero_point and in_dtype == out_dtype ): return cast(ProxyValue, args[0]) return super().call_operator(op, args, kwargs, meta) @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveNopMulOpPass(ExportPass): """ If a mul op is multiplying two tensors with the same shape and one of those tensors is all zeros, return the zero tensor instead. """ def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.aten.mul.Tensor: return super().call_operator(op, args, kwargs, meta) # Parse the args (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) # Check if both inputs have the same shape if input1.to_tensor().shape != input2.to_tensor().shape: return super().call_operator(op, args, kwargs, meta) # Check if one of the inputs is a zero tensor if input1.node.target == exir_ops.edge.aten.full.default: if input1.node.args[1] == 0: return input1 elif input2.node.target == exir_ops.edge.aten.full.default: if input2.node.args[1] == 0: return input2 return super().call_operator(op, args, kwargs, meta) @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveNopAddOpPass(ExportPass): """ If an add op is adding two tensors with the same shape and one of those tensors is all zeros, return the other tensor instead. """ def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op != exir_ops.edge.aten.add.Tensor: return super().call_operator(op, args, kwargs, meta) # Parse the args (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) # Check if both inputs have the same shape if input1.to_tensor().shape != input2.to_tensor().shape: return super().call_operator(op, args, kwargs, meta) # Check if one of the inputs is a zero tensor if input1.node.target == exir_ops.edge.aten.full.default: if input1.node.args[1] == 0: return input2 elif input2.node.target == exir_ops.edge.aten.full.default: if input2.node.args[1] == 0: return input1 return super().call_operator(op, args, kwargs, meta) @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemovePermutesAroundElementwiseOps(ExportPass): """ Looks for subgraphs of elementwise ops sandwiched between permutes and removes those permutes if possible. This pass is targeted at models where delegated subgraphs must be in NHWC format, so there's usually a to_NHWC permute before each delegate and a to_NCHW permute after it. If all the ops between two delegates are elementwise ops then these permutes can be safely removed. Allows special handling for certain non-elementwise ops that can be easily updated based on the permute's parameter, such as mean and cat """ @dataclass() class Subgraph: """ Keeps track of nodes grouped as a subgraph between two sets of permutes """ start_permutes: set[torch.fx.Node] = field(default_factory=set) end_permutes: set[torch.fx.Node] = field(default_factory=set) intermediate_nodes: set[torch.fx.Node] = field(default_factory=set) is_valid: bool = True elementwise_ops: set[EdgeOpOverload] = { exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.cat.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, } # must be initialized in the constructor special_handling: Dict[EdgeOpOverload, Callable[[torch.fx.Node], None]] = {} to_NCHW = [0, 3, 1, 2] to_NHWC = [0, 2, 3, 1] def __init__(self) -> None: super().__init__() self.visited: set[object] = set() self.special_handling = { exir_ops.edge.aten.mean.dim: self.handle_mean_dim, exir_ops.edge.aten.cat.default: self.handle_cat, } def call(self, graph_module: torch.fx.GraphModule) -> PassResult: self.visited = set() for node in graph_module.graph.nodes: sg = self.Subgraph() self.start_search(node, sg) if self.is_valid_subgraph(sg): logging.debug(f"Found valid subgraph: {sg}") self.handle_subgraph(graph_module, sg) result = super().call(graph_module) return result def handle_mean_dim(self, mean_dim: torch.fx.Node) -> None: assert mean_dim.target == exir_ops.edge.aten.mean.dim args = list(mean_dim.args) args[1] = [self.to_NCHW[dim] for dim in cast(list[int], args[1])] mean_dim.args = tuple(args) def handle_cat(self, cat: torch.fx.Node) -> None: assert cat.target == exir_ops.edge.aten.cat.default args = list(cat.args) args[1] = self.to_NCHW[cast(int, args[1])] cat.args = tuple(args) def is_valid_subgraph(self, sg: Subgraph) -> bool: return ( sg.is_valid and len(sg.start_permutes) > 0 and len(sg.end_permutes) > 0 and len(sg.intermediate_nodes) > 0 ) def handle_subgraph(self, graph_module: torch.fx.GraphModule, sg: Subgraph) -> None: for permute in itertools.chain(sg.start_permutes, sg.end_permutes): permute.replace_all_uses_with(permute.args[0]) # pyre-fixme[6] for node in sg.intermediate_nodes: if node.target in self.special_handling: self.special_handling[node.target](node) graph_module.recompile() graph_module.graph.eliminate_dead_code() def start_search(self, node: torch.fx.Node, sg: Subgraph) -> None: if node in self.visited: return if self.is_starting_permute(node): sg.start_permutes.add(node) self.visited.add(node) for user in node.users: self.search_down(user, sg) def search_up(self, node: object, sg: Subgraph) -> None: # non-nodes can be ignored. These would be arguments like integers or lists # of integers, which don't affect the subgraph validity or inclusion set. if not isinstance(node, torch.fx.Node): return if node.op == "placeholder": # If we reach a placeholder or other terminal node without encountering # a start permute, then the subgraph is invalid. # This could be because in the add(x, y) case where x is permuted and # y is a graph input, we can't remove the permute on x because it might # become two different shapes that don't broadcast together. # TODO: Adding a permute on y could be the more optimal solution, # but perhaps not in all cases, say if x is small and y is very large. # This transform prefers to be safe over optimal for now. sg.is_valid = False return if node in self.visited: return self.visited.add(node) if self.is_starting_permute(node): sg.start_permutes.add(node) for user in node.users: self.search_down(user, sg) else: self.traverse_intermediate_node(node, sg) def search_down(self, node: torch.fx.Node, sg: Subgraph) -> None: if node in self.visited or self.is_starting_permute(node): return self.visited.add(node) if self.is_ending_permute(node): sg.end_permutes.add(node) for arg in node.args: if isinstance(arg, list): for elem in arg: self.search_up(elem, sg) else: self.search_up(arg, sg) else: self.traverse_intermediate_node(node, sg) def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None: if node.target in self.elementwise_ops: sg.intermediate_nodes.add(node) for arg in node.args: if isinstance(arg, list): for elem in arg: self.search_up(elem, sg) else: self.search_up(arg, sg) for user in node.users: self.search_down(user, sg) else: sg.is_valid = False def is_starting_permute(self, node: torch.fx.Node) -> bool: return ( node.target == exir_ops.edge.aten.permute_copy.default and cast(list[int], node.args[1]) == self.to_NCHW ) def is_ending_permute(self, node: torch.fx.Node) -> bool: return ( node.target == exir_ops.edge.aten.permute_copy.default and cast(list[int], node.args[1]) == self.to_NHWC ) # The following class consolidates functions to remove ops that are redundant # in Jarvis. Currently, each function in this class iterates over each node of # the graph module once. In future, we could consolidate them into a monolithic # function. class CadenceRemoveNops: passes = [ SimplifySliceOpPass, RemoveCloneOpsTransformImported, RemoveToOpsPass, RemoveNopRequantizeOpPass, RemoveZeroSizedCatArgsPass, RemoveNopSliceOrViewOpPass, RemoveNopExpandOpPass, RemoveZeroSizedConstantPadNd, RemoveCloneOpPass, RemoveContiguousOpPass, RemoveAliasCopyOpPass, RemoveNopMulOpPass, RemoveNopAddOpPass, RemoveNopLinalgVectorNormOpPass, ]