/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include <initializer_list>
#include <tuple>
#include <vector>

#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace {

using ::testing::ElementsAreArray;

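// Builds a single-op ARG_MIN/ARG_MAX model: one data input, one axis input
// (either a constant tensor or a tensor populated at run time), and one index
// output. The ARG_MAX and ARG_MIN wrappers below choose the builtin op.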
class ArgBaseOpModel : public SingleOpModel {
 public:
  ArgBaseOpModel(TensorType input_type, int axis_value, TensorType axis_type,
                 bool constant_axis, TensorType output_type)
      : axis_value_(axis_value),
        axis_type_(axis_type),
        constant_axis_(constant_axis) {
    input_ = AddInput(input_type);
    if (constant_axis) {
      if (axis_type == TensorType_INT64) {
        axis_ =
            AddConstInput(axis_type, {static_cast<int64_t>(axis_value)}, {1});
      } else {
        axis_ = AddConstInput(axis_type, {axis_value}, {1});
      }
    } else {
      axis_ = AddInput(axis_type);
    }
    output_ = AddOutput(output_type);
  }

  int input() const { return input_; }
  int axis() const { return axis_; }

  std::vector<int32_t> GetInt32Output() const {
    return ExtractVector<int32_t>(output_);
  }
  std::vector<int64_t> GetInt64Output() const {
    return ExtractVector<int64_t>(output_);
  }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 protected:
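  // For the non-constant-axis variant, copies axis_value_ into the axis input
  // tensor (as int32 or int64 to match axis_type_). No-op when the axis was
  // added as a constant tensor.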
  void PopulateAxisIfNeeded() {
    if (constant_axis_) return;
    if (axis_type_ == TensorType_INT32) {
      PopulateTensor<int32_t>(axis(), {axis_value_});
    } else {
      PopulateTensor<int64_t>(axis(), {axis_value_});
    }
  }

  const int axis_value_;
  const TensorType axis_type_;
  const bool constant_axis_;

  int input_;
  int axis_;
  int output_;
};

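// ARG_MAX variant: configures ArgMaxOptions with the requested output type and
// builds the interpreter for the given input shape.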
class ArgMaxOpModel : public ArgBaseOpModel {
 public:
  ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
                int axis_value, TensorType axis_type, bool constant_axis,
                TensorType output_type)
      : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
                       output_type) {
    ArgBaseOpModel::SetBuiltinOp(
        BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
        CreateArgMaxOptions(ArgBaseOpModel::builder_, output_type).Union());
    ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
    PopulateAxisIfNeeded();
  }
};

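// ARG_MIN variant: same as above, but with ArgMinOptions and the ARG_MIN op.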
class ArgMinOpModel : public ArgBaseOpModel {
 public:
  ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
                int axis_value, TensorType axis_type, bool constant_axis,
                TensorType output_type)
      : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
                       output_type) {
    ArgBaseOpModel::SetBuiltinOp(
        BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
        CreateArgMinOptions(ArgBaseOpModel::builder_, output_type).Union());
    ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
    PopulateAxisIfNeeded();
  }
};

// Declare ArgMinMaxOpTest as a parameterized test, where the parameter is a
// tuple of:
// - whether the axis is provided as a constant tensor.
// - axis tensor type (TensorType_INT32 or TensorType_INT64).
// - output type (TensorType_INT32 or TensorType_INT64).
class ArgMinMaxOpTest : public ::testing::TestWithParam<
                            std::tuple<bool, TensorType, TensorType>> {
 public:
  bool ConstantAxis() const { return std::get<0>(GetParam()); }

  TensorType AxisType() const { return std::get<1>(GetParam()); }

  TensorType OutputType() const { return std::get<2>(GetParam()); }

  void ValidateOutput(const ArgBaseOpModel& model,
                      const std::vector<int>& expected_output) {
    if (OutputType() == TensorType_INT32) {
      EXPECT_THAT(model.GetInt32Output(), ElementsAreArray(expected_output));
    } else {
      EXPECT_THAT(model.GetInt64Output(), ElementsAreArray(expected_output));
    }
  }
};

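// Run every test case over all eight combinations of (constant vs. runtime
// axis) x (int32/int64 axis tensor) x (int32/int64 output type).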
INSTANTIATE_TEST_SUITE_P(
    ArgMinMaxOpTest, ArgMinMaxOpTest,
    ::testing::Combine(::testing::Bool(),
                       ::testing::Values(TensorType_INT32, TensorType_INT64),
                       ::testing::Values(TensorType_INT32, TensorType_INT64)));

TEST_P(ArgMinMaxOpTest, GetMaxArgFloat) {
  ArgMaxOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {1});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}

TEST_P(ArgMinMaxOpTest, GetMaxArgUInt8) {
  ArgMaxOpModel model({1, 1, 1, 4}, TensorType_UINT8, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<uint8_t>(model.input(), {1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {1});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}

TEST_P(ArgMinMaxOpTest, GetMaxArgInt8) {
  ArgMaxOpModel model({1, 1, 1, 4}, TensorType_INT8, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int8_t>(model.input(), {-1, -9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {2});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}

TEST_P(ArgMinMaxOpTest, GetMaxArgInt) {
  ArgMaxOpModel model({1, 1, 1, 4}, TensorType_INT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {1});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}

TEST_P(ArgMinMaxOpTest, GetMaxArgBool) {
  ArgMaxOpModel model({1, 1, 1, 4}, TensorType_BOOL, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<bool>(model.input(), {true, false, false, false});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {0});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}

TEST_P(ArgMinMaxOpTest, GetMaxArgMulDimensions) {
  ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {3, 1});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}

TEST_P(ArgMinMaxOpTest, GetMaxArgNegativeAxis) {
  ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, -2, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {0, 1, 0, 0});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4}));
}

TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) {
  ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {0, 1});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}

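// Sweeps 1-D input lengths 1..9; the populated prefix is strictly increasing,
// so the argmax is always the last element.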
TEST_P(ArgMinMaxOpTest, GetMaxArgFloatLastAxis) {
  std::vector<float> input{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0};
  for (int i = 1; i < 10; ++i) {
    ArgMaxOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<float>(
        model.input(), std::vector<float>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}

TEST_P(ArgMinMaxOpTest, GetMinArgFloat) {
  ArgMinOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {0});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}

TEST_P(ArgMinMaxOpTest, GetMinArgInt) {
  ArgMinOpModel model({1, 1, 1, 4}, TensorType_INT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {0});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}

TEST_P(ArgMinMaxOpTest, GetMinArgBool) {
  ArgMinOpModel model({1, 1, 1, 4}, TensorType_BOOL, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<bool>(model.input(), {true, false, true, true});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {1});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}

TEST_P(ArgMinMaxOpTest, GetMinArgMulDimensions) {
  ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {0, 0});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}

TEST_P(ArgMinMaxOpTest, GetMinArgNegativeAxis) {
  ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, -2, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {0, 0, 0, 1});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4}));
}

TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) {
  ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ValidateOutput(model, {1, 0});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}

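// Mirror of GetMaxArgFloatLastAxis: the populated prefix is strictly
// decreasing, so the argmin is always the last element.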
TEST_P(ArgMinMaxOpTest, GetMinArgFloatLastAxis) {
  std::vector<float> input{1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1};
  for (int i = 1; i < 10; ++i) {
    ArgMinOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<float>(
        model.input(), std::vector<float>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}

TEST_P(ArgMinMaxOpTest, GetMaxArgInt8LastAxis) {
  // The SIMD vector width for int8 is 16 elements, so 35 covers two full
  // vectors plus a remainder to exercise the scalar tail.
  constexpr int INPUT_SIZE = 35;
  std::vector<int8_t> input;
  input.reserve(INPUT_SIZE);
  for (int i = 0; i < INPUT_SIZE; i++) {
    input.push_back(INPUT_SIZE - i);
  }
  for (int i = 1; i < INPUT_SIZE; ++i) {
    ArgMinOpModel model({i}, TensorType_INT8, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<int8_t>(
        model.input(), std::vector<int8_t>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}

TEST_P(ArgMinMaxOpTest, GetMaxArgUInt8LastAxis) {
  // The SIMD vector width for uint8 is 16 elements, so 35 covers two full
  // vectors plus a remainder to exercise the scalar tail.
  constexpr int INPUT_SIZE = 35;
  std::vector<uint8_t> input;
  input.reserve(INPUT_SIZE);
  for (int i = 0; i < INPUT_SIZE; i++) {
    input.push_back(INPUT_SIZE - i);
  }
  for (int i = 1; i < INPUT_SIZE; ++i) {
    ArgMinOpModel model({i}, TensorType_UINT8, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<uint8_t>(
        model.input(), std::vector<uint8_t>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}

}  // namespace
}  // namespace tflite