• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "SampleDriver"
18 
19 #include "SampleDriver.h"
20 
21 #include <android-base/logging.h>
22 #include <android-base/properties.h>
23 #include <android/sync.h>
24 #include <hidl/LegacySupport.h>
25 
26 #include <algorithm>
27 #include <chrono>
28 #include <map>
29 #include <memory>
30 #include <optional>
31 #include <set>
32 #include <thread>
33 #include <tuple>
34 #include <utility>
35 #include <vector>
36 
37 #include "BufferTracker.h"
38 #include "CpuExecutor.h"
39 #include "ExecutionBurstServer.h"
40 #include "HalInterfaces.h"
41 #include "SampleDriverUtils.h"
42 #include "Tracing.h"
43 #include "ValidateHal.h"
44 
45 namespace android {
46 namespace nn {
47 namespace sample_driver {
48 
49 namespace {
50 
51 using namespace hal;
52 
53 using time_point = std::chrono::steady_clock::time_point;
54 
// Returns the current reading of the steady (monotonic) clock.
auto now() {
    const auto current = std::chrono::steady_clock::now();
    return current;
}
58 
// Returns the signed number of whole microseconds elapsed from `start` to `end`
// (truncated toward zero by duration_cast).
auto microsecondsDuration(std::chrono::steady_clock::time_point end,
                          std::chrono::steady_clock::time_point start) {
    const auto elapsed = end - start;
    const auto elapsedMicros = std::chrono::duration_cast<std::chrono::microseconds>(elapsed);
    return elapsedMicros.count();
}
62 
63 }  // namespace
64 
65 static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
66 
getCapabilities(getCapabilities_cb cb)67 Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
68     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
69                  "SampleDriver::getCapabilities");
70     return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
71         // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
72         cb(convertToV1_0(error), convertToV1_0(capabilities));
73     });
74 }
75 
getCapabilities_1_1(getCapabilities_1_1_cb cb)76 Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
77     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
78                  "SampleDriver::getCapabilities_1_1");
79     return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
80         // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
81         cb(convertToV1_0(error), convertToV1_1(capabilities));
82     });
83 }
84 
getCapabilities_1_2(getCapabilities_1_2_cb cb)85 Return<void> SampleDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
86     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
87                  "SampleDriver::getCapabilities_1_2");
88     return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
89         // TODO(dgross): Do we need to check compliantWithV1_2(capabilities)?
90         cb(convertToV1_0(error), convertToV1_2(capabilities));
91     });
92 }
93 
getVersionString(getVersionString_cb cb)94 Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
95     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
96                  "SampleDriver::getVersionString");
97     cb(V1_0::ErrorStatus::NONE, "JUST_AN_EXAMPLE");
98     return Void();
99 }
100 
getType(getType_cb cb)101 Return<void> SampleDriver::getType(getType_cb cb) {
102     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
103     cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
104     return Void();
105 }
106 
getSupportedExtensions(getSupportedExtensions_cb cb)107 Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
108     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
109                  "SampleDriver::getSupportedExtensions");
110     cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
111     return Void();
112 }
113 
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)114 Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
115                                                   getSupportedOperations_cb cb) {
116     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
117                  "SampleDriver::getSupportedOperations");
118     if (!validateModel(model)) {
119         VLOG(DRIVER) << "getSupportedOperations";
120         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
121         return Void();
122     }
123     return getSupportedOperations_1_3(convertToV1_3(model),
124                                       [&](ErrorStatus status, const hidl_vec<bool>& supported) {
125                                           cb(convertToV1_0(status), supported);
126                                       });
127 }
128 
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)129 Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
130                                                       getSupportedOperations_1_1_cb cb) {
131     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
132                  "SampleDriver::getSupportedOperations_1_1");
133     if (!validateModel(model)) {
134         VLOG(DRIVER) << "getSupportedOperations_1_1";
135         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
136         return Void();
137     }
138     return getSupportedOperations_1_3(convertToV1_3(model),
139                                       [&](ErrorStatus status, const hidl_vec<bool>& supported) {
140                                           cb(convertToV1_0(status), supported);
141                                       });
142 }
143 
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb cb)144 Return<void> SampleDriver::getSupportedOperations_1_2(const V1_2::Model& model,
145                                                       getSupportedOperations_1_2_cb cb) {
146     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
147                  "SampleDriver::getSupportedOperations_1_2");
148     if (!validateModel(model)) {
149         VLOG(DRIVER) << "getSupportedOperations_1_2";
150         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
151         return Void();
152     }
153     return getSupportedOperations_1_3(convertToV1_3(model),
154                                       [&](ErrorStatus status, const hidl_vec<bool>& supported) {
155                                           cb(convertToV1_0(status), supported);
156                                       });
157 }
158 
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)159 Return<void> SampleDriver::getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) {
160     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
161                  "SampleDriver::getNumberOfCacheFilesNeeded");
162     // Set both numbers to be 0 for cache not supported.
163     cb(V1_0::ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
164     return Void();
165 }
166 
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)167 Return<V1_0::ErrorStatus> SampleDriver::prepareModel(
168         const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) {
169     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
170     const ErrorStatus status = prepareModelBase(
171             model, this, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, {}, callback);
172     return convertToV1_0(status);
173 }
174 
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)175 Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_1(
176         const V1_1::Model& model, ExecutionPreference preference,
177         const sp<V1_0::IPreparedModelCallback>& callback) {
178     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
179     const ErrorStatus status =
180             prepareModelBase(model, this, preference, kDefaultPriority, {}, callback);
181     return convertToV1_0(status);
182 }
183 
prepareModel_1_2(const V1_2::Model & model,ExecutionPreference preference,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)184 Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_2(
185         const V1_2::Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>&,
186         const hidl_vec<hidl_handle>&, const CacheToken&,
187         const sp<V1_2::IPreparedModelCallback>& callback) {
188     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
189     const ErrorStatus status =
190             prepareModelBase(model, this, preference, kDefaultPriority, {}, callback);
191     return convertToV1_0(status);
192 }
193 
prepareModel_1_3(const V1_3::Model & model,ExecutionPreference preference,Priority priority,const OptionalTimePoint & deadline,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)194 Return<V1_3::ErrorStatus> SampleDriver::prepareModel_1_3(
195         const V1_3::Model& model, ExecutionPreference preference, Priority priority,
196         const OptionalTimePoint& deadline, const hidl_vec<hidl_handle>&,
197         const hidl_vec<hidl_handle>&, const CacheToken&,
198         const sp<V1_3::IPreparedModelCallback>& callback) {
199     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
200     return prepareModelBase(model, this, preference, priority, deadline, callback);
201 }
202 
prepareModelFromCache(const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)203 Return<V1_0::ErrorStatus> SampleDriver::prepareModelFromCache(
204         const hidl_vec<hidl_handle>&, const hidl_vec<hidl_handle>&, const CacheToken&,
205         const sp<V1_2::IPreparedModelCallback>& callback) {
206     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
207                  "SampleDriver::prepareModelFromCache");
208     notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
209     return V1_0::ErrorStatus::GENERAL_FAILURE;
210 }
211 
prepareModelFromCache_1_3(const OptionalTimePoint &,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)212 Return<ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
213         const OptionalTimePoint& /*deadline*/, const hidl_vec<hidl_handle>&,
214         const hidl_vec<hidl_handle>&, const CacheToken&,
215         const sp<V1_3::IPreparedModelCallback>& callback) {
216     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
217                  "SampleDriver::prepareModelFromCache_1_3");
218     notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
219     return ErrorStatus::GENERAL_FAILURE;
220 }
221 
getStatus()222 Return<DeviceStatus> SampleDriver::getStatus() {
223     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
224     VLOG(DRIVER) << "getStatus()";
225     return DeviceStatus::AVAILABLE;
226 }
227 
228 // Safely downcast an IPreparedModel object to SamplePreparedModel.
229 // This function will return nullptr if the IPreparedModel object is not originated from the sample
230 // driver process.
castToSamplePreparedModel(const sp<IPreparedModel> & preparedModel)231 static const SamplePreparedModel* castToSamplePreparedModel(
232         const sp<IPreparedModel>& preparedModel) {
233     if (preparedModel->isRemote()) {
234         return nullptr;
235     } else {
236         // This static_cast is safe because SamplePreparedModel is the only class that implements
237         // the IPreparedModel interface in the sample driver process.
238         return static_cast<const SamplePreparedModel*>(preparedModel.get());
239     }
240 }
241 
allocate(const V1_3::BufferDesc & desc,const hidl_vec<sp<V1_3::IPreparedModel>> & preparedModels,const hidl_vec<V1_3::BufferRole> & inputRoles,const hidl_vec<V1_3::BufferRole> & outputRoles,allocate_cb cb)242 Return<void> SampleDriver::allocate(const V1_3::BufferDesc& desc,
243                                     const hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
244                                     const hidl_vec<V1_3::BufferRole>& inputRoles,
245                                     const hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) {
246     constexpr uint32_t kInvalidBufferToken = 0;
247 
248     VLOG(DRIVER) << "SampleDriver::allocate";
249     std::set<PreparedModelRole> roles;
250     V1_3::Operand operand;
251     auto getModel = [](const sp<V1_3::IPreparedModel>& preparedModel) -> const V1_3::Model* {
252         const auto* samplePreparedModel = castToSamplePreparedModel(preparedModel);
253         if (samplePreparedModel == nullptr) {
254             LOG(ERROR) << "SampleDriver::allocate -- unknown remote IPreparedModel.";
255             return nullptr;
256         }
257         return samplePreparedModel->getModel();
258     };
259     if (!validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel, &roles,
260                             &operand)) {
261         LOG(ERROR) << "SampleDriver::allocate -- validation failed.";
262         cb(ErrorStatus::INVALID_ARGUMENT, nullptr, kInvalidBufferToken);
263         return Void();
264     }
265 
266     if (isExtensionOperandType(operand.type)) {
267         LOG(ERROR) << "SampleDriver::allocate -- does not support extension type.";
268         cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
269         return Void();
270     }
271 
272     // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
273     uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
274     VLOG(DRIVER) << "SampleDriver::allocate -- type = " << toString(operand.type)
275                  << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
276     if (size == 0) {
277         LOG(ERROR) << "SampleDriver::allocate -- does not support dynamic output shape.";
278         cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
279         return Void();
280     }
281 
282     auto bufferWrapper = ManagedBuffer::create(size, std::move(roles), std::move(operand));
283     if (bufferWrapper == nullptr) {
284         LOG(ERROR) << "SampleDriver::allocate -- not enough memory.";
285         cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
286         return Void();
287     }
288 
289     auto token = mBufferTracker->add(bufferWrapper);
290     if (token == nullptr) {
291         LOG(ERROR) << "SampleDriver::allocate -- BufferTracker returned invalid token.";
292         cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
293         return Void();
294     }
295 
296     const uint32_t tokenValue = token->get();
297     sp<SampleBuffer> sampleBuffer = new SampleBuffer(std::move(bufferWrapper), std::move(token));
298     VLOG(DRIVER) << "SampleDriver::allocate -- successfully allocates the requested memory";
299     cb(ErrorStatus::NONE, std::move(sampleBuffer), tokenValue);
300     return Void();
301 }
302 
run()303 int SampleDriver::run() {
304     android::hardware::configureRpcThreadpool(4, true);
305     if (registerAsService(mName) != android::OK) {
306         LOG(ERROR) << "Could not register service";
307         return 1;
308     }
309     android::hardware::joinRpcThreadpool();
310     LOG(ERROR) << "Service exited!";
311     return 1;
312 }
313 
copyRunTimePoolInfos(const RunTimePoolInfo & srcPool,const RunTimePoolInfo & dstPool)314 static void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
315     CHECK(srcPool.getBuffer() != nullptr);
316     CHECK(dstPool.getBuffer() != nullptr);
317     CHECK(srcPool.getSize() == dstPool.getSize());
318     std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
319     dstPool.flush();
320 }
321 
copyTo(const hidl_memory & dst)322 Return<ErrorStatus> SampleBuffer::copyTo(const hidl_memory& dst) {
323     const auto dstPool = RunTimePoolInfo::createFromHidlMemory(dst);
324     if (!dstPool.has_value()) {
325         LOG(ERROR) << "SampleBuffer::copyTo -- unable to map dst memory.";
326         return ErrorStatus::GENERAL_FAILURE;
327     }
328     const ErrorStatus validationStatus = kBuffer->validateCopyTo(dstPool->getSize());
329     if (validationStatus != ErrorStatus::NONE) {
330         return validationStatus;
331     }
332     const auto srcPool = kBuffer->createRunTimePoolInfo();
333     copyRunTimePoolInfos(srcPool, dstPool.value());
334     return ErrorStatus::NONE;
335 }
336 
copyFromInternal(const hidl_memory & src,const hidl_vec<uint32_t> & dimensions,const std::shared_ptr<ManagedBuffer> & bufferWrapper)337 static ErrorStatus copyFromInternal(const hidl_memory& src, const hidl_vec<uint32_t>& dimensions,
338                                     const std::shared_ptr<ManagedBuffer>& bufferWrapper) {
339     CHECK(bufferWrapper != nullptr);
340     const auto srcPool = RunTimePoolInfo::createFromHidlMemory(src);
341     if (!srcPool.has_value()) {
342         LOG(ERROR) << "SampleBuffer::copyFrom -- unable to map src memory.";
343         return ErrorStatus::GENERAL_FAILURE;
344     }
345     const ErrorStatus validationStatus =
346             bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize());
347     if (validationStatus != ErrorStatus::NONE) {
348         return validationStatus;
349     }
350     const auto dstPool = bufferWrapper->createRunTimePoolInfo();
351     copyRunTimePoolInfos(srcPool.value(), dstPool);
352     return ErrorStatus::NONE;
353 }
354 
copyFrom(const hidl_memory & src,const hidl_vec<uint32_t> & dimensions)355 Return<ErrorStatus> SampleBuffer::copyFrom(const hidl_memory& src,
356                                            const hidl_vec<uint32_t>& dimensions) {
357     const auto status = copyFromInternal(src, dimensions, kBuffer);
358     if (status == ErrorStatus::NONE) {
359         kBuffer->updateDimensions(dimensions);
360         kBuffer->setInitialized(true);
361     } else {
362         kBuffer->setInitialized(false);
363     }
364     return status;
365 }
366 
initialize()367 bool SamplePreparedModel::initialize() {
368     return setRunTimePoolInfosFromHidlMemories(&mPoolInfos, mModel.pools);
369 }
370 
371 static std::tuple<ErrorStatus, std::vector<RunTimePoolInfo>,
372                   std::vector<std::shared_ptr<ManagedBuffer>>>
createRunTimePoolInfos(const Request & request,const SampleDriver & driver,const SamplePreparedModel * preparedModel)373 createRunTimePoolInfos(const Request& request, const SampleDriver& driver,
374                        const SamplePreparedModel* preparedModel) {
375     std::vector<RunTimePoolInfo> requestPoolInfos;
376     std::vector<std::shared_ptr<ManagedBuffer>> bufferWrappers;
377     requestPoolInfos.reserve(request.pools.size());
378     bufferWrappers.reserve(request.pools.size());
379     for (uint32_t i = 0; i < request.pools.size(); i++) {
380         auto& pool = request.pools[i];
381         switch (pool.getDiscriminator()) {
382             case Request::MemoryPool::hidl_discriminator::hidlMemory: {
383                 auto buffer = RunTimePoolInfo::createFromHidlMemory(pool.hidlMemory());
384                 if (!buffer.has_value()) {
385                     LOG(ERROR) << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
386                     return {ErrorStatus::GENERAL_FAILURE, {}, {}};
387                 }
388                 requestPoolInfos.push_back(std::move(*buffer));
389                 bufferWrappers.push_back(nullptr);
390             } break;
391             case Request::MemoryPool::hidl_discriminator::token: {
392                 auto bufferWrapper = driver.getBufferTracker()->get(pool.token());
393                 if (bufferWrapper == nullptr) {
394                     return {ErrorStatus::INVALID_ARGUMENT, {}, {}};
395                 }
396                 const auto validationStatus =
397                         bufferWrapper->validateRequest(i, request, preparedModel);
398                 if (validationStatus != ErrorStatus::NONE) {
399                     return {validationStatus, {}, {}};
400                 }
401                 requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
402                 bufferWrappers.push_back(std::move(bufferWrapper));
403             } break;
404         }
405     }
406     return {ErrorStatus::NONE, std::move(requestPoolInfos), std::move(bufferWrappers)};
407 }
408 
updateDeviceMemories(ErrorStatus status,const Request & request,const std::vector<std::shared_ptr<ManagedBuffer>> & bufferWrappers,const hidl_vec<OutputShape> & outputShapes)409 static ErrorStatus updateDeviceMemories(
410         ErrorStatus status, const Request& request,
411         const std::vector<std::shared_ptr<ManagedBuffer>>& bufferWrappers,
412         const hidl_vec<OutputShape>& outputShapes) {
413     if (status == ErrorStatus::NONE) {
414         for (uint32_t i = 0; i < request.outputs.size(); i++) {
415             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
416             const auto& pool = request.pools[poolIndex];
417             if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
418                 if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
419                     return ErrorStatus::GENERAL_FAILURE;
420                 }
421             }
422         }
423         for (uint32_t i = 0; i < request.outputs.size(); i++) {
424             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
425             const auto& pool = request.pools[poolIndex];
426             if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
427                 bufferWrappers[poolIndex]->setInitialized(true);
428             }
429         }
430     } else if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
431         // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
432         // dimensions of the device memory are incorrectly specified. The driver should return
433         // GENERAL_FAILURE instead in this case.
434         for (uint32_t i = 0; i < request.outputs.size(); i++) {
435             const uint32_t poolIndex = request.outputs[i].location.poolIndex;
436             const auto& pool = request.pools[poolIndex];
437             if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
438                 if (!outputShapes[i].isSufficient) {
439                     LOG(ERROR) << "Invalid dimensions for output " << i
440                                << ": actual shape = " << toString(outputShapes[i].dimensions);
441                     return ErrorStatus::GENERAL_FAILURE;
442                 }
443             }
444         }
445     }
446     return ErrorStatus::NONE;
447 }
448 
449 template <typename T_IExecutionCallback>
asyncExecute(const Request & request,MeasureTiming measure,time_point driverStart,const Model & model,const SampleDriver & driver,const SamplePreparedModel * preparedModel,const std::vector<RunTimePoolInfo> & poolInfos,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,const sp<T_IExecutionCallback> & callback)450 void asyncExecute(const Request& request, MeasureTiming measure, time_point driverStart,
451                   const Model& model, const SampleDriver& driver,
452                   const SamplePreparedModel* preparedModel,
453                   const std::vector<RunTimePoolInfo>& poolInfos,
454                   const std::optional<Deadline>& deadline,
455                   const OptionalTimeoutDuration& loopTimeoutDuration,
456                   const sp<T_IExecutionCallback>& callback) {
457     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
458                  "SampleDriver::asyncExecute");
459 
460     const auto [poolStatus, requestPoolInfos, bufferWrappers] =
461             createRunTimePoolInfos(request, driver, preparedModel);
462     if (poolStatus != ErrorStatus::NONE) {
463         notify(callback, poolStatus, {}, kNoTiming);
464         return;
465     }
466 
467     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
468                         "SampleDriver::asyncExecute");
469     CpuExecutor executor = driver.getExecutor();
470     if (loopTimeoutDuration.getDiscriminator() !=
471         OptionalTimeoutDuration::hidl_discriminator::none) {
472         executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
473     }
474     if (deadline.has_value()) {
475         executor.setDeadline(*deadline);
476     }
477     time_point driverEnd, deviceStart, deviceEnd;
478     if (measure == MeasureTiming::YES) deviceStart = now();
479     int n = executor.run(model, request, poolInfos, requestPoolInfos);
480     if (measure == MeasureTiming::YES) deviceEnd = now();
481     VLOG(DRIVER) << "executor.run returned " << n;
482     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
483     hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
484 
485     // Update device memory metadata.
486     const ErrorStatus updateStatus =
487             updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
488     if (updateStatus != ErrorStatus::NONE) {
489         notify(callback, updateStatus, {}, kNoTiming);
490         return;
491     }
492 
493     if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
494         driverEnd = now();
495         Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
496                          .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
497         VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
498         notify(callback, executionStatus, outputShapes, timing);
499     } else {
500         notify(callback, executionStatus, outputShapes, kNoTiming);
501     }
502 }
503 
504 template <typename T_IExecutionCallback>
executeBase(const Request & request,MeasureTiming measure,const Model & model,const SampleDriver & driver,const SamplePreparedModel * preparedModel,const std::vector<RunTimePoolInfo> & poolInfos,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration,const sp<T_IExecutionCallback> & callback)505 ErrorStatus executeBase(const Request& request, MeasureTiming measure, const Model& model,
506                         const SampleDriver& driver, const SamplePreparedModel* preparedModel,
507                         const std::vector<RunTimePoolInfo>& poolInfos,
508                         const OptionalTimePoint& halDeadline,
509                         const OptionalTimeoutDuration& loopTimeoutDuration,
510                         const sp<T_IExecutionCallback>& callback) {
511     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
512     VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
513 
514     time_point driverStart;
515     if (measure == MeasureTiming::YES) driverStart = now();
516 
517     if (callback.get() == nullptr) {
518         LOG(ERROR) << "invalid callback passed to executeBase";
519         return ErrorStatus::INVALID_ARGUMENT;
520     }
521     if (!validateRequest(request, model)) {
522         notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
523         return ErrorStatus::INVALID_ARGUMENT;
524     }
525     const auto deadline = makeDeadline(halDeadline);
526     if (hasDeadlinePassed(deadline)) {
527         notify(callback, ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming);
528         return ErrorStatus::NONE;
529     }
530 
531     // This thread is intentionally detached because the sample driver service
532     // is expected to live forever.
533     std::thread([&model, &driver, preparedModel, &poolInfos, request, measure, driverStart,
534                  deadline, loopTimeoutDuration, callback] {
535         asyncExecute(request, measure, driverStart, model, driver, preparedModel, poolInfos,
536                      deadline, loopTimeoutDuration, callback);
537     }).detach();
538 
539     return ErrorStatus::NONE;
540 }
541 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)542 Return<V1_0::ErrorStatus> SamplePreparedModel::execute(
543         const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) {
544     const ErrorStatus status = executeBase(convertToV1_3(request), MeasureTiming::NO, mModel,
545                                            *mDriver, this, mPoolInfos, {}, {}, callback);
546     return convertToV1_0(status);
547 }
548 
execute_1_2(const V1_0::Request & request,MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)549 Return<V1_0::ErrorStatus> SamplePreparedModel::execute_1_2(
550         const V1_0::Request& request, MeasureTiming measure,
551         const sp<V1_2::IExecutionCallback>& callback) {
552     const ErrorStatus status = executeBase(convertToV1_3(request), measure, mModel, *mDriver, this,
553                                            mPoolInfos, {}, {}, callback);
554     return convertToV1_0(status);
555 }
556 
execute_1_3(const V1_3::Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)557 Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
558         const V1_3::Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
559         const OptionalTimeoutDuration& loopTimeoutDuration,
560         const sp<V1_3::IExecutionCallback>& callback) {
561     return executeBase(request, measure, mModel, *mDriver, this, mPoolInfos, deadline,
562                        loopTimeoutDuration, callback);
563 }
564 
executeSynchronouslyBase(const Request & request,MeasureTiming measure,const Model & model,const SampleDriver & driver,const SamplePreparedModel * preparedModel,const std::vector<RunTimePoolInfo> & poolInfos,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration)565 static std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> executeSynchronouslyBase(
566         const Request& request, MeasureTiming measure, const Model& model,
567         const SampleDriver& driver, const SamplePreparedModel* preparedModel,
568         const std::vector<RunTimePoolInfo>& poolInfos, const OptionalTimePoint& halDeadline,
569         const OptionalTimeoutDuration& loopTimeoutDuration) {
570     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
571                  "SampleDriver::executeSynchronouslyBase");
572     VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
573 
574     time_point driverStart, driverEnd, deviceStart, deviceEnd;
575     if (measure == MeasureTiming::YES) driverStart = now();
576 
577     if (!validateRequest(request, model)) {
578         return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
579     }
580     const auto deadline = makeDeadline(halDeadline);
581     if (hasDeadlinePassed(deadline)) {
582         return {ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming};
583     }
584 
585     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
586                         "SampleDriver::executeSynchronouslyBase");
587     const auto [poolStatus, requestPoolInfos, bufferWrappers] =
588             createRunTimePoolInfos(request, driver, preparedModel);
589     if (poolStatus != ErrorStatus::NONE) {
590         return {poolStatus, {}, kNoTiming};
591     }
592 
593     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
594                         "SampleDriver::executeSynchronouslyBase");
595     CpuExecutor executor = driver.getExecutor();
596     if (loopTimeoutDuration.getDiscriminator() !=
597         OptionalTimeoutDuration::hidl_discriminator::none) {
598         executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
599     }
600     if (deadline.has_value()) {
601         executor.setDeadline(*deadline);
602     }
603     if (measure == MeasureTiming::YES) deviceStart = now();
604     int n = executor.run(model, request, poolInfos, requestPoolInfos);
605     if (measure == MeasureTiming::YES) deviceEnd = now();
606     VLOG(DRIVER) << "executor.run returned " << n;
607     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
608     hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
609 
610     // Update device memory metadata.
611     const ErrorStatus updateStatus =
612             updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
613     if (updateStatus != ErrorStatus::NONE) {
614         return {updateStatus, {}, kNoTiming};
615     }
616 
617     if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
618         driverEnd = now();
619         Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
620                          .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
621         VLOG(DRIVER) << "executeSynchronouslyBase timing = " << toString(timing);
622         return {executionStatus, std::move(outputShapes), timing};
623     }
624     return {executionStatus, std::move(outputShapes), kNoTiming};
625 }
626 
executeSynchronously(const V1_0::Request & request,MeasureTiming measure,executeSynchronously_cb cb)627 Return<void> SamplePreparedModel::executeSynchronously(const V1_0::Request& request,
628                                                        MeasureTiming measure,
629                                                        executeSynchronously_cb cb) {
630     auto [status, outputShapes, timing] = executeSynchronouslyBase(
631             convertToV1_3(request), measure, mModel, *mDriver, this, mPoolInfos, {}, {});
632     cb(convertToV1_0(status), std::move(outputShapes), timing);
633     return Void();
634 }
635 
executeSynchronously_1_3(const V1_3::Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)636 Return<void> SamplePreparedModel::executeSynchronously_1_3(
637         const V1_3::Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
638         const OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
639     auto [status, outputShapes, timing] = executeSynchronouslyBase(
640             request, measure, mModel, *mDriver, this, mPoolInfos, deadline, loopTimeoutDuration);
641     cb(status, std::move(outputShapes), timing);
642     return Void();
643 }
644 
645 // The sample driver will finish the execution and then return.
executeFenced(const hal::Request & request,const hidl_vec<hidl_handle> & waitFor,MeasureTiming measure,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration,const OptionalTimeoutDuration & duration,executeFenced_cb cb)646 Return<void> SamplePreparedModel::executeFenced(
647         const hal::Request& request, const hidl_vec<hidl_handle>& waitFor, MeasureTiming measure,
648         const OptionalTimePoint& halDeadline, const OptionalTimeoutDuration& loopTimeoutDuration,
649         const OptionalTimeoutDuration& duration, executeFenced_cb cb) {
650     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
651                  "SamplePreparedModel::executeFenced");
652     VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(toString(request)) << ")";
653 
654     time_point driverStart, driverEnd, deviceStart, deviceEnd;
655     if (measure == MeasureTiming::YES) driverStart = now();
656 
657     if (!validateRequest(request, mModel, /*allowUnspecifiedOutput=*/false)) {
658         cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
659         return Void();
660     }
661     const auto deadline = makeDeadline(halDeadline);
662     if (hasDeadlinePassed(deadline)) {
663         cb(ErrorStatus::MISSED_DEADLINE_PERSISTENT, hidl_handle(nullptr), nullptr);
664         return Void();
665     }
666 
667     // Wait for the dependent events to signal
668     for (const auto& fenceHandle : waitFor) {
669         if (!fenceHandle.getNativeHandle()) {
670             cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
671             return Void();
672         }
673         int syncFenceFd = fenceHandle.getNativeHandle()->data[0];
674         if (syncWait(syncFenceFd, -1) != FenceState::SIGNALED) {
675             LOG(ERROR) << "syncWait failed";
676             cb(ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
677             return Void();
678         }
679     }
680 
681     // Update deadline if the timeout duration is closer than the deadline.
682     auto closestDeadline = deadline;
683     if (duration.getDiscriminator() != OptionalTimeoutDuration::hidl_discriminator::none) {
684         const auto timeoutDurationDeadline = makeDeadline(duration.nanoseconds());
685         if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
686             closestDeadline = timeoutDurationDeadline;
687         }
688     }
689 
690     time_point driverStartAfterFence;
691     if (measure == MeasureTiming::YES) driverStartAfterFence = now();
692 
693     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
694                         "SamplePreparedModel::executeFenced");
695     const auto [poolStatus, requestPoolInfos, bufferWrappers] =
696             createRunTimePoolInfos(request, *mDriver, this);
697     if (poolStatus != ErrorStatus::NONE) {
698         cb(poolStatus, hidl_handle(nullptr), nullptr);
699         return Void();
700     }
701 
702     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
703                         "SamplePreparedModel::executeFenced");
704     CpuExecutor executor = mDriver->getExecutor();
705     if (loopTimeoutDuration.getDiscriminator() !=
706         OptionalTimeoutDuration::hidl_discriminator::none) {
707         executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
708     }
709     if (closestDeadline.has_value()) {
710         executor.setDeadline(*closestDeadline);
711     }
712     if (measure == MeasureTiming::YES) deviceStart = now();
713     int n = executor.run(mModel, request, mPoolInfos, requestPoolInfos);
714     if (measure == MeasureTiming::YES) deviceEnd = now();
715     VLOG(DRIVER) << "executor.run returned " << n;
716     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
717     if (executionStatus != ErrorStatus::NONE) {
718         cb(executionStatus, hidl_handle(nullptr), nullptr);
719         return Void();
720     }
721 
722     // Set output memories to the initialized state.
723     if (executionStatus == ErrorStatus::NONE) {
724         for (const auto& output : request.outputs) {
725             const uint32_t poolIndex = output.location.poolIndex;
726             const auto& pool = request.pools[poolIndex];
727             if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
728                 bufferWrappers[poolIndex]->setInitialized(true);
729             }
730         }
731     }
732 
733     Timing timingSinceLaunch = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
734     Timing timingAfterFence = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
735     if (measure == MeasureTiming::YES) {
736         driverEnd = now();
737         timingSinceLaunch = {
738                 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
739                 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
740         timingAfterFence = {
741                 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
742                 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStartAfterFence))};
743         VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << toString(timingSinceLaunch);
744         VLOG(DRIVER) << "executeFenced timingAfterFence = " << toString(timingAfterFence);
745     }
746     sp<SampleFencedExecutionCallback> fencedExecutionCallback =
747             new SampleFencedExecutionCallback(timingSinceLaunch, timingAfterFence, executionStatus);
748     cb(executionStatus, hidl_handle(nullptr), fencedExecutionCallback);
749     return Void();
750 }
751 
752 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
753 // the mapping until either (1) the memory is freed in the runtime, or (2) the
754 // burst object is destroyed. This allows for subsequent executions operating on
755 // pools that have been used before to reuse the mapping instead of mapping and
756 // unmapping the memory on each execution.
757 class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
758    public:
BurstExecutorWithCache(const Model & model,const SampleDriver * driver,const std::vector<RunTimePoolInfo> & poolInfos)759     BurstExecutorWithCache(const Model& model, const SampleDriver* driver,
760                            const std::vector<RunTimePoolInfo>& poolInfos)
761         : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
762 
isCacheEntryPresent(int32_t slot) const763     bool isCacheEntryPresent(int32_t slot) const override {
764         const auto it = mMemoryCache.find(slot);
765         return (it != mMemoryCache.end()) && it->second.has_value();
766     }
767 
addCacheEntry(const hidl_memory & memory,int32_t slot)768     void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
769         mMemoryCache[slot] = RunTimePoolInfo::createFromHidlMemory(memory);
770     }
771 
removeCacheEntry(int32_t slot)772     void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
773 
execute(const V1_0::Request & request,const std::vector<int32_t> & slots,MeasureTiming measure)774     std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
775             const V1_0::Request& request, const std::vector<int32_t>& slots,
776             MeasureTiming measure) override {
777         NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
778                      "BurstExecutorWithCache::execute");
779 
780         time_point driverStart, driverEnd, deviceStart, deviceEnd;
781         if (measure == MeasureTiming::YES) driverStart = now();
782 
783         // ensure all relevant pools are valid
784         if (!std::all_of(slots.begin(), slots.end(),
785                          [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
786             return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
787         }
788 
789         // finish the request object (for validation)
790         hidl_vec<Request::MemoryPool> pools(slots.size());
791         std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
792             Request::MemoryPool pool;
793             pool.hidlMemory(mMemoryCache[slot]->getHidlMemory());
794             return pool;
795         });
796         Request fullRequest = {.inputs = request.inputs, .outputs = request.outputs};
797         fullRequest.pools = std::move(pools);
798 
799         // validate request object against the model
800         if (!validateRequest(fullRequest, mModel)) {
801             return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
802         }
803 
804         // select relevant entries from cache
805         std::vector<RunTimePoolInfo> requestPoolInfos;
806         requestPoolInfos.reserve(slots.size());
807         std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
808                        [this](int32_t slot) { return *mMemoryCache[slot]; });
809 
810         // execution
811         // Configuring the loop timeout duration is not supported. This is OK
812         // because burst does not support HAL 1.3 and hence does not support
813         // WHILE loops.
814         CpuExecutor executor = mDriver->getExecutor();
815         if (measure == MeasureTiming::YES) deviceStart = now();
816         int n = executor.run(mModel, fullRequest, mModelPoolInfos, requestPoolInfos);
817         if (measure == MeasureTiming::YES) deviceEnd = now();
818         VLOG(DRIVER) << "executor.run returned " << n;
819         V1_0::ErrorStatus executionStatus = convertToV1_0(convertResultCodeToErrorStatus(n));
820         hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
821         if (measure == MeasureTiming::YES && executionStatus == V1_0::ErrorStatus::NONE) {
822             driverEnd = now();
823             Timing timing = {
824                     .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
825                     .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
826             VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
827             return std::make_tuple(executionStatus, outputShapes, timing);
828         } else {
829             return std::make_tuple(executionStatus, outputShapes, kNoTiming);
830         }
831     }
832 
833    private:
834     const Model mModel;
835     const SampleDriver* const mDriver;
836     const std::vector<RunTimePoolInfo> mModelPoolInfos;
837     std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache;  // cached requestPoolInfos
838 };
839 
// How long the ExecutionBurstServer should poll the FMQ for available data
// before falling back to waiting on the futex.
//
// The window is 50 microseconds. In NN_DEBUGGABLE builds it can be overridden
// through the "debug.nn.sample-driver-burst-polling-window" property, with 0
// passed to the property lookup as the minimum accepted value.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t kDefaultWindowUs = 50;
    int32_t windowUs = kDefaultWindowUs;
