• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.app;
18 
19 import android.annotation.SuppressLint;
20 import android.app.Activity;
21 import android.content.Intent;
22 import android.os.Bundle;
23 import android.util.Log;
24 import android.view.WindowManager;
25 import android.widget.TextView;
26 import com.android.nn.benchmark.core.BenchmarkException;
27 import com.android.nn.benchmark.core.BenchmarkResult;
28 import com.android.nn.benchmark.core.Processor;
29 import com.android.nn.benchmark.core.TestModels.TestModelEntry;
30 import com.android.nn.benchmark.core.TfLiteBackend;
31 import java.io.IOException;
32 import java.time.Duration;
33 import java.util.concurrent.ExecutorService;
34 import java.util.concurrent.Executors;
35 
36 public class NNBenchmark extends Activity implements Processor.Callback {
37     public static final String TAG = "NN_BENCHMARK";
38 
39     public static final String EXTRA_ENABLE_LONG = "enable long";
40     public static final String EXTRA_ENABLE_PAUSE = "enable pause";
41     public static final String EXTRA_DISABLE_NNAPI = "disable NNAPI";
42     public static final String EXTRA_TESTS = "tests";
43 
44     public static final String EXTRA_RESULTS_TESTS = "tests";
45     public static final String EXTRA_RESULTS_RESULTS = "results";
46     public static final long PROCESSOR_TERMINATION_TIMEOUT_MS = Duration.ofSeconds(20).toMillis();
47     public static final String EXTRA_MAX_ITERATIONS = "max_iterations";
48 
49     private int mTestList[];
50 
51     private boolean mUseNnApiSupportLibrary = false;
52     private boolean mExtractNnApiSupportLibrary = false;
53     private String mNnApiSupportLibraryVendor = "";
54 
55     private Processor mProcessor;
56     private final ExecutorService executorService = Executors.newSingleThreadExecutor();
57 
58     private TextView mTextView;
59 
60     // Initialize the parameters for Instrumentation tests.
prepareInstrumentationTest()61     protected void prepareInstrumentationTest() {
62         mTestList = new int[1];
63         mProcessor = new Processor(this, this, mTestList);
64     }
65 
setUseNNApi(boolean useNNApi)66     public void setUseNNApi(boolean useNNApi) {
67         mProcessor.setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
68     }
69 
setNnApiAcceleratorName(String acceleratorName)70     public void setNnApiAcceleratorName(String acceleratorName) {
71         mProcessor.setNnApiAcceleratorName(acceleratorName);
72     }
73 
setCompleteInputSet(boolean completeInputSet)74     public void setCompleteInputSet(boolean completeInputSet) {
75         mProcessor.setCompleteInputSet(completeInputSet);
76     }
77 
enableCompilationCachingBenchmarks( float warmupTimeSeconds, float runTimeSeconds, int maxIterations)78     public void enableCompilationCachingBenchmarks(
79             float warmupTimeSeconds, float runTimeSeconds, int maxIterations) {
80         mProcessor.enableCompilationCachingBenchmarks(
81                 warmupTimeSeconds, runTimeSeconds, maxIterations);
82     }
83 
setUseNnApiSupportLibrary(boolean value)84     public void setUseNnApiSupportLibrary(boolean value) {
85         mUseNnApiSupportLibrary = value;
86         mProcessor.setUseNnApiSupportLibrary(mUseNnApiSupportLibrary);
87     }
88 
setNnApiSupportLibraryVendor(String value)89     public void setNnApiSupportLibraryVendor(String value) {
90         mNnApiSupportLibraryVendor = value;
91         mProcessor.setNnApiSupportLibraryVendor(mNnApiSupportLibraryVendor);
92     }
93 
setExtractNnApiSupportLibrary(boolean value)94     public void setExtractNnApiSupportLibrary(boolean value) {
95         mExtractNnApiSupportLibrary = value;
96         mProcessor.setExtractNnApiSupportLibrary(value);
97     }
98 
99     @SuppressLint("SetTextI18n")
100     @Override
onCreate(Bundle savedInstanceState)101     protected void onCreate(Bundle savedInstanceState) {
102         super.onCreate(savedInstanceState);
103         mTextView = new TextView(this);
104         mTextView.setTextSize(20);
105         mTextView.setText("Running NN benchmark...");
106         setContentView(mTextView);
107         getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
108     }
109 
110     @Override
onPause()111     protected void onPause() {
112         super.onPause();
113         if (mProcessor != null) {
114             mProcessor.exitWithTimeout(PROCESSOR_TERMINATION_TIMEOUT_MS);
115             mProcessor = null;
116         }
117     }
118 
onBenchmarkFinish(boolean ok)119     public void onBenchmarkFinish(boolean ok) {
120         if (ok) {
121             Intent intent = new Intent();
122             intent.putExtra(EXTRA_RESULTS_TESTS, mTestList);
123             intent.putExtra(EXTRA_RESULTS_RESULTS, mProcessor.getTestResults());
124             setResult(RESULT_OK, intent);
125         } else {
126             setResult(RESULT_CANCELED);
127         }
128         finish();
129     }
130 
131     @SuppressLint("DefaultLocale")
onStatusUpdate(int testNumber, int numTests, String modelName)132     public void onStatusUpdate(int testNumber, int numTests, String modelName) {
133         runOnUiThread(
134                 () -> {
135                     mTextView.setText(
136                             String.format(
137                                     "Running test %d of %d: %s", testNumber, numTests, modelName));
138                 });
139     }
140 
141     @Override
onResume()142     protected void onResume() {
143         super.onResume();
144         Intent i = getIntent();
145         mTestList = i.getIntArrayExtra(EXTRA_TESTS);
146         if (mTestList != null && mTestList.length > 0) {
147             Log.v(TAG, String.format("Starting benchmark with %d test", mTestList.length));
148             mProcessor = new Processor(this, this, mTestList);
149             mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false));
150             mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false));
151             mProcessor.setTfLiteBackend(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false) ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
152             mProcessor.setMaxRunIterations(i.getIntExtra(EXTRA_MAX_ITERATIONS, 0));
153             mProcessor.setUseNnApiSupportLibrary(mUseNnApiSupportLibrary);
154             mProcessor.setNnApiSupportLibraryVendor(mNnApiSupportLibraryVendor);
155             mProcessor.setExtractNnApiSupportLibrary(mExtractNnApiSupportLibrary);
156             executorService.submit(mProcessor);
157         } else {
158             Log.v(TAG, "No test to run, doing nothing");
159         }
160     }
161 
162     @Override
onDestroy()163     protected void onDestroy() {
164         super.onDestroy();
165     }
166 
runSynchronously(TestModelEntry testModel, float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults)167     public BenchmarkResult runSynchronously(TestModelEntry testModel,
168         float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults) throws IOException, BenchmarkException {
169         return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds, sampleResults);
170     }
171 }
172