#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, List

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSCat,
    MPSExpand,
    MPSGraph,
    MPSNode,
    MPSPermute,
    MPSPixelShuffle,
    MPSSelect,
    MPSSlice,
    MPSSplitWithSizes,
    MPSSqueeze,
    MPSUnsqueeze,
    MPSView,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
from executorch.backends.transforms import get_shape
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.sym_util import eval_expr, eval_shape


@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """Lower ``aten.permute_copy.default`` to an ``MPSPermute`` node."""

    target = "aten.permute_copy.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        mps_node = self.create_unary_node(node, mps_graph, MPSPermute)

        # node.args[1] holds the permutation order; its length is the rank.
        permute_order = cast(List[int], node.args[1])
        mps_node.mpsnode_union.num_dims = len(permute_order)
        mps_node.mpsnode_union.perm = permute_order

        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class ViewExpandVisitor(NodeVisitor):
    """Lower ``aten.view_copy`` / ``aten.expand_copy`` to ``MPSView`` / ``MPSExpand``.

    Both ops carry a target shape in ``node.args[1]`` and serialize the same
    ``num_dims`` / ``shape`` fields, so they share one visitor.
    """

    target = ["aten.view_copy.default", "aten.expand_copy.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        node_type = (
            MPSView
            if node.target is exir_ops.edge.aten.view_copy.default
            else MPSExpand
        )
        mps_node = self.create_unary_node(node, mps_graph, node_type)

        view_shape = cast(List[int], node.args[1])
        mps_node.mpsnode_union.num_dims = len(view_shape)
        mps_node.mpsnode_union.shape = view_shape

        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class CatVisitor(NodeVisitor):
    """Lower ``aten.cat.default`` to an ``MPSCat`` node.

    ``dim`` defaults to 0 when omitted. A negative ``dim`` is normalized using
    the rank of the first input tensor.
    """

    target = "aten.cat.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        tensors = cast(List[torch.fx.Node], node.args[0])
        # The output tensor is defined before the inputs, matching the
        # original definition order.
        output_id = self.define_tensor(node, mps_graph)
        input_ids: List[int] = [
            self.define_tensor(tensor, mps_graph) for tensor in tensors
        ]

        dim = 0
        if len(node.args) > 1:
            dim = cast(int, node.args[1])
        if dim < 0 and len(tensors) > 0:
            dim += len(get_shape(tensors[0]))

        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSCat(input_ids=input_ids, output_id=output_id, dim=dim),
            ),
        )


