/*
 * Copyright (c) 2016-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_TYPES_H
#define ARM_COMPUTE_TYPES_H

#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/QuantizationInfo.h"
#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Size3D.h"
#include "arm_compute/core/Strides.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/experimental/IPostOp.h"
#include "arm_compute/core/utils/misc/Macros.h"
#include "support/Bfloat16.h"
#include "support/Half.h"

#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>

namespace arm_compute
{
/** 16-bit floating point type */
using half = half_float::half;

/** Permutation vector */
using PermutationVector = Strides;
/** Bidirectional strides */
using BiStrides = Coordinates;

/** Image colour formats */
enum class Format
{
    UNKNOWN,  /**< Unknown image format */
    U8,       /**< 1 channel, 1 U8 per channel */
    S16,      /**< 1 channel, 1 S16 per channel */
    U16,      /**< 1 channel, 1 U16 per channel */
    S32,      /**< 1 channel, 1 S32 per channel */
    U32,      /**< 1 channel, 1 U32 per channel */
    BFLOAT16, /**< 16-bit brain floating-point number */
    F16,      /**< 1 channel, 1 F16 per channel */
    F32,      /**< 1 channel, 1 F32 per channel */
    UV88,     /**< 2 channels, 1 U8 per channel */
    RGB888,   /**< 3 channels, 1 U8 per channel */
    RGBA8888, /**< 4 channels, 1 U8 per channel */
    YUV444,   /**< 3 planes of 8-bit 4:4:4 sampled Y, U, V */
    YUYV422,  /**< A single plane of 32-bit macro pixels of Y0, U0, Y1, V0 bytes */
    NV12,     /**< A 2-plane YUV format of Luma (Y) and interleaved UV data at 4:2:0 sampling */
    NV21,     /**< A 2-plane YUV format of Luma (Y) and interleaved VU data at 4:2:0 sampling */
    IYUV,     /**< 3 planes of 8-bit 4:2:0 sampled Y, U, V */
    UYVY422   /**< A single plane of 32-bit macro pixels of U0, Y0, V0, Y1 bytes */
};

/** Available data types */
enum class DataType
{
    UNKNOWN,            /**< Unknown data type */
    U8,                 /**< unsigned 8-bit number */
    S8,                 /**< signed 8-bit number */
    QSYMM8,             /**< quantized, symmetric fixed-point 8-bit number */
    QASYMM8,            /**< quantized, asymmetric fixed-point 8-bit number unsigned */
    QASYMM8_SIGNED,     /**< quantized, asymmetric fixed-point 8-bit number signed */
    QSYMM8_PER_CHANNEL, /**< quantized, symmetric per channel fixed-point 8-bit number */
    U16,                /**< unsigned 16-bit number */
    S16,                /**< signed 16-bit number */
    QSYMM16,            /**< quantized, symmetric fixed-point 16-bit number */
    QASYMM16,           /**< quantized, asymmetric fixed-point 16-bit number */
    U32,                /**< unsigned 32-bit number */
    S32,                /**< signed 32-bit number */
    U64,                /**< unsigned 64-bit number */
    S64,                /**< signed 64-bit number */
    BFLOAT16,           /**< 16-bit brain floating-point number */
    F16,                /**< 16-bit floating-point number */
    F32,                /**< 32-bit floating-point number */
    F64,                /**< 64-bit floating-point number */
    SIZET               /**< size_t */
};

/** Available Sampling Policies */
enum class SamplingPolicy
{
    CENTER,  /**< Samples are taken at pixel center */
    TOP_LEFT /**< Samples are taken at pixel top left corner */
};

/** [DataLayout enum definition] **/

/** Supported tensor data layouts */
enum class DataLayout
{
    UNKNOWN, /**< Unknown data layout */
    NCHW,    /**< Num samples, channels, height, width */
    NHWC,    /**< Num samples, height, width, channels */
    NCDHW,   /**< Num samples, channels, depth, height, width */
    NDHWC    /**< Num samples, depth, height, width, channels */
};
/** [DataLayout enum definition] **/

/** Supported tensor data layout dimensions */
enum class DataLayoutDimension
{
    CHANNEL, /**< channel */
    HEIGHT,  /**< height */
    WIDTH,   /**< width */
    DEPTH,   /**< depth */
    BATCHES  /**< batches */
};

/** Available ConvolutionMethod */
enum class ConvolutionMethod
{
    GEMM,        /**< Convolution using GEMM */
    GEMM_CONV2D, /**< Direct 2D GEMM convolution */
    DIRECT,      /**< Direct convolution */
    INDIRECT,    /**< Indirect convolution */
    WINOGRAD,    /**< Convolution using Winograd */
    FFT          /**< Convolution using FFT */
};

/** Available DepthwiseConvolutionFunction */
enum class DepthwiseConvolutionFunction
{
    OPTIMIZED, /**< Optimized Depthwise Convolution */
    GENERIC,   /**< Generic Depthwise Convolution */
};

/** Available DeconvolutionMethod */
enum class DeconvolutionMethod
{
    GEMM,           /**< Deconvolution using GEMM */
    DIRECT,         /**< Direct deconvolution */
    UPSCALE_CONV2D  /**< Deconvolution with upscaling */
};

/** Available FuseBatchNormalizationType */
enum class FuseBatchNormalizationType
{
    CONVOLUTION,         /**< For Convolution weights */
    DEPTHWISECONVOLUTION /**< For Depthwise Convolution weights */
};

/** Padding mode to use for PadLayer */
enum class PaddingMode
{
    CONSTANT,
    REFLECT,
    SYMMETRIC
};

/** Supported comparison operations */
enum class ComparisonOperation
{
    Equal,        /**< Equal comparison ( \f$ x == y \f$ ) */
    NotEqual,     /**< NotEqual comparison ( \f$ x != y \f$ ) */
    Greater,      /**< Greater comparison ( \f$ x > y \f$ ) */
    GreaterEqual, /**< Greater equal comparison ( \f$ x >= y \f$ ) */
    Less,         /**< Less comparison ( \f$ x < y \f$ ) */
    LessEqual     /**< Less equal comparison ( \f$ x <= y \f$ ) */
};

/** Container for valid region of a window */
struct ValidRegion
{
    /** Default constructor */
    ValidRegion()
        : anchor{}, shape{}
    {
    }

    /** Allow instances of this class to be copy constructed */
    ValidRegion(const ValidRegion &) = default;
    /** Allow instances of this class to be move constructed */
    ValidRegion(ValidRegion &&) = default;
    /** Allow instances of this class to be copied */
    ValidRegion &operator=(const ValidRegion &) = default;
    /** Allow instances of this class to be moved */
    ValidRegion &operator=(ValidRegion &&) = default;
    /** Default destructor */
    ~ValidRegion() = default;

    /** Constructor for a valid region with default number of dimensions
     *
     * @param[in] an_anchor Anchor for the start of the valid region.
     * @param[in] a_shape   Shape of the valid region.
     *
     */
    ValidRegion(const Coordinates &an_anchor, const TensorShape &a_shape)
        : anchor{ an_anchor }, shape{ a_shape }
    {
        anchor.set_num_dimensions(std::max(anchor.num_dimensions(), shape.num_dimensions()));
    }

    /** Constructor for a valid region with specified number of dimensions
     *
     * @param[in] an_anchor      Anchor for the start of the valid region.
     * @param[in] a_shape        Shape of the valid region.
     * @param[in] num_dimensions Number of dimensions (must be >= number of dimensions of anchor and shape).
     *
     */
    ValidRegion(const Coordinates &an_anchor, const TensorShape &a_shape, size_t num_dimensions)
        : anchor{ an_anchor }, shape{ a_shape }
    {
        ARM_COMPUTE_ERROR_ON(num_dimensions < std::max(anchor.num_dimensions(), shape.num_dimensions()));
        anchor.set_num_dimensions(num_dimensions);
    }

    /** Return the start of the valid region for the given dimension @p d */
    int start(unsigned int d) const
    {
        return anchor[d];
    }

    /** Return the end of the valid region for the given dimension @p d */
    int end(unsigned int d) const
    {
        return anchor[d] + shape[d];
    }

    /** Accessor to set the value of anchor and shape for one of the dimensions.
     *
     * @param[in] dimension Dimension for which the value is set.
     * @param[in] start     Value to be set in anchor for the dimension.
     * @param[in] size      Value to be set in shape for the dimension.
     *
     * @return *this.
     */
    ValidRegion &set(size_t dimension, int start, size_t size)
    {
        anchor.set(dimension, start);
        shape.set(dimension, size);
        return *this;
    }

    /** Check whether two valid regions are equal.
     *
     * @param[in] lhs LHS valid region
     * @param[in] rhs RHS valid region
     *
     * @return True if the valid regions are the same.
     */
    inline friend bool operator==(const ValidRegion &lhs, const ValidRegion &rhs);

    Coordinates anchor; /**< Anchor for the start of the valid region. */
    TensorShape shape;  /**< Shape of the valid region. */
};
inline bool operator==(const ValidRegion &lhs, const ValidRegion &rhs)
{
    return (lhs.anchor == rhs.anchor) && (lhs.shape == rhs.shape);
}
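
// Illustrative usage (editor's sketch, not part of the upstream header): how a
// ValidRegion is typically constructed and queried. The values below are
// assumptions chosen for the example only.
//
//   ValidRegion region(Coordinates(0, 0), TensorShape(16U, 8U));
//   int x_start = region.start(0); // 0
//   int x_end   = region.end(0);   // 16 = anchor[0] + shape[0]
//   region.set(1, 2, 4);           // dimension 1 now spans [2, 6)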

/** Methods available to handle borders */
enum class BorderMode
{
    UNDEFINED, /**< Borders are left undefined */
    CONSTANT,  /**< Pixels outside the image are assumed to have a constant value */
    REPLICATE  /**< Pixels outside the image are assumed to have the same value as the closest image pixel */
};

/** Container for 2D border size */
struct BorderSize
{
    /** Empty border, i.e. no border */
    constexpr BorderSize() noexcept
        : top{ 0 }, right{ 0 }, bottom{ 0 }, left{ 0 }
    {
    }

    /** Border with equal size around the 2D plane */
    explicit constexpr BorderSize(unsigned int size) noexcept
        : top{ size }, right{ size }, bottom{ size }, left{ size }
    {
    }

    /** Border with same size for top/bottom and left/right */
    constexpr BorderSize(unsigned int top_bottom, unsigned int left_right)
        : top{ top_bottom }, right{ left_right }, bottom{ top_bottom }, left{ left_right }
    {
    }

    /** Border with different sizes */
    constexpr BorderSize(unsigned int top, unsigned int right, unsigned int bottom, unsigned int left)
        : top{ top }, right{ right }, bottom{ bottom }, left{ left }
    {
    }

    /** Check if the entire border is zero */
    constexpr bool empty() const
    {
        return top == 0 && right == 0 && bottom == 0 && left == 0;
    }

    /** Check if the border is the same size on all sides */
    constexpr bool uniform() const
    {
        return top == right && top == bottom && top == left;
    }

    /** Scale this border size.
     *
     * @param[in] scale Scale to multiply border size by.
     *
     * @return *this.
     */
    BorderSize &operator*=(float scale)
    {
        top *= scale;
        right *= scale;
        bottom *= scale;
        left *= scale;

        return *this;
    }

    /** Scale a copy of this border size.
     *
     * @param[in] scale Scale to multiply border size by.
     *
     * @return a scaled copy of this.
     */
    BorderSize operator*(float scale)
    {
        BorderSize size = *this;
        size *= scale;

        return size;
    }

    /** Check equality with another BorderSize struct
     *
     * @param[in] rhs other struct to check against
     *
     * @return true if they are equal
     */
    bool operator==(const BorderSize &rhs) const
    {
        return (top == rhs.top) && (right == rhs.right) && (bottom == rhs.bottom) && (left == rhs.left);
    }

    /** Check non-equality with another BorderSize struct
     *
     * @param[in] rhs other struct to check against
     *
     * @return true if they are different
     */
    bool operator!=(const BorderSize &rhs) const
    {
        return !(*this == rhs);
    }

    /** Limit this border size.
     *
     * @param[in] limit Border size to limit this border size to.
     */
    void limit(const BorderSize &limit)
    {
        top    = std::min(top, limit.top);
        right  = std::min(right, limit.right);
        bottom = std::min(bottom, limit.bottom);
        left   = std::min(left, limit.left);
    }

    unsigned int top;    /**< top of the border */
    unsigned int right;  /**< right of the border */
    unsigned int bottom; /**< bottom of the border */
    unsigned int left;   /**< left of the border */
};

/** Container for 2D padding size */
using PaddingSize = BorderSize;
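
// Illustrative usage (editor's sketch, not part of the upstream header): the
// uniform constructor fills all four sides, and limit() clamps each side
// independently. The values below are assumptions for the example only.
//
//   BorderSize border(3);          // top = right = bottom = left = 3
//   border.limit(BorderSize(2));   // every side clamped to at most 2
//   bool same = border.uniform();  // true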

/** Policy to handle integer overflow
 *  @note: This is ignored by floating point operations where the overflow behavior adheres to the IEEE-754 standard
 *         which states that in case of overflow ±infinity is returned for the round-to-nearest modes (and follows the
 *         rounding rules for the directed rounding modes) by default.
 */
enum class ConvertPolicy
{
    WRAP,    /**< Wrap around */
    SATURATE /**< Saturate */
};
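
// Illustrative semantics (editor's sketch, not part of the upstream header):
// what the two policies mean when, for example, adding two U8 values whose
// true sum exceeds the representable range.
//
//   // true sum = 250 + 10 = 260, which does not fit in a U8 (max 255)
//   // ConvertPolicy::WRAP     -> 260 mod 256 = 4
//   // ConvertPolicy::SATURATE -> clamped to 255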

/** Interpolation method */
enum class InterpolationPolicy
{
    NEAREST_NEIGHBOR, /**< Output values are defined to match the source pixel whose center is nearest to the sample position */
    BILINEAR,         /**< Output values are defined by bilinear interpolation between the pixels */
    AREA,             /**< Output values are determined by averaging the source pixels whose areas fall under the area of the destination pixel, projected onto the source image */
};

/** Bilinear Interpolation method used by LKTracker */
enum class BilinearInterpolation
{
    BILINEAR_OLD_NEW, /**< Old-new method */
    BILINEAR_SCHARR   /**< Scharr method */
};

/** Rectangle type */
struct Rectangle
{
    uint16_t x;      /**< Top-left x coordinate */
    uint16_t y;      /**< Top-left y coordinate */
    uint16_t width;  /**< Width of the rectangle */
    uint16_t height; /**< Height of the rectangle */
};

/** Coordinate type */
struct Coordinates2D
{
    int32_t x; /**< X coordinate */
    int32_t y; /**< Y coordinate */
};

/** Coordinate type */
struct Coordinates3D
{
    uint32_t x; /**< X coordinate */
    uint32_t y; /**< Y coordinate */
    uint32_t z; /**< Z coordinate */
};

/** Padding information as a pair of unsigned int start/end */
using PaddingInfo = std::pair<uint32_t, uint32_t>;

/** List of padding information */
using PaddingList = std::vector<PaddingInfo>;

/** Information to produce a tiled version of a Tensor */
using Multiples = std::vector<uint32_t>;

/** Available channels */
enum class Channel
{
    UNKNOWN, /**< Unknown channel format */
    C0,      /**< First channel (used by formats with unknown channel types). */
    C1,      /**< Second channel (used by formats with unknown channel types). */
    C2,      /**< Third channel (used by formats with unknown channel types). */
    C3,      /**< Fourth channel (used by formats with unknown channel types). */
    R,       /**< Red channel. */
    G,       /**< Green channel. */
    B,       /**< Blue channel. */
    A,       /**< Alpha channel. */
    Y,       /**< Luma channel. */
    U,       /**< Cb/U channel. */
    V        /**< Cr/V/Value channel. */
};

/** Available reduction operations */
enum class ReductionOperation
{
    ARG_IDX_MAX, /**< Index of the max value */
    ARG_IDX_MIN, /**< Index of the min value */
    MEAN_SUM,    /**< Mean of sum */
    PROD,        /**< Product */
    SUM_SQUARE,  /**< Sum of squares */
    SUM,         /**< Sum */
    MIN,         /**< Min */
    MAX,         /**< Max */
};

/** Available element-wise operations */
enum class ArithmeticOperation
{
    ADD,          /**< (x + y) */
    SUB,          /**< (x - y) */
    DIV,          /**< (x / y) */
    MIN,          /**< Min(x, y) */
    MAX,          /**< Max(x, y) */
    SQUARED_DIFF, /**< (x - y)^2 */
    POWER,        /**< x ^ y */
    PRELU,        /**< y*x if x < 0, x otherwise */
};

/** Available element-wise unary operations */
enum class ElementWiseUnary
{
    RSQRT,       /**< Reciprocal square root */
    EXP,         /**< Exponential */
    NEG,         /**< Negate */
    LOG,         /**< Natural logarithm */
    ABS,         /**< Absolute value */
    SIN,         /**< Sine */
    ROUND,       /**< Round */
    LOGICAL_NOT, /**< Logical NOT */
};

/** Available bitwise operations */
enum class BitwiseOperation
{
    AND, /**< Bitwise AND operation */
    NOT, /**< Bitwise NOT operation */
    OR,  /**< Bitwise OR operation */
    XOR, /**< Bitwise XOR operation */
};

/** The normalization type used for the normalization layer */
enum class NormType
{
    IN_MAP_1D, /**< Normalization applied within the same map in a 1D region */
    IN_MAP_2D, /**< Normalization applied within the same map in a 2D region */
    CROSS_MAP  /**< Normalization applied across maps */
};

/** Detection window used for the object detection. The detection window keeps the following information:
 *
 *  -# Geometry of the rectangular window (x/y of top-left corner and width/height)
 *  -# Index of the class used for evaluating which class the detection window belongs to
 *  -# Confidence value (score) obtained with the classifier
 */
struct DetectionWindow
{
    uint16_t x{ 0 };         /**< Top-left x coordinate */
    uint16_t y{ 0 };         /**< Top-left y coordinate */
    uint16_t width{ 0 };     /**< Width of the detection window */
    uint16_t height{ 0 };    /**< Height of the detection window */
    uint16_t idx_class{ 0 }; /**< Index of the class */
    float    score{ 0.f };   /**< Confidence value for the detection window */
};

/** Dimension rounding type when down-scaling on CNNs
 * @note Used in pooling and convolution layers
 */
enum class DimensionRoundingType
{
    FLOOR, /**< Floor rounding */
    CEIL   /**< Ceil rounding */
};

/** Available pooling types */
enum class PoolingType
{
    MAX, /**< Max Pooling */
    AVG, /**< Average Pooling */
    L2   /**< L2 Pooling */
};

/** Available non maxima suppression types */
enum class NMSType
{
    LINEAR,   /**< Linear NMS */
    GAUSSIAN, /**< Gaussian NMS */
    ORIGINAL  /**< Original NMS */
};

/** BoxWithNonMaximaSuppressionLimit Information class */
class BoxNMSLimitInfo final
{
public:
    /** Constructor
     *
     * @param[in] score_thresh             (Optional) Score threshold.
     * @param[in] nms                      (Optional) NMS value
     * @param[in] detections               (Optional) Number of detections
     * @param[in] soft_nms_enabled         (Optional) Enable SoftNMS
     * @param[in] soft_nms_method          (Optional) Soft NMS method
     * @param[in] soft_nms_sigma           (Optional) Soft NMS sigma value
     * @param[in] soft_nms_min_score_thres (Optional) Soft NMS minimum score threshold
     * @param[in] suppress_size            (Optional) Filter out boxes based on their size. Defaults to false
     * @param[in] min_size                 (Optional) Boxes smaller than min_size will be filtered out. Defaults to 1
     * @param[in] im_width                 (Optional) Boxes whose centers (on the x axis) are beyond im_width will be filtered out. Defaults to 1
     * @param[in] im_height                (Optional) Boxes whose centers (on the y axis) are beyond im_height will be filtered out. Defaults to 1
     */
    BoxNMSLimitInfo(float score_thresh = 0.05f, float nms = 0.3f,
                    int detections = 100, bool soft_nms_enabled = false,
                    NMSType soft_nms_method = NMSType::LINEAR,
                    float soft_nms_sigma = 0.5f, float soft_nms_min_score_thres = 0.001f, bool suppress_size = false, float min_size = 1.0f, float im_width = 1.0f, float im_height = 1.0f)
        : _score_thresh(score_thresh), _nms(nms), _detections_per_im(detections), _soft_nms_enabled(soft_nms_enabled), _soft_nms_method(soft_nms_method), _soft_nms_sigma(soft_nms_sigma),
          _soft_nms_min_score_thres(soft_nms_min_score_thres), _suppress_size(suppress_size), _min_size(min_size), _im_width(im_width), _im_height(im_height)
    {
    }
    /** Get the score threshold */
    float score_thresh() const
    {
        return _score_thresh;
    }
    /** Get the NMS */
    float nms() const
    {
        return _nms;
    }
    /** Get the number of detections */
    int detections_per_im() const
    {
        return _detections_per_im;
    }
    /** Check if soft NMS is enabled */
    bool soft_nms_enabled() const
    {
        return _soft_nms_enabled;
    }
    /** Get soft NMS method */
    NMSType soft_nms_method() const
    {
        return _soft_nms_method;
    }
    /** Get soft NMS sigma */
    float soft_nms_sigma() const
    {
        return _soft_nms_sigma;
    }
    /** Get soft NMS min score threshold */
    float soft_nms_min_score_thres() const
    {
        return _soft_nms_min_score_thres;
    }
    /** Get if NMS will suppress boxes based on their size/position */
    bool suppress_size() const
    {
        return _suppress_size;
    }
    /** Get size suppression threshold */
    float min_size() const
    {
        return _min_size;
    }
    /** Get image width (NMS may suppress boxes whose center sits beyond the image width) */
    float im_width() const
    {
        return _im_width;
    }
    /** Get image height (NMS may suppress boxes whose center sits beyond the image height) */
    float im_height() const
    {
        return _im_height;
    }

private:
    float   _score_thresh;
    float   _nms;
    int     _detections_per_im;
    bool    _soft_nms_enabled;
    NMSType _soft_nms_method;
    float   _soft_nms_sigma;
    float   _soft_nms_min_score_thres;
    bool    _suppress_size;
    float   _min_size;
    float   _im_width;
    float   _im_height;
};
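
// Illustrative usage (editor's sketch, not part of the upstream header):
// enabling Gaussian SoftNMS while keeping the other defaults. The parameter
// values are assumptions for the example only.
//
//   BoxNMSLimitInfo info(0.05f /* score_thresh */, 0.3f /* nms */,
//                        100 /* detections */, true /* soft_nms_enabled */,
//                        NMSType::GAUSSIAN);
//   float sigma = info.soft_nms_sigma(); // 0.5f (default)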

/** Padding and stride information class */
class PadStrideInfo
{
public:
    /** Constructor
     *
     * @param[in] stride_x (Optional) Stride, in elements, across x. Defaults to 1.
     * @param[in] stride_y (Optional) Stride, in elements, across y. Defaults to 1.
     * @param[in] pad_x    (Optional) Padding, in elements, across x. Defaults to 0.
     * @param[in] pad_y    (Optional) Padding, in elements, across y. Defaults to 0.
     * @param[in] round    (Optional) Dimensions rounding. Defaults to @ref FLOOR.
     */
    PadStrideInfo(unsigned int stride_x = 1, unsigned int stride_y = 1,
                  unsigned int pad_x = 0, unsigned int pad_y = 0,
                  DimensionRoundingType round = DimensionRoundingType::FLOOR)
        : _stride(std::make_pair(stride_x, stride_y)),
          _pad_left(pad_x),
          _pad_top(pad_y),
          _pad_right(pad_x),
          _pad_bottom(pad_y),
          _round_type(round)
    {
    }
    /** Constructor
     *
     * @param[in] stride_x   Stride, in elements, across x.
     * @param[in] stride_y   Stride, in elements, across y.
     * @param[in] pad_left   Padding across x on the left, in elements.
     * @param[in] pad_right  Padding across x on the right, in elements.
     * @param[in] pad_top    Padding across y on the top, in elements.
     * @param[in] pad_bottom Padding across y on the bottom, in elements.
     * @param[in] round      Dimensions rounding.
     */
    PadStrideInfo(unsigned int stride_x, unsigned int stride_y,
                  unsigned int pad_left, unsigned int pad_right,
                  unsigned int pad_top, unsigned int pad_bottom,
                  DimensionRoundingType round)
        : _stride(std::make_pair(stride_x, stride_y)),
          _pad_left(pad_left),
          _pad_top(pad_top),
          _pad_right(pad_right),
          _pad_bottom(pad_bottom),
          _round_type(round)
    {
    }
    /** Get the stride.
     *
     * @return a pair: stride x, stride y.
     */
    std::pair<unsigned int, unsigned int> stride() const
    {
        return _stride;
    }
    /** Check whether the padding is symmetric.
     *
     * @return True if the padding is symmetric.
     */
    bool padding_is_symmetric() const
    {
        return (_pad_left == _pad_right) && (_pad_top == _pad_bottom);
    }
    /** Get the padding.
     *
     * @note This should only be used when the padding is symmetric.
     *
     * @return a pair: padding left/right, padding top/bottom
     */
    std::pair<unsigned int, unsigned int> pad() const
    {
        // This accessor should be used only when padding is symmetric
        ARM_COMPUTE_ERROR_ON(!padding_is_symmetric());
        return std::make_pair(_pad_left, _pad_top);
    }

    /** Get the left padding */
    unsigned int pad_left() const
    {
        return _pad_left;
    }
    /** Get the right padding */
    unsigned int pad_right() const
    {
        return _pad_right;
    }
    /** Get the top padding */
    unsigned int pad_top() const
    {
        return _pad_top;
    }
    /** Get the bottom padding */
    unsigned int pad_bottom() const
    {
        return _pad_bottom;
    }

    /** Get the rounding type */
    DimensionRoundingType round() const
    {
        return _round_type;
    }

    /** Check whether this has any padding */
    bool has_padding() const
    {
        return (_pad_left != 0 || _pad_top != 0 || _pad_right != 0 || _pad_bottom != 0);
    }

private:
    std::pair<unsigned int, unsigned int> _stride;
    unsigned int _pad_left;
    unsigned int _pad_top;
    unsigned int _pad_right;
    unsigned int _pad_bottom;

    DimensionRoundingType _round_type;
};
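
// Illustrative usage (editor's sketch, not part of the upstream header): the
// stride, padding and rounding held here feed the usual downscaled-output
// formula, out = round((in + pad_left + pad_right - kernel) / stride) + 1.
// The values below are assumptions for the example only.
//
//   PadStrideInfo conv_info(2 /* stride_x */, 2 /* stride_y */,
//                           1 /* pad_x */, 1 /* pad_y */,
//                           DimensionRoundingType::FLOOR);
//   // For in = 224, kernel = 3: floor((224 + 1 + 1 - 3) / 2) + 1 = 112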

/** Padding information for 2D operations like Conv2d */
struct Padding2D
{
    Padding2D() = default;
    Padding2D(size_t left, size_t right, size_t top, size_t bottom)
        : left(left), right(right), top(top), bottom(bottom)
    {
    }
    size_t left   = { 0 }; /**< Padding across the width dimension on the left, in elements. */
    size_t right  = { 0 }; /**< Padding across the width dimension on the right, in elements. */
    size_t top    = { 0 }; /**< Padding across the height dimension on the top, in elements. */
    size_t bottom = { 0 }; /**< Padding across the height dimension on the bottom, in elements. */
};

/** Padding information for 3D operations like Conv3d */
struct Padding3D
{
    Padding3D() noexcept
    {
    }

    Padding3D(size_t pad_x, size_t pad_y, size_t pad_z)
        : left(pad_x), right(pad_x), top(pad_y), bottom(pad_y), front(pad_z), back(pad_z)
    {
    }

    Padding3D(size_t left, size_t right, size_t top, size_t bottom, size_t front, size_t back)
        : left(left), right(right), top(top), bottom(bottom), front(front), back(back)
    {
    }

    size_t left   = { 0 }; /**< Padding across the width dimension on the left, in elements. */
    size_t right  = { 0 }; /**< Padding across the width dimension on the right, in elements. */
    size_t top    = { 0 }; /**< Padding across the height dimension on the top, in elements. */
    size_t bottom = { 0 }; /**< Padding across the height dimension on the bottom, in elements. */
    size_t front  = { 0 }; /**< Padding across the depth dimension on the front, in elements. */
    size_t back   = { 0 }; /**< Padding across the depth dimension on the back, in elements. */
};
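
// Illustrative usage (editor's sketch, not part of the upstream header): the
// three-argument constructor applies each pad symmetrically per axis.
//
//   Padding3D pad(1, 2, 0); // left = right = 1, top = bottom = 2, front = back = 0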

/** PriorBox layer info */
class PriorBoxLayerInfo final
{
public:
    /** Default Constructor */
    PriorBoxLayerInfo()
        : _min_sizes(),
          _variances(),
          _offset(),
          _flip(true),
          _clip(false),
          _max_sizes(),
          _aspect_ratios(),
          _img_size(),
          _steps()
    {
    }
    /** Constructor
     *
     * @param[in] min_sizes     Min sizes vector.
     * @param[in] variances     Variances vector.
     * @param[in] offset        Offset value.
     * @param[in] flip          (Optional) Flip the aspect ratios.
     * @param[in] clip          (Optional) Clip coordinates so that they're within [0,1].
     * @param[in] max_sizes     (Optional) Max sizes vector.
     * @param[in] aspect_ratios (Optional) Aspect ratios of the boxes.
     * @param[in] img_size      (Optional) Image size.
     * @param[in] steps         (Optional) Step values.
     */
    PriorBoxLayerInfo(const std::vector<float> &min_sizes, const std::vector<float> &variances, float offset, bool flip = true, bool clip = false,
                      const std::vector<float> &max_sizes = {}, const std::vector<float> &aspect_ratios = {},
                      const Coordinates2D &img_size = Coordinates2D{ 0, 0 }, const std::array<float, 2> &steps = { { 0.f, 0.f } })
        : _min_sizes(min_sizes),
          _variances(variances),
          _offset(offset),
          _flip(flip),
          _clip(clip),
          _max_sizes(max_sizes),
          _aspect_ratios(),
          _img_size(img_size),
          _steps(steps)
    {
        _aspect_ratios.push_back(1.f);
        for(unsigned int i = 0; i < aspect_ratios.size(); ++i)
        {
            float ar            = aspect_ratios[i];
            bool  already_exist = false;
            for(auto ar_new : _aspect_ratios)
            {
                if(fabs(ar - ar_new) < 1e-6)
                {
                    already_exist = true;
                    break;
                }
            }
            if(!already_exist)
            {
                _aspect_ratios.push_back(ar);
                if(flip)
                {
                    _aspect_ratios.push_back(1.f / ar);
                }
            }
        }
    }
    /** Get min sizes. */
    std::vector<float> min_sizes() const
    {
        return _min_sizes;
    }
    /** Get variances. */
    std::vector<float> variances() const
    {
        return _variances;
    }
    /** Get the step coordinates */
    std::array<float, 2> steps() const
    {
        return _steps;
    }
    /** Get the image size coordinates */
    Coordinates2D img_size() const
    {
        return _img_size;
    }
    /** Get the offset */
    float offset() const
    {
        return _offset;
    }
    /** Get the flip value */
    bool flip() const
    {
        return _flip;
    }
    /** Get the clip value */
    bool clip() const
    {
        return _clip;
    }
    /** Get max sizes. */
    std::vector<float> max_sizes() const
    {
        return _max_sizes;
    }
    /** Get aspect ratios. */
    std::vector<float> aspect_ratios() const
    {
        return _aspect_ratios;
    }

private:
    std::vector<float> _min_sizes;
    std::vector<float> _variances;
    float              _offset;
    bool               _flip;
    bool               _clip;
    std::vector<float> _max_sizes;
    std::vector<float> _aspect_ratios;
    Coordinates2D      _img_size;
    std::array<float, 2> _steps;
};
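
// Worked example (editor's sketch, not part of the upstream header): the
// constructor always seeds the ratio list with 1, de-duplicates the inputs,
// and appends the reciprocal of each new ratio when flip is true.
//
//   PriorBoxLayerInfo info({ 30.f }, { 0.1f }, 0.5f, true /* flip */, false,
//                          {}, { 2.f, 3.f });
//   // info.aspect_ratios() == { 1, 2, 0.5, 3, 1/3 }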

// Bounding Box [xmin, ymin, xmax, ymax]
using BBox = std::array<float, 4>;
// LabelBBox used to map a label to a list of bounding boxes
using LabelBBox = std::map<int, std::vector<BBox>>;

/** Available Detection Output code types */
enum class DetectionOutputLayerCodeType
{
    CORNER,      /**< Use box corners */
    CENTER_SIZE, /**< Use box centers and size */
    CORNER_SIZE, /**< Use box corners and size */
    TF_CENTER    /**< Use box centers and size but flip x and y co-ordinates */
};

/** Detection Output layer info */
class DetectionOutputLayerInfo final
{
public:
    /** Default Constructor */
    DetectionOutputLayerInfo()
        : _num_classes(),
          _share_location(),
          _code_type(DetectionOutputLayerCodeType::CORNER),
          _keep_top_k(),
          _nms_threshold(),
          _top_k(),
          _background_label_id(),
          _confidence_threshold(),
          _variance_encoded_in_target(false),
          _eta(),
          _num_loc_classes()
    {
        _num_loc_classes = _share_location ? 1 : _num_classes;
    }
    /** Constructor
     *
     * @param[in] num_classes                Number of classes to be predicted.
     * @param[in] share_location             If true, bounding boxes are shared among different classes.
     * @param[in] code_type                  Type of coding method for bbox.
     * @param[in] keep_top_k                 Number of total bounding boxes to be kept per image after the NMS step.
     * @param[in] nms_threshold              Threshold to be used in NMS.
     * @param[in] top_k                      (Optional) Number of boxes per image with top confidence scores that are fed into the NMS algorithm. Defaults to -1.
     * @param[in] background_label_id        (Optional) Background label ID. If there is no background class, set it as -1.
     * @param[in] confidence_threshold       (Optional) Only consider detections whose confidences are larger than a threshold. Defaults to -FLT_MAX.
     * @param[in] variance_encoded_in_target (Optional) If true, variance is encoded in target. Otherwise we need to adjust the predicted offset accordingly. Defaults to false.
     * @param[in] eta                        (Optional) Eta.
     */
    DetectionOutputLayerInfo(int num_classes, bool share_location, DetectionOutputLayerCodeType code_type, int keep_top_k, float nms_threshold, int top_k = -1, int background_label_id = -1,
                             float confidence_threshold = std::numeric_limits<float>::lowest(), bool variance_encoded_in_target = false, float eta = 1)
        : _num_classes(num_classes),
          _share_location(share_location),
          _code_type(code_type),
          _keep_top_k(keep_top_k),
          _nms_threshold(nms_threshold),
          _top_k(top_k),
          _background_label_id(background_label_id),
          _confidence_threshold(confidence_threshold),
          _variance_encoded_in_target(variance_encoded_in_target),
          _eta(eta),
          _num_loc_classes()
    {
        _num_loc_classes = _share_location ? 1 : _num_classes;
    }
    /** Get num classes. */
    int num_classes() const
    {
        return _num_classes;
    }
    /** Get share location. */
    bool share_location() const
    {
        return _share_location;
    }
    /** Get detection output code type. */
    DetectionOutputLayerCodeType code_type() const
    {
        return _code_type;
    }
    /** Get if variance is encoded in target. */
    bool variance_encoded_in_target() const
    {
        return _variance_encoded_in_target;
    }
    /** Get the number of total bounding boxes to be kept per image. */
    int keep_top_k() const
    {
        return _keep_top_k;
    }
    /** Get NMS threshold. */
    float nms_threshold() const
    {
        return _nms_threshold;
    }
    /** Get eta. */
    float eta() const
    {
        return _eta;
    }
    /** Get background label ID. */
    int background_label_id() const
    {
        return _background_label_id;
    }
    /** Get confidence threshold. */
    float confidence_threshold() const
    {
        return _confidence_threshold;
    }
    /** Get top K. */
    int top_k() const
    {
        return _top_k;
    }
    /** Get number of location classes. */
    int num_loc_classes() const
    {
        return _num_loc_classes;
    }

private:
    int                          _num_classes;
    bool                         _share_location;
    DetectionOutputLayerCodeType _code_type;
    int                          _keep_top_k;
    float                        _nms_threshold;
    int                          _top_k;
    int                          _background_label_id;
    float                        _confidence_threshold;
    bool                         _variance_encoded_in_target;
    float                        _eta;
    int                          _num_loc_classes;
};
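
// Illustrative usage (editor's sketch, not part of the upstream header):
// share_location collapses the per-class box predictions into one shared set,
// which is what drives num_loc_classes(). Values are assumptions for the
// example only.
//
//   DetectionOutputLayerInfo info(21 /* num_classes */, true /* share_location */,
//                                 DetectionOutputLayerCodeType::CENTER_SIZE,
//                                 100 /* keep_top_k */, 0.45f /* nms_threshold */);
//   int loc = info.num_loc_classes(); // 1, because boxes are shared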

/** Detection Post Process layer info */
class DetectionPostProcessLayerInfo final
{
public:
    /** Default Constructor */
    DetectionPostProcessLayerInfo()
        : _max_detections(),
          _max_classes_per_detection(),
          _nms_score_threshold(),
          _iou_threshold(),
          _num_classes(),
          _scales_values(),
          _use_regular_nms(),
          _detection_per_class(),
          _dequantize_scores()
    {
    }
    /** Constructor
     *
     * @param[in] max_detections            Total number of detections.
     * @param[in] max_classes_per_detection Number of total classes to be kept after the NMS step. Used in the Fast Non-Max-Suppression.
     * @param[in] nms_score_threshold       Threshold to be used in NMS.
     * @param[in] iou_threshold             Threshold to be used during the intersection over union.
     * @param[in] num_classes               Number of classes.
     * @param[in] scales_values             Scales values used to decode center-size boxes.
     * @param[in] use_regular_nms           (Optional) Boolean to determine whether to use regular or fast NMS. Defaults to false.
     * @param[in] detection_per_class       (Optional) Number of detections per class. Used in the Regular Non-Max-Suppression. Defaults to 100.
     * @param[in] dequantize_scores         (Optional) If the scores need to be dequantized. Defaults to true.
     */
    DetectionPostProcessLayerInfo(unsigned int max_detections, unsigned int max_classes_per_detection, float nms_score_threshold, float iou_threshold, unsigned int num_classes,
                                  std::array<float, 4> scales_values, bool use_regular_nms = false, unsigned int detection_per_class = 100, bool dequantize_scores = true)
        : _max_detections(max_detections),
          _max_classes_per_detection(max_classes_per_detection),
          _nms_score_threshold(nms_score_threshold),
          _iou_threshold(iou_threshold),
          _num_classes(num_classes),
          _scales_values(scales_values),
          _use_regular_nms(use_regular_nms),
          _detection_per_class(detection_per_class),
          _dequantize_scores(dequantize_scores)
    {
    }
    /** Get max detections. */
    unsigned int max_detections() const
    {
        return _max_detections;
    }
    /** Get max classes per detection. Used in the Fast Non-Max-Suppression. */
    unsigned int max_classes_per_detection() const
    {
        return _max_classes_per_detection;
    }
    /** Get detections per class. Used in the Regular Non-Max-Suppression. */
    unsigned int detection_per_class() const
    {
        return _detection_per_class;
    }
    /** Get NMS score threshold. */
    float nms_score_threshold() const
    {
        return _nms_score_threshold;
    }
    /** Get intersection over union threshold. */
    float iou_threshold() const
    {
        return _iou_threshold;
    }
    /** Get num classes. */
    unsigned int num_classes() const
    {
        return _num_classes;
    }
    /** Get if regular NMS is used. */
    bool use_regular_nms() const
    {
        return _use_regular_nms;
    }
    /** Get y scale value. */
    float scale_value_y() const
    {
        // Saved as [y, x, h, w]
        return _scales_values[0];
    }
    /** Get x scale value. */
    float scale_value_x() const
    {
        // Saved as [y, x, h, w]
        return _scales_values[1];
    }
    /** Get h scale value. */
    float scale_value_h() const
    {
        // Saved as [y, x, h, w]
        return _scales_values[2];
    }
    /** Get w scale value. */
    float scale_value_w() const
    {
        // Saved as [y, x, h, w]
        return _scales_values[3];
    }
    /** Get dequantize_scores value. */
    bool dequantize_scores() const
    {
        return _dequantize_scores;
    }

private:
    unsigned int _max_detections;
    unsigned int _max_classes_per_detection;
    float        _nms_score_threshold;
    float        _iou_threshold;
    unsigned int _num_classes;
    std::array<float, 4> _scales_values;
    bool         _use_regular_nms;
    unsigned int _detection_per_class;
    bool         _dequantize_scores;
};
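
// Illustrative usage (editor's sketch, not part of the upstream header): the
// four scale values are stored in [y, x, h, w] order, which the per-component
// accessors rely on. Values are assumptions for the example only.
//
//   DetectionPostProcessLayerInfo info(10 /* max_detections */, 1, 0.5f /* nms_score */,
//                                      0.6f /* iou */, 90 /* num_classes */,
//                                      { { 10.f, 10.f, 5.f, 5.f } } /* [y, x, h, w] */);
//   float sy = info.scale_value_y(); // 10.f, element 0 of the array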

/** Pooling Layer Information struct */
struct PoolingLayerInfo
{
    /** Default Constructor */
    PoolingLayerInfo()
        : pool_type(PoolingType::MAX),
          pool_size(Size2D()),
          data_layout(DataLayout::UNKNOWN),
          pad_stride_info(PadStrideInfo()),
          exclude_padding(false),
          is_global_pooling(false),
          fp_mixed_precision(false)
    {
    }
    /** Constructor
     *
     * @param[in] pool_type          Pooling type @ref PoolingType.
     * @param[in] pool_size          Pooling size, in elements, across x and y.
     * @param[in] data_layout        Data layout used by the layer @ref DataLayout
     * @param[in] pad_stride_info    (Optional) Padding and stride information @ref PadStrideInfo
     * @param[in] exclude_padding    (Optional) Strategy when accounting padding in calculations.
     *                               True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area).
     *                               Defaults to false.
     * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     */
    explicit PoolingLayerInfo(PoolingType   pool_type,
                              unsigned int  pool_size,
                              DataLayout    data_layout,
                              PadStrideInfo pad_stride_info    = PadStrideInfo(),
                              bool          exclude_padding    = false,
                              bool          fp_mixed_precision = false)
        : pool_type(pool_type),
          pool_size(Size2D(pool_size, pool_size)),
          data_layout(data_layout),
          pad_stride_info(pad_stride_info),
          exclude_padding(exclude_padding),
          is_global_pooling(false),
          fp_mixed_precision(fp_mixed_precision)
    {
    }

    /** Constructor
     *
     * @param[in] pool_type          Pooling type @ref PoolingType.
     * @param[in] pool_size          Pooling size, in elements, across x and y.
     * @param[in] data_layout        Data layout used by the layer @ref DataLayout
     * @param[in] pad_stride_info    (Optional) Padding and stride information @ref PadStrideInfo
     * @param[in] exclude_padding    (Optional) Strategy when accounting padding in calculations.
     *                               True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area).
     *                               Defaults to false.
     * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     */
    explicit PoolingLayerInfo(PoolingType   pool_type,
                              Size2D        pool_size,
                              DataLayout    data_layout,
                              PadStrideInfo pad_stride_info    = PadStrideInfo(),
                              bool          exclude_padding    = false,
                              bool          fp_mixed_precision = false)
        : pool_type(pool_type),
          pool_size(pool_size),
          data_layout(data_layout),
          pad_stride_info(pad_stride_info),
          exclude_padding(exclude_padding),
          is_global_pooling(false),
          fp_mixed_precision(fp_mixed_precision)
    {
    }

    /** Constructor
     *
     * @note This constructor is used for global pooling
     *
     * @param[in] pool_type   Pooling type @ref PoolingType.
     * @param[in] data_layout Data layout used by the layer @ref DataLayout
     */
    explicit PoolingLayerInfo(PoolingType pool_type, DataLayout data_layout)
        : pool_type(pool_type),
          pool_size(Size2D()),
          data_layout(data_layout),
          pad_stride_info(PadStrideInfo(1, 1, 0, 0)),
          exclude_padding(false),
          is_global_pooling(true),
          fp_mixed_precision(false)
    {
    }

    PoolingType   pool_type;
    Size2D        pool_size;
    DataLayout    data_layout;
    PadStrideInfo pad_stride_info;
    bool          exclude_padding;
    bool          is_global_pooling;
    bool          fp_mixed_precision;
};
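
// Illustrative usage (editor's sketch, not part of the upstream header):
// 3x3 max pooling with stride 2, versus the global-pooling constructor, which
// ignores pool_size and sets is_global_pooling. Values are assumptions only.
//
//   PoolingLayerInfo pool_info(PoolingType::MAX, 3U, DataLayout::NHWC,
//                              PadStrideInfo(2, 2, 1, 1));
//   PoolingLayerInfo global_info(PoolingType::AVG, DataLayout::NHWC);
//   // global_info.is_global_pooling == true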

/** Pooling 3D Layer Information struct */
struct Pooling3dLayerInfo
{
    /** Default Constructor */
    Pooling3dLayerInfo() noexcept
        : pool_type(PoolingType::MAX),
          pool_size(Size3D()),
          stride(Size3D()),
          padding(Padding3D()),
          exclude_padding(false),
          is_global_pooling(false),
          fp_mixed_precision(false),
          round_type(DimensionRoundingType::FLOOR)
    {
    }
    /** Constructor
     *
     * @param[in] pool_type          Pooling type @ref PoolingType.
     * @param[in] pool_size          Pooling size, in elements, across x, y and z.
     * @param[in] stride             (Optional) Stride information @ref Size3D
     * @param[in] padding            (Optional) Padding information @ref Padding3D
     * @param[in] exclude_padding    (Optional) Strategy when accounting padding in calculations.
     *                               True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area).
     *                               Defaults to false.
     * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     * @param[in] round_type         (Optional) Dimensions rounding. Defaults to @ref FLOOR
     */
    explicit Pooling3dLayerInfo(PoolingType           pool_type,
                                unsigned int          pool_size,
                                Size3D                stride             = Size3D(1U, 1U, 1U),
                                Padding3D             padding            = Padding3D(),
                                bool                  exclude_padding    = false,
                                bool                  fp_mixed_precision = false,
                                DimensionRoundingType round_type         = DimensionRoundingType::FLOOR)
        : pool_type(pool_type),
          pool_size(Size3D(pool_size, pool_size, pool_size)),
          stride(stride),
          padding(padding),
          exclude_padding(exclude_padding),
          is_global_pooling(false),
          fp_mixed_precision(fp_mixed_precision),
          round_type(round_type)
    {
    }

    /** Constructor
     *
     * @param[in] pool_type          Pooling type @ref PoolingType.
     * @param[in] pool_size          Pooling size, in elements, across x, y and z.
     * @param[in] stride             (Optional) Stride information @ref Size3D
     * @param[in] padding            (Optional) Padding information @ref Padding3D
     * @param[in] exclude_padding    (Optional) Strategy when accounting padding in calculations.
     *                               True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area).
     *                               Defaults to false.
     * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     * @param[in] round_type         (Optional) Dimensions rounding. Defaults to @ref FLOOR
     */
    explicit Pooling3dLayerInfo(PoolingType           pool_type,
                                Size3D                pool_size,
                                Size3D                stride             = Size3D(1U, 1U, 1U),
                                Padding3D             padding            = Padding3D(),
                                bool                  exclude_padding    = false,
                                bool                  fp_mixed_precision = false,
                                DimensionRoundingType round_type         = DimensionRoundingType::FLOOR)
        : pool_type(pool_type),
          pool_size(pool_size),
          stride(stride),
          padding(padding),
          exclude_padding(exclude_padding),
          is_global_pooling(false),
          fp_mixed_precision(fp_mixed_precision),
          round_type(round_type)
    {
    }

    /** Constructor
     *
     * @note This constructor is used for global pooling
     *
     * @param[in] pool_type Pooling type @ref PoolingType.
     */
    explicit Pooling3dLayerInfo(PoolingType pool_type)
        : pool_type(pool_type),
          pool_size(Size3D()),
          stride(Size3D(1U, 1U, 1U)),
          padding(Padding3D(0, 0, 0)),
          exclude_padding(false),
          is_global_pooling(true),
          fp_mixed_precision(false),
          round_type(DimensionRoundingType::FLOOR)
    {
    }

    PoolingType           pool_type;
    Size3D                pool_size;
    Size3D                stride;
    Padding3D             padding;
    bool                  exclude_padding;
    bool                  is_global_pooling;
    bool                  fp_mixed_precision;
    DimensionRoundingType round_type;
};

/** ROI Pooling Layer Information class */
class ROIPoolingLayerInfo final
{
public:
    /** Constructor
     *
     * @param[in] pooled_width   Pooled width of the layer.
     * @param[in] pooled_height  Pooled height of the layer.
     * @param[in] spatial_scale  Spatial scale to be applied to the ROI coordinates and dimensions.
     * @param[in] sampling_ratio (Optional) Number of samples to include in each pooling region. If set to zero, an adaptive number of samples, ceil(roi_dims / pooled_dims), is used. Defaults to 0.
     */
    ROIPoolingLayerInfo(unsigned int pooled_width, unsigned int pooled_height, float spatial_scale, unsigned int sampling_ratio = 0)
        : _pooled_width(pooled_width), _pooled_height(pooled_height), _spatial_scale(spatial_scale), _sampling_ratio(sampling_ratio)
    {
    }
    /** Get the pooled width of the layer */
    unsigned int pooled_width() const
    {
        return _pooled_width;
    }
    /** Get the pooled height of the layer */
    unsigned int pooled_height() const
    {
        return _pooled_height;
    }
    /** Get the spatial scale */
    float spatial_scale() const
    {
        return _spatial_scale;
    }
    /** Get the sampling ratio */
    unsigned int sampling_ratio() const
    {
        return _sampling_ratio;
    }

private:
    unsigned int _pooled_width;
    unsigned int _pooled_height;
    float        _spatial_scale;
    unsigned int _sampling_ratio;
};
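
/* Editor's note: illustrative only. For a feature map downscaled 16x from the
 * original image, the ROI coordinates are typically mapped with spatial_scale = 1/16:
 *
 *   ROIPoolingLayerInfo roi_info(7, 7, 1.f / 16.f); // 7x7 output bins, adaptive sampling (sampling_ratio = 0)
 */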

/** Generate Proposals Information class */
class GenerateProposalsInfo
{
public:
    /** Constructor
     *
     * @param[in] im_width       Width of the original image
     * @param[in] im_height      Height of the original image
     * @param[in] im_scale       Scale applied to the original image
     * @param[in] spatial_scale  (Optional) Scale applied to the feature map. Defaults to 1.0
     * @param[in] pre_nms_topN   (Optional) Number of the best scores to be selected from the transformations. Defaults to 6000.
     * @param[in] post_nms_topN  (Optional) Number of the best scores to be selected from the NMS operation. Defaults to 300.
     * @param[in] nms_thres      (Optional) NMS overlap threshold. Defaults to 0.7.
     * @param[in] min_size       (Optional) Size used to validate the anchors produced. Defaults to 16.
     * @param[in] values_per_roi (Optional) Number of values used to represent a ROI (Region of Interest). Defaults to 4.
     */
    GenerateProposalsInfo(float im_width, float im_height, float im_scale, float spatial_scale = 1.0, int pre_nms_topN = 6000, int post_nms_topN = 300, float nms_thres = 0.7, float min_size = 16.0,
                          size_t values_per_roi = 4)
        : _im_height(im_height), _im_width(im_width), _im_scale(im_scale), _spatial_scale(spatial_scale), _pre_nms_topN(pre_nms_topN), _post_nms_topN(post_nms_topN), _nms_thres(nms_thres),
          _min_size(min_size), _values_per_roi(values_per_roi)
    {
    }

    /** Get the original height */
    float im_height() const
    {
        return _im_height;
    }
    /** Get the original width */
    float im_width() const
    {
        return _im_width;
    }
    /** Get the image scale */
    float im_scale() const
    {
        return _im_scale;
    }
    /** Get the value of how many best scores to select (before NMS) */
    int pre_nms_topN() const
    {
        return _pre_nms_topN;
    }
    /** Get the value of how many best scores to select (after NMS) */
    int post_nms_topN() const
    {
        return _post_nms_topN;
    }
    /** Get the NMS overlap threshold */
    float nms_thres() const
    {
        return _nms_thres;
    }
    /** Get the minimal size */
    float min_size() const
    {
        return _min_size;
    }
    /** Get the spatial scale to be applied to the feature maps */
    float spatial_scale() const
    {
        return _spatial_scale;
    }
    /** Get the number of values used to represent a ROI (Region of Interest) */
    size_t values_per_roi() const
    {
        return _values_per_roi;
    }

private:
    float  _im_height;
    float  _im_width;
    float  _im_scale;
    float  _spatial_scale;
    int    _pre_nms_topN;
    int    _post_nms_topN;
    float  _nms_thres;
    float  _min_size;
    size_t _values_per_roi;
};
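
/* Editor's note: illustrative only. A typical Faster R-CNN style configuration
 * for an 800x600 image and a stride-16 feature map might look like:
 *
 *   GenerateProposalsInfo proposals_info(800.f, 600.f, 1.f, 1.f / 16.f);
 *   // pre_nms_topN = 6000, post_nms_topN = 300 and nms_thres = 0.7 come from the defaults.
 */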

/** ComputeAnchors information class */
class ComputeAnchorsInfo
{
public:
    /** Constructor
     *
     * @param[in] feat_width     Feature map width
     * @param[in] feat_height    Feature map height
     * @param[in] spatial_scale  Feature map scale
     * @param[in] values_per_roi (Optional) Number of values used to represent a ROI (Region of Interest). Defaults to 4
     */
    ComputeAnchorsInfo(float feat_width, float feat_height, float spatial_scale, size_t values_per_roi = 4)
        : _feat_height(feat_height),
          _feat_width(feat_width),
          _spatial_scale(spatial_scale),
          _values_per_roi(values_per_roi)
    {
    }

    /** Get the height of the feature map */
    float feat_height() const
    {
        return _feat_height;
    }

    /** Get the width of the feature map */
    float feat_width() const
    {
        return _feat_width;
    }

    /** Get the scale of the feature map */
    float spatial_scale() const
    {
        return _spatial_scale;
    }

    /** Get the number of values used to represent a ROI (Region of Interest) */
    size_t values_per_roi() const
    {
        return _values_per_roi;
    }

private:
    float  _feat_height;
    float  _feat_width;
    float  _spatial_scale;
    size_t _values_per_roi;
};

/** Bounding Box Transform information class */
class BoundingBoxTransformInfo final
{
public:
    /** Constructor
     *
     * @param[in] img_width                Width of the original image
     * @param[in] img_height               Height of the original image
     * @param[in] scale                    Scale of the original image
     * @param[in] apply_scale              (Optional) Re-apply scaling after transforming the boxes. Defaults to false
     * @param[in] weights                  (Optional) Weights [wx, wy, ww, wh] for the deltas. Defaults to all ones
     * @param[in] correct_transform_coords (Optional) Correct bounding box transform coordinates. Defaults to false
     * @param[in] bbox_xform_clip          (Optional) Maximum bounding box width and height after bounding box transformation in log-space. Defaults to log(1000 / 16)
     */
    BoundingBoxTransformInfo(float img_width, float img_height, float scale, bool apply_scale = false, const std::array<float, 4> weights = { { 1.f, 1.f, 1.f, 1.f } },
                             bool correct_transform_coords = false, float bbox_xform_clip = 4.135166556742356f)
        : _img_width(img_width), _img_height(img_height), _scale(scale), _apply_scale(apply_scale), _correct_transform_coords(correct_transform_coords), _weights(weights),
          _bbox_xform_clip(bbox_xform_clip)
    {
    }

    std::array<float, 4> weights() const
    {
        return _weights;
    }

    float bbox_xform_clip() const
    {
        return _bbox_xform_clip;
    }

    float img_height() const
    {
        return _img_height;
    }

    float img_width() const
    {
        return _img_width;
    }

    float scale() const
    {
        return _scale;
    }

    bool apply_scale() const
    {
        return _apply_scale;
    }

    bool correct_transform_coords() const
    {
        return _correct_transform_coords;
    }

private:
    float _img_width;
    float _img_height;
    float _scale;
    bool  _apply_scale;
    bool  _correct_transform_coords;
    std::array<float, 4> _weights;
    float _bbox_xform_clip;
};
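
/* Editor's note: the default bbox_xform_clip above is log(1000 / 16)
 * (~4.135166556742356); clipping the log-space deltas there means exp() of a
 * delta cannot blow a 16-pixel box up beyond roughly 1000 pixels. Illustrative only:
 *
 *   BoundingBoxTransformInfo bbox_info(1333.f, 800.f, 1.f);
 *   float clip = bbox_info.bbox_xform_clip(); // ~4.1352 == std::log(1000.f / 16.f)
 */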

/** Activation Layer Information class */
class ActivationLayerInfo
{
public:
    /** Available activation functions */
    enum class ActivationFunction
    {
        LOGISTIC,        /**< Logistic ( \f$ f(x) = \frac{1}{1 + e^{-x}} \f$ ) */
        TANH,            /**< Hyperbolic tangent ( \f$ f(x) = a \cdot tanh(b \cdot x) \f$ ) */
        RELU,            /**< Rectifier ( \f$ f(x) = max(0,x) \f$ ) */
        BOUNDED_RELU,    /**< Upper Bounded Rectifier ( \f$ f(x) = min(a, max(0,x)) \f$ ) */
        LU_BOUNDED_RELU, /**< Lower and Upper Bounded Rectifier ( \f$ f(x) = min(a, max(b,x)) \f$ ) */
        LEAKY_RELU,      /**< Leaky Rectifier ( \f$ f(x) = \begin{cases}  \alpha x & \quad \text{if } x \text{ < 0}\\  x & \quad \text{if } x \geq \text{ 0 } \end{cases} \f$ ) */
        SOFT_RELU,       /**< Soft Rectifier ( \f$ f(x)= log(1+e^x) \f$ ) */
        ELU,             /**< Exponential Linear Unit ( \f$ f(x) = \begin{cases}  \alpha (exp(x) - 1) & \quad \text{if } x \text{ < 0}\\  x & \quad \text{if } x \geq \text{ 0 } \end{cases} \f$ ) */
        ABS,             /**< Absolute ( \f$ f(x)= |x| \f$ ) */
        SQUARE,          /**< Square ( \f$ f(x)= x^2 \f$ ) */
        SQRT,            /**< Square root ( \f$ f(x) = \sqrt{x} \f$ ) */
        LINEAR,          /**< Linear ( \f$ f(x)= ax + b \f$ ) */
        IDENTITY,        /**< Identity ( \f$ f(x)= x \f$ ) */
        HARD_SWISH,      /**< Hard-swish ( \f$ f(x) = (x \text{ReLU6}(x+3))/6 = x \min(\max(0,x+3),6)/6 \f$ ) */
        SWISH,           /**< Swish ( \f$ f(x) = \frac{x}{1 + e^{-ax}} = x \text{logistic}(ax) \f$ ) */
        GELU             /**< GELU ( \f$ f(x) = \frac{x}{2} (1 + erf(\frac{x}{\sqrt{2}})) \f$ ) */
    };

    /** Lookup table */
    using LookupTable256 = std::array<qasymm8_t, 256>;

    ActivationLayerInfo() = default;
    /** Constructor
     *
     * @param[in] f The activation function to use.
     * @param[in] a (Optional) The alpha parameter used by some activation functions
     *              (@ref ActivationFunction::BOUNDED_RELU, @ref ActivationFunction::LU_BOUNDED_RELU, @ref ActivationFunction::LINEAR, @ref ActivationFunction::TANH).
     * @param[in] b (Optional) The beta parameter used by some activation functions (@ref ActivationFunction::LINEAR, @ref ActivationFunction::LU_BOUNDED_RELU, @ref ActivationFunction::TANH).
     */
    ActivationLayerInfo(ActivationFunction f, float a = 0.0f, float b = 0.0f)
        : _act(f), _a(a), _b(b), _enabled(true)
    {
    }
    /** Get the type of activation function */
    ActivationFunction activation() const
    {
        return _act;
    }
    /** Get the alpha value */
    float a() const
    {
        return _a;
    }
    /** Get the beta value */
    float b() const
    {
        return _b;
    }
    /** Check if initialised */
    bool enabled() const
    {
        return _enabled;
    }

#ifdef __aarch64__
    const LookupTable256 &lut() const
    {
        return _lut;
    }

    void init_lut(DataType data_type, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out)
    {
        if(_act == ActivationFunction::HARD_SWISH)
        {
            if(data_type == DataType::QASYMM8)
            {
                qasymm8_hard_swish_populate_table(_lut, qi_in, qi_out);
            }
            else
            {
                qasymm8_signed_hard_swish_populate_table(_lut, qi_in, qi_out);
            }
        }
        else if(_act == ActivationFunction::LEAKY_RELU)
        {
            qasymm8_leaky_relu_populate_table(_lut, qi_in, qi_out, _a);
        }
        else if(_act == ActivationFunction::LOGISTIC)
        {
            if(data_type == DataType::QASYMM8)
            {
                qasymm8_logistic_populate_table(_lut, qi_in, qi_out);
            }
            else
            {
                qasymm8_signed_logistic_populate_table(_lut, qi_in, qi_out);
            }
        }
    }
#endif // __aarch64__

    static inline bool is_lut_supported(ActivationFunction act_func, DataType data_type)
    {
#ifdef __aarch64__
        switch(act_func)
        {
            case ActivationFunction::HARD_SWISH:
                return data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED;
            case ActivationFunction::LEAKY_RELU:
                return data_type == DataType::QASYMM8;
            case ActivationFunction::LOGISTIC:
                return data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED;
            default:
                return false;
        }
#else  // __aarch64__
        ARM_COMPUTE_UNUSED(act_func);
        ARM_COMPUTE_UNUSED(data_type);
        return false;
#endif // __aarch64__
    }

private:
    ActivationFunction _act     = { ActivationLayerInfo::ActivationFunction::IDENTITY };
    float              _a       = {};
    float              _b       = {};
    bool               _enabled = { false };

#ifdef __aarch64__
    LookupTable256 _lut = {};

    static inline void qasymm8_hard_swish_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out)
    {
        for(size_t i = 0; i < lut.size(); ++i)
        {
            lut[i] = qasymm8_hard_swish(i, qi_in, qi_out);
        }
    }

    static inline void qasymm8_signed_hard_swish_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out)
    {
        for(size_t i = 0; i < lut.size(); ++i)
        {
            lut[i] = qasymm8_signed_hard_swish(i, qi_in, qi_out);
        }
    }

    static inline void qasymm8_leaky_relu_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out, float alpha)
    {
        for(size_t i = 0; i < lut.size(); ++i)
        {
            lut[i] = qasymm8_leaky_relu(i, qi_in, qi_out, alpha);
        }
    }

    static inline void qasymm8_logistic_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out)
    {
        for(size_t i = 0; i < lut.size(); ++i)
        {
            lut[i] = qasymm8_logistic(i, qi_in, qi_out);
        }
    }

    static inline void qasymm8_signed_logistic_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out)
    {
        for(size_t i = 0; i < lut.size(); ++i)
        {
            lut[i] = qasymm8_signed_logistic(static_cast<int8_t>(i), qi_in, qi_out);
        }
    }
#endif // __aarch64__
};
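
/* Editor's note: illustrative only.
 *
 *   using AF = ActivationLayerInfo::ActivationFunction;
 *   ActivationLayerInfo relu6(AF::BOUNDED_RELU, 6.f); // f(x) = min(6, max(0, x))
 *   ActivationLayerInfo lrelu(AF::LEAKY_RELU, 0.1f);  // alpha = 0.1
 *   ActivationLayerInfo none;                         // enabled() == false: no fused activation
 */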

/** Fully connected layer info */
struct FullyConnectedLayerInfo
{
    /* Fused-activation parameters */
    ActivationLayerInfo activation_info{}; /**< Fused activation to apply after the matrix multiplication. */
    /* Information about weights */
    DataLayout weights_trained_layout{ DataLayout::NCHW }; /**< Layout that the weights have been trained with. */
    bool       transpose_weights{ true };                  /**< Transpose weights if true. */
    bool       are_weights_reshaped{ false };              /**< Reshape the weights tensor if false. */
    bool       retain_internal_weights{ false };           /**< Retain internal reshaped weights. */
    bool       enable_fast_math{ false };                  /**< Enable fast math computation. */
    /* Other parameters */
    bool fp_mixed_precision{ false }; /**< Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. */

    /** Sets the weights trained data layout
     *
     * @param[in] layout Data layout that the weights were trained with
     *
     * @return Updated object
     */
    FullyConnectedLayerInfo &set_weights_trained_layout(DataLayout layout)
    {
        weights_trained_layout = layout;
        return *this;
    }
    /** Sets the transpose weights flag
     *
     * @param[in] should_transpose_weights Boolean flag indicating if weights should be transposed
     *
     * @return Updated object
     */
    FullyConnectedLayerInfo &set_transpose_weights(bool should_transpose_weights)
    {
        transpose_weights = should_transpose_weights;
        return *this;
    }
};
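
/* Editor's note: the setters above return *this, so they can be chained.
 * Illustrative only:
 *
 *   FullyConnectedLayerInfo fc_info{};
 *   fc_info.set_weights_trained_layout(DataLayout::NHWC).set_transpose_weights(false);
 */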

/** Normalization Layer Information class */
class NormalizationLayerInfo
{
public:
    /** Constructor
     *
     * @param[in] type      The normalization type. Can be @ref NormType::IN_MAP_1D, @ref NormType::IN_MAP_2D or @ref NormType::CROSS_MAP
     * @param[in] norm_size (Optional) The normalization size, i.e. the number of elements to normalize across. Defaults to 5.
     * @param[in] alpha     (Optional) Alpha parameter used by the normalization equation. Defaults to 0.0001.
     * @param[in] beta      (Optional) Beta parameter used by the normalization equation. Defaults to 0.5.
     * @param[in] kappa     (Optional) Kappa parameter used by the [Krizhevsky 2012] Across Channel Local Brightness Normalization equation. Defaults to 1.
     * @param[in] is_scaled (Optional) Boolean that specifies if alpha will be scaled by the normalization size or not.
     *                      Should be false to follow [Krizhevsky 2012]. Defaults to true.
     */
    NormalizationLayerInfo(NormType type, uint32_t norm_size = 5, float alpha = 0.0001f, float beta = 0.5f, float kappa = 1.f, bool is_scaled = true)
        : _type(type), _norm_size(norm_size), _alpha(alpha), _beta(beta), _kappa(kappa), _is_scaled(is_scaled)
    {
    }
    /** Get the normalization type */
    NormType type() const
    {
        return _type;
    }
    /** Get the normalization size */
    uint32_t norm_size() const
    {
        return _norm_size;
    }
    /** Get the alpha value */
    float alpha() const
    {
        return _alpha;
    }
    /** Get the beta value */
    float beta() const
    {
        return _beta;
    }
    /** Get the kappa value */
    float kappa() const
    {
        return _kappa;
    }
    /** Get the is_scaled value */
    bool is_scaled() const
    {
        return _is_scaled;
    }
    /** Check if normalization is cross map */
    bool is_cross_map() const
    {
        return _type == NormType::CROSS_MAP;
    }
    /** Check if normalization is not cross map */
    bool is_in_map() const
    {
        return !is_cross_map();
    }
    /** Return the scaling factor of the normalization function.
     *
     * If is_scaled is set to false, [Krizhevsky 2012] normalization scaling is performed and alpha is returned plainly;
     * otherwise alpha is scaled by the total number of elements used for the normalization.
     *
     * @return The normalization scaling factor.
     */
    float scale_coeff() const
    {
        const uint32_t size = (_type == NormType::IN_MAP_2D) ? _norm_size * _norm_size : _norm_size;
        return (_is_scaled) ? (_alpha / size) : _alpha;
    }

private:
    NormType _type;
    uint32_t _norm_size;
    float    _alpha;
    float    _beta;
    float    _kappa;
    bool     _is_scaled;
};
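
/* Editor's note: worked example of scale_coeff() above, illustrative only.
 * With the defaults (norm_size = 5, alpha = 0.0001, is_scaled = true):
 *
 *   NormalizationLayerInfo norm_info(NormType::IN_MAP_2D);
 *   float coeff = norm_info.scale_coeff(); // 0.0001 / (5 * 5) = 4e-6, since IN_MAP_2D squares norm_size
 */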

class StridedSliceLayerInfo
{
public:
    /** Default Constructor
     *
     * @param[in] begin_mask       (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and the fullest possible range in that dimension is used instead.
     * @param[in] end_mask         (Optional) If the ith bit of end_mask is set, ends[i] is ignored and the fullest possible range in that dimension is used instead.
     * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1.
     */
    StridedSliceLayerInfo(int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0)
        : _begin_mask(begin_mask), _end_mask(end_mask), _shrink_axis_mask(shrink_axis_mask)
    {
    }

    /** Get the begin mask value */
    int32_t begin_mask() const
    {
        return _begin_mask;
    }

    /** Get the end mask value */
    int32_t end_mask() const
    {
        return _end_mask;
    }

    /** Get the shrink axis mask value */
    int32_t shrink_axis_mask() const
    {
        return _shrink_axis_mask;
    }

private:
    int32_t _begin_mask;
    int32_t _end_mask;
    int32_t _shrink_axis_mask;
};
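
/* Editor's note: the masks are plain bitfields, one bit per dimension.
 * Illustrative only:
 *
 *   // Ignore starts[0] (bit 0 of begin_mask) and drop dimension 2 (bit 2 of shrink_axis_mask):
 *   StridedSliceLayerInfo slice_info(1 << 0, 0, 1 << 2);
 */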

/** Memory layouts for the weights tensor.
  *
  * * UNSPECIFIED is used to select kernels that do not run in
  *   variable weights mode.
  *
  * * ANY is used to query the kernel database to retrieve any of the
  *   kernels that run in variable weights mode. Once a kernel is
  *   found, the specific format expected by the kernel can be
  *   retrieved by the user for reordering the weights tensor
  *   accordingly.
  *
  * The other values OHWIo{interleave_by}i{block_by} describe the
  * memory layout of a 4D tensor with layout OHWI that has been
  * transformed into a 4D tensor with dimensions O'HWI' where:
  *
  * O' = first multiple of {interleave_by} s.t. O <= O'
  * I' = first multiple of {block_by} s.t. I <= I'
  *
  * The total size of the dst tensor is O' x H x W x I'
  *
  * The access function of the tensor with layout
  * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
  * access function, where the 6 parameters are computed as follows:
  *
  * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
  *
  * x4 = h                        RANGE [0, H-1]                   SIZE: H
  * x3 = w                        RANGE [0, W-1]                   SIZE: W
  * x2 = floor(i/{block_by})      RANGE [0, I'/{block_by} -1]      SIZE: I'/{block_by}
  * x1 = o%{interleave_by}        RANGE [0, {interleave_by} -1]    SIZE: {interleave_by}
  * x0 = i%{block_by}             RANGE [0, {block_by} -1]         SIZE: {block_by}
  *                                                          TOTAL SIZE: O' * H * W * I'
  *
  *        4D                       6D
  * -----------------   -----------------------------------
  * value(o, h, w, i) =   x5 * H * W * I' * {interleave_by}
  *                     + x4 * W * I' * {interleave_by}
  *                     + x3 * I' * {interleave_by}
  *                     + x2 * {interleave_by} * {block_by}
  *                     + x1 * {block_by}
  *                     + x0
  *
  * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created
  * for the OHWIo{interleave_by}i{block_by} format is in reality seen
  * as a 2D tensor, where the number of rows is O'/{interleave_by}
  * and the number of columns is {interleave_by} * H * W * I'.
  *
  * The postfix *_bf16 is for the memory layout needed for the
  * fast-mode kernels, in which the weights are passed in bfloat16
  * format.
  */
enum class WeightFormat
{
    UNSPECIFIED    = 0x1,
    ANY            = 0x2,
    OHWI           = 0x100100,
    OHWIo2         = 0x100200,
    OHWIo4         = 0x100400,
    OHWIo8         = 0x100800,
    OHWIo16        = 0x101000,
    OHWIo32        = 0x102000,
    OHWIo64        = 0x104000,
    OHWIo128       = 0x108000,
    OHWIo4i2       = 0x200400,
    OHWIo4i2_bf16  = 0x200410,
    OHWIo8i2       = 0x200800,
    OHWIo8i2_bf16  = 0x200810,
    OHWIo16i2      = 0x201000,
    OHWIo16i2_bf16 = 0x201010,
    OHWIo32i2      = 0x202000,
    OHWIo32i2_bf16 = 0x202010,
    OHWIo64i2      = 0x204000,
    OHWIo64i2_bf16 = 0x204010,
    OHWIo4i4       = 0x400400,
    OHWIo4i4_bf16  = 0x400410,
    OHWIo8i4       = 0x400800,
    OHWIo8i4_bf16  = 0x400810,
    OHWIo16i4      = 0x401000,
    OHWIo16i4_bf16 = 0x401010,
    OHWIo32i4      = 0x402000,
    OHWIo32i4_bf16 = 0x402010,
    OHWIo64i4      = 0x404000,
    OHWIo64i4_bf16 = 0x404010,
    OHWIo2i8       = 0x800200,
    OHWIo4i8       = 0x800400,
    OHWIo8i8       = 0x800800,
    OHWIo16i8      = 0x801000,
    OHWIo32i8      = 0x802000,
    OHWIo64i8      = 0x804000
};
// OHWIo<interleave_by>i<block_by>
inline int interleave_by(const WeightFormat wf)
{
    return (static_cast<int>(wf) >> 8) & 0xFFF;
}
inline int block_by(const WeightFormat wf)
{
    return (static_cast<int>(wf) >> 20) & 0xF;
}
inline bool is_fixed_format(const WeightFormat &wf)
{
    return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY;
}
inline bool is_fixed_format_fast_math(const WeightFormat &wf)
{
    return (static_cast<int>(wf) >> 4) & 0x1;
}
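
/* Editor's note: worked example of the decoding helpers above, illustrative only.
 * OHWIo8i4_bf16 == 0x400810:
 *
 *   interleave_by(WeightFormat::OHWIo8i4_bf16);             // (0x400810 >> 8) & 0xFFF == 8
 *   block_by(WeightFormat::OHWIo8i4_bf16);                  // (0x400810 >> 20) & 0xF  == 4
 *   is_fixed_format_fast_math(WeightFormat::OHWIo8i4_bf16); // bit 4 set -> true (bf16 fast-mode layout)
 */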

/** Convolution Layer Weights Information class. This class stores the necessary information to compute convolution layer when the weights are already reshaped */
class WeightsInfo
{
public:
    /** Default constructor */
    WeightsInfo()
        : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false), _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
    {
    }
    /** Constructor
     *
     * @param[in] are_reshaped            True if the weights have been reshaped
     * @param[in] kernel_width            Kernel width.
     * @param[in] kernel_height           Kernel height.
     * @param[in] num_kernels             Number of convolution kernels.
     * @param[in] retain_internal_weights (Optional) True if internal reshaped weights must be retained. Used for reconfiguration purposes. Default is false.
     * @param[in] weight_format           (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
     */
    WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false,
                arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED)
        : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights),
          _weight_format(weight_format)
    {
    }
    /** Flag which specifies if the weights tensor has been reshaped.
     *
     * @return True if the weights tensor has been reshaped
     */
    bool are_reshaped() const
    {
        return _are_reshaped;
    };
    /** Return the number of convolution kernels
     *
     * @return The number of convolution kernels
     */
    unsigned int num_kernels() const
    {
        return _num_kernels;
    };
    /** Return the width and height of the kernel
     *
     * @return The width and height of the kernel
     */
    std::pair<unsigned int, unsigned int> kernel_size() const
    {
        return std::make_pair(_kernel_width, _kernel_height);
    }
    /** Flag which specifies if the internal reshaped weights must be retained */
    bool retain_internal_weights() const
    {
        return _retain_internal_weights;
    }
    /** Return the weight format requested by the user */
    arm_compute::WeightFormat weight_format() const
    {
        return _weight_format;
    }
    /** Set the weight format to be used
     *
     * @param[in] weight_format arm_compute::WeightFormat enumeration
     */
    void set_weight_format(arm_compute::WeightFormat weight_format)
    {
        _weight_format = weight_format;
    }

    /** Return the kernel width */
    unsigned int kernel_width() const
    {
        return _kernel_width;
    }
    /** Return the kernel height */
    unsigned int kernel_height() const
    {
        return _kernel_height;
    }

private:
    bool                      _are_reshaped;
    unsigned int              _kernel_width;
    unsigned int              _kernel_height;
    unsigned int              _num_kernels;
    bool                      _retain_internal_weights;
    arm_compute::WeightFormat _weight_format;
};

/** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape.
 *
 * The matrix A can only be reshaped through @ref opencl::kernels::ClGemmReshapeLhsMatrixKernel or @ref cpu::kernels::CpuGemmInterleave4x4Kernel
 * Note: mult_interleave4x4_height, the multiplication factor for the height of the 4x4 interleaved block, can optionally be set only for @ref opencl::kernels::ClGemmReshapeLhsMatrixKernel
 *
 * The matrix B can only be reshaped through @ref opencl::kernels::ClGemmReshapeRhsMatrixKernel or @ref cpu::kernels::CpuGemmTranspose1xWKernel
 * Note: mult_transpose1xW_width, the multiplication factor for the width of the 1xW transposed block, can optionally be set only for @ref opencl::kernels::ClGemmReshapeRhsMatrixKernel
 *
 */
class GEMMReshapeInfo final
{
public:
    /** Default constructor */
    GEMMReshapeInfo()
        : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _broadcast_bias(false)
    {
    }
    /** Constructor
     *
     * @param[in] m                         Number of matrix A rows
     * @param[in] n                         Number of matrix B columns
     * @param[in] k                         Number of matrix A columns or matrix B rows
     * @param[in] mult_transpose1xW_width   (Optional) Multiplication factor for the width of the 1xW transposed block
     * @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleaved block
     * @param[in] depth_output_gemm3d       (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
     *                                      If 0 the output will not be reinterpreted as 3D. Default 0
     * @param[in] reinterpret_input_as_3d   (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used
     *                                      to perform 1x1 convolutions with the NHWC data layout)
     * @param[in] broadcast_bias            (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
     */
    GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool broadcast_bias = false)
        : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d), _broadcast_bias(broadcast_bias)
    {
    }
    /** Number of matrix A rows
     *
     * @return the number of matrix A rows
     */
    int m() const
    {
        return _m;
    }
    /** Number of matrix B columns
     *
     * @return the number of matrix B columns
     */
    int n() const
    {
        return _n;
    }
    /** Number of matrix A columns or matrix B rows
     *
     * @return the number of matrix A columns or matrix B rows
     */
    int k() const
    {
        return _k;
    }
    /** Multiplication factor for the width of the 1xW transposed block
     *
     * @return the multiplication factor for the width of the 1xW transposed block
     */
    int mult_transpose1xW_width() const
    {
        return _mult_transpose1xW_width;
    }
    /** Multiplication factor for the height of the 4x4 interleaved block
     *
     * @return the multiplication factor for the height of the 4x4 interleaved block
     */
    int mult_interleave4x4_height() const
    {
        return _mult_interleave4x4_height;
    }
    /** Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
     *
     * @note The GEMM3D kernel is used when the output has to be reinterpreted as a 3D tensor. In that case:
     *       m = depth_output_gemm3d * output_height
     *
     * @return the depth of the output tensor to be used with the GEMM3D kernel
     */
    int depth_output_gemm3d() const
    {
        return _depth_output_gemm3d;
    }
    /** Flag which specifies if the input tensor has to be reinterpreted as 3D
     *
     * @return True if the input tensor has to be reinterpreted as 3D tensor
     */
    bool reinterpret_input_as_3d() const
    {
        return _reinterpret_input_as_3d;
    };
    /** Flag which specifies whether to broadcast the shape of the bias tensor.
     *
     * @return True if the shape of the bias tensor is to be broadcasted.
     */
    bool broadcast_bias() const
    {
        return _broadcast_bias;
    };

private:
    int  _m;
    int  _n;
    int  _k;
    int  _mult_transpose1xW_width;
    int  _mult_interleave4x4_height;
    int  _depth_output_gemm3d;
    bool _reinterpret_input_as_3d;
    bool _broadcast_bias;
};
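
/* Editor's note: illustrative only. When the GEMM output is reinterpreted as a
 * 3D tensor, m must equal depth_output_gemm3d * output_height. For example, a
 * 24x16 output (k = 8) viewed as 4 slices of height 6:
 *
 *   GEMMReshapeInfo reshape_info(24, 16, 8, 1, 1, 4); // m = 24 == 4 * 6
 */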

struct ConvolutionInfo
{
    ConvolutionInfo() = default;
    ConvolutionInfo(const PadStrideInfo &pad_stride_info, unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation)
        : pad_stride_info(pad_stride_info), depth_multiplier(depth_multiplier), act_info(act_info), dilation(dilation)
    {
    }
    PadStrideInfo       pad_stride_info{};        /**< Convolution info (Pads, strides,...) */
    unsigned int        depth_multiplier{ 1 };    /**< Multiplier to apply to the input's depth to retrieve the output depth. Defaults to 1 */
    ActivationLayerInfo act_info{};               /**< Fused activation to apply after convolution. */
    Size2D              dilation{ Size2D(1, 1) }; /**< Dilation, in elements, across x and y. Defaults to (1, 1). */
};

/** GEMMLowp output stage type */
enum class GEMMLowpOutputStageType
{
    NONE,                     /**< No quantization */
    QUANTIZE_DOWN,            /**< Quantize using an integer multiplication */
    QUANTIZE_DOWN_FIXEDPOINT, /**< Quantize using a fixed point multiplication */
    QUANTIZE_DOWN_FLOAT       /**< Quantize using a floating point multiplication */
};

/** GEMMLowp output stage info */
struct GEMMLowpOutputStageInfo
{
    GEMMLowpOutputStageType type{ GEMMLowpOutputStageType::NONE };                        /**< GEMMLowp output stage type */
    int32_t                 gemmlowp_offset{ 0 };                                         /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */
    int32_t                 gemmlowp_multiplier{ 0 };                                     /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
    int32_t                 gemmlowp_shift{ 0 };                                          /**< GEMMLowp output stage shift used for quantizing to uint8 */
    int32_t                 gemmlowp_min_bound{ std::numeric_limits<int32_t>::lowest() }; /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */
    int32_t                 gemmlowp_max_bound{ std::numeric_limits<int32_t>::max() };    /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */
    std::vector<int32_t>    gemmlowp_multipliers{};                                       /**< GEMMLowp output stage per-channel multipliers used for quantizing to QASYMM8 */
    std::vector<int32_t>    gemmlowp_shifts{};                                            /**< GEMMLowp output stage per-channel shifts used for quantizing to QASYMM8 */
    float                   gemmlowp_real_multiplier{ 0 };                                /**< GEMMLowp output stage real multiplier used for quantizing to QASYMM8 */
    bool                    is_quantized_per_channel{ false };                            /**< GEMMLowp quantized per-channel flag */
    DataType                output_data_type{ DataType::UNKNOWN };                        /**< Output tensor data type to use if the output is not initialized */
};
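
/* Editor's note: illustrative only. A fixed-point requantization stage that maps
 * int32 accumulators back to QASYMM8 might be filled in as follows; the
 * multiplier and shift values here are made-up placeholders:
 *
 *   GEMMLowpOutputStageInfo stage{};
 *   stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
 *   stage.gemmlowp_offset     = 10;         // output zero point
 *   stage.gemmlowp_multiplier = 1073741824; // fixed-point multiplier (placeholder)
 *   stage.gemmlowp_shift      = 5;          // right shift (placeholder)
 *   stage.gemmlowp_min_bound  = 0;
 *   stage.gemmlowp_max_bound  = 255;
 *   stage.output_data_type    = DataType::QASYMM8;
 */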

/** GEMM LHS (Left Hand Side) matrix information */
struct GEMMLHSMatrixInfo
{
    GEMMLHSMatrixInfo() = default;
    GEMMLHSMatrixInfo(unsigned int m, unsigned int k, unsigned int v, bool trans, bool inter)
        : m0(m), k0(k), v0(v), transpose(trans), interleave(inter)
    {
    }
    unsigned int m0{ 1 };            /**< Number of rows processed by the matrix multiplication */
    unsigned int k0{ 1 };            /**< Number of partial accumulations performed by the matrix multiplication */
    unsigned int v0{ 1 };            /**< Number of vertical blocks of size (m0xk0) stored on the same output row */
    bool         transpose{ true };  /**< True if the (m0xk0) block has to be transposed before being stored */
    bool         interleave{ true }; /**< True if the v0 (m0xk0) blocks have to be interleaved in the output row */
};

/** GEMM RHS (Right Hand Side) matrix information */
struct GEMMRHSMatrixInfo
{
    GEMMRHSMatrixInfo() = default;
    GEMMRHSMatrixInfo(unsigned int n, unsigned int k, unsigned int h, bool trans, bool inter, bool export_to_cl_img)
        : n0(n), k0(k), h0(h), transpose(trans), interleave(inter), export_to_cl_image(export_to_cl_img)
    {
    }
    unsigned int n0{ 1 };                     /**< Number of columns processed by the matrix multiplication */
    unsigned int k0{ 1 };                     /**< Number of partial accumulations performed by the matrix multiplication */
    unsigned int h0{ 1 };                     /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
    bool         transpose{ true };           /**< True if the (k0xn0) block has to be transposed before being stored */
    bool         interleave{ true };          /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */
    bool         export_to_cl_image{ false }; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */
};

class ITensorInfo;
/** GEMM information class. This class stores the necessary information to compute GEMM functions
 *
 * This object also contains the information about how matrix A and matrix B have been reshaped
 *
 */
class GEMMInfo
{
public:
    /** Default constructor */
    GEMMInfo() noexcept
        : _is_a_reshaped(false),
          _is_b_reshaped(false),
          _reshape_b_only_on_first_run(true),
          _depth_output_gemm3d(0),
          _reinterpret_input_as_3d(false),
          _retain_internal_weights(false),
          _gemmlowp_output_stage(),
          _fast_math(false),
          _fp_mixed_precision(false),
          _broadcast_bias(false),
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(),
          _post_ops(),
          _fixed_format(false),
          _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
    {
    }
    /** Constructor
     *
     * @param[in] is_a_reshaped               True if the matrix A has been reshaped
     * @param[in] is_b_reshaped               True if the matrix B has been reshaped
     * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
     * @param[in] depth_output_gemm3d         (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
     *                                        If 0 the output will not be reinterpreted as 3D. Default 0
     * @param[in] reinterpret_input_as_3d     (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used
     *                                        to perform 1x1 convolutions with the NHWC data layout)
     * @param[in] retain_internal_weights     (Optional) Retain the weights tensor from previous run
     * @param[in] gemmlowp_output_stage       (Optional) GEMMLowp Output stage info
     * @param[in] fp_mixed_precision          (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     * @param[in] fast_math                   (Optional) Use a data type of shorter width to improve performance
     * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
     * @param[in] activation_info             (Optional) Activation to apply after the matrix multiplication
     * @param[in] post_ops                    (Optional) A sequence of post operations that are performed after the main operation.
     * @param[in] fixed_format                (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
     * @param[in] weight_format               (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
     */
    GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
             const ActivationLayerInfo &activation_info = ActivationLayerInfo(), const experimental::PostOpList<ITensorInfo *> &post_ops = experimental::PostOpList<ITensorInfo *>(),
             bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
        : _is_a_reshaped(is_a_reshaped),
          _is_b_reshaped(is_b_reshaped),
          _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
          _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d),
          _retain_internal_weights(retain_internal_weights),
          _gemmlowp_output_stage(gemmlowp_output_stage),
          _fast_math(fast_math),
          _fp_mixed_precision(fp_mixed_precision),
          _broadcast_bias(broadcast_bias),
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(activation_info),
          _post_ops(post_ops),
          _fixed_format(fixed_format),
          _weight_format(weight_format)
    {
    }
    /** Flag which specifies if the matrix A has been reshaped
     *
     * @return True if the matrix A has been reshaped
     */
    bool is_a_reshaped() const
    {
        return _is_a_reshaped;
    };
    /** Flag which specifies if the matrix B has been reshaped
     *
     * @return True if the matrix B has been reshaped
     */
    bool is_b_reshaped() const
    {
        return _is_b_reshaped;
    };
    /** Flag which specifies if the reshape of matrix B should be executed only on the first run
     *
     * @note This flag could be set to TRUE when GEMM is used to accelerate convolution layer
     *
     * @return True if the reshape of matrix B happens only for the first run
     */
    bool reshape_b_only_on_first_run() const
    {
        return _reshape_b_only_on_first_run;
    };
    /** Depth of the output when GEMM output is reinterpreted as 3D tensor
     *
     * @return the depth of the output tensor
     */
    int depth_output_gemm3d() const
    {
        return _depth_output_gemm3d;
    };
    /** Flag which specifies if the input tensor has to be reinterpreted as 3D
     *
     * @return True if the input tensor has to be reinterpreted as 3D tensor
     */
    bool reinterpret_input_as_3d() const
    {
        return _reinterpret_input_as_3d;
    };
    /** Flag which specifies if the weights tensor has to be retained from previous run
     *
     * @return True if the weights tensor has to be retained
     */
    bool retain_internal_weights() const
    {
        return _retain_internal_weights;
    };
    /** GEMMLowp output stage
     *
     * @return the GEMMLowp output stage info
     */
    GEMMLowpOutputStageInfo gemmlowp_output_stage() const
    {
        return _gemmlowp_output_stage;
    };
    /** Sets GEMMLowp output stage
     *
     * @param[in] output_stage Output stage to set
     */
    void set_gemmlowp_output_stage(GEMMLowpOutputStageInfo &output_stage)
    {
        _gemmlowp_output_stage = output_stage;
    };
    /** Flag which specifies if a wider accumulator should be used.
     *
     * @return True if a wider accumulator has to be used
     */
    bool fp_mixed_precision() const
    {
        return _fp_mixed_precision;
    };
    /** Flag which specifies if a shorter accumulator should be used.
     *
     * @return True if a shorter accumulator has to be used
     */
    bool fast_math() const
    {
        return _fast_math;
    };
    /** Set fast math flag
     *
     * @param[in] fast_math Flag to set
     */
    void set_fast_math(bool fast_math)
    {
        _fast_math = fast_math;
    }
    /** Flag which specifies whether to broadcast the shape of the bias tensor.
     *
     * @return True if the shape of the bias tensor is to be broadcasted.
     */
    bool broadcast_bias() const
    {
        return _broadcast_bias;
    };
    /** Flag which specifies whether A should be pre-transposed if supported.
     *
     * @return True if A should be pre-transposed else false.
     */
    bool pretranspose_A() const
    {
        return _pretranspose_A;
    };
    /** Set pre-transpose A flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_A(bool flag)
    {
        _pretranspose_A = flag;
    }
    /** Flag which specifies whether B should be pre-transposed if supported.
     *
     * @return True if B should be pre-transposed else false.
     */
    bool pretranspose_B() const
    {
        return _pretranspose_B;
    };
    /** Set pre-transpose B flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_B(bool flag)
    {
        _pretranspose_B = flag;
    }
    /** Activation layer to apply after the matrix multiplication
     *
     * @return ActivationLayerInfo object
     */
    ActivationLayerInfo activation_info() const
    {
        return _activation_info;
    }
    /** Set activation layer info
     *
     * @param[in] activation_info ActivationLayerInfo object to set
     */
    void set_activation_info(const ActivationLayerInfo &activation_info)
    {
        _activation_info = activation_info;
    }
    /** Post operations to apply after the matrix multiplication
     *
     * @return experimental::PostOpList object
     */
    const experimental::PostOpList<ITensorInfo *> &post_ops() const
    {
        return _post_ops;
    }
    /** Set post ops
     *
     * @param[in] post_ops experimental::PostOpList object to set
     */
    void set_post_ops(const experimental::PostOpList<ITensorInfo *> &post_ops)
    {
        _post_ops = post_ops;
    }
    /** Flag which specifies if the GEMM operation is running fixed-format kernels.
     *
     * @return True if the GEMM operation is running fixed-format kernels else false.
     */
    bool fixed_format() const
    {
        return _fixed_format;
    }

    /** Set fixed-format flag
     *
     * @param[in] fixed_format sets whether or not to use fixed-format kernels
     */
    void set_fixed_format(bool fixed_format)
    {
        _fixed_format = fixed_format;
    }

    /** Weight format to be used
     *
     * @return the arm_compute::WeightFormat enumeration
     */
    arm_compute::WeightFormat weight_format() const
    {
        return _weight_format;
    }

    /** Set weight format to be used
     *
     * @param[in] weight_format arm_compute::WeightFormat enumeration
     */
    void set_weight_format(arm_compute::WeightFormat weight_format)
    {
        _weight_format = weight_format;
    }

private:
    bool                                    _is_a_reshaped;
    bool                                    _is_b_reshaped;
    bool                                    _reshape_b_only_on_first_run;
    int                                     _depth_output_gemm3d;
    bool                                    _reinterpret_input_as_3d;
    bool                                    _retain_internal_weights;
    GEMMLowpOutputStageInfo                 _gemmlowp_output_stage;
    bool                                    _fast_math;
    bool                                    _fp_mixed_precision;
    bool                                    _broadcast_bias;
    bool                                    _pretranspose_A;
    bool                                    _pretranspose_B;
    ActivationLayerInfo                     _activation_info;
    experimental::PostOpList<ITensorInfo *> _post_ops;
    bool                                    _fixed_format;
    arm_compute::WeightFormat               _weight_format;
};
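
/* Editor's note: illustrative only. A GEMM used to accelerate a convolution
 * layer would typically reshape B only once and fuse an activation:
 *
 *   GEMMInfo gemm_info(false, false, true); // A/B not yet reshaped; reshape B only on the first run
 *   gemm_info.set_activation_info(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
 *   gemm_info.set_fast_math(true);          // allow shorter-width accumulation where supported
 */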

/** Winograd information */
struct WinogradInfo
{
    /** Constructor
     *
     * @param[in] output_tile_sz Width and height of the output tile
     * @param[in] kernel_sz      Width and height of the kernel
     * @param[in] input_dims     Width and height of the input tensor before the convolution is applied
     * @param[in] conv_info      Convolution info (Pads, strides)
     * @param[in] data_layout    Data layout to use for the output tensor once the convolution has been applied
     */
    WinogradInfo(Size2D output_tile_sz, Size2D kernel_sz, Size2D input_dims, PadStrideInfo conv_info, DataLayout data_layout)
        : output_tile_size(output_tile_sz), kernel_size(kernel_sz), input_dimensions(input_dims), convolution_info(conv_info), output_data_layout(data_layout)
    {
    }

    Size2D        output_tile_size{};                     /**< Width and height of the output tile */
    Size2D        kernel_size{};                          /**< Width and height of the kernel */
    Size2D        input_dimensions{};                     /**< Width and height of the input tensor before the convolution is applied */
    PadStrideInfo convolution_info{};                     /**< Convolution info (Pads, strides,...) */
    DataLayout    output_data_layout{ DataLayout::NCHW }; /**< Data layout to use for the output tensor once the convolution has been applied (NCHW or NHWC) */
};
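
/* Editor's note: illustrative only. A Winograd F(2x2, 3x3) configuration for a
 * 224x224 NHWC input with a 3x3 kernel, stride 1 and padding 1:
 *
 *   WinogradInfo winograd_info(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(224U, 224U),
 *                              PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC);
 */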

/** IO formatting information class */
struct IOFormatInfo
{
    /** Precision type used when printing floating point numbers */
    enum class PrecisionType
    {
        Default, /**< Default precision to the one that the current stream has */
        Custom,  /**< Custom precision specified by the user using the precision parameter */
        Full     /**< The maximum precision of the floating point representation */
    };

    /** Specifies the area to be printed, used by Tensor objects */
    enum class PrintRegion
    {
        ValidRegion, /**< Prints the valid region of the Tensor object */
        NoPadding,   /**< Prints the Tensor object without the padding */
        Full         /**< Prints the Tensor object including padding */
    };

    /** Construct a set of IO formatting information.
     *
     * @param[in] print_region   Area to be printed. Used by Tensor objects. Default: ValidRegion.
     * @param[in] precision_type Precision type for floating point numbers. Default: stream default.
     * @param[in] precision      Precision value for floating point numbers. Default: 10.
     * @param[in] align_columns  Whether to align columns when printed. Default: true.
     * @param[in] element_delim  Delimiter between elements. Default: " ".
     * @param[in] row_delim      Delimiter between rows. Default: "\n".
     */
    IOFormatInfo(PrintRegion   print_region   = PrintRegion::ValidRegion,
                 PrecisionType precision_type = PrecisionType::Default,
                 unsigned int  precision      = 10,
                 bool          align_columns  = true,
                 std::string   element_delim  = " ",
                 std::string   row_delim      = "\n")
        : print_region(print_region),
          precision_type(precision_type),
          precision(precision),
          element_delim(element_delim),
          row_delim(row_delim),
          align_columns(align_columns)
    {
    }

    /** Area to be printed by Tensor objects */
    PrintRegion print_region;
    /** Floating point precision type */
    PrecisionType precision_type;
    /** Floating point precision */
    unsigned int precision;
    /** Element delimiter */
    std::string element_delim;
    /** Row delimiter */
    std::string row_delim;
    /** Align columns */
    bool align_columns;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_TYPES_H */