• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
14using executorch::aten::ScalarType;
15
/*
Applies a binary MPSGraph operation to two tensors after bringing both
operands to a common data type.

The scalar types of the operands are combined with
executorch::runtime::promoteTypes. An operand whose scalar type differs from
the promoted type is cast with castMPSTensor; an operand that already has the
promoted type is passed through unchanged.

primaryTensor     - left-hand operand.
secondaryTensor   - right-hand operand.
mpsGraph          - graph handed to castMPSTensor when a cast is needed.
binaryOpFunction  - callback invoked with the (possibly cast) operands.

Returns whatever binaryOpFunction returns.
*/
MPSGraphTensor*
binaryOpTensor(
  MPSGraphTensor* primaryTensor,
  MPSGraphTensor* secondaryTensor,
  MPSGraph* mpsGraph,
  std::function<MPSGraphTensor*(MPSGraphTensor*, MPSGraphTensor*)> binaryOpFunction) {
  const ScalarType lhsType = getScalarType([primaryTensor dataType]);
  const ScalarType rhsType = getScalarType([secondaryTensor dataType]);
  const ScalarType promotedType = executorch::runtime::promoteTypes(lhsType, rhsType);

  // Only insert a cast node for an operand that is not already of the promoted type.
  MPSGraphTensor* lhs = (lhsType == promotedType)
    ? primaryTensor
    : castMPSTensor(mpsGraph, primaryTensor, promotedType);
  MPSGraphTensor* rhs = (rhsType == promotedType)
    ? secondaryTensor
    : castMPSTensor(mpsGraph, secondaryTensor, promotedType);

  return binaryOpFunction(lhs, rhs);
}
40
/*
Defines MPSGraphBuilder::mps<aot_name>Op for a binary operator whose serialized
FlatBuffer node carries an additional `alpha` scalar. A couple of PyTorch
operators, such as torch.add and torch.sub, take this parameter; see
https://pytorch.org/docs/stable/generated/torch.sub.html.

The generated method looks up both input tensors, lets binaryOpTensor bring
them to a common type, multiplies the second operand by alpha when alpha is not
1.0, applies `graph_op` and stores the result under the node's output id.
*/
#define REGISTER_BINARY_WITH_ALPHA_OP(aot_name, graph_op) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
  ET_LOG( \
    Debug, "%s: (%d, %d) -> %d", \
    __FUNCTION__, \
    graphNode->input1_id(), \
    graphNode->input2_id(), \
    graphNode->output_id() \
  ); \
  \
  auto applyOp = [&](MPSGraphTensor* lhs, MPSGraphTensor* rhs) -> MPSGraphTensor* { \
    MPSGraphTensor* scaledRhs = rhs; \
    if (graphNode->alpha() != 1.0) { \
      /* alpha becomes a one-element constant in the left operand's data type */ \
      MPSGraphTensor* alphaTensor = [_mpsGraph constantWithScalar:graphNode->alpha() \
                                                             shape:@[@1] \
                                                          dataType:[lhs dataType]]; \
      scaledRhs = [_mpsGraph multiplicationWithPrimaryTensor:rhs \
                                             secondaryTensor:alphaTensor \
                                                        name:nil]; \
    } \
    return [_mpsGraph graph_op##WithPrimaryTensor:lhs \
                                  secondaryTensor:scaledRhs \
                                             name:nil]; \
  }; \
  \
  MPSGraphTensor* resultTensor = binaryOpTensor( \
    getMPSGraphTensor(graphNode->input1_id()), \
    getMPSGraphTensor(graphNode->input2_id()), \
    _mpsGraph, \
    applyOp); \
  _idToMPSGraphTensor[graphNode->output_id()] = resultTensor; \
  return Error::Ok; \
}
80
/*
Defines MPSGraphBuilder::mps<aot_name>Op for a plain binary operator described
by a serialized FlatBuffer node: two inputs, one output.

The generated method looks up both input tensors, lets binaryOpTensor bring
them to a common type, applies `graph_op` and stores the result under the
node's output id.
*/
#define REGISTER_BINARY_OP(aot_name, graph_op) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
  ET_LOG( \
    Debug, "%s: (%d, %d) -> %d", \
    __FUNCTION__, \
    graphNode->input1_id(), \
    graphNode->input2_id(), \
    graphNode->output_id() \
  ); \
  \
  auto applyOp = [&](MPSGraphTensor* lhs, MPSGraphTensor* rhs) -> MPSGraphTensor* { \
    return [_mpsGraph graph_op##WithPrimaryTensor:lhs \
                                  secondaryTensor:rhs \
                                             name:nil]; \
  }; \
  \
  MPSGraphTensor* resultTensor = binaryOpTensor( \
    getMPSGraphTensor(graphNode->input1_id()), \
    getMPSGraphTensor(graphNode->input2_id()), \
    _mpsGraph, \
    applyOp); \
  _idToMPSGraphTensor[graphNode->output_id()] = resultTensor; \
  \
  return Error::Ok; \
}
111
/*
Defines MPSGraphBuilder::mps<aot_name>Op for a bitwise binary operator described
by a serialized FlatBuffer node.

The generated method first goes through ET_CHECK_OR_RETURN_ERROR with
is_macos_13_or_newer() and the NotSupported error. It then looks up both input
tensors and lets binaryOpTensor bring them to a common type. When that common
type is Bool the logical<graph_op> graph method is used; for every other type
the bitwise<graph_op> graph method is used. The result is stored under the
node's output id.
*/
#define REGISTER_BITWISE_BINARY_OP(aot_name, graph_op) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
  ET_LOG( \
    Debug, "%s: (%d, %d) -> %d", \
    __FUNCTION__, \
    graphNode->input1_id(), \
    graphNode->input2_id(), \
    graphNode->output_id() \
  ); \
  ET_CHECK_OR_RETURN_ERROR( \
    is_macos_13_or_newer(), NotSupported, \
    "%s supported by MPS on MacOS13.0+/iOS16.1+", #aot_name); \
  \
  auto applyOp = [&](MPSGraphTensor* lhs, MPSGraphTensor* rhs) -> MPSGraphTensor* { \
    const bool isBoolInput = getScalarType([lhs dataType]) == ScalarType::Bool; \
    if (!isBoolInput) { \
      return [_mpsGraph bitwise##graph_op##WithPrimaryTensor:lhs \
                                             secondaryTensor:rhs \
                                                        name:nil]; \
    } \
    return [_mpsGraph logical##graph_op##WithPrimaryTensor:lhs \
                                           secondaryTensor:rhs \
                                                      name:nil]; \
  }; \
  \
  MPSGraphTensor* resultTensor = binaryOpTensor( \
    getMPSGraphTensor(graphNode->input1_id()), \
    getMPSGraphTensor(graphNode->input2_id()), \
    _mpsGraph, \
    applyOp); \
  _idToMPSGraphTensor[graphNode->output_id()] = resultTensor; \
  \
  return Error::Ok; \
}
147
// Arithmetic Binary Ops
REGISTER_BINARY_WITH_ALPHA_OP(Add, addition)
REGISTER_BINARY_WITH_ALPHA_OP(Sub, subtraction)
REGISTER_BINARY_OP(Mul, multiplication)
REGISTER_BINARY_OP(Pow, power)
REGISTER_BINARY_OP(Minimum, minimum)

