1 /*
2  * Copyright (c) 2016-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_UTILS_H
25 #define ARM_COMPUTE_UTILS_H
26 
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/PixelValue.h"
29 #include "arm_compute/core/Rounding.h"
30 #include "arm_compute/core/Types.h"
31 #include "arm_compute/core/Version.h"
32 
33 #include <algorithm>
34 #include <cstdint>
35 #include <cstdlib>
36 #include <iomanip>
37 #include <numeric>
38 #include <sstream>
39 #include <string>
40 #include <type_traits>
41 #include <unordered_map>
42 #include <utility>
43 #include <vector>
44 
45 namespace arm_compute
46 {
47 class ITensor;
48 class ITensorInfo;
49 
50 /** Calculate the rounded up quotient of val / m.
51  *
52  * @param[in] val Value to divide and round up.
53  * @param[in] m   Value to divide by.
54  *
55  * @return the result.
56  */
57 template <typename S, typename T>
58 constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
59 {
60     return (val + m - 1) / m;
61 }
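// Illustrative example (not part of the original header): DIV_CEIL rounds the
// quotient up, so DIV_CEIL(10, 3) == 4 while plain integer division gives 3,
// and DIV_CEIL(9, 3) == 3.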
62 
63 /** Computes the smallest number larger than or equal to value that is a multiple of divisor.
64  *
65  * @param[in] value   Lower bound value
66  * @param[in] divisor Value to compute multiple of.
67  *
68  * @return the result.
69  */
70 template <typename S, typename T>
71 inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
72 {
73     ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
74     return DIV_CEIL(value, divisor) * divisor;
75 }
76 
77 /** Computes the largest number smaller than or equal to value that is a multiple of divisor.
78  *
79  * @param[in] value   Upper bound value
80  * @param[in] divisor Value to compute multiple of.
81  *
82  * @return the result.
83  */
84 template <typename S, typename T>
85 inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor) * divisor)
86 {
87     ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
88     return (value / divisor) * divisor;
89 }
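// Illustrative example (not part of the original header): ceil_to_multiple(10, 4) == 12
// and floor_to_multiple(10, 4) == 8, i.e. the closest multiples of 4 on either side
// of 10; a value that is already a multiple is returned unchanged by both.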
90 
91 /** Load an entire file in memory
92  *
93  * @param[in] filename Name of the file to read.
94  * @param[in] binary   Is it a binary file?
95  *
96  * @return The content of the file.
97  */
98 std::string read_file(const std::string &filename, bool binary);
99 
100 /** The size in bytes of the data type
101  *
102  * @param[in] data_type Input data type
103  *
104  * @return The size in bytes of the data type
105  */
106 inline size_t data_size_from_type(DataType data_type)
107 {
108     switch(data_type)
109     {
110         case DataType::U8:
111         case DataType::S8:
112         case DataType::QSYMM8:
113         case DataType::QASYMM8:
114         case DataType::QASYMM8_SIGNED:
115         case DataType::QSYMM8_PER_CHANNEL:
116             return 1;
117         case DataType::U16:
118         case DataType::S16:
119         case DataType::QSYMM16:
120         case DataType::QASYMM16:
121         case DataType::BFLOAT16:
122         case DataType::F16:
123             return 2;
124         case DataType::F32:
125         case DataType::U32:
126         case DataType::S32:
127             return 4;
128         case DataType::F64:
129         case DataType::U64:
130         case DataType::S64:
131             return 8;
132         case DataType::SIZET:
133             return sizeof(size_t);
134         default:
135             ARM_COMPUTE_ERROR("Invalid data type");
136             return 0;
137     }
138 }
139 
140 /** The size in bytes of the pixel format
141  *
142  * @param[in] format Input format
143  *
144  * @return The size in bytes of the pixel format
145  */
146 inline size_t pixel_size_from_format(Format format)
147 {
148     switch(format)
149     {
150         case Format::U8:
151             return 1;
152         case Format::U16:
153         case Format::S16:
154         case Format::BFLOAT16:
155         case Format::F16:
156         case Format::UV88:
157         case Format::YUYV422:
158         case Format::UYVY422:
159             return 2;
160         case Format::RGB888:
161             return 3;
162         case Format::RGBA8888:
163             return 4;
164         case Format::U32:
165         case Format::S32:
166         case Format::F32:
167             return 4;
168         //Doesn't make sense for planar formats:
169         case Format::NV12:
170         case Format::NV21:
171         case Format::IYUV:
172         case Format::YUV444:
173         default:
174             ARM_COMPUTE_ERROR("Undefined pixel size for given format");
175             return 0;
176     }
177 }
178 
179 /** The size in bytes of the data type
180  *
181  * @param[in] dt Input data type
182  *
183  * @return The size in bytes of the data type
184  */
185 inline size_t element_size_from_data_type(DataType dt)
186 {
187     switch(dt)
188     {
189         case DataType::S8:
190         case DataType::U8:
191         case DataType::QSYMM8:
192         case DataType::QASYMM8:
193         case DataType::QASYMM8_SIGNED:
194         case DataType::QSYMM8_PER_CHANNEL:
195             return 1;
196         case DataType::U16:
197         case DataType::S16:
198         case DataType::QSYMM16:
199         case DataType::QASYMM16:
200         case DataType::BFLOAT16:
201         case DataType::F16:
202             return 2;
203         case DataType::U32:
204         case DataType::S32:
205         case DataType::F32:
206             return 4;
207         default:
208             ARM_COMPUTE_ERROR("Undefined element size for given data type");
209             return 0;
210     }
211 }
212 
213 /** Return the data type used by a given single-planar pixel format
214  *
215  * @param[in] format Input format
216  *
217  * @return The data type used by the pixel format
218  */
219 inline DataType data_type_from_format(Format format)
220 {
221     switch(format)
222     {
223         case Format::U8:
224         case Format::UV88:
225         case Format::RGB888:
226         case Format::RGBA8888:
227         case Format::YUYV422:
228         case Format::UYVY422:
229             return DataType::U8;
230         case Format::U16:
231             return DataType::U16;
232         case Format::S16:
233             return DataType::S16;
234         case Format::U32:
235             return DataType::U32;
236         case Format::S32:
237             return DataType::S32;
238         case Format::BFLOAT16:
239             return DataType::BFLOAT16;
240         case Format::F16:
241             return DataType::F16;
242         case Format::F32:
243             return DataType::F32;
244         //Doesn't make sense for planar formats:
245         case Format::NV12:
246         case Format::NV21:
247         case Format::IYUV:
248         case Format::YUV444:
249         default:
250             ARM_COMPUTE_ERROR("Not supported data_type for given format");
251             return DataType::UNKNOWN;
252     }
253 }
254 
255 /** Return the plane index of a given channel given an input format.
256  *
257  * @param[in] format  Input format
258  * @param[in] channel Input channel
259  *
260  * @return The plane index of the specific channel of the specific format
261  */
262 inline int plane_idx_from_channel(Format format, Channel channel)
263 {
264     switch(format)
265     {
266         // Single planar formats have a single plane
267         case Format::U8:
268         case Format::U16:
269         case Format::S16:
270         case Format::U32:
271         case Format::S32:
272         case Format::BFLOAT16:
273         case Format::F16:
274         case Format::F32:
275         case Format::UV88:
276         case Format::RGB888:
277         case Format::RGBA8888:
278         case Format::YUYV422:
279         case Format::UYVY422:
280             return 0;
281         // Multi planar formats
282         case Format::NV12:
283         case Format::NV21:
284         {
285             // Channel U and V share the same plane of format UV88
286             switch(channel)
287             {
288                 case Channel::Y:
289                     return 0;
290                 case Channel::U:
291                 case Channel::V:
292                     return 1;
293                 default:
294                     ARM_COMPUTE_ERROR("Not supported channel");
295                     return 0;
296             }
297         }
298         case Format::IYUV:
299         case Format::YUV444:
300         {
301             switch(channel)
302             {
303                 case Channel::Y:
304                     return 0;
305                 case Channel::U:
306                     return 1;
307                 case Channel::V:
308                     return 2;
309                 default:
310                     ARM_COMPUTE_ERROR("Not supported channel");
311                     return 0;
312             }
313         }
314         default:
315             ARM_COMPUTE_ERROR("Not supported format");
316             return 0;
317     }
318 }
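// Illustrative example (not part of the original header): for the two-plane NV12
// format, plane_idx_from_channel(Format::NV12, Channel::Y) == 0 and the interleaved
// U and V channels both map to plane 1, whereas the three-plane IYUV format maps
// Channel::V to plane 2.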
319 
320 /** Return the channel index of a given channel given an input format.
321  *
322  * @param[in] format  Input format
323  * @param[in] channel Input channel
324  *
325  * @return The channel index of the specific channel of the specific format
326  */
327 inline int channel_idx_from_format(Format format, Channel channel)
328 {
329     switch(format)
330     {
331         case Format::RGB888:
332         {
333             switch(channel)
334             {
335                 case Channel::R:
336                     return 0;
337                 case Channel::G:
338                     return 1;
339                 case Channel::B:
340                     return 2;
341                 default:
342                     ARM_COMPUTE_ERROR("Not supported channel");
343                     return 0;
344             }
345         }
346         case Format::RGBA8888:
347         {
348             switch(channel)
349             {
350                 case Channel::R:
351                     return 0;
352                 case Channel::G:
353                     return 1;
354                 case Channel::B:
355                     return 2;
356                 case Channel::A:
357                     return 3;
358                 default:
359                     ARM_COMPUTE_ERROR("Not supported channel");
360                     return 0;
361             }
362         }
363         case Format::YUYV422:
364         {
365             switch(channel)
366             {
367                 case Channel::Y:
368                     return 0;
369                 case Channel::U:
370                     return 1;
371                 case Channel::V:
372                     return 3;
373                 default:
374                     ARM_COMPUTE_ERROR("Not supported channel");
375                     return 0;
376             }
377         }
378         case Format::UYVY422:
379         {
380             switch(channel)
381             {
382                 case Channel::Y:
383                     return 1;
384                 case Channel::U:
385                     return 0;
386                 case Channel::V:
387                     return 2;
388                 default:
389                     ARM_COMPUTE_ERROR("Not supported channel");
390                     return 0;
391             }
392         }
393         case Format::NV12:
394         {
395             switch(channel)
396             {
397                 case Channel::Y:
398                     return 0;
399                 case Channel::U:
400                     return 0;
401                 case Channel::V:
402                     return 1;
403                 default:
404                     ARM_COMPUTE_ERROR("Not supported channel");
405                     return 0;
406             }
407         }
408         case Format::NV21:
409         {
410             switch(channel)
411             {
412                 case Channel::Y:
413                     return 0;
414                 case Channel::U:
415                     return 1;
416                 case Channel::V:
417                     return 0;
418                 default:
419                     ARM_COMPUTE_ERROR("Not supported channel");
420                     return 0;
421             }
422         }
423         case Format::YUV444:
424         case Format::IYUV:
425         {
426             switch(channel)
427             {
428                 case Channel::Y:
429                     return 0;
430                 case Channel::U:
431                     return 0;
432                 case Channel::V:
433                     return 0;
434                 default:
435                     ARM_COMPUTE_ERROR("Not supported channel");
436                     return 0;
437             }
438         }
439         default:
440             ARM_COMPUTE_ERROR("Not supported format");
441             return 0;
442     }
443 }
444 
445 /** Return the number of planes for a given format
446  *
447  * @param[in] format Input format
448  *
449  * @return The number of planes for a given image format.
450  */
451 inline size_t num_planes_from_format(Format format)
452 {
453     switch(format)
454     {
455         case Format::U8:
456         case Format::S16:
457         case Format::U16:
458         case Format::S32:
459         case Format::U32:
460         case Format::BFLOAT16:
461         case Format::F16:
462         case Format::F32:
463         case Format::RGB888:
464         case Format::RGBA8888:
465         case Format::YUYV422:
466         case Format::UYVY422:
467             return 1;
468         case Format::NV12:
469         case Format::NV21:
470             return 2;
471         case Format::IYUV:
472         case Format::YUV444:
473             return 3;
474         default:
475             ARM_COMPUTE_ERROR("Not supported format");
476             return 0;
477     }
478 }
479 
480 /** Return the number of channels for a given single-planar pixel format
481  *
482  * @param[in] format Input format
483  *
484  * @return The number of channels for a given image format.
485  */
486 inline size_t num_channels_from_format(Format format)
487 {
488     switch(format)
489     {
490         case Format::U8:
491         case Format::U16:
492         case Format::S16:
493         case Format::U32:
494         case Format::S32:
495         case Format::BFLOAT16:
496         case Format::F16:
497         case Format::F32:
498             return 1;
499         // Because the U and V channels are subsampled
500         // these formats appear like having only 2 channels:
501         case Format::YUYV422:
502         case Format::UYVY422:
503             return 2;
504         case Format::UV88:
505             return 2;
506         case Format::RGB888:
507             return 3;
508         case Format::RGBA8888:
509             return 4;
510         //Doesn't make sense for planar formats:
511         case Format::NV12:
512         case Format::NV21:
513         case Format::IYUV:
514         case Format::YUV444:
515         default:
516             return 0;
517     }
518 }
519 
520 /** Return the promoted data type of a given data type.
521  *
522  * @note If promoted data type is not supported an error will be thrown
523  *
524  * @param[in] dt Data type to get the promoted type of.
525  *
526  * @return Promoted data type
527  */
528 inline DataType get_promoted_data_type(DataType dt)
529 {
530     switch(dt)
531     {
532         case DataType::U8:
533             return DataType::U16;
534         case DataType::S8:
535             return DataType::S16;
536         case DataType::U16:
537             return DataType::U32;
538         case DataType::S16:
539             return DataType::S32;
540         case DataType::QSYMM8:
541         case DataType::QASYMM8:
542         case DataType::QASYMM8_SIGNED:
543         case DataType::QSYMM8_PER_CHANNEL:
544         case DataType::QSYMM16:
545         case DataType::QASYMM16:
546         case DataType::BFLOAT16:
547         case DataType::F16:
548         case DataType::U32:
549         case DataType::S32:
550         case DataType::F32:
551             ARM_COMPUTE_ERROR("Unsupported data type promotions!");
552         default:
553             ARM_COMPUTE_ERROR("Undefined data type!");
554     }
555     return DataType::UNKNOWN;
556 }
557 
558 /** Compute the minimum and maximum values a data type can take
559  *
560  * @param[in] dt Data type to get the min/max bounds of
561  *
562  * @return A tuple (min,max) with the minimum and maximum values respectively wrapped in PixelValue.
563  */
564 inline std::tuple<PixelValue, PixelValue> get_min_max(DataType dt)
565 {
566     PixelValue min{};
567     PixelValue max{};
568     switch(dt)
569     {
570         case DataType::U8:
571         case DataType::QASYMM8:
572         {
573             min = PixelValue(static_cast<int32_t>(std::numeric_limits<uint8_t>::lowest()));
574             max = PixelValue(static_cast<int32_t>(std::numeric_limits<uint8_t>::max()));
575             break;
576         }
577         case DataType::S8:
578         case DataType::QSYMM8:
579         case DataType::QASYMM8_SIGNED:
580         case DataType::QSYMM8_PER_CHANNEL:
581         {
582             min = PixelValue(static_cast<int32_t>(std::numeric_limits<int8_t>::lowest()));
583             max = PixelValue(static_cast<int32_t>(std::numeric_limits<int8_t>::max()));
584             break;
585         }
586         case DataType::U16:
587         case DataType::QASYMM16:
588         {
589             min = PixelValue(static_cast<int32_t>(std::numeric_limits<uint16_t>::lowest()));
590             max = PixelValue(static_cast<int32_t>(std::numeric_limits<uint16_t>::max()));
591             break;
592         }
593         case DataType::S16:
594         case DataType::QSYMM16:
595         {
596             min = PixelValue(static_cast<int32_t>(std::numeric_limits<int16_t>::lowest()));
597             max = PixelValue(static_cast<int32_t>(std::numeric_limits<int16_t>::max()));
598             break;
599         }
600         case DataType::U32:
601         {
602             min = PixelValue(std::numeric_limits<uint32_t>::lowest());
603             max = PixelValue(std::numeric_limits<uint32_t>::max());
604             break;
605         }
606         case DataType::S32:
607         {
608             min = PixelValue(std::numeric_limits<int32_t>::lowest());
609             max = PixelValue(std::numeric_limits<int32_t>::max());
610             break;
611         }
612         case DataType::BFLOAT16:
613         {
614             min = PixelValue(bfloat16::lowest());
615             max = PixelValue(bfloat16::max());
616             break;
617         }
618         case DataType::F16:
619         {
620             min = PixelValue(std::numeric_limits<half>::lowest());
621             max = PixelValue(std::numeric_limits<half>::max());
622             break;
623         }
624         case DataType::F32:
625         {
626             min = PixelValue(std::numeric_limits<float>::lowest());
627             max = PixelValue(std::numeric_limits<float>::max());
628             break;
629         }
630         default:
631             ARM_COMPUTE_ERROR("Undefined data type!");
632     }
633     return std::make_tuple(min, max);
634 }
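// Illustrative usage (not part of the original header); a minimal sketch of
// unpacking the returned tuple:
//
//     PixelValue min_val{};
//     PixelValue max_val{};
//     std::tie(min_val, max_val) = get_min_max(DataType::QASYMM8);
//     // min_val and max_val now wrap 0 and 255, stored as int32_t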
635 
636 /** Return true if the given format has horizontal subsampling.
637  *
638  * @param[in] format Format to determine subsampling.
639  *
640  * @return True if the format can be subsampled horizontally.
641  */
642 inline bool has_format_horizontal_subsampling(Format format)
643 {
644     return (format == Format::YUYV422 || format == Format::UYVY422 || format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
645 }
646 
647 /** Return true if the given format has vertical subsampling.
648  *
649  * @param[in] format Format to determine subsampling.
650  *
651  * @return True if the format can be subsampled vertically.
652  */
653 inline bool has_format_vertical_subsampling(Format format)
654 {
655     return (format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
656 }
657 
658 /** Separate a 2D convolution into two 1D convolutions
659  *
660  * @param[in]  conv     2D convolution
661  * @param[out] conv_col 1D vertical convolution
662  * @param[out] conv_row 1D horizontal convolution
663  * @param[in]  size     Size of the 2D convolution
664  *
665  * @return true if the separation was successful
666  */
667 inline bool separate_matrix(const int16_t *conv, int16_t *conv_col, int16_t *conv_row, uint8_t size)
668 {
669     int32_t min_col     = -1;
670     int16_t min_col_val = -1;
671 
672     for(int32_t i = 0; i < size; ++i)
673     {
674         if(conv[i] != 0 && (min_col < 0 || abs(min_col_val) > abs(conv[i])))
675         {
676             min_col     = i;
677             min_col_val = conv[i];
678         }
679     }
680 
681     if(min_col < 0)
682     {
683         return false;
684     }
685 
686     for(uint32_t j = 0; j < size; ++j)
687     {
688         conv_col[j] = conv[min_col + j * size];
689     }
690 
691     for(uint32_t i = 0; i < size; i++)
692     {
693         if(static_cast<int>(i) == min_col)
694         {
695             conv_row[i] = 1;
696         }
697         else
698         {
699             int16_t coeff = conv[i] / conv[min_col];
700 
701             for(uint32_t j = 1; j < size; ++j)
702             {
703                 if(conv[i + j * size] != (conv_col[j] * coeff))
704                 {
705                     return false;
706                 }
707             }
708 
709             conv_row[i] = coeff;
710         }
711     }
712 
713     return true;
714 }
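// Illustrative usage (not part of the original header); a minimal sketch with a
// separable 3x3 Sobel-x kernel:
//
//     const int16_t sobel_x[9] = { -1, 0, 1, -2, 0, 2, -1, 0, 1 };
//     int16_t       col[3]     = {};
//     int16_t       row[3]     = {};
//     const bool    separable  = separate_matrix(sobel_x, col, row, 3);
//     // separable == true and sobel_x[i + j * 3] == col[j] * row[i] for all i, j;
//     // a non-separable kernel would make separate_matrix() return false instead.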
715 
716 /** Calculate the scale of the given square matrix
717  *
718  * The scale is the absolute value of the sum of all the coefficients in the matrix.
719  *
720  * @note If the coefficients add up to 0 then the scale is set to 1.
721  *
722  * @param[in] matrix      Matrix coefficients
723  * @param[in] matrix_size Number of elements per side of the square matrix. (Number of coefficients = matrix_size * matrix_size).
724  *
725  * @return The absolute value of the sum of the coefficients if they don't add up to 0, otherwise 1.
726  */
727 inline uint32_t calculate_matrix_scale(const int16_t *matrix, unsigned int matrix_size)
728 {
729     const size_t size = matrix_size * matrix_size;
730 
731     return std::max(1, std::abs(std::accumulate(matrix, matrix + size, 0)));
732 }
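// Illustrative example (not part of the original header): a 3x3 Gaussian kernel
// { 1, 2, 1, 2, 4, 2, 1, 2, 1 } has a scale of 16 (the sum of its coefficients),
// while a Sobel kernel whose coefficients sum to 0 gets a scale of 1.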
733 
734 /** Adjust tensor shape size if width or height are odd for a given multi-planar format. No modification is done for other formats.
735  *
736  * @note Adding here a few links discussing the issue of odd size and sharing the same solution:
737  *       <a href="https://android.googlesource.com/platform/frameworks/base/+/refs/heads/master/graphics/java/android/graphics/YuvImage.java">Android Source</a>
738  *       <a href="https://groups.google.com/a/webmproject.org/forum/#!topic/webm-discuss/LaCKpqiDTXM">WebM</a>
739  *       <a href="https://bugs.chromium.org/p/libyuv/issues/detail?id=198&amp;can=1&amp;q=odd%20width">libYUV</a>
740  *       <a href="https://sourceforge.net/p/raw-yuvplayer/bugs/1/">YUVPlayer</a>
741  *
742  * @param[in, out] shape  Tensor shape of 2D size
743  * @param[in]      format Format of the tensor
744  *
745  * @return The adjusted tensor shape.
746  */
747 inline TensorShape adjust_odd_shape(const TensorShape &shape, Format format)
748 {
749     TensorShape output{ shape };
750 
751     // Force width to be even for formats which require subsampling of the U and V channels
752     if(has_format_horizontal_subsampling(format))
753     {
754         output.set(0, (output.x() + 1) & ~1U);
755     }
756 
757     // Force height to be even for formats which require subsampling of the U and V channels
758     if(has_format_vertical_subsampling(format))
759     {
760         output.set(1, (output.y() + 1) & ~1U);
761     }
762 
763     return output;
764 }
765 
766 /** Calculate subsampled shape for a given format and channel
767  *
768  * @param[in] shape   Shape of the tensor to calculate the extracted channel.
769  * @param[in] format  Format of the tensor.
770  * @param[in] channel Channel to create tensor shape to be extracted.
771  *
772  * @return The subsampled tensor shape.
773  */
774 inline TensorShape calculate_subsampled_shape(const TensorShape &shape, Format format, Channel channel = Channel::UNKNOWN)
775 {
776     TensorShape output{ shape };
777 
778     // Subsample shape only for U or V channel
779     if(Channel::U == channel || Channel::V == channel || Channel::UNKNOWN == channel)
780     {
781         // Subsample width for the tensor shape when channel is U or V
782         if(has_format_horizontal_subsampling(format))
783         {
784             output.set(0, output.x() / 2U);
785         }
786 
787         // Subsample height for the tensor shape when channel is U or V
788         if(has_format_vertical_subsampling(format))
789         {
790             output.set(1, output.y() / 2U);
791         }
792     }
793 
794     return output;
795 }
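// Illustrative example (not part of the original header): for Format::NV12,
// adjust_odd_shape() forces even dimensions (a 639x479 image becomes 640x480), and
// calculate_subsampled_shape(TensorShape(640U, 480U), Format::NV12, Channel::U)
// then returns 320x240, the extent of the interleaved UV plane.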
796 
797 /** Calculate the accuracy required by the horizontal and vertical convolution computations
798  *
799  * @param[in] conv_col Pointer to the vertical vector of the separated convolution filter
800  * @param[in] conv_row Pointer to the horizontal vector of the convolution filter
801  * @param[in] size     Number of elements per vector of the separated matrix
802  *
803  * @return The return type is a pair. The first element of the pair is the biggest data type needed for the first stage. The second
804  * element of the pair is the biggest data type needed for the second stage.
805  */
806 inline std::pair<DataType, DataType> data_type_for_convolution(const int16_t *conv_col, const int16_t *conv_row, size_t size)
807 {
808     DataType first_stage  = DataType::UNKNOWN;
809     DataType second_stage = DataType::UNKNOWN;
810 
811     auto gez = [](const int16_t &v)
812     {
813         return v >= 0;
814     };
815 
816     auto accu_neg = [](const int &first, const int &second)
817     {
818         return first + (second < 0 ? second : 0);
819     };
820 
821     auto accu_pos = [](const int &first, const int &second)
822     {
823         return first + (second > 0 ? second : 0);
824     };
825 
826     const bool only_positive_coefficients = std::all_of(conv_row, conv_row + size, gez) && std::all_of(conv_col, conv_col + size, gez);
827 
828     if(only_positive_coefficients)
829     {
830         const int max_row_value = std::accumulate(conv_row, conv_row + size, 0) * UINT8_MAX;
831         const int max_value     = std::accumulate(conv_col, conv_col + size, 0) * max_row_value;
832 
833         first_stage = (max_row_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
834 
835         second_stage = (max_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
836     }
837     else
838     {
839         const int min_row_value  = std::accumulate(conv_row, conv_row + size, 0, accu_neg) * UINT8_MAX;
840         const int max_row_value  = std::accumulate(conv_row, conv_row + size, 0, accu_pos) * UINT8_MAX;
841         const int neg_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_neg);
842         const int pos_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_pos);
843         const int min_value      = neg_coeffs_sum * max_row_value + pos_coeffs_sum * min_row_value;
844         const int max_value      = neg_coeffs_sum * min_row_value + pos_coeffs_sum * max_row_value;
845 
846         first_stage = ((INT16_MIN <= min_row_value) && (max_row_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
847 
848         second_stage = ((INT16_MIN <= min_value) && (max_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
849     }
850 
851     return std::make_pair(first_stage, second_stage);
852 }
853 
854 /** Calculate the accuracy required by the square convolution matrix calculation.
855  *
856  *
857  * @param[in] conv Pointer to the square convolution matrix
858  * @param[in] size The total size of the convolution matrix
859  *
860  * @return The biggest data type needed to do the convolution
861  */
862 inline DataType data_type_for_convolution_matrix(const int16_t *conv, size_t size)
863 {
864     auto gez = [](const int16_t v)
865     {
866         return v >= 0;
867     };
868 
869     const bool only_positive_coefficients = std::all_of(conv, conv + size, gez);
870 
871     if(only_positive_coefficients)
872     {
873         const int max_conv_value = std::accumulate(conv, conv + size, 0) * UINT8_MAX;
874         if(max_conv_value <= UINT16_MAX)
875         {
876             return DataType::U16;
877         }
878         else
879         {
880             return DataType::S32;
881         }
882     }
883     else
884     {
885         const int min_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
886         {
887             return b < 0 ? a + b : a;
888         })
889         * UINT8_MAX;
890 
891         const int max_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
892         {
893             return b > 0 ? a + b : a;
894         })
895         * UINT8_MAX;
896 
897         if((INT16_MIN <= min_value) && (INT16_MAX >= max_value))
898         {
899             return DataType::S16;
900         }
901         else
902         {
903             return DataType::S32;
904         }
905     }
906 }
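// Illustrative example (not part of the original header): a 3x3 box filter of ones
// has only positive coefficients and a worst case of 9 * 255 = 2295, which fits in
// U16, while a Sobel-x kernel spans [-1020, 1020] and therefore fits in S16; only
// kernels with a larger accumulated range fall back to S32.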
907 
908 /** Permutes the given dimensions according the permutation vector
909  *
910  * @param[in,out] dimensions Dimensions to be permuted.
911  * @param[in]     perm       Vector describing the permutation.
912  *
913  */
914 template <typename T>
915 inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
916 {
917     const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
918     for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
919     {
920         T dimension_val = old_dim[i];
921         dimensions.set(perm[i], dimension_val);
922     }
923 }
924 
925 /** Calculate padding requirements in case of SAME padding
926  *
927  * @param[in] input_shape   Input shape
928  * @param[in] weights_shape Weights shape
929  * @param[in] conv_info     Convolution information (containing strides)
930  * @param[in] data_layout   (Optional) Data layout of the input and weights tensor
931  * @param[in] dilation      (Optional) Dilation factor used in the convolution.
932  * @param[in] rounding_type (Optional) Dimension rounding type when down-scaling.
933  *
934  * @return PadStrideInfo for SAME padding
935  */
936 PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout = DataLayout::NCHW, const Size2D &dilation = Size2D(1u, 1u),
937                                  const DimensionRoundingType &rounding_type = DimensionRoundingType::FLOOR);
938 
939 /** Returns expected width and height of the deconvolution's output tensor.
940  *
941  * @param[in] in_width        Width of input tensor (Number of columns)
942  * @param[in] in_height       Height of input tensor (Number of rows)
943  * @param[in] kernel_width    Kernel width.
944  * @param[in] kernel_height   Kernel height.
945  * @param[in] pad_stride_info Pad and stride information.
946  *
947  * @return A pair with the new width in the first position and the new height in the second.
948  */
949 std::pair<unsigned int, unsigned int> deconvolution_output_dimensions(unsigned int in_width, unsigned int in_height,
950                                                                       unsigned int kernel_width, unsigned int kernel_height,
951                                                                       const PadStrideInfo &pad_stride_info);
952 
953 /** Returns expected width and height of output scaled tensor depending on dimensions rounding mode.
954  *
955  * @param[in] width           Width of input tensor (Number of columns)
956  * @param[in] height          Height of input tensor (Number of rows)
957  * @param[in] kernel_width    Kernel width.
958  * @param[in] kernel_height   Kernel height.
959  * @param[in] pad_stride_info Pad and stride information.
960  * @param[in] dilation        (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
961  *
962  * @return A pair with the new width in the first position and the new height in the second.
963  */
964 std::pair<unsigned int, unsigned int> scaled_dimensions(int width, int height,
965                                                         int kernel_width, int kernel_height,
966                                                         const PadStrideInfo &pad_stride_info,
967                                                         const Size2D        &dilation = Size2D(1U, 1U));
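// Illustrative expectation (not part of the original header), assuming the usual
// convolution output arithmetic: a 224x224 input convolved with a 3x3 kernel,
// stride 2 and padding 1 under FLOOR rounding scales to 112x112.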
968 
969 /** Check if the given reduction operation should be handled in a serial way.
970  *
971  * @param[in] op   Reduction operation to perform
972  * @param[in] dt   Data type
973  * @param[in] axis Axis along which to reduce
974  *
975  * @return True if the given reduction operation should be handled in a serial way.
976  */
977 bool needs_serialized_reduction(ReductionOperation op, DataType dt, unsigned int axis);
978 
979 /** Returns output quantization information for softmax layer
980  *
981  * @param[in] input_type The data type of the input tensor
982  * @param[in] is_log     True for log softmax
983  *
984  * @return Quantization information for the output tensor
985  */
986 QuantizationInfo get_softmax_output_quantization_info(DataType input_type, bool is_log);
987 
988 /** Returns a pair of minimum and maximum values for a quantized activation
989  *
990  * @param[in] act_info  The information for activation
991  * @param[in] data_type The used data type
992  * @param[in] oq_info   The output quantization information
993  *
994  * @return The pair with minimum and maximum values
995  */
996 std::pair<int32_t, int32_t> get_quantized_activation_min_max(ActivationLayerInfo act_info, DataType data_type, UniformQuantizationInfo oq_info);
997 
998 /** Convert a tensor format into a string.
999  *
1000  * @param[in] format @ref Format to be translated to string.
1001  *
1002  * @return The string describing the format.
1003  */
1004 const std::string &string_from_format(Format format);
1005 
1006 /** Convert a channel identity into a string.
1007  *
1008  * @param[in] channel @ref Channel to be translated to string.
1009  *
1010  * @return The string describing the channel.
1011  */
1012 const std::string &string_from_channel(Channel channel);
1013 /** Convert a data layout identity into a string.
1014  *
1015  * @param[in] dl @ref DataLayout to be translated to string.
1016  *
1017  * @return The string describing the data layout.
1018  */
1019 const std::string &string_from_data_layout(DataLayout dl);
1020 /** Convert a data type identity into a string.
1021  *
1022  * @param[in] dt @ref DataType to be translated to string.
1023  *
1024  * @return The string describing the data type.
1025  */
1026 const std::string &string_from_data_type(DataType dt);
1027 /** Convert a matrix pattern into a string.
1028  *
1029  * @param[in] pattern @ref MatrixPattern to be translated to string.
1030  *
1031  * @return The string describing the matrix pattern.
1032  */
1033 const std::string &string_from_matrix_pattern(MatrixPattern pattern);
1034 /** Translates a given activation function to a string.
1035  *
1036  * @param[in] act @ref ActivationLayerInfo::ActivationFunction to be translated to string.
1037  *
1038  * @return The string describing the activation function.
1039  */
1040 const std::string &string_from_activation_func(ActivationLayerInfo::ActivationFunction act);
1041 /** Translates a given non linear function to a string.
1042  *
1043  * @param[in] function @ref NonLinearFilterFunction to be translated to string.
1044  *
1045  * @return The string describing the non linear function.
1046  */
1047 const std::string &string_from_non_linear_filter_function(NonLinearFilterFunction function);
1048 /** Translates a given interpolation policy to a string.
1049  *
1050  * @param[in] policy @ref InterpolationPolicy to be translated to string.
1051  *
1052  * @return The string describing the interpolation policy.
1053  */
1054 const std::string &string_from_interpolation_policy(InterpolationPolicy policy);
1055 /** Translates a given border mode policy to a string.
1056  *
1057  * @param[in] border_mode @ref BorderMode to be translated to string.
1058  *
1059  * @return The string describing the border mode.
1060  */
1061 const std::string &string_from_border_mode(BorderMode border_mode);
1062 /** Translates a given normalization type to a string.
1063  *
1064  * @param[in] type @ref NormType to be translated to string.
1065  *
1066  * @return The string describing the normalization type.
1067  */
1068 const std::string &string_from_norm_type(NormType type);
1069 /** Translates a given pooling type to a string.
1070  *
1071  * @param[in] type @ref PoolingType to be translated to string.
1072  *
1073  * @return The string describing the pooling type.
1074  */
1075 const std::string &string_from_pooling_type(PoolingType type);
1076 /** Translates a given GEMMLowp output stage to a string.
1077  *
1078  * @param[in] output_stage @ref GEMMLowpOutputStageInfo to be translated to string.
1079  *
1080  * @return The string describing the GEMMLowp output stage
1081  */
1082 const std::string &string_from_gemmlowp_output_stage(GEMMLowpOutputStageType output_stage);
1083 /** Convert a PixelValue to a string, represented through the specific data type
1084  *
1085  * @param[in] value     The PixelValue to convert
1086  * @param[in] data_type The type to be used to convert the @p value
1087  *
1088  * @return String representation of the PixelValue through the given data type.
1089  */
1090 std::string string_from_pixel_value(const PixelValue &value, const DataType data_type);
1091 /** Convert a string to DataType
1092  *
1093  * @param[in] name The name of the data type
1094  *
1095  * @return DataType
1096  */
1097 DataType data_type_from_name(const std::string &name);
1098 /** Stores padding information before configuring a kernel
1099  *
1100  * @param[in] infos list of tensor infos to store the padding info for
1101  *
1102  * @return An unordered map where each tensor info pointer is paired with its original padding info
1103  */
1104 std::unordered_map<const ITensorInfo *, PaddingSize> get_padding_info(std::initializer_list<const ITensorInfo *> infos);
1105 /** Stores padding information before configuring a kernel
1106  *
1107  * @param[in] tensors list of tensors to store the padding info for
1108  *
1109  * @return An unordered map where each tensor info pointer is paired with its original padding info
1110  */
1111 std::unordered_map<const ITensorInfo *, PaddingSize> get_padding_info(std::initializer_list<const ITensor *> tensors);
1112 /** Check if the previously stored padding info has changed after configuring a kernel
1113  *
1114  * @param[in] padding_map an unordered map where each tensor info pointer is paired with its original padding info
1115  *
1116  * @return true if any of the tensor infos has changed its paddings
1117  */
1118 bool has_padding_changed(const std::unordered_map<const ITensorInfo *, PaddingSize> &padding_map);
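// Illustrative usage (not part of the original header); a hedged sketch of a
// configure-time guard, with hypothetical tensor-info pointers src and dst:
//
//     auto padding_map = get_padding_info({ src, dst });
//     // ... configure the kernel window here ...
//     ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_map));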
1119 
1120 /** Input Stream operator for @ref DataType
1121  *
1122  * @param[in]  stream    Stream to parse
1123  * @param[out] data_type Output data type
1124  *
1125  * @return Updated stream
1126  */
1127 inline ::std::istream &operator>>(::std::istream &stream, DataType &data_type)
1128 {
1129     std::string value;
1130     stream >> value;
1131     data_type = data_type_from_name(value);
1132     return stream;
1133 }
1134 /** Lower a given string.
1135  *
1136  * @param[in] val Given string to lower.
1137  *
1138  * @return The lowered string
1139  */
1140 std::string lower_string(const std::string &val);
1141 
1142 /** Check if a given data type is of floating point type
1143  *
1144  * @param[in] dt Input data type.
1145  *
1146  * @return True if data type is of floating point type, else false.
1147  */
1148 inline bool is_data_type_float(DataType dt)
1149 {
1150     switch(dt)
1151     {
1152         case DataType::F16:
1153         case DataType::F32:
1154             return true;
1155         default:
1156             return false;
1157     }
1158 }
1159 
1160 /** Check if a given data type is of quantized type
1161  *
1162  * @note Quantized is considered a super-set of fixed-point and asymmetric data types.
1163  *
1164  * @param[in] dt Input data type.
1165  *
1166  * @return True if data type is of quantized type, else false.
1167  */
1168 inline bool is_data_type_quantized(DataType dt)
1169 {
1170     switch(dt)
1171     {
1172         case DataType::QSYMM8:
1173         case DataType::QASYMM8:
1174         case DataType::QASYMM8_SIGNED:
1175         case DataType::QSYMM8_PER_CHANNEL:
1176         case DataType::QSYMM16:
1177         case DataType::QASYMM16:
1178             return true;
1179         default:
1180             return false;
1181     }
1182 }
1183 
1184 /** Check if a given data type is of asymmetric quantized type
1185  *
1186  * @param[in] dt Input data type.
1187  *
1188  * @return True if data type is of asymmetric quantized type, else false.
1189  */
1190 inline bool is_data_type_quantized_asymmetric(DataType dt)
1191 {
1192     switch(dt)
1193     {
1194         case DataType::QASYMM8:
1195         case DataType::QASYMM8_SIGNED:
1196         case DataType::QASYMM16:
1197             return true;
1198         default:
1199             return false;
1200     }
1201 }
1202 
1203 /** Check if a given data type is of asymmetric quantized signed type
1204  *
1205  * @param[in] dt Input data type.
1206  *
1207  * @return True if data type is of asymmetric quantized signed type, else false.
1208  */
1209 inline bool is_data_type_quantized_asymmetric_signed(DataType dt)
1210 {
1211     switch(dt)
1212     {
1213         case DataType::QASYMM8_SIGNED:
1214             return true;
1215         default:
1216             return false;
1217     }
1218 }
1219 
1220 /** Check if a given data type is of symmetric quantized type
1221  *
1222  * @param[in] dt Input data type.
1223  *
1224  * @return True if data type is of symmetric quantized type, else false.
1225  */
1226 inline bool is_data_type_quantized_symmetric(DataType dt)
1227 {
1228     switch(dt)
1229     {
1230         case DataType::QSYMM8:
1231         case DataType::QSYMM8_PER_CHANNEL:
1232         case DataType::QSYMM16:
1233             return true;
1234         default:
1235             return false;
1236     }
1237 }
1238 
1239 /** Check if a given data type is of per channel type
1240  *
1241  * @param[in] dt Input data type.
1242  *
1243  * @return True if data type is of per channel type, else false.
1244  */
1245 inline bool is_data_type_quantized_per_channel(DataType dt)
1246 {
1247     switch(dt)
1248     {
1249         case DataType::QSYMM8_PER_CHANNEL:
1250             return true;
1251         default:
1252             return false;
1253     }
1254 }
1255 
1256 /** Create a string with the float in full precision.
1257  *
1258  * @param val Floating point value
1259  *
1260  * @return String with the floating point value.
1261  */
1262 inline std::string float_to_string_with_full_precision(float val)
1263 {
1264     std::stringstream ss;
1265     ss.precision(std::numeric_limits<float>::max_digits10);
1266     ss << val;
1267 
1268     if(val != static_cast<int>(val))
1269     {
1270         ss << "f";
1271     }
1272 
1273     return ss.str();
1274 }
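// Illustrative example (not part of the original header):
// float_to_string_with_full_precision(0.5f) returns "0.5f", while a value with no
// fractional part such as 2.0f is printed as "2" without the trailing 'f'.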
1275 
1276 /** Returns the number of elements required to go from start to end with the wanted step
1277  *
1278  * @param[in] start start value
1279  * @param[in] end   end value
1280  * @param[in] step  step value between each number in the wanted sequence
1281  *
1282  * @return number of elements to go from start value to end value using the wanted step
1283  */
1284 inline size_t num_of_elements_in_range(const float start, const float end, const float step)
1285 {
1286     ARM_COMPUTE_ERROR_ON_MSG(step == 0, "Range Step cannot be 0");
1287     return size_t(std::ceil((end - start) / step));
1288 }
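// Illustrative example (not part of the original header):
// num_of_elements_in_range(0.f, 10.f, 2.f) == 5, and a non-integer ratio such as
// num_of_elements_in_range(0.f, 1.f, 0.3f) is rounded up to 4.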
1289 
1290 /** Returns true if the value can be represented by the given data type
1291  *
1292  * @param[in] val   value to be checked
1293  * @param[in] dt    data type that is checked
1294  * @param[in] qinfo (Optional) quantization info if the data type is QASYMM8
1295  *
1296  * @return true if the data type can hold the value.
1297  */
1298 template <typename T>
1299 bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = QuantizationInfo())
1300 {
1301     switch(dt)
1302     {
1303         case DataType::U8:
1304         {
1305             const auto val_u8 = static_cast<uint8_t>(val);
1306             return ((val_u8 == val) && val_u8 >= std::numeric_limits<uint8_t>::lowest() && val_u8 <= std::numeric_limits<uint8_t>::max());
1307         }
1308         case DataType::QASYMM8:
1309         {
1310             double min = static_cast<double>(dequantize_qasymm8(0, qinfo));
1311             double max = static_cast<double>(dequantize_qasymm8(std::numeric_limits<uint8_t>::max(), qinfo));
1312             return ((double)val >= min && (double)val <= max);
1313         }
1314         case DataType::S8:
1315         {
1316             const auto val_s8 = static_cast<int8_t>(val);
1317             return ((val_s8 == val) && val_s8 >= std::numeric_limits<int8_t>::lowest() && val_s8 <= std::numeric_limits<int8_t>::max());
1318         }
1319         case DataType::U16:
1320         {
1321             const auto val_u16 = static_cast<uint16_t>(val);
1322             return ((val_u16 == val) && val_u16 >= std::numeric_limits<uint16_t>::lowest() && val_u16 <= std::numeric_limits<uint16_t>::max());
1323         }
1324         case DataType::S16:
1325         {
1326             const auto val_s16 = static_cast<int16_t>(val);
1327             return ((val_s16 == val) && val_s16 >= std::numeric_limits<int16_t>::lowest() && val_s16 <= std::numeric_limits<int16_t>::max());
1328         }
1329         case DataType::U32:
1330         {
1331             const auto val_u32 = static_cast<uint32_t>(val);
1332             return ((val_u32 == val) && val_u32 >= std::numeric_limits<uint32_t>::lowest() && val_u32 <= std::numeric_limits<uint32_t>::max());
1333         }
1334         case DataType::S32:
1335         {
1336             const auto val_s32 = static_cast<int32_t>(val);
1337             return ((val_s32 == val) && val_s32 >= std::numeric_limits<int32_t>::lowest() && val_s32 <= std::numeric_limits<int32_t>::max());
1338         }
1339         case DataType::BFLOAT16:
1340             return (val >= bfloat16::lowest() && val <= bfloat16::max());
1341         case DataType::F16:
1342             return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
1343         case DataType::F32:
1344             return (val >= std::numeric_limits<float>::lowest() && val <= std::numeric_limits<float>::max());
1345         default:
1346             ARM_COMPUTE_ERROR("Data type not supported");
1347             return false;
1348     }
1349 }
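// Illustrative example (not part of the original header):
// check_value_range(200, DataType::U8) returns true, whereas
// check_value_range(256, DataType::U8) returns false because 256 does not survive
// the round trip through uint8_t.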
1350 
1351 /** Returns the vector size, rounded down to the closest valid vector size if it exceeds the input's first dimension
1352  *
1353  * @param[in] vec_size vector size to be adjusted
1354  * @param[in] dim0     size of the first dimension
1355  *
1356  * @return The number of elements processed along the X axis per thread
1357  */
1358 inline unsigned int adjust_vec_size(unsigned int vec_size, size_t dim0)
1359 {
1360     ARM_COMPUTE_ERROR_ON(vec_size > 16);
1361 
1362     if((vec_size >= dim0) && (dim0 == 3))
1363     {
1364         return dim0;
1365     }
1366 
1367     while(vec_size > dim0)
1368     {
1369         vec_size >>= 1;
1370     }
1371 
1372     return vec_size;
1373 }
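// Illustrative example (not part of the original header): adjust_vec_size(16, 7)
// halves the vector size down to 4, adjust_vec_size(16, 3) returns 3 via the
// explicit special case, and adjust_vec_size(8, 100) keeps the full width of 8.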
1374 
1375 #ifdef ARM_COMPUTE_ASSERTS_ENABLED
1376 /** Print consecutive elements to an output stream.
1377  *
1378  * @param[out] s             Output stream to print the elements to.
1379  * @param[in]  ptr           Pointer to print the elements from.
1380  * @param[in]  n             Number of elements to print.
1381  * @param[in]  stream_width  (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1382  * @param[in]  element_delim (Optional) Delimiter among the consecutive elements. Defaults to space delimiter
1383  */
1384 template <typename T>
1385 void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int n, int stream_width = 0, const std::string &element_delim = " ")
1386 {
1387     using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1388     std::ios stream_status(nullptr);
1389     stream_status.copyfmt(s);
1390 
1391     for(unsigned int i = 0; i < n; ++i)
1392     {
1393         // Set stream width as it is not a "sticky" stream manipulator
1394         if(stream_width != 0)
1395         {
1396             s.width(stream_width);
1397         }
1398 
1399         if(std::is_same<typename std::decay<T>::type, half>::value)
1400         {
1401                 // We use T instead of print_type here because std::is_floating_point<half> returns false and print_type then becomes int.
1402             s << std::right << static_cast<T>(ptr[i]) << element_delim;
1403         }
1404         else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
1405         {
1406                 // We use T instead of print_type here because std::is_floating_point<bfloat16> returns false and print_type then becomes int.
1407             s << std::right << float(ptr[i]) << element_delim;
1408         }
1409         else
1410         {
1411             s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
1412         }
1413     }
1414 
1415     // Restore output stream flags
1416     s.copyfmt(stream_status);
1417 }
1418 
1419 /** Identify the maximum width of n consecutive elements.
1420  *
1421  * @param[in] s   The output stream which will be used to print the elements. Used to extract the stream format.
1422  * @param[in] ptr Pointer to the elements.
1423  * @param[in] n   Number of elements.
1424  *
1425  * @return The maximum width of the elements.
1426  */
1427 template <typename T>
1428 int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, unsigned int n)
1429 {
1430     using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1431 
1432     int max_width = -1;
1433     for(unsigned int i = 0; i < n; ++i)
1434     {
1435         std::stringstream ss;
1436         ss.copyfmt(s);
1437 
1438         if(std::is_same<typename std::decay<T>::type, half>::value)
1439         {
1440             // We use T instead of print_type here because std::is_floating_point<half> returns false and print_type then becomes int.
1441             ss << static_cast<T>(ptr[i]);
1442         }
1443         else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
1444         {
1445             // We use T instead of print_type here because std::is_floating_point<bfloat16> returns false and print_type then becomes int.
1446             ss << float(ptr[i]);
1447         }
1448         else
1449         {
1450             ss << static_cast<print_type>(ptr[i]);
1451         }
1452 
1453         max_width = std::max<int>(max_width, ss.str().size());
1454     }
1455     return max_width;
1456 }
1457 
1458 /** Print consecutive elements to an output stream.
1459  *
1460  * @param[out] s             Output stream to print the elements to.
1461  * @param[in]  dt            Data type of the elements
1462  * @param[in]  ptr           Pointer to print the elements from.
1463  * @param[in]  n             Number of elements to print.
1464  * @param[in]  stream_width  (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1465  * @param[in]  element_delim (Optional) Delimiter among the consecutive elements. Defaults to space delimiter
1466  */
1467 void print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim = " ");
1468 
1469 /** Identify the maximum width of n consecutive elements.
1470  *
1471  * @param[in] s   Output stream to print the elements to.
1472  * @param[in] dt  Data type of the elements
1473  * @param[in] ptr Pointer to print the elements from.
1474  * @param[in] n   Number of elements to print.
1475  *
1476  * @return The maximum width of the elements.
1477  */
1478 int max_consecutive_elements_display_width(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n);
1479 #endif /* ARM_COMPUTE_ASSERTS_ENABLED */
1480 }
1481 #endif /*ARM_COMPUTE_UTILS_H */
1482