/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetManager;
import android.os.Build;
import android.system.Os;
import android.system.ErrnoException;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;
import androidx.test.InstrumentationRegistry;
import com.android.nn.benchmark.core.sl.QualcommSupportLibraryDriverHandler;
import com.android.nn.benchmark.core.sl.SupportLibraryDriverHandler;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import dalvik.system.BaseDexClassLoader;
import android.content.res.AssetFileDescriptor;
import android.os.ParcelFileDescriptor;
import android.os.ParcelFileDescriptor.AutoCloseInputStream;
import java.util.jar.JarFile;
import java.util.jar.JarEntry;

/**
 * Java-side wrapper around a benchmark model that lives in the {@code nnbenchmark_jni} native
 * library.
 *
 * <p>An instance copies the model asset into the app cache directory, asks the native code to
 * initialise it ({@link #setupModel(Context)}), runs inference or compilation benchmarks against
 * it, and finally releases the native handle and the temporary file ({@link #destroy()} /
 * {@link #close()}).
 */
public class NNTestBase implements AutoCloseable {
    protected static final String TAG = "NN_TESTBASE";

    // Used to load the 'native-lib' library on application startup.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    // Does the device have any NNAPI accelerator?
    // We only consider a real device, not 'nnapi-reference'.
    public static native boolean hasAccelerator();

    /**
     * Fills resultList with the name of the available NNAPI accelerators
     *
     * @return False if any error occurred, true otherwise
     */
    public static native boolean getAcceleratorNames(List<String> resultList);

    public static native boolean hasNnApiDevice(String nnApiDeviceName);

