• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.crashtest.app;
18 
19 import android.content.Context;
20 import android.util.Log;
21 
22 import androidx.test.InstrumentationRegistry;
23 
24 import com.android.nn.benchmark.core.BenchmarkException;
25 import com.android.nn.benchmark.core.BenchmarkResult;
26 import com.android.nn.benchmark.core.NNTestBase;
27 import com.android.nn.benchmark.core.NnApiDelegationFailure;
28 import com.android.nn.benchmark.core.Processor;
29 import com.android.nn.benchmark.core.TestModels;
30 import com.android.nn.benchmark.core.TfLiteBackend;
31 
32 import java.io.IOException;
33 import java.util.ArrayList;
34 import java.util.Arrays;
35 import java.util.List;
36 import java.util.Optional;
37 import java.util.concurrent.Callable;
38 import java.util.concurrent.atomic.AtomicBoolean;
39 import java.util.stream.Collectors;
40 
41 public interface AcceleratorSpecificTestSupport {
42     String TAG = "AcceleratorTest";
43 
findTestModelRunningOnAccelerator( Context context, String acceleratorName)44     static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator(
45             Context context, String acceleratorName) throws NnApiDelegationFailure {
46         for (TestModels.TestModelEntry model : TestModels.modelsList()) {
47             if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
48                 return Optional.of(model);
49             }
50         }
51         return Optional.empty();
52     }
53 
findAllTestModelsRunningOnAccelerator( Context context, String acceleratorName)54     static List<TestModels.TestModelEntry> findAllTestModelsRunningOnAccelerator(
55             Context context, String acceleratorName) throws NnApiDelegationFailure {
56         List<TestModels.TestModelEntry> result = new ArrayList<>();
57         for (TestModels.TestModelEntry model : TestModels.modelsList()) {
58             if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
59                 result.add(model);
60             }
61         }
62         return result;
63     }
64 
ramdomInRange(long min, long max)65     default long ramdomInRange(long min, long max) {
66         return min + (long) (Math.random() * (max - min));
67     }
68 
getTestParameter(String key, String defaultValue)69     static String getTestParameter(String key, String defaultValue) {
70         return InstrumentationRegistry.getArguments().getString(key, defaultValue);
71     }
72 
getBooleanTestParameter(String key, boolean defaultValue)73     static boolean getBooleanTestParameter(String key, boolean defaultValue) {
74         // All instrumentation arguments are passed as String so I have to convert the value here.
75         return Boolean.parseBoolean(
76                 InstrumentationRegistry.getArguments().getString(key, "" + defaultValue));
77     }
78 
79     static final String ACCELERATOR_FILTER_PROPERTY = "nnCrashtestDeviceFilter";
80     static final String INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY =
81             "nnCrashtestIncludeNnapiReference";
82 
getTargetAcceleratorNames()83     static List<String> getTargetAcceleratorNames() {
84         List<String> accelerators = new ArrayList<>();
85         String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, ".+");
86         accelerators.addAll(NNTestBase.availableAcceleratorNames().stream().filter(
87                 name -> name.matches(acceleratorFilter)).collect(
88                 Collectors.toList()));
89         if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) {
90             accelerators.add(null); // running tests with no specified target accelerator too
91         }
92         return accelerators;
93     }
94 
95 
perAcceleratorTestConfig(List<Object[]> testConfig)96     static List<Object[]> perAcceleratorTestConfig(List<Object[]> testConfig) {
97         return testConfig.stream()
98                 .flatMap(currConfigurationParams -> getTargetAcceleratorNames().stream().map(
99                         accelerator -> {
100                             Object[] result =
101                                     Arrays.copyOf(currConfigurationParams,
102                                             currConfigurationParams.length + 1);
103                             result[currConfigurationParams.length] = accelerator;
104                             return result;
105                         }))
106                 .collect(Collectors.toList());
107     }
108 
109     class DriverLivenessChecker implements Callable<Boolean> {
110         final Processor mProcessor;
111         private final AtomicBoolean mRun = new AtomicBoolean(true);
112         private final TestModels.TestModelEntry mTestModelEntry;
113 
DriverLivenessChecker(Context context, String acceleratorName, TestModels.TestModelEntry testModelEntry)114         public DriverLivenessChecker(Context context, String acceleratorName,
115                 TestModels.TestModelEntry testModelEntry) {
116             mProcessor = new Processor(context,
117                     new Processor.Callback() {
118                         @Override
119                         public void onBenchmarkFinish(boolean ok) {
120                         }
121 
122                         @Override
123                         public void onStatusUpdate(int testNumber, int numTests, String modelName) {
124                         }
125                     }, new int[0]);
126             mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
127             mProcessor.setCompleteInputSet(false);
128             mProcessor.setNnApiAcceleratorName(acceleratorName);
129             mTestModelEntry = testModelEntry;
130         }
131 
stop()132         public void stop() {
133             mRun.set(false);
134         }
135 
136         @Override
call()137         public Boolean call() throws Exception {
138             while (mRun.get()) {
139                 try {
140                     BenchmarkResult modelExecutionResult = mProcessor.getInstrumentationResult(
141                             mTestModelEntry, 0, 3);
142                     if (modelExecutionResult.hasBenchmarkError()) {
143                         Log.e(TAG, String.format("Benchmark failed with message %s",
144                                 modelExecutionResult.getBenchmarkError()));
145                         return false;
146                     }
147                 } catch (IOException | BenchmarkException e) {
148                     Log.e(TAG, String.format("Error running model %s", mTestModelEntry.mModelName));
149                     return false;
150                 }
151             }
152 
153             return true;
154         }
155     }
156 }
157