# Copyright (c) Qualcomm Innovation Center, Inc. # All rights reserved # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass class RemoveRedundancy(ExportPass): """ Trim the 'identity' operators to reduce the unnecessary copy overhead. """ redundant_ops = { torch.clone, torch.ops.aten.clone.default, exir_ops.edge.aten.clone.default, torch.ops.aten.alias.default, exir_ops.edge.aten.alias.default, exir_ops.edge.aten.lift_fresh_copy.default, } def __init__(self): super(RemoveRedundancy, self).__init__() def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: if n.target not in self.redundant_ops: continue to_be_remove = n for user_n in list(n.users.keys()): user_n.replace_input_with(n, n.args[0]) graph_module.graph.erase_node(to_be_remove) def call(self, graph_module: torch.fx.GraphModule): self._remove(graph_module) graph_module.recompile() dead_code_elimination_pass(graph_module) return PassResult(graph_module, True)