1 /* 2 * Copyright (c) 2017-2019 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 25 #pragma once 26 27 #include "arm_gemm.hpp" 28 29 #include <cstddef> 30 #include <utility> 31 32 namespace winograd 33 { 34 35 class ITransform 36 { 37 public: 38 virtual ~ITransform() = default; 39 40 /** 41 * Get the working space required to perform the transformation. 42 * 43 * Note, the working space is only required when performing the 44 * transformation - hence it can be reused whenever the transformation is 45 * not running. 46 * 47 * @param nthreads The greatest number of threads that will be used to execute the transform. 48 * @return Size of working space required in bytes. 49 */ 50 virtual size_t get_working_space_size(unsigned int nthreads=1) const = 0; 51 52 /** 53 * Set the working space to be used by the transformation. 54 * 55 * Note, the working space is only required when performing the 56 * transformation - hence it can be reused whenever the transformation is 57 * not running. 58 * 59 * @param Pointer to the working space. 60 */ 61 virtual void set_working_space(void *buffer) = 0; 62 63 /** 64 * Get the window of work a given operator can perform. 65 */ 66 virtual unsigned int get_window() const = 0; 67 68 /** 69 * Perform work upon a window of the transform. 70 */ 71 virtual void run(unsigned int start, unsigned int stop, unsigned int threadid=0) = 0; 72 }; 73 74 class IInputTransform : public ITransform 75 { 76 public: 77 virtual ~IInputTransform() = default; 78 79 /** 80 * Set the pointer to the (NHWC-ordered) tensor to be transformed. 81 */ 82 virtual void set_input_tensor(const void *input) = 0; 83 84 /** 85 * Set the pointer to the (NHWC-ordered) tensor to be transformed. 86 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). 87 */ 88 virtual void set_input_tensor(const void *input, int col_stride) = 0; 89 90 /** 91 * Set the pointer to the (NHWC-ordered) tensor to be transformed. 92 * @param row_stride Stride between rows of the tensor, measured in elements (not bytes). 93 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). 94 */ 95 virtual void set_input_tensor(const void *input, int row_stride, int col_stride) = 0; 96 97 /** 98 * Set the pointer to the (NHWC-ordered) tensor to be transformed. 99 * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes). 100 * @param row_stride Stride between rows of the tensor, measured in elements (not bytes). 101 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). 102 */ 103 virtual void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) = 0; 104 105 /** 106 * Set pointers to the matrices written by the transform. 107 * @param matrices Pointer to the start of the first matrix representing the transformed input. 108 * @param inter_matrix_stride Stride (in elements) between matrices. 109 * @param matrix_row_stride Stride (in elements) between the rows within a single matrix. 110 */ 111 virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0; 112 }; 113 114 class IOutputTransform : public ITransform 115 { 116 public: 117 virtual ~IOutputTransform() = default; 118 119 /** 120 * Set pointers to the matrices written by the transform. 121 * @param matrices Pointer to the start of the first matrix representing the input to the transform. 122 * @param inter_matrix_stride Stride (in elements) between matrices. 123 * @param matrix_row_stride Stride (in elements) between the rows within a single matrix. 124 */ 125 virtual void set_input_matrices(const void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0; 126 127 /** 128 * Set pointer to the bias tensor (can be ignored or called with nullptr for no bias. 129 */ 130 virtual void set_bias(const void *bias=nullptr) = 0; 131 132 /** 133 * Set pointer to the output tensor produced by the transform. 134 */ 135 virtual void set_output_tensor(void *output) = 0; 136 137 /** 138 * Set pointer to the output tensor produced by the transform. 139 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). 140 */ 141 virtual void set_output_tensor(void *output, int col_stride) = 0; 142 143 /** 144 * Set pointer to the output tensor produced by the transform. 145 * @param row_stride Stride between rows of the tensor, measured in elements (not bytes). 146 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). 147 */ 148 virtual void set_output_tensor(void *output, int row_stride, int col_stride) = 0; 149 150 /** 151 * Set pointer to the output tensor produced by the transform. 152 * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes). 153 * @param row_stride Stride between rows of the tensor, measured in elements (not bytes). 154 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). 155 */ 156 virtual void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) = 0; 157 }; 158 159 class IWeightTransform : public ITransform 160 { 161 public: 162 virtual ~IWeightTransform() = default; 163 164 /** Set pointer to the weight tensor read by the transform. */ 165 virtual void set_weight_tensor(const void *weights) = 0; 166 167 /** 168 * Set pointers to the matrices written by the transform. 169 * @param matrices Pointer to the start of the first matrix representing the transformed input. 170 * @param inter_matrix_stride Stride (in elements) between matrices. 171 * @param matrix_row_stride Stride (in elements) between the rows within a single matrix. 172 */ 173 virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0; 174 }; 175 176 enum class WinogradRoots 177 { 178 Integers, 179 }; 180 181 template <int InnerTileRows, int InnerTileCols, typename TIn, typename TOut, WinogradRoots Roots> 182 class InputTransform : public IInputTransform 183 { 184 public: 185 /** Create an InputTransform operator fixed on a given problem and set of 186 * pointers. 187 */ 188 InputTransform( 189 int kernel_rows, /**< Number of rows in the kernel */ 190 int kernel_cols, /**< Number of columns in the kernel */ 191 int n_batches, /**< Number of batches in input tensor. */ 192 int n_rows, /**< Number of rows in input tensor. */ 193 int n_cols, /**< Number of columns in input tensor. */ 194 int n_channels, /**< Number of channels in input tensor. */ 195 int padding_top, /**< Padding to apply to the top of the image. */ 196 int padding_left, /**< Padding to apply to the left of the image. */ 197 int padding_bottom, /**< Padding to apply to the bottom of the image. */ 198 int padding_right /**< Padding to apply to the right of the image. */ 199 ); 200 201 InputTransform(InputTransform&) = delete; 202 InputTransform operator=(InputTransform&) = delete; 203 204 /** Set pointers to the input tensor read by the transform. */ 205 void set_input_tensor(const void *input) override; 206 void set_input_tensor(const void *input, int col_stride) override; 207 void set_input_tensor(const void *input, int row_stride, int col_stride) override; 208 void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override; 209 210 /** Set pointers to the matrices written by the transform. */ 211 void set_output_matrices(void *matrices, int iter_matrix_stride, int matrix_row_stride) override; 212 213 /** Get the working space required to perform the transformation. */ 214 size_t get_working_space_size(unsigned int nthreads=1) const override; 215 void set_working_space(void *buffer) override; 216 217 /** Get the window of work a given operator can perform. */ 218 unsigned int get_window() const override; 219 static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window 220 221 /** Perform work upon a window of the input. */ 222 void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override; 223 224 protected: 225 const int _n_batches, _n_rows, _n_cols, _n_channels; 226 227 private: 228 void transform_unpadded_tile( 229 unsigned int threadid, 230 int n_channels, 231 TOut *outptr, 232 const TIn *inptr 233 ); 234 235 void transform_padded_tile( 236 unsigned int threadid, 237 int n_channels, 238 TOut *outptr, 239 const TIn *inptr, 240 int padding_top, 241 int padding_left, 242 int padding_bottom, 243 int padding_right 244 ); 245 246 /* Tile implementation */ 247 static void transform_tile( 248 int n_channels, /** @param[in] Number of channels in the tensor. */ 249 const TIn* inptr_base, /** @param[in] Pointer to the base of the input tile. */ 250 int input_row_stride, /** @param[in] Stride between rows of the input tensor. */ 251 int input_col_stride, /** @param[in] Stride between columns of the input tensor. */ 252 TOut* mptr_base, /** @param[out] Base pointer to transformed input matrices. */ 253 int matrix_stride /** @param[in] Stride between matrices in the input space. */ 254 ); 255 256 /** Get the working space for a thread. */ 257 void * get_working_space(unsigned int threadid) const; 258 259 const TIn* _inptr; 260 TOut* _outptr; 261 262 const int _overlap_rows, _overlap_cols; 263 const int _padding_top, _padding_left, _padding_bottom, _padding_right; 264 const int _tiles_M, _tiles_N; 265 int _matrix_stride, _matrix_row_stride, _matrix_batch_stride; 266 int _in_col_stride, _in_row_stride, _in_batch_stride; 267 268 const int _working_space_col_stride, _working_space_row_stride; 269 TIn *_working_space; 270 }; 271 272 template <int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots> 273 class InputTransform<InnerTileRows, 1, TIn, TOut, Roots> : 274 public InputTransform<1, InnerTileRows, TIn, TOut, Roots> 275 { 276 using Base = InputTransform<1, InnerTileRows, TIn, TOut, Roots>; 277 278 public: 279 InputTransform( 280 int kernel_rows, /**< Number of rows in the kernel. */ 281 int kernel_cols, /**< Number of columns in the kernel. */ 282 int n_batches, /**< Number of batches in input tensor. */ 283 int n_rows, /**< Number of rows in input tensor. */ 284 int n_cols, /**< Number of columns in input tensor. */ 285 int n_channels, /**< Number of channels in input tensor. */ 286 int padding_top, /**< Padding to apply to the top of the image. */ 287 int padding_left, /**< Padding to apply to the left of the image. */ 288 int padding_bottom, /**< Padding to apply to the bottom of the image. */ 289 int padding_right /**< Padding to apply to the right of the image. */ 290 ); 291 292 /** Set pointers to the input tensor read by the transform. */ 293 void set_input_tensor(const void *input) override; 294 void set_input_tensor(const void *input, int col_stride) override; 295 void set_input_tensor(const void *input, int row_stride, int col_stride) override; 296 void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override; 297 }; 298 299 template < 300 int KernelRows, int KernelCols, 301 int InnerTileRows, int InnerTileCols, 302 typename TIn, typename TOut, 303 WinogradRoots Roots 304 > 305 class OutputTransform : public IOutputTransform 306 { 307 public: 308 OutputTransform( 309 int n_batches, /**< Number of batches in output tensor. */ 310 int n_rows, /**< Number of rows in output tensor. */ 311 int n_cols, /**< Number of columns in output tensor. */ 312 int n_channels, /**< Number of channels in output tensor. */ 313 const arm_gemm::Activation &activation 314 ); 315 316 OutputTransform(OutputTransform&) = delete; 317 OutputTransform operator=(OutputTransform&) = delete; 318 319 /** Set pointers to the matrices read by the transform. */ 320 void set_input_matrices(const void *matrices, int iter_matrix_stride, int matrix_row_stride) override; 321 322 /** Set pointer to the bias tensor (can be ignored or called with nullptr for no bias */ 323 void set_bias(const void *bias=nullptr) override; 324 325 /** Set pointers to the output tensor written by the transform. */ 326 void set_output_tensor(void *output) override; 327 void set_output_tensor(void *output, int col_stride) override; 328 void set_output_tensor(void *output, int row_stride, int col_stride) override; 329 void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override; 330 331 /** Get the working space required to perform the transformation. */ 332 size_t get_working_space_size(unsigned int nthreads=1) const override; 333 void set_working_space(void *buffer) override; 334 335 /** Get the window of work a given operator can perform. */ 336 unsigned int get_window() const override; 337 static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window 338 339 /** Perform work upon a window of the input. */ 340 void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override; 341 342 protected: 343 static constexpr int inner_tile_rows = InnerTileRows; 344 static constexpr int inner_tile_cols = InnerTileCols; 345 static constexpr int output_tile_rows = InnerTileRows - KernelRows + 1; 346 static constexpr int output_tile_cols = InnerTileCols - KernelCols + 1; 347 348 const int _n_batches, _n_rows, _n_cols, _n_channels; 349 const TOut _output_min, _output_max; 350 351 private: 352 void transform_uncropped_tile( 353 unsigned int threadid, 354 int n_channels, 355 TOut *outptr, 356 const TIn *inptr, 357 const TOut *biases 358 ); 359 360 void transform_cropped_tile( 361 unsigned int threadid, 362 int n_channels, 363 TOut *outptr, 364 const TIn *inptr, 365 const TOut *biases, 366 int pad_bottom, 367 int pad_right 368 ); 369 370 /** Implementation of the tile transformation method. */ 371 static void transform_tile( 372 int n_channels, 373 const TIn* matrix_base, 374 int matrix_stride, 375 const TOut* biases, 376 TOut* output, 377 int output_row_stride, 378 int output_col_stride, 379 TOut output_min, 380 TOut output_max 381 ); 382 383 /** Get the working space for a thread. */ 384 void * get_working_space(unsigned int threadid) const; 385 386 const TIn* _matrix_base; 387 const TOut* _biases; 388 int _matrix_stride, _matrix_row_stride, _matrix_batch_stride; 389 TOut* _outptr; 390 const int _tiles_M, _tiles_N; 391 int _out_col_stride, _out_row_stride, _out_batch_stride; 392 393 const int _working_space_col_stride, _working_space_row_stride; 394 TOut *_working_space; 395 }; 396 397 template < 398 int KernelRows, 399 int InnerTileRows, 400 typename TIn, typename TOut, 401 WinogradRoots Roots 402 > 403 class OutputTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> : 404 public OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots> 405 { 406 using Base = OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>; 407 408 public: 409 OutputTransform( 410 int n_batches, /**< Number of batches in output tensor. */ 411 int n_rows, /**< Number of rows in output tensor. */ 412 int n_cols, /**< Number of columns in output tensor. */ 413 int n_channels, /**< Number of channels in output tensor. */ 414 const arm_gemm::Activation &activation 415 ); 416 417 /** Set pointers to the output tensor written by the transform. */ 418 void set_output_tensor(void *output) override; 419 void set_output_tensor(void *output, int col_stride) override; 420 void set_output_tensor(void *output, int row_stride, int col_stride) override; 421 void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override; 422 }; 423 424 template < 425 int KernelRows, int KernelCols, 426 int InnerTileRows, int InnerTileCols, 427 typename TIn, typename TOut, 428 WinogradRoots Roots 429 > 430 class WeightTransform : public IWeightTransform 431 { 432 public: 433 WeightTransform( 434 int n_output_channels, /**< Number of output channels in the kernel. */ 435 int n_input_channels /**< Number of input channels in the kernel. */ 436 ); 437 438 WeightTransform(WeightTransform&) = delete; 439 WeightTransform operator=(WeightTransform&) = delete; 440 441 /** Set pointer to the weight tensor read by the transform. */ 442 void set_weight_tensor(const void *weights) override; 443 444 /** Set pointer to the matrices written by the transform. */ 445 void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) override; 446 447 /** Get the working space required to perform the transformation. */ 448 size_t get_working_space_size(unsigned int nthreads=1) const override; 449 void set_working_space(void *buffer) override; 450 451 /** Get the window of work a given operator can perform. */ 452 unsigned int get_window() const override; 453 static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window 454 455 /** Perform work upon a window of the input. */ 456 void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override; 457 458 protected: 459 static const int kernel_rows = KernelRows; 460 static const int kernel_cols = KernelCols; 461 static const int inner_tile_rows = InnerTileRows; 462 static const int inner_tile_cols = InnerTileCols; 463 464 private: 465 /** Apply the transform to a tensor. */ 466 static void execute( 467 int n_output_channels, 468 int n_input_channels, 469 const TIn* input, 470 TOut* output, 471 int matrix_stride, 472 int matrix_row_stride 473 ); 474 475 const int _n_output_channels, _n_input_channels; 476 TOut *_matrices; 477 int _matrix_stride, _matrix_row_stride; 478 const TIn *_weights; 479 }; 480 481 template <int KernelRows, int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots> 482 class WeightTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> : 483 public WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots> 484 { 485 public: 486 using WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::WeightTransform; 487 }; 488 489 template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, WinogradRoots Roots> 490 class WinogradGEMM 491 { 492 public: 493 // Information about the specific Winograd instance 494 static constexpr int output_tile_rows = OutputTileRows; 495 static constexpr int output_tile_cols = OutputTileCols; 496 static constexpr int kernel_rows = KernelRows; 497 static constexpr int kernel_cols = KernelCols; 498 static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1; 499 static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1; 500 static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols; 501 502 /** Transform weights from the spatial to the Winograd domain. */ 503 template <typename TIn, typename TOut> 504 using WeightsTransform = WeightTransform< 505 KernelRows, KernelCols, inner_tile_rows, inner_tile_cols, 506 TIn, TOut, Roots 507 >; 508 509 /** Transform input feature maps from the spatial to the Winograd domain. 510 */ 511 template <typename TIn, typename TOut> 512 using InputTransform = InputTransform< 513 inner_tile_rows, inner_tile_cols, TIn, TOut, Roots 514 >; 515 516 /** Transform output feature maps from the Winograd to the spatial domain. 517 */ 518 template <typename TIn, typename TOut> 519 using OutputTransform = OutputTransform< 520 KernelRows, KernelCols, inner_tile_rows, inner_tile_cols, 521 TIn, TOut, Roots 522 >; 523 524 /** Perform a convolution. 525 */ 526 template <typename TOut, typename TIn, typename TInGEMM=TIn, typename TOutGEMM=TOut> 527 class Convolution 528 { 529 public: 530 // Information about the typed Winograd instance 531 typedef TOut OutputType; 532 typedef TOutGEMM GemmOutputType; 533 typedef TInGEMM GemmInputType; 534 typedef TIn InputType; 535 536 /** Get the output shape of a convolution. */ 537 static std::pair<unsigned int, unsigned int> get_output_shape( 538 const std::pair<unsigned int, unsigned int> input_shape, 539 bool padding_same); 540 541 /** Get the memory required to store the kernel transformed into the 542 * Winograd domain. 543 */ 544 static size_t get_kernel_storage_size(unsigned int n_input_channels, 545 unsigned int n_output_channels); 546 547 /** Get the memory required to store the input tensor transformed into 548 * the Winograd domain. 549 */ 550 static size_t get_input_storage_size( 551 unsigned int n_batches, // Number of batches 552 unsigned int n_rows, // Number of input rows 553 unsigned int n_cols, // Number of input columns 554 unsigned int n_channels, // Number of input channels 555 bool padding_same); 556 557 /** Get the memory required to store the output tensor in the Winograd 558 * domain. 559 */ 560 static size_t get_output_storage_size( 561 unsigned int n_batches, // Number of batches 562 unsigned int n_rows, // Number of output rows 563 unsigned int n_cols, // Number of output columns 564 unsigned int n_channels // Number of output channels 565 ); 566 567 /** Get the memory required to apply a Winograd operator to some input. 568 */ 569 static size_t get_working_space_size( 570 unsigned int n_batches, 571 unsigned int n_rows, // Number of input rows 572 unsigned int n_cols, // Number of input columns 573 unsigned int n_input_channels, // Number of input channels 574 unsigned int n_output_channels, // Number of output channels 575 bool padding_same); 576 577 /* Get the memory required by a single "input" matrix. 578 */ 579 static size_t get_input_matrix_size( 580 unsigned int n_batches, // Number of batches 581 unsigned int n_rows, // Number of input rows 582 unsigned int n_cols, // Number of input columns 583 unsigned int n_channels, // Number of input channels 584 bool padding_same); 585 586 static int get_input_matrix_stride( 587 unsigned int n_batches, // Number of batches 588 unsigned int n_rows, // Number of input rows 589 unsigned int n_cols, // Number of input columns 590 unsigned int n_channels, // Number of input channels 591 bool padding_same); 592 593 /* Get the memory required by a single "output" matrix. 594 */ 595 static size_t get_output_matrix_size( 596 unsigned int n_batches, // Number of batches 597 unsigned int n_rows, // Number of output rows 598 unsigned int n_cols, // Number of output columns 599 unsigned int n_channels // Number of output channels 600 ); 601 602 static int get_output_matrix_stride( 603 unsigned int n_batches, // Number of batches 604 unsigned int n_rows, // Number of output rows 605 unsigned int n_cols, // Number of output columns 606 unsigned int n_channels // Number of output channels 607 ); 608 609 /* Get the memory required by a single "kernel" matrix. 610 */ 611 static size_t get_kernel_matrix_size(unsigned int n_input_channels, 612 unsigned int n_output_channels); 613 static int get_kernel_matrix_stride(unsigned int n_input_channels, 614 unsigned int n_output_channels); 615 616 static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */ 617 static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */ 618 }; 619 }; 620 621 } // namespace winograd 622