• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
ExpandDimsFixtureExpandDimsFixture14     ExpandDimsFixture(const std::string& expandDim)
15     {
16         m_Prototext =
17                 "node { \n"
18                 "    name: \"graphInput\" \n"
19                 "    op: \"Placeholder\" \n"
20                 "    attr { \n"
21                 "      key: \"dtype\" \n"
22                 "      value { \n"
23                 "        type: DT_FLOAT \n"
24                 "      } \n"
25                 "    } \n"
26                 "    attr { \n"
27                 "      key: \"shape\" \n"
28                 "      value { \n"
29                 "        shape { \n"
30                 "        } \n"
31                 "      } \n"
32                 "    } \n"
33                 "  } \n"
34                 "node { \n"
35                 "  name: \"ExpandDims\" \n"
36                 "  op: \"ExpandDims\" \n"
37                 "  input: \"graphInput\" \n"
38                 "  attr { \n"
39                 "    key: \"T\" \n"
40                 "    value { \n"
41                 "      type: DT_FLOAT \n"
42                 "    } \n"
43                 "  } \n"
44                 "  attr { \n"
45                 "    key: \"Tdim\" \n"
46                 "    value { \n";
47             m_Prototext += "i:" + expandDim;
48             m_Prototext +=
49                 "    } \n"
50                 "  } \n"
51                 "} \n";
52 
53         SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims");
54     }
55 };
56 
57 struct ExpandZeroDim : ExpandDimsFixture
58 {
ExpandZeroDimExpandZeroDim59     ExpandZeroDim() : ExpandDimsFixture("0") {}
60 };
61 
62 struct ExpandTwoDim : ExpandDimsFixture
63 {
ExpandTwoDimExpandTwoDim64     ExpandTwoDim() : ExpandDimsFixture("2") {}
65 };
66 
67 struct ExpandThreeDim : ExpandDimsFixture
68 {
ExpandThreeDimExpandThreeDim69     ExpandThreeDim() : ExpandDimsFixture("3") {}
70 };
71 
72 struct ExpandMinusOneDim : ExpandDimsFixture
73 {
ExpandMinusOneDimExpandMinusOneDim74     ExpandMinusOneDim() : ExpandDimsFixture("-1") {}
75 };
76 
77 struct ExpandMinusThreeDim : ExpandDimsFixture
78 {
ExpandMinusThreeDimExpandMinusThreeDim79     ExpandMinusThreeDim() : ExpandDimsFixture("-3") {}
80 };
81 
// Tdim "0" on the fixture's {2, 3, 5} input is expected to give a {1, 2, 3, 5} output.
BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
{
    const armnn::TensorShape expectedShape({1, 2, 3, 5});
    const auto outputBinding = m_Parser->GetNetworkOutputBindingInfo("ExpandDims");

    BOOST_TEST((outputBinding.second.GetShape() == expectedShape));
}
87 
// Tdim "2" on the fixture's {2, 3, 5} input is expected to give a {2, 3, 1, 5} output.
BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim)
{
    const armnn::TensorShape expectedShape({2, 3, 1, 5});
    const auto outputBinding = m_Parser->GetNetworkOutputBindingInfo("ExpandDims");

    BOOST_TEST((outputBinding.second.GetShape() == expectedShape));
}
93 
// Tdim "3" on the fixture's {2, 3, 5} input is expected to give a {2, 3, 5, 1} output.
BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim)
{
    const armnn::TensorShape expectedShape({2, 3, 5, 1});
    const auto outputBinding = m_Parser->GetNetworkOutputBindingInfo("ExpandDims");

    BOOST_TEST((outputBinding.second.GetShape() == expectedShape));
}
99 
// Tdim "-1" on the fixture's {2, 3, 5} input is expected to give a {2, 3, 5, 1} output.
BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim)
{
    const armnn::TensorShape expectedShape({2, 3, 5, 1});
    const auto outputBinding = m_Parser->GetNetworkOutputBindingInfo("ExpandDims");

    BOOST_TEST((outputBinding.second.GetShape() == expectedShape));
}
105 
// Tdim "-3" on the fixture's {2, 3, 5} input is expected to give a {2, 1, 3, 5} output.
BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
{
    const armnn::TensorShape expectedShape({2, 1, 3, 5});
    const auto outputBinding = m_Parser->GetNetworkOutputBindingInfo("ExpandDims");

    BOOST_TEST((outputBinding.second.GetShape() == expectedShape));
}
111 
112 struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
113 {
ExpandDimsAsInputFixtureExpandDimsAsInputFixture114     ExpandDimsAsInputFixture(const std::string& expandDim,
115                              const bool wrongDataType = false,
116                              const std::string& numElements = "1")
117     {
118         std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32";
119         std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim);
120 
121         m_Prototext = R"(
122         node {
123             name: "a"
124             op: "Placeholder"
125             attr {
126                 key: "dtype"
127                 value {
128                     type: DT_FLOAT
129                 }
130             }
131             attr {
132                 key: "shape"
133                 value {
134                     shape {
135                         dim {
136                             size: 1
137                         }
138                         dim {
139                             size: 4
140                         }
141                     }
142                 }
143             }
144         }
145         node {
146             name: "b"
147             op: "Const"
148             attr {
149                 key: "dtype"
150                 value {
151                     type:  )" + dataType + R"(
152                 }
153             }
154             attr {
155                 key: "value"
156                 value {
157                     tensor {
158                         dtype: )" + dataType + R"(
159                         tensor_shape {
160                             dim {
161                                 size: )" + numElements + R"(
162                             }
163                         }
164                         )" + val + R"(
165                     }
166                 }
167             }
168         }
169         node {
170             name: "ExpandDims"
171             op: "ExpandDims"
172             input: "a"
173             input: "b"
174             attr {
175                 key: "T"
176                 value {
177                     type: DT_FLOAT
178                 }
179             }
180             attr {
181                 key: "Tdim"
182                 value {
183                     type: DT_INT32
184                 }
185             }
186         }
187         versions {
188             producer: 134
189         })";
190     }
191 };
192 
193 struct ExpandDimAsInput : ExpandDimsAsInputFixture
194 {
ExpandDimAsInputExpandDimAsInput195     ExpandDimAsInput() : ExpandDimsAsInputFixture("0")
196     {
197         Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
198     }
199 };
200 
201 
// The axis to expand is supplied through the second input (Const node "b") rather than
// through an attribute; with axis 0 the 1x4 input is expected to become {1, 1, 4}.
BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput)
{
    const armnn::TensorShape expectedShape({1, 1, 4});
    const auto outputBinding = m_Parser->GetNetworkOutputBindingInfo("ExpandDims");

    BOOST_TEST((outputBinding.second.GetShape() == expectedShape));
}
208 
209 struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture
210 {
ExpandDimAsInputWrongDataTypeExpandDimAsInputWrongDataType211     ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {}
212 };
213 
// The axis is supplied through the second input, but with the wrong data type
// (float instead of int32); Setup is required to throw armnn::ParseException.
BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType)
{
    const auto runSetup = [this]
    {
        Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
    };

    BOOST_REQUIRE_THROW(runSetup(), armnn::ParseException);
}
220 
221 struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture
222 {
ExpandDimAsInputWrongShapeExpandDimAsInputWrongShape223     ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {}
224 };
225 
// The axis is supplied through the second input, but that tensor has the wrong shape;
// Setup is required to throw armnn::ParseException.
BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape)
{
    const auto runSetup = [this]
    {
        Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
    };

    BOOST_REQUIRE_THROW(runSetup(), armnn::ParseException);
}
232 
233 struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
234 {
ExpandDimsAsNotConstInputFixtureExpandDimsAsNotConstInputFixture235     ExpandDimsAsNotConstInputFixture()
236     {
237         m_Prototext = R"(
238             node {
239                 name: "a"
240                 op: "Placeholder"
241                 attr {
242                     key: "dtype"
243                     value {
244                         type: DT_FLOAT
245                     }
246                 }
247                 attr {
248                     key: "shape"
249                     value {
250                         shape {
251                             dim {
252                                 size: 1
253                             }
254                             dim {
255                             size: 4
256                             }
257                         }
258                     }
259                 }
260             }
261             node {
262                 name: "b"
263                 op: "Placeholder"
264                 attr {
265                     key: "dtype"
266                         value {
267                             type: DT_INT32
268                         }
269                 }
270                 attr {
271                     key: "shape"
272                     value {
273                         shape {
274                             dim {
275                                 size: 1
276                             }
277                         }
278                     }
279                 }
280             }
281             node {
282                 name: "ExpandDims"
283                 op: "ExpandDims"
284                 input: "a"
285                 input: "b"
286                 attr {
287                     key: "T"
288                         value {
289                             type: DT_FLOAT
290                         }
291                     }
292                     attr {
293                         key: "Tdim"
294                         value {
295                             type: DT_INT32
296                         }
297                     }
298                 }
299             versions {
300                 producer: 134
301             })";
302     }
303 };
304 
// The axis is supplied through the second input, but that input is not a constant tensor,
// which is not supported; Setup is required to throw armnn::ParseException.
BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture)
{
    const auto runSetup = [this]
    {
        Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
    };

    BOOST_REQUIRE_THROW(runSetup(), armnn::ParseException);
}
312 
313 BOOST_AUTO_TEST_SUITE_END()
314