• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnTfParser/ITfParser.hpp"
7 #include "ParserPrototxtFixture.hpp"
8 
9 #include <armnn/utility/IgnoreUnused.hpp>
10 
11 #include <boost/test/unit_test.hpp>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 
15 struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
16 {
SplitFixtureSplitFixture17     SplitFixture(bool withDimZero=false) {
18         m_Prototext = R"(
19         node {
20           name: "graphInput"
21           op: "Placeholder"
22           attr {
23             key: "dtype"
24             value {
25               type: DT_FLOAT
26             }
27           }
28           attr {
29             key: "shape"
30             value {
31               shape {
32               }
33             }
34           }
35         }
36         node {
37           name: "graphInput2"
38           op: "Placeholder"
39           attr {
40             key: "dtype"
41             value {
42               type: DT_FLOAT
43             }
44           }
45           attr {
46             key: "shape"
47             value {
48               shape {
49               }
50             }
51           }
52         }
53         node {
54         name: "multiplication"
55         op : "Mul"
56         input: "graphInput"
57         input: "graphInput2"
58         attr {
59         key: "T"
60         value {
61             type: DT_FLOAT
62         }
63         }
64         }
65         node {
66           name: "SplitInput"
67           op: "Const"
68           attr {
69             key: "dtype"
70             value {
71               type: DT_INT32
72             }
73           }
74           attr {
75             key: "value"
76             value {
77               tensor {
78                 dtype: DT_INT32
79                 tensor_shape {
80                 }
81                 int_val: )";
82 
83         if(withDimZero)
84         {
85             m_Prototext += std::to_string(3);
86         }
87         else
88         {
89             m_Prototext += std::to_string(1);
90         }
91 
92         m_Prototext += R"(
93         }
94         }
95         }
96         }
97         node {
98           name: "Split"
99           op: "Split" )";
100         if(withDimZero)
101         {
102             m_Prototext += "input: \"SplitInput\"\n";
103             m_Prototext += "input: \"multiplication\"\n";
104         }
105         else
106         {
107             m_Prototext += "input: \"graphInput\"\n";
108             m_Prototext += "input: \"SplitInput\"\n";
109         }
110         m_Prototext += R"(
111           attr {
112             key: "num_split"
113             value {
114               i: 2
115             }
116           }
117         }
118         node {
119             name: "Relu_1"
120             op: "Relu"
121             input: "Split:0"
122             attr {
123             key: "T"
124             value {
125             type: DT_FLOAT
126              }
127             }
128             }
129          node {
130             name: "Relu_2"
131             op: "Relu"
132             input:"Split:1"
133             attr {
134             key: "T"
135             value {
136             type: DT_FLOAT
137              }
138             }
139             } )";
140 
141         Setup( { { "graphInput", { 1,  2,  2 , 2} } , { "graphInput2", { 1,  2,  2 , 2} }},
142                { "Relu_1", "Relu_2" });
143     }
144 };
145 
146 struct InputFirstSplitFixture : SplitFixture
147 {
InputFirstSplitFixtureInputFirstSplitFixture148     InputFirstSplitFixture() : SplitFixture(true) {}
149 };
150 
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo,SplitFixture)151 BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
152 {
153     BOOST_TEST(
154         (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
155 
156     BOOST_TEST(
157         (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
158 
159     RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
160                { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
161                  { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
162 }
163 
BOOST_FIXTURE_TEST_CASE(ParseSplit,InputFirstSplitFixture)164 BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
165 {
166 
167     BOOST_TEST(
168             (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
169 
170     BOOST_TEST(
171             (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
172 
173     RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
174                  { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
175                { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
176                  { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
177 }
178 
179 struct SplitLastDimFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
180 {
SplitLastDimFixtureSplitLastDimFixture181     SplitLastDimFixture(bool withDimZero=false) {
182         armnn::IgnoreUnused(withDimZero);
183         m_Prototext = R"(
184         node {
185           name: "Placeholder"
186           op: "Placeholder"
187           attr {
188             key: "dtype"
189             value {
190               type: DT_FLOAT
191             }
192           }
193           attr {
194             key: "shape"
195             value {
196               shape {
197                 dim {
198                   size: 1
199                 }
200                 dim {
201                   size: 2
202                 }
203                 dim {
204                   size: 2
205                 }
206                 dim {
207                   size: 3
208                 }
209               }
210             }
211           }
212         }
213         node {
214           name: "Const"
215           op: "Const"
216           attr {
217             key: "dtype"
218             value {
219               type: DT_INT32
220             }
221           }
222           attr {
223             key: "value"
224             value {
225               tensor {
226                 dtype: DT_INT32
227                 tensor_shape {
228                 }
229                 int_val: 3
230               }
231             }
232           }
233         }
234         node {
235           name: "split/split_dim"
236           op: "Const"
237           attr {
238             key: "dtype"
239             value {
240               type: DT_INT32
241             }
242           }
243           attr {
244             key: "value"
245             value {
246               tensor {
247                 dtype: DT_INT32
248                 tensor_shape {
249                 }
250                 int_val: 3
251               }
252             }
253           }
254         }
255         node {
256           name: "split"
257           op: "Split"
258           input: "split/split_dim"
259           input: "Placeholder"
260           attr {
261             key: "T"
262             value {
263               type: DT_FLOAT
264             }
265           }
266           attr {
267             key: "num_split"
268             value {
269               i: 3
270             }
271           }
272         }
273         node {
274           name: "sub0/y"
275           op: "Const"
276           attr {
277             key: "dtype"
278             value {
279               type: DT_FLOAT
280             }
281           }
282           attr {
283             key: "value"
284             value {
285               tensor {
286                 dtype: DT_FLOAT
287                 tensor_shape {
288                 }
289                 float_val: 3.0
290               }
291             }
292           }
293         }
294         node {
295           name: "sub0"
296           op: "Sub"
297           input: "split"
298           input: "sub0/y"
299           attr {
300             key: "T"
301             value {
302               type: DT_FLOAT
303             }
304           }
305         }
306         node {
307           name: "sub1/y"
308           op: "Const"
309           attr {
310             key: "dtype"
311             value {
312               type: DT_FLOAT
313             }
314           }
315           attr {
316             key: "value"
317             value {
318               tensor {
319                 dtype: DT_FLOAT
320                 tensor_shape {
321                 }
322                 float_val: 2.0
323               }
324             }
325           }
326         }
327         node {
328           name: "sub1"
329           op: "Sub"
330           input: "split:1"
331           input: "sub1/y"
332           attr {
333             key: "T"
334             value {
335               type: DT_FLOAT
336             }
337           }
338         }
339         node {
340           name: "sub2/y"
341           op: "Const"
342           attr {
343             key: "dtype"
344             value {
345               type: DT_FLOAT
346             }
347           }
348           attr {
349             key: "value"
350             value {
351               tensor {
352                 dtype: DT_FLOAT
353                 tensor_shape {
354                 }
355                 float_val: 1.0
356               }
357             }
358           }
359         }
360         node {
361           name: "sub2"
362           op: "Sub"
363           input: "split:2"
364           input: "sub2/y"
365           attr {
366             key: "T"
367             value {
368               type: DT_FLOAT
369             }
370           }
371         }
372         versions {
373           producer: 27
374         } )";
375 
376         Setup( { { "Placeholder", { 1,  2,  2 , 3} } },
377                { "sub0", "sub1", "sub2" });
378     }
379 };
380 
BOOST_FIXTURE_TEST_CASE(SplitLastDimTest,SplitLastDimFixture)381 BOOST_FIXTURE_TEST_CASE(SplitLastDimTest, SplitLastDimFixture)
382 {
383     BOOST_TEST(
384             (m_Parser->GetNetworkOutputBindingInfo("sub0").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
385 
386     BOOST_TEST(
387             (m_Parser->GetNetworkOutputBindingInfo("sub1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
388 
389     BOOST_TEST(
390             (m_Parser->GetNetworkOutputBindingInfo("sub2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
391 
392     RunTest<4>({ { "Placeholder", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f } } },
393                { { "sub0", { -2.0f, 1.0f, 4.0f, 7.0f } },
394                  { "sub1", { 0.0f, 3.0f, 6.0f, 9.0f } },
395                  { "sub2", { 2.0f, 5.0f, 8.0f, 11.0f } } });
396 }
397 
398 BOOST_AUTO_TEST_SUITE_END()
399