1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "CanonicalPreparedModel.h"
18
19 #include <DefaultExecution.h>
20 #include <Tracing.h>
21 #include <nnapi/IPreparedModel.h>
22 #include <nnapi/Result.h>
23 #include <nnapi/TypeUtils.h>
24 #include <nnapi/Types.h>
25 #include <nnapi/Validation.h>
26
27 #include <memory>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31
32 #include "CanonicalBurst.h"
33 #include "CanonicalDevice.h"
34
35 namespace android::nn::sample {
36 namespace {
37
38 GeneralResult<std::pair<std::vector<RunTimePoolInfo>, std::vector<std::shared_ptr<ManagedBuffer>>>>
createRunTimePoolInfos(const Request & request,const BufferTracker & bufferTracker,const PreparedModel & preparedModel)39 createRunTimePoolInfos(const Request& request, const BufferTracker& bufferTracker,
40 const PreparedModel& preparedModel) {
41 std::vector<RunTimePoolInfo> requestPoolInfos;
42 std::vector<std::shared_ptr<ManagedBuffer>> bufferWrappers;
43 requestPoolInfos.reserve(request.pools.size());
44 bufferWrappers.reserve(request.pools.size());
45 for (uint32_t i = 0; i < request.pools.size(); ++i) {
46 auto& pool = request.pools[i];
47 if (const auto* maybeMemory = std::get_if<SharedMemory>(&pool)) {
48 auto buffer = RunTimePoolInfo::createFromMemory(*maybeMemory);
49 if (!buffer.has_value()) {
50 return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
51 << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
52 }
53 requestPoolInfos.push_back(std::move(*buffer));
54 bufferWrappers.push_back(nullptr);
55 } else if (const auto* maybeToken = std::get_if<Request::MemoryDomainToken>(&pool)) {
56 auto bufferWrapper = bufferTracker.get(*maybeToken);
57 if (bufferWrapper == nullptr) {
58 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT);
59 }
60 const auto validationStatus =
61 bufferWrapper->validateRequest(i, request, &preparedModel);
62 if (validationStatus != ErrorStatus::NONE) {
63 return NN_ERROR(validationStatus);
64 }
65 requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
66 bufferWrappers.push_back(std::move(bufferWrapper));
67 }
68 }
69 return std::make_pair(std::move(requestPoolInfos), std::move(bufferWrappers));
70 }
71
72 template <typename T>
makeExecutionResult(GeneralResult<T> result)73 ExecutionResult<T> makeExecutionResult(GeneralResult<T> result) {
74 if (!result.has_value()) {
75 const auto& [message, code] = std::move(result).error();
76 return error(code) << message;
77 }
78 return std::move(result).value();
79 }
80
updateDeviceMemories(ErrorStatus status,const Request & request,const std::vector<std::shared_ptr<ManagedBuffer>> & bufferWrappers,const std::vector<OutputShape> & outputShapes)81 ErrorStatus updateDeviceMemories(ErrorStatus status, const Request& request,
82 const std::vector<std::shared_ptr<ManagedBuffer>>& bufferWrappers,
83 const std::vector<OutputShape>& outputShapes) {
84 if (status == ErrorStatus::NONE) {
85 for (uint32_t i = 0; i < request.outputs.size(); i++) {
86 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
87 const auto& pool = request.pools[poolIndex];
88 if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
89 if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
90 return ErrorStatus::GENERAL_FAILURE;
91 }
92 }
93 }
94 for (uint32_t i = 0; i < request.outputs.size(); i++) {
95 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
96 const auto& pool = request.pools[poolIndex];
97 if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
98 bufferWrappers[poolIndex]->setInitialized(true);
99 }
100 }
101 } else if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
102 // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
103 // dimensions of the device memory are incorrectly specified. The driver should return
104 // GENERAL_FAILURE instead in this case.
105 for (uint32_t i = 0; i < request.outputs.size(); i++) {
106 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
107 const auto& pool = request.pools[poolIndex];
108 if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
109 if (!outputShapes[i].isSufficient) {
110 LOG(ERROR) << "Invalid dimensions for output " << i
111 << ": actual shape = " << toString(outputShapes[i].dimensions);
112 return ErrorStatus::GENERAL_FAILURE;
113 }
114 }
115 }
116 }
117 return ErrorStatus::NONE;
118 }
119
120 } // namespace
121
PreparedModel(Model model,ExecutionPreference preference,Priority priority,const IOperationResolver * operationResolver,std::shared_ptr<BufferTracker> bufferTracker,std::vector<RunTimePoolInfo> poolInfos)122 PreparedModel::PreparedModel(Model model, ExecutionPreference preference, Priority priority,
123 const IOperationResolver* operationResolver,
124 std::shared_ptr<BufferTracker> bufferTracker,
125 std::vector<RunTimePoolInfo> poolInfos)
126 : kModel(std::move(model)),
127 kExecutionPreference(preference),
128 kExecutionPriority(priority),
129 kOperationResolver(*operationResolver),
130 kBufferTracker(std::move(bufferTracker)),
131 kPoolInfos(std::move(poolInfos)) {
132 CHECK(operationResolver != nullptr);
133 CHECK(kBufferTracker != nullptr);
134 }
135
execute(const Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalDuration & loopTimeoutDuration) const136 ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> PreparedModel::execute(
137 const Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
138 const OptionalDuration& loopTimeoutDuration) const {
139 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "sample::PreparedModel::execute");
140 VLOG(DRIVER) << "sample::PreparedModel::execute(" << SHOW_IF_DEBUG(request) << ")";
141
142 TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
143 if (measure == MeasureTiming::YES) driverStart = Clock::now();
144
145 if (const auto result = validateRequestForModel(request, kModel); !result.ok()) {
146 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
147 }
148 if (hasDeadlinePassed(deadline)) {
149 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
150 }
151
152 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
153 "sample::Device::execute");
154 const auto [requestPoolInfos, bufferWrappers] =
155 NN_TRY(makeExecutionResult(createRunTimePoolInfos(request, *kBufferTracker, *this)));
156
157 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "sample::Device::execute");
158 auto executor = CpuExecutor(&kOperationResolver);
159 if (loopTimeoutDuration.has_value()) {
160 executor.setLoopTimeout(loopTimeoutDuration->count());
161 }
162 if (deadline.has_value()) {
163 executor.setDeadline(*deadline);
164 }
165
166 // Perform execution.
167 if (measure == MeasureTiming::YES) deviceStart = Clock::now();
168 int n = executor.run(kModel, request, kPoolInfos, requestPoolInfos);
169 if (measure == MeasureTiming::YES) deviceEnd = Clock::now();
170 VLOG(DRIVER) << "executor.run returned " << n;
171 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
172 const auto& outputShapes = executor.getOutputShapes();
173
174 // Update device memory metadata.
175 const ErrorStatus updateStatus =
176 updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
177 if (updateStatus != ErrorStatus::NONE) {
178 return NN_ERROR(updateStatus);
179 }
180 if (executionStatus != ErrorStatus::NONE) {
181 return NN_ERROR(executionStatus, outputShapes);
182 }
183
184 Timing timing = {};
185 if (measure == MeasureTiming::YES) {
186 driverEnd = Clock::now();
187 timing = {.timeOnDevice = deviceEnd - deviceStart, .timeInDriver = driverEnd - driverStart};
188 VLOG(DRIVER) << "sample::PreparedModel::execute timing = " << timing;
189 }
190
191 return std::make_pair(outputShapes, timing);
192 }
193
executeFenced(const Request & request,const std::vector<SyncFence> & waitFor,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalDuration & loopTimeoutDuration,const OptionalDuration & timeoutDurationAfterFence) const194 GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> PreparedModel::executeFenced(
195 const Request& request, const std::vector<SyncFence>& waitFor, MeasureTiming measure,
196 const OptionalTimePoint& deadline, const OptionalDuration& loopTimeoutDuration,
197 const OptionalDuration& timeoutDurationAfterFence) const {
198 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
199 "sample::PreparedModel::executeFenced");
200 VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(request) << ")";
201
202 TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
203 if (measure == MeasureTiming::YES) driverStart = Clock::now();
204
205 if (const auto result = validateRequestForModel(request, kModel); !result.ok()) {
206 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
207 }
208 if (hasDeadlinePassed(deadline)) {
209 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
210 }
211
212 // Wait for the dependent events to signal
213 for (const auto& syncFence : waitFor) {
214 if (!syncFence.getSharedHandle()) {
215 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT);
216 }
217 if (syncFence.syncWait({}) != SyncFence::FenceState::SIGNALED) {
218 return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "syncWait failed";
219 }
220 }
221
222 // Update deadline if the timeout duration is closer than the deadline.
223 auto closestDeadline = deadline;
224 if (timeoutDurationAfterFence.has_value()) {
225 const auto timeoutDurationDeadline = makeDeadline(*timeoutDurationAfterFence);
226 if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
227 closestDeadline = timeoutDurationDeadline;
228 }
229 }
230
231 TimePoint driverStartAfterFence;
232 if (measure == MeasureTiming::YES) driverStartAfterFence = Clock::now();
233
234 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
235 "sample::PreparedModel::executeFenced");
236 const auto [requestPoolInfos, bufferWrappers] =
237 NN_TRY(createRunTimePoolInfos(request, *kBufferTracker, *this));
238
239 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
240 "sample::PreparedModel::executeFenced");
241 auto executor = CpuExecutor(&kOperationResolver);
242 if (loopTimeoutDuration.has_value()) {
243 executor.setLoopTimeout(loopTimeoutDuration->count());
244 }
245 if (closestDeadline.has_value()) {
246 executor.setDeadline(*closestDeadline);
247 }
248 if (measure == MeasureTiming::YES) deviceStart = Clock::now();
249 int n = executor.run(kModel, request, kPoolInfos, requestPoolInfos);
250 if (measure == MeasureTiming::YES) deviceEnd = Clock::now();
251 VLOG(DRIVER) << "executor.run returned " << n;
252 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
253 if (executionStatus != ErrorStatus::NONE) {
254 return NN_ERROR(executionStatus);
255 }
256
257 // Set output memories to the initialized state.
258 for (const auto& output : request.outputs) {
259 const uint32_t poolIndex = output.location.poolIndex;
260 const auto& pool = request.pools[poolIndex];
261 if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
262 bufferWrappers[poolIndex]->setInitialized(true);
263 }
264 }
265
266 Timing timingSinceLaunch = {};
267 Timing timingAfterFence = {};
268 if (measure == MeasureTiming::YES) {
269 driverEnd = Clock::now();
270 timingSinceLaunch = {.timeOnDevice = deviceEnd - deviceStart,
271 .timeInDriver = driverEnd - driverStart};
272 timingAfterFence = {.timeOnDevice = deviceEnd - deviceStart,
273 .timeInDriver = driverEnd - driverStartAfterFence};
274 VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << timingSinceLaunch;
275 VLOG(DRIVER) << "executeFenced timingAfterFence = " << timingAfterFence;
276 }
277
278 ExecuteFencedInfoCallback fencedExecutionCallback = [timingSinceLaunch, timingAfterFence]() {
279 return std::make_pair(timingSinceLaunch, timingAfterFence);
280 };
281 return std::make_pair(SyncFence::createAsSignaled(), std::move(fencedExecutionCallback));
282 }
283
createReusableExecution(const Request & request,MeasureTiming measure,const OptionalDuration & loopTimeoutDuration) const284 GeneralResult<SharedExecution> PreparedModel::createReusableExecution(
285 const Request& request, MeasureTiming measure,
286 const OptionalDuration& loopTimeoutDuration) const {
287 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
288 "sample::PreparedModel::createReusableExecution");
289 return std::make_shared<DefaultExecution>(shared_from_this(), request, measure,
290 loopTimeoutDuration);
291 }
292
configureExecutionBurst() const293 GeneralResult<SharedBurst> PreparedModel::configureExecutionBurst() const {
294 return std::make_shared<const Burst>(shared_from_this());
295 }
296
getUnderlyingResource() const297 std::any PreparedModel::getUnderlyingResource() const {
298 return &kModel;
299 }
300
301 } // namespace android::nn::sample
302