/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace comparisons {
namespace {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

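// Shared Prepare step for all comparison ops: validates that the node has two
// inputs of the same non-string type and one output, forces the output type
// to bool, and resizes the output to either the broadcast shape of the two
// inputs or, when the shapes already match, a copy of input1's shape.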
TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Don't support string inputs; bool inputs are accepted only by the
  // Equal/NotEqual kernels below.
  TF_LITE_ENSURE(context, input1->type != kTfLiteString);
  // Currently only tensors of the same type are supported.
  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
  output->type = kTfLiteBool;

  bool requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}

// TODO(ruic): Optimize the macros below by converting them to template
// functions.
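// Quantized (uint8/int8) tensors cannot be compared on their raw values when
// the two inputs use different scales or zero points. For each input, the
// macro below derives a fixed-point multiplier and shift from its scale via
// QuantizeMultiplierSmallerThanOneExp and records its negated zero point,
// along with 8 bits of left-shift headroom, in ComparisonParams; the
// *WithScaling reference kernels then use these parameters to bring both
// inputs onto a common scale before comparing.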
#define TF_LITE_QUANTIZE_COMPARISON(opname)                                    \
  template <typename input_dtype>                                              \
  void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node,         \
                             const TfLiteTensor* input1,                       \
                             const TfLiteTensor* input2, TfLiteTensor* output, \
                             bool requires_broadcast) {                        \
    if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {         \
      auto input1_offset = -input1->params.zero_point;                         \
      auto input2_offset = -input2->params.zero_point;                         \
      const int left_shift = 8;                                                \
                                                                               \
      int32 input1_multiplier;                                                 \
      int input1_shift;                                                        \
      QuantizeMultiplierSmallerThanOneExp(input1->params.scale,                \
                                          &input1_multiplier, &input1_shift);  \
      int32 input2_multiplier;                                                 \
      int input2_shift;                                                        \
      QuantizeMultiplierSmallerThanOneExp(input2->params.scale,                \
                                          &input2_multiplier, &input2_shift);  \
                                                                               \
      ComparisonParams op_params;                                              \
      op_params.left_shift = left_shift;                                       \
      op_params.input1_offset = input1_offset;                                 \
      op_params.input1_multiplier = input1_multiplier;                         \
      op_params.input1_shift = input1_shift;                                   \
      op_params.input2_offset = input2_offset;                                 \
      op_params.input2_multiplier = input2_multiplier;                         \
      op_params.input2_shift = input2_shift;                                   \
      if (requires_broadcast) {                                                \
        reference_ops::Broadcast4DSlow##opname##WithScaling(                   \
            op_params, GetTensorShape(input1),                                 \
            GetTensorData<input_dtype>(input1), GetTensorShape(input2),        \
            GetTensorData<input_dtype>(input2), GetTensorShape(output),        \
            GetTensorData<bool>(output));                                      \
      } else {                                                                 \
        reference_ops::opname##WithScaling(                                    \
            op_params, GetTensorShape(input1),                                 \
            GetTensorData<input_dtype>(input1), GetTensorShape(input2),        \
            GetTensorData<input_dtype>(input2), GetTensorShape(output),        \
            GetTensorData<bool>(output));                                      \
      }                                                                        \
    }                                                                          \
  }
TF_LITE_QUANTIZE_COMPARISON(Equal);
TF_LITE_QUANTIZE_COMPARISON(NotEqual);
TF_LITE_QUANTIZE_COMPARISON(Greater);
TF_LITE_QUANTIZE_COMPARISON(GreaterEqual);
TF_LITE_QUANTIZE_COMPARISON(Less);
TF_LITE_QUANTIZE_COMPARISON(LessEqual);
#undef TF_LITE_QUANTIZE_COMPARISON

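// Non-quantized comparisons need no rescaling: TF_LITE_COMPARISON dispatches
// to either the broadcasting or the elementwise *NoScaling reference kernel
// for the given element type and writes bool results into the output tensor.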
#define TF_LITE_COMPARISON(type, opname, requires_broadcast)                   \
  {                                                                            \
    ComparisonParams op_params;                                                \
    requires_broadcast                                                         \
        ? reference_ops::Broadcast4DSlow##opname##NoScaling(                   \
              op_params, GetTensorShape(input1), GetTensorData<type>(input1),  \
              GetTensorShape(input2), GetTensorData<type>(input2),             \
              GetTensorShape(output), GetTensorData<bool>(output))             \
        : reference_ops::opname##NoScaling(                                    \
              op_params, GetTensorShape(input1), GetTensorData<type>(input1),  \
              GetTensorShape(input2), GetTensorData<type>(input2),             \
              GetTensorShape(output), GetTensorData<bool>(output));            \
  }

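// Each Eval function below follows the same pattern: determine whether
// broadcasting is needed, switch on the input element type, route uint8/int8
// tensors through the quantized path above, and report an error for any
// unsupported type.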
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      TF_LITE_COMPARISON(bool, Equal, requires_broadcast);
      break;
    case kTfLiteFloat32:
      TF_LITE_COMPARISON(float, Equal, requires_broadcast);
      break;
    case kTfLiteInt32:
      TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast);
      break;
    case kTfLiteInt64:
      TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast);
      break;
    case kTfLiteUInt8:
      EvalQuantizedEqual<uint8_t>(context, node, input1, input2, output,
                                  requires_broadcast);
      break;
    case kTfLiteInt8:
      EvalQuantizedEqual<int8_t>(context, node, input1, input2, output,
                                 requires_broadcast);
      break;
    default:
      context->ReportError(
          context,
          "Does not support type %d, requires bool|float|int|uint8|int8",
          input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

// TODO(renjieliu): Refactor the logic to avoid duplication.
TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      TF_LITE_COMPARISON(bool, NotEqual, requires_broadcast);
      break;
    case kTfLiteFloat32:
      TF_LITE_COMPARISON(float, NotEqual, requires_broadcast);
      break;
    case kTfLiteInt32:
      TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast);
      break;
    case kTfLiteInt64:
      TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast);
      break;
    case kTfLiteUInt8:
      EvalQuantizedNotEqual<uint8_t>(context, node, input1, input2, output,
                                     requires_broadcast);
      break;
    case kTfLiteInt8:
      EvalQuantizedNotEqual<int8_t>(context, node, input1, input2, output,
                                    requires_broadcast);
      break;
    default:
      context->ReportError(
          context,
          "Does not support type %d, requires bool|float|int|uint8|int8",
          input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      TF_LITE_COMPARISON(float, Greater, requires_broadcast);
      break;
    case kTfLiteInt32:
      TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast);
      break;
    case kTfLiteInt64:
      TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast);
      break;
    case kTfLiteUInt8:
      EvalQuantizedGreater<uint8_t>(context, node, input1, input2, output,
                                    requires_broadcast);
      break;
    case kTfLiteInt8:
      EvalQuantizedGreater<int8_t>(context, node, input1, input2, output,
                                   requires_broadcast);
      break;
    default:
      context->ReportError(
          context, "Does not support type %d, requires float|int|uint8|int8",
          input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast);
      break;
    case kTfLiteInt32:
      TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast);
      break;
    case kTfLiteInt64:
      TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast);
      break;
    case kTfLiteUInt8:
      EvalQuantizedGreaterEqual<uint8_t>(context, node, input1, input2, output,
                                         requires_broadcast);
      break;
    case kTfLiteInt8:
      EvalQuantizedGreaterEqual<int8_t>(context, node, input1, input2, output,
                                        requires_broadcast);
      break;
    default:
      context->ReportError(
          context, "Does not support type %d, requires float|int|uint8|int8",
          input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      TF_LITE_COMPARISON(float, Less, requires_broadcast);
      break;
    case kTfLiteInt32:
      TF_LITE_COMPARISON(int32_t, Less, requires_broadcast);
      break;
    case kTfLiteInt64:
      TF_LITE_COMPARISON(int64_t, Less, requires_broadcast);
      break;
    case kTfLiteUInt8:
      EvalQuantizedLess<uint8_t>(context, node, input1, input2, output,
                                 requires_broadcast);
      break;
    case kTfLiteInt8:
      EvalQuantizedLess<int8_t>(context, node, input1, input2, output,
                                requires_broadcast);
      break;
    default:
      context->ReportError(
          context, "Does not support type %d, requires float|int|uint8|int8",
          input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  bool requires_broadcast = !HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      TF_LITE_COMPARISON(float, LessEqual, requires_broadcast);
      break;
    case kTfLiteInt32:
      TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast);
      break;
    case kTfLiteInt64:
      TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast);
      break;
    case kTfLiteUInt8:
      EvalQuantizedLessEqual<uint8_t>(context, node, input1, input2, output,
                                      requires_broadcast);
      break;
    case kTfLiteInt8:
      EvalQuantizedLessEqual<int8_t>(context, node, input1, input2, output,
                                     requires_broadcast);
      break;
    default:
      context->ReportError(
          context, "Does not support type %d, requires float|int|uint8|int8",
          input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace
}  // namespace comparisons

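// Each TfLiteRegistration below lists {init, free, prepare, invoke}. The
// comparison kernels keep no per-node state, so init and free are nullptr,
// and all six ops share ComparisonPrepare.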
TfLiteRegistration* Register_EQUAL() {
  static TfLiteRegistration r = {
      nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval};
  return &r;
}

TfLiteRegistration* Register_NOT_EQUAL() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepare,
                                 comparisons::NotEqualEval};
  return &r;
}

TfLiteRegistration* Register_GREATER() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepare,
                                 comparisons::GreaterEval};
  return &r;
}

TfLiteRegistration* Register_GREATER_EQUAL() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepare,
                                 comparisons::GreaterEqualEval};
  return &r;
}

TfLiteRegistration* Register_LESS() {
  static TfLiteRegistration r = {
      nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval};
  return &r;
}

TfLiteRegistration* Register_LESS_EQUAL() {
  static TfLiteRegistration r = {nullptr, nullptr,
                                 comparisons::ComparisonPrepare,
                                 comparisons::LessEqualEval};
  return &r;
}
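
// Minimal usage sketch (not part of this file; assumes the standard
// MutableOpResolver API): an op resolver maps the builtin operator codes to
// the registrations above, e.g.
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_EQUAL,
//                       tflite::ops::builtin::Register_EQUAL());
//   resolver.AddBuiltin(tflite::BuiltinOperator_LESS,
//                       tflite::ops::builtin::Register_LESS());
//
// The stock BuiltinOpResolver performs this wiring for all builtin ops,
// including the six comparisons registered here.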

}  // namespace builtin
}  // namespace ops
}  // namespace tflite