1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H 18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H 19 20 #include <LegacyUtils.h> 21 #include <android-base/macros.h> 22 #include <nnapi/IBurst.h> 23 #include <nnapi/IDevice.h> 24 #include <nnapi/Types.h> 25 26 #include <map> 27 #include <memory> 28 #include <string> 29 #include <tuple> 30 #include <unordered_set> 31 #include <utility> 32 #include <vector> 33 34 #include "ExecutionCallback.h" 35 #include "Memory.h" 36 37 namespace android { 38 namespace nn { 39 40 // Forward declaration 41 class Device; 42 class MetaModel; 43 class ModelArgumentInfo; 44 45 // A unified interface for a reusable execution with cached resources. 46 // This object provides no thread-safety guarantee. The caller must guarantee there is at most one 47 // call to RuntimeExecution::compute or RuntimeExecution::computeFenced on the same RuntimeExecution 48 // object in flight at a time. 
class RuntimeExecution {
    DISALLOW_COPY_AND_ASSIGN(RuntimeExecution);

   public:
    RuntimeExecution() = default;
    virtual ~RuntimeExecution() = default;

    // Performs the computation using the argument info and memory pools cached at creation time
    // (see RuntimePreparedModel::createReusableExecution).
    // Returns error code, output shapes and timing information.
    virtual std::tuple<int, std::vector<OutputShape>, Timing> compute(
            const SharedBurst& burstController, const OptionalTimePoint& deadline) const = 0;

    // Fenced variant of compute: presumably gated on the sync fences in "waitFor"
    // (mirrors RuntimePreparedModel::executeFenced).
    // The returned timing information is only valid if the callback is nullptr.
    // Returns error_code, sync_fence, callback and timing.
    virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> computeFenced(
            const std::vector<int>& waitFor, const OptionalTimePoint& deadline,
            const OptionalDuration& timeoutDurationAfterFence) const = 0;
};

// A unified interface for actual driver prepared model as well as the CPU.
class RuntimePreparedModel {
    DISALLOW_COPY_AND_ASSIGN(RuntimePreparedModel);

   public:
    RuntimePreparedModel() = default;
    virtual ~RuntimePreparedModel() = default;

    // The Device this prepared model belongs to (raw pointer — lifetime managed elsewhere).
    virtual const Device* getDevice() const = 0;
    // The underlying canonical prepared-model interface.
    virtual SharedPreparedModel getInterface() const = 0;

    // Perform computation with given input/output argument info and memory pools.
    // Returns error code, output shapes and timing information.
    virtual std::tuple<int, std::vector<OutputShape>, Timing> execute(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, const SharedBurst& burstController,
            MeasureTiming measure, const OptionalTimePoint& deadline,
            const OptionalDuration& loopTimeoutDuration,
            const std::vector<TokenValuePair>& metaData) const = 0;

    // Perform fenced computation with given input/output argument info and memory pools.
    // The returned timing information is only valid if the callback is nullptr.
    // Returns error_code, sync_fence, callback and timing.
    virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> executeFenced(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, const std::vector<int>& waitFor,
            MeasureTiming measure, const OptionalTimePoint& deadline,
            const OptionalDuration& loopTimeoutDuration,
            const OptionalDuration& timeoutDurationAfterFence,
            const std::vector<TokenValuePair>& metaData) const = 0;

    // Create a reusable execution with given input/output argument info and memory pools.
    // Returns an error code and, on success, the reusable execution object.
    virtual std::pair<int, std::shared_ptr<RuntimeExecution>> createReusableExecution(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, MeasureTiming measure,
            const OptionalDuration& loopTimeoutDuration,
            const std::vector<TokenValuePair>& metaData) const = 0;

    // Creates a burst object for executions on this prepared model.
    virtual GeneralResult<SharedBurst> configureExecutionBurst() const = 0;

