# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class InsertSqueezeAfterSumPass(ExportPass):
    """
    In PyTorch, the default behaviour of Tensor.sum is to squeeze
    the dimensions that are summed over (keep_dim = False).
    However, in TOSA, REDUCE_SUM always preserves the
    rank of the input (keep_dim = True).
    To get a 1-1 mapping in the sum lowering, normalize the
    keep_dim = False case to keep_dim = True and add squeeze ops.

    Original:
        sum(dims, keep_dim = False)
    After pass:
        sum(dims, keep_dim = True)
        squeeze(dim = dims)
    """
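
    # Illustration (a minimal sketch; shapes chosen only for this example): for
    # a sum over dim 1 of a (2, 3, 4) input, the rewrite turns
    #
    #     out = sum(x, [1], keep_dim=False)        # shape (2, 4)
    #
    # into
    #
    #     tmp = sum(x, [1], keep_dim=True)         # shape (2, 1, 4)
    #     out = squeeze_copy(tmp, [1])             # shape (2, 4)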

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target != exir_ops.edge.aten.sum.dim_IntList:
                continue
            sum_node = cast(torch.fx.Node, node)
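            # sum.dim_IntList args are (input, dims, keep_dim); keep_dim
            # defaults to False when the third arg is not present.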
            keep_dim = cast(
                bool, sum_node.args[2] if len(sum_node.args) > 2 else False
            )
            if keep_dim:
                continue

            dim_list = cast(list[int], sum_node.args[1])

            # Add keep_dim = True arg to sum node.
            sum_node.args = sum_node.args[0:2] + (True,)

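            # Insert a squeeze_copy after the sum and reroute all of the sum's
            # users through it. The squeeze's args are filled in only after
            # replace_all_uses_with, so the squeeze itself still consumes the
            # (now keep_dim = True) sum rather than being rerouted to itself.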
            with graph_module.graph.inserting_after(sum_node):
                squeeze_node = create_node(
                    graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
                )
                sum_node.replace_all_uses_with(squeeze_node)
                squeeze_node.args = (sum_node, dim_list)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
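

# Minimal, self-contained sketch (an assumed example, not used by the pass) of
# the eager-mode identity the rewrite relies on: summing with keepdim=False
# matches summing with keepdim=True followed by squeezing the summed dims.
if __name__ == "__main__":
    x = torch.randn(2, 3, 4)
    reference = torch.sum(x, dim=[1, 2], keepdim=False)  # shape (2,)
    # Squeeze the kept dims one at a time, highest dim first, to mirror the
    # single squeeze over [1, 2] inserted by the pass.
    rewritten = torch.sum(x, dim=[1, 2], keepdim=True).squeeze(2).squeeze(1)
    assert torch.allclose(reference, rewritten)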