1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 9 TEST_SUITE("OnnxParser_Pooling") 10 { 11 struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 12 { PoolingMainFixturePoolingMainFixture13 PoolingMainFixture(const std::string& dataType, const std::string& op) 14 { 15 m_Prototext = R"( 16 ir_version: 3 17 producer_name: "CNTK" 18 producer_version: "2.5.1" 19 domain: "ai.cntk" 20 model_version: 1 21 graph { 22 name: "CNTKGraph" 23 input { 24 name: "Input" 25 type { 26 tensor_type { 27 elem_type: )" + dataType + R"( 28 shape { 29 dim { 30 dim_value: 1 31 } 32 dim { 33 dim_value: 1 34 } 35 dim { 36 dim_value: 2 37 } 38 dim { 39 dim_value: 2 40 } 41 } 42 } 43 } 44 } 45 node { 46 input: "Input" 47 output: "Output" 48 name: "Pooling" 49 op_type: )" + op + R"( 50 attribute { 51 name: "kernel_shape" 52 ints: 2 53 ints: 2 54 type: INTS 55 } 56 attribute { 57 name: "strides" 58 ints: 1 59 ints: 1 60 type: INTS 61 } 62 attribute { 63 name: "pads" 64 ints: 0 65 ints: 0 66 ints: 0 67 ints: 0 68 type: INTS 69 } 70 } 71 output { 72 name: "Output" 73 type { 74 tensor_type { 75 elem_type: 1 76 shape { 77 dim { 78 dim_value: 1 79 } 80 dim { 81 dim_value: 1 82 } 83 dim { 84 dim_value: 1 85 } 86 dim { 87 dim_value: 1 88 } 89 } 90 } 91 } 92 } 93 } 94 opset_import { 95 version: 7 96 })"; 97 } 98 }; 99 100 struct MaxPoolValidFixture : PoolingMainFixture 101 { MaxPoolValidFixtureMaxPoolValidFixture102 MaxPoolValidFixture() : PoolingMainFixture("1", "\"MaxPool\"") { 103 Setup(); 104 } 105 }; 106 107 struct MaxPoolInvalidFixture : PoolingMainFixture 108 { MaxPoolInvalidFixtureMaxPoolInvalidFixture109 MaxPoolInvalidFixture() : PoolingMainFixture("10", "\"MaxPool\"") { } 110 }; 111 112 TEST_CASE_FIXTURE(MaxPoolValidFixture, "ValidMaxPoolTest") 113 { 114 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}}); 115 } 116 117 struct AvgPoolValidFixture : PoolingMainFixture 118 { AvgPoolValidFixtureAvgPoolValidFixture119 AvgPoolValidFixture() : PoolingMainFixture("1", "\"AveragePool\"") { 120 Setup(); 121 } 122 }; 123 124 struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 125 { PoolingWithPadFixturePoolingWithPadFixture126 PoolingWithPadFixture() 127 { 128 m_Prototext = R"( 129 ir_version: 3 130 producer_name: "CNTK" 131 producer_version: "2.5.1" 132 domain: "ai.cntk" 133 model_version: 1 134 graph { 135 name: "CNTKGraph" 136 input { 137 name: "Input" 138 type { 139 tensor_type { 140 elem_type: 1 141 shape { 142 dim { 143 dim_value: 1 144 } 145 dim { 146 dim_value: 1 147 } 148 dim { 149 dim_value: 2 150 } 151 dim { 152 dim_value: 2 153 } 154 } 155 } 156 } 157 } 158 node { 159 input: "Input" 160 output: "Output" 161 name: "Pooling" 162 op_type: "AveragePool" 163 attribute { 164 name: "kernel_shape" 165 ints: 4 166 ints: 4 167 type: INTS 168 } 169 attribute { 170 name: "strides" 171 ints: 1 172 ints: 1 173 type: INTS 174 } 175 attribute { 176 name: "pads" 177 ints: 1 178 ints: 1 179 ints: 1 180 ints: 1 181 type: INTS 182 } 183 attribute { 184 name: "count_include_pad" 185 i: 1 186 type: INT 187 } 188 } 189 output { 190 name: "Output" 191 type { 192 tensor_type { 193 elem_type: 1 194 shape { 195 dim { 196 dim_value: 1 197 } 198 dim { 199 dim_value: 1 200 } 201 dim { 202 dim_value: 1 203 } 204 dim { 205 dim_value: 1 206 } 207 } 208 } 209 } 210 } 211 } 212 opset_import { 213 version: 7 214 })"; 215 Setup(); 216 } 217 }; 218 219 TEST_CASE_FIXTURE(AvgPoolValidFixture, "AveragePoolValid") 220 { 221 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}}); 222 } 223 224 TEST_CASE_FIXTURE(PoolingWithPadFixture, "ValidAvgWithPadTest") 225 { 226 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}}); 227 } 228 229 struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 230 { GlobalAvgFixtureGlobalAvgFixture231 GlobalAvgFixture() 232 { 233 m_Prototext = R"( 234 ir_version: 3 235 producer_name: "CNTK" 236 producer_version: "2.5.1" 237 domain: "ai.cntk" 238 model_version: 1 239 graph { 240 name: "CNTKGraph" 241 input { 242 name: "Input" 243 type { 244 tensor_type { 245 elem_type: 1 246 shape { 247 dim { 248 dim_value: 1 249 } 250 dim { 251 dim_value: 2 252 } 253 dim { 254 dim_value: 2 255 } 256 dim { 257 dim_value: 2 258 } 259 } 260 } 261 } 262 } 263 node { 264 input: "Input" 265 output: "Output" 266 name: "Pooling" 267 op_type: "GlobalAveragePool" 268 } 269 output { 270 name: "Output" 271 type { 272 tensor_type { 273 elem_type: 1 274 shape { 275 dim { 276 dim_value: 1 277 } 278 dim { 279 dim_value: 2 280 } 281 dim { 282 dim_value: 1 283 } 284 dim { 285 dim_value: 1 286 } 287 } 288 } 289 } 290 } 291 } 292 opset_import { 293 version: 7 294 })"; 295 Setup(); 296 } 297 }; 298 299 TEST_CASE_FIXTURE(GlobalAvgFixture, "GlobalAvgTest") 300 { 301 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}}); 302 } 303 304 TEST_CASE_FIXTURE(MaxPoolInvalidFixture, "IncorrectDataTypeMaxPool") 305 { 306 CHECK_THROWS_AS(Setup(), armnn::ParseException); 307 } 308 309 } 310