• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
16 
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 
20 namespace tflite {
21 namespace {
22 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_1DNoChanges)23 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_1DNoChanges) {
24   RuntimeShape input_shape({9});
25   RuntimeShape output_shape({9});
26 
27   TransposeParams params;
28   params.perm_count = 1;
29   params.perm[0] = 0;
30 
31   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
32                                            &params);
33 
34   EXPECT_EQ(input_shape, RuntimeShape({9}));
35   EXPECT_EQ(output_shape, RuntimeShape({9}));
36 
37   EXPECT_EQ(params.perm_count, 1);
38   EXPECT_EQ(params.perm[0], 0);
39 }
40 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_2DNoChanges)41 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DNoChanges) {
42   RuntimeShape input_shape({9, 3});
43   RuntimeShape output_shape({3, 9});
44 
45   TransposeParams params;
46   params.perm_count = 2;
47   params.perm[0] = 1;
48   params.perm[1] = 0;
49 
50   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
51                                            &params);
52 
53   EXPECT_EQ(input_shape, RuntimeShape({9, 3}));
54   EXPECT_EQ(output_shape, RuntimeShape({3, 9}));
55 
56   EXPECT_EQ(params.perm_count, 2);
57   EXPECT_EQ(params.perm[0], 1);
58   EXPECT_EQ(params.perm[1], 0);
59 }
60 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_2DShrinking)61 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DShrinking) {
62   RuntimeShape input_shape({9, 1});
63   RuntimeShape output_shape({1, 9});
64 
65   TransposeParams params;
66   params.perm_count = 2;
67   params.perm[0] = 1;
68   params.perm[1] = 0;
69 
70   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
71                                            &params);
72 
73   EXPECT_EQ(input_shape, RuntimeShape({9}));
74   EXPECT_EQ(output_shape, RuntimeShape({9}));
75 
76   EXPECT_EQ(params.perm_count, 1);
77   EXPECT_EQ(params.perm[0], 0);
78 }
79 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_3DNoChanges)80 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DNoChanges) {
81   RuntimeShape input_shape({4, 3, 8});
82   RuntimeShape output_shape({8, 4, 3});
83 
84   TransposeParams params;
85   params.perm_count = 3;
86   params.perm[0] = 2;
87   params.perm[1] = 0;
88   params.perm[2] = 1;
89 
90   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
91                                            &params);
92 
93   EXPECT_EQ(input_shape, RuntimeShape({4, 3, 8}));
94   EXPECT_EQ(output_shape, RuntimeShape({8, 4, 3}));
95 
96   EXPECT_EQ(params.perm_count, 3);
97   EXPECT_EQ(params.perm[0], 2);
98   EXPECT_EQ(params.perm[1], 0);
99   EXPECT_EQ(params.perm[2], 1);
100 }
101 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_3DShrinkingOnce)102 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingOnce) {
103   RuntimeShape input_shape({4, 1, 8});
104   RuntimeShape output_shape({8, 4, 1});
105 
106   TransposeParams params;
107   params.perm_count = 3;
108   params.perm[0] = 2;
109   params.perm[1] = 0;
110   params.perm[2] = 1;
111 
112   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
113                                            &params);
114 
115   EXPECT_EQ(input_shape, RuntimeShape({4, 8}));
116   EXPECT_EQ(output_shape, RuntimeShape({8, 4}));
117   EXPECT_EQ(output_shape.Dims(1), 4);
118 
119   EXPECT_EQ(params.perm_count, 2);
120   EXPECT_EQ(params.perm[0], 1);
121   EXPECT_EQ(params.perm[1], 0);
122 }
123 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_3DShrinkingTwice)124 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingTwice) {
125   RuntimeShape input_shape({4, 1, 1});
126   RuntimeShape output_shape({1, 4, 1});
127 
128   TransposeParams params;
129   params.perm_count = 3;
130   params.perm[0] = 2;
131   params.perm[1] = 0;
132   params.perm[2] = 1;
133 
134   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
135                                            &params);
136 
137   EXPECT_EQ(input_shape, RuntimeShape({4}));
138   EXPECT_EQ(output_shape, RuntimeShape({4}));
139 
140   EXPECT_EQ(params.perm_count, 1);
141   EXPECT_EQ(params.perm[0], 0);
142 }
143 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_3DAllOnes)144 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DAllOnes) {
145   RuntimeShape input_shape({1, 1, 1});
146   RuntimeShape output_shape({1, 1, 1});
147 
148   TransposeParams params;
149   params.perm_count = 3;
150   params.perm[0] = 2;
151   params.perm[1] = 0;
152   params.perm[2] = 1;
153 
154   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
155                                            &params);
156 
157   EXPECT_EQ(input_shape, RuntimeShape({1}));
158   EXPECT_EQ(output_shape, RuntimeShape({1}));
159 
160   EXPECT_EQ(params.perm_count, 1);
161   EXPECT_EQ(params.perm[0], 0);
162 }
163 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_4DNoChanges)164 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DNoChanges) {
165   RuntimeShape input_shape({9, 3, 2, 4});
166   RuntimeShape output_shape({3, 9, 4, 2});
167 
168   TransposeParams params;
169   params.perm_count = 4;
170   params.perm[0] = 1;
171   params.perm[1] = 0;
172   params.perm[2] = 3;
173   params.perm[3] = 2;
174 
175   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
176                                            &params);
177 
178   EXPECT_EQ(input_shape, RuntimeShape({9, 3, 2, 4}));
179   EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4, 2}));
180 
181   EXPECT_EQ(params.perm_count, 4);
182   EXPECT_EQ(params.perm[0], 1);
183   EXPECT_EQ(params.perm[1], 0);
184   EXPECT_EQ(params.perm[2], 3);
185   EXPECT_EQ(params.perm[3], 2);
186 }
187 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_4DShrinkingOnce)188 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingOnce) {
189   RuntimeShape input_shape({9, 3, 1, 4});
190   RuntimeShape output_shape({3, 9, 4, 1});
191 
192   TransposeParams params;
193   params.perm_count = 4;
194   params.perm[0] = 1;
195   params.perm[1] = 0;
196   params.perm[2] = 3;
197   params.perm[3] = 2;
198 
199   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
200                                            &params);
201 
202   EXPECT_EQ(input_shape, RuntimeShape({9, 3, 4}));
203   EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4}));
204 
205   EXPECT_EQ(params.perm_count, 3);
206   EXPECT_EQ(params.perm[0], 1);
207   EXPECT_EQ(params.perm[1], 0);
208   EXPECT_EQ(params.perm[2], 2);
209 }
210 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_4DShrinkingTwice)211 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingTwice) {
212   RuntimeShape input_shape({1, 3, 1, 4});
213   RuntimeShape output_shape({3, 1, 4, 1});
214 
215   TransposeParams params;
216   params.perm_count = 4;
217   params.perm[0] = 1;
218   params.perm[1] = 2;
219   params.perm[2] = 3;
220   params.perm[3] = 0;
221 
222   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
223                                            &params);
224 
225   EXPECT_EQ(input_shape, RuntimeShape({3, 4}));
226   EXPECT_EQ(output_shape, RuntimeShape({3, 4}));
227 
228   EXPECT_EQ(params.perm_count, 2);
229   EXPECT_EQ(params.perm[0], 0);
230   EXPECT_EQ(params.perm[1], 1);
231 }
232 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_4DShrinkingThirdTimes)233 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingThirdTimes) {
234   RuntimeShape input_shape({1, 1, 7, 1});
235   RuntimeShape output_shape({1, 7, 1, 1});
236 
237   TransposeParams params;
238   params.perm_count = 4;
239   params.perm[0] = 0;
240   params.perm[1] = 2;
241   params.perm[2] = 1;
242   params.perm[3] = 3;
243 
244   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
245                                            &params);
246 
247   EXPECT_EQ(input_shape, RuntimeShape({7}));
248   EXPECT_EQ(output_shape, RuntimeShape({7}));
249 
250   EXPECT_EQ(params.perm_count, 1);
251   EXPECT_EQ(params.perm[0], 0);
252 }
253 
TEST(TransposeUtilsTest,RemoveOneSizeDimensions_4DAllOnes)254 TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DAllOnes) {
255   RuntimeShape input_shape({1, 1, 1, 1});
256   RuntimeShape output_shape({1, 1, 1, 1});
257 
258   TransposeParams params;
259   params.perm_count = 4;
260   params.perm[0] = 0;
261   params.perm[1] = 2;
262   params.perm[2] = 1;
263   params.perm[3] = 3;
264 
265   transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
266                                            &params);
267 
268   EXPECT_EQ(input_shape, RuntimeShape({1}));
269   EXPECT_EQ(output_shape, RuntimeShape({1}));
270 
271   EXPECT_EQ(params.perm_count, 1);
272   EXPECT_EQ(params.perm[0], 0);
273 }
274 
TEST(TransposeUtilsTest,Flatten3D)275 TEST(TransposeUtilsTest, Flatten3D) {
276   RuntimeShape input_shape({3, 5, 7});
277   RuntimeShape output_shape({3, 7, 5});
278 
279   TransposeParams params;
280   params.perm_count = 3;
281   params.perm[0] = 0;
282   params.perm[1] = 2;
283   params.perm[2] = 1;
284 
285   RuntimeShape non_flatten_input_shape;
286   RuntimeShape non_flatten_output_shape;
287   TransposeParams non_flatten_params;
288   size_t non_flatten_size = transpose_utils::Flatten(
289       input_shape, output_shape, params, &non_flatten_input_shape,
290       &non_flatten_output_shape, &non_flatten_params);
291 
292   EXPECT_EQ(non_flatten_input_shape, RuntimeShape({5, 7}));
293   EXPECT_EQ(non_flatten_output_shape, RuntimeShape({7, 5}));
294   EXPECT_EQ(non_flatten_size, 5 * 7);
295 
296   EXPECT_EQ(non_flatten_params.perm_count, 2);
297   EXPECT_EQ(non_flatten_params.perm[0], 1);
298   EXPECT_EQ(non_flatten_params.perm[1], 0);
299 }
300 
TEST(TransposeUtilsTest,Flatten4DFlattenOnce)301 TEST(TransposeUtilsTest, Flatten4DFlattenOnce) {
302   RuntimeShape input_shape({3, 5, 7, 9});
303   RuntimeShape output_shape({3, 7, 5, 9});
304 
305   TransposeParams params;
306   params.perm_count = 4;
307   params.perm[0] = 0;
308   params.perm[1] = 2;
309   params.perm[2] = 1;
310   params.perm[3] = 3;
311 
312   RuntimeShape non_flatten_input_shape;
313   RuntimeShape non_flatten_output_shape;
314   TransposeParams non_flatten_params;
315   size_t non_flatten_size = transpose_utils::Flatten(
316       input_shape, output_shape, params, &non_flatten_input_shape,
317       &non_flatten_output_shape, &non_flatten_params);
318 
319   EXPECT_EQ(non_flatten_input_shape, RuntimeShape({5, 7, 9}));
320   EXPECT_EQ(non_flatten_output_shape, RuntimeShape({7, 5, 9}));
321   EXPECT_EQ(non_flatten_size, 5 * 7 * 9);
322 
323   EXPECT_EQ(non_flatten_params.perm_count, 3);
324   EXPECT_EQ(non_flatten_params.perm[0], 1);
325   EXPECT_EQ(non_flatten_params.perm[1], 0);
326   EXPECT_EQ(non_flatten_params.perm[2], 2);
327 }
328 
TEST(TransposeUtilsTest,Flatten4DFlattenTwice)329 TEST(TransposeUtilsTest, Flatten4DFlattenTwice) {
330   RuntimeShape input_shape({3, 5, 7, 9});
331   RuntimeShape output_shape({3, 5, 9, 7});
332 
333   TransposeParams params;
334   params.perm_count = 4;
335   params.perm[0] = 0;
336   params.perm[1] = 1;
337   params.perm[2] = 3;
338   params.perm[3] = 2;
339 
340   RuntimeShape non_flatten_input_shape;
341   RuntimeShape non_flatten_output_shape;
342   TransposeParams non_flatten_params;
343   size_t non_flatten_size = transpose_utils::Flatten(
344       input_shape, output_shape, params, &non_flatten_input_shape,
345       &non_flatten_output_shape, &non_flatten_params);
346 
347   EXPECT_EQ(non_flatten_input_shape, RuntimeShape({7, 9}));
348   EXPECT_EQ(non_flatten_output_shape, RuntimeShape({9, 7}));
349   EXPECT_EQ(non_flatten_size, 7 * 9);
350 
351   EXPECT_EQ(non_flatten_params.perm_count, 2);
352   EXPECT_EQ(non_flatten_params.perm[0], 1);
353   EXPECT_EQ(non_flatten_params.perm[1], 0);
354 }
355 
TEST(TransposeUtilsTest,IsTranspose2DApplicable2D)356 TEST(TransposeUtilsTest, IsTranspose2DApplicable2D) {
357   RuntimeShape input_shape({4, 5});
358 
359   TransposeParams params;
360   params.perm_count = 2;
361   params.perm[0] = 1;
362   params.perm[1] = 0;
363 
364   int dim0, dim1;
365   bool applicable = transpose_utils::IsTranspose2DApplicable(
366       params, input_shape, &dim0, &dim1);
367 
368   EXPECT_TRUE(applicable);
369   EXPECT_EQ(dim0, 4);
370   EXPECT_EQ(dim1, 5);
371 }
372 
TEST(TransposeUtilsTest,IsTranspose2DApplicable3DOne)373 TEST(TransposeUtilsTest, IsTranspose2DApplicable3DOne) {
374   RuntimeShape input_shape({4, 5, 6});
375 
376   TransposeParams params;
377   params.perm_count = 3;
378   params.perm[0] = 1;
379   params.perm[1] = 2;
380   params.perm[2] = 0;
381 
382   int dim0, dim1;
383   bool applicable = transpose_utils::IsTranspose2DApplicable(
384       params, input_shape, &dim0, &dim1);
385 
386   EXPECT_TRUE(applicable);
387   EXPECT_EQ(dim0, 4);
388   EXPECT_EQ(dim1, 30);
389 }
390 
TEST(TransposeUtilsTest,IsTranspose2DApplicable3DTwo)391 TEST(TransposeUtilsTest, IsTranspose2DApplicable3DTwo) {
392   RuntimeShape input_shape({4, 5, 6});
393 
394   TransposeParams params;
395   params.perm_count = 3;
396   params.perm[0] = 2;
397   params.perm[1] = 0;
398   params.perm[2] = 1;
399 
400   int dim0, dim1;
401   bool applicable = transpose_utils::IsTranspose2DApplicable(
402       params, input_shape, &dim0, &dim1);
403 
404   EXPECT_TRUE(applicable);
405   EXPECT_EQ(dim0, 20);
406   EXPECT_EQ(dim1, 6);
407 }
408 
TEST(TransposeUtilsTest,IsTranspose2DApplicable3DNotApplicable)409 TEST(TransposeUtilsTest, IsTranspose2DApplicable3DNotApplicable) {
410   RuntimeShape input_shape({4, 5, 6});
411 
412   TransposeParams params;
413   params.perm_count = 3;
414   params.perm[0] = 2;
415   params.perm[1] = 1;
416   params.perm[2] = 0;
417 
418   int dim0, dim1;
419   bool applicable = transpose_utils::IsTranspose2DApplicable(
420       params, input_shape, &dim0, &dim1);
421 
422   EXPECT_FALSE(applicable);
423 }
424 
TEST(TransposeUtilsTest,IsTranspose2DApplicable4DOne)425 TEST(TransposeUtilsTest, IsTranspose2DApplicable4DOne) {
426   RuntimeShape input_shape({4, 5, 6, 7});
427 
428   TransposeParams params;
429   params.perm_count = 4;
430   params.perm[0] = 1;
431   params.perm[1] = 2;
432   params.perm[2] = 3;
433   params.perm[3] = 0;
434 
435   int dim0, dim1;
436   bool applicable = transpose_utils::IsTranspose2DApplicable(
437       params, input_shape, &dim0, &dim1);
438 
439   EXPECT_TRUE(applicable);
440   EXPECT_EQ(dim0, 4);
441   EXPECT_EQ(dim1, 210);
442 }
443 
TEST(TransposeUtilsTest,IsTranspose2DApplicable4DTwo)444 TEST(TransposeUtilsTest, IsTranspose2DApplicable4DTwo) {
445   RuntimeShape input_shape({4, 5, 6, 7});
446 
447   TransposeParams params;
448   params.perm_count = 4;
449   params.perm[0] = 2;
450   params.perm[1] = 3;
451   params.perm[2] = 0;
452   params.perm[3] = 1;
453 
454   int dim0, dim1;
455   bool applicable = transpose_utils::IsTranspose2DApplicable(
456       params, input_shape, &dim0, &dim1);
457 
458   EXPECT_TRUE(applicable);
459   EXPECT_EQ(dim0, 20);
460   EXPECT_EQ(dim1, 42);
461 }
462 
TEST(TransposeUtilsTest,IsTranspose2DApplicable4DThird)463 TEST(TransposeUtilsTest, IsTranspose2DApplicable4DThird) {
464   RuntimeShape input_shape({4, 5, 6, 7});
465 
466   TransposeParams params;
467   params.perm_count = 4;
468   params.perm[0] = 3;
469   params.perm[1] = 0;
470   params.perm[2] = 1;
471   params.perm[3] = 2;
472 
473   int dim0, dim1;
474   bool applicable = transpose_utils::IsTranspose2DApplicable(
475       params, input_shape, &dim0, &dim1);
476 
477   EXPECT_TRUE(applicable);
478   EXPECT_EQ(dim0, 120);
479   EXPECT_EQ(dim1, 7);
480 }
481 
TEST(TransposeUtilsTest,IsTranspose2DApplicable4DNotApplicable)482 TEST(TransposeUtilsTest, IsTranspose2DApplicable4DNotApplicable) {
483   RuntimeShape input_shape({4, 5, 6, 7});
484 
485   TransposeParams params;
486   params.perm_count = 4;
487   params.perm[0] = 3;
488   params.perm[1] = 2;
489   params.perm[2] = 1;
490   params.perm[3] = 0;
491 
492   int dim0, dim1;
493   bool applicable = transpose_utils::IsTranspose2DApplicable(
494       params, input_shape, &dim0, &dim1);
495 
496   EXPECT_FALSE(applicable);
497 }
498 
499 }  // namespace
500 }  // namespace tflite
501 
main(int argc,char ** argv)502 int main(int argc, char** argv) {
503   // On Linux, add: tflite::LogToStderr();
504   ::testing::InitGoogleTest(&argc, argv);
505   return RUN_ALL_TESTS();
506 }
507