/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MANAGER_H
#define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MANAGER_H

#include <LegacyUtils.h>
#include <android-base/macros.h>
#include <nnapi/IBurst.h>
#include <nnapi/IDevice.h>
#include <nnapi/Types.h>

#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ExecutionCallback.h"
#include "Memory.h"

namespace android {
namespace nn {

// Forward declaration
class Device;
class MetaModel;
class ModelArgumentInfo;

// A unified interface for a reusable execution with cached resources.
// This object provides no thread-safety guarantee. The caller must guarantee there is at most one
// call to RuntimeExecution::compute or RuntimeExecution::computeFenced on the same RuntimeExecution
// object in flight at a time.
class RuntimeExecution {
    DISALLOW_COPY_AND_ASSIGN(RuntimeExecution);

   public:
    RuntimeExecution() = default;
    virtual ~RuntimeExecution() = default;

    // Perform the computation, optionally routed through the given burst controller.
    // Returns error_code, output_shapes and timing.
    virtual std::tuple<int, std::vector<OutputShape>, Timing> compute(
            const SharedBurst& burstController, const OptionalTimePoint& deadline) const = 0;

    // Perform a fenced computation: execution may start only after all sync fences in
    // "waitFor" have signaled.
    // The returned timing information is only valid if the callback is nullptr.
    // Returns error_code, sync_fence, callback and timing.
    virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> computeFenced(
            const std::vector<int>& waitFor, const OptionalTimePoint& deadline,
            const OptionalDuration& timeoutDurationAfterFence) const = 0;
};

// A unified interface for actual driver prepared model as well as the CPU.
class RuntimePreparedModel {
    DISALLOW_COPY_AND_ASSIGN(RuntimePreparedModel);

   public:
    RuntimePreparedModel() = default;
    virtual ~RuntimePreparedModel() = default;

    // The device this model was prepared for (non-owning).
    virtual const Device* getDevice() const = 0;
    virtual SharedPreparedModel getInterface() const = 0;

    // Perform computation with given input/output argument info and memory pools.
    // Returns error_code, output_shapes and timing.
    virtual std::tuple<int, std::vector<OutputShape>, Timing> execute(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, const SharedBurst& burstController,
            MeasureTiming measure, const OptionalTimePoint& deadline,
            const OptionalDuration& loopTimeoutDuration) const = 0;

    // Perform fenced computation with given input/output argument info and memory pools.
    // The returned timing information is only valid if the callback is nullptr.
    // Returns error_code, sync_fence, callback and timing.
    virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> executeFenced(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, const std::vector<int>& waitFor,
            MeasureTiming measure, const OptionalTimePoint& deadline,
            const OptionalDuration& loopTimeoutDuration,
            const OptionalDuration& timeoutDurationAfterFence) const = 0;

    // Create a reusable execution with given input/output argument info and memory pools.
    // Returns error_code and the execution object (see RuntimeExecution above for its
    // thread-safety contract).
    virtual std::pair<int, std::shared_ptr<RuntimeExecution>> createReusableExecution(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, MeasureTiming measure,
            const OptionalDuration& loopTimeoutDuration) const = 0;

    virtual GeneralResult<SharedBurst> configureExecutionBurst() const = 0;

    virtual MemoryPreference getMemoryPreference() const = 0;
};

// Factory producing a Model on demand (allows deferred/lazy model construction).
using ModelFactory = std::function<Model()>;

// Compilation cache passed as already-opened handles.
struct CacheHandles {
    std::vector<SharedHandle> modelCache;
    std::vector<SharedHandle> dataCache;
};

using CacheDir = std::string;

// Compilation cache location: either a directory path or pre-opened handles.
struct CacheInfo {
    std::variant<CacheDir, CacheHandles> variant;
};

// A unified interface for actual driver devices as well as the CPU
class Device {
    DISALLOW_COPY_AND_ASSIGN(Device);

   public:
    Device() = default;
    virtual ~Device() = default;

    // Introspection methods returning device information
    virtual const std::string& getName() const = 0;
    virtual const std::string& getVersionString() const = 0;
    virtual int64_t getFeatureLevel() const = 0;
    virtual int32_t getType() const = 0;
    virtual bool isUpdatable() const = 0;
    virtual const std::vector<Extension>& getSupportedExtensions() const = 0;

    // See the MetaModel class in MetaModel.h for more details.
    virtual std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const = 0;

    virtual const Capabilities& getCapabilities() const = 0;
    virtual Capabilities::PerformanceInfo getPerformance(OperandType type) const = 0;
    virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const = 0;
    virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceTensor() const = 0;
    virtual Capabilities::PerformanceInfo getIfPerformance() const = 0;
    virtual Capabilities::PerformanceInfo getWhilePerformance() const = 0;
    // NOTE(review): presumably {numModelCacheFiles, numDataCacheFiles}, matching the
    // CacheHandles field order — confirm against implementations.
    virtual std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const = 0;
    virtual bool isCachingSupported() const = 0;
    // NOTE(review): presumably blocks until the underlying driver is ready and returns a
    // result code — confirm against implementations.
    virtual int wait() const = 0;

