• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import torch
10from executorch.exir.pass_base import ExportPass, PassResult
11
12
class UnsqueezeScalarPlaceholdersPass(ExportPass):
    """
    Placeholders that have node.meta["val"].shape = () cause issues later in the lowering.
    This pass unsqueezes rank-0 parameter and buffer placeholders so that their shape
    is at least (1,). Rank-0 user-input placeholders are not backed by program state
    and are left untouched.

    The backing tensor is looked up (and replaced) by the fully-qualified name that the
    graph signature maps the placeholder to, first in ``exported_program.state_dict``
    and otherwise in ``exported_program.constants``.
    """

    def __init__(self, exported_program):
        self.exported_program = exported_program
        super().__init__()

    def _get_target_name(self, node: torch.fx.Node) -> "str | None":
        """Return the state name backing a buffer/parameter placeholder, else None."""
        signature = self.exported_program.graph_signature
        inputs_to_buffers = signature.inputs_to_buffers
        if node.name in inputs_to_buffers:
            return inputs_to_buffers[node.name]
        inputs_to_parameters = signature.inputs_to_parameters
        if node.name in inputs_to_parameters:
            return inputs_to_parameters[node.name]
        return None

    def _get_tensor_store(self, name: str) -> dict:
        """Return the exported-program dict that holds the tensor called ``name``.

        Raises:
            KeyError: if ``name`` is in neither ``state_dict`` nor ``constants``.
        """
        state_dict = self.exported_program.state_dict
        if name in state_dict:
            return state_dict
        # Tensors not kept in the state dict (e.g. non-persistent buffers) are
        # presumably held in `constants`; getattr guards older ExportedProgram versions.
        constants = getattr(self.exported_program, "constants", None)
        if constants is not None and name in constants:
            return constants
        raise KeyError(name)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Unsqueeze rank-0 buffer/parameter placeholders and their stored tensors.

        Args:
            graph_module: graph of ``self.exported_program`` to transform.

        Returns:
            PassResult wrapping the retraced graph module; always marked modified.
        """
        for node in graph_module.graph.nodes:
            if node.op != "placeholder" or node.meta["val"].dim() != 0:
                continue
            name = self._get_target_name(node)
            if name is None:
                # Rank-0 user inputs have no stored tensor to rewrite; skip them.
                continue

            store = self._get_tensor_store(name)
            tensor = store[name]
            if tensor.dim() == 0:
                tensor = tensor.unsqueeze(0)
                store[name] = tensor
            # Rebuild the fake value so the placeholder's metadata matches the
            # (possibly unsqueezed) stored tensor.
            node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
                tensor, static_shapes=True
            )

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)

    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        """Check that no buffer/parameter placeholder is still rank 0.

        Rank-0 user inputs are skipped by ``call`` and are therefore not rejected.

        Raises:
            ValueError: if a rank-0 buffer or parameter placeholder remains.
        """
        for node in graph_module.graph.nodes:
            if node.op != "placeholder" or node.meta["val"].dim() != 0:
                continue
            if self._get_target_name(node) is not None:
                raise ValueError("Placeholders of rank 0 are not supported!")
56