1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <HalInterfaces.h>
18 #include <SampleDriver.h>
19 #include <ValidateHal.h>
20 #include <gtest/gtest.h>
21
22 #include <algorithm>
23 #include <atomic>
24 #include <cassert>
25 #include <memory>
26 #include <string>
27 #include <thread>
28 #include <tuple>
29 #include <vector>
30
31 #include "CompilationBuilder.h"
32 #include "ExecutionBurstServer.h"
33 #include "ExecutionCallback.h"
34 #include "HalUtils.h"
35 #include "Manager.h"
36 #include "ModelBuilder.h"
37 #include "NeuralNetworks.h"
38 #include "PreparedModelCallback.h"
39 #include "TestNeuralNetworksWrapper.h"
40
41 namespace android {
42
43 namespace V1_0 = ::android::hardware::neuralnetworks::V1_0;
44 namespace V1_1 = ::android::hardware::neuralnetworks::V1_1;
45 namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;
46 namespace V1_3 = ::android::hardware::neuralnetworks::V1_3;
47 using CompilationBuilder = nn::CompilationBuilder;
48 using Device = nn::Device;
49 using SharedDevice = nn::SharedDevice;
50 using DeviceManager = nn::DeviceManager;
51 using HidlModel = V1_3::Model;
52 using PreparedModelCallback = nn::PreparedModelCallback;
53 using SampleDriver = nn::sample_driver::SampleDriver;
54 using WrapperCompilation = nn::test_wrapper::Compilation;
55 using WrapperEvent = nn::test_wrapper::Event;
56 using WrapperExecution = nn::test_wrapper::Execution;
57 using WrapperModel = nn::test_wrapper::Model;
58 using WrapperOperandType = nn::test_wrapper::OperandType;
59 using WrapperResult = nn::test_wrapper::Result;
60 using WrapperType = nn::test_wrapper::Type;
61 using nn::convertToV1_0;
62 using nn::convertToV1_3;
63 using nn::ErrorStatus;
64
65 template <typename T>
66 using MQDescriptorSync = hardware::MQDescriptorSync<T>;
67
68 namespace {
69
// Timing reported on failure paths; both fields UINT64_MAX, i.e. no timing measurement
// is available (presumably the HAL convention for "unmeasured" -- confirm against the
// V1_2::Timing documentation).
const V1_2::Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
71
// Wraps the latest version of IPreparedModel to allow dummying up the execution status,
// and control when the execution finishes.
class TestPreparedModelLatest : public V1_3::IPreparedModel {
   public:
    // If errorStatus is NONE, then execute behaves normally (and sends back
    // the actual execution status). Otherwise, don't bother to execute, and
    // just send back errorStatus (as the execution status, not the launch
    // status).
    TestPreparedModelLatest(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
        : mPreparedModelV1_0(preparedModel),
          // The wrapped model may not implement the newer interfaces; castFrom then
          // yields nullptr, and the corresponding methods below CHECK against that.
          mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
          mPreparedModelV1_3(V1_3::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
          mErrorStatus(errorStatus) {}

    // Asynchronous V1_0 execution. The work (or the dummied-up failure
    // notification) happens on a detached thread, so this method always
    // reports a NONE launch status to the caller.
    hardware::Return<V1_0::ErrorStatus> execute(
            const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_0 != nullptr) << "V1_0 prepared model is nullptr.";
        std::thread([this, request, callback] {
            // Participates in the pause/waitForExecutionToBegin protocol (see below).
            dummyExecution();
            if (mErrorStatus == V1_3::ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_0->execute(request, callback);
            } else {
                // Report the dummied-up status as the *execution* status.
                callback->notify(convertToV1_0(mErrorStatus));
            }
        }).detach();
        return V1_0::ErrorStatus::NONE;
    }

    // Asynchronous V1_2 execution; same detached-thread pattern as execute().
    // OUTPUT_INSUFFICIENT_SIZE additionally reports a single insufficient
    // output shape, as required by that error's contract.
    hardware::Return<V1_0::ErrorStatus> execute_1_2(
            const V1_0::Request& request, V1_2::MeasureTiming measure,
            const sp<V1_2::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        std::thread([this, request, measure, callback] {
            dummyExecution();
            if (mErrorStatus == V1_3::ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_2->execute_1_2(request, measure, callback);
            } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
                V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
                callback->notify_1_2(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
            } else {
                callback->notify_1_2(convertToV1_0(mErrorStatus), {}, kBadTiming);
            }
        }).detach();
        return V1_3::ErrorStatus::NONE == mErrorStatus || true ? V1_0::ErrorStatus::NONE
                                                               : V1_0::ErrorStatus::NONE;
    }

    // Asynchronous V1_3 execution; mirrors execute_1_2 but forwards the
    // deadline and loop timeout, and notifies with the V1_3 status directly.
    hardware::Return<V1_3::ErrorStatus> execute_1_3(
            const V1_3::Request& request, V1_2::MeasureTiming measure,
            const V1_3::OptionalTimePoint& deadline,
            const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
            const sp<V1_3::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        std::thread([this, request, measure, deadline, loopTimeoutDuration, callback] {
            dummyExecution();
            if (mErrorStatus == V1_3::ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_3->execute_1_3(request, measure, deadline,
                                                      loopTimeoutDuration, callback);
            } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
                V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
                callback->notify_1_3(mErrorStatus, {shape}, kBadTiming);
            } else {
                callback->notify_1_3(mErrorStatus, {}, kBadTiming);
            }
        }).detach();
        return V1_3::ErrorStatus::NONE;
    }

    // Synchronous V1_2 execution: runs (and can be paused) on the caller's
    // thread, unlike the asynchronous variants above.
    hardware::Return<void> executeSynchronously(const V1_0::Request& request,
                                                V1_2::MeasureTiming measure,
                                                executeSynchronously_cb cb) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        dummyExecution();
        if (mErrorStatus == V1_3::ErrorStatus::NONE) {
            return mPreparedModelV1_2->executeSynchronously(request, measure, cb);
        } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
            V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
            cb(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
            return hardware::Void();
        } else {
            cb(convertToV1_0(mErrorStatus), {}, kBadTiming);
            return hardware::Void();
        }
    }

    // Synchronous V1_3 execution; same shape as executeSynchronously plus the
    // deadline / loop timeout parameters.
    hardware::Return<void> executeSynchronously_1_3(
            const V1_3::Request& request, V1_2::MeasureTiming measure,
            const V1_3::OptionalTimePoint& deadline,
            const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
            executeSynchronously_1_3_cb cb) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        dummyExecution();
        if (mErrorStatus == V1_3::ErrorStatus::NONE) {
            return mPreparedModelV1_3->executeSynchronously_1_3(request, measure, deadline,
                                                                loopTimeoutDuration, cb);
        } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
            V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
            cb(mErrorStatus, {shape}, kBadTiming);
            return hardware::Void();
        } else {
            cb(mErrorStatus, {}, kBadTiming);
            return hardware::Void();
        }
    }

    // ExecutionBurstServer::create has an overload that will use
    // IPreparedModel::executeSynchronously(), so we can rely on that, rather
    // than having to implement ExecutionBurstServer::IExecutorWithCache.
    hardware::Return<void> configureExecutionBurst(
            const sp<V1_2::IBurstCallback>& callback,
            const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
            const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
            configureExecutionBurst_cb cb) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        if (mErrorStatus == V1_3::ErrorStatus::NONE) {
            const sp<V1_2::IBurstContext> burst =
                    nn::ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);

            cb(burst == nullptr ? V1_0::ErrorStatus::GENERAL_FAILURE : V1_0::ErrorStatus::NONE,
               burst);
            return hardware::Void();
        } else {
            // Dummied-up failure: report it as a burst-configuration failure.
            cb(convertToV1_0(mErrorStatus), nullptr);
            return hardware::Void();
        }
    }

    // Note, due to the limitation of SampleDriver implementation, the call is
    // synchronous. The test code that exercises this implementation of
    // SampleDriver is written with that in mind. Therefore, this
    // implementation is synchronous also. If the SampleDriver is updated to
    // return real sync fence, this must be updated.
    hardware::Return<void> executeFenced(const V1_3::Request& request,
                                         const hardware::hidl_vec<hardware::hidl_handle>& waitFor,
                                         V1_2::MeasureTiming measure,
                                         const V1_3::OptionalTimePoint& deadline,
                                         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                                         const V1_3::OptionalTimeoutDuration& duration,
                                         executeFenced_cb cb) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        CHECK(mErrorStatus != V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE)
                << "executeFenced does not support dynamic output shape";
        dummyExecution();
        if (mErrorStatus == V1_3::ErrorStatus::NONE) {
            return mPreparedModelV1_3->executeFenced(request, waitFor, measure, deadline,
                                                     loopTimeoutDuration, duration, cb);
        } else {
            // Due to the limitations of the SampleDriver, all failures look
            // like launch failures. If the SampleDriver is updated to return
            // real sync fences, this must be updated.
            cb(mErrorStatus, hardware::hidl_handle(nullptr), nullptr);
        }
        return hardware::Void();
    }

    // We can place the TestPreparedModelLatest system in a "pause" mode where
    // no execution will complete until the system is taken out of that mode.
    // Initially, the system is not in that mode.
    static void pauseExecutions(bool v) { mPauseExecutions.store(v); }

    // This function is only guaranteed to work in the following pattern:
    // Consider thread A as primary thread
    // - thread A: pauseExecutions(true);
    // - thread A: launch execution (as thread B)
    // - thread A: waitForExecutionToBegin(), block until call to dummyExecution by
    //             thread B makes mExecutionsInFlight nonzero
    // - thread B: dummyExecution(), which makes mExecutionsInFlight nonzero and blocks
    //             until thread A calls pauseExecutions(false)
    // - thread A: waitForExecutionToBegin() returns
    // - thread A: pauseExecutions(false), allowing dummyExecution() on thread B to continue
    // - thread B: dummyExecution() zeroes mExecutionsInFlight and returns
    // - thread B: thread exits
    static void waitForExecutionToBegin() {
        CHECK(mPauseExecutions.load());
        // Busy-wait; acceptable for test code, where the pause window is short.
        while (mExecutionsInFlight.load() == 0) {
        }
    }

   private:
    const sp<V1_0::IPreparedModel> mPreparedModelV1_0;
    const sp<V1_2::IPreparedModel> mPreparedModelV1_2;  // nullptr if not implemented
    const sp<V1_3::IPreparedModel> mPreparedModelV1_3;  // nullptr if not implemented
    V1_3::ErrorStatus mErrorStatus;                     // dummied-up execution status (or NONE)

    // Shared across all instances: pause flag and count of executions that
    // have entered dummyExecution() and not yet left it.
    static std::atomic<bool> mPauseExecutions;
    static std::atomic<unsigned int> mExecutionsInFlight;

    // Marks an execution as in flight, spins while the system is paused, then
    // marks it finished. Only one execution may be in flight at a time.
    static void dummyExecution() {
        CHECK_EQ(mExecutionsInFlight.fetch_add(1), 0u) << "We do not support concurrent executions";
        while (mPauseExecutions.load()) {
        }
        mExecutionsInFlight.fetch_sub(1);
    }
};
std::atomic<bool> TestPreparedModelLatest::mPauseExecutions = false;
std::atomic<unsigned int> TestPreparedModelLatest::mExecutionsInFlight = 0;

