1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "SampleDriver"
18
19 #include "SampleDriver.h"
20
21 #include "CpuExecutor.h"
22 #include "ExecutionBurstServer.h"
23 #include "HalInterfaces.h"
24 #include "Tracing.h"
25 #include "ValidateHal.h"
26
27 #include <android-base/logging.h>
28 #include <hidl/LegacySupport.h>
29 #include <chrono>
30 #include <optional>
31 #include <thread>
32
33 namespace android {
34 namespace nn {
35 namespace sample_driver {
36
namespace {

// Clock readings used for every timing measurement in this file.
using time_point = std::chrono::steady_clock::time_point;

// Returns the current reading of the monotonic steady clock.
auto now() {
    return std::chrono::steady_clock::now();
}

// Returns the whole number of microseconds from `start` to `end`, truncated toward zero.
// The result is negative if `end` precedes `start`.
auto microsecondsDuration(time_point end, time_point start) {
    const auto elapsed = end - start;
    return std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
}

}  // namespace
50
51 static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
52
getCapabilities(getCapabilities_cb cb)53 Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
54 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
55 "SampleDriver::getCapabilities");
56 return getCapabilities_1_2([&](ErrorStatus error, const V1_2::Capabilities& capabilities) {
57 // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
58 cb(error, convertToV1_0(capabilities));
59 });
60 }
61
getCapabilities_1_1(getCapabilities_1_1_cb cb)62 Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
63 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
64 "SampleDriver::getCapabilities_1_1");
65 return getCapabilities_1_2([&](ErrorStatus error, const V1_2::Capabilities& capabilities) {
66 // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
67 cb(error, convertToV1_1(capabilities));
68 });
69 }
70
getVersionString(getVersionString_cb cb)71 Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
72 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
73 "SampleDriver::getVersionString");
74 cb(ErrorStatus::NONE, "JUST_AN_EXAMPLE");
75 return Void();
76 }
77
getType(getType_cb cb)78 Return<void> SampleDriver::getType(getType_cb cb) {
79 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
80 cb(ErrorStatus::NONE, V1_2::DeviceType::CPU);
81 return Void();
82 }
83
getSupportedExtensions(getSupportedExtensions_cb cb)84 Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
85 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
86 "SampleDriver::getSupportedExtensions");
87 cb(ErrorStatus::NONE, {/* No extensions. */});
88 return Void();
89 }
90
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)91 Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
92 getSupportedOperations_cb cb) {
93 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
94 "SampleDriver::getSupportedOperations");
95 if (!validateModel(model)) {
96 VLOG(DRIVER) << "getSupportedOperations";
97 std::vector<bool> supported;
98 cb(ErrorStatus::INVALID_ARGUMENT, supported);
99 return Void();
100 }
101 return getSupportedOperations_1_2(convertToV1_2(model), cb);
102 }
103
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)104 Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
105 getSupportedOperations_1_1_cb cb) {
106 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
107 "SampleDriver::getSupportedOperations_1_1");
108 if (!validateModel(model)) {
109 VLOG(DRIVER) << "getSupportedOperations_1_1";
110 std::vector<bool> supported;
111 cb(ErrorStatus::INVALID_ARGUMENT, supported);
112 return Void();
113 }
114 return getSupportedOperations_1_2(convertToV1_2(model), cb);
115 }
116
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)117 Return<void> SampleDriver::getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) {
118 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
119 "SampleDriver::getNumberOfCacheFilesNeeded");
120 // Set both numbers to be 0 for cache not supported.
121 cb(ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
122 return Void();
123 }
124
notify(const sp<V1_0::IPreparedModelCallback> & callback,const ErrorStatus & status,const sp<SamplePreparedModel> & preparedModel)125 static void notify(const sp<V1_0::IPreparedModelCallback>& callback, const ErrorStatus& status,
126 const sp<SamplePreparedModel>& preparedModel) {
127 callback->notify(status, preparedModel);
128 }
129
notify(const sp<V1_2::IPreparedModelCallback> & callback,const ErrorStatus & status,const sp<SamplePreparedModel> & preparedModel)130 static void notify(const sp<V1_2::IPreparedModelCallback>& callback, const ErrorStatus& status,
131 const sp<SamplePreparedModel>& preparedModel) {
132 callback->notify_1_2(status, preparedModel);
133 }
134
135 template <typename T_Model, typename T_IPreparedModelCallback>
prepareModelBase(const T_Model & model,const SampleDriver * driver,ExecutionPreference preference,const sp<T_IPreparedModelCallback> & callback)136 Return<ErrorStatus> prepareModelBase(const T_Model& model, const SampleDriver* driver,
137 ExecutionPreference preference,
138 const sp<T_IPreparedModelCallback>& callback) {
139 if (callback.get() == nullptr) {
140 LOG(ERROR) << "invalid callback passed to prepareModelBase";
141 return ErrorStatus::INVALID_ARGUMENT;
142 }
143 if (VLOG_IS_ON(DRIVER)) {
144 VLOG(DRIVER) << "prepareModelBase";
145 logModelToInfo(model);
146 }
147 if (!validateModel(model) || !validateExecutionPreference(preference)) {
148 notify(callback, ErrorStatus::INVALID_ARGUMENT, nullptr);
149 return ErrorStatus::INVALID_ARGUMENT;
150 }
151
152 // TODO: make asynchronous later
153 sp<SamplePreparedModel> preparedModel = new SamplePreparedModel(convertToV1_2(model), driver);
154 if (!preparedModel->initialize()) {
155 notify(callback, ErrorStatus::INVALID_ARGUMENT, nullptr);
156 return ErrorStatus::INVALID_ARGUMENT;
157 }
158 notify(callback, ErrorStatus::NONE, preparedModel);
159 return ErrorStatus::NONE;
160 }
161
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)162 Return<ErrorStatus> SampleDriver::prepareModel(const V1_0::Model& model,
163 const sp<V1_0::IPreparedModelCallback>& callback) {
164 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
165 return prepareModelBase(model, this, ExecutionPreference::FAST_SINGLE_ANSWER, callback);
166 }
167
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)168 Return<ErrorStatus> SampleDriver::prepareModel_1_1(
169 const V1_1::Model& model, ExecutionPreference preference,
170 const sp<V1_0::IPreparedModelCallback>& callback) {
171 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
172 return prepareModelBase(model, this, preference, callback);
173 }
174
prepareModel_1_2(const V1_2::Model & model,ExecutionPreference preference,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const HidlToken &,const sp<V1_2::IPreparedModelCallback> & callback)175 Return<ErrorStatus> SampleDriver::prepareModel_1_2(
176 const V1_2::Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>&,
177 const hidl_vec<hidl_handle>&, const HidlToken&,
178 const sp<V1_2::IPreparedModelCallback>& callback) {
179 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
180 return prepareModelBase(model, this, preference, callback);
181 }
182
prepareModelFromCache(const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const HidlToken &,const sp<V1_2::IPreparedModelCallback> & callback)183 Return<ErrorStatus> SampleDriver::prepareModelFromCache(
184 const hidl_vec<hidl_handle>&, const hidl_vec<hidl_handle>&, const HidlToken&,
185 const sp<V1_2::IPreparedModelCallback>& callback) {
186 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
187 "SampleDriver::prepareModelFromCache");
188 callback->notify_1_2(ErrorStatus::GENERAL_FAILURE, nullptr);
189 return ErrorStatus::GENERAL_FAILURE;
190 }
191
getStatus()192 Return<DeviceStatus> SampleDriver::getStatus() {
193 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED,
194 "SampleDriver::getStatus");
195 VLOG(DRIVER) << "getStatus()";
196 return DeviceStatus::AVAILABLE;
197 }
198
run()199 int SampleDriver::run() {
200 android::hardware::configureRpcThreadpool(4, true);
201 if (registerAsService(mName) != android::OK) {
202 LOG(ERROR) << "Could not register service";
203 return 1;
204 }
205 android::hardware::joinRpcThreadpool();
206 LOG(ERROR) << "Service exited!";
207 return 1;
208 }
209
initialize()210 bool SamplePreparedModel::initialize() {
211 return setRunTimePoolInfosFromHidlMemories(&mPoolInfos, mModel.pools);
212 }
213
notify(const sp<V1_0::IExecutionCallback> & callback,const ErrorStatus & status,const hidl_vec<OutputShape> &,Timing)214 static Return<void> notify(const sp<V1_0::IExecutionCallback>& callback, const ErrorStatus& status,
215 const hidl_vec<OutputShape>&, Timing) {
216 return callback->notify(status);
217 }
218
notify(const sp<V1_2::IExecutionCallback> & callback,const ErrorStatus & status,const hidl_vec<OutputShape> & outputShapes,Timing timing)219 static Return<void> notify(const sp<V1_2::IExecutionCallback>& callback, const ErrorStatus& status,
220 const hidl_vec<OutputShape>& outputShapes, Timing timing) {
221 return callback->notify_1_2(status, outputShapes, timing);
222 }
223
224 // TODO(xusongw): Let callback notify actual output shape once dynamic output shape
225 // is supported in CpuExecutor.
226 template <typename T_IExecutionCallback>
asyncExecute(const Request & request,MeasureTiming measure,time_point driverStart,const Model & model,const SampleDriver & driver,const std::vector<RunTimePoolInfo> & poolInfos,const sp<T_IExecutionCallback> & callback)227 void asyncExecute(const Request& request, MeasureTiming measure, time_point driverStart,
228 const Model& model, const SampleDriver& driver,
229 const std::vector<RunTimePoolInfo>& poolInfos,
230 const sp<T_IExecutionCallback>& callback) {
231 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
232 "SampleDriver::asyncExecute");
233 std::vector<RunTimePoolInfo> requestPoolInfos;
234 if (!setRunTimePoolInfosFromHidlMemories(&requestPoolInfos, request.pools)) {
235 notify(callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
236 return;
237 }
238
239 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
240 "SampleDriver::asyncExecute");
241 CpuExecutor executor = driver.getExecutor();
242 time_point driverEnd, deviceStart, deviceEnd;
243 if (measure == MeasureTiming::YES) deviceStart = now();
244 int n = executor.run(model, request, poolInfos, requestPoolInfos);
245 if (measure == MeasureTiming::YES) deviceEnd = now();
246 VLOG(DRIVER) << "executor.run returned " << n;
247 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
248 hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
249 Return<void> returned;
250 if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
251 driverEnd = now();
252 Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
253 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
254 VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
255 returned = notify(callback, executionStatus, outputShapes, timing);
256 } else {
257 returned = notify(callback, executionStatus, outputShapes, kNoTiming);
258 }
259 if (!returned.isOk()) {
260 LOG(ERROR) << " hidl callback failed to return properly: " << returned.description();
261 }
262 }
263
264 template <typename T_IExecutionCallback>
executeBase(const Request & request,MeasureTiming measure,const Model & model,const SampleDriver & driver,const std::vector<RunTimePoolInfo> & poolInfos,const sp<T_IExecutionCallback> & callback)265 Return<ErrorStatus> executeBase(const Request& request, MeasureTiming measure, const Model& model,
266 const SampleDriver& driver,
267 const std::vector<RunTimePoolInfo>& poolInfos,
268 const sp<T_IExecutionCallback>& callback) {
269 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
270 VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
271
272 time_point driverStart;
273 if (measure == MeasureTiming::YES) driverStart = now();
274
275 if (callback.get() == nullptr) {
276 LOG(ERROR) << "invalid callback passed to executeBase";
277 return ErrorStatus::INVALID_ARGUMENT;
278 }
279 if (!validateRequest(request, model)) {
280 notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
281 return ErrorStatus::INVALID_ARGUMENT;
282 }
283
284 // This thread is intentionally detached because the sample driver service
285 // is expected to live forever.
286 std::thread([&model, &driver, &poolInfos, request, measure, driverStart, callback] {
287 asyncExecute(request, measure, driverStart, model, driver, poolInfos, callback);
288 })
289 .detach();
290
291 return ErrorStatus::NONE;
292 }
293
execute(const Request & request,const sp<V1_0::IExecutionCallback> & callback)294 Return<ErrorStatus> SamplePreparedModel::execute(const Request& request,
295 const sp<V1_0::IExecutionCallback>& callback) {
296 return executeBase(request, MeasureTiming::NO, mModel, *mDriver, mPoolInfos, callback);
297 }
298
execute_1_2(const Request & request,MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)299 Return<ErrorStatus> SamplePreparedModel::execute_1_2(const Request& request, MeasureTiming measure,
300 const sp<V1_2::IExecutionCallback>& callback) {
301 return executeBase(request, measure, mModel, *mDriver, mPoolInfos, callback);
302 }
303
executeSynchronously(const Request & request,MeasureTiming measure,executeSynchronously_cb cb)304 Return<void> SamplePreparedModel::executeSynchronously(const Request& request,
305 MeasureTiming measure,
306 executeSynchronously_cb cb) {
307 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
308 "SampleDriver::executeSynchronously");
309 VLOG(DRIVER) << "executeSynchronously(" << SHOW_IF_DEBUG(toString(request)) << ")";
310
311 time_point driverStart, driverEnd, deviceStart, deviceEnd;
312 if (measure == MeasureTiming::YES) driverStart = now();
313
314 if (!validateRequest(request, mModel)) {
315 cb(ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
316 return Void();
317 }
318
319 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
320 "SampleDriver::executeSynchronously");
321 std::vector<RunTimePoolInfo> requestPoolInfos;
322 if (!setRunTimePoolInfosFromHidlMemories(&requestPoolInfos, request.pools)) {
323 cb(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
324 return Void();
325 }
326
327 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
328 "SampleDriver::executeSynchronously");
329 CpuExecutor executor = mDriver->getExecutor();
330 if (measure == MeasureTiming::YES) deviceStart = now();
331 int n = executor.run(mModel, request, mPoolInfos, requestPoolInfos);
332 if (measure == MeasureTiming::YES) deviceEnd = now();
333 VLOG(DRIVER) << "executor.run returned " << n;
334 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
335 hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
336 if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
337 driverEnd = now();
338 Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
339 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
340 VLOG(DRIVER) << "executeSynchronously timing = " << toString(timing);
341 cb(executionStatus, outputShapes, timing);
342 } else {
343 cb(executionStatus, outputShapes, kNoTiming);
344 }
345 return Void();
346 }
347
348 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
349 // the mapping until either (1) the memory is freed in the runtime, or (2) the
350 // burst object is destroyed. This allows for subsequent executions operating on
351 // pools that have been used before to reuse the mapping instead of mapping and
352 // unmapping the memory on each execution.
353 class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
354 public:
BurstExecutorWithCache(const Model & model,const SampleDriver * driver,const std::vector<RunTimePoolInfo> & poolInfos)355 BurstExecutorWithCache(const Model& model, const SampleDriver* driver,
356 const std::vector<RunTimePoolInfo>& poolInfos)
357 : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
358
isCacheEntryPresent(int32_t slot) const359 bool isCacheEntryPresent(int32_t slot) const override {
360 const auto it = mMemoryCache.find(slot);
361 return (it != mMemoryCache.end()) && it->second.has_value();
362 }
363
addCacheEntry(const hidl_memory & memory,int32_t slot)364 void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
365 mMemoryCache[slot] = RunTimePoolInfo::createFromHidlMemory(memory);
366 }
367
removeCacheEntry(int32_t slot)368 void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
369
execute(const Request & request,const std::vector<int32_t> & slots,MeasureTiming measure)370 std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
371 const Request& request, const std::vector<int32_t>& slots,
372 MeasureTiming measure) override {
373 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
374 "BurstExecutorWithCache::execute");
375
376 time_point driverStart, driverEnd, deviceStart, deviceEnd;
377 if (measure == MeasureTiming::YES) driverStart = now();
378
379 // ensure all relevant pools are valid
380 if (!std::all_of(slots.begin(), slots.end(),
381 [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
382 return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
383 }
384
385 // finish the request object (for validation)
386 hidl_vec<hidl_memory> pools(slots.size());
387 std::transform(slots.begin(), slots.end(), pools.begin(),
388 [this](int32_t slot) { return mMemoryCache[slot]->getHidlMemory(); });
389 Request fullRequest = request;
390 fullRequest.pools = std::move(pools);
391
392 // validate request object against the model
393 if (!validateRequest(fullRequest, mModel)) {
394 return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
395 }
396
397 // select relevant entries from cache
398 std::vector<RunTimePoolInfo> requestPoolInfos;
399 requestPoolInfos.reserve(slots.size());
400 std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
401 [this](int32_t slot) { return *mMemoryCache[slot]; });
402
403 // execution
404 CpuExecutor executor = mDriver->getExecutor();
405 if (measure == MeasureTiming::YES) deviceStart = now();
406 int n = executor.run(mModel, request, mModelPoolInfos, requestPoolInfos);
407 if (measure == MeasureTiming::YES) deviceEnd = now();
408 VLOG(DRIVER) << "executor.run returned " << n;
409 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
410 hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
411 if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
412 driverEnd = now();
413 Timing timing = {
414 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
415 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
416 VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
417 return std::make_tuple(executionStatus, outputShapes, timing);
418 } else {
419 return std::make_tuple(executionStatus, outputShapes, kNoTiming);
420 }
421 }
422
423 private:
424 const Model mModel;
425 const SampleDriver* const mDriver;
426 const std::vector<RunTimePoolInfo> mModelPoolInfos;
427 std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache; // cached requestPoolInfos
428 };
429
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)430 Return<void> SamplePreparedModel::configureExecutionBurst(
431 const sp<V1_2::IBurstCallback>& callback,
432 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
433 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
434 configureExecutionBurst_cb cb) {
435 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
436 "SampleDriver::configureExecutionBurst");
437
438 // Alternatively, the burst could be configured via:
439 // const sp<V1_2::IBurstContext> burst =
440 // ExecutionBurstServer::create(callback, requestChannel,
441 // resultChannel, this);
442 //
443 // However, this alternative representation does not include a memory map
444 // caching optimization, and adds overhead.
445 const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
446 std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
447 const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
448 callback, requestChannel, resultChannel, executorWithCache);
449
450 if (burst == nullptr) {
451 cb(ErrorStatus::GENERAL_FAILURE, {});
452 } else {
453 cb(ErrorStatus::NONE, burst);
454 }
455
456 return Void();
457 }
458
459 } // namespace sample_driver
460 } // namespace nn
461 } // namespace android
462