• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/comparisons.h"
16 
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/quantization_util.h"
19 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/micro/kernels/kernel_util.h"
22 
23 namespace tflite {
24 namespace ops {
25 namespace micro {
26 namespace comparisons {
27 namespace {
28 
29 struct OpData {
30   ComparisonParams params;
31 };
32 
33 constexpr int kInputTensor1 = 0;
34 constexpr int kInputTensor2 = 1;
35 constexpr int kOutputTensor = 0;
36 
EqualEval(TfLiteContext * context,TfLiteNode * node)37 TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
38   TFLITE_DCHECK(node->user_data != nullptr);
39   const OpData* data = static_cast<const OpData*>(node->user_data);
40 
41   const TfLiteEvalTensor* input1 =
42       tflite::micro::GetEvalInput(context, node, kInputTensor1);
43   const TfLiteEvalTensor* input2 =
44       tflite::micro::GetEvalInput(context, node, kInputTensor2);
45   TfLiteEvalTensor* output =
46       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
47 
48   RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
49   RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
50   RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
51   bool* output_data = tflite::micro::GetTensorData<bool>(output);
52 
53   bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
54   switch (input1->type) {
55     case kTfLiteBool:
56       requires_broadcast
57           ? reference_ops::Broadcast4DSlowEqualNoScaling(
58                 data->params, input1_shape,
59                 tflite::micro::GetTensorData<bool>(input1), input2_shape,
60                 tflite::micro::GetTensorData<bool>(input2), output_shape,
61                 output_data)
62           : reference_ops::EqualNoScaling(
63                 data->params, input1_shape,
64                 tflite::micro::GetTensorData<bool>(input1), input2_shape,
65                 tflite::micro::GetTensorData<bool>(input2), output_shape,
66                 output_data);
67       break;
68     case kTfLiteFloat32:
69       requires_broadcast
70           ? reference_ops::Broadcast4DSlowEqualNoScaling(
71                 data->params, input1_shape,
72                 tflite::micro::GetTensorData<float>(input1), input2_shape,
73                 tflite::micro::GetTensorData<float>(input2), output_shape,
74                 output_data)
75           : reference_ops::EqualNoScaling(
76                 data->params, input1_shape,
77                 tflite::micro::GetTensorData<float>(input1), input2_shape,
78                 tflite::micro::GetTensorData<float>(input2), output_shape,
79                 output_data);
80       break;
81     case kTfLiteInt32:
82       requires_broadcast
83           ? reference_ops::Broadcast4DSlowEqualNoScaling(
84                 data->params, input1_shape,
85                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
86                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
87                 output_data)
88           : reference_ops::EqualNoScaling(
89                 data->params, input1_shape,
90                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
91                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
92                 output_data);
93       break;
94     case kTfLiteInt64:
95       requires_broadcast
96           ? reference_ops::Broadcast4DSlowEqualNoScaling(
97                 data->params, input1_shape,
98                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
99                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
100                 output_data)
101           : reference_ops::EqualNoScaling(
102                 data->params, input1_shape,
103                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
104                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
105                 output_data);
106       break;
107     case kTfLiteUInt8:
108       requires_broadcast
109           ? reference_ops::Broadcast4DSlowEqualWithScaling(
110                 data->params, input1_shape,
111                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
112                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
113                 output_data)
114           : reference_ops::EqualWithScaling(
115                 data->params, input1_shape,
116                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
117                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
118                 output_data);
119       break;
120     case kTfLiteInt8:
121       requires_broadcast
122           ? reference_ops::Broadcast4DSlowEqualWithScaling(
123                 data->params, input1_shape,
124                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
125                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
126                 output_data)
127           : reference_ops::EqualWithScaling(
128                 data->params, input1_shape,
129                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
130                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
131                 output_data);
132       break;
133     default:
134       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
135                          TfLiteTypeGetName(input1->type), input1->type);
136       return kTfLiteError;
137   }
138   return kTfLiteOk;
139 }
140 
141 // TODO(renjieliu): Refactor the logic to avoid duplications.
NotEqualEval(TfLiteContext * context,TfLiteNode * node)142 TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
143   TFLITE_DCHECK(node->user_data != nullptr);
144   const OpData* data = static_cast<const OpData*>(node->user_data);
145 
146   const TfLiteEvalTensor* input1 =
147       tflite::micro::GetEvalInput(context, node, kInputTensor1);
148   const TfLiteEvalTensor* input2 =
149       tflite::micro::GetEvalInput(context, node, kInputTensor2);
150   TfLiteEvalTensor* output =
151       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
152 
153   RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
154   RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
155   RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
156   bool* output_data = tflite::micro::GetTensorData<bool>(output);
157 
158   bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
159   switch (input1->type) {
160     case kTfLiteBool:
161       requires_broadcast
162           ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
163                 data->params, input1_shape,
164                 tflite::micro::GetTensorData<bool>(input1), input2_shape,
165                 tflite::micro::GetTensorData<bool>(input2), output_shape,
166                 output_data)
167           : reference_ops::NotEqualNoScaling(
168                 data->params, input1_shape,
169                 tflite::micro::GetTensorData<bool>(input1), input2_shape,
170                 tflite::micro::GetTensorData<bool>(input2), output_shape,
171                 output_data);
172       break;
173     case kTfLiteFloat32:
174       requires_broadcast
175           ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
176                 data->params, input1_shape,
177                 tflite::micro::GetTensorData<float>(input1), input2_shape,
178                 tflite::micro::GetTensorData<float>(input2), output_shape,
179                 output_data)
180           : reference_ops::NotEqualNoScaling(
181                 data->params, input1_shape,
182                 tflite::micro::GetTensorData<float>(input1), input2_shape,
183                 tflite::micro::GetTensorData<float>(input2), output_shape,
184                 output_data);
185       break;
186     case kTfLiteInt32:
187       requires_broadcast
188           ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
189                 data->params, input1_shape,
190                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
191                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
192                 output_data)
193           : reference_ops::NotEqualNoScaling(
194                 data->params, input1_shape,
195                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
196                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
197                 output_data);
198       break;
199     case kTfLiteInt64:
200       requires_broadcast
201           ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
202                 data->params, input1_shape,
203                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
204                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
205                 output_data)
206           : reference_ops::NotEqualNoScaling(
207                 data->params, input1_shape,
208                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
209                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
210                 output_data);
211       break;
212     case kTfLiteUInt8:
213       requires_broadcast
214           ? reference_ops::Broadcast4DSlowNotEqualWithScaling(
215                 data->params, input1_shape,
216                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
217                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
218                 output_data)
219           : reference_ops::NotEqualWithScaling(
220                 data->params, input1_shape,
221                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
222                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
223                 output_data);
224       break;
225     case kTfLiteInt8:
226       requires_broadcast
227           ? reference_ops::Broadcast4DSlowNotEqualWithScaling(
228                 data->params, input1_shape,
229                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
230                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
231                 output_data)
232           : reference_ops::NotEqualWithScaling(
233                 data->params, input1_shape,
234                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
235                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
236                 output_data);
237       break;
238     default:
239       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
240                          TfLiteTypeGetName(input1->type), input1->type);
241       return kTfLiteError;
242   }
243   return kTfLiteOk;
244 }
245 
GreaterEval(TfLiteContext * context,TfLiteNode * node)246 TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
247   TFLITE_DCHECK(node->user_data != nullptr);
248   const OpData* data = static_cast<const OpData*>(node->user_data);
249 
250   const TfLiteEvalTensor* input1 =
251       tflite::micro::GetEvalInput(context, node, kInputTensor1);
252   const TfLiteEvalTensor* input2 =
253       tflite::micro::GetEvalInput(context, node, kInputTensor2);
254   TfLiteEvalTensor* output =
255       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
256 
257   RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
258   RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
259   RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
260   bool* output_data = tflite::micro::GetTensorData<bool>(output);
261 
262   bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
263   switch (input1->type) {
264     case kTfLiteFloat32:
265       requires_broadcast
266           ? reference_ops::Broadcast4DSlowGreaterNoScaling(
267                 data->params, input1_shape,
268                 tflite::micro::GetTensorData<float>(input1), input2_shape,
269                 tflite::micro::GetTensorData<float>(input2), output_shape,
270                 output_data)
271           : reference_ops::GreaterNoScaling(
272                 data->params, input1_shape,
273                 tflite::micro::GetTensorData<float>(input1), input2_shape,
274                 tflite::micro::GetTensorData<float>(input2), output_shape,
275                 output_data);
276       break;
277     case kTfLiteInt32:
278       requires_broadcast
279           ? reference_ops::Broadcast4DSlowGreaterNoScaling(
280                 data->params, input1_shape,
281                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
282                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
283                 output_data)
284           : reference_ops::GreaterNoScaling(
285                 data->params, input1_shape,
286                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
287                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
288                 output_data);
289       break;
290     case kTfLiteInt64:
291       requires_broadcast
292           ? reference_ops::Broadcast4DSlowGreaterNoScaling(
293                 data->params, input1_shape,
294                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
295                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
296                 output_data)
297           : reference_ops::GreaterNoScaling(
298                 data->params, input1_shape,
299                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
300                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
301                 output_data);
302       break;
303     case kTfLiteUInt8:
304       requires_broadcast
305           ? reference_ops::Broadcast4DSlowGreaterWithScaling(
306                 data->params, input1_shape,
307                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
308                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
309                 output_data)
310           : reference_ops::GreaterWithScaling(
311                 data->params, input1_shape,
312                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
313                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
314                 output_data);
315       break;
316     case kTfLiteInt8:
317       requires_broadcast
318           ? reference_ops::Broadcast4DSlowGreaterWithScaling(
319                 data->params, input1_shape,
320                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
321                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
322                 output_data)
323           : reference_ops::GreaterWithScaling(
324                 data->params, input1_shape,
325                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
326                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
327                 output_data);
328       break;
329     default:
330       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
331                          TfLiteTypeGetName(input1->type), input1->type);
332       return kTfLiteError;
333   }
334   return kTfLiteOk;
335 }
336 
GreaterEqualEval(TfLiteContext * context,TfLiteNode * node)337 TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
338   TFLITE_DCHECK(node->user_data != nullptr);
339   const OpData* data = static_cast<const OpData*>(node->user_data);
340 
341   const TfLiteEvalTensor* input1 =
342       tflite::micro::GetEvalInput(context, node, kInputTensor1);
343   const TfLiteEvalTensor* input2 =
344       tflite::micro::GetEvalInput(context, node, kInputTensor2);
345   TfLiteEvalTensor* output =
346       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
347 
348   RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
349   RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
350   RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
351   bool* output_data = tflite::micro::GetTensorData<bool>(output);
352 
353   bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
354   switch (input1->type) {
355     case kTfLiteFloat32:
356       requires_broadcast
357           ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
358                 data->params, input1_shape,
359                 tflite::micro::GetTensorData<float>(input1), input2_shape,
360                 tflite::micro::GetTensorData<float>(input2), output_shape,
361                 output_data)
362           : reference_ops::GreaterEqualNoScaling(
363                 data->params, input1_shape,
364                 tflite::micro::GetTensorData<float>(input1), input2_shape,
365                 tflite::micro::GetTensorData<float>(input2), output_shape,
366                 output_data);
367       break;
368     case kTfLiteInt32:
369       requires_broadcast
370           ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
371                 data->params, input1_shape,
372                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
373                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
374                 output_data)
375           : reference_ops::GreaterEqualNoScaling(
376                 data->params, input1_shape,
377                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
378                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
379                 output_data);
380       break;
381     case kTfLiteInt64:
382       requires_broadcast
383           ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
384                 data->params, input1_shape,
385                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
386                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
387                 output_data)
388           : reference_ops::GreaterEqualNoScaling(
389                 data->params, input1_shape,
390                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
391                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
392                 output_data);
393       break;
394     case kTfLiteUInt8:
395       requires_broadcast
396           ? reference_ops::Broadcast4DSlowGreaterEqualWithScaling(
397                 data->params, input1_shape,
398                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
399                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
400                 output_data)
401           : reference_ops::GreaterEqualWithScaling(
402                 data->params, input1_shape,
403                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
404                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
405                 output_data);
406       break;
407     case kTfLiteInt8:
408       requires_broadcast
409           ? reference_ops::Broadcast4DSlowGreaterEqualWithScaling(
410                 data->params, input1_shape,
411                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
412                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
413                 output_data)
414           : reference_ops::GreaterEqualWithScaling(
415                 data->params, input1_shape,
416                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
417                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
418                 output_data);
419       break;
420     default:
421       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
422                          TfLiteTypeGetName(input1->type), input1->type);
423       return kTfLiteError;
424   }
425   return kTfLiteOk;
426 }
427 
LessEval(TfLiteContext * context,TfLiteNode * node)428 TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
429   TFLITE_DCHECK(node->user_data != nullptr);
430   const OpData* data = static_cast<const OpData*>(node->user_data);
431 
432   const TfLiteEvalTensor* input1 =
433       tflite::micro::GetEvalInput(context, node, kInputTensor1);
434   const TfLiteEvalTensor* input2 =
435       tflite::micro::GetEvalInput(context, node, kInputTensor2);
436   TfLiteEvalTensor* output =
437       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
438 
439   RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
440   RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
441   RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
442   bool* output_data = tflite::micro::GetTensorData<bool>(output);
443 
444   bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
445   switch (input1->type) {
446     case kTfLiteFloat32:
447       requires_broadcast
448           ? reference_ops::Broadcast4DSlowLessNoScaling(
449                 data->params, input1_shape,
450                 tflite::micro::GetTensorData<float>(input1), input2_shape,
451                 tflite::micro::GetTensorData<float>(input2), output_shape,
452                 output_data)
453           : reference_ops::LessNoScaling(
454                 data->params, input1_shape,
455                 tflite::micro::GetTensorData<float>(input1), input2_shape,
456                 tflite::micro::GetTensorData<float>(input2), output_shape,
457                 output_data);
458       break;
459     case kTfLiteInt32:
460       requires_broadcast
461           ? reference_ops::Broadcast4DSlowLessNoScaling(
462                 data->params, input1_shape,
463                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
464                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
465                 output_data)
466           : reference_ops::LessNoScaling(
467                 data->params, input1_shape,
468                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
469                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
470                 output_data);
471       break;
472     case kTfLiteInt64:
473       requires_broadcast
474           ? reference_ops::Broadcast4DSlowLessNoScaling(
475                 data->params, input1_shape,
476                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
477                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
478                 output_data)
479           : reference_ops::LessNoScaling(
480                 data->params, input1_shape,
481                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
482                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
483                 output_data);
484       break;
485     case kTfLiteUInt8:
486       requires_broadcast
487           ? reference_ops::Broadcast4DSlowLessWithScaling(
488                 data->params, input1_shape,
489                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
490                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
491                 output_data)
492           : reference_ops::LessWithScaling(
493                 data->params, input1_shape,
494                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
495                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
496                 output_data);
497       break;
498     case kTfLiteInt8:
499       requires_broadcast
500           ? reference_ops::Broadcast4DSlowLessWithScaling(
501                 data->params, input1_shape,
502                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
503                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
504                 output_data)
505           : reference_ops::LessWithScaling(
506                 data->params, input1_shape,
507                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
508                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
509                 output_data);
510       break;
511     default:
512       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
513                          TfLiteTypeGetName(input1->type), input1->type);
514       return kTfLiteError;
515   }
516   return kTfLiteOk;
517 }
518 
LessEqualEval(TfLiteContext * context,TfLiteNode * node)519 TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
520   TFLITE_DCHECK(node->user_data != nullptr);
521   const OpData* data = static_cast<const OpData*>(node->user_data);
522 
523   const TfLiteEvalTensor* input1 =
524       tflite::micro::GetEvalInput(context, node, kInputTensor1);
525   const TfLiteEvalTensor* input2 =
526       tflite::micro::GetEvalInput(context, node, kInputTensor2);
527   TfLiteEvalTensor* output =
528       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
529 
530   RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
531   RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
532   RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
533   bool* output_data = tflite::micro::GetTensorData<bool>(output);
534 
535   bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
536   switch (input1->type) {
537     case kTfLiteFloat32:
538       requires_broadcast
539           ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
540                 data->params, input1_shape,
541                 tflite::micro::GetTensorData<float>(input1), input2_shape,
542                 tflite::micro::GetTensorData<float>(input2), output_shape,
543                 output_data)
544           : reference_ops::LessEqualNoScaling(
545                 data->params, input1_shape,
546                 tflite::micro::GetTensorData<float>(input1), input2_shape,
547                 tflite::micro::GetTensorData<float>(input2), output_shape,
548                 output_data);
549       break;
550     case kTfLiteInt32:
551       requires_broadcast
552           ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
553                 data->params, input1_shape,
554                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
555                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
556                 output_data)
557           : reference_ops::LessEqualNoScaling(
558                 data->params, input1_shape,
559                 tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
560                 tflite::micro::GetTensorData<int32_t>(input2), output_shape,
561                 output_data);
562       break;
563     case kTfLiteInt64:
564       requires_broadcast
565           ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
566                 data->params, input1_shape,
567                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
568                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
569                 output_data)
570           : reference_ops::LessEqualNoScaling(
571                 data->params, input1_shape,
572                 tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
573                 tflite::micro::GetTensorData<int64_t>(input2), output_shape,
574                 output_data);
575       break;
576     case kTfLiteUInt8:
577       requires_broadcast
578           ? reference_ops::Broadcast4DSlowLessEqualWithScaling(
579                 data->params, input1_shape,
580                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
581                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
582                 output_data)
583           : reference_ops::LessEqualWithScaling(
584                 data->params, input1_shape,
585                 tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
586                 tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
587                 output_data);
588       break;
589     case kTfLiteInt8:
590       requires_broadcast
591           ? reference_ops::Broadcast4DSlowLessEqualWithScaling(
592                 data->params, input1_shape,
593                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
594                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
595                 output_data)
596           : reference_ops::LessEqualWithScaling(
597                 data->params, input1_shape,
598                 tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
599                 tflite::micro::GetTensorData<int8_t>(input2), output_shape,
600                 output_data);
601       break;
602     default:
603       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
604                          TfLiteTypeGetName(input1->type), input1->type);
605       return kTfLiteError;
606   }
607   return kTfLiteOk;
608 }
609 
610 }  // namespace
611 
Init(TfLiteContext * context,const char * buffer,size_t length)612 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
613   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
614   return context->AllocatePersistentBuffer(context, sizeof(OpData));
615 }
616 
Prepare(TfLiteContext * context,TfLiteNode * node)617 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
618   TFLITE_DCHECK(node->user_data != nullptr);
619   OpData* data = static_cast<OpData*>(node->user_data);
620 
621   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
622   TF_LITE_ENSURE(context, input1 != nullptr);
623   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
624   TF_LITE_ENSURE(context, input2 != nullptr);
625 
626   if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
627     auto input1_offset = -input1->params.zero_point;
628     auto input2_offset = -input2->params.zero_point;
629     const int kLeftShift = 8;
630 
631     int32_t input1_multiplier;
632     int input1_shift;
633     QuantizeMultiplierSmallerThanOneExp(
634         static_cast<double>(input1->params.scale), &input1_multiplier,
635         &input1_shift);
636     int32_t input2_multiplier;
637     int input2_shift;
638     QuantizeMultiplierSmallerThanOneExp(
639         static_cast<double>(input2->params.scale), &input2_multiplier,
640         &input2_shift);
641 
642     data->params.left_shift = kLeftShift;
643     data->params.input1_offset = input1_offset;
644     data->params.input1_multiplier = input1_multiplier;
645     data->params.input1_shift = input1_shift;
646     data->params.input2_offset = input2_offset;
647     data->params.input2_multiplier = input2_multiplier;
648     data->params.input2_shift = input2_shift;
649   }
650 
651   return kTfLiteOk;
652 }
653 
654 }  // namespace comparisons
655 
Register_EQUAL()656 TfLiteRegistration Register_EQUAL() {
657   return {/*init=*/comparisons::Init,
658           /*free=*/nullptr,
659           /*prepare=*/comparisons::Prepare,
660           /*invoke=*/comparisons::EqualEval,
661           /*profiling_string=*/nullptr,
662           /*builtin_code=*/0,
663           /*custom_name=*/nullptr,
664           /*version=*/0};
665 }
666 
Register_NOT_EQUAL()667 TfLiteRegistration Register_NOT_EQUAL() {
668   return {/*init=*/comparisons::Init,
669           /*free=*/nullptr,
670           /*prepare=*/comparisons::Prepare,
671           /*invoke=*/comparisons::NotEqualEval,
672           /*profiling_string=*/nullptr,
673           /*builtin_code=*/0,
674           /*custom_name=*/nullptr,
675           /*version=*/0};
676 }
677 
Register_GREATER()678 TfLiteRegistration Register_GREATER() {
679   return {/*init=*/comparisons::Init,
680           /*free=*/nullptr,
681           /*prepare=*/comparisons::Prepare,
682           /*invoke=*/comparisons::GreaterEval,
683           /*profiling_string=*/nullptr,
684           /*builtin_code=*/0,
685           /*custom_name=*/nullptr,
686           /*version=*/0};
687 }
688 
Register_GREATER_EQUAL()689 TfLiteRegistration Register_GREATER_EQUAL() {
690   return {/*init=*/comparisons::Init,
691           /*free=*/nullptr,
692           /*prepare=*/comparisons::Prepare,
693           /*invoke=*/comparisons::GreaterEqualEval,
694           /*profiling_string=*/nullptr,
695           /*builtin_code=*/0,
696           /*custom_name=*/nullptr,
697           /*version=*/0};
698 }
699 
Register_LESS()700 TfLiteRegistration Register_LESS() {
701   return {/*init=*/comparisons::Init,
702           /*free=*/nullptr,
703           /*prepare=*/comparisons::Prepare,
704           /*invoke=*/comparisons::LessEval,
705           /*profiling_string=*/nullptr,
706           /*builtin_code=*/0,
707           /*custom_name=*/nullptr,
708           /*version=*/0};
709 }
710 
Register_LESS_EQUAL()711 TfLiteRegistration Register_LESS_EQUAL() {
712   return {/*init=*/comparisons::Init,
713           /*free=*/nullptr,
714           /*prepare=*/comparisons::Prepare,
715           /*invoke=*/comparisons::LessEqualEval,
716           /*profiling_string=*/nullptr,
717           /*builtin_code=*/0,
718           /*custom_name=*/nullptr,
719           /*version=*/0};
720 }
721 
722 }  // namespace micro
723 }  // namespace ops
724 }  // namespace tflite
725