/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "VersionedInterfaces.h"

#include "Callbacks.h"
#include "ExecutionBurstController.h"
#include "Tracing.h"
#include "Utils.h"

#include <algorithm>
#include <android-base/logging.h>
#include <android-base/scopeguard.h>
#include <android-base/thread_annotations.h>
#include <functional>
#include <type_traits>

namespace android {
namespace nn {

// anonymous namespace
namespace {

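// Fixed-size byte array used as the token argument of prepareModel and
// prepareModelFromCache to identify entries in a driver's compilation cache.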
using HidlToken = hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;

const Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};

void sendFailureMessage(const sp<IPreparedModelCallback>& cb) {
    cb->notify(ErrorStatus::GENERAL_FAILURE, nullptr);
}

void sendFailureMessage(const sp<PreparedModelCallback>& cb) {
    sendFailureMessage(static_cast<sp<IPreparedModelCallback>>(cb));
}

void sendFailureMessage(const sp<IExecutionCallback>& cb) {
    cb->notify(ErrorStatus::GENERAL_FAILURE);
}

void sendFailureMessage(const sp<ExecutionCallback>& cb) {
    sendFailureMessage(static_cast<sp<IExecutionCallback>>(cb));
}

// This class is thread safe
template <typename ICallback>
class DeathHandler : public hardware::hidl_death_recipient {
   public:
    void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
        LOG(ERROR) << "DeathHandler::serviceDied -- service unexpectedly died!";
        std::lock_guard<std::mutex> hold(mMutex);
        std::for_each(mCallbacks.begin(), mCallbacks.end(),
                      [](const auto& cb) { sendFailureMessage(cb); });
    }

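    // Registers the callback for the lifetime of the returned scope guard, so
    // that a service death observed while a call is in flight is turned into a
    // GENERAL_FAILURE notification on the callback. A typical call site looks
    // like the following sketch, where someAsyncCall stands in for any
    // asynchronous HIDL method:
    //
    //     const auto scoped = deathHandler->protectCallback(callback);
    //     Return<ErrorStatus> launchStatus = device->someAsyncCall(..., callback);
    //     // callback is unregistered when 'scoped' goes out of scope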
    [[nodiscard]] base::ScopeGuard<std::function<void()>> protectCallback(
            const sp<ICallback>& callback) {
        registerCallback(callback);
        return ::android::base::make_scope_guard(
                [this, callback] { unregisterCallback(callback); });
    }

   private:
    void registerCallback(const sp<ICallback>& callback) {
        std::lock_guard<std::mutex> hold(mMutex);
        mCallbacks.push_back(callback);
    }

    void unregisterCallback(const sp<ICallback>& callback) {
        std::lock_guard<std::mutex> hold(mMutex);
        mCallbacks.erase(std::remove(mCallbacks.begin(), mCallbacks.end(), callback),
                         mCallbacks.end());
    }

    std::mutex mMutex;
    std::vector<sp<ICallback>> mCallbacks GUARDED_BY(mMutex);
};

}  // anonymous namespace

class IDeviceDeathHandler : public DeathHandler<IPreparedModelCallback> {};
class IPreparedModelDeathHandler : public DeathHandler<IExecutionCallback> {};

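// Wraps the given IPreparedModel in a VersionedIPreparedModel guarded by a
// death recipient. Returns nullptr if the input is null, the death handler
// cannot be allocated, or linkToDeath fails.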
static std::shared_ptr<VersionedIPreparedModel> makeVersionedIPreparedModel(
        sp<V1_0::IPreparedModel> preparedModel) {
    // verify input
    if (!preparedModel) {
        LOG(ERROR) << "makeVersionedIPreparedModel -- passed invalid preparedModel object.";
        return nullptr;
    }

    // create death handler object
    sp<IPreparedModelDeathHandler> deathHandler = new (std::nothrow) IPreparedModelDeathHandler();
    if (!deathHandler) {
        LOG(ERROR) << "makeVersionedIPreparedModel -- Failed to create IPreparedModelDeathHandler.";
        return nullptr;
    }

    // linkToDeath registers a callback that will be invoked on service death to
    // proactively handle service crashes. If the linkToDeath call fails,
    // asynchronous calls are susceptible to hangs if the service crashes before
    // providing the response.
    const Return<bool> ret = preparedModel->linkToDeath(deathHandler, 0);
    if (!ret.isOk() || ret != true) {
        LOG(ERROR) << "makeVersionedIPreparedModel -- Failed to register a death recipient for the "
                      "IPreparedModel object.";
        return nullptr;
    }

    // return a valid VersionedIPreparedModel object
    return std::make_shared<VersionedIPreparedModel>(std::move(preparedModel),
                                                     std::move(deathHandler));
}

VersionedIPreparedModel::VersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel,
                                                 sp<IPreparedModelDeathHandler> deathHandler)
    : mPreparedModelV1_0(std::move(preparedModel)),
      mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(mPreparedModelV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}

VersionedIPreparedModel::~VersionedIPreparedModel() {
    // It is safe to ignore any errors resulting from this unlinkToDeath call
    // because the VersionedIPreparedModel object is already being destroyed and
    // its underlying IPreparedModel object is no longer being used by the NN
    // runtime.
    mPreparedModelV1_0->unlinkToDeath(mDeathHandler).isOk();
}

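// Launches an asynchronous execution through the newest interface the driver
// implements: execute_1_2 when available, otherwise the V1_0 execute. On a
// successful launch this method waits for the callback before returning the
// launch status; on failure it signals the callback itself so that waiters are
// not left hanging.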
ErrorStatus VersionedIPreparedModel::execute(const Request& request, MeasureTiming measure,
                                             const sp<ExecutionCallback>& callback) {
    const auto scoped = mDeathHandler->protectCallback(callback);

    if (mPreparedModelV1_2 != nullptr) {
        Return<ErrorStatus> ret = mPreparedModelV1_2->execute_1_2(request, measure, callback);
        if (!ret.isOk()) {
            sendFailureMessage(callback);
            LOG(ERROR) << "execute_1_2 failure: " << ret.description();
            return ErrorStatus::GENERAL_FAILURE;
        }
        if (ret != ErrorStatus::NONE) {
            sendFailureMessage(callback);
            LOG(ERROR) << "execute_1_2 returned " << toString(static_cast<ErrorStatus>(ret));
            return static_cast<ErrorStatus>(ret);
        }
        callback->wait();
        return static_cast<ErrorStatus>(ret);
    } else if (mPreparedModelV1_0 != nullptr) {
        Return<ErrorStatus> ret = mPreparedModelV1_0->execute(request, callback);
        if (!ret.isOk()) {
            sendFailureMessage(callback);
            LOG(ERROR) << "execute failure: " << ret.description();
            return ErrorStatus::GENERAL_FAILURE;
        }
        if (ret != ErrorStatus::NONE) {
            sendFailureMessage(callback);
            LOG(ERROR) << "execute returned " << toString(static_cast<ErrorStatus>(ret));
            return static_cast<ErrorStatus>(ret);
        }
        callback->wait();
        return static_cast<ErrorStatus>(ret);
    } else {
        sendFailureMessage(callback);
        LOG(ERROR) << "execute called with no preparedModel";
        return ErrorStatus::GENERAL_FAILURE;
    }
}

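// Executes synchronously via executeSynchronously on a V1_2 driver. On older
// drivers, synchronous execution is simulated by launching an asynchronous
// execution and waiting on its callback; that path cannot report output shapes
// or timing information.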
std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing>
VersionedIPreparedModel::executeSynchronously(const Request& request, MeasureTiming measure) {
    const std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> kFailure = {
            ErrorStatus::GENERAL_FAILURE, {}, kBadTiming};

    if (mPreparedModelV1_2 != nullptr) {
        std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> result;
        Return<void> ret = mPreparedModelV1_2->executeSynchronously(
                request, measure,
                [&result](ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
                          const Timing& timing) {
                    result = std::make_tuple(error, outputShapes, timing);
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "executeSynchronously failure: " << ret.description();
            return kFailure;
        }
        return result;
    } else {
        // Simulate synchronous execution.
        sp<ExecutionCallback> callback = new ExecutionCallback();
        ErrorStatus ret = execute(request, measure, callback);
        if (ret != ErrorStatus::NONE) {
            return {ret, {}, kBadTiming};
        }
        callback->wait();
        // callback->getOutputShapes() will always return an empty hidl vector.
        // callback->getTiming() will always return values indicating no measurement.
        return {callback->getStatus(), callback->getOutputShapes(), callback->getTiming()};
    }
}

std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
        bool blocking) const {
    if (mPreparedModelV1_2 != nullptr) {
        return ExecutionBurstController::create(mPreparedModelV1_2, blocking);
    } else {
        return nullptr;
    }
}

bool VersionedIPreparedModel::operator==(nullptr_t) const {
    return mPreparedModelV1_0 == nullptr;
}

bool VersionedIPreparedModel::operator!=(nullptr_t) const {
    return mPreparedModelV1_0 != nullptr;
}

std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
                                                           sp<V1_0::IDevice> device) {
    auto core = Core::create(std::move(device));
    if (!core.has_value()) {
        LOG(ERROR) << "VersionedIDevice::create -- Failed to create Core.";
        return nullptr;
    }

    // return a valid VersionedIDevice object
    return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value()));
}

VersionedIDevice::VersionedIDevice(std::string serviceName, Core core)
    : mServiceName(std::move(serviceName)), mCore(std::move(core)) {}

std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
    // verify input
    if (!device) {
        LOG(ERROR) << "VersionedIDevice::Core::create -- passed invalid device object.";
        return {};
    }

    // create death handler object
    sp<IDeviceDeathHandler> deathHandler = new (std::nothrow) IDeviceDeathHandler();
    if (!deathHandler) {
        LOG(ERROR) << "VersionedIDevice::Core::create -- Failed to create IDeviceDeathHandler.";
        return {};
    }

    // linkToDeath registers a callback that will be invoked on service death to
    // proactively handle service crashes. If the linkToDeath call fails,
    // asynchronous calls are susceptible to hangs if the service crashes before
    // providing the response.
    const Return<bool> ret = device->linkToDeath(deathHandler, 0);
    if (!ret.isOk() || ret != true) {
        LOG(ERROR)
                << "VersionedIDevice::Core::create -- Failed to register a death recipient for the "
                   "IDevice object.";
        return {};
    }

    // return a valid Core object
    return Core(std::move(device), std::move(deathHandler));
}

// HIDL guarantees all V1_1 interfaces inherit from their corresponding V1_0 interfaces.
VersionedIDevice::Core::Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
    : mDeviceV1_0(std::move(device)),
      mDeviceV1_1(V1_1::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeviceV1_2(V1_2::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}

VersionedIDevice::Core::~Core() {
    if (mDeathHandler != nullptr) {
        CHECK(mDeviceV1_0 != nullptr);
        // It is safe to ignore any errors resulting from this unlinkToDeath call
        // because the VersionedIDevice::Core object is already being destroyed and
        // its underlying IDevice object is no longer being used by the NN runtime.
        mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
    }
}

VersionedIDevice::Core::Core(Core&& other) noexcept
    : mDeviceV1_0(std::move(other.mDeviceV1_0)),
      mDeviceV1_1(std::move(other.mDeviceV1_1)),
      mDeviceV1_2(std::move(other.mDeviceV1_2)),
      mDeathHandler(std::move(other.mDeathHandler)) {
    other.mDeathHandler = nullptr;
}

VersionedIDevice::Core& VersionedIDevice::Core::operator=(Core&& other) noexcept {
    if (this != &other) {
        mDeviceV1_0 = std::move(other.mDeviceV1_0);
        mDeviceV1_1 = std::move(other.mDeviceV1_1);
        mDeviceV1_2 = std::move(other.mDeviceV1_2);
        mDeathHandler = std::move(other.mDeathHandler);
        other.mDeathHandler = nullptr;
    }
    return *this;
}

template <typename T_IDevice>
std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> VersionedIDevice::Core::getDeviceAndDeathHandler()
        const {
    return {getDevice<T_IDevice>(), mDeathHandler};
}

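// Invokes fn(device) while the callback is protected by the death handler, so
// that a driver crash mid-call still signals the callback. The nullptr_t
// overload below is used for calls that take no callback and therefore need no
// protection.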
template <typename T_IDevice, typename T_Callback>
Return<ErrorStatus> callProtected(
        const char* context, const std::function<Return<ErrorStatus>(const sp<T_IDevice>&)>& fn,
        const sp<T_IDevice>& device, const sp<T_Callback>& callback,
        const sp<IDeviceDeathHandler>& deathHandler) {
    const auto scoped = deathHandler->protectCallback(callback);
    Return<ErrorStatus> ret = fn(device);
    // Suppose there was a transport error. We have the following cases:
    // 1. Either not due to a dead device, or due to a device that was
    //    already dead at the time of the call to protectCallback(). In
    //    this case, the callback was never signalled.
    // 2. Due to a device that died after the call to protectCallback() but
    //    before fn() completed. In this case, the callback was (or will
    //    be) signalled by the deathHandler.
    // Furthermore, what if there was no transport error, but the ErrorStatus is
    // other than NONE? We'll conservatively signal the callback anyway, just in
    // case the driver was sloppy and failed to do so.
    if (!ret.isOk() || ret != ErrorStatus::NONE) {
        // What if the deathHandler has signalled or will signal the callback?
        // This is fine -- we're permitted to signal multiple times; and we're
        // sending the same signal that the deathHandler does.
        //
        // What if the driver signalled the callback? Then this signal is
        // ignored.

        if (ret.isOk()) {
            LOG(ERROR) << context << " returned " << toString(static_cast<ErrorStatus>(ret));
        } else {
            LOG(ERROR) << context << " failure: " << ret.description();
        }
        sendFailureMessage(callback);
    }
    callback->wait();
    return ret;
}

template <typename T_Return, typename T_IDevice>
Return<T_Return> callProtected(const char*,
                               const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
                               const sp<T_IDevice>& device, const std::nullptr_t&,
                               const sp<IDeviceDeathHandler>&) {
    return fn(device);
}

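// Runs fn against the current device and, if the call failed with a dead
// object, reconnects to the service (at most once per call) and retries.
// A typical call site, taken from getStatus() below:
//
//     Return<DeviceStatus> ret = recoverable<DeviceStatus, V1_0::IDevice>(
//             __FUNCTION__, [](const sp<V1_0::IDevice>& device) {
//                 return device->getStatus();
//             });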
template <typename T_Return, typename T_IDevice, typename T_Callback>
Return<T_Return> VersionedIDevice::recoverable(
        const char* context, const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
        const T_Callback& callback) const EXCLUDES(mMutex) {
    CHECK_EQ(callback == nullptr, (std::is_same_v<T_Callback, std::nullptr_t>));

    sp<T_IDevice> device;
    sp<IDeviceDeathHandler> deathHandler;
    std::tie(device, deathHandler) = getDeviceAndDeathHandler<T_IDevice>();

    Return<T_Return> ret = callProtected(context, fn, device, callback, deathHandler);

    if (ret.isDeadObject()) {
        {
            std::unique_lock lock(mMutex);
            // It's possible that another thread has already done the recovery.
            // It's harmless but wasteful for us to do so in this case.
            auto pingReturn = mCore.getDevice<T_IDevice>()->ping();
            if (pingReturn.isDeadObject()) {
                VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
                             << mServiceName;
                sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(mServiceName);
                if (recoveredDevice == nullptr) {
                    VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDevice for "
                                 << mServiceName;
                    return ret;
                }

                auto core = Core::create(std::move(recoveredDevice));
                if (!core.has_value()) {
                    LOG(ERROR) << "VersionedIDevice::recoverable -- Failed to create Core.";
                    return ret;
                }

                mCore = std::move(core.value());
            } else {
                VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context
                             << ") -- Someone else recovered " << mServiceName;
                // Might still have a transport error, which we need to check
                // before pingReturn goes out of scope.
                (void)pingReturn.isOk();
            }
            std::tie(device, deathHandler) = mCore.getDeviceAndDeathHandler<T_IDevice>();
        }
        ret = callProtected(context, fn, device, callback, deathHandler);
        // It's possible that the device died again, but we're only going to
        // attempt recovery once per call to recoverable().
    }
    return ret;
}

std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() {
    const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    std::pair<ErrorStatus, Capabilities> result;

    if (getDevice<V1_2::IDevice>() != nullptr) {
        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_2");
        Return<void> ret = recoverable<void, V1_2::IDevice>(
                __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
                    return device->getCapabilities_1_2(
                            [&result](ErrorStatus error, const Capabilities& capabilities) {
                                result = std::make_pair(error, capabilities);
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getCapabilities_1_2 failure: " << ret.description();
            return kFailure;
        }
    } else if (getDevice<V1_1::IDevice>() != nullptr) {
        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_1");
        Return<void> ret = recoverable<void, V1_1::IDevice>(
                __FUNCTION__, [&result](const sp<V1_1::IDevice>& device) {
                    return device->getCapabilities_1_1(
                            [&result](ErrorStatus error, const V1_1::Capabilities& capabilities) {
                                // Time taken to convert capabilities is trivial
                                result = std::make_pair(error, convertToV1_2(capabilities));
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getCapabilities_1_1 failure: " << ret.description();
            return kFailure;
        }
    } else if (getDevice<V1_0::IDevice>() != nullptr) {
        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities");
        Return<void> ret = recoverable<void, V1_0::IDevice>(
                __FUNCTION__, [&result](const sp<V1_0::IDevice>& device) {
                    return device->getCapabilities(
                            [&result](ErrorStatus error, const V1_0::Capabilities& capabilities) {
                                // Time taken to convert capabilities is trivial
                                result = std::make_pair(error, convertToV1_2(capabilities));
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getCapabilities failure: " << ret.description();
            return kFailure;
        }
    } else {
        LOG(ERROR) << "Device not available!";
        return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
    }

    return result;
}

std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() {
    const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedExtensions");
    if (getDevice<V1_2::IDevice>() != nullptr) {
        std::pair<ErrorStatus, hidl_vec<Extension>> result;
        Return<void> ret = recoverable<void, V1_2::IDevice>(
                __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
                    return device->getSupportedExtensions(
                            [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) {
                                result = std::make_pair(error, extensions);
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getSupportedExtensions failure: " << ret.description();
            return kFailure;
        }
        return result;
    } else if (getDevice<V1_0::IDevice>() != nullptr) {
        return {ErrorStatus::NONE, {/* No extensions. */}};
    } else {
        LOG(ERROR) << "Device not available!";
        return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
    }
}

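// Queries which of the model's operations the driver supports. Pre-1.2 drivers
// require the model to be converted to their HAL version first; when the model
// is not compliant with that version, the optional slicer extracts the largest
// compliant submodel, and the driver's answers are remapped back to the
// original model's operation indices, with all other operations reported as
// unsupported.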
std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
        const Model& model, IModelSlicer* slicer) {
    const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    std::pair<ErrorStatus, hidl_vec<bool>> result;

    auto noneSupported = [&model] {
        hidl_vec<bool> supported(model.operations.size());
        std::fill(supported.begin(), supported.end(), false);
        return std::make_pair(ErrorStatus::NONE, std::move(supported));
    };

    auto remappedResult = [&model](const std::pair<ErrorStatus, hidl_vec<bool>>& result,
                                   const std::function<uint32_t(uint32_t)>&
                                           submodelOperationIndexToModelOperationIndex) {
        const ErrorStatus status = result.first;
        const hidl_vec<bool>& supported = result.second;
        hidl_vec<bool> remappedSupported(model.operations.size());
        std::fill(remappedSupported.begin(), remappedSupported.end(), false);
        for (size_t i = 0; i < supported.size(); ++i) {
            if (supported[i]) {
                remappedSupported[submodelOperationIndexToModelOperationIndex(i)] = true;
            }
        }
        return std::make_pair(status, std::move(remappedSupported));
    };

    if (getDevice<V1_2::IDevice>() != nullptr) {
        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_2");
        Return<void> ret = recoverable<void, V1_2::IDevice>(
                __FUNCTION__, [&model, &result](const sp<V1_2::IDevice>& device) {
                    return device->getSupportedOperations_1_2(
                            model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
                                result = std::make_pair(error, supported);
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getSupportedOperations_1_2 failure: " << ret.description();
            return kFailure;
        }
        return result;
    }

    if (getDevice<V1_1::IDevice>() != nullptr) {
        const bool compliant = compliantWithV1_1(model);
        if (compliant || slicer) {
            V1_1::Model model11;
            std::function<uint32_t(uint32_t)> submodelOperationIndexToModelOperationIndex;
            if (compliant) {
                model11 = convertToV1_1(model);
            } else {
                const auto slice11 = slicer->getSliceV1_1();
                if (!slice11.has_value()) {
                    return noneSupported();
                }
                std::tie(model11, submodelOperationIndexToModelOperationIndex) = *slice11;
            }
            NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION,
                         "getSupportedOperations_1_1");
            Return<void> ret = recoverable<void, V1_1::IDevice>(
                    __FUNCTION__, [&model11, &result](const sp<V1_1::IDevice>& device) {
                        return device->getSupportedOperations_1_1(
                                model11,
                                [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
                                    result = std::make_pair(error, supported);
                                });
                    });
            if (!ret.isOk()) {
                LOG(ERROR) << "getSupportedOperations_1_1 failure: " << ret.description();
                return kFailure;
            }
            if (!compliant) {
                return remappedResult(result, submodelOperationIndexToModelOperationIndex);
            }
        }
        return result;
    }

    if (getDevice<V1_0::IDevice>() != nullptr) {
        const bool compliant = compliantWithV1_0(model);
        if (compliant || slicer) {
            V1_0::Model model10;
            std::function<uint32_t(uint32_t)> submodelOperationIndexToModelOperationIndex;
            if (compliant) {
                model10 = convertToV1_0(model);
            } else {
                const auto slice10 = slicer->getSliceV1_0();
                if (!slice10.has_value()) {
                    return noneSupported();
                }
                std::tie(model10, submodelOperationIndexToModelOperationIndex) = *slice10;
            }
            NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations");
            Return<void> ret = recoverable<void, V1_0::IDevice>(
                    __FUNCTION__, [&model10, &result](const sp<V1_0::IDevice>& device) {
                        return device->getSupportedOperations(
                                model10,
                                [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
                                    result = std::make_pair(error, supported);
                                });
                    });
            if (!ret.isOk()) {
                LOG(ERROR) << "getSupportedOperations failure: " << ret.description();
                return kFailure;
            }
            if (!compliant) {
                return remappedResult(result, submodelOperationIndexToModelOperationIndex);
            }
        }
        return result;
    }

    return kFailure;
}

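// Compiles the model through the newest interface the driver implements.
// prepareModel_1_2 accepts the model and the compilation caching arguments
// directly; for 1.1 and 1.0 drivers the model must first be checked for
// compliance with the older HAL version and converted, and preparation fails
// if the model uses features the older interface cannot express.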
std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevice::prepareModel(
        const Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>& modelCache,
        const hidl_vec<hidl_handle>& dataCache, const HidlToken& token) {
    const std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> kFailure = {
            ErrorStatus::GENERAL_FAILURE, nullptr};

    const sp<PreparedModelCallback> callback = new (std::nothrow) PreparedModelCallback();
    if (callback == nullptr) {
        LOG(ERROR) << "prepareModel failed to create callback object";
        return kFailure;
    }

    // If 1.2 device, try preparing model
    if (getDevice<V1_2::IDevice>() != nullptr) {
        const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>(
                __FUNCTION__,
                [&model, &preference, &modelCache, &dataCache, &token,
                 &callback](const sp<V1_2::IDevice>& device) {
                    return device->prepareModel_1_2(model, preference, modelCache, dataCache,
                                                    token, callback);
                },
                callback);
        if (!ret.isOk()) {
            LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
            return kFailure;
        }
        if (ret != ErrorStatus::NONE) {
            LOG(ERROR) << "prepareModel_1_2 returned " << toString(static_cast<ErrorStatus>(ret));
            return kFailure;
        }
        callback->wait();
        return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())};
    }

    // If 1.1 device, try preparing model (requires conversion)
    if (getDevice<V1_1::IDevice>() != nullptr) {
        bool compliant = false;
        V1_1::Model model11;
        {
            // Attribute time spent in model inspection and conversion to
            // Runtime, as the time may be substantial (0.03ms for mobilenet,
            // but could be larger for other models).
            NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
                                  "VersionedIDevice::prepareModel_1_1");
            compliant = compliantWithV1_1(model);
            if (compliant) {
                model11 = convertToV1_1(model);  // copy is elided
            }
        }
        if (compliant) {
            const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_1::IDevice>(
                    __FUNCTION__,
                    [&model11, &preference, &callback](const sp<V1_1::IDevice>& device) {
                        return device->prepareModel_1_1(model11, preference, callback);
                    },
                    callback);
            if (!ret.isOk()) {
                LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
                return kFailure;
            }
            if (ret != ErrorStatus::NONE) {
                LOG(ERROR) << "prepareModel_1_1 returned "
                           << toString(static_cast<ErrorStatus>(ret));
                return kFailure;
            }
            callback->wait();
            return {callback->getStatus(),
                    makeVersionedIPreparedModel(callback->getPreparedModel())};
        }

        LOG(ERROR) << "Could not handle prepareModel_1_1!";
        return kFailure;
    }

    // If 1.0 device, try preparing model (requires conversion)
    if (getDevice<V1_0::IDevice>() != nullptr) {
        bool compliant = false;
        V1_0::Model model10;
        {
            // Attribute time spent in model inspection and conversion to
            // Runtime, as the time may be substantial (0.03ms for mobilenet,
            // but could be larger for other models).
            NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
                                  "VersionedIDevice::prepareModel");
            compliant = compliantWithV1_0(model);
            if (compliant) {
                model10 = convertToV1_0(model);  // copy is elided
            }
        }
        if (compliant) {
            const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_0::IDevice>(
                    __FUNCTION__,
                    [&model10, &callback](const sp<V1_0::IDevice>& device) {
                        return device->prepareModel(model10, callback);
                    },
                    callback);
            if (!ret.isOk()) {
                LOG(ERROR) << "prepareModel failure: " << ret.description();
                return kFailure;
            }
            if (ret != ErrorStatus::NONE) {
                LOG(ERROR) << "prepareModel returned " << toString(static_cast<ErrorStatus>(ret));
                return kFailure;
            }
            callback->wait();
            return {callback->getStatus(),
                    makeVersionedIPreparedModel(callback->getPreparedModel())};
        }

        LOG(ERROR) << "Could not handle prepareModel!";
        return kFailure;
    }

    // Return error because there is no valid device
    LOG(ERROR) << "prepareModel called with no device";
    return kFailure;
}

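// Prepares a model directly from the compilation cache entries identified by
// token. Only V1_2 drivers support compilation caching, so this call fails on
// older drivers.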
std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>>
VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
                                        const hidl_vec<hidl_handle>& dataCache,
                                        const HidlToken& token) {
    const std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> kFailure = {
            ErrorStatus::GENERAL_FAILURE, nullptr};

    const sp<PreparedModelCallback> callback = new (std::nothrow) PreparedModelCallback();
    if (callback == nullptr) {
        LOG(ERROR) << "prepareModelFromCache failed to create callback object";
        return kFailure;
    }

    if (getDevice<V1_2::IDevice>() != nullptr) {
        const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>(
                __FUNCTION__,
                [&modelCache, &dataCache, &token, &callback](const sp<V1_2::IDevice>& device) {
                    return device->prepareModelFromCache(modelCache, dataCache, token, callback);
                },
                callback);
        if (!ret.isOk()) {
            LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
            return kFailure;
        }
        if (ret != ErrorStatus::NONE) {
            LOG(ERROR) << "prepareModelFromCache returned "
                       << toString(static_cast<ErrorStatus>(ret));
            return kFailure;
        }
        callback->wait();
        return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())};
    }

    if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
        LOG(ERROR) << "prepareModelFromCache called on V1_1 or V1_0 device";
        return kFailure;
    }

    LOG(ERROR) << "prepareModelFromCache called with no device";
    return kFailure;
}

DeviceStatus VersionedIDevice::getStatus() {
    if (getDevice<V1_0::IDevice>() == nullptr) {
        LOG(ERROR) << "Device not available!";
        return DeviceStatus::UNKNOWN;
    }

    Return<DeviceStatus> ret = recoverable<DeviceStatus, V1_0::IDevice>(
            __FUNCTION__, [](const sp<V1_0::IDevice>& device) { return device->getStatus(); });

    if (!ret.isOk()) {
        LOG(ERROR) << "getStatus failure: " << ret.description();
        return DeviceStatus::UNKNOWN;
    }
    return static_cast<DeviceStatus>(ret);
}

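// Maps the newest HAL version the driver implements to the Android API level
// that introduced it: 1.2 -> Q (API 29), 1.1 -> P (API 28), 1.0 -> O MR1
// (API 27).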
int64_t VersionedIDevice::getFeatureLevel() {
    constexpr int64_t kFailure = -1;

    if (getDevice<V1_2::IDevice>() != nullptr) {
        return __ANDROID_API_Q__;
    } else if (getDevice<V1_1::IDevice>() != nullptr) {
        return __ANDROID_API_P__;
    } else if (getDevice<V1_0::IDevice>() != nullptr) {
        return __ANDROID_API_O_MR1__;
    } else {
        LOG(ERROR) << "Device not available!";
        return kFailure;
    }
}

int32_t VersionedIDevice::getType() const {
    constexpr int32_t kFailure = -1;
    std::pair<ErrorStatus, DeviceType> result;

    if (getDevice<V1_2::IDevice>() != nullptr) {
        Return<void> ret = recoverable<void, V1_2::IDevice>(
                __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
                    return device->getType([&result](ErrorStatus error, DeviceType deviceType) {
                        result = std::make_pair(error, deviceType);
                    });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getType failure: " << ret.description();
            return kFailure;
        }
        return static_cast<int32_t>(result.second);
    } else {
        LOG(INFO) << "Unknown NNAPI device type.";
        return ANEURALNETWORKS_DEVICE_UNKNOWN;
    }
}

std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() {
    const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
    std::pair<ErrorStatus, hidl_string> result;

    if (getDevice<V1_2::IDevice>() != nullptr) {
        Return<void> ret = recoverable<void, V1_2::IDevice>(
                __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
                    return device->getVersionString(
                            [&result](ErrorStatus error, const hidl_string& version) {
                                result = std::make_pair(error, version);
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getVersionString failure: " << ret.description();
            return kFailure;
        }
        return result;
    } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
        return {ErrorStatus::NONE, "UNKNOWN"};
    } else {
        LOG(ERROR) << "Could not handle getVersionString";
        return kFailure;
    }
}

std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() {
    constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE,
                                                                      0, 0};
    std::tuple<ErrorStatus, uint32_t, uint32_t> result;

    if (getDevice<V1_2::IDevice>() != nullptr) {
        Return<void> ret = recoverable<void, V1_2::IDevice>(
                __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
                    return device->getNumberOfCacheFilesNeeded([&result](ErrorStatus error,
                                                                         uint32_t numModelCache,
                                                                         uint32_t numDataCache) {
                        result = {error, numModelCache, numDataCache};
                    });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getNumberOfCacheFilesNeeded failure: " << ret.description();
            return kFailure;
        }
        return result;
    } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
        return {ErrorStatus::NONE, 0, 0};
    } else {
        LOG(ERROR) << "Could not handle getNumberOfCacheFilesNeeded";
        return kFailure;
    }
}

bool VersionedIDevice::operator==(nullptr_t) const {
    return getDevice<V1_0::IDevice>() == nullptr;
}

bool VersionedIDevice::operator!=(nullptr_t) const {
    return getDevice<V1_0::IDevice>() != nullptr;
}

}  // namespace nn
}  // namespace android