1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.app; 18 19 import android.annotation.SuppressLint; 20 import android.app.Activity; 21 import android.content.Intent; 22 import android.os.Bundle; 23 import android.util.Log; 24 import android.view.WindowManager; 25 import android.widget.TextView; 26 import com.android.nn.benchmark.core.BenchmarkException; 27 import com.android.nn.benchmark.core.BenchmarkResult; 28 import com.android.nn.benchmark.core.Processor; 29 import com.android.nn.benchmark.core.TestModels.TestModelEntry; 30 import com.android.nn.benchmark.core.TfLiteBackend; 31 import java.io.IOException; 32 import java.time.Duration; 33 import java.util.concurrent.ExecutorService; 34 import java.util.concurrent.Executors; 35 36 public class NNBenchmark extends Activity implements Processor.Callback { 37 public static final String TAG = "NN_BENCHMARK"; 38 39 public static final String EXTRA_ENABLE_LONG = "enable long"; 40 public static final String EXTRA_ENABLE_PAUSE = "enable pause"; 41 public static final String EXTRA_DISABLE_NNAPI = "disable NNAPI"; 42 public static final String EXTRA_TESTS = "tests"; 43 44 public static final String EXTRA_RESULTS_TESTS = "tests"; 45 public static final String EXTRA_RESULTS_RESULTS = "results"; 46 public static final long PROCESSOR_TERMINATION_TIMEOUT_MS = Duration.ofSeconds(20).toMillis(); 47 public static final String EXTRA_MAX_ITERATIONS = "max_iterations"; 48 49 private int mTestList[]; 50 51 private boolean mUseNnApiSupportLibrary = false; 52 private boolean mExtractNnApiSupportLibrary = false; 53 private String mNnApiSupportLibraryVendor = ""; 54 55 private Processor mProcessor; 56 private final ExecutorService executorService = Executors.newSingleThreadExecutor(); 57 58 private TextView mTextView; 59 60 // Initialize the parameters for Instrumentation tests. prepareInstrumentationTest()61 protected void prepareInstrumentationTest() { 62 mTestList = new int[1]; 63 mProcessor = new Processor(this, this, mTestList); 64 } 65 setUseNNApi(boolean useNNApi)66 public void setUseNNApi(boolean useNNApi) { 67 mProcessor.setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU); 68 } 69 setNnApiAcceleratorName(String acceleratorName)70 public void setNnApiAcceleratorName(String acceleratorName) { 71 mProcessor.setNnApiAcceleratorName(acceleratorName); 72 } 73 setCompleteInputSet(boolean completeInputSet)74 public void setCompleteInputSet(boolean completeInputSet) { 75 mProcessor.setCompleteInputSet(completeInputSet); 76 } 77 enableCompilationCachingBenchmarks( float warmupTimeSeconds, float runTimeSeconds, int maxIterations)78 public void enableCompilationCachingBenchmarks( 79 float warmupTimeSeconds, float runTimeSeconds, int maxIterations) { 80 mProcessor.enableCompilationCachingBenchmarks( 81 warmupTimeSeconds, runTimeSeconds, maxIterations); 82 } 83 setUseNnApiSupportLibrary(boolean value)84 public void setUseNnApiSupportLibrary(boolean value) { 85 mUseNnApiSupportLibrary = value; 86 mProcessor.setUseNnApiSupportLibrary(mUseNnApiSupportLibrary); 87 } 88 setNnApiSupportLibraryVendor(String value)89 public void setNnApiSupportLibraryVendor(String value) { 90 mNnApiSupportLibraryVendor = value; 91 mProcessor.setNnApiSupportLibraryVendor(mNnApiSupportLibraryVendor); 92 } 93 setExtractNnApiSupportLibrary(boolean value)94 public void setExtractNnApiSupportLibrary(boolean value) { 95 mExtractNnApiSupportLibrary = value; 96 mProcessor.setExtractNnApiSupportLibrary(value); 97 } 98 99 @SuppressLint("SetTextI18n") 100 @Override onCreate(Bundle savedInstanceState)101 protected void onCreate(Bundle savedInstanceState) { 102 super.onCreate(savedInstanceState); 103 mTextView = new TextView(this); 104 mTextView.setTextSize(20); 105 mTextView.setText("Running NN benchmark..."); 106 setContentView(mTextView); 107 getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON); 108 } 109 110 @Override onPause()111 protected void onPause() { 112 super.onPause(); 113 if (mProcessor != null) { 114 mProcessor.exitWithTimeout(PROCESSOR_TERMINATION_TIMEOUT_MS); 115 mProcessor = null; 116 } 117 } 118 onBenchmarkFinish(boolean ok)119 public void onBenchmarkFinish(boolean ok) { 120 if (ok) { 121 Intent intent = new Intent(); 122 intent.putExtra(EXTRA_RESULTS_TESTS, mTestList); 123 intent.putExtra(EXTRA_RESULTS_RESULTS, mProcessor.getTestResults()); 124 setResult(RESULT_OK, intent); 125 } else { 126 setResult(RESULT_CANCELED); 127 } 128 finish(); 129 } 130 131 @SuppressLint("DefaultLocale") onStatusUpdate(int testNumber, int numTests, String modelName)132 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 133 runOnUiThread( 134 () -> { 135 mTextView.setText( 136 String.format( 137 "Running test %d of %d: %s", testNumber, numTests, modelName)); 138 }); 139 } 140 141 @Override onResume()142 protected void onResume() { 143 super.onResume(); 144 Intent i = getIntent(); 145 mTestList = i.getIntArrayExtra(EXTRA_TESTS); 146 if (mTestList != null && mTestList.length > 0) { 147 Log.v(TAG, String.format("Starting benchmark with %d test", mTestList.length)); 148 mProcessor = new Processor(this, this, mTestList); 149 mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false)); 150 mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false)); 151 mProcessor.setTfLiteBackend(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false) ? TfLiteBackend.NNAPI : TfLiteBackend.CPU); 152 mProcessor.setMaxRunIterations(i.getIntExtra(EXTRA_MAX_ITERATIONS, 0)); 153 mProcessor.setUseNnApiSupportLibrary(mUseNnApiSupportLibrary); 154 mProcessor.setNnApiSupportLibraryVendor(mNnApiSupportLibraryVendor); 155 mProcessor.setExtractNnApiSupportLibrary(mExtractNnApiSupportLibrary); 156 executorService.submit(mProcessor); 157 } else { 158 Log.v(TAG, "No test to run, doing nothing"); 159 } 160 } 161 162 @Override onDestroy()163 protected void onDestroy() { 164 super.onDestroy(); 165 } 166 runSynchronously(TestModelEntry testModel, float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults)167 public BenchmarkResult runSynchronously(TestModelEntry testModel, 168 float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults) throws IOException, BenchmarkException { 169 return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds, sampleResults); 170 } 171 } 172