• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CanonicalBuffer.h"
18 
19 #include <android-base/logging.h>
20 #include <nnapi/IPreparedModel.h>
21 #include <nnapi/Result.h>
22 #include <nnapi/Types.h>
23 
24 #include <algorithm>
25 #include <memory>
26 #include <utility>
27 
28 namespace android::nn::sample {
29 namespace {
30 
copyRunTimePoolInfos(const RunTimePoolInfo & srcPool,const RunTimePoolInfo & dstPool)31 void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
32     CHECK(srcPool.getBuffer() != nullptr);
33     CHECK(dstPool.getBuffer() != nullptr);
34     CHECK(srcPool.getSize() == dstPool.getSize());
35     std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
36     dstPool.flush();
37 }
38 
copyFromInternal(const SharedMemory & src,const Dimensions & dimensions,const std::shared_ptr<ManagedBuffer> & bufferWrapper)39 GeneralResult<void> copyFromInternal(const SharedMemory& src, const Dimensions& dimensions,
40                                      const std::shared_ptr<ManagedBuffer>& bufferWrapper) {
41     CHECK(bufferWrapper != nullptr);
42     const auto srcPool = RunTimePoolInfo::createFromMemory(src);
43     if (!srcPool.has_value()) {
44         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
45                << "SampleBuffer::copyFrom -- unable to map src memory.";
46     }
47     const ErrorStatus validationStatus =
48             bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize());
49     if (validationStatus != ErrorStatus::NONE) {
50         return NN_ERROR(validationStatus);
51     }
52     const auto dstPool = bufferWrapper->createRunTimePoolInfo();
53     copyRunTimePoolInfos(srcPool.value(), dstPool);
54 
55     return {};
56 }
57 
58 }  // namespace
59 
Buffer(std::shared_ptr<ManagedBuffer> buffer,std::unique_ptr<BufferTracker::Token> token)60 Buffer::Buffer(std::shared_ptr<ManagedBuffer> buffer, std::unique_ptr<BufferTracker::Token> token)
61     : kBuffer(std::move(buffer)), kToken(std::move(token)) {
62     CHECK(kBuffer != nullptr);
63     CHECK(kToken != nullptr);
64 }
65 
getToken() const66 Request::MemoryDomainToken Buffer::getToken() const {
67     return Request::MemoryDomainToken{kToken->get()};
68 }
69 
copyTo(const SharedMemory & dst) const70 GeneralResult<void> Buffer::copyTo(const SharedMemory& dst) const {
71     const auto dstPool = RunTimePoolInfo::createFromMemory(dst);
72     if (!dstPool.has_value()) {
73         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
74                << "SampleBuffer::copyTo -- unable to map dst memory.";
75     }
76 
77     const ErrorStatus validationStatus = kBuffer->validateCopyTo(dstPool->getSize());
78     if (validationStatus != ErrorStatus::NONE) {
79         return NN_ERROR(validationStatus);
80     }
81 
82     const auto srcPool = kBuffer->createRunTimePoolInfo();
83     copyRunTimePoolInfos(srcPool, dstPool.value());
84 
85     return {};
86 }
87 
copyFrom(const SharedMemory & src,const Dimensions & dimensions) const88 GeneralResult<void> Buffer::copyFrom(const SharedMemory& src, const Dimensions& dimensions) const {
89     if (const auto result = copyFromInternal(src, dimensions, kBuffer); !result.ok()) {
90         kBuffer->setInitialized(false);
91         NN_TRY(result);
92     }
93 
94     kBuffer->updateDimensions(dimensions);
95     kBuffer->setInitialized(true);
96 
97     return {};
98 }
99 
100 }  // namespace android::nn::sample
101