• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "PRelu.h"
20 
21 #include <algorithm>
22 #include <vector>
23 
24 #include "IndexedShapeWrapper.h"
25 #include "OperationResolver.h"
26 #include "OperationsExecutionUtils.h"
27 #include "Tracing.h"
28 
29 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
30 #pragma clang diagnostic push
31 #pragma clang diagnostic ignored "-Wunused-parameter"
32 #pragma clang diagnostic ignored "-Wsign-compare"
33 #pragma clang diagnostic ignored "-Winvalid-partial-specialization"
34 #include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
35 #pragma clang diagnostic pop
36 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
37 
38 namespace android {
39 namespace nn {
40 namespace prelu {
41 
42 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
43 template <typename T>
eval(const std::function<T (const T &,const T &)> & func,const T * aData,const Shape & aShape,const T * bData,const Shape & bShape,T * outputData,const Shape & outputShape)44 inline bool eval(const std::function<T(const T&, const T&)>& func, const T* aData,
45                  const Shape& aShape, const T* bData, const Shape& bShape, T* outputData,
46                  const Shape& outputShape) {
47     IndexedShapeWrapper aShapeIndexed(aShape);
48     IndexedShapeWrapper bShapeIndexed(bShape);
49     IndexedShapeWrapper outputShapeIndexed(outputShape);
50     std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
51     bool lastIndex = false;
52     do {
53         uint32_t outputFlatIndex;
54         NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
55         uint32_t aFlatIndex;
56         NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
57         uint32_t bFlatIndex;
58         NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));
59 
60         outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);
61 
62         NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
63     } while (!lastIndex);
64     return true;
65 }
66 
67 template <typename T>
evalQuant8(const T * aData,const Shape & aShape,const T * bData,const Shape & bShape,T * outputData,const Shape & outputShape)68 bool evalQuant8(const T* aData, const Shape& aShape, const T* bData, const Shape& bShape,
69                 T* outputData, const Shape& outputShape) {
70     const int32_t input_offset = -aShape.offset;
71     const int32_t alpha_offset = -bShape.offset;
72     const int32_t output_offset = outputShape.offset;
73     const double input_product_scale = aShape.scale * bShape.scale;
74     const double real_multiplier_pos = aShape.scale / outputShape.scale;
75     const double real_multiplier_neg = input_product_scale / outputShape.scale;
76     int32_t output_multiplier_pos, output_shift_pos;
77     int32_t output_multiplier_neg, output_shift_neg;
78     tflite::QuantizeMultiplier(real_multiplier_pos, &output_multiplier_pos, &output_shift_pos);
79     tflite::QuantizeMultiplier(real_multiplier_neg, &output_multiplier_neg, &output_shift_neg);
80     return eval<T>(
81             [&](const T& val1, const T& val2) -> uint8_t {
82                 const int32_t input = input_offset + static_cast<int32_t>(val1);
83                 int32_t output_val;
84                 if (input >= 0) {
85                     output_val =
86                             output_offset + tflite::MultiplyByQuantizedMultiplier(
87                                                     input, output_multiplier_pos, output_shift_pos);
88                 } else {
89                     const int32_t alpha = alpha_offset + static_cast<int32_t>(val2);
90                     output_val = output_offset +
91                                  tflite::MultiplyByQuantizedMultiplier(
92                                          input * alpha, output_multiplier_neg, output_shift_neg);
93                 }
94                 return saturateCast<T>(output_val);
95             },
96             aData, aShape, bData, bShape, outputData, outputShape);
97 }
98 
prepare(IOperationExecutionContext * context)99 bool prepare(IOperationExecutionContext* context) {
100     Shape input = context->getInputShape(kInputTensor);
101     Shape alpha = context->getInputShape(kAlphaTensor);
102     NN_RET_CHECK(input.type == alpha.type);
103     Shape output = context->getOutputShape(kOutputTensor);
104     NN_RET_CHECK(calculateBroadcastedShape(input, alpha, &output));
105     return context->setOutputShape(kOutputTensor, output);
106 }
107 
execute(IOperationExecutionContext * context)108 bool execute(IOperationExecutionContext* context) {
109     switch (context->getInputType(kInputTensor)) {
110         case OperandType::TENSOR_FLOAT16:
111             return eval<_Float16>(
112                     [](const _Float16& val1, const _Float16& val2) -> _Float16 {
113                         return val1 >= 0.0f ? val1 : val1 * val2;
114                     },
115                     context->getInputBuffer<_Float16>(kInputTensor),
116                     context->getInputShape(kInputTensor),
117                     context->getInputBuffer<_Float16>(kAlphaTensor),
118                     context->getInputShape(kAlphaTensor),
119                     context->getOutputBuffer<_Float16>(kOutputTensor),
120                     context->getOutputShape(kOutputTensor));
121         case OperandType::TENSOR_FLOAT32:
122             return eval<float>(
123                     [](const float& val1, const float& val2) -> float {
124                         return val1 >= 0.0f ? val1 : val1 * val2;
125                     },
126                     context->getInputBuffer<float>(kInputTensor),
127                     context->getInputShape(kInputTensor),
128                     context->getInputBuffer<float>(kAlphaTensor),
129                     context->getInputShape(kAlphaTensor),
130                     context->getOutputBuffer<float>(kOutputTensor),
131                     context->getOutputShape(kOutputTensor));
132         case OperandType::TENSOR_QUANT8_ASYMM: {
133             return evalQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
134                               context->getInputShape(kInputTensor),
135                               context->getInputBuffer<uint8_t>(kAlphaTensor),
136                               context->getInputShape(kAlphaTensor),
137                               context->getOutputBuffer<uint8_t>(kOutputTensor),
138                               context->getOutputShape(kOutputTensor));
139         }
140         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: {
141             return evalQuant8(context->getInputBuffer<int8_t>(kInputTensor),
142                               context->getInputShape(kInputTensor),
143                               context->getInputBuffer<int8_t>(kAlphaTensor),
144                               context->getInputShape(kAlphaTensor),
145                               context->getOutputBuffer<int8_t>(kOutputTensor),
146                               context->getOutputShape(kOutputTensor));
147         }
148         default:
149             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
150     }
151 }
152 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
153 
154 }  // namespace prelu
155 
156 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(PRELU, prelu::prepare, prelu::execute);
157 
158 }  // namespace nn
159 }  // namespace android
160