/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include "Broadcast.h"

#include <algorithm>
#include <functional>
#include <type_traits>
#include <vector>

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "Tracing.h"
#include "nnapi/Types.h"
#include "nnapi/Validation.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/types.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace broadcast {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {
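// Maps the runtime activation value onto a compile-time
// tflite::FusedActivationFunctionType template argument and expands the given
// macro with it, returning false for unknown activation values.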
#define ANDROID_NN_MACRO_DISPATCH(macro)                                  \
    switch (activation) {                                                 \
        case static_cast<int32_t>(FusedActivationFunc::NONE):             \
            macro(kNone);                                                 \
            break;                                                        \
        case static_cast<int32_t>(FusedActivationFunc::RELU):             \
            macro(kRelu);                                                 \
            break;                                                        \
        case static_cast<int32_t>(FusedActivationFunc::RELU1):            \
            macro(kRelu1);                                                \
            break;                                                        \
        case static_cast<int32_t>(FusedActivationFunc::RELU6):            \
            macro(kRelu6);                                                \
            break;                                                        \
        default:                                                          \
            LOG(ERROR) << "Unsupported fused activation function type";   \
            return false;                                                 \
    }

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;
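// Runs a float16 binary operation by converting both inputs to float32, invoking the
// corresponding float32 kernel, and converting the result back to float16.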
bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                               \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                 \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(  \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

template <typename T>
bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    const bool needBroadcast = !SameShape(shape1, shape2);
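    // Rescale both inputs to a shared scale of 2 * max(scale1, scale2), with left_shift bits of
    // headroom, accumulate in int32, then rescale the sum to the output scale and offset.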
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    int32_t output_activation_min;
    int32_t output_activation_max;
    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if (needBroadcast) {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
            tflite::reference_integer_ops::BroadcastAdd4DSlow(
                    op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2),
                    in2, convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
            tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                      in1, convertShapeToTflshape(shape2), in2,
                                                      convertShapeToTflshape(shapeOut), out);
        }
    } else {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("optimized_integer_ops::Add");
            tflite::optimized_integer_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                               convertShapeToTflshape(shape2), in2,
                                               convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("optimized_ops::Add");
            tflite::optimized_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                       convertShapeToTflshape(shape2), in2,
                                       convertShapeToTflshape(shapeOut), out);
        }
    }

    return true;
}
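// Generic elementwise path for TENSOR_INT32: walks every index of the (possibly broadcast)
// output shape via IndexedShapeWrapper and applies func to the corresponding input elements.
// Only FusedActivationFunc::NONE is supported for int32 tensors.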
bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
                  const Shape& bShape, int32_t activation, int32_t* outputData,
                  const Shape& outputShape, int32_t func(int32_t, int32_t)) {
    NN_RET_CHECK_EQ(static_cast<FusedActivationFunc>(activation), FusedActivationFunc::NONE);
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                               \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

template <typename T>
bool mulQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
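    // The product of two quantized values has scale shape1.scale * shape2.scale; a single output
    // multiplier rescales that product to the output scale before the output offset is applied.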
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                                     &output_shift));
    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.input1_offset = input1_offset;
    op_params.input2_offset = input2_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastMul4DSlow");
        tflite::reference_integer_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastMul4DSlow");
        tflite::reference_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }
    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}

template <typename T>
bool subQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    // Negate multiplier of the second input, so that we can use Add kernels.
    input2_multiplier *= -1;

    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    // We are using tflite::optimized_ops::BroadcastAdd unconditionally here
    // because tflite::optimized_ops::Add fails to pass some of the
    // sub_quantized_different_scales tests.
    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
        tflite::reference_integer_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
        tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

} // namespace
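// Shared shape preparation for ADD, MUL, SUB, and DIV: both inputs must have at most
// 4 dimensions, and the output takes their broadcasted shape.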
bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4u);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4u);
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}

bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return addQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return mulQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a * b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return subQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a - b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) {
                                    // In NNAPI, DIV by zero is undefined, but should not crash.
                                    if (b == 0) return 0;
                                    int32_t result = a / b;
                                    if (a % b != 0 && ((a < 0) != (b < 0))) {
                                        // Implement "floor division".
                                        --result;
                                    }
                                    return result;
                                });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}
#endif // NN_INCLUDE_CPU_IMPLEMENTATION

} // namespace broadcast
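// allowZeroSizedInput = true: these operations accept zero-sized input tensors; the execute*
// functions above bypass computation when the output has no elements.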
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(ADD, broadcast::prepare, broadcast::executeAdd,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(MUL, broadcast::prepare, broadcast::executeMul,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(SUB, broadcast::prepare, broadcast::executeSub,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(DIV, broadcast::prepare, broadcast::executeDiv,
                                         .allowZeroSizedInput = true);

} // namespace nn
} // namespace android