1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
19
#include <android-base/logging.h>
#include <android-base/macros.h>

#include <cstddef>
#include <cstdint>
#include <optional>
#include <ostream>
#include <sstream>
#include <utility>
#include <vector>

#include "nnapi/OperandTypes.h"
#include "nnapi/OperationTypes.h"
#include "nnapi/Result.h"
#include "nnapi/Types.h"
31
32 namespace android::nn {
33
// Versions of the NNAPI HAL interface, in increasing order.
//
// The numeric values are spelled out so that the ordering is explicit at a glance.
// NOTE(review): LATEST deliberately aliases V1_3 rather than AIDL_UNSTABLE, presumably because the
// AIDL interface is not yet considered stable — confirm before changing.
enum class HalVersion : int32_t {
    UNKNOWN = 0,
    V1_0 = 1,
    V1_1 = 2,
    V1_2 = 3,
    V1_3 = 4,
    AIDL_UNSTABLE = 5,
    LATEST = V1_3,
};
43
// Returns whether `type` is an extension (vendor-defined) operand type rather than a built-in one.
// NOTE(review): implementation is not visible here; presumably decided by the extension prefix
// bits (see getExtensionPrefix) — confirm against the .cpp.
bool isExtension(OperandType type);
// Returns whether `type` is an extension (vendor-defined) operation type rather than a built-in
// one. Same caveat as the OperandType overload.
bool isExtension(OperationType type);

// Returns whether `operandType` is a built-in scalar (non-tensor) operand type.
// NOTE(review): semantics inferred from the name — confirm against the implementation.
bool isNonExtensionScalar(OperandType operandType);

// Returns the size for the built-in (non-extension) `operandType`, presumably the byte size of a
// single scalar/element of that type — confirm against the implementation.
size_t getNonExtensionSize(OperandType operandType);
50
getExtensionPrefix(uint32_t type)51 inline uint16_t getExtensionPrefix(uint32_t type) {
52 return static_cast<uint16_t>(type >> kExtensionTypeBits);
53 }
54
getTypeWithinExtension(uint32_t type)55 inline uint16_t getTypeWithinExtension(uint32_t type) {
56 return static_cast<uint16_t>(type & kTypeWithinExtensionMask);
57 }
58
// Returns the size of a built-in (non-extension) operand with the given type and dimensions, or
// std::nullopt when no size can be produced.
// NOTE(review): the std::nullopt cases are not visible here (presumably unspecified dimensions
// and/or size overflow) — confirm against the implementation.
std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions);
// Convenience overload of the function above that takes the type and dimensions from `operand`
// (presumably Operand::type and Operand::dimensions — confirm).
std::optional<size_t> getNonExtensionSize(const Operand& operand);

// Pack two ints into a single size_t offset, and unpack an offset back into two int32_t values.
// NOTE(review): presumably getIntsFromOffset(getOffsetFromInts(lo, hi)) == {lo, hi} — confirm.
size_t getOffsetFromInts(int lower, int higher);
std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset);

// Returns, for each of the `numberOfOperands` operands, how many of `operations` consume it, or an
// error Result.
// NOTE(review): the error conditions (e.g. an operand index >= numberOfOperands) are not visible
// here — confirm against the implementation.
Result<std::vector<uint32_t>> countNumberOfConsumers(size_t numberOfOperands,
                                                     const std::vector<nn::Operation>& operations);
67
// Combines two tensor shapes into one; either may have unspecified dimensions or an unspecified
// rank. Returns an error Result when the two cannot be combined.
// NOTE(review): exact incompatibility rules (e.g. differing known ranks or differing known extents)
// are not visible here — confirm against the implementation.
Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs);

// Returns the size of the model's operandValues, together with the size of each memory pool in
// the provided model.
std::pair<size_t, std::vector<size_t>> getMemorySizes(const Model& model);