    // Prepares (compiles) a model on this device. Returns error_code and the prepared model.
    virtual std::pair<int, std::shared_ptr<RuntimePreparedModel>> prepareModel(
            const ModelFactory& makeModel, ExecutionPreference preference, Priority priority,
            const OptionalTimePoint& deadline, const CacheInfo& cacheInfo,
            const std::optional<CacheToken>& maybeToken) const = 0;

    // The caller is responsible for making sure the MemoryDescriptor only contains
    // PreparedModels from the same Device.
    virtual std::pair<int, std::unique_ptr<RuntimeMemory>> allocate(const MemoryDescriptor& desc,
                                                                    OperandType type) const = 0;
};

// Manages the NN HAL devices. Only one instance of this class will exist.
// Use get() to retrieve it.
class DeviceManager {
   public:
    // Returns the devices to use for execution: the CPU-only list when either
    // cpu-only flag is set, otherwise the full discovered list.
    const std::vector<std::shared_ptr<Device>>& getDrivers() const {
        if (mSetCpuOnly || mDebugNNCpuOnly) {
            return mDevicesCpuOnly;
        }
        return mDevices;
    }

    // For testing only:
    void setUseCpuOnly(bool useCpuOnly) { mSetCpuOnly = useCpuOnly; }
    bool getUseCpuOnly() const { return mSetCpuOnly; }

    bool syncExecCpu() const { return mSyncExecCpu; }
    bool syncExecRuntime() const { return mSyncExecRuntime; }

    // How to handle graph partitioning?
    // 0 - Don't do graph partitioning.
    // 1 - Do graph partitioning; but fall back to non-partitioned
    //     execution if there is a partitioning failure.
    // 2 - Do graph partitioning, and rely on it; there is no fallback.
    enum { kPartitioningNo = 0, kPartitioningWithFallback = 1, kPartitioningWithoutFallback = 2 };
    uint32_t getPartitioning() const { return mPartitioning; }
    static bool partitioningAllowsFallback(uint32_t partitioning) {
        return partitioning == kPartitioningWithFallback;
    }

    bool strictSlicing() const { return mStrictSlicing; }

    // Returns the singleton manager.
    static DeviceManager* get();

    // Returns the singleton Cpu device.
    static std::shared_ptr<Device> getCpuDevice();

    // The forTest_* functions below are solely intended for use by unit tests.

    // Returns all devices (ignores the cpu-only flags).
    std::vector<std::shared_ptr<Device>> forTest_getDevices() const { return mDevices; }

    // Sets the device list (does not affect cpu-only queries).
    void forTest_setDevices(std::vector<std::shared_ptr<Device>> devices) {
        mDevices = std::move(devices);
    }

    // Register a test device.
    void forTest_registerDevice(const SharedDevice& device) { registerDevice(device); }

    // Re-initialize the list of available devices.
    void forTest_reInitializeDeviceList() {
        mDevices.clear();
        mDevicesCpuOnly.clear();
        findAvailableDevices();
    }

    // Make a test device
    static std::shared_ptr<Device> forTest_makeDriverDevice(const SharedDevice& device);

    // True iff "device" is the singleton CPU device. The opaque NDK handle is
    // reinterpreted back to the runtime Device it wraps and compared by identity.
    bool forTest_isCpuDevice(const ANeuralNetworksDevice* device) const {
        return reinterpret_cast<const Device*>(device) == getCpuDevice().get();
    }

   private:
    // Builds the list of available drivers and queries their capabilities.
    DeviceManager();

    // Adds a device for the manager to use.
    void registerDevice(const SharedDevice& device);

    void findAvailableDevices();

    // List of all the devices we discovered (including CpuDevice).
    std::vector<std::shared_ptr<Device>> mDevices;

    // We set this one to have CpuDevice only. To be used when m*CpuOnly is true.
    std::vector<std::shared_ptr<Device>> mDevicesCpuOnly;

    // If either of these is true, we'll ignore the drivers that are
    // on the device and run everything on the CPU.
    bool mSetCpuOnly = false;      // set by setUseCpuOnly()
    bool mDebugNNCpuOnly = false;  // derived from system property debug.nn.cpuonly

    // synchronous execution
    bool mSyncExecCpu = true;
    bool mSyncExecRuntime = false;

    static const uint32_t kPartitioningDefault = kPartitioningWithFallback;
    uint32_t mPartitioning = kPartitioningDefault;

    bool mStrictSlicing = false;
};

// Enumerates the devices available on this system (declaration only; see Manager.cpp).
std::vector<SharedDevice> getDevices();

}  // namespace nn
}  // namespace android

#endif  // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MANAGER_H