• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5 
6 #include <xnnpack.h>
7 
8 #include "subgraph-tester.h"
9 #include <gtest/gtest.h>
10 
// A lone dense convolution (value 0 -> value 3) must not be moved to NCHW:
// after optimize() + rewrite(), both its input and its output are expected to
// still report the NHWC layout.
TEST(SUBGRAPH_NCHW, single_conv) {
  SubgraphTester tester(4);

  // Values: graph input, dense filter, dense bias, convolution output.
  // Spatial size goes 256 -> 128 and channels 3 -> 32 per the shapes below.
  tester.add_tensor({1, 256, 256, 3}, kDynamic, 0)
      .add_tensor({32, 3, 3, 3}, kStaticDense, 1)
      .add_tensor({32}, kStaticDense, 2)
      .add_tensor({1, 128, 128, 32}, kDynamic, 3);

  // The single operator: convolution 0 -> 3 using filter 1 and bias 2.
  // NOTE(review): positional argument meaning lives in subgraph-tester.h;
  // "3, 32" appear to be the group input/output channel counts.
  tester.add_conv(1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 1, 3, 32, 0, 1, 2, 3);

  // Optimize first, then run the layout rewrite being tested.
  tester.optimize();
  tester.rewrite();

  ASSERT_EQ(tester.get_layout(0), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.get_layout(3), xnn_layout_type_nhwc);
}
25 
// Dense convolution (0 -> 3) followed by global average pooling (3 -> 4).
// The test expects the graph input, the convolution output and the pooled
// output all to keep the NHWC layout after optimize() + rewrite().
TEST(SUBGRAPH_NCHW, single_conv_and_global_average_pooling) {
  SubgraphTester tester(5);

  // Values: graph input, dense filter, dense bias, convolution output and
  // pooled output.
  // NOTE(review): the pooled output is declared as {32}, unlike the {1, C}
  // shape used for pooled outputs elsewhere in this file — confirm intended.
  tester.add_tensor({1, 256, 256, 3}, kDynamic, 0)
      .add_tensor({32, 3, 3, 3}, kStaticDense, 1)
      .add_tensor({32}, kStaticDense, 2)
      .add_tensor({1, 128, 128, 32}, kDynamic, 3)
      .add_tensor({32}, kDynamic, 4);

  // Operators, in graph order: convolution 0 -> 3, then pooling 3 -> 4.
  tester.add_conv(1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 1, 3, 32, 0, 1, 2, 3)
      .add_global_average_pooling(3, 4);

  // Optimize first, then run the layout rewrite being tested.
  tester.optimize();
  tester.rewrite();

  ASSERT_EQ(tester.get_layout(0), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.get_layout(3), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.get_layout(4), xnn_layout_type_nhwc);
}
43 
// A dense 3x3 convolution (0 -> 3) feeding a 1x1 convolution with a sparse
// filter (3 -> 6), closed by global average pooling (6 -> 7). After
// optimize() + rewrite(), the two convolution outputs are expected in NCHW
// while the graph input and the pooled output stay NHWC.
TEST(SUBGRAPH_NCHW, pixelwise_conv_sandwich) {
  SubgraphTester tester(8);

  // First convolution: graph input, dense filter, dense bias, output
  // (channels 3 -> 8, spatial 256 -> 128 per the shapes).
  tester.add_tensor({1, 256, 256, 3}, kDynamic, 0)
      .add_tensor({8, 3, 3, 3}, kStaticDense, 1)
      .add_tensor({8}, kStaticDense, 2)
      .add_tensor({1, 128, 128, 8}, kDynamic, 3);

  // 1x1 convolution: sparse filter, dense bias, output (channels 8 -> 4).
  tester.add_tensor({4, 1, 1, 8}, kStaticSparse, 4)
      .add_tensor({4}, kStaticDense, 5)
      .add_tensor({1, 128, 128, 4}, kDynamic, 6);

  // Pooled output.
  tester.add_tensor({1, 4}, kDynamic, 7);

  // Operators, in graph order.
  tester.add_conv(1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 1, 3, 8, 0, 1, 2, 3)
      .add_conv(0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 8, 4, 3, 4, 5, 6)
      .add_global_average_pooling(6, 7);

  // Optimize first, then run the layout rewrite being tested.
  tester.optimize();
  tester.rewrite();

  ASSERT_EQ(tester.get_layout(0), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.get_layout(3), xnn_layout_type_nchw);
  ASSERT_EQ(tester.get_layout(6), xnn_layout_type_nchw);
  ASSERT_EQ(tester.get_layout(7), xnn_layout_type_nhwc);
}
66 
// Bottleneck-style graph:
//   dense 3x3 conv          0 -> 3   (3 -> 8 channels)
//   1x1 conv, sparse filter 3 -> 6   (8 -> 4 channels)
//   3x3 depthwise conv      6 -> 9   (4 channels)
//   1x1 conv, sparse filter 9 -> 12  (4 -> 8 channels)
//   addition                3 + 12 -> 13
//   global average pooling  13 -> 14
// After optimize() + rewrite(), the graph input (0) and pooled output (14) are
// expected to stay NHWC while every intermediate activation is NCHW.
TEST(SUBGRAPH_NCHW, bottleneck) {
  // Value IDs 0..14 below, each declared exactly once.
  auto tester = SubgraphTester(15);
  tester
    .add_tensor({1, 256, 256, 3}, kDynamic, 0)
    .add_tensor({8, 3, 3, 3}, kStaticDense, 1)
    .add_tensor({8}, kStaticDense, 2)
    .add_tensor({1, 128, 128, 8}, kDynamic, 3)
    .add_tensor({4, 1, 1, 8}, kStaticSparse, 4)
    .add_tensor({4}, kStaticDense, 5)
    .add_tensor({1, 128, 128, 4}, kDynamic, 6)
    .add_tensor({1, 3, 3, 4}, kStaticDense, 7)
    .add_tensor({4}, kStaticDense, 8)
    .add_tensor({1, 128, 128, 4}, kDynamic, 9)
    .add_tensor({8, 1, 1, 4}, kStaticSparse, 10)
    .add_tensor({8}, kStaticDense, 11)
    .add_tensor({1, 128, 128, 8}, kDynamic, 12)
    // Residual sum; must be declared only once.
    .add_tensor({1, 128, 128, 8}, kDynamic, 13)
    .add_tensor({1, 8}, kDynamic, 14)
    .add_conv(1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 1, 3, 8, 0, 1, 2, 3)
    .add_conv(0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 8, 4, 3, 4, 5, 6)
    .add_depthwise_conv(1, 1, 1, 1, 3, 3, 1, 1, 1, 1, 1, 4, 6, 7, 8, 9)
    // Expansion conv: input 9 has 4 channels, filter 10 is {8, 1, 1, 4} and
    // output 12 has 8 channels, so the (input, output) channel pair is (4, 8),
    // in the same order as the other add_conv calls above.
    .add_conv(0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 4, 8, 9, 10, 11, 12)
    .add_addition(3, 12, 13)
    .add_global_average_pooling(13, 14)
    .optimize()
    .rewrite();

  ASSERT_EQ(tester.get_layout(0), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.get_layout(3), xnn_layout_type_nchw);
  ASSERT_EQ(tester.get_layout(6), xnn_layout_type_nchw);
  ASSERT_EQ(tester.get_layout(9), xnn_layout_type_nchw);
  ASSERT_EQ(tester.get_layout(12), xnn_layout_type_nchw);
  ASSERT_EQ(tester.get_layout(13), xnn_layout_type_nchw);
  ASSERT_EQ(tester.get_layout(14), xnn_layout_type_nhwc);
}
103