// Rounds "size" up to the nearest multiple of "multiple". "multiple" must be a power of 2.
size_t roundUp(size_t size, size_t multiple);
76
// Returns the alignment for data of the specified length. It aligns objects of length:
// 2, 3 on a 2 byte boundary,
// 4+ on a 4 byte boundary.
// We may want to have different alignments for tensors.
// TODO: This is arbitrary, more a proof of concept. We need to determine what this should be.
//
// Note that Types.cpp ensures `new` has sufficient alignment for all alignments returned by this
// function. If this function is changed to return different alignments (e.g., 8 byte boundary
// alignment), the corresponding check in Types.cpp must be updated as well.
size_t getAlignmentForLength(size_t length);
87
// Stream-insertion operators that print NNAPI types in human-readable form (e.g. for logging and
// for the NN_RET_CHECK_xx macros below, which stream both compared values on failure).
std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus);
std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference);
std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType);
std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming);
std::ostream& operator<<(std::ostream& os, const OperandType& operandType);
std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime);
std::ostream& operator<<(std::ostream& os, const OperationType& operationType);
std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime);
std::ostream& operator<<(std::ostream& os, const Priority& priority);
std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus);
std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation);
std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape);
std::ostream& operator<<(std::ostream& os, const Timing& timing);
std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo);
std::ostream& operator<<(std::ostream& os,
                         const Capabilities::OperandPerformance& operandPerformance);
std::ostream& operator<<(std::ostream& os,
                         const Capabilities::OperandPerformanceTable& operandPerformances);
std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities);
std::ostream& operator<<(std::ostream& os,
                         const Extension::OperandTypeInformation& operandTypeInformation);
std::ostream& operator<<(std::ostream& os, const Extension& extension);
std::ostream& operator<<(std::ostream& os, const DataLocation& location);
std::ostream& operator<<(std::ostream& os,
                         const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams);
std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams);
std::ostream& operator<<(std::ostream& os, const Operand& operand);
std::ostream& operator<<(std::ostream& os, const Operation& operation);
std::ostream& operator<<(std::ostream& os, const SharedHandle& handle);
std::ostream& operator<<(std::ostream& os, const Memory& memory);
std::ostream& operator<<(std::ostream& os, const SharedMemory& memory);
std::ostream& operator<<(std::ostream& os, const MemoryPreference& memoryPreference);
std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph);
std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues);
std::ostream& operator<<(std::ostream& os,
                         const Model::ExtensionNameAndPrefix& extensionNameAndPrefix);
std::ostream& operator<<(std::ostream& os, const Model& model);
std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc);
std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole);
std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument);
std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool);
std::ostream& operator<<(std::ostream& os, const Request& request);
std::ostream& operator<<(std::ostream& os, const SyncFence::FenceState& fenceState);
std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint);
std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint);
std::ostream& operator<<(std::ostream& os, const Duration& timeoutDuration);
std::ostream& operator<<(std::ostream& os, const OptionalDuration& optionalTimeoutDuration);
std::ostream& operator<<(std::ostream& os, const Version& version);
std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion);
138
// Equality / inequality comparisons for NNAPI types that do not get them from elsewhere.
// Each operator!= is declared alongside its operator== (presumably implemented as its negation —
// confirm against the implementation).
bool operator==(const Timing& a, const Timing& b);
bool operator!=(const Timing& a, const Timing& b);
bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
bool operator!=(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
bool operator==(const Capabilities::OperandPerformance& a,
                const Capabilities::OperandPerformance& b);
bool operator!=(const Capabilities::OperandPerformance& a,
                const Capabilities::OperandPerformance& b);
bool operator==(const Capabilities& a, const Capabilities& b);
bool operator!=(const Capabilities& a, const Capabilities& b);
bool operator==(const Extension::OperandTypeInformation& a,
                const Extension::OperandTypeInformation& b);
bool operator!=(const Extension::OperandTypeInformation& a,
                const Extension::OperandTypeInformation& b);
bool operator==(const Extension& a, const Extension& b);
bool operator!=(const Extension& a, const Extension& b);
bool operator==(const MemoryPreference& a, const MemoryPreference& b);
bool operator!=(const MemoryPreference& a, const MemoryPreference& b);
bool operator==(const Operand::SymmPerChannelQuantParams& a,
                const Operand::SymmPerChannelQuantParams& b);
bool operator!=(const Operand::SymmPerChannelQuantParams& a,
                const Operand::SymmPerChannelQuantParams& b);
bool operator==(const Operand& a, const Operand& b);
bool operator!=(const Operand& a, const Operand& b);
bool operator==(const Operation& a, const Operation& b);
bool operator!=(const Operation& a, const Operation& b);
165
// The NN_RET_CHECK family of macros defined below is similar to the CHECK family defined in
// system/libbase/include/android-base/logging.h
//
// The difference is that NN_RET_CHECK macros use LOG(ERROR) instead of LOG(FATAL)
// and return false instead of aborting.

// Logs an error and returns false. Append context using << after. For example:
//
//   NN_RET_CHECK_FAIL() << "Something went wrong";
//
// The containing function must return a bool (FalseyErrorStream also converts to
// Result<Version>, so such functions work too). The message is prefixed with the source location
// and is logged when the temporary FalseyErrorStream is destroyed, i.e. at the end of the
// generated return statement.
#define NN_RET_CHECK_FAIL()                     \
    return ::android::nn::FalseyErrorStream() \
           << "NN_RET_CHECK failed (" << __FILE__ << ":" << __LINE__ << "): "
180
// Logs an error and returns false if condition is false. Extra logging can be appended using <<
// after. For example:
//
//   NN_RET_CHECK(false) << "Something went wrong";
//
// The containing function must return a bool (or Result<Version>, see NN_RET_CHECK_FAIL).
// The stringified condition is included in the logged message.
//
// `while` is used instead of `if` so that a user's trailing `else` cannot bind to the macro
// (dangling-else). The loop body always returns, so it executes at most once.
#define NN_RET_CHECK(condition) \
    while (UNLIKELY(!(condition))) NN_RET_CHECK_FAIL() << #condition << " "
189
// Helper for the NN_RET_CHECK_xx(x, y) macros below.
//
// LHS and RHS are evaluated exactly once, in the for-loop init-statement, and the stored values
// are then compared with OP. On failure, both expressions and their values are logged. Values are
// wrapped in android-base's LogNullGuard (presumably so that null C strings can be streamed
// safely). The loop body always returns, so it executes at most once; as with NN_RET_CHECK, a
// loop is used so that a user's trailing `else` cannot bind to the macro.
#define NN_RET_CHECK_OP(LHS, RHS, OP)                                                     \
    for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS);                    \
         UNLIKELY(!(_values.lhs.v OP _values.rhs.v));                                      \
         /* empty */)                                                                      \
    NN_RET_CHECK_FAIL()                                                                   \
            << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = "                 \
            << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
            << ", " << #RHS << " = "                                                      \
            << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
            << ") "