    // NOTE(review): presumably the preferred way to allocate memory used with this prepared
    // model — confirm against RuntimeMemory/MemoryDescriptor usage.
    virtual MemoryPreference getMemoryPreference() const = 0;
};

// Factory producing the Model to be prepared; allows the model to be built lazily.
using ModelFactory = std::function<Model()>;

// Explicit sets of cache file handles for compilation caching.
struct CacheHandles {
    std::vector<SharedHandle> modelCache;
    std::vector<SharedHandle> dataCache;
};

// A directory in which compilation cache files may be stored.
using CacheDir = std::string;

// Compilation-cache location: either a cache directory or explicit cache file handles.
struct CacheInfo {
    std::variant<CacheDir, CacheHandles> variant;
};

// A unified interface for actual driver devices as well as the CPU
class Device {
    DISALLOW_COPY_AND_ASSIGN(Device);

   public:
    Device() = default;
    virtual ~Device() = default;

    // Introspection methods returning device information
    virtual const std::string& getName() const = 0;
    virtual const std::string& getVersionString() const = 0;
    virtual Version getFeatureLevel() const = 0;
    virtual int32_t getType() const = 0;
    virtual const std::vector<Extension>& getSupportedExtensions() const = 0;

    // See the MetaModel class in MetaModel.h for more details.
    virtual std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const = 0;

    // Performance/capability introspection.
    virtual const Capabilities& getCapabilities() const = 0;
    virtual Capabilities::PerformanceInfo getPerformance(OperandType type) const = 0;
    virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const = 0;
    virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceTensor() const = 0;
    virtual Capabilities::PerformanceInfo getIfPerformance() const = 0;
    virtual Capabilities::PerformanceInfo getWhilePerformance() const = 0;
    // Number of cache files needed for compilation caching — presumably
    // (model cache count, data cache count), matching CacheHandles; confirm.
    virtual std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const = 0;
    virtual bool isCachingSupported() const = 0;
    // NOTE(review): presumably blocks until the device is ready; confirm semantics and the
    // meaning of the returned int (error code?).
    virtual int wait() const = 0;

    // Prepares (compiles) the model produced by "makeModel" on this device, optionally using
    // the compilation cache described by cacheInfo/maybeToken.
    // Returns an error code and, on success, the prepared model.
    virtual std::pair<int, std::shared_ptr<RuntimePreparedModel>> prepareModel(
            const ModelFactory& makeModel, ExecutionPreference preference, Priority priority,
            const OptionalTimePoint& deadline, const CacheInfo& cacheInfo,
            const std::optional<CacheToken>& maybeToken,
            const std::vector<TokenValuePair>& metaData,
            const std::vector<ExtensionNameAndPrefix>& extensionNameAndPrefix) const = 0;

    // Allocates device memory described by "desc".
    // The caller is responsible for making sure the MemoryDescriptor only contains
    // PreparedModels from the same Device.
    virtual std::pair<int, std::unique_ptr<RuntimeMemory>> allocate(const MemoryDescriptor& desc,
                                                                    OperandType type) const = 0;
};

// Manages the NN HAL devices. Only one instance of this class will exist.
// Use get() to retrieve it.
class DeviceManager {
   public:
    // Returns the devices available for execution: only the CPU device when either cpu-only
    // flag is set, otherwise all discovered devices.
    const std::vector<std::shared_ptr<Device>>& getDrivers() const {
        if (mSetCpuOnly || mDebugNNCpuOnly) {
            return mDevicesCpuOnly;
        }
        return mDevices;
    }

    // Gets the runtime version corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    Version getRuntimeVersion() const { return mRuntimeVersion; }

    // Gets the runtime feature level corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    int64_t getRuntimeFeatureLevel() const;

    // Convert the internal Version level representation to the NDK representation.
    static int64_t versionToFeatureLevel(Version::Level versionLevel);

    // Returns whether platform telemetry is enabled.
    bool isPlatformTelemetryEnabled() const { return mIsPlatformTelemetryEnabled; }

