//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

using executorch::aten::ScalarType;

/*
Applies a binary graph operation to two tensors after bringing them to a common type.

The common type is computed with executorch::runtime::promoteTypes() from the scalar
types of both operands. Only an operand whose type differs from the common type is
cast (via castMPSTensor); the other is passed through untouched.

primaryTensor     - left-hand operand.
secondaryTensor   - right-hand operand.
mpsGraph          - graph in which the cast nodes are created.
binaryOpFunction  - callback that receives the (possibly cast) operands and returns
                    the resulting tensor.

Returns whatever binaryOpFunction returns; no cast is applied to the result here.
*/
MPSGraphTensor*
binaryOpTensor(
  MPSGraphTensor* primaryTensor,
  MPSGraphTensor* secondaryTensor,
  MPSGraph* mpsGraph,
  std::function<MPSGraphTensor*(MPSGraphTensor*, MPSGraphTensor*)> binaryOpFunction) {
  MPSDataType mpsInputDataType = [primaryTensor dataType];
  MPSDataType mpsOtherDataType = [secondaryTensor dataType];

  ScalarType inputDataType = getScalarType(mpsInputDataType);
  ScalarType otherDataType = getScalarType(mpsOtherDataType);

  MPSGraphTensor* primaryCastTensor = primaryTensor;
  MPSGraphTensor* secondaryCastTensor = secondaryTensor;
  ScalarType commonDataType = executorch::runtime::promoteTypes(inputDataType, otherDataType);
  if (inputDataType != commonDataType) {
    primaryCastTensor = castMPSTensor(mpsGraph, primaryTensor, commonDataType);
  }
  if (otherDataType != commonDataType) {
    secondaryCastTensor = castMPSTensor(mpsGraph, secondaryTensor, commonDataType);
  }

  return binaryOpFunction(primaryCastTensor, secondaryCastTensor);
}

/*
Helper macro to create an MPSGraph node based on the serialized data from the FlatBuffer.
It takes 2 inputs and an alpha parameter, and returns one output. A couple of PyTorch
operators, such as torch.sub and torch.add, take an additional alpha param.
More info at https://pytorch.org/docs/stable/generated/torch.sub.html.

When alpha differs from 1.0, the second operand is multiplied by a constant alpha tensor
(created with the data type of the first, already promoted, operand) before graph_op is
applied. The result is stored in _idToMPSGraphTensor under the node's output id.
*/
#define REGISTER_BINARY_WITH_ALPHA_OP(aot_name, graph_op) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
  ET_LOG( \
    Debug, "%s: (%d, %d) -> %d", \
    __FUNCTION__, \
    graphNode->input1_id(), \
    graphNode->input2_id(), \
    graphNode->output_id() \
  ); \
  \
  _idToMPSGraphTensor[graphNode->output_id()] = binaryOpTensor( \
    getMPSGraphTensor(graphNode->input1_id()), \
    getMPSGraphTensor(graphNode->input2_id()), \
    _mpsGraph, \
    [&](MPSGraphTensor* primaryCastTensor, \
        MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \
      if (graphNode->alpha() != 1.0) { \
        MPSGraphTensor* alphaTensor = [_mpsGraph constantWithScalar:graphNode->alpha() \
                                                              shape:@[@1] \
                                                           dataType:primaryCastTensor.dataType]; \
        secondaryCastTensor = [_mpsGraph multiplicationWithPrimaryTensor:secondaryCastTensor \
                                                         secondaryTensor:alphaTensor \
                                                                    name:nil]; \
      } \
      return [_mpsGraph graph_op##WithPrimaryTensor:primaryCastTensor \
                                    secondaryTensor:secondaryCastTensor \
                                               name:nil]; \
    } \
  ); \
  return Error::Ok; \
}

/*
Helper macro to create an MPSGraph node based on the serialized data from the FlatBuffer.
It takes 2 inputs and returns one output.
*/
#define REGISTER_BINARY_OP(aot_name, graph_op) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
  ET_LOG( \
    Debug, "%s: (%d, %d) -> %d", \
    __FUNCTION__, \
    graphNode->input1_id(), \
    graphNode->input2_id(), \
    graphNode->output_id() \
  ); \
  \
  _idToMPSGraphTensor[graphNode->output_id()] = binaryOpTensor( \
    getMPSGraphTensor(graphNode->input1_id()), \
    getMPSGraphTensor(graphNode->input2_id()), \
    _mpsGraph, \
    [&](MPSGraphTensor* primaryCastTensor, \
        MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \
      return [_mpsGraph graph_op##WithPrimaryTensor:primaryCastTensor \
                                    secondaryTensor:secondaryCastTensor \
                                               name:nil]; \
    } \
  ); \
  \
  return Error::Ok; \
}

