1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdint.h>
17 #include <stdlib.h>
18
19 #include <cmath>
20 #include <limits>
21
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/kernels/internal/quantization_util.h"
24 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
25 #include "tensorflow/lite/kernels/internal/tensor.h"
26 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27 #include "tensorflow/lite/kernels/kernel_util.h"
28 #include "tensorflow/lite/kernels/op_macros.h"
29
30 namespace tflite {
31 namespace ops {
32 namespace builtin {
33 namespace elementwise {
34 namespace {
35
// Op names for the quantized-capable ops. GenericPrepare() compares its
// `op_name` argument against these by pointer identity (`op_name == kAbsName`),
// so the Abs/Rsqrt registrations must pass these exact arrays, not equal
// string literals.
const char kAbsName[] = "Abs";
const char kRsqrtName[] = "Rsqrt";
38
// Per-node quantization state, allocated in ElementWiseQuantizedInit,
// populated by GenericPrepare, and consumed by the quantized Eval paths.
struct OpData {
  // Fixed-point multiplier/shift pair produced by QuantizeMultiplier for
  // rescaling from the input scale to the output scale.
  int32_t multiplier;
  int32_t shift;
  // Zero points of the input and output tensors (0 for int16 per the checks
  // in GenericPrepare).
  int input_offset;
  int output_offset;
  // True when input and output scales differ, i.e. Eval must rescale.
  bool needs_rescale;
};
46
IsNumericSupportedType(const TfLiteType type)47 bool IsNumericSupportedType(const TfLiteType type) {
48 return type == kTfLiteFloat32;
49 }
50
IsLogicalSupportedType(const TfLiteType type)51 bool IsLogicalSupportedType(const TfLiteType type) {
52 return type == kTfLiteBool;
53 }
54
IsAbsSupportedType(const TfLiteType type)55 bool IsAbsSupportedType(const TfLiteType type) {
56 return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
57 }
58
IsRsqrtSupportedType(const TfLiteType type)59 bool IsRsqrtSupportedType(const TfLiteType type) {
60 return type == kTfLiteFloat32 || type == kTfLiteInt8;
61 }
62
SetAbsOutputMultiplier(const float input_scale,const float output_scale,int32_t * multiplier,int32_t * shift)63 inline void SetAbsOutputMultiplier(const float input_scale,
64 const float output_scale,
65 int32_t* multiplier, int32_t* shift) {
66 QuantizeMultiplier(input_scale / output_scale, multiplier, shift);
67 }
68
SetRsqrtOutputMultiplier(const float input_scale,const float output_scale,int32_t * multiplier,int32_t * shift)69 inline void SetRsqrtOutputMultiplier(const float input_scale,
70 const float output_scale,
71 int32_t* multiplier, int32_t* shift) {
72 const double scale = 1. / (std::sqrt(input_scale) * output_scale);
73 QuantizeMultiplier(scale, multiplier, shift);
74 }
75
76 typedef bool (*IsSupportedType)(TfLiteType);
GenericPrepare(TfLiteContext * context,TfLiteNode * node,IsSupportedType is_supported_type,const char * op_name)77 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node,
78 IsSupportedType is_supported_type,
79 const char* op_name) {
80 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
81 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
82 const TfLiteTensor* input;
83 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
84 TfLiteTensor* output;
85 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
86 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
87 if (!is_supported_type(input->type)) {
88 TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
89 }
90 // For int16 type input, we support both quantized and non-quantized
91 // evaluation.
92 if (input->type == kTfLiteInt8 ||
93 (input->type == kTfLiteInt16 &&
94 input->quantization.type != kTfLiteNoQuantization)) {
95 TfLiteTensor* output = GetOutput(context, node, 0);
96 auto* op_data = static_cast<OpData*>(node->user_data);
97 TF_LITE_ENSURE_EQ(context, input->quantization.type,
98 kTfLiteAffineQuantization);
99 TF_LITE_ENSURE_EQ(context, output->quantization.type,
100 kTfLiteAffineQuantization);
101 const auto* input_params =
102 reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
103 const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
104 output->quantization.params);
105 TF_LITE_ENSURE(context, input_params != nullptr);
106 TF_LITE_ENSURE(context, input_params->scale != nullptr);
107 TF_LITE_ENSURE(context, input_params->scale->size > 0);
108 TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
109 TF_LITE_ENSURE(context, output_params != nullptr);
110 TF_LITE_ENSURE(context, output_params->scale != nullptr);
111 TF_LITE_ENSURE(context, output_params->scale->size > 0);
112 TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
113 op_data->input_offset = input_params->zero_point->data[0];
114 op_data->output_offset = output_params->zero_point->data[0];
115 if (input->type == kTfLiteInt16) {
116 TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
117 TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
118 }
119 const float input_scale = input_params->scale->data[0];
120 const float output_scale = output_params->scale->data[0];
121 op_data->needs_rescale = input_scale != output_scale;
122 if (op_name == kAbsName && op_data->needs_rescale) {
123 SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
124 &op_data->shift);
125 } else if (op_name == kRsqrtName) {
126 SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
127 &op_data->shift);
128 }
129 }
130 return context->ResizeTensor(context, output,
131 TfLiteIntArrayCopy(input->dims));
132 }
133
134 template <typename T>
EvalImpl(TfLiteContext * context,TfLiteNode * node,std::function<T (T)> func,std::function<TfLiteStatus (T)> validate_input_func,TfLiteType expected_type)135 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
136 std::function<T(T)> func,
137 std::function<TfLiteStatus(T)> validate_input_func,
138 TfLiteType expected_type) {
139 const TfLiteTensor* input;
140 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
141 TfLiteTensor* output;
142 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
143 TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
144 const int64_t num_elements = NumElements(input);
145 const T* in_data = GetTensorData<T>(input);
146 T* out_data = GetTensorData<T>(output);
147 for (int64_t i = 0; i < num_elements; ++i) {
148 if (validate_input_func) {
149 TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
150 }
151 out_data[i] = func(in_data[i]);
152 }
153 return kTfLiteOk;
154 }
155
156 // Non-quantized evaluation of Abs op when input is int16.
AbsInt16EvalImpl(TfLiteContext * context,TfLiteNode * node,TfLiteType expected_type)157 inline TfLiteStatus AbsInt16EvalImpl(TfLiteContext* context, TfLiteNode* node,
158 TfLiteType expected_type) {
159 const TfLiteTensor* input;
160 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
161 TfLiteTensor* output;
162 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
163 TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
164 const int64_t num_elements = NumElements(input);
165 const int16_t* in_data = GetTensorData<int16_t>(input);
166 int16_t* out_data = GetTensorData<int16_t>(output);
167 for (int64_t i = 0; i < num_elements; ++i) {
168 out_data[i] = static_cast<int16_t>(
169 std::abs<int32_t>(static_cast<int32_t>(in_data[i])));
170 }
171 return kTfLiteOk;
172 }
173
174 template <typename T>
EvalImpl(TfLiteContext * context,TfLiteNode * node,std::function<T (T)> func,TfLiteType expected_type)175 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
176 std::function<T(T)> func,
177 TfLiteType expected_type) {
178 return EvalImpl<T>(context, node, func, /*validate_input_func=*/nullptr,
179 expected_type);
180 }
181
EvalNumeric(TfLiteContext * context,TfLiteNode * node,float float_func (float))182 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
183 float float_func(float)) {
184 return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
185 }
186
EvalLogical(TfLiteContext * context,TfLiteNode * node,bool bool_func (bool))187 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
188 bool bool_func(bool)) {
189 return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
190 }
191
ElementWiseQuantizedInit(TfLiteContext * context,const char * buffer,size_t length)192 void* ElementWiseQuantizedInit(TfLiteContext* context, const char* buffer,
193 size_t length) {
194 return new OpData();
195 }
196
ElementWiseQuantizedFree(TfLiteContext * context,void * buffer)197 void ElementWiseQuantizedFree(TfLiteContext* context, void* buffer) {
198 delete static_cast<OpData*>(buffer);
199 }
200
201 template <typename T>
AbsEvalQuantized(TfLiteContext * context,TfLiteNode * node,TfLiteType type)202 TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
203 TfLiteType type) {
204 const auto* op_data = static_cast<const OpData*>(node->user_data);
205 const int kMin = std::numeric_limits<T>::min();
206 const int kMax = std::numeric_limits<T>::max();
207
208 std::function<T(T)> func = [&](T i) {
209 const int32_t value = std::abs(i - op_data->input_offset);
210 if (!op_data->needs_rescale) {
211 return static_cast<T>(
212 std::min(std::max(value + op_data->output_offset, kMin), kMax));
213 }
214 const int32_t output = MultiplyByQuantizedMultiplier(
215 value, op_data->multiplier, op_data->shift) +
216 op_data->output_offset;
217 return static_cast<T>(std::min(std::max(output, kMin), kMax));
218 };
219
220 return EvalImpl<T>(context, node, func, type);
221 }
222
AbsEval(TfLiteContext * context,TfLiteNode * node)223 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
224 const TfLiteTensor* input = GetInput(context, node, 0);
225 const TfLiteType type = input->type;
226 switch (type) {
227 case kTfLiteFloat32:
228 return EvalImpl<float>(context, node, std::abs<float>, type);
229 case kTfLiteInt8:
230 return AbsEvalQuantized<int8_t>(context, node, type);
231 case kTfLiteInt16:
232 return input->quantization.type == kTfLiteNoQuantization
233 ? AbsInt16EvalImpl(context, node, type)
234 : AbsEvalQuantized<int16_t>(context, node, type);
235 default:
236 TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
237 TfLiteTypeGetName(type));
238 return kTfLiteError;
239 }
240 }
241
SinEval(TfLiteContext * context,TfLiteNode * node)242 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
243 return EvalNumeric(context, node, std::sin);
244 }
245
CosEval(TfLiteContext * context,TfLiteNode * node)246 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
247 return EvalNumeric(context, node, std::cos);
248 }
249
LogEval(TfLiteContext * context,TfLiteNode * node)250 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
251 return EvalNumeric(context, node, std::log);
252 }
253
SqrtEval(TfLiteContext * context,TfLiteNode * node)254 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
255 return EvalNumeric(context, node, std::sqrt);
256 }
257
// Quantized (int8) Rsqrt: computes 1/sqrt(x) entirely in fixed point using
// GetInvSqrtQuantizedMultiplierExp plus the Prepare-time rescale multiplier.
TfLiteStatus RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                                TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<int8_t>::min();
  const int kMax = std::numeric_limits<int8_t>::max();
  // Inputs below the zero point are negative in real terms; rsqrt is
  // undefined there, so evaluation fails on such elements.
  std::function<TfLiteStatus(int8_t)> validate_input_func = [&](int8_t i) {
    TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
                       "Rsqrt is only defined for positive values");
    return kTfLiteOk;
  };

  std::function<int8_t(int8_t)> func = [&](int8_t i) {
    const int32_t value = (i - op_data->input_offset);
    const int32_t kShift = 20;  // Shift to keep value integer.
    if (value == 0) {
      // Assume that any value close to 0 represents the max output value.
      return static_cast<int8_t>(kMax);
    }
    // Fixed-point 1/sqrt(value) as a multiplier/exponent pair.
    int32_t inv_sqrt_multiplier;
    int inv_sqrt_shift;
    GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
                                     &inv_sqrt_shift);
    // Materialize the inverse sqrt, left-shifted by kShift to preserve
    // precision through the integer pipeline...
    const int32_t data = MultiplyByQuantizedMultiplier(1, inv_sqrt_multiplier,
                                                       inv_sqrt_shift + kShift);
    // ...then apply the output rescale, undoing kShift, and re-center on the
    // output zero point.
    const int32_t output =
        MultiplyByQuantizedMultiplier(data, op_data->multiplier,
                                      op_data->shift - kShift) +
        op_data->output_offset;
    // Saturate to int8 range.
    return static_cast<int8_t>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<int8_t>(context, node, func, validate_input_func, type);
}
291
RsqrtEval(TfLiteContext * context,TfLiteNode * node)292 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
293 const TfLiteType type = GetInput(context, node, 0)->type;
294 switch (type) {
295 case kTfLiteFloat32:
296 return EvalImpl<float>(
297 context, node, [](float f) { return 1.f / std::sqrt(f); }, type);
298 case kTfLiteInt8:
299 return RsqrtEvalQuantized(context, node, type);
300 default:
301 TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
302 TfLiteTypeGetName(type));
303 return kTfLiteError;
304 }
305 }
306
SquareEval(TfLiteContext * context,TfLiteNode * node)307 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
308 return EvalNumeric(context, node, [](float f) { return f * f; });
309 }
310
LogicalNotEval(TfLiteContext * context,TfLiteNode * node)311 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
312 return EvalLogical(context, node, [](bool v) { return !v; });
313 }
314
315 } // namespace
316 } // namespace elementwise
317
318 // Given a function...
319 // template<int T>
320 // int Foo(int b)
321 //
322 // typedef int(*Bar)(int);
323 //
324 // MSVC2015 will not see Foo<10> as the same type as Bar.
325 //
326 // This works around the issue by instantiating wrapper methods around
327 // elementwise::GenericPrepare() rather than using a templated
328 // elementwise::GenericPrepare method.
// Stamps out a static Prepare wrapper named `function_name` that forwards to
// elementwise::GenericPrepare with the given type predicate and op name.
#define GENERIC_PREPARE(function_name, is_supported_type_function, type_name) \
  static TfLiteStatus function_name(TfLiteContext* context,                   \
                                    TfLiteNode* node) {                       \
    return elementwise::GenericPrepare(context, node,                         \
                                       is_supported_type_function, type_name); \
  }
