1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Types.h"
18
#include <android-base/logging.h>
#include <errno.h>
#include <poll.h>

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <iterator>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
31
32 #include "OperandTypes.h"
33 #include "OperationTypes.h"
34 #include "Result.h"
35 #include "TypeUtils.h"
36
37 namespace android::nn {
38
// Operand values are kept in a std::vector<uint8_t>, whose data() comes from operator new. That
// storage must be aligned well enough for every NNAPI primitive type. The bound of 4 is the
// largest alignment `getAlignmentForLength` currently returns. If that function ever returns a
// larger alignment, raise the bound here to match.
static_assert(4 <= __STDCPP_DEFAULT_NEW_ALIGNMENT__, "`New` alignment is not sufficient");
44
OperandValues()45 Model::OperandValues::OperandValues() {
46 constexpr size_t kNumberBytes = 4 * 1024;
47 mData.reserve(kNumberBytes);
48 }
49
OperandValues(const uint8_t * data,size_t length)50 Model::OperandValues::OperandValues(const uint8_t* data, size_t length)
51 : mData(data, data + length) {}
52
append(const uint8_t * data,size_t length)53 DataLocation Model::OperandValues::append(const uint8_t* data, size_t length) {
54 CHECK_GT(length, 0u);
55 CHECK_LE(length, std::numeric_limits<uint32_t>::max());
56 const size_t alignment = getAlignmentForLength(length);
57 const size_t offset = roundUp(size(), alignment);
58 CHECK_LE(offset, std::numeric_limits<uint32_t>::max());
59 mData.resize(offset + length);
60 CHECK_LE(size(), std::numeric_limits<uint32_t>::max());
61 std::memcpy(mData.data() + offset, data, length);
62 return {.offset = static_cast<uint32_t>(offset), .length = static_cast<uint32_t>(length)};
63 }
64
data() const65 const uint8_t* Model::OperandValues::data() const {
66 return mData.data();
67 }
68
size() const69 size_t Model::OperandValues::size() const {
70 return mData.size();
71 }
72
OperandPerformanceTable(std::vector<OperandPerformance> operandPerformances)73 Capabilities::OperandPerformanceTable::OperandPerformanceTable(
74 std::vector<OperandPerformance> operandPerformances)
75 : mSorted(std::move(operandPerformances)) {}
76
create(std::vector<OperandPerformance> operandPerformances)77 Result<Capabilities::OperandPerformanceTable> Capabilities::OperandPerformanceTable::create(
78 std::vector<OperandPerformance> operandPerformances) {
79 const auto notUnique = [](const auto& lhs, const auto& rhs) { return !(lhs.type < rhs.type); };
80 const bool isUnique = std::adjacent_find(operandPerformances.begin(), operandPerformances.end(),
81 notUnique) == operandPerformances.end();
82 if (!isUnique) {
83 return NN_ERROR() << "Failed to create OperandPerformanceTable: Input must be sorted by "
84 "key (in ascending order), and there must be no duplicate keys";
85 }
86
87 return Capabilities::OperandPerformanceTable(std::move(operandPerformances));
88 }
89
lookup(OperandType operandType) const90 Capabilities::PerformanceInfo Capabilities::OperandPerformanceTable::lookup(
91 OperandType operandType) const {
92 // Search for operand type in the sorted collection.
93 constexpr auto cmp = [](const auto& performance, auto type) { return performance.type < type; };
94 const auto it = std::lower_bound(mSorted.begin(), mSorted.end(), operandType, cmp);
95
96 // If the operand type is found, return its corresponding info.
97 if (it != mSorted.end() && it->type == operandType) {
98 return it->info;
99 }
100
101 // If no performance info is defined, use the default value (float's max).
102 return Capabilities::PerformanceInfo{};
103 }
104
105 const std::vector<Capabilities::OperandPerformance>&
asVector() const106 Capabilities::OperandPerformanceTable::asVector() const {
107 return mSorted;
108 }
109
createAsSignaled()110 SyncFence SyncFence::createAsSignaled() {
111 return SyncFence(nullptr);
112 }
113
create(base::unique_fd fd)114 SyncFence SyncFence::create(base::unique_fd fd) {
115 std::vector<base::unique_fd> fds;
116 fds.push_back(std::move(fd));
117 return SyncFence(std::make_shared<const Handle>(Handle{
118 .fds = std::move(fds),
119 .ints = {},
120 }));
121 }
122
create(SharedHandle syncFence)123 Result<SyncFence> SyncFence::create(SharedHandle syncFence) {
124 const bool isValid =
125 (syncFence != nullptr && syncFence->fds.size() == 1 && syncFence->ints.empty());
126 if (!isValid) {
127 return NN_ERROR() << "Invalid sync fence handle passed to SyncFence::create";
128 }
129 return SyncFence(std::move(syncFence));
130 }
131
SyncFence(SharedHandle syncFence)132 SyncFence::SyncFence(SharedHandle syncFence) : mSyncFence(std::move(syncFence)) {}
133
syncWait(OptionalTimeout optionalTimeout) const134 SyncFence::FenceState SyncFence::syncWait(OptionalTimeout optionalTimeout) const {
135 if (mSyncFence == nullptr) {
136 return FenceState::SIGNALED;
137 }
138
139 const int fd = mSyncFence->fds.front().get();
140 const int timeout = optionalTimeout.value_or(Timeout{-1}).count();
141
142 // This implementation is directly based on the ::sync_wait() implementation.
143
144 struct pollfd fds;
145 int ret;
146
147 if (fd < 0) {
148 errno = EINVAL;
149 return FenceState::UNKNOWN;
150 }
151
152 fds.fd = fd;
153 fds.events = POLLIN;
154
155 do {
156 ret = poll(&fds, 1, timeout);
157 if (ret > 0) {
158 if (fds.revents & POLLNVAL) {
159 errno = EINVAL;
160 return FenceState::UNKNOWN;
161 }
162 if (fds.revents & POLLERR) {
163 errno = EINVAL;
164 return FenceState::ERROR;
165 }
166 return FenceState::SIGNALED;
167 } else if (ret == 0) {
168 errno = ETIME;
169 return FenceState::ACTIVE;
170 }
171 } while (ret == -1 && (errno == EINTR || errno == EAGAIN));
172
173 return FenceState::UNKNOWN;
174 }
175
getSharedHandle() const176 SharedHandle SyncFence::getSharedHandle() const {
177 return mSyncFence;
178 }
179
hasFd() const180 bool SyncFence::hasFd() const {
181 return mSyncFence != nullptr;
182 }
183
getFd() const184 int SyncFence::getFd() const {
185 return mSyncFence == nullptr ? -1 : mSyncFence->fds.front().get();
186 }
187
188 } // namespace android::nn
189