• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.nn.crashtest.core.test;
18 
19 import android.annotation.SuppressLint;
20 import android.content.Context;
21 import android.content.Intent;
22 import android.util.Log;
23 
24 import com.android.nn.benchmark.core.BenchmarkException;
25 import com.android.nn.benchmark.core.BenchmarkResult;
26 import com.android.nn.benchmark.core.Processor;
27 import com.android.nn.benchmark.core.TestModels;
28 import com.android.nn.benchmark.core.TfLiteBackend;
29 import com.android.nn.crashtest.app.AcceleratorSpecificTestSupport;
30 import com.android.nn.crashtest.core.CrashTest;
31 import com.android.nn.crashtest.core.CrashTestCoordinator;
32 
33 import java.io.IOException;
34 import java.util.Arrays;
35 import java.util.List;
36 import java.util.Optional;
37 import java.util.concurrent.Callable;
38 import java.util.concurrent.CountDownLatch;
39 import java.util.concurrent.ExecutionException;
40 import java.util.concurrent.ExecutorService;
41 import java.util.concurrent.Executors;
42 import java.util.concurrent.Future;
43 import java.util.stream.Stream;
44 
45 public class PerformanceDegradationTest implements CrashTest {
46     public static final String TAG = "NN_PERF_DEG";
47 
48     private static final Processor.Callback mNoOpCallback = new Processor.Callback() {
49         @Override
50         public void onBenchmarkFinish(boolean ok) {
51         }
52 
53         @Override
54         public void onStatusUpdate(int testNumber, int numTests, String modelName) {
55         }
56     };
57 
58     public static final String WARMUP_SECONDS = "warmup_seconds";
59     public static final String RUN_TIME_SECONDS = "run_time_seconds";
60     public static final String ACCELERATOR_NAME = "accelerator_name";
61     public static final float DEFAULT_WARMUP_SECONDS = 3.0f;
62     public static final float DEFAULT_RUN_TIME_SECONDS = 10.0f;
63     public static final String THREAD_COUNT = "thread_count";
64     public static final int DEFAULT_THREAD_COUNT = 5;
65     public static final String MAX_PERFORMANCE_DEGRADATION = "max_performance_degradation";
66     public static final int DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE = 100;
67     public static final String TEST_NAME = "test_name";
68     private static final long INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS = 500;
69 
intentInitializer( float warmupTimeSeconds, float runTimeSeconds, String acceleratorName, int threadCount, int maxPerformanceDegradationPercent, String testName)70     static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(
71             float warmupTimeSeconds, float runTimeSeconds, String acceleratorName, int threadCount,
72             int maxPerformanceDegradationPercent, String testName) {
73         return intent -> {
74             intent.putExtra(WARMUP_SECONDS, warmupTimeSeconds);
75             intent.putExtra(RUN_TIME_SECONDS, runTimeSeconds);
76             intent.putExtra(ACCELERATOR_NAME, acceleratorName);
77             intent.putExtra(THREAD_COUNT, threadCount);
78             intent.putExtra(MAX_PERFORMANCE_DEGRADATION, maxPerformanceDegradationPercent);
79             intent.putExtra(TEST_NAME, testName);
80         };
81     }
82 
intentInitializer( Intent copyFrom)83     static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(
84             Intent copyFrom) {
85         return intentInitializer(
86                 copyFrom.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS),
87                 copyFrom.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS),
88                 copyFrom.getStringExtra(ACCELERATOR_NAME),
89                 copyFrom.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT),
90                 copyFrom.getIntExtra(MAX_PERFORMANCE_DEGRADATION,
91                         DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE),
92                 copyFrom.getStringExtra(TEST_NAME));
93     }
94 
95     private Context mContext;
96     private float mWarmupTimeSeconds;
97     private float mRunTimeSeconds;
98     private String mAcceleratorName;
99     private int mThreadCount;
100     private int mMaxPerformanceDegradationPercent;
101     private String mTestName;
102 
103     @Override
init(Context context, Intent configParams, Optional<ProgressListener> progressListener)104     public void init(Context context, Intent configParams,
105             Optional<ProgressListener> progressListener) {
106         mContext = context;
107 
108         mWarmupTimeSeconds = configParams.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS);
109         mRunTimeSeconds = configParams.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS);
110         mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME);
111         mThreadCount = configParams.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT);
112         mMaxPerformanceDegradationPercent = configParams.getIntExtra(MAX_PERFORMANCE_DEGRADATION,
113                 DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE);
114         mTestName = configParams.getStringExtra(TEST_NAME);
115     }
116 
117     @SuppressLint("DefaultLocale")
118     @Override
call()119     public Optional<String> call() throws Exception {
120         List<TestModels.TestModelEntry> modelsForAccelerator =
121                 AcceleratorSpecificTestSupport.findAllTestModelsRunningOnAccelerator(mContext,
122                         mAcceleratorName);
123 
124         if (modelsForAccelerator.isEmpty()) {
125             return failure("Cannot find any model to use for testing");
126         }
127 
128         Log.i(TAG, String.format("Checking performance degradation using %d models",
129                 modelsForAccelerator.size()));
130 
131         TestModels.TestModelEntry modelForInference = modelsForAccelerator.get(0);
132         // The performance degradation is strongly dependent on the model used to compile
133         // so we check all the available ones.
134         for (TestModels.TestModelEntry modelForCompilation : modelsForAccelerator) {
135             Optional<String> currTestResult = testDegradationForModels(modelForInference,
136                     modelForCompilation);
137             if (isFailure(currTestResult)) {
138                 return currTestResult;
139             }
140         }
141 
142         return success();
143     }
144 
145     @SuppressLint("DefaultLocale")
testDegradationForModels( TestModels.TestModelEntry inferenceModelEntry, TestModels.TestModelEntry compilationModelEntry)146     public Optional<String> testDegradationForModels(
147             TestModels.TestModelEntry inferenceModelEntry,
148             TestModels.TestModelEntry compilationModelEntry) throws Exception {
149         Log.i(TAG, String.format(
150                 "Testing degradation in inference of model %s when running %d threads compliing "
151                         + "model %s",
152                 inferenceModelEntry.mModelName, mThreadCount, compilationModelEntry.mModelName));
153 
154         Log.d(TAG, String.format("%s: Calculating baseline", mTestName));
155         // first let's measure a baseline performance
156         final BenchmarkResult baseline = modelPerformanceCollector(inferenceModelEntry,
157                 /*start=*/ null).call();
158         if (baseline.hasBenchmarkError()) {
159             return failure(String.format("%s: Baseline has benchmark error '%s'",
160                     mTestName, baseline.getBenchmarkError()));
161         }
162         Log.d(TAG, String.format("%s: Baseline mean time is %f seconds", mTestName,
163                 baseline.getMeanTimeSec()));
164 
165         Log.d(TAG, String.format("%s: Sleeping for %d millis", mTestName,
166                 INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS));
167         Thread.sleep(INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS);
168 
169         Log.d(TAG, String.format("%s: Calculating performance with %d threads", mTestName,
170                 mThreadCount));
171         final int totalThreadCount = mThreadCount + 1;
172         final CountDownLatch start = new CountDownLatch(totalThreadCount);
173         ModelCompiler[] compilers = Stream.generate(
174                 () -> new ModelCompiler(start, mContext, mAcceleratorName,
175                         compilationModelEntry)).limit(
176                 mThreadCount).toArray(
177                 ModelCompiler[]::new);
178 
179         Callable<BenchmarkResult> performanceWithOtherCompilingThreadCollector =
180                 modelPerformanceCollector(inferenceModelEntry, start);
181 
182         ExecutorService testExecutor = Executors.newFixedThreadPool(totalThreadCount);
183         Future<?>[] compilerFutures = Arrays.stream(compilers).map(testExecutor::submit).toArray(
184                 Future[]::new);
185         BenchmarkResult benchmarkWithOtherCompilingThread = testExecutor.submit(
186                 performanceWithOtherCompilingThreadCollector).get();
187 
188         Arrays.stream(compilers).forEach(ModelCompiler::stop);
189         Arrays.stream(compilerFutures).forEach(future -> {
190             try {
191                 future.get();
192             } catch (InterruptedException | ExecutionException e) {
193                 Log.e(TAG, "Error waiting for compiler process completion", e);
194             }
195         });
196 
197         if (benchmarkWithOtherCompilingThread.hasBenchmarkError()) {
198             return failure(
199                     String.format(
200                             "%s: Test with parallel compiling thrads has benchmark error '%s'",
201                             mTestName, benchmarkWithOtherCompilingThread.getBenchmarkError()));
202         }
203 
204         Log.d(TAG, String.format("%s: Multithreaded mean time is %f seconds",
205                 mTestName, benchmarkWithOtherCompilingThread.getMeanTimeSec()));
206 
207         int performanceDegradation = (int) (((benchmarkWithOtherCompilingThread.getMeanTimeSec()
208                 / baseline.getMeanTimeSec()) - 1.0) * 100);
209 
210         Log.i(TAG, String.format(
211                 "%s: Performance degradation for accelerator %s, with %d threads is %d%%. "
212                         + "Threshold "
213                         + "is %d%%",
214                 mTestName, mAcceleratorName, mThreadCount, performanceDegradation,
215                 mMaxPerformanceDegradationPercent));
216 
217         if (performanceDegradation > mMaxPerformanceDegradationPercent) {
218             return failure(String.format("Performance degradation is %d%%. Max acceptable is %d%%",
219                     performanceDegradation, mMaxPerformanceDegradationPercent));
220         }
221 
222         return success();
223     }
224 
225 
modelPerformanceCollector( final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start)226     private Callable<BenchmarkResult> modelPerformanceCollector(
227             final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start) {
228         return () -> {
229             Processor benchmarkProcessor = new Processor(mContext, mNoOpCallback, new int[0]);
230             benchmarkProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
231             benchmarkProcessor.setNnApiAcceleratorName(mAcceleratorName);
232             if (start != null) {
233                 start.countDown();
234                 start.await();
235             }
236             final BenchmarkResult result =
237                     benchmarkProcessor.getInstrumentationResult(
238                             inferenceModelEntry, mWarmupTimeSeconds, mRunTimeSeconds);
239 
240             return result;
241         };
242     }
243 
244     private static class ModelCompiler implements Callable<Void> {
245         private static final long SLEEP_BETWEEN_COMPILATION_INTERVAL_MS = 20;
246         private final CountDownLatch mStart;
247         private final Processor mProcessor;
248         private final TestModels.TestModelEntry mTestModelEntry;
249         private volatile boolean mRun;
250 
251         ModelCompiler(final CountDownLatch start, final Context context,
252                 final String acceleratorName, TestModels.TestModelEntry testModelEntry) {
253             mStart = start;
254             mTestModelEntry = testModelEntry;
255             mProcessor = new Processor(context, mNoOpCallback, new int[0]);
256             mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
257             mProcessor.setNnApiAcceleratorName(acceleratorName);
258             mProcessor.setRunModelCompilationOnly(true);
259             mRun = true;
260         }
261 
262         @Override
263         public Void call() throws IOException, BenchmarkException {
264             if (mStart != null) {
265                 try {
266                     mStart.countDown();
267                     mStart.await();
268                 } catch (InterruptedException e) {
269                     Thread.interrupted();
270                     Log.i(TAG, "Interrupted, stopping processing");
271                     return null;
272                 }
273             }
274             while (mRun) {
275                 mProcessor.getInstrumentationResult(mTestModelEntry, 0, 0);
276                 try {
277                     Thread.sleep(SLEEP_BETWEEN_COMPILATION_INTERVAL_MS);
278                 } catch (InterruptedException e) {
279                     Thread.interrupted();
280                     return null;
281                 }
282             }
283             return null;
284         }
285 
286         public void stop() {
287             mRun = false;
288         }
289     }
290 }
291