using TestPreparedModel13 = TestPreparedModelLatest;
272
// Like TestPreparedModelLatest, but implementing 1.2
// Every method simply forwards to an owned TestPreparedModelLatest.
class TestPreparedModel12 : public V1_2::IPreparedModel {
   public:
    TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
        : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}

    hardware::Return<V1_0::ErrorStatus> execute(
            const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
        return mLatestPreparedModel->execute(request, callback);
    }

    hardware::Return<V1_0::ErrorStatus> execute_1_2(
            const V1_0::Request& request, V1_2::MeasureTiming measure,
            const sp<V1_2::IExecutionCallback>& callback) override {
        return mLatestPreparedModel->execute_1_2(request, measure, callback);
    }

    hardware::Return<void> executeSynchronously(const V1_0::Request& request,
                                                V1_2::MeasureTiming measure,
                                                executeSynchronously_cb cb) override {
        return mLatestPreparedModel->executeSynchronously(request, measure, cb);
    }

    hardware::Return<void> configureExecutionBurst(
            const sp<V1_2::IBurstCallback>& callback,
            const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
            const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
            configureExecutionBurst_cb cb) override {
        return mLatestPreparedModel->configureExecutionBurst(callback, requestChannel,
                                                             resultChannel, cb);
    }

   private:
    // The delegate implementing all behavior (including pause/dummy status).
    const sp<V1_3::IPreparedModel> mLatestPreparedModel;
};
308
// Like TestPreparedModelLatest, but implementing 1.0
// Forwards the single V1_0 method to an owned TestPreparedModelLatest.
class TestPreparedModel10 : public V1_0::IPreparedModel {
   public:
    TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
        : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}

    hardware::Return<V1_0::ErrorStatus> execute(
            const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
        return mLatestPreparedModel->execute(request, callback);
    }

   private:
    // The delegate implementing all behavior (including pause/dummy status).
    const sp<V1_3::IPreparedModel> mLatestPreparedModel;
};
323
// Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
class TestDriver13 : public SampleDriver {
   public:
    // Allow dummying up the error status for execution of all models
    // prepared from this driver. If errorStatus is NONE, then
    // execute behaves normally (and sends back the actual execution
    // status). Otherwise, don't bother to execute, and just send
    // back errorStatus (as the execution status, not the launch
    // status).
    TestDriver13(const std::string& name, V1_3::ErrorStatus errorStatus)
        : SampleDriver(name.c_str()), mErrorStatus(errorStatus) {}

