• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9 
10 #include <array>
11 #include <string>
12 #include <iostream>
13 
14 BOOST_AUTO_TEST_SUITE(TensorflowParser)
15 
16 struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
17 {
Convolution2dFixtureConvolution2dFixture18     explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType)
19     : Convolution2dFixture(dataLayout, paddingType, 1)
20     {}
21 
22     // Dilation: 0 - dilations attribute is not included;
23     // Dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
Convolution2dFixtureConvolution2dFixture24     explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType,
25                                   int stride, int dilation = 0)
26     {
27         std::string strideString ("        i: 1 \n"
28                                   "        i: 1 \n");
29         if (dataLayout == "NHWC")
30         {
31             strideString.append("        i: " + std::to_string(stride) + " \n"
32                                 "        i: 1 \n");
33         }
34         else // dataLayout == "NCHW"
35         {
36             strideString.append("        i: 1 \n"
37                                 "        i: " + std::to_string(stride) + " \n");
38         }
39 
40         std::string dilationString = std::to_string(dilation);
41         m_Prototext = "node { \n"
42             "    name: \"graphInput\" \n"
43             "    op: \"Placeholder\" \n"
44             "    attr { \n"
45             "      key: \"dtype\" \n"
46             "      value { \n"
47             "        type: DT_FLOAT \n"
48             "      } \n"
49             "    } \n"
50             "    attr { \n"
51             "      key: \"shape\" \n"
52             "      value { \n"
53             "        shape { \n"
54             "        } \n"
55             "      } \n"
56             "    } \n"
57             "  } \n"
58             "  node { \n"
59             "  name: \"Const_1\" \n"
60             "  op: \"Const\" \n"
61             "  attr { \n"
62             "    key: \"dtype\" \n"
63             "    value { \n"
64             "      type: DT_FLOAT \n"
65             "    } \n"
66             "  } \n"
67             "  attr { \n"
68             "    key: \"value\" \n"
69             "    value { \n"
70             "      tensor { \n"
71             "        dtype: DT_FLOAT \n"
72             "        tensor_shape { \n"
73             "          dim { \n"
74             "            size: 1 \n"
75             "          } \n"
76             "          dim { \n"
77             "            size: 3 \n"
78             "          } \n"
79             "          dim { \n"
80             "            size: 1 \n"
81             "          } \n"
82             "          dim { \n"
83             "            size: 1 \n"
84             "          } \n"
85             "        } \n"
86             "        tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
87             "      } \n"
88             "    } \n"
89             "  } \n"
90             "} \n"
91             "node { \n"
92             "  name: \"potato\" \n"
93             "  op: \"Conv2D\" \n"
94             "  input: \"graphInput\" \n"
95             "  input: \"Const_1\" \n"
96             "  attr { \n"
97             "    key: \"T\" \n"
98             "    value { \n"
99             "      type: DT_FLOAT \n"
100             "    } \n"
101             "  } \n"
102             "  attr { \n"
103             "    key: \"data_format\" \n"
104             "    value { \n"
105             "      s: \"";
106         m_Prototext.append(dataLayout);
107         m_Prototext.append("\"\n"
108                            "    } \n"
109                            "  } \n"
110                            "  attr { \n"
111                            "    key: \"padding\" \n"
112                            "    value { \n"
113                            "      s: \"");
114         m_Prototext.append(paddingType);
115         m_Prototext.append("\"\n"
116                            "    } \n"
117                            "  } \n"
118                            "  attr { \n"
119                            "    key: \"strides\" \n"
120                            "    value { \n"
121                            "      list { \n");
122         m_Prototext.append(strideString);
123 
124         m_Prototext.append("      } \n"
125                            "    } \n"
126                            "  } \n");
127 
128         if (dilation > 0)
129         {
130             m_Prototext.append("  attr { \n"
131                                "    key: \"dilations\" \n"
132                                "    value { \n"
133                                "      list { \n"
134                                "        i: 1 \n"
135                                "        i: ");
136             m_Prototext.append(dilationString);
137             m_Prototext.append(" \n"
138                                "        i: ");
139             m_Prototext.append(dilationString);
140             m_Prototext.append(" \n"
141                                "        i: 1 \n"
142                                "      } \n"
143                                "    } \n"
144                                "  } \n");
145         }
146         m_Prototext.append("  attr { \n"
147                            "    key: \"use_cudnn_on_gpu\" \n"
148                            "    value { \n"
149                            "      b: false \n"
150                            "    } \n"
151                            "  } \n"
152                            "} \n");
153 
154         // Manual height computation based on stride parameter.
155         ARMNN_ASSERT_MSG(stride == 1 || stride == 2, "Add support for strides other than 1 or 2.");
156         std::array<unsigned int, 4> dims;
157         if (dataLayout == "NHWC")
158         {
159             dims = { 1u, (stride == 2 ? 3u : 2u), 3u, 1u };
160         }
161         else // dataLayout == "NCHW"
162         {
163             dims = { 1u, 1u, (stride == 2 ? 3u : 2u), 3u };
164         }
165 
166         SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
167     }
168 };
169 
170 
171 struct Convolution2dNhwcSameFixture : Convolution2dFixture
172 {
Convolution2dNhwcSameFixtureConvolution2dNhwcSameFixture173     Convolution2dNhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 1){}
174 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame,Convolution2dNhwcSameFixture)175 BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame, Convolution2dNhwcSameFixture)
176 {
177     RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
178 }
179 
180 struct Convolution2dNchwSameFixture : Convolution2dFixture
181 {
Convolution2dNchwSameFixtureConvolution2dNchwSameFixture182     Convolution2dNchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 1){}
183 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwSame,Convolution2dNchwSameFixture)184 BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwSame, Convolution2dNchwSameFixture)
185 {
186     RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
187 }
188 
189 
190 struct Convolution2dNhwcValidFixture : Convolution2dFixture
191 {
Convolution2dNhwcValidFixtureConvolution2dNhwcValidFixture192     Convolution2dNhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 1){}
193 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcValid,Convolution2dNhwcValidFixture)194 BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcValid, Convolution2dNhwcValidFixture)
195 {
196     RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
197 }
198 
199 struct Convolution2dNchwValidFixture : Convolution2dFixture
200 {
Convolution2dNchwValidFixtureConvolution2dNchwValidFixture201     Convolution2dNchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 1){}
202 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwValid,Convolution2dNchwValidFixture)203 BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwValid, Convolution2dNchwValidFixture)
204 {
205     RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
206 }
207 
208 
209 struct Convolution2dStride2NhwcSameFixture : Convolution2dFixture
210 {
Convolution2dStride2NhwcSameFixtureConvolution2dStride2NhwcSameFixture211     Convolution2dStride2NhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 2){}
212 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcSame,Convolution2dStride2NhwcSameFixture)213 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcSame, Convolution2dStride2NhwcSameFixture)
214 {
215     RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
216 }
217 
218 struct Convolution2dStride2NchwSameFixture : Convolution2dFixture
219 {
Convolution2dStride2NchwSameFixtureConvolution2dStride2NchwSameFixture220     Convolution2dStride2NchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 2){}
221 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwSame,Convolution2dStride2NchwSameFixture)222 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwSame, Convolution2dStride2NchwSameFixture)
223 {
224     RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
225 }
226 
227 
228 struct Convolution2dStride2NhwcValidFixture : Convolution2dFixture
229 {
Convolution2dStride2NhwcValidFixtureConvolution2dStride2NhwcValidFixture230     Convolution2dStride2NhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 2){}
231 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcValid,Convolution2dStride2NhwcValidFixture)232 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcValid, Convolution2dStride2NhwcValidFixture)
233 {
234     RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
235 }
236 
237 struct Convolution2dStride2NchwValidFixture : Convolution2dFixture
238 {
Convolution2dStride2NchwValidFixtureConvolution2dStride2NchwValidFixture239     Convolution2dStride2NchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 2){}
240 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwValid,Convolution2dStride2NchwValidFixture)241 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwValid, Convolution2dStride2NchwValidFixture)
242 {
243     RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
244 }
245 
246 
247 struct Convolution2dDilation1NhwcFixture : Convolution2dFixture
248 {
Convolution2dDilation1NhwcFixtureConvolution2dDilation1NhwcFixture249     Convolution2dDilation1NhwcFixture() : Convolution2dFixture("NHWC", "SAME", 1, 1){}
250 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nhwc,Convolution2dDilation1NhwcFixture)251 BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nhwc, Convolution2dDilation1NhwcFixture)
252 {
253     RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
254 }
255 
256 struct Convolution2dDilation1NchwFixture : Convolution2dFixture
257 {
Convolution2dDilation1NchwFixtureConvolution2dDilation1NchwFixture258     Convolution2dDilation1NchwFixture() : Convolution2dFixture("NCHW", "SAME", 1, 1){}
259 };
BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nchw,Convolution2dDilation1NchwFixture)260 BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nchw, Convolution2dDilation1NchwFixture)
261 {
262     RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
263 }
264 
265 
BOOST_AUTO_TEST_CASE(ParseConv2dDilation2)266 BOOST_AUTO_TEST_CASE(ParseConv2dDilation2)
267 {
268     const char* prototext = ""
269         "node {\n"
270         "  name: \"graphInput\"\n"
271         "  op: \"Placeholder\"\n"
272         "  attr {\n"
273         "    key: \"dtype\"\n"
274         "    value {\n"
275         "      type: DT_FLOAT\n"
276         "    }\n"
277         "  }\n"
278         "  attr {\n"
279         "    key: \"shape\"\n"
280         "    value {\n"
281         "      shape {\n"
282         "      }\n"
283         "    }\n"
284         "  }\n"
285         "}\n"
286         "node {\n"
287         "  name: \"Const_1\"\n"
288         "  op: \"Const\"\n"
289         "  attr {\n"
290         "    key: \"dtype\"\n"
291         "    value {\n"
292         "      type: DT_FLOAT\n"
293         "    }\n"
294         "  }\n"
295         "  attr {\n"
296         "    key: \"value\"\n"
297         "    value {\n"
298         "      tensor {\n"
299         "        dtype: DT_FLOAT\n"
300         "        tensor_shape {\n"
301         "          dim {\n"
302         "            size: 1\n"
303         "          }\n"
304         "          dim {\n"
305         "            size: 3\n"
306         "          }\n"
307         "          dim {\n"
308         "            size: 1\n"
309         "          }\n"
310         "          dim {\n"
311         "            size: 1\n"
312         "          }\n"
313         "        }\n"
314         "        tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
315         "      }\n"
316         "    }\n"
317         "  }\n"
318         "}\n"
319         "node {\n"
320         "  name: \"potato\"\n"
321         "  op: \"Conv2D\"\n"
322         "  input: \"graphInput\"\n"
323         "  input: \"Const_1\"\n"
324         "  attr {\n"
325         "    key: \"T\"\n"
326         "    value {\n"
327         "      type: DT_FLOAT\n"
328         "    }\n"
329         "  }\n"
330         "  attr {\n"
331         "    key: \"data_format\"\n"
332         "    value {\n"
333         "      s: \"NHWC\"\n"
334         "    }\n"
335         "  }\n"
336         "  attr {\n"
337         "    key: \"padding\"\n"
338         "    value {\n"
339         "      s: \"SAME\"\n"
340         "    }\n"
341         "  }\n"
342         "  attr {\n"
343         "    key: \"strides\"\n"
344         "    value {\n"
345         "      list {\n"
346         "        i: 1\n"
347         "        i: 1\n"
348         "        i: 1\n"
349         "        i: 1\n"
350         "      }\n"
351         "    }\n"
352         "  }\n"
353         "  attr {\n"
354         "    key: \"dilations\"\n"
355         "    value {\n"
356         "      list {\n"
357         "        i: 1\n"
358         "        i: 2\n"
359         "        i: 2\n"
360         "        i: 1\n"
361         "      }\n"
362         "    }\n"
363         "  }\n"
364         "  attr {\n"
365         "    key: \"use_cudnn_on_gpu\"\n"
366         "    value {\n"
367         "      b: false\n"
368         "    }\n"
369         "  }\n"
370         "}\n";
371 
372     std::map<std::string, armnn::TensorShape> inputShapes;
373     armnn::TensorShape tensorShape = { 1, 3, 3, 1 };
374     inputShapes["graphInput"] = tensorShape;
375     armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
376     BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }), armnn::ParseException);
377 }
378 
379 
380 BOOST_AUTO_TEST_SUITE_END()
381