1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include <android-base/logging.h>
20 #include <android/hidl/memory/1.0/IMemory.h>
#include <fcntl.h>  // open() and the O_*/S_I* flags used by the cache-file helpers
21 #include <ftw.h>
22 #include <gtest/gtest.h>
23 #include <hidlmemory/mapping.h>
24 #include <unistd.h>
25 
#include <algorithm>  // std::all_of, std::copy
26 #include <cstdio>
27 #include <cstdlib>
#include <cstring>    // memcpy, memset
28 #include <random>
#include <thread>     // std::thread in the TOCTOU tests
29 
30 #include "Callbacks.h"
31 #include "GeneratedTestHarness.h"
32 #include "TestHarness.h"
33 #include "Utils.h"
34 #include "VtsHalNeuralnetworks.h"
35 
36 namespace android {
37 namespace hardware {
38 namespace neuralnetworks {
39 namespace V1_2 {
40 namespace vts {
41 namespace functional {
42 
43 using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
44 using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
45 using ::android::nn::allocateSharedMemory;
46 using ::test_helper::MixedTypedExample;
47 
48 namespace float32_model {
49 
50 // Generated sources in frameworks/ml/nn/runtime/test/generated/ that create a HIDL model and examples of float32 mobilenet.
51 #include "examples/mobilenet_224_gender_basic_fixed.example.cpp"
52 #include "vts_models/mobilenet_224_gender_basic_fixed.model.cpp"
53 
54 // Prevent the compiler from complaining about an otherwise unused function.
55 [[maybe_unused]] auto dummy_createTestModel = createTestModel_dynamic_output_shape;
56 [[maybe_unused]] auto dummy_get_examples = get_examples_dynamic_output_shape;
57 
58 // MixedTypedExample is defined in frameworks/ml/nn/tools/test_generator/include/TestHarness.h.
59 // This function assumes the operation is always ADD.
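// In the chain model built by createLargeTestModelImpl below, the example input value 1.0f passes
// through `len` ADD operations that each add the constant 1.0f, so the expected output is
// 1.0f + len, which is what outputValue encodes.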
60 std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
61     float outputValue = 1.0f + static_cast<float>(len);
62     return {{.operands = {
63                      // Input
64                      {.operandDimensions = {{0, {1}}}, .float32Operands = {{0, {1.0f}}}},
65                      // Output
66                      {.operandDimensions = {{0, {1}}}, .float32Operands = {{0, {outputValue}}}}}}};
67 }
68 
69 }  // namespace float32_model
70 
71 namespace quant8_model {
72 
73 // Generated sources in frameworks/ml/nn/runtime/test/generated/ that create a HIDL model and examples of quant8 mobilenet.
74 #include "examples/mobilenet_quantized.example.cpp"
75 #include "vts_models/mobilenet_quantized.model.cpp"
76 
77 // Prevent the compiler from complaining about an otherwise unused function.
78 [[maybe_unused]] auto dummy_createTestModel = createTestModel_dynamic_output_shape;
79 [[maybe_unused]] auto dummy_get_examples = get_examples_dynamic_output_shape;
80 
81 // MixedTypedExample is defined in frameworks/ml/nn/tools/test_generator/include/TestHarness.h.
82 // This function assumes the operation is always ADD.
83 std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
84     uint8_t outputValue = 1 + static_cast<uint8_t>(len);
85     return {{.operands = {// Input
86                           {.operandDimensions = {{0, {1}}}, .quant8AsymmOperands = {{0, {1}}}},
87                           // Output
88                           {.operandDimensions = {{0, {1}}},
89                            .quant8AsymmOperands = {{0, {outputValue}}}}}}};
90 }
91 
92 }  // namespace quant8_model
93 
94 namespace {
95 
96 enum class AccessMode { READ_WRITE, READ_ONLY, WRITE_ONLY };
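// Controls the open() flags used for each cache file in createCacheHandles below. The
// invalid-access-mode tests pass READ_ONLY where the driver needs to write and WRITE_ONLY where
// it needs to read.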
97 
98 // Creates cache handles based on provided file groups.
99 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
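// Each hidl_handle takes ownership of its native_handle_t (setTo with shouldOwn = true), so the
// fds opened here are closed automatically when the returned handles are destroyed.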
100 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
101                         const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
102     handles->resize(fileGroups.size());
103     for (uint32_t i = 0; i < fileGroups.size(); i++) {
104         std::vector<int> fds;
105         for (const auto& file : fileGroups[i]) {
106             int fd;
107             if (mode[i] == AccessMode::READ_ONLY) {
108                 fd = open(file.c_str(), O_RDONLY);
109             } else if (mode[i] == AccessMode::WRITE_ONLY) {
110                 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
111             } else if (mode[i] == AccessMode::READ_WRITE) {
112                 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
113             } else {
114                 FAIL();
115             }
116             ASSERT_GE(fd, 0);
117             fds.push_back(fd);
118         }
119         native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
120         ASSERT_NE(cacheNativeHandle, nullptr);
121         std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
122         (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
123     }
124 }
125 
126 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
127                         hidl_vec<hidl_handle>* handles) {
128     createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
129 }
130 
131 // Creates a chain of broadcast operations. The second operand is always the constant tensor [1].
132 // For simplicity, the activation scalar is shared. The second operand is not shared
133 // in the model, so that the driver maintains a non-trivial amount of constant data and the
134 // corresponding data locations in the cache.
135 //
136 //                --------- activation --------
137 //                ↓      ↓      ↓             ↓
138 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
139 //                ↑      ↑      ↑             ↑
140 //               [1]    [1]    [1]           [1]
141 //
142 // This function assumes the operation is either ADD or MUL.
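// Operand layout: index 0 is the shared activation scalar; indices 2*i+1 and 2*i+2 are the i-th
// operation's first input (the model input for i == 0, otherwise the previous operation's output)
// and its constant second operand; the last operand (index len*2+1) is the model output.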
143 template <typename CppType, OperandType operandType>
144 Model createLargeTestModelImpl(OperationType op, uint32_t len) {
145     EXPECT_TRUE(op == OperationType::ADD || op == OperationType::MUL);
146 
147     // Model operations and operands.
148     std::vector<Operation> operations(len);
149     std::vector<Operand> operands(len * 2 + 2);
150 
151     // The constant buffer pool. This contains the activation scalar, followed by the
152     // per-operation constant operands.
153     std::vector<uint8_t> operandValues(sizeof(int32_t) + len * sizeof(CppType));
154 
155     // The activation scalar, value = 0.
156     operands[0] = {
157             .type = OperandType::INT32,
158             .dimensions = {},
159             .numberOfConsumers = len,
160             .scale = 0.0f,
161             .zeroPoint = 0,
162             .lifetime = OperandLifeTime::CONSTANT_COPY,
163             .location = {.poolIndex = 0, .offset = 0, .length = sizeof(int32_t)},
164     };
165     memset(operandValues.data(), 0, sizeof(int32_t));
166 
167     // The buffer value of the constant second operand. The logical value is always 1.0f.
168     CppType bufferValue;
169     // The scale of the first and second operand.
170     float scale1, scale2;
171     if (operandType == OperandType::TENSOR_FLOAT32) {
172         bufferValue = 1.0f;
173         scale1 = 0.0f;
174         scale2 = 0.0f;
175     } else if (op == OperationType::ADD) {
176         bufferValue = 1;
177         scale1 = 1.0f;
178         scale2 = 1.0f;
179     } else {
180         // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
181         // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
182         bufferValue = 2;
183         scale1 = 1.0f;
184         scale2 = 0.5f;
185     }
186 
187     for (uint32_t i = 0; i < len; i++) {
188         const uint32_t firstInputIndex = i * 2 + 1;
189         const uint32_t secondInputIndex = firstInputIndex + 1;
190         const uint32_t outputIndex = secondInputIndex + 1;
191 
192         // The first operation input.
193         operands[firstInputIndex] = {
194                 .type = operandType,
195                 .dimensions = {1},
196                 .numberOfConsumers = 1,
197                 .scale = scale1,
198                 .zeroPoint = 0,
199                 .lifetime = (i == 0 ? OperandLifeTime::MODEL_INPUT
200                                     : OperandLifeTime::TEMPORARY_VARIABLE),
201                 .location = {},
202         };
203 
204         // The second operation input, value = 1.
205         operands[secondInputIndex] = {
206                 .type = operandType,
207                 .dimensions = {1},
208                 .numberOfConsumers = 1,
209                 .scale = scale2,
210                 .zeroPoint = 0,
211                 .lifetime = OperandLifeTime::CONSTANT_COPY,
212                 .location = {.poolIndex = 0,
213                              .offset = static_cast<uint32_t>(i * sizeof(CppType) + sizeof(int32_t)),
214                              .length = sizeof(CppType)},
215         };
216         memcpy(operandValues.data() + sizeof(int32_t) + i * sizeof(CppType), &bufferValue,
217                sizeof(CppType));
218 
219         // The operation. All operations share the same activation scalar.
220         // The output operand is created as an input in the next iteration of the loop, in the case
221         // of all but the last member of the chain; and after the loop as a model output, in the
222         // case of the last member of the chain.
223         operations[i] = {
224                 .type = op,
225                 .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
226                 .outputs = {outputIndex},
227         };
228     }
229 
230     // The model output.
231     operands.back() = {
232             .type = operandType,
233             .dimensions = {1},
234             .numberOfConsumers = 0,
235             .scale = scale1,
236             .zeroPoint = 0,
237             .lifetime = OperandLifeTime::MODEL_OUTPUT,
238             .location = {},
239     };
240 
241     const std::vector<uint32_t> inputIndexes = {1};
242     const std::vector<uint32_t> outputIndexes = {len * 2 + 1};
243     const std::vector<hidl_memory> pools = {};
244 
245     return {
246             .operands = operands,
247             .operations = operations,
248             .inputIndexes = inputIndexes,
249             .outputIndexes = outputIndexes,
250             .operandValues = operandValues,
251             .pools = pools,
252     };
253 }
254 
255 }  // namespace
256 
257 // Tag for the compilation caching tests.
258 class CompilationCachingTestBase : public NeuralnetworksHidlTest {
259   protected:
260     CompilationCachingTestBase(OperandType type) : kOperandType(type) {}
261 
262     void SetUp() override {
263         NeuralnetworksHidlTest::SetUp();
264         ASSERT_NE(device.get(), nullptr);
265 
266         // Create cache directory. The cache directory and a temporary cache file are always created
267         // to test the behavior of prepareModelFromCache, even when caching is not supported.
268         char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
269         char* cacheDir = mkdtemp(cacheDirTemp);
270         ASSERT_NE(cacheDir, nullptr);
271         mCacheDir = cacheDir;
272         mCacheDir.push_back('/');
273 
274         Return<void> ret = device->getNumberOfCacheFilesNeeded(
275                 [this](ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
276                     EXPECT_EQ(ErrorStatus::NONE, status);
277                     mNumModelCache = numModelCache;
278                     mNumDataCache = numDataCache;
279                 });
280         EXPECT_TRUE(ret.isOk());
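        // The driver supports compilation caching iff it reports needing at least one model cache
        // or data cache file.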
281         mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
282 
283         // Create empty cache files.
284         mTmpCache = mCacheDir + "tmp";
285         for (uint32_t i = 0; i < mNumModelCache; i++) {
286             mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
287         }
288         for (uint32_t i = 0; i < mNumDataCache; i++) {
289             mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
290         }
291         // Dummy handles; using AccessMode::WRITE_ONLY makes createCacheHandles create the files.
292         hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
293         createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
294         createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
295         createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);
296 
297         if (!mIsCachingSupported) {
298             LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
299                          "support compilation caching.";
300             std::cout << "[          ]   Early termination of test because vendor service does not "
301                          "support compilation caching."
302                       << std::endl;
303         }
304     }
305 
306     void TearDown() override {
307         // If the test passes, remove the tmp directory.  Otherwise, keep it for debugging purposes.
308         if (!::testing::Test::HasFailure()) {
309             // Recursively remove the cache directory specified by mCacheDir.
310             auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
311                 return remove(entry);
312             };
313             nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
314         }
315         NeuralnetworksHidlTest::TearDown();
316     }
317 
318     // Model and example creators. Depending on kOperandType, the following methods return
319     // either the float32 or the quant8 variant of the model/examples.
320     Model createTestModel() {
321         if (kOperandType == OperandType::TENSOR_FLOAT32) {
322             return float32_model::createTestModel();
323         } else {
324             return quant8_model::createTestModel();
325         }
326     }
327 
328     std::vector<MixedTypedExample> get_examples() {
329         if (kOperandType == OperandType::TENSOR_FLOAT32) {
330             return float32_model::get_examples();
331         } else {
332             return quant8_model::get_examples();
333         }
334     }
335 
336     Model createLargeTestModel(OperationType op, uint32_t len) {
337         if (kOperandType == OperandType::TENSOR_FLOAT32) {
338             return createLargeTestModelImpl<float, OperandType::TENSOR_FLOAT32>(op, len);
339         } else {
340             return createLargeTestModelImpl<uint8_t, OperandType::TENSOR_QUANT8_ASYMM>(op, len);
341         }
342     }
343 
344     std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
345         if (kOperandType == OperandType::TENSOR_FLOAT32) {
346             return float32_model::getLargeModelExamples(len);
347         } else {
348             return quant8_model::getLargeModelExamples(len);
349         }
350     }
351 
352     // See if the service can handle the model.
353     bool isModelFullySupported(const V1_2::Model& model) {
354         bool fullySupportsModel = false;
355         Return<void> supportedCall = device->getSupportedOperations_1_2(
356                 model,
357                 [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
358                     ASSERT_EQ(ErrorStatus::NONE, status);
359                     ASSERT_EQ(supported.size(), model.operations.size());
360                     fullySupportsModel = std::all_of(supported.begin(), supported.end(),
361                                                      [](bool valid) { return valid; });
362                 });
363         EXPECT_TRUE(supportedCall.isOk());
364         return fullySupportsModel;
365     }
366 
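    // Compiles the model via prepareModel_1_2, passing the provided cache handles and mToken so
    // that the driver may save the compilation to the cache files. The preparation is expected to
    // succeed; if preparedModel is non-null, it receives the resulting prepared model.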
367     void saveModelToCache(const V1_2::Model& model, const hidl_vec<hidl_handle>& modelCache,
368                           const hidl_vec<hidl_handle>& dataCache,
369                           sp<IPreparedModel>* preparedModel = nullptr) {
370         if (preparedModel != nullptr) *preparedModel = nullptr;
371 
372         // Launch prepare model.
373         sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
374         ASSERT_NE(nullptr, preparedModelCallback.get());
375         hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
376         Return<ErrorStatus> prepareLaunchStatus =
377                 device->prepareModel_1_2(model, ExecutionPreference::FAST_SINGLE_ANSWER, modelCache,
378                                          dataCache, cacheToken, preparedModelCallback);
379         ASSERT_TRUE(prepareLaunchStatus.isOk());
380         ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);
381 
382         // Retrieve prepared model.
383         preparedModelCallback->wait();
384         ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
385         if (preparedModel != nullptr) {
386             *preparedModel =
387                     V1_2::IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
388                             .withDefault(nullptr);
389         }
390     }
391 
392     bool checkEarlyTermination(ErrorStatus status) {
393         if (status == ErrorStatus::GENERAL_FAILURE) {
394             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
395                          "save the prepared model that it does not support.";
396             std::cout << "[          ]   Early termination of test because vendor service cannot "
397                          "save the prepared model that it does not support."
398                       << std::endl;
399             return true;
400         }
401         return false;
402     }
403 
404     bool checkEarlyTermination(const V1_2::Model& model) {
405         if (!isModelFullySupported(model)) {
406             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
407                          "prepare model that it does not support.";
408             std::cout << "[          ]   Early termination of test because vendor service cannot "
409                          "prepare model that it does not support."
410                       << std::endl;
411             return true;
412         }
413         return false;
414     }
415 
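    // Attempts to prepare a model from the given cache handles via IDevice::prepareModelFromCache
    // using mToken. On return, *status holds the resulting error status and *preparedModel holds
    // the prepared model on success, or nullptr on failure.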
416     void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
417                                const hidl_vec<hidl_handle>& dataCache,
418                                sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
419         // Launch prepare model from cache.
420         sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
421         ASSERT_NE(nullptr, preparedModelCallback.get());
422         hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
423         Return<ErrorStatus> prepareLaunchStatus = device->prepareModelFromCache(
424                 modelCache, dataCache, cacheToken, preparedModelCallback);
425         ASSERT_TRUE(prepareLaunchStatus.isOk());
426         if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
427             *preparedModel = nullptr;
428             *status = static_cast<ErrorStatus>(prepareLaunchStatus);
429             return;
430         }
431 
432         // Retrieve prepared model.
433         preparedModelCallback->wait();
434         *status = preparedModelCallback->getStatus();
435         *preparedModel = V1_2::IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
436                                  .withDefault(nullptr);
437     }
438 
439     // Absolute path to the temporary cache directory.
440     std::string mCacheDir;
441 
442     // Groups of file paths for model and data cache in the tmp cache directory, initialized with
443     // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
444     // and the inner vector is for fds held by each handle.
445     std::vector<std::vector<std::string>> mModelCache;
446     std::vector<std::vector<std::string>> mDataCache;
447 
448     // A separate temporary file path in the tmp cache directory.
449     std::string mTmpCache;
450 
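    // Cache token passed to the driver, zero-initialized; the TOCTOU tests increment mToken[0] so
    // that the two test models use distinct tokens.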
451     uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
452     uint32_t mNumModelCache;
453     uint32_t mNumDataCache;
454     uint32_t mIsCachingSupported;
455 
456     // The primary data type of the testModel.
457     const OperandType kOperandType;
458 };
459 
460 // A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
461 // pass running with float32 models and the second pass running with quant8 models.
462 class CompilationCachingTest : public CompilationCachingTestBase,
463                                public ::testing::WithParamInterface<OperandType> {
464   protected:
465     CompilationCachingTest() : CompilationCachingTestBase(GetParam()) {}
466 };
467 
468 TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
469     // Create test HIDL model and compile.
470     const Model testModel = createTestModel();
471     if (checkEarlyTermination(testModel)) return;
472     sp<IPreparedModel> preparedModel = nullptr;
473 
474     // Save the compilation to cache.
475     {
476         hidl_vec<hidl_handle> modelCache, dataCache;
477         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
478         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
479         saveModelToCache(testModel, modelCache, dataCache);
480     }
481 
482     // Retrieve preparedModel from cache.
483     {
484         preparedModel = nullptr;
485         ErrorStatus status;
486         hidl_vec<hidl_handle> modelCache, dataCache;
487         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
488         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
489         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
490         if (!mIsCachingSupported) {
491             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
492             ASSERT_EQ(preparedModel, nullptr);
493             return;
494         } else if (checkEarlyTermination(status)) {
495             ASSERT_EQ(preparedModel, nullptr);
496             return;
497         } else {
498             ASSERT_EQ(status, ErrorStatus::NONE);
499             ASSERT_NE(preparedModel, nullptr);
500         }
501     }
502 
503     // Execute and verify results.
504     generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; }, get_examples(),
505                                            testModel.relaxComputationFloat32toFloat16,
506                                            /*testDynamicOutputShape=*/false);
507 }
508 
509 TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
510     // Create test HIDL model and compile.
511     const Model testModel = createTestModel();
512     if (checkEarlyTermination(testModel)) return;
513     sp<IPreparedModel> preparedModel = nullptr;
514 
515     // Save the compilation to cache.
516     {
517         hidl_vec<hidl_handle> modelCache, dataCache;
518         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
519         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
520         uint8_t dummyBytes[] = {0, 0};
521         // Write dummy bytes to the cache.
522         // The driver should be able to handle non-empty cache and non-zero fd offset.
523         for (uint32_t i = 0; i < modelCache.size(); i++) {
524             ASSERT_EQ(write(modelCache[i].getNativeHandle()->data[0], &dummyBytes,
525                             sizeof(dummyBytes)),
526                       sizeof(dummyBytes));
527         }
528         for (uint32_t i = 0; i < dataCache.size(); i++) {
529             ASSERT_EQ(
530                     write(dataCache[i].getNativeHandle()->data[0], &dummyBytes, sizeof(dummyBytes)),
531                     sizeof(dummyBytes));
532         }
533         saveModelToCache(testModel, modelCache, dataCache);
534     }
535 
536     // Retrieve preparedModel from cache.
537     {
538         preparedModel = nullptr;
539         ErrorStatus status;
540         hidl_vec<hidl_handle> modelCache, dataCache;
541         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
542         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
543         uint8_t dummyByte = 0;
544         // Advance the offset of each handle by one byte.
545         // The driver should be able to handle non-zero fd offset.
546         for (uint32_t i = 0; i < modelCache.size(); i++) {
547             ASSERT_GE(read(modelCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
548         }
549         for (uint32_t i = 0; i < dataCache.size(); i++) {
550             ASSERT_GE(read(dataCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
551         }
552         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
553         if (!mIsCachingSupported) {
554             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
555             ASSERT_EQ(preparedModel, nullptr);
556             return;
557         } else if (checkEarlyTermination(status)) {
558             ASSERT_EQ(preparedModel, nullptr);
559             return;
560         } else {
561             ASSERT_EQ(status, ErrorStatus::NONE);
562             ASSERT_NE(preparedModel, nullptr);
563         }
564     }
565 
566     // Execute and verify results.
567     generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; }, get_examples(),
568                                            testModel.relaxComputationFloat32toFloat16,
569                                            /*testDynamicOutputShape=*/false);
570 }
571 
572 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
573     // Create test HIDL model and compile.
574     const Model testModel = createTestModel();
575     if (checkEarlyTermination(testModel)) return;
576 
577     // Test with number of model cache files greater than mNumModelCache.
578     {
579         hidl_vec<hidl_handle> modelCache, dataCache;
580         // Pass an additional cache file for model cache.
581         mModelCache.push_back({mTmpCache});
582         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
583         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
584         mModelCache.pop_back();
585         sp<IPreparedModel> preparedModel = nullptr;
586         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
587         ASSERT_NE(preparedModel, nullptr);
588         // Execute and verify results.
589         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
590                                                get_examples(),
591                                                testModel.relaxComputationFloat32toFloat16,
592                                                /*testDynamicOutputShape=*/false);
593         // Check if prepareModelFromCache fails.
594         preparedModel = nullptr;
595         ErrorStatus status;
596         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
597         if (status != ErrorStatus::INVALID_ARGUMENT) {
598             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
599         }
600         ASSERT_EQ(preparedModel, nullptr);
601     }
602 
603     // Test with number of model cache files smaller than mNumModelCache.
604     if (mModelCache.size() > 0) {
605         hidl_vec<hidl_handle> modelCache, dataCache;
606         // Pop out the last cache file.
607         auto tmp = mModelCache.back();
608         mModelCache.pop_back();
609         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
610         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
611         mModelCache.push_back(tmp);
612         sp<IPreparedModel> preparedModel = nullptr;
613         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
614         ASSERT_NE(preparedModel, nullptr);
615         // Execute and verify results.
616         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
617                                                get_examples(),
618                                                testModel.relaxComputationFloat32toFloat16,
619                                                /*testDynamicOutputShape=*/false);
620         // Check if prepareModelFromCache fails.
621         preparedModel = nullptr;
622         ErrorStatus status;
623         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
624         if (status != ErrorStatus::INVALID_ARGUMENT) {
625             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
626         }
627         ASSERT_EQ(preparedModel, nullptr);
628     }
629 
630     // Test with number of data cache files greater than mNumDataCache.
631     {
632         hidl_vec<hidl_handle> modelCache, dataCache;
633         // Pass an additional cache file for data cache.
634         mDataCache.push_back({mTmpCache});
635         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
636         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
637         mDataCache.pop_back();
638         sp<IPreparedModel> preparedModel = nullptr;
639         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
640         ASSERT_NE(preparedModel, nullptr);
641         // Execute and verify results.
642         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
643                                                get_examples(),
644                                                testModel.relaxComputationFloat32toFloat16,
645                                                /*testDynamicOutputShape=*/false);
646         // Check if prepareModelFromCache fails.
647         preparedModel = nullptr;
648         ErrorStatus status;
649         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
650         if (status != ErrorStatus::INVALID_ARGUMENT) {
651             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
652         }
653         ASSERT_EQ(preparedModel, nullptr);
654     }
655 
656     // Test with number of data cache files smaller than mNumDataCache.
657     if (mDataCache.size() > 0) {
658         hidl_vec<hidl_handle> modelCache, dataCache;
659         // Pop out the last cache file.
660         auto tmp = mDataCache.back();
661         mDataCache.pop_back();
662         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
663         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
664         mDataCache.push_back(tmp);
665         sp<IPreparedModel> preparedModel = nullptr;
666         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
667         ASSERT_NE(preparedModel, nullptr);
668         // Execute and verify results.
669         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
670                                                get_examples(),
671                                                testModel.relaxComputationFloat32toFloat16,
672                                                /*testDynamicOutputShape=*/false);
673         // Check if prepareModelFromCache fails.
674         preparedModel = nullptr;
675         ErrorStatus status;
676         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
677         if (status != ErrorStatus::INVALID_ARGUMENT) {
678             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
679         }
680         ASSERT_EQ(preparedModel, nullptr);
681     }
682 }
683 
684 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
685     // Create test HIDL model and compile.
686     const Model testModel = createTestModel();
687     if (checkEarlyTermination(testModel)) return;
688 
689     // Save the compilation to cache.
690     {
691         hidl_vec<hidl_handle> modelCache, dataCache;
692         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
693         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
694         saveModelToCache(testModel, modelCache, dataCache);
695     }
696 
697     // Test with number of model cache files greater than mNumModelCache.
698     {
699         sp<IPreparedModel> preparedModel = nullptr;
700         ErrorStatus status;
701         hidl_vec<hidl_handle> modelCache, dataCache;
702         mModelCache.push_back({mTmpCache});
703         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
704         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
705         mModelCache.pop_back();
706         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
707         if (status != ErrorStatus::GENERAL_FAILURE) {
708             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
709         }
710         ASSERT_EQ(preparedModel, nullptr);
711     }
712 
713     // Test with number of model cache files smaller than mNumModelCache.
714     if (mModelCache.size() > 0) {
715         sp<IPreparedModel> preparedModel = nullptr;
716         ErrorStatus status;
717         hidl_vec<hidl_handle> modelCache, dataCache;
718         auto tmp = mModelCache.back();
719         mModelCache.pop_back();
720         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
721         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
722         mModelCache.push_back(tmp);
723         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
724         if (status != ErrorStatus::GENERAL_FAILURE) {
725             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
726         }
727         ASSERT_EQ(preparedModel, nullptr);
728     }
729 
730     // Test with number of data cache files greater than mNumDataCache.
731     {
732         sp<IPreparedModel> preparedModel = nullptr;
733         ErrorStatus status;
734         hidl_vec<hidl_handle> modelCache, dataCache;
735         mDataCache.push_back({mTmpCache});
736         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
737         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
738         mDataCache.pop_back();
739         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
740         if (status != ErrorStatus::GENERAL_FAILURE) {
741             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
742         }
743         ASSERT_EQ(preparedModel, nullptr);
744     }
745 
746     // Test with number of data cache files smaller than mNumDataCache.
747     if (mDataCache.size() > 0) {
748         sp<IPreparedModel> preparedModel = nullptr;
749         ErrorStatus status;
750         hidl_vec<hidl_handle> modelCache, dataCache;
751         auto tmp = mDataCache.back();
752         mDataCache.pop_back();
753         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
754         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
755         mDataCache.push_back(tmp);
756         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
757         if (status != ErrorStatus::GENERAL_FAILURE) {
758             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
759         }
760         ASSERT_EQ(preparedModel, nullptr);
761     }
762 }
763 
764 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
765     // Create test HIDL model and compile.
766     const Model testModel = createTestModel();
767     if (checkEarlyTermination(testModel)) return;
768 
769     // Go through each handle in model cache, test with NumFd greater than 1.
770     for (uint32_t i = 0; i < mNumModelCache; i++) {
771         hidl_vec<hidl_handle> modelCache, dataCache;
772         // Pass an invalid number of fds for handle i.
773         mModelCache[i].push_back(mTmpCache);
774         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
775         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
776         mModelCache[i].pop_back();
777         sp<IPreparedModel> preparedModel = nullptr;
778         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
779         ASSERT_NE(preparedModel, nullptr);
780         // Execute and verify results.
781         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
782                                                get_examples(),
783                                                testModel.relaxComputationFloat32toFloat16,
784                                                /*testDynamicOutputShape=*/false);
785         // Check if prepareModelFromCache fails.
786         preparedModel = nullptr;
787         ErrorStatus status;
788         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
789         if (status != ErrorStatus::INVALID_ARGUMENT) {
790             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
791         }
792         ASSERT_EQ(preparedModel, nullptr);
793     }
794 
795     // Go through each handle in model cache, test with NumFd equal to 0.
796     for (uint32_t i = 0; i < mNumModelCache; i++) {
797         hidl_vec<hidl_handle> modelCache, dataCache;
798         // Pass an invalid number of fds for handle i.
799         auto tmp = mModelCache[i].back();
800         mModelCache[i].pop_back();
801         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
802         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
803         mModelCache[i].push_back(tmp);
804         sp<IPreparedModel> preparedModel = nullptr;
805         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
806         ASSERT_NE(preparedModel, nullptr);
807         // Execute and verify results.
808         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
809                                                get_examples(),
810                                                testModel.relaxComputationFloat32toFloat16,
811                                                /*testDynamicOutputShape=*/false);
812         // Check if prepareModelFromCache fails.
813         preparedModel = nullptr;
814         ErrorStatus status;
815         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
816         if (status != ErrorStatus::INVALID_ARGUMENT) {
817             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
818         }
819         ASSERT_EQ(preparedModel, nullptr);
820     }
821 
822     // Go through each handle in data cache, test with NumFd greater than 1.
823     for (uint32_t i = 0; i < mNumDataCache; i++) {
824         hidl_vec<hidl_handle> modelCache, dataCache;
825         // Pass an invalid number of fds for handle i.
826         mDataCache[i].push_back(mTmpCache);
827         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
828         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
829         mDataCache[i].pop_back();
830         sp<IPreparedModel> preparedModel = nullptr;
831         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
832         ASSERT_NE(preparedModel, nullptr);
833         // Execute and verify results.
834         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
835                                                get_examples(),
836                                                testModel.relaxComputationFloat32toFloat16,
837                                                /*testDynamicOutputShape=*/false);
838         // Check if prepareModelFromCache fails.
839         preparedModel = nullptr;
840         ErrorStatus status;
841         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
842         if (status != ErrorStatus::INVALID_ARGUMENT) {
843             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
844         }
845         ASSERT_EQ(preparedModel, nullptr);
846     }
847 
848     // Go through each handle in data cache, test with NumFd equal to 0.
849     for (uint32_t i = 0; i < mNumDataCache; i++) {
850         hidl_vec<hidl_handle> modelCache, dataCache;
851         // Pass an invalid number of fds for handle i.
852         auto tmp = mDataCache[i].back();
853         mDataCache[i].pop_back();
854         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
855         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
856         mDataCache[i].push_back(tmp);
857         sp<IPreparedModel> preparedModel = nullptr;
858         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
859         ASSERT_NE(preparedModel, nullptr);
860         // Execute and verify results.
861         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
862                                                get_examples(),
863                                                testModel.relaxComputationFloat32toFloat16,
864                                                /*testDynamicOutputShape=*/false);
865         // Check if prepareModelFromCache fails.
866         preparedModel = nullptr;
867         ErrorStatus status;
868         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
869         if (status != ErrorStatus::INVALID_ARGUMENT) {
870             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
871         }
872         ASSERT_EQ(preparedModel, nullptr);
873     }
874 }
875 
876 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
877     // Create test HIDL model and compile.
878     const Model testModel = createTestModel();
879     if (checkEarlyTermination(testModel)) return;
880 
881     // Save the compilation to cache.
882     {
883         hidl_vec<hidl_handle> modelCache, dataCache;
884         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
885         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
886         saveModelToCache(testModel, modelCache, dataCache);
887     }
888 
889     // Go through each handle in model cache, test with NumFd greater than 1.
890     for (uint32_t i = 0; i < mNumModelCache; i++) {
891         sp<IPreparedModel> preparedModel = nullptr;
892         ErrorStatus status;
893         hidl_vec<hidl_handle> modelCache, dataCache;
894         mModelCache[i].push_back(mTmpCache);
895         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
896         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
897         mModelCache[i].pop_back();
898         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
899         if (status != ErrorStatus::GENERAL_FAILURE) {
900             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
901         }
902         ASSERT_EQ(preparedModel, nullptr);
903     }
904 
905     // Go through each handle in model cache, test with NumFd equal to 0.
906     for (uint32_t i = 0; i < mNumModelCache; i++) {
907         sp<IPreparedModel> preparedModel = nullptr;
908         ErrorStatus status;
909         hidl_vec<hidl_handle> modelCache, dataCache;
910         auto tmp = mModelCache[i].back();
911         mModelCache[i].pop_back();
912         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
913         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
914         mModelCache[i].push_back(tmp);
915         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
916         if (status != ErrorStatus::GENERAL_FAILURE) {
917             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
918         }
919         ASSERT_EQ(preparedModel, nullptr);
920     }
921 
922     // Go through each handle in data cache, test with NumFd greater than 1.
923     for (uint32_t i = 0; i < mNumDataCache; i++) {
924         sp<IPreparedModel> preparedModel = nullptr;
925         ErrorStatus status;
926         hidl_vec<hidl_handle> modelCache, dataCache;
927         mDataCache[i].push_back(mTmpCache);
928         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
929         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
930         mDataCache[i].pop_back();
931         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
932         if (status != ErrorStatus::GENERAL_FAILURE) {
933             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
934         }
935         ASSERT_EQ(preparedModel, nullptr);
936     }
937 
938     // Go through each handle in data cache, test with NumFd equal to 0.
939     for (uint32_t i = 0; i < mNumDataCache; i++) {
940         sp<IPreparedModel> preparedModel = nullptr;
941         ErrorStatus status;
942         hidl_vec<hidl_handle> modelCache, dataCache;
943         auto tmp = mDataCache[i].back();
944         mDataCache[i].pop_back();
945         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
946         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
947         mDataCache[i].push_back(tmp);
948         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
949         if (status != ErrorStatus::GENERAL_FAILURE) {
950             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
951         }
952         ASSERT_EQ(preparedModel, nullptr);
953     }
954 }
955 
956 TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
957     // Create test HIDL model and compile.
958     const Model testModel = createTestModel();
959     if (checkEarlyTermination(testModel)) return;
960     std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
961     std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);
962 
963     // Go through each handle in model cache, test with invalid access mode.
964     for (uint32_t i = 0; i < mNumModelCache; i++) {
965         hidl_vec<hidl_handle> modelCache, dataCache;
966         modelCacheMode[i] = AccessMode::READ_ONLY;
967         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
968         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
969         modelCacheMode[i] = AccessMode::READ_WRITE;
970         sp<IPreparedModel> preparedModel = nullptr;
971         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
972         ASSERT_NE(preparedModel, nullptr);
973         // Execute and verify results.
974         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
975                                                get_examples(),
976                                                testModel.relaxComputationFloat32toFloat16,
977                                                /*testDynamicOutputShape=*/false);
978         // Check if prepareModelFromCache fails.
979         preparedModel = nullptr;
980         ErrorStatus status;
981         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
982         if (status != ErrorStatus::INVALID_ARGUMENT) {
983             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
984         }
985         ASSERT_EQ(preparedModel, nullptr);
986     }
987 
988     // Go through each handle in data cache, test with invalid access mode.
989     for (uint32_t i = 0; i < mNumDataCache; i++) {
990         hidl_vec<hidl_handle> modelCache, dataCache;
991         dataCacheMode[i] = AccessMode::READ_ONLY;
992         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
993         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
994         dataCacheMode[i] = AccessMode::READ_WRITE;
995         sp<IPreparedModel> preparedModel = nullptr;
996         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
997         ASSERT_NE(preparedModel, nullptr);
998         // Execute and verify results.
999         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
1000                                                get_examples(),
1001                                                testModel.relaxComputationFloat32toFloat16,
1002                                                /*testDynamicOutputShape=*/false);
1003         // Check if prepareModelFromCache fails.
1004         preparedModel = nullptr;
1005         ErrorStatus status;
1006         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1007         if (status != ErrorStatus::INVALID_ARGUMENT) {
1008             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1009         }
1010         ASSERT_EQ(preparedModel, nullptr);
1011     }
1012 }
1013 
1014 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
1015     // Create test HIDL model and compile.
1016     const Model testModel = createTestModel();
1017     if (checkEarlyTermination(testModel)) return;
1018     std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
1019     std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);
1020 
1021     // Save the compilation to cache.
1022     {
1023         hidl_vec<hidl_handle> modelCache, dataCache;
1024         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1025         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1026         saveModelToCache(testModel, modelCache, dataCache);
1027     }
1028 
1029     // Go through each handle in model cache, test with invalid access mode.
1030     for (uint32_t i = 0; i < mNumModelCache; i++) {
1031         sp<IPreparedModel> preparedModel = nullptr;
1032         ErrorStatus status;
1033         hidl_vec<hidl_handle> modelCache, dataCache;
1034         modelCacheMode[i] = AccessMode::WRITE_ONLY;
1035         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
1036         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
1037         modelCacheMode[i] = AccessMode::READ_WRITE;
1038         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1039         ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1040         ASSERT_EQ(preparedModel, nullptr);
1041     }
1042 
1043     // Go through each handle in data cache, test with invalid access mode.
1044     for (uint32_t i = 0; i < mNumDataCache; i++) {
1045         sp<IPreparedModel> preparedModel = nullptr;
1046         ErrorStatus status;
1047         hidl_vec<hidl_handle> modelCache, dataCache;
1048         dataCacheMode[i] = AccessMode::WRITE_ONLY;
1049         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
1050         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
1051         dataCacheMode[i] = AccessMode::READ_WRITE;
1052         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1053         ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1054         ASSERT_EQ(preparedModel, nullptr);
1055     }
1056 }
1057 
1058 // Copy file contents between file groups.
1059 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
1060 // The outer vector sizes must match and the inner vectors must have size = 1.
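// Used by the TOCTOU tests below to overwrite one model's cache files with another model's files
// while the driver is concurrently saving to or preparing from them.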
1061 static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
1062                            const std::vector<std::vector<std::string>>& to) {
1063     constexpr size_t kBufferSize = 1000000;
1064     uint8_t buffer[kBufferSize];
1065 
1066     ASSERT_EQ(from.size(), to.size());
1067     for (uint32_t i = 0; i < from.size(); i++) {
1068         ASSERT_EQ(from[i].size(), 1u);
1069         ASSERT_EQ(to[i].size(), 1u);
1070         int fromFd = open(from[i][0].c_str(), O_RDONLY);
1071         int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
1072         ASSERT_GE(fromFd, 0);
1073         ASSERT_GE(toFd, 0);
1074 
1075         ssize_t readBytes;
1076         while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
1077             ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
1078         }
1079         ASSERT_GE(readBytes, 0);
1080 
1081         close(fromFd);
1082         close(toFd);
1083     }
1084 }
1085 
1086 // Number of operations in the large test model.
1087 constexpr uint32_t kLargeModelSize = 100;
1088 constexpr uint32_t kNumIterationsTOCTOU = 100;
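// The large model is intended to make saving to or preparing from the cache slow enough that the
// concurrently spawned copy thread is likely to overlap with it; because the race is
// probabilistic, each TOCTOU test repeats kNumIterationsTOCTOU times.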
1089 
1090 TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
1091     if (!mIsCachingSupported) return;
1092 
1093     // Create test models and check if fully supported by the service.
1094     const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
1095     if (checkEarlyTermination(testModelMul)) return;
1096     const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
1097     if (checkEarlyTermination(testModelAdd)) return;
1098 
1099     // Save the testModelMul compilation to cache.
1100     auto modelCacheMul = mModelCache;
1101     for (auto& cache : modelCacheMul) {
1102         cache[0].append("_mul");
1103     }
1104     {
1105         hidl_vec<hidl_handle> modelCache, dataCache;
1106         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
1107         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1108         saveModelToCache(testModelMul, modelCache, dataCache);
1109     }
1110 
1111     // Use a different token for testModelAdd.
1112     mToken[0]++;
1113 
1114     // This test is probabilistic, so we run it multiple times.
1115     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
1116         // Save the testModelAdd compilation to cache.
1117         {
1118             hidl_vec<hidl_handle> modelCache, dataCache;
1119             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1120             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1121 
1122             // Spawn a thread to copy the cache content concurrently while saving to cache.
1123             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
1124             saveModelToCache(testModelAdd, modelCache, dataCache);
1125             thread.join();
1126         }
1127 
1128         // Retrieve preparedModel from cache.
1129         {
1130             sp<IPreparedModel> preparedModel = nullptr;
1131             ErrorStatus status;
1132             hidl_vec<hidl_handle> modelCache, dataCache;
1133             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1134             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1135             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1136 
1137             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
1138             // the prepared model must be executed with the correct result and not crash.
1139             if (status != ErrorStatus::NONE) {
1140                 ASSERT_EQ(preparedModel, nullptr);
1141             } else {
1142                 ASSERT_NE(preparedModel, nullptr);
1143                 generated_tests::EvaluatePreparedModel(
1144                         preparedModel, [](int) { return false; },
1145                         getLargeModelExamples(kLargeModelSize),
1146                         testModelAdd.relaxComputationFloat32toFloat16,
1147                         /*testDynamicOutputShape=*/false);
1148             }
1149         }
1150     }
1151 }
1152 
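// Same setup as SaveToCache_TOCTOU, except that the concurrent copy races against
// prepareModelFromCache instead of saveModelToCache.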
TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    if (checkEarlyTermination(testModelMul)) return;
    const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    if (checkEarlyTermination(testModelAdd)) return;

    // Save the testModelMul compilation to cache.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(testModelMul, modelCache, dataCache);
    }

    // Use a different token for testModelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the testModelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            saveModelToCache(testModelAdd, modelCache, dataCache);
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while preparing from cache.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
            thread.join();

            // The preparation may fail or succeed, but must not crash. If it succeeds, the
            // prepared model must compute the correct result when executed, and must not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                generated_tests::EvaluatePreparedModel(
                        preparedModel, [](int) { return false; },
                        getLargeModelExamples(kLargeModelSize),
                        testModelAdd.relaxComputationFloat32toFloat16,
                        /*testDynamicOutputShape=*/false);
            }
        }
    }
}

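// Both models are cached under different tokens; the model cache of testModelAdd is then
// overwritten with the files cached for testModelMul. Preparing from the mismatched cache must be
// rejected with GENERAL_FAILURE rather than silently producing a different model.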
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    if (checkEarlyTermination(testModelMul)) return;
    const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    if (checkEarlyTermination(testModelAdd)) return;

    // Save the testModelMul compilation to cache.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(testModelMul, modelCache, dataCache);
    }

    // Use a different token for testModelAdd.
    mToken[0]++;

    // Save the testModelAdd compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(testModelAdd, modelCache, dataCache);
    }

    // Replace the model cache of testModelAdd with testModelMul.
    copyCacheFiles(modelCacheMul, mModelCache);

    // Retrieve the preparedModel from cache, expect failure.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}

static const auto kOperandTypeChoices =
        ::testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);

INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest, kOperandTypeChoices);

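// Security-focused variants of the caching tests. Each test is parameterized over an operand type
// and a random seed (see the INSTANTIATE_TEST_CASE_P at the end of this file) and corrupts the
// cached artifacts, or the token, between saving to cache and preparing from cache.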
class CompilationCachingSecurityTest
    : public CompilationCachingTestBase,
      public ::testing::WithParamInterface<std::tuple<OperandType, uint32_t>> {
  protected:
    CompilationCachingSecurityTest() : CompilationCachingTestBase(std::get<0>(GetParam())) {}

    void SetUp() override {
        CompilationCachingTestBase::SetUp();
        generator.seed(kSeed);
    }

    // Get a random integer within the closed range [lower, upper].
    template <typename T>
    T getRandomInt(T lower, T upper) {
        std::uniform_int_distribution<T> dis(lower, upper);
        return dis(generator);
    }

    // Randomly flip a single bit of the cache entry.
    void flipOneBitOfCache(const std::string& filename, bool* skip) {
        FILE* pFile = fopen(filename.c_str(), "r+");
        ASSERT_NE(pFile, nullptr);
        ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
        long int fileSize = ftell(pFile);
        if (fileSize == 0) {
            fclose(pFile);
            *skip = true;
            return;
        }
        ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
        int readByte = fgetc(pFile);
        ASSERT_NE(readByte, EOF);
        ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
        ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
        fclose(pFile);
        *skip = false;
    }

    // Randomly append bytes to the cache entry.
    void appendBytesToCache(const std::string& filename, bool* skip) {
        FILE* pFile = fopen(filename.c_str(), "a");
        ASSERT_NE(pFile, nullptr);
        uint32_t appendLength = getRandomInt(1, 256);
        for (uint32_t i = 0; i < appendLength; i++) {
            ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
        }
        fclose(pFile);
        *skip = false;
    }

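    // Expected outcome when preparing from a corrupted cache:
    //   GENERAL_FAILURE - the corruption must be detected and the preparation must fail.
    //   NOT_CRASH       - the preparation may fail or succeed, but the returned prepared model
    //                     must be consistent with the reported status.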
    enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };

    // Test if the driver behaves as expected when given a corrupted cache or token.
    // The modifier will be invoked after saving to cache but before preparing from cache.
    // The modifier reports through its single out-parameter "skip" whether the test should be
    // skipped.
    void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
        const Model testModel = createTestModel();
        if (checkEarlyTermination(testModel)) return;

        // Save the compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            saveModelToCache(testModel, modelCache, dataCache);
        }

        bool skip = false;
        modifier(&skip);
        if (skip) return;

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);

            switch (expected) {
                case ExpectedResult::GENERAL_FAILURE:
                    ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
                    ASSERT_EQ(preparedModel, nullptr);
                    break;
                case ExpectedResult::NOT_CRASH:
                    ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
                    break;
                default:
                    FAIL();
            }
        }
    }

    const uint32_t kSeed = std::get<1>(GetParam());
    std::mt19937 generator;
};

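// The model cache is security sensitive: any corruption of it (or of the token) must be detected
// and reported as GENERAL_FAILURE. The data cache only needs to be handled safely: preparation
// may fail or succeed, but the driver must not crash.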
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE,
                           [this, i](bool* skip) { flipOneBitOfCache(mModelCache[i][0], skip); });
    }
}

TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE,
                           [this, i](bool* skip) { appendBytesToCache(mModelCache[i][0], skip); });
    }
}

TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        testCorruptedCache(ExpectedResult::NOT_CRASH,
                           [this, i](bool* skip) { flipOneBitOfCache(mDataCache[i][0], skip); });
    }
}

TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        testCorruptedCache(ExpectedResult::NOT_CRASH,
                           [this, i](bool* skip) { appendBytesToCache(mDataCache[i][0], skip); });
    }
}

TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, [this](bool* skip) {
        // Randomly flip a single bit of mToken.
        uint32_t ind =
                getRandomInt(0u, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1);
        mToken[ind] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    });
}

INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingSecurityTest,
                        ::testing::Combine(kOperandTypeChoices, ::testing::Range(0U, 10U)));

}  // namespace functional
}  // namespace vts
}  // namespace V1_2
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android