• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017-2019 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #pragma once
26 
27 #include "arm_gemm.hpp"
28 
29 #include <cstddef>
30 #include <utility>
31 
32 namespace winograd
33 {
34 
/** Interface shared by all Winograd transforms (input, output and weights).
 *
 * Work is exposed as a one-dimensional window (see get_window()); callers may
 * split it between threads, each thread calling run() with its own sub-range
 * and thread index. Per-thread scratch memory is supplied through
 * set_working_space().
 */
class ITransform
{
  public:
    virtual ~ITransform() = default;

    /**
     * Get the working space required to perform the transformation.
     *
     * Note, the working space is only required when performing the
     * transformation - hence it can be reused whenever the transformation is
     * not running.
     *
     * @param nthreads The greatest number of threads that will be used to execute the transform.
     * @return Size of working space required in bytes.
     */
    virtual size_t get_working_space_size(unsigned int nthreads=1) const = 0;

    /**
     * Set the working space to be used by the transformation.
     *
     * Note, the working space is only required when performing the
     * transformation - hence it can be reused whenever the transformation is
     * not running.
     *
     * @param buffer Pointer to the working space. It should provide at least
     *               get_working_space_size() bytes for the number of threads
     *               that will call run().
     */
    virtual void set_working_space(void *buffer) = 0;

    /**
     * Get the window of work a given operator can perform.
     *
     * @return Extent of the work window; sub-ranges of it are passed to run().
     */
    virtual unsigned int get_window() const = 0;

    /**
     * Perform work upon a window of the transform.
     *
     * @param start    Start of the window range to process.
     * @param stop     End of the window range to process (NOTE(review): confirm
     *                 whether implementations treat this as inclusive or exclusive).
     * @param threadid Index of the calling thread; presumably must be less than
     *                 the nthreads value used to size the working space.
     */
    virtual void run(unsigned int start, unsigned int stop, unsigned int threadid=0) = 0;
};
73 
74 class IInputTransform : public ITransform
75 {
76   public:
77     virtual ~IInputTransform() = default;
78 
79     /**
80      * Set the pointer to the (NHWC-ordered) tensor to be transformed.
81      */
82     virtual void set_input_tensor(const void *input) = 0;
83 
84     /**
85      * Set the pointer to the (NHWC-ordered) tensor to be transformed.
86      * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
87      */
88     virtual void set_input_tensor(const void *input, int col_stride) = 0;
89 
90     /**
91      * Set the pointer to the (NHWC-ordered) tensor to be transformed.
92      * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
93      * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
94      */
95     virtual void set_input_tensor(const void *input, int row_stride, int col_stride) = 0;
96 
97     /**
98      * Set the pointer to the (NHWC-ordered) tensor to be transformed.
99      * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes).
100      * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
101      * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
102      */
103     virtual void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) = 0;
104 
105     /**
106      * Set pointers to the matrices written by the transform.
107      * @param matrices Pointer to the start of the first matrix representing the transformed input.
108      * @param inter_matrix_stride Stride (in elements) between matrices.
109      * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
110      */
111     virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
112 };
113 
114 class IOutputTransform : public ITransform
115 {
116   public:
117     virtual ~IOutputTransform() = default;
118 
119     /**
120      * Set pointers to the matrices written by the transform.
121      * @param matrices Pointer to the start of the first matrix representing the input to the transform.
122      * @param inter_matrix_stride Stride (in elements) between matrices.
123      * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
124      */
125     virtual void set_input_matrices(const void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
126 
127     /**
128      * Set pointer to the bias tensor (can be ignored or called with nullptr for no bias.
129      */
130     virtual void set_bias(const void *bias=nullptr) = 0;
131 
132     /**
133      * Set pointer to the output tensor produced by the transform.
134      */
135     virtual void set_output_tensor(void *output) = 0;
136 
137     /**
138      * Set pointer to the output tensor produced by the transform.
139      * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
140      */
141     virtual void set_output_tensor(void *output, int col_stride) = 0;
142 
143     /**
144      * Set pointer to the output tensor produced by the transform.
145      * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
146      * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
147      */
148     virtual void set_output_tensor(void *output, int row_stride, int col_stride) = 0;
149 
150     /**
151      * Set pointer to the output tensor produced by the transform.
152      * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes).
153      * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
154      * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
155      */
156     virtual void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) = 0;
157 };
158 
159 class IWeightTransform : public ITransform
160 {
161   public:
162     virtual ~IWeightTransform() = default;
163 
164     /** Set pointer to the weight tensor read by the transform. */
165     virtual void set_weight_tensor(const void *weights) = 0;
166 
167     /**
168      * Set pointers to the matrices written by the transform.
169      * @param matrices Pointer to the start of the first matrix representing the transformed input.
170      * @param inter_matrix_stride Stride (in elements) between matrices.
171      * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
172      */
173     virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
174 };
175 
/** Selects the family of Winograd transform to instantiate.
 *
 * Passed as the `Roots` template argument of the input, output and weight
 * transform templates. Only one option is currently declared.
 */
