1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <gtest/gtest.h>
7
8 #include "argmax-pooling-operator-tester.h"
9
10 #include <xnnpack/params.h>
11
12
FindMaxSinglePassPoolingSize(const argmaxpool_parameters * ukernel)13 static uint32_t FindMaxSinglePassPoolingSize(const argmaxpool_parameters* ukernel) {
14 uint32_t mr = 0;
15 while (ukernel->qr == 0) {
16 mr = std::max<uint32_t>(mr, ukernel->mr);
17 ukernel++;
18 }
19 return mr;
20 }
21
// Returns a copy of the first micro-kernel descriptor whose incremental tile
// `qr` is non-zero. The large-pool tests treat it as the multi-pass kernel and
// sweep pooling sizes over (mr, mr + qr].
// NOTE(review): assumes such an entry exists in the table; the scan has no
// other bound.
static argmaxpool_parameters FindMultiPassMicroKernel(const argmaxpool_parameters* ukernel) {
  const argmaxpool_parameters* entry = ukernel;
  for (; entry->qr == 0; ++entry) {
    // Skip descriptors with qr == 0.
  }
  return *entry;
}
28
// Single image, 1xM (horizontal) pooling windows up to the largest
// single-pass size, across a spread of channel counts.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_small_1xM_pool) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t width = 2; width <= max_pool_size; ++width) {
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(2)
        .input_width(width + 2)
        .pooling_height(1)
        .pooling_width(width)
        .channels(num_channels)
        .TestF32();
    }
  }
}
44
// Single image, 1xM single-pass pooling combined with every 0/1 choice of
// explicit left and right padding.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_small_1xM_pool_with_padding) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t width = 3; width <= max_pool_size; ++width) {
      for (size_t left = 0; left <= 1; ++left) {
        for (size_t right = 0; right <= 1; ++right) {
          ArgmaxPoolingOperatorTester()
            .batch_size(1)
            .input_height(2)
            .input_width(width + 2)
            .padding_left(left)
            .padding_right(right)
            .pooling_height(1)
            .pooling_width(width)
            .channels(num_channels)
            .TestF32();
        }
      }
    }
  }
}
66
// Single image, 1xM single-pass pooling with TensorFlow SAME padding. For each
// window width the input width sweeps pool_size + 1 .. 2 * pool_size, the same
// range used by the Mx1 and multi-pass 1xM TF-SAME tests.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_small_1xM_pool_with_tf_same_padding) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  for (size_t channels = 1; channels <= 100; channels += 15) {
    for (size_t pool_size = 3; pool_size <= FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool); pool_size++) {
      // The upper bound must exceed the start value (pool_size + 1); with
      // `<= pool_size` the sweep is empty and nothing gets tested.
      for (size_t input_width = pool_size + 1; input_width <= pool_size * 2; input_width++) {
        ArgmaxPoolingOperatorTester()
          .batch_size(1)
          .input_height(2)
          .input_width(input_width)
          .padding_tf_same(true)
          .pooling_height(1)
          .pooling_width(pool_size)
          .channels(channels)
          .TestF32();
      }
    }
  }
}
85
// Single image, Mx1 (vertical) pooling windows up to the largest single-pass
// size, across a spread of channel counts.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_small_Mx1_pool) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t height = 2; height <= max_pool_size; ++height) {
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(height + 1)
        .input_width(3)
        .pooling_height(height)
        .pooling_width(1)
        .channels(num_channels)
        .TestF32();
    }
  }
}
101
// Single image, Mx1 single-pass pooling combined with every 0/1 choice of
// explicit top and bottom padding.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_small_Mx1_pool_with_padding) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t height = 2; height <= max_pool_size; ++height) {
      for (size_t top = 0; top <= 1; ++top) {
        for (size_t bottom = 0; bottom <= 1; ++bottom) {
          ArgmaxPoolingOperatorTester()
            .batch_size(1)
            .input_height(height + 1)
            .input_width(3)
            .padding_top(top)
            .padding_bottom(bottom)
            .pooling_height(height)
            .pooling_width(1)
            .channels(num_channels)
            .TestF32();
        }
      }
    }
  }
}
123
// Single image, Mx1 single-pass pooling with TensorFlow SAME padding; the
// input height sweeps window + 1 .. 2 * window for each window height.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_small_Mx1_pool_with_tf_same_padding) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = 2; window <= max_pool_size; ++window) {
      const size_t last_rows = window * 2;
      for (size_t rows = window + 1; rows <= last_rows; ++rows) {
        ArgmaxPoolingOperatorTester()
          .batch_size(1)
          .input_height(rows)
          .input_width(3)
          .padding_tf_same(true)
          .pooling_height(window)
          .pooling_width(1)
          .channels(num_channels)
          .TestF32();
      }
    }
  }
}
142
// Single image, single-pass Mx1 then 1xM pooling where input pixels are
// spaced 5 * channels elements apart.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_small_pool_with_input_stride) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = 2; window <= max_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .input_pixel_stride(5 * num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .input_pixel_stride(5 * num_channels)
        .TestF32();
    }
  }
}
168
// Single image, single-pass Mx1 then 1xM pooling where output pixels are
// spaced 5 * channels elements apart.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_small_pool_with_output_stride) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = 2; window <= max_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .output_pixel_stride(5 * num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .output_pixel_stride(5 * num_channels)
        .TestF32();
    }
  }
}
194
// Single image, 1xM windows sized for the multi-pass kernel: mr + 1 .. mr + qr.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_large_1xM_pool) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t width = multipass.mr + 1; width <= last_pool_size; ++width) {
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(2)
        .input_width(width + 2)
        .pooling_height(1)
        .pooling_width(width)
        .channels(num_channels)
        .TestF32();
    }
  }
}
211
// Single image, multi-pass-sized 1xM pooling combined with every 0/1 choice of
// explicit left and right padding.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_large_1xM_pool_with_padding) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t width = multipass.mr + 1; width <= last_pool_size; ++width) {
      for (size_t left = 0; left <= 1; ++left) {
        for (size_t right = 0; right <= 1; ++right) {
          ArgmaxPoolingOperatorTester()
            .batch_size(1)
            .input_height(2)
            .input_width(width + 2)
            .padding_left(left)
            .padding_right(right)
            .pooling_height(1)
            .pooling_width(width)
            .channels(num_channels)
            .TestF32();
        }
      }
    }
  }
}
234
// Single image, multi-pass-sized 1xM pooling with TensorFlow SAME padding; the
// input width sweeps window + 1 .. 2 * window for each window width.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_large_1xM_pool_with_tf_same_padding) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = multipass.mr + 1; window <= last_pool_size; ++window) {
      const size_t last_cols = window * 2;
      for (size_t cols = window + 1; cols <= last_cols; ++cols) {
        ArgmaxPoolingOperatorTester()
          .batch_size(1)
          .input_height(2)
          .input_width(cols)
          .padding_tf_same(true)
          .pooling_height(1)
          .pooling_width(window)
          .channels(num_channels)
          .TestF32();
      }
    }
  }
}
254
// Single image, Mx1 windows sized for the multi-pass kernel: mr + 1 .. mr + qr.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_large_Mx1_pool) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t height = multipass.mr + 1; height <= last_pool_size; ++height) {
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(height + 1)
        .input_width(3)
        .pooling_height(height)
        .pooling_width(1)
        .channels(num_channels)
        .TestF32();
    }
  }
}
271
// Single image, multi-pass-sized Mx1 pooling combined with every 0/1 choice of
// explicit top and bottom padding.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_large_Mx1_pool_with_padding) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t height = multipass.mr + 1; height <= last_pool_size; ++height) {
      for (size_t top = 0; top <= 1; ++top) {
        for (size_t bottom = 0; bottom <= 1; ++bottom) {
          ArgmaxPoolingOperatorTester()
            .batch_size(1)
            .input_height(height + 1)
            .input_width(3)
            .padding_top(top)
            .padding_bottom(bottom)
            .pooling_height(height)
            .pooling_width(1)
            .channels(num_channels)
            .TestF32();
        }
      }
    }
  }
}
294
// Single image, multi-pass-sized Mx1 pooling with TensorFlow SAME padding. For
// each window height the input height sweeps pool_size + 1 .. 2 * pool_size.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_large_Mx1_pool_with_tf_same_padding) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  for (size_t channels = 1; channels <= 100; channels += 15) {
    for (size_t pool_size = multipass.mr + 1; pool_size <= multipass.mr + multipass.qr; pool_size++) {
      // Start at pool_size + 1, as the other TF-SAME sweeps in this file do, so
      // the input_height == pool_size + 1 case is not skipped.
      for (size_t input_height = pool_size + 1; input_height <= pool_size * 2; input_height++) {
        ArgmaxPoolingOperatorTester()
          .batch_size(1)
          .input_height(input_height)
          .input_width(3)
          .padding_tf_same(true)
          .pooling_height(pool_size)
          .pooling_width(1)
          .channels(channels)
          .TestF32();
      }
    }
  }
}
314
// Single image, multi-pass-sized Mx1 then 1xM pooling where input pixels are
// spaced 5 * channels elements apart.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_large_pool_with_input_stride) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = multipass.mr + 1; window <= last_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .input_pixel_stride(5 * num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .input_pixel_stride(5 * num_channels)
        .TestF32();
    }
  }
}
341
// Single image, multi-pass-sized Mx1 then 1xM pooling where output pixels are
// spaced 5 * channels elements apart.
TEST(ARGMAX_POOLING_NHWC_F32, unit_batch_large_pool_with_output_stride) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = multipass.mr + 1; window <= last_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .output_pixel_stride(5 * num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(1)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .output_pixel_stride(5 * num_channels)
        .TestF32();
    }
  }
}
368
// Batch of three images, single-pass Mx1 then 1xM pooling.
TEST(ARGMAX_POOLING_NHWC_F32, small_batch_small_pool) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kBatchSize = 3;
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = 2; window <= max_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .TestF32();
    }
  }
}
392
// Batch of three images, single-pass Mx1 then 1xM pooling where input pixels
// are spaced 5 * channels elements apart.
TEST(ARGMAX_POOLING_NHWC_F32, small_batch_small_pool_with_input_stride) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kBatchSize = 3;
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = 2; window <= max_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .input_pixel_stride(5 * num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .input_pixel_stride(5 * num_channels)
        .TestF32();
    }
  }
}
418
// Batch of three images, single-pass Mx1 then 1xM pooling where output pixels
// are spaced 5 * channels elements apart.
TEST(ARGMAX_POOLING_NHWC_F32, small_batch_small_pool_with_output_stride) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kBatchSize = 3;
  const uint32_t max_pool_size = FindMaxSinglePassPoolingSize(xnn_params.f32.argmaxpool);
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = 2; window <= max_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .output_pixel_stride(5 * num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .output_pixel_stride(5 * num_channels)
        .TestF32();
    }
  }
}
444
// Batch of three images, multi-pass-sized Mx1 then 1xM pooling.
TEST(ARGMAX_POOLING_NHWC_F32, small_batch_large_pool) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kBatchSize = 3;
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = multipass.mr + 1; window <= last_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .TestF32();
    }
  }
}
469
// Batch of three images, multi-pass-sized Mx1 then 1xM pooling where input
// pixels are spaced 5 * channels elements apart.
TEST(ARGMAX_POOLING_NHWC_F32, small_batch_large_pool_with_input_stride) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kBatchSize = 3;
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = multipass.mr + 1; window <= last_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .input_pixel_stride(5 * num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .input_pixel_stride(5 * num_channels)
        .TestF32();
    }
  }
}
496
// Batch of three images, multi-pass-sized Mx1 then 1xM pooling where output
// pixels are spaced 5 * channels elements apart.
TEST(ARGMAX_POOLING_NHWC_F32, small_batch_large_pool_with_output_stride) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kBatchSize = 3;
  const auto multipass = FindMultiPassMicroKernel(xnn_params.f32.argmaxpool);
  const auto last_pool_size = multipass.mr + multipass.qr;
  for (size_t num_channels = 1; num_channels <= 100; num_channels += 15) {
    for (size_t window = multipass.mr + 1; window <= last_pool_size; ++window) {
      // Vertical (Mx1) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(window + 1)
        .input_width(3)
        .pooling_height(window)
        .pooling_width(1)
        .channels(num_channels)
        .output_pixel_stride(5 * num_channels)
        .TestF32();
      // Horizontal (1xM) window.
      ArgmaxPoolingOperatorTester()
        .batch_size(kBatchSize)
        .input_height(2)
        .input_width(window + 2)
        .pooling_height(1)
        .pooling_width(window)
        .channels(num_channels)
        .output_pixel_stride(5 * num_channels)
        .TestF32();
    }
  }
}
523
// Runs TestSetupF32 with batch_size 3 followed by next_batch_size 5 on a fixed
// 8x8 input, 5x3 window, 24 channels.
TEST(ARGMAX_POOLING_NHWC_F32, setup_increasing_batch) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kFirstBatch = 3;
  constexpr size_t kSecondBatch = 5;
  ArgmaxPoolingOperatorTester()
    .batch_size(kFirstBatch)
    .next_batch_size(kSecondBatch)
    .input_height(8)
    .input_width(8)
    .pooling_height(5)
    .pooling_width(3)
    .channels(24)
    .TestSetupF32();
}
536
// Runs TestSetupF32 with batch_size 5 followed by next_batch_size 3 on a fixed
// 8x8 input, 5x3 window, 24 channels.
TEST(ARGMAX_POOLING_NHWC_F32, setup_decreasing_batch) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kFirstBatch = 5;
  constexpr size_t kSecondBatch = 3;
  ArgmaxPoolingOperatorTester()
    .batch_size(kFirstBatch)
    .next_batch_size(kSecondBatch)
    .input_height(8)
    .input_width(8)
    .pooling_height(5)
    .pooling_width(3)
    .channels(24)
    .TestSetupF32();
}
549
// Runs TestSetupF32 on an 8x8 input whose next input height is first larger
// (9) and then smaller (7) than the initial one.
TEST(ARGMAX_POOLING_NHWC_F32, setup_changing_height) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t next_heights[] = {9, 7};
  for (const uint32_t next_height : next_heights) {
    ArgmaxPoolingOperatorTester()
      .batch_size(3)
      .input_height(8)
      .input_width(8)
      .next_input_height(next_height)
      .pooling_height(5)
      .pooling_width(3)
      .channels(24)
      .TestSetupF32();
  }
}
571
// Runs TestSetupF32 on an 8x8 input whose next input width is first larger (9)
// and then smaller (7) than the initial one.
TEST(ARGMAX_POOLING_NHWC_F32, setup_changing_width) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  const uint32_t next_widths[] = {9, 7};
  for (const uint32_t next_width : next_widths) {
    ArgmaxPoolingOperatorTester()
      .batch_size(3)
      .input_height(8)
      .input_width(8)
      .next_input_width(next_width)
      .pooling_height(5)
      .pooling_width(3)
      .channels(24)
      .TestSetupF32();
  }
}
593
// Runs TestSetupF32 on a 9x8 (height x width) input whose next shape is the
// transposed 8x9.
TEST(ARGMAX_POOLING_NHWC_F32, setup_swap_height_and_width) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
  constexpr size_t kTall = 9;
  constexpr size_t kShort = 8;
  ArgmaxPoolingOperatorTester()
    .batch_size(3)
    .input_height(kTall)
    .input_width(kShort)
    .next_input_height(kShort)
    .next_input_width(kTall)
    .pooling_height(5)
    .pooling_width(3)
    .channels(24)
    .TestSetupF32();
}
607