• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H
18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H
19 
20 #include <CpuExecutor.h>
21 #include <LegacyUtils.h>
22 #include <android-base/macros.h>
23 #include <android-base/scopeguard.h>
24 #include <nnapi/IBuffer.h>
25 #include <nnapi/IBurst.h>
26 #include <nnapi/SharedMemory.h>
27 #include <nnapi/Validation.h>
28 #include <sys/mman.h>
29 
30 #include <algorithm>
31 #include <map>
32 #include <memory>
33 #include <mutex>
34 #include <set>
35 #include <tuple>
36 #include <unordered_map>
37 #include <utility>
38 #include <vector>
39 
40 #include "NeuralNetworks.h"
41 
42 namespace android {
43 namespace nn {
44 
45 class CompilationBuilder;
46 class Device;
47 class ModelBuilder;
48 class RuntimePreparedModel;
49 
50 // A utility template class to accumulate multiple objects and assign each
51 // a distinct index number, starting with 0.
52 //
53 // The user of this class is responsible for avoiding concurrent calls
54 // to this class from multiple threads.
55 template <typename ObjectType>
56 class ObjectTracker {
57    public:
58     // Adds the object, if it does not already exists.  Returns its index.
59     // The objects should survive the tracker.
add(const ObjectType * object)60     uint32_t add(const ObjectType* object) {
61         VLOG(MEMORY) << __func__ << "(" << SHOW_IF_DEBUG(object) << ")";
62         // See if we already have this object. If so, return its index.
63         auto i = mKnown.find(object);
64         if (i != mKnown.end()) {
65             return i->second;
66         }
67         VLOG(MEMORY) << "It's new";
68         // It's a new one.  Save it an assign an index to it.
69         size_t next = mKnown.size();
70         uint32_t idx = static_cast<uint32_t>(next);
71         mKnown[object] = idx;
72         mObjects.push_back(object);
73         return idx;
74     }
75 
76     // Returns the number of objects contained.
size()77     uint32_t size() const { return mObjects.size(); }
78     // Returns the ith object.
79     const ObjectType* operator[](size_t i) const {
80         CHECK(i < size());
81         return mObjects[i];
82     }
83     // Iteration
begin()84     auto begin() { return mObjects.begin(); }
end()85     auto end() { return mObjects.end(); }
begin()86     auto begin() const { return mObjects.begin(); }
end()87     auto end() const { return mObjects.end(); }
getObjects()88     const std::vector<const ObjectType*>& getObjects() const { return mObjects; }
89 
90    private:
91     // The vector of object pointers we are building.
92     std::vector<const ObjectType*> mObjects;
93     // A faster way to see if we already have an object than doing find().
94     std::unordered_map<const ObjectType*, uint32_t> mKnown;
95 };
96 
97 using CompilationRole = std::tuple<const CompilationBuilder*, IOType, uint32_t>;
98 
99 struct MemoryDescriptor {
100     std::vector<uint32_t> dimensions;
101     ObjectTracker<RuntimePreparedModel> preparedModels;
102     std::vector<BufferRole> inputRoles, outputRoles;
103 };
104 
105 class MemoryValidatorBase {
106     DISALLOW_COPY_AND_ASSIGN(MemoryValidatorBase);
107 
108    public:
109     MemoryValidatorBase() = default;
110     virtual ~MemoryValidatorBase() = default;
111 
112     // Validate the memory usage and size information when passed in
113     // ANeuralNetworks{Model,Compilation}_set*FromMemory.
114     //
115     // This method only validates the arguments against the memory. It does not validate the
116     // correctness of the arguments themselves. E.g. it does not validate if the index is out of
117     // range.
118     //
119     // Usages:
120     //   - ANeuralNetworksModel_setOperandValueFromMemory:
121     //         validate(nullptr, IOType::INPUT, operandIndex, nullptr, offset, length)
122     //
123     //   - ANeuralNetworksExecution_setInputFromMemory:
124     //         validate(compilation, IOType::INPUT, inputIndex, type, offset, length)
125     //
126     //   - ANeuralNetworksExecution_setOutputFromMemory:
127     //         validate(compilation, IOType::OUTPUT, outputIndex, type, offset, length)
128     //
129     virtual bool validate(const CompilationBuilder* compilation, IOType ioType, uint32_t index,
130                           const ANeuralNetworksOperandType* type, uint32_t offset,
131                           uint32_t length) const = 0;
132 
133     // Validate the memory dimensional information at the beginning of a computation.
validateInputDimensions(const std::vector<uint32_t> &)134     virtual bool validateInputDimensions(const std::vector<uint32_t>&) const { return true; }
135 
136     // The validation metadata for this memory.
137     struct Metadata {
138         // The byte size of the memory when it is transformed to a closely packed layout.
139         // Set to 0 if unknown (e.g. non-BLOB mode AHWB or device memory with dynamic shape).
140         uint32_t logicalSize;
141 
142         // The dimensions of the memory. Set to empty if undefined.
143         std::vector<uint32_t> dimensions;
144 
145         // The data type, scale, zero point, and extra parameters of the target operand.
146         // Other fields will be ignored, including dimensions, lifetime, location, etc.
147         // Set to std::nullopt if undefined.
148         std::optional<Operand> operand;
149     };
150     virtual Metadata getMetadata() const = 0;
151 
152     // Try update the memory metadata with the provided metadata. Return false if incompatible.
153     virtual bool updateMetadata(const Metadata& metadata) = 0;
154 
155     // Whether the memory is created with unknown dimensions or rank.
createdWithUnknownShape()156     virtual bool createdWithUnknownShape() const { return false; }
157 
setInitialized(bool)158     virtual void setInitialized(bool) {}
isInitialized()159     virtual bool isInitialized() const { return true; }
160 };
161 
162 int copyIBufferToMemory(const SharedBuffer& src, const SharedMemory& dst);
163 
164 int copyMemoryToIBuffer(const SharedMemory& src, const SharedBuffer& dst,
165                         const std::vector<uint32_t>& dimensions);
166 
167 // Represents a memory region.
168 class RuntimeMemory {
169     // Disallow copy and assign to prevent slicing
170     DISALLOW_COPY_AND_ASSIGN(RuntimeMemory);
171 
172    public:
173     virtual ~RuntimeMemory() = default;
174 
175     Request::MemoryPool getMemoryPool() const;
getMemory()176     const SharedMemory& getMemory() const { return kMemory; }
getIBuffer()177     const SharedBuffer& getIBuffer() const { return kBuffer; }
getSize()178     virtual uint32_t getSize() const { return nn::getSize(getMemory()); }
179     virtual std::optional<RunTimePoolInfo> getRunTimePoolInfo() const;
180 
getValidator()181     MemoryValidatorBase& getValidator() const {
182         CHECK(mValidator != nullptr);
183         return *mValidator;
184     }
185 
setValidator(std::unique_ptr<MemoryValidatorBase> validator)186     void setValidator(std::unique_ptr<MemoryValidatorBase> validator) {
187         mValidator = std::move(validator);
188     }
189 
190     // This function binds `cacheHold` to the memory object, holding it for as long as the Memory
191     // object is alive. This keeps the cache present while the Memory object is alive. If
192     // `cacheHold` is null, this function is a no-op.
193     void hold(const IBurst::OptionalCacheHold& cacheHold) const;
194 
195     static int copy(const RuntimeMemory& src, const RuntimeMemory& dst);
196 
197    protected:
198     explicit RuntimeMemory(SharedMemory memory);
199     RuntimeMemory(SharedMemory memory, std::unique_ptr<MemoryValidatorBase> validator);
200     explicit RuntimeMemory(SharedBuffer buffer);
201 
202     // The canonical representation for this memory.  We will use one of the
203     // following values when communicating with the drivers.
204     const SharedMemory kMemory = std::make_shared<const Memory>();
205     const SharedBuffer kBuffer;
206 
207     std::unique_ptr<MemoryValidatorBase> mValidator;
208 
209    private:
210     mutable std::mutex mMutex;
211 
212     // This set contains `CacheHold` objects, holding it for as long as the Memory object is alive.
213     // This keeps the cache present while the Memory object is alive.
214     mutable std::set<IBurst::OptionalCacheHold> mHold;
215 
216     mutable std::optional<RunTimePoolInfo> mCachedRunTimePoolInfo;
217     mutable bool mHasCachedRunTimePoolInfo = false;
218 };
219 
220 class MemoryBuilder {
221     DISALLOW_COPY_AND_ASSIGN(MemoryBuilder);
222 
223    public:
224     MemoryBuilder() = default;
225 
226     int addRole(const CompilationBuilder& compilation, IOType ioType, uint32_t index, float freq);
227     int setDimensions(const std::vector<uint32_t>& dimensions);
228 
229     int finish();
230 
231     std::pair<int, std::unique_ptr<RuntimeMemory>> allocate() const;
232 
233    private:
234     bool badState(const char* name) const;
235 
236     // The memory descriptor that the MemoryBuilder is building.
237     MemoryDescriptor mDesc;
238 
239     // The roles that have been specified via addRole.
240     // This is to check whether a new role has been seen before or not.
241     std::set<CompilationRole> mRoles;
242 
243     // Keep track of the data type, scale, zero point, and extra parameters of the target operand.
244     // Other fields will be ignored, including dimensions, lifetime, location, etc.
245     // It is std::nullopt if no usage has been specified yet.
246     std::optional<Operand> mOperand;
247 
248     // Once the descriptor has been finished, we should not allow further modifications.
249     bool mFinished = false;
250 
251     // The following fields are only valid when finished.
252 
253     // The chosen device to allocate the memory. Set to nullptr if there are multiple devices.
254     const Device* mAllocator = nullptr;
255 
256     // Whether BLOB mode AHWB is supported on all of the relevant devices of the roles.
257     bool mSupportsAhwb = false;
258 
259     // If set to true, allocate() will fallback to Ashmem or AHardwareBuffer if the memory
260     // allocation fails on the chosen device, or if there is no device chosen.
261     bool mShouldFallback = true;
262 };
263 
264 class MemoryAshmem : public RuntimeMemory {
265    public:
266     // Creates a memory object containing a new android shared memory ("ashmem")
267     // object of the size specified in bytes. Because this ashmem region can be
268     // shared with and accessed by one or more driver processes, MemoryAshmem
269     // has shared ownership over the ashmem region.
270     //
271     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
272     // On error, returns the appropriate NNAPI error code and nullptr.
273     static std::pair<int, std::unique_ptr<MemoryAshmem>> create(uint32_t size);
274 
275     // Get a pointer to the ashmem region of memory. The returned pointer is
276     // valid for the lifetime of the MemoryAshmem object. This call always
277     // returns non-null because it was validated during MemoryAshmem::create.
278     uint8_t* getPointer() const;
279 
getRunTimePoolInfo()280     std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override {
281         return RunTimePoolInfo::createFromExistingBuffer(getPointer(), nn::getSize(kMemory));
282     }
283 
284     // prefer using MemoryAshmem::create
285     MemoryAshmem(SharedMemory memory, Mapping mapped);
286 
287    private:
288     const Mapping kMapping;
289 };
290 
291 class MemoryFd : public RuntimeMemory {
292    public:
293     // Create a memory object based on input size, prot, and fd. This function
294     // duplicates the provided fd, and owns the duplicate.
295     //
296     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
297     // On error, returns the appropriate NNAPI error code and nullptr.
298     static std::pair<int, std::unique_ptr<MemoryFd>> create(size_t size, int prot, int fd,
299                                                             size_t offset);
300 
301     // prefer using MemoryFd::create
302     explicit MemoryFd(SharedMemory memory);
303 };
304 
305 class MemoryAHWB : public RuntimeMemory {
306    public:
307     // Create a memory object to keep track of (but not take ownership of) the
308     // provided AHardwareBuffer handle.
309     //
310     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
311     // On error, returns the appropriate NNAPI error code and nullptr.
312     static std::pair<int, std::unique_ptr<MemoryAHWB>> create(const AHardwareBuffer& ahwb);
313 
314     // prefer using MemoryAHWB::create
MemoryAHWB(SharedMemory memory,std::unique_ptr<MemoryValidatorBase> validator)315     MemoryAHWB(SharedMemory memory, std::unique_ptr<MemoryValidatorBase> validator)
316         : RuntimeMemory(std::move(memory), std::move(validator)) {}
317 };
318 
319 class MemoryRuntimeAHWB : public RuntimeMemory {
320    public:
321     // Create a memory object containing a new BLOB-mode AHardwareBuffer memory
322     // object of the size specified in bytes. The created memory is managed and
323     // owned by the NNAPI runtime.
324     //
325     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
326     // On error, returns the appropriate NNAPI error code and nullptr.
327     static std::pair<int, std::unique_ptr<MemoryRuntimeAHWB>> create(uint32_t size);
328 
329     // Get a pointer to the content of the memory. The returned pointer is
330     // valid for the lifetime of the MemoryRuntimeAHWB object. This call always
331     // returns non-null because it was validated during MemoryRuntimeAHWB::create.
332     uint8_t* getPointer() const;
333 
getRunTimePoolInfo()334     std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override {
335         return RunTimePoolInfo::createFromExistingBuffer(getPointer(), nn::getSize(kMemory));
336     }
337 
338     // prefer using MemoryRuntimeAHWB::create
339     MemoryRuntimeAHWB(SharedMemory memory, Mapping mapping);
340 
341    private:
342     const Mapping kMapping;
343 };
344 
345 class MemoryFromDevice : public RuntimeMemory {
346    public:
347     // Create a memory object to keep track of a driver-allocated device memory.
348     // The memory is recognized by the driver via a token.
349     //
350     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
351     // On error, returns the appropriate NNAPI error code and nullptr.
352     static std::pair<int, std::unique_ptr<MemoryFromDevice>> create(SharedBuffer buffer);
353 
354     // prefer using MemoryFromDevice::create
355     explicit MemoryFromDevice(SharedBuffer buffer);
356 };
357 
358 using MemoryTracker = ObjectTracker<RuntimeMemory>;
359 
360 }  // namespace nn
361 }  // namespace android
362 
363 #endif  // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H
364