/*
 * Copyright (c) 2017-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #ifndef ARM_COMPUTE_TEST_SHAPE_DATASETS_H
25 #define ARM_COMPUTE_TEST_SHAPE_DATASETS_H
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "tests/framework/datasets/Datasets.h"
29 
30 #include <type_traits>
31 
32 namespace arm_compute
33 {
34 namespace test
35 {
36 namespace datasets
37 {
38 /** Parent type for all for shape datasets. */
39 using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>;
40 
41 /** Data set containing tiny 1D tensor shapes. */
42 class Tiny1DShapes final : public ShapeDataset
43 {
44 public:
Tiny1DShapes()45     Tiny1DShapes()
46         : ShapeDataset("Shape",
47     {
48         TensorShape{ 2U },
49                      TensorShape{ 3U },
50     })
51     {
52     }
53 };
54 
55 /** Data set containing small 1D tensor shapes. */
56 class Small1DShapes final : public ShapeDataset
57 {
58 public:
Small1DShapes()59     Small1DShapes()
60         : ShapeDataset("Shape",
61     {
62         TensorShape{ 128U },
63                      TensorShape{ 256U },
64                      TensorShape{ 512U },
65                      TensorShape{ 1024U }
66     })
67     {
68     }
69 };
70 
71 /** Data set containing tiny 2D tensor shapes. */
72 class Tiny2DShapes final : public ShapeDataset
73 {
74 public:
Tiny2DShapes()75     Tiny2DShapes()
76         : ShapeDataset("Shape",
77     {
78         TensorShape{ 7U, 7U },
79                      TensorShape{ 11U, 13U },
80     })
81     {
82     }
83 };
84 /** Data set containing small 2D tensor shapes. */
85 class Small2DShapes final : public ShapeDataset
86 {
87 public:
Small2DShapes()88     Small2DShapes()
89         : ShapeDataset("Shape",
90     {
91         TensorShape{ 1U, 7U },
92                      TensorShape{ 5U, 13U },
93                      TensorShape{ 32U, 64U }
94     })
95     {
96     }
97 };
98 
99 /** Data set containing tiny 3D tensor shapes. */
100 class Tiny3DShapes final : public ShapeDataset
101 {
102 public:
Tiny3DShapes()103     Tiny3DShapes()
104         : ShapeDataset("Shape",
105     {
106         TensorShape{ 7U, 7U, 5U },
107                      TensorShape{ 23U, 13U, 9U },
108     })
109     {
110     }
111 };
112 
113 /** Data set containing small 3D tensor shapes. */
114 class Small3DShapes final : public ShapeDataset
115 {
116 public:
Small3DShapes()117     Small3DShapes()
118         : ShapeDataset("Shape",
119     {
120         TensorShape{ 1U, 7U, 7U },
121                      TensorShape{ 2U, 5U, 4U },
122 
123                      TensorShape{ 7U, 7U, 5U },
124                      TensorShape{ 16U, 16U, 5U },
125                      TensorShape{ 27U, 13U, 37U },
126     })
127     {
128     }
129 };
130 
131 /** Data set containing tiny 4D tensor shapes. */
132 class Tiny4DShapes final : public ShapeDataset
133 {
134 public:
Tiny4DShapes()135     Tiny4DShapes()
136         : ShapeDataset("Shape",
137     {
138         TensorShape{ 2U, 7U, 5U, 3U },
139                      TensorShape{ 17U, 13U, 7U, 2U },
140     })
141     {
142     }
143 };
144 /** Data set containing small 4D tensor shapes. */
145 class Small4DShapes final : public ShapeDataset
146 {
147 public:
Small4DShapes()148     Small4DShapes()
149         : ShapeDataset("Shape",
150     {
151         TensorShape{ 2U, 7U, 1U, 3U },
152                      TensorShape{ 7U, 7U, 5U, 3U },
153                      TensorShape{ 27U, 13U, 37U, 2U },
154                      TensorShape{ 128U, 64U, 21U, 3U }
155     })
156     {
157     }
158 };
159 
160 /** Data set containing tiny tensor shapes. */
161 class TinyShapes final : public ShapeDataset
162 {
163 public:
TinyShapes()164     TinyShapes()
165         : ShapeDataset("Shape",
166     {
167         // Batch size 1
168         TensorShape{ 1U, 9U },
169                      TensorShape{ 27U, 13U, 2U },
170     })
171     {
172     }
173 };
174 /** Data set containing small tensor shapes with none of the dimensions equal to 1 (unit). */
175 class SmallNoneUnitShapes final : public ShapeDataset
176 {
177 public:
SmallNoneUnitShapes()178     SmallNoneUnitShapes()
179         : ShapeDataset("Shape",
180     {
181         // Batch size 1
182         TensorShape{ 13U, 11U },
183                      TensorShape{ 16U, 16U },
184                      TensorShape{ 24U, 26U, 5U },
185                      TensorShape{ 7U, 7U, 17U, 2U },
186                      // Batch size 4
187                      TensorShape{ 27U, 13U, 2U, 4U },
188                      // Arbitrary batch size
189                      TensorShape{ 8U, 7U, 5U, 5U }
190     })
191     {
192     }
193 };
194 /** Data set containing small tensor shapes. */
195 class SmallShapes final : public ShapeDataset
196 {
197 public:
SmallShapes()198     SmallShapes()
199         : ShapeDataset("Shape",
200     {
201         // Batch size 1
202         TensorShape{ 3U, 11U },
203                      TensorShape{ 1U, 16U },
204                      TensorShape{ 27U, 13U, 7U },
205                      TensorShape{ 7U, 7U, 17U, 2U },
206                      // Batch size 4 and 2 SIMD iterations
207                      TensorShape{ 33U, 13U, 2U, 4U },
208                      // Arbitrary batch size
209                      TensorShape{ 11U, 11U, 3U, 5U }
210     })
211     {
212     }
213 };
214 
215 /** Data set containing small tensor shapes. */
216 class SmallShapesNoBatches final : public ShapeDataset
217 {
218 public:
SmallShapesNoBatches()219     SmallShapesNoBatches()
220         : ShapeDataset("Shape",
221     {
222         // Batch size 1
223         TensorShape{ 3U, 11U },
224                      TensorShape{ 1U, 16U },
225                      TensorShape{ 27U, 13U, 7U },
226                      TensorShape{ 7U, 7U, 17U },
227                      TensorShape{ 33U, 13U, 2U },
228                      TensorShape{ 11U, 11U, 3U }
229     })
230     {
231     }
232 };
233 
234 /** Data set containing pairs of tiny tensor shapes that are broadcast compatible. */
235 class TinyShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
236 {
237 public:
TinyShapesBroadcast()238     TinyShapesBroadcast()
239         : ZipDataset<ShapeDataset, ShapeDataset>(
240               ShapeDataset("Shape0",
241     {
242         TensorShape{ 9U, 9U },
243                      TensorShape{ 10U, 2U, 14U, 2U },
244     }),
245     ShapeDataset("Shape1",
246     {
247         TensorShape{ 9U, 1U, 9U },
248         TensorShape{ 10U },
249     }))
250     {
251     }
252 };
253 /** Data set containing pairs of tiny tensor shapes that are broadcast compatible and can do in_place calculation. */
254 class TinyShapesBroadcastInplace final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
255 {
256 public:
TinyShapesBroadcastInplace()257     TinyShapesBroadcastInplace()
258         : ZipDataset<ShapeDataset, ShapeDataset>(
259               ShapeDataset("Shape0",
260     {
261         TensorShape{ 9U },
262                      TensorShape{ 10U, 2U, 14U, 2U },
263     }),
264     ShapeDataset("Shape1",
265     {
266         TensorShape{ 9U, 1U, 9U },
267         TensorShape{ 10U },
268     }))
269     {
270     }
271 };
272 /** Data set containing pairs of small tensor shapes that are broadcast compatible. */
273 class SmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
274 {
275 public:
SmallShapesBroadcast()276     SmallShapesBroadcast()
277         : ZipDataset<ShapeDataset, ShapeDataset>(
278               ShapeDataset("Shape0",
279     {
280         TensorShape{ 9U, 9U },
281                      TensorShape{ 27U, 13U, 2U },
282                      TensorShape{ 128U, 1U, 5U, 3U },
283                      TensorShape{ 9U, 9U, 3U, 4U },
284                      TensorShape{ 27U, 13U, 2U, 4U },
285                      TensorShape{ 1U, 1U, 1U, 5U },
286                      TensorShape{ 1U, 16U, 10U, 2U, 128U },
287                      TensorShape{ 1U, 16U, 10U, 2U, 128U }
288     }),
289     ShapeDataset("Shape1",
290     {
291         TensorShape{ 9U, 1U, 2U },
292         TensorShape{ 1U, 13U, 2U },
293         TensorShape{ 128U, 64U, 1U, 3U },
294         TensorShape{ 9U, 1U, 3U },
295         TensorShape{ 1U },
296         TensorShape{ 9U, 9U, 3U, 5U },
297         TensorShape{ 1U, 1U, 1U, 1U, 128U },
298         TensorShape{ 128U }
299     }))
300     {
301     }
302 };
303 
304 class TemporaryLimitedSmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
305 {
306 public:
TemporaryLimitedSmallShapesBroadcast()307     TemporaryLimitedSmallShapesBroadcast()
308         : ZipDataset<ShapeDataset, ShapeDataset>(
309               ShapeDataset("Shape0",
310     {
311         TensorShape{ 1U, 3U, 4U, 2U },  // LHS broadcast X
312         TensorShape{ 6U, 4U, 2U, 3U },  // RHS broadcast X
313         TensorShape{ 7U, 1U, 1U, 4U },  // LHS broadcast Y, Z
314         TensorShape{ 8U, 5U, 6U, 3U },  // RHS broadcast Y, Z
315         TensorShape{ 1U, 1U, 1U, 2U },  // LHS broadcast X, Y, Z
316         TensorShape{ 2U, 6U, 4U, 3U },  // RHS broadcast X, Y, Z
317     }),
318     ShapeDataset("Shape1",
319     {
320         TensorShape{ 5U, 3U, 4U, 2U },
321         TensorShape{ 1U, 4U, 2U, 3U },
322         TensorShape{ 7U, 2U, 3U, 4U },
323         TensorShape{ 8U, 1U, 1U, 3U },
324         TensorShape{ 4U, 7U, 3U, 2U },
325         TensorShape{ 1U, 1U, 1U, 3U },
326     }))
327     {
328     }
329 };
330 
331 class TemporaryLimitedLargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
332 {
333 public:
TemporaryLimitedLargeShapesBroadcast()334     TemporaryLimitedLargeShapesBroadcast()
335         : ZipDataset<ShapeDataset, ShapeDataset>(
336               ShapeDataset("Shape0",
337     {
338         TensorShape{ 127U, 25U, 5U },
339                      TensorShape{ 485, 40U, 10U }
340     }),
341     ShapeDataset("Shape1",
342     {
343         TensorShape{ 1U, 1U, 1U },   // Broadcast in X, Y, Z
344         TensorShape{ 485U, 1U, 1U }, // Broadcast in Y, Z
345     }))
346     {
347     }
348 };
349 
350 /** Data set containing medium tensor shapes. */
351 class MediumShapes final : public ShapeDataset
352 {
353 public:
MediumShapes()354     MediumShapes()
355         : ShapeDataset("Shape",
356     {
357         // Batch size 1
358         TensorShape{ 37U, 37U },
359                      TensorShape{ 27U, 33U, 2U },
360                      // Arbitrary batch size
361                      TensorShape{ 37U, 37U, 3U, 5U }
362     })
363     {
364     }
365 };
366 
367 /** Data set containing medium 2D tensor shapes. */
368 class Medium2DShapes final : public ShapeDataset
369 {
370 public:
Medium2DShapes()371     Medium2DShapes()
372         : ShapeDataset("Shape",
373     {
374         TensorShape{ 42U, 37U },
375                      TensorShape{ 57U, 60U },
376                      TensorShape{ 128U, 64U },
377                      TensorShape{ 83U, 72U },
378                      TensorShape{ 40U, 40U }
379     })
380     {
381     }
382 };
383 
384 /** Data set containing medium 3D tensor shapes. */
385 class Medium3DShapes final : public ShapeDataset
386 {
387 public:
Medium3DShapes()388     Medium3DShapes()
389         : ShapeDataset("Shape",
390     {
391         TensorShape{ 42U, 37U, 8U },
392                      TensorShape{ 57U, 60U, 13U },
393                      TensorShape{ 83U, 72U, 14U }
394     })
395     {
396     }
397 };
398 
399 /** Data set containing medium 4D tensor shapes. */
400 class Medium4DShapes final : public ShapeDataset
401 {
402 public:
Medium4DShapes()403     Medium4DShapes()
404         : ShapeDataset("Shape",
405     {
406         TensorShape{ 42U, 37U, 8U, 15U },
407                      TensorShape{ 57U, 60U, 13U, 8U },
408                      TensorShape{ 83U, 72U, 14U, 5U }
409     })
410     {
411     }
412 };
413 
414 /** Data set containing large tensor shapes. */
415 class LargeShapes final : public ShapeDataset
416 {
417 public:
LargeShapes()418     LargeShapes()
419         : ShapeDataset("Shape",
420     {
421         TensorShape{ 582U, 131U, 1U, 4U },
422     })
423     {
424     }
425 };
426 
427 /** Data set containing large tensor shapes. */
428 class LargeShapesNoBatches final : public ShapeDataset
429 {
430 public:
LargeShapesNoBatches()431     LargeShapesNoBatches()
432         : ShapeDataset("Shape",
433     {
434         TensorShape{ 582U, 131U, 2U },
435     })
436     {
437     }
438 };
439 
440 /** Data set containing pairs of large tensor shapes that are broadcast compatible. */
441 class LargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
442 {
443 public:
LargeShapesBroadcast()444     LargeShapesBroadcast()
445         : ZipDataset<ShapeDataset, ShapeDataset>(
446               ShapeDataset("Shape0",
447     {
448         TensorShape{ 1921U, 541U },
449                      TensorShape{ 1U, 485U, 2U, 3U },
450                      TensorShape{ 4159U, 1U },
451                      TensorShape{ 799U }
452     }),
453     ShapeDataset("Shape1",
454     {
455         TensorShape{ 1921U, 1U, 2U },
456         TensorShape{ 641U, 1U, 2U, 3U },
457         TensorShape{ 1U, 127U, 25U },
458         TensorShape{ 799U, 595U, 1U, 4U }
459     }))
460     {
461     }
462 };
463 
464 /** Data set containing large 1D tensor shapes. */
465 class Large1DShapes final : public ShapeDataset
466 {
467 public:
Large1DShapes()468     Large1DShapes()
469         : ShapeDataset("Shape",
470     {
471         TensorShape{ 1245U }
472     })
473     {
474     }
475 };
476 
477 /** Data set containing large 2D tensor shapes. */
478 class Large2DShapes final : public ShapeDataset
479 {
480 public:
Large2DShapes()481     Large2DShapes()
482         : ShapeDataset("Shape",
483     {
484         TensorShape{ 1245U, 652U }
485     })
486     {
487     }
488 };
489 
490 /** Data set containing large 3D tensor shapes. */
491 class Large3DShapes final : public ShapeDataset
492 {
493 public:
Large3DShapes()494     Large3DShapes()
495         : ShapeDataset("Shape",
496     {
497         TensorShape{ 320U, 240U, 3U }
498     })
499     {
500     }
501 };
502 
503 /** Data set containing large 4D tensor shapes. */
504 class Large4DShapes final : public ShapeDataset
505 {
506 public:
Large4DShapes()507     Large4DShapes()
508         : ShapeDataset("Shape",
509     {
510         TensorShape{ 320U, 123U, 3U, 3U }
511     })
512     {
513     }
514 };
515 
516 /** Data set containing small 3x3 tensor shapes. */
517 class Small3x3Shapes final : public ShapeDataset
518 {
519 public:
Small3x3Shapes()520     Small3x3Shapes()
521         : ShapeDataset("Shape",
522     {
523         TensorShape{ 3U, 3U, 7U, 4U },
524                      TensorShape{ 3U, 3U, 4U, 13U },
525                      TensorShape{ 3U, 3U, 3U, 5U },
526     })
527     {
528     }
529 };
530 
531 /** Data set containing small 3x1 tensor shapes. */
532 class Small3x1Shapes final : public ShapeDataset
533 {
534 public:
Small3x1Shapes()535     Small3x1Shapes()
536         : ShapeDataset("Shape",
537     {
538         TensorShape{ 3U, 1U, 7U, 4U },
539                      TensorShape{ 3U, 1U, 4U, 13U },
540                      TensorShape{ 3U, 1U, 3U, 5U },
541     })
542     {
543     }
544 };
545 
546 /** Data set containing small 1x3 tensor shapes. */
547 class Small1x3Shapes final : public ShapeDataset
548 {
549 public:
Small1x3Shapes()550     Small1x3Shapes()
551         : ShapeDataset("Shape",
552     {
553         TensorShape{ 1U, 3U, 7U, 4U },
554                      TensorShape{ 1U, 3U, 4U, 13U },
555                      TensorShape{ 1U, 3U, 3U, 5U },
556     })
557     {
558     }
559 };
560 
561 /** Data set containing large 3x3 tensor shapes. */
562 class Large3x3Shapes final : public ShapeDataset
563 {
564 public:
Large3x3Shapes()565     Large3x3Shapes()
566         : ShapeDataset("Shape",
567     {
568         TensorShape{ 3U, 3U, 32U, 64U },
569                      TensorShape{ 3U, 3U, 51U, 13U },
570                      TensorShape{ 3U, 3U, 53U, 47U },
571     })
572     {
573     }
574 };
575 
576 /** Data set containing large 3x1 tensor shapes. */
577 class Large3x1Shapes final : public ShapeDataset
578 {
579 public:
Large3x1Shapes()580     Large3x1Shapes()
581         : ShapeDataset("Shape",
582     {
583         TensorShape{ 3U, 1U, 32U, 64U },
584                      TensorShape{ 3U, 1U, 51U, 13U },
585                      TensorShape{ 3U, 1U, 53U, 47U },
586     })
587     {
588     }
589 };
590 
591 /** Data set containing large 1x3 tensor shapes. */
592 class Large1x3Shapes final : public ShapeDataset
593 {
594 public:
Large1x3Shapes()595     Large1x3Shapes()
596         : ShapeDataset("Shape",
597     {
598         TensorShape{ 1U, 3U, 32U, 64U },
599                      TensorShape{ 1U, 3U, 51U, 13U },
600                      TensorShape{ 1U, 3U, 53U, 47U },
601     })
602     {
603     }
604 };
605 
606 /** Data set containing small 5x5 tensor shapes. */
607 class Small5x5Shapes final : public ShapeDataset
608 {
609 public:
Small5x5Shapes()610     Small5x5Shapes()
611         : ShapeDataset("Shape",
612     {
613         TensorShape{ 5U, 5U, 7U, 4U },
614                      TensorShape{ 5U, 5U, 4U, 13U },
615                      TensorShape{ 5U, 5U, 3U, 5U },
616     })
617     {
618     }
619 };
620 
621 /** Data set containing small 5D tensor shapes. */
622 class Small5dShapes final : public ShapeDataset
623 {
624 public:
Small5dShapes()625     Small5dShapes()
626         : ShapeDataset("Shape",
627     {
628         TensorShape{ 5U, 5U, 7U, 4U, 3U },
629                      TensorShape{ 5U, 5U, 4U, 13U, 2U },
630                      TensorShape{ 5U, 5U, 3U, 5U, 2U },
631     })
632     {
633     }
634 };
635 
636 /** Data set containing large 5x5 tensor shapes. */
637 class Large5x5Shapes final : public ShapeDataset
638 {
639 public:
Large5x5Shapes()640     Large5x5Shapes()
641         : ShapeDataset("Shape",
642     {
643         TensorShape{ 5U, 5U, 32U, 64U }
644     })
645     {
646     }
647 };
648 
649 /** Data set containing large 5D tensor shapes. */
650 class Large5dShapes final : public ShapeDataset
651 {
652 public:
Large5dShapes()653     Large5dShapes()
654         : ShapeDataset("Shape",
655     {
656         TensorShape{ 30U, 40U, 30U, 32U, 3U }
657     })
658     {
659     }
660 };
661 
662 /** Data set containing small 5x1 tensor shapes. */
663 class Small5x1Shapes final : public ShapeDataset
664 {
665 public:
Small5x1Shapes()666     Small5x1Shapes()
667         : ShapeDataset("Shape",
668     {
669         TensorShape{ 5U, 1U, 7U, 4U }
670     })
671     {
672     }
673 };
674 
675 /** Data set containing large 5x1 tensor shapes. */
676 class Large5x1Shapes final : public ShapeDataset
677 {
678 public:
Large5x1Shapes()679     Large5x1Shapes()
680         : ShapeDataset("Shape",
681     {
682         TensorShape{ 5U, 1U, 32U, 64U }
683     })
684     {
685     }
686 };
687 
688 /** Data set containing small 1x5 tensor shapes. */
689 class Small1x5Shapes final : public ShapeDataset
690 {
691 public:
Small1x5Shapes()692     Small1x5Shapes()
693         : ShapeDataset("Shape",
694     {
695         TensorShape{ 1U, 5U, 7U, 4U }
696     })
697     {
698     }
699 };
700 
701 /** Data set containing large 1x5 tensor shapes. */
702 class Large1x5Shapes final : public ShapeDataset
703 {
704 public:
Large1x5Shapes()705     Large1x5Shapes()
706         : ShapeDataset("Shape",
707     {
708         TensorShape{ 1U, 5U, 32U, 64U }
709     })
710     {
711     }
712 };
713 
714 /** Data set containing small 1x7 tensor shapes. */
715 class Small1x7Shapes final : public ShapeDataset
716 {
717 public:
Small1x7Shapes()718     Small1x7Shapes()
719         : ShapeDataset("Shape",
720     {
721         TensorShape{ 1U, 7U, 7U, 4U }
722     })
723     {
724     }
725 };
726 
727 /** Data set containing large 1x7 tensor shapes. */
728 class Large1x7Shapes final : public ShapeDataset
729 {
730 public:
Large1x7Shapes()731     Large1x7Shapes()
732         : ShapeDataset("Shape",
733     {
734         TensorShape{ 1U, 7U, 32U, 64U }
735     })
736     {
737     }
738 };
739 
740 /** Data set containing small 7x7 tensor shapes. */
741 class Small7x7Shapes final : public ShapeDataset
742 {
743 public:
Small7x7Shapes()744     Small7x7Shapes()
745         : ShapeDataset("Shape",
746     {
747         TensorShape{ 7U, 7U, 7U, 4U }
748     })
749     {
750     }
751 };
752 
753 /** Data set containing large 7x7 tensor shapes. */
754 class Large7x7Shapes final : public ShapeDataset
755 {
756 public:
Large7x7Shapes()757     Large7x7Shapes()
758         : ShapeDataset("Shape",
759     {
760         TensorShape{ 7U, 7U, 32U, 64U }
761     })
762     {
763     }
764 };
765 
766 /** Data set containing small 7x1 tensor shapes. */
767 class Small7x1Shapes final : public ShapeDataset
768 {
769 public:
Small7x1Shapes()770     Small7x1Shapes()
771         : ShapeDataset("Shape",
772     {
773         TensorShape{ 7U, 1U, 7U, 4U }
774     })
775     {
776     }
777 };
778 
779 /** Data set containing large 7x1 tensor shapes. */
780 class Large7x1Shapes final : public ShapeDataset
781 {
782 public:
Large7x1Shapes()783     Large7x1Shapes()
784         : ShapeDataset("Shape",
785     {
786         TensorShape{ 7U, 1U, 32U, 64U }
787     })
788     {
789     }
790 };
791 
792 /** Data set containing small tensor shapes for deconvolution. */
793 class SmallDeconvolutionShapes final : public ShapeDataset
794 {
795 public:
SmallDeconvolutionShapes()796     SmallDeconvolutionShapes()
797         : ShapeDataset("InputShape",
798     {
799         // Multiple Vector Loops for FP32
800         TensorShape{ 5U, 4U, 3U, 2U },
801                      TensorShape{ 5U, 5U, 3U },
802                      TensorShape{ 11U, 13U, 4U, 3U }
803     })
804     {
805     }
806 };
807 
808 class SmallDeconvolutionShapesWithLargerChannels final : public ShapeDataset
809 {
810 public:
SmallDeconvolutionShapesWithLargerChannels()811     SmallDeconvolutionShapesWithLargerChannels()
812         : ShapeDataset("InputShape",
813     {
814         // Multiple Vector Loops for all data types
815         TensorShape{ 5U, 5U, 35U }
816     })
817     {
818     }
819 };
820 
821 /** Data set containing tiny tensor shapes for direct convolution. */
822 class TinyDirectConvolutionShapes final : public ShapeDataset
823 {
824 public:
TinyDirectConvolutionShapes()825     TinyDirectConvolutionShapes()
826         : ShapeDataset("InputShape",
827     {
828         // Batch size 1
829         TensorShape{ 11U, 13U, 3U },
830                      TensorShape{ 7U, 27U, 3U }
831     })
832     {
833     }
834 };
835 /** Data set containing small tensor shapes for direct convolution. */
836 class SmallDirectConvolutionShapes final : public ShapeDataset
837 {
838 public:
SmallDirectConvolutionShapes()839     SmallDirectConvolutionShapes()
840         : ShapeDataset("InputShape",
841     {
842         // Batch size 1
843         TensorShape{ 32U, 37U, 3U },
844                      // Batch size 4
845                      TensorShape{ 6U, 9U, 5U, 4U },
846     })
847     {
848     }
849 };
850 
851 class SmallDirectConv3DShapes final : public ShapeDataset
852 {
853 public:
SmallDirectConv3DShapes()854     SmallDirectConv3DShapes()
855         : ShapeDataset("InputShape",
856     {
857         // Batch size 2
858         TensorShape{ 1U, 3U, 4U, 5U, 2U },
859                      // Batch size 3
860                      TensorShape{ 7U, 27U, 3U, 6U, 3U },
861                      // Batch size 1
862                      TensorShape{ 32U, 37U, 13U, 1U, 1U },
863     })
864     {
865     }
866 };
867 
868 /** Data set containing small tensor shapes for direct convolution. */
869 class SmallDirectConvolutionTensorShiftShapes final : public ShapeDataset
870 {
871 public:
SmallDirectConvolutionTensorShiftShapes()872     SmallDirectConvolutionTensorShiftShapes()
873         : ShapeDataset("InputShape",
874     {
875         // Batch size 1
876         TensorShape{ 32U, 37U, 3U },
877                      // Batch size 4
878                      TensorShape{ 32U, 37U, 3U, 4U },
879                      // Arbitrary batch size
880                      TensorShape{ 32U, 37U, 3U, 8U }
881     })
882     {
883     }
884 };
885 
886 /** Data set containing small grouped im2col tensor shapes. */
887 class GroupedIm2ColSmallShapes final : public ShapeDataset
888 {
889 public:
GroupedIm2ColSmallShapes()890     GroupedIm2ColSmallShapes()
891         : ShapeDataset("Shape",
892     {
893         TensorShape{ 11U, 11U, 48U },
894                      TensorShape{ 27U, 13U, 24U },
895                      TensorShape{ 128U, 64U, 12U, 3U },
896                      TensorShape{ 11U, 11U, 48U, 4U },
897                      TensorShape{ 27U, 13U, 24U, 4U },
898                      TensorShape{ 11U, 11U, 48U, 5U }
899     })
900     {
901     }
902 };
903 
904 /** Data set containing large grouped im2col tensor shapes. */
905 class GroupedIm2ColLargeShapes final : public ShapeDataset
906 {
907 public:
GroupedIm2ColLargeShapes()908     GroupedIm2ColLargeShapes()
909         : ShapeDataset("Shape",
910     {
911         TensorShape{ 153U, 231U, 12U },
912                      TensorShape{ 123U, 191U, 12U, 2U },
913     })
914     {
915     }
916 };
917 
918 /** Data set containing small grouped weights tensor shapes. */
919 class GroupedWeightsSmallShapes final : public ShapeDataset
920 {
921 public:
GroupedWeightsSmallShapes()922     GroupedWeightsSmallShapes()
923         : ShapeDataset("Shape",
924     {
925         TensorShape{ 3U, 3U, 48U, 120U },
926                      TensorShape{ 1U, 3U, 24U, 240U },
927                      TensorShape{ 3U, 1U, 12U, 480U },
928                      TensorShape{ 5U, 5U, 48U, 120U }
929     })
930     {
931     }
932 };
933 
934 /** Data set containing large grouped weights tensor shapes. */
935 class GroupedWeightsLargeShapes final : public ShapeDataset
936 {
937 public:
GroupedWeightsLargeShapes()938     GroupedWeightsLargeShapes()
939         : ShapeDataset("Shape",
940     {
941         TensorShape{ 9U, 9U, 96U, 240U },
942                      TensorShape{ 13U, 13U, 96U, 240U }
943     })
944     {
945     }
946 };
947 
948 /** Data set containing 2D tensor shapes for DepthConcatenateLayer. */
949 class DepthConcatenateLayerShapes final : public ShapeDataset
950 {
951 public:
DepthConcatenateLayerShapes()952     DepthConcatenateLayerShapes()
953         : ShapeDataset("Shape",
954     {
955         TensorShape{ 322U, 243U },
956                      TensorShape{ 463U, 879U },
957                      TensorShape{ 416U, 651U }
958     })
959     {
960     }
961 };
962 
963 /** Data set containing tensor shapes for ConcatenateLayer. */
964 class ConcatenateLayerShapes final : public ShapeDataset
965 {
966 public:
ConcatenateLayerShapes()967     ConcatenateLayerShapes()
968         : ShapeDataset("Shape",
969     {
970         TensorShape{ 232U, 65U, 3U },
971                      TensorShape{ 432U, 65U, 3U },
972                      TensorShape{ 124U, 65U, 3U },
973                      TensorShape{ 124U, 65U, 3U, 4U }
974     })
975     {
976     }
977 };
978 
979 /** Data set containing global pooling tensor shapes. */
980 class GlobalPoolingShapes final : public ShapeDataset
981 {
982 public:
GlobalPoolingShapes()983     GlobalPoolingShapes()
984         : ShapeDataset("Shape",
985     {
986         // Batch size 1
987         TensorShape{ 9U, 9U },
988                      TensorShape{ 13U, 13U, 2U },
989                      TensorShape{ 27U, 27U, 1U, 3U },
990                      // Batch size 4
991                      TensorShape{ 31U, 31U, 3U, 4U },
992                      TensorShape{ 34U, 34U, 2U, 4U }
993     })
994     {
995     }
996 };
997 /** Data set containing tiny softmax layer shapes. */
998 class SoftmaxLayerTinyShapes final : public ShapeDataset
999 {
1000 public:
SoftmaxLayerTinyShapes()1001     SoftmaxLayerTinyShapes()
1002         : ShapeDataset("Shape",
1003     {
1004         TensorShape{ 9U, 9U },
1005                      TensorShape{ 128U, 10U },
1006     })
1007     {
1008     }
1009 };
1010 
1011 /** Data set containing small softmax layer shapes. */
1012 class SoftmaxLayerSmallShapes final : public ShapeDataset
1013 {
1014 public:
SoftmaxLayerSmallShapes()1015     SoftmaxLayerSmallShapes()
1016         : ShapeDataset("Shape",
1017     {
1018         TensorShape{ 1U, 9U },
1019                      TensorShape{ 256U, 10U },
1020                      TensorShape{ 353U, 8U },
1021                      TensorShape{ 781U, 5U },
1022     })
1023     {
1024     }
1025 };
1026 
1027 /** Data set containing large softmax layer shapes. */
1028 class SoftmaxLayerLargeShapes final : public ShapeDataset
1029 {
1030 public:
SoftmaxLayerLargeShapes()1031     SoftmaxLayerLargeShapes()
1032         : ShapeDataset("Shape",
1033     {
1034         TensorShape{ 1000U, 10U }
1035 
1036     })
1037     {
1038     }
1039 };
1040 
1041 /** Data set containing large and small softmax layer 4D shapes. */
1042 class SoftmaxLayer4DShapes final : public ShapeDataset
1043 {
1044 public:
SoftmaxLayer4DShapes()1045     SoftmaxLayer4DShapes()
1046         : ShapeDataset("Shape",
1047     {
1048         TensorShape{ 9U, 9U, 9U, 9U },
1049                      TensorShape{ 31U, 10U, 1U, 9U },
1050     })
1051     {
1052     }
1053 };
1054 
1055 /** Data set containing 2D tensor shapes relative to an image size. */
1056 class SmallImageShapes final : public ShapeDataset
1057 {
1058 public:
SmallImageShapes()1059     SmallImageShapes()
1060         : ShapeDataset("Shape",
1061     {
1062         TensorShape{ 640U, 480U },
1063                      TensorShape{ 800U, 600U },
1064     })
1065     {
1066     }
1067 };
1068 
1069 /** Data set containing 2D tensor shapes relative to an image size. */
1070 class LargeImageShapes final : public ShapeDataset
1071 {
1072 public:
LargeImageShapes()1073     LargeImageShapes()
1074         : ShapeDataset("Shape",
1075     {
1076         TensorShape{ 1920U, 1080U },
1077                      TensorShape{ 2560U, 1536U },
1078                      TensorShape{ 3584U, 2048U }
1079     })
1080     {
1081     }
1082 };
1083 
1084 /** Data set containing small YOLO tensor shapes. */
1085 class SmallYOLOShapes final : public ShapeDataset
1086 {
1087 public:
SmallYOLOShapes()1088     SmallYOLOShapes()
1089         : ShapeDataset("Shape",
1090     {
1091         // Batch size 1
1092         TensorShape{ 11U, 11U, 270U },
1093                      TensorShape{ 27U, 13U, 90U },
1094                      TensorShape{ 13U, 12U, 45U, 2U },
1095     })
1096     {
1097     }
1098 };
1099 
1100 /** Data set containing large YOLO tensor shapes. */
1101 class LargeYOLOShapes final : public ShapeDataset
1102 {
1103 public:
LargeYOLOShapes()1104     LargeYOLOShapes()
1105         : ShapeDataset("Shape",
1106     {
1107         TensorShape{ 24U, 23U, 270U },
1108                      TensorShape{ 51U, 63U, 90U, 2U },
1109                      TensorShape{ 76U, 91U, 45U, 3U }
1110     })
1111     {
1112     }
1113 };
1114 
1115 /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel */
1116 class SmallGEMMReshape2DShapes final : public ShapeDataset
1117 {
1118 public:
SmallGEMMReshape2DShapes()1119     SmallGEMMReshape2DShapes()
1120         : ShapeDataset("Shape",
1121     {
1122         TensorShape{ 63U, 72U },
1123     })
1124     {
1125     }
1126 };
1127 
1128 /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */
1129 class SmallGEMMReshape3DShapes final : public ShapeDataset
1130 {
1131 public:
SmallGEMMReshape3DShapes()1132     SmallGEMMReshape3DShapes()
1133         : ShapeDataset("Shape",
1134     {
1135         TensorShape{ 63U, 9U, 8U },
1136     })
1137     {
1138     }
1139 };
1140 
1141 /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel */
1142 class LargeGEMMReshape2DShapes final : public ShapeDataset
1143 {
1144 public:
LargeGEMMReshape2DShapes()1145     LargeGEMMReshape2DShapes()
1146         : ShapeDataset("Shape",
1147     {
1148         TensorShape{ 16U, 27U },
1149                      TensorShape{ 345U, 171U }
1150     })
1151     {
1152     }
1153 };
1154 
1155 /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */
1156 class LargeGEMMReshape3DShapes final : public ShapeDataset
1157 {
1158 public:
LargeGEMMReshape3DShapes()1159     LargeGEMMReshape3DShapes()
1160         : ShapeDataset("Shape",
1161     {
1162         TensorShape{ 16U, 3U, 9U },
1163                      TensorShape{ 345U, 34U, 18U }
1164     })
1165     {
1166     }
1167 };
1168 
1169 /** Data set containing small 2D tensor shapes. */
1170 class Small2DNonMaxSuppressionShapes final : public ShapeDataset
1171 {
1172 public:
Small2DNonMaxSuppressionShapes()1173     Small2DNonMaxSuppressionShapes()
1174         : ShapeDataset("Shape",
1175     {
1176         TensorShape{ 4U, 7U },
1177                      TensorShape{ 4U, 13U },
1178                      TensorShape{ 4U, 64U }
1179     })
1180     {
1181     }
1182 };
1183 
1184 /** Data set containing large 2D tensor shapes. */
1185 class Large2DNonMaxSuppressionShapes final : public ShapeDataset
1186 {
1187 public:
Large2DNonMaxSuppressionShapes()1188     Large2DNonMaxSuppressionShapes()
1189         : ShapeDataset("Shape",
1190     {
1191         TensorShape{ 4U, 113U }
1192     })
1193     {
1194     }
1195 };
1196 
1197 } // namespace datasets
1198 } // namespace test
1199 } // namespace arm_compute
1200 #endif /* ARM_COMPUTE_TEST_SHAPE_DATASETS_H */
1201