• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ActivationFunctor.h"
18 #include "CpuOperationUtils.h"
19 #include "OperationResolver.h"
20 
21 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
22 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
23 
24 #include "Tracing.h"
25 
26 namespace android {
27 namespace nn {
28 
29 namespace activation {
30 
// Each activation operation takes exactly one input tensor and produces
// exactly one output tensor of the same shape.
constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
36 
37 namespace {
38 
39 template <typename T>
reluFloat(const T * inputData,const Shape & inputShape,T * outputData,const Shape & outputShape,float reluMin=0.f,float reluMax=std::numeric_limits<float>::max ())40 bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape,
41                float reluMin = 0.f, float reluMax = std::numeric_limits<float>::max()) {
42     NNTRACE_COMP("reluX");
43     int numElements = getNumberOfElements(inputShape);
44     for (int i = 0; i < numElements; i++, inputData++, outputData++) {
45         *outputData = static_cast<T>(
46                 std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax));
47     }
48     return true;
49 }
50 template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData,
51                                const Shape& outputShape, float reluMin, float reluMax);
52 template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
53                                   _Float16* outputData, const Shape& outputShape, float reluMin,
54                                   float reluMax);
55 
// RELU1: clamps each element to the range [-1, 1]. Thin wrapper over reluFloat.
template <typename T>
bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f);
}
template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);
65 
// RELU6: clamps each element to the range [0, 6]. Thin wrapper over reluFloat.
template <typename T>
bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f);
}
template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);
75 
tanhFloat16(const _Float16 * inputData,const Shape & inputShape,_Float16 * outputData,const Shape & outputShape)76 bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData,
77                  const Shape& outputShape) {
78     NNTRACE_COMP("tanhFloat16");
79     int numElements = getNumberOfElements(inputShape);
80     for (int i = 0; i < numElements; i++, inputData++, outputData++) {
81         *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData)));
82     }
83     return true;
84 }
85 
tanhFloat32(const float * inputData,const Shape & inputShape,float * outputData,const Shape & outputShape)86 bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData,
87                  const Shape& outputShape) {
88     NNTRACE_COMP("tanhFloat32");
89     int numElements = getNumberOfElements(inputShape);
90     for (int i = 0; i < numElements; i++, inputData++, outputData++) {
91         *outputData = std::tanh(*inputData);
92     }
93     return true;
94 }
95 
96 template <typename T>
logisticFloat(const T * inputData,const Shape & inputShape,T * outputData,const Shape & outputShape)97 bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData,
98                    const Shape& outputShape) {
99     NNTRACE_COMP("logisticFloat");
100     int numElements = getNumberOfElements(inputShape);
101     for (int i = 0; i < numElements; i++, inputData++, outputData++) {
102         *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData))));
103     }
104     return true;
105 }
106 template bool logisticFloat<float>(const float* inputData, const Shape& inputShape,
107                                    float* outputData, const Shape& outputShape);
108 template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
109                                       _Float16* outputData, const Shape& outputShape);
110 
// Shared body for the quantized RELU variants below. Derives the uint8 clamp
// bounds for `activation` from the input's quantization parameters via
// CalculateActivationRangeUint8, then clamps each element in place.
// Relies on inputData / outputData / inputShape being in scope at the
// expansion site; #undef'd immediately after its three uses.
#define ANDROID_NN_RELUX_QUANT8(activation)                                           \
    int numElements = getNumberOfElements(inputShape);                                \
    int32_t output_activation_min = 0;                                                \
    int32_t output_activation_max = 0;                                                \
                                                                                      \
    CalculateActivationRangeUint8(activation, inputShape, &output_activation_min,     \
                                  &output_activation_max);                            \
                                                                                      \
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {                \
        *outputData = std::min((uint8_t)output_activation_max,                        \
                               std::max((uint8_t)output_activation_min, *inputData)); \
    }
