/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import static java.util.concurrent.TimeUnit.MILLISECONDS;

import android.content.Context;
import android.os.Trace;
import android.util.Log;
import android.util.Pair;

import com.android.nn.benchmark.core.TestModels.TestModelEntry;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

/** Processor is a helper thread for running the work without blocking the UI thread. */
public class Processor implements Runnable {

    public interface Callback {
        /** Called once the benchmark run has finished; {@code ok} is false if a run failed. */
        void onBenchmarkFinish(boolean ok);

        /** Reports progress before each model in the test list is benchmarked. */
        void onStatusUpdate(int testNumber, int numTests, String modelName);
    }
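
    // A minimal usage sketch, assuming the Processor is driven by a host activity or
    // instrumentation on its own thread. The callback bodies, log tag, test indices, backend
    // choice and timeout below are illustrative, not prescribed by this class.
    //
    //   Processor processor = new Processor(context, new Processor.Callback() {
    //       @Override
    //       public void onBenchmarkFinish(boolean ok) {
    //           Log.d("ExampleDriver", "Benchmark finished, success=" + ok);
    //       }
    //
    //       @Override
    //       public void onStatusUpdate(int testNumber, int numTests, String modelName) {
    //           Log.d("ExampleDriver", "Running " + testNumber + "/" + numTests + ": " + modelName);
    //       }
    //   }, new int[] {0, 1});
    //   processor.setTfLiteBackend(TfLiteBackend.CPU);
    //   new Thread(processor).start();
    //   // Later: ask the run loop to stop and wait up to 60 seconds for it to complete.
    //   processor.exitWithTimeout(60_000L);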

    protected static final String TAG = "NN_BENCHMARK";
    private Context mContext;

    private final AtomicBoolean mRun = new AtomicBoolean(true);

    volatile boolean mHasBeenStarted = false;
    // You cannot restart a thread, so the completion flag is final.
    private final CountDownLatch mCompleted = new CountDownLatch(1);
    private NNTestBase mTest;
    private int[] mTestList;
    private BenchmarkResult[] mTestResults;

    private Processor.Callback mCallback;

    private TfLiteBackend mBackend;
    private boolean mMmapModel;
    private boolean mCompleteInputSet;
    private boolean mToggleLong;
    private boolean mTogglePause;
    private String mAcceleratorName;
    private boolean mIgnoreUnsupportedModels;
    private boolean mRunModelCompilationOnly;
    // Max number of benchmark iterations to do in the run method.
    // Less than or equal to 0 means unlimited.
    private int mMaxRunIterations;

    private boolean mBenchmarkCompilationCaching;
    private float mCompilationBenchmarkWarmupTimeSeconds;
    private float mCompilationBenchmarkRunTimeSeconds;
    private int mCompilationBenchmarkMaxIterations;

    // Used to avoid accessing the Instrumentation Arguments when the crash tests are spawning
    // a separate process.
    private String mModelFilterRegex;

    private boolean mUseNnApiSupportLibrary;
    private boolean mExtractNnApiSupportLibrary;

    public Processor(Context context, Processor.Callback callback, int[] testList) {
        mContext = context;
        mCallback = callback;
        mTestList = testList;
        if (mTestList != null) {
            mTestResults = new BenchmarkResult[mTestList.length];
        }
        mAcceleratorName = null;
        mIgnoreUnsupportedModels = false;
        mRunModelCompilationOnly = false;
        mMaxRunIterations = 0;
        mBenchmarkCompilationCaching = false;
        mBackend = TfLiteBackend.CPU;
        mModelFilterRegex = null;
        mUseNnApiSupportLibrary = false;
        mExtractNnApiSupportLibrary = false;
    }

    public void setUseNNApi(boolean useNNApi) {
        setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
    }

    public void setTfLiteBackend(TfLiteBackend backend) {
        mBackend = backend;
    }

    public void setCompleteInputSet(boolean completeInputSet) {
        mCompleteInputSet = completeInputSet;
    }

    public void setToggleLong(boolean toggleLong) {
        mToggleLong = toggleLong;
    }

    public void setTogglePause(boolean togglePause) {
        mTogglePause = togglePause;
    }

    public void setNnApiAcceleratorName(String acceleratorName) {
        mAcceleratorName = acceleratorName;
    }

    public void setIgnoreUnsupportedModels(boolean value) {
        mIgnoreUnsupportedModels = value;
    }

    public void setRunModelCompilationOnly(boolean value) {
        mRunModelCompilationOnly = value;
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

    public void setMaxRunIterations(int value) {
        mMaxRunIterations = value;
    }

    public void setModelFilterRegex(String value) {
        this.mModelFilterRegex = value;
    }

    public void setUseNnApiSupportLibrary(boolean value) { mUseNnApiSupportLibrary = value; }
    public void setExtractNnApiSupportLibrary(boolean value) { mExtractNnApiSupportLibrary = value; }

    public void enableCompilationCachingBenchmarks(
            float warmupTimeSeconds, float runTimeSeconds, int maxIterations) {
        mBenchmarkCompilationCaching = true;
        mCompilationBenchmarkWarmupTimeSeconds = warmupTimeSeconds;
        mCompilationBenchmarkRunTimeSeconds = runTimeSeconds;
        mCompilationBenchmarkMaxIterations = maxIterations;
    }
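
    // Example configuration sketch: enable the compilation-caching benchmark with a 0.5 s
    // warm-up, a 2 s measured run, and at most 10 iterations. The values are illustrative only.
    //
    //   processor.enableCompilationCachingBenchmarks(
    //           /*warmupTimeSeconds=*/ 0.5f, /*runTimeSeconds=*/ 2.0f, /*maxIterations=*/ 10);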

    public BenchmarkResult getInstrumentationResult(
            TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
            throws IOException, BenchmarkException {
        return getInstrumentationResult(t, warmupTimeSeconds, runTimeSeconds, false);
    }

    // Method to retrieve benchmark results for instrumentation tests.
    // Returns null if the processor is configured to run compilation only.
    public BenchmarkResult getInstrumentationResult(
            TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds,
            boolean sampleResults)
            throws IOException, BenchmarkException {
        mTest = changeTest(mTest, t);
        mTest.setSampleResult(sampleResults);
        try {
            BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(
                    warmupTimeSeconds,
                    runTimeSeconds);
            return result;
        } finally {
            mTest.destroy();
            mTest = null;
        }
    }
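
    // Illustrative call of the overload above from an instrumentation test. The model index and
    // timings are hypothetical; a null filter regex is used here in the same way as
    // benchmarkAllModels() does when no filter is set.
    //
    //   TestModels.TestModelEntry entry = TestModels.modelsList(null).get(0);
    //   BenchmarkResult result = processor.getInstrumentationResult(
    //           entry, /*warmupTimeSeconds=*/ 0.3f, /*runTimeSeconds=*/ 1.0f);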

    public static boolean isTestModelSupportedByAccelerator(Context context,
            TestModels.TestModelEntry testModelEntry, String acceleratorName)
            throws NnApiDelegationFailure {
        try (NNTestBase tb = testModelEntry.createNNTestBase(TfLiteBackend.NNAPI,
                /*enableIntermediateTensorsDump=*/ false,
                /*mmapModel=*/ false,
                NNTestBase.shouldUseNnApiSupportLibrary(),
                NNTestBase.shouldExtractNnApiSupportLibrary()
            )) {
            tb.setNNApiDeviceName(acceleratorName);
            return tb.setupModel(context);
        } catch (IOException e) {
            Log.w(TAG,
                    String.format("Error trying to check support for model %s on accelerator %s",
                            testModelEntry.mModelName, acceleratorName), e);
            return false;
        } catch (NnApiDelegationFailure nnApiDelegationFailure) {
            if (nnApiDelegationFailure.getNnApiErrno() == 4 /*ANEURALNETWORKS_BAD_DATA*/) {
                // Compilation will fail with ANEURALNETWORKS_BAD_DATA if the device does not
                // support all of the operations in the model.
                return false;
            }

            throw nnApiDelegationFailure;
        }
    }
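
    // Illustrative use of the helper above; the accelerator name is hypothetical and in practice
    // comes from the device's NNAPI driver.
    //
    //   TestModels.TestModelEntry entry = TestModels.modelsList(null).get(0);
    //   boolean supported = Processor.isTestModelSupportedByAccelerator(
    //           context, entry, "example-accelerator");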

    private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)
            throws IOException, UnsupportedModelException, NnApiDelegationFailure {
        if (oldTestBase != null) {
            // Make sure we don't leak memory.
            oldTestBase.destroy();
        }
        NNTestBase tb = t.createNNTestBase(mBackend, /*enableIntermediateTensorsDump=*/ false,
            mMmapModel, mUseNnApiSupportLibrary, mExtractNnApiSupportLibrary);
        if (mBackend == TfLiteBackend.NNAPI) {
            tb.setNNApiDeviceName(mAcceleratorName);
        }
        if (!tb.setupModel(mContext)) {
            throw new UnsupportedModelException("Cannot initialise model");
        }
        return tb;
    }

    // Run one loop of kernels for roughly the specified maximum time; a non-positive maxTime
    // runs a single inference instead. Returns the benchmark result for the test run.
    private BenchmarkResult runBenchmarkLoop(float maxTime, boolean completeInputSet)
            throws IOException {
        try {
            // Run the kernel.
            Pair<List<InferenceInOutSequence>, List<InferenceResult>> results;
            if (maxTime > 0.f) {
                if (completeInputSet) {
                    results = mTest.runBenchmarkCompleteInputSet(1, maxTime);
                } else {
                    results = mTest.runBenchmark(maxTime);
                }
            } else {
                results = mTest.runInferenceOnce();
            }
            return BenchmarkResult.fromInferenceResults(
                    mTest.getTestInfo(),
                    mBackend.toString(),
                    results.first,
                    results.second,
                    mTest.getEvaluator());
        } catch (BenchmarkException e) {
            return new BenchmarkResult(e.getMessage());
        }
    }

    // Run one loop of compilations for at least the specified minimum time.
    // The function will set the compilation results into the provided benchmark result object.
    private void runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime,
            int maxIterations, BenchmarkResult benchmarkResult) throws IOException {
        try {
            CompilationBenchmarkResult result =
                    mTest.runCompilationBenchmark(warmupMinTime, runMinTime, maxIterations);
            benchmarkResult.setCompilationBenchmarkResult(result);
        } catch (BenchmarkException e) {
            benchmarkResult.setBenchmarkError(e.getMessage());
        }
    }

    public BenchmarkResult[] getTestResults() {
        return mTestResults;
    }

    // Get a benchmark result for a specific test.
    private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds)
            throws IOException {
        try {
            mTest.checkSdkVersion();
        } catch (UnsupportedSdkException e) {
            BenchmarkResult r = new BenchmarkResult(e.getMessage());
            Log.w(TAG, "Unsupported SDK for test: " + r.toString());
            return r;
        }

        // We run a short bit of work before starting the actual test; this is to let any power
        // management do its job and respond. For NNAPI systrace usage documentation, see
        // frameworks/ml/nn/common/include/Tracing.h.
        try {
            final String traceName = "[NN_LA_PWU]runBenchmarkLoop";
            Trace.beginSection(traceName);
            runBenchmarkLoop(warmupTimeSeconds, false);
        } finally {
            Trace.endSection();
        }

        // Run the actual benchmark.
        BenchmarkResult r;
        try {
            final String traceName = "[NN_LA_PBM]runBenchmarkLoop";
            Trace.beginSection(traceName);
            r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet);
        } finally {
            Trace.endSection();
        }

        // Compilation benchmark.
        if (mBenchmarkCompilationCaching) {
            runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds,
                    mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r);
        }

        return r;
    }

    @Override
    public void run() {
        mHasBeenStarted = true;
        Log.d(TAG, "Processor starting");
        boolean success = true;
        int benchmarkIterationsCount = 0;
        try {
            while (mRun.get()) {
                if (mMaxRunIterations > 0 && benchmarkIterationsCount >= mMaxRunIterations) {
                    break;
                }
                benchmarkIterationsCount++;
                try {
                    benchmarkAllModels();
                } catch (IOException | BenchmarkException e) {
                    Log.e(TAG, "Exception during benchmark run", e);
                    success = false;
                    break;
                } catch (Throwable e) {
                    Log.e(TAG, "Error during execution", e);
                    throw e;
                }
            }
            Log.d(TAG, "Processor completed work");
            mCallback.onBenchmarkFinish(success);
        } finally {
            if (mTest != null) {
                // Make sure we don't leak memory.
                mTest.destroy();
                mTest = null;
            }
            mCompleted.countDown();
        }
    }

    private void benchmarkAllModels() throws IOException, BenchmarkException {
        final List<TestModelEntry> modelsList = TestModels.modelsList(mModelFilterRegex);
        // Loop over the tests we want to benchmark.
        for (int ct = 0; ct < mTestList.length; ct++) {
            if (!mRun.get()) {
                Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
                break;
            }
            // For reproducibility, wait a short time so that any sporadic work created by the
            // user touching the screen to launch the test can pass, and so that things can
            // settle after the test changes.
            try {
                Thread.sleep(250);
            } catch (InterruptedException ignored) {
                Thread.currentThread().interrupt();
                break;
            }

            TestModels.TestModelEntry testModel =
                    modelsList.get(mTestList[ct]);

            int testNumber = ct + 1;
            mCallback.onStatusUpdate(testNumber, mTestList.length,
                    testModel.toString());

            // Select the next test.
            try {
                mTest = changeTest(mTest, testModel);
            } catch (UnsupportedModelException e) {
                if (mIgnoreUnsupportedModels) {
                    Log.d(TAG, String.format(
                            "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct,
                            testModel.mTestName, mAcceleratorName));
                } else {
                    Log.e(TAG,
                            String.format("Cannot initialise test %d: '%s' on accelerator %s.", ct,
                                    testModel.mTestName, mAcceleratorName), e);
                    throw e;
                }
            }

            // If the user selected the "long pause" option, wait.
            if (mTogglePause) {
                for (int i = 0; (i < 100) && mRun.get(); i++) {
                    try {
                        Thread.sleep(100);
                    } catch (InterruptedException ignored) {
                        Thread.currentThread().interrupt();
                        break;
                    }
                }
            }

            if (mRunModelCompilationOnly) {
                mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName,
                        mBackend.toString(),
                        Collections.emptyList(),
                        Collections.emptyList(), null);
            } else {
                // Run the test.
                float warmupTime = 0.3f;
                float runTime = 1.f;
                if (mToggleLong) {
                    warmupTime = 2.f;
                    runTime = 10.f;
                }
                mTestResults[ct] = getBenchmark(warmupTime, runTime);
            }
        }
    }

    public void exit() {
        exitWithTimeout(-1L);
    }

    public void exitWithTimeout(long timeoutMs) {
        mRun.set(false);

        if (mHasBeenStarted) {
            Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs));
            try {
                if (timeoutMs > 0) {
                    boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS);
                    if (!hasCompleted) {
                        Log.w(TAG, "Exiting before execution actually completed");
                    }
                } else {
                    mCompleted.await();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                Log.w(TAG, "Interrupted while waiting for Processor to complete", e);
            }
        }

        Log.d(TAG, "Done, cleaning up");

        if (mTest != null) {
            mTest.destroy();
            mTest = null;
        }
    }
}