enum class WinogradRoots
{
  Integers
};
180 
181 template <int InnerTileRows, int InnerTileCols, typename TIn, typename TOut, WinogradRoots Roots>
182 class InputTransform : public IInputTransform
183 {
184   public:
185     /** Create an InputTransform operator fixed on a given problem and set of
186      * pointers.
187      */
188     InputTransform(
189         int kernel_rows,     /**< Number of rows in the kernel */
190         int kernel_cols,     /**< Number of columns in the kernel */
191         int n_batches,       /**< Number of batches in input tensor. */
192         int n_rows,          /**< Number of rows in input tensor. */
193         int n_cols,          /**< Number of columns in input tensor. */
194         int n_channels,      /**< Number of channels in input tensor. */
195         int padding_top,     /**< Padding to apply to the top of the image. */
196         int padding_left,    /**< Padding to apply to the left of the image. */
197         int padding_bottom,  /**< Padding to apply to the bottom of the image. */
198         int padding_right    /**< Padding to apply to the right of the image. */
199     );
200 
201     InputTransform(InputTransform&) = delete;
202     InputTransform operator=(InputTransform&) = delete;
203 
204     /** Set pointers to the input tensor read by the transform. */
205     void set_input_tensor(const void *input) override;
206     void set_input_tensor(const void *input, int col_stride) override;
207     void set_input_tensor(const void *input, int row_stride, int col_stride) override;
208     void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override;
209 
210     /** Set pointers to the matrices written by the transform. */
211     void set_output_matrices(void *matrices, int iter_matrix_stride, int matrix_row_stride) override;
212 
213     /** Get the working space required to perform the transformation. */
214     size_t get_working_space_size(unsigned int nthreads=1) const override;
215     void set_working_space(void *buffer) override;
216 
217     /** Get the window of work a given operator can perform. */
218     unsigned int get_window() const override;
219     static constexpr unsigned int WINDOW_BLOCK = 16;  // Base size of window
220 
221     /** Perform work upon a window of the input. */
222     void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
223 
224   protected:
225     const int _n_batches, _n_rows, _n_cols, _n_channels;
226 
227   private:
228     void transform_unpadded_tile(
229       unsigned int threadid,
230       int n_channels,
231       TOut *outptr,
232       const TIn *inptr
233     );
234 
235     void transform_padded_tile(
236       unsigned int threadid,
237       int n_channels,
238       TOut *outptr,
239       const TIn *inptr,
240       int padding_top,
241       int padding_left,
242       int padding_bottom,
243       int padding_right
244     );
245 
246     /* Tile implementation */
247     static void transform_tile(
248       int n_channels,         /** @param[in] Number of channels in the tensor. */
249       const TIn* inptr_base,  /** @param[in] Pointer to the base of the input tile. */
250       int input_row_stride,   /** @param[in] Stride between rows of the input tensor. */
251       int input_col_stride,   /** @param[in] Stride between columns of the input tensor. */
252       TOut* mptr_base,        /** @param[out] Base pointer to transformed input matrices. */
253       int matrix_stride       /** @param[in] Stride between matrices in the input space. */
254     );
255 
256     /** Get the working space for a thread. */
257     void * get_working_space(unsigned int threadid) const;
258 
259     const TIn* _inptr;
260     TOut* _outptr;
261 
262     const int _overlap_rows, _overlap_cols;
263     const int _padding_top, _padding_left, _padding_bottom, _padding_right;
264     const int _tiles_M, _tiles_N;
265     int _matrix_stride, _matrix_row_stride, _matrix_batch_stride;
266     int _in_col_stride, _in_row_stride, _in_batch_stride;
267 
268     const int _working_space_col_stride, _working_space_row_stride;
269     TIn *_working_space;
270 };
271 
272 template <int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots>
273 class InputTransform<InnerTileRows, 1, TIn, TOut, Roots> :
274   public InputTransform<1, InnerTileRows, TIn, TOut, Roots>
275 {
276   using Base = InputTransform<1, InnerTileRows, TIn, TOut, Roots>;
277 
278   public:
279     InputTransform(
280       int kernel_rows,     /**< Number of rows in the kernel. */
281       int kernel_cols,     /**< Number of columns in the kernel. */
282       int n_batches,       /**< Number of batches in input tensor. */
283       int n_rows,          /**< Number of rows in input tensor. */
284       int n_cols,          /**< Number of columns in input tensor. */
285       int n_channels,      /**< Number of channels in input tensor. */
286       int padding_top,     /**< Padding to apply to the top of the image. */
287       int padding_left,    /**< Padding to apply to the left of the image. */
288       int padding_bottom,  /**< Padding to apply to the bottom of the image. */
289       int padding_right    /**< Padding to apply to the right of the image. */
290     );
291 
292     /** Set pointers to the input tensor read by the transform. */
293     void set_input_tensor(const void *input) override;
294     void set_input_tensor(const void *input, int col_stride) override;
295     void set_input_tensor(const void *input, int row_stride, int col_stride) override;
296     void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override;
297 };
298 
299 template <
300   int KernelRows, int KernelCols,
301   int InnerTileRows, int InnerTileCols,
302   typename TIn, typename TOut,
303   WinogradRoots Roots
304 >
305 class OutputTransform : public IOutputTransform
306 {
307   public:
308     OutputTransform(
309       int n_batches,  /**< Number of batches in output tensor. */
310       int n_rows,     /**< Number of rows in output tensor. */
311       int n_cols,     /**< Number of columns in output tensor. */
312       int n_channels, /**< Number of channels in output tensor. */
313       const arm_gemm::Activation &activation
314     );
315 
316     OutputTransform(OutputTransform&) = delete;
317     OutputTransform operator=(OutputTransform&) = delete;
318 
319     /** Set pointers to the matrices read by the transform. */
320     void set_input_matrices(const void *matrices, int iter_matrix_stride, int matrix_row_stride) override;
321 
322     /** Set pointer to the bias tensor (can be ignored or called with nullptr for no bias */
323     void set_bias(const void *bias=nullptr) override;
324 
325     /** Set pointers to the output tensor written by the transform. */
326     void set_output_tensor(void *output) override;
327     void set_output_tensor(void *output, int col_stride) override;
328     void set_output_tensor(void *output, int row_stride, int col_stride) override;
329     void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override;
330 
331     /** Get the working space required to perform the transformation. */
332     size_t get_working_space_size(unsigned int nthreads=1) const override;
333     void set_working_space(void *buffer) override;
334 
335     /** Get the window of work a given operator can perform. */
336     unsigned int get_window() const override;
337     static constexpr unsigned int WINDOW_BLOCK = 16;  // Base size of window
338 
339     /** Perform work upon a window of the input. */
340     void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
341 
342   protected:
343     static constexpr int inner_tile_rows = InnerTileRows;
344     static constexpr int inner_tile_cols = InnerTileCols;
345     static constexpr int output_tile_rows = InnerTileRows - KernelRows + 1;
346     static constexpr int output_tile_cols = InnerTileCols - KernelCols + 1;
347 
348     const int _n_batches, _n_rows, _n_cols, _n_channels;
349     const TOut _output_min, _output_max;
350 
351   private:
352     void transform_uncropped_tile(
353       unsigned int threadid,
354       int n_channels,
355       TOut *outptr,
356       const TIn *inptr,
357       const TOut *biases
358     );
359 
360     void transform_cropped_tile(
361       unsigned int threadid,
362       int n_channels,
363       TOut *outptr,
364       const TIn *inptr,
365       const TOut *biases,
366       int pad_bottom,
367       int pad_right
368     );
369 
370     /** Implementation of the tile transformation method. */
371     static void transform_tile(
372       int n_channels,
373       const TIn* matrix_base,
374       int matrix_stride,
375       const TOut* biases,
376       TOut* output,
377       int output_row_stride,
378       int output_col_stride,
379       TOut output_min,
380       TOut output_max
381     );
382 
383     /** Get the working space for a thread. */
384     void * get_working_space(unsigned int threadid) const;
385 
386     const TIn* _matrix_base;
387     const TOut* _biases;
388     int _matrix_stride, _matrix_row_stride, _matrix_batch_stride;
389     TOut* _outptr;
390     const int _tiles_M, _tiles_N;
391     int _out_col_stride, _out_row_stride, _out_batch_stride;
392 
393     const int _working_space_col_stride, _working_space_row_stride;
394     TOut *_working_space;
395 };
396 
397 template <
398   int KernelRows,
399   int InnerTileRows,
400   typename TIn, typename TOut,
401   WinogradRoots Roots
402 >
403 class OutputTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> :
404   public OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>
405 {
406   using Base = OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>;
407 
408   public:
409     OutputTransform(
410       int n_batches,  /**< Number of batches in output tensor. */
411       int n_rows,     /**< Number of rows in output tensor. */
412       int n_cols,     /**< Number of columns in output tensor. */
413       int n_channels, /**< Number of channels in output tensor. */
414       const arm_gemm::Activation &activation
415     );
416 
417     /** Set pointers to the output tensor written by the transform. */
418     void set_output_tensor(void *output) override;
419     void set_output_tensor(void *output, int col_stride) override;
420     void set_output_tensor(void *output, int row_stride, int col_stride) override;
421     void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override;
422 };
423 
424 template <
425   int KernelRows, int KernelCols,
426   int InnerTileRows, int InnerTileCols,
427   typename TIn, typename TOut,
428   WinogradRoots Roots
429 >
430 class WeightTransform : public IWeightTransform
431 {
432   public:
433     WeightTransform(
434       int n_output_channels,  /**< Number of output channels in the kernel. */
435       int n_input_channels    /**< Number of input channels in the kernel. */
436     );
437 
438     WeightTransform(WeightTransform&) = delete;
439     WeightTransform operator=(WeightTransform&) = delete;
440 
441     /** Set pointer to the weight tensor read by the transform. */
442     void set_weight_tensor(const void *weights) override;
443 
444     /** Set pointer to the matrices written by the transform. */
445     void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) override;
446 
447     /** Get the working space required to perform the transformation. */
448     size_t get_working_space_size(unsigned int nthreads=1) const override;
449     void set_working_space(void *buffer) override;
450 
451     /** Get the window of work a given operator can perform. */
452     unsigned int get_window() const override;
453     static constexpr unsigned int WINDOW_BLOCK = 16;  // Base size of window
454 
455     /** Perform work upon a window of the input. */
456     void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
457 
458   protected:
459     static const int kernel_rows = KernelRows;
460     static const int kernel_cols = KernelCols;
461     static const int inner_tile_rows = InnerTileRows;
462     static const int inner_tile_cols = InnerTileCols;
463 
464   private:
465     /** Apply the transform to a tensor. */
466     static void execute(
467       int n_output_channels,
468       int n_input_channels,
469       const TIn* input,
470       TOut* output,
471       int matrix_stride,
472       int matrix_row_stride
473     );
474 
475     const int _n_output_channels, _n_input_channels;
476     TOut *_matrices;
477     int _matrix_stride, _matrix_row_stride;
478     const TIn *_weights;
479 };
480 
481 template <int KernelRows, int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots>
482 class WeightTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> :
483   public WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>
484 {
485   public:
486     using WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::WeightTransform;
487 };
488 
489 template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, WinogradRoots Roots>
490 class WinogradGEMM
491 {
492   public:
493     // Information about the specific Winograd instance
494     static constexpr int output_tile_rows = OutputTileRows;
495     static constexpr int output_tile_cols = OutputTileCols;
496     static constexpr int kernel_rows = KernelRows;
497     static constexpr int kernel_cols = KernelCols;
498     static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1;
499     static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1;
500     static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
501 
502     /** Transform weights from the spatial to the Winograd domain. */
503     template <typename TIn, typename TOut>
504     using WeightsTransform = WeightTransform<
505       KernelRows, KernelCols, inner_tile_rows, inner_tile_cols,
506       TIn, TOut, Roots
507     >;
508 
509     /** Transform input feature maps from the spatial to the Winograd domain.
510      */
511     template <typename TIn, typename TOut>
512     using InputTransform = InputTransform<
513       inner_tile_rows, inner_tile_cols, TIn, TOut, Roots
514     >;
515 
516     /** Transform output feature maps from the Winograd to the spatial domain.
517      */
518     template <typename TIn, typename TOut>
519     using OutputTransform = OutputTransform<
520       KernelRows, KernelCols, inner_tile_rows, inner_tile_cols,
521       TIn, TOut, Roots
522     >;
523 
524     /** Perform a convolution.
525      */
526     template <typename TOut, typename TIn, typename TInGEMM=TIn, typename TOutGEMM=TOut>
527     class Convolution
528     {
529       public:
530         // Information about the typed Winograd instance
531         typedef TOut OutputType;
532         typedef TOutGEMM GemmOutputType;
533         typedef TInGEMM GemmInputType;
534         typedef TIn InputType;
535 
536         /** Get the output shape of a convolution. */
537         static std::pair<unsigned int, unsigned int> get_output_shape(
538             const std::pair<unsigned int, unsigned int> input_shape,
539             bool padding_same);
540 
541         /** Get the memory required to store the kernel transformed into the
542          * Winograd domain.
543          */
544         static size_t get_kernel_storage_size(unsigned int n_input_channels,
545                                               unsigned int n_output_channels);
546 
547         /** Get the memory required to store the input tensor transformed into
548          * the Winograd domain.
549          */
550         static size_t get_input_storage_size(
551             unsigned int n_batches,  // Number of batches
552             unsigned int n_rows,     // Number of input rows
553             unsigned int n_cols,     // Number of input columns
554             unsigned int n_channels, // Number of input channels
555             bool padding_same);
556 
557         /** Get the memory required to store the output tensor in the Winograd
558          * domain.
559          */
560         static size_t get_output_storage_size(
561             unsigned int n_batches, // Number of batches
562             unsigned int n_rows,    // Number of output rows
563             unsigned int n_cols,    // Number of output columns
564             unsigned int n_channels // Number of output channels
565             );
566 
567         /** Get the memory required to apply a Winograd operator to some input.
568          */
569         static size_t get_working_space_size(
570             unsigned int n_batches,
571             unsigned int n_rows,            // Number of input rows
572             unsigned int n_cols,            // Number of input columns
573             unsigned int n_input_channels,  // Number of input channels
574             unsigned int n_output_channels, // Number of output channels
575             bool padding_same);
576 
577         /* Get the memory required by a single "input" matrix.
578          */
579         static size_t get_input_matrix_size(
580             unsigned int n_batches,  // Number of batches
581             unsigned int n_rows,     // Number of input rows
582             unsigned int n_cols,     // Number of input columns
583             unsigned int n_channels, // Number of input channels
584             bool padding_same);
585 
586         static int get_input_matrix_stride(
587             unsigned int n_batches,  // Number of batches
588             unsigned int n_rows,     // Number of input rows
589             unsigned int n_cols,     // Number of input columns
590             unsigned int n_channels, // Number of input channels
591             bool padding_same);
592 
593         /* Get the memory required by a single "output" matrix.
594          */
595         static size_t get_output_matrix_size(
596             unsigned int n_batches, // Number of batches
597             unsigned int n_rows,    // Number of output rows
598             unsigned int n_cols,    // Number of output columns
599             unsigned int n_channels // Number of output channels
600             );
601 
602         static int get_output_matrix_stride(
603             unsigned int n_batches, // Number of batches
604             unsigned int n_rows,    // Number of output rows
605             unsigned int n_cols,    // Number of output columns
606             unsigned int n_channels // Number of output channels
607             );
608 
609         /* Get the memory required by a single "kernel" matrix.
610          */
611         static size_t get_kernel_matrix_size(unsigned int n_input_channels,
612                                              unsigned int n_output_channels);
613         static int get_kernel_matrix_stride(unsigned int n_input_channels,
614                                             unsigned int n_output_channels);
615 
616         static constexpr int M_BLOCK = 4;   /** Size of block used by GEMM. */
617         static constexpr int N_BLOCK = 16;  /** Size of block used by GEMM. */
618     };
619 };
620 
621 }  // namespace winograd
622