#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, Union

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSDataType,
    MPSFull,
    MPSFullLike,
    MPSGraph,
    MPSNode,
)
from executorch.backends.apple.mps.utils.mps_utils import (
    edge_dtype_to_mps_dtype,
    get_input_node,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.sym_util import eval_shape


def _serialize_fill_value(fill_value: float) -> Union[float, str]:
    """Map infinite fill values to the string form stored in the MPS graph.

    ``float("-inf")`` becomes ``"-inf"`` and ``float("inf")`` becomes
    ``"inf"``; every other value is returned unchanged. Shared by all
    fill-style visitors so ``full`` and ``full_like`` encode infinities the
    same way.

    NOTE(review): the string encoding is presumably required because the
    graph serialization path cannot emit IEEE infinities as plain numbers --
    confirm against the MPS graph serializer.
    """
    if fill_value == float("-inf"):
        return "-inf"
    if fill_value == float("inf"):
        return "inf"
    return fill_value


@register_node_visitor
class ConstantOpVisitor(NodeVisitor):
    """Lower ``full``, ``empty.memory_format`` and ``scalar_tensor`` to ``MPSFull``."""

    target = [
        "aten.full.default",
        "aten.empty.memory_format",
        "aten.scalar_tensor.default",
    ]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSFull`` node describing ``node``'s output to ``mps_graph``.

        Args:
            node: Edge-dialect ``full``, ``empty.memory_format`` or
                ``scalar_tensor`` call.
            mps_graph: Graph that receives the output tensor and the new node.

        Raises:
            AssertionError: If ``node`` has three or more positional args, or
                its target is not one handled by this visitor.
        """
        if len(node.args) >= 3:
            raise AssertionError("Unexpected number of input parameters")

        if node.target == exir_ops.edge.aten.full.default:
            # full(size, fill_value, ...)
            shape = eval_shape(node.args[0])
            fill_value = cast(float, node.args[1])
        elif node.target == exir_ops.edge.aten.empty.memory_format:
            # empty(size, ...): the output is materialized as a zero fill.
            shape = eval_shape(node.args[0])
            fill_value = 0
        elif node.target == exir_ops.edge.aten.scalar_tensor.default:
            # scalar_tensor(value, ...): emitted as a one-element tensor.
            shape = [1]
            fill_value = cast(float, node.args[0])
        else:
            # Guard against an unbound fill_value if the target list grows.
            raise AssertionError(f"Unsupported target for MPSFull: {node.target}")

        # float32 unless the node carries an explicit, non-None dtype kwarg.
        dtype = MPSDataType.mps_data_type_float32
        requested_dtype = node.kwargs.get("dtype")
        if requested_dtype is not None:
            dtype = edge_dtype_to_mps_dtype(requested_dtype)

        output_id = self.define_tensor(node, mps_graph)
        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSFull(
                    output_id=output_id,
                    shape=shape,
                    fill_value=_serialize_fill_value(fill_value),
                    dtype=dtype,
                )
            )
        )


@register_node_visitor
class FullLikeVisitor(NodeVisitor):
    """Lower ``aten.full_like`` to an ``MPSFullLike`` node."""

    target = "aten.full_like.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSFullLike`` node for ``full_like(input, fill_value)``.

        The output dtype defaults to the serialized dtype of the input node
        and is overridden by an explicit, non-None ``dtype`` kwarg.

        Args:
            node: Edge-dialect ``full_like`` call with exactly two positional
                args (input tensor, fill value).
            mps_graph: Graph that receives the new node.

        Raises:
            AssertionError: If ``node`` does not have exactly two positional
                args.
        """
        # Validate arity before create_unary_node runs, so a malformed node
        # is rejected before anything is created for it (create_unary_node's
        # effects on mps_graph are not visible here).
        if len(node.args) < 2:
            raise AssertionError(
                "FullLike op requires at least input & fill_value args"
            )
        if len(node.args) >= 3:
            raise AssertionError("Unexpected number of input parameters")

        mps_node = self.create_unary_node(node, mps_graph, MPSFullLike)

        # Encode infinities the same way ConstantOpVisitor does for MPSFull.
        mps_node.mpsnode_union.fill_value = _serialize_fill_value(
            cast(float, node.args[1])
        )
        mps_node.mpsnode_union.dtype = self.get_serialized_dtype(
            get_input_node(node, 0)
        )
        requested_dtype = node.kwargs.get("dtype")
        if requested_dtype is not None:
            mps_node.mpsnode_union.dtype = edge_dtype_to_mps_dtype(requested_dtype)

        mps_graph.mps_nodes.append(mps_node)