#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, List, Optional, Tuple

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSEmbedding,
    MPSGraph,
    MPSIndexPut,
    MPSIndexSelect,
    MPSIndexTensor,
    MPSScatter,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
from executorch.backends.transforms import get_shape
from executorch.exir.sym_util import eval_expr


@register_node_visitor
class IndexSelectVisitor(NodeVisitor):
    """Lower ``aten.index_select.default`` to an ``MPSIndexSelect`` node."""

    target = "aten.index_select.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` into ``mps_graph``.

        The base node comes from ``create_unary_node``; ``node.args[1]`` is
        stored as the selection dimension and ``node.args[2]`` (the index
        tensor) is defined in the graph and referenced by its id.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSIndexSelect)
        mps_node.mpsnode_union.dim = cast(int, node.args[1])
        mps_node.mpsnode_union.index_id = self.define_tensor(
            get_input_node(node, 2), mps_graph
        )

        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class IndexTensorVisitor(NodeVisitor):
    """Lower ``aten.index.Tensor`` to an ``MPSIndexTensor`` node."""

    target = "aten.index.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` into ``mps_graph``.

        Every entry of the index list ``node.args[1]`` is defined as a graph
        tensor and its id appended to ``indices_id`` in order.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSIndexTensor)
        # NOTE(review): the aten schema allows ``None`` entries in the index
        # list; each entry is passed to define_tensor unchanged here --
        # confirm ``None`` is handled or rejected upstream.
        tensors = cast(List[torch.fx.Node], node.args[1])
        for tensor in tensors:
            mps_node.mpsnode_union.indices_id.append(
                self.define_tensor(tensor, mps_graph)
            )

        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class IndexPutVisitor(NodeVisitor):
    """Lower ``aten.index_put.default`` to an ``MPSIndexPut`` node."""

    target = "aten.index_put.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def infer_sizes(self, a: List[int], b: List[int]) -> List[int]:
        """Merge two shapes aligned on their trailing dimensions.

        The result has ``max(len(a), len(b))`` entries. For each position the
        size from ``b`` is used when ``b`` has that dimension, otherwise the
        size from ``a``.

        NOTE(review): unlike torch broadcasting, a size of 1 in ``b`` is not
        expanded to ``a``'s size -- confirm this matches what the MPS runtime
        expects for ``values_shape``.
        """
        dimsA = len(a)
        dimsB = len(b)
        ndim = dimsA if dimsA > dimsB else dimsB
        expandedSizes = [0] * ndim
        for i in range(ndim - 1, -1, -1):
            offset = ndim - 1 - i
            dimA = dimsA - 1 - offset
            dimB = dimsB - 1 - offset
            # -1 marks "this shape has no dimension at this position".
            sizeA = a[dimA] if dimA >= 0 else -1
            sizeB = b[dimB] if dimB >= 0 else -1
            expandedSizes[i] = sizeA if sizeB == -1 else sizeB

        return expandedSizes

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` into ``mps_graph``.

        ``values_shape`` is left empty unless the rank of the values tensor
        (``node.args[2]``) is neither 1 nor the rank of the input
        (``node.args[0]``); in that case it is ``infer_sizes(input, values)``.
        The index tensors (``node.args[1]``) and the values tensor are defined
        in the graph and referenced by id.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSIndexPut)
        updates_shape = get_shape(node.args[2])
        input_shape = get_shape(node.args[0])
        new_shape = []
        if len(updates_shape) != 1 and len(updates_shape) != len(input_shape):
            new_shape = self.infer_sizes(input_shape, updates_shape)
        mps_node.mpsnode_union.values_shape = new_shape

        tensors = cast(List[torch.fx.Node], node.args[1])
        for tensor in tensors:
            mps_node.mpsnode_union.indices_id.append(
                self.define_tensor(tensor, mps_graph)
            )

        mps_node.mpsnode_union.values_id = self.define_tensor(
            get_input_node(node, 2), mps_graph
        )
        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class SliceScatterVisitor(NodeVisitor):
    """Lower ``aten.slice_scatter.default`` to an ``MPSScatter`` node.

    The slice is turned into an explicit index tensor (built with
    ``torch.arange`` and expanded to the input shape) that is stored as a
    graph constant.
    """

    target = "aten.slice_scatter.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # INT64_MAX; an ``end`` argument equal to this is treated as "no end".
        self.invalid_val = 2**63 - 1

    def maybe_wrap_dim(self, dim: int, n: int) -> int:
        """Map ``dim`` into ``[0, n]`` for a rank-``n`` shape.

        Negative dims are wrapped by adding ``n`` and clamped at 0; dims
        greater than ``n`` are clamped to ``n``.
        """
        if dim < 0:
            wrapped_dim = dim + n
            if wrapped_dim < 0:
                wrapped_dim = 0
            return wrapped_dim
        elif dim > n:
            return n
        return dim

    def get_exapnded_index(self, idx, shape, dim):
        """Broadcast index vector ``idx`` along ``dim`` to match ``shape``.

        A 0-dim ``idx`` is expanded to ``shape`` directly. Otherwise ``idx``
        is viewed as ``[1, ..., len(idx), ..., 1]`` (``len(idx)`` at ``dim``)
        and expanded to ``shape`` with ``shape[dim]`` replaced by
        ``len(idx)``. (The method name's spelling is kept for compatibility.)
        """
        if idx.dim() == 0:
            return idx.expand(shape)

        dim = self.maybe_wrap_dim(dim, len(shape))

        # setup new_index_shape as [BS, 1, ..., idx_size, ..., 1]
        # to reshape index_
        idx_size = idx.size(0)
        new_index_shape = [1] * len(shape)
        new_index_shape[dim] = idx_size

        # Now apply expand to index_
        index = idx.view(new_index_shape)
        new_index_shape = list(shape)
        new_index_shape[dim] = idx_size
        index = index.expand(new_index_shape)

        return index

    def get_slice_scatter_indices(
        self, dim, start, end, step, input_shape, dtype=torch.int64
    ):
        """Return ``arange(start, end, step)`` expanded along ``dim``."""
        idx = torch.arange(start, end, step, dtype=dtype)
        return self.get_exapnded_index(idx, input_shape, dim)

    @staticmethod
    def _normalize_slice_bounds(
        start: Optional[int], end: Optional[int], dim_len: int
    ) -> Tuple[int, int]:
        """Resolve slice bounds the way ``aten.slice`` does.

        ``None`` means 0 for ``start`` and ``dim_len`` for ``end``. Negative
        bounds are wrapped by ``dim_len``. Both are then clamped to
        ``[0, dim_len]`` with ``end`` never below ``start``, so an empty
        slice yields an empty range instead of an ``arange`` error.
        """
        start_val = 0 if start is None else start
        end_val = dim_len if end is None else end
        if start_val < 0:
            start_val += dim_len
        if end_val < 0:
            end_val += dim_len
        start_val = min(max(start_val, 0), dim_len)
        end_val = min(max(end_val, start_val), dim_len)
        return start_val, end_val

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` as an ``MPSScatter`` with precomputed indices.

        Positional arguments read from ``node.args`` when present:
        ``[1]`` src, ``[2]`` dim, ``[3]`` start, ``[4]`` end, ``[5]`` step.
        Missing start/end default to the full extent of ``dim``; step
        defaults to 1.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSScatter)

        start: Optional[int] = None
        end: Optional[int] = None
        step = 1

        mps_node.mpsnode_union.src_id = self.define_tensor(
            get_input_node(node, 1), mps_graph
        )
        if len(node.args) >= 3:
            mps_node.mpsnode_union.dim = cast(int, node.args[2])
        if len(node.args) >= 4:
            start = cast(Optional[int], node.args[3])
        if len(node.args) >= 5 and node.args[4] != self.invalid_val:
            end = cast(Optional[int], node.args[4])
        if len(node.args) >= 6:
            step = cast(int, node.args[5])

        input_shape = get_shape(get_input_node(node, 0))
        dim_len = input_shape[
            self.maybe_wrap_dim(mps_node.mpsnode_union.dim, len(input_shape))
        ]

        # Wrap negative and clamp out-of-range bounds so that, for a positive
        # step, every generated index lies in [0, dim_len).
        start_val, end_val = self._normalize_slice_bounds(start, end, dim_len)

        scatter_indices = self.get_slice_scatter_indices(
            mps_node.mpsnode_union.dim, start_val, end_val, step, input_shape
        )
        mps_node.mpsnode_union.idx_id = self.define_constant(scatter_indices, mps_graph)
        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class EmbeddingVisitor(NodeVisitor):
    """Lower ``aten.embedding.default`` to an ``MPSEmbedding`` node."""

    target = "aten.embedding.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` into ``mps_graph``.

        The base node comes from ``create_binary_node``. Optional positional
        arguments, when present, are copied onto the node: ``[2]``
        padding_idx (evaluated with ``eval_expr``), ``[3]``
        scale_grad_by_freq, ``[4]`` sparse.
        """
        n_args = len(node.args)
        mps_node = self.create_binary_node(
            node,
            mps_graph,
            MPSEmbedding,
        )

        if n_args >= 3:
            # padding_idx may be symbolic; eval_expr resolves it to a value.
            mps_node.mpsnode_union.padding_idx = eval_expr(
                cast(torch.SymInt, node.args[2])
            )
        if n_args >= 4:
            mps_node.mpsnode_union.scale_grad_by_freq = cast(bool, node.args[3])
        if n_args >= 5:
            mps_node.mpsnode_union.sparse = cast(bool, node.args[4])
        mps_graph.mps_nodes.append(mps_node)