# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class UnsqueezeScalarPlaceholdersPass(ExportPass):
    """
    Placeholders that have node.meta["val"].shape = () cause issues later in the lowering.
    This pass unsqueezes such placeholders so that their shape is at least (1,).
    """

    def __init__(self, exported_program):
        self.exported_program = exported_program
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            rank = node.meta["val"].dim()
            if rank == 0:
                # Only buffers and parameters are backed by the state dict;
                # skip any other rank-0 placeholders (e.g. user inputs).
                if not (
                    node.name in self.exported_program.graph_signature.inputs_to_buffers
                    or node.name
                    in self.exported_program.graph_signature.inputs_to_parameters
                ):
                    continue

                tensor = self.exported_program.state_dict[node.name]
                if tensor.dim() == 0:
                    # Unsqueeze the stored tensor to rank 1 and rebuild the fake
                    # value so node.meta["val"] reports the new static shape.
                    self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
                    node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
                        tensor.unsqueeze(0), static_shapes=True
                    )
                else:
                    # The stored tensor already has rank >= 1; only refresh the
                    # fake value so the metadata matches the state dict entry.
                    node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
                        tensor, static_shapes=True
                    )

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)

    def ensures(self, graph_module: torch.fx.GraphModule):
        # Verify that no rank-0 placeholders remain after the pass has run.
        for node in graph_module.graph.nodes:
            if node.op == "placeholder":
                rank = node.meta["val"].dim()
                if rank == 0:
                    raise ValueError("Placeholders of rank 0 are not supported!")
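

# ---------------------------------------------------------------------------
# Minimal sketch (not part of the original file) of the transformation the
# pass applies to a scalar placeholder: a rank-0 tensor is unsqueezed to rank
# 1, and a static-shaped fake tensor is rebuilt from it, mirroring the
# node.meta["val"] update above. The tensor values here are made up for
# illustration; the pass itself operates on an exported program's state dict.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch._subclasses.fake_tensor import FakeTensorMode

    scalar = torch.tensor(3.0)  # shape: torch.Size([])
    unsqueezed = scalar.unsqueeze(0)  # shape: torch.Size([1])

    # Rebuild the fake value with a static shape, as done in call() above.
    fake_mode = FakeTensorMode()
    fake_val = fake_mode.from_tensor(unsqueezed, static_shapes=True)
    print(scalar.shape, unsqueezed.shape, fake_val.shape)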