/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the CPU implementations of the broadcast binary operations: ADD, MUL, SUB, and DIV.

#define LOG_TAG "Operations"

#include "Broadcast.h"

#include <algorithm>
#include <vector>

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "Tracing.h"
#include "nnapi/Types.h"
#include "nnapi/Validation.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/types.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace broadcast {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

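// Dispatches on the runtime `activation` value, expanding the given macro with the
// matching compile-time FusedActivationFunctionType tag (kNone, kRelu, ...).
// Returns false from the enclosing function on an unsupported activation value.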
#define ANDROID_NN_MACRO_DISPATCH(macro)                                \
    switch (activation) {                                               \
        case static_cast<int32_t>(FusedActivationFunc::NONE):           \
            macro(kNone);                                               \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU):           \
            macro(kRelu);                                               \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU1):          \
            macro(kRelu1);                                              \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU6):          \
            macro(kRelu6);                                              \
            break;                                                      \
        default:                                                        \
            LOG(ERROR) << "Unsupported fused activation function type"; \
            return false;                                               \
    }

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;

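// Computes a half-precision binary operation by widening the inputs to float32,
// running the float32 kernel, and narrowing the result back to _Float16.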
bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                 \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

template <typename T>
bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    const bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
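    // Rescale both inputs to a common scale (twice the larger input scale) via
    // fixed-point multipliers; left_shift bits of headroom preserve precision,
    // following the TFLite quantized Add scheme.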
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    int32_t output_activation_min;
    int32_t output_activation_max;
    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if (needBroadcast) {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
            tflite::reference_integer_ops::BroadcastAdd4DSlow(
                    op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2),
                    in2, convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
            tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                      in1, convertShapeToTflshape(shape2), in2,
                                                      convertShapeToTflshape(shapeOut), out);
        }
    } else {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("optimized_integer_ops::Add");
            tflite::optimized_integer_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                               convertShapeToTflshape(shape2), in2,
                                               convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("optimized_ops::Add");
            tflite::optimized_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                       convertShapeToTflshape(shape2), in2,
                                       convertShapeToTflshape(shapeOut), out);
        }
    }

    return true;
}

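// Generic elementwise loop for TENSOR_INT32: walks every output index and maps it
// back to the (possibly broadcast) flat indices of both inputs.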
bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
                  const Shape& bShape, int32_t activation, int32_t* outputData,
                  const Shape& outputShape, int32_t func(int32_t, int32_t)) {
    NN_RET_CHECK_EQ(static_cast<FusedActivationFunc>(activation), FusedActivationFunc::NONE);
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

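// Quantized multiply: the product of the two input scales is folded into a single
// fixed-point output multiplier (real_multiplier = scale1 * scale2 / scaleOut).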
template <typename T>
bool mulQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.input1_offset = input1_offset;
    op_params.input2_offset = input2_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastMul4DSlow");
        tflite::reference_integer_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastMul4DSlow");
        tflite::reference_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub, so clamp to the activation range here.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }
    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}

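// Quantized subtract reuses the quantized Add kernels: the fixed-point multiplier
// of the second input is negated below.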
template <typename T>
bool subQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    // Negate multiplier of the second input, so that we can use Add kernels.
    input2_multiplier *= -1;

    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    // Use the broadcast-capable BroadcastAdd4DSlow kernels unconditionally here
    // because the non-broadcast Add kernels fail to pass some of the
    // sub_quantized_different_scales tests.
    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
        tflite::reference_integer_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
        tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

}  // namespace

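// Validates that both inputs are at most 4-D and sets the output shape to their
// broadcast shape.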
bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4u);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4u);
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}

bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return addQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return mulQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a * b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return subQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a - b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) {
                                    // In NNAPI, DIV by zero is undefined, but should not crash.
                                    if (b == 0) return 0;
                                    int32_t result = a / b;
                                    if (a % b != 0 && ((a < 0) != (b < 0))) {
                                        // Implement "floor division".
                                        --result;
                                    }
                                    return result;
                                });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace broadcast

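// Register each operation with default validation. Zero-sized inputs are allowed
// here and short-circuited in the execute* functions above.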
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(ADD, broadcast::prepare, broadcast::executeAdd,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(MUL, broadcast::prepare, broadcast::executeMul,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(SUB, broadcast::prepare, broadcast::executeSub,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(DIV, broadcast::prepare, broadcast::executeDiv,
                                         .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android