1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <initializer_list>
18 #include <type_traits>
19 #include <vector>
20
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25
26 namespace tflite {
27 namespace {
28
29 using ::testing::ElementsAre;
30 using ::testing::ElementsAreArray;
31
32 template <typename T>
33 class PackOpModel : public SingleOpModel {
34 public:
PackOpModel(const TensorData & input_template,int axis,int values_count)35 PackOpModel(const TensorData& input_template, int axis, int values_count) {
36 std::vector<std::vector<int>> all_input_shapes;
37 for (int i = 0; i < values_count; ++i) {
38 all_input_shapes.push_back(input_template.shape);
39 AddInput(input_template);
40 }
41 output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
42 input_template.max});
43 SetBuiltinOp(BuiltinOperator_PACK, BuiltinOptions_PackOptions,
44 CreatePackOptions(builder_, values_count, axis).Union());
45 BuildInterpreter(all_input_shapes);
46 }
47
SetInput(int index,std::initializer_list<T> data)48 void SetInput(int index, std::initializer_list<T> data) {
49 PopulateTensor(index, data);
50 }
51
GetOutput()52 std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetOutputShape()53 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
54
55 private:
56 int output_;
57 };
58
59 // float32 tests.
TEST(PackOpTest, FloatThreeInputs) {
  // Packing three length-2 vectors along axis 0 makes each input one row of a
  // 3x2 result, so the inputs appear back to back in the output.
  constexpr int kAxis = 0;
  constexpr int kNumInputs = 3;
  PackOpModel<float> model({TensorType_FLOAT32, {2}}, kAxis, kNumInputs);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<float> expected = {1, 4, 2, 5, 3, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
69
TEST(PackOpTest, FloatThreeInputsDifferentAxis) {
  // Packing along axis 1 interleaves the inputs: output[i][j] == input_j[i],
  // which gives a 2x3 result.
  constexpr int kAxis = 1;
  constexpr int kNumInputs = 3;
  PackOpModel<float> model({TensorType_FLOAT32, {2}}, kAxis, kNumInputs);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<float> expected = {1, 2, 3, 4, 5, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
79
TEST(PackOpTest, FloatThreeInputsNegativeAxis) {
  // A negative axis counts from the end of the output rank. For these 1-D
  // inputs, axis -1 is expected to give the same 2x3 interleaved result as
  // axis 1.
  constexpr int kAxis = -1;
  constexpr int kNumInputs = 3;
  PackOpModel<float> model({TensorType_FLOAT32, {2}}, kAxis, kNumInputs);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<float> expected = {1, 2, 3, 4, 5, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
89
TEST(PackOpTest, FloatMultilDimensions) {
  // Two 2x3 inputs packed along axis 1 give a 2x2x3 output. Slice [i][j] of
  // the output is row i of input j.
  constexpr int kAxis = 1;
  constexpr int kNumInputs = 2;
  PackOpModel<float> model({TensorType_FLOAT32, {2, 3}}, kAxis, kNumInputs);
  model.SetInput(0, {1, 2, 3, 4, 5, 6});
  model.SetInput(1, {7, 8, 9, 10, 11, 12});
  model.Invoke();

  const std::vector<float> expected = {1, 2, 3, 7,  8,  9,
                                       4, 5, 6, 10, 11, 12};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
99
TEST(PackOpTest, FloatFiveDimensions) {
  // Two rank-4 inputs of shape 2x2x2x2 packed along axis 1 give a rank-5
  // output. Each 8-element half of an input lands next to the matching half
  // of the other input.
  constexpr int kAxis = 1;
  constexpr int kNumInputs = 2;
  PackOpModel<float> model({TensorType_FLOAT32, {2, 2, 2, 2}}, kAxis,
                           kNumInputs);
  model.SetInput(0, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.SetInput(
      1, {17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
  model.Invoke();

  const std::vector<float> expected = {
      1,  2,  3,  4,  5,  6,  7,  8,   // input 0, first half
      17, 18, 19, 20, 21, 22, 23, 24,  // input 1, first half
      9,  10, 11, 12, 13, 14, 15, 16,  // input 0, second half
      25, 26, 27, 28, 29, 30, 31, 32,  // input 1, second half
  };
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 2, 2, 2));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
112
113 // int32 tests.
TEST(PackOpTest, Int32ThreeInputs) {
  // int32 variant: axis 0 stacks the three inputs as the rows of a 3x2
  // result.
  PackOpModel<int32_t> model({TensorType_INT32, {2}}, /*axis=*/0,
                             /*values_count=*/3);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<int32_t> expected = {1, 4, 2, 5, 3, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
123
TEST(PackOpTest, Int32ThreeInputsDifferentAxis) {
  // int32 variant: axis 1 interleaves the inputs into a 2x3 result.
  PackOpModel<int32_t> model({TensorType_INT32, {2}}, /*axis=*/1,
                             /*values_count=*/3);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<int32_t> expected = {1, 2, 3, 4, 5, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
133
TEST(PackOpTest, Int32ThreeInputsNegativeAxis) {
  // int32 variant: axis -1 is expected to give the same 2x3 interleaved
  // output as axis 1 for these 1-D inputs.
  PackOpModel<int32_t> model({TensorType_INT32, {2}}, /*axis=*/-1,
                             /*values_count=*/3);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<int32_t> expected = {1, 2, 3, 4, 5, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
143
TEST(PackOpTest, Int32MultilDimensions) {
  // int32 variant: two 2x3 inputs packed along axis 1 give a 2x2x3 output.
  PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, /*axis=*/1,
                             /*values_count=*/2);
  model.SetInput(0, {1, 2, 3, 4, 5, 6});
  model.SetInput(1, {7, 8, 9, 10, 11, 12});
  model.Invoke();

  const std::vector<int32_t> expected = {1, 2, 3, 7,  8,  9,
                                         4, 5, 6, 10, 11, 12};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
153
154 // int64 tests.
TEST(PackOpTest, Int64ThreeInputs) {
  // kLarge and kNegLarge lie outside the int32 range, so truncation of the
  // 64-bit values would make this test fail.
  constexpr int64_t kLarge = 1LL << 33;
  constexpr int64_t kNegLarge = -(1LL << 34);

  PackOpModel<int64_t> model({TensorType_INT64, {2}}, /*axis=*/0,
                             /*values_count=*/3);
  model.SetInput(0, {kLarge, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, kNegLarge});
  model.Invoke();

  const std::vector<int64_t> expected = {kLarge, 4, 2, 5, 3, kNegLarge};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
165
TEST(PackOpTest, Int64ThreeInputsDifferentAxis) {
  // Out-of-int32-range values packed along axis 1 into a 2x3 result.
  constexpr int64_t kLarge = 1LL << 33;
  constexpr int64_t kNegLarge = -(1LL << 34);

  PackOpModel<int64_t> model({TensorType_INT64, {2}}, /*axis=*/1,
                             /*values_count=*/3);
  model.SetInput(0, {kLarge, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, kNegLarge});
  model.Invoke();

  const std::vector<int64_t> expected = {kLarge, 2, 3, 4, 5, kNegLarge};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
176
TEST(PackOpTest, Int64ThreeInputsNegativeAxis) {
  // Out-of-int32-range values with axis -1. The expected 2x3 result matches
  // the axis 1 case.
  constexpr int64_t kLarge = 1LL << 33;
  constexpr int64_t kNegLarge = -(1LL << 34);

  PackOpModel<int64_t> model({TensorType_INT64, {2}}, /*axis=*/-1,
                             /*values_count=*/3);
  model.SetInput(0, {kLarge, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, kNegLarge});
  model.Invoke();

  const std::vector<int64_t> expected = {kLarge, 2, 3, 4, 5, kNegLarge};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
187
TEST(PackOpTest, Int64MultilDimensions) {
  // Two 2x3 int64 inputs, each holding one out-of-int32-range value, packed
  // along axis 1 into a 2x2x3 output.
  constexpr int64_t kLarge = 1LL << 33;
  constexpr int64_t kNegLarge = -(1LL << 34);

  PackOpModel<int64_t> model({TensorType_INT64, {2, 3}}, /*axis=*/1,
                             /*values_count=*/2);
  model.SetInput(0, {kLarge, 2, 3, 4, 5, 6});
  model.SetInput(1, {7, 8, kNegLarge, 10, 11, 12});
  model.Invoke();

  const std::vector<int64_t> expected = {kLarge, 2, 3, 7,  8,  kNegLarge,
                                         4,      5, 6, 10, 11, 12};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
198
199 template <typename InputType>
200 struct PackOpTestInt : public ::testing::Test {
201 using TypeToTest = InputType;
202 TensorType TENSOR_TYPE =
203 (std::is_same<InputType, int16_t>::value
204 ? TensorType_INT16
205 : (std::is_same<InputType, uint8_t>::value ? TensorType_UINT8
206 : TensorType_INT8));
207 };
208
209 using TestTypes = testing::Types<int8_t, uint8_t, int16_t>;
210 TYPED_TEST_CASE(PackOpTestInt, TestTypes);
211
TYPED_TEST(PackOpTestInt, ThreeInputs) {
  // Runs once per type in TestTypes. Axis 0 stacks the three inputs as the
  // rows of a 3x2 result.
  PackOpModel<TypeParam> model({this->TENSOR_TYPE, {2}}, /*axis=*/0,
                               /*values_count=*/3);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<int> expected = {1, 4, 2, 5, 3, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
222
TYPED_TEST(PackOpTestInt, ThreeInputsDifferentAxis) {
  // Runs once per type in TestTypes. Axis 1 interleaves the inputs into a 2x3
  // result.
  PackOpModel<TypeParam> model({this->TENSOR_TYPE, {2}}, /*axis=*/1,
                               /*values_count=*/3);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<int> expected = {1, 2, 3, 4, 5, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
233
TYPED_TEST(PackOpTestInt, ThreeInputsNegativeAxis) {
  // Runs once per type in TestTypes. Axis -1 is expected to match the 2x3
  // axis 1 result for these 1-D inputs.
  PackOpModel<TypeParam> model({this->TENSOR_TYPE, {2}}, /*axis=*/-1,
                               /*values_count=*/3);
  model.SetInput(0, {1, 4});
  model.SetInput(1, {2, 5});
  model.SetInput(2, {3, 6});
  model.Invoke();

  const std::vector<int> expected = {1, 2, 3, 4, 5, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
244
TYPED_TEST(PackOpTestInt, MultilDimensions) {
  // Runs once per type in TestTypes. Two 2x3 inputs packed along axis 1 give
  // a 2x2x3 output.
  PackOpModel<TypeParam> model({this->TENSOR_TYPE, {2, 3}}, /*axis=*/1,
                               /*values_count=*/2);
  model.SetInput(0, {1, 2, 3, 4, 5, 6});
  model.SetInput(1, {7, 8, 9, 10, 11, 12});
  model.Invoke();

  const std::vector<int> expected = {1, 2, 3, 7,  8,  9,
                                     4, 5, 6, 10, 11, 12};
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
255
256 } // namespace
257 } // namespace tflite
258