• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
16 
17 #include <stdint.h>
18 
19 #include <functional>
20 
21 #include "tensorflow/lite/c/builtin_op_data.h"
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
24 #include "tensorflow/lite/kernels/internal/quantization_util.h"
25 #include "tensorflow/lite/kernels/internal/tensor.h"
26 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27 #include "tensorflow/lite/kernels/kernel_util.h"
28 
29 namespace tflite {
30 namespace ops {
31 namespace builtin {
32 namespace arg_min_max {
33 
// Tensor indices used by this kernel.
// Input 0 carries the data to reduce; input 1 carries the reduction axis.
constexpr int kInputTensor{0};
constexpr int kAxis{1};
// Output 0 receives the computed arg-min / arg-max indices.
constexpr int kOutputTensor{0};
37 
ResizeOutput(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * axis,TfLiteTensor * output)38 TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* input,
39                           const TfLiteTensor* axis, TfLiteTensor* output) {
40   int axis_value;
41   // Retrive all 8 bytes when axis type is kTfLiteInt64 to avoid data loss.
42   if (axis->type == kTfLiteInt64) {
43     axis_value = static_cast<int>(*GetTensorData<int64_t>(axis));
44   } else {
45     axis_value = *GetTensorData<int>(axis);
46   }
47   if (axis_value < 0) {
48     axis_value += NumDimensions(input);
49   }
50 
51   TF_LITE_ENSURE(context, axis_value >= 0);
52   TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
53 
54   // Copy the input dimensions to output except the axis dimension.
55   TfLiteIntArray* output_dims = TfLiteIntArrayCreate(NumDimensions(input) - 1);
56   int j = 0;
57   for (int i = 0; i < NumDimensions(input); ++i) {
58     if (i != axis_value) {
59       output_dims->data[j] = SizeOfDimension(input, i);
60       ++j;
61     }
62   }
63   return context->ResizeTensor(context, output, output_dims);
64 }
65 
Prepare(TfLiteContext * context,TfLiteNode * node)66 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
67   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
68   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
69 
70   const TfLiteTensor* input;
71   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
72   const TfLiteTensor* axis;
73   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
74   // Make sure the axis is only 1 dimension.
75   TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
76   // Make sure the axis is only either int32 or int64.
77   TF_LITE_ENSURE(context,
78                  axis->type == kTfLiteInt32 || axis->type == kTfLiteInt64);
79 
80   TfLiteTensor* output;
81   TF_LITE_ENSURE_OK(context,
82                     GetOutputSafe(context, node, kOutputTensor, &output));
83 
84   auto* params = reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
85   switch (params->output_type) {
86     case kTfLiteInt32:
87       output->type = kTfLiteInt32;
88       break;
89     case kTfLiteInt64:
90       output->type = kTfLiteInt64;
91       break;
92     default:
93       context->ReportError(context, "Unknown index output data type: %d",
94                            params->output_type);
95       return kTfLiteError;
96   }
97 
98   // Check conditions for different types.
99   switch (input->type) {
100     case kTfLiteFloat32:
101     case kTfLiteUInt8:
102     case kTfLiteInt8:
103     case kTfLiteInt32:
104       break;
105 
106     default:
107       context->ReportError(
108           context,
109           "Unknown input type: %d, only float32 and int types are supported",
110           input->type);
111       return kTfLiteError;
112   }
113 
114   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
115 
116   if (IsConstantTensor(axis)) {
117     TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
118   } else {
119     SetTensorToDynamic(output);
120   }
121 
122   return kTfLiteOk;
123 }
124 
Eval(TfLiteContext * context,TfLiteNode * node,bool is_arg_max)125 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
126   const TfLiteTensor* input;
127   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
128   const TfLiteTensor* axis;
129   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
130   TfLiteTensor* output;
131   TF_LITE_ENSURE_OK(context,
132                     GetOutputSafe(context, node, kOutputTensor, &output));
133   if (IsDynamicTensor(output)) {
134     TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
135   }
136 
137 #define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
138   optimized_ops::ArgMinMax(                                    \
139       GetTensorShape(input), GetTensorData<data_type>(input),  \
140       GetTensorData<axis_type>(axis), GetTensorShape(output),  \
141       GetTensorData<output_type>(output), is_arg_max)
142   if (axis->type == kTfLiteInt32) {
143     switch (output->type) {
144       case kTfLiteInt32: {
145         switch (input->type) {
146           case kTfLiteFloat32:
147             TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
148             break;
149           case kTfLiteUInt8:
150             TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
151             break;
152           case kTfLiteInt8:
153             TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
154             break;
155           case kTfLiteInt32:
156             TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
157             break;
158           default:
159             context->ReportError(context,
160                                  "Only float32, uint8, int8 and int32 are "
161                                  "supported currently, got %s.",
162                                  TfLiteTypeGetName(input->type));
163             return kTfLiteError;
164         }
165       } break;
166       case kTfLiteInt64: {
167         switch (input->type) {
168           case kTfLiteFloat32:
169             TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
170             break;
171           case kTfLiteUInt8:
172             TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
173             break;
174           case kTfLiteInt8:
175             TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t);
176             break;
177           case kTfLiteInt32:
178             TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
179             break;
180           default:
181             context->ReportError(context,
182                                  "Only float32, uint8, int8 and int32 are "
183                                  "supported currently, got %s.",
184                                  TfLiteTypeGetName(input->type));
185             return kTfLiteError;
186         }
187       } break;
188       default:
189         context->ReportError(
190             context, "Only int32 and int64 are supported currently, got %s.",
191             TfLiteTypeGetName(output->type));
192         return kTfLiteError;
193     }
194   } else {
195     switch (output->type) {
196       case kTfLiteInt32: {
197         switch (input->type) {
198           case kTfLiteFloat32:
199             TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t);
200             break;
201           case kTfLiteUInt8:
202             TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
203             break;
204           case kTfLiteInt8:
205             TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int32_t);
206             break;
207           case kTfLiteInt32:
208             TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
209             break;
210           default:
211             context->ReportError(context,
212                                  "Only float32, uint8, int8 and int32 are "
213                                  "supported currently, got %s.",
214                                  TfLiteTypeGetName(input->type));
215             return kTfLiteError;
216         }
217       } break;
218       case kTfLiteInt64: {
219         switch (input->type) {
220           case kTfLiteFloat32:
221             TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t);
222             break;
223           case kTfLiteUInt8:
224             TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
225             break;
226           case kTfLiteInt8:
227             TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t);
228             break;
229           case kTfLiteInt32:
230             TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
231             break;
232           default:
233             context->ReportError(context,
234                                  "Only float32, uint8, int8 and int32 are "
235                                  "supported currently, got %s.",
236                                  TfLiteTypeGetName(input->type));
237             return kTfLiteError;
238         }
239       } break;
240       default:
241         context->ReportError(
242             context, "Only int32 and int64 are supported currently, got %s.",
243             TfLiteTypeGetName(output->type));
244         return kTfLiteError;
245     }
246   }
247 #undef TF_LITE_ARG_MIN_MAX
248 
249   return kTfLiteOk;
250 }
251 
ArgMinEval(TfLiteContext * context,TfLiteNode * node)252 TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
253   return Eval(context, node, false);
254 }
255 
ArgMaxEval(TfLiteContext * context,TfLiteNode * node)256 TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
257   return Eval(context, node, true);
258 }
259 
260 }  // namespace arg_min_max
261 
Register_ARG_MAX()262 TfLiteRegistration* Register_ARG_MAX() {
263   static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
264                                  arg_min_max::ArgMaxEval};
265   return &r;
266 }
267 
Register_ARG_MIN()268 TfLiteRegistration* Register_ARG_MIN() {
269   static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
270                                  arg_min_max::ArgMinEval};
271   return &r;
272 }
273 
274 }  // namespace builtin
275 }  // namespace ops
276 }  // namespace tflite
277