//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include  "ParserPrototxtFixture.hpp"
8 
9 TEST_SUITE("OnnxParser_Pooling")
10 {
11 struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
PoolingMainFixturePoolingMainFixture13     PoolingMainFixture(const std::string& dataType, const std::string& op)
14     {
15         m_Prototext = R"(
16                    ir_version: 3
17                    producer_name:  "CNTK"
18                    producer_version:  "2.5.1"
19                    domain:  "ai.cntk"
20                    model_version: 1
21                    graph {
22                      name:  "CNTKGraph"
23                      input {
24                         name: "Input"
25                         type {
26                           tensor_type {
27                             elem_type: )" + dataType + R"(
28                             shape {
29                               dim {
30                                 dim_value: 1
31                               }
32                               dim {
33                                 dim_value: 1
34                               }
35                               dim {
36                                 dim_value: 2
37                               }
38                               dim {
39                                 dim_value: 2
40                               }
41                             }
42                           }
43                         }
44                       }
45                      node {
46                          input: "Input"
47                          output: "Output"
48                          name: "Pooling"
49                          op_type: )" + op + R"(
50                          attribute {
51                            name: "kernel_shape"
52                            ints: 2
53                            ints: 2
54                            type: INTS
55                          }
56                          attribute {
57                            name: "strides"
58                            ints: 1
59                            ints: 1
60                            type: INTS
61                          }
62                          attribute {
63                            name: "pads"
64                            ints: 0
65                            ints: 0
66                            ints: 0
67                            ints: 0
68                            type: INTS
69                          }
70                       }
71                       output {
72                           name: "Output"
73                           type {
74                              tensor_type {
75                                elem_type: 1
76                                shape {
77                                    dim {
78                                        dim_value: 1
79                                    }
80                                    dim {
81                                        dim_value: 1
82                                    }
83                                    dim {
84                                        dim_value: 1
85                                    }
86                                    dim {
87                                        dim_value: 1
88                                    }
89                                }
90                             }
91                         }
92                         }
93                     }
94                    opset_import {
95                       version: 7
96                     })";
97     }
98 };
99 
100 struct MaxPoolValidFixture : PoolingMainFixture
101 {
MaxPoolValidFixtureMaxPoolValidFixture102     MaxPoolValidFixture() : PoolingMainFixture("1", "\"MaxPool\"") {
103         Setup();
104     }
105 };
106 
107 struct MaxPoolInvalidFixture : PoolingMainFixture
108 {
MaxPoolInvalidFixtureMaxPoolInvalidFixture109     MaxPoolInvalidFixture() : PoolingMainFixture("10", "\"MaxPool\"") { }
110 };
111 
112 TEST_CASE_FIXTURE(MaxPoolValidFixture, "ValidMaxPoolTest")
113 {
114     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}});
115 }
116 
117 struct AvgPoolValidFixture : PoolingMainFixture
118 {
AvgPoolValidFixtureAvgPoolValidFixture119     AvgPoolValidFixture() : PoolingMainFixture("1", "\"AveragePool\"") {
120         Setup();
121     }
122 };
123 
124 struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
125 {
PoolingWithPadFixturePoolingWithPadFixture126     PoolingWithPadFixture()
127     {
128         m_Prototext = R"(
129                    ir_version: 3
130                    producer_name:  "CNTK"
131                    producer_version:  "2.5.1"
132                    domain:  "ai.cntk"
133                    model_version: 1
134                    graph {
135                      name:  "CNTKGraph"
136                      input {
137                         name: "Input"
138                         type {
139                           tensor_type {
140                             elem_type: 1
141                             shape {
142                               dim {
143                                 dim_value: 1
144                               }
145                               dim {
146                                 dim_value: 1
147                               }
148                               dim {
149                                 dim_value: 2
150                               }
151                               dim {
152                                 dim_value: 2
153                               }
154                             }
155                           }
156                         }
157                       }
158                      node {
159                          input: "Input"
160                          output: "Output"
161                          name: "Pooling"
162                          op_type: "AveragePool"
163                          attribute {
164                            name: "kernel_shape"
165                            ints: 4
166                            ints: 4
167                            type: INTS
168                          }
169                          attribute {
170                            name: "strides"
171                            ints: 1
172                            ints: 1
173                            type: INTS
174                          }
175                          attribute {
176                            name: "pads"
177                            ints: 1
178                            ints: 1
179                            ints: 1
180                            ints: 1
181                            type: INTS
182                          }
183                          attribute {
184                            name: "count_include_pad"
185                            i: 1
186                            type: INT
187                          }
188                       }
189                       output {
190                           name: "Output"
191                           type {
192                              tensor_type {
193                                elem_type: 1
194                                shape {
195                                    dim {
196                                        dim_value: 1
197                                    }
198                                    dim {
199                                        dim_value: 1
200                                    }
201                                    dim {
202                                        dim_value: 1
203                                    }
204                                    dim {
205                                        dim_value: 1
206                                    }
207                                }
208                             }
209                         }
210                         }
211                     }
212                    opset_import {
213                       version: 7
214                     })";
215         Setup();
216     }
217 };
218 
219 TEST_CASE_FIXTURE(AvgPoolValidFixture, "AveragePoolValid")
220 {
221     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}});
222 }
223 
224 TEST_CASE_FIXTURE(PoolingWithPadFixture, "ValidAvgWithPadTest")
225 {
226     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}});
227 }
228 
229 struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
230 {
GlobalAvgFixtureGlobalAvgFixture231     GlobalAvgFixture()
232     {
233         m_Prototext = R"(
234                    ir_version: 3
235                    producer_name:  "CNTK"
236                    producer_version:  "2.5.1"
237                    domain:  "ai.cntk"
238                    model_version: 1
239                    graph {
240                      name:  "CNTKGraph"
241                      input {
242                         name: "Input"
243                         type {
244                           tensor_type {
245                             elem_type: 1
246                             shape {
247                               dim {
248                                 dim_value: 1
249                               }
250                               dim {
251                                 dim_value: 2
252                               }
253                               dim {
254                                 dim_value: 2
255                               }
256                               dim {
257                                 dim_value: 2
258                               }
259                             }
260                           }
261                         }
262                       }
263                      node {
264                          input: "Input"
265                          output: "Output"
266                          name: "Pooling"
267                          op_type: "GlobalAveragePool"
268                       }
269                       output {
270                           name: "Output"
271                           type {
272                              tensor_type {
273                                elem_type: 1
274                                shape {
275                                    dim {
276                                        dim_value: 1
277                                    }
278                                    dim {
279                                        dim_value: 2
280                                    }
281                                    dim {
282                                        dim_value: 1
283                                    }
284                                    dim {
285                                        dim_value: 1
286                                    }
287                                }
288                             }
289                         }
290                         }
291                     }
292                    opset_import {
293                       version: 7
294                     })";
295         Setup();
296     }
297 };
298 
299 TEST_CASE_FIXTURE(GlobalAvgFixture, "GlobalAvgTest")
300 {
301     RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}});
302 }
303 
304 TEST_CASE_FIXTURE(MaxPoolInvalidFixture, "IncorrectDataTypeMaxPool")
305 {
306    CHECK_THROWS_AS(Setup(), armnn::ParseException);
307 }
308 
309 }
310