• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2018-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <stdint.h>
#include <cmath>
#include <ostream>
#include <set>
#include <sstream>
#include <type_traits>
17 
18 namespace armnn
19 {
20 
GetStatusAsCString(Status status)21 constexpr char const* GetStatusAsCString(Status status)
22 {
23     switch (status)
24     {
25         case armnn::Status::Success: return "Status::Success";
26         case armnn::Status::Failure: return "Status::Failure";
27         default:                     return "Unknown";
28     }
29 }
30 
GetActivationFunctionAsCString(ActivationFunction activation)31 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
32 {
33     switch (activation)
34     {
35         case ActivationFunction::Sigmoid:       return "Sigmoid";
36         case ActivationFunction::TanH:          return "TanH";
37         case ActivationFunction::Linear:        return "Linear";
38         case ActivationFunction::ReLu:          return "ReLu";
39         case ActivationFunction::BoundedReLu:   return "BoundedReLu";
40         case ActivationFunction::SoftReLu:      return "SoftReLu";
41         case ActivationFunction::LeakyReLu:     return "LeakyReLu";
42         case ActivationFunction::Abs:           return "Abs";
43         case ActivationFunction::Sqrt:          return "Sqrt";
44         case ActivationFunction::Square:        return "Square";
45         case ActivationFunction::Elu:           return "Elu";
46         case ActivationFunction::HardSwish:     return "HardSwish";
47         default:                                return "Unknown";
48     }
49 }
50 
GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)51 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
52 {
53     switch (function)
54     {
55         case ArgMinMaxFunction::Max:    return "Max";
56         case ArgMinMaxFunction::Min:    return "Min";
57         default:                        return "Unknown";
58     }
59 }
60 
GetComparisonOperationAsCString(ComparisonOperation operation)61 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
62 {
63     switch (operation)
64     {
65         case ComparisonOperation::Equal:          return "Equal";
66         case ComparisonOperation::Greater:        return "Greater";
67         case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
68         case ComparisonOperation::Less:           return "Less";
69         case ComparisonOperation::LessOrEqual:    return "LessOrEqual";
70         case ComparisonOperation::NotEqual:       return "NotEqual";
71         default:                                  return "Unknown";
72     }
73 }
74 
GetBinaryOperationAsCString(BinaryOperation operation)75 constexpr char const* GetBinaryOperationAsCString(BinaryOperation operation)
76 {
77     switch (operation)
78     {
79         case BinaryOperation::Add:      return "Add";
80         case BinaryOperation::Div:      return "Div";
81         case BinaryOperation::Maximum:  return "Maximum";
82         case BinaryOperation::Minimum:  return "Minimum";
83         case BinaryOperation::Mul:      return "Mul";
84         case BinaryOperation::Sub:      return "Sub";
85         default:                        return "Unknown";
86     }
87 }
88 
GetUnaryOperationAsCString(UnaryOperation operation)89 constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation)
90 {
91     switch (operation)
92     {
93         case UnaryOperation::Abs:        return "Abs";
94         case UnaryOperation::Exp:        return "Exp";
95         case UnaryOperation::Sqrt:       return "Sqrt";
96         case UnaryOperation::Rsqrt:      return "Rsqrt";
97         case UnaryOperation::Neg:        return "Neg";
98         case UnaryOperation::Log:        return "Log";
99         case UnaryOperation::LogicalNot: return "LogicalNot";
100         case UnaryOperation::Sin:        return "Sin";
101         default:                         return "Unknown";
102     }
103 }
104 
GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)105 constexpr char const* GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)
106 {
107     switch (operation)
108     {
109         case LogicalBinaryOperation::LogicalAnd: return "LogicalAnd";
110         case LogicalBinaryOperation::LogicalOr:  return "LogicalOr";
111         default:                                 return "Unknown";
112     }
113 }
114 
GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)115 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
116 {
117     switch (pooling)
118     {
119         case PoolingAlgorithm::Average:  return "Average";
120         case PoolingAlgorithm::Max:      return "Max";
121         case PoolingAlgorithm::L2:       return "L2";
122         default:                         return "Unknown";
123     }
124 }
125 
GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)126 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
127 {
128     switch (rounding)
129     {
130         case OutputShapeRounding::Ceiling:  return "Ceiling";
131         case OutputShapeRounding::Floor:    return "Floor";
132         default:                            return "Unknown";
133     }
134 }
135 
GetPaddingMethodAsCString(PaddingMethod method)136 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
137 {
138     switch (method)
139     {
140         case PaddingMethod::Exclude:       return "Exclude";
141         case PaddingMethod::IgnoreValue:   return "IgnoreValue";
142         default:                           return "Unknown";
143     }
144 }
145 
GetPaddingModeAsCString(PaddingMode mode)146 constexpr char const* GetPaddingModeAsCString(PaddingMode mode)
147 {
148     switch (mode)
149     {
150         case PaddingMode::Constant:   return "Exclude";
151         case PaddingMode::Symmetric:  return "Symmetric";
152         case PaddingMode::Reflect:    return "Reflect";
153         default:                      return "Unknown";
154     }
155 }
156 
GetReduceOperationAsCString(ReduceOperation reduce_operation)157 constexpr char const* GetReduceOperationAsCString(ReduceOperation reduce_operation)
158 {
159     switch (reduce_operation)
160     {
161         case ReduceOperation::Sum:  return "Sum";
162         case ReduceOperation::Max:  return "Max";
163         case ReduceOperation::Mean: return "Mean";
164         case ReduceOperation::Min:  return "Min";
165         case ReduceOperation::Prod: return "Prod";
166         default:                    return "Unknown";
167     }
168 }
GetDataTypeSize(DataType dataType)169 constexpr unsigned int GetDataTypeSize(DataType dataType)
170 {
171     switch (dataType)
172     {
173         case DataType::BFloat16:
174         case DataType::Float16:               return 2U;
175         case DataType::Float32:
176         case DataType::Signed32:              return 4U;
177         case DataType::Signed64:              return 8U;
178         case DataType::QAsymmU8:              return 1U;
179         case DataType::QAsymmS8:              return 1U;
180         case DataType::QSymmS8:               return 1U;
181         case DataType::QSymmS16:              return 2U;
182         case DataType::Boolean:               return 1U;
183         default:                              return 0U;
184     }
185 }
186 
/// constexpr-capable comparison of a C string against a character array (normally a string literal).
///
/// Compares up to N characters, including strB's final element (the terminating '\0' of a
/// literal), and stops at the first mismatch. A string that is a strict prefix or extension of a
/// literal strB therefore compares unequal.
/// @param strA - String to test; a null pointer compares unequal instead of being dereferenced.
/// @param strB - Reference array of N characters.
/// @return true if the first N characters of strA match strB exactly.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    // A null input cannot match anything; reading through it would be undefined behaviour.
    if (strA == nullptr)
    {
        return false;
    }
    bool isEqual = true;
    for (unsigned i = 0; isEqual && (i < N); ++i)
    {
        isEqual = (strA[i] == strB[i]);
    }
    return isEqual;
}
197 
198 /// Deprecated function that will be removed together with
199 /// the Compute enum
ParseComputeDevice(const char * str)200 constexpr armnn::Compute ParseComputeDevice(const char* str)
201 {
202     if (armnn::StrEqual(str, "CpuAcc"))
203     {
204         return armnn::Compute::CpuAcc;
205     }
206     else if (armnn::StrEqual(str, "CpuRef"))
207     {
208         return armnn::Compute::CpuRef;
209     }
210     else if (armnn::StrEqual(str, "GpuAcc"))
211     {
212         return armnn::Compute::GpuAcc;
213     }
214     else
215     {
216         return armnn::Compute::Undefined;
217     }
218 }
219 
GetDataTypeName(DataType dataType)220 constexpr const char* GetDataTypeName(DataType dataType)
221 {
222     switch (dataType)
223     {
224         case DataType::Float16:               return "Float16";
225         case DataType::Float32:               return "Float32";
226         case DataType::Signed64:              return "Signed64";
227         case DataType::QAsymmU8:              return "QAsymmU8";
228         case DataType::QAsymmS8:              return "QAsymmS8";
229         case DataType::QSymmS8:               return "QSymmS8";
230         case DataType::QSymmS16:              return "QSymm16";
231         case DataType::Signed32:              return "Signed32";
232         case DataType::Boolean:               return "Boolean";
233         case DataType::BFloat16:              return "BFloat16";
234 
235         default:
236             return "Unknown";
237     }
238 }
239 
GetDataLayoutName(DataLayout dataLayout)240 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
241 {
242     switch (dataLayout)
243     {
244         case DataLayout::NCHW:  return "NCHW";
245         case DataLayout::NHWC:  return "NHWC";
246         case DataLayout::NDHWC: return "NDHWC";
247         case DataLayout::NCDHW: return "NCDHW";
248         default:                return "Unknown";
249     }
250 }
251 
GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)252 constexpr const char* GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
253 {
254     switch (channel)
255     {
256         case NormalizationAlgorithmChannel::Across: return "Across";
257         case NormalizationAlgorithmChannel::Within: return "Within";
258         default:                                    return "Unknown";
259     }
260 }
261 
GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)262 constexpr const char* GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
263 {
264     switch (method)
265     {
266         case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
267         case NormalizationAlgorithmMethod::LocalContrast:   return "LocalContrast";
268         default:                                            return "Unknown";
269     }
270 }
271 
GetResizeMethodAsCString(ResizeMethod method)272 constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
273 {
274     switch (method)
275     {
276         case ResizeMethod::Bilinear:        return "Bilinear";
277         case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
278         default:                            return "Unknown";
279     }
280 }
281 
GetMemBlockStrategyTypeName(MemBlockStrategyType memBlockStrategyType)282 constexpr const char* GetMemBlockStrategyTypeName(MemBlockStrategyType memBlockStrategyType)
283 {
284     switch (memBlockStrategyType)
285     {
286         case MemBlockStrategyType::SingleAxisPacking: return "SingleAxisPacking";
287         case MemBlockStrategyType::MultiAxisPacking:  return "MultiAxisPacking";
288         default:                                      return "Unknown";
289     }
290 }
291 
/// Type trait: true when T is two bytes wide and std::is_floating_point<T> holds for it.
template<typename T>
struct IsHalfType
    : std::integral_constant<bool, (sizeof(T) == 2) && std::is_floating_point<T>::value>
{
};
296 
/// Compile-time check that treats every integral type (per std::is_integral, which also
/// covers bool and character types) as a quantized type; floating-point types are not.
template<typename T>
constexpr bool IsQuantizedType()
{
    using IntegralCheck = std::is_integral<T>;
    return IntegralCheck::value;
}
302 
IsQuantized8BitType(DataType dataType)303 constexpr bool IsQuantized8BitType(DataType dataType)
304 {
305     return dataType == DataType::QAsymmU8        ||
306            dataType == DataType::QAsymmS8        ||
307            dataType == DataType::QSymmS8;
308 }
309 
IsQuantizedType(DataType dataType)310 constexpr bool IsQuantizedType(DataType dataType)
311 {
312     return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
313 }
314 
operator <<(std::ostream & os,Status stat)315 inline std::ostream& operator<<(std::ostream& os, Status stat)
316 {
317     os << GetStatusAsCString(stat);
318     return os;
319 }
320 
321 
operator <<(std::ostream & os,const armnn::TensorShape & shape)322 inline std::ostream& operator<<(std::ostream& os, const armnn::TensorShape& shape)
323 {
324     os << "[";
325     if (shape.GetDimensionality() != Dimensionality::NotSpecified)
326     {
327         for (uint32_t i = 0; i < shape.GetNumDimensions(); ++i)
328         {
329             if (i != 0)
330             {
331                 os << ",";
332             }
333             if (shape.GetDimensionSpecificity(i))
334             {
335                 os << shape[i];
336             }
337             else
338             {
339                 os << "?";
340             }
341         }
342     }
343     else
344     {
345         os << "Dimensionality Not Specified";
346     }
347     os << "]";
348     return os;
349 }
350 
/// Quantize a floating point value into the integer type QuantizedType.
/// @param value - The value to quantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset added after scaling.
/// @return - The quantized value, nominally round(value/scale)+offset.
///           NOTE(review): the definition is not in this header; confirm how results outside
///           the representable range of QuantizedType are handled (e.g. clamped).
///
template<typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);
359 
/// Dequantize a value of the integer type QuantizedType into a floating point value.
/// @param value - The value to dequantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset that was added during quantization.
/// @return - The dequantized value calculated as (value-offset)*scale.
///
template <typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
368 
VerifyTensorInfoDataType(const armnn::TensorInfo & info,armnn::DataType dataType)369 inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
370 {
371     if (info.GetDataType() != dataType)
372     {
373         std::stringstream ss;
374         ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
375            << " for tensor:" << info.GetShape()
376            << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
377         throw armnn::Exception(ss.str());
378     }
379 }
380 
381 } //namespace armnn
382