/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"

#include <stdint.h>

#include <functional>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace arg_min_max {

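// Tensor indices used by the ARG_MIN/ARG_MAX builtin: one input tensor, one
// axis tensor, and a single output tensor of indices.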
constexpr int kInputTensor = 0;
constexpr int kAxis = 1;
constexpr int kOutputTensor = 0;

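// Resolves the axis value (reading it as int32 or int64 and wrapping negative
// values) and resizes the output tensor to the input shape with the axis
// dimension removed.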
TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* input,
                          const TfLiteTensor* axis, TfLiteTensor* output) {
  int axis_value;
  // Retrieve all 8 bytes when axis type is kTfLiteInt64 to avoid data loss.
  if (axis->type == kTfLiteInt64) {
    axis_value = static_cast<int>(*GetTensorData<int64_t>(axis));
  } else {
    axis_value = *GetTensorData<int>(axis);
  }
  if (axis_value < 0) {
    axis_value += NumDimensions(input);
  }

  TF_LITE_ENSURE(context, axis_value >= 0);
  TF_LITE_ENSURE(context, axis_value < NumDimensions(input));

  // Copy the input dimensions to output except the axis dimension.
  TfLiteIntArray* output_dims = TfLiteIntArrayCreate(NumDimensions(input) - 1);
  int j = 0;
  for (int i = 0; i < NumDimensions(input); ++i) {
    if (i != axis_value) {
      output_dims->data[j] = SizeOfDimension(input, i);
      ++j;
    }
  }
  return context->ResizeTensor(context, output, output_dims);
}

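// Validates the input, axis, and output tensors, sets the output index type
// from the builtin params, and resizes the output now if the axis is a
// constant tensor; otherwise the output is marked dynamic and resized at eval
// time. TfLiteArgMaxParams is used for both ARG_MAX and ARG_MIN, since both
// builtin param structs carry the same output_type field.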
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* axis;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
  // Make sure the axis tensor holds exactly one element.
  TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
  // Make sure the axis type is either int32 or int64.
  TF_LITE_ENSURE(context,
                 axis->type == kTfLiteInt32 || axis->type == kTfLiteInt64);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  auto* params = reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
  switch (params->output_type) {
    case kTfLiteInt32:
      output->type = kTfLiteInt32;
      break;
    case kTfLiteInt64:
      output->type = kTfLiteInt64;
      break;
    default:
      context->ReportError(context, "Unknown index output data type: %d",
                           params->output_type);
      return kTfLiteError;
  }

  // Check conditions for different types.
  switch (input->type) {
    case kTfLiteFloat32:
    case kTfLiteUInt8:
    case kTfLiteInt8:
    case kTfLiteInt32:
      break;

    default:
      context->ReportError(
          context,
          "Unknown input type: %d, only float32 and int types are supported",
          input->type);
      return kTfLiteError;
  }

  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);

  if (IsConstantTensor(axis)) {
    TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
  } else {
    SetTensorToDynamic(output);
  }

  return kTfLiteOk;
}

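// Shared evaluation routine for ARG_MIN and ARG_MAX: resizes a dynamic output
// if needed, then dispatches to optimized_ops::ArgMinMax for each supported
// combination of input, axis, and output index types.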
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* axis;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  if (IsDynamicTensor(output)) {
    TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
  }

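// Helper macro: instantiates optimized_ops::ArgMinMax for a given input data
// type, axis index type, and output index type.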
#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
  optimized_ops::ArgMinMax(                                     \
      GetTensorShape(input), GetTensorData<data_type>(input),   \
      GetTensorData<axis_type>(axis), GetTensorShape(output),   \
      GetTensorData<output_type>(output), is_arg_max)
  if (axis->type == kTfLiteInt32) {
    switch (output->type) {
      case kTfLiteInt32: {
        switch (input->type) {
          case kTfLiteFloat32:
            TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
            break;
          case kTfLiteUInt8:
            TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
            break;
          case kTfLiteInt8:
            TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
            break;
          case kTfLiteInt32:
            TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
            break;
          default:
            context->ReportError(context,
                                 "Only float32, uint8, int8 and int32 are "
                                 "supported currently, got %s.",
                                 TfLiteTypeGetName(input->type));
            return kTfLiteError;
        }
      } break;
      case kTfLiteInt64: {
        switch (input->type) {
          case kTfLiteFloat32:
            TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
            break;
          case kTfLiteUInt8:
            TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
            break;
          case kTfLiteInt8:
            TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t);
            break;
          case kTfLiteInt32:
            TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
            break;
          default:
            context->ReportError(context,
                                 "Only float32, uint8, int8 and int32 are "
                                 "supported currently, got %s.",
                                 TfLiteTypeGetName(input->type));
            return kTfLiteError;
        }
      } break;
      default:
        context->ReportError(
            context, "Only int32 and int64 are supported currently, got %s.",
            TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else {
    switch (output->type) {
      case kTfLiteInt32: {
        switch (input->type) {
          case kTfLiteFloat32:
            TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t);
            break;
          case kTfLiteUInt8:
            TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
            break;
          case kTfLiteInt8:
            TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int32_t);
            break;
          case kTfLiteInt32:
            TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
            break;
          default:
            context->ReportError(context,
                                 "Only float32, uint8, int8 and int32 are "
                                 "supported currently, got %s.",
                                 TfLiteTypeGetName(input->type));
            return kTfLiteError;
        }
      } break;
      case kTfLiteInt64: {
        switch (input->type) {
          case kTfLiteFloat32:
            TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t);
            break;
          case kTfLiteUInt8:
            TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
            break;
          case kTfLiteInt8:
            TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t);
            break;
          case kTfLiteInt32:
            TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
            break;
          default:
            context->ReportError(context,
                                 "Only float32, uint8, int8 and int32 are "
                                 "supported currently, got %s.",
                                 TfLiteTypeGetName(input->type));
            return kTfLiteError;
        }
      } break;
      default:
        context->ReportError(
            context, "Only int32 and int64 are supported currently, got %s.",
            TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  }
#undef TF_LITE_ARG_MIN_MAX

  return kTfLiteOk;
}

TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(context, node, false);
}

TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(context, node, true);
}

}  // namespace arg_min_max

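// ARG_MAX and ARG_MIN share the same Prepare; the registrations differ only
// in which Eval wrapper (and hence which is_arg_max flag) is used.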
TfLiteRegistration* Register_ARG_MAX() {
  static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
                                 arg_min_max::ArgMaxEval};
  return &r;
}

TfLiteRegistration* Register_ARG_MIN() {
  static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
                                 arg_min_max::ArgMinEval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite