1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.util; 18 19 import android.app.Activity; 20 import android.os.Bundle; 21 import android.util.Log; 22 23 import com.android.nn.benchmark.core.NNTestBase; 24 import com.android.nn.benchmark.core.TestModels; 25 import com.android.nn.benchmark.core.TestModels.TestModelEntry; 26 import com.android.nn.benchmark.core.TfLiteBackend; 27 28 import java.io.File; 29 30 31 /** 32 * Helper activity for dumping state of interference intermediate tensors. 33 * 34 * Example usage: 35 * adb shell am start -n com.android.nn.benchmark.app/com.android.nn.benchmark.\ 36 * util.DumpIntermediateTensors --es modelName mobilenet_v1_1.0_224_quant_topk_aosp,tts_float\ 37 * inputAssetIndex 0 38 * 39 * Assets will be then dumped into /data/data/com.android.nn.benchmark.app/files/intermediate 40 * To fetch: 41 * adb pull /data/data/com.android.nn.benchmark.app/files/intermediate 42 */ 43 public class DumpIntermediateTensors extends Activity { 44 protected static final String TAG = "VDEBUG"; 45 public static final String EXTRA_MODEL_NAME = "modelName"; 46 public static final String EXTRA_INPUT_ASSET_INDEX = "inputAssetIndex"; 47 public static final String EXTRA_INPUT_ASSET_SIZE = "inputAssetSize"; 48 public static final String EXTRA_TFLITE_BACKEND = "tfLiteBackend"; 49 public static final String DUMP_DIR = "intermediate"; 50 public static final String CPU_DIR = "cpu"; 51 public static final String NNAPI_DIR = "nnapi"; 52 // TODO(veralin): Update to use other models in vendor as well. 53 // Due to recent change in NNScoringTest, the model names are moved to here. 54 private static final String[] MODEL_NAMES = new String[]{ 55 "tts_float", 56 "asr_float", 57 "mobilenet_v1_1.0_224_quant_topk_aosp", 58 "mobilenet_v1_1.0_224_topk_aosp", 59 "mobilenet_v1_0.75_192_quant_topk_aosp", 60 "mobilenet_v1_0.75_192_topk_aosp", 61 "mobilenet_v1_0.5_160_quant_topk_aosp", 62 "mobilenet_v1_0.5_160_topk_aosp", 63 "mobilenet_v1_0.25_128_quant_topk_aosp", 64 "mobilenet_v1_0.25_128_topk_aosp", 65 "mobilenet_v2_0.35_128_topk_aosp", 66 "mobilenet_v2_0.5_160_topk_aosp", 67 "mobilenet_v2_0.75_192_topk_aosp", 68 "mobilenet_v2_1.0_224_topk_aosp", 69 "mobilenet_v2_1.0_224_quant_topk_aosp", 70 }; 71 72 @Override onCreate(Bundle savedInstanceState)73 protected void onCreate(Bundle savedInstanceState) { 74 super.onCreate(savedInstanceState); 75 Bundle extras = getIntent().getExtras(); 76 77 String userModelName = extras.getString(EXTRA_MODEL_NAME); 78 int inputAssetIndex = extras.getInt(EXTRA_INPUT_ASSET_INDEX, 0); 79 int inputAssetSize = extras.getInt(EXTRA_INPUT_ASSET_SIZE, 1); 80 81 // Default to run all models in NNScoringTest 82 String[] modelNames = userModelName == null ? MODEL_NAMES : userModelName.split(","); 83 84 try { 85 File dumpDir = new File(getFilesDir(), DUMP_DIR); 86 safeMkdir(dumpDir); 87 88 for (String modelName : modelNames) { 89 File modelDir = new File(getFilesDir() + "/" + DUMP_DIR, modelName); 90 safeMkdir(modelDir); 91 // Run in CPU and NNAPI mode 92 for (final boolean useNNAPI : new boolean[]{false, true}) { 93 String useNNAPIDir = useNNAPI ? NNAPI_DIR : CPU_DIR; 94 TfLiteBackend backend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU; 95 TestModelEntry modelEntry = TestModels.getModelByName(modelName); 96 try (NNTestBase testBase = modelEntry.createNNTestBase( 97 backend, /*enableIntermediateTensorsDump*/true, /*mmapModel*/false)) { 98 testBase.setupModel(this); 99 File outputDir = new File(getFilesDir() + "/" + DUMP_DIR + 100 "/" + modelName, useNNAPIDir); 101 safeMkdir(outputDir); 102 testBase.dumpAllLayers(outputDir, inputAssetIndex, inputAssetSize); 103 } 104 } 105 } 106 107 } catch (Exception e) { 108 Log.e(TAG, "Failed to dump tensors", e); 109 throw new IllegalStateException("Failed to dump tensors", e); 110 } 111 finish(); 112 } 113 deleteRecursive(File fileOrDirectory)114 private void deleteRecursive(File fileOrDirectory) { 115 if (fileOrDirectory.isDirectory()) { 116 for (File child : fileOrDirectory.listFiles()) { 117 deleteRecursive(child); 118 } 119 } 120 fileOrDirectory.delete(); 121 } 122 safeMkdir(File fileOrDirectory)123 private void safeMkdir(File fileOrDirectory) { 124 deleteRecursive(fileOrDirectory); 125 fileOrDirectory.mkdir(); 126 } 127 } 128