• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct ConcatOfConcatsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
ConcatOfConcatsFixtureConcatOfConcatsFixture14     explicit ConcatOfConcatsFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
15                                     const armnn::TensorShape& inputShape2, const armnn::TensorShape& inputShape3,
16                                     unsigned int concatDim)
17     {
18         m_Prototext = R"(
19             node {
20               name: "graphInput0"
21               op: "Placeholder"
22               attr {
23                 key: "dtype"
24                 value {
25                   type: DT_FLOAT
26                 }
27               }
28               attr {
29                 key: "shape"
30                 value {
31                   shape {
32                   }
33                 }
34               }
35             }
36             node {
37               name: "graphInput1"
38               op: "Placeholder"
39               attr {
40                 key: "dtype"
41                 value {
42                   type: DT_FLOAT
43                 }
44               }
45               attr {
46                 key: "shape"
47                 value {
48                   shape {
49                   }
50                 }
51               }
52             }
53             node {
54               name: "graphInput2"
55               op: "Placeholder"
56               attr {
57                 key: "dtype"
58                 value {
59                   type: DT_FLOAT
60                 }
61               }
62               attr {
63                 key: "shape"
64                 value {
65                   shape {
66                   }
67                 }
68               }
69             }
70             node {
71               name: "graphInput3"
72               op: "Placeholder"
73               attr {
74                 key: "dtype"
75                 value {
76                   type: DT_FLOAT
77                 }
78               }
79               attr {
80                 key: "shape"
81                 value {
82                   shape {
83                   }
84                 }
85               }
86             }
87             node {
88               name: "Relu"
89               op: "Relu"
90               input: "graphInput0"
91               attr {
92                 key: "T"
93                 value {
94                   type: DT_FLOAT
95                 }
96               }
97             }
98             node {
99               name: "Relu_1"
100               op: "Relu"
101               input: "graphInput1"
102               attr {
103                 key: "T"
104                 value {
105                   type: DT_FLOAT
106                 }
107               }
108             }
109             node {
110               name: "Relu_2"
111               op: "Relu"
112               input: "graphInput2"
113               attr {
114                 key: "T"
115                 value {
116                   type: DT_FLOAT
117                 }
118               }
119             }
120             node {
121               name: "Relu_3"
122               op: "Relu"
123               input: "graphInput3"
124               attr {
125                 key: "T"
126                 value {
127                   type: DT_FLOAT
128                 }
129               }
130             }
131             node {
132               name: "concat/axis"
133               op: "Const"
134               attr {
135                 key: "dtype"
136                 value {
137                   type: DT_INT32
138                 }
139               }
140               attr {
141                 key: "value"
142                 value {
143                   tensor {
144                     dtype: DT_INT32
145                     tensor_shape {
146                     }
147                     int_val: )";
148                 m_Prototext += std::to_string(concatDim);
149                 m_Prototext += R"(
150                   }
151                 }
152               }
153             }
154             node {
155               name: "concat"
156               op: "ConcatV2"
157               input: "Relu"
158               input: "Relu_1"
159               input: "concat/axis"
160               attr {
161                 key: "N"
162                 value {
163                   i: 2
164                 }
165               }
166               attr {
167                 key: "T"
168                 value {
169                   type: DT_FLOAT
170                 }
171               }
172               attr {
173                 key: "Tidx"
174                 value {
175                   type: DT_INT32
176                 }
177               }
178             }
179             node {
180               name: "concat_1/axis"
181               op: "Const"
182               attr {
183                 key: "dtype"
184                 value {
185                   type: DT_INT32
186                 }
187               }
188               attr {
189                 key: "value"
190                 value {
191                   tensor {
192                     dtype: DT_INT32
193                     tensor_shape {
194                     }
195                     int_val: )";
196                 m_Prototext += std::to_string(concatDim);
197                 m_Prototext += R"(
198                   }
199                 }
200               }
201             }
202             node {
203               name: "concat_1"
204               op: "ConcatV2"
205               input: "Relu_2"
206               input: "Relu_3"
207               input: "concat_1/axis"
208               attr {
209                 key: "N"
210                 value {
211                   i: 2
212                 }
213               }
214               attr {
215                 key: "T"
216                 value {
217                   type: DT_FLOAT
218                 }
219               }
220               attr {
221                 key: "Tidx"
222                 value {
223                   type: DT_INT32
224                 }
225               }
226             }
227             node {
228               name: "concat_2/axis"
229               op: "Const"
230               attr {
231                 key: "dtype"
232                 value {
233                   type: DT_INT32
234                 }
235               }
236               attr {
237                 key: "value"
238                 value {
239                   tensor {
240                     dtype: DT_INT32
241                     tensor_shape {
242                     }
243                     int_val: )";
244                 m_Prototext += std::to_string(concatDim);
245                 m_Prototext += R"(
246                   }
247                 }
248               }
249             }
250             node {
251               name: "concat_2"
252               op: "ConcatV2"
253               input: "concat"
254               input: "concat_1"
255               input: "concat_2/axis"
256               attr {
257                 key: "N"
258                 value {
259                   i: 2
260                 }
261               }
262               attr {
263                 key: "T"
264                 value {
265                   type: DT_FLOAT
266                 }
267               }
268               attr {
269                 key: "Tidx"
270                 value {
271                   type: DT_INT32
272                 }
273               }
274             }
275             )";
276 
277         Setup({{ "graphInput0", inputShape0 },
278                { "graphInput1", inputShape1 },
279                { "graphInput2", inputShape2 },
280                { "graphInput3", inputShape3}}, {"concat_2"});
281     }
282 };
283 
284 struct ConcatOfConcatsFixtureNCHW : ConcatOfConcatsFixture
285 {
ConcatOfConcatsFixtureNCHWConcatOfConcatsFixtureNCHW286     ConcatOfConcatsFixtureNCHW() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
287                                                           { 1, 1, 2, 2 }, 1 ) {}
288 };
289 
290 struct ConcatOfConcatsFixtureNHWC : ConcatOfConcatsFixture
291 {
ConcatOfConcatsFixtureNHWCConcatOfConcatsFixtureNHWC292     ConcatOfConcatsFixtureNHWC() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
293                                                           { 1, 1, 2, 2 }, 3 ) {}
294 };
295 
BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW,ConcatOfConcatsFixtureNCHW)296 BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW, ConcatOfConcatsFixtureNCHW)
297 {
298     RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
299                 {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
300                 {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
301                 {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
302                {{"concat_2", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
303                                      8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0 }}});
304 }
305 
BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNHWC,ConcatOfConcatsFixtureNHWC)306 BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNHWC, ConcatOfConcatsFixtureNHWC)
307 {
308     RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
309                 {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
310                 {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
311                 {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
312                {{"concat_2", { 0.0, 1.0, 4.0, 5.0, 8.0, 9.0, 12.0, 13.0,
313                                      2.0, 3.0, 6.0, 7.0, 10.0, 11.0, 14.0, 15.0 }}});
314 }
315 
316 BOOST_AUTO_TEST_SUITE_END()
317