• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct ConcatFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
ConcatFixtureConcatFixture14     explicit ConcatFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
15                            unsigned int concatDim)
16     {
17         m_Prototext = R"(
18         node {
19           name: "graphInput0"
20           op: "Placeholder"
21           attr {
22             key: "dtype"
23             value {
24               type: DT_FLOAT
25             }
26           }
27           attr {
28             key: "shape"
29             value {
30               shape {
31               }
32             }
33           }
34         }
35         node {
36           name: "graphInput1"
37           op: "Placeholder"
38           attr {
39             key: "dtype"
40             value {
41               type: DT_FLOAT
42             }
43           }
44           attr {
45             key: "shape"
46             value {
47               shape {
48               }
49             }
50           }
51         }
52         node {
53           name: "concat/axis"
54           op: "Const"
55           attr {
56             key: "dtype"
57             value {
58               type: DT_INT32
59             }
60           }
61           attr {
62             key: "value"
63             value {
64               tensor {
65                 dtype: DT_INT32
66                 tensor_shape {
67                 }
68                 int_val: )";
69 
70         m_Prototext += std::to_string(concatDim);
71 
72         m_Prototext += R"(
73               }
74             }
75           }
76         }
77         node {
78           name: "concat"
79           op: "ConcatV2"
80           input: "graphInput0"
81           input: "graphInput1"
82           input: "concat/axis"
83           attr {
84             key: "N"
85             value {
86               i: 2
87             }
88           }
89           attr {
90             key: "T"
91             value {
92               type: DT_FLOAT
93             }
94           }
95           attr {
96             key: "Tidx"
97             value {
98               type: DT_FLOAT
99             }
100           }
101         }
102         )";
103 
104         Setup({{"graphInput0", inputShape0 },
105                {"graphInput1", inputShape1 }}, {"concat"});
106     }
107 };
108 
109 struct ConcatFixtureNCHW : ConcatFixture
110 {
ConcatFixtureNCHWConcatFixtureNCHW111     ConcatFixtureNCHW() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 1 ) {}
112 };
113 
114 struct ConcatFixtureNHWC : ConcatFixture
115 {
ConcatFixtureNHWCConcatFixtureNHWC116     ConcatFixtureNHWC() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 3 ) {}
117 };
118 
BOOST_FIXTURE_TEST_CASE(ParseConcatNCHW, ConcatFixtureNCHW)
{
    // The fixture joins two 1x1x2x2 inputs along dimension 1, so the expected output
    // is graphInput0's values followed by graphInput1's values.
    RunTest<4>(
        {
            { "graphInput0", { 0.0, 1.0, 2.0, 3.0 } },
            { "graphInput1", { 4.0, 5.0, 6.0, 7.0 } }
        },
        {
            { "concat", { 0.0, 1.0, 2.0, 3.0,
                          4.0, 5.0, 6.0, 7.0 } }
        });
}
125 
BOOST_FIXTURE_TEST_CASE(ParseConcatNHWC, ConcatFixtureNHWC)
{
    // The fixture joins two 1x1x2x2 inputs along dimension 3, so each expected output
    // row is a pair of values from graphInput0 followed by the matching pair from
    // graphInput1.
    RunTest<4>(
        {
            { "graphInput0", { 0.0, 1.0, 2.0, 3.0 } },
            { "graphInput1", { 4.0, 5.0, 6.0, 7.0 } }
        },
        {
            { "concat", { 0.0, 1.0, 4.0, 5.0,
                          2.0, 3.0, 6.0, 7.0 } }
        });
}
132 
133 struct ConcatFixtureDim1 : ConcatFixture
134 {
ConcatFixtureDim1ConcatFixtureDim1135     ConcatFixtureDim1() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 1) {}
136 };
137 
138 struct ConcatFixtureDim3 : ConcatFixture
139 {
ConcatFixtureDim3ConcatFixtureDim3140     ConcatFixtureDim3() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 3) {}
141 };
142 
BOOST_FIXTURE_TEST_CASE(ParseConcatDim1, ConcatFixtureDim1)
{
    // Two 1x2x3x4 inputs joined along dimension 1: the expected output holds all 24
    // values of graphInput0 followed by all 24 values of graphInput1.
    RunTest<4>(
        {
            { "graphInput0", {  0.0,  1.0,  2.0,  3.0,
                                4.0,  5.0,  6.0,  7.0,
                                8.0,  9.0, 10.0, 11.0,
                               12.0, 13.0, 14.0, 15.0,
                               16.0, 17.0, 18.0, 19.0,
                               20.0, 21.0, 22.0, 23.0 } },
            { "graphInput1", { 50.0, 51.0, 52.0, 53.0,
                               54.0, 55.0, 56.0, 57.0,
                               58.0, 59.0, 60.0, 61.0,
                               62.0, 63.0, 64.0, 65.0,
                               66.0, 67.0, 68.0, 69.0,
                               70.0, 71.0, 72.0, 73.0 } }
        },
        {
            { "concat",      {  0.0,  1.0,  2.0,  3.0,
                                4.0,  5.0,  6.0,  7.0,
                                8.0,  9.0, 10.0, 11.0,
                               12.0, 13.0, 14.0, 15.0,
                               16.0, 17.0, 18.0, 19.0,
                               20.0, 21.0, 22.0, 23.0,
                               50.0, 51.0, 52.0, 53.0,
                               54.0, 55.0, 56.0, 57.0,
                               58.0, 59.0, 60.0, 61.0,
                               62.0, 63.0, 64.0, 65.0,
                               66.0, 67.0, 68.0, 69.0,
                               70.0, 71.0, 72.0, 73.0 } }
        });
}
154 
BOOST_FIXTURE_TEST_CASE(ParseConcatDim3, ConcatFixtureDim3)
{
    // Two 1x2x3x4 inputs joined along dimension 3: every expected output row is a
    // group of four values from graphInput0 followed by the corresponding group of
    // four values from graphInput1.
    RunTest<4>(
        {
            { "graphInput0", {  0.0,  1.0,  2.0,  3.0,
                                4.0,  5.0,  6.0,  7.0,
                                8.0,  9.0, 10.0, 11.0,
                               12.0, 13.0, 14.0, 15.0,
                               16.0, 17.0, 18.0, 19.0,
                               20.0, 21.0, 22.0, 23.0 } },
            { "graphInput1", { 50.0, 51.0, 52.0, 53.0,
                               54.0, 55.0, 56.0, 57.0,
                               58.0, 59.0, 60.0, 61.0,
                               62.0, 63.0, 64.0, 65.0,
                               66.0, 67.0, 68.0, 69.0,
                               70.0, 71.0, 72.0, 73.0 } }
        },
        {
            { "concat",      {  0.0,  1.0,  2.0,  3.0,   50.0, 51.0, 52.0, 53.0,
                                4.0,  5.0,  6.0,  7.0,   54.0, 55.0, 56.0, 57.0,
                                8.0,  9.0, 10.0, 11.0,   58.0, 59.0, 60.0, 61.0,
                               12.0, 13.0, 14.0, 15.0,   62.0, 63.0, 64.0, 65.0,
                               16.0, 17.0, 18.0, 19.0,   66.0, 67.0, 68.0, 69.0,
                               20.0, 21.0, 22.0, 23.0,   70.0, 71.0, 72.0, 73.0 } }
        });
}
182 
183 BOOST_AUTO_TEST_SUITE_END()