1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
16
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19
20 namespace tflite {
21 namespace {
22
// A rank-1 shape of size 9 has no size-1 dimension to drop, so both shapes
// and the permutation are expected to come back unchanged.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_1DNoChanges) {
  RuntimeShape input_shape({9});
  RuntimeShape output_shape({9});

  TransposeParams params;
  params.perm_count = 1;
  params.perm[0] = 0;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9}));
  EXPECT_EQ(output_shape, RuntimeShape({9}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
40
// A 2-D transpose {9, 3} -> {3, 9} contains no size-1 dimension, so the
// shapes and the swap permutation {1, 0} are expected to be left untouched.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DNoChanges) {
  RuntimeShape input_shape({9, 3});
  RuntimeShape output_shape({3, 9});

  TransposeParams params;
  params.perm_count = 2;
  params.perm[0] = 1;
  params.perm[1] = 0;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9}));

  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
}
60
// Transposing {9, 1} -> {1, 9} only moves a size-1 axis. After it is removed,
// both shapes collapse to {9} and the permutation becomes the rank-1
// identity {0}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DShrinking) {
  RuntimeShape input_shape({9, 1});
  RuntimeShape output_shape({1, 9});

  TransposeParams params;
  params.perm_count = 2;
  params.perm[0] = 1;
  params.perm[1] = 0;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9}));
  EXPECT_EQ(output_shape, RuntimeShape({9}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
79
// A 3-D transpose with no size-1 dimension ({4, 3, 8}, perm {2, 0, 1}) is
// expected to pass through unchanged.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DNoChanges) {
  RuntimeShape input_shape({4, 3, 8});
  RuntimeShape output_shape({8, 4, 3});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4, 3, 8}));
  EXPECT_EQ(output_shape, RuntimeShape({8, 4, 3}));

  EXPECT_EQ(params.perm_count, 3);
  EXPECT_EQ(params.perm[0], 2);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 1);
}
101
// Input axis 1 has size 1. Dropping it leaves {4, 8} -> {8, 4}, and the
// remaining permutation indices are renumbered to the 2-D swap {1, 0}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingOnce) {
  RuntimeShape input_shape({4, 1, 8});
  RuntimeShape output_shape({8, 4, 1});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4, 8}));
  EXPECT_EQ(output_shape, RuntimeShape({8, 4}));
  EXPECT_EQ(output_shape.Dims(1), 4);

  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
}
123
// Two of the three input axes have size 1. Only the size-4 axis survives, so
// both shapes become {4} and the permutation becomes {0}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingTwice) {
  RuntimeShape input_shape({4, 1, 1});
  RuntimeShape output_shape({1, 4, 1});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4}));
  EXPECT_EQ(output_shape, RuntimeShape({4}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
143
// When every dimension has size 1, the result is expected to keep a single
// size-1 dimension ({1}, perm {0}) rather than collapsing to rank 0.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DAllOnes) {
  RuntimeShape input_shape({1, 1, 1});
  RuntimeShape output_shape({1, 1, 1});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({1}));
  EXPECT_EQ(output_shape, RuntimeShape({1}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
163
// A 4-D transpose with no size-1 dimension ({9, 3, 2, 4}, perm {1, 0, 3, 2})
// is expected to pass through unchanged.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DNoChanges) {
  RuntimeShape input_shape({9, 3, 2, 4});
  RuntimeShape output_shape({3, 9, 4, 2});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 0;
  params.perm[2] = 3;
  params.perm[3] = 2;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3, 2, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4, 2}));

  EXPECT_EQ(params.perm_count, 4);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 3);
  EXPECT_EQ(params.perm[3], 2);
}
187
// Input axis 2 has size 1. Removing it turns perm {1, 0, 3, 2} into
// {1, 0, 2}: the former index 3 shifts down to fill the gap.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingOnce) {
  RuntimeShape input_shape({9, 3, 1, 4});
  RuntimeShape output_shape({3, 9, 4, 1});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 0;
  params.perm[2] = 3;
  params.perm[3] = 2;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4}));

  EXPECT_EQ(params.perm_count, 3);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 2);
}
210
// Input axes 0 and 2 have size 1. Once they are dropped, the remaining {3, 4}
// keeps its relative order, so the permutation reduces to the identity {0, 1}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingTwice) {
  RuntimeShape input_shape({1, 3, 1, 4});
  RuntimeShape output_shape({3, 1, 4, 1});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 2;
  params.perm[2] = 3;
  params.perm[3] = 0;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({3, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 4}));

  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 0);
  EXPECT_EQ(params.perm[1], 1);
}
232
// Three of the four input axes have size 1. Only the size-7 axis survives, so
// both shapes become {7} and the permutation becomes {0}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingThirdTimes) {
  RuntimeShape input_shape({1, 1, 7, 1});
  RuntimeShape output_shape({1, 7, 1, 1});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({7}));
  EXPECT_EQ(output_shape, RuntimeShape({7}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
253
// When all four dimensions have size 1, the result is expected to keep a
// single size-1 dimension ({1}, perm {0}) rather than collapsing to rank 0.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DAllOnes) {
  RuntimeShape input_shape({1, 1, 1, 1});
  RuntimeShape output_shape({1, 1, 1, 1});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;

  // All three arguments are passed by address; the call rewrites them in place.
  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({1}));
  EXPECT_EQ(output_shape, RuntimeShape({1}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
274
// Leading axis 0 maps to itself (perm[0] == 0), so Flatten is expected to
// split it off. The non-flattened part is the inner {5, 7} -> {7, 5} swap,
// and the returned size is that part's element count.
TEST(TransposeUtilsTest, Flatten3D) {
  RuntimeShape input_shape({3, 5, 7});
  RuntimeShape output_shape({3, 7, 5});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;

  RuntimeShape non_flatten_input_shape;
  RuntimeShape non_flatten_output_shape;
  TransposeParams non_flatten_params;
  size_t non_flatten_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &non_flatten_input_shape,
      &non_flatten_output_shape, &non_flatten_params);

  EXPECT_EQ(non_flatten_input_shape, RuntimeShape({5, 7}));
  EXPECT_EQ(non_flatten_output_shape, RuntimeShape({7, 5}));
  // Compare as size_t to avoid a signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(non_flatten_size, size_t{5 * 7});

  EXPECT_EQ(non_flatten_params.perm_count, 2);
  EXPECT_EQ(non_flatten_params.perm[0], 1);
  EXPECT_EQ(non_flatten_params.perm[1], 0);
}
300
// Only the first axis is untouched by the permutation {0, 2, 1, 3}, so Flatten
// is expected to peel off one leading axis. That leaves the 3-D transpose
// {5, 7, 9} -> {7, 5, 9} with perm {1, 0, 2}.
TEST(TransposeUtilsTest, Flatten4DFlattenOnce) {
  RuntimeShape input_shape({3, 5, 7, 9});
  RuntimeShape output_shape({3, 7, 5, 9});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;

  RuntimeShape non_flatten_input_shape;
  RuntimeShape non_flatten_output_shape;
  TransposeParams non_flatten_params;
  size_t non_flatten_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &non_flatten_input_shape,
      &non_flatten_output_shape, &non_flatten_params);

  EXPECT_EQ(non_flatten_input_shape, RuntimeShape({5, 7, 9}));
  EXPECT_EQ(non_flatten_output_shape, RuntimeShape({7, 5, 9}));
  // Compare as size_t to avoid a signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(non_flatten_size, size_t{5 * 7 * 9});

  EXPECT_EQ(non_flatten_params.perm_count, 3);
  EXPECT_EQ(non_flatten_params.perm[0], 1);
  EXPECT_EQ(non_flatten_params.perm[1], 0);
  EXPECT_EQ(non_flatten_params.perm[2], 2);
}
328
// The first two axes map to themselves (perm {0, 1, 3, 2}), so Flatten is
// expected to peel off both. That leaves the 2-D swap {7, 9} -> {9, 7}.
TEST(TransposeUtilsTest, Flatten4DFlattenTwice) {
  RuntimeShape input_shape({3, 5, 7, 9});
  RuntimeShape output_shape({3, 5, 9, 7});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 1;
  params.perm[2] = 3;
  params.perm[3] = 2;

  RuntimeShape non_flatten_input_shape;
  RuntimeShape non_flatten_output_shape;
  TransposeParams non_flatten_params;
  size_t non_flatten_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &non_flatten_input_shape,
      &non_flatten_output_shape, &non_flatten_params);

  EXPECT_EQ(non_flatten_input_shape, RuntimeShape({7, 9}));
  EXPECT_EQ(non_flatten_output_shape, RuntimeShape({9, 7}));
  // Compare as size_t to avoid a signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(non_flatten_size, size_t{7 * 9});

  EXPECT_EQ(non_flatten_params.perm_count, 2);
  EXPECT_EQ(non_flatten_params.perm[0], 1);
  EXPECT_EQ(non_flatten_params.perm[1], 0);
}
355
// A plain 2-D swap of a {4, 5} input is trivially a 2-D transpose, and the
// reported matrix dimensions are expected to be the input dims 4 x 5.
TEST(TransposeUtilsTest, IsTranspose2DApplicable2D) {
  RuntimeShape input_shape({4, 5});

  TransposeParams params;
  params.perm_count = 2;
  params.perm[0] = 1;
  params.perm[1] = 0;

  // Initialized so the checks below never read indeterminate values, even if
  // the helper returns without writing them.
  int dim0 = 0;
  int dim1 = 0;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  // Stop here on failure: dim0/dim1 are only meaningful when applicable.
  ASSERT_TRUE(applicable);
  EXPECT_EQ(dim0, 4);
  EXPECT_EQ(dim1, 5);
}
372
// perm {1, 2, 0} moves axis 0 to the end. This is expected to be reported as
// a 2-D transpose of a 4 x (5 * 6) = 4 x 30 matrix.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DOne) {
  RuntimeShape input_shape({4, 5, 6});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 1;
  params.perm[1] = 2;
  params.perm[2] = 0;

  // Initialized so the checks below never read indeterminate values, even if
  // the helper returns without writing them.
  int dim0 = 0;
  int dim1 = 0;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  // Stop here on failure: dim0/dim1 are only meaningful when applicable.
  ASSERT_TRUE(applicable);
  EXPECT_EQ(dim0, 4);
  EXPECT_EQ(dim1, 30);
}
390
// perm {2, 0, 1} moves the last axis to the front. This is expected to be
// reported as a 2-D transpose of a (4 * 5) x 6 = 20 x 6 matrix.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DTwo) {
  RuntimeShape input_shape({4, 5, 6});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  // Initialized so the checks below never read indeterminate values, even if
  // the helper returns without writing them.
  int dim0 = 0;
  int dim1 = 0;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  // Stop here on failure: dim0/dim1 are only meaningful when applicable.
  ASSERT_TRUE(applicable);
  EXPECT_EQ(dim0, 20);
  EXPECT_EQ(dim1, 6);
}
408
// Reversing all three axes ({2, 1, 0}) is expected to be rejected as a 2-D
// transpose candidate. The output dims are never inspected here.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DNotApplicable) {
  RuntimeShape input_shape({4, 5, 6});

  // Fill the permutation from a table instead of element-by-element.
  constexpr int kReversed[] = {2, 1, 0};
  TransposeParams params;
  params.perm_count = 3;
  for (int axis = 0; axis < params.perm_count; ++axis) {
    params.perm[axis] = kReversed[axis];
  }

  int dim0 = 0;
  int dim1 = 0;
  EXPECT_FALSE(transpose_utils::IsTranspose2DApplicable(params, input_shape,
                                                        &dim0, &dim1));
}
424
// perm {1, 2, 3, 0} moves axis 0 to the end. This is expected to be reported
// as a 2-D transpose of a 4 x (5 * 6 * 7) = 4 x 210 matrix.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DOne) {
  RuntimeShape input_shape({4, 5, 6, 7});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 2;
  params.perm[2] = 3;
  params.perm[3] = 0;

  // Initialized so the checks below never read indeterminate values, even if
  // the helper returns without writing them.
  int dim0 = 0;
  int dim1 = 0;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  // Stop here on failure: dim0/dim1 are only meaningful when applicable.
  ASSERT_TRUE(applicable);
  EXPECT_EQ(dim0, 4);
  EXPECT_EQ(dim1, 210);
}
443
// perm {2, 3, 0, 1} swaps the leading pair of axes with the trailing pair.
// This is expected to be reported as a 2-D transpose of a
// (4 * 5) x (6 * 7) = 20 x 42 matrix.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DTwo) {
  RuntimeShape input_shape({4, 5, 6, 7});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 2;
  params.perm[1] = 3;
  params.perm[2] = 0;
  params.perm[3] = 1;

  // Initialized so the checks below never read indeterminate values, even if
  // the helper returns without writing them.
  int dim0 = 0;
  int dim1 = 0;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  // Stop here on failure: dim0/dim1 are only meaningful when applicable.
  ASSERT_TRUE(applicable);
  EXPECT_EQ(dim0, 20);
  EXPECT_EQ(dim1, 42);
}
462
// perm {3, 0, 1, 2} moves the last axis to the front. This is expected to be
// reported as a 2-D transpose of a (4 * 5 * 6) x 7 = 120 x 7 matrix.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DThird) {
  RuntimeShape input_shape({4, 5, 6, 7});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 3;
  params.perm[1] = 0;
  params.perm[2] = 1;
  params.perm[3] = 2;

  // Initialized so the checks below never read indeterminate values, even if
  // the helper returns without writing them.
  int dim0 = 0;
  int dim1 = 0;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  // Stop here on failure: dim0/dim1 are only meaningful when applicable.
  ASSERT_TRUE(applicable);
  EXPECT_EQ(dim0, 120);
  EXPECT_EQ(dim1, 7);
}
481
// Reversing all four axes ({3, 2, 1, 0}) is expected to be rejected as a 2-D
// transpose candidate. The output dims are never inspected here.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DNotApplicable) {
  RuntimeShape input_shape({4, 5, 6, 7});

  // Fill the permutation from a table instead of element-by-element.
  constexpr int kReversed[] = {3, 2, 1, 0};
  TransposeParams params;
  params.perm_count = 4;
  for (int axis = 0; axis < params.perm_count; ++axis) {
    params.perm[axis] = kReversed[axis];
  }

  int dim0 = 0;
  int dim1 = 0;
  EXPECT_FALSE(transpose_utils::IsTranspose2DApplicable(params, input_shape,
                                                        &dim0, &dim1));
}
498
499 } // namespace
500 } // namespace tflite
501
main(int argc,char ** argv)502 int main(int argc, char** argv) {
503 // On Linux, add: tflite::LogToStderr();
504 ::testing::InitGoogleTest(&argc, argv);
505 return RUN_ALL_TESTS();
506 }
507