• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9from typing import cast, Union
10
11import torch
12from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
13
14from executorch.exir.pass_base import ExportPass, PassResult
15from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
16from torch.fx import GraphModule, Node
17
18
class ScalarsToAttributePass(ExportPass):
    """
    For ops in 'targeted_ops', convert scalar operands to get_attr Nodes
    that output the same value as a constant tensor.

    Each constant is registered as a buffer on the GraphModule. It is
    reshaped to rank ``max(1, largest rank among the node's tensor inputs)``
    with every dimension of size 1, so it broadcasts against the other
    operand.

    Integer scalars that feed an op whose output is not floating point are
    left untouched. Materialising them as float tensors would change the op's
    result dtype through type promotion (e.g. ``int32_tensor * float_tensor``
    yields a float tensor), so the graph would no longer match its recorded
    ``meta["val"]``.
    """

    targeted_ops = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Tensor,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Tensor,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul_.Tensor,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.div_.Tensor,
    ]

    # Only the first two positional arguments ("self" and "other") of the
    # targeted ops are operands. Any later positional argument (e.g. the
    # "alpha" of rsub.Scalar) is not an operand and must remain a scalar.
    _num_operands = 2

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Replace scalar operands of targeted ops with get_attr constant nodes.

        Args:
            graph_module: module whose graph is rewritten in place; new
                constant buffers are registered on it.

        Returns:
            PassResult wrapping ``graph_module``. ``modified`` is True only if
            at least one constant node was inserted.
        """
        # The helper returned here probes ``graph_module`` for the first free
        # "_tensor_constant_<i>" name on every call. Since each buffer is
        # registered before the next lookup, one instance can be reused.
        get_new_attr_name = get_new_attr_name_with_prefix("_tensor_constant_")
        modified = False

        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            # Give the constant the largest rank among the tensor inputs
            # (at least 1) so it broadcasts without becoming a 0-d tensor.
            biggest_rank = 1
            for arg in n.args:
                if isinstance(arg, Node):
                    shape = get_first_fake_tensor(arg).shape
                    biggest_rank = max(biggest_rank, len(shape))

            new_args: list = []
            for i, arg in enumerate(n.args):
                if isinstance(arg, Node) or i >= self._num_operands:
                    new_args.append(arg)
                    continue

                output = get_first_fake_tensor(n)
                if isinstance(arg, int) and not output.is_floating_point():
                    # A float constant would promote integer arithmetic to a
                    # float result; keep the integer scalar as-is.
                    new_args.append(arg)
                    continue

                tensor_constant_name = get_new_attr_name(graph_module)
                float_tensor = torch.tensor(
                    float(cast(Union[int, float], arg)), device=output.device
                ).reshape((1,) * biggest_rank)
                graph_module.register_buffer(tensor_constant_name, float_tensor)
                fake_mode = n.meta["val"].fake_mode

                with graph_module.graph.inserting_before(n):
                    get_attr_node = graph_module.graph.create_node(
                        "get_attr", tensor_constant_name, (), {}
                    )
                    get_attr_node.meta["val"] = fake_mode.from_tensor(
                        float_tensor, static_shapes=True
                    )
                new_args.append(get_attr_node)
                modified = True
            n.args = tuple(new_args)

        graph_module.recompile()
        return PassResult(graph_module, modified)
76