/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace comparisons {
namespace {

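// Per-node data, allocated from the persistent arena in Init(). For quantized
// (uint8/int8) inputs, Prepare() fills `params` with the offsets, multipliers
// and shifts consumed by the *WithScaling reference comparison kernels.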
struct OpData {
  ComparisonParams params;
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

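// Each Eval below follows the same pattern: fetch the two input tensors and
// the boolean output tensor, then dispatch on the input type. When the input
// shapes differ, the slow 4D broadcasting variant of the reference kernel is
// used; otherwise the element-wise variant runs. Quantized types go through
// the *WithScaling variants, all other types through *NoScaling.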
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data);
      break;
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

// TODO(renjieliu): Refactor the logic to avoid duplications.
TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data);
      break;
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TF_LITE_ENSURE(context, input2 != nullptr);

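  // For quantized inputs, precompute the rescaling parameters used by the
  // *WithScaling comparison kernels: each operand is offset by its negated
  // zero point, shifted left by kLeftShift bits, and scaled by a fixed-point
  // multiplier derived from its scale, so both operands end up in a common
  // integer domain before being compared.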
  if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
    auto input1_offset = -input1->params.zero_point;
    auto input2_offset = -input2->params.zero_point;
    const int kLeftShift = 8;

    int32_t input1_multiplier;
    int input1_shift;
    QuantizeMultiplierSmallerThanOneExp(
        static_cast<double>(input1->params.scale), &input1_multiplier,
        &input1_shift);
    int32_t input2_multiplier;
    int input2_shift;
    QuantizeMultiplierSmallerThanOneExp(
        static_cast<double>(input2->params.scale), &input2_multiplier,
        &input2_shift);

    data->params.left_shift = kLeftShift;
    data->params.input1_offset = input1_offset;
    data->params.input1_multiplier = input1_multiplier;
    data->params.input1_shift = input1_shift;
    data->params.input2_offset = input2_offset;
    data->params.input2_multiplier = input2_multiplier;
    data->params.input2_shift = input2_shift;
  }

  return kTfLiteOk;
}

}  // namespace comparisons

TfLiteRegistration Register_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::EqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_NOT_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::NotEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_GREATER() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::GreaterEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_GREATER_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::GreaterEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_LESS() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::LessEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_LESS_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::LessEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
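
// Usage note (illustrative only, not part of the kernel): an application makes
// these comparison ops available by registering them with its op resolver
// before building the interpreter. A minimal sketch, assuming this TFLM
// version's MicroMutableOpResolver exposes per-op helpers (exact method names
// such as AddEqual() depend on the release and are an assumption here):
//
//   tflite::MicroMutableOpResolver<1> resolver;
//   resolver.AddEqual();  // hypothetical helper wrapping Register_EQUAL()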