• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <HalInterfaces.h>
18 #include <SampleDriver.h>
19 #include <ValidateHal.h>
20 #include <gtest/gtest.h>
21 
22 #include <algorithm>
23 #include <atomic>
24 #include <cassert>
25 #include <memory>
26 #include <string>
27 #include <thread>
28 #include <tuple>
29 #include <vector>
30 
31 #include "CompilationBuilder.h"
32 #include "ExecutionBurstServer.h"
33 #include "ExecutionCallback.h"
34 #include "HalUtils.h"
35 #include "Manager.h"
36 #include "ModelBuilder.h"
37 #include "NeuralNetworks.h"
38 #include "PreparedModelCallback.h"
39 #include "TestNeuralNetworksWrapper.h"
40 
41 namespace android {
42 
43 namespace V1_0 = ::android::hardware::neuralnetworks::V1_0;
44 namespace V1_1 = ::android::hardware::neuralnetworks::V1_1;
45 namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;
46 namespace V1_3 = ::android::hardware::neuralnetworks::V1_3;
47 using CompilationBuilder = nn::CompilationBuilder;
48 using Device = nn::Device;
49 using SharedDevice = nn::SharedDevice;
50 using DeviceManager = nn::DeviceManager;
51 using HidlModel = V1_3::Model;
52 using PreparedModelCallback = nn::PreparedModelCallback;
53 using SampleDriver = nn::sample_driver::SampleDriver;
54 using WrapperCompilation = nn::test_wrapper::Compilation;
55 using WrapperEvent = nn::test_wrapper::Event;
56 using WrapperExecution = nn::test_wrapper::Execution;
57 using WrapperModel = nn::test_wrapper::Model;
58 using WrapperOperandType = nn::test_wrapper::OperandType;
59 using WrapperResult = nn::test_wrapper::Result;
60 using WrapperType = nn::test_wrapper::Type;
61 using nn::convertToV1_0;
62 using nn::convertToV1_3;
63 using nn::ErrorStatus;
64 
65 template <typename T>
66 using MQDescriptorSync = hardware::MQDescriptorSync<T>;
67 
68 namespace {
69 
70 const V1_2::Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
71 
72 // Wraps the latest version of IPreparedModel to allow dummying up the execution status,
73 // and control when the execution finishes.
74 class TestPreparedModelLatest : public V1_3::IPreparedModel {
75    public:
76     // If errorStatus is NONE, then execute behaves normally (and sends back
77     // the actual execution status).  Otherwise, don't bother to execute, and
78     // just send back errorStatus (as the execution status, not the launch
79     // status).
TestPreparedModelLatest(sp<V1_0::IPreparedModel> preparedModel,V1_3::ErrorStatus errorStatus)80     TestPreparedModelLatest(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
81         : mPreparedModelV1_0(preparedModel),
82           mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
83           mPreparedModelV1_3(V1_3::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
84           mErrorStatus(errorStatus) {}
85 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)86     hardware::Return<V1_0::ErrorStatus> execute(
87             const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
88         CHECK(mPreparedModelV1_0 != nullptr) << "V1_0 prepared model is nullptr.";
89         std::thread([this, request, callback] {
90             dummyExecution();
91             if (mErrorStatus == V1_3::ErrorStatus::NONE) {
92                 // Note that we lose the actual launch status.
93                 (void)mPreparedModelV1_0->execute(request, callback);
94             } else {
95                 callback->notify(convertToV1_0(mErrorStatus));
96             }
97         }).detach();
98         return V1_0::ErrorStatus::NONE;
99     }
100 
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)101     hardware::Return<V1_0::ErrorStatus> execute_1_2(
102             const V1_0::Request& request, V1_2::MeasureTiming measure,
103             const sp<V1_2::IExecutionCallback>& callback) override {
104         CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
105         std::thread([this, request, measure, callback] {
106             dummyExecution();
107             if (mErrorStatus == V1_3::ErrorStatus::NONE) {
108                 // Note that we lose the actual launch status.
109                 (void)mPreparedModelV1_2->execute_1_2(request, measure, callback);
110             } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
111                 V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
112                 callback->notify_1_2(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
113             } else {
114                 callback->notify_1_2(convertToV1_0(mErrorStatus), {}, kBadTiming);
115             }
116         }).detach();
117         return V1_0::ErrorStatus::NONE;
118     }
119 
execute_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)120     hardware::Return<V1_3::ErrorStatus> execute_1_3(
121             const V1_3::Request& request, V1_2::MeasureTiming measure,
122             const V1_3::OptionalTimePoint& deadline,
123             const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
124             const sp<V1_3::IExecutionCallback>& callback) override {
125         CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
126         std::thread([this, request, measure, deadline, loopTimeoutDuration, callback] {
127             dummyExecution();
128             if (mErrorStatus == V1_3::ErrorStatus::NONE) {
129                 // Note that we lose the actual launch status.
130                 (void)mPreparedModelV1_3->execute_1_3(request, measure, deadline,
131                                                       loopTimeoutDuration, callback);
132             } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
133                 V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
134                 callback->notify_1_3(mErrorStatus, {shape}, kBadTiming);
135             } else {
136                 callback->notify_1_3(mErrorStatus, {}, kBadTiming);
137             }
138         }).detach();
139         return V1_3::ErrorStatus::NONE;
140     }
141 
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)142     hardware::Return<void> executeSynchronously(const V1_0::Request& request,
143                                                 V1_2::MeasureTiming measure,
144                                                 executeSynchronously_cb cb) override {
145         CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
146         dummyExecution();
147         if (mErrorStatus == V1_3::ErrorStatus::NONE) {
148             return mPreparedModelV1_2->executeSynchronously(request, measure, cb);
149         } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
150             V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
151             cb(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
152             return hardware::Void();
153         } else {
154             cb(convertToV1_0(mErrorStatus), {}, kBadTiming);
155             return hardware::Void();
156         }
157     }
158 
executeSynchronously_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)159     hardware::Return<void> executeSynchronously_1_3(
160             const V1_3::Request& request, V1_2::MeasureTiming measure,
161             const V1_3::OptionalTimePoint& deadline,
162             const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
163             executeSynchronously_1_3_cb cb) override {
164         CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
165         dummyExecution();
166         if (mErrorStatus == V1_3::ErrorStatus::NONE) {
167             return mPreparedModelV1_3->executeSynchronously_1_3(request, measure, deadline,
168                                                                 loopTimeoutDuration, cb);
169         } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
170             V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
171             cb(mErrorStatus, {shape}, kBadTiming);
172             return hardware::Void();
173         } else {
174             cb(mErrorStatus, {}, kBadTiming);
175             return hardware::Void();
176         }
177     }
178 
179     // ExecutionBurstServer::create has an overload that will use
180     // IPreparedModel::executeSynchronously(), so we can rely on that, rather
181     // than having to implement ExecutionBurstServer::IExecutorWithCache.
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)182     hardware::Return<void> configureExecutionBurst(
183             const sp<V1_2::IBurstCallback>& callback,
184             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
185             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
186             configureExecutionBurst_cb cb) override {
187         CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
188         if (mErrorStatus == V1_3::ErrorStatus::NONE) {
189             const sp<V1_2::IBurstContext> burst =
190                     nn::ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);
191 
192             cb(burst == nullptr ? V1_0::ErrorStatus::GENERAL_FAILURE : V1_0::ErrorStatus::NONE,
193                burst);
194             return hardware::Void();
195         } else {
196             cb(convertToV1_0(mErrorStatus), nullptr);
197             return hardware::Void();
198         }
199     }
200 
201     // Note, due to the limitation of SampleDriver implementation, the call is
202     // synchronous.  The test code that exercises this implementation of
203     // SampleDriver is written with that in mind.  Therefore, this
204     // implementation is synchronous also.  If the SampleDriver is updated to
205     // return real sync fence, this must be updated.
executeFenced(const V1_3::Request & request,const hardware::hidl_vec<hardware::hidl_handle> & waitFor,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const V1_3::OptionalTimeoutDuration & duration,executeFenced_cb cb)206     hardware::Return<void> executeFenced(const V1_3::Request& request,
207                                          const hardware::hidl_vec<hardware::hidl_handle>& waitFor,
208                                          V1_2::MeasureTiming measure,
209                                          const V1_3::OptionalTimePoint& deadline,
210                                          const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
211                                          const V1_3::OptionalTimeoutDuration& duration,
212                                          executeFenced_cb cb) override {
213         CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
214         CHECK(mErrorStatus != V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE)
215                 << "executeFenced does not support dynamic output shape";
216         dummyExecution();
217         if (mErrorStatus == V1_3::ErrorStatus::NONE) {
218             return mPreparedModelV1_3->executeFenced(request, waitFor, measure, deadline,
219                                                      loopTimeoutDuration, duration, cb);
220         } else {
221             // Due to the limitations of the SampleDriver, all failures look
222             // like launch failures.  If the SampleDriver is updated to return
223             // real sync fences, this must be updated.
224             cb(mErrorStatus, hardware::hidl_handle(nullptr), nullptr);
225         }
226         return hardware::Void();
227     }
228 
229     // We can place the TestPreparedModelLatest system in a "pause" mode where
230     // no execution will complete until the system is taken out of that mode.
231     // Initially, the system is not in that mode.
pauseExecutions(bool v)232     static void pauseExecutions(bool v) { mPauseExecutions.store(v); }
233 
234     // This function is only guaranteed to work in the following pattern:
235     // Consider thread A as primary thread
236     // - thread A: pauseExecutions(true);
237     // - thread A: launch execution (as thread B)
238     // - thread A: waitForExecutionToBegin(), block until call to dummyExecution by
239     //                                        thread B makes mExecutionsInFlight nonzero
240     // - thread B: dummyExecution(), which makes mExecutionsInFlight nonzero and blocks
241     //                               until thread A calls pauseExecutions(false)
242     // - thread A: waitForExecutionToBegin() returns
243     // - thread A: pauseExecutions(false), allowing dummyExecution() on thread B to continue
244     // - thread B: dummyExecution() zeroes mExecutionsInFlight and returns
245     // - thread B: thread exits
waitForExecutionToBegin()246     static void waitForExecutionToBegin() {
247         CHECK(mPauseExecutions.load());
248         while (mExecutionsInFlight.load() == 0) {
249         }
250     }
251 
252    private:
253     const sp<V1_0::IPreparedModel> mPreparedModelV1_0;
254     const sp<V1_2::IPreparedModel> mPreparedModelV1_2;
255     const sp<V1_3::IPreparedModel> mPreparedModelV1_3;
256     V1_3::ErrorStatus mErrorStatus;
257 
258     static std::atomic<bool> mPauseExecutions;
259     static std::atomic<unsigned int> mExecutionsInFlight;
260 
dummyExecution()261     static void dummyExecution() {
262         CHECK_EQ(mExecutionsInFlight.fetch_add(1), 0u) << "We do not support concurrent executions";
263         while (mPauseExecutions.load()) {
264         }
265         mExecutionsInFlight.fetch_sub(1);
266     }
267 };
268 std::atomic<bool> TestPreparedModelLatest::mPauseExecutions = false;
269 std::atomic<unsigned int> TestPreparedModelLatest::mExecutionsInFlight = 0;
270 
271 using TestPreparedModel13 = TestPreparedModelLatest;
272 
273 // Like TestPreparedModelLatest, but implementing 1.2
274 class TestPreparedModel12 : public V1_2::IPreparedModel {
275    public:
TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel,V1_3::ErrorStatus errorStatus)276     TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
277         : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}
278 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)279     hardware::Return<V1_0::ErrorStatus> execute(
280             const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
281         return mLatestPreparedModel->execute(request, callback);
282     }
283 
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)284     hardware::Return<V1_0::ErrorStatus> execute_1_2(
285             const V1_0::Request& request, V1_2::MeasureTiming measure,
286             const sp<V1_2::IExecutionCallback>& callback) override {
287         return mLatestPreparedModel->execute_1_2(request, measure, callback);
288     }
289 
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)290     hardware::Return<void> executeSynchronously(const V1_0::Request& request,
291                                                 V1_2::MeasureTiming measure,
292                                                 executeSynchronously_cb cb) override {
293         return mLatestPreparedModel->executeSynchronously(request, measure, cb);
294     }
295 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)296     hardware::Return<void> configureExecutionBurst(
297             const sp<V1_2::IBurstCallback>& callback,
298             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
299             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
300             configureExecutionBurst_cb cb) override {
301         return mLatestPreparedModel->configureExecutionBurst(callback, requestChannel,
302                                                              resultChannel, cb);
303     }
304 
305    private:
306     const sp<V1_3::IPreparedModel> mLatestPreparedModel;
307 };
308 
309 // Like TestPreparedModelLatest, but implementing 1.0
310 class TestPreparedModel10 : public V1_0::IPreparedModel {
311    public:
TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel,V1_3::ErrorStatus errorStatus)312     TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
313         : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}
314 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)315     hardware::Return<V1_0::ErrorStatus> execute(
316             const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
317         return mLatestPreparedModel->execute(request, callback);
318     }
319 
320    private:
321     const sp<V1_3::IPreparedModel> mLatestPreparedModel;
322 };
323 
324 // Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
325 class TestDriver13 : public SampleDriver {
326    public:
327     // Allow dummying up the error status for execution of all models
328     // prepared from this driver.  If errorStatus is NONE, then
329     // execute behaves normally (and sends back the actual execution
330     // status). Otherwise, don't bother to execute, and just send
331     // back errorStatus (as the execution status, not the launch
332     // status).
TestDriver13(const std::string & name,V1_3::ErrorStatus errorStatus)333     TestDriver13(const std::string& name, V1_3::ErrorStatus errorStatus)
334         : SampleDriver(name.c_str()), mErrorStatus(errorStatus) {}
335 
getCapabilities_1_3(getCapabilities_1_3_cb _hidl_cb)336     hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb _hidl_cb) override {
337         android::nn::initVLogMask();
338         const V1_0::PerformanceInfo kPerf = {.execTime = 0.75f, .powerUsage = 0.75f};
339         V1_3::Capabilities capabilities = {
340                 .relaxedFloat32toFloat16PerformanceScalar = kPerf,
341                 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
342                 .operandPerformance =
343                         nn::nonExtensionOperandPerformance<nn::HalVersion::V1_3>(kPerf),
344                 .ifPerformance = kPerf,
345                 .whilePerformance = kPerf};
346         _hidl_cb(V1_3::ErrorStatus::NONE, capabilities);
347         return hardware::Void();
348     }
349 
getSupportedOperations_1_3(const HidlModel & model,getSupportedOperations_1_3_cb cb)350     hardware::Return<void> getSupportedOperations_1_3(const HidlModel& model,
351                                                       getSupportedOperations_1_3_cb cb) override {
352         if (nn::validateModel(model)) {
353             std::vector<bool> supported(model.main.operations.size(), true);
354             cb(V1_3::ErrorStatus::NONE, supported);
355         } else {
356             cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
357         }
358         return hardware::Void();
359     }
360 
prepareModel_1_3(const HidlModel & model,V1_1::ExecutionPreference preference,V1_3::Priority priority,const V1_3::OptionalTimePoint & deadline,const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_3::IPreparedModelCallback> & actualCallback)361     hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
362             const HidlModel& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
363             const V1_3::OptionalTimePoint& deadline,
364             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
365             const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
366             const nn::HalCacheToken& token,
367             const sp<V1_3::IPreparedModelCallback>& actualCallback) override {
368         sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
369         hardware::Return<V1_3::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_3(
370                 model, preference, priority, deadline, modelCache, dataCache, token, localCallback);
371         if (!prepareModelReturn.isOkUnchecked()) {
372             return prepareModelReturn;
373         }
374         if (prepareModelReturn != V1_3::ErrorStatus::NONE) {
375             actualCallback->notify_1_3(
376                     convertToV1_3(localCallback->getStatus()),
377                     V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
378             return prepareModelReturn;
379         }
380         localCallback->wait();
381         if (localCallback->getStatus() != ErrorStatus::NONE) {
382             actualCallback->notify_1_3(
383                     convertToV1_3(localCallback->getStatus()),
384                     V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
385         } else {
386             actualCallback->notify_1_3(
387                     V1_3::ErrorStatus::NONE,
388                     new TestPreparedModel13(localCallback->getPreparedModel(), mErrorStatus));
389         }
390         return prepareModelReturn;
391     }
392 
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & actualCallback)393     hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
394             const V1_2::Model& model, V1_1::ExecutionPreference preference,
395             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
396             const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
397             const nn::HalCacheToken& token,
398             const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
399         sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
400         hardware::Return<V1_0::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_2(
401                 model, preference, modelCache, dataCache, token, localCallback);
402         if (!prepareModelReturn.isOkUnchecked()) {
403             return prepareModelReturn;
404         }
405         if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
406             actualCallback->notify_1_2(
407                     convertToV1_0(localCallback->getStatus()),
408                     V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
409             return prepareModelReturn;
410         }
411         localCallback->wait();
412         if (localCallback->getStatus() != ErrorStatus::NONE) {
413             actualCallback->notify_1_2(
414                     convertToV1_0(localCallback->getStatus()),
415                     V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
416         } else {
417             actualCallback->notify_1_2(
418                     V1_0::ErrorStatus::NONE,
419                     new TestPreparedModel12(localCallback->getPreparedModel(), mErrorStatus));
420         }
421         return prepareModelReturn;
422     }
423 
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)424     hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
425             const V1_1::Model& model, V1_1::ExecutionPreference preference,
426             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
427         sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
428         hardware::Return<V1_0::ErrorStatus> prepareModelReturn =
429                 SampleDriver::prepareModel_1_1(model, preference, localCallback);
430         if (!prepareModelReturn.isOkUnchecked()) {
431             return prepareModelReturn;
432         }
433         if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
434             actualCallback->notify(convertToV1_0(localCallback->getStatus()),
435                                    localCallback->getPreparedModel());
436             return prepareModelReturn;
437         }
438         localCallback->wait();
439         if (localCallback->getStatus() != ErrorStatus::NONE) {
440             actualCallback->notify(convertToV1_0(localCallback->getStatus()),
441                                    localCallback->getPreparedModel());
442         } else {
443             actualCallback->notify(
444                     V1_0::ErrorStatus::NONE,
445                     new TestPreparedModel10(localCallback->getPreparedModel(), mErrorStatus));
446         }
447         return prepareModelReturn;
448     }
449 
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)450     hardware::Return<V1_0::ErrorStatus> prepareModel(
451             const V1_0::Model& model,
452             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
453         return prepareModel_1_1(nn::convertToV1_1(model),
454                                 V1_1::ExecutionPreference::FAST_SINGLE_ANSWER, actualCallback);
455     }
456 
457    private:
458     V1_3::ErrorStatus mErrorStatus;
459 };
460 
461 // Like TestDriver, but implementing 1.2
462 class TestDriver12 : public V1_2::IDevice {
463    public:
TestDriver12(const std::string & name,V1_3::ErrorStatus errorStatus)464     TestDriver12(const std::string& name, V1_3::ErrorStatus errorStatus)
465         : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb)466     hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override {
467         return mLatestDriver->getCapabilities_1_2(_hidl_cb);
468     }
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)469     hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
470         return mLatestDriver->getCapabilities_1_1(_hidl_cb);
471     }
getCapabilities(getCapabilities_cb _hidl_cb)472     hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
473         return mLatestDriver->getCapabilities(_hidl_cb);
474     }
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb _hidl_cb)475     hardware::Return<void> getSupportedOperations_1_2(
476             const V1_2::Model& model, getSupportedOperations_1_2_cb _hidl_cb) override {
477         return mLatestDriver->getSupportedOperations_1_2(model, _hidl_cb);
478     }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)479     hardware::Return<void> getSupportedOperations_1_1(
480             const V1_1::Model& model, getSupportedOperations_1_1_cb _hidl_cb) override {
481         return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
482     }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)483     hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
484                                                   getSupportedOperations_cb _hidl_cb) override {
485         return mLatestDriver->getSupportedOperations(model, _hidl_cb);
486     }
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & actualCallback)487     hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
488             const V1_2::Model& model, V1_1::ExecutionPreference preference,
489             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
490             const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
491             const nn::HalCacheToken& token,
492             const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
493         return mLatestDriver->prepareModel_1_2(model, preference, modelCache, dataCache, token,
494                                                actualCallback);
495     }
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)496     hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
497             const V1_1::Model& model, V1_1::ExecutionPreference preference,
498             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
499         return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
500     }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)501     hardware::Return<V1_0::ErrorStatus> prepareModel(
502             const V1_0::Model& model,
503             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
504         return mLatestDriver->prepareModel(model, actualCallback);
505     }
getStatus()506     hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
getVersionString(getVersionString_cb _hidl_cb)507     hardware::Return<void> getVersionString(getVersionString_cb _hidl_cb) override {
508         return mLatestDriver->getVersionString(_hidl_cb);
509     }
getType(getType_cb _hidl_cb)510     hardware::Return<void> getType(getType_cb _hidl_cb) override {
511         return mLatestDriver->getType(_hidl_cb);
512     }
getSupportedExtensions(getSupportedExtensions_cb _hidl_cb)513     hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb _hidl_cb) {
514         return mLatestDriver->getSupportedExtensions(_hidl_cb);
515     }
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb)516     hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb) {
517         return mLatestDriver->getNumberOfCacheFilesNeeded(_hidl_cb);
518     }
// Prepares a model from compilation cache handles. All arguments are forwarded
// unchanged to the wrapped V1_3 driver.
// Marked `override` like the sibling forwarders in this class. That lets the
// compiler check the signature against the HAL interface, and it avoids
// clang's -Winconsistent-missing-override diagnostic.
hardware::Return<V1_0::ErrorStatus> prepareModelFromCache(
        const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
        const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
        const nn::HalCacheToken& token,
        const sp<V1_2::IPreparedModelCallback>& callback) override {
    return mLatestDriver->prepareModelFromCache(modelCache, dataCache, token, callback);
}
525 
526    private:
527     const sp<V1_3::IDevice> mLatestDriver;
528 };
529 
530 // Like TestDriver, but implementing 1.1
531 class TestDriver11 : public V1_1::IDevice {
532    public:
TestDriver11(const std::string & name,V1_3::ErrorStatus errorStatus)533     TestDriver11(const std::string& name, V1_3::ErrorStatus errorStatus)
534         : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)535     hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
536         return mLatestDriver->getCapabilities_1_1(_hidl_cb);
537     }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)538     hardware::Return<void> getSupportedOperations_1_1(
539             const V1_1::Model& model, getSupportedOperations_1_1_cb _hidl_cb) override {
540         return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
541     }
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)542     hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
543             const V1_1::Model& model, V1_1::ExecutionPreference preference,
544             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
545         return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
546     }
getStatus()547     hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
getCapabilities(getCapabilities_cb _hidl_cb)548     hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
549         return mLatestDriver->getCapabilities(_hidl_cb);
550     }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)551     hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
552                                                   getSupportedOperations_cb _hidl_cb) override {
553         return mLatestDriver->getSupportedOperations(model, _hidl_cb);
554     }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)555     hardware::Return<V1_0::ErrorStatus> prepareModel(
556             const V1_0::Model& model,
557             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
558         return mLatestDriver->prepareModel(model, actualCallback);
559     }
560 
561    private:
562     const sp<V1_3::IDevice> mLatestDriver;
563 };
564 
565 // Like TestDriver, but implementing 1.0
566 class TestDriver10 : public V1_0::IDevice {
567    public:
TestDriver10(const std::string & name,V1_3::ErrorStatus errorStatus)568     TestDriver10(const std::string& name, V1_3::ErrorStatus errorStatus)
569         : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities(getCapabilities_cb _hidl_cb)570     hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
571         return mLatestDriver->getCapabilities(_hidl_cb);
572     }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)573     hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
574                                                   getSupportedOperations_cb _hidl_cb) override {
575         return mLatestDriver->getSupportedOperations(model, _hidl_cb);
576     }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)577     hardware::Return<V1_0::ErrorStatus> prepareModel(
578             const V1_0::Model& model,
579             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
580         return mLatestDriver->prepareModel(model, actualCallback);
581     }
getStatus()582     hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
583 
584    private:
585     const sp<V1_3::IDevice> mLatestDriver;
586 };
587 
588 // This class adds some simple utilities on top of WrapperCompilation in order
589 // to provide access to certain features from CompilationBuilder that are not
590 // exposed by the base class.
591 template <typename DriverClass>
592 class TestCompilation : public WrapperCompilation {
593    public:
594     // Allow dummying up the error status for all executions from this
595     // compilation.  If errorStatus is NONE, then execute behaves
596     // normally (and sends back the actual execution status).
597     // Otherwise, don't bother to execute, and just send back
598     // errorStatus (as the execution status, not the launch status).
TestCompilation(const WrapperModel * model,const std::string & deviceName,V1_3::ErrorStatus errorStatus)599     TestCompilation(const WrapperModel* model, const std::string& deviceName,
600                     V1_3::ErrorStatus errorStatus) {
601         std::vector<std::shared_ptr<Device>> devices;
602         auto device = DeviceManager::forTest_makeDriverDevice(
603                 nn::makeSharedDevice(deviceName, new DriverClass(deviceName, errorStatus)));
604         devices.push_back(device);
605 
606         nn::ModelBuilder* m = reinterpret_cast<nn::ModelBuilder*>(model->getHandle());
607         CompilationBuilder* c = nullptr;
608         int result = m->createCompilation(&c, devices);
609         EXPECT_EQ(result, 0);
610         // We need to ensure that we use our TestDriver and do not
611         // fall back to CPU.  (If we allow CPU fallback, then when our
612         // TestDriver reports an execution failure, we'll re-execute
613         // on CPU, and will not see the failure.)
614         c->forTest_setPartitioning(DeviceManager::kPartitioningWithoutFallback);
615         mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
616     }
617 };
618 
619 // This class has roughly the same functionality as TestCompilation class.
620 // The major difference is that Introspection API is used to select the device.
621 class TestIntrospectionCompilation : public WrapperCompilation {
622    public:
TestIntrospectionCompilation(const WrapperModel * model,const std::string & deviceName)623     TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) {
624         std::vector<ANeuralNetworksDevice*> mDevices;
625         uint32_t numDevices = 0;
626         EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
627         EXPECT_GE(numDevices, (uint32_t)1);
628 
629         for (uint32_t i = 0; i < numDevices; i++) {
630             ANeuralNetworksDevice* device = nullptr;
631             EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
632             const char* buffer = nullptr;
633             int result = ANeuralNetworksDevice_getName(device, &buffer);
634             if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) {
635                 mDevices.push_back(device);
636             }
637         }
638         // In CPU only mode, DeviceManager::getDrivers() will not be able to
639         // provide the actual device list. We will not be able to find the test
640         // driver with specified deviceName.
641         if (!DeviceManager::get()->getUseCpuOnly()) {
642             EXPECT_EQ(mDevices.size(), (uint32_t)1);
643 
644             int result = ANeuralNetworksCompilation_createForDevices(
645                     model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation);
646             EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
647         }
648     }
649 };
650 
651 template <class DriverClass>
652 class ExecutionTestTemplate
653     : public ::testing::TestWithParam<std::tuple<V1_3::ErrorStatus, WrapperResult, bool>> {
654    public:
ExecutionTestTemplate()655     ExecutionTestTemplate()
656         : kName(toString(std::get<0>(GetParam()))),
657           kForceErrorStatus(std::get<0>(GetParam())),
658           kExpectResult(std::get<1>(GetParam())),
659           kUseIntrospectionAPI(std::get<2>(GetParam())),
660           mModel(makeModel()) {
661         if (kUseIntrospectionAPI) {
662             DeviceManager::get()->forTest_registerDevice(
663                     nn::makeSharedDevice(kName, new DriverClass(kName.c_str(), kForceErrorStatus)));
664             mCompilation = TestIntrospectionCompilation(&mModel, kName);
665         } else {
666             mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus);
667         }
668     }
669 
670    protected:
671     // Unit test method
672     // Set "reusable" to true to test reusable execution; Otherwise, test non-reusable execution.
673     void TestWait(bool reusable);
674 
TearDown()675     virtual void TearDown() {
676         // Reinitialize the device list since Introspection API path altered it.
677         if (kUseIntrospectionAPI) {
678             DeviceManager::get()->forTest_reInitializeDeviceList();
679         }
680     }
681 
getDimensionsWhileRunning(WrapperExecution & execution)682     void getDimensionsWhileRunning(WrapperExecution& execution) {
683         TestPreparedModelLatest::waitForExecutionToBegin();
684         // Cannot query dimensions while execution is running
685         std::vector<uint32_t> dimensions;
686         EXPECT_EQ(execution.getOutputOperandDimensions(0, &dimensions), WrapperResult::BAD_STATE);
687     }
688 
689     const std::string kName;
690 
691     // Allow dummying up the error status for execution.  If
692     // kForceErrorStatus is NONE, then execution behaves normally (and
693     // sends back the actual execution status).  Otherwise, don't
694     // bother to execute, and just send back kForceErrorStatus (as the
695     // execution status, not the launch status).
696     const V1_3::ErrorStatus kForceErrorStatus;
697 
698     // What result do we expect from the execution?  (The WrapperResult
699     // equivalent of kForceErrorStatus.)
700     const WrapperResult kExpectResult;
701 
702     // Whether mCompilation is created via Introspection API or not.
703     const bool kUseIntrospectionAPI;
704 
705     WrapperModel mModel;
706     WrapperCompilation mCompilation;
707 
setInputOutput(WrapperExecution * execution)708     void setInputOutput(WrapperExecution* execution) {
709         mInputBuffer = kInputBuffer;
710         mOutputBuffer = kOutputBufferInitial;
711         ASSERT_EQ(execution->setInput(0, &mInputBuffer, sizeof(mInputBuffer)),
712                   WrapperResult::NO_ERROR);
713         ASSERT_EQ(execution->setOutput(0, &mOutputBuffer, sizeof(mOutputBuffer)),
714                   WrapperResult::NO_ERROR);
715     }
716 
717     const float kInputBuffer = 3.14;
718     const float kOutputBufferInitial = 0;
719     float mInputBuffer;
720     float mOutputBuffer;
721     const float kOutputBufferExpected = 3;
722     const std::vector<uint32_t> kOutputDimensionsExpected = {1};
723 
724    private:
makeModel()725     static WrapperModel makeModel() {
726         static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, {1});
727 
728         WrapperModel model;
729         uint32_t input = model.addOperand(&tensorType);
730         uint32_t output = model.addOperand(&tensorType);
731         model.addOperation(ANEURALNETWORKS_FLOOR, {input}, {output});
732         model.identifyInputsAndOutputs({input}, {output});
733         assert(model.finish() == WrapperResult::NO_ERROR);
734 
735         return model;
736     }
737 };
738 
computeHelper(bool reusable,const std::function<void ()> & compute)739 void computeHelper(bool reusable, const std::function<void()>& compute) {
740     {
741         SCOPED_TRACE(reusable ? "first time reusable" : "non-reusable");
742         compute();
743     }
744     if (reusable) {
745         SCOPED_TRACE("second time reusable");
746         compute();
747     }
748 }
749 
750 template <class DriverClass>
TestWait(bool reusable)751 void ExecutionTestTemplate<DriverClass>::TestWait(bool reusable) {
752     SCOPED_TRACE(kName);
753     // Skip Introspection API tests when CPU only flag is forced on.
754     if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
755         GTEST_SKIP();
756     }
757 
758     ASSERT_EQ(mCompilation.finish(), WrapperResult::NO_ERROR);
759 
760     {
761         SCOPED_TRACE("startCompute");
762         WrapperExecution execution(&mCompilation);
763         ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
764         ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
765         const auto compute = [this, &execution] {
766             TestPreparedModelLatest::pauseExecutions(true);
767             WrapperEvent event;
768             ASSERT_EQ(execution.startCompute(&event), WrapperResult::NO_ERROR);
769             getDimensionsWhileRunning(execution);
770             TestPreparedModelLatest::pauseExecutions(false);
771             ASSERT_EQ(event.wait(), kExpectResult);
772             if (kExpectResult == WrapperResult::NO_ERROR) {
773                 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
774             }
775             std::vector<uint32_t> dimensions;
776             if (kExpectResult == WrapperResult::NO_ERROR ||
777                 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
778                 // Only one output operand, hardcoded as index 0.
779                 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
780                 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
781             } else {
782                 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
783                           WrapperResult::BAD_STATE);
784             }
785         };
786         computeHelper(reusable, compute);
787     }
788     {
789         SCOPED_TRACE("compute");
790         WrapperExecution execution(&mCompilation);
791         ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
792         ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
793         const auto compute = [this, &execution] {
794             TestPreparedModelLatest::pauseExecutions(true);
795             std::thread run([this, &execution] { EXPECT_EQ(execution.compute(), kExpectResult); });
796             getDimensionsWhileRunning(execution);
797             TestPreparedModelLatest::pauseExecutions(false);
798             run.join();
799             if (kExpectResult == WrapperResult::NO_ERROR) {
800                 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
801             }
802             std::vector<uint32_t> dimensions;
803             if (kExpectResult == WrapperResult::NO_ERROR ||
804                 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
805                 // Only one output operand, hardcoded as index 0.
806                 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
807                 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
808             } else {
809                 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
810                           WrapperResult::BAD_STATE);
811             }
812         };
813         computeHelper(reusable, compute);
814     }
815     {
816         SCOPED_TRACE("burstCompute");
817 
818         // TODO: If a burst API is added to nn::test_wrapper (e.g.,
819         // Execution::burstCompute()), then use that, rather than
820         // Execution::compute(WrapperExecution::ComputeMode::BURST).
821 
822         WrapperExecution execution(&mCompilation);
823         ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
824         ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
825         const auto compute = [this, &execution] {
826             TestPreparedModelLatest::pauseExecutions(true);
827             std::thread run([this, &execution] {
828                 EXPECT_EQ(execution.compute(WrapperExecution::ComputeMode::BURST), kExpectResult);
829             });
830             getDimensionsWhileRunning(execution);
831             TestPreparedModelLatest::pauseExecutions(false);
832             run.join();
833             if (kExpectResult == WrapperResult::NO_ERROR) {
834                 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
835             }
836             std::vector<uint32_t> dimensions;
837             if (kExpectResult == WrapperResult::NO_ERROR ||
838                 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
839                 // Only one output operand, hardcoded as index 0.
840                 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
841                 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
842             } else {
843                 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
844                           WrapperResult::BAD_STATE);
845             }
846         };
847         computeHelper(reusable, compute);
848     }
849     if (kExpectResult != WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
850         // computeWithDependencies doesn't support OUTPUT_INSUFFICIENT_SIZE
851         SCOPED_TRACE("computeWithDependencies");
852         WrapperExecution execution(&mCompilation);
853         ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
854         ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
855 
856         const auto compute = [this, &execution] {
857             TestPreparedModelLatest::pauseExecutions(true);
858 
859             WrapperEvent event;
860             // Note, due to the limitation of SampleDriver implementation, the call is synchronous.
861             // If the SampleDriver is updated to return real sync fence, this must be updated.
862             std::thread run([this, &execution, &event] {
863                 EXPECT_EQ(execution.startComputeWithDependencies({}, 0, &event), kExpectResult);
864             });
865             getDimensionsWhileRunning(execution);
866             TestPreparedModelLatest::pauseExecutions(false);
867             run.join();
868             if (kExpectResult == WrapperResult::NO_ERROR) {
869                 ASSERT_EQ(event.wait(), kExpectResult);
870                 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
871             } else {
872                 ASSERT_EQ(event.wait(), WrapperResult::UNEXPECTED_NULL);
873             }
874             std::vector<uint32_t> dimensions;
875             if (kExpectResult == WrapperResult::NO_ERROR) {
876                 // Only one output operand, hardcoded as index 0.
877                 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
878                 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
879             } else {
880                 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
881                           WrapperResult::BAD_STATE);
882             }
883         };
884         computeHelper(reusable, compute);
885     }
886 }
887 
888 auto kTestValues = ::testing::Values(
889         std::make_tuple(V1_3::ErrorStatus::NONE, WrapperResult::NO_ERROR,
890                         /* kUseIntrospectionAPI */ false),
891         std::make_tuple(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, WrapperResult::UNAVAILABLE_DEVICE,
892                         /* kUseIntrospectionAPI */ false),
893         std::make_tuple(V1_3::ErrorStatus::GENERAL_FAILURE, WrapperResult::OP_FAILED,
894                         /* kUseIntrospectionAPI */ false),
895         std::make_tuple(V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
896                         WrapperResult::OUTPUT_INSUFFICIENT_SIZE,
897                         /* kUseIntrospectionAPI */ false),
898         std::make_tuple(V1_3::ErrorStatus::INVALID_ARGUMENT, WrapperResult::BAD_DATA,
899                         /* kUseIntrospectionAPI */ false));
900 
901 class ExecutionTest13 : public ExecutionTestTemplate<TestDriver13> {};
TEST_P(ExecutionTest13,Wait)902 TEST_P(ExecutionTest13, Wait) {
903     TestWait(/*reusable=*/false);
904 }
TEST_P(ExecutionTest13,WaitReusable)905 TEST_P(ExecutionTest13, WaitReusable) {
906     TestWait(/*reusable=*/true);
907 }
908 INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest13, kTestValues);
909 
910 class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {};
TEST_P(ExecutionTest12,Wait)911 TEST_P(ExecutionTest12, Wait) {
912     TestWait(/*reusable=*/false);
913 }
TEST_P(ExecutionTest12,WaitReusable)914 TEST_P(ExecutionTest12, WaitReusable) {
915     TestWait(/*reusable=*/true);
916 }
917 INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest12, kTestValues);
918 
919 class ExecutionTest11 : public ExecutionTestTemplate<TestDriver11> {};
TEST_P(ExecutionTest11,Wait)920 TEST_P(ExecutionTest11, Wait) {
921     if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
922     TestWait(/*reusable=*/false);
923 }
TEST_P(ExecutionTest11,WaitReusable)924 TEST_P(ExecutionTest11, WaitReusable) {
925     if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
926     TestWait(/*reusable=*/true);
927 }
928 INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest11, kTestValues);
929 
930 class ExecutionTest10 : public ExecutionTestTemplate<TestDriver10> {};
TEST_P(ExecutionTest10,Wait)931 TEST_P(ExecutionTest10, Wait) {
932     if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
933     TestWait(/*reusable=*/false);
934 }
TEST_P(ExecutionTest10,WaitReusable)935 TEST_P(ExecutionTest10, WaitReusable) {
936     if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
937     TestWait(/*reusable=*/true);
938 }
939 INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest10, kTestValues);
940 
941 auto kIntrospectionTestValues = ::testing::Values(
942         std::make_tuple(V1_3::ErrorStatus::NONE, WrapperResult::NO_ERROR,
943                         /* kUseIntrospectionAPI */ true),
944         std::make_tuple(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, WrapperResult::UNAVAILABLE_DEVICE,
945                         /* kUseIntrospectionAPI */ true),
946         std::make_tuple(V1_3::ErrorStatus::GENERAL_FAILURE, WrapperResult::OP_FAILED,
947                         /* kUseIntrospectionAPI */ true),
948         std::make_tuple(V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
949                         WrapperResult::OUTPUT_INSUFFICIENT_SIZE,
950                         /* kUseIntrospectionAPI */ true),
951         std::make_tuple(V1_3::ErrorStatus::INVALID_ARGUMENT, WrapperResult::BAD_DATA,
952                         /* kUseIntrospectionAPI */ true));
953 
954 INSTANTIATE_TEST_SUITE_P(IntrospectionFlavor, ExecutionTest13, kIntrospectionTestValues);
955 
956 }  // namespace
957 }  // namespace android
958