/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace elementwise {
namespace {

const char kAbsName[] = "Abs";
const char kRsqrtName[] = "Rsqrt";

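// Quantization parameters cached at Prepare time for the quantized Abs and
// Rsqrt kernels: the fixed-point multiplier/shift used to rescale from input
// to output scale, the input/output zero points, and whether the input and
// output scales differ (used by the quantized Abs path to skip rescaling).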
struct OpData {
  int32_t multiplier;
  int32_t shift;
  int input_offset;
  int output_offset;
  bool needs_rescale;
};

bool IsNumericSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32;
}

bool IsLogicalSupportedType(const TfLiteType type) {
  return type == kTfLiteBool;
}

bool IsAbsSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
}

bool IsRsqrtSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32 || type == kTfLiteInt8;
}

inline void SetAbsOutputMultiplier(const float input_scale,
                                   const float output_scale,
                                   int32_t* multiplier, int32_t* shift) {
  QuantizeMultiplier(input_scale / output_scale, multiplier, shift);
}

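// For Rsqrt the real-valued output is 1 / sqrt(real input), so the combined
// rescale factor from the zero-point-adjusted quantized input to the
// quantized output is 1 / (sqrt(input_scale) * output_scale).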
inline void SetRsqrtOutputMultiplier(const float input_scale,
                                     const float output_scale,
                                     int32_t* multiplier, int32_t* shift) {
  const double scale = 1. / (std::sqrt(input_scale) * output_scale);
  QuantizeMultiplier(scale, multiplier, shift);
}

typedef bool (*IsSupportedType)(TfLiteType);
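// Shared Prepare: checks the single input/output pair, verifies that the
// input type is supported by the op and, for quantized int8/int16 tensors,
// caches zero points and the output rescale multiplier in OpData before
// resizing the output to the input shape.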
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node,
                            IsSupportedType is_supported_type,
                            const char* op_name) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  if (!is_supported_type(input->type)) {
    TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
  }
  // For int16 type input, we support both quantized and non-quantized
  // evaluation.
  if (input->type == kTfLiteInt8 ||
      (input->type == kTfLiteInt16 &&
       input->quantization.type != kTfLiteNoQuantization)) {
    TfLiteTensor* output = GetOutput(context, node, 0);
    auto* op_data = static_cast<OpData*>(node->user_data);
    TF_LITE_ENSURE_EQ(context, input->quantization.type,
                      kTfLiteAffineQuantization);
    TF_LITE_ENSURE_EQ(context, output->quantization.type,
                      kTfLiteAffineQuantization);
    const auto* input_params =
        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
    const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
        output->quantization.params);
    TF_LITE_ENSURE(context, input_params != nullptr);
    TF_LITE_ENSURE(context, input_params->scale != nullptr);
    TF_LITE_ENSURE(context, input_params->scale->size > 0);
    TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
    TF_LITE_ENSURE(context, output_params != nullptr);
    TF_LITE_ENSURE(context, output_params->scale != nullptr);
    TF_LITE_ENSURE(context, output_params->scale->size > 0);
    TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
    op_data->input_offset = input_params->zero_point->data[0];
    op_data->output_offset = output_params->zero_point->data[0];
    if (input->type == kTfLiteInt16) {
      TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
      TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
    }
    const float input_scale = input_params->scale->data[0];
    const float output_scale = output_params->scale->data[0];
    op_data->needs_rescale = input_scale != output_scale;
    if (op_name == kAbsName && op_data->needs_rescale) {
      SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                             &op_data->shift);
    } else if (op_name == kRsqrtName) {
      SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                               &op_data->shift);
    }
  }
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

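// Applies `func` to every element of the single input tensor and writes the
// result to the output tensor. If `validate_input_func` is provided, each
// input element is checked before it is transformed.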
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
                             std::function<T(T)> func,
                             std::function<TfLiteStatus(T)> validate_input_func,
                             TfLiteType expected_type) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
  const int64_t num_elements = NumElements(input);
  const T* in_data = GetTensorData<T>(input);
  T* out_data = GetTensorData<T>(output);
  for (int64_t i = 0; i < num_elements; ++i) {
    if (validate_input_func) {
      TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
    }
    out_data[i] = func(in_data[i]);
  }
  return kTfLiteOk;
}

// Non-quantized evaluation of Abs op when input is int16.
inline TfLiteStatus AbsInt16EvalImpl(TfLiteContext* context, TfLiteNode* node,
                                     TfLiteType expected_type) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
  const int64_t num_elements = NumElements(input);
  const int16_t* in_data = GetTensorData<int16_t>(input);
  int16_t* out_data = GetTensorData<int16_t>(output);
  for (int64_t i = 0; i < num_elements; ++i) {
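    // Take the absolute value in 32-bit arithmetic so that the minimum int16
    // value (-32768), whose magnitude is not representable in int16, does not
    // overflow before the result is cast back.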
    out_data[i] = static_cast<int16_t>(
        std::abs<int32_t>(static_cast<int32_t>(in_data[i])));
  }
  return kTfLiteOk;
}

template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
                             std::function<T(T)> func,
                             TfLiteType expected_type) {
  return EvalImpl<T>(context, node, func, /*validate_input_func=*/nullptr,
                     expected_type);
}

inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
                                float float_func(float)) {
  return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
}

inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
                                bool bool_func(bool)) {
  return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
}

void* ElementWiseQuantizedInit(TfLiteContext* context, const char* buffer,
                               size_t length) {
  return new OpData();
}

void ElementWiseQuantizedFree(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

template <typename T>
TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<T>::min();
  const int kMax = std::numeric_limits<T>::max();

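  // Per-element quantized Abs: remove the input zero point, take the absolute
  // value, rescale by the cached multiplier/shift when the input and output
  // scales differ, re-add the output zero point, and clamp to the type range.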
  std::function<T(T)> func = [&](T i) {
    const int32_t value = std::abs(i - op_data->input_offset);
    if (!op_data->needs_rescale) {
      return static_cast<T>(
          std::min(std::max(value + op_data->output_offset, kMin), kMax));
    }
    const int32_t output = MultiplyByQuantizedMultiplier(
                               value, op_data->multiplier, op_data->shift) +
                           op_data->output_offset;
    return static_cast<T>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<T>(context, node, func, type);
}

TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  const TfLiteType type = input->type;
  switch (type) {
    case kTfLiteFloat32:
      return EvalImpl<float>(context, node, std::abs<float>, type);
    case kTfLiteInt8:
      return AbsEvalQuantized<int8_t>(context, node, type);
    case kTfLiteInt16:
      return input->quantization.type == kTfLiteNoQuantization
                 ? AbsInt16EvalImpl(context, node, type)
                 : AbsEvalQuantized<int16_t>(context, node, type);
    default:
      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
                         TfLiteTypeGetName(type));
      return kTfLiteError;
  }
}

TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::sin);
}

TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::cos);
}

TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::log);
}

TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::sqrt);
}

TfLiteStatus RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                                TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<int8_t>::min();
  const int kMax = std::numeric_limits<int8_t>::max();
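  // The input zero point represents real value 0, so any quantized value
  // below it encodes a negative real input, for which Rsqrt is undefined.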
  std::function<TfLiteStatus(int8_t)> validate_input_func = [&](int8_t i) {
    TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
                       "Rsqrt is only defined for positive values");
    return kTfLiteOk;
  };

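  // Fixed-point evaluation of 1/sqrt: compute an inverse-sqrt multiplier for
  // the zero-point-adjusted input, scale it up by 2^kShift to preserve
  // precision in integer arithmetic, then fold the extra shift back out while
  // applying the output rescale multiplier.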
  std::function<int8_t(int8_t)> func = [&](int8_t i) {
    const int32_t value = (i - op_data->input_offset);
    const int32_t kShift = 20;  // Shift to keep value integer.
    if (value == 0) {
      // Assume that any value close to 0 represents the max output value.
      return static_cast<int8_t>(kMax);
    }
    int32_t inv_sqrt_multiplier;
    int inv_sqrt_shift;
    GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
                                     &inv_sqrt_shift);
    const int32_t data = MultiplyByQuantizedMultiplier(1, inv_sqrt_multiplier,
                                                       inv_sqrt_shift + kShift);
    const int32_t output =
        MultiplyByQuantizedMultiplier(data, op_data->multiplier,
                                      op_data->shift - kShift) +
        op_data->output_offset;
    return static_cast<int8_t>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<int8_t>(context, node, func, validate_input_func, type);
}

TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteType type = GetInput(context, node, 0)->type;
  switch (type) {
    case kTfLiteFloat32:
      return EvalImpl<float>(
          context, node, [](float f) { return 1.f / std::sqrt(f); }, type);
    case kTfLiteInt8:
      return RsqrtEvalQuantized(context, node, type);
    default:
      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
                         TfLiteTypeGetName(type));
      return kTfLiteError;
  }
}

TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, [](float f) { return f * f; });
}

TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalLogical(context, node, [](bool v) { return !v; });
}

}  // namespace
}  // namespace elementwise

// Given a function...
// template<int T>
// int Foo(int b)
//
// typedef int(*Bar)(int);
//
// MSVC2015 will not see Foo<10> as the same type as Bar.
//
// This works around the issue by instantiating wrapper methods around
// elementwise::GenericPrepare() rather than using a templated
// elementwise::GenericPrepare method.
#define GENERIC_PREPARE(function_name, is_supported_type_function, type_name)  \
  static TfLiteStatus function_name(TfLiteContext* context,                    \
                                    TfLiteNode* node) {                        \
    return elementwise::GenericPrepare(context, node,                          \
                                       is_supported_type_function, type_name); \
  }

GENERIC_PREPARE(PrepareAbs, elementwise::IsAbsSupportedType,
                elementwise::kAbsName)

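// Abs and Rsqrt register the quantized init/free pair so an OpData instance
// is allocated per node to hold the cached quantization parameters; the
// remaining float/bool ops carry no per-node state and register nullptr.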
TfLiteRegistration* Register_ABS() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareAbs, elementwise::AbsEval};
  return &r;
}

GENERIC_PREPARE(PrepareSin, elementwise::IsNumericSupportedType, "Sin")

TfLiteRegistration* Register_SIN() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareSin,
                                 elementwise::SinEval};
  return &r;
}

GENERIC_PREPARE(PrepareCos, elementwise::IsNumericSupportedType, "Cos")

TfLiteRegistration* Register_COS() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareCos,
                                 elementwise::CosEval};
  return &r;
}

GENERIC_PREPARE(PrepareLog, elementwise::IsNumericSupportedType, "Log")

TfLiteRegistration* Register_LOG() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareLog,
                                 elementwise::LogEval};
  return &r;
}

GENERIC_PREPARE(PrepareSqrt, elementwise::IsNumericSupportedType, "Sqrt")

TfLiteRegistration* Register_SQRT() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 PrepareSqrt, elementwise::SqrtEval};
  return &r;
}

GENERIC_PREPARE(PrepareRsqrt, elementwise::IsRsqrtSupportedType,
                elementwise::kRsqrtName)

TfLiteRegistration* Register_RSQRT() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareRsqrt, elementwise::RsqrtEval};
  return &r;
}

GENERIC_PREPARE(PrepareSquare, elementwise::IsNumericSupportedType, "Square")

TfLiteRegistration* Register_SQUARE() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 PrepareSquare, elementwise::SquareEval};
  return &r;
}

GENERIC_PREPARE(PrepareNot, elementwise::IsLogicalSupportedType, "Not")

TfLiteRegistration* Register_LOGICAL_NOT() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareNot,
                                 elementwise::LogicalNotEval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite