• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <tuple>
19 #include <vector>
20 
21 #include <gtest/gtest.h>
22 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 namespace {
28 
29 using ::testing::ElementsAreArray;
30 
31 class ArgBaseOpModel : public SingleOpModel {
32  public:
ArgBaseOpModel(TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)33   ArgBaseOpModel(TensorType input_type, int axis_value, TensorType axis_type,
34                  bool constant_axis, TensorType output_type)
35       : axis_value_(axis_value),
36         axis_type_(axis_type),
37         constant_axis_(constant_axis) {
38     input_ = AddInput(input_type);
39     if (constant_axis) {
40       if (axis_type == TensorType_INT64) {
41         axis_ =
42             AddConstInput(axis_type, {static_cast<int64_t>(axis_value)}, {1});
43       } else {
44         axis_ = AddConstInput(axis_type, {axis_value}, {1});
45       }
46     } else {
47       axis_ = AddInput(axis_type);
48     }
49     output_ = AddOutput(output_type);
50   }
51 
input() const52   int input() const { return input_; }
axis() const53   int axis() const { return axis_; }
54 
GetInt32Output() const55   std::vector<int32_t> GetInt32Output() const {
56     return ExtractVector<int32_t>(output_);
57   }
GetInt64Output() const58   std::vector<int64_t> GetInt64Output() const {
59     return ExtractVector<int64_t>(output_);
60   }
GetOutputShape()61   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
62 
63  protected:
PopulateAxisIfNeeded()64   void PopulateAxisIfNeeded() {
65     if (constant_axis_) return;
66     if (axis_type_ == TensorType_INT32) {
67       PopulateTensor<int32_t>(axis(), {axis_value_});
68     } else {
69       PopulateTensor<int64_t>(axis(), {axis_value_});
70     }
71   }
72 
73   const int axis_value_;
74   const TensorType axis_type_;
75   const bool constant_axis_;
76 
77   int input_;
78   int axis_;
79   int output_;
80 };
81 
82 class ArgMaxOpModel : public ArgBaseOpModel {
83  public:
ArgMaxOpModel(std::initializer_list<int> input_shape,TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)84   ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
85                 int axis_value, TensorType axis_type, bool constant_axis,
86                 TensorType output_type)
87       : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
88                        output_type) {
89     ArgBaseOpModel::SetBuiltinOp(
90         BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
91         CreateArgMaxOptions(ArgBaseOpModel::builder_, output_type).Union());
92     ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
93     PopulateAxisIfNeeded();
94   }
95 };
96 
97 class ArgMinOpModel : public ArgBaseOpModel {
98  public:
ArgMinOpModel(std::initializer_list<int> input_shape,TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)99   ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
100                 int axis_value, TensorType axis_type, bool constant_axis,
101                 TensorType output_type)
102       : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
103                        output_type) {
104     ArgBaseOpModel::SetBuiltinOp(
105         BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
106         CreateArgMinOptions(ArgBaseOpModel::builder_, output_type).Union());
107     ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
108     PopulateAxisIfNeeded();
109   }
110 };
111 
112 // Declare ArgMinMaxOpTest as a parameterized test, where the parameter is a
113 // tuple with:
114 // - boolean indicating whether to use a constant axis or not.
115 // - axis type (TensorType_INT32 or TensorType_INT64)
116 // - output type (TensorType_INT32 or TensorType_INT64)
117 class ArgMinMaxOpTest : public ::testing::TestWithParam<
118                             std::tuple<bool, TensorType, TensorType>> {
119  public:
ConstantAxis() const120   bool ConstantAxis() const { return std::get<0>(GetParam()); }
121 
AxisType() const122   TensorType AxisType() const { return std::get<1>(GetParam()); }
123 
OutputType() const124   TensorType OutputType() const { return std::get<2>(GetParam()); }
125 
ValidateOutput(const ArgBaseOpModel & model,const std::vector<int> & expected_output)126   void ValidateOutput(const ArgBaseOpModel& model,
127                       const std::vector<int>& expected_output) {
128     if (OutputType() == TensorType_INT32) {
129       EXPECT_THAT(model.GetInt32Output(), ElementsAreArray(expected_output));
130     } else {
131       EXPECT_THAT(model.GetInt64Output(), ElementsAreArray(expected_output));
132     }
133   }
134 };
135 
136 INSTANTIATE_TEST_SUITE_P(
137     ArgMinMaxOpTest, ArgMinMaxOpTest,
138     ::testing::Combine(::testing::Bool(),
139                        ::testing::Values(TensorType_INT32, TensorType_INT64),
140                        ::testing::Values(TensorType_INT32, TensorType_INT64)));
141 
TEST_P(ArgMinMaxOpTest, GetMaxArgFloat) {
  // Four floats along the last axis (axis 3); 0.9 at index 1 is the largest.
  const std::vector<float> values = {0.1, 0.9, 0.7, 0.3};
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_FLOAT32,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<float>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_index = {1};
  const std::vector<int> expected_shape = {1, 1, 1};
  ValidateOutput(model, expected_index);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
151 
TEST_P(ArgMinMaxOpTest, GetMaxArgUInt8) {
  // Four uint8 values along the last axis; 9 at index 1 is the largest.
  const std::vector<uint8_t> values = {1, 9, 7, 3};
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_UINT8,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<uint8_t>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_index = {1};
  const std::vector<int> expected_shape = {1, 1, 1};
  ValidateOutput(model, expected_index);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
161 
TEST_P(ArgMinMaxOpTest, GetMaxArgInt8) {
  // Mixed-sign int8 values along the last axis; 7 at index 2 is the largest.
  const std::vector<int8_t> values = {-1, -9, 7, 3};
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_INT8,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int8_t>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_index = {2};
  const std::vector<int> expected_shape = {1, 1, 1};
  ValidateOutput(model, expected_index);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
171 
TEST_P(ArgMinMaxOpTest, GetMaxArgInt) {
  // Four int32 values along the last axis; 9 at index 1 is the largest.
  const std::vector<int> values = {1, 9, 7, 3};
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_INT32,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_index = {1};
  const std::vector<int> expected_shape = {1, 1, 1};
  ValidateOutput(model, expected_index);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
181 
TEST_P(ArgMinMaxOpTest, GetMaxArgBool) {
  // Only the first element is true, so it is the unique maximum.
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_BOOL,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<bool>(model.input(), {true, false, false, false});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_index = {0};
  const std::vector<int> expected_shape = {1, 1, 1};
  ValidateOutput(model, expected_index);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
191 
TEST_P(ArgMinMaxOpTest, GetMaxArgMulDimensions) {
  // Two rows of four values, each reduced independently over the last axis.
  const std::vector<int> values = {1, 2, 7, 8,   // row 0: max 8 at index 3
                                   1, 9, 7, 3};  // row 1: max 9 at index 1
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_indices = {3, 1};
  const std::vector<int> expected_shape = {1, 1, 2};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
201 
TEST_P(ArgMinMaxOpTest, GetMaxArgNegativeAxis) {
  // Axis -2 of the rank-4 input reduces the size-2 dimension, i.e. down each
  // of the four columns. Tied columns (0 and 2) expect the first index.
  const std::vector<int> values = {1, 2, 7, 8,
                                   1, 9, 7, 3};
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      /*axis_value=*/-2, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_indices = {0, 1, 0, 0};
  const std::vector<int> expected_shape = {1, 1, 4};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
211 
TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) {
  // Despite the name, the output type comes from the test parameter
  // (OutputType()), so this case runs for both int32 and int64 outputs.
  const std::vector<int> values = {10, 2, 7, 8,  // row 0: max 10 at index 0
                                   1, 9, 7, 3};  // row 1: max 9 at index 1
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_indices = {0, 1};
  const std::vector<int> expected_shape = {1, 1, 2};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
221 
TEST_P(ArgMinMaxOpTest, GetMaxArgFloatLastAxis) {
  // Strictly increasing values, so the maximum of every prefix is its last
  // element. Every prefix length is tested, including the full vector; the
  // previous bound (`i < 10`) never exercised the 10th element.
  const std::vector<float> input{0.1, 0.2, 0.3, 0.4, 0.5,
                                 0.6, 0.7, 0.8, 0.9, 1.0};
  const int input_size = static_cast<int>(input.size());
  for (int i = 1; i <= input_size; ++i) {
    ArgMaxOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<float>(
        model.input(), std::vector<float>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}
234 
TEST_P(ArgMinMaxOpTest, GetMinArgFloat) {
  // Four floats along the last axis (axis 3); 0.1 at index 0 is the smallest.
  const std::vector<float> values = {0.1, 0.9, 0.7, 0.3};
  ArgMinOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_FLOAT32,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<float>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_index = {0};
  const std::vector<int> expected_shape = {1, 1, 1};
  ValidateOutput(model, expected_index);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
244 
TEST_P(ArgMinMaxOpTest, GetMinArgInt) {
  // Four int32 values along the last axis; 1 at index 0 is the smallest.
  const std::vector<int> values = {1, 9, 7, 3};
  ArgMinOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_INT32,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_index = {0};
  const std::vector<int> expected_shape = {1, 1, 1};
  ValidateOutput(model, expected_index);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
254 
TEST_P(ArgMinMaxOpTest, GetMinArgBool) {
  // Only the second element is false, so it is the unique minimum.
  ArgMinOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_BOOL,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<bool>(model.input(), {true, false, true, true});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_index = {1};
  const std::vector<int> expected_shape = {1, 1, 1};
  ValidateOutput(model, expected_index);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
264 
TEST_P(ArgMinMaxOpTest, GetMinArgMulDimensions) {
  // Two rows of four values, each reduced independently over the last axis.
  const std::vector<int> values = {1, 2, 7, 8,   // row 0: min 1 at index 0
                                   1, 9, 7, 3};  // row 1: min 1 at index 0
  ArgMinOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_indices = {0, 0};
  const std::vector<int> expected_shape = {1, 1, 2};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
274 
TEST_P(ArgMinMaxOpTest, GetMinArgNegativeAxis) {
  // Axis -2 of the rank-4 input reduces the size-2 dimension, i.e. down each
  // of the four columns. Tied columns (0 and 2) expect the first index.
  const std::vector<int> values = {1, 2, 7, 8,
                                   1, 9, 7, 3};
  ArgMinOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      /*axis_value=*/-2, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_indices = {0, 0, 0, 1};
  const std::vector<int> expected_shape = {1, 1, 4};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
284 
TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) {
  // Despite the name, the output type comes from the test parameter
  // (OutputType()), so this case runs for both int32 and int64 outputs.
  const std::vector<int> values = {10, 2, 7, 8,  // row 0: min 2 at index 1
                                   1, 9, 7, 3};  // row 1: min 1 at index 0
  ArgMinOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      /*axis_value=*/3, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_indices = {1, 0};
  const std::vector<int> expected_shape = {1, 1, 2};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
}
294 
TEST_P(ArgMinMaxOpTest, GetMinArgFloatLastAxis) {
  // Strictly decreasing values, so the minimum of every prefix is its last
  // element. Every prefix length is tested, including the full vector; the
  // previous bound (`i < 10`) never exercised the 10th element.
  const std::vector<float> input{1.0, 0.9, 0.8, 0.7, 0.6,
                                 0.5, 0.4, 0.3, 0.2, 0.1};
  const int input_size = static_cast<int>(input.size());
  for (int i = 1; i <= input_size; ++i) {
    ArgMinOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<float>(
        model.input(), std::vector<float>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}
307 
TEST_P(ArgMinMaxOpTest, GetMaxArgInt8LastAxis) {
  // Vector size for int8 is 16 elements, so 35 covers two SIMD widths plus a
  // remainder. Every prefix length from 1 to 35 is exercised.
  constexpr int kInputSize = 35;
  // Strictly increasing values (1, 2, ..., 35): the maximum of every prefix
  // is its last element.
  std::vector<int8_t> input;
  input.reserve(kInputSize);
  for (int i = 0; i < kInputSize; ++i) {
    input.push_back(static_cast<int8_t>(i + 1));
  }
  for (int i = 1; i <= kInputSize; ++i) {
    ArgMaxOpModel model({i}, TensorType_INT8, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<int8_t>(
        model.input(), std::vector<int8_t>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}

// ARG_MIN counterpart of GetMaxArgInt8LastAxis above. This keeps the int8
// last-axis ARG_MIN coverage that the max-named test used to provide by
// mistake.
TEST_P(ArgMinMaxOpTest, GetMinArgInt8LastAxis) {
  constexpr int kInputSize = 35;
  // Strictly decreasing values (35, 34, ..., 1): the minimum of every prefix
  // is its last element.
  std::vector<int8_t> input;
  input.reserve(kInputSize);
  for (int i = 0; i < kInputSize; ++i) {
    input.push_back(static_cast<int8_t>(kInputSize - i));
  }
  for (int i = 1; i <= kInputSize; ++i) {
    ArgMinOpModel model({i}, TensorType_INT8, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<int8_t>(
        model.input(), std::vector<int8_t>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}
327 
TEST_P(ArgMinMaxOpTest, GetMaxArgUInt8LastAxis) {
  // Vector size for 8-bit types is 16 elements, so 35 covers two SIMD widths
  // plus a remainder. Every prefix length from 1 to 35 is exercised.
  constexpr int kInputSize = 35;
  // Strictly increasing values (1, 2, ..., 35): the maximum of every prefix
  // is its last element.
  std::vector<uint8_t> input;
  input.reserve(kInputSize);
  for (int i = 0; i < kInputSize; ++i) {
    input.push_back(static_cast<uint8_t>(i + 1));
  }
  for (int i = 1; i <= kInputSize; ++i) {
    ArgMaxOpModel model({i}, TensorType_UINT8, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<uint8_t>(
        model.input(), std::vector<uint8_t>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}

// ARG_MIN counterpart of GetMaxArgUInt8LastAxis above. This keeps the uint8
// last-axis ARG_MIN coverage that the max-named test used to provide by
// mistake.
TEST_P(ArgMinMaxOpTest, GetMinArgUInt8LastAxis) {
  constexpr int kInputSize = 35;
  // Strictly decreasing values (35, 34, ..., 1): the minimum of every prefix
  // is its last element.
  std::vector<uint8_t> input;
  input.reserve(kInputSize);
  for (int i = 0; i < kInputSize; ++i) {
    input.push_back(static_cast<uint8_t>(kInputSize - i));
  }
  for (int i = 1; i <= kInputSize; ++i) {
    ArgMinOpModel model({i}, TensorType_UINT8, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<uint8_t>(
        model.input(), std::vector<uint8_t>(input.begin(), input.begin() + i));
    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    ValidateOutput(model, {i - 1});
  }
}
347 
348 }  // namespace
349 }  // namespace tflite
350