#ifdef NN_DEBUGGABLE
    constexpr int32_t kMinWindowUs = 0;
    windowUs = base::GetIntProperty("debug.nn.sample-driver-burst-polling-window",
                                    kDefaultWindowUs, kMinWindowUs);
#endif  // NN_DEBUGGABLE
    return std::chrono::microseconds{windowUs};
}
855 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)856 Return<void> SamplePreparedModel::configureExecutionBurst(
857         const sp<V1_2::IBurstCallback>& callback,
858         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
859         const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
860         configureExecutionBurst_cb cb) {
861     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
862                  "SampleDriver::configureExecutionBurst");
863 
864     const bool preferPowerOverLatency = (kPreference == ExecutionPreference::LOW_POWER);
865     const auto pollingTimeWindow =
866             (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
867 
868     // Alternatively, the burst could be configured via:
869     // const sp<V1_2::IBurstContext> burst =
870     //         ExecutionBurstServer::create(callback, requestChannel,
871     //                                      resultChannel, this,
872     //                                      pollingTimeWindow);
873     //
874     // However, this alternative representation does not include a memory map
875     // caching optimization, and adds overhead.
876     const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
877             std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
878     const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
879             callback, requestChannel, resultChannel, executorWithCache, pollingTimeWindow);
880 
881     if (burst == nullptr) {
882         cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
883     } else {
884         cb(V1_0::ErrorStatus::NONE, burst);
885     }
886 
887     return Void();
888 }
889 
890 }  // namespace sample_driver
891 }  // namespace nn
892 }  // namespace android
893