• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "GeneratedTestHarness.h"
18 #include "Callbacks.h"
19 #include "ExecutionBurstController.h"
20 #include "TestHarness.h"
21 #include "Utils.h"
22 
23 #include <android-base/logging.h>
24 #include <android/hardware/neuralnetworks/1.0/IDevice.h>
25 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
26 #include <android/hardware/neuralnetworks/1.0/IPreparedModel.h>
27 #include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
28 #include <android/hardware/neuralnetworks/1.0/types.h>
29 #include <android/hardware/neuralnetworks/1.1/IDevice.h>
30 #include <android/hardware/neuralnetworks/1.2/IDevice.h>
31 #include <android/hardware/neuralnetworks/1.2/IExecutionCallback.h>
32 #include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
33 #include <android/hardware/neuralnetworks/1.2/IPreparedModelCallback.h>
34 #include <android/hidl/allocator/1.0/IAllocator.h>
35 #include <android/hidl/memory/1.0/IMemory.h>
36 #include <hidlmemory/mapping.h>
37 #include <iostream>
38 
39 namespace android {
40 namespace hardware {
41 namespace neuralnetworks {
42 
43 namespace generated_tests {
44 using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
45 using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
46 using ::test_helper::bool8;
47 using ::test_helper::compare;
48 using ::test_helper::expectMultinomialDistributionWithinTolerance;
49 using ::test_helper::filter;
50 using ::test_helper::for_all;
51 using ::test_helper::for_each;
52 using ::test_helper::MixedTyped;
53 using ::test_helper::MixedTypedExample;
54 using ::test_helper::resize_accordingly;
55 using HidlToken = hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
56 
57 template <typename T>
copy_back_(std::map<int,std::vector<T>> * dst,const std::vector<RequestArgument> & ra,char * src)58 void copy_back_(std::map<int, std::vector<T>>* dst, const std::vector<RequestArgument>& ra,
59                 char* src) {
60     for_each<T>(*dst, [&ra, src](int index, std::vector<T>& m) {
61         ASSERT_EQ(m.size(), ra[index].location.length / sizeof(T));
62         char* begin = src + ra[index].location.offset;
63         memcpy(m.data(), begin, ra[index].location.length);
64     });
65 }
66 
copy_back(MixedTyped * dst,const std::vector<RequestArgument> & ra,char * src)67 void copy_back(MixedTyped* dst, const std::vector<RequestArgument>& ra, char* src) {
68     copy_back_(&dst->float32Operands, ra, src);
69     copy_back_(&dst->int32Operands, ra, src);
70     copy_back_(&dst->quant8AsymmOperands, ra, src);
71     copy_back_(&dst->quant16SymmOperands, ra, src);
72     copy_back_(&dst->float16Operands, ra, src);
73     copy_back_(&dst->bool8Operands, ra, src);
74     copy_back_(&dst->quant8ChannelOperands, ra, src);
75     copy_back_(&dst->quant16AsymmOperands, ra, src);
76     copy_back_(&dst->quant8SymmOperands, ra, src);
77     static_assert(9 == MixedTyped::kNumTypes,
78                   "Number of types in MixedTyped changed, but copy_back function wasn't updated");
79 }
80 
isZeroSized(const MixedTyped & example,uint32_t index)81 static bool isZeroSized(const MixedTyped& example, uint32_t index) {
82     for (auto i : example.operandDimensions.at(index)) {
83         if (i == 0) return true;
84     }
85     return false;
86 }
87 
88 // Top level driver for models and examples generated by test_generator.py
89 // Test driver for those generated from ml/nn/runtime/test/spec
ExecutePreparedModel(sp<V1_0::IPreparedModel> & preparedModel,const Request & request,MeasureTiming,sp<ExecutionCallback> & callback)90 static Return<ErrorStatus> ExecutePreparedModel(sp<V1_0::IPreparedModel>& preparedModel,
91                                                 const Request& request, MeasureTiming,
92                                                 sp<ExecutionCallback>& callback) {
93     return preparedModel->execute(request, callback);
94 }
ExecutePreparedModel(sp<V1_2::IPreparedModel> & preparedModel,const Request & request,MeasureTiming measure,sp<ExecutionCallback> & callback)95 static Return<ErrorStatus> ExecutePreparedModel(sp<V1_2::IPreparedModel>& preparedModel,
96                                                 const Request& request, MeasureTiming measure,
97                                                 sp<ExecutionCallback>& callback) {
98     return preparedModel->execute_1_2(request, measure, callback);
99 }
ExecutePreparedModel(sp<V1_0::IPreparedModel> &,const Request &,MeasureTiming,hidl_vec<OutputShape> *,Timing *)100 static Return<ErrorStatus> ExecutePreparedModel(sp<V1_0::IPreparedModel>&, const Request&,
101                                                 MeasureTiming, hidl_vec<OutputShape>*, Timing*) {
102     ADD_FAILURE() << "asking for synchronous execution at V1_0";
103     return ErrorStatus::GENERAL_FAILURE;
104 }
ExecutePreparedModel(sp<V1_2::IPreparedModel> & preparedModel,const Request & request,MeasureTiming measure,hidl_vec<OutputShape> * outputShapes,Timing * timing)105 static Return<ErrorStatus> ExecutePreparedModel(sp<V1_2::IPreparedModel>& preparedModel,
106                                                 const Request& request, MeasureTiming measure,
107                                                 hidl_vec<OutputShape>* outputShapes,
108                                                 Timing* timing) {
109     ErrorStatus result;
110     Return<void> ret = preparedModel->executeSynchronously(
111             request, measure,
112             [&result, outputShapes, timing](ErrorStatus error, const hidl_vec<OutputShape>& shapes,
113                                             const Timing& time) {
114                 result = error;
115                 *outputShapes = shapes;
116                 *timing = time;
117             });
118     if (!ret.isOk()) {
119         return ErrorStatus::GENERAL_FAILURE;
120     }
121     return result;
122 }
CreateBurst(const sp<V1_0::IPreparedModel> &)123 static std::unique_ptr<::android::nn::ExecutionBurstController> CreateBurst(
124         const sp<V1_0::IPreparedModel>&) {
125     ADD_FAILURE() << "asking for burst execution at V1_0";
126     return nullptr;
127 }
CreateBurst(const sp<V1_2::IPreparedModel> & preparedModel)128 static std::shared_ptr<::android::nn::ExecutionBurstController> CreateBurst(
129         const sp<V1_2::IPreparedModel>& preparedModel) {
130     return ::android::nn::ExecutionBurstController::create(preparedModel, /*blocking=*/true);
131 }
// How EvaluatePreparedModel submits each request to the driver.
enum class Executor {
    ASYNC,  // execute / execute_1_2 with an ExecutionCallback
    SYNC,   // executeSynchronously (only available on 1.2 prepared models)
    BURST,  // ExecutionBurstController::compute (only available on 1.2 prepared models)
};
// How output buffers are described in the Request; this decides which execution status and
// output shapes EvaluatePreparedModel expects back.
enum class OutputType {
    FULLY_SPECIFIED,  // execution must succeed; returned shapes may be empty
    UNSPECIFIED,      // execution must succeed and report one shape per output
    INSUFFICIENT,     // output 0 is one byte too small; expect OUTPUT_INSUFFICIENT_SIZE
};
// Default absolute and relative tolerances for comparing float outputs with golden values.
constexpr float kDefaultAtol = 1e-5f;
constexpr float kDefaultRtol = 1e-5f;
136 template <typename T_IPreparedModel>
EvaluatePreparedModel(sp<T_IPreparedModel> & preparedModel,std::function<bool (int)> is_ignored,const std::vector<MixedTypedExample> & examples,bool hasRelaxedFloat32Model,float fpAtol,float fpRtol,Executor executor,MeasureTiming measure,OutputType outputType)137 void EvaluatePreparedModel(sp<T_IPreparedModel>& preparedModel, std::function<bool(int)> is_ignored,
138                            const std::vector<MixedTypedExample>& examples,
139                            bool hasRelaxedFloat32Model, float fpAtol, float fpRtol,
140                            Executor executor, MeasureTiming measure, OutputType outputType) {
141     const uint32_t INPUT = 0;
142     const uint32_t OUTPUT = 1;
143 
144     int example_no = 1;
145     for (auto& example : examples) {
146         SCOPED_TRACE(example_no++);
147         const MixedTyped& inputs = example.operands.first;
148         const MixedTyped& golden = example.operands.second;
149 
150         const bool hasFloat16Inputs = !inputs.float16Operands.empty();
151         if (hasRelaxedFloat32Model || hasFloat16Inputs) {
152             // TODO: Adjust the error limit based on testing.
153             // If in relaxed mode, set the absolute tolerance to be 5ULP of FP16.
154             fpAtol = 5.0f * 0.0009765625f;
155             // Set the relative tolerance to be 5ULP of the corresponding FP precision.
156             fpRtol = 5.0f * 0.0009765625f;
157         }
158 
159         std::vector<RequestArgument> inputs_info, outputs_info;
160         uint32_t inputSize = 0, outputSize = 0;
161         // This function only partially specifies the metadata (vector of RequestArguments).
162         // The contents are copied over below.
163         for_all(inputs, [&inputs_info, &inputSize](int index, auto, auto s) {
164             if (inputs_info.size() <= static_cast<size_t>(index)) inputs_info.resize(index + 1);
165             RequestArgument arg = {
166                 .location = {.poolIndex = INPUT, .offset = 0, .length = static_cast<uint32_t>(s)},
167                 .dimensions = {},
168             };
169             RequestArgument arg_empty = {
170                 .hasNoValue = true,
171             };
172             inputs_info[index] = s ? arg : arg_empty;
173             inputSize += s;
174         });
175         // Compute offset for inputs 1 and so on
176         {
177             size_t offset = 0;
178             for (auto& i : inputs_info) {
179                 if (!i.hasNoValue) i.location.offset = offset;
180                 offset += i.location.length;
181             }
182         }
183 
184         MixedTyped test;  // holding test results
185 
186         // Go through all outputs, initialize RequestArgument descriptors
187         resize_accordingly(golden, test);
188         bool sizeLargerThanOne = true;
189         for_all(golden, [&golden, &outputs_info, &outputSize, &outputType, &sizeLargerThanOne](
190                                 int index, auto, auto s) {
191             if (outputs_info.size() <= static_cast<size_t>(index)) outputs_info.resize(index + 1);
192             if (index == 0) {
193                 // On OutputType::INSUFFICIENT, set the output operand with index 0 with
194                 // buffer size one byte less than needed.
195                 if (outputType == OutputType::INSUFFICIENT) {
196                     if (s > 1 && !isZeroSized(golden, index)) {
197                         s -= 1;
198                     } else {
199                         sizeLargerThanOne = false;
200                     }
201                 }
202             }
203             RequestArgument arg = {
204                 .location = {.poolIndex = OUTPUT, .offset = 0, .length = static_cast<uint32_t>(s)},
205                 .dimensions = {},
206             };
207             outputs_info[index] = arg;
208             outputSize += s;
209         });
210         // If output0 does not have size larger than one byte,
211         // we can not provide an insufficient buffer
212         if (!sizeLargerThanOne && outputType == OutputType::INSUFFICIENT) return;
213         // Compute offset for outputs 1 and so on
214         {
215             size_t offset = 0;
216             for (auto& i : outputs_info) {
217                 i.location.offset = offset;
218                 offset += i.location.length;
219             }
220         }
221         std::vector<hidl_memory> pools = {nn::allocateSharedMemory(inputSize),
222                                           nn::allocateSharedMemory(outputSize)};
223         ASSERT_NE(0ull, pools[INPUT].size());
224         ASSERT_NE(0ull, pools[OUTPUT].size());
225 
226         // load data
227         sp<IMemory> inputMemory = mapMemory(pools[INPUT]);
228         sp<IMemory> outputMemory = mapMemory(pools[OUTPUT]);
229         ASSERT_NE(nullptr, inputMemory.get());
230         ASSERT_NE(nullptr, outputMemory.get());
231         char* inputPtr = reinterpret_cast<char*>(static_cast<void*>(inputMemory->getPointer()));
232         char* outputPtr = reinterpret_cast<char*>(static_cast<void*>(outputMemory->getPointer()));
233         ASSERT_NE(nullptr, inputPtr);
234         ASSERT_NE(nullptr, outputPtr);
235         inputMemory->update();
236         outputMemory->update();
237 
238         // Go through all inputs, copy the values
239         for_all(inputs, [&inputs_info, inputPtr](int index, auto p, auto s) {
240             char* begin = (char*)p;
241             char* end = begin + s;
242             // TODO: handle more than one input
243             std::copy(begin, end, inputPtr + inputs_info[index].location.offset);
244         });
245 
246         inputMemory->commit();
247         outputMemory->commit();
248 
249         const Request request = {.inputs = inputs_info, .outputs = outputs_info, .pools = pools};
250 
251         ErrorStatus executionStatus;
252         hidl_vec<OutputShape> outputShapes;
253         Timing timing;
254         switch (executor) {
255             case Executor::ASYNC: {
256                 SCOPED_TRACE("asynchronous");
257 
258                 // launch execution
259                 sp<ExecutionCallback> executionCallback = new ExecutionCallback();
260                 ASSERT_NE(nullptr, executionCallback.get());
261                 Return<ErrorStatus> executionLaunchStatus =
262                         ExecutePreparedModel(preparedModel, request, measure, executionCallback);
263                 ASSERT_TRUE(executionLaunchStatus.isOk());
264                 EXPECT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(executionLaunchStatus));
265 
266                 // retrieve execution status
267                 executionCallback->wait();
268                 executionStatus = executionCallback->getStatus();
269                 outputShapes = executionCallback->getOutputShapes();
270                 timing = executionCallback->getTiming();
271 
272                 break;
273             }
274             case Executor::SYNC: {
275                 SCOPED_TRACE("synchronous");
276 
277                 // execute
278                 Return<ErrorStatus> executionReturnStatus = ExecutePreparedModel(
279                         preparedModel, request, measure, &outputShapes, &timing);
280                 ASSERT_TRUE(executionReturnStatus.isOk());
281                 executionStatus = static_cast<ErrorStatus>(executionReturnStatus);
282 
283                 break;
284             }
285             case Executor::BURST: {
286                 SCOPED_TRACE("burst");
287 
288                 // create burst
289                 const std::shared_ptr<::android::nn::ExecutionBurstController> controller =
290                         CreateBurst(preparedModel);
291                 ASSERT_NE(nullptr, controller.get());
292 
293                 // create memory keys
294                 std::vector<intptr_t> keys(request.pools.size());
295                 for (size_t i = 0; i < keys.size(); ++i) {
296                     keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
297                 }
298 
299                 // execute burst
300                 std::tie(executionStatus, outputShapes, timing) =
301                         controller->compute(request, measure, keys);
302 
303                 break;
304             }
305         }
306 
307         if (outputType != OutputType::FULLY_SPECIFIED &&
308             executionStatus == ErrorStatus::GENERAL_FAILURE) {
309             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
310                          "execute model that it does not support.";
311             std::cout << "[          ]   Early termination of test because vendor service cannot "
312                          "execute model that it does not support."
313                       << std::endl;
314             GTEST_SKIP();
315         }
316         if (measure == MeasureTiming::NO) {
317             EXPECT_EQ(UINT64_MAX, timing.timeOnDevice);
318             EXPECT_EQ(UINT64_MAX, timing.timeInDriver);
319         } else {
320             if (timing.timeOnDevice != UINT64_MAX && timing.timeInDriver != UINT64_MAX) {
321                 EXPECT_LE(timing.timeOnDevice, timing.timeInDriver);
322             }
323         }
324 
325         switch (outputType) {
326             case OutputType::FULLY_SPECIFIED:
327                 // If the model output operands are fully specified, outputShapes must be either
328                 // either empty, or have the same number of elements as the number of outputs.
329                 ASSERT_EQ(ErrorStatus::NONE, executionStatus);
330                 ASSERT_TRUE(outputShapes.size() == 0 ||
331                             outputShapes.size() == test.operandDimensions.size());
332                 break;
333             case OutputType::UNSPECIFIED:
334                 // If the model output operands are not fully specified, outputShapes must have
335                 // the same number of elements as the number of outputs.
336                 ASSERT_EQ(ErrorStatus::NONE, executionStatus);
337                 ASSERT_EQ(outputShapes.size(), test.operandDimensions.size());
338                 break;
339             case OutputType::INSUFFICIENT:
340                 ASSERT_EQ(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, executionStatus);
341                 ASSERT_EQ(outputShapes.size(), test.operandDimensions.size());
342                 ASSERT_FALSE(outputShapes[0].isSufficient);
343                 return;
344         }
345         // Go through all outputs, overwrite output dimensions with returned output shapes
346         if (outputShapes.size() > 0) {
347             for_each<uint32_t>(test.operandDimensions,
348                                [&outputShapes](int idx, std::vector<uint32_t>& dim) {
349                                    dim = outputShapes[idx].dimensions;
350                                });
351         }
352 
353         // validate results
354         outputMemory->read();
355         copy_back(&test, outputs_info, outputPtr);
356         outputMemory->commit();
357         // Filter out don't cares
358         MixedTyped filtered_golden = filter(golden, is_ignored);
359         MixedTyped filtered_test = filter(test, is_ignored);
360 
361         // We want "close-enough" results for float
362         compare(filtered_golden, filtered_test, fpAtol, fpRtol);
363 
364         if (example.expectedMultinomialDistributionTolerance > 0) {
365             expectMultinomialDistributionWithinTolerance(test, example);
366         }
367     }
368 }
369 template <typename T_IPreparedModel>
EvaluatePreparedModel(sp<T_IPreparedModel> & preparedModel,std::function<bool (int)> is_ignored,const std::vector<MixedTypedExample> & examples,bool hasRelaxedFloat32Model,Executor executor,MeasureTiming measure,OutputType outputType)370 void EvaluatePreparedModel(sp<T_IPreparedModel>& preparedModel, std::function<bool(int)> is_ignored,
371                            const std::vector<MixedTypedExample>& examples,
372                            bool hasRelaxedFloat32Model, Executor executor, MeasureTiming measure,
373                            OutputType outputType) {
374     EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model, kDefaultAtol,
375                           kDefaultRtol, executor, measure, outputType);
376 }
377 
EvaluatePreparedModel(sp<V1_2::IPreparedModel> & preparedModel,std::function<bool (int)> is_ignored,const std::vector<MixedTypedExample> & examples,bool hasRelaxedFloat32Model,bool testDynamicOutputShape)378 void EvaluatePreparedModel(sp<V1_2::IPreparedModel>& preparedModel,
379                            std::function<bool(int)> is_ignored,
380                            const std::vector<MixedTypedExample>& examples,
381                            bool hasRelaxedFloat32Model, bool testDynamicOutputShape) {
382     if (testDynamicOutputShape) {
383         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
384                               Executor::ASYNC, MeasureTiming::NO, OutputType::UNSPECIFIED);
385         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
386                               Executor::SYNC, MeasureTiming::NO, OutputType::UNSPECIFIED);
387         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
388                               Executor::BURST, MeasureTiming::NO, OutputType::UNSPECIFIED);
389         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
390                               Executor::ASYNC, MeasureTiming::YES, OutputType::UNSPECIFIED);
391         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
392                               Executor::SYNC, MeasureTiming::YES, OutputType::UNSPECIFIED);
393         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
394                               Executor::BURST, MeasureTiming::YES, OutputType::UNSPECIFIED);
395         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
396                               Executor::ASYNC, MeasureTiming::NO, OutputType::INSUFFICIENT);
397         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
398                               Executor::SYNC, MeasureTiming::NO, OutputType::INSUFFICIENT);
399         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
400                               Executor::BURST, MeasureTiming::NO, OutputType::INSUFFICIENT);
401         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
402                               Executor::ASYNC, MeasureTiming::YES, OutputType::INSUFFICIENT);
403         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
404                               Executor::SYNC, MeasureTiming::YES, OutputType::INSUFFICIENT);
405         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
406                               Executor::BURST, MeasureTiming::YES, OutputType::INSUFFICIENT);
407     } else {
408         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
409                               Executor::ASYNC, MeasureTiming::NO, OutputType::FULLY_SPECIFIED);
410         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
411                               Executor::SYNC, MeasureTiming::NO, OutputType::FULLY_SPECIFIED);
412         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
413                               Executor::BURST, MeasureTiming::NO, OutputType::FULLY_SPECIFIED);
414         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
415                               Executor::ASYNC, MeasureTiming::YES, OutputType::FULLY_SPECIFIED);
416         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
417                               Executor::SYNC, MeasureTiming::YES, OutputType::FULLY_SPECIFIED);
418         EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model,
419                               Executor::BURST, MeasureTiming::YES, OutputType::FULLY_SPECIFIED);
420     }
421 }
422 
getPreparedModel(sp<PreparedModelCallback> callback,sp<V1_0::IPreparedModel> * preparedModel)423 static void getPreparedModel(sp<PreparedModelCallback> callback,
424                              sp<V1_0::IPreparedModel>* preparedModel) {
425     *preparedModel = callback->getPreparedModel();
426 }
getPreparedModel(sp<PreparedModelCallback> callback,sp<V1_2::IPreparedModel> * preparedModel)427 static void getPreparedModel(sp<PreparedModelCallback> callback,
428                              sp<V1_2::IPreparedModel>* preparedModel) {
429     sp<V1_0::IPreparedModel> preparedModelV1_0 = callback->getPreparedModel();
430     *preparedModel = V1_2::IPreparedModel::castFrom(preparedModelV1_0).withDefault(nullptr);
431 }
432 
Execute(const sp<V1_0::IDevice> & device,std::function<V1_0::Model (void)> create_model,std::function<bool (int)> is_ignored,const std::vector<MixedTypedExample> & examples)433 void Execute(const sp<V1_0::IDevice>& device, std::function<V1_0::Model(void)> create_model,
434              std::function<bool(int)> is_ignored, const std::vector<MixedTypedExample>& examples) {
435     V1_0::Model model = create_model();
436 
437     // see if service can handle model
438     bool fullySupportsModel = false;
439     Return<void> supportedCall = device->getSupportedOperations(
440         model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
441             ASSERT_EQ(ErrorStatus::NONE, status);
442             ASSERT_NE(0ul, supported.size());
443             fullySupportsModel =
444                 std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
445         });
446     ASSERT_TRUE(supportedCall.isOk());
447 
448     // launch prepare model
449     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
450     ASSERT_NE(nullptr, preparedModelCallback.get());
451     Return<ErrorStatus> prepareLaunchStatus = device->prepareModel(model, preparedModelCallback);
452     ASSERT_TRUE(prepareLaunchStatus.isOk());
453     ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));
454 
455     // retrieve prepared model
456     preparedModelCallback->wait();
457     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
458     sp<V1_0::IPreparedModel> preparedModel;
459     getPreparedModel(preparedModelCallback, &preparedModel);
460 
461     // early termination if vendor service cannot fully prepare model
462     if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
463         ASSERT_EQ(nullptr, preparedModel.get());
464         LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
465                      "prepare model that it does not support.";
466         std::cout << "[          ]   Early termination of test because vendor service cannot "
467                      "prepare model that it does not support."
468                   << std::endl;
469         GTEST_SKIP();
470     }
471     EXPECT_EQ(ErrorStatus::NONE, prepareReturnStatus);
472     ASSERT_NE(nullptr, preparedModel.get());
473 
474     float fpAtol = 1e-5f, fpRtol = 5.0f * 1.1920928955078125e-7f;
475     EvaluatePreparedModel(preparedModel, is_ignored, examples,
476                           /*hasRelaxedFloat32Model=*/false, fpAtol, fpRtol, Executor::ASYNC,
477                           MeasureTiming::NO, OutputType::FULLY_SPECIFIED);
478 }
479 
Execute(const sp<V1_1::IDevice> & device,std::function<V1_1::Model (void)> create_model,std::function<bool (int)> is_ignored,const std::vector<MixedTypedExample> & examples)480 void Execute(const sp<V1_1::IDevice>& device, std::function<V1_1::Model(void)> create_model,
481              std::function<bool(int)> is_ignored, const std::vector<MixedTypedExample>& examples) {
482     V1_1::Model model = create_model();
483 
484     // see if service can handle model
485     bool fullySupportsModel = false;
486     Return<void> supportedCall = device->getSupportedOperations_1_1(
487         model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
488             ASSERT_EQ(ErrorStatus::NONE, status);
489             ASSERT_NE(0ul, supported.size());
490             fullySupportsModel =
491                 std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
492         });
493     ASSERT_TRUE(supportedCall.isOk());
494 
495     // launch prepare model
496     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
497     ASSERT_NE(nullptr, preparedModelCallback.get());
498     Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_1(
499         model, ExecutionPreference::FAST_SINGLE_ANSWER, preparedModelCallback);
500     ASSERT_TRUE(prepareLaunchStatus.isOk());
501     ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));
502 
503     // retrieve prepared model
504     preparedModelCallback->wait();
505     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
506     sp<V1_0::IPreparedModel> preparedModel;
507     getPreparedModel(preparedModelCallback, &preparedModel);
508 
509     // early termination if vendor service cannot fully prepare model
510     if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
511         ASSERT_EQ(nullptr, preparedModel.get());
512         LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
513                      "prepare model that it does not support.";
514         std::cout << "[          ]   Early termination of test because vendor service cannot "
515                      "prepare model that it does not support."
516                   << std::endl;
517         GTEST_SKIP();
518     }
519     EXPECT_EQ(ErrorStatus::NONE, prepareReturnStatus);
520     ASSERT_NE(nullptr, preparedModel.get());
521 
522     EvaluatePreparedModel(preparedModel, is_ignored, examples,
523                           model.relaxComputationFloat32toFloat16, 1e-5f, 1e-5f, Executor::ASYNC,
524                           MeasureTiming::NO, OutputType::FULLY_SPECIFIED);
525 }
526 
PrepareModel(const sp<V1_2::IDevice> & device,const V1_2::Model & model,sp<V1_2::IPreparedModel> * preparedModel)527 void PrepareModel(const sp<V1_2::IDevice>& device, const V1_2::Model& model,
528                   sp<V1_2::IPreparedModel>* preparedModel) {
529     // see if service can handle model
530     bool fullySupportsModel = false;
531     Return<void> supportedCall = device->getSupportedOperations_1_2(
532         model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
533             ASSERT_EQ(ErrorStatus::NONE, status);
534             ASSERT_NE(0ul, supported.size());
535             fullySupportsModel =
536                 std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
537         });
538     ASSERT_TRUE(supportedCall.isOk());
539 
540     // launch prepare model
541     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
542     ASSERT_NE(nullptr, preparedModelCallback.get());
543     Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_2(
544             model, ExecutionPreference::FAST_SINGLE_ANSWER, hidl_vec<hidl_handle>(),
545             hidl_vec<hidl_handle>(), HidlToken(), preparedModelCallback);
546     ASSERT_TRUE(prepareLaunchStatus.isOk());
547     ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));
548 
549     // retrieve prepared model
550     preparedModelCallback->wait();
551     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
552     getPreparedModel(preparedModelCallback, preparedModel);
553 
554     // early termination if vendor service cannot fully prepare model
555     if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
556         ASSERT_EQ(nullptr, preparedModel->get());
557         LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
558                      "prepare model that it does not support.";
559         std::cout << "[          ]   Early termination of test because vendor service cannot "
560                      "prepare model that it does not support."
561                   << std::endl;
562         return;
563     }
564     EXPECT_EQ(ErrorStatus::NONE, prepareReturnStatus);
565     ASSERT_NE(nullptr, preparedModel->get());
566 }
567 
568 // TODO: Reduce code duplication.
Execute(const sp<V1_2::IDevice> & device,std::function<V1_2::Model (void)> create_model,std::function<bool (int)> is_ignored,const std::vector<MixedTypedExample> & examples,bool testDynamicOutputShape)569 void Execute(const sp<V1_2::IDevice>& device, std::function<V1_2::Model(void)> create_model,
570              std::function<bool(int)> is_ignored, const std::vector<MixedTypedExample>& examples,
571              bool testDynamicOutputShape) {
572     V1_2::Model model = create_model();
573     sp<V1_2::IPreparedModel> preparedModel = nullptr;
574     PrepareModel(device, model, &preparedModel);
575     if (preparedModel == nullptr) {
576         GTEST_SKIP();
577     }
578     EvaluatePreparedModel(preparedModel, is_ignored, examples,
579                           model.relaxComputationFloat32toFloat16, testDynamicOutputShape);
580 }
581 
582 }  // namespace generated_tests
583 
584 }  // namespace neuralnetworks
585 }  // namespace hardware
586 }  // namespace android
587