/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"

#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace comparisons {
namespace {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

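// Shared Prepare logic: validates the input/output counts and types, marks
// the output as boolean, and resizes it (broadcasting the input shapes if
// they differ).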
TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node,
                                     bool is_string_allowed) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Reject string inputs unless the caller explicitly allows them.
  if (!is_string_allowed) {
    TF_LITE_ENSURE(context, input1->type != kTfLiteString);
  }
  // Currently both inputs must have the same type.
  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
  output->type = kTfLiteBool;

  bool requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}

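// Thin wrappers that select whether string inputs are permitted at Prepare
// time.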
TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
  return ComparisonPrepareCommon(context, node, false);
}

TfLiteStatus ComparisonPrepareStringAllowed(TfLiteContext* context,
                                            TfLiteNode* node) {
  return ComparisonPrepareCommon(context, node, true);
}

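// Converts a real-valued multiplier into a fixed-point multiplier and shift,
// dispatching on whether the value is below 1.0.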
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
                        int* left_shift) {
  if (double_multiplier < 1.0) {
    QuantizeMultiplierSmallerThanOneExp(double_multiplier, quantized_multiplier,
                                        left_shift);
  } else {
    QuantizeMultiplierGreaterThanOne(double_multiplier, quantized_multiplier,
                                     left_shift);
  }
}

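// Quantized (uint8/int8) comparison: converts each input's scale/zero-point
// into a fixed-point multiplier and shift via QuantizeMultiplier, then runs
// the scaled reference comparison, with or without broadcasting.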
template <typename input_dtype, reference_ops::ComparisonFn<int32> opname>
void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
                         TfLiteTensor* output, bool requires_broadcast) {
  if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
    auto input1_offset = -input1->params.zero_point;
    auto input2_offset = -input2->params.zero_point;
    const int left_shift = 8;

    int32 input1_multiplier;
    int32 input2_multiplier;
    int input1_shift;
    int input2_shift;
    QuantizeMultiplier(input1->params.scale, &input1_multiplier, &input1_shift);
    QuantizeMultiplier(input2->params.scale, &input2_multiplier, &input2_shift);

    ComparisonParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    if (requires_broadcast) {
      reference_ops::BroadcastComparison4DSlowWithScaling<input_dtype, opname>(
          op_params, GetTensorShape(input1), GetTensorData<input_dtype>(input1),
          GetTensorShape(input2), GetTensorData<input_dtype>(input2),
          GetTensorShape(output), GetTensorData<bool>(output));
    } else {
      reference_ops::ComparisonWithScaling<input_dtype, opname>(
          op_params, GetTensorShape(input1), GetTensorData<input_dtype>(input1),
          GetTensorShape(input2), GetTensorData<input_dtype>(input2),
          GetTensorShape(output), GetTensorData<bool>(output));
    }
  }
}

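// Element-wise comparison for non-quantized types; picks the broadcasting
// implementation when the input shapes differ.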
template <typename T, reference_ops::ComparisonFn<T> opname>
void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2,
                TfLiteTensor* output, bool requires_broadcast) {
  ComparisonParams op_params;
  requires_broadcast
      ? reference_ops::BroadcastComparison4DSlowImpl<T, opname>(
            op_params, GetTensorShape(input1), GetTensorData<T>(input1),
            GetTensorShape(input2), GetTensorData<T>(input2),
            GetTensorShape(output), GetTensorData<bool>(output))
      : reference_ops::ComparisonImpl<T, opname>(
            op_params, GetTensorShape(input1), GetTensorData<T>(input1),
            GetTensorShape(input2), GetTensorData<T>(input2),
            GetTensorShape(output), GetTensorData<bool>(output));
}

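// Comparison over string tensors, using the supplied StringRef predicate.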
void ComparisonString(bool (*opname)(const StringRef&, const StringRef&),
                      const TfLiteTensor* input1, const TfLiteTensor* input2,
                      TfLiteTensor* output, bool requires_broadcast) {
  bool* output_data = GetTensorData<bool>(output);
  if (requires_broadcast) {
    reference_ops::BroadcastComparison4DSlowStringImpl(
        opname, GetTensorShape(input1), input1, GetTensorShape(input2), input2,
        GetTensorShape(output), output_data);
  } else {
    reference_ops::ComparisonStringImpl(opname, GetTensorShape(input1), input1,
                                        GetTensorShape(input2), input2,
                                        GetTensorShape(output), output_data);
  }
}

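// The Eval functions below all follow the same pattern: fetch the tensors,
// check whether broadcasting is needed, then dispatch on the input type to
// the plain, quantized, or string comparison helper.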
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      Comparison<bool, reference_ops::EqualFn>(input1, input2, output,
                                               requires_broadcast);
      break;
    case kTfLiteFloat32:
      Comparison<float, reference_ops::EqualFn>(input1, input2, output,
                                                requires_broadcast);
      break;
    case kTfLiteInt32:
      Comparison<int32_t, reference_ops::EqualFn>(input1, input2, output,
                                                  requires_broadcast);
      break;
    case kTfLiteInt64:
      Comparison<int64_t, reference_ops::EqualFn>(input1, input2, output,
                                                  requires_broadcast);
      break;
    case kTfLiteUInt8:
      ComparisonQuantized<uint8_t, reference_ops::EqualFn>(
          input1, input2, output, requires_broadcast);
      break;
    case kTfLiteInt8:
      ComparisonQuantized<int8_t, reference_ops::EqualFn>(
          input1, input2, output, requires_broadcast);
      break;
    case kTfLiteString:
      ComparisonString(reference_ops::StringRefEqualFn, input1, input2, output,
                       requires_broadcast);
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context,
          "Does not support type %d, requires bool|float|int|uint8|string",
          input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      Comparison<bool, reference_ops::NotEqualFn>(input1, input2, output,
                                                  requires_broadcast);
      break;
    case kTfLiteFloat32:
      Comparison<float, reference_ops::NotEqualFn>(input1, input2, output,
                                                   requires_broadcast);
      break;
    case kTfLiteInt32:
      Comparison<int32_t, reference_ops::NotEqualFn>(input1, input2, output,
                                                     requires_broadcast);
      break;
    case kTfLiteInt64:
      Comparison<int64_t, reference_ops::NotEqualFn>(input1, input2, output,
                                                     requires_broadcast);
      break;
    case kTfLiteUInt8:
      ComparisonQuantized<uint8_t, reference_ops::NotEqualFn>(
          input1, input2, output, requires_broadcast);
      break;
    case kTfLiteInt8:
      ComparisonQuantized<int8_t, reference_ops::NotEqualFn>(
          input1, input2, output, requires_broadcast);
      break;
    case kTfLiteString:
      ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2,
                       output, requires_broadcast);
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context,
          "Does not support type %d, requires bool|float|int|uint8|string",
          input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      Comparison<float, reference_ops::GreaterFn>(input1, input2, output,
                                                  requires_broadcast);
      break;
    case kTfLiteInt32:
      Comparison<int32_t, reference_ops::GreaterFn>(input1, input2, output,
                                                    requires_broadcast);
      break;
    case kTfLiteInt64:
      Comparison<int64_t, reference_ops::GreaterFn>(input1, input2, output,
                                                    requires_broadcast);
      break;
    case kTfLiteUInt8:
      ComparisonQuantized<uint8_t, reference_ops::GreaterFn>(
          input1, input2, output, requires_broadcast);
      break;
    case kTfLiteInt8:
      ComparisonQuantized<int8_t, reference_ops::GreaterFn>(
          input1, input2, output, requires_broadcast);
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Does not support type %d, requires float|int|uint8",
                         input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      Comparison<float, reference_ops::GreaterEqualFn>(input1, input2, output,
                                                       requires_broadcast);
      break;
    case kTfLiteInt32:
      Comparison<int32_t, reference_ops::GreaterEqualFn>(input1, input2, output,
                                                         requires_broadcast);
      break;
    case kTfLiteInt64:
      Comparison<int64_t, reference_ops::GreaterEqualFn>(input1, input2, output,
                                                         requires_broadcast);
      break;
    case kTfLiteUInt8:
      ComparisonQuantized<uint8_t, reference_ops::GreaterEqualFn>(
          input1, input2, output, requires_broadcast);
      break;
    case kTfLiteInt8:
      ComparisonQuantized<int8_t, reference_ops::GreaterEqualFn>(
          input1, input2, output, requires_broadcast);
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Does not support type %d, requires float|int|uint8",
                         input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      Comparison<float, reference_ops::LessFn>(input1, input2, output,
                                               requires_broadcast);
      break;
    case kTfLiteInt32:
      Comparison<int32_t, reference_ops::LessFn>(input1, input2, output,
                                                 requires_broadcast);
      break;
    case kTfLiteInt64:
      Comparison<int64_t, reference_ops::LessFn>(input1, input2, output,
                                                 requires_broadcast);
      break;
    case kTfLiteUInt8:
      ComparisonQuantized<uint8_t, reference_ops::LessFn>(
          input1, input2, output, requires_broadcast);
      break;
    case kTfLiteInt8:
      ComparisonQuantized<int8_t, reference_ops::LessFn>(input1, input2, output,
                                                         requires_broadcast);
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Does not support type %d, requires float|int|uint8",
                         input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      Comparison<float, reference_ops::LessEqualFn>(input1, input2, output,
                                                    requires_broadcast);
      break;
    case kTfLiteInt32:
      Comparison<int32_t, reference_ops::LessEqualFn>(input1, input2, output,
                                                      requires_broadcast);
      break;
    case kTfLiteInt64:
      Comparison<int64_t, reference_ops::LessEqualFn>(input1, input2, output,
                                                      requires_broadcast);
      break;
    case kTfLiteUInt8:
      ComparisonQuantized<uint8_t, reference_ops::LessEqualFn>(
          input1, input2, output, requires_broadcast);
      break;
    case kTfLiteInt8:
      ComparisonQuantized<int8_t, reference_ops::LessEqualFn>(
          input1, input2, output, requires_broadcast);
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Does not support type %d, requires float|int|uint8",
                         input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace
}  // namespace comparisons

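// The brace-initialized TfLiteRegistration fields are {init, free, prepare,
// invoke}; these kernels need no init/free, so the first two entries are
// nullptr. EQUAL and NOT_EQUAL use the string-allowing Prepare variant.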
TfLiteRegistration* Register_EQUAL() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepareStringAllowed,
                                 comparisons::EqualEval};
  return &r;
}

TfLiteRegistration* Register_NOT_EQUAL() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepareStringAllowed,
                                 comparisons::NotEqualEval};
  return &r;
}

TfLiteRegistration* Register_GREATER() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepare,
                                 comparisons::GreaterEval};
  return &r;
}

TfLiteRegistration* Register_GREATER_EQUAL() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepare,
                                 comparisons::GreaterEqualEval};
  return &r;
}

TfLiteRegistration* Register_LESS() {
  static TfLiteRegistration r = {
      nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval};
  return &r;
}

TfLiteRegistration* Register_LESS_EQUAL() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepare,
                                 comparisons::LessEqualEval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite