• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
16 
17 #include <stdint.h>
18 
19 #include <functional>
20 
21 #include "tensorflow/lite/c/builtin_op_data.h"
22 #include "tensorflow/lite/c/c_api_types.h"
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
25 #include "tensorflow/lite/kernels/internal/quantization_util.h"
26 #include "tensorflow/lite/kernels/internal/tensor.h"
27 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
28 #include "tensorflow/lite/kernels/kernel_util.h"
29 
30 namespace tflite {
31 namespace ops {
32 namespace builtin {
33 namespace arg_min_max {
34 
// Indices of this op's tensors within the node's input and output lists.
constexpr int kInputTensor = 0;   // Input: the tensor to reduce.
constexpr int kAxis = 1;          // Input: the axis to reduce along.
constexpr int kOutputTensor = 0;  // Output: the resulting indices.
38 
ResizeOutput(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * axis,TfLiteTensor * output)39 TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* input,
40                           const TfLiteTensor* axis, TfLiteTensor* output) {
41   int axis_value;
42   // Retrieve all 8 bytes when axis type is kTfLiteInt64 to avoid data loss.
43   if (axis->type == kTfLiteInt64) {
44     axis_value = static_cast<int>(*GetTensorData<int64_t>(axis));
45   } else {
46     axis_value = *GetTensorData<int>(axis);
47   }
48   if (axis_value < 0) {
49     axis_value += NumDimensions(input);
50   }
51 
52   TF_LITE_ENSURE(context, axis_value >= 0);
53   TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
54 
55   // Copy the input dimensions to output except the axis dimension.
56   TfLiteIntArray* output_dims = TfLiteIntArrayCreate(NumDimensions(input) - 1);
57   int j = 0;
58   for (int i = 0; i < NumDimensions(input); ++i) {
59     if (i != axis_value) {
60       output_dims->data[j] = SizeOfDimension(input, i);
61       ++j;
62     }
63   }
64   return context->ResizeTensor(context, output, output_dims);
65 }
66 
Prepare(TfLiteContext * context,TfLiteNode * node)67 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
68   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
69   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
70 
71   const TfLiteTensor* input;
72   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
73   const TfLiteTensor* axis;
74   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
75   // Make sure the axis is only 1 dimension.
76   TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
77   // Make sure the axis is only either int32 or int64.
78   TF_LITE_ENSURE(context,
79                  axis->type == kTfLiteInt32 || axis->type == kTfLiteInt64);
80 
81   TfLiteTensor* output;
82   TF_LITE_ENSURE_OK(context,
83                     GetOutputSafe(context, node, kOutputTensor, &output));
84 
85   auto* params = reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
86   switch (params->output_type) {
87     case kTfLiteInt32:
88       output->type = kTfLiteInt32;
89       break;
90     case kTfLiteInt64:
91       output->type = kTfLiteInt64;
92       break;
93     default:
94       TF_LITE_KERNEL_LOG(context, "Unknown index output data type: %d",
95                          params->output_type);
96       return kTfLiteError;
97   }
98 
99   // Check conditions for different types.
100   switch (input->type) {
101     case kTfLiteFloat32:
102     case kTfLiteUInt8:
103     case kTfLiteInt8:
104     case kTfLiteInt32:
105     case kTfLiteBool:
106       break;
107 
108     default:
109       TF_LITE_KERNEL_LOG(context,
110                          "Unknown input type: %d, only float32, int types "
111                          "and bool are supported",
112                          input->type);
113       return kTfLiteError;
114   }
115 
116   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
117 
118   if (IsConstantTensor(axis)) {
119     TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
120   } else {
121     SetTensorToDynamic(output);
122   }
123 
124   return kTfLiteOk;
125 }
126 
Eval(TfLiteContext * context,TfLiteNode * node,bool is_arg_max)127 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
128   const TfLiteTensor* input;
129   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
130   const TfLiteTensor* axis;
131   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
132   TfLiteTensor* output;
133   TF_LITE_ENSURE_OK(context,
134                     GetOutputSafe(context, node, kOutputTensor, &output));
135   if (IsDynamicTensor(output)) {
136     TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
137   }
138 
139 #define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
140   optimized_ops::ArgMinMax(                                    \
141       GetTensorShape(input), GetTensorData<data_type>(input),  \
142       GetTensorData<axis_type>(axis), GetTensorShape(output),  \
143       GetTensorData<output_type>(output), is_arg_max)
144   if (axis->type == kTfLiteInt32) {
145     switch (output->type) {
146       case kTfLiteInt32: {
147         switch (input->type) {
148           case kTfLiteFloat32:
149             TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
150             break;
151           case kTfLiteUInt8:
152             TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
153             break;
154           case kTfLiteInt8:
155             TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
156             break;
157           case kTfLiteInt32:
158             TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
159             break;
160           case kTfLiteBool:
161             TF_LITE_ARG_MIN_MAX(bool, int32_t, int32_t);
162             break;
163           default:
164             TF_LITE_KERNEL_LOG(context,
165                                "Only float32, uint8, int8, int32 and bool are "
166                                "supported currently, got %s.",
167                                TfLiteTypeGetName(input->type));
168             return kTfLiteError;
169         }
170       } break;
171       case kTfLiteInt64: {
172         switch (input->type) {
173           case kTfLiteFloat32:
174             TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
175             break;
176           case kTfLiteUInt8:
177             TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
178             break;
179           case kTfLiteInt8:
180             TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t);
181             break;
182           case kTfLiteInt32:
183             TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
184             break;
185           case kTfLiteBool:
186             TF_LITE_ARG_MIN_MAX(bool, int32_t, int64_t);
187             break;
188           default:
189             TF_LITE_KERNEL_LOG(context,
190                                "Only float32, uint8, int8, int32 and bool are "
191                                "supported currently, got %s.",
192                                TfLiteTypeGetName(input->type));
193             return kTfLiteError;
194         }
195       } break;
196       default:
197         TF_LITE_KERNEL_LOG(
198             context, "Only int32 and int64 are supported currently, got %s.",
199             TfLiteTypeGetName(output->type));
200         return kTfLiteError;
201     }
202   } else {
203     switch (output->type) {
204       case kTfLiteInt32: {
205         switch (input->type) {
206           case kTfLiteFloat32:
207             TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t);
208             break;
209           case kTfLiteUInt8:
210             TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
211             break;
212           case kTfLiteInt8:
213             TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int32_t);
214             break;
215           case kTfLiteInt32:
216             TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
217             break;
218           case kTfLiteBool:
219             TF_LITE_ARG_MIN_MAX(bool, int64_t, int32_t);
220             break;
221           default:
222             TF_LITE_KERNEL_LOG(context,
223                                "Only float32, uint8, int8, int32 and bool are "
224                                "supported currently, got %s.",
225                                TfLiteTypeGetName(input->type));
226             return kTfLiteError;
227         }
228       } break;
229       case kTfLiteInt64: {
230         switch (input->type) {
231           case kTfLiteFloat32:
232             TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t);
233             break;
234           case kTfLiteUInt8:
235             TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
236             break;
237           case kTfLiteInt8:
238             TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t);
239             break;
240           case kTfLiteInt32:
241             TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
242             break;
243           case kTfLiteBool:
244             TF_LITE_ARG_MIN_MAX(bool, int64_t, int64_t);
245             break;
246           default:
247             TF_LITE_KERNEL_LOG(context,
248                                "Only float32, uint8, int8, int32 and bool are "
249                                "supported currently, got %s.",
250                                TfLiteTypeGetName(input->type));
251             return kTfLiteError;
252         }
253       } break;
254       default:
255         TF_LITE_KERNEL_LOG(
256             context, "Only int32 and int64 are supported currently, got %s.",
257             TfLiteTypeGetName(output->type));
258         return kTfLiteError;
259     }
260   }
261 #undef TF_LITE_ARG_MIN_MAX
262 
263   return kTfLiteOk;
264 }
265 
ArgMinEval(TfLiteContext * context,TfLiteNode * node)266 TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
267   return Eval(context, node, false);
268 }
269 
ArgMaxEval(TfLiteContext * context,TfLiteNode * node)270 TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
271   return Eval(context, node, true);
272 }
273 
274 }  // namespace arg_min_max
275 
Register_ARG_MAX()276 TfLiteRegistration* Register_ARG_MAX() {
277   static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
278                                  arg_min_max::ArgMaxEval};
279   return &r;
280 }
281 
Register_ARG_MIN()282 TfLiteRegistration* Register_ARG_MIN() {
283   static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
284                                  arg_min_max::ArgMinEval};
285   return &r;
286 }
287 
288 }  // namespace builtin
289 }  // namespace ops
290 }  // namespace tflite
291