@register_node_visitor
class SqueezeUnsqueezeVisitor(NodeVisitor):
    """Lower ``aten.unsqueeze_copy.default`` / ``aten.squeeze_copy.dims``.

    Unsqueeze forwards its single ``dim``. Squeeze forwards only the requested
    dims whose input size is 1. PyTorch leaves a dim of any other size
    untouched when asked to squeeze it.
    """

    target = ["aten.unsqueeze_copy.default", "aten.squeeze_copy.dims"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        node_type = (
            MPSUnsqueeze
            if node.target is exir_ops.edge.aten.unsqueeze_copy.default
            else MPSSqueeze
        )

        mps_node = self.create_unary_node(node, mps_graph, node_type)

        if node_type is MPSUnsqueeze:
            mps_node.mpsnode_union.dim = cast(int, node.args[1])
        else:
            dims = cast(List[int], node.args[1])
            input_shape = get_shape(get_input_node(node, 0))
            mps_node.mpsnode_union.dims = [
                dim for dim in dims if input_shape[dim] == 1
            ]

        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class SelectVisitor(NodeVisitor):
    """Lower ``aten.select_copy.int`` to an ``MPSSelect`` node."""

    target = "aten.select_copy.int"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        mps_node = self.create_unary_node(node, mps_graph, MPSSelect)
        mps_node.mpsnode_union.dim = cast(int, node.args[1])
        # The index may be symbolic; eval_expr presumably resolves it to a
        # concrete int (assumption - confirm against exir.sym_util).
        mps_node.mpsnode_union.index = eval_expr(cast(torch.SymInt, node.args[2]))
        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class PixelShuffleVisitor(NodeVisitor):
    """Lower ``aten.pixel_shuffle.default`` to an ``MPSPixelShuffle`` node."""

    target = "aten.pixel_shuffle.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        mps_node = self.create_unary_node(node, mps_graph, MPSPixelShuffle)
        mps_node.mpsnode_union.upscale_factor = cast(int, node.args[1])
        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class SliceVisitor(NodeVisitor):
    """Lower ``aten.slice_copy.Tensor`` to an ``MPSSlice`` node.

    The op's positional schema is ``(self, dim=0, start=None, end=None, step=1)``.
    Trailing arguments may be absent from ``node.args``. ``start`` and ``end``
    are resolved against the length of the sliced dimension: negative values
    are wrapped, and the results are clamped to ``[0, dim_len]``.
    """

    target = "aten.slice_copy.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        mps_node = self.create_unary_node(node, mps_graph, MPSSlice)

        def maybe_wrap_dim(dim: int, n: int) -> int:
            """Wrap a negative ``dim`` by adding ``n``, then clamp to ``[0, n]``."""
            if dim < 0:
                wrapped_dim = dim + n
                if wrapped_dim < 0:
                    wrapped_dim = 0
                return wrapped_dim
            elif dim > n:
                return n
            return dim

        args = node.args
        if len(args) >= 2:
            mps_node.mpsnode_union.dim = cast(int, args[1])
        # start and end are independent optional arguments: read each one
        # whenever it is present so a start-only slice is not dropped.
        start = cast(int, args[2]) if len(args) >= 3 else None
        end = cast(int, args[3]) if len(args) >= 4 else None
        if len(args) >= 5:
            mps_node.mpsnode_union.step = cast(int, args[4])

        input_shape = get_shape(get_input_node(node, 0))
        dim_len = input_shape[
            maybe_wrap_dim(mps_node.mpsnode_union.dim, len(input_shape))
        ]

        # A None start/end (absent or explicitly passed) means "from the
        # beginning" / "to the end" of the dimension.
        start_val = start if start is not None else 0
        end_val = end if end is not None else dim_len

        mps_node.mpsnode_union.start = maybe_wrap_dim(start_val, dim_len)
        mps_node.mpsnode_union.end = maybe_wrap_dim(end_val, dim_len)
        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class SplitWithSizesVisitor(NodeVisitor):
    """Lower ``aten.split_with_sizes_copy.default`` to an ``MPSSplitWithSizes`` node.

    ``dim`` defaults to 0 when omitted. Negative dims are wrapped into
    ``[0, rank)`` the way PyTorch does. An out-of-range ``dim`` raises
    ``RuntimeError``.
    """

    target = "aten.split_with_sizes_copy.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        input_node = get_input_node(node, 0)
        input1_id = self.define_tensor(input_node, mps_graph)
        output_ids = self.define_tensor_list(node, mps_graph)
        split_sizes = eval_shape(cast(List[torch.SymInt], node.args[1]))
        # ``dim`` has a schema default of 0 and may be omitted from args.
        dim = cast(int, node.args[2]) if len(node.args) > 2 else 0
        input_shape = get_shape(input_node)
        rank = len(input_shape)

        # Accept negative dims (e.g. dim=-1) by wrapping before validation.
        wrapped_dim = dim + rank if dim < 0 else dim
        if wrapped_dim < 0 or wrapped_dim >= rank:
            raise RuntimeError(
                f"split_copy: dim {dim} out of range for input tensor with {rank} dimensions"
            )

        mps_node = MPSNode(
            mpsnode_union=MPSSplitWithSizes(
                input1_id=input1_id,
                output_ids=output_ids,
                split_sizes=split_sizes,
                dim=wrapped_dim,
            )
        )
        mps_graph.mps_nodes.append(mps_node)