1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdint>
16 #include <initializer_list>
17 #include <vector>
18
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/kernels/test_util.h"
22 #include "tensorflow/lite/schema/schema_generated.h"
23
24 namespace tflite {
25 namespace {
26
27 using ::testing::ElementsAre;
28 using ::testing::ElementsAreArray;
29
// Selects how the model's output_shape operand is supplied to the op.
enum class TestType : int {
  // output_shape is baked into the model as a constant tensor.
  kConst,
  // output_shape is a regular input tensor that is populated after the
  // interpreter has been built.
  kDynamic,
};
34
35 class Conv3dTransposeOpModel : public SingleOpModel {
36 public:
Conv3dTransposeOpModel(std::initializer_list<int> output_shape_data,const TensorData & filter,const TensorData & input,const TensorData & bias,const TensorData & output,TestType test_type,Padding padding=Padding_VALID,int32_t stride_depth=1,int32_t stride_width=1,int32_t stride_height=1,ActivationFunctionType activation=ActivationFunctionType_NONE,int32_t dilation_depth=1,int32_t dilation_width=1,int32_t dilation_height=1)37 Conv3dTransposeOpModel(
38 std::initializer_list<int> output_shape_data, const TensorData& filter,
39 const TensorData& input, const TensorData& bias, const TensorData& output,
40 TestType test_type, Padding padding = Padding_VALID,
41 int32_t stride_depth = 1, int32_t stride_width = 1,
42 int32_t stride_height = 1,
43 ActivationFunctionType activation = ActivationFunctionType_NONE,
44 int32_t dilation_depth = 1, int32_t dilation_width = 1,
45 int32_t dilation_height = 1) {
46 if (test_type == TestType::kDynamic) {
47 output_shape_ = AddInput({TensorType_INT32, {5}});
48 } else {
49 output_shape_ = AddConstInput(TensorType_INT32, output_shape_data, {5});
50 }
51 filter_ = AddInput(filter);
52 input_ = AddInput(input);
53 bias_ = AddInput(bias);
54 output_ = AddOutput(output);
55 SetBuiltinOp(
56 BuiltinOperator_CONV_3D_TRANSPOSE, BuiltinOptions_Conv3DOptions,
57 CreateConv3DOptions(builder_, padding, stride_depth, stride_width,
58 stride_height, activation, dilation_depth,
59 dilation_width, dilation_height)
60 .Union());
61 BuildInterpreter({GetShape(output_shape_), GetShape(filter_),
62 GetShape(input_), GetShape(bias_)});
63
64 if (test_type == TestType::kDynamic) {
65 PopulateTensor(output_shape_, output_shape_data);
66 }
67 }
68
Conv3dTransposeOpModel(std::initializer_list<int> output_shape_data,const TensorData & filter,const TensorData & input,const TensorData & output,TestType test_type,Padding padding=Padding_VALID,int32_t stride_depth=1,int32_t stride_width=1,int32_t stride_height=1,ActivationFunctionType activation=ActivationFunctionType_NONE,int32_t dilation_depth=1,int32_t dilation_width=1,int32_t dilation_height=1)69 Conv3dTransposeOpModel(
70 std::initializer_list<int> output_shape_data, const TensorData& filter,
71 const TensorData& input, const TensorData& output, TestType test_type,
72 Padding padding = Padding_VALID, int32_t stride_depth = 1,
73 int32_t stride_width = 1, int32_t stride_height = 1,
74 ActivationFunctionType activation = ActivationFunctionType_NONE,
75 int32_t dilation_depth = 1, int32_t dilation_width = 1,
76 int32_t dilation_height = 1) {
77 if (test_type == TestType::kDynamic) {
78 output_shape_ = AddInput({TensorType_INT32, {5}});
79 } else {
80 output_shape_ = AddConstInput(TensorType_INT32, output_shape_data, {5});
81 }
82 filter_ = AddInput(filter);
83 input_ = AddInput(input);
84 output_ = AddOutput(output);
85 SetBuiltinOp(
86 BuiltinOperator_CONV_3D_TRANSPOSE, BuiltinOptions_Conv3DOptions,
87 CreateConv3DOptions(builder_, padding, stride_depth, stride_width,
88 stride_height, activation, dilation_depth,
89 dilation_width, dilation_height)
90 .Union());
91 BuildInterpreter(
92 {GetShape(output_shape_), GetShape(filter_), GetShape(input_)});
93 if (test_type == TestType::kDynamic) {
94 PopulateTensor(output_shape_, output_shape_data);
95 }
96 }
97
SetFilter(std::vector<float> f)98 void SetFilter(std::vector<float> f) { PopulateTensor(filter_, f); }
99
SetBias(std::initializer_list<float> f)100 void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
101
SetInput(std::vector<float> data)102 void SetInput(std::vector<float> data) { PopulateTensor(input_, data); }
103
GetOutput()104 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()105 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
106
107 private:
108 int output_shape_;
109 int input_;
110 int filter_;
111 int bias_;
112 int output_;
113 };
114
// Returns {0, 1, ..., N - 1} with each value converted to T. A non-positive N
// yields an empty vector.
template <typename T>
std::vector<T> CreateRangeVector(int N) {
  std::vector<T> values;
  int next = 0;
  while (next < N) {
    values.push_back(static_cast<T>(next));
    ++next;
  }
  return values;
}
121
122 class Conv3dTransposeOpTest : public ::testing::TestWithParam<TestType> {};
123
TEST_P(Conv3dTransposeOpTest, InvalidInputDimsTest) {
  // Filter and input are both rank 4 here; the expected failure is the input
  // rank check.
  const TensorData filter{TensorType_FLOAT32, {2, 2, 4, 1}};
  const TensorData input{TensorType_FLOAT32, {3, 2, 2, 1}};
  const TensorData output{TensorType_FLOAT32, {}};
  EXPECT_DEATH_IF_SUPPORTED(
      Conv3dTransposeOpModel m({1, 2, 3, 4, 5}, filter, input, output,
                               GetParam()),
      "input->dims->size != 5");
}
132
TEST_P(Conv3dTransposeOpTest, InvalidFilterDimsTest) {
  // The input is a valid rank-5 tensor but the filter is rank 4.
  const TensorData filter{TensorType_FLOAT32, {2, 2, 4, 1}};
  const TensorData input{TensorType_FLOAT32, {1, 3, 2, 2, 1}};
  const TensorData output{TensorType_FLOAT32, {}};
  EXPECT_DEATH_IF_SUPPORTED(
      Conv3dTransposeOpModel m({1, 2, 3, 4, 5}, filter, input, output,
                               GetParam()),
      "filter->dims->size != 5");
}
141
TEST_P(Conv3dTransposeOpTest, MismatchChannelSizeTest) {
  // Last dimension: filter has 1 channel, input has 2.
  const TensorData filter{TensorType_FLOAT32, {1, 2, 2, 4, 1}};
  const TensorData input{TensorType_FLOAT32, {1, 3, 2, 2, 2}};
  const TensorData output{TensorType_FLOAT32, {}};
  EXPECT_DEATH_IF_SUPPORTED(
      Conv3dTransposeOpModel m({1, 2, 3, 4, 5}, filter, input, output,
                               GetParam()),
      "SizeOfDimension.input, 4. != SizeOfDimension.filter, 4.");
}
150
TEST_P(Conv3dTransposeOpTest, MismatchBiasSizeTest) {
  // The bias has 3 elements while filter dimension 3 has size 2.
  const TensorData filter{TensorType_FLOAT32, {1, 3, 2, 2, 2}};
  const TensorData input{TensorType_FLOAT32, {1, 2, 2, 4, 2}};
  const TensorData bias{TensorType_FLOAT32, {3}};
  const TensorData output{TensorType_FLOAT32, {}};
  EXPECT_DEATH_IF_SUPPORTED(
      Conv3dTransposeOpModel m({1, 2, 3, 4, 5}, filter, input, bias, output,
                               GetParam()),
      "NumElements.bias. != SizeOfDimension.filter, 3.");
}
159
TEST_P(Conv3dTransposeOpTest, SimpleFloat32Test) {
  Conv3dTransposeOpModel model(
      /*output_shape_data=*/{1, 3, 3, 5, 2},
      /*filter=*/{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
      /*input=*/{TensorType_FLOAT32, {1, 2, 2, 4, 2}},
      /*output=*/{TensorType_FLOAT32, {}}, GetParam());

  const std::vector<float> filter_values = {
      -1, -1, -1, -1, -1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1,
      1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, 1, -1};
  model.SetInput(CreateRangeVector<float>(32));
  model.SetFilter(filter_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_output = {
      -1, -1, -4, -4, -8, -8, -12, -12, 1, 1, -16, -16, -18,
      -16, -18, -20, -18, -24, 14, -12, 1, 17, 18, 4, 22, 4,
      26, 4, 29, -29, -34, -32, -36, -30, -36, -30, -36, -30, 14,
      2, -50, 2, -8, -26, -8, -26, -8, -26, 74, -44, -16, 50,
      28, 4, 28, 4, 28, 4, 60, -62, -1, 33, 32, 38, 36,
      42, 40, 46, 45, 1, -34, 50, 10, 54, 10, 58, 10, 62,
      60, 0, -49, 1, -54, 0, -58, 0, -62, 0, -1, -1};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3, 3, 5, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
}
183
TEST_P(Conv3dTransposeOpTest, PaddingValidTest) {
  Conv3dTransposeOpModel model(
      /*output_shape_data=*/{1, 4, 5, 6, 2},
      /*filter=*/{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
      /*input=*/{TensorType_FLOAT32, {1, 3, 4, 5, 2}},
      /*output=*/{TensorType_FLOAT32, {}}, GetParam());

  const std::vector<float> filter_values = {
      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1,
      1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, 1, 1};
  model.SetInput(CreateRangeVector<float>(120));
  model.SetFilter(filter_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_output = {
      -1, -1, -6, -6, -14, -14, -22, -22, -30, -30, -17,
      -17, -22, -20, -50, -46, -58, -58, -66, -70, -74, -82,
      -20, -54, -62, -40, -90, -106, -98, -118, -106, -130, -114,
      -142, -20, -94, -102, -60, -130, -166, -138, -178, -146, -190,
      -154, -202, -20, -134, -61, 1, -4, -60, -4, -64, -4,
      -68, -4, -72, 77, -77, -80, -80, -160, -164, -164, -172,
      -168, -180, -172, -188, -96, -96, -162, -98, -188, -282, -196,
      -290, -204, -298, -212, -306, -18, -196, -202, -118, -228, -322,
      -236, -330, -244, -338, -252, -346, -18, -216, -242, -138, -268,
      -362, -276, -370, -284, -378, -292, -386, -18, -236, -202, 2,
      -68, -78, -72, -78, -76, -78, -80, -78, 158, -80, -80,
      -160, -240, -324, -244, -332, -248, -340, -252, -348, -176, -176,
      -322, -178, -348, -442, -356, -450, -364, -458, -372, -466, -18,
      -276, -362, -198, -388, -482, -396, -490, -404, -498, -412, -506,
      -18, -296, -402, -218, -428, -522, -436, -530, -444, -538, -452,
      -546, -18, -316, -362, 2, -148, -78, -152, -78, -156, -78,
      -160, -78, 238, -80, 161, 1, 166, 2, 170, 2, 174,
      2, 178, 2, 1, 1, 20, 2, 22, 164, 22, 168,
      22, 172, 22, 176, 2, 178, 20, 2, 22, 184, 22,
      188, 22, 192, 22, 196, 2, 198, 20, 2, 22, 204,
      22, 208, 22, 212, 22, 216, 2, 218, -221, 1, -224,
      222, -228, 226, -232, 230, -236, 234, 1, 237};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 5, 6, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
}
222
TEST_P(Conv3dTransposeOpTest, PaddingSameTest) {
  // With SAME padding and unit strides the output shape equals the input's.
  Conv3dTransposeOpModel model(
      /*output_shape_data=*/{1, 3, 4, 5, 2},
      /*filter=*/{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
      /*input=*/{TensorType_FLOAT32, {1, 3, 4, 5, 2}},
      /*output=*/{TensorType_FLOAT32, {}}, GetParam(), Padding_SAME);

  const std::vector<float> filter_values = {
      1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1,
      -1, 1, -1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1};
  model.SetInput(CreateRangeVector<float>(120));
  model.SetFilter(filter_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_output = {
      -1, -1, -2, 0, -2, 0, -2, 0, -2, 0, -2, 0, -4, 2,
      -4, 2, -4, 2, -4, 2, -2, 0, -4, 2, -4, 2, -4, 2,
      -4, 2, -2, 0, -4, 2, -4, 2, -4, 2, -4, 2, 0, 0,
      -2, 2, -6, 2, -10, 2, -14, 2, 0, 2, -18, 10, -18, 14,
      -18, 18, -18, 22, 20, 22, -18, 30, -18, 34, -18, 38, -18, 42,
      40, 42, -18, 50, -18, 54, -18, 58, -18, 62, 0, 0, -82, 2,
      -86, 2, -90, 2, -94, 2, 80, 82, -18, 90, -18, 94, -18, 98,
      -18, 102, 100, 102, -18, 110, -18, 114, -18, 118, -18, 122, 120, 122,
      -18, 130, -18, 134, -18, 138, -18, 142};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3, 4, 5, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
}
248
TEST_P(Conv3dTransposeOpTest, PaddingValidComplexTest) {
  Conv3dTransposeOpModel model(
      /*output_shape_data=*/{2, 4, 3, 2, 2},
      /*filter=*/{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
      /*input=*/{TensorType_FLOAT32, {2, 3, 2, 1, 2}},
      /*output=*/{TensorType_FLOAT32, {}}, GetParam(), Padding_VALID);

  const std::vector<float> filter_values = {
      1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
      1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1};
  model.SetInput(CreateRangeVector<float>(24));
  model.SetFilter(filter_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_output = {
      -1, 1, 1, -1, -2, 4, 2, 0, -1, -5, 1, 5, -2, 10, 2, -2,
      -4, 8, 4, 8, -2, -18, 2, 18, -2, 26, 2, -2, -4, 8, 4, 24,
      -2, -34, 2, 34, -1, 17, 1, -1, -2, 4, 2, 16, -1, -21, 1, 21,
      -1, 25, 1, -1, -2, 4, 2, 24, -1, -29, 1, 29, -2, 58, 2, -2,
      -4, 8, 4, 56, -2, -66, 2, 66, -2, 74, 2, -2, -4, 8, 4, 72,
      -2, -82, 2, 82, -1, 41, 1, -1, -2, 4, 2, 40, -1, -45, 1, 45};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 3, 2, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
}
271
TEST_P(Conv3dTransposeOpTest, StrideTest) {
  // Only the depth stride is non-unit.
  Conv3dTransposeOpModel model(
      /*output_shape_data=*/{2, 4, 3, 2, 2},
      /*filter=*/{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
      /*input=*/{TensorType_FLOAT32, {2, 2, 2, 1, 2}},
      /*output=*/{TensorType_FLOAT32, {}}, GetParam(), Padding_VALID,
      /*stride_depth=*/2, /*stride_width=*/1, /*stride_height=*/1);

  const std::vector<float> filter_values = {
      1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
      1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1};
  model.SetInput(CreateRangeVector<float>(16));
  model.SetFilter(filter_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_output = {
      -1, 1, 1, -1, -2, 4, 2, 0, -1, -5, 1, 5, -1, 1, 1, -1,
      -2, 4, 2, 0, -1, -5, 1, 5, -1, 9, 1, -1, -2, 4, 2, 8,
      -1, -13, 1, 13, -1, 9, 1, -1, -2, 4, 2, 8, -1, -13, 1, 13,
      -1, 17, 1, -1, -2, 4, 2, 16, -1, -21, 1, 21, -1, 17, 1, -1,
      -2, 4, 2, 16, -1, -21, 1, 21, -1, 25, 1, -1, -2, 4, 2, 24,
      -1, -29, 1, 29, -1, 25, 1, -1, -2, 4, 2, 24, -1, -29, 1, 29};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 3, 2, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
}
296
TEST_P(Conv3dTransposeOpTest, StrideAndPaddingSameTest) {
  // SAME padding combined with a depth stride of 2.
  Conv3dTransposeOpModel model(
      /*output_shape_data=*/{2, 4, 2, 1, 2},
      /*filter=*/{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
      /*input=*/{TensorType_FLOAT32, {2, 2, 2, 1, 2}},
      /*output=*/{TensorType_FLOAT32, {}}, GetParam(), Padding_SAME,
      /*stride_depth=*/2, /*stride_width=*/1, /*stride_height=*/1);

  const std::vector<float> filter_values = {
      1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
      1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1};
  model.SetInput(CreateRangeVector<float>(16));
  model.SetFilter(filter_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_output = {
      -1, 1, -2, 4, -1, 1, -2, 4, -1, 9, -2,
      4, -1, 9, -2, 4, -1, 17, -2, 4, -1, 17,
      -2, 4, -1, 25, -2, 4, -1, 25, -2, 4};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 2, 1, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
}
316
TEST_P(Conv3dTransposeOpTest, DilationTest) {
  // Unit strides everywhere; only the height dilation is 2.
  Conv3dTransposeOpModel model(
      /*output_shape_data=*/{1, 3, 3, 2, 2},
      /*filter=*/{TensorType_FLOAT32, {1, 2, 2, 2, 1}},
      /*input=*/{TensorType_FLOAT32, {1, 3, 1, 1, 1}},
      /*output=*/{TensorType_FLOAT32, {}}, GetParam(), Padding_VALID,
      /*stride_depth=*/1, /*stride_width=*/1, /*stride_height=*/1,
      /*activation=*/ActivationFunctionType_NONE,
      /*dilation_depth=*/1, /*dilation_width=*/1, /*dilation_height=*/2);

  const std::vector<float> filter_values = {1, -1, 1, 1, -1, 1, 1, -1};
  model.SetInput(CreateRangeVector<float>(3));
  model.SetFilter(filter_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_output = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      1, -1, 1, 1, 0, 0, 0, 0, -1, 1, 1, -1,
      2, -2, 2, 2, 0, 0, 0, 0, -2, 2, 2, -2};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3, 3, 2, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
}
338
TEST_P(Conv3dTransposeOpTest, BiasTest) {
  // Same configuration as PaddingValidComplexTest, plus a per-channel bias.
  Conv3dTransposeOpModel model(
      /*output_shape_data=*/{2, 4, 3, 2, 2},
      /*filter=*/{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
      /*input=*/{TensorType_FLOAT32, {2, 3, 2, 1, 2}},
      /*bias=*/{TensorType_FLOAT32, {2}},
      /*output=*/{TensorType_FLOAT32, {}}, GetParam(), Padding_VALID);

  const std::vector<float> filter_values = {
      1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
      1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1};
  model.SetInput(CreateRangeVector<float>(24));
  model.SetFilter(filter_values);
  model.SetBias({1, 2});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_output = {
      0, 3, 2, 1, -1, 6, 3, 2, 0, -3, 2, 7, -1, 12, 3, 0,
      -3, 10, 5, 10, -1, -16, 3, 20, -1, 28, 3, 0, -3, 10, 5, 26,
      -1, -32, 3, 36, 0, 19, 2, 1, -1, 6, 3, 18, 0, -19, 2, 23,
      0, 27, 2, 1, -1, 6, 3, 26, 0, -27, 2, 31, -1, 60, 3, 0,
      -3, 10, 5, 58, -1, -64, 3, 68, -1, 76, 3, 0, -3, 10, 5, 74,
      -1, -80, 3, 84, 0, 43, 2, 1, -1, 6, 3, 42, 0, -43, 2, 47};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 3, 2, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
}
363
364 INSTANTIATE_TEST_SUITE_P(Conv3dTransposeOpTest, Conv3dTransposeOpTest,
365 ::testing::Values(TestType::kConst,
366 TestType::kDynamic));
367
368 } // namespace
369 } // namespace tflite
370