1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_RUNTIME_MEMORY_H 18 #define ANDROID_ML_NN_RUNTIME_MEMORY_H 19 20 #include "NeuralNetworks.h" 21 #include "Utils.h" 22 23 #include <cutils/native_handle.h> 24 #include <sys/mman.h> 25 #include <mutex> 26 #include <unordered_map> 27 #include "vndk/hardware_buffer.h" 28 29 namespace android { 30 namespace nn { 31 32 class ExecutionBurstController; 33 class ModelBuilder; 34 35 // Represents a memory region. 36 class Memory { 37 public: Memory()38 Memory() {} 39 virtual ~Memory(); 40 41 // Disallow copy semantics to ensure the runtime object can only be freed 42 // once. Copy semantics could be enabled if some sort of reference counting 43 // or deep-copy system for runtime objects is added later. 44 Memory(const Memory&) = delete; 45 Memory& operator=(const Memory&) = delete; 46 47 // Creates a shared memory object of the size specified in bytes. 48 int create(uint32_t size); 49 getHidlMemory()50 hardware::hidl_memory getHidlMemory() const { return mHidlMemory; } 51 52 // Returns a pointer to the underlying memory of this memory object. 53 // The function will fail if the memory is not CPU accessible and nullptr 54 // will be returned. getPointer(uint8_t ** buffer)55 virtual int getPointer(uint8_t** buffer) const { 56 *buffer = static_cast<uint8_t*>(static_cast<void*>(mMemory->getPointer())); 57 if (*buffer == nullptr) { 58 return ANEURALNETWORKS_BAD_DATA; 59 } 60 return ANEURALNETWORKS_NO_ERROR; 61 } 62 63 virtual bool validateSize(uint32_t offset, uint32_t length) const; 64 65 // Unique key representing this memory object. 66 intptr_t getKey() const; 67 68 // Marks a burst object as currently using this memory. When this 69 // memory object is destroyed, it will automatically free this memory from 70 // the bursts' memory cache. 71 void usedBy(const std::shared_ptr<ExecutionBurstController>& burst) const; 72 73 protected: 74 // The hidl_memory handle for this shared memory. We will pass this value when 75 // communicating with the drivers. 76 hardware::hidl_memory mHidlMemory; 77 sp<IMemory> mMemory; 78 79 mutable std::mutex mMutex; 80 // mUsedBy is essentially a set of burst objects which use this Memory 81 // object. However, std::weak_ptr does not have comparison operations nor a 82 // std::hash implementation. This is because it is either a valid pointer 83 // (non-null) if the shared object is still alive, or it is null if the 84 // object has been freed. To circumvent this, mUsedBy is a map with the raw 85 // pointer as the key and the weak_ptr as the value. 86 mutable std::unordered_map<const ExecutionBurstController*, 87 std::weak_ptr<ExecutionBurstController>> 88 mUsedBy; 89 }; 90 91 class MemoryFd : public Memory { 92 public: MemoryFd()93 MemoryFd() {} 94 ~MemoryFd() override; 95 96 // Disallow copy semantics to ensure the runtime object can only be freed 97 // once. Copy semantics could be enabled if some sort of reference counting 98 // or deep-copy system for runtime objects is added later. 99 MemoryFd(const MemoryFd&) = delete; 100 MemoryFd& operator=(const MemoryFd&) = delete; 101 102 // Create the native_handle based on input size, prot, and fd. 103 // Existing native_handle will be deleted, and mHidlMemory will wrap 104 // the newly created native_handle. 105 int set(size_t size, int prot, int fd, size_t offset); 106 107 int getPointer(uint8_t** buffer) const override; 108 109 private: 110 native_handle_t* mHandle = nullptr; 111 mutable uint8_t* mMapping = nullptr; 112 }; 113 114 // TODO(miaowang): move function definitions to Memory.cpp 115 class MemoryAHWB : public Memory { 116 public: MemoryAHWB()117 MemoryAHWB() {} ~MemoryAHWB()118 ~MemoryAHWB() override{}; 119 120 // Disallow copy semantics to ensure the runtime object can only be freed 121 // once. Copy semantics could be enabled if some sort of reference counting 122 // or deep-copy system for runtime objects is added later. 123 MemoryAHWB(const MemoryAHWB&) = delete; 124 MemoryAHWB& operator=(const MemoryAHWB&) = delete; 125 126 // Keep track of the provided AHardwareBuffer handle. set(const AHardwareBuffer * ahwb)127 int set(const AHardwareBuffer* ahwb) { 128 AHardwareBuffer_describe(ahwb, &mBufferDesc); 129 const native_handle_t* handle = AHardwareBuffer_getNativeHandle(ahwb); 130 mHardwareBuffer = ahwb; 131 if (mBufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) { 132 mHidlMemory = hidl_memory("hardware_buffer_blob", handle, mBufferDesc.width); 133 } else { 134 // memory size is not used. 135 mHidlMemory = hidl_memory("hardware_buffer", handle, 0); 136 } 137 return ANEURALNETWORKS_NO_ERROR; 138 }; 139 getPointer(uint8_t ** buffer)140 int getPointer(uint8_t** buffer) const override { 141 *buffer = nullptr; 142 return ANEURALNETWORKS_BAD_DATA; 143 }; 144 145 // validateSize should only be called for blob mode AHardwareBuffer. 146 // Calling it on non-blob mode AHardwareBuffer will result in an error. 147 // TODO(miaowang): consider separate blob and non-blob into different classes. validateSize(uint32_t offset,uint32_t length)148 bool validateSize(uint32_t offset, uint32_t length) const override { 149 if (mHardwareBuffer == nullptr) { 150 LOG(ERROR) << "MemoryAHWB has not been initialized."; 151 return false; 152 } 153 // validateSize should only be called on BLOB mode buffer. 154 if (mBufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) { 155 if (offset + length > mBufferDesc.width) { 156 LOG(ERROR) << "Request size larger than the memory size."; 157 return false; 158 } else { 159 return true; 160 } 161 } else { 162 LOG(ERROR) << "Invalid AHARDWAREBUFFER_FORMAT, must be AHARDWAREBUFFER_FORMAT_BLOB."; 163 return false; 164 } 165 } 166 167 private: 168 const AHardwareBuffer* mHardwareBuffer = nullptr; 169 AHardwareBuffer_Desc mBufferDesc; 170 }; 171 172 // A utility class to accumulate mulitple Memory objects and assign each 173 // a distinct index number, starting with 0. 174 // 175 // The user of this class is responsible for avoiding concurrent calls 176 // to this class from multiple threads. 177 class MemoryTracker { 178 private: 179 // The vector of Memory pointers we are building. 180 std::vector<const Memory*> mMemories; 181 // A faster way to see if we already have a memory than doing find(). 182 std::unordered_map<const Memory*, uint32_t> mKnown; 183 184 public: 185 // Adds the memory, if it does not already exists. Returns its index. 186 // The memories should survive the tracker. 187 uint32_t add(const Memory* memory); 188 // Returns the number of memories contained. size()189 uint32_t size() const { return static_cast<uint32_t>(mKnown.size()); } 190 // Returns the ith memory. 191 const Memory* operator[](size_t i) const { return mMemories[i]; } 192 // Iteration begin()193 decltype(mMemories.begin()) begin() { return mMemories.begin(); } end()194 decltype(mMemories.end()) end() { return mMemories.end(); } 195 }; 196 197 } // namespace nn 198 } // namespace android 199 200 #endif // ANDROID_ML_NN_RUNTIME_MEMORY_H 201