1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19 #include <android-base/logging.h>
20 #include <android/hidl/memory/1.0/IMemory.h>
21 #include <ftw.h>
22 #include <gtest/gtest.h>
23 #include <hidlmemory/mapping.h>
24 #include <unistd.h>
25
26 #include <cstdio>
27 #include <cstdlib>
28 #include <random>
29
30 #include "Callbacks.h"
31 #include "GeneratedTestHarness.h"
32 #include "TestHarness.h"
33 #include "Utils.h"
34 #include "VtsHalNeuralnetworks.h"
35
36 namespace android {
37 namespace hardware {
38 namespace neuralnetworks {
39 namespace V1_2 {
40 namespace vts {
41 namespace functional {
42
43 using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
44 using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
45 using ::android::nn::allocateSharedMemory;
46 using ::test_helper::MixedTypedExample;
47
48 namespace float32_model {
49
50 // In frameworks/ml/nn/runtime/test/generated/, creates a hidl model of float32 mobilenet.
51 #include "examples/mobilenet_224_gender_basic_fixed.example.cpp"
52 #include "vts_models/mobilenet_224_gender_basic_fixed.model.cpp"
53
54 // Prevent the compiler from complaining about an otherwise unused function.
55 [[maybe_unused]] auto dummy_createTestModel = createTestModel_dynamic_output_shape;
56 [[maybe_unused]] auto dummy_get_examples = get_examples_dynamic_output_shape;
57
58 // MixedTypedExample is defined in frameworks/ml/nn/tools/test_generator/include/TestHarness.h.
59 // This function assumes the operation is always ADD.
getLargeModelExamples(uint32_t len)60 std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
61 float outputValue = 1.0f + static_cast<float>(len);
62 return {{.operands = {
63 // Input
64 {.operandDimensions = {{0, {1}}}, .float32Operands = {{0, {1.0f}}}},
65 // Output
66 {.operandDimensions = {{0, {1}}}, .float32Operands = {{0, {outputValue}}}}}}};
67 }
68
69 } // namespace float32_model
70
71 namespace quant8_model {
72
73 // In frameworks/ml/nn/runtime/test/generated/, creates a hidl model of quant8 mobilenet.
74 #include "examples/mobilenet_quantized.example.cpp"
75 #include "vts_models/mobilenet_quantized.model.cpp"
76
77 // Prevent the compiler from complaining about an otherwise unused function.
78 [[maybe_unused]] auto dummy_createTestModel = createTestModel_dynamic_output_shape;
79 [[maybe_unused]] auto dummy_get_examples = get_examples_dynamic_output_shape;
80
81 // MixedTypedExample is defined in frameworks/ml/nn/tools/test_generator/include/TestHarness.h.
82 // This function assumes the operation is always ADD.
getLargeModelExamples(uint32_t len)83 std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
84 uint8_t outputValue = 1 + static_cast<uint8_t>(len);
85 return {{.operands = {// Input
86 {.operandDimensions = {{0, {1}}}, .quant8AsymmOperands = {{0, {1}}}},
87 // Output
88 {.operandDimensions = {{0, {1}}},
89 .quant8AsymmOperands = {{0, {outputValue}}}}}}};
90 }
91
92 } // namespace quant8_model
93
94 namespace {
95
// How createCacheHandles opens each cache file.
enum class AccessMode {
    READ_WRITE,  // opened with O_RDWR | O_CREAT
    READ_ONLY,   // opened with O_RDONLY
    WRITE_ONLY,  // opened with O_WRONLY | O_CREAT
};
97
98 // Creates cache handles based on provided file groups.
99 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,const std::vector<AccessMode> & mode,hidl_vec<hidl_handle> * handles)100 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
101 const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
102 handles->resize(fileGroups.size());
103 for (uint32_t i = 0; i < fileGroups.size(); i++) {
104 std::vector<int> fds;
105 for (const auto& file : fileGroups[i]) {
106 int fd;
107 if (mode[i] == AccessMode::READ_ONLY) {
108 fd = open(file.c_str(), O_RDONLY);
109 } else if (mode[i] == AccessMode::WRITE_ONLY) {
110 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
111 } else if (mode[i] == AccessMode::READ_WRITE) {
112 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
113 } else {
114 FAIL();
115 }
116 ASSERT_GE(fd, 0);
117 fds.push_back(fd);
118 }
119 native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
120 ASSERT_NE(cacheNativeHandle, nullptr);
121 std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
122 (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
123 }
124 }
125
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,AccessMode mode,hidl_vec<hidl_handle> * handles)126 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
127 hidl_vec<hidl_handle>* handles) {
128 createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
129 }
130
131 // Create a chain of broadcast operations. The second operand is always constant tensor [1].
132 // For simplicity, activation scalar is shared. The second operand is not shared
133 // in the model to let driver maintain a non-trivial size of constant data and the corresponding
134 // data locations in cache.
135 //
136 // --------- activation --------
137 // ↓ ↓ ↓ ↓
138 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
139 // ↑ ↑ ↑ ↑
140 // [1] [1] [1] [1]
141 //
142 // This function assumes the operation is either ADD or MUL.
143 template <typename CppType, OperandType operandType>
createLargeTestModelImpl(OperationType op,uint32_t len)144 Model createLargeTestModelImpl(OperationType op, uint32_t len) {
145 EXPECT_TRUE(op == OperationType::ADD || op == OperationType::MUL);
146
147 // Model operations and operands.
148 std::vector<Operation> operations(len);
149 std::vector<Operand> operands(len * 2 + 2);
150
151 // The constant buffer pool. This contains the activation scalar, followed by the
152 // per-operation constant operands.
153 std::vector<uint8_t> operandValues(sizeof(int32_t) + len * sizeof(CppType));
154
155 // The activation scalar, value = 0.
156 operands[0] = {
157 .type = OperandType::INT32,
158 .dimensions = {},
159 .numberOfConsumers = len,
160 .scale = 0.0f,
161 .zeroPoint = 0,
162 .lifetime = OperandLifeTime::CONSTANT_COPY,
163 .location = {.poolIndex = 0, .offset = 0, .length = sizeof(int32_t)},
164 };
165 memset(operandValues.data(), 0, sizeof(int32_t));
166
167 // The buffer value of the constant second operand. The logical value is always 1.0f.
168 CppType bufferValue;
169 // The scale of the first and second operand.
170 float scale1, scale2;
171 if (operandType == OperandType::TENSOR_FLOAT32) {
172 bufferValue = 1.0f;
173 scale1 = 0.0f;
174 scale2 = 0.0f;
175 } else if (op == OperationType::ADD) {
176 bufferValue = 1;
177 scale1 = 1.0f;
178 scale2 = 1.0f;
179 } else {
180 // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
181 // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
182 bufferValue = 2;
183 scale1 = 1.0f;
184 scale2 = 0.5f;
185 }
186
187 for (uint32_t i = 0; i < len; i++) {
188 const uint32_t firstInputIndex = i * 2 + 1;
189 const uint32_t secondInputIndex = firstInputIndex + 1;
190 const uint32_t outputIndex = secondInputIndex + 1;
191
192 // The first operation input.
193 operands[firstInputIndex] = {
194 .type = operandType,
195 .dimensions = {1},
196 .numberOfConsumers = 1,
197 .scale = scale1,
198 .zeroPoint = 0,
199 .lifetime = (i == 0 ? OperandLifeTime::MODEL_INPUT
200 : OperandLifeTime::TEMPORARY_VARIABLE),
201 .location = {},
202 };
203
204 // The second operation input, value = 1.
205 operands[secondInputIndex] = {
206 .type = operandType,
207 .dimensions = {1},
208 .numberOfConsumers = 1,
209 .scale = scale2,
210 .zeroPoint = 0,
211 .lifetime = OperandLifeTime::CONSTANT_COPY,
212 .location = {.poolIndex = 0,
213 .offset = static_cast<uint32_t>(i * sizeof(CppType) + sizeof(int32_t)),
214 .length = sizeof(CppType)},
215 };
216 memcpy(operandValues.data() + sizeof(int32_t) + i * sizeof(CppType), &bufferValue,
217 sizeof(CppType));
218
219 // The operation. All operations share the same activation scalar.
220 // The output operand is created as an input in the next iteration of the loop, in the case
221 // of all but the last member of the chain; and after the loop as a model output, in the
222 // case of the last member of the chain.
223 operations[i] = {
224 .type = op,
225 .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
226 .outputs = {outputIndex},
227 };
228 }
229
230 // The model output.
231 operands.back() = {
232 .type = operandType,
233 .dimensions = {1},
234 .numberOfConsumers = 0,
235 .scale = scale1,
236 .zeroPoint = 0,
237 .lifetime = OperandLifeTime::MODEL_OUTPUT,
238 .location = {},
239 };
240
241 const std::vector<uint32_t> inputIndexes = {1};
242 const std::vector<uint32_t> outputIndexes = {len * 2 + 1};
243 const std::vector<hidl_memory> pools = {};
244
245 return {
246 .operands = operands,
247 .operations = operations,
248 .inputIndexes = inputIndexes,
249 .outputIndexes = outputIndexes,
250 .operandValues = operandValues,
251 .pools = pools,
252 };
253 }
254
255 } // namespace
256
257 // Tag for the compilation caching tests.
258 class CompilationCachingTestBase : public NeuralnetworksHidlTest {
259 protected:
CompilationCachingTestBase(OperandType type)260 CompilationCachingTestBase(OperandType type) : kOperandType(type) {}
261
SetUp()262 void SetUp() override {
263 NeuralnetworksHidlTest::SetUp();
264 ASSERT_NE(device.get(), nullptr);
265
266 // Create cache directory. The cache directory and a temporary cache file is always created
267 // to test the behavior of prepareModelFromCache, even when caching is not supported.
268 char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
269 char* cacheDir = mkdtemp(cacheDirTemp);
270 ASSERT_NE(cacheDir, nullptr);
271 mCacheDir = cacheDir;
272 mCacheDir.push_back('/');
273
274 Return<void> ret = device->getNumberOfCacheFilesNeeded(
275 [this](ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
276 EXPECT_EQ(ErrorStatus::NONE, status);
277 mNumModelCache = numModelCache;
278 mNumDataCache = numDataCache;
279 });
280 EXPECT_TRUE(ret.isOk());
281 mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
282
283 // Create empty cache files.
284 mTmpCache = mCacheDir + "tmp";
285 for (uint32_t i = 0; i < mNumModelCache; i++) {
286 mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
287 }
288 for (uint32_t i = 0; i < mNumDataCache; i++) {
289 mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
290 }
291 // Dummy handles, use AccessMode::WRITE_ONLY for createCacheHandles to create files.
292 hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
293 createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
294 createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
295 createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);
296
297 if (!mIsCachingSupported) {
298 LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
299 "support compilation caching.";
300 std::cout << "[ ] Early termination of test because vendor service does not "
301 "support compilation caching."
302 << std::endl;
303 }
304 }
305
TearDown()306 void TearDown() override {
307 // If the test passes, remove the tmp directory. Otherwise, keep it for debugging purposes.
308 if (!::testing::Test::HasFailure()) {
309 // Recursively remove the cache directory specified by mCacheDir.
310 auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
311 return remove(entry);
312 };
313 nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
314 }
315 NeuralnetworksHidlTest::TearDown();
316 }
317
318 // Model and examples creators. According to kOperandType, the following methods will return
319 // either float32 model/examples or the quant8 variant.
createTestModel()320 Model createTestModel() {
321 if (kOperandType == OperandType::TENSOR_FLOAT32) {
322 return float32_model::createTestModel();
323 } else {
324 return quant8_model::createTestModel();
325 }
326 }
327
get_examples()328 std::vector<MixedTypedExample> get_examples() {
329 if (kOperandType == OperandType::TENSOR_FLOAT32) {
330 return float32_model::get_examples();
331 } else {
332 return quant8_model::get_examples();
333 }
334 }
335
createLargeTestModel(OperationType op,uint32_t len)336 Model createLargeTestModel(OperationType op, uint32_t len) {
337 if (kOperandType == OperandType::TENSOR_FLOAT32) {
338 return createLargeTestModelImpl<float, OperandType::TENSOR_FLOAT32>(op, len);
339 } else {
340 return createLargeTestModelImpl<uint8_t, OperandType::TENSOR_QUANT8_ASYMM>(op, len);
341 }
342 }
343
getLargeModelExamples(uint32_t len)344 std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
345 if (kOperandType == OperandType::TENSOR_FLOAT32) {
346 return float32_model::getLargeModelExamples(len);
347 } else {
348 return quant8_model::getLargeModelExamples(len);
349 }
350 }
351
352 // See if the service can handle the model.
isModelFullySupported(const V1_2::Model & model)353 bool isModelFullySupported(const V1_2::Model& model) {
354 bool fullySupportsModel = false;
355 Return<void> supportedCall = device->getSupportedOperations_1_2(
356 model,
357 [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
358 ASSERT_EQ(ErrorStatus::NONE, status);
359 ASSERT_EQ(supported.size(), model.operations.size());
360 fullySupportsModel = std::all_of(supported.begin(), supported.end(),
361 [](bool valid) { return valid; });
362 });
363 EXPECT_TRUE(supportedCall.isOk());
364 return fullySupportsModel;
365 }
366
saveModelToCache(const V1_2::Model & model,const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,sp<IPreparedModel> * preparedModel=nullptr)367 void saveModelToCache(const V1_2::Model& model, const hidl_vec<hidl_handle>& modelCache,
368 const hidl_vec<hidl_handle>& dataCache,
369 sp<IPreparedModel>* preparedModel = nullptr) {
370 if (preparedModel != nullptr) *preparedModel = nullptr;
371
372 // Launch prepare model.
373 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
374 ASSERT_NE(nullptr, preparedModelCallback.get());
375 hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
376 Return<ErrorStatus> prepareLaunchStatus =
377 device->prepareModel_1_2(model, ExecutionPreference::FAST_SINGLE_ANSWER, modelCache,
378 dataCache, cacheToken, preparedModelCallback);
379 ASSERT_TRUE(prepareLaunchStatus.isOk());
380 ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);
381
382 // Retrieve prepared model.
383 preparedModelCallback->wait();
384 ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
385 if (preparedModel != nullptr) {
386 *preparedModel =
387 V1_2::IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
388 .withDefault(nullptr);
389 }
390 }
391
checkEarlyTermination(ErrorStatus status)392 bool checkEarlyTermination(ErrorStatus status) {
393 if (status == ErrorStatus::GENERAL_FAILURE) {
394 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
395 "save the prepared model that it does not support.";
396 std::cout << "[ ] Early termination of test because vendor service cannot "
397 "save the prepared model that it does not support."
398 << std::endl;
399 return true;
400 }
401 return false;
402 }
403
checkEarlyTermination(const V1_2::Model & model)404 bool checkEarlyTermination(const V1_2::Model& model) {
405 if (!isModelFullySupported(model)) {
406 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
407 "prepare model that it does not support.";
408 std::cout << "[ ] Early termination of test because vendor service cannot "
409 "prepare model that it does not support."
410 << std::endl;
411 return true;
412 }
413 return false;
414 }
415
prepareModelFromCache(const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,sp<IPreparedModel> * preparedModel,ErrorStatus * status)416 void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
417 const hidl_vec<hidl_handle>& dataCache,
418 sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
419 // Launch prepare model from cache.
420 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
421 ASSERT_NE(nullptr, preparedModelCallback.get());
422 hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
423 Return<ErrorStatus> prepareLaunchStatus = device->prepareModelFromCache(
424 modelCache, dataCache, cacheToken, preparedModelCallback);
425 ASSERT_TRUE(prepareLaunchStatus.isOk());
426 if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
427 *preparedModel = nullptr;
428 *status = static_cast<ErrorStatus>(prepareLaunchStatus);
429 return;
430 }
431
432 // Retrieve prepared model.
433 preparedModelCallback->wait();
434 *status = preparedModelCallback->getStatus();
435 *preparedModel = V1_2::IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
436 .withDefault(nullptr);
437 }
438
439 // Absolute path to the temporary cache directory.
440 std::string mCacheDir;
441
442 // Groups of file paths for model and data cache in the tmp cache directory, initialized with
443 // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
444 // and the inner vector is for fds held by each handle.
445 std::vector<std::vector<std::string>> mModelCache;
446 std::vector<std::vector<std::string>> mDataCache;
447
448 // A separate temporary file path in the tmp cache directory.
449 std::string mTmpCache;
450
451 uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
452 uint32_t mNumModelCache;
453 uint32_t mNumDataCache;
454 uint32_t mIsCachingSupported;
455
456 // The primary data type of the testModel.
457 const OperandType kOperandType;
458 };
459
460 // A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
461 // pass running with float32 models and the second pass running with quant8 models.
462 class CompilationCachingTest : public CompilationCachingTestBase,
463 public ::testing::WithParamInterface<OperandType> {
464 protected:
CompilationCachingTest()465 CompilationCachingTest() : CompilationCachingTestBase(GetParam()) {}
466 };
467
// Saves a compilation of the test model to a fresh cache, restores it with
// prepareModelFromCache, and checks that the restored model computes the expected results.
TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    const Model testModel = createTestModel();
    if (checkEarlyTermination(testModel)) return;

    // Phase 1: compile the model and let the driver write the cache files.
    {
        hidl_vec<hidl_handle> modelCacheHandles;
        hidl_vec<hidl_handle> dataCacheHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCacheHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCacheHandles);
        saveModelToCache(testModel, modelCacheHandles, dataCacheHandles);
    }

    // Phase 2: reopen the same files and prepare the model from them.
    sp<IPreparedModel> restoredModel = nullptr;
    {
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCacheHandles;
        hidl_vec<hidl_handle> dataCacheHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCacheHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCacheHandles);
        prepareModelFromCache(modelCacheHandles, dataCacheHandles, &restoredModel, &status);
        if (!mIsCachingSupported) {
            // A driver without caching support must reject the request.
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(restoredModel, nullptr);
            return;
        }
        if (checkEarlyTermination(status)) {
            ASSERT_EQ(restoredModel, nullptr);
            return;
        }
        ASSERT_EQ(status, ErrorStatus::NONE);
        ASSERT_NE(restoredModel, nullptr);
    }

    // Phase 3: the restored model must produce the reference results.
    generated_tests::EvaluatePreparedModel(restoredModel, [](int) { return false; }, get_examples(),
                                           testModel.relaxComputationFloat32toFloat16,
                                           /*testDynamicOutputShape=*/false);
}
508
// Like CacheSavingAndRetrieval, but the cache fds are handed to the driver at a non-zero file
// offset. Two padding bytes are written before saving, and one byte is consumed before
// restoring. The driver must handle non-empty caches and non-zero offsets in both directions.
TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
    const Model testModel = createTestModel();
    if (checkEarlyTermination(testModel)) return;

    // Save: pad every cache file (model caches first, then data caches), then compile.
    {
        hidl_vec<hidl_handle> modelCacheHandles;
        hidl_vec<hidl_handle> dataCacheHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCacheHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCacheHandles);
        uint8_t padding[] = {0, 0};
        for (const hidl_vec<hidl_handle>* handles : {&modelCacheHandles, &dataCacheHandles}) {
            for (const hidl_handle& handle : *handles) {
                const int fd = handle.getNativeHandle()->data[0];
                ASSERT_EQ(write(fd, &padding, sizeof(padding)), sizeof(padding));
            }
        }
        saveModelToCache(testModel, modelCacheHandles, dataCacheHandles);
    }

    // Restore: advance every fd by one byte, then prepare the model from the cache.
    sp<IPreparedModel> restoredModel = nullptr;
    {
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCacheHandles;
        hidl_vec<hidl_handle> dataCacheHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCacheHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCacheHandles);
        uint8_t skippedByte = 0;
        for (const hidl_vec<hidl_handle>* handles : {&modelCacheHandles, &dataCacheHandles}) {
            for (const hidl_handle& handle : *handles) {
                const int fd = handle.getNativeHandle()->data[0];
                ASSERT_GE(read(fd, &skippedByte, 1), 0);
            }
        }
        prepareModelFromCache(modelCacheHandles, dataCacheHandles, &restoredModel, &status);
        if (!mIsCachingSupported) {
            // A driver without caching support must reject the request.
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(restoredModel, nullptr);
            return;
        }
        if (checkEarlyTermination(status)) {
            ASSERT_EQ(restoredModel, nullptr);
            return;
        }
        ASSERT_EQ(status, ErrorStatus::NONE);
        ASSERT_NE(restoredModel, nullptr);
    }

    // The restored model must produce the reference results.
    generated_tests::EvaluatePreparedModel(restoredModel, [](int) { return false; }, get_examples(),
                                           testModel.relaxComputationFloat32toFloat16,
                                           /*testDynamicOutputShape=*/false);
}
571
// Compiles with one model/data cache file too many or too few. The compilation must still
// succeed and produce correct results, but preparing from such a cache must fail with
// INVALID_ARGUMENT or GENERAL_FAILURE.
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
    const Model testModel = createTestModel();
    if (checkEarlyTermination(testModel)) return;

    // Returns |groups| with one extra single-file group (the spare tmp file) appended.
    const auto withExtraGroup = [this](std::vector<std::vector<std::string>> groups) {
        groups.push_back({mTmpCache});
        return groups;
    };
    // Returns |groups| without its last group. Only called on non-empty |groups|.
    const auto withoutLastGroup = [](std::vector<std::vector<std::string>> groups) {
        groups.pop_back();
        return groups;
    };

    // The mismatching layouts, checked in order: more model caches, fewer model caches,
    // more data caches, fewer data caches. "Fewer" is skipped when there is nothing to drop.
    struct CacheLayout {
        std::vector<std::vector<std::string>> modelFiles;
        std::vector<std::vector<std::string>> dataFiles;
    };
    std::vector<CacheLayout> layouts;
    layouts.push_back({withExtraGroup(mModelCache), mDataCache});
    if (!mModelCache.empty()) layouts.push_back({withoutLastGroup(mModelCache), mDataCache});
    layouts.push_back({mModelCache, withExtraGroup(mDataCache)});
    if (!mDataCache.empty()) layouts.push_back({mModelCache, withoutLastGroup(mDataCache)});

    for (const CacheLayout& layout : layouts) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(layout.modelFiles, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(layout.dataFiles, AccessMode::READ_WRITE, &dataCache);

        // Compilation itself must succeed and produce correct results.
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
                                               get_examples(),
                                               testModel.relaxComputationFloat32toFloat16,
                                               /*testDynamicOutputShape=*/false);

        // Preparing from the mismatching cache must be rejected.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
