1 /* 2 * Copyright (C) 2021 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #pragma once 18 19 #include <aidl/android/hardware/neuralnetworks/BnPreparedModel.h> 20 #include <android-base/logging.h> 21 22 #include <memory> 23 #include <utility> 24 #include <vector> 25 26 #include "ShimDevice.h" 27 #include "SupportLibrary.h" 28 #include "SupportLibraryWrapper.h" 29 30 namespace aidl::android::hardware::neuralnetworks { 31 32 class ShimPreparedModel : public BnPreparedModel { 33 public: ShimPreparedModel(std::shared_ptr<const NnApiSupportLibrary> nnapi,std::shared_ptr<ShimBufferTracker> bufferTracker,::android::nn::sl_wrapper::Compilation compilation,std::vector<::android::nn::sl_wrapper::Model> mainAndReferencedModels,std::vector<std::unique_ptr<::android::nn::sl_wrapper::Memory>> memoryPools,std::vector<uint8_t> copiedOperandValues)34 ShimPreparedModel(std::shared_ptr<const NnApiSupportLibrary> nnapi, 35 std::shared_ptr<ShimBufferTracker> bufferTracker, 36 ::android::nn::sl_wrapper::Compilation compilation, 37 std::vector<::android::nn::sl_wrapper::Model> mainAndReferencedModels, 38 std::vector<std::unique_ptr<::android::nn::sl_wrapper::Memory>> memoryPools, 39 std::vector<uint8_t> copiedOperandValues) 40 : mNnapi(nnapi), 41 mBufferTracker(bufferTracker), 42 mCompilation(std::move(compilation)), 43 mMainAndReferencedModels(std::move(mainAndReferencedModels)), 44 mMemoryPools(std::move(memoryPools)), 45 mCopiedOperandValues(std::move(copiedOperandValues)) { 46 CHECK(mMainAndReferencedModels.size() > 0); 47 }; 48 49 ::ndk::ScopedAStatus executeSynchronously(const Request& request, bool measureTiming, 50 int64_t deadlineNs, int64_t loopTimeoutDurationNs, 51 ExecutionResult* executionResults) override; 52 ::ndk::ScopedAStatus executeFenced(const Request& request, 53 const std::vector<::ndk::ScopedFileDescriptor>& waitFor, 54 bool measureTiming, int64_t deadlineNs, 55 int64_t loopTimeoutDurationNs, int64_t durationNs, 56 FencedExecutionResult* fencedExecutionResult) override; 57 ::ndk::ScopedAStatus executeSynchronouslyWithConfig(const Request& request, 58 const ExecutionConfig& config, 59 int64_t deadlineNs, 60 ExecutionResult* executionResult) override; 61 ::ndk::ScopedAStatus executeFencedWithConfig( 62 const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor, 63 const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs, 64 FencedExecutionResult* executionResult) override; 65 66 ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<IBurst>* burst) override; 67 ndk::ScopedAStatus createReusableExecution(const Request& request, 68 const ExecutionConfig& config, 69 std::shared_ptr<IExecution>* execution) override; 70 getCompilation()71 const ::android::nn::sl_wrapper::Compilation& getCompilation() const { return mCompilation; } getMainModel()72 const ::android::nn::sl_wrapper::Model& getMainModel() const { 73 return mMainAndReferencedModels[0]; 74 } 75 76 private: 77 ErrorStatus parseInputs( 78 const Request& request, bool measure, int64_t deadlineNs, int64_t loopTimeoutDurationNs, 79 ::android::nn::sl_wrapper::Execution* execution, 80 std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>>* requestMemoryPools, 81 const std::vector<TokenValuePair>& executionHints, 82 const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix); 83 84 ::ndk::ScopedAStatus executeSynchronouslyCommon( 85 const Request& request, bool measureTiming, int64_t deadlineNs, 86 int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& executionHints, 87 const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix, 88 ExecutionResult* executionResult); 89 ::ndk::ScopedAStatus executeFencedCommon( 90 const Request& request, const std::vector<::ndk::ScopedFileDescriptor>& waitFor, 91 bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, 92 int64_t durationNs, const std::vector<TokenValuePair>& executionHints, 93 const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix, 94 FencedExecutionResult* fencedExecutionResult); 95 96 std::shared_ptr<const NnApiSupportLibrary> mNnapi; 97 std::shared_ptr<ShimBufferTracker> mBufferTracker; 98 99 ::android::nn::sl_wrapper::Compilation mCompilation; 100 std::vector<::android::nn::sl_wrapper::Model> mMainAndReferencedModels; 101 std::vector<std::unique_ptr<::android::nn::sl_wrapper::Memory>> mMemoryPools; 102 std::vector<uint8_t> mCopiedOperandValues; 103 }; 104 105 } // namespace aidl::android::hardware::neuralnetworks 106