/*
Helper macro to create a bitwise/logical MPSGraph node from the serialized FlatBuffer data.
It takes 2 inputs and returns one output. The builder returns Error::NotSupported when
is_macos_13_or_newer() is false. When the promoted operand type is Bool the
logical<graph_op> graph method is used; for every other type the bitwise<graph_op>
method is used.
*/
#define REGISTER_BITWISE_BINARY_OP(aot_name, graph_op) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
  ET_LOG( \
    Debug, "%s: (%d, %d) -> %d", \
    __FUNCTION__, \
    graphNode->input1_id(), \
    graphNode->input2_id(), \
    graphNode->output_id() \
  ); \
  ET_CHECK_OR_RETURN_ERROR( \
    is_macos_13_or_newer(), NotSupported, \
    "%s supported by MPS on MacOS13.0+/iOS16.1+", #aot_name); \
  \
  _idToMPSGraphTensor[graphNode->output_id()] = binaryOpTensor( \
    getMPSGraphTensor(graphNode->input1_id()), \
    getMPSGraphTensor(graphNode->input2_id()), \
    _mpsGraph, \
    [&](MPSGraphTensor* primaryCastTensor, \
        MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \
      MPSDataType mpsInputDataType = [primaryCastTensor dataType]; \
      if (getScalarType(mpsInputDataType) == ScalarType::Bool) { \
        return [_mpsGraph logical##graph_op##WithPrimaryTensor:primaryCastTensor \
                                               secondaryTensor:secondaryCastTensor \
                                                          name:nil]; \
      } \
      return [_mpsGraph bitwise##graph_op##WithPrimaryTensor:primaryCastTensor \
                                             secondaryTensor:secondaryCastTensor \
                                                        name:nil]; \
    } \
  ); \
  \
  return Error::Ok; \
}

// Arithmetic Binary Ops
REGISTER_BINARY_WITH_ALPHA_OP(Add, addition)
REGISTER_BINARY_WITH_ALPHA_OP(Sub, subtraction)
REGISTER_BINARY_OP(Mul, multiplication)
REGISTER_BINARY_OP(Pow, power)
REGISTER_BINARY_OP(Minimum, minimum)

// Boolean Binary ops
REGISTER_BINARY_OP(Eq, equal)
REGISTER_BINARY_OP(Ne, notEqual)
REGISTER_BINARY_OP(Ge, greaterThanOrEqualTo)
REGISTER_BINARY_OP(Gt, greaterThan)
REGISTER_BINARY_OP(Le, lessThanOrEqualTo)
REGISTER_BINARY_OP(Lt, lessThan)

// Bitwise Binary ops
REGISTER_BITWISE_BINARY_OP(BitwiseAnd, AND)
REGISTER_BITWISE_BINARY_OP(BitwiseOr, OR)
REGISTER_BITWISE_BINARY_OP(BitwiseXor, XOR)

// All three registration macros are local to this section; undefine each of them so
// none leaks into the rest of the translation unit.
#undef REGISTER_BINARY_WITH_ALPHA_OP
#undef REGISTER_BINARY_OP
#undef REGISTER_BITWISE_BINARY_OP

/*
Rounds a tensor toward zero.

Non-floating-point tensors are returned unchanged. For floating-point tensors, when
is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS) is false the truncation is
emulated as select(x < 0, ceil(x), floor(x)); otherwise truncateWithTensor is used.
*/
static
MPSGraphTensor* mpsTruncTensor(MPSGraphTensor* inputTensor, MPSGraph* mpsGraph) {
  // Rounding is a no-op for integral types, and also a reasonable workaround
  // for an MPSGraph bug on Apple Silicon that throws
  // `Function floorOp_i64 was not found in the library`.
  // See https://github.com/pytorch/pytorch/issues/84995
  bool isFloatInput = ([inputTensor dataType] & MPSDataTypeFloatBit) != 0;
  if (!isFloatInput) {
    return inputTensor;
  }

  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS)) {
    MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
    MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
                                                          secondaryTensor:zeroTensor
                                                                     name:nil];
    // Negative values round up (ceil), non-negative values round down (floor).
    return [mpsGraph selectWithPredicateTensor:predicateTensor
                           truePredicateTensor:[mpsGraph ceilWithTensor:inputTensor name:nil]
                          falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
                                          name:nil];
  } else {
    return [mpsGraph truncateWithTensor:inputTensor
                                   name:nil];
  }
}

