/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H
#define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H

#include <LegacyUtils.h>
#include <android-base/macros.h>
#include <nnapi/IBurst.h>
#include <nnapi/IDevice.h>
#include <nnapi/Types.h>

#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>

#include "ExecutionCallback.h"
#include "Memory.h"

namespace android {
namespace nn {

// Forward declaration
class Device;
class MetaModel;
class ModelArgumentInfo;

// A unified interface for a reusable execution with cached resources.
// This object provides no thread-safety guarantee. The caller must guarantee there is at most one
// call to RuntimeExecution::compute or RuntimeExecution::computeFenced on the same RuntimeExecution
// object in flight at a time.
class RuntimeExecution {
    DISALLOW_COPY_AND_ASSIGN(RuntimeExecution);

   public:
    RuntimeExecution() = default;
    virtual ~RuntimeExecution() = default;

    virtual std::tuple<int, std::vector<OutputShape>, Timing> compute(
            const SharedBurst& burstController, const OptionalTimePoint& deadline) const = 0;

    // The returned timing information is only valid if the callback is nullptr.
    // Returns error_code, sync_fence, callback and timing.
    virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> computeFenced(
            const std::vector<int>& waitFor, const OptionalTimePoint& deadline,
            const OptionalDuration& timeoutDurationAfterFence) const = 0;
};
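
// Illustrative usage sketch only (not part of this header's contract), assuming an execution
// obtained from RuntimePreparedModel::createReusableExecution below; `burstController` and
// `deadline` are hypothetical caller-provided values and error handling is elided:
//
//   const auto [status, outputShapes, timing] = execution->compute(burstController, deadline);
//   if (status != ANEURALNETWORKS_NO_ERROR) {
//       // handle failure; outputShapes may describe the required output buffer sizes
//   }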

// A unified interface for actual driver prepared model as well as the CPU.
class RuntimePreparedModel {
    DISALLOW_COPY_AND_ASSIGN(RuntimePreparedModel);

   public:
    RuntimePreparedModel() = default;
    virtual ~RuntimePreparedModel() = default;

    virtual const Device* getDevice() const = 0;
    virtual SharedPreparedModel getInterface() const = 0;

    // Perform computation with given input/output argument info and memory pools.
    virtual std::tuple<int, std::vector<OutputShape>, Timing> execute(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, const SharedBurst& burstController,
            MeasureTiming measure, const OptionalTimePoint& deadline,
            const OptionalDuration& loopTimeoutDuration,
            const std::vector<TokenValuePair>& metaData) const = 0;

    // Perform fenced computation with given input/output argument info and memory pools.
    // The returned timing information is only valid if the callback is nullptr.
    // Returns error_code, sync_fence, callback and timing.
    virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> executeFenced(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, const std::vector<int>& waitFor,
            MeasureTiming measure, const OptionalTimePoint& deadline,
            const OptionalDuration& loopTimeoutDuration,
            const OptionalDuration& timeoutDurationAfterFence,
            const std::vector<TokenValuePair>& metaData) const = 0;

    // Create a reusable execution with given input/output argument info and memory pools.
    virtual std::pair<int, std::shared_ptr<RuntimeExecution>> createReusableExecution(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, MeasureTiming measure,
            const OptionalDuration& loopTimeoutDuration,
            const std::vector<TokenValuePair>& metaData) const = 0;

    virtual GeneralResult<SharedBurst> configureExecutionBurst() const = 0;

    virtual MemoryPreference getMemoryPreference() const = 0;
};
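
// Illustrative usage sketch only (an assumption, not a prescribed flow): create a reusable
// execution once and compute with it repeatedly; `inputs`, `outputs`, `memories`, and
// `burstController` are hypothetical caller-owned objects, and error handling is elided:
//
//   auto [status, execution] = preparedModel->createReusableExecution(
//           inputs, outputs, memories, MeasureTiming::NO,
//           /*loopTimeoutDuration=*/{}, /*metaData=*/{});
//   if (status == ANEURALNETWORKS_NO_ERROR) {
//       const auto [computeStatus, shapes, timing] =
//               execution->compute(burstController, /*deadline=*/{});
//   }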

using ModelFactory = std::function<Model()>;

struct CacheHandles {
    std::vector<SharedHandle> modelCache;
    std::vector<SharedHandle> dataCache;
};

using CacheDir = std::string;

struct CacheInfo {
    std::variant<CacheDir, CacheHandles> variant;
};

// A unified interface for actual driver devices as well as the CPU
class Device {
    DISALLOW_COPY_AND_ASSIGN(Device);

   public:
    Device() = default;
    virtual ~Device() = default;

    // Introspection methods returning device information
    virtual const std::string& getName() const = 0;
    virtual const std::string& getVersionString() const = 0;
    virtual Version getFeatureLevel() const = 0;
    virtual int32_t getType() const = 0;
    virtual const std::vector<Extension>& getSupportedExtensions() const = 0;

    // See the MetaModel class in MetaModel.h for more details.
    virtual std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const = 0;

    virtual const Capabilities& getCapabilities() const = 0;
    virtual Capabilities::PerformanceInfo getPerformance(OperandType type) const = 0;
    virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const = 0;
    virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceTensor() const = 0;
    virtual Capabilities::PerformanceInfo getIfPerformance() const = 0;
    virtual Capabilities::PerformanceInfo getWhilePerformance() const = 0;
    virtual std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const = 0;
    virtual bool isCachingSupported() const = 0;
    virtual int wait() const = 0;

    virtual std::pair<int, std::shared_ptr<RuntimePreparedModel>> prepareModel(
            const ModelFactory& makeModel, ExecutionPreference preference, Priority priority,
            const OptionalTimePoint& deadline, const CacheInfo& cacheInfo,
            const std::optional<CacheToken>& maybeToken,
            const std::vector<TokenValuePair>& metaData,
            const std::vector<ExtensionNameAndPrefix>& extensionNameAndPrefix) const = 0;

    // The caller is responsible for making sure the MemoryDescriptor only contains
    // PreparedModels from the same Device.
    virtual std::pair<int, std::unique_ptr<RuntimeMemory>> allocate(const MemoryDescriptor& desc,
                                                                    OperandType type) const = 0;
};
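
// Illustrative usage sketch only (an assumption, not part of this header's contract): prepare a
// model on a device; `model` and `cacheInfo` are hypothetical caller-owned objects, and the
// preference/priority values are arbitrary examples:
//
//   const ModelFactory makeModel = [&model] { return model; };
//   auto [status, preparedModel] = device->prepareModel(
//           makeModel, ExecutionPreference::FAST_SINGLE_ANSWER, Priority::MEDIUM,
//           /*deadline=*/{}, cacheInfo, /*maybeToken=*/std::nullopt,
//           /*metaData=*/{}, /*extensionNameAndPrefix=*/{});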

// Manages the NN HAL devices.  Only one instance of this class will exist.
// Use get() to retrieve it.
class DeviceManager {
   public:
    const std::vector<std::shared_ptr<Device>>& getDrivers() const {
        if (mSetCpuOnly || mDebugNNCpuOnly) {
            return mDevicesCpuOnly;
        }
        return mDevices;
    }

    // Gets the runtime version corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    Version getRuntimeVersion() const { return mRuntimeVersion; }

    // Gets the runtime feature level corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    int64_t getRuntimeFeatureLevel() const;

    // Convert the internal Version level representation to the NDK representation.
    static int64_t versionToFeatureLevel(Version::Level versionLevel);

    // Returns whether platform telemetry is enabled.
    bool isPlatformTelemetryEnabled() const { return mIsPlatformTelemetryEnabled; }

    // For testing only:
    void setUseCpuOnly(bool useCpuOnly) { mSetCpuOnly = useCpuOnly; }
    bool getUseCpuOnly() const { return mSetCpuOnly; }

    bool syncExecCpu() const { return mSyncExecCpu; }
    bool syncExecRuntime() const { return mSyncExecRuntime; }

    // How to handle graph partitioning?
    // 0 - Don't do graph partitioning.
    // 1 - Do graph partitioning; but fall back to non-partitioned
    //     execution if there is a partitioning failure.
    // 2 - Do graph partitioning, and rely on it; there is no fallback.
    enum { kPartitioningNo = 0, kPartitioningWithFallback = 1, kPartitioningWithoutFallback = 2 };
    uint32_t getPartitioning() const { return mPartitioning; }
    static bool partitioningAllowsFallback(uint32_t partitioning) {
        return partitioning == kPartitioningWithFallback;
    }
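
    // Illustrative sketch only (an assumption about caller behavior, not an API requirement):
    // a client building an execution plan might consult these values as follows, where
    // `partitioningFailed` is a hypothetical flag set by the caller:
    //
    //   const uint32_t partitioning = DeviceManager::get()->getPartitioning();
    //   if (partitioningFailed && !DeviceManager::partitioningAllowsFallback(partitioning)) {
    //       return error;  // no fallback to non-partitioned execution
    //   }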

    bool strictSlicing() const { return mStrictSlicing; }

    // Returns the singleton manager.
    static DeviceManager* get();

    // Returns the singleton Cpu device.
    static std::shared_ptr<Device> getCpuDevice();

    // The forTest_* functions below are solely intended for use by unit tests.

    // Returns all devices (ignores the cpu-only flags).
    std::vector<std::shared_ptr<Device>> forTest_getDevices() const { return mDevices; }

    // Sets the device list (does not affect cpu-only queries).
    void forTest_setDevices(std::vector<std::shared_ptr<Device>> devices) {
        mDevices = std::move(devices);
    }

    // Register a test device.
    void forTest_registerDevice(const SharedDevice& device) { registerDevice(device); }

    // Re-initialize the list of available devices.
    void forTest_reInitializeDeviceList() {
        mDevices.clear();
        mDevicesCpuOnly.clear();
        findAvailableDevices();
    }

    // Make a test device
    static std::shared_ptr<Device> forTest_makeDriverDevice(const SharedDevice& device);

    bool forTest_isCpuDevice(const ANeuralNetworksDevice* device) const {
        return reinterpret_cast<const Device*>(device) == getCpuDevice().get();
    }

   private:
    // Builds the list of available drivers and queries their capabilities.
    DeviceManager();

    // Adds a device for the manager to use.
    void registerDevice(const SharedDevice& device);

    void findAvailableDevices();

    // Runtime version corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    Version mRuntimeVersion;

    // Holds whether platform telemetry is enabled, as indicated by getServerTelemetryEnableFlag (in
    // ServerFlag.h).
    bool mIsPlatformTelemetryEnabled;

    // List of all the devices we discovered (including CpuDevice).
    std::vector<std::shared_ptr<Device>> mDevices;

    // We set this one to have CpuDevice only. To be used when m*CpuOnly is true.
    std::vector<std::shared_ptr<Device>> mDevicesCpuOnly;

    // If either of these is true, we'll ignore the drivers that are
    // on the device and run everything on the CPU.
    bool mSetCpuOnly = false;      // set by setUseCpuOnly()
    bool mDebugNNCpuOnly = false;  // derived from system property debug.nn.cpuonly

    // synchronous execution
    bool mSyncExecCpu = true;
    bool mSyncExecRuntime = false;

    static const uint32_t kPartitioningDefault = kPartitioningWithFallback;
    uint32_t mPartitioning = kPartitioningDefault;

    bool mStrictSlicing = false;
};
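
// Illustrative usage sketch only (an assumption, not a mandated pattern): enumerate the devices
// that the runtime will consider, honoring the cpu-only flags:
//
//   for (const std::shared_ptr<Device>& device : DeviceManager::get()->getDrivers()) {
//       // e.g. inspect device->getName(), device->getFeatureLevel(), device->getType(), ...
//   }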

std::vector<SharedDevice> getDevices();

}  // namespace nn
}  // namespace android

#endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H