1 //
2 // Copyright © 2018-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <stdint.h>
#include <cmath>
#include <ostream>
#include <set>
#include <sstream>
#include <type_traits>
17
18 namespace armnn
19 {
20
GetStatusAsCString(Status status)21 constexpr char const* GetStatusAsCString(Status status)
22 {
23 switch (status)
24 {
25 case armnn::Status::Success: return "Status::Success";
26 case armnn::Status::Failure: return "Status::Failure";
27 default: return "Unknown";
28 }
29 }
30
GetActivationFunctionAsCString(ActivationFunction activation)31 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
32 {
33 switch (activation)
34 {
35 case ActivationFunction::Sigmoid: return "Sigmoid";
36 case ActivationFunction::TanH: return "TanH";
37 case ActivationFunction::Linear: return "Linear";
38 case ActivationFunction::ReLu: return "ReLu";
39 case ActivationFunction::BoundedReLu: return "BoundedReLu";
40 case ActivationFunction::SoftReLu: return "SoftReLu";
41 case ActivationFunction::LeakyReLu: return "LeakyReLu";
42 case ActivationFunction::Abs: return "Abs";
43 case ActivationFunction::Sqrt: return "Sqrt";
44 case ActivationFunction::Square: return "Square";
45 case ActivationFunction::Elu: return "Elu";
46 case ActivationFunction::HardSwish: return "HardSwish";
47 default: return "Unknown";
48 }
49 }
50
GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)51 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
52 {
53 switch (function)
54 {
55 case ArgMinMaxFunction::Max: return "Max";
56 case ArgMinMaxFunction::Min: return "Min";
57 default: return "Unknown";
58 }
59 }
60
GetComparisonOperationAsCString(ComparisonOperation operation)61 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
62 {
63 switch (operation)
64 {
65 case ComparisonOperation::Equal: return "Equal";
66 case ComparisonOperation::Greater: return "Greater";
67 case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
68 case ComparisonOperation::Less: return "Less";
69 case ComparisonOperation::LessOrEqual: return "LessOrEqual";
70 case ComparisonOperation::NotEqual: return "NotEqual";
71 default: return "Unknown";
72 }
73 }
74
GetBinaryOperationAsCString(BinaryOperation operation)75 constexpr char const* GetBinaryOperationAsCString(BinaryOperation operation)
76 {
77 switch (operation)
78 {
79 case BinaryOperation::Add: return "Add";
80 case BinaryOperation::Div: return "Div";
81 case BinaryOperation::Maximum: return "Maximum";
82 case BinaryOperation::Minimum: return "Minimum";
83 case BinaryOperation::Mul: return "Mul";
84 case BinaryOperation::Sub: return "Sub";
85 default: return "Unknown";
86 }
87 }
88
GetUnaryOperationAsCString(UnaryOperation operation)89 constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation)
90 {
91 switch (operation)
92 {
93 case UnaryOperation::Abs: return "Abs";
94 case UnaryOperation::Exp: return "Exp";
95 case UnaryOperation::Sqrt: return "Sqrt";
96 case UnaryOperation::Rsqrt: return "Rsqrt";
97 case UnaryOperation::Neg: return "Neg";
98 case UnaryOperation::Log: return "Log";
99 case UnaryOperation::LogicalNot: return "LogicalNot";
100 case UnaryOperation::Sin: return "Sin";
101 default: return "Unknown";
102 }
103 }
104
GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)105 constexpr char const* GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)
106 {
107 switch (operation)
108 {
109 case LogicalBinaryOperation::LogicalAnd: return "LogicalAnd";
110 case LogicalBinaryOperation::LogicalOr: return "LogicalOr";
111 default: return "Unknown";
112 }
113 }
114
GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)115 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
116 {
117 switch (pooling)
118 {
119 case PoolingAlgorithm::Average: return "Average";
120 case PoolingAlgorithm::Max: return "Max";
121 case PoolingAlgorithm::L2: return "L2";
122 default: return "Unknown";
123 }
124 }
125
GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)126 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
127 {
128 switch (rounding)
129 {
130 case OutputShapeRounding::Ceiling: return "Ceiling";
131 case OutputShapeRounding::Floor: return "Floor";
132 default: return "Unknown";
133 }
134 }
135
GetPaddingMethodAsCString(PaddingMethod method)136 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
137 {
138 switch (method)
139 {
140 case PaddingMethod::Exclude: return "Exclude";
141 case PaddingMethod::IgnoreValue: return "IgnoreValue";
142 default: return "Unknown";
143 }
144 }
145
GetPaddingModeAsCString(PaddingMode mode)146 constexpr char const* GetPaddingModeAsCString(PaddingMode mode)
147 {
148 switch (mode)
149 {
150 case PaddingMode::Constant: return "Exclude";
151 case PaddingMode::Symmetric: return "Symmetric";
152 case PaddingMode::Reflect: return "Reflect";
153 default: return "Unknown";
154 }
155 }
156
GetReduceOperationAsCString(ReduceOperation reduce_operation)157 constexpr char const* GetReduceOperationAsCString(ReduceOperation reduce_operation)
158 {
159 switch (reduce_operation)
160 {
161 case ReduceOperation::Sum: return "Sum";
162 case ReduceOperation::Max: return "Max";
163 case ReduceOperation::Mean: return "Mean";
164 case ReduceOperation::Min: return "Min";
165 case ReduceOperation::Prod: return "Prod";
166 default: return "Unknown";
167 }
168 }
GetDataTypeSize(DataType dataType)169 constexpr unsigned int GetDataTypeSize(DataType dataType)
170 {
171 switch (dataType)
172 {
173 case DataType::BFloat16:
174 case DataType::Float16: return 2U;
175 case DataType::Float32:
176 case DataType::Signed32: return 4U;
177 case DataType::Signed64: return 8U;
178 case DataType::QAsymmU8: return 1U;
179 case DataType::QAsymmS8: return 1U;
180 case DataType::QSymmS8: return 1U;
181 case DataType::QSymmS16: return 2U;
182 case DataType::Boolean: return 1U;
183 default: return 0U;
184 }
185 }
186
/// Compares a C string against a character array (typically a string literal) of N characters.
/// All N characters of strB are compared, so when strB is a string literal its terminating
/// null is part of the comparison and strA must match it exactly (no prefix matches).
/// The scan stops at the first mismatch, so strA is never read beyond its own terminator
/// when strB contains no embedded nulls.
/// @param strA - The null-terminated string to test. A null pointer compares unequal.
/// @param strB - The reference character array.
/// @return - True if the first N characters of strA equal strB, false otherwise.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    // Guard against dereferencing a null pointer.
    if (strA == nullptr)
    {
        return false;
    }
    for (unsigned i = 0; i < N; ++i)
    {
        if (strA[i] != strB[i])
        {
            return false;
        }
    }
    return true;
}
197
198 /// Deprecated function that will be removed together with
199 /// the Compute enum
ParseComputeDevice(const char * str)200 constexpr armnn::Compute ParseComputeDevice(const char* str)
201 {
202 if (armnn::StrEqual(str, "CpuAcc"))
203 {
204 return armnn::Compute::CpuAcc;
205 }
206 else if (armnn::StrEqual(str, "CpuRef"))
207 {
208 return armnn::Compute::CpuRef;
209 }
210 else if (armnn::StrEqual(str, "GpuAcc"))
211 {
212 return armnn::Compute::GpuAcc;
213 }
214 else
215 {
216 return armnn::Compute::Undefined;
217 }
218 }
219
GetDataTypeName(DataType dataType)220 constexpr const char* GetDataTypeName(DataType dataType)
221 {
222 switch (dataType)
223 {
224 case DataType::Float16: return "Float16";
225 case DataType::Float32: return "Float32";
226 case DataType::Signed64: return "Signed64";
227 case DataType::QAsymmU8: return "QAsymmU8";
228 case DataType::QAsymmS8: return "QAsymmS8";
229 case DataType::QSymmS8: return "QSymmS8";
230 case DataType::QSymmS16: return "QSymm16";
231 case DataType::Signed32: return "Signed32";
232 case DataType::Boolean: return "Boolean";
233 case DataType::BFloat16: return "BFloat16";
234
235 default:
236 return "Unknown";
237 }
238 }
239
GetDataLayoutName(DataLayout dataLayout)240 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
241 {
242 switch (dataLayout)
243 {
244 case DataLayout::NCHW: return "NCHW";
245 case DataLayout::NHWC: return "NHWC";
246 case DataLayout::NDHWC: return "NDHWC";
247 case DataLayout::NCDHW: return "NCDHW";
248 default: return "Unknown";
249 }
250 }
251
GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)252 constexpr const char* GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
253 {
254 switch (channel)
255 {
256 case NormalizationAlgorithmChannel::Across: return "Across";
257 case NormalizationAlgorithmChannel::Within: return "Within";
258 default: return "Unknown";
259 }
260 }
261
GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)262 constexpr const char* GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
263 {
264 switch (method)
265 {
266 case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
267 case NormalizationAlgorithmMethod::LocalContrast: return "LocalContrast";
268 default: return "Unknown";
269 }
270 }
271
GetResizeMethodAsCString(ResizeMethod method)272 constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
273 {
274 switch (method)
275 {
276 case ResizeMethod::Bilinear: return "Bilinear";
277 case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
278 default: return "Unknown";
279 }
280 }
281
GetMemBlockStrategyTypeName(MemBlockStrategyType memBlockStrategyType)282 constexpr const char* GetMemBlockStrategyTypeName(MemBlockStrategyType memBlockStrategyType)
283 {
284 switch (memBlockStrategyType)
285 {
286 case MemBlockStrategyType::SingleAxisPacking: return "SingleAxisPacking";
287 case MemBlockStrategyType::MultiAxisPacking: return "MultiAxisPacking";
288 default: return "Unknown";
289 }
290 }
291
/// Type trait that is true when T is a floating point type occupying exactly two bytes.
/// Derives from std::integral_constant<bool, ...>, so ::value and the usual conversions are available.
template<typename T>
struct IsHalfType
    : std::conditional<std::is_floating_point<T>::value && (sizeof(T) == 2),
                       std::true_type,
                       std::false_type>::type
{};
296
/// Compile-time check on a C++ type: reports true exactly when T is an integral type.
/// @return - std::is_integral<T>::value.
template<typename T>
constexpr bool IsQuantizedType()
{
    using IntegralCheck = std::is_integral<T>;
    return IntegralCheck::value;
}
302
IsQuantized8BitType(DataType dataType)303 constexpr bool IsQuantized8BitType(DataType dataType)
304 {
305 return dataType == DataType::QAsymmU8 ||
306 dataType == DataType::QAsymmS8 ||
307 dataType == DataType::QSymmS8;
308 }
309
IsQuantizedType(DataType dataType)310 constexpr bool IsQuantizedType(DataType dataType)
311 {
312 return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
313 }
314
operator <<(std::ostream & os,Status stat)315 inline std::ostream& operator<<(std::ostream& os, Status stat)
316 {
317 os << GetStatusAsCString(stat);
318 return os;
319 }
320
321
operator <<(std::ostream & os,const armnn::TensorShape & shape)322 inline std::ostream& operator<<(std::ostream& os, const armnn::TensorShape& shape)
323 {
324 os << "[";
325 if (shape.GetDimensionality() != Dimensionality::NotSpecified)
326 {
327 for (uint32_t i = 0; i < shape.GetNumDimensions(); ++i)
328 {
329 if (i != 0)
330 {
331 os << ",";
332 }
333 if (shape.GetDimensionSpecificity(i))
334 {
335 os << shape[i];
336 }
337 else
338 {
339 os << "?";
340 }
341 }
342 }
343 else
344 {
345 os << "Dimensionality Not Specified";
346 }
347 os << "]";
348 return os;
349 }
350
/// Quantize a floating point value into the quantized type QuantizedType.
/// Only the declaration is provided here; the definition lives elsewhere.
/// @tparam QuantizedType - The type of the quantized result.
/// @param value - The value to quantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset.
/// @return - The quantized value calculated as round(value/scale)+offset.
/// NOTE(review): any clamping to the range of QuantizedType is not visible from this
/// declaration - confirm against the definition.
///
template<typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);
359
/// Dequantize a value of the quantized type QuantizedType into a floating point value.
/// Only the declaration is provided here; the definition lives elsewhere.
/// @tparam QuantizedType - The type of the quantized input.
/// @param value - The value to dequantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset.
/// @return - The dequantized value calculated as (value-offset)*scale.
///
template <typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
368
VerifyTensorInfoDataType(const armnn::TensorInfo & info,armnn::DataType dataType)369 inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
370 {
371 if (info.GetDataType() != dataType)
372 {
373 std::stringstream ss;
374 ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
375 << " for tensor:" << info.GetShape()
376 << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
377 throw armnn::Exception(ss.str());
378 }
379 }
380
381 } //namespace armnn
382