683
// Saves a valid cache, then tries to prepare from it with one model/data cache file too many
// or too few. Every attempt must fail with GENERAL_FAILURE or INVALID_ARGUMENT and yield no
// prepared model.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
    const Model testModel = createTestModel();
    if (checkEarlyTermination(testModel)) return;

    // Populate the cache using the layout the driver asked for.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(testModel, modelCache, dataCache);
    }

    // Returns |groups| with one extra single-file group (the spare tmp file) appended.
    const auto withExtraGroup = [this](std::vector<std::vector<std::string>> groups) {
        groups.push_back({mTmpCache});
        return groups;
    };
    // Returns |groups| without its last group. Only called on non-empty |groups|.
    const auto withoutLastGroup = [](std::vector<std::vector<std::string>> groups) {
        groups.pop_back();
        return groups;
    };

    // The mismatching layouts, checked in order: more model caches, fewer model caches,
    // more data caches, fewer data caches. "Fewer" is skipped when there is nothing to drop.
    struct CacheLayout {
        std::vector<std::vector<std::string>> modelFiles;
        std::vector<std::vector<std::string>> dataFiles;
    };
    std::vector<CacheLayout> layouts;
    layouts.push_back({withExtraGroup(mModelCache), mDataCache});
    if (!mModelCache.empty()) layouts.push_back({withoutLastGroup(mModelCache), mDataCache});
    layouts.push_back({mModelCache, withExtraGroup(mDataCache)});
    if (!mDataCache.empty()) layouts.push_back({mModelCache, withoutLastGroup(mDataCache)});

    for (const CacheLayout& layout : layouts) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(layout.modelFiles, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(layout.dataFiles, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
763
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const Model testModel = createTestModel();
    if (checkEarlyTermination(testModel)) return;

    // One handle at a time gets an invalid number of fds: first every model cache handle, then
    // every data cache handle. Each handle is tried twice, first with an extra fd appended
    // (NumFd > 1) and then with its fd removed (NumFd == 0). The file lists are restored as
    // soon as the handles have been built, so every iteration starts from the same layout.
    for (const bool corruptDataCache : {false, true}) {
        auto& files = corruptDataCache ? mDataCache : mModelCache;
        const auto numHandles = corruptDataCache ? mNumDataCache : mNumModelCache;
        for (const bool addExtraFd : {true, false}) {
            for (uint32_t i = 0; i < numHandles; i++) {
                hidl_vec<hidl_handle> modelCache, dataCache;

                // Pass an invalid number of fds for handle i.
                std::string removedFd;
                if (addExtraFd) {
                    files[i].push_back(mTmpCache);
                } else {
                    removedFd = files[i].back();
                    files[i].pop_back();
                }
                createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
                createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
                if (addExtraFd) {
                    files[i].pop_back();
                } else {
                    files[i].push_back(removedFd);
                }

                // Compilation itself must still succeed and produce correct results.
                sp<IPreparedModel> preparedModel = nullptr;
                saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
                ASSERT_NE(preparedModel, nullptr);
                generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
                                                       get_examples(),
                                                       testModel.relaxComputationFloat32toFloat16,
                                                       /*testDynamicOutputShape=*/false);

                // Preparing from the same malformed handles must fail.
                preparedModel = nullptr;
                ErrorStatus status;
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
                if (status != ErrorStatus::INVALID_ARGUMENT) {
                    ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
                }
                ASSERT_EQ(preparedModel, nullptr);
            }
        }
    }
}
875
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const Model testModel = createTestModel();
    if (checkEarlyTermination(testModel)) return;

    // Populate the cache using well-formed handles.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(testModel, modelCache, dataCache);
    }

    // Then try to prepare from the cache while one handle at a time carries an invalid number
    // of fds: model cache handles first, then data cache handles, each first with an extra fd
    // (NumFd > 1) and then with its fd removed (NumFd == 0). The file lists are restored as soon
    // as the handles have been built.
    for (const bool corruptDataCache : {false, true}) {
        auto& files = corruptDataCache ? mDataCache : mModelCache;
        const auto numHandles = corruptDataCache ? mNumDataCache : mNumModelCache;
        for (const bool addExtraFd : {true, false}) {
            for (uint32_t i = 0; i < numHandles; i++) {
                sp<IPreparedModel> preparedModel = nullptr;
                ErrorStatus status;
                hidl_vec<hidl_handle> modelCache, dataCache;

                std::string removedFd;
                if (addExtraFd) {
                    files[i].push_back(mTmpCache);
                } else {
                    removedFd = files[i].back();
                    files[i].pop_back();
                }
                createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
                createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
                if (addExtraFd) {
                    files[i].pop_back();
                } else {
                    files[i].push_back(removedFd);
                }

                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
                if (status != ErrorStatus::GENERAL_FAILURE) {
                    ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
                }
                ASSERT_EQ(preparedModel, nullptr);
            }
        }
    }
}
955
TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const Model testModel = createTestModel();
    if (checkEarlyTermination(testModel)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Open one handle at a time read-only, first each model cache handle and then each data
    // cache handle. Saving must still yield a working prepared model, but preparing from those
    // same handles afterwards must fail.
    for (const bool corruptDataCache : {false, true}) {
        std::vector<AccessMode>& modes = corruptDataCache ? dataCacheMode : modelCacheMode;
        const auto numHandles = corruptDataCache ? mNumDataCache : mNumModelCache;
        for (uint32_t i = 0; i < numHandles; i++) {
            hidl_vec<hidl_handle> modelCache, dataCache;
            modes[i] = AccessMode::READ_ONLY;
            createCacheHandles(mModelCache, modelCacheMode, &modelCache);
            createCacheHandles(mDataCache, dataCacheMode, &dataCache);
            modes[i] = AccessMode::READ_WRITE;

            // Compilation itself must still succeed and produce correct results.
            sp<IPreparedModel> preparedModel = nullptr;
            saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
            ASSERT_NE(preparedModel, nullptr);
            generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
                                                   get_examples(),
                                                   testModel.relaxComputationFloat32toFloat16,
                                                   /*testDynamicOutputShape=*/false);

            // Preparing from the same handles must fail.
            preparedModel = nullptr;
            ErrorStatus status;
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
            if (status != ErrorStatus::INVALID_ARGUMENT) {
                ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            }
            ASSERT_EQ(preparedModel, nullptr);
        }
    }
}
1013
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const Model testModel = createTestModel();
    if (checkEarlyTermination(testModel)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Populate the cache using read-write handles.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(testModel, modelCache, dataCache);
    }

    // Open one handle at a time write-only (model cache handles first, then data cache
    // handles); preparing from cache must then fail with GENERAL_FAILURE.
    for (const bool corruptDataCache : {false, true}) {
        std::vector<AccessMode>& modes = corruptDataCache ? dataCacheMode : modelCacheMode;
        const auto numHandles = corruptDataCache ? mNumDataCache : mNumModelCache;
        for (uint32_t i = 0; i < numHandles; i++) {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            modes[i] = AccessMode::WRITE_ONLY;
            createCacheHandles(mModelCache, modelCacheMode, &modelCache);
            createCacheHandles(mDataCache, dataCacheMode, &dataCache);
            modes[i] = AccessMode::READ_WRITE;
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
        }
    }
}
1057
1058 // Copy file contents between file groups.
1059 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
1060 // The outer vector sizes must match and the inner vectors must have size = 1.
copyCacheFiles(const std::vector<std::vector<std::string>> & from,const std::vector<std::vector<std::string>> & to)1061 static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
1062 const std::vector<std::vector<std::string>>& to) {
1063 constexpr size_t kBufferSize = 1000000;
1064 uint8_t buffer[kBufferSize];
1065
1066 ASSERT_EQ(from.size(), to.size());
1067 for (uint32_t i = 0; i < from.size(); i++) {
1068 ASSERT_EQ(from[i].size(), 1u);
1069 ASSERT_EQ(to[i].size(), 1u);
1070 int fromFd = open(from[i][0].c_str(), O_RDONLY);
1071 int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
1072 ASSERT_GE(fromFd, 0);
1073 ASSERT_GE(toFd, 0);
1074
1075 ssize_t readBytes;
1076 while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
1077 ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
1078 }
1079 ASSERT_GE(readBytes, 0);
1080
1081 close(fromFd);
1082 close(toFd);
1083 }
1084 }
1085
// Operation count of the large MUL/ADD test models built via createLargeTestModel.
constexpr uint32_t kLargeModelSize{100};
// How many times each probabilistic TOCTOU test repeats its save/prepare-vs-copy race.
constexpr uint32_t kNumIterationsTOCTOU{100};
1089
TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Build a MUL-based and an ADD-based large model; stop early if either is not fully
    // supported by the service.
    const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    if (checkEarlyTermination(testModelMul)) return;
    const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    if (checkEarlyTermination(testModelAdd)) return;

    // testModelMul gets its own model cache files: the regular names plus a "_mul" suffix.
    // The data cache files are shared.
    auto modelCacheMul = mModelCache;
    for (auto& handleFiles : modelCacheMul) {
        handleFiles.front().append("_mul");
    }
    {
        hidl_vec<hidl_handle> mulModelHandles, mulDataHandles;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &mulModelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &mulDataHandles);
        saveModelToCache(testModelMul, mulModelHandles, mulDataHandles);
    }

    // Use a different token for testModelAdd.
    mToken[0]++;

    // The race below is probabilistic, so it is repeated kNumIterationsTOCTOU times.
    for (uint32_t iteration = 0; iteration < kNumIterationsTOCTOU; iteration++) {
        // Save testModelAdd while another thread overwrites the model cache files with the
        // testModelMul contents.
        {
            hidl_vec<hidl_handle> addModelHandles, addDataHandles;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &addModelHandles);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &addDataHandles);
            std::thread copier(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            saveModelToCache(testModelAdd, addModelHandles, addDataHandles);
            copier.join();
        }

        // Retrieve preparedModel from cache. Either outcome is allowed, but the status and the
        // returned model must agree. A successfully prepared model must also pass evaluation
        // against getLargeModelExamples(kLargeModelSize).
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);

        if (status == ErrorStatus::NONE) {
            ASSERT_NE(preparedModel, nullptr);
            generated_tests::EvaluatePreparedModel(
                    preparedModel, [](int) { return false; },
                    getLargeModelExamples(kLargeModelSize),
                    testModelAdd.relaxComputationFloat32toFloat16,
                    /*testDynamicOutputShape=*/false);
        } else {
            ASSERT_EQ(preparedModel, nullptr);
        }
    }
}
1152
TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Build a MUL-based and an ADD-based large model; stop early if either is not fully
    // supported by the service.
    const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    if (checkEarlyTermination(testModelMul)) return;
    const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    if (checkEarlyTermination(testModelAdd)) return;

    // testModelMul gets its own model cache files: the regular names plus a "_mul" suffix.
    // The data cache files are shared.
    auto modelCacheMul = mModelCache;
    for (auto& handleFiles : modelCacheMul) {
        handleFiles.front().append("_mul");
    }
    {
        hidl_vec<hidl_handle> mulModelHandles, mulDataHandles;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &mulModelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &mulDataHandles);
        saveModelToCache(testModelMul, mulModelHandles, mulDataHandles);
    }

    // Use a different token for testModelAdd.
    mToken[0]++;

    // The race below is probabilistic, so it is repeated kNumIterationsTOCTOU times.
    for (uint32_t iteration = 0; iteration < kNumIterationsTOCTOU; iteration++) {
        // Save the testModelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> addModelHandles, addDataHandles;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &addModelHandles);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &addDataHandles);
            saveModelToCache(testModelAdd, addModelHandles, addDataHandles);
        }

        // Prepare from cache while another thread overwrites the model cache files with the
        // testModelMul contents.
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        std::thread copier(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        copier.join();

        // Either outcome is allowed, but the status and the returned model must agree. A
        // successfully prepared model must also pass evaluation against
        // getLargeModelExamples(kLargeModelSize).
        if (status == ErrorStatus::NONE) {
            ASSERT_NE(preparedModel, nullptr);
            generated_tests::EvaluatePreparedModel(
                    preparedModel, [](int) { return false; },
                    getLargeModelExamples(kLargeModelSize),
                    testModelAdd.relaxComputationFloat32toFloat16,
                    /*testDynamicOutputShape=*/false);
        } else {
            ASSERT_EQ(preparedModel, nullptr);
        }
    }
}
1215
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    if (checkEarlyTermination(testModelMul)) return;
    const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    if (checkEarlyTermination(testModelAdd)) return;

    // Save the testModelMul compilation to cache, using model cache files whose names carry a
    // "_mul" suffix.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        // A fatal failure inside a helper only returns from that helper, so stop the test
        // explicitly if the setup did not complete.
        ASSERT_NO_FATAL_FAILURE(saveModelToCache(testModelMul, modelCache, dataCache));
    }

    // Use a different token for testModelAdd.
    mToken[0]++;

    // Save the testModelAdd compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        ASSERT_NO_FATAL_FAILURE(saveModelToCache(testModelAdd, modelCache, dataCache));
    }

    // Replace the model cache of testModelAdd with testModelMul. If the copy did not complete,
    // the files may still hold (part of) testModelAdd's cache. The check below would then
    // attribute a setup problem to the driver, so abort here instead.
    ASSERT_NO_FATAL_FAILURE(copyCacheFiles(modelCacheMul, mModelCache));

    // Retrieve the preparedModel from cache, expect failure.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
