//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnOnnxParser/IOnnxParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12 struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
PoolingMainFixturePoolingMainFixture14 PoolingMainFixture(const std::string& dataType, const std::string& op)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 1
35 }
36 dim {
37 dim_value: 2
38 }
39 dim {
40 dim_value: 2
41 }
42 }
43 }
44 }
45 }
46 node {
47 input: "Input"
48 output: "Output"
49 name: "Pooling"
50 op_type: )" + op + R"(
51 attribute {
52 name: "kernel_shape"
53 ints: 2
54 ints: 2
55 type: INTS
56 }
57 attribute {
58 name: "strides"
59 ints: 1
60 ints: 1
61 type: INTS
62 }
63 attribute {
64 name: "pads"
65 ints: 0
66 ints: 0
67 ints: 0
68 ints: 0
69 type: INTS
70 }
71 }
72 output {
73 name: "Output"
74 type {
75 tensor_type {
76 elem_type: 1
77 shape {
78 dim {
79 dim_value: 1
80 }
81 dim {
82 dim_value: 1
83 }
84 dim {
85 dim_value: 1
86 }
87 dim {
88 dim_value: 1
89 }
90 }
91 }
92 }
93 }
94 }
95 opset_import {
96 version: 7
97 })";
98 }
99 };
100
101 struct MaxPoolValidFixture : PoolingMainFixture
102 {
MaxPoolValidFixtureMaxPoolValidFixture103 MaxPoolValidFixture() : PoolingMainFixture("1", "\"MaxPool\"") {
104 Setup();
105 }
106 };
107
108 struct MaxPoolInvalidFixture : PoolingMainFixture
109 {
MaxPoolInvalidFixtureMaxPoolInvalidFixture110 MaxPoolInvalidFixture() : PoolingMainFixture("10", "\"MaxPool\"") { }
111 };
112
// The 2x2 kernel covers the whole 2x2 input, so the single output element is
// max(1, 2, 3, -4) = 3.
BOOST_FIXTURE_TEST_CASE(ValidMaxPoolTest, MaxPoolValidFixture)
{
    RunTest<4>({ { "Input",  { 1.0f, 2.0f, 3.0f, -4.0f } } },
               { { "Output", { 3.0f } } });
}
117
118 struct AvgPoolValidFixture : PoolingMainFixture
119 {
AvgPoolValidFixtureAvgPoolValidFixture120 AvgPoolValidFixture() : PoolingMainFixture("1", "\"AveragePool\"") {
121 Setup();
122 }
123 };
124
125 struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
126 {
PoolingWithPadFixturePoolingWithPadFixture127 PoolingWithPadFixture()
128 {
129 m_Prototext = R"(
130 ir_version: 3
131 producer_name: "CNTK"
132 producer_version: "2.5.1"
133 domain: "ai.cntk"
134 model_version: 1
135 graph {
136 name: "CNTKGraph"
137 input {
138 name: "Input"
139 type {
140 tensor_type {
141 elem_type: 1
142 shape {
143 dim {
144 dim_value: 1
145 }
146 dim {
147 dim_value: 1
148 }
149 dim {
150 dim_value: 2
151 }
152 dim {
153 dim_value: 2
154 }
155 }
156 }
157 }
158 }
159 node {
160 input: "Input"
161 output: "Output"
162 name: "Pooling"
163 op_type: "AveragePool"
164 attribute {
165 name: "kernel_shape"
166 ints: 4
167 ints: 4
168 type: INTS
169 }
170 attribute {
171 name: "strides"
172 ints: 1
173 ints: 1
174 type: INTS
175 }
176 attribute {
177 name: "pads"
178 ints: 1
179 ints: 1
180 ints: 1
181 ints: 1
182 type: INTS
183 }
184 attribute {
185 name: "count_include_pad"
186 i: 1
187 type: INT
188 }
189 }
190 output {
191 name: "Output"
192 type {
193 tensor_type {
194 elem_type: 1
195 shape {
196 dim {
197 dim_value: 1
198 }
199 dim {
200 dim_value: 1
201 }
202 dim {
203 dim_value: 1
204 }
205 dim {
206 dim_value: 1
207 }
208 }
209 }
210 }
211 }
212 }
213 opset_import {
214 version: 7
215 })";
216 Setup();
217 }
218 };
219
// 2x2 average pool over the whole 2x2 input with no padding:
// (1 + 2 + 3 - 4) / 4 = 0.5.
BOOST_FIXTURE_TEST_CASE(AveragePoolValid, AvgPoolValidFixture)
{
    // Float literal, matching the input data and the other pooling tests,
    // rather than relying on an implicit double -> float conversion.
    RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5f}}});
}
224
// The 4x4 kernel spans the 2x2 input padded by 1 on every side (4x4 in total).
// With count_include_pad = 1 all 16 cells count toward the divisor:
// (1 + 2 + 3 - 4) / 16 = 1/8.
BOOST_FIXTURE_TEST_CASE(ValidAvgWithPadTest, PoolingWithPadFixture)
{
    // Float arithmetic, consistent with the float input data.
    RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0f / 8.0f}}});
}
229
230 struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
231 {
GlobalAvgFixtureGlobalAvgFixture232 GlobalAvgFixture()
233 {
234 m_Prototext = R"(
235 ir_version: 3
236 producer_name: "CNTK"
237 producer_version: "2.5.1"
238 domain: "ai.cntk"
239 model_version: 1
240 graph {
241 name: "CNTKGraph"
242 input {
243 name: "Input"
244 type {
245 tensor_type {
246 elem_type: 1
247 shape {
248 dim {
249 dim_value: 1
250 }
251 dim {
252 dim_value: 2
253 }
254 dim {
255 dim_value: 2
256 }
257 dim {
258 dim_value: 2
259 }
260 }
261 }
262 }
263 }
264 node {
265 input: "Input"
266 output: "Output"
267 name: "Pooling"
268 op_type: "GlobalAveragePool"
269 }
270 output {
271 name: "Output"
272 type {
273 tensor_type {
274 elem_type: 1
275 shape {
276 dim {
277 dim_value: 1
278 }
279 dim {
280 dim_value: 2
281 }
282 dim {
283 dim_value: 1
284 }
285 dim {
286 dim_value: 1
287 }
288 }
289 }
290 }
291 }
292 }
293 opset_import {
294 version: 7
295 })";
296 Setup();
297 }
298 };
299
// GlobalAveragePool averages each 2x2 plane of the 1x2x2x2 input. Assuming NCHW
// layout, the first four values form channel 0 and the last four form channel 1:
// channel 0 -> (1 + 2 + 3 + 4) / 4 = 2.5, channel 1 -> (5 + 6 + 7 + 8) / 4 = 6.5.
BOOST_FIXTURE_TEST_CASE(GlobalAvgTest, GlobalAvgFixture)
{
    // Float literals throughout, consistent with the other pooling tests.
    RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}}},
               {{"Output", {10.0f / 4.0f, 26.0f / 4.0f}}});
}
304
// MaxPoolInvalidFixture does not call Setup() itself. Parsing its model, whose
// input elem_type is 10, must fail with armnn::ParseException.
BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeMaxPool, MaxPoolInvalidFixture)
{
    BOOST_CHECK_THROW(Setup(), armnn::ParseException);
}
309
310 BOOST_AUTO_TEST_SUITE_END()
311