1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <initializer_list>
18 #include <tuple>
19 #include <vector>
20
21 #include <gtest/gtest.h>
22 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25
26 namespace tflite {
27 namespace {
28
29 using ::testing::ElementsAreArray;
30
31 class ArgBaseOpModel : public SingleOpModel {
32 public:
ArgBaseOpModel(TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)33 ArgBaseOpModel(TensorType input_type, int axis_value, TensorType axis_type,
34 bool constant_axis, TensorType output_type)
35 : axis_value_(axis_value),
36 axis_type_(axis_type),
37 constant_axis_(constant_axis) {
38 input_ = AddInput(input_type);
39 if (constant_axis) {
40 if (axis_type == TensorType_INT64) {
41 axis_ =
42 AddConstInput(axis_type, {static_cast<int64_t>(axis_value)}, {1});
43 } else {
44 axis_ = AddConstInput(axis_type, {axis_value}, {1});
45 }
46 } else {
47 axis_ = AddInput(axis_type);
48 }
49 output_ = AddOutput(output_type);
50 }
51
input() const52 int input() const { return input_; }
axis() const53 int axis() const { return axis_; }
54
GetInt32Output() const55 std::vector<int32_t> GetInt32Output() const {
56 return ExtractVector<int32_t>(output_);
57 }
GetInt64Output() const58 std::vector<int64_t> GetInt64Output() const {
59 return ExtractVector<int64_t>(output_);
60 }
GetOutputShape()61 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
62
63 protected:
PopulateAxisIfNeeded()64 void PopulateAxisIfNeeded() {
65 if (constant_axis_) return;
66 if (axis_type_ == TensorType_INT32) {
67 PopulateTensor<int32_t>(axis(), {axis_value_});
68 } else {
69 PopulateTensor<int64_t>(axis(), {axis_value_});
70 }
71 }
72
73 const int axis_value_;
74 const TensorType axis_type_;
75 const bool constant_axis_;
76
77 int input_;
78 int axis_;
79 int output_;
80 };
81
82 class ArgMaxOpModel : public ArgBaseOpModel {
83 public:
ArgMaxOpModel(std::initializer_list<int> input_shape,TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)84 ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
85 int axis_value, TensorType axis_type, bool constant_axis,
86 TensorType output_type)
87 : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
88 output_type) {
89 ArgBaseOpModel::SetBuiltinOp(
90 BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
91 CreateArgMaxOptions(ArgBaseOpModel::builder_, output_type).Union());
92 ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
93 PopulateAxisIfNeeded();
94 }
95 };
96
97 class ArgMinOpModel : public ArgBaseOpModel {
98 public:
ArgMinOpModel(std::initializer_list<int> input_shape,TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)99 ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
100 int axis_value, TensorType axis_type, bool constant_axis,
101 TensorType output_type)
102 : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
103 output_type) {
104 ArgBaseOpModel::SetBuiltinOp(
105 BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
106 CreateArgMinOptions(ArgBaseOpModel::builder_, output_type).Union());
107 ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
108 PopulateAxisIfNeeded();
109 }
110 };
111
112 // Declare ArgMinMaxOpTest as a parameterized test, where the parameter is a
113 // tuple with:
114 // - boolean indicating whether to use a constant axis or not.
115 // - axis type (TensorType_INT32 or TensorType_INT64)
116 // - output type (TensorType_INT32 or TensorType_INT64)
117 class ArgMinMaxOpTest : public ::testing::TestWithParam<
118 std::tuple<bool, TensorType, TensorType>> {
119 public:
ConstantAxis() const120 bool ConstantAxis() const { return std::get<0>(GetParam()); }
121
AxisType() const122 TensorType AxisType() const { return std::get<1>(GetParam()); }
123
OutputType() const124 TensorType OutputType() const { return std::get<2>(GetParam()); }
125
ValidateOutput(const ArgBaseOpModel & model,const std::vector<int> & expected_output)126 void ValidateOutput(const ArgBaseOpModel& model,
127 const std::vector<int>& expected_output) {
128 if (OutputType() == TensorType_INT32) {
129 EXPECT_THAT(model.GetInt32Output(), ElementsAreArray(expected_output));
130 } else {
131 EXPECT_THAT(model.GetInt64Output(), ElementsAreArray(expected_output));
132 }
133 }
134 };
135
136 INSTANTIATE_TEST_SUITE_P(
137 ArgMinMaxOpTest, ArgMinMaxOpTest,
138 ::testing::Combine(::testing::Bool(),
139 ::testing::Values(TensorType_INT32, TensorType_INT64),
140 ::testing::Values(TensorType_INT32, TensorType_INT64)));
141
// One row of four floats reduced along the last axis: the maximum, 0.9, is at
// index 1.
TEST_P(ArgMinMaxOpTest, GetMaxArgFloat) {
  constexpr int kLastAxis = 3;
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_FLOAT32,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
  model.Invoke();

  const std::vector<int> expected_indices = {1};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
151
// UINT8 row {1, 9, 7, 3} reduced along the last axis: 9 is at index 1.
TEST_P(ArgMinMaxOpTest, GetMaxArgUInt8) {
  constexpr int kLastAxis = 3;
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_UINT8,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<uint8_t>(model.input(), {1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {1};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
161
// Signed INT8 row {-1, -9, 7, 3}: negative values must not win, so the
// maximum, 7, is at index 2.
TEST_P(ArgMinMaxOpTest, GetMaxArgInt8) {
  constexpr int kLastAxis = 3;
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_INT8,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<int8_t>(model.input(), {-1, -9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {2};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
171
// INT32 row {1, 9, 7, 3} reduced along the last axis: 9 is at index 1.
TEST_P(ArgMinMaxOpTest, GetMaxArgInt) {
  constexpr int kLastAxis = 3;
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_INT32,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {1};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
181
// Two rows of four INT32 values, each reduced independently along the last
// axis: {1, 2, 7, 8} -> 3 and {1, 9, 7, 3} -> 1.
TEST_P(ArgMinMaxOpTest, GetMaxArgMulDimensions) {
  constexpr int kLastAxis = 3;
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {3, 1};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
191
// Axis -2 on a {1, 1, 2, 4} input reduces the size-2 dimension, leaving one
// index per column. On the tied columns (1, 1) and (7, 7) the expected index
// is the first one, 0.
TEST_P(ArgMinMaxOpTest, GetMaxArgNegativeAxis) {
  constexpr int kSecondToLastAxis = -2;
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      kSecondToLastAxis, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {0, 1, 0, 0};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4}));
}
201
// Two rows reduced along the last axis: {10, 2, 7, 8} -> 0 and
// {1, 9, 7, 3} -> 1. Despite the name, the output type comes from the test
// parameter, so both INT32 and INT64 outputs are exercised.
TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) {
  constexpr int kLastAxis = 3;
  ArgMaxOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {0, 1};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
211
// One row of four floats reduced along the last axis: the minimum, 0.1, is at
// index 0.
TEST_P(ArgMinMaxOpTest, GetMinArgFloat) {
  constexpr int kLastAxis = 3;
  ArgMinOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_FLOAT32,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
  model.Invoke();

  const std::vector<int> expected_indices = {0};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
221
// INT32 row {1, 9, 7, 3} reduced along the last axis: 1 is at index 0.
TEST_P(ArgMinMaxOpTest, GetMinArgInt) {
  constexpr int kLastAxis = 3;
  ArgMinOpModel model(/*input_shape=*/{1, 1, 1, 4}, TensorType_INT32,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {0};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
231
// Two rows of four INT32 values, each reduced independently along the last
// axis: both rows have their minimum, 1, at index 0.
TEST_P(ArgMinMaxOpTest, GetMinArgMulDimensions) {
  constexpr int kLastAxis = 3;
  ArgMinOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {0, 0};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
241
// Axis -2 on a {1, 1, 2, 4} input reduces the size-2 dimension, leaving one
// index per column. Only the last column (8, 3) has its minimum in the second
// row. On the tied columns the expected index is the first one, 0.
TEST_P(ArgMinMaxOpTest, GetMinArgNegativeAxis) {
  constexpr int kSecondToLastAxis = -2;
  ArgMinOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      kSecondToLastAxis, AxisType(), ConstantAxis(),
                      OutputType());
  model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {0, 0, 0, 1};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4}));
}
251
// Two rows reduced along the last axis: {10, 2, 7, 8} -> 1 and
// {1, 9, 7, 3} -> 0. Despite the name, the output type comes from the test
// parameter, so both INT32 and INT64 outputs are exercised.
TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) {
  constexpr int kLastAxis = 3;
  ArgMinOpModel model(/*input_shape=*/{1, 1, 2, 4}, TensorType_INT32,
                      kLastAxis, AxisType(), ConstantAxis(), OutputType());
  model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
  model.Invoke();

  const std::vector<int> expected_indices = {1, 0};
  ValidateOutput(model, expected_indices);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
261
262 } // namespace
263 } // namespace tflite
264