• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CanonicalPreparedModel.h"
18 
19 #include <DefaultExecution.h>
20 #include <Tracing.h>
21 #include <nnapi/IPreparedModel.h>
22 #include <nnapi/Result.h>
23 #include <nnapi/TypeUtils.h>
24 #include <nnapi/Types.h>
25 #include <nnapi/Validation.h>
26 
27 #include <memory>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31 
32 #include "CanonicalBurst.h"
33 #include "CanonicalDevice.h"
34 
35 namespace android::nn::sample {
36 namespace {
37 
38 GeneralResult<std::pair<std::vector<RunTimePoolInfo>, std::vector<std::shared_ptr<ManagedBuffer>>>>
createRunTimePoolInfos(const Request & request,const BufferTracker & bufferTracker,const PreparedModel & preparedModel)39 createRunTimePoolInfos(const Request& request, const BufferTracker& bufferTracker,
40                        const PreparedModel& preparedModel) {
41     std::vector<RunTimePoolInfo> requestPoolInfos;
42     std::vector<std::shared_ptr<ManagedBuffer>> bufferWrappers;
43     requestPoolInfos.reserve(request.pools.size());
44     bufferWrappers.reserve(request.pools.size());
45     for (uint32_t i = 0; i < request.pools.size(); ++i) {
46         auto& pool = request.pools[i];
47         if (const auto* maybeMemory = std::get_if<SharedMemory>(&pool)) {
48             auto buffer = RunTimePoolInfo::createFromMemory(*maybeMemory);
49             if (!buffer.has_value()) {
50                 return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
51                        << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
52             }
53             requestPoolInfos.push_back(std::move(*buffer));
54             bufferWrappers.push_back(nullptr);
55         } else if (const auto* maybeToken = std::get_if<Request::MemoryDomainToken>(&pool)) {
56             auto bufferWrapper = bufferTracker.get(*maybeToken);
57             if (bufferWrapper == nullptr) {
58                 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT);
59             }
60             const auto validationStatus =
61                     bufferWrapper->validateRequest(i, request, &preparedModel);
62             if (validationStatus != ErrorStatus::NONE) {
63                 return NN_ERROR(validationStatus);
64             }
65             requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
66             bufferWrappers.push_back(std::move(bufferWrapper));
67         }
68     }
69     return std::make_pair(std::move(requestPoolInfos), std::move(bufferWrappers));
70 }
71 
72 template <typename T>
makeExecutionResult(GeneralResult<T> result)73 ExecutionResult<T> makeExecutionResult(GeneralResult<T> result) {
74     if (!result.has_value()) {
75         const auto& [message, code] = std::move(result).error();
76         return error(code) << message;
77     }
78     return std::move(result).value();
79 }
80 
updateDeviceMemories(ErrorStatus status,const Request & request,const std::vector<std::shared_ptr<ManagedBuffer>> & bufferWrappers,const std::vector<OutputShape> & outputShapes)81 ErrorStatus updateDeviceMemories(ErrorStatus status, const Request& request,
82                                  const std::vector<std::shared_ptr<ManagedBuffer>>& bufferWrappers,
83                                  const std::vector<OutputShape>& outputShapes) {
84     if (status == ErrorStatus::NONE) {
85         for (uint32_t i = 0; i < request.outputs.size(); i++) {
86             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
87             const auto& pool = request.pools[poolIndex];
88             if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
89                 if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
90                     return ErrorStatus::GENERAL_FAILURE;
91                 }
92             }
93         }
94         for (uint32_t i = 0; i < request.outputs.size(); i++) {
95             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
96             const auto& pool = request.pools[poolIndex];
97             if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
98                 bufferWrappers[poolIndex]->setInitialized(true);
99             }
100         }
101     } else if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
102         // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
103         // dimensions of the device memory are incorrectly specified. The driver should return
104         // GENERAL_FAILURE instead in this case.
105         for (uint32_t i = 0; i < request.outputs.size(); i++) {
106             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
107             const auto& pool = request.pools[poolIndex];
108             if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
109                 if (!outputShapes[i].isSufficient) {
110                     LOG(ERROR) << "Invalid dimensions for output " << i
111                                << ": actual shape = " << toString(outputShapes[i].dimensions);
112                     return ErrorStatus::GENERAL_FAILURE;
113                 }
114             }
115         }
116     }
117     return ErrorStatus::NONE;
118 }
119 
120 }  // namespace
121 
// Constructs a PreparedModel from a canonical model and its execution
// settings. |operationResolver| and |bufferTracker| must be non-null;
// |poolInfos| holds the mapped constant-data pools passed to CpuExecutor.
PreparedModel::PreparedModel(Model model, ExecutionPreference preference, Priority priority,
                             const IOperationResolver* operationResolver,
                             std::shared_ptr<BufferTracker> bufferTracker,
                             std::vector<RunTimePoolInfo> poolInfos)
    : kModel(std::move(model)),
      kExecutionPreference(preference),
      kExecutionPriority(priority),
      kOperationResolver(*operationResolver),
      kBufferTracker(std::move(bufferTracker)),
      kPoolInfos(std::move(poolInfos)) {
    // NOTE(review): kOperationResolver dereferences operationResolver in the
    // initializer list above, before these CHECKs run; the checks document the
    // non-null preconditions rather than guard the dereference.
    CHECK(operationResolver != nullptr);
    CHECK(kBufferTracker != nullptr);
}
135 
// Synchronously executes |request| on the CPU.
//
// Steps: validate the request against kModel, map the request's memory pools
// (plain shared memory and memory-domain tokens), run CpuExecutor, push the
// resulting output shapes back into any device memories, and return the
// output shapes plus (optionally measured) timing.
ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> PreparedModel::execute(
        const Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
        const OptionalDuration& loopTimeoutDuration) const {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "sample::PreparedModel::execute");
    VLOG(DRIVER) << "sample::PreparedModel::execute(" << SHOW_IF_DEBUG(request) << ")";

    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == MeasureTiming::YES) driverStart = Clock::now();

    // Reject malformed requests and already-expired deadlines up front.
    if (const auto result = validateRequestForModel(request, kModel); !result.ok()) {
        return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
    }
    if (hasDeadlinePassed(deadline)) {
        return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "sample::Device::execute");
    // bufferWrappers keeps the device memories alive and is index-aligned with
    // requestPoolInfos (nullptr entries for plain shared memory).
    const auto [requestPoolInfos, bufferWrappers] =
            NN_TRY(makeExecutionResult(createRunTimePoolInfos(request, *kBufferTracker, *this)));

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "sample::Device::execute");
    auto executor = CpuExecutor(&kOperationResolver);
    if (loopTimeoutDuration.has_value()) {
        executor.setLoopTimeout(loopTimeoutDuration->count());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }

    // Perform execution.
    if (measure == MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(kModel, request, kPoolInfos, requestPoolInfos);
    if (measure == MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
    const auto& outputShapes = executor.getOutputShapes();

    // Update device memory metadata (dimensions + initialized flag). A failure
    // here takes precedence over the execution status.
    const ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != ErrorStatus::NONE) {
        return NN_ERROR(updateStatus);
    }
    if (executionStatus != ErrorStatus::NONE) {
        // The output shapes accompany the error so the caller can resize its
        // buffers on OUTPUT_INSUFFICIENT_SIZE.
        return NN_ERROR(executionStatus, outputShapes);
    }

    Timing timing = {};
    if (measure == MeasureTiming::YES) {
        driverEnd = Clock::now();
        timing = {.timeOnDevice = deviceEnd - deviceStart, .timeInDriver = driverEnd - driverStart};
        VLOG(DRIVER) << "sample::PreparedModel::execute timing = " << timing;
    }

    return std::make_pair(outputShapes, timing);
}
193 
// Executes |request| after synchronously waiting on |waitFor| sync fences.
//
// Returns an already-signaled fence (this sample executes synchronously) and a
// callback yielding timing both since launch and since the fences signaled.
// The effective deadline is the earlier of |deadline| and driverStart +
// |timeoutDurationAfterFence|.
GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> PreparedModel::executeFenced(
        const Request& request, const std::vector<SyncFence>& waitFor, MeasureTiming measure,
        const OptionalTimePoint& deadline, const OptionalDuration& loopTimeoutDuration,
        const OptionalDuration& timeoutDurationAfterFence) const {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "sample::PreparedModel::executeFenced");
    VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(request) << ")";

    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == MeasureTiming::YES) driverStart = Clock::now();

    // Reject malformed requests and already-expired deadlines up front.
    if (const auto result = validateRequestForModel(request, kModel); !result.ok()) {
        return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
    }
    if (hasDeadlinePassed(deadline)) {
        return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
    }

    // Wait for the dependent events to signal
    for (const auto& syncFence : waitFor) {
        if (!syncFence.getSharedHandle()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT);
        }
        // syncWait({}) blocks without a timeout; anything but SIGNALED is fatal.
        if (syncFence.syncWait({}) != SyncFence::FenceState::SIGNALED) {
            return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "syncWait failed";
        }
    }

    // Update deadline if the timeout duration is closer than the deadline.
    auto closestDeadline = deadline;
    if (timeoutDurationAfterFence.has_value()) {
        const auto timeoutDurationDeadline = makeDeadline(*timeoutDurationAfterFence);
        if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
            closestDeadline = timeoutDurationDeadline;
        }
    }

    // Second timestamp taken after the fences have signaled, used for the
    // "after fence" timing reported by the fenced-execution callback.
    TimePoint driverStartAfterFence;
    if (measure == MeasureTiming::YES) driverStartAfterFence = Clock::now();

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "sample::PreparedModel::executeFenced");
    // bufferWrappers keeps the device memories alive and is index-aligned with
    // requestPoolInfos (nullptr entries for plain shared memory).
    const auto [requestPoolInfos, bufferWrappers] =
            NN_TRY(createRunTimePoolInfos(request, *kBufferTracker, *this));

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "sample::PreparedModel::executeFenced");
    auto executor = CpuExecutor(&kOperationResolver);
    if (loopTimeoutDuration.has_value()) {
        executor.setLoopTimeout(loopTimeoutDuration->count());
    }
    if (closestDeadline.has_value()) {
        executor.setDeadline(*closestDeadline);
    }
    if (measure == MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(kModel, request, kPoolInfos, requestPoolInfos);
    if (measure == MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
    if (executionStatus != ErrorStatus::NONE) {
        return NN_ERROR(executionStatus);
    }

    // Set output memories to the initialized state.
    for (const auto& output : request.outputs) {
        const uint32_t poolIndex = output.location.poolIndex;
        const auto& pool = request.pools[poolIndex];
        if (std::holds_alternative<Request::MemoryDomainToken>(pool)) {
            bufferWrappers[poolIndex]->setInitialized(true);
        }
    }

    Timing timingSinceLaunch = {};
    Timing timingAfterFence = {};
    if (measure == MeasureTiming::YES) {
        driverEnd = Clock::now();
        timingSinceLaunch = {.timeOnDevice = deviceEnd - deviceStart,
                             .timeInDriver = driverEnd - driverStart};
        timingAfterFence = {.timeOnDevice = deviceEnd - deviceStart,
                            .timeInDriver = driverEnd - driverStartAfterFence};
        VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << timingSinceLaunch;
        VLOG(DRIVER) << "executeFenced timingAfterFence = " << timingAfterFence;
    }

    // The timings are captured by value so the callback stays valid after this
    // method returns.
    ExecuteFencedInfoCallback fencedExecutionCallback = [timingSinceLaunch, timingAfterFence]() {
        return std::make_pair(timingSinceLaunch, timingAfterFence);
    };
    return std::make_pair(SyncFence::createAsSignaled(), std::move(fencedExecutionCallback));
}
283 
createReusableExecution(const Request & request,MeasureTiming measure,const OptionalDuration & loopTimeoutDuration) const284 GeneralResult<SharedExecution> PreparedModel::createReusableExecution(
285         const Request& request, MeasureTiming measure,
286         const OptionalDuration& loopTimeoutDuration) const {
287     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
288                  "sample::PreparedModel::createReusableExecution");
289     return std::make_shared<DefaultExecution>(shared_from_this(), request, measure,
290                                               loopTimeoutDuration);
291 }
292 
configureExecutionBurst() const293 GeneralResult<SharedBurst> PreparedModel::configureExecutionBurst() const {
294     return std::make_shared<const Burst>(shared_from_this());
295 }
296 
getUnderlyingResource() const297 std::any PreparedModel::getUnderlyingResource() const {
298     return &kModel;
299 }
300 
301 }  // namespace android::nn::sample
302