• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <cmath>
#include <ostream>
#include <set>
#include <sstream>
13 
14 namespace armnn
15 {
16 
GetStatusAsCString(Status status)17 constexpr char const* GetStatusAsCString(Status status)
18 {
19     switch (status)
20     {
21         case armnn::Status::Success: return "Status::Success";
22         case armnn::Status::Failure: return "Status::Failure";
23         default:                     return "Unknown";
24     }
25 }
26 
GetActivationFunctionAsCString(ActivationFunction activation)27 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
28 {
29     switch (activation)
30     {
31         case ActivationFunction::Sigmoid:       return "Sigmoid";
32         case ActivationFunction::TanH:          return "TanH";
33         case ActivationFunction::Linear:        return "Linear";
34         case ActivationFunction::ReLu:          return "ReLu";
35         case ActivationFunction::BoundedReLu:   return "BoundedReLu";
36         case ActivationFunction::SoftReLu:      return "SoftReLu";
37         case ActivationFunction::LeakyReLu:     return "LeakyReLu";
38         case ActivationFunction::Abs:           return "Abs";
39         case ActivationFunction::Sqrt:          return "Sqrt";
40         case ActivationFunction::Square:        return "Square";
41         case ActivationFunction::Elu:           return "Elu";
42         case ActivationFunction::HardSwish:     return "HardSwish";
43         default:                                return "Unknown";
44     }
45 }
46 
GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)47 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
48 {
49     switch (function)
50     {
51         case ArgMinMaxFunction::Max:    return "Max";
52         case ArgMinMaxFunction::Min:    return "Min";
53         default:                        return "Unknown";
54     }
55 }
56 
GetComparisonOperationAsCString(ComparisonOperation operation)57 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
58 {
59     switch (operation)
60     {
61         case ComparisonOperation::Equal:          return "Equal";
62         case ComparisonOperation::Greater:        return "Greater";
63         case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
64         case ComparisonOperation::Less:           return "Less";
65         case ComparisonOperation::LessOrEqual:    return "LessOrEqual";
66         case ComparisonOperation::NotEqual:       return "NotEqual";
67         default:                                  return "Unknown";
68     }
69 }
70 
GetUnaryOperationAsCString(UnaryOperation operation)71 constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation)
72 {
73     switch (operation)
74     {
75         case UnaryOperation::Abs:        return "Abs";
76         case UnaryOperation::Exp:        return "Exp";
77         case UnaryOperation::Sqrt:       return "Sqrt";
78         case UnaryOperation::Rsqrt:      return "Rsqrt";
79         case UnaryOperation::Neg:        return "Neg";
80         case UnaryOperation::LogicalNot: return "LogicalNot";
81         default:                         return "Unknown";
82     }
83 }
84 
GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)85 constexpr char const* GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)
86 {
87     switch (operation)
88     {
89         case LogicalBinaryOperation::LogicalAnd: return "LogicalAnd";
90         case LogicalBinaryOperation::LogicalOr:  return "LogicalOr";
91         default:                                 return "Unknown";
92     }
93 }
94 
GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)95 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
96 {
97     switch (pooling)
98     {
99         case PoolingAlgorithm::Average:  return "Average";
100         case PoolingAlgorithm::Max:      return "Max";
101         case PoolingAlgorithm::L2:       return "L2";
102         default:                         return "Unknown";
103     }
104 }
105 
GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)106 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
107 {
108     switch (rounding)
109     {
110         case OutputShapeRounding::Ceiling:  return "Ceiling";
111         case OutputShapeRounding::Floor:    return "Floor";
112         default:                            return "Unknown";
113     }
114 }
115 
GetPaddingMethodAsCString(PaddingMethod method)116 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
117 {
118     switch (method)
119     {
120         case PaddingMethod::Exclude:       return "Exclude";
121         case PaddingMethod::IgnoreValue:   return "IgnoreValue";
122         default:                           return "Unknown";
123     }
124 }
125 
GetDataTypeSize(DataType dataType)126 constexpr unsigned int GetDataTypeSize(DataType dataType)
127 {
128     switch (dataType)
129     {
130         case DataType::BFloat16:
131         case DataType::Float16:               return 2U;
132         case DataType::Float32:
133         case DataType::Signed32:              return 4U;
134         case DataType::Signed64:              return 8U;
135         case DataType::QAsymmU8:              return 1U;
136         case DataType::QAsymmS8:              return 1U;
137         case DataType::QSymmS8:               return 1U;
138         ARMNN_NO_DEPRECATE_WARN_BEGIN
139         case DataType::QuantizedSymm8PerAxis: return 1U;
140         ARMNN_NO_DEPRECATE_WARN_END
141         case DataType::QSymmS16:              return 2U;
142         case DataType::Boolean:               return 1U;
143         default:                              return 0U;
144     }
145 }
146 
/// Compares a C string against a fixed-size character array (typically a string literal).
/// All N characters of strB take part in the comparison, so for a string literal the
/// terminating '\0' is compared too and strA must match it exactly (same length, same characters).
/// The comparison stops at the first mismatch.
/// @param strA - The C string to test; a null pointer compares unequal to everything.
/// @param strB - The reference array of N characters.
/// @return - true if strA is non-null and its first N characters equal those of strB.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    // Guard against dereferencing a null pointer; it cannot match any array.
    if (strA == nullptr)
    {
        return false;
    }

    for (unsigned i = 0; i < N; ++i)
    {
        if (strA[i] != strB[i])
        {
            return false;
        }
    }
    return true;
}
157 
158 /// Deprecated function that will be removed together with
159 /// the Compute enum
ParseComputeDevice(const char * str)160 constexpr armnn::Compute ParseComputeDevice(const char* str)
161 {
162     if (armnn::StrEqual(str, "CpuAcc"))
163     {
164         return armnn::Compute::CpuAcc;
165     }
166     else if (armnn::StrEqual(str, "CpuRef"))
167     {
168         return armnn::Compute::CpuRef;
169     }
170     else if (armnn::StrEqual(str, "GpuAcc"))
171     {
172         return armnn::Compute::GpuAcc;
173     }
174     else
175     {
176         return armnn::Compute::Undefined;
177     }
178 }
179 
GetDataTypeName(DataType dataType)180 constexpr const char* GetDataTypeName(DataType dataType)
181 {
182     switch (dataType)
183     {
184         case DataType::Float16:               return "Float16";
185         case DataType::Float32:               return "Float32";
186         case DataType::Signed64:              return "Signed64";
187         case DataType::QAsymmU8:              return "QAsymmU8";
188         case DataType::QAsymmS8:              return "QAsymmS8";
189         case DataType::QSymmS8:               return "QSymmS8";
190         ARMNN_NO_DEPRECATE_WARN_BEGIN
191         case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
192         ARMNN_NO_DEPRECATE_WARN_END
193         case DataType::QSymmS16:              return "QSymm16";
194         case DataType::Signed32:              return "Signed32";
195         case DataType::Boolean:               return "Boolean";
196         case DataType::BFloat16:              return "BFloat16";
197 
198         default:
199             return "Unknown";
200     }
201 }
202 
GetDataLayoutName(DataLayout dataLayout)203 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
204 {
205     switch (dataLayout)
206     {
207         case DataLayout::NCHW: return "NCHW";
208         case DataLayout::NHWC: return "NHWC";
209         default:               return "Unknown";
210     }
211 }
212 
GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)213 constexpr const char* GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
214 {
215     switch (channel)
216     {
217         case NormalizationAlgorithmChannel::Across: return "Across";
218         case NormalizationAlgorithmChannel::Within: return "Within";
219         default:                                    return "Unknown";
220     }
221 }
222 
GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)223 constexpr const char* GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
224 {
225     switch (method)
226     {
227         case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
228         case NormalizationAlgorithmMethod::LocalContrast:   return "LocalContrast";
229         default:                                            return "Unknown";
230     }
231 }
232 
GetResizeMethodAsCString(ResizeMethod method)233 constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
234 {
235     switch (method)
236     {
237         case ResizeMethod::Bilinear:        return "Bilinear";
238         case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
239         default:                            return "Unknown";
240     }
241 }
242 
/// Compile-time trait that is true when T is a floating point type occupying exactly two bytes.
/// Exposes the result through the usual std::integral_constant interface (::value, ::type).
template<typename T>
struct IsHalfType
    : std::integral_constant<bool, (sizeof(T) == 2) && std::is_floating_point<T>::value>
{
};
247 
/// Reports whether the C++ type T is treated as a quantized type.
/// The check is purely std::is_integral<T>, so every integral type (including bool) qualifies.
/// @return - true if T is an integral type, false otherwise.
template<typename T>
constexpr bool IsQuantizedType()
{
    constexpr bool isIntegral = std::is_integral<T>::value;
    return isIntegral;
}
253 
IsQuantized8BitType(DataType dataType)254 constexpr bool IsQuantized8BitType(DataType dataType)
255 {
256     ARMNN_NO_DEPRECATE_WARN_BEGIN
257     return dataType == DataType::QAsymmU8        ||
258            dataType == DataType::QAsymmS8        ||
259            dataType == DataType::QSymmS8         ||
260            dataType == DataType::QuantizedSymm8PerAxis;
261     ARMNN_NO_DEPRECATE_WARN_END
262 }
263 
IsQuantizedType(DataType dataType)264 constexpr bool IsQuantizedType(DataType dataType)
265 {
266     return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
267 }
268 
operator <<(std::ostream & os,Status stat)269 inline std::ostream& operator<<(std::ostream& os, Status stat)
270 {
271     os << GetStatusAsCString(stat);
272     return os;
273 }
274 
275 
operator <<(std::ostream & os,const armnn::TensorShape & shape)276 inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
277 {
278     os << "[";
279     for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
280     {
281         if (i!=0)
282         {
283             os << ",";
284         }
285         os << shape[i];
286     }
287     os << "]";
288     return os;
289 }
290 
/// Quantize a floating point value into a quantized data type (documented as an 8-bit data type).
/// NOTE(review): the template parameter itself does not restrict QuantizedType to 8 bits;
/// confirm the supported types against the definition, which is not part of this header.
/// @tparam QuantizedType - The type of the quantized result.
/// @param value - The value to quantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset.
/// @return - The quantized value calculated as round(value/scale)+offset.
///
template <typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);
299 
/// Dequantize a quantized value (documented as an 8-bit data type) into a floating point value.
/// NOTE(review): the template parameter itself does not restrict QuantizedType to 8 bits;
/// confirm the supported types against the definition, which is not part of this header.
/// @tparam QuantizedType - The type of the quantized input.
/// @param value - The value to dequantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset.
/// @return - The dequantized value calculated as (value-offset)*scale.
///
template<typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
308 
VerifyTensorInfoDataType(const armnn::TensorInfo & info,armnn::DataType dataType)309 inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
310 {
311     if (info.GetDataType() != dataType)
312     {
313         std::stringstream ss;
314         ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
315            << " for tensor:" << info.GetShape()
316            << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
317         throw armnn::Exception(ss.str());
318     }
319 }
320 
321 } //namespace armnn
322