# # Copyright (c) 2023 Apple Inc. All rights reserved. # Provided subject to the LICENSE file in the top level directory. # from typing import cast, Optional, Union import torch from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSDataType from executorch.exir import ExportedProgram from torch._export.utils import get_buffer, get_param, is_buffer, is_param def get_input_node(node: torch.fx.Node, input_index: int) -> Union[torch.fx.Node, None]: return None if node is None else cast(torch.fx.Node, node.args[input_index]) def get_scalar_val(node: torch.fx.Node, input_index: int) -> Union[float, int]: return node.args[input_index] def edge_dtype_to_mps_dtype(dtype: torch.dtype): if not hasattr(edge_dtype_to_mps_dtype, "map"): edge_dtype_to_mps_dtype.map = { torch.float16: MPSDataType.mps_data_type_float16, torch.float32: MPSDataType.mps_data_type_float32, torch.float64: MPSDataType.mps_data_type_float32, torch.bfloat16: MPSDataType.mps_data_type_bfloat16, torch.int8: MPSDataType.mps_data_type_int8, torch.int16: MPSDataType.mps_data_type_int16, torch.int32: MPSDataType.mps_data_type_int32, torch.int64: MPSDataType.mps_data_type_int64, torch.uint8: MPSDataType.mps_data_type_uint8, torch.bool: MPSDataType.mps_data_type_bool, torch.cfloat: MPSDataType.mps_data_type_complex_float32, torch.chalf: MPSDataType.mps_data_type_complex_float16, } try: return edge_dtype_to_mps_dtype.map[dtype] except KeyError: raise RuntimeError(f"Invalid data type: {dtype}") def get_param_tensor( exp_prog: ExportedProgram, node: torch.fx.Node ) -> Optional[torch.Tensor]: if node is None: return None elif is_param(exp_prog, node): return get_param(exp_prog, node) elif is_buffer(exp_prog, node): return get_buffer(exp_prog, node) elif is_get_attr(node): # Support both lifted and unlifted graph try: # Unlifted graph (coming from old exir.capture API) return getattr(node.graph.owning_module, node.target) except AttributeError: return getattr(exp_prog.graph_module, node.target) raise RuntimeError(f"unsupported param type, {node.op}.") def is_get_attr(node: torch.fx.Node): """ Returns true if the given node is a get attr node for a tensor of the model """ return isinstance(node, torch.fx.Node) and node.op == "get_attr" def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool: """ Check if a node is a lifted parameter (static data like weights and bias are are supplied as inputs to the graph. Args: exp_prog (torch.export.ExportedProgram): _description_ node (torch.fx.Node): _description_ Returns: bool: _description_ """ return is_get_attr(node) or is_param(exp_prog, node) or is_buffer(exp_prog, node)