/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
18 #define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
19 
20 #include <hwbinder/IPCThreadState.h>
21 
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "BufferTracker.h"
28 #include "CpuExecutor.h"
29 #include "HalInterfaces.h"
30 #include "NeuralNetworks.h"
31 
32 namespace android {
33 namespace nn {
34 namespace sample_driver {
35 
36 using hardware::MQDescriptorSync;
37 
38 // Manages the data buffer for an operand.
39 class SampleBuffer : public hal::IBuffer {
40    public:
SampleBuffer(std::shared_ptr<ManagedBuffer> buffer,std::unique_ptr<BufferTracker::Token> token)41     SampleBuffer(std::shared_ptr<ManagedBuffer> buffer, std::unique_ptr<BufferTracker::Token> token)
42         : kBuffer(std::move(buffer)), kToken(std::move(token)) {
43         CHECK(kBuffer != nullptr);
44         CHECK(kToken != nullptr);
45     }
46     hal::Return<hal::ErrorStatus> copyTo(const hal::hidl_memory& dst) override;
47     hal::Return<hal::ErrorStatus> copyFrom(const hal::hidl_memory& src,
48                                            const hal::hidl_vec<uint32_t>& dimensions) override;
49 
50    private:
51     const std::shared_ptr<ManagedBuffer> kBuffer;
52     const std::unique_ptr<BufferTracker::Token> kToken;
53 };
54 
55 // Base class used to create sample drivers for the NN HAL.  This class
56 // provides some implementation of the more common functions.
57 //
58 // Since these drivers simulate hardware, they must run the computations
59 // on the CPU.  An actual driver would not do that.
60 class SampleDriver : public hal::IDevice {
61    public:
62     SampleDriver(const char* name,
63                  const IOperationResolver* operationResolver = BuiltinOperationResolver::get())
mName(name)64         : mName(name),
65           mOperationResolver(operationResolver),
66           mBufferTracker(BufferTracker::create()) {
67         android::nn::initVLogMask();
68     }
69     hal::Return<void> getCapabilities(getCapabilities_cb cb) override;
70     hal::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override;
71     hal::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
72     hal::Return<void> getVersionString(getVersionString_cb cb) override;
73     hal::Return<void> getType(getType_cb cb) override;
74     hal::Return<void> getSupportedExtensions(getSupportedExtensions_cb) override;
75     hal::Return<void> getSupportedOperations(const hal::V1_0::Model& model,
76                                              getSupportedOperations_cb cb) override;
77     hal::Return<void> getSupportedOperations_1_1(const hal::V1_1::Model& model,
78                                                  getSupportedOperations_1_1_cb cb) override;
79     hal::Return<void> getSupportedOperations_1_2(const hal::V1_2::Model& model,
80                                                  getSupportedOperations_1_2_cb cb) override;
81     hal::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override;
82     hal::Return<hal::V1_0::ErrorStatus> prepareModel(
83             const hal::V1_0::Model& model,
84             const sp<hal::V1_0::IPreparedModelCallback>& callback) override;
85     hal::Return<hal::V1_0::ErrorStatus> prepareModel_1_1(
86             const hal::V1_1::Model& model, hal::ExecutionPreference preference,
87             const sp<hal::V1_0::IPreparedModelCallback>& callback) override;
88     hal::Return<hal::V1_0::ErrorStatus> prepareModel_1_2(
89             const hal::V1_2::Model& model, hal::ExecutionPreference preference,
90             const hal::hidl_vec<hal::hidl_handle>& modelCache,
91             const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token,
92             const sp<hal::V1_2::IPreparedModelCallback>& callback) override;
93     hal::Return<hal::V1_3::ErrorStatus> prepareModel_1_3(
94             const hal::V1_3::Model& model, hal::ExecutionPreference preference,
95             hal::Priority priority, const hal::OptionalTimePoint& deadline,
96             const hal::hidl_vec<hal::hidl_handle>& modelCache,
97             const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token,
98             const sp<hal::V1_3::IPreparedModelCallback>& callback) override;
99     hal::Return<hal::V1_0::ErrorStatus> prepareModelFromCache(
100             const hal::hidl_vec<hal::hidl_handle>& modelCache,
101             const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token,
102             const sp<hal::V1_2::IPreparedModelCallback>& callback) override;
103     hal::Return<hal::V1_3::ErrorStatus> prepareModelFromCache_1_3(
104             const hal::OptionalTimePoint& deadline,
105             const hal::hidl_vec<hal::hidl_handle>& modelCache,
106             const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token,
107             const sp<hal::V1_3::IPreparedModelCallback>& callback) override;
108     hal::Return<hal::DeviceStatus> getStatus() override;
109     hal::Return<void> allocate(const hal::V1_3::BufferDesc& desc,
110                                const hal::hidl_vec<sp<hal::V1_3::IPreparedModel>>& preparedModels,
111                                const hal::hidl_vec<hal::V1_3::BufferRole>& inputRoles,
112                                const hal::hidl_vec<hal::V1_3::BufferRole>& outputRoles,
113                                allocate_cb cb) override;
114 
115     // Starts and runs the driver service.  Typically called from main().
116     // This will return only once the service shuts down.
117     int run();
118 
getExecutor()119     CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); }
getBufferTracker()120     const std::shared_ptr<BufferTracker>& getBufferTracker() const { return mBufferTracker; }
121 
122    protected:
123     std::string mName;
124     const IOperationResolver* mOperationResolver;
125     const std::shared_ptr<BufferTracker> mBufferTracker;
126 };
127 
128 class SamplePreparedModel : public hal::IPreparedModel {
129    public:
SamplePreparedModel(const hal::Model & model,const SampleDriver * driver,hal::ExecutionPreference preference,uid_t userId,hal::Priority priority)130     SamplePreparedModel(const hal::Model& model, const SampleDriver* driver,
131                         hal::ExecutionPreference preference, uid_t userId, hal::Priority priority)
132         : mModel(model),
133           mDriver(driver),
134           kPreference(preference),
135           kUserId(userId),
136           kPriority(priority) {
137         (void)kUserId;
138         (void)kPriority;
139     }
140     bool initialize();
141     hal::Return<hal::V1_0::ErrorStatus> execute(
142             const hal::V1_0::Request& request,
143             const sp<hal::V1_0::IExecutionCallback>& callback) override;
144     hal::Return<hal::V1_0::ErrorStatus> execute_1_2(
145             const hal::V1_0::Request& request, hal::MeasureTiming measure,
146             const sp<hal::V1_2::IExecutionCallback>& callback) override;
147     hal::Return<hal::V1_3::ErrorStatus> execute_1_3(
148             const hal::V1_3::Request& request, hal::MeasureTiming measure,
149             const hal::OptionalTimePoint& deadline,
150             const hal::OptionalTimeoutDuration& loopTimeoutDuration,
151             const sp<hal::V1_3::IExecutionCallback>& callback) override;
152     hal::Return<void> executeSynchronously(const hal::V1_0::Request& request,
153                                            hal::MeasureTiming measure,
154                                            executeSynchronously_cb cb) override;
155     hal::Return<void> executeSynchronously_1_3(
156             const hal::V1_3::Request& request, hal::MeasureTiming measure,
157             const hal::OptionalTimePoint& deadline,
158             const hal::OptionalTimeoutDuration& loopTimeoutDuration,
159             executeSynchronously_1_3_cb cb) override;
160     hal::Return<void> configureExecutionBurst(
161             const sp<hal::V1_2::IBurstCallback>& callback,
162             const MQDescriptorSync<hal::V1_2::FmqRequestDatum>& requestChannel,
163             const MQDescriptorSync<hal::V1_2::FmqResultDatum>& resultChannel,
164             configureExecutionBurst_cb cb) override;
165     hal::Return<void> executeFenced(const hal::Request& request,
166                                     const hal::hidl_vec<hal::hidl_handle>& wait_for,
167                                     hal::MeasureTiming measure,
168                                     const hal::OptionalTimePoint& deadline,
169                                     const hal::OptionalTimeoutDuration& loopTimeoutDuration,
170                                     const hal::OptionalTimeoutDuration& duration,
171                                     executeFenced_cb callback) override;
getModel()172     const hal::Model* getModel() const { return &mModel; }
173 
174    private:
175     hal::Model mModel;
176     const SampleDriver* mDriver;
177     std::vector<RunTimePoolInfo> mPoolInfos;
178     const hal::ExecutionPreference kPreference;
179     const uid_t kUserId;
180     const hal::Priority kPriority;
181 };
182 
183 class SampleFencedExecutionCallback : public hal::IFencedExecutionCallback {
184    public:
SampleFencedExecutionCallback(hal::Timing timingSinceLaunch,hal::Timing timingAfterFence,hal::ErrorStatus error)185     SampleFencedExecutionCallback(hal::Timing timingSinceLaunch, hal::Timing timingAfterFence,
186                                   hal::ErrorStatus error)
187         : kTimingSinceLaunch(timingSinceLaunch),
188           kTimingAfterFence(timingAfterFence),
189           kErrorStatus(error) {}
getExecutionInfo(getExecutionInfo_cb callback)190     hal::Return<void> getExecutionInfo(getExecutionInfo_cb callback) override {
191         callback(kErrorStatus, kTimingSinceLaunch, kTimingAfterFence);
192         return hal::Void();
193     }
194 
195    private:
196     const hal::Timing kTimingSinceLaunch;
197     const hal::Timing kTimingAfterFence;
198     const hal::ErrorStatus kErrorStatus;
199 };
200 
201 }  // namespace sample_driver
202 }  // namespace nn
203 }  // namespace android
204 
205 #endif  // ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
206