• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "SampleDriver"
18 
19 #include "SampleDriver.h"
20 
21 #include "CpuExecutor.h"
22 #include "ExecutionBurstServer.h"
23 #include "HalInterfaces.h"
24 #include "Tracing.h"
25 #include "ValidateHal.h"
26 
27 #include <android-base/logging.h>
28 #include <hidl/LegacySupport.h>
29 #include <chrono>
30 #include <optional>
31 #include <thread>
32 
33 namespace android {
34 namespace nn {
35 namespace sample_driver {
36 
namespace {

using time_point = std::chrono::steady_clock::time_point;

// Current reading of the monotonic clock used for all driver/device timing.
auto now() {
    return std::chrono::steady_clock::now();
}

// Signed number of whole microseconds elapsed from `start` to `end`
// (duration_cast truncates toward zero).
auto microsecondsDuration(time_point end, time_point start) {
    return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
}

}  // namespace
50 
51 static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
52 
getCapabilities(getCapabilities_cb cb)53 Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
54     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
55                  "SampleDriver::getCapabilities");
56     return getCapabilities_1_2([&](ErrorStatus error, const V1_2::Capabilities& capabilities) {
57         // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
58         cb(error, convertToV1_0(capabilities));
59     });
60 }
61 
getCapabilities_1_1(getCapabilities_1_1_cb cb)62 Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
63     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
64                  "SampleDriver::getCapabilities_1_1");
65     return getCapabilities_1_2([&](ErrorStatus error, const V1_2::Capabilities& capabilities) {
66         // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
67         cb(error, convertToV1_1(capabilities));
68     });
69 }
70 
getVersionString(getVersionString_cb cb)71 Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
72     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
73                  "SampleDriver::getVersionString");
74     cb(ErrorStatus::NONE, "JUST_AN_EXAMPLE");
75     return Void();
76 }
77 
getType(getType_cb cb)78 Return<void> SampleDriver::getType(getType_cb cb) {
79     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
80     cb(ErrorStatus::NONE, V1_2::DeviceType::CPU);
81     return Void();
82 }
83 
getSupportedExtensions(getSupportedExtensions_cb cb)84 Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
85     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
86                  "SampleDriver::getSupportedExtensions");
87     cb(ErrorStatus::NONE, {/* No extensions. */});
88     return Void();
89 }
90 
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)91 Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
92                                                   getSupportedOperations_cb cb) {
93     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
94                  "SampleDriver::getSupportedOperations");
95     if (!validateModel(model)) {
96         VLOG(DRIVER) << "getSupportedOperations";
97         std::vector<bool> supported;
98         cb(ErrorStatus::INVALID_ARGUMENT, supported);
99         return Void();
100     }
101     return getSupportedOperations_1_2(convertToV1_2(model), cb);
102 }
103 
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)104 Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
105                                                       getSupportedOperations_1_1_cb cb) {
106     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
107                  "SampleDriver::getSupportedOperations_1_1");
108     if (!validateModel(model)) {
109         VLOG(DRIVER) << "getSupportedOperations_1_1";
110         std::vector<bool> supported;
111         cb(ErrorStatus::INVALID_ARGUMENT, supported);
112         return Void();
113     }
114     return getSupportedOperations_1_2(convertToV1_2(model), cb);
115 }
116 
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)117 Return<void> SampleDriver::getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) {
118     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
119                  "SampleDriver::getNumberOfCacheFilesNeeded");
120     // Set both numbers to be 0 for cache not supported.
121     cb(ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
122     return Void();
123 }
124 
notify(const sp<V1_0::IPreparedModelCallback> & callback,const ErrorStatus & status,const sp<SamplePreparedModel> & preparedModel)125 static void notify(const sp<V1_0::IPreparedModelCallback>& callback, const ErrorStatus& status,
126                    const sp<SamplePreparedModel>& preparedModel) {
127     callback->notify(status, preparedModel);
128 }
129 
notify(const sp<V1_2::IPreparedModelCallback> & callback,const ErrorStatus & status,const sp<SamplePreparedModel> & preparedModel)130 static void notify(const sp<V1_2::IPreparedModelCallback>& callback, const ErrorStatus& status,
131                    const sp<SamplePreparedModel>& preparedModel) {
132     callback->notify_1_2(status, preparedModel);
133 }
134 
135 template <typename T_Model, typename T_IPreparedModelCallback>
prepareModelBase(const T_Model & model,const SampleDriver * driver,ExecutionPreference preference,const sp<T_IPreparedModelCallback> & callback)136 Return<ErrorStatus> prepareModelBase(const T_Model& model, const SampleDriver* driver,
137                                      ExecutionPreference preference,
138                                      const sp<T_IPreparedModelCallback>& callback) {
139     if (callback.get() == nullptr) {
140         LOG(ERROR) << "invalid callback passed to prepareModelBase";
141         return ErrorStatus::INVALID_ARGUMENT;
142     }
143     if (VLOG_IS_ON(DRIVER)) {
144         VLOG(DRIVER) << "prepareModelBase";
145         logModelToInfo(model);
146     }
147     if (!validateModel(model) || !validateExecutionPreference(preference)) {
148         notify(callback, ErrorStatus::INVALID_ARGUMENT, nullptr);
149         return ErrorStatus::INVALID_ARGUMENT;
150     }
151 
152     // TODO: make asynchronous later
153     sp<SamplePreparedModel> preparedModel = new SamplePreparedModel(convertToV1_2(model), driver);
154     if (!preparedModel->initialize()) {
155         notify(callback, ErrorStatus::INVALID_ARGUMENT, nullptr);
156         return ErrorStatus::INVALID_ARGUMENT;
157     }
158     notify(callback, ErrorStatus::NONE, preparedModel);
159     return ErrorStatus::NONE;
160 }
161 
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)162 Return<ErrorStatus> SampleDriver::prepareModel(const V1_0::Model& model,
163                                                const sp<V1_0::IPreparedModelCallback>& callback) {
164     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
165     return prepareModelBase(model, this, ExecutionPreference::FAST_SINGLE_ANSWER, callback);
166 }
167 
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)168 Return<ErrorStatus> SampleDriver::prepareModel_1_1(
169         const V1_1::Model& model, ExecutionPreference preference,
170         const sp<V1_0::IPreparedModelCallback>& callback) {
171     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
172     return prepareModelBase(model, this, preference, callback);
173 }
174 
prepareModel_1_2(const V1_2::Model & model,ExecutionPreference preference,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const HidlToken &,const sp<V1_2::IPreparedModelCallback> & callback)175 Return<ErrorStatus> SampleDriver::prepareModel_1_2(
176         const V1_2::Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>&,
177         const hidl_vec<hidl_handle>&, const HidlToken&,
178         const sp<V1_2::IPreparedModelCallback>& callback) {
179     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
180     return prepareModelBase(model, this, preference, callback);
181 }
182 
prepareModelFromCache(const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const HidlToken &,const sp<V1_2::IPreparedModelCallback> & callback)183 Return<ErrorStatus> SampleDriver::prepareModelFromCache(
184         const hidl_vec<hidl_handle>&, const hidl_vec<hidl_handle>&, const HidlToken&,
185         const sp<V1_2::IPreparedModelCallback>& callback) {
186     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
187                  "SampleDriver::prepareModelFromCache");
188     callback->notify_1_2(ErrorStatus::GENERAL_FAILURE, nullptr);
189     return ErrorStatus::GENERAL_FAILURE;
190 }
191 
getStatus()192 Return<DeviceStatus> SampleDriver::getStatus() {
193     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED,
194                  "SampleDriver::getStatus");
195     VLOG(DRIVER) << "getStatus()";
196     return DeviceStatus::AVAILABLE;
197 }
198 
run()199 int SampleDriver::run() {
200     android::hardware::configureRpcThreadpool(4, true);
201     if (registerAsService(mName) != android::OK) {
202         LOG(ERROR) << "Could not register service";
203         return 1;
204     }
205     android::hardware::joinRpcThreadpool();
206     LOG(ERROR) << "Service exited!";
207     return 1;
208 }
209 
initialize()210 bool SamplePreparedModel::initialize() {
211     return setRunTimePoolInfosFromHidlMemories(&mPoolInfos, mModel.pools);
212 }
213 
notify(const sp<V1_0::IExecutionCallback> & callback,const ErrorStatus & status,const hidl_vec<OutputShape> &,Timing)214 static Return<void> notify(const sp<V1_0::IExecutionCallback>& callback, const ErrorStatus& status,
215                            const hidl_vec<OutputShape>&, Timing) {
216     return callback->notify(status);
217 }
218 
notify(const sp<V1_2::IExecutionCallback> & callback,const ErrorStatus & status,const hidl_vec<OutputShape> & outputShapes,Timing timing)219 static Return<void> notify(const sp<V1_2::IExecutionCallback>& callback, const ErrorStatus& status,
220                            const hidl_vec<OutputShape>& outputShapes, Timing timing) {
221     return callback->notify_1_2(status, outputShapes, timing);
222 }
223 
224 // TODO(xusongw): Let callback notify actual output shape once dynamic output shape
225 //                is supported in CpuExecutor.
226 template <typename T_IExecutionCallback>
asyncExecute(const Request & request,MeasureTiming measure,time_point driverStart,const Model & model,const SampleDriver & driver,const std::vector<RunTimePoolInfo> & poolInfos,const sp<T_IExecutionCallback> & callback)227 void asyncExecute(const Request& request, MeasureTiming measure, time_point driverStart,
228                   const Model& model, const SampleDriver& driver,
229                   const std::vector<RunTimePoolInfo>& poolInfos,
230                   const sp<T_IExecutionCallback>& callback) {
231     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
232                  "SampleDriver::asyncExecute");
233     std::vector<RunTimePoolInfo> requestPoolInfos;
234     if (!setRunTimePoolInfosFromHidlMemories(&requestPoolInfos, request.pools)) {
235         notify(callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
236         return;
237     }
238 
239     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
240                         "SampleDriver::asyncExecute");
241     CpuExecutor executor = driver.getExecutor();
242     time_point driverEnd, deviceStart, deviceEnd;
243     if (measure == MeasureTiming::YES) deviceStart = now();
244     int n = executor.run(model, request, poolInfos, requestPoolInfos);
245     if (measure == MeasureTiming::YES) deviceEnd = now();
246     VLOG(DRIVER) << "executor.run returned " << n;
247     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
248     hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
249     Return<void> returned;
250     if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
251         driverEnd = now();
252         Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
253                          .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
254         VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
255         returned = notify(callback, executionStatus, outputShapes, timing);
256     } else {
257         returned = notify(callback, executionStatus, outputShapes, kNoTiming);
258     }
259     if (!returned.isOk()) {
260         LOG(ERROR) << " hidl callback failed to return properly: " << returned.description();
261     }
262 }
263 
264 template <typename T_IExecutionCallback>
executeBase(const Request & request,MeasureTiming measure,const Model & model,const SampleDriver & driver,const std::vector<RunTimePoolInfo> & poolInfos,const sp<T_IExecutionCallback> & callback)265 Return<ErrorStatus> executeBase(const Request& request, MeasureTiming measure, const Model& model,
266                                 const SampleDriver& driver,
267                                 const std::vector<RunTimePoolInfo>& poolInfos,
268                                 const sp<T_IExecutionCallback>& callback) {
269     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
270     VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
271 
272     time_point driverStart;
273     if (measure == MeasureTiming::YES) driverStart = now();
274 
275     if (callback.get() == nullptr) {
276         LOG(ERROR) << "invalid callback passed to executeBase";
277         return ErrorStatus::INVALID_ARGUMENT;
278     }
279     if (!validateRequest(request, model)) {
280         notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
281         return ErrorStatus::INVALID_ARGUMENT;
282     }
283 
284     // This thread is intentionally detached because the sample driver service
285     // is expected to live forever.
286     std::thread([&model, &driver, &poolInfos, request, measure, driverStart, callback] {
287         asyncExecute(request, measure, driverStart, model, driver, poolInfos, callback);
288     })
289             .detach();
290 
291     return ErrorStatus::NONE;
292 }
293 
execute(const Request & request,const sp<V1_0::IExecutionCallback> & callback)294 Return<ErrorStatus> SamplePreparedModel::execute(const Request& request,
295                                                  const sp<V1_0::IExecutionCallback>& callback) {
296     return executeBase(request, MeasureTiming::NO, mModel, *mDriver, mPoolInfos, callback);
297 }
298 
execute_1_2(const Request & request,MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)299 Return<ErrorStatus> SamplePreparedModel::execute_1_2(const Request& request, MeasureTiming measure,
300                                                      const sp<V1_2::IExecutionCallback>& callback) {
301     return executeBase(request, measure, mModel, *mDriver, mPoolInfos, callback);
302 }
303 
executeSynchronously(const Request & request,MeasureTiming measure,executeSynchronously_cb cb)304 Return<void> SamplePreparedModel::executeSynchronously(const Request& request,
305                                                        MeasureTiming measure,
306                                                        executeSynchronously_cb cb) {
307     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
308                  "SampleDriver::executeSynchronously");
309     VLOG(DRIVER) << "executeSynchronously(" << SHOW_IF_DEBUG(toString(request)) << ")";
310 
311     time_point driverStart, driverEnd, deviceStart, deviceEnd;
312     if (measure == MeasureTiming::YES) driverStart = now();
313 
314     if (!validateRequest(request, mModel)) {
315         cb(ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
316         return Void();
317     }
318 
319     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
320                         "SampleDriver::executeSynchronously");
321     std::vector<RunTimePoolInfo> requestPoolInfos;
322     if (!setRunTimePoolInfosFromHidlMemories(&requestPoolInfos, request.pools)) {
323         cb(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
324         return Void();
325     }
326 
327     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
328                         "SampleDriver::executeSynchronously");
329     CpuExecutor executor = mDriver->getExecutor();
330     if (measure == MeasureTiming::YES) deviceStart = now();
331     int n = executor.run(mModel, request, mPoolInfos, requestPoolInfos);
332     if (measure == MeasureTiming::YES) deviceEnd = now();
333     VLOG(DRIVER) << "executor.run returned " << n;
334     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
335     hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
336     if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
337         driverEnd = now();
338         Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
339                          .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
340         VLOG(DRIVER) << "executeSynchronously timing = " << toString(timing);
341         cb(executionStatus, outputShapes, timing);
342     } else {
343         cb(executionStatus, outputShapes, kNoTiming);
344     }
345     return Void();
346 }
347 
348 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
349 // the mapping until either (1) the memory is freed in the runtime, or (2) the
350 // burst object is destroyed. This allows for subsequent executions operating on
351 // pools that have been used before to reuse the mapping instead of mapping and
352 // unmapping the memory on each execution.
353 class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
354    public:
BurstExecutorWithCache(const Model & model,const SampleDriver * driver,const std::vector<RunTimePoolInfo> & poolInfos)355     BurstExecutorWithCache(const Model& model, const SampleDriver* driver,
356                            const std::vector<RunTimePoolInfo>& poolInfos)
357         : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
358 
isCacheEntryPresent(int32_t slot) const359     bool isCacheEntryPresent(int32_t slot) const override {
360         const auto it = mMemoryCache.find(slot);
361         return (it != mMemoryCache.end()) && it->second.has_value();
362     }
363 
addCacheEntry(const hidl_memory & memory,int32_t slot)364     void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
365         mMemoryCache[slot] = RunTimePoolInfo::createFromHidlMemory(memory);
366     }
367 
removeCacheEntry(int32_t slot)368     void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
369 
execute(const Request & request,const std::vector<int32_t> & slots,MeasureTiming measure)370     std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
371             const Request& request, const std::vector<int32_t>& slots,
372             MeasureTiming measure) override {
373         NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
374                      "BurstExecutorWithCache::execute");
375 
376         time_point driverStart, driverEnd, deviceStart, deviceEnd;
377         if (measure == MeasureTiming::YES) driverStart = now();
378 
379         // ensure all relevant pools are valid
380         if (!std::all_of(slots.begin(), slots.end(),
381                          [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
382             return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
383         }
384 
385         // finish the request object (for validation)
386         hidl_vec<hidl_memory> pools(slots.size());
387         std::transform(slots.begin(), slots.end(), pools.begin(),
388                        [this](int32_t slot) { return mMemoryCache[slot]->getHidlMemory(); });
389         Request fullRequest = request;
390         fullRequest.pools = std::move(pools);
391 
392         // validate request object against the model
393         if (!validateRequest(fullRequest, mModel)) {
394             return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
395         }
396 
397         // select relevant entries from cache
398         std::vector<RunTimePoolInfo> requestPoolInfos;
399         requestPoolInfos.reserve(slots.size());
400         std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
401                        [this](int32_t slot) { return *mMemoryCache[slot]; });
402 
403         // execution
404         CpuExecutor executor = mDriver->getExecutor();
405         if (measure == MeasureTiming::YES) deviceStart = now();
406         int n = executor.run(mModel, request, mModelPoolInfos, requestPoolInfos);
407         if (measure == MeasureTiming::YES) deviceEnd = now();
408         VLOG(DRIVER) << "executor.run returned " << n;
409         ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
410         hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
411         if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
412             driverEnd = now();
413             Timing timing = {
414                     .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
415                     .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
416             VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
417             return std::make_tuple(executionStatus, outputShapes, timing);
418         } else {
419             return std::make_tuple(executionStatus, outputShapes, kNoTiming);
420         }
421     }
422 
423    private:
424     const Model mModel;
425     const SampleDriver* const mDriver;
426     const std::vector<RunTimePoolInfo> mModelPoolInfos;
427     std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache;  // cached requestPoolInfos
428 };
429 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)430 Return<void> SamplePreparedModel::configureExecutionBurst(
431         const sp<V1_2::IBurstCallback>& callback,
432         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
433         const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
434         configureExecutionBurst_cb cb) {
435     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
436                  "SampleDriver::configureExecutionBurst");
437 
438     // Alternatively, the burst could be configured via:
439     // const sp<V1_2::IBurstContext> burst =
440     //         ExecutionBurstServer::create(callback, requestChannel,
441     //                                      resultChannel, this);
442     //
443     // However, this alternative representation does not include a memory map
444     // caching optimization, and adds overhead.
445     const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
446             std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
447     const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
448             callback, requestChannel, resultChannel, executorWithCache);
449 
450     if (burst == nullptr) {
451         cb(ErrorStatus::GENERAL_FAILURE, {});
452     } else {
453         cb(ErrorStatus::NONE, burst);
454     }
455 
456     return Void();
457 }
458 
459 } // namespace sample_driver
460 } // namespace nn
461 } // namespace android
462