• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/common.h"
19 #include "tensorflow/lite/kernels/internal/quantization_util.h"
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/kernels/internal/types.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/kernels/op_macros.h"
24 #include "tensorflow/lite/micro/kernels/kernel_util.h"
25 #include "tensorflow/lite/micro/micro_utils.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace activations {
31 namespace {
32 
// Persistent per-op state for RELU: the requantization parameters computed
// in ReluPrepare and consumed by the int8/uint8 ReluQuantized kernel.
struct ReluOpData {
  ReluParams params;
};
36 
// Persistent per-op state for RELU6: quantized representations of the clamp
// bounds [0, 6], computed once in Relu6Prepare for each supported type.
struct Relu6OpData {
  int8_t six_int8;     // quantized 6.0f for int8 inputs
  int8_t zero_int8;    // input zero point (quantized 0.0f) for int8 inputs
  uint8_t six_uint8;   // quantized 6.0f for uint8 inputs
  uint8_t zero_uint8;  // input zero point for uint8 inputs
};
43 
44 }  // namespace
45 
// Indices of the tensors within the node's input/output arrays.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
48 
49 template <typename T>
ReluQuantized(const ReluOpData & data,const RuntimeShape & input_shape,const RuntimeShape & output_shape,const T * input_data,T * output_data)50 inline void ReluQuantized(const ReluOpData& data,
51                           const RuntimeShape& input_shape,
52                           const RuntimeShape& output_shape, const T* input_data,
53                           T* output_data) {
54   const int flat_size = MatchingFlatSize(input_shape, output_shape);
55   for (int i = 0; i < flat_size; ++i) {
56     const int32_t val = static_cast<int32_t>(input_data[i]);
57     int32_t clamped =
58         data.params.output_offset +
59         MultiplyByQuantizedMultiplier(val - data.params.input_offset,
60                                       data.params.output_multiplier,
61                                       data.params.output_shift);
62     clamped = std::max(data.params.quantized_activation_min, clamped);
63     clamped = std::min(data.params.quantized_activation_max, clamped);
64     output_data[i] = static_cast<T>(clamped);
65   }
66 }
67 
68 template <typename T>
CalculateReluOpData(const TfLiteTensor * input,TfLiteTensor * output,ReluOpData * data)69 inline void CalculateReluOpData(const TfLiteTensor* input, TfLiteTensor* output,
70                                 ReluOpData* data) {
71   float act_min = 0.0;
72   float act_max = std::numeric_limits<float>::infinity();
73   double real_multiplier =
74       static_cast<double>(input->params.scale / output->params.scale);
75 
76   const RuntimeShape input_shape = GetTensorShape(input);
77   const RuntimeShape output_shape = GetTensorShape(output);
78 
79   QuantizeMultiplier(real_multiplier, &data->params.output_multiplier,
80                      &data->params.output_shift);
81 
82   data->params.quantized_activation_min = std::max(
83       static_cast<int32_t>(std::numeric_limits<T>::min()),
84       output->params.zero_point +
85           static_cast<int32_t>(roundf(act_min / output->params.scale)));
86   data->params.quantized_activation_max =
87       act_max == std::numeric_limits<float>::infinity()
88           ? static_cast<int32_t>(std::numeric_limits<T>::max())
89           : std::min(static_cast<int32_t>(std::numeric_limits<T>::max()),
90                      output->params.zero_point +
91                          static_cast<int32_t>(
92                              roundf(act_max / output->params.scale)));
93   data->params.input_offset = input->params.zero_point;
94   data->params.output_offset = output->params.zero_point;
95 }
96 
ReluFloat(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)97 inline void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
98                       const RuntimeShape& output_shape, float* output_data) {
99   const int flat_size = MatchingFlatSize(input_shape, output_shape);
100   for (int i = 0; i < flat_size; ++i) {
101     const float val = input_data[i];
102     const float lower = 0.0f;
103     const float clamped = val < lower ? lower : val;
104     output_data[i] = clamped;
105   }
106 }
107 
Relu6Float(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)108 inline void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
109                        const RuntimeShape& output_shape, float* output_data) {
110   const int flat_size = MatchingFlatSize(input_shape, output_shape);
111   for (int i = 0; i < flat_size; ++i) {
112     const float val = input_data[i];
113     const float upper = 6.0f;
114     const float lower = 0.0f;
115     const float clamped = val > upper ? upper : val < lower ? lower : val;
116     output_data[i] = clamped;
117   }
118 }
119 
120 template <typename Q>
Relu6Quantized(Q lower,Q upper,const RuntimeShape & input_shape,const Q * input_data,const RuntimeShape & output_shape,Q * output_data)121 inline void Relu6Quantized(Q lower, Q upper, const RuntimeShape& input_shape,
122                            const Q* input_data,
123                            const RuntimeShape& output_shape, Q* output_data) {
124   const int flat_size = MatchingFlatSize(input_shape, output_shape);
125   for (int i = 0; i < flat_size; ++i) {
126     const Q val = input_data[i];
127     const Q clamped = val > upper ? upper : val < lower ? lower : val;
128     output_data[i] = clamped;
129   }
130 }
131 
// Allocates the per-op ReluOpData from the interpreter's persistent arena.
void* ReluInit(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(ReluOpData));
}
136 
ReluPrepare(TfLiteContext * context,TfLiteNode * node)137 TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
138   TFLITE_DCHECK(node->user_data != nullptr);
139   ReluOpData* data = static_cast<ReluOpData*>(node->user_data);
140 
141   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
142   TF_LITE_ENSURE(context, input != nullptr);
143   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
144   TF_LITE_ENSURE(context, output != nullptr);
145 
146   if (input->type == kTfLiteInt8) {
147     CalculateReluOpData<int8_t>(input, output, data);
148   } else if (input->type == kTfLiteUInt8) {
149     CalculateReluOpData<uint8_t>(input, output, data);
150   }
151 
152   return kTfLiteOk;
153 }
154 
ReluEval(TfLiteContext * context,TfLiteNode * node)155 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
156   TFLITE_DCHECK(node->user_data != nullptr);
157   const ReluOpData& data = *(static_cast<const ReluOpData*>(node->user_data));
158 
159   const TfLiteEvalTensor* input =
160       tflite::micro::GetEvalInput(context, node, kInputTensor);
161   TfLiteEvalTensor* output =
162       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
163 
164   switch (input->type) {
165     case kTfLiteFloat32: {
166       ReluFloat(tflite::micro::GetTensorShape(input),
167                 tflite::micro::GetTensorData<float>(input),
168                 tflite::micro::GetTensorShape(output),
169                 tflite::micro::GetTensorData<float>(output));
170 
171       return kTfLiteOk;
172     }
173     case kTfLiteInt8: {
174       ReluQuantized<int8_t>(data, tflite::micro::GetTensorShape(input),
175                             tflite::micro::GetTensorShape(output),
176                             tflite::micro::GetTensorData<int8_t>(input),
177                             tflite::micro::GetTensorData<int8_t>(output));
178       return kTfLiteOk;
179     }
180     case kTfLiteUInt8: {
181       ReluQuantized<uint8_t>(data, tflite::micro::GetTensorShape(input),
182                              tflite::micro::GetTensorShape(output),
183                              tflite::micro::GetTensorData<uint8_t>(input),
184                              tflite::micro::GetTensorData<uint8_t>(output));
185       return kTfLiteOk;
186     }
187     default: {
188       TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
189                          TfLiteTypeGetName(input->type));
190       return kTfLiteError;
191     }
192   }
193 }
194 
// Allocates the per-op Relu6OpData from the interpreter's persistent arena.
void* Relu6Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(Relu6OpData));
}
199 
Relu6Prepare(TfLiteContext * context,TfLiteNode * node)200 TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
201   TFLITE_DCHECK(node->user_data != nullptr);
202   Relu6OpData* data = static_cast<Relu6OpData*>(node->user_data);
203 
204   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
205   TF_LITE_ENSURE(context, input != nullptr);
206 
207   if (input->type == kTfLiteInt8) {
208     data->six_int8 = FloatToQuantizedType<int8_t>(6.0f, input->params.scale,
209                                                   input->params.zero_point);
210     data->zero_int8 = input->params.zero_point;
211   } else if (input->type == kTfLiteUInt8) {
212     data->six_uint8 = FloatToQuantizedType<uint8_t>(6.0f, input->params.scale,
213                                                     input->params.zero_point);
214     data->zero_uint8 = input->params.zero_point;
215   }
216 
217   return kTfLiteOk;
218 }
219 
Relu6Eval(TfLiteContext * context,TfLiteNode * node)220 TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
221   TFLITE_DCHECK(node->user_data != nullptr);
222   const Relu6OpData& data = *(static_cast<const Relu6OpData*>(node->user_data));
223 
224   const TfLiteEvalTensor* input =
225       tflite::micro::GetEvalInput(context, node, kInputTensor);
226   TfLiteEvalTensor* output =
227       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
228 
229   switch (input->type) {
230     case kTfLiteFloat32: {
231       Relu6Float(tflite::micro::GetTensorShape(input),
232                  tflite::micro::GetTensorData<float>(input),
233                  tflite::micro::GetTensorShape(output),
234                  tflite::micro::GetTensorData<float>(output));
235 
236       return kTfLiteOk;
237     }
238     case kTfLiteInt8: {
239       Relu6Quantized<int8_t>(data.zero_int8, data.six_int8,
240                              tflite::micro::GetTensorShape(input),
241                              tflite::micro::GetTensorData<int8_t>(input),
242                              tflite::micro::GetTensorShape(output),
243                              tflite::micro::GetTensorData<int8_t>(output));
244       return kTfLiteOk;
245     }
246     case kTfLiteUInt8: {
247       Relu6Quantized<uint8_t>(data.zero_uint8, data.six_uint8,
248                               tflite::micro::GetTensorShape(input),
249                               tflite::micro::GetTensorData<uint8_t>(input),
250                               tflite::micro::GetTensorShape(output),
251                               tflite::micro::GetTensorData<uint8_t>(output));
252       return kTfLiteOk;
253     }
254     default: {
255       TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
256                          TfLiteTypeGetName(input->type));
257       return kTfLiteError;
258     }
259   }
260 }
261 
262 }  // namespace activations
263 
// Returns the kernel registration for the built-in RELU op. Field order
// must match TfLiteRegistration; unused hooks are null.
TfLiteRegistration Register_RELU() {
  return {/*init=*/activations::ReluInit,
          /*free=*/nullptr,
          /*prepare=*/activations::ReluPrepare,
          /*invoke=*/activations::ReluEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}
274 
// Returns the kernel registration for the built-in RELU6 op. Field order
// must match TfLiteRegistration; unused hooks are null.
TfLiteRegistration Register_RELU6() {
  return {/*init=*/activations::Relu6Init,
          /*free=*/nullptr,
          /*prepare=*/activations::Relu6Prepare,
          /*invoke=*/activations::Relu6Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}
285 
286 }  // namespace micro
287 }  // namespace ops
288 }  // namespace tflite
289