/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define LOG_TAG "Operations" #include #include #include #include "ActivationFunctor.h" #include "OperationResolver.h" #include "OperationsUtils.h" #include "Tracing.h" #ifdef NN_INCLUDE_CPU_IMPLEMENTATION #include #include #include #include #include #include "CpuOperationUtils.h" #endif // NN_INCLUDE_CPU_IMPLEMENTATION namespace android { namespace nn { namespace activation { constexpr uint32_t kNumInputs = 1; constexpr uint32_t kInputTensor = 0; constexpr uint32_t kNumOutputs = 1; constexpr uint32_t kOutputTensor = 0; #ifdef NN_INCLUDE_CPU_IMPLEMENTATION namespace { template bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape, float reluMin = 0.f, float reluMax = std::numeric_limits::max()) { NNTRACE_COMP("reluX"); int numElements = getNumberOfElements(inputShape); for (int i = 0; i < numElements; i++, inputData++, outputData++) { *outputData = static_cast( std::min(std::max(reluMin, static_cast(*inputData)), reluMax)); } return true; } template bool reluFloat(const float* inputData, const Shape& inputShape, float* outputData, const Shape& outputShape, float reluMin, float reluMax); template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape, _Float16* outputData, const Shape& outputShape, float reluMin, float reluMax); template bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape) { return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f); } template bool relu1Float(const float* inputData, const Shape& inputShape, float* outputData, const Shape& outputShape); template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape, _Float16* outputData, const Shape& outputShape); template bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape) { return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f); } template bool relu6Float(const float* inputData, const Shape& inputShape, float* outputData, const Shape& outputShape); template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape, _Float16* outputData, const Shape& outputShape); bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData, const Shape& outputShape) { NNTRACE_COMP("tanhFloat16"); int numElements = getNumberOfElements(inputShape); for (int i = 0; i < numElements; i++, inputData++, outputData++) { *outputData = static_cast<_Float16>(std::tanh(static_cast(*inputData))); } return true; } bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData, const Shape& outputShape) { NNTRACE_COMP("tanhFloat32"); int numElements = getNumberOfElements(inputShape); for (int i = 0; i < numElements; i++, inputData++, outputData++) { *outputData = std::tanh(*inputData); } return true; } template bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape) { NNTRACE_COMP("logisticFloat"); int numElements = getNumberOfElements(inputShape); for (int i = 0; i < numElements; i++, inputData++, outputData++) { *outputData = static_cast(1.f / (1.f + std::exp(static_cast(-*inputData)))); } return true; } template bool logisticFloat(const float* inputData, const Shape& inputShape, float* outputData, const Shape& outputShape); template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape, _Float16* outputData, const Shape& outputShape); template inline bool reluXQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, const Shape& outputShape) { int numElements = getNumberOfElements(inputShape); int32_t output_activation_min = 0; int32_t output_activation_max = 0; CalculateActivationRangeUint8(activation, inputShape, &output_activation_min, &output_activation_max); for (int i = 0; i < numElements; i++, inputData++, outputData++) { *outputData = std::min((uint8_t)output_activation_max, std::max((uint8_t)output_activation_min, *inputData)); } return true; } bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, const Shape& outputShape) { NNTRACE_COMP("reluQuant8"); return reluXQuant8(inputData, inputShape, outputData, outputShape); } bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, const Shape& outputShape) { NNTRACE_COMP("relu1Quant8"); return reluXQuant8(inputData, inputShape, outputData, outputShape); } bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, const Shape& outputShape) { NNTRACE_COMP("relu6Quant8"); return reluXQuant8(inputData, inputShape, outputData, outputShape); } bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, const Shape& outputShape) { NNTRACE_TRANS("tanhQuant8"); if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) { LOG(ERROR) << "incorrect scale or offset for TANH output"; return false; } int numElements = getNumberOfElements(inputShape); static constexpr int kInputIntegerBits = 4; const double input_real_multiplier = inputShape.scale * static_cast(1 << (31 - kInputIntegerBits)); int32_t input_multiplier = 0; int32_t input_left_shift = 0; if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier, &input_left_shift)) { return false; } int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift); NNTRACE_COMP_SWITCH("optimized_ops::Tanh"); tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius, input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape)); return true; } bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, const Shape& outputShape) { NNTRACE_TRANS("logisticQuant8"); if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) { LOG(ERROR) << "incorrect scale / offset for output"; return false; } int numElements = getNumberOfElements(inputShape); static constexpr int kInputIntegerBits = 4; const double input_real_multiplier = inputShape.scale * static_cast(1 << (31 - kInputIntegerBits)); int32_t input_multiplier = 0; int32_t input_left_shift = 0; if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier, &input_left_shift)) { return false; } int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift); NNTRACE_COMP_SWITCH("optimized_ops::Logistic"); tflite::optimized_ops::Logistic( inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius, input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape)); return true; } template inline bool reluXQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData, const Shape& outputShape) { int numElements = getNumberOfElements(inputShape); int32_t output_activation_min = 0; int32_t output_activation_max = 0; CalculateActivationRangeInt8(activation, inputShape, &output_activation_min, &output_activation_max); for (int i = 0; i < numElements; i++, inputData++, outputData++) { *outputData = std::min((int8_t)output_activation_max, std::max((int8_t)output_activation_min, *inputData)); } return true; } bool reluQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData, const Shape& outputShape) { NNTRACE_COMP("reluQuant8"); return reluXQuant8Signed(inputData, inputShape, outputData, outputShape); } bool relu1Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData, const Shape& outputShape) { NNTRACE_COMP("relu1Quant8"); return reluXQuant8Signed(inputData, inputShape, outputData, outputShape); } bool relu6Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData, const Shape& outputShape) { NNTRACE_COMP("relu6Quant8"); return reluXQuant8Signed(inputData, inputShape, outputData, outputShape); } bool tanhQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData, const Shape& outputShape) { NNTRACE_TRANS("tanhQuant8Signed"); if (outputShape.offset != 0 || outputShape.scale != 1.f / 128) { LOG(ERROR) << "incorrect scale or offset for TANH output"; return false; } int numElements = getNumberOfElements(inputShape); static constexpr int kInputIntegerBits = 4; const double input_real_multiplier = inputShape.scale * static_cast(1 << (31 - kInputIntegerBits)); int32_t input_multiplier = 0; int32_t input_left_shift = 0; if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier, &input_left_shift)) { return false; } int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift); NNTRACE_COMP_SWITCH("reference_integer_ops::Tanh"); tflite::reference_integer_ops::Tanh(inputShape.offset, input_range_radius, input_multiplier, input_left_shift, convertShapeToTflshape(inputShape), inputData, convertShapeToTflshape(outputShape), outputData); return true; } bool logisticQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData, const Shape& outputShape) { NNTRACE_TRANS("logisticQuant8Signed"); if (outputShape.offset != -128 || outputShape.scale != 1.f / 256) { LOG(ERROR) << "incorrect scale / offset for output"; return false; } int numElements = getNumberOfElements(inputShape); static constexpr int kInputIntegerBits = 4; const double input_real_multiplier = inputShape.scale * static_cast(1 << (31 - kInputIntegerBits)); int32_t input_multiplier = 0; int32_t input_left_shift = 0; if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier, &input_left_shift)) { return false; } int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift); NNTRACE_COMP_SWITCH("reference_integer_ops::Logistic"); tflite::reference_integer_ops::Logistic(inputShape.offset, input_range_radius, input_multiplier, input_left_shift, numElements, inputData, outputData); return true; } void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32, int16_t* multiplier_int16) { TFLITE_DCHECK_GE(multiplier_int32, 0); static constexpr int32_t kRoundingOffset = 1 << 15; if (multiplier_int32 >= std::numeric_limits::max() - kRoundingOffset) { *multiplier_int16 = std::numeric_limits::max(); return; } const int32_t result = (multiplier_int32 + kRoundingOffset) >> 16; TFLITE_DCHECK_LE(result << 16, multiplier_int32 + kRoundingOffset); TFLITE_DCHECK_GT(result << 16, multiplier_int32 - kRoundingOffset); *multiplier_int16 = result; TFLITE_DCHECK_EQ(*multiplier_int16, result); } template bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape) { tflite::HardSwishParams params; params.input_zero_point = inputShape.offset; params.output_zero_point = outputShape.offset; const float input_scale = inputShape.scale; const float hires_input_scale = (1.0f / 128.0f) * input_scale; const float reluish_scale = 3.0f / 32768.0f; const float output_scale = outputShape.scale; const float output_multiplier = hires_input_scale / output_scale; int32_t output_multiplier_fixedpoint_int32; NN_RET_CHECK(QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32, ¶ms.output_multiplier_exponent)); DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32, ¶ms.output_multiplier_fixedpoint_int16); NN_RET_CHECK(params.output_multiplier_exponent <= 0); const float reluish_multiplier = hires_input_scale / reluish_scale; int32_t reluish_multiplier_fixedpoint_int32; NN_RET_CHECK(QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32, ¶ms.reluish_multiplier_exponent)); DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32, ¶ms.reluish_multiplier_fixedpoint_int16); tflite::reference_ops::HardSwish(params, convertShapeToTflshape(inputShape), inputData, convertShapeToTflshape(outputShape), outputData); return true; } } // namespace #endif // NN_INCLUDE_CPU_IMPLEMENTATION Result validate(OperationType opType, const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { minSupportedVersion = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_FLOAT16) { minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { if (opType == OperationType::TANH) { minSupportedVersion = Version::ANDROID_Q; } else { minSupportedVersion = Version::ANDROID_OC_MR1; } } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType; } const Shape& input = context->getInputShape(kInputTensor); if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } NN_RET_CHECK(validateInputTypes(context, {inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); return minSupportedVersion; } Result validateHardSwish(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU"; } NN_RET_CHECK(validateInputTypes(context, {inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); return minSupportedVersion; } #ifdef NN_INCLUDE_CPU_IMPLEMENTATION bool prepare(OperationType opType, IOperationExecutionContext* context) { Shape input = context->getInputShape(kInputTensor); if (opType != OperationType::HARD_SWISH) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } Shape output = input; if (input.type == OperandType::TENSOR_QUANT8_ASYMM || input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { bool isSigned = input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED; switch (opType) { case OperationType::HARD_SWISH: { auto outputShape = context->getOutputShape(kOutputTensor); output.scale = outputShape.scale; output.offset = outputShape.offset; } break; case OperationType::RELU: case OperationType::RELU1: case OperationType::RELU6: break; case OperationType::LOGISTIC: output.scale = 1.f / 256; output.offset = isSigned ? -128 : 0; break; case OperationType::TANH: output.scale = 1.f / 128; output.offset = isSigned ? 0 : 128; break; default: NN_RET_CHECK_FAIL() << "Unsupported operation type"; } } return context->setOutputShape(kOutputTensor, output); } bool executeRelu(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor)) { case OperandType::TENSOR_FLOAT16: return reluFloat(context->getInputBuffer<_Float16>(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return reluFloat(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM: return reluQuant8(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return reluQuant8Signed(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU"; } } bool executeRelu1(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor)) { case OperandType::TENSOR_FLOAT16: return relu1Float(context->getInputBuffer<_Float16>(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return relu1Float(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM: return relu1Quant8(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return relu1Quant8Signed(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1"; } } bool executeRelu6(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor)) { case OperandType::TENSOR_FLOAT16: return relu6Float(context->getInputBuffer<_Float16>(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return relu6Float(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM: return relu6Quant8(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return relu6Quant8Signed(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6"; } } bool executeLogistic(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor)) { case OperandType::TENSOR_FLOAT16: return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return logisticFloat(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM: return logisticQuant8(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return logisticQuant8Signed(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC"; } } bool executeTanh(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor)) { case OperandType::TENSOR_FLOAT16: return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return tanhFloat32(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM: return tanhQuant8(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return tanhQuant8Signed(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH"; } } bool executeHardSwish(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor)) { case OperandType::TENSOR_FLOAT16: { const Shape& inputShape = context->getInputShape(kInputTensor); const Shape& outputShape = context->getOutputShape(kOutputTensor); std::vector inputFloat(getNumberOfElements(inputShape)); std::vector outputFloat(getNumberOfElements(outputShape)); convertFloat16ToFloat32(context->getInputBuffer<_Float16>(kInputTensor), &inputFloat); tflite::reference_ops::HardSwish(convertShapeToTflshape(inputShape), inputFloat.data(), convertShapeToTflshape(outputShape), outputFloat.data()); convertFloat32ToFloat16(outputFloat, context->getOutputBuffer<_Float16>(kOutputTensor)); return true; } case OperandType::TENSOR_FLOAT32: { tflite::reference_ops::HardSwish( convertShapeToTflshape(context->getInputShape(kInputTensor)), context->getInputBuffer(kInputTensor), convertShapeToTflshape(context->getOutputShape(kOutputTensor)), context->getOutputBuffer(kOutputTensor)); return true; } case OperandType::TENSOR_QUANT8_ASYMM: return hardSwishQuant(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return hardSwishQuant(context->getInputBuffer(kInputTensor), context->getInputShape(kInputTensor), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH"; } } #endif // NN_INCLUDE_CPU_IMPLEMENTATION } // namespace activation using std::placeholders::_1; NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1), std::bind(activation::prepare, OperationType::RELU, _1), activation::executeRelu, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1), std::bind(activation::prepare, OperationType::RELU1, _1), activation::executeRelu1, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1), std::bind(activation::prepare, OperationType::RELU6, _1), activation::executeRelu6, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC", std::bind(activation::validate, OperationType::LOGISTIC, _1), std::bind(activation::prepare, OperationType::LOGISTIC, _1), activation::executeLogistic, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1), std::bind(activation::prepare, OperationType::TANH, _1), activation::executeTanh, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(HARD_SWISH, "HARD_SWISH", activation::validateHardSwish, std::bind(activation::prepare, OperationType::HARD_SWISH, _1), activation::executeHardSwish, .allowZeroSizedInput = true); } // namespace nn } // namespace android