1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/Tensor.hpp>
8 #include <armnn/Types.hpp>
9
10 #include <cmath>
11 #include <ostream>
12 #include <set>
13
14 namespace armnn
15 {
16
GetStatusAsCString(Status status)17 constexpr char const* GetStatusAsCString(Status status)
18 {
19 switch (status)
20 {
21 case armnn::Status::Success: return "Status::Success";
22 case armnn::Status::Failure: return "Status::Failure";
23 default: return "Unknown";
24 }
25 }
26
GetActivationFunctionAsCString(ActivationFunction activation)27 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
28 {
29 switch (activation)
30 {
31 case ActivationFunction::Sigmoid: return "Sigmoid";
32 case ActivationFunction::TanH: return "TanH";
33 case ActivationFunction::Linear: return "Linear";
34 case ActivationFunction::ReLu: return "ReLu";
35 case ActivationFunction::BoundedReLu: return "BoundedReLu";
36 case ActivationFunction::SoftReLu: return "SoftReLu";
37 case ActivationFunction::LeakyReLu: return "LeakyReLu";
38 case ActivationFunction::Abs: return "Abs";
39 case ActivationFunction::Sqrt: return "Sqrt";
40 case ActivationFunction::Square: return "Square";
41 case ActivationFunction::Elu: return "Elu";
42 case ActivationFunction::HardSwish: return "HardSwish";
43 default: return "Unknown";
44 }
45 }
46
GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)47 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
48 {
49 switch (function)
50 {
51 case ArgMinMaxFunction::Max: return "Max";
52 case ArgMinMaxFunction::Min: return "Min";
53 default: return "Unknown";
54 }
55 }
56
GetComparisonOperationAsCString(ComparisonOperation operation)57 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
58 {
59 switch (operation)
60 {
61 case ComparisonOperation::Equal: return "Equal";
62 case ComparisonOperation::Greater: return "Greater";
63 case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
64 case ComparisonOperation::Less: return "Less";
65 case ComparisonOperation::LessOrEqual: return "LessOrEqual";
66 case ComparisonOperation::NotEqual: return "NotEqual";
67 default: return "Unknown";
68 }
69 }
70
GetUnaryOperationAsCString(UnaryOperation operation)71 constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation)
72 {
73 switch (operation)
74 {
75 case UnaryOperation::Abs: return "Abs";
76 case UnaryOperation::Exp: return "Exp";
77 case UnaryOperation::Sqrt: return "Sqrt";
78 case UnaryOperation::Rsqrt: return "Rsqrt";
79 case UnaryOperation::Neg: return "Neg";
80 case UnaryOperation::LogicalNot: return "LogicalNot";
81 default: return "Unknown";
82 }
83 }
84
GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)85 constexpr char const* GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)
86 {
87 switch (operation)
88 {
89 case LogicalBinaryOperation::LogicalAnd: return "LogicalAnd";
90 case LogicalBinaryOperation::LogicalOr: return "LogicalOr";
91 default: return "Unknown";
92 }
93 }
94
GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)95 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
96 {
97 switch (pooling)
98 {
99 case PoolingAlgorithm::Average: return "Average";
100 case PoolingAlgorithm::Max: return "Max";
101 case PoolingAlgorithm::L2: return "L2";
102 default: return "Unknown";
103 }
104 }
105
GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)106 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
107 {
108 switch (rounding)
109 {
110 case OutputShapeRounding::Ceiling: return "Ceiling";
111 case OutputShapeRounding::Floor: return "Floor";
112 default: return "Unknown";
113 }
114 }
115
GetPaddingMethodAsCString(PaddingMethod method)116 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
117 {
118 switch (method)
119 {
120 case PaddingMethod::Exclude: return "Exclude";
121 case PaddingMethod::IgnoreValue: return "IgnoreValue";
122 default: return "Unknown";
123 }
124 }
125
GetDataTypeSize(DataType dataType)126 constexpr unsigned int GetDataTypeSize(DataType dataType)
127 {
128 switch (dataType)
129 {
130 case DataType::BFloat16:
131 case DataType::Float16: return 2U;
132 case DataType::Float32:
133 case DataType::Signed32: return 4U;
134 case DataType::Signed64: return 8U;
135 case DataType::QAsymmU8: return 1U;
136 case DataType::QAsymmS8: return 1U;
137 case DataType::QSymmS8: return 1U;
138 ARMNN_NO_DEPRECATE_WARN_BEGIN
139 case DataType::QuantizedSymm8PerAxis: return 1U;
140 ARMNN_NO_DEPRECATE_WARN_END
141 case DataType::QSymmS16: return 2U;
142 case DataType::Boolean: return 1U;
143 default: return 0U;
144 }
145 }
146
/// Compares a C string against a character array, terminator included.
/// @tparam N - Number of elements in strB (for a string literal this counts the trailing '\0').
/// @param strA - Null-terminated string to test. A null pointer never compares equal.
/// @param strB - Character array, typically a string literal, to compare against.
/// @return - true only if the first N characters of strA are identical to strB.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    // Guard against dereferencing a null input; treat it as "no match".
    if (strA == nullptr)
    {
        return false;
    }

    // Stop at the first mismatch: for a literal strB this ends the scan at strA's
    // terminator at the latest, so a shorter strA is not read beyond its end.
    for (unsigned i = 0; i < N; ++i)
    {
        if (strA[i] != strB[i])
        {
            return false;
        }
    }
    return true;
}
157
158 /// Deprecated function that will be removed together with
159 /// the Compute enum
ParseComputeDevice(const char * str)160 constexpr armnn::Compute ParseComputeDevice(const char* str)
161 {
162 if (armnn::StrEqual(str, "CpuAcc"))
163 {
164 return armnn::Compute::CpuAcc;
165 }
166 else if (armnn::StrEqual(str, "CpuRef"))
167 {
168 return armnn::Compute::CpuRef;
169 }
170 else if (armnn::StrEqual(str, "GpuAcc"))
171 {
172 return armnn::Compute::GpuAcc;
173 }
174 else
175 {
176 return armnn::Compute::Undefined;
177 }
178 }
179
GetDataTypeName(DataType dataType)180 constexpr const char* GetDataTypeName(DataType dataType)
181 {
182 switch (dataType)
183 {
184 case DataType::Float16: return "Float16";
185 case DataType::Float32: return "Float32";
186 case DataType::Signed64: return "Signed64";
187 case DataType::QAsymmU8: return "QAsymmU8";
188 case DataType::QAsymmS8: return "QAsymmS8";
189 case DataType::QSymmS8: return "QSymmS8";
190 ARMNN_NO_DEPRECATE_WARN_BEGIN
191 case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
192 ARMNN_NO_DEPRECATE_WARN_END
193 case DataType::QSymmS16: return "QSymm16";
194 case DataType::Signed32: return "Signed32";
195 case DataType::Boolean: return "Boolean";
196 case DataType::BFloat16: return "BFloat16";
197
198 default:
199 return "Unknown";
200 }
201 }
202
GetDataLayoutName(DataLayout dataLayout)203 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
204 {
205 switch (dataLayout)
206 {
207 case DataLayout::NCHW: return "NCHW";
208 case DataLayout::NHWC: return "NHWC";
209 default: return "Unknown";
210 }
211 }
212
GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)213 constexpr const char* GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
214 {
215 switch (channel)
216 {
217 case NormalizationAlgorithmChannel::Across: return "Across";
218 case NormalizationAlgorithmChannel::Within: return "Within";
219 default: return "Unknown";
220 }
221 }
222
GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)223 constexpr const char* GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
224 {
225 switch (method)
226 {
227 case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
228 case NormalizationAlgorithmMethod::LocalContrast: return "LocalContrast";
229 default: return "Unknown";
230 }
231 }
232
GetResizeMethodAsCString(ResizeMethod method)233 constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
234 {
235 switch (method)
236 {
237 case ResizeMethod::Bilinear: return "Bilinear";
238 case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
239 default: return "Unknown";
240 }
241 }
242
/// Type trait whose value is true when T satisfies std::is_floating_point
/// and sizeof(T) equals 2, and false otherwise.
/// @tparam T - The type to inspect.
template<typename T>
struct IsHalfType
    : std::integral_constant<bool, (sizeof(T) == 2) && std::is_floating_point<T>::value>
{
};
247
/// Compile-time check used to classify a C++ type as quantized.
/// @tparam T - The type to inspect.
/// @return - The value of std::is_integral<T>.
template<typename T>
constexpr bool IsQuantizedType()
{
    using IsIntegral = std::is_integral<T>;
    return IsIntegral::value;
}
253
IsQuantized8BitType(DataType dataType)254 constexpr bool IsQuantized8BitType(DataType dataType)
255 {
256 ARMNN_NO_DEPRECATE_WARN_BEGIN
257 return dataType == DataType::QAsymmU8 ||
258 dataType == DataType::QAsymmS8 ||
259 dataType == DataType::QSymmS8 ||
260 dataType == DataType::QuantizedSymm8PerAxis;
261 ARMNN_NO_DEPRECATE_WARN_END
262 }
263
IsQuantizedType(DataType dataType)264 constexpr bool IsQuantizedType(DataType dataType)
265 {
266 return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
267 }
268
operator <<(std::ostream & os,Status stat)269 inline std::ostream& operator<<(std::ostream& os, Status stat)
270 {
271 os << GetStatusAsCString(stat);
272 return os;
273 }
274
275
operator <<(std::ostream & os,const armnn::TensorShape & shape)276 inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
277 {
278 os << "[";
279 for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
280 {
281 if (i!=0)
282 {
283 os << ",";
284 }
285 os << shape[i];
286 }
287 os << "]";
288 return os;
289 }
290
/// Quantize a floating point value into the quantized data type QuantizedType.
/// Only the declaration lives here; the definition is provided elsewhere.
/// @tparam QuantizedType - The type of the quantized result.
/// @param value - The floating point value to quantize.
/// @param scale - The quantization scale (must be non-zero).
/// @param offset - The quantization offset.
/// @return - The quantized value, calculated as round(value/scale)+offset.
///
template <typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);
299
/// Dequantize a value of the quantized data type QuantizedType into a floating point value.
/// Only the declaration lives here; the definition is provided elsewhere.
/// @tparam QuantizedType - The type of the quantized input.
/// @param value - The quantized value to dequantize.
/// @param scale - The quantization scale (must be non-zero).
/// @param offset - The quantization offset.
/// @return - The dequantized value, calculated as (value-offset)*scale.
///
template<typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
308
VerifyTensorInfoDataType(const armnn::TensorInfo & info,armnn::DataType dataType)309 inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
310 {
311 if (info.GetDataType() != dataType)
312 {
313 std::stringstream ss;
314 ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
315 << " for tensor:" << info.GetShape()
316 << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
317 throw armnn::Exception(ss.str());
318 }
319 }
320
321 } //namespace armnn
322