• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H
19 
20 #include <android-base/macros.h>
21 
22 #include <map>
23 #include <memory>
24 #include <mutex>
25 #include <set>
26 #include <stack>
27 #include <utility>
28 #include <vector>
29 
30 #include "CpuExecutor.h"
31 #include "HalInterfaces.h"
32 #include "Utils.h"
33 #include "ValidateHal.h"
34 
35 namespace android::nn {
36 
37 // This class manages a CPU buffer allocated on heap and provides validation methods.
38 class HalManagedBuffer {
39    public:
40     static std::shared_ptr<HalManagedBuffer> create(uint32_t size,
41                                                     std::set<HalPreparedModelRole> roles,
42                                                     const Operand& operand);
43 
44     // Prefer HalManagedBuffer::create.
45     HalManagedBuffer(std::unique_ptr<uint8_t[]> buffer, uint32_t size,
46                      std::set<HalPreparedModelRole> roles, const Operand& operand);
47 
createRunTimePoolInfo()48     RunTimePoolInfo createRunTimePoolInfo() const {
49         return RunTimePoolInfo::createFromExistingBuffer(kBuffer.get(), kSize);
50     }
51 
52     // "poolIndex" is the index of this buffer in the request.pools.
53     ErrorStatus validateRequest(uint32_t poolIndex, const Request& request,
54                                 const V1_3::IPreparedModel* preparedModel) const;
55 
56     // "size" is the byte size of the Memory provided to the copyFrom or copyTo method.
57     ErrorStatus validateCopyFrom(const std::vector<uint32_t>& dimensions, uint32_t size) const;
58     ErrorStatus validateCopyTo(uint32_t size) const;
59 
60     bool updateDimensions(const std::vector<uint32_t>& dimensions);
61     void setInitialized(bool initialized);
62 
63    private:
64     mutable std::mutex mMutex;
65     const std::unique_ptr<uint8_t[]> kBuffer;
66     const uint32_t kSize;
67     const std::set<HalPreparedModelRole> kRoles;
68     const OperandType kOperandType;
69     const std::vector<uint32_t> kInitialDimensions;
70     std::vector<uint32_t> mUpdatedDimensions;
71     bool mInitialized = false;
72 };
73 
74 // Keep track of all HalManagedBuffers and assign each with a unique token.
75 class HalBufferTracker : public std::enable_shared_from_this<HalBufferTracker> {
76     DISALLOW_COPY_AND_ASSIGN(HalBufferTracker);
77 
78    public:
79     // A RAII class to help manage the lifetime of the token.
80     // It is only supposed to be constructed in HalBufferTracker::add.
81     class Token {
82         DISALLOW_COPY_AND_ASSIGN(Token);
83 
84        public:
Token(uint32_t token,std::shared_ptr<HalBufferTracker> tracker)85         Token(uint32_t token, std::shared_ptr<HalBufferTracker> tracker)
86             : kToken(token), kHalBufferTracker(std::move(tracker)) {}
~Token()87         ~Token() { kHalBufferTracker->free(kToken); }
get()88         uint32_t get() const { return kToken; }
89 
90        private:
91         const uint32_t kToken;
92         const std::shared_ptr<HalBufferTracker> kHalBufferTracker;
93     };
94 
95     // The factory of HalBufferTracker. This ensures that the HalBufferTracker is always managed by
96     // a shared_ptr.
create()97     static std::shared_ptr<HalBufferTracker> create() {
98         return std::make_shared<HalBufferTracker>();
99     }
100 
101     // Prefer HalBufferTracker::create.
HalBufferTracker()102     HalBufferTracker() : mTokenToBuffers(1) {}
103 
104     std::unique_ptr<Token> add(std::shared_ptr<HalManagedBuffer> buffer);
105     std::shared_ptr<HalManagedBuffer> get(uint32_t token) const;
106 
107    private:
108     void free(uint32_t token);
109 
110     mutable std::mutex mMutex;
111     std::stack<uint32_t, std::vector<uint32_t>> mFreeTokens;
112 
113     // Since the tokens are allocated in a non-sparse way, we use a vector to represent the mapping.
114     // The index of the vector is the token. When the token gets freed, the corresponding entry is
115     // set to nullptr. mTokenToBuffers[0] is always set to nullptr because 0 is an invalid token.
116     std::vector<std::shared_ptr<HalManagedBuffer>> mTokenToBuffers;
117 };
118 
119 }  // namespace android::nn
120 
121 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H
122