    // For testing only:
    void setUseCpuOnly(bool useCpuOnly) { mSetCpuOnly = useCpuOnly; }
    bool getUseCpuOnly() const { return mSetCpuOnly; }

    // Synchronous-execution flags (see mSyncExecCpu/mSyncExecRuntime below).
    bool syncExecCpu() const { return mSyncExecCpu; }
    bool syncExecRuntime() const { return mSyncExecRuntime; }

    // How to handle graph partitioning?
    // 0 - Don't do graph partitioning.
    // 1 - Do graph partitioning; but fall back to non-partitioned
    //     execution if there is a partitioning failure.
    // 2 - Do graph partitioning, and rely on it; there is no fallback.
    enum { kPartitioningNo = 0, kPartitioningWithFallback = 1, kPartitioningWithoutFallback = 2 };
    uint32_t getPartitioning() const { return mPartitioning; }
    static bool partitioningAllowsFallback(uint32_t partitioning) {
        return partitioning == kPartitioningWithFallback;
    }

    bool strictSlicing() const { return mStrictSlicing; }

    // Returns the singleton manager.
    static DeviceManager* get();

    // Returns the singleton Cpu device.
    static std::shared_ptr<Device> getCpuDevice();

    // The forTest_* functions below are solely intended for use by unit tests.

    // Returns all devices (ignores the cpu-only flags).
    std::vector<std::shared_ptr<Device>> forTest_getDevices() const { return mDevices; }

    // Sets the device list (does not affect cpu-only queries).
    void forTest_setDevices(std::vector<std::shared_ptr<Device>> devices) {
        mDevices = std::move(devices);
    }

    // Register a test device.
    void forTest_registerDevice(const SharedDevice& device) { registerDevice(device); }

    // Re-initialize the list of available devices.
    void forTest_reInitializeDeviceList() {
        mDevices.clear();
        mDevicesCpuOnly.clear();
        findAvailableDevices();
    }

    // Make a test device
    static std::shared_ptr<Device> forTest_makeDriverDevice(const SharedDevice& device);

    // True iff "device" is the singleton CPU device (pointer identity with getCpuDevice()).
    bool forTest_isCpuDevice(const ANeuralNetworksDevice* device) const {
        return reinterpret_cast<const Device*>(device) == getCpuDevice().get();
    }

   private:
    // Builds the list of available drivers and queries their capabilities.
    DeviceManager();

    // Adds a device for the manager to use.
    void registerDevice(const SharedDevice& device);

    // Discovers devices and registers them; presumably repopulates mDevices and
    // mDevicesCpuOnly (it is called after both are cleared in
    // forTest_reInitializeDeviceList).
    void findAvailableDevices();

    // Runtime version corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    Version mRuntimeVersion;

    // Holds whether platform telemetry is enabled, as indicated by getServerTelemetryEnableFlag (in
    // ServerFlag.h).
    bool mIsPlatformTelemetryEnabled;

    // List of all the devices we discovered (including CpuDevice).
    std::vector<std::shared_ptr<Device>> mDevices;

    // We set this one to have CpuDevice only. To be used when m*CpuOnly is true.
    std::vector<std::shared_ptr<Device>> mDevicesCpuOnly;

    // If either of these is true, we'll ignore the drivers that are
    // on the device and run everything on the CPU.
    bool mSetCpuOnly = false;      // set by setUseCpuOnly()
    bool mDebugNNCpuOnly = false;  // derived from system property debug.nn.cpuonly

    // synchronous execution
    bool mSyncExecCpu = true;
    bool mSyncExecRuntime = false;

    static const uint32_t kPartitioningDefault = kPartitioningWithFallback;
    uint32_t mPartitioning = kPartitioningDefault;

    bool mStrictSlicing = false;
};

// Free function declared here, defined elsewhere; returns the canonical SharedDevice handles.
std::vector<SharedDevice> getDevices();

}  // namespace nn
}  // namespace android

#endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H