• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Burst.h"
18 
19 #include "Conversions.h"
20 #include "Utils.h"
21 
22 #include <android-base/logging.h>
23 #include <android/binder_auto_utils.h>
24 #include <nnapi/IBurst.h>
25 #include <nnapi/IExecution.h>
26 #include <nnapi/Result.h>
27 #include <nnapi/TypeUtils.h>
28 #include <nnapi/Types.h>
29 
30 #include <memory>
31 #include <mutex>
32 #include <optional>
33 #include <utility>
34 
35 namespace aidl::android::hardware::neuralnetworks::utils {
36 namespace {
37 
// Reusable execution object created from a Burst. It captures a fully converted AIDL request,
// the memory-slot identifier tokens for each request pool, and the cache holds that keep those
// slots registered with the service, so Burst::executeInternal can be invoked repeatedly
// without re-converting or re-registering anything.
class BurstExecution final : public nn::IExecution,
                             public std::enable_shared_from_this<BurstExecution> {
    // Tag type so only the `create` factory can reach the public constructor while still
    // allowing std::make_shared to construct the object.
    struct PrivateConstructorTag {};

  public:
    // Factory: validates that `burst` is non-null and wraps all execution state in a shared,
    // immutable BurstExecution. Returns a GeneralError on invalid input.
    static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
            std::shared_ptr<const Burst> burst, Request request,
            std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
            const std::vector<nn::TokenValuePair>& hints,
            const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
            hal::utils::RequestRelocation relocation,
            std::vector<Burst::OptionalCacheHold> cacheHolds);

    BurstExecution(PrivateConstructorTag tag, std::shared_ptr<const Burst> burst, Request request,
                   std::vector<int64_t> memoryIdentifierTokens, bool measure,
                   int64_t loopTimeoutDuration, const std::vector<nn::TokenValuePair>& hints,
                   const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
                   hal::utils::RequestRelocation relocation,
                   std::vector<Burst::OptionalCacheHold> cacheHolds);

    // Runs the stored request once, forwarding the stored arguments (plus the converted
    // deadline) to Burst::executeInternal.
    nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
            const nn::OptionalTimePoint& deadline) const override;

    // Fenced execution is not supported on burst objects; always returns GENERAL_FAILURE.
    nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> computeFenced(
            const std::vector<nn::SyncFence>& waitFor, const nn::OptionalTimePoint& deadline,
            const nn::OptionalDuration& timeoutDurationAfterFence) const override;

  private:
    const std::shared_ptr<const Burst> kBurst;
    const Request kRequest;
    const std::vector<int64_t> kMemoryIdentifierTokens;
    const bool kMeasure;
    const int64_t kLoopTimeoutDuration;
    const std::vector<nn::TokenValuePair> kHints;
    const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix;
    // kRelocation/kCacheHolds keep the shared request memory and the service-side memory slots
    // alive for the lifetime of this execution object.
    const hal::utils::RequestRelocation kRelocation;
    const std::vector<Burst::OptionalCacheHold> kCacheHolds;
};
76 
convertExecutionResults(const std::vector<OutputShape> & outputShapes,const Timing & timing)77 nn::GeneralResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> convertExecutionResults(
78         const std::vector<OutputShape>& outputShapes, const Timing& timing) {
79     return std::make_pair(NN_TRY(nn::convert(outputShapes)), NN_TRY(nn::convert(timing)));
80 }
81 
82 }  // namespace
83 
// Takes shared ownership of the HAL burst object so that cached memory slots can later be
// released on it (see tryFreeMemory).
Burst::MemoryCache::MemoryCache(std::shared_ptr<aidl_hal::IBurst> burst)
    : kBurst(std::move(burst)) {}
