/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.os.Build;
import android.os.ParcelFileDescriptor;
import android.os.ParcelFileDescriptor.AutoCloseInputStream;
import android.system.ErrnoException;
import android.system.Os;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;
import androidx.test.InstrumentationRegistry;
import com.android.nn.benchmark.core.sl.QualcommSupportLibraryDriverHandler;
import com.android.nn.benchmark.core.sl.SupportLibraryDriverHandler;
import dalvik.system.BaseDexClassLoader;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
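
/**
 * Base class for NNAPI benchmark tests. Wraps a TFLite model and the JNI
 * entry points used to load and run it, optionally delegating to NNAPI or to
 * the NNAPI Support Library.
 *
 * <p>Illustrative usage sketch (not from the original sources; the model
 * parameters and variable names below are hypothetical placeholders, and
 * exception handling is omitted):
 *
 * <pre>{@code
 * // inputOutputAssets and evaluatorConfig are assumed to exist.
 * try (NNTestBase test = new NNTestBase("my_model", "my_model_file",
 *         new int[] {1, 224, 224, 3}, inputOutputAssets, null,
 *         evaluatorConfig, 0)) {
 *     test.useNNApi();
 *     if (test.setupModel(context)) {
 *         Pair<List<InferenceInOutSequence>, List<InferenceResult>> results =
 *                 test.runBenchmark(60.0f); // up to 60 seconds of inferences
 *     }
 * }
 * }</pre>
 */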
public class NNTestBase implements AutoCloseable {
    protected static final String TAG = "NN_TESTBASE";

    // Loads the JNI benchmark library when the class is first used.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    // Does the device have any NNAPI accelerator? Only real devices are
    // considered, not 'nnapi-reference'.
    public static native boolean hasAccelerator();

    /**
     * Fills resultList with the names of the available NNAPI accelerators.
     *
     * @return false if any error occurred, true otherwise.
     */
    public static native boolean getAcceleratorNames(List<String> resultList);

    public static native boolean hasNnApiDevice(String nnApiDeviceName);

    private synchronized native long initModel(
            String modelFileName,
            int tfliteBackend,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName,
            boolean mmapModel,
            String nnApiCacheDir,
            long nnApiLibHandle) throws NnApiDelegationFailure;

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    private synchronized native CompilationBenchmarkResult runCompilationBenchmark(
            long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec,
            boolean useNnapiSl);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    public static List<String> availableAcceleratorNames() {
        List<String> availableAccelerators = new ArrayList<>();
        if (NNTestBase.getAcceleratorNames(availableAccelerators)) {
            return availableAccelerators.stream().filter(
                    acceleratorName -> !acceleratorName.equalsIgnoreCase(
                            "nnapi-reference")).collect(Collectors.toList());
        } else {
            Log.e(TAG, "Unable to retrieve accelerator names!");
            return Collections.emptyList();
        }
    }
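
    // Illustrative sketch (not from the original sources; variable names are
    // hypothetical): pick the first available accelerator, if any, before
    // setting up the model.
    //
    //   List<String> names = NNTestBase.availableAcceleratorNames();
    //   if (!names.isEmpty()) {
    //       test.useNNApi();
    //       test.setNNApiDeviceName(names.get(0));
    //   }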

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * <p>Useful when there is no straightforward golden output value for the
     * benchmark. This also skips calculating the basic (golden-output-based)
     * error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;

    /** Collect only one benchmark result out of every ten. */
    public static final int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;

    protected Context mContext;
    protected TextView mText;
    private final String mModelName;
    private final String mModelFile;
    private long mModelHandle;
    private final int[] mInputShape;
    private final InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private final EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    private boolean mHasGoldenOutputs;
    private TfLiteBackend mTfLiteBackend;
    private boolean mEnableIntermediateTensorsDump = false;
    private final int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();
    private boolean mMmapModel = false;
    // Path where the current model has been stored for execution.
    private String mTemporaryModelFilePath;
    private boolean mSampleResults;

    // If set to true, the test will look for the NNAPI SL binaries in the app
    // resources, copy them into the app cache dir, and configure the TfLite
    // test to load NNAPI from that library.
    private boolean mUseNnApiSupportLibrary = false;
    private boolean mExtractNnApiSupportLibrary = false;

    static final String USE_NNAPI_SL_PROPERTY = "useNnApiSupportLibrary";
    static final String EXTRACT_NNAPI_SL_PROPERTY = "extractNnApiSupportLibrary";

    private static boolean getBooleanTestParameter(String key, boolean defaultValue) {
        // All instrumentation arguments are passed as Strings, so the value
        // has to be converted here.
        return Boolean.parseBoolean(
                InstrumentationRegistry.getArguments().getString(key, "" + defaultValue));
    }

    public static boolean shouldUseNnApiSupportLibrary() {
        return getBooleanTestParameter(USE_NNAPI_SL_PROPERTY, false);
    }

    public static boolean shouldExtractNnApiSupportLibrary() {
        return getBooleanTestParameter(EXTRACT_NNAPI_SL_PROPERTY, false);
    }
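
    // Illustrative sketch (not from the original sources): both properties are
    // read from the instrumentation arguments, so a run exercising the NNAPI
    // Support Library path could be launched with something like:
    //
    //   adb shell am instrument -w \
    //       -e useNnApiSupportLibrary true \
    //       -e extractNnApiSupportLibrary true \
    //       <test-package>/<runner-class>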

    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets nor inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            throw new IllegalArgumentException(
                    "Both inputOutputAssets and inputOutputDatasets given. Only one is"
                            + " supported at a time.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
        mSampleResults = false;
    }

    public void setTfLiteBackend(TfLiteBackend tfLiteBackend) {
        mTfLiteBackend = tfLiteBackend;
    }

    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    public void useNNApi() {
        setTfLiteBackend(TfLiteBackend.NNAPI);
    }

    public void setUseNnApiSupportLibrary(boolean value) {
        mUseNnApiSupportLibrary = value;
    }

    public void setExtractNnApiSupportLibrary(boolean value) {
        mExtractNnApiSupportLibrary = value;
    }

    public void setNNApiDeviceName(String value) {
        if (mTfLiteBackend != TfLiteBackend.NNAPI) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }
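
    /**
     * Copies the model from the app assets into the app cache directory and
     * initializes it on the configured TFLite backend; when the NNAPI Support
     * Library is enabled, its entry point is loaded first.
     *
     * @return true if the model was initialized and its input tensors were
     *         resized successfully, false otherwise.
     */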
    public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure {
        mContext = ipcxt;
        long nnApiLibHandle = 0;
        if (mUseNnApiSupportLibrary) {
            // TODO: Support different driver providers, perhaps selected with a flag.
            QualcommSupportLibraryDriverHandler qcSlHandler =
                    new QualcommSupportLibraryDriverHandler();
            nnApiLibHandle = qcSlHandler.getOrLoadNnApiSlHandle(mContext,
                    mExtractNnApiSupportLibrary);
            if (nnApiLibHandle == 0) {
                final String errorMsg = String.format(
                        "Unable to find NNAPI SL entry point '%s' in embedded libraries path.",
                        SupportLibraryDriverHandler.NNAPI_SL_LIB_NAME);
                Log.e(TAG, errorMsg);
                throw new NnApiDelegationFailure(errorMsg);
            }
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
        }
        mTemporaryModelFilePath = copyAssetToFile();
        String nnApiCacheDir = mContext.getCodeCacheDir().toString();
        mModelHandle = initModel(
                mTemporaryModelFilePath, mTfLiteBackend.ordinal(), mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir, nnApiLibHandle);
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        if (!resizeInputTensors(mModelHandle, mInputShape)) {
            return false;
        }

        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
        }
        return true;
    }

    public String getTestInfo() {
        return mModelName;
    }

    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: "
                    + mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    private void deleteOrWarn(String path) {
        if (!new File(path).delete()) {
            Log.w(TAG, String.format(
                    "Unable to delete file '%s'. This might cause the device to run out of space.",
                    path));
        }
    }

    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference.
        List<InferenceInOutSequence> inOutList =
                getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets);

        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = mHasGoldenOutputs;
            } else if (lastGolden != mHasGoldenOutputs) {
                throw new IllegalArgumentException(
                        "Some inputs for " + mModelName + " have outputs while some don't.");
            }
        }
        return inOutList;
    }

    public static List<InferenceInOutSequence> getInputOutputAssets(Context context,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException {
        // TODO: Caching, don't read inputs for every inference.
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (inputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) {
                inOutList.add(ioAsset.readAssets(context.getAssets()));
            }
        }
        if (inputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir()));
            }
        }

        return inOutList;
    }

    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags |= FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags |= FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        // For very long tests, collect only a sample of the results.
        if (mSampleResults) {
            flags |= FLAG_SAMPLE_BENCHMARK_RESULTS;
        }
        return flags;
    }
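
    /**
     * Dumps all intermediate layer tensors for the input sequences in the
     * range [inputAssetIndex, inputAssetSize). Note that the second parameter
     * is passed to {@link List#subList} as the exclusive end index, not as a
     * count.
     */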
    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is set to false, "
                    + "impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        return runBenchmark(ios, 1, Float.MAX_VALUE, flags);
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Run as many inferences as possible before the timeout.
        int flags = getDefaultFlags();
        return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
    }

    /** Run through the whole input set (once or multiple times). */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int minInferences,
            float timeoutSec)
            throws IOException, BenchmarkException {
        int flags = getDefaultFlags();
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int setInferences = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            setInferences += iosSeq.size();
        }
        int setRepeat = (minInferences + setInferences - 1) / setInferences; // ceil.
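        // Integer ceiling: e.g. with minInferences = 100 and setInferences = 30
        // (hypothetical values), setRepeat = (100 + 30 - 1) / 30 = 4, so the
        // whole input set is replayed four times.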
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = setInferences * setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec, flags);
        if (result.second.size() != expectedResults) {
            // We reached a timeout or failed to evaluate the whole set for some
            // other reason; abort.
            @SuppressLint("DefaultLocale")
            final String errorMsg = String.format(
                    "Failed to evaluate complete input set, in %f seconds expected: %d, received:"
                            + " %d",
                    timeoutSec, expectedResults, result.second.size());
            Log.w(TAG, errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        return result;
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<>(inOutList, resultList);
    }
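
    // Illustrative call (a sketch, not from the original sources; variable
    // names are hypothetical): run at most ten input sequences with a
    // 60-second budget, using this model's default flags.
    //
    //   List<InferenceInOutSequence> ios = ...; // e.g. read from assets
    //   Pair<List<InferenceInOutSequence>, List<InferenceResult>> r =
    //           test.runBenchmark(ios, 10, 60.0f, test.getDefaultFlags());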

    public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec,
            float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        CompilationBenchmarkResult result = runCompilationBenchmark(
                mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec,
                shouldUseNnApiSupportLibrary());
        if (result == null) {
            throw new BenchmarkException("Failed to run compilation benchmark");
        }
        return result;
    }

    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
            mTemporaryModelFilePath = null;
        }
    }

    private final Random mRandom = new Random(System.currentTimeMillis());

    // The model is copied to the cache dir so that TFLite can load it directly.
    private String copyAssetToFile() throws IOException {
        @SuppressLint("DefaultLocale")
        String outFileName =
                String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(),
                        mModelFile,
                        Thread.currentThread().getId(), mRandom.nextInt(10000));

        copyAssetToFile(mContext, mModelFile + ".tflite", outFileName);
        return outFileName;
    }

    public static boolean copyModelToFile(Context context, String modelFileName, File targetFile)
            throws IOException {
        if (!targetFile.exists() && !targetFile.createNewFile()) {
            Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath()));
            return false;
        }
        NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath());
        return true;
    }

    public static void copyAssetToFile(Context context, String modelAssetName, String targetPath)
            throws IOException {
        AssetManager assetManager = context.getAssets();
        try {
            File outFile = new File(targetPath);

            try (InputStream in = assetManager.open(modelAssetName);
                 FileOutputStream out = new FileOutputStream(outFile)) {
                copyFull(in, out);
            }
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            throw e;
        }
    }
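
    /**
     * Streams {@code in} to {@code out} in 1 KiB chunks until end of stream.
     * Neither stream is closed by this method.
     */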
    public static void copyFull(InputStream in, OutputStream out) throws IOException {
        byte[] byteBuffer = new byte[1024];
        int readBytes;
        while ((readBytes = in.read(byteBuffer)) != -1) {
            out.write(byteBuffer, 0, readBytes);
        }
    }

    @Override
    public void close() {
        destroy();
    }

    public void setSampleResult(boolean sampleResults) {
        this.mSampleResults = sampleResults;
    }
}