• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <tuple>
19 #include <vector>
20 
21 #include <gtest/gtest.h>
22 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 namespace {
28 
29 using ::testing::ElementsAreArray;
30 
31 class ArgBaseOpModel : public SingleOpModel {
32  public:
ArgBaseOpModel(TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)33   ArgBaseOpModel(TensorType input_type, int axis_value, TensorType axis_type,
34                  bool constant_axis, TensorType output_type)
35       : axis_value_(axis_value),
36         axis_type_(axis_type),
37         constant_axis_(constant_axis) {
38     input_ = AddInput(input_type);
39     if (constant_axis) {
40       if (axis_type == TensorType_INT64) {
41         axis_ =
42             AddConstInput(axis_type, {static_cast<int64_t>(axis_value)}, {1});
43       } else {
44         axis_ = AddConstInput(axis_type, {axis_value}, {1});
45       }
46     } else {
47       axis_ = AddInput(axis_type);
48     }
49     output_ = AddOutput(output_type);
50   }
51 
input() const52   int input() const { return input_; }
axis() const53   int axis() const { return axis_; }
54 
GetInt32Output() const55   std::vector<int32_t> GetInt32Output() const {
56     return ExtractVector<int32_t>(output_);
57   }
GetInt64Output() const58   std::vector<int64_t> GetInt64Output() const {
59     return ExtractVector<int64_t>(output_);
60   }
GetOutputShape()61   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
62 
63  protected:
PopulateAxisIfNeeded()64   void PopulateAxisIfNeeded() {
65     if (constant_axis_) return;
66     if (axis_type_ == TensorType_INT32) {
67       PopulateTensor<int32_t>(axis(), {axis_value_});
68     } else {
69       PopulateTensor<int64_t>(axis(), {axis_value_});
70     }
71   }
72 
73   const int axis_value_;
74   const TensorType axis_type_;
75   const bool constant_axis_;
76 
77   int input_;
78   int axis_;
79   int output_;
80 };
81 
82 class ArgMaxOpModel : public ArgBaseOpModel {
83  public:
ArgMaxOpModel(std::initializer_list<int> input_shape,TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)84   ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
85                 int axis_value, TensorType axis_type, bool constant_axis,
86                 TensorType output_type)
87       : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
88                        output_type) {
89     ArgBaseOpModel::SetBuiltinOp(
90         BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
91         CreateArgMaxOptions(ArgBaseOpModel::builder_, output_type).Union());
92     ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
93     PopulateAxisIfNeeded();
94   }
95 };
96 
97 class ArgMinOpModel : public ArgBaseOpModel {
98  public:
ArgMinOpModel(std::initializer_list<int> input_shape,TensorType input_type,int axis_value,TensorType axis_type,bool constant_axis,TensorType output_type)99   ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
100                 int axis_value, TensorType axis_type, bool constant_axis,
101                 TensorType output_type)
102       : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
103                        output_type) {
104     ArgBaseOpModel::SetBuiltinOp(
105         BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
106         CreateArgMinOptions(ArgBaseOpModel::builder_, output_type).Union());
107     ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
108     PopulateAxisIfNeeded();
109   }
110 };
111 
112 // Declare ArgMinMaxOpTest as a parameterized test, where the parameter is a
113 // tuple with:
114 // - boolean indicating whether to use a constant axis or not.
115 // - axis type (TensorType_INT32 or TensorType_INT64)
116 // - output type (TensorType_INT32 or TensorType_INT64)
117 class ArgMinMaxOpTest : public ::testing::TestWithParam<
118                             std::tuple<bool, TensorType, TensorType>> {
119  public:
ConstantAxis() const120   bool ConstantAxis() const { return std::get<0>(GetParam()); }
121 
AxisType() const122   TensorType AxisType() const { return std::get<1>(GetParam()); }
123 
OutputType() const124   TensorType OutputType() const { return std::get<2>(GetParam()); }
125 
ValidateOutput(const ArgBaseOpModel & model,const std::vector<int> & expected_output)126   void ValidateOutput(const ArgBaseOpModel& model,
127                       const std::vector<int>& expected_output) {
128     if (OutputType() == TensorType_INT32) {
129       EXPECT_THAT(model.GetInt32Output(), ElementsAreArray(expected_output));
130     } else {
131       EXPECT_THAT(model.GetInt64Output(), ElementsAreArray(expected_output));
132     }
133   }
134 };
135 
136 INSTANTIATE_TEST_SUITE_P(
137     ArgMinMaxOpTest, ArgMinMaxOpTest,
138     ::testing::Combine(::testing::Bool(),
139                        ::testing::Values(TensorType_INT32, TensorType_INT64),
140                        ::testing::Values(TensorType_INT32, TensorType_INT64)));
141 
TEST_P(ArgMinMaxOpTest, GetMaxArgFloat) {
  ArgMaxOpModel op({1, 1, 1, 4}, TensorType_FLOAT32, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // The largest value (0.9) sits at position 1 of the last axis.
  op.PopulateTensor<float>(op.input(), {0.1f, 0.9f, 0.7f, 0.3f});
  op.Invoke();

  const std::vector<int> expected_indices = {1};
  ValidateOutput(op, expected_indices);
  // Reducing the last axis removes it from the output shape.
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
151 
TEST_P(ArgMinMaxOpTest, GetMaxArgUInt8) {
  ArgMaxOpModel op({1, 1, 1, 4}, TensorType_UINT8, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // The largest value (9) sits at position 1 of the last axis.
  op.PopulateTensor<uint8_t>(op.input(), {1, 9, 7, 3});
  op.Invoke();

  const std::vector<int> expected_indices = {1};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
161 
TEST_P(ArgMinMaxOpTest, GetMaxArgInt8) {
  ArgMaxOpModel op({1, 1, 1, 4}, TensorType_INT8, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // Mixed signs: the largest value (7) sits at position 2.
  op.PopulateTensor<int8_t>(op.input(), {-1, -9, 7, 3});
  op.Invoke();

  const std::vector<int> expected_indices = {2};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
171 
TEST_P(ArgMinMaxOpTest, GetMaxArgInt) {
  ArgMaxOpModel op({1, 1, 1, 4}, TensorType_INT32, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // The largest value (9) sits at position 1 of the last axis.
  op.PopulateTensor<int>(op.input(), {1, 9, 7, 3});
  op.Invoke();

  const std::vector<int> expected_indices = {1};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
181 
TEST_P(ArgMinMaxOpTest, GetMaxArgMulDimensions) {
  ArgMaxOpModel op({1, 1, 2, 4}, TensorType_INT32, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // Two rows along the last axis: [1, 2, 7, 8] and [1, 9, 7, 3].
  op.PopulateTensor<int>(op.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  op.Invoke();

  // One index per row.
  const std::vector<int> expected_indices = {3, 1};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
191 
TEST_P(ArgMinMaxOpTest, GetMaxArgNegativeAxis) {
  // Axis -2 of the rank-4 input is dimension 2 (size 2), so each of the four
  // columns is reduced across the two rows.
  ArgMaxOpModel op({1, 1, 2, 4}, TensorType_INT32, /*axis_value=*/-2,
                   AxisType(), ConstantAxis(), OutputType());
  op.PopulateTensor<int>(op.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  op.Invoke();

  // Columns 0 and 2 tie (1 vs 1, 7 vs 7); the expected index there is 0.
  const std::vector<int> expected_indices = {0, 1, 0, 0};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 4}));
}
201 
TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) {
  // Despite the name, the output type comes from the test parameter and
  // covers both INT32 and INT64.
  ArgMaxOpModel op({1, 1, 2, 4}, TensorType_INT32, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // Rows [10, 2, 7, 8] and [1, 9, 7, 3].
  op.PopulateTensor<int>(op.input(), {10, 2, 7, 8, 1, 9, 7, 3});
  op.Invoke();

  const std::vector<int> expected_indices = {0, 1};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
211 
TEST_P(ArgMinMaxOpTest, GetMinArgFloat) {
  ArgMinOpModel op({1, 1, 1, 4}, TensorType_FLOAT32, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // The smallest value (0.1) sits at position 0 of the last axis.
  op.PopulateTensor<float>(op.input(), {0.1f, 0.9f, 0.7f, 0.3f});
  op.Invoke();

  const std::vector<int> expected_indices = {0};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
221 
TEST_P(ArgMinMaxOpTest, GetMinArgInt) {
  ArgMinOpModel op({1, 1, 1, 4}, TensorType_INT32, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // The smallest value (1) sits at position 0 of the last axis.
  op.PopulateTensor<int>(op.input(), {1, 9, 7, 3});
  op.Invoke();

  const std::vector<int> expected_indices = {0};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
231 
TEST_P(ArgMinMaxOpTest, GetMinArgMulDimensions) {
  ArgMinOpModel op({1, 1, 2, 4}, TensorType_INT32, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // Two rows along the last axis: [1, 2, 7, 8] and [1, 9, 7, 3].
  op.PopulateTensor<int>(op.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  op.Invoke();

  // Both rows start with their minimum.
  const std::vector<int> expected_indices = {0, 0};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
241 
TEST_P(ArgMinMaxOpTest, GetMinArgNegativeAxis) {
  // Axis -2 of the rank-4 input is dimension 2 (size 2), so each of the four
  // columns is reduced across the two rows.
  ArgMinOpModel op({1, 1, 2, 4}, TensorType_INT32, /*axis_value=*/-2,
                   AxisType(), ConstantAxis(), OutputType());
  op.PopulateTensor<int>(op.input(), {1, 2, 7, 8, 1, 9, 7, 3});
  op.Invoke();

  // Columns 0 and 2 tie (1 vs 1, 7 vs 7); the expected index there is 0.
  const std::vector<int> expected_indices = {0, 0, 0, 1};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 4}));
}
251 
TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) {
  // Despite the name, the output type comes from the test parameter and
  // covers both INT32 and INT64.
  ArgMinOpModel op({1, 1, 2, 4}, TensorType_INT32, /*axis_value=*/3,
                   AxisType(), ConstantAxis(), OutputType());
  // Rows [10, 2, 7, 8] and [1, 9, 7, 3].
  op.PopulateTensor<int>(op.input(), {10, 2, 7, 8, 1, 9, 7, 3});
  op.Invoke();

  const std::vector<int> expected_indices = {1, 0};
  ValidateOutput(op, expected_indices);
  EXPECT_THAT(op.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
261 
262 }  // namespace
263 }  // namespace tflite
264