1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H 18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H 19 20 #include <android-base/macros.h> 21 22 #include <map> 23 #include <memory> 24 #include <mutex> 25 #include <set> 26 #include <stack> 27 #include <utility> 28 #include <vector> 29 30 #include "CpuExecutor.h" 31 #include "HalInterfaces.h" 32 #include "Utils.h" 33 #include "ValidateHal.h" 34 35 namespace android::nn { 36 37 // This class manages a CPU buffer allocated on heap and provides validation methods. 38 class HalManagedBuffer { 39 public: 40 static std::shared_ptr<HalManagedBuffer> create(uint32_t size, 41 std::set<HalPreparedModelRole> roles, 42 const Operand& operand); 43 44 // Prefer HalManagedBuffer::create. 45 HalManagedBuffer(std::unique_ptr<uint8_t[]> buffer, uint32_t size, 46 std::set<HalPreparedModelRole> roles, const Operand& operand); 47 createRunTimePoolInfo()48 RunTimePoolInfo createRunTimePoolInfo() const { 49 return RunTimePoolInfo::createFromExistingBuffer(kBuffer.get(), kSize); 50 } 51 52 // "poolIndex" is the index of this buffer in the request.pools. 53 ErrorStatus validateRequest(uint32_t poolIndex, const Request& request, 54 const V1_3::IPreparedModel* preparedModel) const; 55 56 // "size" is the byte size of the Memory provided to the copyFrom or copyTo method. 57 ErrorStatus validateCopyFrom(const std::vector<uint32_t>& dimensions, uint32_t size) const; 58 ErrorStatus validateCopyTo(uint32_t size) const; 59 60 bool updateDimensions(const std::vector<uint32_t>& dimensions); 61 void setInitialized(bool initialized); 62 63 private: 64 mutable std::mutex mMutex; 65 const std::unique_ptr<uint8_t[]> kBuffer; 66 const uint32_t kSize; 67 const std::set<HalPreparedModelRole> kRoles; 68 const OperandType kOperandType; 69 const std::vector<uint32_t> kInitialDimensions; 70 std::vector<uint32_t> mUpdatedDimensions; 71 bool mInitialized = false; 72 }; 73 74 // Keep track of all HalManagedBuffers and assign each with a unique token. 75 class HalBufferTracker : public std::enable_shared_from_this<HalBufferTracker> { 76 DISALLOW_COPY_AND_ASSIGN(HalBufferTracker); 77 78 public: 79 // A RAII class to help manage the lifetime of the token. 80 // It is only supposed to be constructed in HalBufferTracker::add. 81 class Token { 82 DISALLOW_COPY_AND_ASSIGN(Token); 83 84 public: Token(uint32_t token,std::shared_ptr<HalBufferTracker> tracker)85 Token(uint32_t token, std::shared_ptr<HalBufferTracker> tracker) 86 : kToken(token), kHalBufferTracker(std::move(tracker)) {} ~Token()87 ~Token() { kHalBufferTracker->free(kToken); } get()88 uint32_t get() const { return kToken; } 89 90 private: 91 const uint32_t kToken; 92 const std::shared_ptr<HalBufferTracker> kHalBufferTracker; 93 }; 94 95 // The factory of HalBufferTracker. This ensures that the HalBufferTracker is always managed by 96 // a shared_ptr. create()97 static std::shared_ptr<HalBufferTracker> create() { 98 return std::make_shared<HalBufferTracker>(); 99 } 100 101 // Prefer HalBufferTracker::create. HalBufferTracker()102 HalBufferTracker() : mTokenToBuffers(1) {} 103 104 std::unique_ptr<Token> add(std::shared_ptr<HalManagedBuffer> buffer); 105 std::shared_ptr<HalManagedBuffer> get(uint32_t token) const; 106 107 private: 108 void free(uint32_t token); 109 110 mutable std::mutex mMutex; 111 std::stack<uint32_t, std::vector<uint32_t>> mFreeTokens; 112 113 // Since the tokens are allocated in a non-sparse way, we use a vector to represent the mapping. 114 // The index of the vector is the token. When the token gets freed, the corresponding entry is 115 // set to nullptr. mTokenToBuffers[0] is always set to nullptr because 0 is an invalid token. 116 std::vector<std::shared_ptr<HalManagedBuffer>> mTokenToBuffers; 117 }; 118 119 } // namespace android::nn 120 121 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H 122