• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
19 
#include <android-base/logging.h>
#include <android-base/macros.h>

#include <cstddef>
#include <cstdint>
#include <optional>
#include <ostream>
#include <sstream>
#include <utility>
#include <vector>

#include "nnapi/OperandTypes.h"
#include "nnapi/OperationTypes.h"
#include "nnapi/Result.h"
#include "nnapi/Types.h"
31 
32 namespace android::nn {
33 
// Version identifiers, stored as int32_t.
//
// The numeric values are written out explicitly; they are the same values the enumerators
// receive from implicit numbering starting at zero. LATEST is an alias for V1_3 (not for
// AIDL_UNSTABLE, even though that enumerator has the larger numeric value).
enum class HalVersion : int32_t {
    UNKNOWN = 0,
    V1_0 = 1,
    V1_1 = 2,
    V1_2 = 3,
    V1_3 = 4,
    AIDL_UNSTABLE = 5,
    LATEST = V1_3,
};
43 
44 bool isExtension(OperandType type);
45 bool isExtension(OperationType type);
46 
47 bool isNonExtensionScalar(OperandType operandType);
48 
49 size_t getNonExtensionSize(OperandType operandType);
50 
getExtensionPrefix(uint32_t type)51 inline uint16_t getExtensionPrefix(uint32_t type) {
52     return static_cast<uint16_t>(type >> kExtensionTypeBits);
53 }
54 
getTypeWithinExtension(uint32_t type)55 inline uint16_t getTypeWithinExtension(uint32_t type) {
56     return static_cast<uint16_t>(type & kTypeWithinExtensionMask);
57 }
58 
59 std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions);
60 std::optional<size_t> getNonExtensionSize(const Operand& operand);
61 
62 size_t getOffsetFromInts(int lower, int higher);
63 std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset);
64 
65 Result<std::vector<uint32_t>> countNumberOfConsumers(size_t numberOfOperands,
66                                                      const std::vector<nn::Operation>& operations);
67 
68 // Combine two tensor dimensions, both may have unspecified dimensions or rank.
69 Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs);
70 
71 // Returns the operandValues's size and a size for each pool in the provided model.
72 std::pair<size_t, std::vector<size_t>> getMemorySizes(const Model& model);
73 
74 // Round up "size" to the nearest multiple of "multiple". "multiple" must be a power of 2.
75 size_t roundUp(size_t size, size_t multiple);
76 
77 // Returns the alignment for data of the specified length.  It aligns object of length:
78 // 2, 3 on a 2 byte boundary,
79 // 4+ on a 4 byte boundary.
80 // We may want to have different alignments for tensors.
81 // TODO: This is arbitrary, more a proof of concept.  We need to determine what this should be.
82 //
83 // Note that Types.cpp ensures `new` has sufficient alignment for all alignments returned by this
84 // function. If this function is changed to return different alignments (e.g., 8 byte boundary
85 // alignment), the code check in Types.cpp similarly needs to be updated.
86 size_t getAlignmentForLength(size_t length);
87 
88 // Set of output utility functions.
89 std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus);
90 std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference);
91 std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType);
92 std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming);
93 std::ostream& operator<<(std::ostream& os, const OperandType& operandType);
94 std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime);
95 std::ostream& operator<<(std::ostream& os, const OperationType& operationType);
96 std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime);
97 std::ostream& operator<<(std::ostream& os, const Priority& priority);
98 std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus);
99 std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation);
100 std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape);
101 std::ostream& operator<<(std::ostream& os, const Timing& timing);
102 std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo);
103 std::ostream& operator<<(std::ostream& os,
104                          const Capabilities::OperandPerformance& operandPerformance);
105 std::ostream& operator<<(std::ostream& os,
106                          const Capabilities::OperandPerformanceTable& operandPerformances);
107 std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities);
108 std::ostream& operator<<(std::ostream& os,
109                          const Extension::OperandTypeInformation& operandTypeInformation);
110 std::ostream& operator<<(std::ostream& os, const Extension& extension);
111 std::ostream& operator<<(std::ostream& os, const DataLocation& location);
112 std::ostream& operator<<(std::ostream& os,
113                          const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams);
114 std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams);
115 std::ostream& operator<<(std::ostream& os, const Operand& operand);
116 std::ostream& operator<<(std::ostream& os, const Operation& operation);
117 std::ostream& operator<<(std::ostream& os, const SharedHandle& handle);
118 std::ostream& operator<<(std::ostream& os, const Memory& memory);
119 std::ostream& operator<<(std::ostream& os, const SharedMemory& memory);
120 std::ostream& operator<<(std::ostream& os, const MemoryPreference& memoryPreference);
121 std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph);
122 std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues);
123 std::ostream& operator<<(std::ostream& os,
124                          const Model::ExtensionNameAndPrefix& extensionNameAndPrefix);
125 std::ostream& operator<<(std::ostream& os, const Model& model);
126 std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc);
127 std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole);
128 std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument);
129 std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool);
130 std::ostream& operator<<(std::ostream& os, const Request& request);
131 std::ostream& operator<<(std::ostream& os, const SyncFence::FenceState& fenceState);
132 std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint);
133 std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint);
134 std::ostream& operator<<(std::ostream& os, const Duration& timeoutDuration);
135 std::ostream& operator<<(std::ostream& os, const OptionalDuration& optionalTimeoutDuration);
136 std::ostream& operator<<(std::ostream& os, const Version& version);
137 std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion);
138 
139 bool operator==(const Timing& a, const Timing& b);
140 bool operator!=(const Timing& a, const Timing& b);
141 bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
142 bool operator!=(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
143 bool operator==(const Capabilities::OperandPerformance& a,
144                 const Capabilities::OperandPerformance& b);
145 bool operator!=(const Capabilities::OperandPerformance& a,
146                 const Capabilities::OperandPerformance& b);
147 bool operator==(const Capabilities& a, const Capabilities& b);
148 bool operator!=(const Capabilities& a, const Capabilities& b);
149 bool operator==(const Extension::OperandTypeInformation& a,
150                 const Extension::OperandTypeInformation& b);
151 bool operator!=(const Extension::OperandTypeInformation& a,
152                 const Extension::OperandTypeInformation& b);
153 bool operator==(const Extension& a, const Extension& b);
154 bool operator!=(const Extension& a, const Extension& b);
155 bool operator==(const MemoryPreference& a, const MemoryPreference& b);
156 bool operator!=(const MemoryPreference& a, const MemoryPreference& b);
157 bool operator==(const Operand::SymmPerChannelQuantParams& a,
158                 const Operand::SymmPerChannelQuantParams& b);
159 bool operator!=(const Operand::SymmPerChannelQuantParams& a,
160                 const Operand::SymmPerChannelQuantParams& b);
161 bool operator==(const Operand& a, const Operand& b);
162 bool operator!=(const Operand& a, const Operand& b);
163 bool operator==(const Operation& a, const Operation& b);
164 bool operator!=(const Operation& a, const Operation& b);
165 
166 // The NN_RET_CHECK family of macros defined below is similar to the CHECK family defined in
167 // system/libbase/include/android-base/logging.h
168 //
169 // The difference is that NN_RET_CHECK macros use LOG(ERROR) instead of LOG(FATAL)
170 // and return false instead of aborting.
171 
// Logs an error and returns a failure value from the containing function. Context can be
// appended with << after the macro. For example:
//
//   NN_RET_CHECK_FAIL() << "Something went wrong";
//
// The logged message starts with "NN_RET_CHECK failed (<file>:<line>): ".
//
// The macro expands to a `return` statement whose operand is a FalseyErrorStream, so the
// containing function must return a type that FalseyErrorStream converts to: bool (the
// conversion always yields false) or Result<Version>.
#define NN_RET_CHECK_FAIL()                   \
    return ::android::nn::FalseyErrorStream() \
           << "NN_RET_CHECK failed (" << __FILE__ << ":" << __LINE__ << "): "
180 
// If `condition` is false, logs an error containing the text of the condition and returns a
// failure value from the containing function. Extra logging can be appended using << after
// the macro. For example:
//
//   NN_RET_CHECK(false) << "Something went wrong";
//
// The body of the `while` is the `return` statement produced by NN_RET_CHECK_FAIL(), so the
// loop never runs more than once.
//
// The containing function must return bool or Result<Version> (the types FalseyErrorStream
// converts to).
#define NN_RET_CHECK(condition) \
    while (UNLIKELY(!(condition))) NN_RET_CHECK_FAIL() << #condition << " "
189 
// Helper for the NN_RET_CHECK_xx(x, y) macros.
//
// LHS and RHS are handed to ::android::base::MakeEagerEvaluator once and the comparison is then
// made on the stored `lhs.v` / `rhs.v` members. When `lhs.v OP rhs.v` is false, the failure
// message contains the source text of both operands and of the operator, followed by the two
// values (passed through ::android::base::LogNullGuard before being streamed).
//
// The `for` statement has no increment expression: its body is the `return` statement produced
// by NN_RET_CHECK_FAIL(), so it never runs more than once.
#define NN_RET_CHECK_OP(LHS, RHS, OP)                                                       \
    for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS);                      \
         UNLIKELY(!(_values.lhs.v OP _values.rhs.v));                                       \
         /* empty */)                                                                       \
    NN_RET_CHECK_FAIL()                                                                     \
            << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = "                   \
            << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
            << ", " << #RHS << " = "                                                        \
            << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
            << ") "
