• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/c/c_api_internal.h"
16 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
17 #include "tensorflow/lite/kernels/internal/tensor.h"
18 #include "tensorflow/lite/kernels/kernel_util.h"
19 #include "tensorflow/lite/kernels/op_macros.h"
20 #include "tensorflow/lite/string_util.h"
21 
22 namespace tflite {
23 namespace ops {
24 namespace builtin {
25 namespace comparisons {
26 namespace {
27 
// Tensor slot indices shared by every comparison kernel in this file:
// two operands in, one bool result out.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
31 
ComparisonPrepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
33   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
34   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35 
36   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
37   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
38   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
39 
40   // Don't support string and bool.
41   TF_LITE_ENSURE(context,
42                  input1->type != kTfLiteString || input1->type != kTfLiteBool);
43   // Currently only support tensors have the same type.
44   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
45   output->type = kTfLiteBool;
46 
47   bool requires_broadcast = !HaveSameShapes(input1, input2);
48 
49   TfLiteIntArray* output_size = nullptr;
50   if (requires_broadcast) {
51     TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
52                                    context, input1, input2, &output_size));
53   } else {
54     output_size = TfLiteIntArrayCopy(input1->dims);
55   }
56 
57   return context->ResizeTensor(context, output, output_size);
58 }
59 
60 // TODO(ruic): replace the macros below with template functions.
61 #define TF_LITE_QUANTIZE_COMPARISON(opname)                                    \
62   template <typename input_dtype>                                              \
63   void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node,         \
64                              const TfLiteTensor* input1,                       \
65                              const TfLiteTensor* input2, TfLiteTensor* output, \
66                              bool requires_broadcast) {                        \
67     if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {         \
68       auto input1_offset = -input1->params.zero_point;                         \
69       auto input2_offset = -input2->params.zero_point;                         \
70       const int left_shift = 8;                                                \
71                                                                                \
72       int32 input1_multiplier;                                                 \
73       int input1_shift;                                                        \
74       QuantizeMultiplierSmallerThanOneExp(input1->params.scale,                \
75                                           &input1_multiplier, &input1_shift);  \
76       int32 input2_multiplier;                                                 \
77       int input2_shift;                                                        \
78       QuantizeMultiplierSmallerThanOneExp(input2->params.scale,                \
79                                           &input2_multiplier, &input2_shift);  \
80                                                                                \
81       ComparisonParams op_params;                                              \
82       op_params.left_shift = left_shift;                                       \
83       op_params.input1_offset = input1_offset;                                 \
84       op_params.input1_multiplier = input1_multiplier;                         \
85       op_params.input1_shift = input1_shift;                                   \
86       op_params.input2_offset = input2_offset;                                 \
87       op_params.input2_multiplier = input2_multiplier;                         \
88       op_params.input2_shift = input2_shift;                                   \
89       if (requires_broadcast) {                                                \
90         reference_ops::Broadcast4DSlow##opname##WithScaling(                   \
91             op_params, GetTensorShape(input1),                                 \
92             GetTensorData<input_dtype>(input1), GetTensorShape(input2),        \
93             GetTensorData<input_dtype>(input2), GetTensorShape(output),        \
94             GetTensorData<bool>(output));                                      \
95       } else {                                                                 \
96         reference_ops::opname##WithScaling(                                    \
97             op_params, GetTensorShape(input1),                                 \
98             GetTensorData<input_dtype>(input1), GetTensorShape(input2),        \
99             GetTensorData<input_dtype>(input2), GetTensorShape(output),        \
100             GetTensorData<bool>(output));                                      \
101       }                                                                        \
102     }                                                                          \
103   }
104 TF_LITE_QUANTIZE_COMPARISON(Equal);
105 TF_LITE_QUANTIZE_COMPARISON(NotEqual);
106 TF_LITE_QUANTIZE_COMPARISON(Greater);
107 TF_LITE_QUANTIZE_COMPARISON(GreaterEqual);
108 TF_LITE_QUANTIZE_COMPARISON(Less);
109 TF_LITE_QUANTIZE_COMPARISON(LessEqual);
110 #undef TF_LITE_QUANTIZE_COMPARISON
111 
// Runs the non-quantized reference kernel for `opname` on tensors whose
// element type is the C++ type `type`, writing bool results into `output`.
// When `requires_broadcast` is true, the Broadcast4DSlow variant is used.
// The expansion site must have `input1`, `input2` and `output` in scope.
// The body is wrapped in do/while(0) so `TF_LITE_COMPARISON(...);` behaves
// as a single statement.
#define TF_LITE_COMPARISON(type, opname, requires_broadcast)                 \
  do {                                                                       \
    ComparisonParams op_params;                                              \
    if (requires_broadcast) {                                                \
      reference_ops::Broadcast4DSlow##opname##NoScaling(                     \
          op_params, GetTensorShape(input1), GetTensorData<type>(input1),    \
          GetTensorShape(input2), GetTensorData<type>(input2),               \
          GetTensorShape(output), GetTensorData<bool>(output));              \
    } else {                                                                 \
      reference_ops::opname##NoScaling(                                      \
          op_params, GetTensorShape(input1), GetTensorData<type>(input1),    \
          GetTensorShape(input2), GetTensorData<type>(input2),               \
          GetTensorShape(output), GetTensorData<bool>(output));              \
    }                                                                        \
  } while (0)
