1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import static java.util.concurrent.TimeUnit.MILLISECONDS; 20 21 import android.content.Context; 22 import android.os.Trace; 23 import android.util.Log; 24 import android.util.Pair; 25 26 import com.android.nn.benchmark.core.TestModels.TestModelEntry; 27 import java.io.IOException; 28 import java.util.Collections; 29 import java.util.List; 30 import java.util.concurrent.CountDownLatch; 31 import java.util.concurrent.atomic.AtomicBoolean; 32 33 /** Processor is a helper thread for running the work without blocking the UI thread. */ 34 public class Processor implements Runnable { 35 36 37 public interface Callback { onBenchmarkFinish(boolean ok)38 void onBenchmarkFinish(boolean ok); 39 onStatusUpdate(int testNumber, int numTests, String modelName)40 void onStatusUpdate(int testNumber, int numTests, String modelName); 41 } 42 43 protected static final String TAG = "NN_BENCHMARK"; 44 private Context mContext; 45 46 private final AtomicBoolean mRun = new AtomicBoolean(true); 47 48 volatile boolean mHasBeenStarted = false; 49 // You cannot restart a thread, so the completion flag is final 50 private final CountDownLatch mCompleted = new CountDownLatch(1); 51 private NNTestBase mTest; 52 private int mTestList[]; 53 private BenchmarkResult mTestResults[]; 54 55 private Processor.Callback mCallback; 56 57 private TfLiteBackend mBackend; 58 private boolean mMmapModel; 59 private boolean mCompleteInputSet; 60 private boolean mToggleLong; 61 private boolean mTogglePause; 62 private String mAcceleratorName; 63 private boolean mIgnoreUnsupportedModels; 64 private boolean mRunModelCompilationOnly; 65 // Max number of benchmark iterations to do in run method. 66 // Less or equal to 0 means unlimited 67 private int mMaxRunIterations; 68 69 private boolean mBenchmarkCompilationCaching; 70 private float mCompilationBenchmarkWarmupTimeSeconds; 71 private float mCompilationBenchmarkRunTimeSeconds; 72 private int mCompilationBenchmarkMaxIterations; 73 74 // Used to avoid accessing the Instrumentation Arguments when the crash tests are spawning 75 // a separate process. 76 private String mModelFilterRegex; 77 78 private boolean mUseNnApiSupportLibrary; 79 private boolean mExtractNnApiSupportLibrary; 80 Processor(Context context, Processor.Callback callback, int[] testList)81 public Processor(Context context, Processor.Callback callback, int[] testList) { 82 mContext = context; 83 mCallback = callback; 84 mTestList = testList; 85 if (mTestList != null) { 86 mTestResults = new BenchmarkResult[mTestList.length]; 87 } 88 mAcceleratorName = null; 89 mIgnoreUnsupportedModels = false; 90 mRunModelCompilationOnly = false; 91 mMaxRunIterations = 0; 92 mBenchmarkCompilationCaching = false; 93 mBackend = TfLiteBackend.CPU; 94 mModelFilterRegex = null; 95 mUseNnApiSupportLibrary = false; 96 mExtractNnApiSupportLibrary = false; 97 } 98 setUseNNApi(boolean useNNApi)99 public void setUseNNApi(boolean useNNApi) { 100 setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU); 101 } 102 setTfLiteBackend(TfLiteBackend backend)103 public void setTfLiteBackend(TfLiteBackend backend) { 104 mBackend = backend; 105 } 106 setCompleteInputSet(boolean completeInputSet)107 public void setCompleteInputSet(boolean completeInputSet) { 108 mCompleteInputSet = completeInputSet; 109 } 110 setToggleLong(boolean toggleLong)111 public void setToggleLong(boolean toggleLong) { 112 mToggleLong = toggleLong; 113 } 114 setTogglePause(boolean togglePause)115 public void setTogglePause(boolean togglePause) { 116 mTogglePause = togglePause; 117 } 118 setNnApiAcceleratorName(String acceleratorName)119 public void setNnApiAcceleratorName(String acceleratorName) { 120 mAcceleratorName = acceleratorName; 121 } 122 setIgnoreUnsupportedModels(boolean value)123 public void setIgnoreUnsupportedModels(boolean value) { 124 mIgnoreUnsupportedModels = value; 125 } 126 setRunModelCompilationOnly(boolean value)127 public void setRunModelCompilationOnly(boolean value) { 128 mRunModelCompilationOnly = value; 129 } 130 setMmapModel(boolean value)131 public void setMmapModel(boolean value) { 132 mMmapModel = value; 133 } 134 setMaxRunIterations(int value)135 public void setMaxRunIterations(int value) { 136 mMaxRunIterations = value; 137 } 138 setModelFilterRegex(String value)139 public void setModelFilterRegex(String value) { 140 this.mModelFilterRegex = value; 141 } 142 setUseNnApiSupportLibrary(boolean value)143 public void setUseNnApiSupportLibrary(boolean value) { mUseNnApiSupportLibrary = value; } setExtractNnApiSupportLibrary(boolean value)144 public void setExtractNnApiSupportLibrary(boolean value) { mExtractNnApiSupportLibrary = value; } 145 enableCompilationCachingBenchmarks( float warmupTimeSeconds, float runTimeSeconds, int maxIterations)146 public void enableCompilationCachingBenchmarks( 147 float warmupTimeSeconds, float runTimeSeconds, int maxIterations) { 148 mBenchmarkCompilationCaching = true; 149 mCompilationBenchmarkWarmupTimeSeconds = warmupTimeSeconds; 150 mCompilationBenchmarkRunTimeSeconds = runTimeSeconds; 151 mCompilationBenchmarkMaxIterations = maxIterations; 152 } 153 getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)154 public BenchmarkResult getInstrumentationResult( 155 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds) 156 throws IOException, BenchmarkException { 157 return getInstrumentationResult(t, warmupTimeSeconds, runTimeSeconds, false); 158 } 159 160 // Method to retrieve benchmark results for instrumentation tests. 161 // Returns null if the processor is configured to run compilation only getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults)162 public BenchmarkResult getInstrumentationResult( 163 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds, 164 boolean sampleResults) 165 throws IOException, BenchmarkException { 166 mTest = changeTest(mTest, t); 167 mTest.setSampleResult(sampleResults); 168 try { 169 BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark( 170 warmupTimeSeconds, 171 runTimeSeconds); 172 return result; 173 } finally { 174 mTest.destroy(); 175 mTest = null; 176 } 177 } 178 isTestModelSupportedByAccelerator(Context context, TestModels.TestModelEntry testModelEntry, String acceleratorName)179 public static boolean isTestModelSupportedByAccelerator(Context context, 180 TestModels.TestModelEntry testModelEntry, String acceleratorName) 181 throws NnApiDelegationFailure { 182 try (NNTestBase tb = testModelEntry.createNNTestBase(TfLiteBackend.NNAPI, 183 /*enableIntermediateTensorsDump=*/false, 184 /*mmapModel=*/ false, 185 NNTestBase.shouldUseNnApiSupportLibrary(), 186 NNTestBase.shouldExtractNnApiSupportLibrary() 187 )) { 188 tb.setNNApiDeviceName(acceleratorName); 189 return tb.setupModel(context); 190 } catch (IOException e) { 191 Log.w(TAG, 192 String.format("Error trying to check support for model %s on accelerator %s", 193 testModelEntry.mModelName, acceleratorName), e); 194 return false; 195 } catch (NnApiDelegationFailure nnApiDelegationFailure) { 196 if (nnApiDelegationFailure.getNnApiErrno() == 4 /*ANEURALNETWORKS_BAD_DATA*/) { 197 // Compilation will fail with ANEURALNETWORKS_BAD_DATA if the device is not 198 // supporting all operation in the model 199 return false; 200 } 201 202 throw nnApiDelegationFailure; 203 } 204 } 205 changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)206 private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) 207 throws IOException, UnsupportedModelException, NnApiDelegationFailure { 208 if (oldTestBase != null) { 209 // Make sure we don't leak memory. 210 oldTestBase.destroy(); 211 } 212 NNTestBase tb = t.createNNTestBase(mBackend, /*enableIntermediateTensorsDump=*/false, 213 mMmapModel, mUseNnApiSupportLibrary, mExtractNnApiSupportLibrary); 214 if (mBackend == TfLiteBackend.NNAPI) { 215 tb.setNNApiDeviceName(mAcceleratorName); 216 } 217 if (!tb.setupModel(mContext)) { 218 throw new UnsupportedModelException("Cannot initialise model"); 219 } 220 return tb; 221 } 222 223 // Run one loop of kernels for at most the specified minimum time. 224 // The function returns the average time in ms for the test run runBenchmarkLoop(float maxTime, boolean completeInputSet)225 private BenchmarkResult runBenchmarkLoop(float maxTime, boolean completeInputSet) 226 throws IOException { 227 try { 228 // Run the kernel 229 Pair<List<InferenceInOutSequence>, List<InferenceResult>> results; 230 if (maxTime > 0.f) { 231 if (completeInputSet) { 232 results = mTest.runBenchmarkCompleteInputSet(1, maxTime); 233 } else { 234 results = mTest.runBenchmark(maxTime); 235 } 236 } else { 237 results = mTest.runInferenceOnce(); 238 } 239 return BenchmarkResult.fromInferenceResults( 240 mTest.getTestInfo(), 241 mBackend.toString(), 242 results.first, 243 results.second, 244 mTest.getEvaluator()); 245 } catch (BenchmarkException e) { 246 return new BenchmarkResult(e.getMessage()); 247 } 248 } 249 250 // Run one loop of compilations for at least the specified minimum time. 251 // The function will set the compilation results into the provided benchmark result object. runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, int maxIterations, BenchmarkResult benchmarkResult)252 private void runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, 253 int maxIterations, BenchmarkResult benchmarkResult) throws IOException { 254 try { 255 CompilationBenchmarkResult result = 256 mTest.runCompilationBenchmark(warmupMinTime, runMinTime, maxIterations); 257 benchmarkResult.setCompilationBenchmarkResult(result); 258 } catch (BenchmarkException e) { 259 benchmarkResult.setBenchmarkError(e.getMessage()); 260 } 261 } 262 getTestResults()263 public BenchmarkResult[] getTestResults() { 264 return mTestResults; 265 } 266 267 // Get a benchmark result for a specific test getBenchmark(float warmupTimeSeconds, float runTimeSeconds)268 private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds) 269 throws IOException { 270 try { 271 mTest.checkSdkVersion(); 272 } catch (UnsupportedSdkException e) { 273 BenchmarkResult r = new BenchmarkResult(e.getMessage()); 274 Log.w(TAG, "Unsupported SDK for test: " + r.toString()); 275 return r; 276 } 277 278 // We run a short bit of work before starting the actual test 279 // this is to let any power management do its job and respond. 280 // For NNAPI systrace usage documentation, see 281 // frameworks/ml/nn/common/include/Tracing.h. 282 try { 283 final String traceName = "[NN_LA_PWU]runBenchmarkLoop"; 284 Trace.beginSection(traceName); 285 runBenchmarkLoop(warmupTimeSeconds, false); 286 } finally { 287 Trace.endSection(); 288 } 289 290 // Run the actual benchmark 291 BenchmarkResult r; 292 try { 293 final String traceName = "[NN_LA_PBM]runBenchmarkLoop"; 294 Trace.beginSection(traceName); 295 r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet); 296 } finally { 297 Trace.endSection(); 298 } 299 300 // Compilation benchmark 301 if (mBenchmarkCompilationCaching) { 302 runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds, 303 mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r); 304 } 305 306 return r; 307 } 308 309 @Override run()310 public void run() { 311 mHasBeenStarted = true; 312 Log.d(TAG, "Processor starting"); 313 boolean success = true; 314 int benchmarkIterationsCount = 0; 315 try { 316 while (mRun.get()) { 317 if (mMaxRunIterations > 0 && benchmarkIterationsCount >= mMaxRunIterations) { 318 break; 319 } 320 benchmarkIterationsCount++; 321 try { 322 benchmarkAllModels(); 323 } catch (IOException | BenchmarkException e) { 324 Log.e(TAG, "Exception during benchmark run", e); 325 success = false; 326 break; 327 } catch (Throwable e) { 328 Log.e(TAG, "Error during execution", e); 329 throw e; 330 } 331 } 332 Log.d(TAG, "Processor completed work"); 333 mCallback.onBenchmarkFinish(success); 334 } finally { 335 if (mTest != null) { 336 // Make sure we don't leak memory. 337 mTest.destroy(); 338 mTest = null; 339 } 340 mCompleted.countDown(); 341 } 342 } 343 benchmarkAllModels()344 private void benchmarkAllModels() throws IOException, BenchmarkException { 345 final List<TestModelEntry> modelsList = TestModels.modelsList(mModelFilterRegex); 346 // Loop over the tests we want to benchmark 347 for (int ct = 0; ct < mTestList.length; ct++) { 348 if (!mRun.get()) { 349 Log.v(TAG, String.format("Asked to stop execution at model #%d", ct)); 350 break; 351 } 352 // For reproducibility we wait a short time for any sporadic work 353 // created by the user touching the screen to launch the test to pass. 354 // Also allows for things to settle after the test changes. 355 try { 356 Thread.sleep(250); 357 } catch (InterruptedException ignored) { 358 Thread.currentThread().interrupt(); 359 break; 360 } 361 362 TestModels.TestModelEntry testModel = 363 modelsList.get(mTestList[ct]); 364 365 int testNumber = ct + 1; 366 mCallback.onStatusUpdate(testNumber, mTestList.length, 367 testModel.toString()); 368 369 // Select the next test 370 try { 371 mTest = changeTest(mTest, testModel); 372 } catch (UnsupportedModelException e) { 373 if (mIgnoreUnsupportedModels) { 374 Log.d(TAG, String.format( 375 "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct, 376 testModel.mTestName, mAcceleratorName)); 377 } else { 378 Log.e(TAG, 379 String.format("Cannot initialise test %d: '%s' on accelerator %s.", ct, 380 testModel.mTestName, mAcceleratorName), e); 381 throw e; 382 } 383 } 384 385 // If the user selected the "long pause" option, wait 386 if (mTogglePause) { 387 for (int i = 0; (i < 100) && mRun.get(); i++) { 388 try { 389 Thread.sleep(100); 390 } catch (InterruptedException ignored) { 391 Thread.currentThread().interrupt(); 392 break; 393 } 394 } 395 } 396 397 if (mRunModelCompilationOnly) { 398 mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName, 399 mBackend.toString(), 400 Collections.emptyList(), 401 Collections.emptyList(), null); 402 } else { 403 // Run the test 404 float warmupTime = 0.3f; 405 float runTime = 1.f; 406 if (mToggleLong) { 407 warmupTime = 2.f; 408 runTime = 10.f; 409 } 410 mTestResults[ct] = getBenchmark(warmupTime, runTime); 411 } 412 } 413 } 414 exit()415 public void exit() { 416 exitWithTimeout(-1l); 417 } 418 exitWithTimeout(long timeoutMs)419 public void exitWithTimeout(long timeoutMs) { 420 mRun.set(false); 421 422 if (mHasBeenStarted) { 423 Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs)); 424 try { 425 if (timeoutMs > 0) { 426 boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS); 427 if (!hasCompleted) { 428 Log.w(TAG, "Exiting before execution actually completed"); 429 } 430 } else { 431 mCompleted.await(); 432 } 433 } catch (InterruptedException e) { 434 Thread.currentThread().interrupt(); 435 Log.w(TAG, "Interrupted while waiting for Processor to complete", e); 436 } 437 } 438 439 Log.d(TAG, "Done, cleaning up"); 440 441 if (mTest != null) { 442 mTest.destroy(); 443 mTest = null; 444 } 445 } 446 } 447