123 
// Quantized (uint8) RELU: clamps to the activation range computed for kActivationRelu.
bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8");
    ANDROID_NN_RELUX_QUANT8(kActivationRelu)
    return true;
}
130 
// Quantized (uint8) RELU1: clamps to the activation range computed for kActivationRelu1.
bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8");
    ANDROID_NN_RELUX_QUANT8(kActivationRelu1)
    return true;
}
137 
// Quantized (uint8) RELU6: clamps to the activation range computed for kActivationRelu6.
bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8");
    ANDROID_NN_RELUX_QUANT8(kActivationRelu6)
    return true;
}
144 
145 #undef ANDROID_NN_RELUX_QUANT8
146 
tanhQuant8(const uint8_t * inputData,const Shape & inputShape,uint8_t * outputData,const Shape & outputShape)147 bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
148                 const Shape& outputShape) {
149     NNTRACE_TRANS("tanhQuant8");
150     if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) {
151         LOG(ERROR) << "incorrect scale or offset for TANH output";
152         return false;
153     }
154 
155     int numElements = getNumberOfElements(inputShape);
156     static constexpr int kInputIntegerBits = 4;
157 
158     const double input_real_multiplier =
159             inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));
160 
161     int32_t input_multiplier = 0;
162     int32_t input_left_shift = 0;
163     if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
164                                           &input_left_shift)) {
165         return false;
166     }
167     int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);
168 
169     NNTRACE_COMP_SWITCH("optimized_ops::Tanh");
170     tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset,
171                                 input_range_radius, input_multiplier, input_left_shift, outputData,
172                                 convertShapeToTflshape(outputShape));
173 
174     return true;
175 }
176 
logisticQuant8(const uint8_t * inputData,const Shape & inputShape,uint8_t * outputData,const Shape & outputShape)177 bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
178                     const Shape& outputShape) {
179     NNTRACE_TRANS("logisticQuant8");
180     if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
181         LOG(ERROR) << "incorrect scale / offset for output";
182         return false;
183     }
184 
185     int numElements = getNumberOfElements(inputShape);
186     static constexpr int kInputIntegerBits = 4;
187 
188     const double input_real_multiplier =
189             inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));
190 
191     int32_t input_multiplier = 0;
192     int32_t input_left_shift = 0;
193     if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
194                                           &input_left_shift)) {
195         return false;
196     }
197     int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);
198 
199     NNTRACE_COMP_SWITCH("optimized_ops::Logistic");
200     tflite::optimized_ops::Logistic(
201             inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius,
202             input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape));
203 
204     return true;
205 }
206 
207 }  // namespace
208 
validate(OperationType opType,const IOperationValidationContext * context)209 bool validate(OperationType opType, const IOperationValidationContext* context) {
210     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
211     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
212     auto inputType = context->getInputType(kInputTensor);
213     if (inputType == OperandType::TENSOR_FLOAT32) {
214         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
215     } else if (inputType == OperandType::TENSOR_FLOAT16) {
216         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
217     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
218         if (opType == OperationType::TANH) {
219             NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
220         } else {
221             NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
222         }
223     } else {
224         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
225     }
226     return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
227 }
228 
prepare(OperationType opType,IOperationExecutionContext * context)229 bool prepare(OperationType opType, IOperationExecutionContext* context) {
230     Shape input = context->getInputShape(kInputTensor);
231     NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
232     Shape output = input;
233     if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
234         switch (opType) {
235             case OperationType::RELU:
236             case OperationType::RELU1:
237             case OperationType::RELU6:
238                 break;
239             case OperationType::LOGISTIC:
240                 output.scale = 1.f / 256;
241                 output.offset = 0;
242                 break;
243             case OperationType::TANH:
244                 output.scale = 1.f / 128;
245                 output.offset = 128;
246                 break;
247             default:
248                 NN_RET_CHECK_FAIL() << "Unsupported operation type";
249         }
250     }
251     return context->setOutputShape(kOutputTensor, output);
252 }
253 
executeRelu(IOperationExecutionContext * context)254 bool executeRelu(IOperationExecutionContext* context) {
255     // Bypass execution in the case of zero-sized input.
256     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
257     switch (context->getInputType(kInputTensor)) {
258         case OperandType::TENSOR_FLOAT16:
259             return reluFloat(context->getInputBuffer<_Float16>(kInputTensor),
260                              context->getInputShape(kInputTensor),
261                              context->getOutputBuffer<_Float16>(kOutputTensor),
262                              context->getOutputShape(kOutputTensor));
263         case OperandType::TENSOR_FLOAT32:
264             return reluFloat(context->getInputBuffer<float>(kInputTensor),
265                              context->getInputShape(kInputTensor),
266                              context->getOutputBuffer<float>(kOutputTensor),
267                              context->getOutputShape(kOutputTensor));
268         case OperandType::TENSOR_QUANT8_ASYMM:
269             return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
270                               context->getInputShape(kInputTensor),
271                               context->getOutputBuffer<uint8_t>(kOutputTensor),
272                               context->getOutputShape(kOutputTensor));
273         default:
274             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU";
275     }
276 }
277 
executeRelu1(IOperationExecutionContext * context)278 bool executeRelu1(IOperationExecutionContext* context) {
279     // Bypass execution in the case of zero-sized input.
280     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
281     switch (context->getInputType(kInputTensor)) {
282         case OperandType::TENSOR_FLOAT16:
283             return relu1Float(context->getInputBuffer<_Float16>(kInputTensor),
284                               context->getInputShape(kInputTensor),
285                               context->getOutputBuffer<_Float16>(kOutputTensor),
286                               context->getOutputShape(kOutputTensor));
287         case OperandType::TENSOR_FLOAT32:
288             return relu1Float(context->getInputBuffer<float>(kInputTensor),
289                               context->getInputShape(kInputTensor),
290                               context->getOutputBuffer<float>(kOutputTensor),
291                               context->getOutputShape(kOutputTensor));
292         case OperandType::TENSOR_QUANT8_ASYMM:
293             return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
294                                context->getInputShape(kInputTensor),
295                                context->getOutputBuffer<uint8_t>(kOutputTensor),
296                                context->getOutputShape(kOutputTensor));
297         default:
298             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1";
299     }
300 }
301 
executeRelu6(IOperationExecutionContext * context)302 bool executeRelu6(IOperationExecutionContext* context) {
303     // Bypass execution in the case of zero-sized input.
304     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
305     switch (context->getInputType(kInputTensor)) {
306         case OperandType::TENSOR_FLOAT16:
307             return relu6Float(context->getInputBuffer<_Float16>(kInputTensor),
308                               context->getInputShape(kInputTensor),
309                               context->getOutputBuffer<_Float16>(kOutputTensor),
310                               context->getOutputShape(kOutputTensor));
311         case OperandType::TENSOR_FLOAT32:
312             return relu6Float(context->getInputBuffer<float>(kInputTensor),
313                               context->getInputShape(kInputTensor),
314                               context->getOutputBuffer<float>(kOutputTensor),
315                               context->getOutputShape(kOutputTensor));
316         case OperandType::TENSOR_QUANT8_ASYMM:
317             return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
318                                context->getInputShape(kInputTensor),
319                                context->getOutputBuffer<uint8_t>(kOutputTensor),
320                                context->getOutputShape(kOutputTensor));
321         default:
322             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6";
323     }
324 }
325 
executeLogistic(IOperationExecutionContext * context)326 bool executeLogistic(IOperationExecutionContext* context) {
327     // Bypass execution in the case of zero-sized input.
328     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
329     switch (context->getInputType(kInputTensor)) {
330         case OperandType::TENSOR_FLOAT16:
331             return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor),
332                                  context->getInputShape(kInputTensor),
333                                  context->getOutputBuffer<_Float16>(kOutputTensor),
334                                  context->getOutputShape(kOutputTensor));
335         case OperandType::TENSOR_FLOAT32:
336             return logisticFloat(context->getInputBuffer<float>(kInputTensor),
337                                  context->getInputShape(kInputTensor),
338                                  context->getOutputBuffer<float>(kOutputTensor),
339                                  context->getOutputShape(kOutputTensor));
340         case OperandType::TENSOR_QUANT8_ASYMM:
341             return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
342                                   context->getInputShape(kInputTensor),
343                                   context->getOutputBuffer<uint8_t>(kOutputTensor),
344                                   context->getOutputShape(kOutputTensor));
345         default:
346             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
347     }
348 }
349 
executeTanh(IOperationExecutionContext * context)350 bool executeTanh(IOperationExecutionContext* context) {
351     // Bypass execution in the case of zero-sized input.
352     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
353     switch (context->getInputType(kInputTensor)) {
354         case OperandType::TENSOR_FLOAT16:
355             return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor),
356                                context->getInputShape(kInputTensor),
357                                context->getOutputBuffer<_Float16>(kOutputTensor),
358                                context->getOutputShape(kOutputTensor));
359         case OperandType::TENSOR_FLOAT32:
360             return tanhFloat32(context->getInputBuffer<float>(kInputTensor),
361                                context->getInputShape(kInputTensor),
362                                context->getOutputBuffer<float>(kOutputTensor),
363                                context->getOutputShape(kOutputTensor));
364         case OperandType::TENSOR_QUANT8_ASYMM:
365             return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
366                               context->getInputShape(kInputTensor),
367                               context->getOutputBuffer<uint8_t>(kOutputTensor),
368                               context->getOutputShape(kOutputTensor));
369         default:
370             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
371     }
372 }
373 
374 }  // namespace activation
375 
using std::placeholders::_1;
// Registers the five activation ops. They share validate()/prepare(),
// parameterized by OperationType via std::bind. allowZeroSizedInput lets the
// runtime invoke these ops on zero-sized tensors, which each execute*
// function short-circuits.
NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1),
                      std::bind(activation::prepare, OperationType::RELU, _1),
                      activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1),
                      std::bind(activation::prepare, OperationType::RELU1, _1),
                      activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1),
                      std::bind(activation::prepare, OperationType::RELU6, _1),
                      activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC",
                      std::bind(activation::validate, OperationType::LOGISTIC, _1),
                      std::bind(activation::prepare, OperationType::LOGISTIC, _1),
                      activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1),
                      std::bind(activation::prepare, OperationType::TANH, _1),
                      activation::executeTanh, .allowZeroSizedInput = true);
393 
394 }  // namespace nn
395 }  // namespace android
396