    /**
     * Initialises the native model.
     *
     * @return the native model handle; {@link #setupModel(Context)} treats 0 as a failure
     */
    private synchronized native long initModel(
            String modelFileName,
            int tfliteBackend,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName,
            boolean mmapModel,
            String nnApiCacheDir,
            long nnApiLibHandle) throws NnApiDelegationFailure;

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    /**
     * Runs the native benchmark, appending to {@code resultList}.
     *
     * @return false on failure (surfaced to callers as a {@link BenchmarkException})
     */
    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    /**
     * Runs the native compilation benchmark.
     *
     * @return the result, or null on failure (surfaced to callers as a
     *         {@link BenchmarkException})
     */
    private synchronized native CompilationBenchmarkResult runCompilationBenchmark(
            long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec,
            boolean useNnapiSl);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    /**
     * Lists the NNAPI accelerators reported by the native layer, excluding the
     * {@code nnapi-reference} device (compared case-insensitively).
     *
     * @return the accelerator names, or an empty list if they could not be retrieved
     */
    public static List<String> availableAcceleratorNames() {
        List<String> availableAccelerators = new ArrayList<>();
        if (!NNTestBase.getAcceleratorNames(availableAccelerators)) {
            Log.e(TAG, "Unable to retrieve accelerator names!!");
            // Typed factory instead of the raw Collections.EMPTY_LIST constant.
            return Collections.emptyList();
        }
        return availableAccelerators.stream()
                .filter(acceleratorName -> !acceleratorName.equalsIgnoreCase("nnapi-reference"))
                .collect(Collectors.toList());
    }

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * Useful in cases where there's no straightforward golden output values
     * for the benchmark. This will also skip calculating basic (golden
     * output based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;


    /** Collect only 1 benchmark result every 10 **/
    public static final int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;

    protected Context mContext;
    protected TextView mText;
    private final String mModelName;
    private final String mModelFile;
    // Native model handle; 0 means no model is currently loaded.
    private long mModelHandle;
    private final int[] mInputShape;
    private final InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private final EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    // Refreshed every time the inputs are (re)loaded by getInputOutputAssets().
    private boolean mHasGoldenOutputs;
    private TfLiteBackend mTfLiteBackend;
    private boolean mEnableIntermediateTensorsDump = false;
    private final int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();
    private boolean mMmapModel = false;
    // Path where the current model has been stored for execution
    private String mTemporaryModelFilePath;
    private boolean mSampleResults;

    // If set to true the test will look for the NNAPI SL binaries in the app resources,
    // copy them into the app cache dir and configure the TfLite test to load NNAPI
    // from the library.
    private boolean mUseNnApiSupportLibrary = false;
    private boolean mExtractNnApiSupportLibrary = false;

    static final String USE_NNAPI_SL_PROPERTY = "useNnApiSupportLibrary";
    static final String EXTRACT_NNAPI_SL_PROPERTY = "extractNnApiSupportLibrary";

    /**
     * Reads a boolean instrumentation argument.
     *
     * @param key          name of the instrumentation argument
     * @param defaultValue value used when the argument is absent
     */
    private static boolean getBooleanTestParameter(String key, boolean defaultValue) {
        // All instrumentation arguments are passed as String so the value is converted here.
        return Boolean.parseBoolean(
                InstrumentationRegistry.getArguments().getString(key,
                        String.valueOf(defaultValue)));
    }

    /** Returns the {@code useNnApiSupportLibrary} instrumentation argument (default false). */
    public static boolean shouldUseNnApiSupportLibrary() {
        return getBooleanTestParameter(USE_NNAPI_SL_PROPERTY, false);
    }

    /** Returns the {@code extractNnApiSupportLibrary} instrumentation argument (default false). */
    public static boolean shouldExtractNnApiSupportLibrary() {
        return getBooleanTestParameter(EXTRACT_NNAPI_SL_PROPERTY, false);
    }

    /**
     * Creates a test for the given model. Exactly one of {@code inputOutputAssets} and
     * {@code inputOutputDatasets} must be non-null.
     *
     * @param modelName           name reported by {@link #getTestInfo()}
     * @param modelFile           asset name of the model, without the {@code .tflite} suffix
     * @param inputShape          shape passed to the native input tensor resize
     * @param inputOutputAssets   inputs read from assets, or null
     * @param inputOutputDatasets inputs read from datasets, or null
     * @param evaluator           optional evaluator configuration, may be null
     * @param minSdkVersion       minimum SDK checked by {@link #checkSdkVersion()}; values
     *                            &lt;= 0 disable the check
     * @throws IllegalArgumentException if both or neither of the input sources are given
     */
    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets or inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            // The two literals used to be joined without a space ("Only onesupported").
            throw new IllegalArgumentException(
                    "Both inputOutputAssets or inputOutputDatasets given. Only one "
                            + "supported at once.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
        mSampleResults = false;
    }

    public void setTfLiteBackend(TfLiteBackend tfLiteBackend) {
        mTfLiteBackend = tfLiteBackend;
    }

    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    public void useNNApi() {
        setTfLiteBackend(TfLiteBackend.NNAPI);
    }

    public void setUseNnApiSupportLibrary(boolean value) {
        mUseNnApiSupportLibrary = value;
    }

    public void setExtractNnApiSupportLibrary(boolean value) {
        mExtractNnApiSupportLibrary = value;
    }

    /**
     * Selects the NNAPI device to run on. The name is stored even when the current backend is
     * not NNAPI; in that case an error is logged.
     *
     * @param value device name, or null to clear the selection
     */
    public void setNNApiDeviceName(String value) {
        if (mTfLiteBackend != TfLiteBackend.NNAPI) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

    /**
     * Copies the model asset to the cache directory and initialises the native model.
     *
     * <p>If a model was already set up on this instance, its native handle and temporary file
     * are released before the new ones are created.
     *
     * @param ipcxt context used for assets and cache directories
     * @return true on success; false if the native model could not be initialised or its input
     *         tensors could not be resized
     * @throws IOException            if the model asset cannot be copied
     * @throws NnApiDelegationFailure if the NNAPI support library was requested but its handle
     *                                could not be obtained, or if native initialisation throws
     */
    public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure {
        mContext = ipcxt;
        long nnApiLibHandle = 0;
        if (mUseNnApiSupportLibrary) {
            // TODO: support different drivers providers maybe with a flag
            QualcommSupportLibraryDriverHandler qcSlhandler =
                    new QualcommSupportLibraryDriverHandler();
            nnApiLibHandle = qcSlhandler.getOrLoadNnApiSlHandle(mContext,
                    mExtractNnApiSupportLibrary);
            if (nnApiLibHandle == 0) {
                final String errorMsg = String.format(
                        "Unable to find NNAPI SL entry point '%s' in embedded libraries path.",
                        SupportLibraryDriverHandler.NNAPI_SL_LIB_NAME);
                Log.e(TAG, errorMsg);
                throw new NnApiDelegationFailure(errorMsg);
            }
        }
        // Release the model from a previous setupModel() call; otherwise its handle would be
        // overwritten below and never passed to destroyModel().
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
            mTemporaryModelFilePath = null;
        }
        mTemporaryModelFilePath = copyAssetToFile();
        String nnApiCacheDir = mContext.getCodeCacheDir().toString();
        mModelHandle = initModel(
                mTemporaryModelFilePath, mTfLiteBackend.ordinal(), mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir, nnApiLibHandle);
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        if (!resizeInputTensors(mModelHandle, mInputShape)) {
            return false;
        }

        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
        }
        return true;
    }

