# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, Union

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.fx import GraphModule, Node


class ScalarsToAttributePass(ExportPass):
    """
    For ops in 'targeted_ops', convert inputs that are scalar values
    to attribute Nodes that output the same value.
    """

    # Only positional arguments of these call_function targets are rewritten.
    targeted_ops = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Tensor,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Tensor,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul_.Tensor,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.div_.Tensor,
    ]

    def call(self, graph_module: GraphModule) -> PassResult:
        """Replace scalar positional args of targeted ops with get_attr nodes.

        Every non-Node positional argument of a targeted node is converted to
        a float tensor, registered on ``graph_module`` as a buffer under a
        fresh ``_tensor_constant_<i>`` name, and referenced through a new
        ``get_attr`` node inserted right before the consuming node.

        Args:
            graph_module: The module whose graph is rewritten in place.

        Returns:
            A PassResult wrapping ``graph_module``; its ``modified`` flag is
            True only if at least one scalar argument was replaced.
        """
        # The name generator depends only on the prefix, so build it once
        # instead of once per scalar argument.
        get_new_attr_name = get_new_attr_name_with_prefix("_tensor_constant_")
        modified = False

        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            # Nothing to rewrite when every positional argument is a Node.
            if all(isinstance(arg, Node) for arg in n.args):
                continue

            # The constant is reshaped to all-ones dims matching the highest
            # rank among the Node arguments (at least rank 1).
            node_ranks = [
                len(get_first_fake_tensor(arg).shape)
                for arg in n.args
                if isinstance(arg, Node)
            ]
            biggest_rank = max([1, *node_ranks])

            fake_mode = n.meta["val"].fake_mode
            new_args = []
            for arg in n.args:
                if isinstance(arg, Node):
                    new_args.append(arg)
                    continue

                tensor_constant_name = get_new_attr_name(graph_module)
                float_tensor = torch.tensor(
                    float(cast(Union[int, float], arg))
                ).reshape((1,) * biggest_rank)
                # Registering the buffer before asking for the next name makes
                # the following lookup skip this name.
                graph_module.register_buffer(tensor_constant_name, float_tensor)

                with graph_module.graph.inserting_before(n):
                    get_attr_node = graph_module.graph.create_node(
                        "get_attr", tensor_constant_name, (), {}
                    )
                get_attr_node.meta["val"] = fake_mode.from_tensor(
                    float_tensor, static_shapes=True
                )
                new_args.append(get_attr_node)

            n.args = tuple(new_args)
            modified = True

        # Regenerate the module's code only if the graph actually changed.
        if modified:
            graph_module.recompile()
        return PassResult(graph_module, modified)