1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
17
18 #include <cstdint>
19
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/quantization_util.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/micro/kernels/kernel_util.h"
25
26 namespace tflite {
27 namespace ops {
28 namespace micro {
29 namespace activations {
30 namespace {
31
CalculatePreluParams(const TfLiteTensor * input,const TfLiteTensor * alpha,TfLiteTensor * output,PreluParams * params)32 TfLiteStatus CalculatePreluParams(const TfLiteTensor* input,
33 const TfLiteTensor* alpha,
34 TfLiteTensor* output, PreluParams* params) {
35 if (output->type == kTfLiteInt8 || output->type == kTfLiteUInt8 ||
36 output->type == kTfLiteInt16) {
37 double real_multiplier_1 = static_cast<double>(input->params.scale) /
38 static_cast<double>(output->params.scale);
39 double real_multiplier_2 = static_cast<double>(input->params.scale) *
40 static_cast<double>(alpha->params.scale) /
41 static_cast<double>(output->params.scale);
42 QuantizeMultiplier(real_multiplier_1, ¶ms->output_multiplier_1,
43 ¶ms->output_shift_1);
44 QuantizeMultiplier(real_multiplier_2, ¶ms->output_multiplier_2,
45 ¶ms->output_shift_2);
46
47 params->input_offset = -input->params.zero_point;
48 params->alpha_offset = -alpha->params.zero_point;
49 params->output_offset = output->params.zero_point;
50 }
51
52 return kTfLiteOk;
53 }
54
55 } // namespace
56
BroadcastPrelu4DSlowFloat(const RuntimeShape & unextended_input1_shape,const float * input1_data,const RuntimeShape & unextended_input2_shape,const float * input2_data,const RuntimeShape & unextended_output_shape,float * output_data)57 inline void BroadcastPrelu4DSlowFloat(
58 const RuntimeShape& unextended_input1_shape, const float* input1_data,
59 const RuntimeShape& unextended_input2_shape, const float* input2_data,
60 const RuntimeShape& unextended_output_shape, float* output_data) {
61 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
62 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
63 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
64 const RuntimeShape output_shape =
65 RuntimeShape::ExtendedShape(4, unextended_output_shape);
66
67 NdArrayDesc<4> desc1;
68 NdArrayDesc<4> desc2;
69 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
70 unextended_input2_shape, &desc1, &desc2);
71
72 for (int b = 0; b < output_shape.Dims(0); ++b) {
73 for (int y = 0; y < output_shape.Dims(1); ++y) {
74 for (int x = 0; x < output_shape.Dims(2); ++x) {
75 for (int c = 0; c < output_shape.Dims(3); ++c) {
76 auto out_idx = Offset(output_shape, b, y, x, c);
77 auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
78 auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
79 auto in1_val = input1_data[in1_idx];
80 auto in2_val = input2_data[in2_idx];
81 output_data[out_idx] = in1_val >= 0.0f ? in1_val : in1_val * in2_val;
82 }
83 }
84 }
85 }
86 }
87
PreluInit(TfLiteContext * context,const char * buffer,size_t length)88 void* PreluInit(TfLiteContext* context, const char* buffer, size_t length) {
89 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
90 return context->AllocatePersistentBuffer(context, sizeof(PreluParams));
91 }
92
PreluPrepare(TfLiteContext * context,TfLiteNode * node)93 TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
94 TFLITE_DCHECK(node->user_data != nullptr);
95 PreluParams* params = static_cast<PreluParams*>(node->user_data);
96
97 const TfLiteTensor* input = GetInput(context, node, 0);
98 TF_LITE_ENSURE(context, input != nullptr);
99 const TfLiteTensor* alpha = GetInput(context, node, 1);
100 TF_LITE_ENSURE(context, alpha != nullptr);
101 TfLiteTensor* output = GetOutput(context, node, 0);
102 TF_LITE_ENSURE(context, output != nullptr);
103
104 return CalculatePreluParams(input, alpha, output, params);
105 }
106
PreluEval(TfLiteContext * context,TfLiteNode * node)107 TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
108 TFLITE_DCHECK(node->user_data != nullptr);
109 const PreluParams& params =
110 *(static_cast<const PreluParams*>(node->user_data));
111
112 const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
113 const TfLiteEvalTensor* alpha = tflite::micro::GetEvalInput(context, node, 1);
114 TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
115
116 switch (input->type) {
117 case kTfLiteFloat32: {
118 BroadcastPrelu4DSlowFloat(tflite::micro::GetTensorShape(input),
119 tflite::micro::GetTensorData<float>(input),
120 tflite::micro::GetTensorShape(alpha),
121 tflite::micro::GetTensorData<float>(alpha),
122 tflite::micro::GetTensorShape(output),
123 tflite::micro::GetTensorData<float>(output));
124 return kTfLiteOk;
125 } break;
126 case kTfLiteUInt8: {
127 reference_ops::BroadcastPrelu4DSlow(
128 params, tflite::micro::GetTensorShape(input),
129 tflite::micro::GetTensorData<uint8_t>(input),
130 tflite::micro::GetTensorShape(alpha),
131 tflite::micro::GetTensorData<uint8_t>(alpha),
132 tflite::micro::GetTensorShape(output),
133 tflite::micro::GetTensorData<uint8_t>(output));
134 return kTfLiteOk;
135 } break;
136 case kTfLiteInt8: {
137 reference_ops::BroadcastPrelu4DSlow(
138 params, tflite::micro::GetTensorShape(input),
139 tflite::micro::GetTensorData<int8_t>(input),
140 tflite::micro::GetTensorShape(alpha),
141 tflite::micro::GetTensorData<int8_t>(alpha),
142 tflite::micro::GetTensorShape(output),
143 tflite::micro::GetTensorData<int8_t>(output));
144 return kTfLiteOk;
145 } break;
146 default:
147 TF_LITE_KERNEL_LOG(
148 context, "Only float32 and uint8_t are supported currently, got %d.",
149 TfLiteTypeGetName(input->type));
150 return kTfLiteError;
151 }
152 }
153
154 } // namespace activations
155
Register_PRELU()156 TfLiteRegistration Register_PRELU() {
157 return {/*init=*/activations::PreluInit,
158 /*free=*/nullptr,
159 /*prepare=*/activations::PreluPrepare,
160 /*invoke=*/activations::PreluEval,
161 /*profiling_string=*/nullptr,
162 /*builtin_code=*/0,
163 /*custom_name=*/nullptr,
164 /*version=*/0};
165 }
166
167 } // namespace micro
168 } // namespace ops
169 } // namespace tflite
170