• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant
8from executorch.exir.dialects._ops import ops as exir_ops
9from executorch.exir.pass_base import ExportPass, PassResult
10from torch._subclasses.fake_tensor import FakeTensor
11
12
class I64toI32(ExportPass):
    """
    Cast unsupported int64 datatype into int32.

    Constant nodes whose parameter is int64 get their ``meta["val"]`` re-typed
    to int32. Any other placeholder whose ``meta["val"]`` is an int64
    ``FakeTensor`` is followed by an inserted ``_to_copy`` node
    (``dtype=torch.int32``) that takes over all of the placeholder's users.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        self.edge_program = edge_program
        # pyre-ignore[4]
        self.copy_op = exir_ops.edge.aten._to_copy.default

    def _update_meta(self, node: torch.fx.Node) -> None:
        """
        Rewrite ``node.meta["val"]`` so that int64 entries become int32.

        Handles both a single value and a tuple of values; entries that are
        not int64 are left untouched.
        """
        meta_val = node.meta["val"]
        if isinstance(meta_val, tuple):
            # Materialize a real tuple: a bare generator expression could be
            # consumed only once and would no longer be recognised as a tuple.
            node.meta["val"] = tuple(
                (
                    fake_tensor.to(torch.int32)
                    if fake_tensor.dtype == torch.int64
                    else fake_tensor
                )
                for fake_tensor in meta_val
            )
        elif meta_val.dtype == torch.int64:
            # The target dtype is int32 (not float), matching the tuple branch.
            node.meta["val"] = meta_val.to(torch.int32)

    # pyre-ignore[2]
    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
        """Return True when ``node_val`` is a ``FakeTensor`` of ``dtype``."""
        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

    def _cast_to_int32(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Walk the graph and retarget int64 constants / placeholders to int32.

        Constants only have their metadata updated; remaining int64
        placeholders get an explicit ``_to_copy`` cast node inserted right
        after them.
        """
        for n in graph_module.graph.nodes:
            if is_constant(n, self.edge_program):
                param = get_parameter(n, self.edge_program)
                if param.dtype == torch.int64:
                    # QNN does not support int64
                    self._update_meta(n)
            elif n.op == "placeholder":
                node_val = n.meta["val"]
                if self._is_tensor_of_dtype(node_val, torch.int64):
                    with graph_module.graph.inserting_after(n):
                        args = (n,)
                        to_dst_node = graph_module.graph.create_node(
                            "call_function",
                            self.copy_op,
                            args,
                            {"dtype": torch.int32},
                        )
                        to_dst_node.meta["val"] = node_val.to(torch.int32)

                        # Replace usage of the src dtype result with the dst dtype result.
                        n.replace_all_uses_with(to_dst_node)
                        # replace_all_uses_with also rewired the cast node's own
                        # input to itself; point it back at the placeholder.
                        to_dst_node.args = (n,)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the int64 -> int32 rewrite, recompile, then run the base pass.

        Returns a ``PassResult`` wrapping the graph module produced by
        ``ExportPass.call``, always flagged as modified.
        """
        self._cast_to_int32(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
72