1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/common.h"
19 #include "tensorflow/lite/kernels/internal/quantization_util.h"
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/kernels/internal/types.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/kernels/op_macros.h"
24 #include "tensorflow/lite/micro/kernels/kernel_util.h"
25 #include "tensorflow/lite/micro/micro_utils.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace activations {
31 namespace {
32
// Persistent per-node state for the RELU kernel. `params` carries the
// quantization parameters (input/output offsets, requantization
// multiplier/shift, and clamp range) filled in by CalculateReluOpData()
// during Prepare and consumed by ReluQuantized() during Eval.
struct ReluOpData {
  ReluParams params;
};
36
// Persistent per-node state for the RELU6 kernel: the clamp bounds 0 and 6
// pre-quantized into the input tensor's domain at Prepare time, one pair per
// supported quantized element type.
struct Relu6OpData {
  int8_t six_int8;     // 6.0f quantized with the input's scale/zero point.
  int8_t zero_int8;    // The input's zero point (i.e. quantized 0.0f).
  uint8_t six_uint8;   // Same as six_int8, for uint8 tensors.
  uint8_t zero_uint8;  // Same as zero_int8, for uint8 tensors.
};
43
44 } // namespace
45
// Both kernels use a single input tensor and a single output tensor.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
48
49 template <typename T>
ReluQuantized(const ReluOpData & data,const RuntimeShape & input_shape,const RuntimeShape & output_shape,const T * input_data,T * output_data)50 inline void ReluQuantized(const ReluOpData& data,
51 const RuntimeShape& input_shape,
52 const RuntimeShape& output_shape, const T* input_data,
53 T* output_data) {
54 const int flat_size = MatchingFlatSize(input_shape, output_shape);
55 for (int i = 0; i < flat_size; ++i) {
56 const int32_t val = static_cast<int32_t>(input_data[i]);
57 int32_t clamped =
58 data.params.output_offset +
59 MultiplyByQuantizedMultiplier(val - data.params.input_offset,
60 data.params.output_multiplier,
61 data.params.output_shift);
62 clamped = std::max(data.params.quantized_activation_min, clamped);
63 clamped = std::min(data.params.quantized_activation_max, clamped);
64 output_data[i] = static_cast<T>(clamped);
65 }
66 }
67
68 template <typename T>
CalculateReluOpData(const TfLiteTensor * input,TfLiteTensor * output,ReluOpData * data)69 inline void CalculateReluOpData(const TfLiteTensor* input, TfLiteTensor* output,
70 ReluOpData* data) {
71 float act_min = 0.0;
72 float act_max = std::numeric_limits<float>::infinity();
73 double real_multiplier =
74 static_cast<double>(input->params.scale / output->params.scale);
75
76 const RuntimeShape input_shape = GetTensorShape(input);
77 const RuntimeShape output_shape = GetTensorShape(output);
78
79 QuantizeMultiplier(real_multiplier, &data->params.output_multiplier,
80 &data->params.output_shift);
81
82 data->params.quantized_activation_min = std::max(
83 static_cast<int32_t>(std::numeric_limits<T>::min()),
84 output->params.zero_point +
85 static_cast<int32_t>(roundf(act_min / output->params.scale)));
86 data->params.quantized_activation_max =
87 act_max == std::numeric_limits<float>::infinity()
88 ? static_cast<int32_t>(std::numeric_limits<T>::max())
89 : std::min(static_cast<int32_t>(std::numeric_limits<T>::max()),
90 output->params.zero_point +
91 static_cast<int32_t>(
92 roundf(act_max / output->params.scale)));
93 data->params.input_offset = input->params.zero_point;
94 data->params.output_offset = output->params.zero_point;
95 }
96
ReluFloat(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)97 inline void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
98 const RuntimeShape& output_shape, float* output_data) {
99 const int flat_size = MatchingFlatSize(input_shape, output_shape);
100 for (int i = 0; i < flat_size; ++i) {
101 const float val = input_data[i];
102 const float lower = 0.0f;
103 const float clamped = val < lower ? lower : val;
104 output_data[i] = clamped;
105 }
106 }
107
Relu6Float(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)108 inline void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
109 const RuntimeShape& output_shape, float* output_data) {
110 const int flat_size = MatchingFlatSize(input_shape, output_shape);
111 for (int i = 0; i < flat_size; ++i) {
112 const float val = input_data[i];
113 const float upper = 6.0f;
114 const float lower = 0.0f;
115 const float clamped = val > upper ? upper : val < lower ? lower : val;
116 output_data[i] = clamped;
117 }
118 }
119
120 template <typename Q>
Relu6Quantized(Q lower,Q upper,const RuntimeShape & input_shape,const Q * input_data,const RuntimeShape & output_shape,Q * output_data)121 inline void Relu6Quantized(Q lower, Q upper, const RuntimeShape& input_shape,
122 const Q* input_data,
123 const RuntimeShape& output_shape, Q* output_data) {
124 const int flat_size = MatchingFlatSize(input_shape, output_shape);
125 for (int i = 0; i < flat_size; ++i) {
126 const Q val = input_data[i];
127 const Q clamped = val > upper ? upper : val < lower ? lower : val;
128 output_data[i] = clamped;
129 }
130 }
131
ReluInit(TfLiteContext * context,const char * buffer,size_t length)132 void* ReluInit(TfLiteContext* context, const char* buffer, size_t length) {
133 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
134 return context->AllocatePersistentBuffer(context, sizeof(ReluOpData));
135 }
136
ReluPrepare(TfLiteContext * context,TfLiteNode * node)137 TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
138 TFLITE_DCHECK(node->user_data != nullptr);
139 ReluOpData* data = static_cast<ReluOpData*>(node->user_data);
140
141 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
142 TF_LITE_ENSURE(context, input != nullptr);
143 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
144 TF_LITE_ENSURE(context, output != nullptr);
145
146 if (input->type == kTfLiteInt8) {
147 CalculateReluOpData<int8_t>(input, output, data);
148 } else if (input->type == kTfLiteUInt8) {
149 CalculateReluOpData<uint8_t>(input, output, data);
150 }
151
152 return kTfLiteOk;
153 }
154
ReluEval(TfLiteContext * context,TfLiteNode * node)155 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
156 TFLITE_DCHECK(node->user_data != nullptr);
157 const ReluOpData& data = *(static_cast<const ReluOpData*>(node->user_data));
158
159 const TfLiteEvalTensor* input =
160 tflite::micro::GetEvalInput(context, node, kInputTensor);
161 TfLiteEvalTensor* output =
162 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
163
164 switch (input->type) {
165 case kTfLiteFloat32: {
166 ReluFloat(tflite::micro::GetTensorShape(input),
167 tflite::micro::GetTensorData<float>(input),
168 tflite::micro::GetTensorShape(output),
169 tflite::micro::GetTensorData<float>(output));
170
171 return kTfLiteOk;
172 }
173 case kTfLiteInt8: {
174 ReluQuantized<int8_t>(data, tflite::micro::GetTensorShape(input),
175 tflite::micro::GetTensorShape(output),
176 tflite::micro::GetTensorData<int8_t>(input),
177 tflite::micro::GetTensorData<int8_t>(output));
178 return kTfLiteOk;
179 }
180 case kTfLiteUInt8: {
181 ReluQuantized<uint8_t>(data, tflite::micro::GetTensorShape(input),
182 tflite::micro::GetTensorShape(output),
183 tflite::micro::GetTensorData<uint8_t>(input),
184 tflite::micro::GetTensorData<uint8_t>(output));
185 return kTfLiteOk;
186 }
187 default: {
188 TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
189 TfLiteTypeGetName(input->type));
190 return kTfLiteError;
191 }
192 }
193 }
194
Relu6Init(TfLiteContext * context,const char * buffer,size_t length)195 void* Relu6Init(TfLiteContext* context, const char* buffer, size_t length) {
196 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
197 return context->AllocatePersistentBuffer(context, sizeof(Relu6OpData));
198 }
199
Relu6Prepare(TfLiteContext * context,TfLiteNode * node)200 TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
201 TFLITE_DCHECK(node->user_data != nullptr);
202 Relu6OpData* data = static_cast<Relu6OpData*>(node->user_data);
203
204 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
205 TF_LITE_ENSURE(context, input != nullptr);
206
207 if (input->type == kTfLiteInt8) {
208 data->six_int8 = FloatToQuantizedType<int8_t>(6.0f, input->params.scale,
209 input->params.zero_point);
210 data->zero_int8 = input->params.zero_point;
211 } else if (input->type == kTfLiteUInt8) {
212 data->six_uint8 = FloatToQuantizedType<uint8_t>(6.0f, input->params.scale,
213 input->params.zero_point);
214 data->zero_uint8 = input->params.zero_point;
215 }
216
217 return kTfLiteOk;
218 }
219
Relu6Eval(TfLiteContext * context,TfLiteNode * node)220 TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
221 TFLITE_DCHECK(node->user_data != nullptr);
222 const Relu6OpData& data = *(static_cast<const Relu6OpData*>(node->user_data));
223
224 const TfLiteEvalTensor* input =
225 tflite::micro::GetEvalInput(context, node, kInputTensor);
226 TfLiteEvalTensor* output =
227 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
228
229 switch (input->type) {
230 case kTfLiteFloat32: {
231 Relu6Float(tflite::micro::GetTensorShape(input),
232 tflite::micro::GetTensorData<float>(input),
233 tflite::micro::GetTensorShape(output),
234 tflite::micro::GetTensorData<float>(output));
235
236 return kTfLiteOk;
237 }
238 case kTfLiteInt8: {
239 Relu6Quantized<int8_t>(data.zero_int8, data.six_int8,
240 tflite::micro::GetTensorShape(input),
241 tflite::micro::GetTensorData<int8_t>(input),
242 tflite::micro::GetTensorShape(output),
243 tflite::micro::GetTensorData<int8_t>(output));
244 return kTfLiteOk;
245 }
246 case kTfLiteUInt8: {
247 Relu6Quantized<uint8_t>(data.zero_uint8, data.six_uint8,
248 tflite::micro::GetTensorShape(input),
249 tflite::micro::GetTensorData<uint8_t>(input),
250 tflite::micro::GetTensorShape(output),
251 tflite::micro::GetTensorData<uint8_t>(output));
252 return kTfLiteOk;
253 }
254 default: {
255 TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
256 TfLiteTypeGetName(input->type));
257 return kTfLiteError;
258 }
259 }
260 }
261
262 } // namespace activations
263
Register_RELU()264 TfLiteRegistration Register_RELU() {
265 return {/*init=*/activations::ReluInit,
266 /*free=*/nullptr,
267 /*prepare=*/activations::ReluPrepare,
268 /*invoke=*/activations::ReluEval,
269 /*profiling_string=*/nullptr,
270 /*builtin_code=*/0,
271 /*custom_name=*/nullptr,
272 /*version=*/0};
273 }
274
Register_RELU6()275 TfLiteRegistration Register_RELU6() {
276 return {/*init=*/activations::Relu6Init,
277 /*free=*/nullptr,
278 /*prepare=*/activations::Relu6Prepare,
279 /*invoke=*/activations::Relu6Eval,
280 /*profiling_string=*/nullptr,
281 /*builtin_code=*/0,
282 /*custom_name=*/nullptr,
283 /*version=*/0};
284 }
285
286 } // namespace micro
287 } // namespace ops
288 } // namespace tflite
289