1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "CanonicalPreparedModel.h"
18
19 #include <DefaultExecution.h>
20 #include <Tracing.h>
21 #include <nnapi/IPreparedModel.h>
22 #include <nnapi/Result.h>
23 #include <nnapi/TypeUtils.h>
24 #include <nnapi/Types.h>
25 #include <nnapi/Validation.h>
26
27 #include <memory>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31
32 #include "CanonicalBurst.h"
33 #include "CanonicalDevice.h"
34
35 namespace android::nn::sample {
36 namespace {
37
38 GeneralResult<std::pair<std::vector<RunTimePoolInfo>, std::vector<std::shared_ptr<ManagedBuffer>>>>
createRunTimePoolInfos(const Request & request,const BufferTracker & bufferTracker,const PreparedModel & preparedModel)39 createRunTimePoolInfos(const Request& request, const BufferTracker& bufferTracker,
40 const PreparedModel& preparedModel) {
41 std::vector<RunTimePoolInfo> requestPoolInfos;
42 std::vector<std::shared_ptr<ManagedBuffer>> bufferWrappers;
43 requestPoolInfos.reserve(request.pools.size());
44 bufferWrappers.reserve(request.pools.size());
45 for (uint32_t i = 0; i < request.pools.size(); ++i) {
46 auto& pool = request.pools[i];
47 if (const auto* maybeMemory = std::get_if<SharedMemory>(&pool)) {
48 auto buffer = RunTimePoolInfo::createFromMemory(*maybeMemory);
49 if (!buffer.has_value()) {
50 return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
51 << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
52 }
53 requestPoolInfos.push_back(std::move(*buffer));
54 bufferWrappers.push_back(nullptr);
55 } else if (const auto* maybeToken = std::get_if<Request::MemoryDomainToken>(&pool)) {
56 auto bufferWrapper = bufferTracker.get(*maybeToken);
57 if (bufferWrapper == nullptr) {
58 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT);
59 }
60 const auto validationStatus =
61 bufferWrapper->validateRequest(i, request, &preparedModel);
62 if (validationStatus != ErrorStatus::NONE) {
63 return NN_ERROR(validationStatus);
64 }
65 requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
66 bufferWrappers.push_back(std::move(bufferWrapper));
67 }
68 }
69 return std::make_pair(std::move(requestPoolInfos), std::move(bufferWrappers));
70 }
71
updateDeviceMemories(ErrorStatus status,const Request & request,const std::vector<std::shared_ptr<ManagedBuffer>> & bufferWrappers,const std::vector<OutputShape> & outputShapes)72 ErrorStatus updateDeviceMemories(ErrorStatus status, const Request& request,
73 const std::vector<std::shared_ptr<ManagedBuffer>>& bufferWrappers,
74 const std::vector<OutputShape>& outputShapes) {
75 if (status == ErrorStatus::NONE) {
76 for (uint32_t i = 0; i < request.outputs.size(); i++) {
77 if (request.outputs[i].lifetime != Request::Argument::LifeTime::POOL) continue;
78 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
79 const auto& pool = request.pools[poolIndex];
80 if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
81 if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
82 return ErrorStatus::GENERAL_FAILURE;
83 }
84 }
85 }
86 for (uint32_t i = 0; i < request.outputs.size(); i++) {
87 if (request.outputs[i].lifetime != Request::Argument::LifeTime::POOL) continue;
88 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
89 const auto& pool = request.pools[poolIndex];
90 if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
91 bufferWrappers[poolIndex]->setInitialized(true);
92 }
93 }
94 } else if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
95 // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
96 // dimensions of the device memory are incorrectly specified. The driver should return
97 // GENERAL_FAILURE instead in this case.
98 for (uint32_t i = 0; i < request.outputs.size(); i++) {
99 if (request.outputs[i].lifetime != Request::Argument::LifeTime::POOL) continue;
100 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
101 const auto& pool = request.pools[poolIndex];
102 if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
103 if (!outputShapes[i].isSufficient) {
104 LOG(ERROR) << "Invalid dimensions for output " << i
105 << ": actual shape = " << toString(outputShapes[i].dimensions);
106 return ErrorStatus::GENERAL_FAILURE;
107 }
108 }
109 }
110 }
111 return ErrorStatus::NONE;
112 }
113
114 } // namespace
115
PreparedModel(Model model,ExecutionPreference preference,Priority priority,const IOperationResolver * operationResolver,std::shared_ptr<BufferTracker> bufferTracker,std::vector<RunTimePoolInfo> poolInfos)116 PreparedModel::PreparedModel(Model model, ExecutionPreference preference, Priority priority,
117 const IOperationResolver* operationResolver,
118 std::shared_ptr<BufferTracker> bufferTracker,
119 std::vector<RunTimePoolInfo> poolInfos)
120 : kModel(std::move(model)),
121 kExecutionPreference(preference),
122 kExecutionPriority(priority),
123 kOperationResolver(*operationResolver),
124 kBufferTracker(std::move(bufferTracker)),
125 kPoolInfos(std::move(poolInfos)) {
126 CHECK(operationResolver != nullptr);
127 CHECK(kBufferTracker != nullptr);
128 }
129
execute(const Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalDuration & loopTimeoutDuration,const std::vector<TokenValuePair> &,const std::vector<ExtensionNameAndPrefix> &) const130 ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> PreparedModel::execute(
131 const Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
132 const OptionalDuration& loopTimeoutDuration, const std::vector<TokenValuePair>& /*hints*/,
133 const std::vector<ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
134 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "sample::PreparedModel::execute");
135 VLOG(DRIVER) << "sample::PreparedModel::execute(" << SHOW_IF_DEBUG(request) << ")";
136
137 TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
138 if (measure == MeasureTiming::YES) driverStart = Clock::now();
139
140 if (const auto result = validateRequestForModel(request, kModel); !result.ok()) {
141 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
142 }
143 if (hasDeadlinePassed(deadline)) {
144 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
145 }
146
147 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
148 "sample::Device::execute");
149 const auto [requestPoolInfos, bufferWrappers] =
150 NN_TRY(createRunTimePoolInfos(request, *kBufferTracker, *this));
151
152 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "sample::Device::execute");
153 auto executor = CpuExecutor(&kOperationResolver);
154 if (loopTimeoutDuration.has_value()) {
155 executor.setLoopTimeout(loopTimeoutDuration->count());
156 }
157 if (deadline.has_value()) {
158 executor.setDeadline(*deadline);
159 }
160
161 // Perform execution.
162 if (measure == MeasureTiming::YES) deviceStart = Clock::now();
163 int n = executor.run(kModel, request, kPoolInfos, requestPoolInfos);
164 if (measure == MeasureTiming::YES) deviceEnd = Clock::now();
165 VLOG(DRIVER) << "executor.run returned " << n;
166 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
167 const auto& outputShapes = executor.getOutputShapes();
168
169 // Update device memory metadata.
170 const ErrorStatus updateStatus =
171 updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
172 if (updateStatus != ErrorStatus::NONE) {
173 return NN_ERROR(updateStatus);
174 }
175 if (executionStatus != ErrorStatus::NONE) {
176 return NN_ERROR(executionStatus, outputShapes);
177 }
178
179 Timing timing = {};
180 if (measure == MeasureTiming::YES) {
181 driverEnd = Clock::now();
182 timing = {.timeOnDevice = deviceEnd - deviceStart, .timeInDriver = driverEnd - driverStart};
183 VLOG(DRIVER) << "sample::PreparedModel::execute timing = " << timing;
184 }
185
186 return std::make_pair(outputShapes, timing);
187 }
188
executeFenced(const Request & request,const std::vector<SyncFence> & waitFor,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalDuration & loopTimeoutDuration,const OptionalDuration & timeoutDurationAfterFence,const std::vector<TokenValuePair> &,const std::vector<ExtensionNameAndPrefix> &) const189 GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> PreparedModel::executeFenced(
190 const Request& request, const std::vector<SyncFence>& waitFor, MeasureTiming measure,
191 const OptionalTimePoint& deadline, const OptionalDuration& loopTimeoutDuration,
192 const OptionalDuration& timeoutDurationAfterFence,
193 const std::vector<TokenValuePair>& /*hints*/,
194 const std::vector<ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
195 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
196 "sample::PreparedModel::executeFenced");
197 VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(request) << ")";
198
199 TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
200 if (measure == MeasureTiming::YES) driverStart = Clock::now();
201
202 if (const auto result =
203 validateRequestForModel(request, kModel, /*allowUnspecifiedOutput=*/false);
204 !result.ok()) {
205 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
206 }
207 if (std::any_of(waitFor.begin(), waitFor.end(),
208 [](const SyncFence& syncFence) { return !syncFence.getSharedHandle(); })) {
209 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
210 << "sample::PreparedModel::executeFenced passed an empty SyncFence";
211 }
212 if (hasDeadlinePassed(deadline)) {
213 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
214 }
215
216 // Wait for the dependent events to signal
217 for (const auto& syncFence : waitFor) {
218 if (syncFence.syncWait({}) != SyncFence::FenceState::SIGNALED) {
219 return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "syncWait failed";
220 }
221 }
222
223 // Update deadline if the timeout duration is closer than the deadline.
224 auto closestDeadline = deadline;
225 if (timeoutDurationAfterFence.has_value()) {
226 const auto timeoutDurationDeadline = makeDeadline(*timeoutDurationAfterFence);
227 if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
228 closestDeadline = timeoutDurationDeadline;
229 }
230 }
231
232 TimePoint driverStartAfterFence;
233 if (measure == MeasureTiming::YES) driverStartAfterFence = Clock::now();
234
235 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
236 "sample::PreparedModel::executeFenced");
237 const auto [requestPoolInfos, bufferWrappers] =
238 NN_TRY(createRunTimePoolInfos(request, *kBufferTracker, *this));
239
240 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
241 "sample::PreparedModel::executeFenced");
242 auto executor = CpuExecutor(&kOperationResolver);
243 if (loopTimeoutDuration.has_value()) {
244 executor.setLoopTimeout(loopTimeoutDuration->count());
245 }
246 if (closestDeadline.has_value()) {
247 executor.setDeadline(*closestDeadline);
248 }
249 if (measure == MeasureTiming::YES) deviceStart = Clock::now();
250 int n = executor.run(kModel, request, kPoolInfos, requestPoolInfos);
251 if (measure == MeasureTiming::YES) deviceEnd = Clock::now();
252 VLOG(DRIVER) << "executor.run returned " << n;
253 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
254 if (executionStatus != ErrorStatus::NONE) {
255 return NN_ERROR(executionStatus);
256 }
257
258 // Set output memories to the initialized state.
259 for (const auto& output : request.outputs) {
260 if (output.lifetime != Request::Argument::LifeTime::POOL) continue;
261 const uint32_t poolIndex = output.location.poolIndex;
262 const auto& pool = request.pools[poolIndex];
263 if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
264 bufferWrappers[poolIndex]->setInitialized(true);
265 }
266 }
267
268 Timing timingSinceLaunch = {};
269 Timing timingAfterFence = {};
270 if (measure == MeasureTiming::YES) {
271 driverEnd = Clock::now();
272 timingSinceLaunch = {.timeOnDevice = deviceEnd - deviceStart,
273 .timeInDriver = driverEnd - driverStart};
274 timingAfterFence = {.timeOnDevice = deviceEnd - deviceStart,
275 .timeInDriver = driverEnd - driverStartAfterFence};
276 VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << timingSinceLaunch;
277 VLOG(DRIVER) << "executeFenced timingAfterFence = " << timingAfterFence;
278 }
279
280 ExecuteFencedInfoCallback fencedExecutionCallback = [timingSinceLaunch, timingAfterFence]() {
281 return std::make_pair(timingSinceLaunch, timingAfterFence);
282 };
283 return std::make_pair(SyncFence::createAsSignaled(), std::move(fencedExecutionCallback));
284 }
285
createReusableExecution(const Request & request,MeasureTiming measure,const OptionalDuration & loopTimeoutDuration,const std::vector<TokenValuePair> &,const std::vector<ExtensionNameAndPrefix> &) const286 GeneralResult<SharedExecution> PreparedModel::createReusableExecution(
287 const Request& request, MeasureTiming measure, const OptionalDuration& loopTimeoutDuration,
288 const std::vector<TokenValuePair>& /*hints*/,
289 const std::vector<ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
290 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
291 "sample::PreparedModel::createReusableExecution");
292 return std::make_shared<DefaultExecution>(shared_from_this(), request, measure,
293 loopTimeoutDuration);
294 }
295
configureExecutionBurst() const296 GeneralResult<SharedBurst> PreparedModel::configureExecutionBurst() const {
297 return std::make_shared<const Burst>(shared_from_this());
298 }
299
getUnderlyingResource() const300 std::any PreparedModel::getUnderlyingResource() const {
301 return &kModel;
302 }
303
304 } // namespace android::nn::sample
305