• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "CanonicalPreparedModel.h"
18 
19 #include <DefaultExecution.h>
20 #include <Tracing.h>
21 #include <nnapi/IPreparedModel.h>
22 #include <nnapi/Result.h>
23 #include <nnapi/TypeUtils.h>
24 #include <nnapi/Types.h>
25 #include <nnapi/Validation.h>
26 
27 #include <memory>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31 
32 #include "CanonicalBurst.h"
33 #include "CanonicalDevice.h"
34 
35 namespace android::nn::sample {
36 namespace {
37 
38 GeneralResult<std::pair<std::vector<RunTimePoolInfo>, std::vector<std::shared_ptr<ManagedBuffer>>>>
createRunTimePoolInfos(const Request & request,const BufferTracker & bufferTracker,const PreparedModel & preparedModel)39 createRunTimePoolInfos(const Request& request, const BufferTracker& bufferTracker,
40                        const PreparedModel& preparedModel) {
41     std::vector<RunTimePoolInfo> requestPoolInfos;
42     std::vector<std::shared_ptr<ManagedBuffer>> bufferWrappers;
43     requestPoolInfos.reserve(request.pools.size());
44     bufferWrappers.reserve(request.pools.size());
45     for (uint32_t i = 0; i < request.pools.size(); ++i) {
46         auto& pool = request.pools[i];
47         if (const auto* maybeMemory = std::get_if<SharedMemory>(&pool)) {
48             auto buffer = RunTimePoolInfo::createFromMemory(*maybeMemory);
49             if (!buffer.has_value()) {
50                 return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
51                        << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
52             }
53             requestPoolInfos.push_back(std::move(*buffer));
54             bufferWrappers.push_back(nullptr);
55         } else if (const auto* maybeToken = std::get_if<Request::MemoryDomainToken>(&pool)) {
56             auto bufferWrapper = bufferTracker.get(*maybeToken);
57             if (bufferWrapper == nullptr) {
58                 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT);
59             }
60             const auto validationStatus =
61                     bufferWrapper->validateRequest(i, request, &preparedModel);
62             if (validationStatus != ErrorStatus::NONE) {
63                 return NN_ERROR(validationStatus);
64             }
65             requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
66             bufferWrappers.push_back(std::move(bufferWrapper));
67         }
68     }
69     return std::make_pair(std::move(requestPoolInfos), std::move(bufferWrappers));
70 }
71 
updateDeviceMemories(ErrorStatus status,const Request & request,const std::vector<std::shared_ptr<ManagedBuffer>> & bufferWrappers,const std::vector<OutputShape> & outputShapes)72 ErrorStatus updateDeviceMemories(ErrorStatus status, const Request& request,
73                                  const std::vector<std::shared_ptr<ManagedBuffer>>& bufferWrappers,
74                                  const std::vector<OutputShape>& outputShapes) {
75     if (status == ErrorStatus::NONE) {
76         for (uint32_t i = 0; i < request.outputs.size(); i++) {
77             if (request.outputs[i].lifetime != Request::Argument::LifeTime::POOL) continue;
78             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
79             const auto& pool = request.pools[poolIndex];
80             if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
81                 if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
82                     return ErrorStatus::GENERAL_FAILURE;
83                 }
84             }
85         }
86         for (uint32_t i = 0; i < request.outputs.size(); i++) {
87             if (request.outputs[i].lifetime != Request::Argument::LifeTime::POOL) continue;
88             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
89             const auto& pool = request.pools[poolIndex];
90             if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
91                 bufferWrappers[poolIndex]->setInitialized(true);
92             }
93         }
94     } else if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
95         // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
96         // dimensions of the device memory are incorrectly specified. The driver should return
97         // GENERAL_FAILURE instead in this case.
98         for (uint32_t i = 0; i < request.outputs.size(); i++) {
99             if (request.outputs[i].lifetime != Request::Argument::LifeTime::POOL) continue;
100             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
101             const auto& pool = request.pools[poolIndex];
102             if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
103                 if (!outputShapes[i].isSufficient) {
104                     LOG(ERROR) << "Invalid dimensions for output " << i
105                                << ": actual shape = " << toString(outputShapes[i].dimensions);
106                     return ErrorStatus::GENERAL_FAILURE;
107                 }
108             }
109         }
110     }
111     return ErrorStatus::NONE;
112 }
113 
114 }  // namespace
115 
PreparedModel(Model model,ExecutionPreference preference,Priority priority,const IOperationResolver * operationResolver,std::shared_ptr<BufferTracker> bufferTracker,std::vector<RunTimePoolInfo> poolInfos)116 PreparedModel::PreparedModel(Model model, ExecutionPreference preference, Priority priority,
117                              const IOperationResolver* operationResolver,
118                              std::shared_ptr<BufferTracker> bufferTracker,
119                              std::vector<RunTimePoolInfo> poolInfos)
120     : kModel(std::move(model)),
121       kExecutionPreference(preference),
122       kExecutionPriority(priority),
123       kOperationResolver(*operationResolver),
124       kBufferTracker(std::move(bufferTracker)),
125       kPoolInfos(std::move(poolInfos)) {
126     CHECK(operationResolver != nullptr);
127     CHECK(kBufferTracker != nullptr);
128 }
129 
execute(const Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalDuration & loopTimeoutDuration,const std::vector<TokenValuePair> &,const std::vector<ExtensionNameAndPrefix> &) const130 ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> PreparedModel::execute(
131         const Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
132         const OptionalDuration& loopTimeoutDuration, const std::vector<TokenValuePair>& /*hints*/,
133         const std::vector<ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
134     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "sample::PreparedModel::execute");
135     VLOG(DRIVER) << "sample::PreparedModel::execute(" << SHOW_IF_DEBUG(request) << ")";
136 
137     TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
138     if (measure == MeasureTiming::YES) driverStart = Clock::now();
139 
140     if (const auto result = validateRequestForModel(request, kModel); !result.ok()) {
141         return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
142     }
143     if (hasDeadlinePassed(deadline)) {
144         return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
145     }
146 
147     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
148                         "sample::Device::execute");
149     const auto [requestPoolInfos, bufferWrappers] =
150             NN_TRY(createRunTimePoolInfos(request, *kBufferTracker, *this));
151 
152     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "sample::Device::execute");
153     auto executor = CpuExecutor(&kOperationResolver);
154     if (loopTimeoutDuration.has_value()) {
155         executor.setLoopTimeout(loopTimeoutDuration->count());
156     }
157     if (deadline.has_value()) {
158         executor.setDeadline(*deadline);
159     }
160 
161     // Perform execution.
162     if (measure == MeasureTiming::YES) deviceStart = Clock::now();
163     int n = executor.run(kModel, request, kPoolInfos, requestPoolInfos);
164     if (measure == MeasureTiming::YES) deviceEnd = Clock::now();
165     VLOG(DRIVER) << "executor.run returned " << n;
166     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
167     const auto& outputShapes = executor.getOutputShapes();
168 
169     // Update device memory metadata.
170     const ErrorStatus updateStatus =
171             updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
172     if (updateStatus != ErrorStatus::NONE) {
173         return NN_ERROR(updateStatus);
174     }
175     if (executionStatus != ErrorStatus::NONE) {
176         return NN_ERROR(executionStatus, outputShapes);
177     }
178 
179     Timing timing = {};
180     if (measure == MeasureTiming::YES) {
181         driverEnd = Clock::now();
182         timing = {.timeOnDevice = deviceEnd - deviceStart, .timeInDriver = driverEnd - driverStart};
183         VLOG(DRIVER) << "sample::PreparedModel::execute timing = " << timing;
184     }
185 
186     return std::make_pair(outputShapes, timing);
187 }
188 
executeFenced(const Request & request,const std::vector<SyncFence> & waitFor,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalDuration & loopTimeoutDuration,const OptionalDuration & timeoutDurationAfterFence,const std::vector<TokenValuePair> &,const std::vector<ExtensionNameAndPrefix> &) const189 GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> PreparedModel::executeFenced(
190         const Request& request, const std::vector<SyncFence>& waitFor, MeasureTiming measure,
191         const OptionalTimePoint& deadline, const OptionalDuration& loopTimeoutDuration,
192         const OptionalDuration& timeoutDurationAfterFence,
193         const std::vector<TokenValuePair>& /*hints*/,
194         const std::vector<ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
195     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
196                  "sample::PreparedModel::executeFenced");
197     VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(request) << ")";
198 
199     TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
200     if (measure == MeasureTiming::YES) driverStart = Clock::now();
201 
202     if (const auto result =
203                 validateRequestForModel(request, kModel, /*allowUnspecifiedOutput=*/false);
204         !result.ok()) {
205         return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
206     }
207     if (std::any_of(waitFor.begin(), waitFor.end(),
208                     [](const SyncFence& syncFence) { return !syncFence.getSharedHandle(); })) {
209         return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
210                << "sample::PreparedModel::executeFenced passed an empty SyncFence";
211     }
212     if (hasDeadlinePassed(deadline)) {
213         return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
214     }
215 
216     // Wait for the dependent events to signal
217     for (const auto& syncFence : waitFor) {
218         if (syncFence.syncWait({}) != SyncFence::FenceState::SIGNALED) {
219             return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "syncWait failed";
220         }
221     }
222 
223     // Update deadline if the timeout duration is closer than the deadline.
224     auto closestDeadline = deadline;
225     if (timeoutDurationAfterFence.has_value()) {
226         const auto timeoutDurationDeadline = makeDeadline(*timeoutDurationAfterFence);
227         if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
228             closestDeadline = timeoutDurationDeadline;
229         }
230     }
231 
232     TimePoint driverStartAfterFence;
233     if (measure == MeasureTiming::YES) driverStartAfterFence = Clock::now();
234 
235     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
236                         "sample::PreparedModel::executeFenced");
237     const auto [requestPoolInfos, bufferWrappers] =
238             NN_TRY(createRunTimePoolInfos(request, *kBufferTracker, *this));
239 
240     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
241                         "sample::PreparedModel::executeFenced");
242     auto executor = CpuExecutor(&kOperationResolver);
243     if (loopTimeoutDuration.has_value()) {
244         executor.setLoopTimeout(loopTimeoutDuration->count());
245     }
246     if (closestDeadline.has_value()) {
247         executor.setDeadline(*closestDeadline);
248     }
249     if (measure == MeasureTiming::YES) deviceStart = Clock::now();
250     int n = executor.run(kModel, request, kPoolInfos, requestPoolInfos);
251     if (measure == MeasureTiming::YES) deviceEnd = Clock::now();
252     VLOG(DRIVER) << "executor.run returned " << n;
253     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
254     if (executionStatus != ErrorStatus::NONE) {
255         return NN_ERROR(executionStatus);
256     }
257 
258     // Set output memories to the initialized state.
259     for (const auto& output : request.outputs) {
260         if (output.lifetime != Request::Argument::LifeTime::POOL) continue;
261         const uint32_t poolIndex = output.location.poolIndex;
262         const auto& pool = request.pools[poolIndex];
263         if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
264             bufferWrappers[poolIndex]->setInitialized(true);
265         }
266     }
267 
268     Timing timingSinceLaunch = {};
269     Timing timingAfterFence = {};
270     if (measure == MeasureTiming::YES) {
271         driverEnd = Clock::now();
272         timingSinceLaunch = {.timeOnDevice = deviceEnd - deviceStart,
273                              .timeInDriver = driverEnd - driverStart};
274         timingAfterFence = {.timeOnDevice = deviceEnd - deviceStart,
275                             .timeInDriver = driverEnd - driverStartAfterFence};
276         VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << timingSinceLaunch;
277         VLOG(DRIVER) << "executeFenced timingAfterFence = " << timingAfterFence;
278     }
279 
280     ExecuteFencedInfoCallback fencedExecutionCallback = [timingSinceLaunch, timingAfterFence]() {
281         return std::make_pair(timingSinceLaunch, timingAfterFence);
282     };
283     return std::make_pair(SyncFence::createAsSignaled(), std::move(fencedExecutionCallback));
284 }
285 
createReusableExecution(const Request & request,MeasureTiming measure,const OptionalDuration & loopTimeoutDuration,const std::vector<TokenValuePair> &,const std::vector<ExtensionNameAndPrefix> &) const286 GeneralResult<SharedExecution> PreparedModel::createReusableExecution(
287         const Request& request, MeasureTiming measure, const OptionalDuration& loopTimeoutDuration,
288         const std::vector<TokenValuePair>& /*hints*/,
289         const std::vector<ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
290     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
291                  "sample::PreparedModel::createReusableExecution");
292     return std::make_shared<DefaultExecution>(shared_from_this(), request, measure,
293                                               loopTimeoutDuration);
294 }
295 
configureExecutionBurst() const296 GeneralResult<SharedBurst> PreparedModel::configureExecutionBurst() const {
297     return std::make_shared<const Burst>(shared_from_this());
298 }
299 
getUnderlyingResource() const300 std::any PreparedModel::getUnderlyingResource() const {
301     return &kModel;
302 }
303 
304 }  // namespace android::nn::sample
305