1 /* 2 * Copyright (c) 2017-2023 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_TEST_SHAPE_DATASETS_H 25 #define ARM_COMPUTE_TEST_SHAPE_DATASETS_H 26 27 #include "arm_compute/core/TensorShape.h" 28 #include "tests/framework/datasets/Datasets.h" 29 30 #include <type_traits> 31 32 namespace arm_compute 33 { 34 namespace test 35 { 36 namespace datasets 37 { 38 /** Parent type for all for shape datasets. */ 39 using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>; 40 41 /** Data set containing tiny 1D tensor shapes. */ 42 class Tiny1DShapes final : public ShapeDataset 43 { 44 public: Tiny1DShapes()45 Tiny1DShapes() 46 : ShapeDataset("Shape", 47 { 48 TensorShape{ 2U }, 49 TensorShape{ 3U }, 50 }) 51 { 52 } 53 }; 54 55 /** Data set containing small 1D tensor shapes. */ 56 class Small1DShapes final : public ShapeDataset 57 { 58 public: Small1DShapes()59 Small1DShapes() 60 : ShapeDataset("Shape", 61 { 62 TensorShape{ 128U }, 63 TensorShape{ 256U }, 64 TensorShape{ 512U }, 65 TensorShape{ 1024U } 66 }) 67 { 68 } 69 }; 70 71 /** Data set containing tiny 2D tensor shapes. */ 72 class Tiny2DShapes final : public ShapeDataset 73 { 74 public: Tiny2DShapes()75 Tiny2DShapes() 76 : ShapeDataset("Shape", 77 { 78 TensorShape{ 7U, 7U }, 79 TensorShape{ 11U, 13U }, 80 }) 81 { 82 } 83 }; 84 /** Data set containing small 2D tensor shapes. */ 85 class Small2DShapes final : public ShapeDataset 86 { 87 public: Small2DShapes()88 Small2DShapes() 89 : ShapeDataset("Shape", 90 { 91 TensorShape{ 1U, 7U }, 92 TensorShape{ 5U, 13U }, 93 TensorShape{ 32U, 64U } 94 }) 95 { 96 } 97 }; 98 99 /** Data set containing tiny 3D tensor shapes. */ 100 class Tiny3DShapes final : public ShapeDataset 101 { 102 public: Tiny3DShapes()103 Tiny3DShapes() 104 : ShapeDataset("Shape", 105 { 106 TensorShape{ 7U, 7U, 5U }, 107 TensorShape{ 23U, 13U, 9U }, 108 }) 109 { 110 } 111 }; 112 113 /** Data set containing small 3D tensor shapes. */ 114 class Small3DShapes final : public ShapeDataset 115 { 116 public: Small3DShapes()117 Small3DShapes() 118 : ShapeDataset("Shape", 119 { 120 TensorShape{ 1U, 7U, 7U }, 121 TensorShape{ 2U, 5U, 4U }, 122 123 TensorShape{ 7U, 7U, 5U }, 124 TensorShape{ 16U, 16U, 5U }, 125 TensorShape{ 27U, 13U, 37U }, 126 }) 127 { 128 } 129 }; 130 131 /** Data set containing tiny 4D tensor shapes. */ 132 class Tiny4DShapes final : public ShapeDataset 133 { 134 public: Tiny4DShapes()135 Tiny4DShapes() 136 : ShapeDataset("Shape", 137 { 138 TensorShape{ 2U, 7U, 5U, 3U }, 139 TensorShape{ 17U, 13U, 7U, 2U }, 140 }) 141 { 142 } 143 }; 144 /** Data set containing small 4D tensor shapes. */ 145 class Small4DShapes final : public ShapeDataset 146 { 147 public: Small4DShapes()148 Small4DShapes() 149 : ShapeDataset("Shape", 150 { 151 TensorShape{ 2U, 7U, 1U, 3U }, 152 TensorShape{ 7U, 7U, 5U, 3U }, 153 TensorShape{ 27U, 13U, 37U, 2U }, 154 TensorShape{ 128U, 64U, 21U, 3U } 155 }) 156 { 157 } 158 }; 159 160 /** Data set containing tiny tensor shapes. */ 161 class TinyShapes final : public ShapeDataset 162 { 163 public: TinyShapes()164 TinyShapes() 165 : ShapeDataset("Shape", 166 { 167 // Batch size 1 168 TensorShape{ 1U, 9U }, 169 TensorShape{ 27U, 13U, 2U }, 170 }) 171 { 172 } 173 }; 174 /** Data set containing small tensor shapes with none of the dimensions equal to 1 (unit). */ 175 class SmallNoneUnitShapes final : public ShapeDataset 176 { 177 public: SmallNoneUnitShapes()178 SmallNoneUnitShapes() 179 : ShapeDataset("Shape", 180 { 181 // Batch size 1 182 TensorShape{ 13U, 11U }, 183 TensorShape{ 16U, 16U }, 184 TensorShape{ 24U, 26U, 5U }, 185 TensorShape{ 7U, 7U, 17U, 2U }, 186 // Batch size 4 187 TensorShape{ 27U, 13U, 2U, 4U }, 188 // Arbitrary batch size 189 TensorShape{ 8U, 7U, 5U, 5U } 190 }) 191 { 192 } 193 }; 194 /** Data set containing small tensor shapes. */ 195 class SmallShapes final : public ShapeDataset 196 { 197 public: SmallShapes()198 SmallShapes() 199 : ShapeDataset("Shape", 200 { 201 // Batch size 1 202 TensorShape{ 3U, 11U }, 203 TensorShape{ 1U, 16U }, 204 TensorShape{ 27U, 13U, 7U }, 205 TensorShape{ 7U, 7U, 17U, 2U }, 206 // Batch size 4 and 2 SIMD iterations 207 TensorShape{ 33U, 13U, 2U, 4U }, 208 // Arbitrary batch size 209 TensorShape{ 11U, 11U, 3U, 5U } 210 }) 211 { 212 } 213 }; 214 215 /** Data set containing small tensor shapes. */ 216 class SmallShapesNoBatches final : public ShapeDataset 217 { 218 public: SmallShapesNoBatches()219 SmallShapesNoBatches() 220 : ShapeDataset("Shape", 221 { 222 // Batch size 1 223 TensorShape{ 3U, 11U }, 224 TensorShape{ 1U, 16U }, 225 TensorShape{ 27U, 13U, 7U }, 226 TensorShape{ 7U, 7U, 17U }, 227 TensorShape{ 33U, 13U, 2U }, 228 TensorShape{ 11U, 11U, 3U } 229 }) 230 { 231 } 232 }; 233 234 /** Data set containing pairs of tiny tensor shapes that are broadcast compatible. */ 235 class TinyShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 236 { 237 public: TinyShapesBroadcast()238 TinyShapesBroadcast() 239 : ZipDataset<ShapeDataset, ShapeDataset>( 240 ShapeDataset("Shape0", 241 { 242 TensorShape{ 9U, 9U }, 243 TensorShape{ 10U, 2U, 14U, 2U }, 244 }), 245 ShapeDataset("Shape1", 246 { 247 TensorShape{ 9U, 1U, 9U }, 248 TensorShape{ 10U }, 249 })) 250 { 251 } 252 }; 253 /** Data set containing pairs of tiny tensor shapes that are broadcast compatible and can do in_place calculation. */ 254 class TinyShapesBroadcastInplace final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 255 { 256 public: TinyShapesBroadcastInplace()257 TinyShapesBroadcastInplace() 258 : ZipDataset<ShapeDataset, ShapeDataset>( 259 ShapeDataset("Shape0", 260 { 261 TensorShape{ 9U }, 262 TensorShape{ 10U, 2U, 14U, 2U }, 263 }), 264 ShapeDataset("Shape1", 265 { 266 TensorShape{ 9U, 1U, 9U }, 267 TensorShape{ 10U }, 268 })) 269 { 270 } 271 }; 272 /** Data set containing pairs of small tensor shapes that are broadcast compatible. */ 273 class SmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 274 { 275 public: SmallShapesBroadcast()276 SmallShapesBroadcast() 277 : ZipDataset<ShapeDataset, ShapeDataset>( 278 ShapeDataset("Shape0", 279 { 280 TensorShape{ 9U, 9U }, 281 TensorShape{ 27U, 13U, 2U }, 282 TensorShape{ 128U, 1U, 5U, 3U }, 283 TensorShape{ 9U, 9U, 3U, 4U }, 284 TensorShape{ 27U, 13U, 2U, 4U }, 285 TensorShape{ 1U, 1U, 1U, 5U }, 286 TensorShape{ 1U, 16U, 10U, 2U, 128U }, 287 TensorShape{ 1U, 16U, 10U, 2U, 128U } 288 }), 289 ShapeDataset("Shape1", 290 { 291 TensorShape{ 9U, 1U, 2U }, 292 TensorShape{ 1U, 13U, 2U }, 293 TensorShape{ 128U, 64U, 1U, 3U }, 294 TensorShape{ 9U, 1U, 3U }, 295 TensorShape{ 1U }, 296 TensorShape{ 9U, 9U, 3U, 5U }, 297 TensorShape{ 1U, 1U, 1U, 1U, 128U }, 298 TensorShape{ 128U } 299 })) 300 { 301 } 302 }; 303 304 class TemporaryLimitedSmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 305 { 306 public: TemporaryLimitedSmallShapesBroadcast()307 TemporaryLimitedSmallShapesBroadcast() 308 : ZipDataset<ShapeDataset, ShapeDataset>( 309 ShapeDataset("Shape0", 310 { 311 TensorShape{ 1U, 3U, 4U, 2U }, // LHS broadcast X 312 TensorShape{ 6U, 4U, 2U, 3U }, // RHS broadcast X 313 TensorShape{ 7U, 1U, 1U, 4U }, // LHS broadcast Y, Z 314 TensorShape{ 8U, 5U, 6U, 3U }, // RHS broadcast Y, Z 315 TensorShape{ 1U, 1U, 1U, 2U }, // LHS broadcast X, Y, Z 316 TensorShape{ 2U, 6U, 4U, 3U }, // RHS broadcast X, Y, Z 317 }), 318 ShapeDataset("Shape1", 319 { 320 TensorShape{ 5U, 3U, 4U, 2U }, 321 TensorShape{ 1U, 4U, 2U, 3U }, 322 TensorShape{ 7U, 2U, 3U, 4U }, 323 TensorShape{ 8U, 1U, 1U, 3U }, 324 TensorShape{ 4U, 7U, 3U, 2U }, 325 TensorShape{ 1U, 1U, 1U, 3U }, 326 })) 327 { 328 } 329 }; 330 331 class TemporaryLimitedLargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 332 { 333 public: TemporaryLimitedLargeShapesBroadcast()334 TemporaryLimitedLargeShapesBroadcast() 335 : ZipDataset<ShapeDataset, ShapeDataset>( 336 ShapeDataset("Shape0", 337 { 338 TensorShape{ 127U, 25U, 5U }, 339 TensorShape{ 485, 40U, 10U } 340 }), 341 ShapeDataset("Shape1", 342 { 343 TensorShape{ 1U, 1U, 1U }, // Broadcast in X, Y, Z 344 TensorShape{ 485U, 1U, 1U }, // Broadcast in Y, Z 345 })) 346 { 347 } 348 }; 349 350 /** Data set containing medium tensor shapes. */ 351 class MediumShapes final : public ShapeDataset 352 { 353 public: MediumShapes()354 MediumShapes() 355 : ShapeDataset("Shape", 356 { 357 // Batch size 1 358 TensorShape{ 37U, 37U }, 359 TensorShape{ 27U, 33U, 2U }, 360 // Arbitrary batch size 361 TensorShape{ 37U, 37U, 3U, 5U } 362 }) 363 { 364 } 365 }; 366 367 /** Data set containing medium 2D tensor shapes. */ 368 class Medium2DShapes final : public ShapeDataset 369 { 370 public: Medium2DShapes()371 Medium2DShapes() 372 : ShapeDataset("Shape", 373 { 374 TensorShape{ 42U, 37U }, 375 TensorShape{ 57U, 60U }, 376 TensorShape{ 128U, 64U }, 377 TensorShape{ 83U, 72U }, 378 TensorShape{ 40U, 40U } 379 }) 380 { 381 } 382 }; 383 384 /** Data set containing medium 3D tensor shapes. */ 385 class Medium3DShapes final : public ShapeDataset 386 { 387 public: Medium3DShapes()388 Medium3DShapes() 389 : ShapeDataset("Shape", 390 { 391 TensorShape{ 42U, 37U, 8U }, 392 TensorShape{ 57U, 60U, 13U }, 393 TensorShape{ 83U, 72U, 14U } 394 }) 395 { 396 } 397 }; 398 399 /** Data set containing medium 4D tensor shapes. */ 400 class Medium4DShapes final : public ShapeDataset 401 { 402 public: Medium4DShapes()403 Medium4DShapes() 404 : ShapeDataset("Shape", 405 { 406 TensorShape{ 42U, 37U, 8U, 15U }, 407 TensorShape{ 57U, 60U, 13U, 8U }, 408 TensorShape{ 83U, 72U, 14U, 5U } 409 }) 410 { 411 } 412 }; 413 414 /** Data set containing large tensor shapes. */ 415 class LargeShapes final : public ShapeDataset 416 { 417 public: LargeShapes()418 LargeShapes() 419 : ShapeDataset("Shape", 420 { 421 TensorShape{ 582U, 131U, 1U, 4U }, 422 }) 423 { 424 } 425 }; 426 427 /** Data set containing large tensor shapes. */ 428 class LargeShapesNoBatches final : public ShapeDataset 429 { 430 public: LargeShapesNoBatches()431 LargeShapesNoBatches() 432 : ShapeDataset("Shape", 433 { 434 TensorShape{ 582U, 131U, 2U }, 435 }) 436 { 437 } 438 }; 439 440 /** Data set containing pairs of large tensor shapes that are broadcast compatible. */ 441 class LargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 442 { 443 public: LargeShapesBroadcast()444 LargeShapesBroadcast() 445 : ZipDataset<ShapeDataset, ShapeDataset>( 446 ShapeDataset("Shape0", 447 { 448 TensorShape{ 1921U, 541U }, 449 TensorShape{ 1U, 485U, 2U, 3U }, 450 TensorShape{ 4159U, 1U }, 451 TensorShape{ 799U } 452 }), 453 ShapeDataset("Shape1", 454 { 455 TensorShape{ 1921U, 1U, 2U }, 456 TensorShape{ 641U, 1U, 2U, 3U }, 457 TensorShape{ 1U, 127U, 25U }, 458 TensorShape{ 799U, 595U, 1U, 4U } 459 })) 460 { 461 } 462 }; 463 464 /** Data set containing large 1D tensor shapes. */ 465 class Large1DShapes final : public ShapeDataset 466 { 467 public: Large1DShapes()468 Large1DShapes() 469 : ShapeDataset("Shape", 470 { 471 TensorShape{ 1245U } 472 }) 473 { 474 } 475 }; 476 477 /** Data set containing large 2D tensor shapes. */ 478 class Large2DShapes final : public ShapeDataset 479 { 480 public: Large2DShapes()481 Large2DShapes() 482 : ShapeDataset("Shape", 483 { 484 TensorShape{ 1245U, 652U } 485 }) 486 { 487 } 488 }; 489 490 /** Data set containing large 3D tensor shapes. */ 491 class Large3DShapes final : public ShapeDataset 492 { 493 public: Large3DShapes()494 Large3DShapes() 495 : ShapeDataset("Shape", 496 { 497 TensorShape{ 320U, 240U, 3U } 498 }) 499 { 500 } 501 }; 502 503 /** Data set containing large 4D tensor shapes. */ 504 class Large4DShapes final : public ShapeDataset 505 { 506 public: Large4DShapes()507 Large4DShapes() 508 : ShapeDataset("Shape", 509 { 510 TensorShape{ 320U, 123U, 3U, 3U } 511 }) 512 { 513 } 514 }; 515 516 /** Data set containing small 3x3 tensor shapes. */ 517 class Small3x3Shapes final : public ShapeDataset 518 { 519 public: Small3x3Shapes()520 Small3x3Shapes() 521 : ShapeDataset("Shape", 522 { 523 TensorShape{ 3U, 3U, 7U, 4U }, 524 TensorShape{ 3U, 3U, 4U, 13U }, 525 TensorShape{ 3U, 3U, 3U, 5U }, 526 }) 527 { 528 } 529 }; 530 531 /** Data set containing small 3x1 tensor shapes. */ 532 class Small3x1Shapes final : public ShapeDataset 533 { 534 public: Small3x1Shapes()535 Small3x1Shapes() 536 : ShapeDataset("Shape", 537 { 538 TensorShape{ 3U, 1U, 7U, 4U }, 539 TensorShape{ 3U, 1U, 4U, 13U }, 540 TensorShape{ 3U, 1U, 3U, 5U }, 541 }) 542 { 543 } 544 }; 545 546 /** Data set containing small 1x3 tensor shapes. */ 547 class Small1x3Shapes final : public ShapeDataset 548 { 549 public: Small1x3Shapes()550 Small1x3Shapes() 551 : ShapeDataset("Shape", 552 { 553 TensorShape{ 1U, 3U, 7U, 4U }, 554 TensorShape{ 1U, 3U, 4U, 13U }, 555 TensorShape{ 1U, 3U, 3U, 5U }, 556 }) 557 { 558 } 559 }; 560 561 /** Data set containing large 3x3 tensor shapes. */ 562 class Large3x3Shapes final : public ShapeDataset 563 { 564 public: Large3x3Shapes()565 Large3x3Shapes() 566 : ShapeDataset("Shape", 567 { 568 TensorShape{ 3U, 3U, 32U, 64U }, 569 TensorShape{ 3U, 3U, 51U, 13U }, 570 TensorShape{ 3U, 3U, 53U, 47U }, 571 }) 572 { 573 } 574 }; 575 576 /** Data set containing large 3x1 tensor shapes. */ 577 class Large3x1Shapes final : public ShapeDataset 578 { 579 public: Large3x1Shapes()580 Large3x1Shapes() 581 : ShapeDataset("Shape", 582 { 583 TensorShape{ 3U, 1U, 32U, 64U }, 584 TensorShape{ 3U, 1U, 51U, 13U }, 585 TensorShape{ 3U, 1U, 53U, 47U }, 586 }) 587 { 588 } 589 }; 590 591 /** Data set containing large 1x3 tensor shapes. */ 592 class Large1x3Shapes final : public ShapeDataset 593 { 594 public: Large1x3Shapes()595 Large1x3Shapes() 596 : ShapeDataset("Shape", 597 { 598 TensorShape{ 1U, 3U, 32U, 64U }, 599 TensorShape{ 1U, 3U, 51U, 13U }, 600 TensorShape{ 1U, 3U, 53U, 47U }, 601 }) 602 { 603 } 604 }; 605 606 /** Data set containing small 5x5 tensor shapes. */ 607 class Small5x5Shapes final : public ShapeDataset 608 { 609 public: Small5x5Shapes()610 Small5x5Shapes() 611 : ShapeDataset("Shape", 612 { 613 TensorShape{ 5U, 5U, 7U, 4U }, 614 TensorShape{ 5U, 5U, 4U, 13U }, 615 TensorShape{ 5U, 5U, 3U, 5U }, 616 }) 617 { 618 } 619 }; 620 621 /** Data set containing small 5D tensor shapes. */ 622 class Small5dShapes final : public ShapeDataset 623 { 624 public: Small5dShapes()625 Small5dShapes() 626 : ShapeDataset("Shape", 627 { 628 TensorShape{ 5U, 5U, 7U, 4U, 3U }, 629 TensorShape{ 5U, 5U, 4U, 13U, 2U }, 630 TensorShape{ 5U, 5U, 3U, 5U, 2U }, 631 }) 632 { 633 } 634 }; 635 636 /** Data set containing large 5x5 tensor shapes. */ 637 class Large5x5Shapes final : public ShapeDataset 638 { 639 public: Large5x5Shapes()640 Large5x5Shapes() 641 : ShapeDataset("Shape", 642 { 643 TensorShape{ 5U, 5U, 32U, 64U } 644 }) 645 { 646 } 647 }; 648 649 /** Data set containing large 5D tensor shapes. */ 650 class Large5dShapes final : public ShapeDataset 651 { 652 public: Large5dShapes()653 Large5dShapes() 654 : ShapeDataset("Shape", 655 { 656 TensorShape{ 30U, 40U, 30U, 32U, 3U } 657 }) 658 { 659 } 660 }; 661 662 /** Data set containing small 5x1 tensor shapes. */ 663 class Small5x1Shapes final : public ShapeDataset 664 { 665 public: Small5x1Shapes()666 Small5x1Shapes() 667 : ShapeDataset("Shape", 668 { 669 TensorShape{ 5U, 1U, 7U, 4U } 670 }) 671 { 672 } 673 }; 674 675 /** Data set containing large 5x1 tensor shapes. */ 676 class Large5x1Shapes final : public ShapeDataset 677 { 678 public: Large5x1Shapes()679 Large5x1Shapes() 680 : ShapeDataset("Shape", 681 { 682 TensorShape{ 5U, 1U, 32U, 64U } 683 }) 684 { 685 } 686 }; 687 688 /** Data set containing small 1x5 tensor shapes. */ 689 class Small1x5Shapes final : public ShapeDataset 690 { 691 public: Small1x5Shapes()692 Small1x5Shapes() 693 : ShapeDataset("Shape", 694 { 695 TensorShape{ 1U, 5U, 7U, 4U } 696 }) 697 { 698 } 699 }; 700 701 /** Data set containing large 1x5 tensor shapes. */ 702 class Large1x5Shapes final : public ShapeDataset 703 { 704 public: Large1x5Shapes()705 Large1x5Shapes() 706 : ShapeDataset("Shape", 707 { 708 TensorShape{ 1U, 5U, 32U, 64U } 709 }) 710 { 711 } 712 }; 713 714 /** Data set containing small 1x7 tensor shapes. */ 715 class Small1x7Shapes final : public ShapeDataset 716 { 717 public: Small1x7Shapes()718 Small1x7Shapes() 719 : ShapeDataset("Shape", 720 { 721 TensorShape{ 1U, 7U, 7U, 4U } 722 }) 723 { 724 } 725 }; 726 727 /** Data set containing large 1x7 tensor shapes. */ 728 class Large1x7Shapes final : public ShapeDataset 729 { 730 public: Large1x7Shapes()731 Large1x7Shapes() 732 : ShapeDataset("Shape", 733 { 734 TensorShape{ 1U, 7U, 32U, 64U } 735 }) 736 { 737 } 738 }; 739 740 /** Data set containing small 7x7 tensor shapes. */ 741 class Small7x7Shapes final : public ShapeDataset 742 { 743 public: Small7x7Shapes()744 Small7x7Shapes() 745 : ShapeDataset("Shape", 746 { 747 TensorShape{ 7U, 7U, 7U, 4U } 748 }) 749 { 750 } 751 }; 752 753 /** Data set containing large 7x7 tensor shapes. */ 754 class Large7x7Shapes final : public ShapeDataset 755 { 756 public: Large7x7Shapes()757 Large7x7Shapes() 758 : ShapeDataset("Shape", 759 { 760 TensorShape{ 7U, 7U, 32U, 64U } 761 }) 762 { 763 } 764 }; 765 766 /** Data set containing small 7x1 tensor shapes. */ 767 class Small7x1Shapes final : public ShapeDataset 768 { 769 public: Small7x1Shapes()770 Small7x1Shapes() 771 : ShapeDataset("Shape", 772 { 773 TensorShape{ 7U, 1U, 7U, 4U } 774 }) 775 { 776 } 777 }; 778 779 /** Data set containing large 7x1 tensor shapes. */ 780 class Large7x1Shapes final : public ShapeDataset 781 { 782 public: Large7x1Shapes()783 Large7x1Shapes() 784 : ShapeDataset("Shape", 785 { 786 TensorShape{ 7U, 1U, 32U, 64U } 787 }) 788 { 789 } 790 }; 791 792 /** Data set containing small tensor shapes for deconvolution. */ 793 class SmallDeconvolutionShapes final : public ShapeDataset 794 { 795 public: SmallDeconvolutionShapes()796 SmallDeconvolutionShapes() 797 : ShapeDataset("InputShape", 798 { 799 // Multiple Vector Loops for FP32 800 TensorShape{ 5U, 4U, 3U, 2U }, 801 TensorShape{ 5U, 5U, 3U }, 802 TensorShape{ 11U, 13U, 4U, 3U } 803 }) 804 { 805 } 806 }; 807 808 class SmallDeconvolutionShapesWithLargerChannels final : public ShapeDataset 809 { 810 public: SmallDeconvolutionShapesWithLargerChannels()811 SmallDeconvolutionShapesWithLargerChannels() 812 : ShapeDataset("InputShape", 813 { 814 // Multiple Vector Loops for all data types 815 TensorShape{ 5U, 5U, 35U } 816 }) 817 { 818 } 819 }; 820 821 /** Data set containing tiny tensor shapes for direct convolution. */ 822 class TinyDirectConvolutionShapes final : public ShapeDataset 823 { 824 public: TinyDirectConvolutionShapes()825 TinyDirectConvolutionShapes() 826 : ShapeDataset("InputShape", 827 { 828 // Batch size 1 829 TensorShape{ 11U, 13U, 3U }, 830 TensorShape{ 7U, 27U, 3U } 831 }) 832 { 833 } 834 }; 835 /** Data set containing small tensor shapes for direct convolution. */ 836 class SmallDirectConvolutionShapes final : public ShapeDataset 837 { 838 public: SmallDirectConvolutionShapes()839 SmallDirectConvolutionShapes() 840 : ShapeDataset("InputShape", 841 { 842 // Batch size 1 843 TensorShape{ 32U, 37U, 3U }, 844 // Batch size 4 845 TensorShape{ 6U, 9U, 5U, 4U }, 846 }) 847 { 848 } 849 }; 850 851 class SmallDirectConv3DShapes final : public ShapeDataset 852 { 853 public: SmallDirectConv3DShapes()854 SmallDirectConv3DShapes() 855 : ShapeDataset("InputShape", 856 { 857 // Batch size 2 858 TensorShape{ 1U, 3U, 4U, 5U, 2U }, 859 // Batch size 3 860 TensorShape{ 7U, 27U, 3U, 6U, 3U }, 861 // Batch size 1 862 TensorShape{ 32U, 37U, 13U, 1U, 1U }, 863 }) 864 { 865 } 866 }; 867 868 /** Data set containing small tensor shapes for direct convolution. */ 869 class SmallDirectConvolutionTensorShiftShapes final : public ShapeDataset 870 { 871 public: SmallDirectConvolutionTensorShiftShapes()872 SmallDirectConvolutionTensorShiftShapes() 873 : ShapeDataset("InputShape", 874 { 875 // Batch size 1 876 TensorShape{ 32U, 37U, 3U }, 877 // Batch size 4 878 TensorShape{ 32U, 37U, 3U, 4U }, 879 // Arbitrary batch size 880 TensorShape{ 32U, 37U, 3U, 8U } 881 }) 882 { 883 } 884 }; 885 886 /** Data set containing small grouped im2col tensor shapes. */ 887 class GroupedIm2ColSmallShapes final : public ShapeDataset 888 { 889 public: GroupedIm2ColSmallShapes()890 GroupedIm2ColSmallShapes() 891 : ShapeDataset("Shape", 892 { 893 TensorShape{ 11U, 11U, 48U }, 894 TensorShape{ 27U, 13U, 24U }, 895 TensorShape{ 128U, 64U, 12U, 3U }, 896 TensorShape{ 11U, 11U, 48U, 4U }, 897 TensorShape{ 27U, 13U, 24U, 4U }, 898 TensorShape{ 11U, 11U, 48U, 5U } 899 }) 900 { 901 } 902 }; 903 904 /** Data set containing large grouped im2col tensor shapes. */ 905 class GroupedIm2ColLargeShapes final : public ShapeDataset 906 { 907 public: GroupedIm2ColLargeShapes()908 GroupedIm2ColLargeShapes() 909 : ShapeDataset("Shape", 910 { 911 TensorShape{ 153U, 231U, 12U }, 912 TensorShape{ 123U, 191U, 12U, 2U }, 913 }) 914 { 915 } 916 }; 917 918 /** Data set containing small grouped weights tensor shapes. */ 919 class GroupedWeightsSmallShapes final : public ShapeDataset 920 { 921 public: GroupedWeightsSmallShapes()922 GroupedWeightsSmallShapes() 923 : ShapeDataset("Shape", 924 { 925 TensorShape{ 3U, 3U, 48U, 120U }, 926 TensorShape{ 1U, 3U, 24U, 240U }, 927 TensorShape{ 3U, 1U, 12U, 480U }, 928 TensorShape{ 5U, 5U, 48U, 120U } 929 }) 930 { 931 } 932 }; 933 934 /** Data set containing large grouped weights tensor shapes. */ 935 class GroupedWeightsLargeShapes final : public ShapeDataset 936 { 937 public: GroupedWeightsLargeShapes()938 GroupedWeightsLargeShapes() 939 : ShapeDataset("Shape", 940 { 941 TensorShape{ 9U, 9U, 96U, 240U }, 942 TensorShape{ 13U, 13U, 96U, 240U } 943 }) 944 { 945 } 946 }; 947 948 /** Data set containing 2D tensor shapes for DepthConcatenateLayer. */ 949 class DepthConcatenateLayerShapes final : public ShapeDataset 950 { 951 public: DepthConcatenateLayerShapes()952 DepthConcatenateLayerShapes() 953 : ShapeDataset("Shape", 954 { 955 TensorShape{ 322U, 243U }, 956 TensorShape{ 463U, 879U }, 957 TensorShape{ 416U, 651U } 958 }) 959 { 960 } 961 }; 962 963 /** Data set containing tensor shapes for ConcatenateLayer. */ 964 class ConcatenateLayerShapes final : public ShapeDataset 965 { 966 public: ConcatenateLayerShapes()967 ConcatenateLayerShapes() 968 : ShapeDataset("Shape", 969 { 970 TensorShape{ 232U, 65U, 3U }, 971 TensorShape{ 432U, 65U, 3U }, 972 TensorShape{ 124U, 65U, 3U }, 973 TensorShape{ 124U, 65U, 3U, 4U } 974 }) 975 { 976 } 977 }; 978 979 /** Data set containing global pooling tensor shapes. */ 980 class GlobalPoolingShapes final : public ShapeDataset 981 { 982 public: GlobalPoolingShapes()983 GlobalPoolingShapes() 984 : ShapeDataset("Shape", 985 { 986 // Batch size 1 987 TensorShape{ 9U, 9U }, 988 TensorShape{ 13U, 13U, 2U }, 989 TensorShape{ 27U, 27U, 1U, 3U }, 990 // Batch size 4 991 TensorShape{ 31U, 31U, 3U, 4U }, 992 TensorShape{ 34U, 34U, 2U, 4U } 993 }) 994 { 995 } 996 }; 997 /** Data set containing tiny softmax layer shapes. */ 998 class SoftmaxLayerTinyShapes final : public ShapeDataset 999 { 1000 public: SoftmaxLayerTinyShapes()1001 SoftmaxLayerTinyShapes() 1002 : ShapeDataset("Shape", 1003 { 1004 TensorShape{ 9U, 9U }, 1005 TensorShape{ 128U, 10U }, 1006 }) 1007 { 1008 } 1009 }; 1010 1011 /** Data set containing small softmax layer shapes. */ 1012 class SoftmaxLayerSmallShapes final : public ShapeDataset 1013 { 1014 public: SoftmaxLayerSmallShapes()1015 SoftmaxLayerSmallShapes() 1016 : ShapeDataset("Shape", 1017 { 1018 TensorShape{ 1U, 9U }, 1019 TensorShape{ 256U, 10U }, 1020 TensorShape{ 353U, 8U }, 1021 TensorShape{ 781U, 5U }, 1022 }) 1023 { 1024 } 1025 }; 1026 1027 /** Data set containing large softmax layer shapes. */ 1028 class SoftmaxLayerLargeShapes final : public ShapeDataset 1029 { 1030 public: SoftmaxLayerLargeShapes()1031 SoftmaxLayerLargeShapes() 1032 : ShapeDataset("Shape", 1033 { 1034 TensorShape{ 1000U, 10U } 1035 1036 }) 1037 { 1038 } 1039 }; 1040 1041 /** Data set containing large and small softmax layer 4D shapes. */ 1042 class SoftmaxLayer4DShapes final : public ShapeDataset 1043 { 1044 public: SoftmaxLayer4DShapes()1045 SoftmaxLayer4DShapes() 1046 : ShapeDataset("Shape", 1047 { 1048 TensorShape{ 9U, 9U, 9U, 9U }, 1049 TensorShape{ 31U, 10U, 1U, 9U }, 1050 }) 1051 { 1052 } 1053 }; 1054 1055 /** Data set containing 2D tensor shapes relative to an image size. */ 1056 class SmallImageShapes final : public ShapeDataset 1057 { 1058 public: SmallImageShapes()1059 SmallImageShapes() 1060 : ShapeDataset("Shape", 1061 { 1062 TensorShape{ 640U, 480U }, 1063 TensorShape{ 800U, 600U }, 1064 }) 1065 { 1066 } 1067 }; 1068 1069 /** Data set containing 2D tensor shapes relative to an image size. */ 1070 class LargeImageShapes final : public ShapeDataset 1071 { 1072 public: LargeImageShapes()1073 LargeImageShapes() 1074 : ShapeDataset("Shape", 1075 { 1076 TensorShape{ 1920U, 1080U }, 1077 TensorShape{ 2560U, 1536U }, 1078 TensorShape{ 3584U, 2048U } 1079 }) 1080 { 1081 } 1082 }; 1083 1084 /** Data set containing small YOLO tensor shapes. */ 1085 class SmallYOLOShapes final : public ShapeDataset 1086 { 1087 public: SmallYOLOShapes()1088 SmallYOLOShapes() 1089 : ShapeDataset("Shape", 1090 { 1091 // Batch size 1 1092 TensorShape{ 11U, 11U, 270U }, 1093 TensorShape{ 27U, 13U, 90U }, 1094 TensorShape{ 13U, 12U, 45U, 2U }, 1095 }) 1096 { 1097 } 1098 }; 1099 1100 /** Data set containing large YOLO tensor shapes. */ 1101 class LargeYOLOShapes final : public ShapeDataset 1102 { 1103 public: LargeYOLOShapes()1104 LargeYOLOShapes() 1105 : ShapeDataset("Shape", 1106 { 1107 TensorShape{ 24U, 23U, 270U }, 1108 TensorShape{ 51U, 63U, 90U, 2U }, 1109 TensorShape{ 76U, 91U, 45U, 3U } 1110 }) 1111 { 1112 } 1113 }; 1114 1115 /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel */ 1116 class SmallGEMMReshape2DShapes final : public ShapeDataset 1117 { 1118 public: SmallGEMMReshape2DShapes()1119 SmallGEMMReshape2DShapes() 1120 : ShapeDataset("Shape", 1121 { 1122 TensorShape{ 63U, 72U }, 1123 }) 1124 { 1125 } 1126 }; 1127 1128 /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */ 1129 class SmallGEMMReshape3DShapes final : public ShapeDataset 1130 { 1131 public: SmallGEMMReshape3DShapes()1132 SmallGEMMReshape3DShapes() 1133 : ShapeDataset("Shape", 1134 { 1135 TensorShape{ 63U, 9U, 8U }, 1136 }) 1137 { 1138 } 1139 }; 1140 1141 /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel */ 1142 class LargeGEMMReshape2DShapes final : public ShapeDataset 1143 { 1144 public: LargeGEMMReshape2DShapes()1145 LargeGEMMReshape2DShapes() 1146 : ShapeDataset("Shape", 1147 { 1148 TensorShape{ 16U, 27U }, 1149 TensorShape{ 345U, 171U } 1150 }) 1151 { 1152 } 1153 }; 1154 1155 /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */ 1156 class LargeGEMMReshape3DShapes final : public ShapeDataset 1157 { 1158 public: LargeGEMMReshape3DShapes()1159 LargeGEMMReshape3DShapes() 1160 : ShapeDataset("Shape", 1161 { 1162 TensorShape{ 16U, 3U, 9U }, 1163 TensorShape{ 345U, 34U, 18U } 1164 }) 1165 { 1166 } 1167 }; 1168 1169 /** Data set containing small 2D tensor shapes. */ 1170 class Small2DNonMaxSuppressionShapes final : public ShapeDataset 1171 { 1172 public: Small2DNonMaxSuppressionShapes()1173 Small2DNonMaxSuppressionShapes() 1174 : ShapeDataset("Shape", 1175 { 1176 TensorShape{ 4U, 7U }, 1177 TensorShape{ 4U, 13U }, 1178 TensorShape{ 4U, 64U } 1179 }) 1180 { 1181 } 1182 }; 1183 1184 /** Data set containing large 2D tensor shapes. */ 1185 class Large2DNonMaxSuppressionShapes final : public ShapeDataset 1186 { 1187 public: Large2DNonMaxSuppressionShapes()1188 Large2DNonMaxSuppressionShapes() 1189 : ShapeDataset("Shape", 1190 { 1191 TensorShape{ 4U, 113U } 1192 }) 1193 { 1194 } 1195 }; 1196 1197 } // namespace datasets 1198 } // namespace test 1199 } // namespace arm_compute 1200 #endif /* ARM_COMPUTE_TEST_SHAPE_DATASETS_H */ 1201