    public String getTestInfo() {
        return mModelName;
    }

    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    /**
     * Verifies that the device SDK level satisfies the model's minimum.
     *
     * @throws UnsupportedSdkException if a positive minimum is set and the device is below it
     */
    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Mininum required: " +
                    mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    /** Deletes the file at {@code path}, logging a warning if the deletion fails. */
    private void deleteOrWarn(String path) {
        if (!new File(path).delete()) {
            Log.w(TAG, String.format(
                    "Unable to delete file '%s'. This might cause device to run out of space.",
                    path));
        }
    }


    /**
     * Loads this test's inputs and refreshes {@code mHasGoldenOutputs} from them.
     *
     * @throws IllegalArgumentException if only some of the sequences have golden outputs
     */
    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList =
                getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets);

        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = mHasGoldenOutputs;
            } else if (lastGolden != mHasGoldenOutputs) {
                throw new IllegalArgumentException(
                        "Some inputs for " + mModelName + " have outputs while some don't.");
            }
        }
        return inOutList;
    }

    /**
     * Reads every given asset and dataset input into a single list (assets first).
     *
     * @param context             used for the asset manager and the cache directory
     * @param inputOutputAssets   asset-backed inputs, may be null
     * @param inputOutputDatasets dataset-backed inputs, may be null
     */
    public static List<InferenceInOutSequence> getInputOutputAssets(Context context,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (inputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) {
                inOutList.add(ioAsset.readAssets(context.getAssets()));
            }
        }
        if (inputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir()));
            }
        }

        return inOutList;
    }

    /**
     * Builds the benchmark flags from the current state: golden outputs are ignored when the
     * most recently loaded inputs had none, inference output is discarded when there is no
     * evaluator, and results are sampled when sampling was requested.
     */
    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags |= FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags |= FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        // For very long tests we will collect only a sample of the results
        if (mSampleResults) {
            flags |= FLAG_SAMPLE_BENCHMARK_RESULTS;
        }
        return flags;
    }

    /**
     * Dumps the intermediate tensors for a range of the test inputs into {@code dumpDir}.
     *
     * @param dumpDir         existing directory receiving the dump
     * @param inputAssetIndex index of the first input sequence to use (inclusive)
     * @param inputAssetSize  passed as the exclusive end index of the range.
     *                        NOTE(review): despite its name this is not added to
     *                        {@code inputAssetIndex}; confirm against callers.
     * @throws IllegalArgumentException if {@code dumpDir} is missing or not a directory
     * @throws IllegalStateException    if intermediate tensor dumping was not enabled
     */
    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is " +
                    "set to false, impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    /** Runs at most one input sequence with no effective timeout. */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        return runBenchmark(ios, 1, Float.MAX_VALUE, flags);
    }

    /** Runs as many input sequences as possible before {@code timeoutSec} elapses. */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Load the inputs first: loading is what refreshes mHasGoldenOutputs, which
        // getDefaultFlags() reads. Computing the flags first would use a stale value.
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        // Run as many as possible before timeout.
        return runBenchmark(ios, 0xFFFFFFF, timeoutSec, flags);
    }

    /**
     * Run through whole input set (once or multiple times).
     *
     * @param minInferences minimum number of inferences; the set is repeated until reached
     * @param timeoutSec    timeout handed to the native benchmark
     * @throws IllegalStateException if the input set contains no inferences, or if the number
     *                               of results differs from the number expected
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int minInferences,
            float timeoutSec)
            throws IOException, BenchmarkException {
        // Inputs before flags, see runBenchmark(float).
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        int setInferences = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            setInferences += iosSeq.size();
        }
        if (setInferences == 0) {
            // Avoids an opaque "/ by zero" ArithmeticException in the ceil division below.
            throw new IllegalStateException(
                    "Input set for " + mModelName + " contains no inferences");
        }
        int setRepeat = (minInferences + setInferences - 1) / setInferences; // ceil.
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = setInferences * setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                        flags);
        if (result.second.size() != expectedResults) {
            // We reached a timeout or failed to evaluate whole set for other reason, abort.
            @SuppressLint("DefaultLocale")
            final String errorMsg = String.format(
                    "Failed to evaluate complete input set, in %f seconds expected: %d, received:"
                            + " %d",
                    timeoutSec, expectedResults, result.second.size());
            Log.w(TAG, errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        return result;
    }

    /**
     * Runs the native benchmark over the given inputs.
     *
     * @param inOutList             input sequences to run
     * @param inferencesSeqMaxCount maximum number of sequences to run
     * @param timeoutSec            timeout handed to the native benchmark
     * @param flags                 bitwise OR of the {@code FLAG_*} constants
     * @return the inputs paired with the collected results
     * @throws UnsupportedModelException if no model is loaded
     * @throws BenchmarkException        if the native benchmark reports failure
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<>(inOutList, resultList);
    }

    /**
     * Runs the native compilation benchmark on the loaded model.
     *
     * @throws UnsupportedModelException if no model is loaded
     * @throws BenchmarkException        if the native benchmark returns no result
     */
    public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec,
            float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        CompilationBenchmarkResult result = runCompilationBenchmark(
                mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec,
                shouldUseNnApiSupportLibrary());
        if (result == null) {
            throw new BenchmarkException("Failed to run compilation benchmark");
        }
        return result;
    }

    /**
     * Releases the native model and deletes the temporary model file, if present. Calling it
     * again afterwards does nothing because both fields are reset.
     */
    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
            mTemporaryModelFilePath = null;
        }
    }

    private final Random mRandom = new Random(System.currentTimeMillis());

    // We need to copy it to cache dir, so that TFlite can load it directly.
    /**
     * Copies this test's model asset to a uniquely named file in the cache directory.
     *
     * @return the absolute path of the copy
     * @throws IOException if the copy fails; a partially written file is removed first
     */
    private String copyAssetToFile() throws IOException {
        @SuppressLint("DefaultLocale")
        String outFileName =
                String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(),
                        mModelFile,
                        Thread.currentThread().getId(), mRandom.nextInt(10000));

        try {
            copyAssetToFile(mContext, mModelFile + ".tflite", outFileName);
        } catch (IOException e) {
            // The path is not recorded on failure, so destroy() could never clean this up.
            if (new File(outFileName).exists()) {
                deleteOrWarn(outFileName);
            }
            throw e;
        }
        return outFileName;
    }

    /**
     * Copies a model asset into {@code targetFile}, creating the file if needed.
     *
     * @return false if the target file did not exist and could not be created, true otherwise
     * @throws IOException if creating the file or copying the asset fails
     */
    public static boolean copyModelToFile(Context context, String modelFileName, File targetFile)
            throws IOException {
        if (!targetFile.exists() && !targetFile.createNewFile()) {
            Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath()));
            return false;
        }
        NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath());
        return true;
    }

    /**
     * Copies the asset {@code modelAssetName} to the file at {@code targetPath}.
     *
     * @throws IOException if the asset cannot be opened or the copy fails; the error is logged
     *                     before being rethrown
     */
    public static void copyAssetToFile(Context context, String modelAssetName, String targetPath)
            throws IOException {
        AssetManager assetManager = context.getAssets();
        File outFile = new File(targetPath);
        try (InputStream in = assetManager.open(modelAssetName);
             FileOutputStream out = new FileOutputStream(outFile)) {
            copyFull(in, out);
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            throw e;
        }
    }

    /** Copies {@code in} to {@code out} until end of stream. Neither stream is closed. */
    public static void copyFull(InputStream in, OutputStream out) throws IOException {
        byte[] byteBuffer = new byte[1024];
        int readBytes;
        while ((readBytes = in.read(byteBuffer)) != -1) {
            out.write(byteBuffer, 0, readBytes);
        }
    }

    @Override
    public void close() {
        destroy();
    }

    public void setSampleResult(boolean sampleResults) {
        this.mSampleResults = sampleResults;
    }
}