• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.nn.crashtest.core.test;
18 
19 import android.annotation.SuppressLint;
20 import android.content.Context;
21 import android.content.Intent;
22 import android.text.TextUtils;
23 import android.util.Log;
24 
25 
26 import com.android.nn.crashtest.core.CrashTest;
27 import com.android.nn.crashtest.core.CrashTestCoordinator;
28 
29 import java.io.File;
30 import java.time.Duration;
31 import java.time.LocalDateTime;
32 import java.util.Optional;
33 
34 public class RandomGraphTest implements CrashTest {
35     private static final String TAG = "NN_RAND_MODEL";
36 
37     private static final boolean ENABLE_NNAPI_LOGS = false;
38 
getGeneratorOutFilePath(String fileExtension)39     private String getGeneratorOutFilePath(String fileExtension) {
40         return mContext.getExternalFilesDir(null).getAbsolutePath() + "/"
41                 + mTestName.hashCode() + "." + fileExtension;
42     }
43 
getNnapiLogFilePath()44     private String getNnapiLogFilePath() {
45         if (ENABLE_NNAPI_LOGS) {
46             String logFile = getGeneratorOutFilePath("model.py");
47             Log.d(TAG, String.format("Writing NNAPI Fuzzer logs to %s", logFile));
48             return logFile;
49         } else {
50             return "";
51         }
52     }
53 
getFailedModelDumpPath()54     private String getFailedModelDumpPath() {
55         return getGeneratorOutFilePath("log");
56     }
57 
58     static {
59         System.loadLibrary("random_graph_test_jni");
60     }
61 
62     private enum RandomModelExecutionResult {
63         // This is the java translation of the RandomModelExecutionResult c++ enum in
64         // random_graph_test_jni.cpp
65         kSuccess(0, ""),
66         kFailedCompilation(1, "Compilation failed"),
67         kFailedExecution(2, "Execution failed"),
68         kFailedOtherNnApiCall(3,
69                 "Failure trying to interact with the driver"),
70         kInvalidModelGenerated(4, "Unable to generate a valid model"),
71         kUnsupportedModelGenerated(5, "Unable to generate a model supported by the driver");
72 
73 
74         private final int mValue;
75         private final String mDescription;
76 
RandomModelExecutionResult(int value, String description)77         RandomModelExecutionResult(int value, String description) {
78             mValue = value;
79             mDescription = description;
80         }
81 
fromNativeResult(int nativeResult)82         public static RandomModelExecutionResult fromNativeResult(int nativeResult) {
83             for (RandomModelExecutionResult currValue : RandomModelExecutionResult.values()) {
84                 if (currValue.mValue == nativeResult) {
85                     return currValue;
86                 }
87             }
88             throw new IllegalArgumentException(
89                     String.format("Invalid native result value %d", nativeResult));
90         }
91     }
92 
93     public static final String MAX_TEST_DURATION = "max_test_duration";
94     public static final String GRAPH_SIZE = "graph_size";
95     public static final String DIMENSIONS_RANGE = "dimensions_range";
96     public static final String MODELS_COUNT = "models_count";
97     public static final String PAUSE_BETWEEN_MODELS_MS = "pause_between_models_ms";
98     public static final String COMPILATION_ONLY = "compilation_only";
99     public static final String DEVICE_NAME = "device_name";
100     public static final String TEST_NAME = "test_name";
101 
102     public static final int DEFAULT_GRAPH_SIZE = 100;
103     public static final int DEFAULT_DIMENSIONS_RANGE = 100;
104     public static final int DEFAULT_MODELS_COUNT = 100;
105     public static final long DEFAULT_PAUSE_BETWEEN_MODELS_MILLIS = 300;
106     public static final boolean DEFAULT_COMPILATION_ONLY = false;
107     public static final long DEFAULT_MAX_TEST_DURATION_MILLIS = Duration.ofMinutes(2).toMillis();
108     private static final long MAX_TIME_TO_LOOK_FOR_SUITABLE_MODEL_SECONDS = 30;
109 
intentInitializer(int graphSize, int dimensionsRange, int modelsCount, long pauseBetweenModelsMillis, boolean compilationOnly, String deviceName, long maxTestDurationMillis, String testName)110     static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(int graphSize,
111             int dimensionsRange, int modelsCount, long pauseBetweenModelsMillis,
112             boolean compilationOnly, String deviceName, long maxTestDurationMillis,
113             String testName) {
114         return intent -> {
115             intent.putExtra(GRAPH_SIZE, graphSize);
116             intent.putExtra(DIMENSIONS_RANGE, dimensionsRange);
117             intent.putExtra(MODELS_COUNT, modelsCount);
118             intent.putExtra(PAUSE_BETWEEN_MODELS_MS, pauseBetweenModelsMillis);
119             intent.putExtra(COMPILATION_ONLY, compilationOnly);
120             intent.putExtra(DEVICE_NAME, deviceName);
121             intent.putExtra(MAX_TEST_DURATION, maxTestDurationMillis);
122             intent.putExtra(TEST_NAME, testName);
123         };
124     }
125 
intentInitializer( Intent copyFrom)126     static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(
127             Intent copyFrom) {
128         return intentInitializer(
129                 copyFrom.getIntExtra(RandomGraphTest.GRAPH_SIZE,
130                         RandomGraphTest.DEFAULT_GRAPH_SIZE),
131                 copyFrom.getIntExtra(
132                         RandomGraphTest.DIMENSIONS_RANGE, RandomGraphTest.DEFAULT_DIMENSIONS_RANGE),
133                 copyFrom.getIntExtra(RandomGraphTest.MODELS_COUNT,
134                         RandomGraphTest.DEFAULT_MODELS_COUNT),
135                 copyFrom.getLongExtra(RandomGraphTest.PAUSE_BETWEEN_MODELS_MS,
136                         RandomGraphTest.DEFAULT_PAUSE_BETWEEN_MODELS_MILLIS),
137                 copyFrom.getBooleanExtra(
138                         RandomGraphTest.COMPILATION_ONLY, RandomGraphTest.DEFAULT_COMPILATION_ONLY),
139                 copyFrom.getStringExtra(RandomGraphTest.DEVICE_NAME),
140                 copyFrom.getLongExtra(MAX_TEST_DURATION,
141                         DEFAULT_MAX_TEST_DURATION_MILLIS),
142                 copyFrom.getStringExtra(RandomGraphTest.TEST_NAME));
143     }
144 
145     private Context mContext;
146     private String mDeviceName;
147     private boolean mCompilationOnly;
148     private int mGraphSize;
149     private int mDimensionsRange;
150     private int mModelsCount;
151     private long mPauseBetweenModelsMillis;
152     private Duration mMaxTestDuration;
153     private String mTestName;
154 
createRandomGraphGenerator(String nnApiDeviceName, int numOperations, int dimensionRange, String testName, String nnapiLogPath, String failedModelDumpPath)155     public static native long createRandomGraphGenerator(String nnApiDeviceName, int numOperations,
156             int dimensionRange,
157             String testName, String nnapiLogPath, String failedModelDumpPath);
158 
destroyRandomGraphGenerator(long generatorHandle)159     public static native long destroyRandomGraphGenerator(long generatorHandle);
160 
runRandomModel(long generatorHandle, boolean compilationOnly, long maxModelSearchTimeSeconds)161     private static native int runRandomModel(long generatorHandle,
162             boolean compilationOnly, long maxModelSearchTimeSeconds);
163 
164     @Override
init(Context context, Intent configParams, Optional<ProgressListener> progressListener)165     public void init(Context context, Intent configParams,
166             Optional<ProgressListener> progressListener) {
167         mContext = context;
168         mDeviceName = configParams.getStringExtra(DEVICE_NAME);
169         mCompilationOnly = configParams.getBooleanExtra(COMPILATION_ONLY, DEFAULT_COMPILATION_ONLY);
170         mGraphSize = configParams.getIntExtra(GRAPH_SIZE, DEFAULT_GRAPH_SIZE);
171         mDimensionsRange = configParams.getIntExtra(DIMENSIONS_RANGE, DEFAULT_DIMENSIONS_RANGE);
172         mModelsCount = configParams.getIntExtra(MODELS_COUNT, DEFAULT_MODELS_COUNT);
173         mPauseBetweenModelsMillis =
174                 configParams.getLongExtra(PAUSE_BETWEEN_MODELS_MS,
175                         DEFAULT_PAUSE_BETWEEN_MODELS_MILLIS);
176         mMaxTestDuration =
177                 Duration.ofMillis(configParams.getLongExtra(MAX_TEST_DURATION,
178                         DEFAULT_MAX_TEST_DURATION_MILLIS));
179         mTestName = configParams.getStringExtra(TEST_NAME) != null
180                 ? configParams.getStringExtra(TEST_NAME)
181                 : "no-name";
182     }
183 
184     @SuppressLint("DefaultLocale")
185     @Override
call()186     public Optional<String> call() throws Exception {
187         LocalDateTime testStart = LocalDateTime.now();
188         Log.i(TAG,
189                 String.format(String.format(
190                         "Starting test '%s', testing %d models of size %d and dimension range %d "
191                                 + "for a max duration of %s on device %s.",
192                         mTestName, mModelsCount, mGraphSize, mDimensionsRange, mMaxTestDuration,
193                         mDeviceName != null ? mDeviceName : "no-device")));
194 
195         final long generatorHandle = RandomGraphTest.createRandomGraphGenerator(mDeviceName,
196                 mGraphSize, mDimensionsRange, mTestName, getNnapiLogFilePath(),
197                 getFailedModelDumpPath());
198         if (generatorHandle == 0) {
199             Log.e(TAG, "Unable to initialize random graph generator, failing test");
200             return failure("Unable to initialize random graph generator");
201         }
202         try {
203             for (int i = 0; i < mModelsCount; i++) {
204                 if (Duration.between(testStart, LocalDateTime.now()).plus(
205                         Duration.ofSeconds(MAX_TIME_TO_LOOK_FOR_SUITABLE_MODEL_SECONDS)).compareTo(
206                         mMaxTestDuration)
207                         >= 0) {
208                     Log.d(TAG, "Max test duration reached, ending test");
209                     break;
210                 }
211 
212                 int nativeExecutionResult = runRandomModel(generatorHandle,
213                         mCompilationOnly, MAX_TIME_TO_LOOK_FOR_SUITABLE_MODEL_SECONDS);
214 
215                 RandomModelExecutionResult executionResult =
216                         RandomModelExecutionResult.fromNativeResult(nativeExecutionResult);
217 
218                 if (executionResult != RandomModelExecutionResult.kSuccess) {
219                     Log.w(TAG, String.format(
220                             "Received failure result '%s' at iteration %d, failing",
221                             executionResult.mDescription, i));
222                     if (executionResult == RandomModelExecutionResult.kFailedExecution ||
223                             executionResult == RandomModelExecutionResult.kFailedCompilation) {
224                         Log.i(TAG, String.format("Model has been dumped at path '%s'",
225                                 getFailedModelDumpPath()));
226                     } else if (
227                             executionResult == RandomModelExecutionResult.kUnsupportedModelGenerated
228                                     || executionResult
229                                     == RandomModelExecutionResult.kInvalidModelGenerated) {
230                         Log.w(TAG, String.format(
231                                 "Unable to find a valid model for test '%s', returning success "
232                                         + "anyway",
233                                 mTestName));
234 
235                         return success();
236                     }
237 
238                     return failure(executionResult.mDescription);
239                 } else if (!TextUtils.isEmpty(getNnapiLogFilePath())) {
240                     (new File(getNnapiLogFilePath())).delete();
241                 }
242 
243                 Thread.sleep(mPauseBetweenModelsMillis);
244             }
245 
246             return success();
247         } finally {
248             RandomGraphTest.destroyRandomGraphGenerator(generatorHandle);
249         }
250     }
251 }
252