// Boolean Binary ops
REGISTER_BINARY_OP(Eq, equal)
REGISTER_BINARY_OP(Ne, notEqual)
REGISTER_BINARY_OP(Ge, greaterThanOrEqualTo)
REGISTER_BINARY_OP(Gt, greaterThan)
REGISTER_BINARY_OP(Le, lessThanOrEqualTo)
REGISTER_BINARY_OP(Lt, lessThan)

// Bitwise Binary ops
REGISTER_BITWISE_BINARY_OP(BitwiseAnd, AND)
REGISTER_BITWISE_BINARY_OP(BitwiseOr, OR)
REGISTER_BITWISE_BINARY_OP(BitwiseXor, XOR)

// The registration helpers are only meaningful for the instantiations above;
// undefine all of them so none leaks into the rest of the translation unit.
#undef REGISTER_BINARY_WITH_ALPHA_OP
#undef REGISTER_BINARY_OP
#undef REGISTER_BITWISE_BINARY_OP
170
/*
Returns a tensor holding `inputTensor` rounded toward zero.

A tensor whose data type does not have MPSDataTypeFloatBit set is returned
as-is. For floating-point tensors, truncateWithTensor is used when
is_macos_13_or_newer(MACOS_VER_13_0_PLUS) holds; otherwise the result is built
as select(x < 0, ceil(x), floor(x)).
*/
static
MPSGraphTensor* mpsTruncTensor(MPSGraphTensor* inputTensor, MPSGraph* mpsGraph) {
  // Rounding is a no-op for integral types. Skipping it is also a reasonable workaround
  // for an MPSGraph bug on Apple Silicon that throws
  // `Function floorOp_i64 was not found in the library`.
  // See https://github.com/pytorch/pytorch/issues/84995
  const bool hasFloatBit = ([inputTensor dataType] & MPSDataTypeFloatBit) != 0;
  if (!hasFloatBit) {
    return inputTensor;
  }

  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS)) {
    return [mpsGraph truncateWithTensor:inputTensor
                                   name:nil];
  }

  // Fallback path: negative values round up (ceil), all others round down (floor).
  MPSGraphTensor* zero = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
  MPSGraphTensor* isNegative = [mpsGraph lessThanWithPrimaryTensor:inputTensor
                                                   secondaryTensor:zero
                                                              name:nil];
  MPSGraphTensor* roundedUp = [mpsGraph ceilWithTensor:inputTensor name:nil];
  MPSGraphTensor* roundedDown = [mpsGraph floorWithTensor:inputTensor name:nil];
  return [mpsGraph selectWithPredicateTensor:isNegative
                         truePredicateTensor:roundedUp
                        falsePredicateTensor:roundedDown
                                        name:nil];
}
195
/*
Builds the MPSGraph nodes shared by the division-family operators.

primaryTensor / secondaryTensor - dividend and divisor.
rounding_mode - empty for plain division, otherwise "trunc" or "floor".
mpsGraph      - graph that receives the new nodes.
op_name       - operator name; "Fmod" and "Remainder" select the modulo forms.

Behaviour:
  - "trunc" with a float16 primary input is rejected via ET_CHECK_MSG.
  - Operands are promoted to a common type by binaryOpTensor. For "floor" and
    "trunc", non-floating-point operands are additionally cast to Float32
    before dividing.
  - With no rounding mode, or when the quotient is not floating point, the
    quotient is returned directly.
  - "trunc": returns trunc(a / b), or a - trunc(a / b) * b when op_name is "Fmod".
  - "floor": returns floor(a / b), or a - floor(a / b) * b when op_name is
    "Remainder".
  - Any other rounding mode on a floating-point quotient fails via ET_CHECK_MSG.
*/
static
MPSGraphTensor* divModeTemplate(
  MPSGraphTensor* primaryTensor,
  MPSGraphTensor* secondaryTensor,
  std::optional<flatbuffers::string_view> rounding_mode,
  MPSGraph* mpsGraph,
  const std::string& op_name) {
  MPSDataType mpsInputDataType = [primaryTensor dataType];
  ScalarType inputDataType = getScalarType(mpsInputDataType);

  if (rounding_mode.has_value() && *rounding_mode == "trunc") {
    ET_CHECK_MSG(inputDataType != ScalarType::Half,
                "MPS: does not support trunc_divide op with float16 input");
  }

  auto divOpFunc = [&](MPSGraphTensor* primaryCastTensor,
                      MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* {
    // floor/trunc need a floating-point quotient, so integral operands are
    // widened to Float32 before the division node is created.
    bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
    if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
      primaryCastTensor = [mpsGraph castTensor:primaryCastTensor
                                        toType:MPSDataTypeFloat32
                                          name:@"primaryCastTensor"];
      secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
                                          toType:MPSDataTypeFloat32
                                            name:@"secondaryCastTensor"];
    }
    MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
                                                    secondaryTensor:secondaryCastTensor
                                                               name:nil];

    // Rounding is a no-op for integral types, and also a reasonable workaround
    // for an MPSGraph bug on Apple Silicon that throws
    // `Function floorOp_i64 was not found in the library`.
    // See https://github.com/pytorch/pytorch/issues/84995
    bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
    if (!rounding_mode.has_value() || !isFloatOutput) {
      return divTensor;
    } else if (*rounding_mode == "trunc") {
      auto truncTensor = mpsTruncTensor(divTensor, mpsGraph);
      if (op_name == "Fmod") {
        // fmod(a, b) = a - trunc(a / b) * b
        auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
                                                   secondaryTensor:secondaryCastTensor
                                                              name:nil];
        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
                                      secondaryTensor:mulTensor
                                                 name:nil];
      }
      return truncTensor;
    } else if (*rounding_mode == "floor") {
      MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil];
      if (op_name == "Remainder") {
        // remainder(a, b) = a - floor(a / b) * b
        auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
                                                   secondaryTensor:secondaryCastTensor
                                                              name:nil];
        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
                                      secondaryTensor:mulTensor
                                                 name:nil];
      }
      return floorTensor;
    } else {
      // A plain assert() is compiled out under NDEBUG, which would let a null
      // tensor escape into the graph; ET_CHECK_MSG fails in every build type.
      ET_CHECK_MSG(false,
                  "MPS: invalid rounding mode '%.*s'",
                  static_cast<int>(rounding_mode->size()),
                  rounding_mode->data());
    }
    // Not reached when the check above aborts; kept so every path returns a value.
    return nullptr;
  };
  return binaryOpTensor(primaryTensor, secondaryTensor, mpsGraph, divOpFunc);
}
261
/*
Defines MPSGraphBuilder::mps<aot_name>Op for a division-family operator.

`round_mode` is the rounding mode used when the serialized node does not carry
one; a rounding mode present on the node takes precedence. The node's two
inputs, the chosen rounding mode and the stringified operator name are passed
to divModeTemplate, and the result is stored under the node's output id.
*/
#define REGISTER_DIV_OP(aot_name, round_mode) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
  ET_LOG( \
    Debug, "%s: (%d, %d) -> %d", \
    __FUNCTION__, \
    graphNode->input1_id(), \
    graphNode->input2_id(), \
    graphNode->output_id() \
  ); \
  \
  /* Start from the per-operator default, then let the node override it. */ \
  std::optional<flatbuffers::string_view> roundingMode = round_mode; \
  if (graphNode->rounding_mode() != nullptr) { \
    roundingMode = graphNode->rounding_mode()->string_view(); \
  } \
  \
  MPSGraphTensor* resultTensor = divModeTemplate( \
    getMPSGraphTensor(graphNode->input1_id()), \
    getMPSGraphTensor(graphNode->input2_id()), \
    roundingMode, \
    _mpsGraph, \
    #aot_name \
  ); \
  _idToMPSGraphTensor[graphNode->output_id()] = resultTensor; \
  \
  return Error::Ok; \
}

REGISTER_DIV_OP(Div, std::nullopt)
REGISTER_DIV_OP(Fmod, "trunc")
REGISTER_DIV_OP(Remainder, "floor")

#undef REGISTER_DIV_OP
293
294
295} // namespace delegate
296} // namespace mps
297} // namespace backends
298} // namespace executorch
299