/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ShimPreparedModel.h"

#include <aidl/android/hardware/neuralnetworks/BnBurst.h>
#include <aidl/android/hardware/neuralnetworks/BnFencedExecutionCallback.h>
#include <aidl/android/hardware/neuralnetworks/ErrorStatus.h>
#include <aidl/android/hardware/neuralnetworks/OutputShape.h>
#include <aidl/android/hardware/neuralnetworks/RequestMemoryPool.h>
#include <android-base/chrono_utils.h>
#include <android-base/logging.h>
#include <android-base/scopeguard.h>
#include <android/binder_auto_utils.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/hal/aidl/Conversions.h>

#include <algorithm>
#include <chrono>
#include <limits>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

#include "ShimConverter.h"
#include "ShimUtils.h"

namespace aidl::android::hardware::neuralnetworks {

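// Maps an AIDL Request onto an SL Execution: imports or looks up each request memory pool,
// binds the request inputs and outputs to those pools, and applies the measurement, deadline,
// and loop-timeout settings. Returns ErrorStatus::NONE on success.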
ErrorStatus ShimPreparedModel::parseInputs(
        const Request& request, bool measure, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
        ::android::nn::sl_wrapper::Execution* execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>>* requestMemoryPools) {
    for (const auto& requestPool : request.pools) {
        switch (requestPool.getTag()) {
            case RequestMemoryPool::pool: {
                const auto& memoryPool = requestPool.get<RequestMemoryPool::pool>();
                std::shared_ptr<::android::nn::sl_wrapper::Memory> mem =
                        convertFromHAL(mNnapi.get(), memoryPool);
                if (!mem) {
                    LOG(ERROR) << "Failed to convert request HAL memory pools into SL memory";
                    return ErrorStatus::INVALID_ARGUMENT;
                }

                requestMemoryPools->push_back(mem);
                break;
            }
            case RequestMemoryPool::token: {
                int token = requestPool.get<RequestMemoryPool::token>();

                auto memory = mBufferTracker->get(static_cast<uint32_t>(token));
                if (memory == nullptr) {
                    return ErrorStatus::INVALID_ARGUMENT;
                }

                requestMemoryPools->push_back(memory);
                break;
            }
        }
    }

    const auto& model = mMainAndReferencedModels[0];
    // set inputs
    for (int i = 0; i < request.inputs.size(); ++i) {
        const auto& input = request.inputs[i];
        ::android::nn::wrapper::OperandType operandType = model.getOperands()[model.getInputs()[i]];
        if (!input.hasNoValue) {
            if (input.dimensions.size() > 0) {
                operandType.updateDimensions(::android::nn::toUnsigned(input.dimensions).value());
            }
            auto result = execution->setInputFromMemory(
                    i, requestMemoryPools->at(input.location.poolIndex).get(),
                    input.location.offset, input.location.length, &operandType.operandType);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        } else {
            auto result = execution->setInput(i, nullptr, 0);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    // set outputs
    for (int i = 0; i < request.outputs.size(); ++i) {
        const auto& output = request.outputs[i];
        ::android::nn::wrapper::OperandType operandType =
                model.getOperands()[model.getOutputs()[i]];

        if (!output.hasNoValue) {
            if (output.dimensions.size() > 0) {
                operandType.updateDimensions(::android::nn::toUnsigned(output.dimensions).value());
            }
            auto result = execution->setOutputFromMemory(
                    i, requestMemoryPools->at(output.location.poolIndex).get(),
                    output.location.offset, output.location.length, &operandType.operandType);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        } else {
            auto result = execution->setOutput(i, nullptr, 0);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    if (measure) {
        execution->setMeasureTiming(true);
    }

    if (deadlineNs > -1) {
        std::chrono::time_point<::android::base::boot_clock> deadlinePoint(
                std::chrono::nanoseconds{deadlineNs});
        const auto currentTime = ::android::base::boot_clock::now();
        const auto timeoutDuration = std::chrono::nanoseconds(deadlinePoint - currentTime);
        if (timeoutDuration <= std::chrono::nanoseconds::zero()) {
            return ErrorStatus::MISSED_DEADLINE_TRANSIENT;
        } else {
            auto result = execution->setTimeout(std::max<uint64_t>(1, timeoutDuration.count()));
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    if (loopTimeoutDurationNs > 0) {
        execution->setLoopTimeout(loopTimeoutDurationNs);
    }
    return ErrorStatus::NONE;
}

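// Callback returned to the client from executeFenced(). It keeps the SL execution, the
// completion event, and the request memory pools alive until the fenced execution finishes,
// and reports launched/fenced timing once the event has been signaled.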
class ShimFencedExecutionCallback : public BnFencedExecutionCallback {
   public:
    ShimFencedExecutionCallback(
            ::android::nn::sl_wrapper::Execution execution, Event e,
            std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> memoryPools,
            bool measureTiming)
        : mMemoryPools(std::move(memoryPools)),
          mExecution(std::move(execution)),
          mEvent(std::move(e)),
          mMeasureTiming(measureTiming) {}

    ndk::ScopedAStatus getExecutionInfo(Timing* timingLaunched, Timing* timingFenced,
                                        ErrorStatus* errorStatus) override {
        auto status = mEvent.wait();
        *errorStatus = convertResultToErrorStatus(status);

        if (mMeasureTiming) {
            uint64_t duration;
            constexpr int64_t int64cap = std::numeric_limits<int64_t>::max();
            // Special value used for "no measurements"
            constexpr uint64_t uint64cap = std::numeric_limits<uint64_t>::max();
            auto result = mExecution.getDuration(Duration::ON_HARDWARE, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingLaunched->timeOnDeviceNs =
                    (duration == uint64cap)
                            ? -1
                            : (duration > int64cap) ? int64cap : static_cast<int64_t>(duration);

            result = mExecution.getDuration(Duration::IN_DRIVER, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingLaunched->timeInDriverNs =
                    (duration == uint64cap)
                            ? -1
                            : (duration > int64cap) ? int64cap : static_cast<int64_t>(duration);

            result = mExecution.getDuration(Duration::FENCED_ON_HARDWARE, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingFenced->timeOnDeviceNs =
                    (duration == uint64cap)
                            ? -1
                            : (duration > int64cap) ? int64cap : static_cast<int64_t>(duration);

            result = mExecution.getDuration(Duration::FENCED_IN_DRIVER, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingFenced->timeInDriverNs =
                    (duration == uint64cap)
                            ? -1
                            : (duration > int64cap) ? int64cap : static_cast<int64_t>(duration);
        } else {
            timingFenced->timeOnDeviceNs = -1;
            timingFenced->timeInDriverNs = -1;
            timingLaunched->timeOnDeviceNs = -1;
            timingLaunched->timeInDriverNs = -1;
        }

        return ndk::ScopedAStatus::ok();
    }

   private:
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> mMemoryPools;
    ::android::nn::sl_wrapper::Execution mExecution;
    ::android::nn::wrapper::Event mEvent;
    bool mMeasureTiming;
};

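// Starts a fenced (asynchronous) execution: binds the request to a new SL execution, converts
// the wait-for sync fences into ANeuralNetworksEvent dependencies, and returns the resulting
// sync fence together with a ShimFencedExecutionCallback.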
::ndk::ScopedAStatus ShimPreparedModel::executeFenced(
        const ::aidl::android::hardware::neuralnetworks::Request& request,
        const std::vector<::ndk::ScopedFileDescriptor>& waitFor, bool measureTiming,
        int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        FencedExecutionResult* fencedExecutionResult) {
    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }
    auto execution = ::android::nn::sl_wrapper::Execution(mNnapi.get(), &mCompilation);
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools;
    auto errorStatus = parseInputs(request, measureTiming, deadlineNs, loopTimeoutDurationNs,
                                   &execution, &requestMemoryPools);
    if (errorStatus != ErrorStatus::NONE) {
        return toAStatus(errorStatus);
    }

    std::vector<const ANeuralNetworksEvent*> deps(waitFor.size());
    auto createResult = Result::NO_ERROR;
    std::transform(waitFor.begin(), waitFor.end(), deps.begin(),
                   [&](const ::ndk::ScopedFileDescriptor& e) {
                       ANeuralNetworksEvent* r = nullptr;
                       if (createResult == Result::NO_ERROR) {
                           createResult = static_cast<Result>(
                                   mNnapi->ANeuralNetworksEvent_createFromSyncFenceFd(e.get(), &r));
                       }
                       return r;
                   });

    const auto guard = ::android::base::make_scope_guard([this, deps] {
        for (auto& dep : deps) {
            if (dep != nullptr) {
                mNnapi->ANeuralNetworksEvent_free(const_cast<ANeuralNetworksEvent*>(dep));
            }
        }
    });

    SLW2SAS_RETURN_IF_ERROR(createResult);

    Event e(mNnapi.get());
    auto result = execution.startComputeWithDependencies(deps, durationNs, &e);
    SLW2SAS_RETURN_IF_ERROR(result);

    int syncFence = -1;
    fencedExecutionResult->syncFence = ndk::ScopedFileDescriptor(
            (e.getSyncFenceFd(&syncFence) == Result::NO_ERROR) ? syncFence : -1);
    fencedExecutionResult->callback = ndk::SharedRefBase::make<ShimFencedExecutionCallback>(
            std::move(execution), std::move(e), std::move(requestMemoryPools), measureTiming);

    return ndk::ScopedAStatus::ok();
}

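// Runs a blocking execution: binds the request, calls compute(), collects the output shapes
// (including the OUTPUT_INSUFFICIENT_SIZE case), and reports timing when it was requested and
// the execution succeeded.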
::ndk::ScopedAStatus ShimPreparedModel::executeSynchronously(
        const Request& request, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs,
        ::aidl::android::hardware::neuralnetworks::ExecutionResult* executionResult) {
    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    auto execution =
            std::make_unique<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools;
    auto errorStatus = parseInputs(request, measureTiming, deadlineNs, loopTimeoutDurationNs,
                                   execution.get(), &requestMemoryPools);
    if (errorStatus != ErrorStatus::NONE) {
        return toAStatus(errorStatus);
    }

    auto result = execution->compute();
    errorStatus = convertResultToErrorStatus(result);

    int numOutputs = request.outputs.size();
    std::vector<OutputShape> outputShapes;
    outputShapes.reserve(numOutputs);
    bool sufficientSize = true;
    for (int i = 0; i < numOutputs; ++i) {
        OutputShape outputShape;
        std::vector<uint32_t> outputDims;
        auto result = execution->getOutputOperandDimensions(i, &outputDims);
        if (result == Result::NO_ERROR) {
            outputShape.isSufficient = true;
            outputShape.dimensions.assign(outputDims.begin(), outputDims.end());
        } else if (result == Result::OUTPUT_INSUFFICIENT_SIZE) {
            sufficientSize = false;
            outputShape.isSufficient = false;
            outputShape.dimensions.assign(outputDims.begin(), outputDims.end());
        } else {
            if (errorStatus == ErrorStatus::NONE) {
                errorStatus = ErrorStatus::GENERAL_FAILURE;
            }
        }
        outputShapes.push_back(std::move(outputShape));
    }

    int64_t timeOnDeviceNs = -1;
    int64_t timeInDriverNs = -1;
    if (measureTiming && errorStatus == ErrorStatus::NONE) {
        uint64_t duration;
        constexpr int64_t int64cap = std::numeric_limits<int64_t>::max();
        // Special value used for "no measurements"
        constexpr uint64_t uint64cap = std::numeric_limits<uint64_t>::max();
        auto result = execution->getDuration(Duration::ON_HARDWARE, &duration);
        SLW2SAS_RETURN_IF_ERROR(result);
        timeOnDeviceNs =
                (duration == uint64cap)
                        ? -1
                        : (duration > int64cap) ? int64cap : static_cast<int64_t>(duration);

        result = execution->getDuration(Duration::IN_DRIVER, &duration);
        SLW2SAS_RETURN_IF_ERROR(result);
        timeInDriverNs =
                (duration == uint64cap)
                        ? -1
                        : (duration > int64cap) ? int64cap : static_cast<int64_t>(duration);
    }

    *executionResult =
            ExecutionResult{sufficientSize,
                            std::move(outputShapes),
                            {.timeOnDeviceNs = timeOnDeviceNs, .timeInDriverNs = timeInDriverNs}};
    if (errorStatus == ErrorStatus::NONE || errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        return ndk::ScopedAStatus::ok();
    }
    return toAStatus(errorStatus);
}

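// Burst object returned by configureExecutionBurst(). It forwards synchronous executions to the
// owning ShimPreparedModel and enforces that at most one execution is in flight at a time.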
// TODO(183397380): make it use ANNBurst object
class ShimBurst : public BnBurst {
   public:
    // Precondition: preparedModel != nullptr
    explicit ShimBurst(std::shared_ptr<ShimPreparedModel> preparedModel);

    ndk::ScopedAStatus executeSynchronously(const Request& request,
                                            const std::vector<int64_t>& memoryIdentifierTokens,
                                            bool measureTiming, int64_t deadlineNs,
                                            int64_t loopTimeoutDurationNs,
                                            ExecutionResult* executionResult) override;
    ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override;

   protected:
    std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
    const std::shared_ptr<ShimPreparedModel> kPreparedModel;
};

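// Creates a burst object that shares ownership of this prepared model.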
ndk::ScopedAStatus ShimPreparedModel::configureExecutionBurst(std::shared_ptr<IBurst>* burst) {
    std::shared_ptr<ShimPreparedModel> self = this->template ref<ShimPreparedModel>();
    *burst = ndk::SharedRefBase::make<ShimBurst>(std::move(self));
    return ndk::ScopedAStatus::ok();
}

ShimBurst::ShimBurst(std::shared_ptr<ShimPreparedModel> preparedModel)
    : kPreparedModel(std::move(preparedModel)) {
    CHECK(kPreparedModel != nullptr);
}

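// Validates the memory identifier tokens, guards against concurrent executions, and delegates
// the actual work to ShimPreparedModel::executeSynchronously().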
ndk::ScopedAStatus ShimBurst::executeSynchronously(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
        ExecutionResult* executionResult) {
    if (request.pools.size() != memoryIdentifierTokens.size()) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT,
                         "request.pools.size() != memoryIdentifierTokens.size()");
    }
    if (!std::all_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                     [](int64_t token) { return token >= -1; })) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierTokens");
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Burst object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return kPreparedModel->executeSynchronously(request, measureTiming, deadlineNs,
                                                loopTimeoutDurationNs, executionResult);
}

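// The shim burst does not cache memories by identifier token, so releasing a resource only
// requires validating the token.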
ndk::ScopedAStatus ShimBurst::releaseMemoryResource(int64_t memoryIdentifierToken) {
    if (memoryIdentifierToken < -1) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierToken");
    }
    return ndk::ScopedAStatus::ok();
}

}  // namespace aidl::android::hardware::neuralnetworks