# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, Union

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.fx import GraphModule, Node


class ScalarsToAttributePass(ExportPass):
    """
    For ops in 'targeted_ops', convert inputs that are scalar values
    to attribute Nodes that output the same value.

    Each converted scalar is stored as a float tensor buffer of shape
    ``(1,) * rank`` on the graph module, where ``rank`` is the largest rank
    among the node's tensor operands (minimum 1). The scalar is then
    replaced in the node's args by a ``get_attr`` node reading that buffer.

    Only the leading operand positions (the first ``_num_operands``
    positional args) are converted. Trailing positional scalars such as the
    ``alpha`` argument of ``rsub.Scalar`` are left as-is. Integer scalars
    feeding a node whose output is not floating point are also left as-is:
    a float constant would promote the result to a floating dtype and no
    longer match the node's recorded ``meta["val"]``.
    """

    targeted_ops = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Tensor,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Tensor,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul_.Tensor,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.div_.Tensor,
    ]

    # Number of leading positional args of the targeted ops that are
    # operands (self, other); any later positional args are not converted.
    _num_operands = 2

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Replace scalar operands of targeted ops with get_attr constant nodes.

        Args:
            graph_module: Module whose graph is rewritten in place. New
                buffers named with the ``_tensor_constant_`` prefix are
                registered on it.

        Returns:
            PassResult wrapping the same (recompiled) graph_module, always
            flagged as modified.
        """
        # The returned callable looks up the first unused name on the module
        # each time it is called, so one instance keeps producing fresh
        # names as buffers are registered below.
        get_new_attr_name = get_new_attr_name_with_prefix("_tensor_constant_")

        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            operands = n.args[: self._num_operands]
            trailing_args = n.args[self._num_operands :]
            if all(isinstance(arg, Node) for arg in operands):
                continue  # No scalar operands: nothing to convert.

            # Constants take the rank of the highest-rank tensor operand
            # (at least 1).
            biggest_rank = 1
            for arg in operands:
                if isinstance(arg, Node):
                    shape = get_first_fake_tensor(arg).shape
                    biggest_rank = max(biggest_rank, len(shape))

            output_val = get_first_fake_tensor(n)
            output_is_float = torch.is_floating_point(output_val)
            fake_mode = output_val.fake_mode

            new_operands = []
            for arg in operands:
                if isinstance(arg, Node):
                    new_operands.append(arg)
                    continue
                if isinstance(arg, int) and not output_is_float:
                    # A float constant would promote this integer op's
                    # result to float; keep the scalar so dtypes still match.
                    new_operands.append(arg)
                    continue

                tensor_constant_name = get_new_attr_name(graph_module)
                float_tensor = torch.tensor(
                    float(cast(Union[int, float], arg))
                ).reshape((1,) * biggest_rank)
                graph_module.register_buffer(tensor_constant_name, float_tensor)

                with graph_module.graph.inserting_before(n):
                    get_attr_node = graph_module.graph.create_node(
                        "get_attr", tensor_constant_name, (), {}
                    )
                    get_attr_node.meta["val"] = fake_mode.from_tensor(
                        float_tensor, static_shapes=True
                    )
                    new_operands.append(get_attr_node)

            # NOTE(review): for rsub.Scalar the 'other' operand is now a
            # tensor-valued get_attr node rather than a Scalar, which does not
            # match that overload's schema -- confirm downstream passes and
            # any retracing handle this (e.g. by rewriting to sub.Tensor).
            n.args = tuple(new_operands) + tuple(trailing_args)

        graph_module.recompile()
        return PassResult(graph_module, True)