1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
// Graph which tests that nodes are re-ordered in the queue when they are encountered a second time.
// In this case R0 will be encountered first via R1 and then via R2. At that time
// we need to make sure that R0 (and the I on which it depends) is moved to the front again
// so that it is before both R1 and R2.
//      I
//      |
//      R0
//     /  \'
//    R1   R2
//     \   |
//      \  R3
//       \ |
//        O
25 struct RediscoveredDependenciesFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
26 {
RediscoveredDependenciesFixtureRediscoveredDependenciesFixture27 RediscoveredDependenciesFixture()
28 {
29 // Input = tf.placeholder(tf.float32, 1, "input")
30 // Relu0 = tf.nn.relu(input, "relu0")
31 // Relu1 = tf.nn.relu(relu0, "relu1")
32 // Relu2 = tf.nn.relu(relu0, "relu2")
33 // Relu3 = tf.nn.relu(relu2, "relu3")
34 // Output = tf.add(relu1, relu3, "output")
35 m_Prototext = R"(
36 node {
37 name: "input"
38 op: "Placeholder"
39 attr {
40 key: "dtype"
41 value {
42 type: DT_FLOAT
43 }
44 }
45 attr {
46 key: "shape"
47 value {
48 shape {
49 dim {
50 size: 1
51 }
52 }
53 }
54 }
55 }
56 node {
57 name: "relu0"
58 op: "Relu"
59 input: "input"
60 attr {
61 key: "T"
62 value {
63 type: DT_FLOAT
64 }
65 }
66 }
67 node {
68 name: "relu1"
69 op: "Relu"
70 input: "relu0"
71 attr {
72 key: "T"
73 value {
74 type: DT_FLOAT
75 }
76 }
77 }
78 node {
79 name: "relu2"
80 op: "Relu"
81 input: "relu0"
82 attr {
83 key: "T"
84 value {
85 type: DT_FLOAT
86 }
87 }
88 }
89 node {
90 name: "relu3"
91 op: "Relu"
92 input: "relu2"
93 attr {
94 key: "T"
95 value {
96 type: DT_FLOAT
97 }
98 }
99 }
100 node {
101 name: "output"
102 op: "Add"
103 input: "relu1"
104 input: "relu3"
105 attr {
106 key: "T"
107 value {
108 type: DT_FLOAT
109 }
110 }
111 }
112 )";
113 SetupSingleInputSingleOutput({ 1 }, "input", "output");
114 }
115 };
116
BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture)
{
    // A positive input passes unchanged through every Relu in the fixture graph, so relu1 and
    // relu3 both produce 1 and the final Add yields 1 + 1 = 2.
    RunTest<1>({ 1.0f }, { 2.0f });
}
121
// Tests that a simple cycle in the TensorFlow graph is detected and an exception thrown, rather than the TfParser
// getting stuck in an infinite loop.
BOOST_AUTO_TEST_CASE(SimpleCycle)
{
    auto parser = armnnTfParser::ITfParser::Create();

    // r1 takes r2 as its input and r2 takes r1, forming a two-node cycle.
    const char* const cyclicGraph = R"(
        node {
          name: "r1"
          op: "Relu"
          input: "r2"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        node {
          name: "r2"
          op: "Relu"
          input: "r1"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        )";

    // Requesting r2 as the output makes the parser walk r2 -> r1 -> r2 ...; the cycle must be
    // reported as a ParseException.
    BOOST_CHECK_THROW(parser->CreateNetworkFromString(cyclicGraph, {}, { "r2" }), armnn::ParseException);
}
153
154 // Similar to the above SimpleCycle test, but has a single node which connects to itself.
BOOST_AUTO_TEST_CASE(SingleNodeCycle)
{
    auto parser = armnnTfParser::ITfParser::Create();

    // r1 lists itself as its own input: the smallest possible cycle.
    const char* const selfLoopGraph = R"(
        node {
          name: "r1"
          op: "Relu"
          input: "r1"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        )";

    // Requesting r1 as the output must be rejected with a ParseException.
    BOOST_CHECK_THROW(parser->CreateNetworkFromString(selfLoopGraph, {}, { "r1" }), armnn::ParseException);
}
173
// Similar to the SimpleCycle test above, but with a more complicated graph in which the output of A1
// is fed back into A2.
//      I
//      |
//      A2---<---<-
//     /  \       |
//    R1   R2     |
//     \   |      |
//      \  R3     |
//       \ |      |
//        A1-->-->|
//
BOOST_AUTO_TEST_CASE(ComplexCycle)
{
    // Equivalent TensorFlow Python. This cannot actually be run in TF, because add2 consumes add1
    // before add1 is defined - that back-edge is exactly the cycle under test:
    //   Input = tf.placeholder(tf.float32, 1, "input")
    //   Add2  = tf.add(input, add1, "add2")
    //   Relu1 = tf.nn.relu(add2, "relu1")
    //   Relu2 = tf.nn.relu(add2, "relu2")
    //   Relu3 = tf.nn.relu(relu2, "relu3")
    //   Add1  = tf.add(relu1, relu3, "add1")
    const char* prototext = R"(
        node {
          name: "input"
          op: "Placeholder"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "shape"
            value {
              shape {
                dim {
                  size: 1
                }
              }
            }
          }
        }
        node {
          name: "add2"
          op: "Add"
          input: "input"
          input: "add1"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        node {
          name: "relu1"
          op: "Relu"
          input: "add2"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        node {
          name: "relu2"
          op: "Relu"
          input: "add2"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        node {
          name: "relu3"
          op: "Relu"
          input: "relu2"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        node {
          name: "add1"
          op: "Add"
          input: "relu1"
          input: "relu3"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        )";

    // Requesting add1 as the output leads the parser from add1 back through relu1/relu3 to add2,
    // which again depends on add1. That cycle must be rejected with a ParseException.
    armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
    BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException);
}
275
// Tests that a graph in which a node references an input that does not exist throws a ParseException.
BOOST_AUTO_TEST_CASE(InvalidInput)
{
    auto parser = armnnTfParser::ITfParser::Create();

    // r1's only input names a node that is never defined in the graph.
    const char* const danglingInputGraph = R"(
        node {
          name: "r1"
          op: "Relu"
          input: "a-node-that-does-not-exist"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
        }
        )";

    // Resolving r1's inputs must fail with a ParseException.
    BOOST_CHECK_THROW(parser->CreateNetworkFromString(danglingInputGraph, {}, { "r1" }),
                      armnn::ParseException);
}
295
296 BOOST_AUTO_TEST_SUITE_END()
297