201 
// Binary comparison checks. Each macro logs an error and returns a failure value from the
// containing function if the named relation between x and y does not hold. Extra logging can
// be appended using << after the macro. For example:
//
//   NN_RET_CHECK_EQ(a, b) << "Something went wrong";
//
// The values must implement the appropriate comparison operator as well as
// `operator<<(std::ostream&, ...)`.
// The containing function must return bool or Result<Version> (the types FalseyErrorStream
// converts to).
#define NN_RET_CHECK_EQ(x, y) NN_RET_CHECK_OP(x, y, ==)
#define NN_RET_CHECK_NE(x, y) NN_RET_CHECK_OP(x, y, !=)
#define NN_RET_CHECK_LE(x, y) NN_RET_CHECK_OP(x, y, <=)
#define NN_RET_CHECK_LT(x, y) NN_RET_CHECK_OP(x, y, <)
#define NN_RET_CHECK_GE(x, y) NN_RET_CHECK_OP(x, y, >=)
#define NN_RET_CHECK_GT(x, y) NN_RET_CHECK_OP(x, y, >)
216 
217 // Ensure that every user of FalseyErrorStream is linked to the
218 // correct instance, using the correct LOG_TAG
219 namespace {
220 
221 // A wrapper around LOG(ERROR) that can be implicitly converted to bool (always evaluates to false).
222 // Used to implement stream logging in NN_RET_CHECK.
223 class FalseyErrorStream {
224     DISALLOW_COPY_AND_ASSIGN(FalseyErrorStream);
225 
226    public:
FalseyErrorStream()227     FalseyErrorStream() {}
228 
229     template <typename T>
230     FalseyErrorStream& operator<<(const T& value) {
231         mBuffer << value;
232         return *this;
233     }
234 
~FalseyErrorStream()235     ~FalseyErrorStream() { LOG(ERROR) << mBuffer.str(); }
236 
237     operator bool() const { return false; }
238 
239     operator Result<Version>() const { return error() << mBuffer.str(); }
240 
241    private:
242     std::ostringstream mBuffer;
243 };
244 
245 }  // namespace
246 
247 }  // namespace android::nn
248 
249 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
250