1 /* 2 * Copyright (c) 2017-2020 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_TEST_SHAPE_DATASETS_H 25 #define ARM_COMPUTE_TEST_SHAPE_DATASETS_H 26 27 #include "arm_compute/core/TensorShape.h" 28 #include "tests/framework/datasets/Datasets.h" 29 30 #include <type_traits> 31 32 namespace arm_compute 33 { 34 namespace test 35 { 36 namespace datasets 37 { 38 /** Parent type for all for shape datasets. */ 39 using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>; 40 41 /** Data set containing tiny 1D tensor shapes. */ 42 class Tiny1DShapes final : public ShapeDataset 43 { 44 public: Tiny1DShapes()45 Tiny1DShapes() 46 : ShapeDataset("Shape", 47 { 48 TensorShape{ 2U }, 49 TensorShape{ 3U }, 50 }) 51 { 52 } 53 }; 54 55 /** Data set containing small 1D tensor shapes. */ 56 class Small1DShapes final : public ShapeDataset 57 { 58 public: Small1DShapes()59 Small1DShapes() 60 : ShapeDataset("Shape", 61 { 62 TensorShape{ 128U }, 63 TensorShape{ 256U }, 64 TensorShape{ 512U }, 65 TensorShape{ 1024U } 66 }) 67 { 68 } 69 }; 70 71 /** Data set containing tiny 2D tensor shapes. */ 72 class Tiny2DShapes final : public ShapeDataset 73 { 74 public: Tiny2DShapes()75 Tiny2DShapes() 76 : ShapeDataset("Shape", 77 { 78 TensorShape{ 7U, 7U }, 79 TensorShape{ 11U, 13U }, 80 }) 81 { 82 } 83 }; 84 /** Data set containing small 2D tensor shapes. */ 85 class Small2DShapes final : public ShapeDataset 86 { 87 public: Small2DShapes()88 Small2DShapes() 89 : ShapeDataset("Shape", 90 { 91 TensorShape{ 7U, 7U }, 92 TensorShape{ 27U, 13U }, 93 TensorShape{ 128U, 64U } 94 }) 95 { 96 } 97 }; 98 99 /** Data set containing tiny 3D tensor shapes. */ 100 class Tiny3DShapes final : public ShapeDataset 101 { 102 public: Tiny3DShapes()103 Tiny3DShapes() 104 : ShapeDataset("Shape", 105 { 106 TensorShape{ 7U, 7U, 5U }, 107 TensorShape{ 23U, 13U, 9U }, 108 }) 109 { 110 } 111 }; 112 113 /** Data set containing small 3D tensor shapes. */ 114 class Small3DShapes final : public ShapeDataset 115 { 116 public: Small3DShapes()117 Small3DShapes() 118 : ShapeDataset("Shape", 119 { 120 TensorShape{ 1U, 7U, 7U }, 121 TensorShape{ 2U, 5U, 4U }, 122 123 TensorShape{ 7U, 7U, 5U }, 124 TensorShape{ 16U, 16U, 5U }, 125 TensorShape{ 27U, 13U, 37U }, 126 }) 127 { 128 } 129 }; 130 131 /** Data set containing tiny 4D tensor shapes. */ 132 class Tiny4DShapes final : public ShapeDataset 133 { 134 public: Tiny4DShapes()135 Tiny4DShapes() 136 : ShapeDataset("Shape", 137 { 138 TensorShape{ 7U, 7U, 5U, 3U }, 139 TensorShape{ 17U, 13U, 7U, 2U }, 140 }) 141 { 142 } 143 }; 144 /** Data set containing small 4D tensor shapes. */ 145 class Small4DShapes final : public ShapeDataset 146 { 147 public: Small4DShapes()148 Small4DShapes() 149 : ShapeDataset("Shape", 150 { 151 TensorShape{ 2U, 7U, 1U, 3U }, 152 TensorShape{ 7U, 7U, 5U, 3U }, 153 TensorShape{ 27U, 13U, 37U, 2U }, 154 TensorShape{ 128U, 64U, 21U, 3U } 155 }) 156 { 157 } 158 }; 159 160 /** Data set containing tiny tensor shapes. */ 161 class TinyShapes final : public ShapeDataset 162 { 163 public: TinyShapes()164 TinyShapes() 165 : ShapeDataset("Shape", 166 { 167 // Batch size 1 168 TensorShape{ 9U, 9U }, 169 TensorShape{ 27U, 13U, 2U }, 170 }) 171 { 172 } 173 }; 174 /** Data set containing small tensor shapes. */ 175 class SmallShapes final : public ShapeDataset 176 { 177 public: SmallShapes()178 SmallShapes() 179 : ShapeDataset("Shape", 180 { 181 // Batch size 1 182 TensorShape{ 11U, 11U }, 183 TensorShape{ 16U, 16U }, 184 TensorShape{ 27U, 13U, 7U }, 185 TensorShape{ 31U, 27U, 17U, 2U }, 186 // Batch size 4 187 TensorShape{ 27U, 13U, 2U, 4U }, 188 // Arbitrary batch size 189 TensorShape{ 11U, 11U, 3U, 5U } 190 }) 191 { 192 } 193 }; 194 195 /** Data set containing pairs of tiny tensor shapes that are broadcast compatible. */ 196 class TinyShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 197 { 198 public: TinyShapesBroadcast()199 TinyShapesBroadcast() 200 : ZipDataset<ShapeDataset, ShapeDataset>( 201 ShapeDataset("Shape0", 202 { 203 TensorShape{ 9U, 9U }, 204 TensorShape{ 10U, 2U, 14U, 2U }, 205 }), 206 ShapeDataset("Shape1", 207 { 208 TensorShape{ 9U, 1U, 9U }, 209 TensorShape{ 10U }, 210 })) 211 { 212 } 213 }; 214 /** Data set containing pairs of small tensor shapes that are broadcast compatible. */ 215 class SmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 216 { 217 public: SmallShapesBroadcast()218 SmallShapesBroadcast() 219 : ZipDataset<ShapeDataset, ShapeDataset>( 220 ShapeDataset("Shape0", 221 { 222 TensorShape{ 9U, 9U }, 223 TensorShape{ 27U, 13U, 2U }, 224 TensorShape{ 128U, 1U, 5U, 3U }, 225 TensorShape{ 9U, 9U, 3U, 4U }, 226 TensorShape{ 27U, 13U, 2U, 4U }, 227 TensorShape{ 1U, 1U, 1U, 5U }, 228 TensorShape{ 1U, 16U, 10U, 2U, 128U }, 229 TensorShape{ 1U, 16U, 10U, 2U, 128U } 230 }), 231 ShapeDataset("Shape1", 232 { 233 TensorShape{ 9U, 1U, 2U }, 234 TensorShape{ 1U, 13U, 2U }, 235 TensorShape{ 128U, 64U, 1U, 3U }, 236 TensorShape{ 9U, 1U, 3U }, 237 TensorShape{ 1U }, 238 TensorShape{ 9U, 9U, 3U, 5U }, 239 TensorShape{ 1U, 1U, 1U, 1U, 128U }, 240 TensorShape{ 128U } 241 })) 242 { 243 } 244 }; 245 246 /** Data set containing medium tensor shapes. */ 247 class MediumShapes final : public ShapeDataset 248 { 249 public: MediumShapes()250 MediumShapes() 251 : ShapeDataset("Shape", 252 { 253 // Batch size 1 254 TensorShape{ 37U, 37U }, 255 TensorShape{ 27U, 33U, 2U }, 256 // Arbitrary batch size 257 TensorShape{ 37U, 37U, 3U, 5U } 258 }) 259 { 260 } 261 }; 262 263 /** Data set containing medium 2D tensor shapes. */ 264 class Medium2DShapes final : public ShapeDataset 265 { 266 public: Medium2DShapes()267 Medium2DShapes() 268 : ShapeDataset("Shape", 269 { 270 TensorShape{ 42U, 37U }, 271 TensorShape{ 57U, 60U }, 272 TensorShape{ 128U, 64U }, 273 TensorShape{ 83U, 72U }, 274 TensorShape{ 40U, 40U } 275 }) 276 { 277 } 278 }; 279 280 /** Data set containing medium 3D tensor shapes. */ 281 class Medium3DShapes final : public ShapeDataset 282 { 283 public: Medium3DShapes()284 Medium3DShapes() 285 : ShapeDataset("Shape", 286 { 287 TensorShape{ 42U, 37U, 8U }, 288 TensorShape{ 57U, 60U, 13U }, 289 TensorShape{ 83U, 72U, 14U } 290 }) 291 { 292 } 293 }; 294 295 /** Data set containing medium 4D tensor shapes. */ 296 class Medium4DShapes final : public ShapeDataset 297 { 298 public: Medium4DShapes()299 Medium4DShapes() 300 : ShapeDataset("Shape", 301 { 302 TensorShape{ 42U, 37U, 8U, 15U }, 303 TensorShape{ 57U, 60U, 13U, 8U }, 304 TensorShape{ 83U, 72U, 14U, 5U } 305 }) 306 { 307 } 308 }; 309 310 /** Data set containing large tensor shapes. */ 311 class LargeShapes final : public ShapeDataset 312 { 313 public: LargeShapes()314 LargeShapes() 315 : ShapeDataset("Shape", 316 { 317 TensorShape{ 582U, 131U, 1U, 4U }, 318 }) 319 { 320 } 321 }; 322 323 /** Data set containing pairs of large tensor shapes that are broadcast compatible. */ 324 class LargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 325 { 326 public: LargeShapesBroadcast()327 LargeShapesBroadcast() 328 : ZipDataset<ShapeDataset, ShapeDataset>( 329 ShapeDataset("Shape0", 330 { 331 TensorShape{ 1921U, 541U }, 332 TensorShape{ 1U, 485U, 2U, 3U }, 333 TensorShape{ 4159U, 1U }, 334 TensorShape{ 799U } 335 }), 336 ShapeDataset("Shape1", 337 { 338 TensorShape{ 1921U, 1U, 2U }, 339 TensorShape{ 641U, 1U, 2U, 3U }, 340 TensorShape{ 1U, 127U, 25U }, 341 TensorShape{ 799U, 595U, 1U, 4U } 342 })) 343 { 344 } 345 }; 346 347 /** Data set containing large 1D tensor shapes. */ 348 class Large1DShapes final : public ShapeDataset 349 { 350 public: Large1DShapes()351 Large1DShapes() 352 : ShapeDataset("Shape", 353 { 354 TensorShape{ 1245U } 355 }) 356 { 357 } 358 }; 359 360 /** Data set containing large 2D tensor shapes. */ 361 class Large2DShapes final : public ShapeDataset 362 { 363 public: Large2DShapes()364 Large2DShapes() 365 : ShapeDataset("Shape", 366 { 367 TensorShape{ 1245U, 652U } 368 }) 369 { 370 } 371 }; 372 373 /** Data set containing large 3D tensor shapes. */ 374 class Large3DShapes final : public ShapeDataset 375 { 376 public: Large3DShapes()377 Large3DShapes() 378 : ShapeDataset("Shape", 379 { 380 TensorShape{ 320U, 240U, 3U } 381 }) 382 { 383 } 384 }; 385 386 /** Data set containing large 4D tensor shapes. */ 387 class Large4DShapes final : public ShapeDataset 388 { 389 public: Large4DShapes()390 Large4DShapes() 391 : ShapeDataset("Shape", 392 { 393 TensorShape{ 320U, 123U, 3U, 3U } 394 }) 395 { 396 } 397 }; 398 399 /** Data set containing small 3x3 tensor shapes. */ 400 class Small3x3Shapes final : public ShapeDataset 401 { 402 public: Small3x3Shapes()403 Small3x3Shapes() 404 : ShapeDataset("Shape", 405 { 406 TensorShape{ 3U, 3U, 7U, 4U }, 407 TensorShape{ 3U, 3U, 4U, 13U }, 408 TensorShape{ 3U, 3U, 3U, 5U }, 409 }) 410 { 411 } 412 }; 413 414 /** Data set containing small 3x1 tensor shapes. */ 415 class Small3x1Shapes final : public ShapeDataset 416 { 417 public: Small3x1Shapes()418 Small3x1Shapes() 419 : ShapeDataset("Shape", 420 { 421 TensorShape{ 3U, 1U, 7U, 4U }, 422 TensorShape{ 3U, 1U, 4U, 13U }, 423 TensorShape{ 3U, 1U, 3U, 5U }, 424 }) 425 { 426 } 427 }; 428 429 /** Data set containing small 1x3 tensor shapes. */ 430 class Small1x3Shapes final : public ShapeDataset 431 { 432 public: Small1x3Shapes()433 Small1x3Shapes() 434 : ShapeDataset("Shape", 435 { 436 TensorShape{ 1U, 3U, 7U, 4U }, 437 TensorShape{ 1U, 3U, 4U, 13U }, 438 TensorShape{ 1U, 3U, 3U, 5U }, 439 }) 440 { 441 } 442 }; 443 444 /** Data set containing large 3x3 tensor shapes. */ 445 class Large3x3Shapes final : public ShapeDataset 446 { 447 public: Large3x3Shapes()448 Large3x3Shapes() 449 : ShapeDataset("Shape", 450 { 451 TensorShape{ 3U, 3U, 32U, 64U }, 452 TensorShape{ 3U, 3U, 51U, 13U }, 453 TensorShape{ 3U, 3U, 53U, 47U }, 454 }) 455 { 456 } 457 }; 458 459 /** Data set containing large 3x1 tensor shapes. */ 460 class Large3x1Shapes final : public ShapeDataset 461 { 462 public: Large3x1Shapes()463 Large3x1Shapes() 464 : ShapeDataset("Shape", 465 { 466 TensorShape{ 3U, 1U, 32U, 64U }, 467 TensorShape{ 3U, 1U, 51U, 13U }, 468 TensorShape{ 3U, 1U, 53U, 47U }, 469 }) 470 { 471 } 472 }; 473 474 /** Data set containing large 1x3 tensor shapes. */ 475 class Large1x3Shapes final : public ShapeDataset 476 { 477 public: Large1x3Shapes()478 Large1x3Shapes() 479 : ShapeDataset("Shape", 480 { 481 TensorShape{ 1U, 3U, 32U, 64U }, 482 TensorShape{ 1U, 3U, 51U, 13U }, 483 TensorShape{ 1U, 3U, 53U, 47U }, 484 }) 485 { 486 } 487 }; 488 489 /** Data set containing small 5x5 tensor shapes. */ 490 class Small5x5Shapes final : public ShapeDataset 491 { 492 public: Small5x5Shapes()493 Small5x5Shapes() 494 : ShapeDataset("Shape", 495 { 496 TensorShape{ 5U, 5U, 7U, 4U }, 497 TensorShape{ 5U, 5U, 4U, 13U }, 498 TensorShape{ 5U, 5U, 3U, 5U }, 499 }) 500 { 501 } 502 }; 503 504 /** Data set containing large 5x5 tensor shapes. */ 505 class Large5x5Shapes final : public ShapeDataset 506 { 507 public: Large5x5Shapes()508 Large5x5Shapes() 509 : ShapeDataset("Shape", 510 { 511 TensorShape{ 5U, 5U, 32U, 64U } 512 }) 513 { 514 } 515 }; 516 517 /** Data set containing small 5x1 tensor shapes. */ 518 class Small5x1Shapes final : public ShapeDataset 519 { 520 public: Small5x1Shapes()521 Small5x1Shapes() 522 : ShapeDataset("Shape", 523 { 524 TensorShape{ 5U, 1U, 7U, 4U } 525 }) 526 { 527 } 528 }; 529 530 /** Data set containing large 5x1 tensor shapes. */ 531 class Large5x1Shapes final : public ShapeDataset 532 { 533 public: Large5x1Shapes()534 Large5x1Shapes() 535 : ShapeDataset("Shape", 536 { 537 TensorShape{ 5U, 1U, 32U, 64U } 538 }) 539 { 540 } 541 }; 542 543 /** Data set containing small 1x5 tensor shapes. */ 544 class Small1x5Shapes final : public ShapeDataset 545 { 546 public: Small1x5Shapes()547 Small1x5Shapes() 548 : ShapeDataset("Shape", 549 { 550 TensorShape{ 1U, 5U, 7U, 4U } 551 }) 552 { 553 } 554 }; 555 556 /** Data set containing large 1x5 tensor shapes. */ 557 class Large1x5Shapes final : public ShapeDataset 558 { 559 public: Large1x5Shapes()560 Large1x5Shapes() 561 : ShapeDataset("Shape", 562 { 563 TensorShape{ 1U, 5U, 32U, 64U } 564 }) 565 { 566 } 567 }; 568 569 /** Data set containing small 1x7 tensor shapes. */ 570 class Small1x7Shapes final : public ShapeDataset 571 { 572 public: Small1x7Shapes()573 Small1x7Shapes() 574 : ShapeDataset("Shape", 575 { 576 TensorShape{ 1U, 7U, 7U, 4U } 577 }) 578 { 579 } 580 }; 581 582 /** Data set containing large 1x7 tensor shapes. */ 583 class Large1x7Shapes final : public ShapeDataset 584 { 585 public: Large1x7Shapes()586 Large1x7Shapes() 587 : ShapeDataset("Shape", 588 { 589 TensorShape{ 1U, 7U, 32U, 64U } 590 }) 591 { 592 } 593 }; 594 595 /** Data set containing small 7x7 tensor shapes. */ 596 class Small7x7Shapes final : public ShapeDataset 597 { 598 public: Small7x7Shapes()599 Small7x7Shapes() 600 : ShapeDataset("Shape", 601 { 602 TensorShape{ 7U, 7U, 7U, 4U } 603 }) 604 { 605 } 606 }; 607 608 /** Data set containing large 7x7 tensor shapes. */ 609 class Large7x7Shapes final : public ShapeDataset 610 { 611 public: Large7x7Shapes()612 Large7x7Shapes() 613 : ShapeDataset("Shape", 614 { 615 TensorShape{ 7U, 7U, 32U, 64U } 616 }) 617 { 618 } 619 }; 620 621 /** Data set containing small 7x1 tensor shapes. */ 622 class Small7x1Shapes final : public ShapeDataset 623 { 624 public: Small7x1Shapes()625 Small7x1Shapes() 626 : ShapeDataset("Shape", 627 { 628 TensorShape{ 7U, 1U, 7U, 4U } 629 }) 630 { 631 } 632 }; 633 634 /** Data set containing large 7x1 tensor shapes. */ 635 class Large7x1Shapes final : public ShapeDataset 636 { 637 public: Large7x1Shapes()638 Large7x1Shapes() 639 : ShapeDataset("Shape", 640 { 641 TensorShape{ 7U, 1U, 32U, 64U } 642 }) 643 { 644 } 645 }; 646 647 /** Data set containing small tensor shapes for deconvolution. */ 648 class SmallDeconvolutionShapes final : public ShapeDataset 649 { 650 public: SmallDeconvolutionShapes()651 SmallDeconvolutionShapes() 652 : ShapeDataset("InputShape", 653 { 654 TensorShape{ 5U, 4U, 3U, 2U }, 655 TensorShape{ 5U, 5U, 3U }, 656 TensorShape{ 11U, 13U, 4U, 3U } 657 }) 658 { 659 } 660 }; 661 662 /** Data set containing tiny tensor shapes for direct convolution. */ 663 class TinyDirectConvolutionShapes final : public ShapeDataset 664 { 665 public: TinyDirectConvolutionShapes()666 TinyDirectConvolutionShapes() 667 : ShapeDataset("InputShape", 668 { 669 // Batch size 1 670 TensorShape{ 11U, 13U, 3U }, 671 TensorShape{ 7U, 27U, 3U } 672 }) 673 { 674 } 675 }; 676 /** Data set containing small tensor shapes for direct convolution. */ 677 class SmallDirectConvolutionShapes final : public ShapeDataset 678 { 679 public: SmallDirectConvolutionShapes()680 SmallDirectConvolutionShapes() 681 : ShapeDataset("InputShape", 682 { 683 // Batch size 1 684 TensorShape{ 32U, 37U, 3U }, 685 // Batch size 4 686 TensorShape{ 32U, 37U, 3U, 4U }, 687 }) 688 { 689 } 690 }; 691 692 /** Data set containing small tensor shapes for direct convolution. */ 693 class SmallDirectConvolutionTensorShiftShapes final : public ShapeDataset 694 { 695 public: SmallDirectConvolutionTensorShiftShapes()696 SmallDirectConvolutionTensorShiftShapes() 697 : ShapeDataset("InputShape", 698 { 699 // Batch size 1 700 TensorShape{ 32U, 37U, 3U }, 701 // Batch size 4 702 TensorShape{ 32U, 37U, 3U, 4U }, 703 // Arbitrary batch size 704 TensorShape{ 32U, 37U, 3U, 8U } 705 }) 706 { 707 } 708 }; 709 710 /** Data set containing small grouped im2col tensor shapes. */ 711 class GroupedIm2ColSmallShapes final : public ShapeDataset 712 { 713 public: GroupedIm2ColSmallShapes()714 GroupedIm2ColSmallShapes() 715 : ShapeDataset("Shape", 716 { 717 TensorShape{ 11U, 11U, 48U }, 718 TensorShape{ 27U, 13U, 24U }, 719 TensorShape{ 128U, 64U, 12U, 3U }, 720 TensorShape{ 11U, 11U, 48U, 4U }, 721 TensorShape{ 27U, 13U, 24U, 4U }, 722 TensorShape{ 11U, 11U, 48U, 5U } 723 }) 724 { 725 } 726 }; 727 728 /** Data set containing large grouped im2col tensor shapes. */ 729 class GroupedIm2ColLargeShapes final : public ShapeDataset 730 { 731 public: GroupedIm2ColLargeShapes()732 GroupedIm2ColLargeShapes() 733 : ShapeDataset("Shape", 734 { 735 TensorShape{ 153U, 231U, 12U }, 736 TensorShape{ 123U, 191U, 12U, 2U }, 737 }) 738 { 739 } 740 }; 741 742 /** Data set containing small grouped weights tensor shapes. */ 743 class GroupedWeightsSmallShapes final : public ShapeDataset 744 { 745 public: GroupedWeightsSmallShapes()746 GroupedWeightsSmallShapes() 747 : ShapeDataset("Shape", 748 { 749 TensorShape{ 3U, 3U, 48U, 120U }, 750 TensorShape{ 1U, 3U, 24U, 240U }, 751 TensorShape{ 3U, 1U, 12U, 480U }, 752 TensorShape{ 5U, 5U, 48U, 120U } 753 }) 754 { 755 } 756 }; 757 758 /** Data set containing large grouped weights tensor shapes. */ 759 class GroupedWeightsLargeShapes final : public ShapeDataset 760 { 761 public: GroupedWeightsLargeShapes()762 GroupedWeightsLargeShapes() 763 : ShapeDataset("Shape", 764 { 765 TensorShape{ 9U, 9U, 96U, 240U }, 766 TensorShape{ 13U, 13U, 96U, 240U } 767 }) 768 { 769 } 770 }; 771 772 /** Data set containing 2D tensor shapes for DepthConcatenateLayer. */ 773 class DepthConcatenateLayerShapes final : public ShapeDataset 774 { 775 public: DepthConcatenateLayerShapes()776 DepthConcatenateLayerShapes() 777 : ShapeDataset("Shape", 778 { 779 TensorShape{ 322U, 243U }, 780 TensorShape{ 463U, 879U }, 781 TensorShape{ 416U, 651U } 782 }) 783 { 784 } 785 }; 786 787 /** Data set containing tensor shapes for ConcatenateLayer. */ 788 class ConcatenateLayerShapes final : public ShapeDataset 789 { 790 public: ConcatenateLayerShapes()791 ConcatenateLayerShapes() 792 : ShapeDataset("Shape", 793 { 794 TensorShape{ 232U, 65U, 3U }, 795 TensorShape{ 432U, 65U, 3U }, 796 TensorShape{ 124U, 65U, 3U }, 797 TensorShape{ 124U, 65U, 3U, 4U } 798 }) 799 { 800 } 801 }; 802 803 /** Data set containing global pooling tensor shapes. */ 804 class GlobalPoolingShapes final : public ShapeDataset 805 { 806 public: GlobalPoolingShapes()807 GlobalPoolingShapes() 808 : ShapeDataset("Shape", 809 { 810 // Batch size 1 811 TensorShape{ 9U, 9U }, 812 TensorShape{ 13U, 13U, 2U }, 813 TensorShape{ 27U, 27U, 1U, 3U }, 814 // Batch size 4 815 TensorShape{ 31U, 31U, 3U, 4U }, 816 TensorShape{ 34U, 34U, 2U, 4U } 817 }) 818 { 819 } 820 }; 821 /** Data set containing tiny softmax layer shapes. */ 822 class SoftmaxLayerTinyShapes final : public ShapeDataset 823 { 824 public: SoftmaxLayerTinyShapes()825 SoftmaxLayerTinyShapes() 826 : ShapeDataset("Shape", 827 { 828 TensorShape{ 9U, 9U }, 829 TensorShape{ 128U, 10U }, 830 }) 831 { 832 } 833 }; 834 835 /** Data set containing small softmax layer shapes. */ 836 class SoftmaxLayerSmallShapes final : public ShapeDataset 837 { 838 public: SoftmaxLayerSmallShapes()839 SoftmaxLayerSmallShapes() 840 : ShapeDataset("Shape", 841 { 842 TensorShape{ 9U, 9U }, 843 TensorShape{ 256U, 10U }, 844 TensorShape{ 353U, 8U }, 845 TensorShape{ 781U, 5U }, 846 }) 847 { 848 } 849 }; 850 851 /** Data set containing large softmax layer shapes. */ 852 class SoftmaxLayerLargeShapes final : public ShapeDataset 853 { 854 public: SoftmaxLayerLargeShapes()855 SoftmaxLayerLargeShapes() 856 : ShapeDataset("Shape", 857 { 858 TensorShape{ 1000U, 10U } 859 860 }) 861 { 862 } 863 }; 864 865 /** Data set containing large and small softmax layer 4D shapes. */ 866 class SoftmaxLayer4DShapes final : public ShapeDataset 867 { 868 public: SoftmaxLayer4DShapes()869 SoftmaxLayer4DShapes() 870 : ShapeDataset("Shape", 871 { 872 TensorShape{ 9U, 9U, 9U, 9U }, 873 TensorShape{ 31U, 10U, 1U, 9U }, 874 }) 875 { 876 } 877 }; 878 879 /** Data set containing 2D tensor shapes relative to an image size. */ 880 class SmallImageShapes final : public ShapeDataset 881 { 882 public: SmallImageShapes()883 SmallImageShapes() 884 : ShapeDataset("Shape", 885 { 886 TensorShape{ 640U, 480U }, 887 TensorShape{ 800U, 600U }, 888 }) 889 { 890 } 891 }; 892 893 /** Data set containing 2D tensor shapes relative to an image size. */ 894 class LargeImageShapes final : public ShapeDataset 895 { 896 public: LargeImageShapes()897 LargeImageShapes() 898 : ShapeDataset("Shape", 899 { 900 TensorShape{ 1920U, 1080U }, 901 TensorShape{ 2560U, 1536U }, 902 TensorShape{ 3584U, 2048U } 903 }) 904 { 905 } 906 }; 907 908 /** Data set containing small YOLO tensor shapes. */ 909 class SmallYOLOShapes final : public ShapeDataset 910 { 911 public: SmallYOLOShapes()912 SmallYOLOShapes() 913 : ShapeDataset("Shape", 914 { 915 // Batch size 1 916 TensorShape{ 11U, 11U, 270U }, 917 TensorShape{ 27U, 13U, 90U }, 918 TensorShape{ 13U, 12U, 45U, 2U }, 919 }) 920 { 921 } 922 }; 923 924 /** Data set containing large YOLO tensor shapes. */ 925 class LargeYOLOShapes final : public ShapeDataset 926 { 927 public: LargeYOLOShapes()928 LargeYOLOShapes() 929 : ShapeDataset("Shape", 930 { 931 TensorShape{ 24U, 23U, 270U }, 932 TensorShape{ 51U, 63U, 90U, 2U }, 933 TensorShape{ 76U, 91U, 45U, 3U } 934 }) 935 { 936 } 937 }; 938 939 /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel */ 940 class SmallGEMMReshape2DShapes final : public ShapeDataset 941 { 942 public: SmallGEMMReshape2DShapes()943 SmallGEMMReshape2DShapes() 944 : ShapeDataset("Shape", 945 { 946 TensorShape{ 63U, 72U }, 947 }) 948 { 949 } 950 }; 951 952 /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */ 953 class SmallGEMMReshape3DShapes final : public ShapeDataset 954 { 955 public: SmallGEMMReshape3DShapes()956 SmallGEMMReshape3DShapes() 957 : ShapeDataset("Shape", 958 { 959 TensorShape{ 63U, 9U, 8U }, 960 }) 961 { 962 } 963 }; 964 965 /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel */ 966 class LargeGEMMReshape2DShapes final : public ShapeDataset 967 { 968 public: LargeGEMMReshape2DShapes()969 LargeGEMMReshape2DShapes() 970 : ShapeDataset("Shape", 971 { 972 TensorShape{ 16U, 27U }, 973 TensorShape{ 345U, 171U } 974 }) 975 { 976 } 977 }; 978 979 /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */ 980 class LargeGEMMReshape3DShapes final : public ShapeDataset 981 { 982 public: LargeGEMMReshape3DShapes()983 LargeGEMMReshape3DShapes() 984 : ShapeDataset("Shape", 985 { 986 TensorShape{ 16U, 3U, 9U }, 987 TensorShape{ 345U, 34U, 18U } 988 }) 989 { 990 } 991 }; 992 993 /** Data set containing small 2D tensor shapes. */ 994 class Small2DNonMaxSuppressionShapes final : public ShapeDataset 995 { 996 public: Small2DNonMaxSuppressionShapes()997 Small2DNonMaxSuppressionShapes() 998 : ShapeDataset("Shape", 999 { 1000 TensorShape{ 4U, 7U }, 1001 TensorShape{ 4U, 13U }, 1002 TensorShape{ 4U, 64U } 1003 }) 1004 { 1005 } 1006 }; 1007 1008 /** Data set containing large 2D tensor shapes. */ 1009 class Large2DNonMaxSuppressionShapes final : public ShapeDataset 1010 { 1011 public: Large2DNonMaxSuppressionShapes()1012 Large2DNonMaxSuppressionShapes() 1013 : ShapeDataset("Shape", 1014 { 1015 TensorShape{ 4U, 113U } 1016 }) 1017 { 1018 } 1019 }; 1020 1021 } // namespace datasets 1022 } // namespace test 1023 } // namespace arm_compute 1024 #endif /* ARM_COMPUTE_TEST_SHAPE_DATASETS_H */ 1025