335
// Abs must pass the kAbsName constant (not a string literal): GenericPrepare
// selects the rescale path by pointer identity.
GENERIC_PREPARE(PrepareAbs, elementwise::IsAbsSupportedType,
                elementwise::kAbsName)

// Abs supports quantized types, so it registers Init/Free to manage OpData.
TfLiteRegistration* Register_ABS() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareAbs, elementwise::AbsEval};
  return &r;
}
345
346 GENERIC_PREPARE(PrepareSin, elementwise::IsNumericSupportedType, "Sin")
347
Register_SIN()348 TfLiteRegistration* Register_SIN() {
349 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareSin,
350 elementwise::SinEval};
351 return &r;
352 }
353
354 GENERIC_PREPARE(PrepareCos, elementwise::IsNumericSupportedType, "Cos")
355
Register_COS()356 TfLiteRegistration* Register_COS() {
357 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareCos,
358 elementwise::CosEval};
359 return &r;
360 }
361
362 GENERIC_PREPARE(PrepareLog, elementwise::IsNumericSupportedType, "Log")
363
Register_LOG()364 TfLiteRegistration* Register_LOG() {
365 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareLog,
366 elementwise::LogEval};
367 return &r;
368 }
369
370 GENERIC_PREPARE(PrepareSqrt, elementwise::IsNumericSupportedType, "Sqrt")
371
Register_SQRT()372 TfLiteRegistration* Register_SQRT() {
373 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
374 PrepareSqrt, elementwise::SqrtEval};
375 return &r;
376 }
377
// Rsqrt must pass the kRsqrtName constant (not a string literal):
// GenericPrepare selects the rescale path by pointer identity.
GENERIC_PREPARE(PrepareRsqrt, elementwise::IsRsqrtSupportedType,
                elementwise::kRsqrtName)

// Rsqrt supports quantized int8, so it registers Init/Free to manage OpData.
TfLiteRegistration* Register_RSQRT() {
  static TfLiteRegistration registration = {
      /*init=*/elementwise::ElementWiseQuantizedInit,
      /*free=*/elementwise::ElementWiseQuantizedFree,
      /*prepare=*/PrepareRsqrt, /*invoke=*/elementwise::RsqrtEval};
  return &registration;
}
387
388 GENERIC_PREPARE(PrepareSquare, elementwise::IsNumericSupportedType, "Square")
389
Register_SQUARE()390 TfLiteRegistration* Register_SQUARE() {
391 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
392 PrepareSquare, elementwise::SquareEval};
393 return &r;
394 }
395
396 GENERIC_PREPARE(PrepareNot, elementwise::IsLogicalSupportedType, "Not")
397
Register_LOGICAL_NOT()398 TfLiteRegistration* Register_LOGICAL_NOT() {
399 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareNot,
400 elementwise::LogicalNotEval};
401 return &r;
402 }
403
404 } // namespace builtin
405 } // namespace ops
406 } // namespace tflite
407