# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._subclasses.fake_tensor import FakeTensor


class I64toI32(ExportPass):
    """
    Cast unsupported int64 datatype into int32.

    Two kinds of nodes are handled:

    * constants (as reported by ``is_constant``) whose parameter is int64
      get their ``meta["val"]`` rewritten to int32;
    * the remaining ``placeholder`` nodes whose ``meta["val"]`` is an int64
      ``FakeTensor`` get an ``aten._to_copy`` node (``dtype=torch.int32``)
      inserted right after them, and every former user of the placeholder
      is redirected to that cast node.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        """Store the exported program used to look up constant parameters."""
        super(I64toI32, self).__init__()
        self.edge_program = edge_program
        # pyre-ignore[4]
        self.copy_op = exir_ops.edge.aten._to_copy.default

    def _update_meta(self, node: torch.fx.Node) -> None:
        """Rewrite int64 entries of ``node.meta["val"]`` as int32.

        ``meta["val"]`` may be a single tensor-like value or a tuple of them.
        Values whose dtype is not int64 are left untouched.
        """
        meta_val = node.meta["val"]
        if isinstance(meta_val, tuple):
            # Materialize a real tuple: a bare generator expression would be
            # single-use and would silently change the type of meta["val"].
            node.meta["val"] = tuple(
                (
                    fake_tensor.to(torch.int32)
                    if fake_tensor.dtype == torch.int64
                    else fake_tensor
                )
                for fake_tensor in meta_val
            )
        elif meta_val.dtype == torch.int64:
            # The target dtype is int32 (not float), consistent with the
            # tuple branch above and with the cast inserted for placeholders.
            node.meta["val"] = meta_val.to(torch.int32)

    # pyre-ignore[2]
    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
        """Return True when ``node_val`` is a ``FakeTensor`` of ``dtype``."""
        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

    def _cast_to_int32(self, graph_module: torch.fx.GraphModule):
        """Walk the graph and convert int64 constants / placeholders to int32.

        The graph is modified in place; nothing is returned.
        """
        for n in graph_module.graph.nodes:
            if is_constant(n, self.edge_program):
                param = get_parameter(n, self.edge_program)
                if param.dtype == torch.int64:
                    # QNN does not support int64
                    # Only the node metadata is updated here; the parameter
                    # returned by get_parameter is not modified in this method.
                    self._update_meta(n)
            elif n.op == "placeholder":
                node_val = n.meta["val"]
                if self._is_tensor_of_dtype(node_val, torch.int64):
                    with graph_module.graph.inserting_after(n):
                        args = (n,)
                        to_dst_node = graph_module.graph.create_node(
                            "call_function",
                            self.copy_op,
                            args,
                            {"dtype": torch.int32},
                        )
                        to_dst_node.meta["val"] = node_val.to(torch.int32)

                        # Replace usage of the src dtype result with the dst dtype result.
                        n.replace_all_uses_with(to_dst_node)
                        # The cast node is itself a user of ``n``, so the call
                        # above also rewired its input; point it back at ``n``.
                        to_dst_node.args = (n,)

    def call(self, graph_module: torch.fx.GraphModule):
        """Run the pass on ``graph_module``.

        Applies the int32 conversion, recompiles the module, then runs the
        base ``ExportPass.call`` on it. The returned ``PassResult`` always
        reports the graph as modified.
        """
        self._cast_to_int32(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)