# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class InsertSqueezeAfterSumPass(ExportPass):
    """
    In Pytorch, the default behaviour of Tensor.sum is to squeeze
    the dimension that is summed (keep_dim = False).
    However, in TOSA, REDUCE_SUM always preserves the
    rank of the input (keep_dim = True).
    To get a 1-1 mapping in the sum lowering, normalize the
    keep_dim = False case to keep_dim = True and add squeeze ops.

    Original:
        sum(dims, keep_dim = False)
    After pass:
        sum(dims, keep_dim = True)
        squeeze(dim = dims)
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite every ``sum.dim_IntList`` node with ``keep_dim=False`` to
        ``keep_dim=True`` followed by a ``squeeze_copy.dims`` on the summed dims.

        Returns a PassResult whose ``modified`` flag accurately reflects
        whether the graph was changed; retracing is skipped when it was not.
        """
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target != exir_ops.edge.aten.sum.dim_IntList:
                continue
            sum_node = cast(torch.fx.Node, node)
            # keep_dim is the optional third positional arg; absent means False.
            keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
            if keep_dim:
                continue

            dim_list = cast(list[int], sum_node.args[1])

            # Normalize to keep_dim = True so the sum maps 1-1 to TOSA REDUCE_SUM.
            sum_node.args = sum_node.args[0:2] + (True,)

            with graph_module.graph.inserting_after(sum_node):
                squeeze_node = create_node(
                    graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
                )
                # Redirect all consumers of the sum to the squeeze first; only
                # then wire the squeeze's input, so the squeeze does not end up
                # consuming itself via replace_all_uses_with.
                sum_node.replace_all_uses_with(squeeze_node)
                squeeze_node.args = (sum_node, dim_list)
            modified = True

        if modified:
            # Only pay for dead-code elimination, recompilation and the
            # base-class retrace when the graph actually changed.
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)