• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_SHAPE_DATASETS_H
25 #define ARM_COMPUTE_TEST_SHAPE_DATASETS_H
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "tests/framework/datasets/Datasets.h"
29 
30 #include <type_traits>
31 
32 namespace arm_compute
33 {
34 namespace test
35 {
36 namespace datasets
37 {
38 /** Parent type for all for shape datasets. */
39 using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>;
40 
41 /** Data set containing tiny 1D tensor shapes. */
42 class Tiny1DShapes final : public ShapeDataset
43 {
44 public:
Tiny1DShapes()45     Tiny1DShapes()
46         : ShapeDataset("Shape",
47     {
48         TensorShape{ 2U },
49                      TensorShape{ 3U },
50     })
51     {
52     }
53 };
54 
55 /** Data set containing small 1D tensor shapes. */
56 class Small1DShapes final : public ShapeDataset
57 {
58 public:
Small1DShapes()59     Small1DShapes()
60         : ShapeDataset("Shape",
61     {
62         TensorShape{ 128U },
63                      TensorShape{ 256U },
64                      TensorShape{ 512U },
65                      TensorShape{ 1024U }
66     })
67     {
68     }
69 };
70 
71 /** Data set containing tiny 2D tensor shapes. */
72 class Tiny2DShapes final : public ShapeDataset
73 {
74 public:
Tiny2DShapes()75     Tiny2DShapes()
76         : ShapeDataset("Shape",
77     {
78         TensorShape{ 7U, 7U },
79                      TensorShape{ 11U, 13U },
80     })
81     {
82     }
83 };
84 /** Data set containing small 2D tensor shapes. */
85 class Small2DShapes final : public ShapeDataset
86 {
87 public:
Small2DShapes()88     Small2DShapes()
89         : ShapeDataset("Shape",
90     {
91         TensorShape{ 7U, 7U },
92                      TensorShape{ 27U, 13U },
93                      TensorShape{ 128U, 64U }
94     })
95     {
96     }
97 };
98 
99 /** Data set containing tiny 3D tensor shapes. */
100 class Tiny3DShapes final : public ShapeDataset
101 {
102 public:
Tiny3DShapes()103     Tiny3DShapes()
104         : ShapeDataset("Shape",
105     {
106         TensorShape{ 7U, 7U, 5U },
107                      TensorShape{ 23U, 13U, 9U },
108     })
109     {
110     }
111 };
112 
113 /** Data set containing small 3D tensor shapes. */
114 class Small3DShapes final : public ShapeDataset
115 {
116 public:
Small3DShapes()117     Small3DShapes()
118         : ShapeDataset("Shape",
119     {
120         TensorShape{ 1U, 7U, 7U },
121                      TensorShape{ 2U, 5U, 4U },
122 
123                      TensorShape{ 7U, 7U, 5U },
124                      TensorShape{ 16U, 16U, 5U },
125                      TensorShape{ 27U, 13U, 37U },
126     })
127     {
128     }
129 };
130 
131 /** Data set containing tiny 4D tensor shapes. */
132 class Tiny4DShapes final : public ShapeDataset
133 {
134 public:
Tiny4DShapes()135     Tiny4DShapes()
136         : ShapeDataset("Shape",
137     {
138         TensorShape{ 7U, 7U, 5U, 3U },
139                      TensorShape{ 17U, 13U, 7U, 2U },
140     })
141     {
142     }
143 };
144 /** Data set containing small 4D tensor shapes. */
145 class Small4DShapes final : public ShapeDataset
146 {
147 public:
Small4DShapes()148     Small4DShapes()
149         : ShapeDataset("Shape",
150     {
151         TensorShape{ 2U, 7U, 1U, 3U },
152                      TensorShape{ 7U, 7U, 5U, 3U },
153                      TensorShape{ 27U, 13U, 37U, 2U },
154                      TensorShape{ 128U, 64U, 21U, 3U }
155     })
156     {
157     }
158 };
159 
160 /** Data set containing tiny tensor shapes. */
161 class TinyShapes final : public ShapeDataset
162 {
163 public:
TinyShapes()164     TinyShapes()
165         : ShapeDataset("Shape",
166     {
167         // Batch size 1
168         TensorShape{ 9U, 9U },
169                      TensorShape{ 27U, 13U, 2U },
170     })
171     {
172     }
173 };
174 /** Data set containing small tensor shapes. */
175 class SmallShapes final : public ShapeDataset
176 {
177 public:
SmallShapes()178     SmallShapes()
179         : ShapeDataset("Shape",
180     {
181         // Batch size 1
182         TensorShape{ 11U, 11U },
183                      TensorShape{ 16U, 16U },
184                      TensorShape{ 27U, 13U, 7U },
185                      TensorShape{ 31U, 27U, 17U, 2U },
186                      // Batch size 4
187                      TensorShape{ 27U, 13U, 2U, 4U },
188                      // Arbitrary batch size
189                      TensorShape{ 11U, 11U, 3U, 5U }
190     })
191     {
192     }
193 };
194 
195 /** Data set containing pairs of tiny tensor shapes that are broadcast compatible. */
196 class TinyShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
197 {
198 public:
TinyShapesBroadcast()199     TinyShapesBroadcast()
200         : ZipDataset<ShapeDataset, ShapeDataset>(
201               ShapeDataset("Shape0",
202     {
203         TensorShape{ 9U, 9U },
204                      TensorShape{ 10U, 2U, 14U, 2U },
205     }),
206     ShapeDataset("Shape1",
207     {
208         TensorShape{ 9U, 1U, 9U },
209         TensorShape{ 10U },
210     }))
211     {
212     }
213 };
214 /** Data set containing pairs of small tensor shapes that are broadcast compatible. */
215 class SmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
216 {
217 public:
SmallShapesBroadcast()218     SmallShapesBroadcast()
219         : ZipDataset<ShapeDataset, ShapeDataset>(
220               ShapeDataset("Shape0",
221     {
222         TensorShape{ 9U, 9U },
223                      TensorShape{ 27U, 13U, 2U },
224                      TensorShape{ 128U, 1U, 5U, 3U },
225                      TensorShape{ 9U, 9U, 3U, 4U },
226                      TensorShape{ 27U, 13U, 2U, 4U },
227                      TensorShape{ 1U, 1U, 1U, 5U },
228                      TensorShape{ 1U, 16U, 10U, 2U, 128U },
229                      TensorShape{ 1U, 16U, 10U, 2U, 128U }
230     }),
231     ShapeDataset("Shape1",
232     {
233         TensorShape{ 9U, 1U, 2U },
234         TensorShape{ 1U, 13U, 2U },
235         TensorShape{ 128U, 64U, 1U, 3U },
236         TensorShape{ 9U, 1U, 3U },
237         TensorShape{ 1U },
238         TensorShape{ 9U, 9U, 3U, 5U },
239         TensorShape{ 1U, 1U, 1U, 1U, 128U },
240         TensorShape{ 128U }
241     }))
242     {
243     }
244 };
245 
246 /** Data set containing medium tensor shapes. */
247 class MediumShapes final : public ShapeDataset
248 {
249 public:
MediumShapes()250     MediumShapes()
251         : ShapeDataset("Shape",
252     {
253         // Batch size 1
254         TensorShape{ 37U, 37U },
255                      TensorShape{ 27U, 33U, 2U },
256                      // Arbitrary batch size
257                      TensorShape{ 37U, 37U, 3U, 5U }
258     })
259     {
260     }
261 };
262 
263 /** Data set containing medium 2D tensor shapes. */
264 class Medium2DShapes final : public ShapeDataset
265 {
266 public:
Medium2DShapes()267     Medium2DShapes()
268         : ShapeDataset("Shape",
269     {
270         TensorShape{ 42U, 37U },
271                      TensorShape{ 57U, 60U },
272                      TensorShape{ 128U, 64U },
273                      TensorShape{ 83U, 72U },
274                      TensorShape{ 40U, 40U }
275     })
276     {
277     }
278 };
279 
280 /** Data set containing medium 3D tensor shapes. */
281 class Medium3DShapes final : public ShapeDataset
282 {
283 public:
Medium3DShapes()284     Medium3DShapes()
285         : ShapeDataset("Shape",
286     {
287         TensorShape{ 42U, 37U, 8U },
288                      TensorShape{ 57U, 60U, 13U },
289                      TensorShape{ 83U, 72U, 14U }
290     })
291     {
292     }
293 };
294 
295 /** Data set containing medium 4D tensor shapes. */
296 class Medium4DShapes final : public ShapeDataset
297 {
298 public:
Medium4DShapes()299     Medium4DShapes()
300         : ShapeDataset("Shape",
301     {
302         TensorShape{ 42U, 37U, 8U, 15U },
303                      TensorShape{ 57U, 60U, 13U, 8U },
304                      TensorShape{ 83U, 72U, 14U, 5U }
305     })
306     {
307     }
308 };
309 
310 /** Data set containing large tensor shapes. */
311 class LargeShapes final : public ShapeDataset
312 {
313 public:
LargeShapes()314     LargeShapes()
315         : ShapeDataset("Shape",
316     {
317         TensorShape{ 582U, 131U, 1U, 4U },
318     })
319     {
320     }
321 };
322 
323 /** Data set containing pairs of large tensor shapes that are broadcast compatible. */
324 class LargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
325 {
326 public:
LargeShapesBroadcast()327     LargeShapesBroadcast()
328         : ZipDataset<ShapeDataset, ShapeDataset>(
329               ShapeDataset("Shape0",
330     {
331         TensorShape{ 1921U, 541U },
332                      TensorShape{ 1U, 485U, 2U, 3U },
333                      TensorShape{ 4159U, 1U },
334                      TensorShape{ 799U }
335     }),
336     ShapeDataset("Shape1",
337     {
338         TensorShape{ 1921U, 1U, 2U },
339         TensorShape{ 641U, 1U, 2U, 3U },
340         TensorShape{ 1U, 127U, 25U },
341         TensorShape{ 799U, 595U, 1U, 4U }
342     }))
343     {
344     }
345 };
346 
347 /** Data set containing large 1D tensor shapes. */
348 class Large1DShapes final : public ShapeDataset
349 {
350 public:
Large1DShapes()351     Large1DShapes()
352         : ShapeDataset("Shape",
353     {
354         TensorShape{ 1245U }
355     })
356     {
357     }
358 };
359 
360 /** Data set containing large 2D tensor shapes. */
361 class Large2DShapes final : public ShapeDataset
362 {
363 public:
Large2DShapes()364     Large2DShapes()
365         : ShapeDataset("Shape",
366     {
367         TensorShape{ 1245U, 652U }
368     })
369     {
370     }
371 };
372 
373 /** Data set containing large 3D tensor shapes. */
374 class Large3DShapes final : public ShapeDataset
375 {
376 public:
Large3DShapes()377     Large3DShapes()
378         : ShapeDataset("Shape",
379     {
380         TensorShape{ 320U, 240U, 3U }
381     })
382     {
383     }
384 };
385 
386 /** Data set containing large 4D tensor shapes. */
387 class Large4DShapes final : public ShapeDataset
388 {
389 public:
Large4DShapes()390     Large4DShapes()
391         : ShapeDataset("Shape",
392     {
393         TensorShape{ 320U, 123U, 3U, 3U }
394     })
395     {
396     }
397 };
398 
399 /** Data set containing small 3x3 tensor shapes. */
400 class Small3x3Shapes final : public ShapeDataset
401 {
402 public:
Small3x3Shapes()403     Small3x3Shapes()
404         : ShapeDataset("Shape",
405     {
406         TensorShape{ 3U, 3U, 7U, 4U },
407                      TensorShape{ 3U, 3U, 4U, 13U },
408                      TensorShape{ 3U, 3U, 3U, 5U },
409     })
410     {
411     }
412 };
413 
414 /** Data set containing small 3x1 tensor shapes. */
415 class Small3x1Shapes final : public ShapeDataset
416 {
417 public:
Small3x1Shapes()418     Small3x1Shapes()
419         : ShapeDataset("Shape",
420     {
421         TensorShape{ 3U, 1U, 7U, 4U },
422                      TensorShape{ 3U, 1U, 4U, 13U },
423                      TensorShape{ 3U, 1U, 3U, 5U },
424     })
425     {
426     }
427 };
428 
429 /** Data set containing small 1x3 tensor shapes. */
430 class Small1x3Shapes final : public ShapeDataset
431 {
432 public:
Small1x3Shapes()433     Small1x3Shapes()
434         : ShapeDataset("Shape",
435     {
436         TensorShape{ 1U, 3U, 7U, 4U },
437                      TensorShape{ 1U, 3U, 4U, 13U },
438                      TensorShape{ 1U, 3U, 3U, 5U },
439     })
440     {
441     }
442 };
443 
444 /** Data set containing large 3x3 tensor shapes. */
445 class Large3x3Shapes final : public ShapeDataset
446 {
447 public:
Large3x3Shapes()448     Large3x3Shapes()
449         : ShapeDataset("Shape",
450     {
451         TensorShape{ 3U, 3U, 32U, 64U },
452                      TensorShape{ 3U, 3U, 51U, 13U },
453                      TensorShape{ 3U, 3U, 53U, 47U },
454     })
455     {
456     }
457 };
458 
459 /** Data set containing large 3x1 tensor shapes. */
460 class Large3x1Shapes final : public ShapeDataset
461 {
462 public:
Large3x1Shapes()463     Large3x1Shapes()
464         : ShapeDataset("Shape",
465     {
466         TensorShape{ 3U, 1U, 32U, 64U },
467                      TensorShape{ 3U, 1U, 51U, 13U },
468                      TensorShape{ 3U, 1U, 53U, 47U },
469     })
470     {
471     }
472 };
473 
474 /** Data set containing large 1x3 tensor shapes. */
475 class Large1x3Shapes final : public ShapeDataset
476 {
477 public:
Large1x3Shapes()478     Large1x3Shapes()
479         : ShapeDataset("Shape",
480     {
481         TensorShape{ 1U, 3U, 32U, 64U },
482                      TensorShape{ 1U, 3U, 51U, 13U },
483                      TensorShape{ 1U, 3U, 53U, 47U },
484     })
485     {
486     }
487 };
488 
489 /** Data set containing small 5x5 tensor shapes. */
490 class Small5x5Shapes final : public ShapeDataset
491 {
492 public:
Small5x5Shapes()493     Small5x5Shapes()
494         : ShapeDataset("Shape",
495     {
496         TensorShape{ 5U, 5U, 7U, 4U },
497                      TensorShape{ 5U, 5U, 4U, 13U },
498                      TensorShape{ 5U, 5U, 3U, 5U },
499     })
500     {
501     }
502 };
503 
504 /** Data set containing large 5x5 tensor shapes. */
505 class Large5x5Shapes final : public ShapeDataset
506 {
507 public:
Large5x5Shapes()508     Large5x5Shapes()
509         : ShapeDataset("Shape",
510     {
511         TensorShape{ 5U, 5U, 32U, 64U }
512     })
513     {
514     }
515 };
516 
517 /** Data set containing small 5x1 tensor shapes. */
518 class Small5x1Shapes final : public ShapeDataset
519 {
520 public:
Small5x1Shapes()521     Small5x1Shapes()
522         : ShapeDataset("Shape",
523     {
524         TensorShape{ 5U, 1U, 7U, 4U }
525     })
526     {
527     }
528 };
529 
530 /** Data set containing large 5x1 tensor shapes. */
531 class Large5x1Shapes final : public ShapeDataset
532 {
533 public:
Large5x1Shapes()534     Large5x1Shapes()
535         : ShapeDataset("Shape",
536     {
537         TensorShape{ 5U, 1U, 32U, 64U }
538     })
539     {
540     }
541 };
542 
543 /** Data set containing small 1x5 tensor shapes. */
544 class Small1x5Shapes final : public ShapeDataset
545 {
546 public:
Small1x5Shapes()547     Small1x5Shapes()
548         : ShapeDataset("Shape",
549     {
550         TensorShape{ 1U, 5U, 7U, 4U }
551     })
552     {
553     }
554 };
555 
556 /** Data set containing large 1x5 tensor shapes. */
557 class Large1x5Shapes final : public ShapeDataset
558 {
559 public:
Large1x5Shapes()560     Large1x5Shapes()
561         : ShapeDataset("Shape",
562     {
563         TensorShape{ 1U, 5U, 32U, 64U }
564     })
565     {
566     }
567 };
568 
569 /** Data set containing small 1x7 tensor shapes. */
570 class Small1x7Shapes final : public ShapeDataset
571 {
572 public:
Small1x7Shapes()573     Small1x7Shapes()
574         : ShapeDataset("Shape",
575     {
576         TensorShape{ 1U, 7U, 7U, 4U }
577     })
578     {
579     }
580 };
581 
582 /** Data set containing large 1x7 tensor shapes. */
583 class Large1x7Shapes final : public ShapeDataset
584 {
585 public:
Large1x7Shapes()586     Large1x7Shapes()
587         : ShapeDataset("Shape",
588     {
589         TensorShape{ 1U, 7U, 32U, 64U }
590     })
591     {
592     }
593 };
594 
595 /** Data set containing small 7x7 tensor shapes. */
596 class Small7x7Shapes final : public ShapeDataset
597 {
598 public:
Small7x7Shapes()599     Small7x7Shapes()
600         : ShapeDataset("Shape",
601     {
602         TensorShape{ 7U, 7U, 7U, 4U }
603     })
604     {
605     }
606 };
607 
608 /** Data set containing large 7x7 tensor shapes. */
609 class Large7x7Shapes final : public ShapeDataset
610 {
611 public:
Large7x7Shapes()612     Large7x7Shapes()
613         : ShapeDataset("Shape",
614     {
615         TensorShape{ 7U, 7U, 32U, 64U }
616     })
617     {
618     }
619 };
620 
621 /** Data set containing small 7x1 tensor shapes. */
622 class Small7x1Shapes final : public ShapeDataset
623 {
624 public:
Small7x1Shapes()625     Small7x1Shapes()
626         : ShapeDataset("Shape",
627     {
628         TensorShape{ 7U, 1U, 7U, 4U }
629     })
630     {
631     }
632 };
633 
634 /** Data set containing large 7x1 tensor shapes. */
635 class Large7x1Shapes final : public ShapeDataset
636 {
637 public:
Large7x1Shapes()638     Large7x1Shapes()
639         : ShapeDataset("Shape",
640     {
641         TensorShape{ 7U, 1U, 32U, 64U }
642     })
643     {
644     }
645 };
646 
647 /** Data set containing small tensor shapes for deconvolution. */
648 class SmallDeconvolutionShapes final : public ShapeDataset
649 {
650 public:
SmallDeconvolutionShapes()651     SmallDeconvolutionShapes()
652         : ShapeDataset("InputShape",
653     {
654         TensorShape{ 5U, 4U, 3U, 2U },
655                      TensorShape{ 5U, 5U, 3U },
656                      TensorShape{ 11U, 13U, 4U, 3U }
657     })
658     {
659     }
660 };
661 
662 /** Data set containing tiny tensor shapes for direct convolution. */
663 class TinyDirectConvolutionShapes final : public ShapeDataset
664 {
665 public:
TinyDirectConvolutionShapes()666     TinyDirectConvolutionShapes()
667         : ShapeDataset("InputShape",
668     {
669         // Batch size 1
670         TensorShape{ 11U, 13U, 3U },
671                      TensorShape{ 7U, 27U, 3U }
672     })
673     {
674     }
675 };
676 /** Data set containing small tensor shapes for direct convolution. */
677 class SmallDirectConvolutionShapes final : public ShapeDataset
678 {
679 public:
SmallDirectConvolutionShapes()680     SmallDirectConvolutionShapes()
681         : ShapeDataset("InputShape",
682     {
683         // Batch size 1
684         TensorShape{ 32U, 37U, 3U },
685                      // Batch size 4
686                      TensorShape{ 32U, 37U, 3U, 4U },
687     })
688     {
689     }
690 };
691 
692 /** Data set containing small tensor shapes for direct convolution. */
693 class SmallDirectConvolutionTensorShiftShapes final : public ShapeDataset
694 {
695 public:
SmallDirectConvolutionTensorShiftShapes()696     SmallDirectConvolutionTensorShiftShapes()
697         : ShapeDataset("InputShape",
698     {
699         // Batch size 1
700         TensorShape{ 32U, 37U, 3U },
701                      // Batch size 4
702                      TensorShape{ 32U, 37U, 3U, 4U },
703                      // Arbitrary batch size
704                      TensorShape{ 32U, 37U, 3U, 8U }
705     })
706     {
707     }
708 };
709 
710 /** Data set containing small grouped im2col tensor shapes. */
711 class GroupedIm2ColSmallShapes final : public ShapeDataset
712 {
713 public:
GroupedIm2ColSmallShapes()714     GroupedIm2ColSmallShapes()
715         : ShapeDataset("Shape",
716     {
717         TensorShape{ 11U, 11U, 48U },
718                      TensorShape{ 27U, 13U, 24U },
719                      TensorShape{ 128U, 64U, 12U, 3U },
720                      TensorShape{ 11U, 11U, 48U, 4U },
721                      TensorShape{ 27U, 13U, 24U, 4U },
722                      TensorShape{ 11U, 11U, 48U, 5U }
723     })
724     {
725     }
726 };
727 
728 /** Data set containing large grouped im2col tensor shapes. */
729 class GroupedIm2ColLargeShapes final : public ShapeDataset
730 {
731 public:
GroupedIm2ColLargeShapes()732     GroupedIm2ColLargeShapes()
733         : ShapeDataset("Shape",
734     {
735         TensorShape{ 153U, 231U, 12U },
736                      TensorShape{ 123U, 191U, 12U, 2U },
737     })
738     {
739     }
740 };
741 
742 /** Data set containing small grouped weights tensor shapes. */
743 class GroupedWeightsSmallShapes final : public ShapeDataset
744 {
745 public:
GroupedWeightsSmallShapes()746     GroupedWeightsSmallShapes()
747         : ShapeDataset("Shape",
748     {
749         TensorShape{ 3U, 3U, 48U, 120U },
750                      TensorShape{ 1U, 3U, 24U, 240U },
751                      TensorShape{ 3U, 1U, 12U, 480U },
752                      TensorShape{ 5U, 5U, 48U, 120U }
753     })
754     {
755     }
756 };
757 
758 /** Data set containing large grouped weights tensor shapes. */
759 class GroupedWeightsLargeShapes final : public ShapeDataset
760 {
761 public:
GroupedWeightsLargeShapes()762     GroupedWeightsLargeShapes()
763         : ShapeDataset("Shape",
764     {
765         TensorShape{ 9U, 9U, 96U, 240U },
766                      TensorShape{ 13U, 13U, 96U, 240U }
767     })
768     {
769     }
770 };
771 
772 /** Data set containing 2D tensor shapes for DepthConcatenateLayer. */
773 class DepthConcatenateLayerShapes final : public ShapeDataset
774 {
775 public:
DepthConcatenateLayerShapes()776     DepthConcatenateLayerShapes()
777         : ShapeDataset("Shape",
778     {
779         TensorShape{ 322U, 243U },
780                      TensorShape{ 463U, 879U },
781                      TensorShape{ 416U, 651U }
782     })
783     {
784     }
785 };
786 
787 /** Data set containing tensor shapes for ConcatenateLayer. */
788 class ConcatenateLayerShapes final : public ShapeDataset
789 {
790 public:
ConcatenateLayerShapes()791     ConcatenateLayerShapes()
792         : ShapeDataset("Shape",
793     {
794         TensorShape{ 232U, 65U, 3U },
795                      TensorShape{ 432U, 65U, 3U },
796                      TensorShape{ 124U, 65U, 3U },
797                      TensorShape{ 124U, 65U, 3U, 4U }
798     })
799     {
800     }
801 };
802 
803 /** Data set containing global pooling tensor shapes. */
804 class GlobalPoolingShapes final : public ShapeDataset
805 {
806 public:
GlobalPoolingShapes()807     GlobalPoolingShapes()
808         : ShapeDataset("Shape",
809     {
810         // Batch size 1
811         TensorShape{ 9U, 9U },
812                      TensorShape{ 13U, 13U, 2U },
813                      TensorShape{ 27U, 27U, 1U, 3U },
814                      // Batch size 4
815                      TensorShape{ 31U, 31U, 3U, 4U },
816                      TensorShape{ 34U, 34U, 2U, 4U }
817     })
818     {
819     }
820 };
821 /** Data set containing tiny softmax layer shapes. */
822 class SoftmaxLayerTinyShapes final : public ShapeDataset
823 {
824 public:
SoftmaxLayerTinyShapes()825     SoftmaxLayerTinyShapes()
826         : ShapeDataset("Shape",
827     {
828         TensorShape{ 9U, 9U },
829                      TensorShape{ 128U, 10U },
830     })
831     {
832     }
833 };
834 
835 /** Data set containing small softmax layer shapes. */
836 class SoftmaxLayerSmallShapes final : public ShapeDataset
837 {
838 public:
SoftmaxLayerSmallShapes()839     SoftmaxLayerSmallShapes()
840         : ShapeDataset("Shape",
841     {
842         TensorShape{ 9U, 9U },
843                      TensorShape{ 256U, 10U },
844                      TensorShape{ 353U, 8U },
845                      TensorShape{ 781U, 5U },
846     })
847     {
848     }
849 };
850 
851 /** Data set containing large softmax layer shapes. */
852 class SoftmaxLayerLargeShapes final : public ShapeDataset
853 {
854 public:
SoftmaxLayerLargeShapes()855     SoftmaxLayerLargeShapes()
856         : ShapeDataset("Shape",
857     {
858         TensorShape{ 1000U, 10U }
859 
860     })
861     {
862     }
863 };
864 
865 /** Data set containing large and small softmax layer 4D shapes. */
866 class SoftmaxLayer4DShapes final : public ShapeDataset
867 {
868 public:
SoftmaxLayer4DShapes()869     SoftmaxLayer4DShapes()
870         : ShapeDataset("Shape",
871     {
872         TensorShape{ 9U, 9U, 9U, 9U },
873                      TensorShape{ 31U, 10U, 1U, 9U },
874     })
875     {
876     }
877 };
878 
879 /** Data set containing 2D tensor shapes relative to an image size. */
880 class SmallImageShapes final : public ShapeDataset
881 {
882 public:
SmallImageShapes()883     SmallImageShapes()
884         : ShapeDataset("Shape",
885     {
886         TensorShape{ 640U, 480U },
887                      TensorShape{ 800U, 600U },
888     })
889     {
890     }
891 };
892 
893 /** Data set containing 2D tensor shapes relative to an image size. */
894 class LargeImageShapes final : public ShapeDataset
895 {
896 public:
LargeImageShapes()897     LargeImageShapes()
898         : ShapeDataset("Shape",
899     {
900         TensorShape{ 1920U, 1080U },
901                      TensorShape{ 2560U, 1536U },
902                      TensorShape{ 3584U, 2048U }
903     })
904     {
905     }
906 };
907 
908 /** Data set containing small YOLO tensor shapes. */
909 class SmallYOLOShapes final : public ShapeDataset
910 {
911 public:
SmallYOLOShapes()912     SmallYOLOShapes()
913         : ShapeDataset("Shape",
914     {
915         // Batch size 1
916         TensorShape{ 11U, 11U, 270U },
917                      TensorShape{ 27U, 13U, 90U },
918                      TensorShape{ 13U, 12U, 45U, 2U },
919     })
920     {
921     }
922 };
923 
924 /** Data set containing large YOLO tensor shapes. */
925 class LargeYOLOShapes final : public ShapeDataset
926 {
927 public:
LargeYOLOShapes()928     LargeYOLOShapes()
929         : ShapeDataset("Shape",
930     {
931         TensorShape{ 24U, 23U, 270U },
932                      TensorShape{ 51U, 63U, 90U, 2U },
933                      TensorShape{ 76U, 91U, 45U, 3U }
934     })
935     {
936     }
937 };
938 
939 /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel */
940 class SmallGEMMReshape2DShapes final : public ShapeDataset
941 {
942 public:
SmallGEMMReshape2DShapes()943     SmallGEMMReshape2DShapes()
944         : ShapeDataset("Shape",
945     {
946         TensorShape{ 63U, 72U },
947     })
948     {
949     }
950 };
951 
952 /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */
953 class SmallGEMMReshape3DShapes final : public ShapeDataset
954 {
955 public:
SmallGEMMReshape3DShapes()956     SmallGEMMReshape3DShapes()
957         : ShapeDataset("Shape",
958     {
959         TensorShape{ 63U, 9U, 8U },
960     })
961     {
962     }
963 };
964 
965 /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel */
966 class LargeGEMMReshape2DShapes final : public ShapeDataset
967 {
968 public:
LargeGEMMReshape2DShapes()969     LargeGEMMReshape2DShapes()
970         : ShapeDataset("Shape",
971     {
972         TensorShape{ 16U, 27U },
973                      TensorShape{ 345U, 171U }
974     })
975     {
976     }
977 };
978 
979 /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */
980 class LargeGEMMReshape3DShapes final : public ShapeDataset
981 {
982 public:
LargeGEMMReshape3DShapes()983     LargeGEMMReshape3DShapes()
984         : ShapeDataset("Shape",
985     {
986         TensorShape{ 16U, 3U, 9U },
987                      TensorShape{ 345U, 34U, 18U }
988     })
989     {
990     }
991 };
992 
993 /** Data set containing small 2D tensor shapes. */
994 class Small2DNonMaxSuppressionShapes final : public ShapeDataset
995 {
996 public:
Small2DNonMaxSuppressionShapes()997     Small2DNonMaxSuppressionShapes()
998         : ShapeDataset("Shape",
999     {
1000         TensorShape{ 4U, 7U },
1001                      TensorShape{ 4U, 13U },
1002                      TensorShape{ 4U, 64U }
1003     })
1004     {
1005     }
1006 };
1007 
1008 /** Data set containing large 2D tensor shapes. */
1009 class Large2DNonMaxSuppressionShapes final : public ShapeDataset
1010 {
1011 public:
Large2DNonMaxSuppressionShapes()1012     Large2DNonMaxSuppressionShapes()
1013         : ShapeDataset("Shape",
1014     {
1015         TensorShape{ 4U, 113U }
1016     })
1017     {
1018     }
1019 };
1020 
1021 } // namespace datasets
1022 } // namespace test
1023 } // namespace arm_compute
1024 #endif /* ARM_COMPUTE_TEST_SHAPE_DATASETS_H */
1025