1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import android.content.res.AssetManager; 20 21 import org.json.JSONArray; 22 import org.json.JSONException; 23 import org.json.JSONObject; 24 25 import java.io.IOException; 26 import java.io.InputStream; 27 import java.io.InputStreamReader; 28 import java.io.Reader; 29 30 /** Helper class to register test model definitions from assets data */ 31 public class TestModelsListLoader { 32 33 /** 34 * Parse list of models in form of json data. 35 * 36 * Example input: 37 * { "models" : [ 38 * {"name" : "modelName", 39 * "testName" : "testName", 40 * "baselineSec" : 0.03, 41 * "evaluator": "TopK", 42 * "inputSize" : [1,2,3,4], 43 * "dataSize" : 4, 44 * "inputOutputs" : [ {"input": "input1", "output": "output2"} ] 45 * } 46 * ]} 47 */ parseJSONModelsList(String jsonStringInput)48 static public void parseJSONModelsList(String jsonStringInput) throws JSONException { 49 JSONObject jsonRootObject = new JSONObject(jsonStringInput); 50 JSONArray jsonModelsArray = jsonRootObject.getJSONArray("models"); 51 52 for (int i = 0; i < jsonModelsArray.length(); i++) { 53 JSONObject jsonTestModelEntry = jsonModelsArray.getJSONObject(i); 54 55 String name = jsonTestModelEntry.getString("name"); 56 String testName = name; 57 if (jsonTestModelEntry.has("testName")) { 58 testName = jsonTestModelEntry.getString("testName"); 59 } 60 String modelFile = name; 61 if (jsonTestModelEntry.has("modelFile")) { 62 modelFile = jsonTestModelEntry.getString("modelFile"); 63 } 64 double baseline = jsonTestModelEntry.getDouble("baselineSec"); 65 int minSdkVersion = 0; 66 if (jsonTestModelEntry.has("minSdkVersion")) { 67 minSdkVersion = jsonTestModelEntry.getInt("minSdkVersion"); 68 } 69 EvaluatorConfig evaluator = null; 70 if (jsonTestModelEntry.has("evaluator")) { 71 JSONObject evaluatorJson = jsonTestModelEntry.getJSONObject("evaluator"); 72 evaluator = new EvaluatorConfig(evaluatorJson.getString("className"), 73 evaluatorJson.has("outputMeanStdDev") 74 ? evaluatorJson.getString("outputMeanStdDev") 75 : null, 76 evaluatorJson.has("expectedTop1") 77 ? evaluatorJson.getDouble("expectedTop1") 78 : null); 79 } 80 81 int dataSize = jsonTestModelEntry.getInt("dataSize"); 82 JSONArray jsonInputSize = jsonTestModelEntry.getJSONArray("inputSize"); 83 int[] inputSize = new int[jsonInputSize.length()]; 84 int inputSizeBytes = dataSize; 85 for (int k = 0; k < jsonInputSize.length(); ++k) { 86 inputSize[k] = jsonInputSize.getInt(k); 87 inputSizeBytes *= inputSize[k]; 88 } 89 90 InferenceInOutSequence.FromAssets[] inputOutputs = null; 91 if (jsonTestModelEntry.has("inputOutputs")) { 92 JSONArray jsonInputOutputs = jsonTestModelEntry.getJSONArray("inputOutputs"); 93 inputOutputs = 94 new InferenceInOutSequence.FromAssets[jsonInputOutputs.length()]; 95 96 for (int j = 0; j < jsonInputOutputs.length(); j++) { 97 JSONObject jsonInputOutput = jsonInputOutputs.getJSONObject(j); 98 String input = jsonInputOutput.getString("input"); 99 String[] outputs = null; 100 String output = jsonInputOutput.optString("output", null); 101 if (output != null) { 102 outputs = new String[]{output}; 103 } else { 104 JSONArray outputArray = jsonInputOutput.getJSONArray("outputs"); 105 if (outputArray != null) { 106 outputs = new String[outputArray.length()]; 107 for (int k = 0; k < outputArray.length(); ++k) { 108 outputs[k] = outputArray.getString(k); 109 } 110 } 111 } 112 113 inputOutputs[j] = new InferenceInOutSequence.FromAssets(input, outputs, 114 dataSize, 115 inputSizeBytes); 116 } 117 } 118 InferenceInOutSequence.FromDataset[] datasets = null; 119 if (jsonTestModelEntry.has("dataset")) { 120 JSONObject jsonDataset = jsonTestModelEntry.getJSONObject("dataset"); 121 String inputPath = jsonDataset.getString("inputPath"); 122 String groundTruth = jsonDataset.getString("groundTruth"); 123 String labels = jsonDataset.getString("labels"); 124 String preprocessor = jsonDataset.getString("preprocessor"); 125 if (inputSize.length != 4 || inputSize[0] != 1 || inputSize[1] != inputSize[2] || 126 inputSize[3] != 3) { 127 throw new IllegalArgumentException("Datasets only support square images," + 128 "input size [1, D, D, 3], given " + inputSize[0] + 129 ", " + inputSize[1] + ", " + inputSize[2] + ", " + inputSize[3]); 130 } 131 float quantScale = 0.f; 132 float quantZeroPoint = 0.f; 133 if (dataSize == 1) { 134 if (!jsonTestModelEntry.has("inputScale") || 135 !jsonTestModelEntry.has("inputZeroPoint")) { 136 throw new IllegalArgumentException("Quantized test model must include " + 137 "inputScale and inputZeroPoint for reading a dataset"); 138 } 139 quantScale = (float) jsonTestModelEntry.getDouble("inputScale"); 140 quantZeroPoint = (float) jsonTestModelEntry.getDouble("inputZeroPoint"); 141 } 142 datasets = new InferenceInOutSequence.FromDataset[]{ 143 new InferenceInOutSequence.FromDataset(inputPath, labels, groundTruth, 144 preprocessor, dataSize, quantScale, quantZeroPoint, inputSize[1]) 145 }; 146 } 147 148 TestModels.registerModel( 149 new TestModels.TestModelEntry(name, (float) baseline, inputSize, 150 inputOutputs, datasets, testName, modelFile, evaluator, minSdkVersion)); 151 } 152 } 153 readAssetsFileAsString(InputStream inputStream)154 static String readAssetsFileAsString(InputStream inputStream) throws IOException { 155 Reader reader = new InputStreamReader(inputStream); 156 StringBuilder sb = new StringBuilder(); 157 char buffer[] = new char[16384]; 158 int len; 159 while ((len = reader.read(buffer)) > 0) { 160 sb.append(buffer, 0, len); 161 } 162 reader.close(); 163 return sb.toString(); 164 } 165 166 /** Parse all ".json" files in root assets directory */ 167 private static final String MODELS_LIST_ROOT = "models_list"; 168 parseFromAssets(AssetManager assetManager)169 static public void parseFromAssets(AssetManager assetManager) throws IOException { 170 for (String file : assetManager.list(MODELS_LIST_ROOT)) { 171 if (!file.endsWith(".json")) { 172 continue; 173 } 174 try { 175 parseJSONModelsList(readAssetsFileAsString( 176 assetManager.open(MODELS_LIST_ROOT + "/" + file))); 177 } catch (JSONException e) { 178 throw new IOException("JSON error in " + file, e); 179 } catch (Exception e) { 180 // Wrap exception to add a filename to it 181 throw new IOException("Error while parsing " + file, e); 182 } 183 184 } 185 } 186 } 187