125 
EqualEval(TfLiteContext * context,TfLiteNode * node)126 TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
127   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
128   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
129   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
130   bool requires_broadcast = !HaveSameShapes(input1, input2);
131   switch (input1->type) {
132     case kTfLiteBool:
133       TF_LITE_COMPARISON(bool, Equal, requires_broadcast);
134       break;
135     case kTfLiteFloat32:
136       TF_LITE_COMPARISON(float, Equal, requires_broadcast);
137       break;
138     case kTfLiteInt32:
139       TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast);
140       break;
141     case kTfLiteInt64:
142       TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast);
143       break;
144     case kTfLiteUInt8:
145       EvalQuantizedEqual<uint8_t>(context, node, input1, input2, output,
146                                   requires_broadcast);
147       break;
148     case kTfLiteInt8:
149       EvalQuantizedEqual<int8_t>(context, node, input1, input2, output,
150                                  requires_broadcast);
151       break;
152     default:
153       context->ReportError(
154           context, "Does not support type %d, requires bool|float|int|uint8",
155           input1->type);
156       return kTfLiteError;
157   }
158   return kTfLiteOk;
159 }
160 
161 // TODO(renjieliu): Refactor the Eval functions to avoid duplication.
NotEqualEval(TfLiteContext * context,TfLiteNode * node)162 TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
163   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
164   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
165   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
166   bool requires_broadcast = !HaveSameShapes(input1, input2);
167   switch (input1->type) {
168     case kTfLiteBool:
169       TF_LITE_COMPARISON(bool, NotEqual, requires_broadcast);
170       break;
171     case kTfLiteFloat32:
172       TF_LITE_COMPARISON(float, NotEqual, requires_broadcast);
173       break;
174     case kTfLiteInt32:
175       TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast);
176       break;
177     case kTfLiteInt64:
178       TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast);
179       break;
180     case kTfLiteUInt8:
181       EvalQuantizedNotEqual<uint8_t>(context, node, input1, input2, output,
182                                      requires_broadcast);
183       break;
184     case kTfLiteInt8:
185       EvalQuantizedNotEqual<int8_t>(context, node, input1, input2, output,
186                                     requires_broadcast);
187       break;
188     default:
189       context->ReportError(
190           context, "Does not support type %d, requires bool|float|int|uint8",
191           input1->type);
192       return kTfLiteError;
193   }
194   return kTfLiteOk;
195 }
196 
GreaterEval(TfLiteContext * context,TfLiteNode * node)197 TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
198   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
199   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
200   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
201   bool requires_broadcast = !HaveSameShapes(input1, input2);
202   switch (input1->type) {
203     case kTfLiteFloat32:
204       TF_LITE_COMPARISON(float, Greater, requires_broadcast);
205       break;
206     case kTfLiteInt32:
207       TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast);
208       break;
209     case kTfLiteInt64:
210       TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast);
211       break;
212     case kTfLiteUInt8:
213       EvalQuantizedGreater<uint8_t>(context, node, input1, input2, output,
214                                     requires_broadcast);
215       break;
216     case kTfLiteInt8:
217       EvalQuantizedGreater<int8_t>(context, node, input1, input2, output,
218                                    requires_broadcast);
219       break;
220     default:
221       context->ReportError(context,
222                            "Does not support type %d, requires float|int|uint8",
223                            input1->type);
224       return kTfLiteError;
225   }
226   return kTfLiteOk;
227 }
228 
GreaterEqualEval(TfLiteContext * context,TfLiteNode * node)229 TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
230   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
231   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
232   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
233   bool requires_broadcast = !HaveSameShapes(input1, input2);
234   switch (input1->type) {
235     case kTfLiteFloat32:
236       TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast);
237       break;
238     case kTfLiteInt32:
239       TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast);
240       break;
241     case kTfLiteInt64:
242       TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast);
243       break;
244     case kTfLiteUInt8:
245       EvalQuantizedGreaterEqual<uint8_t>(context, node, input1, input2, output,
246                                          requires_broadcast);
247       break;
248     case kTfLiteInt8:
249       EvalQuantizedGreaterEqual<int8_t>(context, node, input1, input2, output,
250                                         requires_broadcast);
251       break;
252     default:
253       context->ReportError(context,
254                            "Does not support type %d, requires float|int|uint8",
255                            input1->type);
256       return kTfLiteError;
257   }
258   return kTfLiteOk;
259 }
260 
LessEval(TfLiteContext * context,TfLiteNode * node)261 TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
262   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
263   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
264   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
265   bool requires_broadcast = !HaveSameShapes(input1, input2);
266   switch (input1->type) {
267     case kTfLiteFloat32:
268       TF_LITE_COMPARISON(float, Less, requires_broadcast);
269       break;
270     case kTfLiteInt32:
271       TF_LITE_COMPARISON(int32_t, Less, requires_broadcast);
272       break;
273     case kTfLiteInt64:
274       TF_LITE_COMPARISON(int64_t, Less, requires_broadcast);
275       break;
276     case kTfLiteUInt8:
277       EvalQuantizedLess<uint8_t>(context, node, input1, input2, output,
278                                  requires_broadcast);
279       break;
280     case kTfLiteInt8:
281       EvalQuantizedLess<int8_t>(context, node, input1, input2, output,
282                                 requires_broadcast);
283       break;
284     default:
285       context->ReportError(context,
286                            "Does not support type %d, requires float|int|uint8",
287                            input1->type);
288       return kTfLiteError;
289   }
290   return kTfLiteOk;
291 }
292 
LessEqualEval(TfLiteContext * context,TfLiteNode * node)293 TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
294   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
295   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
296   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
297   bool requires_broadcast = !HaveSameShapes(input1, input2);
298   switch (input1->type) {
299     case kTfLiteFloat32:
300       TF_LITE_COMPARISON(float, LessEqual, requires_broadcast);
301       break;
302     case kTfLiteInt32:
303       TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast);
304       break;
305     case kTfLiteInt64:
306       TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast);
307       break;
308     case kTfLiteUInt8:
309       EvalQuantizedLessEqual<uint8_t>(context, node, input1, input2, output,
310                                       requires_broadcast);
311       break;
312     case kTfLiteInt8:
313       EvalQuantizedLessEqual<int8_t>(context, node, input1, input2, output,
314                                      requires_broadcast);
315       break;
316     default:
317       context->ReportError(context,
318                            "Does not support type %d, requires float|int|uint8",
319                            input1->type);
320       return kTfLiteError;
321   }
322   return kTfLiteOk;
323 }
324 
325 }  // namespace
326 }  // namespace comparisons
327 
Register_EQUAL()328 TfLiteRegistration* Register_EQUAL() {
329   static TfLiteRegistration r = {
330       nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval};
331   return &r;
332 }
333 
Register_NOT_EQUAL()334 TfLiteRegistration* Register_NOT_EQUAL() {
335   static TfLiteRegistration r = {nullptr, nullptr,
336                                  comparisons::ComparisonPrepare,
337                                  comparisons::NotEqualEval};
338   return &r;
339 }
340 
Register_GREATER()341 TfLiteRegistration* Register_GREATER() {
342   static TfLiteRegistration r = {nullptr, nullptr,
343                                  comparisons::ComparisonPrepare,
344                                  comparisons::GreaterEval};
345   return &r;
346 }
347 
Register_GREATER_EQUAL()348 TfLiteRegistration* Register_GREATER_EQUAL() {
349   static TfLiteRegistration r = {nullptr, nullptr,
350                                  comparisons::ComparisonPrepare,
351                                  comparisons::GreaterEqualEval};
352   return &r;
353 }
354 
Register_LESS()355 TfLiteRegistration* Register_LESS() {
356   static TfLiteRegistration r = {
357       nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval};
358   return &r;
359 }
360 
Register_LESS_EQUAL()361 TfLiteRegistration* Register_LESS_EQUAL() {
362   static TfLiteRegistration r = {nullptr, nullptr,
363                                  comparisons::ComparisonPrepare,
364                                  comparisons::LessEqualEval};
365   return &r;
366 }
367 
368 }  // namespace builtin
369 }  // namespace ops
370 }  // namespace tflite
371