/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H
#define ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H

#include "CpuExecutor.h"
#include "HalInterfaces.h"
#include "NeuralNetworks.h"

#include <string>
#include <vector>

namespace android {
namespace nn {
namespace sample_driver {

using ::android::hardware::MQDescriptorSync;
using HidlToken = hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>;

// Base class used to create sample drivers for the NN HAL.  This class
// provides some implementation of the more common functions.
//
// Since these drivers simulate hardware, they must run the computations
// on the CPU.  An actual driver would not do that.
38 class SampleDriver : public IDevice { 39 public: 40 SampleDriver(const char* name, 41 const IOperationResolver* operationResolver = BuiltinOperationResolver::get()) mName(name)42 : mName(name), mOperationResolver(operationResolver) { 43 android::nn::initVLogMask(); 44 } ~SampleDriver()45 ~SampleDriver() override {} 46 Return<void> getCapabilities(getCapabilities_cb cb) override; 47 Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override; 48 Return<void> getVersionString(getVersionString_cb cb) override; 49 Return<void> getType(getType_cb cb) override; 50 Return<void> getSupportedExtensions(getSupportedExtensions_cb) override; 51 Return<void> getSupportedOperations(const V1_0::Model& model, 52 getSupportedOperations_cb cb) override; 53 Return<void> getSupportedOperations_1_1(const V1_1::Model& model, 54 getSupportedOperations_1_1_cb cb) override; 55 Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override; 56 Return<ErrorStatus> prepareModel(const V1_0::Model& model, 57 const sp<V1_0::IPreparedModelCallback>& callback) override; 58 Return<ErrorStatus> prepareModel_1_1(const V1_1::Model& model, ExecutionPreference preference, 59 const sp<V1_0::IPreparedModelCallback>& callback) override; 60 Return<ErrorStatus> prepareModel_1_2(const V1_2::Model& model, ExecutionPreference preference, 61 const hidl_vec<hidl_handle>& modelCache, 62 const hidl_vec<hidl_handle>& dataCache, 63 const HidlToken& token, 64 const sp<V1_2::IPreparedModelCallback>& callback) override; 65 Return<ErrorStatus> prepareModelFromCache( 66 const hidl_vec<hidl_handle>& modelCache, const hidl_vec<hidl_handle>& dataCache, 67 const HidlToken& token, const sp<V1_2::IPreparedModelCallback>& callback) override; 68 Return<DeviceStatus> getStatus() override; 69 70 // Starts and runs the driver service. Typically called from main(). 71 // This will return only once the service shuts down. 
72 int run(); 73 getExecutor()74 CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); } 75 76 protected: 77 std::string mName; 78 const IOperationResolver* mOperationResolver; 79 }; 80 81 class SamplePreparedModel : public IPreparedModel { 82 public: SamplePreparedModel(const Model & model,const SampleDriver * driver)83 SamplePreparedModel(const Model& model, const SampleDriver* driver) 84 : mModel(model), mDriver(driver) {} ~SamplePreparedModel()85 ~SamplePreparedModel() override {} 86 bool initialize(); 87 Return<ErrorStatus> execute(const Request& request, 88 const sp<V1_0::IExecutionCallback>& callback) override; 89 Return<ErrorStatus> execute_1_2(const Request& request, MeasureTiming measure, 90 const sp<V1_2::IExecutionCallback>& callback) override; 91 Return<void> executeSynchronously(const Request& request, MeasureTiming measure, 92 executeSynchronously_cb cb) override; 93 Return<void> configureExecutionBurst( 94 const sp<V1_2::IBurstCallback>& callback, 95 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel, 96 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel, 97 configureExecutionBurst_cb cb) override; 98 99 private: 100 Model mModel; 101 const SampleDriver* mDriver; 102 std::vector<RunTimePoolInfo> mPoolInfos; 103 }; 104 105 } // namespace sample_driver 106 } // namespace nn 107 } // namespace android 108 109 #endif // ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H 110