/*
Builds the graph nodes for division with an optional rounding mode, and for the
Fmod / Remainder operators that are derived from it.

primaryTensor    - dividend.
secondaryTensor  - divisor.
rounding_mode    - std::nullopt for plain division, otherwise "trunc" or "floor".
mpsGraph         - graph in which the nodes are created.
op_name          - operator name; "Fmod" and "Remainder" select the derived results below.

Behavior:
  - Aborts via ET_CHECK_MSG when rounding_mode is "trunc" and the primary input is Half.
  - Operands are promoted to a common type through binaryOpTensor().
  - Non-float operands are cast to Float32 first when the mode is "floor" or "trunc".
  - No rounding mode, or a non-float quotient: the raw quotient is returned.
  - "trunc": trunc(a / b); for op_name "Fmod", a - trunc(a / b) * b.
  - "floor": floor(a / b); for op_name "Remainder", a - floor(a / b) * b.
  - Any other mode on a float quotient aborts via ET_CHECK_MSG.

NOTE(review): when integer operands are cast to Float32 the result stays Float32; no cast
back to the promoted integer type happens in this function - confirm this is intended.
*/
static
MPSGraphTensor* divModeTemplate(
  MPSGraphTensor* primaryTensor,
  MPSGraphTensor* secondaryTensor,
  std::optional<flatbuffers::string_view> rounding_mode,
  MPSGraph* mpsGraph,
  const std::string& op_name) {
  MPSDataType mpsInputDataType = [primaryTensor dataType];
  ScalarType inputDataType = getScalarType(mpsInputDataType);

  if (rounding_mode.has_value() && *rounding_mode == "trunc") {
    ET_CHECK_MSG(inputDataType != ScalarType::Half,
                 "MPS: does not support trunc_divide op with float16 input");
  }

  auto divOpFunc = [&](MPSGraphTensor* primaryCastTensor,
                       MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* {
    bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
    if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
      primaryCastTensor = [mpsGraph castTensor:primaryCastTensor
                                        toType:MPSDataTypeFloat32
                                          name:@"primaryCastTensor"];
      secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
                                          toType:MPSDataTypeFloat32
                                            name:@"secondaryCastTensor"];
    }
    MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
                                                    secondaryTensor:secondaryCastTensor
                                                               name:nil];

    // Rounding is a no-op for integral types, and also a reasonable workaround
    // for an MPSGraph bug on Apple Silicon that throws
    // `Function floorOp_i64 was not found in the library`.
    // See https://github.com/pytorch/pytorch/issues/84995
    bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
    if (!rounding_mode.has_value() || !isFloatOutput) {
      return divTensor;
    } else if (*rounding_mode == "trunc") {
      auto truncTensor = mpsTruncTensor(divTensor, mpsGraph);
      if (op_name == "Fmod") {
        // fmod(a, b) = a - trunc(a / b) * b
        auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
                                                   secondaryTensor:secondaryCastTensor
                                                              name:nil];
        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
                                      secondaryTensor:mulTensor
                                                 name:nil];
      }
      return truncTensor;
    } else if (*rounding_mode == "floor") {
      MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil];
      if (op_name == "Remainder") {
        // remainder(a, b) = a - floor(a / b) * b
        auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
                                                   secondaryTensor:secondaryCastTensor
                                                              name:nil];
        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
                                      secondaryTensor:mulTensor
                                                 name:nil];
      }
      return floorTensor;
    }

    // A plain assert() is compiled out under NDEBUG, which would let a null tensor be
    // returned and stored by the caller. Fail loudly in every build configuration.
    ET_CHECK_MSG(false, "MPS: invalid rounding mode for %s", op_name.c_str());
    return nullptr;
  };
  return binaryOpTensor(primaryTensor, secondaryTensor, mpsGraph, divOpFunc);
}

/*
Helper macro to create a division-family MPSGraph node from the serialized FlatBuffer data.
The rounding mode serialized on the node is used when it is present; otherwise the
macro's round_mode argument is used. The result of divModeTemplate() is stored in
_idToMPSGraphTensor under the node's output id.
*/
#define REGISTER_DIV_OP(aot_name, round_mode) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
  ET_LOG( \
    Debug, "%s: (%d, %d) -> %d", \
    __FUNCTION__, \
    graphNode->input1_id(), \
    graphNode->input2_id(), \
    graphNode->output_id() \
  ); \
  \
  auto strView = graphNode->rounding_mode() != nullptr ? \
    std::make_optional(graphNode->rounding_mode()->string_view()) : round_mode; \
  \
  _idToMPSGraphTensor[graphNode->output_id()] = divModeTemplate( \
    getMPSGraphTensor(graphNode->input1_id()), \
    getMPSGraphTensor(graphNode->input2_id()), \
    strView, \
    _mpsGraph, \
    #aot_name \
  ); \
  \
  return Error::Ok; \
}

REGISTER_DIV_OP(Div, std::nullopt)
REGISTER_DIV_OP(Fmod, "trunc")
REGISTER_DIV_OP(Remainder, "floor")

#undef REGISTER_DIV_OP


} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch