• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ShimDevice"
18 
19 #include "ShimDevice.h"
20 
21 #include <NeuralNetworks.h>
22 #include <aidl/android/hardware/neuralnetworks/DataLocation.h>
23 #include <aidl/android/hardware/neuralnetworks/ErrorStatus.h>
24 #include <aidl/android/hardware/neuralnetworks/Extension.h>
25 #include <aidl/android/hardware/neuralnetworks/ExtensionOperandTypeInformation.h>
26 #include <aidl/android/hardware/neuralnetworks/Memory.h>
27 #include <aidl/android/hardware/neuralnetworks/NumberOfCacheFiles.h>
28 #include <aidl/android/hardware/neuralnetworks/OperandLifeTime.h>
29 #include <aidl/android/hardware/neuralnetworks/OperandPerformance.h>
30 #include <android-base/logging.h>
31 #include <android-base/scopeguard.h>
32 #include <android/binder_auto_utils.h>
33 #include <android/binder_manager.h>
34 #include <android/binder_process.h>
35 #include <nnapi/TypeUtils.h>
36 #include <nnapi/hal/aidl/Conversions.h>
37 
38 #include <algorithm>
39 #include <limits>
40 #include <memory>
41 #include <optional>
42 #include <string>
43 #include <unordered_map>
44 #include <utility>
45 #include <vector>
46 
47 #include "ShimConverter.h"
48 #include "ShimPreparedModel.h"
49 #include "ShimUtils.h"
50 #include "SupportLibrary.h"
51 
52 using namespace ::android::nn::sl_wrapper;
53 
54 namespace aidl::android::hardware::neuralnetworks {
55 
56 namespace {
57 
convertToNDKPriority(Priority priority)58 constexpr std::optional<::android::nn::wrapper::ExecutePriority> convertToNDKPriority(
59         Priority priority) {
60     switch (priority) {
61         case Priority::LOW:
62             return ::android::nn::wrapper::ExecutePriority::LOW;
63         case Priority::MEDIUM:
64             return ::android::nn::wrapper::ExecutePriority::MEDIUM;
65         case Priority::HIGH:
66             return ::android::nn::wrapper::ExecutePriority::HIGH;
67     }
68     LOG(ERROR) << "unrecognized priority: " << static_cast<int32_t>(priority);
69     return std::nullopt;
70 }
71 
convertToNDKPreference(ExecutionPreference preference)72 constexpr std::optional<::android::nn::wrapper::ExecutePreference> convertToNDKPreference(
73         ExecutionPreference preference) {
74     switch (preference) {
75         case ExecutionPreference::LOW_POWER:
76             return ::android::nn::wrapper::ExecutePreference::PREFER_LOW_POWER;
77         case ExecutionPreference::FAST_SINGLE_ANSWER:
78             return ::android::nn::wrapper::ExecutePreference::PREFER_FAST_SINGLE_ANSWER;
79         case ExecutionPreference::SUSTAINED_SPEED:
80             return ::android::nn::wrapper::ExecutePreference::PREFER_SUSTAINED_SPEED;
81     }
82     LOG(ERROR) << "unrecognized preference: " << static_cast<int32_t>(preference);
83     return std::nullopt;
84 }
85 
86 // Safely downcast an IPreparedModel object to ShimPreparedModel.
87 // This function will return nullptr if the IPreparedModel object is not originated from the
88 // shim process.
castToShimPreparedModel(IPreparedModel * preparedModel)89 const ShimPreparedModel* castToShimPreparedModel(IPreparedModel* preparedModel) {
90     if (preparedModel->isRemote()) {
91         return nullptr;
92     }
93     // This static_cast is safe because ShimPreparedModel is the only class that implements
94     // the IPreparedModel interface in the sample driver process.
95     return static_cast<const ShimPreparedModel*>(preparedModel);
96 }
97 
convertPerformanceInfo(const SL_ANeuralNetworksPerformanceInfo & info)98 static PerformanceInfo convertPerformanceInfo(const SL_ANeuralNetworksPerformanceInfo& info) {
99     return {.execTime = info.execTime, .powerUsage = info.powerUsage};
100 }
101 
getCapabilities(const NnApiSupportLibrary * nnapi,ANeuralNetworksDevice * device)102 Capabilities getCapabilities(const NnApiSupportLibrary* nnapi, ANeuralNetworksDevice* device) {
103     Capabilities capabilities;
104     SL_ANeuralNetworksPerformanceInfo performanceInfo;
105 
106     nnapi->getFL5()->SL_ANeuralNetworksDevice_getPerformanceInfo(
107             device, SL_ANEURALNETWORKS_CAPABILITIES_PERFORMANCE_RELAXED_SCALAR, &performanceInfo);
108     capabilities.relaxedFloat32toFloat16PerformanceScalar = convertPerformanceInfo(performanceInfo);
109 
110     nnapi->getFL5()->SL_ANeuralNetworksDevice_getPerformanceInfo(
111             device, SL_ANEURALNETWORKS_CAPABILITIES_PERFORMANCE_RELAXED_TENSOR, &performanceInfo);
112     capabilities.relaxedFloat32toFloat16PerformanceTensor = convertPerformanceInfo(performanceInfo);
113 
114     nnapi->getFL5()->SL_ANeuralNetworksDevice_getPerformanceInfo(
115             device, SL_ANEURALNETWORKS_CAPABILITIES_PERFORMANCE_IF, &performanceInfo);
116     capabilities.ifPerformance = convertPerformanceInfo(performanceInfo);
117 
118     nnapi->getFL5()->SL_ANeuralNetworksDevice_getPerformanceInfo(
119             device, SL_ANEURALNETWORKS_CAPABILITIES_PERFORMANCE_WHILE, &performanceInfo);
120     capabilities.whilePerformance = convertPerformanceInfo(performanceInfo);
121 
122     constexpr auto fn = [](SL_ANeuralNetworksOperandPerformanceInfo info, void* context) {
123         auto* out = static_cast<std::vector<OperandPerformance>*>(context);
124         out->push_back(OperandPerformance{
125                 .type = static_cast<OperandType>(info.operandType),
126                 .info = convertPerformanceInfo(info.performanceInfo),
127         });
128     };
129 
130     nnapi->getFL5()->SL_ANeuralNetworksDevice_forEachOperandPerformanceInfo(
131             device, static_cast<void*>(&capabilities.operandPerformance), fn);
132 
133     return capabilities;
134 }
135 
getNumberOfCacheFilesNeeded(const NnApiSupportLibrary * nnapi,ANeuralNetworksDevice * device)136 NumberOfCacheFiles getNumberOfCacheFilesNeeded(const NnApiSupportLibrary* nnapi,
137                                                ANeuralNetworksDevice* device) {
138     uint32_t numModelCacheFiles;
139     uint32_t numDataCacheFiles;
140     nnapi->getFL5()->SL_ANeuralNetworksDevice_getNumberOfCacheFilesNeeded(
141             device, &numModelCacheFiles, &numDataCacheFiles);
142     return {
143             .numModelCache = static_cast<int32_t>(numModelCacheFiles),
144             .numDataCache = static_cast<int32_t>(numDataCacheFiles),
145     };
146 }
147 
getVendorExtensions(const NnApiSupportLibrary * nnapi,ANeuralNetworksDevice * device)148 std::vector<Extension> getVendorExtensions(const NnApiSupportLibrary* nnapi,
149                                            ANeuralNetworksDevice* device) {
150     uint32_t vendorExtensionCount;
151     nnapi->getFL5()->SL_ANeuralNetworksDevice_getVendorExtensionCount(device,
152                                                                       &vendorExtensionCount);
153 
154     std::vector<Extension> extensions(vendorExtensionCount);
155 
156     for (uint32_t vendorExtensionIndex = 0; vendorExtensionIndex < vendorExtensionCount;
157          ++vendorExtensionIndex) {
158         auto& extension = extensions[vendorExtensionIndex];
159 
160         const char* extensionName;
161         nnapi->getFL5()->SL_ANeuralNetworksDevice_getVendorExtensionName(
162                 device, vendorExtensionIndex, &extensionName);
163         extension.name = extensionName;
164 
165         constexpr auto fn = [](SL_ANeuralNetworksExtensionOperandTypeInformation info,
166                                void* context) {
167             auto* out = static_cast<std::vector<ExtensionOperandTypeInformation>*>(context);
168             out->push_back(ExtensionOperandTypeInformation{
169                     .type = info.type,
170                     .isTensor = info.isTensor,
171                     .byteSize = static_cast<int32_t>(info.byteSize),
172             });
173         };
174         nnapi->getFL5()->SL_ANeuralNetworksDevice_forEachVendorExtensionOperandTypeInformation(
175                 device, vendorExtensionIndex, static_cast<void*>(&extension.operandTypes), fn);
176     }
177 
178     return extensions;
179 }
180 
181 }  // namespace
182 
// Constructs a shim device that exposes the given support library device
// through the AIDL IDevice interface.
//
// `nnapi` is the loaded NNAPI support library, `device` is the SL device
// handle this instance wraps, and `serviceName` is the name the service is
// registered under. Capabilities, cache-file counts, and vendor extensions
// are queried once here and cached, since they are fixed per device.
//
// NOTE(review): the init list uses mNnapi/mDevice to initialize the cached
// members, which is only safe if mNnapi and mDevice are declared before them
// in the header -- confirm the member declaration order in ShimDevice.h.
ShimDevice::ShimDevice(std::shared_ptr<const NnApiSupportLibrary> nnapi,
                       ANeuralNetworksDevice* device, std::string serviceName)
    : mNnapi(std::move(nnapi)),
      mBufferTracker(ShimBufferTracker::create()),
      mServiceName(std::move(serviceName)),
      mDevice(device),
      mCapabilities(neuralnetworks::getCapabilities(mNnapi.get(), mDevice)),
      mNumberOfCacheFiles(neuralnetworks::getNumberOfCacheFilesNeeded(mNnapi.get(), mDevice)),
      mExtensions(neuralnetworks::getVendorExtensions(mNnapi.get(), mDevice)) {}
192 
// Manages the data buffer for an operand.
//
// Implements the AIDL IBuffer interface on top of a support-library memory
// object. The held ShimBufferTracker::Token keeps the buffer registered with
// the tracker for the lifetime of this object.
class ShimBuffer : public BnBuffer {
   public:
    // `initialDimensions` and `type` describe the operand this buffer was
    // allocated for; `memory` is the backing SL memory; `token` keeps the
    // buffer registered with the ShimBufferTracker.
    ShimBuffer(const NnApiSupportLibrary* nnApi, const ::android::nn::Dimensions initialDimensions,
               const ::android::nn::OperandType type,
               std::shared_ptr<::android::nn::sl_wrapper::Memory> memory,
               std::unique_ptr<ShimBufferTracker::Token> token)
        : kInitialDimensions(initialDimensions),
          kType(type),
          mNnApi(nnApi),
          mMemory(std::move(memory)),
          kToken(std::move(token)) {}

    // Returns true if a tensor of `type` with `dimensions` is not fully
    // specified: no dimensions at all, or any dimension equal to 0.
    // Non-extension scalars never count as unspecified.
    bool tensorHasUnspecifiedDimensions(::android::nn::OperandType type,
                                        const ::android::nn::Dimensions& dimensions) {
        if (!::android::nn::isExtension(type)) {
            if (isNonExtensionScalar(type)) {
                return false;
            }
        }
        return dimensions.size() == 0 || std::any_of(dimensions.begin(), dimensions.end(),
                                                     [](int32_t dim) { return dim == 0; });
    }

    // Validates a dimension update supplied by a copy operation (may be
    // empty): scalars take no dimensions; tensors must end up fully
    // specified, either initially or via the update; and the update must be
    // combinable with the initial dimensions.
    bool validateDimensions(const ::android::nn::Dimensions& dimensions) {
        if (isNonExtensionScalar(kType)) {
            if (!dimensions.empty()) {
                LOG(ERROR) << "ShimBuffer::validateDimensions -- invalid dimensions for scalar "
                              "operand";
                return false;
            }
            return true;
        }

        if (dimensions.empty()) {
            if (tensorHasUnspecifiedDimensions(kType, kInitialDimensions)) {
                LOG(ERROR) << "ShimBuffer::validateDimensions -- the initial dimensions are not "
                              "fully specified and no dimension update is provided: ";

                return false;
            }
        } else {
            if (tensorHasUnspecifiedDimensions(kType, dimensions)) {
                LOG(ERROR) << "ShimBuffer::validateDimensions -- the updated dimensions are not "
                              "fully specified: ";

                return false;
            }
        }

        // The update must be element-wise compatible with the initial shape.
        const auto combined = ::android::nn::combineDimensions(kInitialDimensions, dimensions);
        if (!combined.has_value()) {
            LOG(ERROR) << "ShimBuffer::validateDimensions -- incompatible dimensions";
            return false;
        }
        return true;
    }

    // IBuffer::copyFrom -- copies the contents of `src` into this buffer.
    // `dimensions` optionally updates the buffer's dimensions and must pass
    // validateDimensions().
    ndk::ScopedAStatus copyFrom(const aidl::android::hardware::neuralnetworks::Memory& src,
                                const std::vector<int32_t>& dimensions) override {
        auto memory = convertFromHAL(mNnApi, src);

        if (!memory) {
            LOG(ERROR) << "Failed to convert HAL Memory to SL memory";
            return toAStatus(ErrorStatus::INVALID_ARGUMENT);
        }
        const auto unsignedDimensions = ::android::nn::toUnsigned(dimensions);
        if (!unsignedDimensions.has_value()) {
            return toAStatus(ErrorStatus::INVALID_ARGUMENT, unsignedDimensions.error().message);
        }

        if (!validateDimensions(unsignedDimensions.value())) {
            LOG(ERROR) << "Invalid dimensions";
            return toAStatus(ErrorStatus::INVALID_ARGUMENT);
        }
        Result result = memory->copyTo(*mMemory.get());

        // Special case expected error status for uninitialized source memory
        if (result == Result::BAD_DATA) {
            // NNAPI Runtime reports both uninitialized memory
            // and incompatible dimensions as BAD_DATA, but
            // VTS expects to see INVALID_ARGUMENT for bad dimensions,
            // and GENERAL_FAILURE for uninitialized memory.
            if (memory->getSize() != mMemory->getSize()) {
                return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Incompatible sizes");
            }

            return toAStatus(ErrorStatus::GENERAL_FAILURE);
        }
        SLW2SAS_RETURN_IF_ERROR(result);
        return ndk::ScopedAStatus::ok();
    }

    // IBuffer::copyTo -- copies this buffer's contents into `dst`.
    ndk::ScopedAStatus copyTo(const Memory& dst) override {
        auto memory = convertFromHAL(mNnApi, dst);

        if (!memory) {
            LOG(ERROR) << "Failed to convert HAL Memory to SL memory";
            return toAStatus(ErrorStatus::INVALID_ARGUMENT);
        }

        Result result = mMemory->copyTo(*memory);
        // Special case expected error status for uninitialized source memory
        if (result == Result::BAD_DATA) {
            // NNAPI Runtime reports both uninitialized memory
            // and incompatible dimensions as BAD_DATA, but
            // VTS expects to see INVALID_ARGUMENT for bad dimensions,
            // and GENERAL_FAILURE for uninitialized memory.
            if (memory->getSize() != mMemory->getSize()) {
                return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Incompatible sizes");
            }
            return toAStatus(ErrorStatus::GENERAL_FAILURE);
        }
        SLW2SAS_RETURN_IF_ERROR(result);
        return ndk::ScopedAStatus::ok();
    }

   private:
    // Shape and operand type the buffer was allocated with.
    const ::android::nn::Dimensions kInitialDimensions;
    const ::android::nn::OperandType kType;

    // Support library handle; not owned by this object.
    const NnApiSupportLibrary* mNnApi;
    std::shared_ptr<::android::nn::sl_wrapper::Memory> mMemory;
    // Keeps the buffer registered with the ShimBufferTracker until destruction.
    const std::unique_ptr<ShimBufferTracker::Token> kToken;
};
318 
allocate(const BufferDesc & desc,const std::vector<IPreparedModelParcel> & preparedModels,const std::vector<BufferRole> & inputRoles,const std::vector<BufferRole> & outputRoles,DeviceBuffer * buffer)319 ::ndk::ScopedAStatus ShimDevice::allocate(const BufferDesc& desc,
320                                           const std::vector<IPreparedModelParcel>& preparedModels,
321                                           const std::vector<BufferRole>& inputRoles,
322                                           const std::vector<BufferRole>& outputRoles,
323                                           DeviceBuffer* buffer) {
324     if (!isValidDimension(desc.dimensions)) {
325         LOG(ERROR) << "ShimDriver::allocate -- passed invalid dimension values.";
326         return toAStatus(ErrorStatus::INVALID_ARGUMENT,
327                          "ShimDriver::allocate -- passed invalid dimension values");
328     }
329     ANeuralNetworksMemoryDesc* slDesc = nullptr;
330     mNnapi->getFL5()->ANeuralNetworksMemoryDesc_create(&slDesc);
331     const auto slDescGuard = ::android::base::make_scope_guard(
332             [this, slDesc] { mNnapi->getFL5()->ANeuralNetworksMemoryDesc_free(slDesc); });
333 
334     auto unsignedDimensions = ::android::nn::toUnsigned(desc.dimensions).value();
335     if (mNnapi->getFL5()->ANeuralNetworksMemoryDesc_setDimensions(slDesc, desc.dimensions.size(),
336                                                                   unsignedDimensions.data()) !=
337         ANEURALNETWORKS_NO_ERROR) {
338         LOG(ERROR) << "ShimDriver::allocate -- ANeuralNetworksMemoryDesc_setDimensions fail.";
339         return toAStatus(ErrorStatus::INVALID_ARGUMENT,
340                          "ShimDriver::allocate -- ANeuralNetworksMemoryDesc_setDimensions fail");
341     }
342 
343     constexpr auto getCompilation = [](IPreparedModel* preparedModel) -> const ShimPreparedModel* {
344         const auto* samplePreparedModel = castToShimPreparedModel(preparedModel);
345         if (samplePreparedModel == nullptr) {
346             LOG(ERROR) << "ShimDriver::allocate -- unknown remote IPreparedModel.";
347             return nullptr;
348         }
349         return samplePreparedModel;
350     };
351 
352     std::optional<::android::nn::OperandType> type;
353     std::vector<uint32_t> dimensions = ::android::nn::toUnsigned(desc.dimensions).value();
354 
355     for (const auto& role : inputRoles) {
356         if (role.modelIndex < 0 || role.modelIndex >= preparedModels.size()) {
357             LOG(ERROR) << "Invalid modelIndex value " << role.modelIndex;
358             return toAStatus(ErrorStatus::INVALID_ARGUMENT,
359                              "ShimDriver::allocate -- Input role modeIndex with invalid value");
360         }
361         auto preparedModel = preparedModels[role.modelIndex];
362         if (preparedModel.preparedModel == nullptr) {
363             return toAStatus(ErrorStatus::INVALID_ARGUMENT,
364                              "ShimDriver::allocate -- nullptr model");
365         }
366 
367         auto pmodel = getCompilation(preparedModel.preparedModel.get());
368         if (pmodel == nullptr) {
369             return toAStatus(ErrorStatus::INVALID_ARGUMENT,
370                              "ShimDriver::allocate -- nullptr model");
371         }
372 
373         auto result = mNnapi->getFL5()->ANeuralNetworksMemoryDesc_addInputRole(
374                 slDesc, pmodel->getCompilation().getHandle(), role.ioIndex, role.probability);
375 
376         if (result != ANEURALNETWORKS_NO_ERROR) {
377             LOG(ERROR) << "SampleDriver::allocate -- ANeuralNetworksMemoryDesc_addInputRole fail.";
378             return toAStatus(ErrorStatus::INVALID_ARGUMENT,
379                              "ShimDriver::allocate -- ANeuralNetworksMemoryDesc_addInputRole fail");
380         }
381 
382         const auto& model = pmodel->getMainModel();
383         const auto& op = model.getOperands()[model.getInputs()[role.ioIndex]];
384         auto operandType = static_cast<::android::nn::OperandType>(op.operandType.type);
385         if (!type) {
386             type = operandType;
387         }
388         if (dimensions.empty()) {
389             dimensions = op.dimensions;
390         }
391     }
392 
393     for (const auto& role : outputRoles) {
394         if (role.modelIndex < 0 || role.modelIndex >= preparedModels.size()) {
395             LOG(ERROR) << "Invalid modelIndex value " << role.modelIndex;
396             return toAStatus(ErrorStatus::INVALID_ARGUMENT,
397                              "ShimDriver::allocate -- Ou0tput role modeIndex with invalid value");
398         }
399         auto preparedModel = preparedModels[role.modelIndex];
400         if (preparedModel.preparedModel == nullptr) {
401             return toAStatus(ErrorStatus::INVALID_ARGUMENT,
402                              "ShimDriver::allocate -- nullptr model");
403         }
404 
405         auto pmodel = getCompilation(preparedModel.preparedModel.get());
406         if (pmodel == nullptr) {
407             return toAStatus(ErrorStatus::INVALID_ARGUMENT,
408                              "ShimDriver::allocate -- nullptr model");
409         }
410 
411         auto result = mNnapi->getFL5()->ANeuralNetworksMemoryDesc_addOutputRole(
412                 slDesc, pmodel->getCompilation().getHandle(), role.ioIndex, role.probability);
413 
414         if (result != ANEURALNETWORKS_NO_ERROR) {
415             LOG(ERROR) << "SampleDriver::allocate -- ANeuralNetworksMemoryDesc_addInputRole fail.";
416             return toAStatus(ErrorStatus::INVALID_ARGUMENT,
417                              "ShimDriver::allocate -- ANeuralNetworksMemoryDesc_addInputRole fail");
418         }
419         const auto& model = pmodel->getMainModel();
420         const auto& op = model.getOperands()[model.getOutputs()[role.ioIndex]];
421         auto operandType = static_cast<::android::nn::OperandType>(op.operandType.type);
422         if (!type) {
423             type = operandType;
424         }
425         if (dimensions.empty()) {
426             dimensions = op.dimensions;
427         }
428     }
429 
430     auto typeSize = ::android::nn::getNonExtensionSize(*type, dimensions);
431     if (!typeSize.has_value()) {
432         return toAStatus(ErrorStatus::INVALID_ARGUMENT,
433                          "ShimDriver::allocate -- failed to get underlying type size, "
434                          "possibly an extension type");
435     }
436 
437     mNnapi->getFL5()->ANeuralNetworksMemoryDesc_finish(slDesc);
438     auto memory =
439             std::make_shared<::android::nn::sl_wrapper::Memory>(mNnapi.get(), slDesc, *typeSize);
440 
441     if (!memory->isValid()) {
442         LOG(ERROR) << "ShimDriver::allocate -- ANeuralNetworksMemory_createFromDesc failed.";
443         return toAStatus(ErrorStatus::GENERAL_FAILURE,
444                          "ShimDriver::allocate -- ANeuralNetworksMemory_createFromDesc failed");
445     }
446 
447     auto token = mBufferTracker->add(memory);
448     if (token == nullptr) {
449         LOG(ERROR) << "ShimDriver::allocate -- ShimBufferTracker returned invalid token.";
450         return toAStatus(ErrorStatus::GENERAL_FAILURE,
451                          "ShimDriver::allocate -- ShimBufferTracker returned invalid token.");
452     }
453     const uint32_t tokenValue = token->get();
454     auto shimbuffer = ndk::SharedRefBase::make<ShimBuffer>(mNnapi.get(), dimensions, *type,
455                                                            std::move(memory), std::move(token));
456     buffer->buffer = std::move(shimbuffer);
457     buffer->token = tokenValue;
458 
459     return ndk::ScopedAStatus::ok();
460 }
461 
// IDevice::getCapabilities -- returns the capabilities cached at construction.
ndk::ScopedAStatus ShimDevice::getCapabilities(Capabilities* capabilities) {
    *capabilities = mCapabilities;
    return ndk::ScopedAStatus::ok();
}
466 
// IDevice::getNumberOfCacheFilesNeeded -- returns the counts cached at
// construction.
ndk::ScopedAStatus ShimDevice::getNumberOfCacheFilesNeeded(NumberOfCacheFiles* numberOfCacheFiles) {
    *numberOfCacheFiles = mNumberOfCacheFiles;
    return ndk::ScopedAStatus::ok();
}
471 
// IDevice::getSupportedExtensions -- returns the vendor extensions cached at
// construction.
ndk::ScopedAStatus ShimDevice::getSupportedExtensions(std::vector<Extension>* extensions) {
    *extensions = mExtensions;
    return ndk::ScopedAStatus::ok();
}
476 
getSupportedOperations(const Model & model,std::vector<bool> * supportedOperations)477 ndk::ScopedAStatus ShimDevice::getSupportedOperations(const Model& model,
478                                                       std::vector<bool>* supportedOperations) {
479     const auto numOperations = model.main.operations.size();
480     supportedOperations->resize(numOperations);
481 
482     ErrorStatus convertErrorStatus = ErrorStatus::NONE;
483     std::vector<uint8_t> copiedOperandValues;
484     auto modelAndMemory =
485             convertFromHAL(mNnapi.get(), model, &copiedOperandValues, &convertErrorStatus);
486     if (!modelAndMemory || modelAndMemory->models.empty()) {
487         LOG(ERROR) << "Failed to convert HAL model to SL model";
488         return toAStatus(convertErrorStatus);
489     }
490 
491     auto annModel = modelAndMemory->models[0].getHandle();
492     auto supportedOps = std::make_unique<bool[]>(numOperations);
493 
494     auto result = mNnapi->getFL5()->ANeuralNetworksModel_getSupportedOperationsForDevices(
495             annModel, &mDevice, /*numDevices=*/1, supportedOps.get());
496     SLW2SAS_RETURN_IF_ERROR(result);
497 
498     std::copy(supportedOps.get(), supportedOps.get() + numOperations, supportedOperations->begin());
499     return ndk::ScopedAStatus::ok();
500 }
501 
// IDevice::getType -- queries the device category from the support library.
ndk::ScopedAStatus ShimDevice::getType(DeviceType* type) {
    int32_t deviceType;
    auto result = mNnapi->getFL5()->ANeuralNetworksDevice_getType(mDevice, &deviceType);
    SLW2SAS_RETURN_IF_ERROR(result);
    // NOTE(review): assumes the SL device-type values match the AIDL
    // DeviceType enumerators numerically -- confirm against both headers.
    *type = static_cast<DeviceType>(deviceType);
    return ndk::ScopedAStatus::ok();
}
509 
getVersionString(std::string * versionString)510 ndk::ScopedAStatus ShimDevice::getVersionString(std::string* versionString) {
511     const char* buffer;
512     auto result = mNnapi->getFL5()->ANeuralNetworksDevice_getVersion(mDevice, &buffer);
513     SLW2SAS_RETURN_IF_ERROR(result);
514 
515     *versionString = std::string(buffer);
516     return ndk::ScopedAStatus::ok();
517 }
518 
getIntFds(const std::vector<::ndk::ScopedFileDescriptor> & scopedFds)519 static std::vector<int> getIntFds(const std::vector<::ndk::ScopedFileDescriptor>& scopedFds) {
520     std::vector<int> fds;
521     fds.reserve(scopedFds.size());
522     for (const auto& scopedFd : scopedFds) {
523         fds.push_back(scopedFd.get());
524     }
525     return fds;
526 }
527 
prepareModelCommon(const Model & model,ExecutionPreference preference,Priority priority,int64_t deadlineNs,const std::vector<::ndk::ScopedFileDescriptor> & modelCache,const std::vector<::ndk::ScopedFileDescriptor> & dataCache,const std::vector<uint8_t> & token,const std::vector<TokenValuePair> & compilationHints,const std::vector<ExtensionNameAndPrefix> & extensionNameToPrefix,const std::shared_ptr<IPreparedModelCallback> & callback)528 ndk::ScopedAStatus ShimDevice::prepareModelCommon(
529         const Model& model, ExecutionPreference preference, Priority priority, int64_t deadlineNs,
530         const std::vector<::ndk::ScopedFileDescriptor>& modelCache,
531         const std::vector<::ndk::ScopedFileDescriptor>& dataCache,
532         const std::vector<uint8_t>& token, const std::vector<TokenValuePair>& compilationHints,
533         const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
534         const std::shared_ptr<IPreparedModelCallback>& callback) {
535     // TODO(183398748): Run model preparation in detached thread.
536     if (callback == nullptr) {
537         return toAStatus(ErrorStatus::INVALID_ARGUMENT);
538     }
539 
540     auto ndkPreference = convertToNDKPreference(preference);
541     if (!ndkPreference) {
542         callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
543         return toAStatus(ErrorStatus::INVALID_ARGUMENT);
544     }
545     auto ndkPriority = convertToNDKPriority(priority);
546     if (!ndkPriority) {
547         callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
548         return toAStatus(ErrorStatus::INVALID_ARGUMENT);
549     }
550 
551     ErrorStatus convertErrorStatus = ErrorStatus::NONE;
552     std::vector<uint8_t> copiedOperandValues;
553     auto modelAndMemory =
554             convertFromHAL(mNnapi.get(), model, &copiedOperandValues, &convertErrorStatus);
555 
556     if (!modelAndMemory || modelAndMemory->models.empty()) {
557         callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
558         return toAStatus(convertErrorStatus);
559     }
560 
561     // b/185976051, past this point we pretend that compilation is asynchronous, and in
562     /// case of error we return OK status, but communicate the error through the callback.
563     auto compilation = ::android::nn::sl_wrapper::Compilation::createForDevice(
564             mNnapi.get(), &modelAndMemory->models[0], mDevice);
565 
566     SLW2SAS_OK_RETURN_AND_ERROR_CALLBACK_IF_ERROR(compilation.first, callback);
567     SLW2SAS_OK_RETURN_AND_ERROR_CALLBACK_IF_ERROR(compilation.second.setPreference(*ndkPreference),
568                                                   callback);
569     SLW2SAS_OK_RETURN_AND_ERROR_CALLBACK_IF_ERROR(compilation.second.setPriority(*ndkPriority),
570                                                   callback);
571     if (deadlineNs > -1) {
572         std::chrono::time_point<::android::base::boot_clock> deadlinePoint(
573                 std::chrono::nanoseconds{deadlineNs});
574         const auto currentTime = ::android::base::boot_clock::now();
575         const auto timeoutDuration = std::chrono::nanoseconds(deadlinePoint - currentTime);
576         if (timeoutDuration <= std::chrono::nanoseconds::zero()) {
577             callback->notify(ErrorStatus::MISSED_DEADLINE_TRANSIENT, nullptr);
578             return ndk::ScopedAStatus::ok();
579         }
580         SLW2SAS_OK_RETURN_AND_ERROR_CALLBACK_IF_ERROR(
581                 compilation.second.setTimeout(std::max<uint64_t>(1, timeoutDuration.count())),
582                 callback);
583     }
584     if (!modelCache.empty() || !dataCache.empty()) {
585         SLW2SAS_OK_RETURN_AND_ERROR_CALLBACK_IF_ERROR(
586                 compilation.second.setCachingFromFds(getIntFds(modelCache), getIntFds(dataCache),
587                                                      token),
588                 callback);
589     }
590     if (!compilationHints.empty() || !extensionNameToPrefix.empty()) {
591         std::unordered_map<uint16_t, std::string> prefixToName;
592         for (const auto [name, prefix] : extensionNameToPrefix) {
593             prefixToName.emplace(prefix, name);
594         }
595 
596         for (const auto& [token, value] : compilationHints) {
597             const auto uToken = static_cast<uint32_t>(token);
598             const auto prefix = ::android::nn::getExtensionPrefix(uToken);
599             const auto attributeCodeWithinExtension = ::android::nn::getTypeWithinExtension(uToken);
600 
601             const auto it = prefixToName.find(prefix);
602             if (it == prefixToName.end()) {
603                 callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
604                 return toAStatus(ErrorStatus::INVALID_ARGUMENT);
605             }
606             const std::string& extensionName = it->second;
607 
608             SLW2SAS_OK_RETURN_AND_ERROR_CALLBACK_IF_ERROR(
609                     compilation.second.addExtensionAttribute(extensionName,
610                                                              attributeCodeWithinExtension, value),
611                     callback);
612         }
613     }
614 
615     SLW2SAS_OK_RETURN_AND_ERROR_CALLBACK_IF_ERROR(compilation.second.finish(), callback);
616 
617     const std::shared_ptr<ShimPreparedModel> preparedModel =
618             ndk::SharedRefBase::make<ShimPreparedModel>(
619                     mNnapi, mBufferTracker, std::move(compilation.second),
620                     std::move(modelAndMemory->models), std::move(modelAndMemory->memory),
621                     std::move(copiedOperandValues));
622 
623     callback->notify(ErrorStatus::NONE, preparedModel);
624     return ndk::ScopedAStatus::ok();
625 }
626 
// IDevice::prepareModel -- delegates to prepareModelCommon() with no
// compilation hints and no extension prefixes.
ndk::ScopedAStatus ShimDevice::prepareModel(
        const Model& model, ExecutionPreference preference, Priority priority, int64_t deadlineNs,
        const std::vector<::ndk::ScopedFileDescriptor>& modelCache,
        const std::vector<::ndk::ScopedFileDescriptor>& dataCache,
        const std::vector<uint8_t>& token,
        const std::shared_ptr<IPreparedModelCallback>& callback) {
    return prepareModelCommon(model, preference, priority, deadlineNs, modelCache, dataCache, token,
                              /*compilationHints=*/{}, /*extensionNameToPrefix=*/{}, callback);
}
636 
// IDevice::prepareModelWithConfig -- unpacks the PrepareModelConfig parcelable
// and delegates to prepareModelCommon().
ndk::ScopedAStatus ShimDevice::prepareModelWithConfig(
        const Model& model, const PrepareModelConfig& config,
        const std::shared_ptr<IPreparedModelCallback>& callback) {
    return prepareModelCommon(model, config.preference, config.priority, config.deadlineNs,
                              config.modelCache, config.dataCache, utils::toVec(config.cacheToken),
                              config.compilationHints, config.extensionNameToPrefix, callback);
}
644 
prepareModelFromCache(int64_t,const std::vector<::ndk::ScopedFileDescriptor> &,const std::vector<::ndk::ScopedFileDescriptor> &,const std::vector<uint8_t> &,const std::shared_ptr<IPreparedModelCallback> & callback)645 ndk::ScopedAStatus ShimDevice::prepareModelFromCache(
646         int64_t /*deadlineNs*/, const std::vector<::ndk::ScopedFileDescriptor>& /*modelCache*/,
647         const std::vector<::ndk::ScopedFileDescriptor>& /*dataCache*/,
648         const std::vector<uint8_t>& /*token*/,
649         const std::shared_ptr<IPreparedModelCallback>& callback) {
650     if (callback == nullptr) {
651         return toAStatus(ErrorStatus::INVALID_ARGUMENT);
652     }
653     // The NNAPI runtime will attempt to call this before falling back to
654     // ShimDevice::prepareModel(). This is not a LOG(ERROR) to avoid producing
655     // misleading logcat messages on every compilation request because there is
656     // technically nothing wrong.
657     LOG(DEBUG) << "ShimDevice::prepareModelFromCache() is not supported. Use "
658                   "ShimDevice::prepareModel() instead.";
659     const auto ret = callback->notify(ErrorStatus::GENERAL_FAILURE, nullptr);
660     return toAStatus(ErrorStatus::GENERAL_FAILURE);
661 }
662 
663 }  // namespace aidl::android::hardware::neuralnetworks
664