1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H 18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H 19 20 #include <CpuExecutor.h> 21 #include <LegacyUtils.h> 22 #include <android-base/macros.h> 23 #include <android-base/scopeguard.h> 24 #include <nnapi/IBuffer.h> 25 #include <nnapi/IBurst.h> 26 #include <nnapi/SharedMemory.h> 27 #include <nnapi/Validation.h> 28 #include <sys/mman.h> 29 30 #include <algorithm> 31 #include <map> 32 #include <memory> 33 #include <mutex> 34 #include <set> 35 #include <tuple> 36 #include <unordered_map> 37 #include <utility> 38 #include <vector> 39 40 #include "NeuralNetworks.h" 41 42 namespace android { 43 namespace nn { 44 45 class CompilationBuilder; 46 class Device; 47 class ModelBuilder; 48 class RuntimePreparedModel; 49 50 // A utility template class to accumulate multiple objects and assign each 51 // a distinct index number, starting with 0. 52 // 53 // The user of this class is responsible for avoiding concurrent calls 54 // to this class from multiple threads. 55 template <typename ObjectType> 56 class ObjectTracker { 57 public: 58 // Adds the object, if it does not already exists. Returns its index. 59 // The objects should survive the tracker. add(const ObjectType * object)60 uint32_t add(const ObjectType* object) { 61 VLOG(MEMORY) << __func__ << "(" << SHOW_IF_DEBUG(object) << ")"; 62 // See if we already have this object. If so, return its index. 63 auto i = mKnown.find(object); 64 if (i != mKnown.end()) { 65 return i->second; 66 } 67 VLOG(MEMORY) << "It's new"; 68 // It's a new one. Save it an assign an index to it. 69 size_t next = mKnown.size(); 70 uint32_t idx = static_cast<uint32_t>(next); 71 mKnown[object] = idx; 72 mObjects.push_back(object); 73 return idx; 74 } 75 76 // Returns the number of objects contained. size()77 uint32_t size() const { return mObjects.size(); } 78 // Returns the ith object. 79 const ObjectType* operator[](size_t i) const { 80 CHECK(i < size()); 81 return mObjects[i]; 82 } 83 // Iteration begin()84 auto begin() { return mObjects.begin(); } end()85 auto end() { return mObjects.end(); } begin()86 auto begin() const { return mObjects.begin(); } end()87 auto end() const { return mObjects.end(); } getObjects()88 const std::vector<const ObjectType*>& getObjects() const { return mObjects; } 89 90 private: 91 // The vector of object pointers we are building. 92 std::vector<const ObjectType*> mObjects; 93 // A faster way to see if we already have an object than doing find(). 94 std::unordered_map<const ObjectType*, uint32_t> mKnown; 95 }; 96 97 using CompilationRole = std::tuple<const CompilationBuilder*, IOType, uint32_t>; 98 99 struct MemoryDescriptor { 100 std::vector<uint32_t> dimensions; 101 ObjectTracker<RuntimePreparedModel> preparedModels; 102 std::vector<BufferRole> inputRoles, outputRoles; 103 }; 104 105 class MemoryValidatorBase { 106 DISALLOW_COPY_AND_ASSIGN(MemoryValidatorBase); 107 108 public: 109 MemoryValidatorBase() = default; 110 virtual ~MemoryValidatorBase() = default; 111 112 // Validate the memory usage and size information when passed in 113 // ANeuralNetworks{Model,Compilation}_set*FromMemory. 114 // 115 // This method only validates the arguments against the memory. It does not validate the 116 // correctness of the arguments themselves. E.g. it does not validate if the index is out of 117 // range. 118 // 119 // Usages: 120 // - ANeuralNetworksModel_setOperandValueFromMemory: 121 // validate(nullptr, IOType::INPUT, operandIndex, nullptr, offset, length) 122 // 123 // - ANeuralNetworksExecution_setInputFromMemory: 124 // validate(compilation, IOType::INPUT, inputIndex, type, offset, length) 125 // 126 // - ANeuralNetworksExecution_setOutputFromMemory: 127 // validate(compilation, IOType::OUTPUT, outputIndex, type, offset, length) 128 // 129 virtual bool validate(const CompilationBuilder* compilation, IOType ioType, uint32_t index, 130 const ANeuralNetworksOperandType* type, uint32_t offset, 131 uint32_t length) const = 0; 132 133 // Validate the memory dimensional information at the beginning of a computation. validateInputDimensions(const std::vector<uint32_t> &)134 virtual bool validateInputDimensions(const std::vector<uint32_t>&) const { return true; } 135 136 // The validation metadata for this memory. 137 struct Metadata { 138 // The byte size of the memory when it is transformed to a closely packed layout. 139 // Set to 0 if unknown (e.g. non-BLOB mode AHWB or device memory with dynamic shape). 140 uint32_t logicalSize; 141 142 // The dimensions of the memory. Set to empty if undefined. 143 std::vector<uint32_t> dimensions; 144 145 // The data type, scale, zero point, and extra parameters of the target operand. 146 // Other fields will be ignored, including dimensions, lifetime, location, etc. 147 // Set to std::nullopt if undefined. 148 std::optional<Operand> operand; 149 }; 150 virtual Metadata getMetadata() const = 0; 151 152 // Try update the memory metadata with the provided metadata. Return false if incompatible. 153 virtual bool updateMetadata(const Metadata& metadata) = 0; 154 155 // Whether the memory is created with unknown dimensions or rank. createdWithUnknownShape()156 virtual bool createdWithUnknownShape() const { return false; } 157 setInitialized(bool)158 virtual void setInitialized(bool) {} isInitialized()159 virtual bool isInitialized() const { return true; } 160 }; 161 162 int copyIBufferToMemory(const SharedBuffer& src, const SharedMemory& dst); 163 164 int copyMemoryToIBuffer(const SharedMemory& src, const SharedBuffer& dst, 165 const std::vector<uint32_t>& dimensions); 166 167 // Represents a memory region. 168 class RuntimeMemory { 169 // Disallow copy and assign to prevent slicing 170 DISALLOW_COPY_AND_ASSIGN(RuntimeMemory); 171 172 public: 173 virtual ~RuntimeMemory() = default; 174 175 Request::MemoryPool getMemoryPool() const; getMemory()176 const SharedMemory& getMemory() const { return kMemory; } getIBuffer()177 const SharedBuffer& getIBuffer() const { return kBuffer; } getSize()178 virtual uint32_t getSize() const { return nn::getSize(getMemory()); } 179 virtual std::optional<RunTimePoolInfo> getRunTimePoolInfo() const; 180 getValidator()181 MemoryValidatorBase& getValidator() const { 182 CHECK(mValidator != nullptr); 183 return *mValidator; 184 } 185 setValidator(std::unique_ptr<MemoryValidatorBase> validator)186 void setValidator(std::unique_ptr<MemoryValidatorBase> validator) { 187 mValidator = std::move(validator); 188 } 189 190 // This function binds `cacheHold` to the memory object, holding it for as long as the Memory 191 // object is alive. This keeps the cache present while the Memory object is alive. If 192 // `cacheHold` is null, this function is a no-op. 193 void hold(const IBurst::OptionalCacheHold& cacheHold) const; 194 195 static int copy(const RuntimeMemory& src, const RuntimeMemory& dst); 196 197 protected: 198 explicit RuntimeMemory(SharedMemory memory); 199 RuntimeMemory(SharedMemory memory, std::unique_ptr<MemoryValidatorBase> validator); 200 explicit RuntimeMemory(SharedBuffer buffer); 201 202 // The canonical representation for this memory. We will use one of the 203 // following values when communicating with the drivers. 204 const SharedMemory kMemory = std::make_shared<const Memory>(); 205 const SharedBuffer kBuffer; 206 207 std::unique_ptr<MemoryValidatorBase> mValidator; 208 209 private: 210 mutable std::mutex mMutex; 211 212 // This set contains `CacheHold` objects, holding it for as long as the Memory object is alive. 213 // This keeps the cache present while the Memory object is alive. 214 mutable std::set<IBurst::OptionalCacheHold> mHold; 215 216 mutable std::optional<RunTimePoolInfo> mCachedRunTimePoolInfo; 217 mutable bool mHasCachedRunTimePoolInfo = false; 218 }; 219 220 class MemoryBuilder { 221 DISALLOW_COPY_AND_ASSIGN(MemoryBuilder); 222 223 public: 224 MemoryBuilder() = default; 225 226 int addRole(const CompilationBuilder& compilation, IOType ioType, uint32_t index, float freq); 227 int setDimensions(const std::vector<uint32_t>& dimensions); 228 229 int finish(); 230 231 std::pair<int, std::unique_ptr<RuntimeMemory>> allocate() const; 232 233 private: 234 bool badState(const char* name) const; 235 236 // The memory descriptor that the MemoryBuilder is building. 237 MemoryDescriptor mDesc; 238 239 // The roles that have been specified via addRole. 240 // This is to check whether a new role has been seen before or not. 241 std::set<CompilationRole> mRoles; 242 243 // Keep track of the data type, scale, zero point, and extra parameters of the target operand. 244 // Other fields will be ignored, including dimensions, lifetime, location, etc. 245 // It is std::nullopt if no usage has been specified yet. 246 std::optional<Operand> mOperand; 247 248 // Once the descriptor has been finished, we should not allow further modifications. 249 bool mFinished = false; 250 251 // The following fields are only valid when finished. 252 253 // The chosen device to allocate the memory. Set to nullptr if there are multiple devices. 254 const Device* mAllocator = nullptr; 255 256 // Whether BLOB mode AHWB is supported on all of the relevant devices of the roles. 257 bool mSupportsAhwb = false; 258 259 // If set to true, allocate() will fallback to Ashmem or AHardwareBuffer if the memory 260 // allocation fails on the chosen device, or if there is no device chosen. 261 bool mShouldFallback = true; 262 }; 263 264 class MemoryAshmem : public RuntimeMemory { 265 public: 266 // Creates a memory object containing a new android shared memory ("ashmem") 267 // object of the size specified in bytes. Because this ashmem region can be 268 // shared with and accessed by one or more driver processes, MemoryAshmem 269 // has shared ownership over the ashmem region. 270 // 271 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 272 // On error, returns the appropriate NNAPI error code and nullptr. 273 static std::pair<int, std::unique_ptr<MemoryAshmem>> create(uint32_t size); 274 275 // Get a pointer to the ashmem region of memory. The returned pointer is 276 // valid for the lifetime of the MemoryAshmem object. This call always 277 // returns non-null because it was validated during MemoryAshmem::create. 278 uint8_t* getPointer() const; 279 getRunTimePoolInfo()280 std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override { 281 return RunTimePoolInfo::createFromExistingBuffer(getPointer(), nn::getSize(kMemory)); 282 } 283 284 // prefer using MemoryAshmem::create 285 MemoryAshmem(SharedMemory memory, Mapping mapped); 286 287 private: 288 const Mapping kMapping; 289 }; 290 291 class MemoryFd : public RuntimeMemory { 292 public: 293 // Create a memory object based on input size, prot, and fd. This function 294 // duplicates the provided fd, and owns the duplicate. 295 // 296 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 297 // On error, returns the appropriate NNAPI error code and nullptr. 298 static std::pair<int, std::unique_ptr<MemoryFd>> create(size_t size, int prot, int fd, 299 size_t offset); 300 301 // prefer using MemoryFd::create 302 explicit MemoryFd(SharedMemory memory); 303 }; 304 305 class MemoryAHWB : public RuntimeMemory { 306 public: 307 // Create a memory object to keep track of (but not take ownership of) the 308 // provided AHardwareBuffer handle. 309 // 310 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 311 // On error, returns the appropriate NNAPI error code and nullptr. 312 static std::pair<int, std::unique_ptr<MemoryAHWB>> create(const AHardwareBuffer& ahwb); 313 314 // prefer using MemoryAHWB::create MemoryAHWB(SharedMemory memory,std::unique_ptr<MemoryValidatorBase> validator)315 MemoryAHWB(SharedMemory memory, std::unique_ptr<MemoryValidatorBase> validator) 316 : RuntimeMemory(std::move(memory), std::move(validator)) {} 317 }; 318 319 class MemoryRuntimeAHWB : public RuntimeMemory { 320 public: 321 // Create a memory object containing a new BLOB-mode AHardwareBuffer memory 322 // object of the size specified in bytes. The created memory is managed and 323 // owned by the NNAPI runtime. 324 // 325 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 326 // On error, returns the appropriate NNAPI error code and nullptr. 327 static std::pair<int, std::unique_ptr<MemoryRuntimeAHWB>> create(uint32_t size); 328 329 // Get a pointer to the content of the memory. The returned pointer is 330 // valid for the lifetime of the MemoryRuntimeAHWB object. This call always 331 // returns non-null because it was validated during MemoryRuntimeAHWB::create. 332 uint8_t* getPointer() const; 333 getRunTimePoolInfo()334 std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override { 335 return RunTimePoolInfo::createFromExistingBuffer(getPointer(), nn::getSize(kMemory)); 336 } 337 338 // prefer using MemoryRuntimeAHWB::create 339 MemoryRuntimeAHWB(SharedMemory memory, Mapping mapping); 340 341 private: 342 const Mapping kMapping; 343 }; 344 345 class MemoryFromDevice : public RuntimeMemory { 346 public: 347 // Create a memory object to keep track of a driver-allocated device memory. 348 // The memory is recognized by the driver via a token. 349 // 350 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 351 // On error, returns the appropriate NNAPI error code and nullptr. 352 static std::pair<int, std::unique_ptr<MemoryFromDevice>> create(SharedBuffer buffer); 353 354 // prefer using MemoryFromDevice::create 355 explicit MemoryFromDevice(SharedBuffer buffer); 356 }; 357 358 using MemoryTracker = ObjectTracker<RuntimeMemory>; 359 360 } // namespace nn 361 } // namespace android 362 363 #endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H 364