201
// Logs an error and returns false if a condition between x and y does not hold. Extra logging can
// be appended using << after. For example:
//
//   NN_RET_CHECK_EQ(a, b) << "Something went wrong";
//
// The values must implement the appropriate comparison operator as well as
// `operator<<(std::ostream&, ...)`, because both values are printed on failure.
// The containing function must return a bool (or Result<Version>, see NN_RET_CHECK_FAIL).
#define NN_RET_CHECK_EQ(x, y) NN_RET_CHECK_OP(x, y, ==)
#define NN_RET_CHECK_NE(x, y) NN_RET_CHECK_OP(x, y, !=)
#define NN_RET_CHECK_LE(x, y) NN_RET_CHECK_OP(x, y, <=)
#define NN_RET_CHECK_LT(x, y) NN_RET_CHECK_OP(x, y, <)
#define NN_RET_CHECK_GE(x, y) NN_RET_CHECK_OP(x, y, >=)
#define NN_RET_CHECK_GT(x, y) NN_RET_CHECK_OP(x, y, >)
216
217 // Ensure that every user of FalseyErrorStream is linked to the
218 // correct instance, using the correct LOG_TAG
219 namespace {
220
221 // A wrapper around LOG(ERROR) that can be implicitly converted to bool (always evaluates to false).
222 // Used to implement stream logging in NN_RET_CHECK.
223 class FalseyErrorStream {
224 DISALLOW_COPY_AND_ASSIGN(FalseyErrorStream);
225
226 public:
FalseyErrorStream()227 FalseyErrorStream() {}
228
229 template <typename T>
230 FalseyErrorStream& operator<<(const T& value) {
231 mBuffer << value;
232 return *this;
233 }
234
~FalseyErrorStream()235 ~FalseyErrorStream() { LOG(ERROR) << mBuffer.str(); }
236
237 operator bool() const { return false; }
238
239 operator Result<Version>() const { return error() << mBuffer.str(); }
240
241 private:
242 std::ostringstream mBuffer;
243 };
244
245 } // namespace
246
247 } // namespace android::nn
248
249 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
250