/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_DRIVER_SAMPLE_SAMPLE_DRIVER_H
#define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_DRIVER_SAMPLE_SAMPLE_DRIVER_H

#include <CpuExecutor.h>
#include <HalBufferTracker.h>
#include <HalInterfaces.h>
#include <hwbinder/IPCThreadState.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "NeuralNetworks.h"

namespace android {
namespace nn {
namespace sample_driver {

using hardware::MQDescriptorSync;

// Manages the data buffer for an operand.
39 class SampleBuffer : public V1_3::IBuffer { 40 public: SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer,std::unique_ptr<HalBufferTracker::Token> token)41 SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer, 42 std::unique_ptr<HalBufferTracker::Token> token) 43 : kBuffer(std::move(buffer)), kToken(std::move(token)) { 44 CHECK(kBuffer != nullptr); 45 CHECK(kToken != nullptr); 46 } 47 hardware::Return<V1_3::ErrorStatus> copyTo(const hardware::hidl_memory& dst) override; 48 hardware::Return<V1_3::ErrorStatus> copyFrom( 49 const hardware::hidl_memory& src, 50 const hardware::hidl_vec<uint32_t>& dimensions) override; 51 52 private: 53 const std::shared_ptr<HalManagedBuffer> kBuffer; 54 const std::unique_ptr<HalBufferTracker::Token> kToken; 55 }; 56 57 // Base class used to create sample drivers for the NN HAL. This class 58 // provides some implementation of the more common functions. 59 // 60 // Since these drivers simulate hardware, they must run the computations 61 // on the CPU. An actual driver would not do that. 
62 class SampleDriver : public V1_3::IDevice { 63 public: 64 SampleDriver(const char* name, 65 const IOperationResolver* operationResolver = BuiltinOperationResolver::get()) mName(name)66 : mName(name), 67 mOperationResolver(operationResolver), 68 mHalBufferTracker(HalBufferTracker::create()) { 69 android::nn::initVLogMask(); 70 } 71 hardware::Return<void> getCapabilities(getCapabilities_cb cb) override; 72 hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override; 73 hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override; 74 hardware::Return<void> getVersionString(getVersionString_cb cb) override; 75 hardware::Return<void> getType(getType_cb cb) override; 76 hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb) override; 77 hardware::Return<void> getSupportedOperations(const V1_0::Model& model, 78 getSupportedOperations_cb cb) override; 79 hardware::Return<void> getSupportedOperations_1_1(const V1_1::Model& model, 80 getSupportedOperations_1_1_cb cb) override; 81 hardware::Return<void> getSupportedOperations_1_2(const V1_2::Model& model, 82 getSupportedOperations_1_2_cb cb) override; 83 hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override; 84 hardware::Return<V1_0::ErrorStatus> prepareModel( 85 const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) override; 86 hardware::Return<V1_0::ErrorStatus> prepareModel_1_1( 87 const V1_1::Model& model, V1_1::ExecutionPreference preference, 88 const sp<V1_0::IPreparedModelCallback>& callback) override; 89 hardware::Return<V1_0::ErrorStatus> prepareModel_1_2( 90 const V1_2::Model& model, V1_1::ExecutionPreference preference, 91 const hardware::hidl_vec<hardware::hidl_handle>& modelCache, 92 const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token, 93 const sp<V1_2::IPreparedModelCallback>& callback) override; 94 hardware::Return<V1_3::ErrorStatus> prepareModel_1_3( 95 
const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority, 96 const V1_3::OptionalTimePoint& deadline, 97 const hardware::hidl_vec<hardware::hidl_handle>& modelCache, 98 const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token, 99 const sp<V1_3::IPreparedModelCallback>& callback) override; 100 hardware::Return<V1_0::ErrorStatus> prepareModelFromCache( 101 const hardware::hidl_vec<hardware::hidl_handle>& modelCache, 102 const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token, 103 const sp<V1_2::IPreparedModelCallback>& callback) override; 104 hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3( 105 const V1_3::OptionalTimePoint& deadline, 106 const hardware::hidl_vec<hardware::hidl_handle>& modelCache, 107 const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token, 108 const sp<V1_3::IPreparedModelCallback>& callback) override; 109 hardware::Return<V1_0::DeviceStatus> getStatus() override; 110 hardware::Return<void> allocate( 111 const V1_3::BufferDesc& desc, 112 const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels, 113 const hardware::hidl_vec<V1_3::BufferRole>& inputRoles, 114 const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) override; 115 getExecutor()116 CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); } getHalBufferTracker()117 const std::shared_ptr<HalBufferTracker>& getHalBufferTracker() const { 118 return mHalBufferTracker; 119 } 120 121 protected: 122 std::string mName; 123 const IOperationResolver* mOperationResolver; 124 const std::shared_ptr<HalBufferTracker> mHalBufferTracker; 125 }; 126 127 class SamplePreparedModel : public V1_3::IPreparedModel { 128 public: SamplePreparedModel(const V1_3::Model & model,const SampleDriver * driver,V1_1::ExecutionPreference preference,uid_t userId,V1_3::Priority priority)129 SamplePreparedModel(const V1_3::Model& model, const 
SampleDriver* driver, 130 V1_1::ExecutionPreference preference, uid_t userId, V1_3::Priority priority) 131 : mModel(model), 132 mDriver(driver), 133 kPreference(preference), 134 kUserId(userId), 135 kPriority(priority) { 136 (void)kUserId; 137 (void)kPriority; 138 } 139 bool initialize(); 140 hardware::Return<V1_0::ErrorStatus> execute( 141 const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override; 142 hardware::Return<V1_0::ErrorStatus> execute_1_2( 143 const V1_0::Request& request, V1_2::MeasureTiming measure, 144 const sp<V1_2::IExecutionCallback>& callback) override; 145 hardware::Return<V1_3::ErrorStatus> execute_1_3( 146 const V1_3::Request& request, V1_2::MeasureTiming measure, 147 const V1_3::OptionalTimePoint& deadline, 148 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, 149 const sp<V1_3::IExecutionCallback>& callback) override; 150 hardware::Return<void> executeSynchronously(const V1_0::Request& request, 151 V1_2::MeasureTiming measure, 152 executeSynchronously_cb cb) override; 153 hardware::Return<void> executeSynchronously_1_3( 154 const V1_3::Request& request, V1_2::MeasureTiming measure, 155 const V1_3::OptionalTimePoint& deadline, 156 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, 157 executeSynchronously_1_3_cb cb) override; 158 hardware::Return<void> configureExecutionBurst( 159 const sp<V1_2::IBurstCallback>& callback, 160 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel, 161 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel, 162 configureExecutionBurst_cb cb) override; 163 hardware::Return<void> executeFenced(const V1_3::Request& request, 164 const hardware::hidl_vec<hardware::hidl_handle>& wait_for, 165 V1_2::MeasureTiming measure, 166 const V1_3::OptionalTimePoint& deadline, 167 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, 168 const V1_3::OptionalTimeoutDuration& duration, 169 executeFenced_cb callback) override; getModel()170 const V1_3::Model* getModel() 
const { return &mModel; } 171 172 protected: 173 V1_3::Model mModel; 174 const SampleDriver* mDriver; 175 std::vector<RunTimePoolInfo> mPoolInfos; 176 const V1_1::ExecutionPreference kPreference; 177 const uid_t kUserId; 178 const V1_3::Priority kPriority; 179 }; 180 181 class SampleFencedExecutionCallback : public V1_3::IFencedExecutionCallback { 182 public: SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch,V1_2::Timing timingAfterFence,V1_3::ErrorStatus error)183 SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch, V1_2::Timing timingAfterFence, 184 V1_3::ErrorStatus error) 185 : kTimingSinceLaunch(timingSinceLaunch), 186 kTimingAfterFence(timingAfterFence), 187 kErrorStatus(error) {} getExecutionInfo(getExecutionInfo_cb callback)188 hardware::Return<void> getExecutionInfo(getExecutionInfo_cb callback) override { 189 callback(kErrorStatus, kTimingSinceLaunch, kTimingAfterFence); 190 return hardware::Void(); 191 } 192 193 private: 194 const V1_2::Timing kTimingSinceLaunch; 195 const V1_2::Timing kTimingAfterFence; 196 const V1_3::ErrorStatus kErrorStatus; 197 }; 198 199 } // namespace sample_driver 200 } // namespace nn 201 } // namespace android 202 203 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_DRIVER_SAMPLE_SAMPLE_DRIVER_H 204