    // Advertises uniform performance (0.75 exec time / power) for all operand
    // types, relaxed-precision modes, and control flow.
    hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb _hidl_cb) override {
        android::nn::initVLogMask();
        const V1_0::PerformanceInfo kPerf = {.execTime = 0.75f, .powerUsage = 0.75f};
        V1_3::Capabilities capabilities = {
                .relaxedFloat32toFloat16PerformanceScalar = kPerf,
                .relaxedFloat32toFloat16PerformanceTensor = kPerf,
                .operandPerformance =
                        nn::nonExtensionOperandPerformance<nn::HalVersion::V1_3>(kPerf),
                .ifPerformance = kPerf,
                .whilePerformance = kPerf};
        _hidl_cb(V1_3::ErrorStatus::NONE, capabilities);
        return hardware::Void();
    }

    // Claims support for every operation of any valid model.
    hardware::Return<void> getSupportedOperations_1_3(const HidlModel& model,
                                                      getSupportedOperations_1_3_cb cb) override {
        if (nn::validateModel(model)) {
            std::vector<bool> supported(model.main.operations.size(), true);
            cb(V1_3::ErrorStatus::NONE, supported);
        } else {
            cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
        }
        return hardware::Void();
    }

    // Prepares the model via SampleDriver but intercepts the result with a
    // local callback; on success, notifies the caller with a
    // TestPreparedModel13 wrapper, otherwise forwards the failure as-is.
    hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
            const HidlModel& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
            const V1_3::OptionalTimePoint& deadline,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
            const nn::HalCacheToken& token,
            const sp<V1_3::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        hardware::Return<V1_3::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_3(
                model, preference, priority, deadline, modelCache, dataCache, token, localCallback);
        // Transport error: nothing to forward to the callback.
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        if (prepareModelReturn != V1_3::ErrorStatus::NONE) {
            // Launch failure: forward localCallback's state unchanged.
            // NOTE(review): no wait() here -- assumes SampleDriver notified the
            // callback before returning a failing launch status; confirm.
            actualCallback->notify_1_3(
                    convertToV1_3(localCallback->getStatus()),
                    V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            actualCallback->notify_1_3(
                    convertToV1_3(localCallback->getStatus()),
                    V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
        } else {
            // Success: wrap the prepared model so executions can be dummied up.
            actualCallback->notify_1_3(
                    V1_3::ErrorStatus::NONE,
                    new TestPreparedModel13(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    // Same interception pattern as prepareModel_1_3, for the 1.2 API.
    hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
            const V1_2::Model& model, V1_1::ExecutionPreference preference,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
            const nn::HalCacheToken& token,
            const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        hardware::Return<V1_0::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_2(
                model, preference, modelCache, dataCache, token, localCallback);
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
            actualCallback->notify_1_2(
                    convertToV1_0(localCallback->getStatus()),
                    V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            actualCallback->notify_1_2(
                    convertToV1_0(localCallback->getStatus()),
                    V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
        } else {
            actualCallback->notify_1_2(
                    V1_0::ErrorStatus::NONE,
                    new TestPreparedModel12(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    // Same interception pattern, for the 1.1 API (no cache parameters).
    hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
            const V1_1::Model& model, V1_1::ExecutionPreference preference,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        hardware::Return<V1_0::ErrorStatus> prepareModelReturn =
                SampleDriver::prepareModel_1_1(model, preference, localCallback);
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
            actualCallback->notify(convertToV1_0(localCallback->getStatus()),
                                   localCallback->getPreparedModel());
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            actualCallback->notify(convertToV1_0(localCallback->getStatus()),
                                   localCallback->getPreparedModel());
        } else {
            actualCallback->notify(
                    V1_0::ErrorStatus::NONE,
                    new TestPreparedModel10(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    // 1.0 prepare is implemented by upgrading the model to 1.1 with a default
    // execution preference.
    hardware::Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return prepareModel_1_1(nn::convertToV1_1(model),
                                V1_1::ExecutionPreference::FAST_SINGLE_ANSWER, actualCallback);
    }

   private:
    // Status to dummy up for all executions of models prepared by this driver.
    V1_3::ErrorStatus mErrorStatus;
};
460
461 // Like TestDriver, but implementing 1.2
462 class TestDriver12 : public V1_2::IDevice {
463 public:
TestDriver12(const std::string & name,V1_3::ErrorStatus errorStatus)464 TestDriver12(const std::string& name, V1_3::ErrorStatus errorStatus)
465 : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb)466 hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override {
467 return mLatestDriver->getCapabilities_1_2(_hidl_cb);
468 }
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)469 hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
470 return mLatestDriver->getCapabilities_1_1(_hidl_cb);
471 }
getCapabilities(getCapabilities_cb _hidl_cb)472 hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
473 return mLatestDriver->getCapabilities(_hidl_cb);
474 }
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb _hidl_cb)475 hardware::Return<void> getSupportedOperations_1_2(
476 const V1_2::Model& model, getSupportedOperations_1_2_cb _hidl_cb) override {
477 return mLatestDriver->getSupportedOperations_1_2(model, _hidl_cb);
478 }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)479 hardware::Return<void> getSupportedOperations_1_1(
480 const V1_1::Model& model, getSupportedOperations_1_1_cb _hidl_cb) override {
481 return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
482 }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)483 hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
484 getSupportedOperations_cb _hidl_cb) override {
485 return mLatestDriver->getSupportedOperations(model, _hidl_cb);
486 }
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & actualCallback)487 hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
488 const V1_2::Model& model, V1_1::ExecutionPreference preference,
489 const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
490 const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
491 const nn::HalCacheToken& token,
492 const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
493 return mLatestDriver->prepareModel_1_2(model, preference, modelCache, dataCache, token,
494 actualCallback);
495 }
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)496 hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
497 const V1_1::Model& model, V1_1::ExecutionPreference preference,
498 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
499 return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
500 }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)501 hardware::Return<V1_0::ErrorStatus> prepareModel(
502 const V1_0::Model& model,
503 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
504 return mLatestDriver->prepareModel(model, actualCallback);
505 }
getStatus()506 hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
getVersionString(getVersionString_cb _hidl_cb)507 hardware::Return<void> getVersionString(getVersionString_cb _hidl_cb) override {
508 return mLatestDriver->getVersionString(_hidl_cb);
509 }
getType(getType_cb _hidl_cb)510 hardware::Return<void> getType(getType_cb _hidl_cb) override {
511 return mLatestDriver->getType(_hidl_cb);
512 }
getSupportedExtensions(getSupportedExtensions_cb _hidl_cb)513 hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb _hidl_cb) {
514 return mLatestDriver->getSupportedExtensions(_hidl_cb);
515 }
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb)516 hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb) {
517 return mLatestDriver->getNumberOfCacheFilesNeeded(_hidl_cb);
518 }
prepareModelFromCache(const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & callback)519 hardware::Return<V1_0::ErrorStatus> prepareModelFromCache(
520 const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
521 const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
522 const nn::HalCacheToken& token, const sp<V1_2::IPreparedModelCallback>& callback) {
523 return mLatestDriver->prepareModelFromCache(modelCache, dataCache, token, callback);
524 }
525
526 private:
527 const sp<V1_3::IDevice> mLatestDriver;
528 };
529
// Like TestDriver, but implementing 1.1
// Every method forwards to an owned TestDriver13 instance.
class TestDriver11 : public V1_1::IDevice {
   public:
    TestDriver11(const std::string& name, V1_3::ErrorStatus errorStatus)
        : mLatestDriver(new TestDriver13(name, errorStatus)) {}
    hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities_1_1(_hidl_cb);
    }
    hardware::Return<void> getSupportedOperations_1_1(
            const V1_1::Model& model, getSupportedOperations_1_1_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
    }
    hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
            const V1_1::Model& model, V1_1::ExecutionPreference preference,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
    }
    hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
    hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities(_hidl_cb);
    }
    hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
                                                  getSupportedOperations_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations(model, _hidl_cb);
    }
    hardware::Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel(model, actualCallback);
    }

   private:
    // The delegate implementing all behavior.
    const sp<V1_3::IDevice> mLatestDriver;
};
564
// Like TestDriver, but implementing 1.0
// Every method forwards to an owned TestDriver13 instance.
class TestDriver10 : public V1_0::IDevice {
   public:
    TestDriver10(const std::string& name, V1_3::ErrorStatus errorStatus)
        : mLatestDriver(new TestDriver13(name, errorStatus)) {}
    hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities(_hidl_cb);
    }
    hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
                                                  getSupportedOperations_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations(model, _hidl_cb);
    }
    hardware::Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel(model, actualCallback);
    }
    hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }

   private:
    // The delegate implementing all behavior.
    const sp<V1_3::IDevice> mLatestDriver;
};
587
588 // This class adds some simple utilities on top of WrapperCompilation in order
589 // to provide access to certain features from CompilationBuilder that are not
590 // exposed by the base class.
591 template <typename DriverClass>
592 class TestCompilation : public WrapperCompilation {
593 public:
594 // Allow dummying up the error status for all executions from this
595 // compilation. If errorStatus is NONE, then execute behaves
596 // normally (and sends back the actual execution status).
597 // Otherwise, don't bother to execute, and just send back
598 // errorStatus (as the execution status, not the launch status).
TestCompilation(const WrapperModel * model,const std::string & deviceName,V1_3::ErrorStatus errorStatus)599 TestCompilation(const WrapperModel* model, const std::string& deviceName,
600 V1_3::ErrorStatus errorStatus) {
601 std::vector<std::shared_ptr<Device>> devices;
602 auto device = DeviceManager::forTest_makeDriverDevice(
603 nn::makeSharedDevice(deviceName, new DriverClass(deviceName, errorStatus)));
604 devices.push_back(device);
605
606 nn::ModelBuilder* m = reinterpret_cast<nn::ModelBuilder*>(model->getHandle());
607 CompilationBuilder* c = nullptr;
608 int result = m->createCompilation(&c, devices);
609 EXPECT_EQ(result, 0);
610 // We need to ensure that we use our TestDriver and do not
611 // fall back to CPU. (If we allow CPU fallback, then when our
612 // TestDriver reports an execution failure, we'll re-execute
613 // on CPU, and will not see the failure.)
614 c->forTest_setPartitioning(DeviceManager::kPartitioningWithoutFallback);
615 mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
616 }
617 };
618
619 // This class has roughly the same functionality as TestCompilation class.
620 // The major difference is that Introspection API is used to select the device.
621 class TestIntrospectionCompilation : public WrapperCompilation {
622 public:
TestIntrospectionCompilation(const WrapperModel * model,const std::string & deviceName)623 TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) {
624 std::vector<ANeuralNetworksDevice*> mDevices;
625 uint32_t numDevices = 0;
626 EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
627 EXPECT_GE(numDevices, (uint32_t)1);
628
629 for (uint32_t i = 0; i < numDevices; i++) {
630 ANeuralNetworksDevice* device = nullptr;
631 EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
632 const char* buffer = nullptr;
633 int result = ANeuralNetworksDevice_getName(device, &buffer);
634 if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) {
635 mDevices.push_back(device);
636 }
637 }
638 // In CPU only mode, DeviceManager::getDrivers() will not be able to
639 // provide the actual device list. We will not be able to find the test
640 // driver with specified deviceName.
641 if (!DeviceManager::get()->getUseCpuOnly()) {
642 EXPECT_EQ(mDevices.size(), (uint32_t)1);
643
644 int result = ANeuralNetworksCompilation_createForDevices(
645 model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation);
646 EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
647 }
648 }
649 };
650
651 template <class DriverClass>
652 class ExecutionTestTemplate
653 : public ::testing::TestWithParam<std::tuple<V1_3::ErrorStatus, WrapperResult, bool>> {
654 public:
ExecutionTestTemplate()655 ExecutionTestTemplate()
656 : kName(toString(std::get<0>(GetParam()))),
657 kForceErrorStatus(std::get<0>(GetParam())),
658 kExpectResult(std::get<1>(GetParam())),
659 kUseIntrospectionAPI(std::get<2>(GetParam())),
660 mModel(makeModel()) {
661 if (kUseIntrospectionAPI) {
662 DeviceManager::get()->forTest_registerDevice(
663 nn::makeSharedDevice(kName, new DriverClass(kName.c_str(), kForceErrorStatus)));
664 mCompilation = TestIntrospectionCompilation(&mModel, kName);
665 } else {
666 mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus);
667 }
668 }
669
670 protected:
671 // Unit test method
672 // Set "reusable" to true to test reusable execution; Otherwise, test non-reusable execution.
673 void TestWait(bool reusable);
674
TearDown()675 virtual void TearDown() {
676 // Reinitialize the device list since Introspection API path altered it.
677 if (kUseIntrospectionAPI) {
678 DeviceManager::get()->forTest_reInitializeDeviceList();
679 }
680 }
681
getDimensionsWhileRunning(WrapperExecution & execution)682 void getDimensionsWhileRunning(WrapperExecution& execution) {
683 TestPreparedModelLatest::waitForExecutionToBegin();
684 // Cannot query dimensions while execution is running
685 std::vector<uint32_t> dimensions;
686 EXPECT_EQ(execution.getOutputOperandDimensions(0, &dimensions), WrapperResult::BAD_STATE);
687 }
688
689 const std::string kName;
690
691 // Allow dummying up the error status for execution. If
692 // kForceErrorStatus is NONE, then execution behaves normally (and
693 // sends back the actual execution status). Otherwise, don't
694 // bother to execute, and just send back kForceErrorStatus (as the
695 // execution status, not the launch status).
696 const V1_3::ErrorStatus kForceErrorStatus;
697
698 // What result do we expect from the execution? (The WrapperResult
699 // equivalent of kForceErrorStatus.)
700 const WrapperResult kExpectResult;
701
702 // Whether mCompilation is created via Introspection API or not.
703 const bool kUseIntrospectionAPI;
704
705 WrapperModel mModel;
706 WrapperCompilation mCompilation;
707
setInputOutput(WrapperExecution * execution)708 void setInputOutput(WrapperExecution* execution) {
709 mInputBuffer = kInputBuffer;
710 mOutputBuffer = kOutputBufferInitial;
711 ASSERT_EQ(execution->setInput(0, &mInputBuffer, sizeof(mInputBuffer)),
712 WrapperResult::NO_ERROR);
713 ASSERT_EQ(execution->setOutput(0, &mOutputBuffer, sizeof(mOutputBuffer)),
714 WrapperResult::NO_ERROR);
715 }
716
717 const float kInputBuffer = 3.14;
718 const float kOutputBufferInitial = 0;
719 float mInputBuffer;
720 float mOutputBuffer;
721 const float kOutputBufferExpected = 3;
722 const std::vector<uint32_t> kOutputDimensionsExpected = {1};
723
724 private:
makeModel()725 static WrapperModel makeModel() {
726 static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, {1});
727
728 WrapperModel model;
729 uint32_t input = model.addOperand(&tensorType);
730 uint32_t output = model.addOperand(&tensorType);
731 model.addOperation(ANEURALNETWORKS_FLOOR, {input}, {output});
732 model.identifyInputsAndOutputs({input}, {output});
733 assert(model.finish() == WrapperResult::NO_ERROR);
734
735 return model;
736 }
737 };
738
computeHelper(bool reusable,const std::function<void ()> & compute)739 void computeHelper(bool reusable, const std::function<void()>& compute) {
740 {
741 SCOPED_TRACE(reusable ? "first time reusable" : "non-reusable");
742 compute();
743 }
744 if (reusable) {
745 SCOPED_TRACE("second time reusable");
746 compute();
747 }
748 }
749
750 template <class DriverClass>
TestWait(bool reusable)751 void ExecutionTestTemplate<DriverClass>::TestWait(bool reusable) {
752 SCOPED_TRACE(kName);
753 // Skip Introspection API tests when CPU only flag is forced on.
754 if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
755 GTEST_SKIP();
756 }
757
758 ASSERT_EQ(mCompilation.finish(), WrapperResult::NO_ERROR);
759
760 {
761 SCOPED_TRACE("startCompute");
762 WrapperExecution execution(&mCompilation);
763 ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
764 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
765 const auto compute = [this, &execution] {
766 TestPreparedModelLatest::pauseExecutions(true);
767 WrapperEvent event;
768 ASSERT_EQ(execution.startCompute(&event), WrapperResult::NO_ERROR);
769 getDimensionsWhileRunning(execution);
770 TestPreparedModelLatest::pauseExecutions(false);
771 ASSERT_EQ(event.wait(), kExpectResult);
772 if (kExpectResult == WrapperResult::NO_ERROR) {
773 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
774 }
775 std::vector<uint32_t> dimensions;
776 if (kExpectResult == WrapperResult::NO_ERROR ||
777 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
778 // Only one output operand, hardcoded as index 0.
779 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
780 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
781 } else {
782 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
783 WrapperResult::BAD_STATE);
784 }
785 };
786 computeHelper(reusable, compute);
787 }
788 {
789 SCOPED_TRACE("compute");
790 WrapperExecution execution(&mCompilation);
791 ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
792 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
793 const auto compute = [this, &execution] {
794 TestPreparedModelLatest::pauseExecutions(true);
795 std::thread run([this, &execution] { EXPECT_EQ(execution.compute(), kExpectResult); });
796 getDimensionsWhileRunning(execution);
797 TestPreparedModelLatest::pauseExecutions(false);
798 run.join();
799 if (kExpectResult == WrapperResult::NO_ERROR) {
800 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
801 }
802 std::vector<uint32_t> dimensions;
803 if (kExpectResult == WrapperResult::NO_ERROR ||
804 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
805 // Only one output operand, hardcoded as index 0.
806 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
807 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
808 } else {
809 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
810 WrapperResult::BAD_STATE);
811 }
812 };
813 computeHelper(reusable, compute);
814 }
815 {
816 SCOPED_TRACE("burstCompute");
817
818 // TODO: If a burst API is added to nn::test_wrapper (e.g.,
819 // Execution::burstCompute()), then use that, rather than
820 // Execution::compute(WrapperExecution::ComputeMode::BURST).
821
822 WrapperExecution execution(&mCompilation);
823 ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
824 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
825 const auto compute = [this, &execution] {
826 TestPreparedModelLatest::pauseExecutions(true);
827 std::thread run([this, &execution] {
828 EXPECT_EQ(execution.compute(WrapperExecution::ComputeMode::BURST), kExpectResult);
829 });
830 getDimensionsWhileRunning(execution);
831 TestPreparedModelLatest::pauseExecutions(false);
832 run.join();
833 if (kExpectResult == WrapperResult::NO_ERROR) {
834 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
835 }
836 std::vector<uint32_t> dimensions;
837 if (kExpectResult == WrapperResult::NO_ERROR ||
838 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
839 // Only one output operand, hardcoded as index 0.
840 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
841 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
842 } else {
843 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
844 WrapperResult::BAD_STATE);
845 }
846 };
847 computeHelper(reusable, compute);
848 }
849 if (kExpectResult != WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
850 // computeWithDependencies doesn't support OUTPUT_INSUFFICIENT_SIZE
851 SCOPED_TRACE("computeWithDependencies");
852 WrapperExecution execution(&mCompilation);
853 ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
854 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
855
856 const auto compute = [this, &execution] {
857 TestPreparedModelLatest::pauseExecutions(true);
858
859 WrapperEvent event;
860 // Note, due to the limitation of SampleDriver implementation, the call is synchronous.
861 // If the SampleDriver is updated to return real sync fence, this must be updated.
862 std::thread run([this, &execution, &event] {
863 EXPECT_EQ(execution.startComputeWithDependencies({}, 0, &event), kExpectResult);
864 });
865 getDimensionsWhileRunning(execution);
866 TestPreparedModelLatest::pauseExecutions(false);
867 run.join();
868 if (kExpectResult == WrapperResult::NO_ERROR) {
869 ASSERT_EQ(event.wait(), kExpectResult);
870 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
871 } else {
872 ASSERT_EQ(event.wait(), WrapperResult::UNEXPECTED_NULL);
873 }
874 std::vector<uint32_t> dimensions;
875 if (kExpectResult == WrapperResult::NO_ERROR) {
876 // Only one output operand, hardcoded as index 0.
877 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
878 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
879 } else {
880 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
881 WrapperResult::BAD_STATE);
882 }
883 };
884 computeHelper(reusable, compute);
885 }
886 }
887
// Test parameter tuples: {error status forced on the test driver,
// the WrapperResult the client is expected to observe, whether the
// compilation is created via the Introspection API (false here)}.
auto kTestValues = ::testing::Values(
        std::make_tuple(V1_3::ErrorStatus::NONE, WrapperResult::NO_ERROR,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, WrapperResult::UNAVAILABLE_DEVICE,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(V1_3::ErrorStatus::GENERAL_FAILURE, WrapperResult::OP_FAILED,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
                        WrapperResult::OUTPUT_INSUFFICIENT_SIZE,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(V1_3::ErrorStatus::INVALID_ARGUMENT, WrapperResult::BAD_DATA,
                        /* kUseIntrospectionAPI */ false));
900
// Runs the TestWait scenarios against the V1_3 test driver, for both
// single-use and reusable executions.
class ExecutionTest13 : public ExecutionTestTemplate<TestDriver13> {};
TEST_P(ExecutionTest13, Wait) {
    TestWait(/*reusable=*/false);
}
TEST_P(ExecutionTest13, WaitReusable) {
    TestWait(/*reusable=*/true);
}
INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest13, kTestValues);
909
// Runs the TestWait scenarios against the V1_2 test driver, for both
// single-use and reusable executions.
class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {};
TEST_P(ExecutionTest12, Wait) {
    TestWait(/*reusable=*/false);
}
TEST_P(ExecutionTest12, WaitReusable) {
    TestWait(/*reusable=*/true);
}
INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest12, kTestValues);
918
// Runs the TestWait scenarios against the V1_1 test driver. The
// OUTPUT_INSUFFICIENT_SIZE flavor is skipped — presumably because pre-1.2
// drivers cannot report that status; confirm against the HAL definitions.
class ExecutionTest11 : public ExecutionTestTemplate<TestDriver11> {};
TEST_P(ExecutionTest11, Wait) {
    if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait(/*reusable=*/false);
}
TEST_P(ExecutionTest11, WaitReusable) {
    if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait(/*reusable=*/true);
}
INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest11, kTestValues);
929
// Runs the TestWait scenarios against the V1_0 test driver. As with the
// V1_1 suite, the OUTPUT_INSUFFICIENT_SIZE flavor is skipped — presumably
// unsupported by pre-1.2 drivers; confirm against the HAL definitions.
class ExecutionTest10 : public ExecutionTestTemplate<TestDriver10> {};
TEST_P(ExecutionTest10, Wait) {
    if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait(/*reusable=*/false);
}
TEST_P(ExecutionTest10, WaitReusable) {
    if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait(/*reusable=*/true);
}
INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest10, kTestValues);
940
// Same parameter tuples as kTestValues, but with the Introspection API flag
// set so the device is selected by name via ANeuralNetworks_getDevice*.
// Only instantiated for the V1_3 suite.
auto kIntrospectionTestValues = ::testing::Values(
        std::make_tuple(V1_3::ErrorStatus::NONE, WrapperResult::NO_ERROR,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, WrapperResult::UNAVAILABLE_DEVICE,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(V1_3::ErrorStatus::GENERAL_FAILURE, WrapperResult::OP_FAILED,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
                        WrapperResult::OUTPUT_INSUFFICIENT_SIZE,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(V1_3::ErrorStatus::INVALID_ARGUMENT, WrapperResult::BAD_DATA,
                        /* kUseIntrospectionAPI */ true));

INSTANTIATE_TEST_SUITE_P(IntrospectionFlavor, ExecutionTest13, kIntrospectionTestValues);
955
956 } // namespace
957 } // namespace android
958