86 
// Returns the slot identifier and a strong cleanup handle for `memory`, registering the memory
// in the cache under a freshly allocated identifier if no live entry exists. Thread-safe: all
// cache state is guarded by mMutex.
std::pair<int64_t, Burst::MemoryCache::SharedCleanup> Burst::MemoryCache::getOrCacheMemory(
        const nn::SharedMemory& memory) {
    std::lock_guard lock(mMutex);

    // Get the cache payload or create it (with default values) if it does not exist.
    // Note: the reference into mCache stays valid below because nothing else mutates the map
    // before `cachedPayload` is reassigned.
    auto& cachedPayload = mCache[memory];
    {
        const auto& [identifier, maybeCleaner] = cachedPayload;
        // If cache payload already exists, reuse it.
        if (auto cleaner = maybeCleaner.lock()) {
            return std::make_pair(identifier, std::move(cleaner));
        }
    }

    // If the code reaches this point, the cached payload either did not exist or expired prior to
    // this call.

    // Allocate a new identifier.
    CHECK_LT(mUnusedIdentifier, std::numeric_limits<int64_t>::max());
    const int64_t identifier = mUnusedIdentifier++;

    // Create reference-counted self-cleaning cache object. The cleanup task holds only a weak
    // reference to this MemoryCache so it does not extend the cache's lifetime.
    auto self = weak_from_this();
    Task cleanup = [memory, identifier, maybeMemoryCache = std::move(self)] {
        if (const auto memoryCache = maybeMemoryCache.lock()) {
            memoryCache->tryFreeMemory(memory, identifier);
        }
    };
    auto cleaner = std::make_shared<const Cleanup>(std::move(cleanup));

    // Store the result in the cache (the strong cleanup handle converts to the cache's weak
    // reference on assignment) and return it.
    auto result = std::make_pair(identifier, std::move(cleaner));
    cachedPayload = result;
    return result;
}
122 
123 std::optional<std::pair<int64_t, Burst::MemoryCache::SharedCleanup>>
getMemoryIfAvailable(const nn::SharedMemory & memory)124 Burst::MemoryCache::getMemoryIfAvailable(const nn::SharedMemory& memory) {
125     std::lock_guard lock(mMutex);
126 
127     // Get the existing cached entry if it exists.
128     const auto iter = mCache.find(memory);
129     if (iter != mCache.end()) {
130         const auto& [identifier, maybeCleaner] = iter->second;
131         if (auto cleaner = maybeCleaner.lock()) {
132             return std::make_pair(identifier, std::move(cleaner));
133         }
134     }
135 
136     // If the code reaches this point, the cached payload did not exist or was actively being
137     // deleted.
138     return std::nullopt;
139 }
140 
tryFreeMemory(const nn::SharedMemory & memory,int64_t identifier)141 void Burst::MemoryCache::tryFreeMemory(const nn::SharedMemory& memory, int64_t identifier) {
142     {
143         std::lock_guard guard(mMutex);
144         // Remove the cached memory and payload if it is present but expired. Note that it may not
145         // be present or may not be expired because another thread may have removed or cached the
146         // same memory object before the current thread locked mMutex in tryFreeMemory.
147         const auto iter = mCache.find(memory);
148         if (iter != mCache.end()) {
149             if (std::get<WeakCleanup>(iter->second).expired()) {
150                 mCache.erase(iter);
151             }
152         }
153     }
154     kBurst->releaseMemoryResource(identifier);
155 }
156 
create(std::shared_ptr<aidl_hal::IBurst> burst,nn::Version featureLevel)157 nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
158         std::shared_ptr<aidl_hal::IBurst> burst, nn::Version featureLevel) {
159     if (burst == nullptr) {
160         return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
161                << "aidl_hal::utils::Burst::create must have non-null burst";
162     }
163 
164     return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst), featureLevel);
165 }
166 
// Private constructor (reachable only through Burst::create because of the tag type). Takes
// ownership of the HAL burst, builds the shared MemoryCache around it, and records the device
// feature level used by executeInternal to pick the execution entry point.
Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst,
             nn::Version featureLevel)
    : kBurst(std::move(burst)),
      kMemoryCache(std::make_shared<MemoryCache>(kBurst)),
      kFeatureLevel(featureLevel) {
    // create() already rejects null bursts; this guards direct construction paths.
    CHECK(kBurst != nullptr);
}
174 
cacheMemory(const nn::SharedMemory & memory) const175 Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) const {
176     auto [identifier, hold] = kMemoryCache->getOrCacheMemory(memory);
177     return hold;
178 }
179 
// Runs a single synchronous execution over the burst channel: prepares the request for IPC,
// converts arguments to AIDL, resolves each memory pool to its cached slot identifier (or -1
// if not cached), and delegates to executeInternal.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
        const nn::Request& request, nn::MeasureTiming measure,
        const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
        const std::vector<nn::TokenValuePair>& hints,
        const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
    // Ensure that request is ready for IPC (pointer-based arguments relocated into shared
    // memory when necessary).
    std::optional<nn::Request> maybeRequestInShared;
    hal::utils::RequestRelocation relocation;
    const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
            &request, nn::kDefaultRequestMemoryAlignment, nn::kDefaultRequestMemoryPadding,
            &maybeRequestInShared, &relocation));

    // Convert canonical arguments to their AIDL representations.
    const auto aidlRequest = NN_TRY(convert(requestInShared));
    const auto aidlMeasure = NN_TRY(convert(measure));
    const auto aidlDeadline = NN_TRY(convert(deadline));
    const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));

    // Build one identifier token per request pool: the cached slot identifier for memory pools
    // already in the cache, or -1 for pools without a cached slot. `holds` keeps the matching
    // cache entries alive until executeInternal below returns.
    std::vector<int64_t> memoryIdentifierTokens;
    std::vector<OptionalCacheHold> holds;
    memoryIdentifierTokens.reserve(requestInShared.pools.size());
    holds.reserve(requestInShared.pools.size());
    for (const auto& memoryPool : requestInShared.pools) {
        if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
            if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
                auto& [identifier, hold] = *cached;
                memoryIdentifierTokens.push_back(identifier);
                holds.push_back(std::move(hold));
                continue;
            }
        }
        memoryIdentifierTokens.push_back(-1);
    }
    CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());
    return executeInternal(aidlRequest, memoryIdentifierTokens, aidlMeasure, aidlDeadline,
                           aidlLoopTimeoutDuration, hints, extensionNameToPrefix, relocation);
}
216 
// Core burst execution path shared by execute() and BurstExecution::compute(). Enforces the
// single-execution-in-flight burst contract, flushes relocated input memory before the call
// and output memory after it, and dispatches to the feature-level-appropriate HAL entry point.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::executeInternal(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measure,
        int64_t deadline, int64_t loopTimeoutDuration, const std::vector<nn::TokenValuePair>& hints,
        const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
        const hal::utils::RequestRelocation& relocation) const {
    // Ensure that at most one execution is in flight at any given time.
    const bool alreadyInFlight = mExecutionInFlight.test_and_set();
    if (alreadyInFlight) {
        return NN_ERROR() << "IBurst already has an execution in flight";
    }
    // Clear the in-flight flag on every exit path, including early error returns.
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    // Copy caller-side input data into the shared request memory before the IPC call.
    if (relocation.input) {
        relocation.input->flush();
    }

    ExecutionResult executionResult;
    if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
        // Feature level 8+ supports passing execution hints and extension prefixes via the
        // ExecutionConfig overload.
        auto aidlHints = NN_TRY(convert(hints));
        auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
        const auto ret = kBurst->executeSynchronouslyWithConfig(
                request, memoryIdentifierTokens,
                {measure, loopTimeoutDuration, std::move(aidlHints),
                 std::move(aidlExtensionPrefix)},
                deadline, &executionResult);
        HANDLE_ASTATUS(ret) << "execute failed";
    } else {
        // Older services only support the plain synchronous entry point; hints and extension
        // prefixes are dropped.
        const auto ret =
                kBurst->executeSynchronously(request, memoryIdentifierTokens, measure, deadline,
                                             loopTimeoutDuration, &executionResult);
        HANDLE_ASTATUS(ret) << "execute failed";
    }
    // Surface OUTPUT_INSUFFICIENT_SIZE with whatever output shapes could be converted so the
    // caller can resize its output buffers and retry.
    if (!executionResult.outputSufficientSize) {
        auto canonicalOutputShapes =
                nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
        return NN_ERROR(nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, std::move(canonicalOutputShapes))
               << "execution failed with " << nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
    }
    auto [outputShapes, timing] =
            NN_TRY(convertExecutionResults(executionResult.outputShapes, executionResult.timing));

    // Copy result data from the shared request memory back to the caller's buffers.
    if (relocation.output) {
        relocation.output->flush();
    }
    return std::make_pair(std::move(outputShapes), timing);
}
263 
// Builds a reusable execution object: performs the same request preparation, AIDL conversion,
// and memory-slot resolution as execute(), but hands the results (including the cache holds
// and relocation) to a BurstExecution so the request can be dispatched repeatedly.
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
        const nn::Request& request, nn::MeasureTiming measure,
        const nn::OptionalDuration& loopTimeoutDuration,
        const std::vector<nn::TokenValuePair>& hints,
        const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
    // Ensure that request is ready for IPC (pointer-based arguments relocated into shared
    // memory when necessary).
    std::optional<nn::Request> maybeRequestInShared;
    hal::utils::RequestRelocation relocation;
    const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
            &request, nn::kDefaultRequestMemoryAlignment, nn::kDefaultRequestMemoryPadding,
            &maybeRequestInShared, &relocation));

    // Convert canonical arguments to their AIDL representations.
    auto aidlRequest = NN_TRY(convert(requestInShared));
    const auto aidlMeasure = NN_TRY(convert(measure));
    const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));

    // Build one identifier token per request pool (cached slot identifier, or -1 when the pool
    // has no cached slot). The holds are moved into the BurstExecution below so the cached
    // slots stay registered for the lifetime of the reusable execution.
    std::vector<int64_t> memoryIdentifierTokens;
    std::vector<OptionalCacheHold> holds;
    memoryIdentifierTokens.reserve(requestInShared.pools.size());
    holds.reserve(requestInShared.pools.size());
    for (const auto& memoryPool : requestInShared.pools) {
        if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
            if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
                auto& [identifier, hold] = *cached;
                memoryIdentifierTokens.push_back(identifier);
                holds.push_back(std::move(hold));
                continue;
            }
        }
        memoryIdentifierTokens.push_back(-1);
    }
    CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());

    return BurstExecution::create(shared_from_this(), std::move(aidlRequest),
                                  std::move(memoryIdentifierTokens), aidlMeasure,
                                  aidlLoopTimeoutDuration, hints, extensionNameToPrefix,
                                  std::move(relocation), std::move(holds));
}
302 
create(std::shared_ptr<const Burst> burst,Request request,std::vector<int64_t> memoryIdentifierTokens,bool measure,int64_t loopTimeoutDuration,const std::vector<nn::TokenValuePair> & hints,const std::vector<nn::ExtensionNameAndPrefix> & extensionNameToPrefix,hal::utils::RequestRelocation relocation,std::vector<Burst::OptionalCacheHold> cacheHolds)303 nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
304         std::shared_ptr<const Burst> burst, Request request,
305         std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
306         const std::vector<nn::TokenValuePair>& hints,
307         const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
308         hal::utils::RequestRelocation relocation,
309         std::vector<Burst::OptionalCacheHold> cacheHolds) {
310     if (burst == nullptr) {
311         return NN_ERROR() << "aidl::utils::BurstExecution::create must have non-null burst";
312     }
313 
314     return std::make_shared<const BurstExecution>(
315             PrivateConstructorTag{}, std::move(burst), std::move(request),
316             std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, hints,
317             extensionNameToPrefix, std::move(relocation), std::move(cacheHolds));
318 }
319 
// Private constructor (reachable only through BurstExecution::create because of the tag type).
// Moves all prepared execution state into immutable members; kRelocation and kCacheHolds keep
// the shared request memory and cached memory slots alive for the object's lifetime.
BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<const Burst> burst,
                               Request request, std::vector<int64_t> memoryIdentifierTokens,
                               bool measure, int64_t loopTimeoutDuration,
                               const std::vector<nn::TokenValuePair>& hints,
                               const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
                               hal::utils::RequestRelocation relocation,
                               std::vector<Burst::OptionalCacheHold> cacheHolds)
    : kBurst(std::move(burst)),
      kRequest(std::move(request)),
      kMemoryIdentifierTokens(std::move(memoryIdentifierTokens)),
      kMeasure(measure),
      kLoopTimeoutDuration(loopTimeoutDuration),
      kHints(hints),
      kExtensionNameToPrefix(extensionNameToPrefix),
      kRelocation(std::move(relocation)),
      kCacheHolds(std::move(cacheHolds)) {}
336 
// Runs the stored request once over the owning burst: only the deadline varies per call, so it
// is converted here and forwarded with the precomputed state to Burst::executeInternal.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> BurstExecution::compute(
        const nn::OptionalTimePoint& deadline) const {
    const auto aidlDeadline = NN_TRY(convert(deadline));
    return kBurst->executeInternal(kRequest, kMemoryIdentifierTokens, kMeasure, aidlDeadline,
                                   kLoopTimeoutDuration, kHints, kExtensionNameToPrefix,
                                   kRelocation);
}
344 
// Fenced execution is not part of the burst contract; unconditionally report GENERAL_FAILURE.
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
BurstExecution::computeFenced(const std::vector<nn::SyncFence>& /*waitFor*/,
                              const nn::OptionalTimePoint& /*deadline*/,
                              const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "IExecution::computeFenced is not supported on burst object";
}
352 
353 }  // namespace aidl::android::hardware::neuralnetworks::utils
354