1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <array>
#include <iostream>
#include <map>
#include <string>
13
14 BOOST_AUTO_TEST_SUITE(TensorflowParser)
15
16 struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
17 {
Convolution2dFixtureConvolution2dFixture18 explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType)
19 : Convolution2dFixture(dataLayout, paddingType, 1)
20 {}
21
22 // Dilation: 0 - dilations attribute is not included;
23 // Dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
Convolution2dFixtureConvolution2dFixture24 explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType,
25 int stride, int dilation = 0)
26 {
27 std::string strideString (" i: 1 \n"
28 " i: 1 \n");
29 if (dataLayout == "NHWC")
30 {
31 strideString.append(" i: " + std::to_string(stride) + " \n"
32 " i: 1 \n");
33 }
34 else // dataLayout == "NCHW"
35 {
36 strideString.append(" i: 1 \n"
37 " i: " + std::to_string(stride) + " \n");
38 }
39
40 std::string dilationString = std::to_string(dilation);
41 m_Prototext = "node { \n"
42 " name: \"graphInput\" \n"
43 " op: \"Placeholder\" \n"
44 " attr { \n"
45 " key: \"dtype\" \n"
46 " value { \n"
47 " type: DT_FLOAT \n"
48 " } \n"
49 " } \n"
50 " attr { \n"
51 " key: \"shape\" \n"
52 " value { \n"
53 " shape { \n"
54 " } \n"
55 " } \n"
56 " } \n"
57 " } \n"
58 " node { \n"
59 " name: \"Const_1\" \n"
60 " op: \"Const\" \n"
61 " attr { \n"
62 " key: \"dtype\" \n"
63 " value { \n"
64 " type: DT_FLOAT \n"
65 " } \n"
66 " } \n"
67 " attr { \n"
68 " key: \"value\" \n"
69 " value { \n"
70 " tensor { \n"
71 " dtype: DT_FLOAT \n"
72 " tensor_shape { \n"
73 " dim { \n"
74 " size: 1 \n"
75 " } \n"
76 " dim { \n"
77 " size: 3 \n"
78 " } \n"
79 " dim { \n"
80 " size: 1 \n"
81 " } \n"
82 " dim { \n"
83 " size: 1 \n"
84 " } \n"
85 " } \n"
86 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
87 " } \n"
88 " } \n"
89 " } \n"
90 "} \n"
91 "node { \n"
92 " name: \"potato\" \n"
93 " op: \"Conv2D\" \n"
94 " input: \"graphInput\" \n"
95 " input: \"Const_1\" \n"
96 " attr { \n"
97 " key: \"T\" \n"
98 " value { \n"
99 " type: DT_FLOAT \n"
100 " } \n"
101 " } \n"
102 " attr { \n"
103 " key: \"data_format\" \n"
104 " value { \n"
105 " s: \"";
106 m_Prototext.append(dataLayout);
107 m_Prototext.append("\"\n"
108 " } \n"
109 " } \n"
110 " attr { \n"
111 " key: \"padding\" \n"
112 " value { \n"
113 " s: \"");
114 m_Prototext.append(paddingType);
115 m_Prototext.append("\"\n"
116 " } \n"
117 " } \n"
118 " attr { \n"
119 " key: \"strides\" \n"
120 " value { \n"
121 " list { \n");
122 m_Prototext.append(strideString);
123
124 m_Prototext.append(" } \n"
125 " } \n"
126 " } \n");
127
128 if (dilation > 0)
129 {
130 m_Prototext.append(" attr { \n"
131 " key: \"dilations\" \n"
132 " value { \n"
133 " list { \n"
134 " i: 1 \n"
135 " i: ");
136 m_Prototext.append(dilationString);
137 m_Prototext.append(" \n"
138 " i: ");
139 m_Prototext.append(dilationString);
140 m_Prototext.append(" \n"
141 " i: 1 \n"
142 " } \n"
143 " } \n"
144 " } \n");
145 }
146 m_Prototext.append(" attr { \n"
147 " key: \"use_cudnn_on_gpu\" \n"
148 " value { \n"
149 " b: false \n"
150 " } \n"
151 " } \n"
152 "} \n");
153
154 // Manual height computation based on stride parameter.
155 ARMNN_ASSERT_MSG(stride == 1 || stride == 2, "Add support for strides other than 1 or 2.");
156 std::array<unsigned int, 4> dims;
157 if (dataLayout == "NHWC")
158 {
159 dims = { 1u, (stride == 2 ? 3u : 2u), 3u, 1u };
160 }
161 else // dataLayout == "NCHW"
162 {
163 dims = { 1u, 1u, (stride == 2 ? 3u : 2u), 3u };
164 }
165
166 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
167 }
168 };
169
170
171 struct Convolution2dNhwcSameFixture : Convolution2dFixture
172 {
Convolution2dNhwcSameFixtureConvolution2dNhwcSameFixture173 Convolution2dNhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 1){}
174 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame,Convolution2dNhwcSameFixture)175 BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame, Convolution2dNhwcSameFixture)
176 {
177 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
178 }
179
180 struct Convolution2dNchwSameFixture : Convolution2dFixture
181 {
Convolution2dNchwSameFixtureConvolution2dNchwSameFixture182 Convolution2dNchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 1){}
183 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwSame,Convolution2dNchwSameFixture)184 BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwSame, Convolution2dNchwSameFixture)
185 {
186 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
187 }
188
189
190 struct Convolution2dNhwcValidFixture : Convolution2dFixture
191 {
Convolution2dNhwcValidFixtureConvolution2dNhwcValidFixture192 Convolution2dNhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 1){}
193 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcValid,Convolution2dNhwcValidFixture)194 BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcValid, Convolution2dNhwcValidFixture)
195 {
196 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
197 }
198
199 struct Convolution2dNchwValidFixture : Convolution2dFixture
200 {
Convolution2dNchwValidFixtureConvolution2dNchwValidFixture201 Convolution2dNchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 1){}
202 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwValid,Convolution2dNchwValidFixture)203 BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwValid, Convolution2dNchwValidFixture)
204 {
205 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
206 }
207
208
209 struct Convolution2dStride2NhwcSameFixture : Convolution2dFixture
210 {
Convolution2dStride2NhwcSameFixtureConvolution2dStride2NhwcSameFixture211 Convolution2dStride2NhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 2){}
212 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcSame,Convolution2dStride2NhwcSameFixture)213 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcSame, Convolution2dStride2NhwcSameFixture)
214 {
215 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
216 }
217
218 struct Convolution2dStride2NchwSameFixture : Convolution2dFixture
219 {
Convolution2dStride2NchwSameFixtureConvolution2dStride2NchwSameFixture220 Convolution2dStride2NchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 2){}
221 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwSame,Convolution2dStride2NchwSameFixture)222 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwSame, Convolution2dStride2NchwSameFixture)
223 {
224 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
225 }
226
227
228 struct Convolution2dStride2NhwcValidFixture : Convolution2dFixture
229 {
Convolution2dStride2NhwcValidFixtureConvolution2dStride2NhwcValidFixture230 Convolution2dStride2NhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 2){}
231 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcValid,Convolution2dStride2NhwcValidFixture)232 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcValid, Convolution2dStride2NhwcValidFixture)
233 {
234 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
235 }
236
237 struct Convolution2dStride2NchwValidFixture : Convolution2dFixture
238 {
Convolution2dStride2NchwValidFixtureConvolution2dStride2NchwValidFixture239 Convolution2dStride2NchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 2){}
240 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwValid,Convolution2dStride2NchwValidFixture)241 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwValid, Convolution2dStride2NchwValidFixture)
242 {
243 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
244 }
245
246
247 struct Convolution2dDilation1NhwcFixture : Convolution2dFixture
248 {
Convolution2dDilation1NhwcFixtureConvolution2dDilation1NhwcFixture249 Convolution2dDilation1NhwcFixture() : Convolution2dFixture("NHWC", "SAME", 1, 1){}
250 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nhwc,Convolution2dDilation1NhwcFixture)251 BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nhwc, Convolution2dDilation1NhwcFixture)
252 {
253 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
254 }
255
256 struct Convolution2dDilation1NchwFixture : Convolution2dFixture
257 {
Convolution2dDilation1NchwFixtureConvolution2dDilation1NchwFixture258 Convolution2dDilation1NchwFixture() : Convolution2dFixture("NCHW", "SAME", 1, 1){}
259 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nchw,Convolution2dDilation1NchwFixture)260 BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nchw, Convolution2dDilation1NchwFixture)
261 {
262 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
263 }
264
265
BOOST_AUTO_TEST_CASE(ParseConv2dDilation2)
{
    // Placeholder "graphInput" -> Conv2D "potato" (NHWC, SAME, unit strides) with dilations [1,2,2,1].
    // CreateNetworkFromString is expected to reject this graph with an armnn::ParseException.
    // std::map comes from <map>, which is now included explicitly instead of relying on a transitive include.
    const char* const prototext = ""
        "node {\n"
        " name: \"graphInput\"\n"
        " op: \"Placeholder\"\n"
        " attr {\n"
        " key: \"dtype\"\n"
        " value {\n"
        " type: DT_FLOAT\n"
        " }\n"
        " }\n"
        " attr {\n"
        " key: \"shape\"\n"
        " value {\n"
        " shape {\n"
        " }\n"
        " }\n"
        " }\n"
        "}\n"
        "node {\n"
        " name: \"Const_1\"\n"
        " op: \"Const\"\n"
        " attr {\n"
        " key: \"dtype\"\n"
        " value {\n"
        " type: DT_FLOAT\n"
        " }\n"
        " }\n"
        " attr {\n"
        " key: \"value\"\n"
        " value {\n"
        " tensor {\n"
        " dtype: DT_FLOAT\n"
        " tensor_shape {\n"
        " dim {\n"
        " size: 1\n"
        " }\n"
        " dim {\n"
        " size: 3\n"
        " }\n"
        " dim {\n"
        " size: 1\n"
        " }\n"
        " dim {\n"
        " size: 1\n"
        " }\n"
        " }\n"
        " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
        " }\n"
        " }\n"
        " }\n"
        "}\n"
        "node {\n"
        " name: \"potato\"\n"
        " op: \"Conv2D\"\n"
        " input: \"graphInput\"\n"
        " input: \"Const_1\"\n"
        " attr {\n"
        " key: \"T\"\n"
        " value {\n"
        " type: DT_FLOAT\n"
        " }\n"
        " }\n"
        " attr {\n"
        " key: \"data_format\"\n"
        " value {\n"
        " s: \"NHWC\"\n"
        " }\n"
        " }\n"
        " attr {\n"
        " key: \"padding\"\n"
        " value {\n"
        " s: \"SAME\"\n"
        " }\n"
        " }\n"
        " attr {\n"
        " key: \"strides\"\n"
        " value {\n"
        " list {\n"
        " i: 1\n"
        " i: 1\n"
        " i: 1\n"
        " i: 1\n"
        " }\n"
        " }\n"
        " }\n"
        " attr {\n"
        " key: \"dilations\"\n"
        " value {\n"
        " list {\n"
        " i: 1\n"
        " i: 2\n"
        " i: 2\n"
        " i: 1\n"
        " }\n"
        " }\n"
        " }\n"
        " attr {\n"
        " key: \"use_cudnn_on_gpu\"\n"
        " value {\n"
        " b: false\n"
        " }\n"
        " }\n"
        "}\n";

    // The Placeholder's "shape" attribute is empty, so the input shape is passed to the parser explicitly.
    const armnn::TensorShape inputShape = { 1, 3, 3, 1 };
    std::map<std::string, armnn::TensorShape> inputShapes;
    inputShapes["graphInput"] = inputShape;

    armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
    BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }), armnn::ParseException);
}
378
379
380 BOOST_AUTO_TEST_SUITE_END()
381