1263
1264 static const auto kOperandTypeChoices =
1265 ::testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1266
1267 INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest, kOperandTypeChoices);
1268
1269 class CompilationCachingSecurityTest
1270 : public CompilationCachingTestBase,
1271 public ::testing::WithParamInterface<std::tuple<OperandType, uint32_t>> {
1272 protected:
CompilationCachingSecurityTest()1273 CompilationCachingSecurityTest() : CompilationCachingTestBase(std::get<0>(GetParam())) {}
1274
SetUp()1275 void SetUp() {
1276 CompilationCachingTestBase::SetUp();
1277 generator.seed(kSeed);
1278 }
1279
1280 // Get a random integer within a closed range [lower, upper].
1281 template <typename T>
getRandomInt(T lower,T upper)1282 T getRandomInt(T lower, T upper) {
1283 std::uniform_int_distribution<T> dis(lower, upper);
1284 return dis(generator);
1285 }
1286
1287 // Randomly flip one single bit of the cache entry.
flipOneBitOfCache(const std::string & filename,bool * skip)1288 void flipOneBitOfCache(const std::string& filename, bool* skip) {
1289 FILE* pFile = fopen(filename.c_str(), "r+");
1290 ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1291 long int fileSize = ftell(pFile);
1292 if (fileSize == 0) {
1293 fclose(pFile);
1294 *skip = true;
1295 return;
1296 }
1297 ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1298 int readByte = fgetc(pFile);
1299 ASSERT_NE(readByte, EOF);
1300 ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1301 ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1302 fclose(pFile);
1303 *skip = false;
1304 }
1305
1306 // Randomly append bytes to the cache entry.
appendBytesToCache(const std::string & filename,bool * skip)1307 void appendBytesToCache(const std::string& filename, bool* skip) {
1308 FILE* pFile = fopen(filename.c_str(), "a");
1309 uint32_t appendLength = getRandomInt(1, 256);
1310 for (uint32_t i = 0; i < appendLength; i++) {
1311 ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
1312 }
1313 fclose(pFile);
1314 *skip = false;
1315 }
1316
1317 enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1318
1319 // Test if the driver behaves as expected when given corrupted cache or token.
1320 // The modifier will be invoked after save to cache but before prepare from cache.
1321 // The modifier accepts one pointer argument "skip" as the returning value, indicating
1322 // whether the test should be skipped or not.
testCorruptedCache(ExpectedResult expected,std::function<void (bool *)> modifier)1323 void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1324 const Model testModel = createTestModel();
1325 if (checkEarlyTermination(testModel)) return;
1326
1327 // Save the compilation to cache.
1328 {
1329 hidl_vec<hidl_handle> modelCache, dataCache;
1330 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1331 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1332 saveModelToCache(testModel, modelCache, dataCache);
1333 }
1334
1335 bool skip = false;
1336 modifier(&skip);
1337 if (skip) return;
1338
1339 // Retrieve preparedModel from cache.
1340 {
1341 sp<IPreparedModel> preparedModel = nullptr;
1342 ErrorStatus status;
1343 hidl_vec<hidl_handle> modelCache, dataCache;
1344 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1345 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1346 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1347
1348 switch (expected) {
1349 case ExpectedResult::GENERAL_FAILURE:
1350 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1351 ASSERT_EQ(preparedModel, nullptr);
1352 break;
1353 case ExpectedResult::NOT_CRASH:
1354 ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1355 break;
1356 default:
1357 FAIL();
1358 }
1359 }
1360 }
1361
1362 const uint32_t kSeed = std::get<1>(GetParam());
1363 std::mt19937 generator;
1364 };
1365
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t handle = 0; handle < mNumModelCache; handle++) {
        // Flipping a random bit of a model cache file must make preparation fail.
        const auto corruptModelFile = [this, handle](bool* skip) {
            flipOneBitOfCache(mModelCache[handle][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptModelFile);
    }
}
1373
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t handle = 0; handle < mNumModelCache; handle++) {
        // Appending random bytes to a model cache file must make preparation fail.
        const auto extendModelFile = [this, handle](bool* skip) {
            appendBytesToCache(mModelCache[handle][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, extendModelFile);
    }
}
1381
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t handle = 0; handle < mNumDataCache; handle++) {
        // A corrupted data cache file may or may not be rejected, but must not crash the driver.
        const auto corruptDataFile = [this, handle](bool* skip) {
            flipOneBitOfCache(mDataCache[handle][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corruptDataFile);
    }
}
1389
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t handle = 0; handle < mNumDataCache; handle++) {
        // A data cache file with extra bytes may or may not be rejected, but must not crash.
        const auto extendDataFile = [this, handle](bool* skip) {
            appendBytesToCache(mDataCache[handle][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, extendDataFile);
    }
}
1397
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    // Randomly flip one single bit in mToken; preparation with the altered token must fail.
    const auto flipOneTokenBit = [this](bool* skip) {
        const uint32_t lastByte = static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1;
        const uint32_t byteIndex = getRandomInt(0u, lastByte);
        mToken[byteIndex] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    };
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, flipOneTokenBit);
}
1408
1409 INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingSecurityTest,
1410 ::testing::Combine(kOperandTypeChoices, ::testing::Range(0U, 10U)));
1411
1412 } // namespace functional
1413 } // namespace vts
1414 } // namespace V1_2
1415 } // namespace neuralnetworks
1416 } // namespace hardware
1417 } // namespace android
1418