# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
)

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node


class MatchArgRanksPass(ExportPass):
    """
    For ops in 'targeted_ops', make sure that the inputs share the same rank.

    New dimensions of size 1 are inserted at the beginning of the shape of
    every lower-rank input. For example, inputs of shape [2, 3, 4] and [4]
    become [2, 3, 4] and [1, 1, 4].

    Buffers and parameters used only by the targeted node are reshaped in
    place (state_dict entry and fake-tensor meta). Every other lower-rank
    input gets a view_copy node inserted in front of the targeted node.
    """

    def __init__(self, exported_program):
        super().__init__()
        self.exported_program = exported_program

    targeted_ops = [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    ]

    @staticmethod
    def _padded_shape(arg, max_rank):
        """Return arg's shape as a list, left-padded with 1s to max_rank dims."""
        shape = list(get_first_fake_tensor(arg).shape)
        return [1] * (max_rank - len(shape)) + shape

    def _match_op_rank(self, graph_module, node, arg, max_rank):
        """
        In graph_module, insert a view between arg and node to make the
        rank of arg match the other args to node.
        """
        new_shape = self._padded_shape(arg, max_rank)
        with graph_module.graph.inserting_before(node):
            view = create_node(
                graph_module.graph,
                exir_ops.edge.aten.view_copy.default,
                args=(arg, new_shape),
                kwargs={},
            )
            node.replace_input_with(arg, view)

    def _match_buffer_rank(self, arg, max_rank):
        """
        Change arg's fake tensor meta and its state_dict entry to match
        max_rank if arg is found in inputs_to_buffers or inputs_to_parameters.

        Returns:
            True if arg was a buffer/parameter and was reshaped, False if arg
            is neither (in which case nothing is modified).
        """
        signature = self.exported_program.graph_signature
        buffer_name = signature.inputs_to_buffers.get(arg.name)
        if buffer_name is None:
            buffer_name = signature.inputs_to_parameters.get(arg.name)
        if not buffer_name:
            return False

        fake_tensor = get_first_fake_tensor(arg)
        new_tensor = self.exported_program.state_dict[buffer_name].reshape(
            self._padded_shape(arg, max_rank)
        )
        self.exported_program.state_dict[buffer_name] = new_tensor
        arg.meta["val"] = fake_tensor.fake_mode.from_tensor(
            new_tensor, static_shapes=True
        )
        return True

    def call(self, graph_module: GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            node = cast(Node, node)

            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue

            # Rank of every distinct tensor input. A dict dedups an arg that
            # appears twice, since replace_input_with rewrites all occurrences.
            ranks = {
                arg: len(get_first_fake_tensor(arg).shape)
                for arg in node.args
                if isinstance(arg, Node)
            }
            max_rank = max([1, *ranks.values()])

            for arg, rank in ranks.items():
                if rank == max_rank:
                    continue

                # If the argument is call_function, match shape by inserting view node.
                if arg.op == "call_function":
                    self._match_op_rank(graph_module, node, arg, max_rank)
                    continue

                # Reshape a buffer/parameter in place only when this node is
                # its sole user. Otherwise its other users would silently see
                # the new, higher-rank tensor as well.
                if len(arg.users) == 1 and self._match_buffer_rank(arg, max_rank):
                    continue

                # Any other input (e.g. a user input placeholder, or a shared
                # buffer/parameter): fall back to an explicit view so the rank
                # mismatch is not silently left in the graph.
                self._match_op_rank(graph_module, node, arg, max_rank)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)