• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H
18 #define ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H
19 
#include "CpuExecutor.h"
#include "HalInterfaces.h"
#include "NeuralNetworks.h"

#include <string>
#include <vector>
25 
26 namespace android {
27 namespace nn {
28 namespace sample_driver {
29 
30 using ::android::hardware::MQDescriptorSync;
31 using HidlToken = hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>;
32 
33 // Base class used to create sample drivers for the NN HAL.  This class
34 // provides some implementation of the more common functions.
35 //
36 // Since these drivers simulate hardware, they must run the computations
37 // on the CPU.  An actual driver would not do that.
38 class SampleDriver : public IDevice {
39    public:
40     SampleDriver(const char* name,
41                  const IOperationResolver* operationResolver = BuiltinOperationResolver::get())
mName(name)42         : mName(name), mOperationResolver(operationResolver) {
43         android::nn::initVLogMask();
44     }
~SampleDriver()45     ~SampleDriver() override {}
46     Return<void> getCapabilities(getCapabilities_cb cb) override;
47     Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override;
48     Return<void> getVersionString(getVersionString_cb cb) override;
49     Return<void> getType(getType_cb cb) override;
50     Return<void> getSupportedExtensions(getSupportedExtensions_cb) override;
51     Return<void> getSupportedOperations(const V1_0::Model& model,
52                                         getSupportedOperations_cb cb) override;
53     Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
54                                             getSupportedOperations_1_1_cb cb) override;
55     Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override;
56     Return<ErrorStatus> prepareModel(const V1_0::Model& model,
57                                      const sp<V1_0::IPreparedModelCallback>& callback) override;
58     Return<ErrorStatus> prepareModel_1_1(const V1_1::Model& model, ExecutionPreference preference,
59                                          const sp<V1_0::IPreparedModelCallback>& callback) override;
60     Return<ErrorStatus> prepareModel_1_2(const V1_2::Model& model, ExecutionPreference preference,
61                                          const hidl_vec<hidl_handle>& modelCache,
62                                          const hidl_vec<hidl_handle>& dataCache,
63                                          const HidlToken& token,
64                                          const sp<V1_2::IPreparedModelCallback>& callback) override;
65     Return<ErrorStatus> prepareModelFromCache(
66             const hidl_vec<hidl_handle>& modelCache, const hidl_vec<hidl_handle>& dataCache,
67             const HidlToken& token, const sp<V1_2::IPreparedModelCallback>& callback) override;
68     Return<DeviceStatus> getStatus() override;
69 
70     // Starts and runs the driver service.  Typically called from main().
71     // This will return only once the service shuts down.
72     int run();
73 
getExecutor()74     CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); }
75 
76    protected:
77     std::string mName;
78     const IOperationResolver* mOperationResolver;
79 };
80 
81 class SamplePreparedModel : public IPreparedModel {
82    public:
SamplePreparedModel(const Model & model,const SampleDriver * driver)83     SamplePreparedModel(const Model& model, const SampleDriver* driver)
84         : mModel(model), mDriver(driver) {}
~SamplePreparedModel()85     ~SamplePreparedModel() override {}
86     bool initialize();
87     Return<ErrorStatus> execute(const Request& request,
88                                 const sp<V1_0::IExecutionCallback>& callback) override;
89     Return<ErrorStatus> execute_1_2(const Request& request, MeasureTiming measure,
90                                     const sp<V1_2::IExecutionCallback>& callback) override;
91     Return<void> executeSynchronously(const Request& request, MeasureTiming measure,
92                                       executeSynchronously_cb cb) override;
93     Return<void> configureExecutionBurst(
94             const sp<V1_2::IBurstCallback>& callback,
95             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
96             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
97             configureExecutionBurst_cb cb) override;
98 
99    private:
100     Model mModel;
101     const SampleDriver* mDriver;
102     std::vector<RunTimePoolInfo> mPoolInfos;
103 };
104 
105 }  // namespace sample_driver
106 }  // namespace nn
107 }  // namespace android
108 
109 #endif  // ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H
110