• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
import android.content.res.AssetManager;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
29 
30 /** Helper class to register test model definitions from assets data */
31 public class TestModelsListLoader {
32 
33     /**
34      * Parse list of models in form of json data.
35      *
36      * Example input:
37      * { "models" : [
38      * {"name" : "modelName",
39      * "testName" : "testName",
40      * "baselineSec" : 0.03,
41      * "evaluator": "TopK",
42      * "inputSize" : [1,2,3,4],
43      * "dataSize" : 4,
44      * "inputOutputs" : [ {"input": "input1", "output": "output2"} ]
45      * }
46      * ]}
47      */
parseJSONModelsList(String jsonStringInput)48     static public void parseJSONModelsList(String jsonStringInput) throws JSONException {
49         JSONObject jsonRootObject = new JSONObject(jsonStringInput);
50         JSONArray jsonModelsArray = jsonRootObject.getJSONArray("models");
51 
52         for (int i = 0; i < jsonModelsArray.length(); i++) {
53             JSONObject jsonTestModelEntry = jsonModelsArray.getJSONObject(i);
54 
55             String name = jsonTestModelEntry.getString("name");
56             String testName = name;
57             if (jsonTestModelEntry.has("testName")) {
58                 testName = jsonTestModelEntry.getString("testName");
59             }
60             String modelFile = name;
61             if (jsonTestModelEntry.has("modelFile")) {
62                 modelFile = jsonTestModelEntry.getString("modelFile");
63             }
64             double baseline = jsonTestModelEntry.getDouble("baselineSec");
65             int minSdkVersion = 0;
66             if (jsonTestModelEntry.has("minSdkVersion")) {
67                 minSdkVersion = jsonTestModelEntry.getInt("minSdkVersion");
68             }
69             EvaluatorConfig evaluator = null;
70             if (jsonTestModelEntry.has("evaluator")) {
71                 JSONObject evaluatorJson = jsonTestModelEntry.getJSONObject("evaluator");
72                 evaluator = new EvaluatorConfig(evaluatorJson.getString("className"),
73                         evaluatorJson.has("outputMeanStdDev")
74                                 ? evaluatorJson.getString("outputMeanStdDev")
75                                 : null,
76                         evaluatorJson.has("expectedTop1")
77                                 ? evaluatorJson.getDouble("expectedTop1")
78                                 : null);
79             }
80 
81             int dataSize = jsonTestModelEntry.getInt("dataSize");
82             JSONArray jsonInputSize = jsonTestModelEntry.getJSONArray("inputSize");
83             int[] inputSize = new int[jsonInputSize.length()];
84             int inputSizeBytes = dataSize;
85             for (int k = 0; k < jsonInputSize.length(); ++k) {
86                 inputSize[k] = jsonInputSize.getInt(k);
87                 inputSizeBytes *= inputSize[k];
88             }
89 
90             InferenceInOutSequence.FromAssets[] inputOutputs = null;
91             if (jsonTestModelEntry.has("inputOutputs")) {
92                 JSONArray jsonInputOutputs = jsonTestModelEntry.getJSONArray("inputOutputs");
93                 inputOutputs =
94                         new InferenceInOutSequence.FromAssets[jsonInputOutputs.length()];
95 
96                 for (int j = 0; j < jsonInputOutputs.length(); j++) {
97                     JSONObject jsonInputOutput = jsonInputOutputs.getJSONObject(j);
98                     String input = jsonInputOutput.getString("input");
99                     String[] outputs = null;
100                     String output = jsonInputOutput.optString("output", null);
101                     if (output != null) {
102                         outputs = new String[]{output};
103                     } else {
104                         JSONArray outputArray = jsonInputOutput.getJSONArray("outputs");
105                         if (outputArray != null) {
106                             outputs = new String[outputArray.length()];
107                             for (int k = 0; k < outputArray.length(); ++k) {
108                                 outputs[k] = outputArray.getString(k);
109                             }
110                         }
111                     }
112 
113                     inputOutputs[j] = new InferenceInOutSequence.FromAssets(input, outputs,
114                             dataSize,
115                             inputSizeBytes);
116                 }
117             }
118             InferenceInOutSequence.FromDataset[] datasets = null;
119             if (jsonTestModelEntry.has("dataset")) {
120                 JSONObject jsonDataset = jsonTestModelEntry.getJSONObject("dataset");
121                 String inputPath = jsonDataset.getString("inputPath");
122                 String groundTruth = jsonDataset.getString("groundTruth");
123                 String labels = jsonDataset.getString("labels");
124                 String preprocessor = jsonDataset.getString("preprocessor");
125                 if (inputSize.length != 4 || inputSize[0] != 1 || inputSize[1] != inputSize[2] ||
126                         inputSize[3] != 3) {
127                     throw new IllegalArgumentException("Datasets only support square images," +
128                             "input size [1, D, D, 3], given " + inputSize[0] +
129                             ", " + inputSize[1] + ", " + inputSize[2] + ", " + inputSize[3]);
130                 }
131                 float quantScale = 0.f;
132                 float quantZeroPoint = 0.f;
133                 if (dataSize == 1) {
134                     if (!jsonTestModelEntry.has("inputScale") ||
135                             !jsonTestModelEntry.has("inputZeroPoint")) {
136                         throw new IllegalArgumentException("Quantized test model must include " +
137                                 "inputScale and inputZeroPoint for reading a dataset");
138                     }
139                     quantScale = (float) jsonTestModelEntry.getDouble("inputScale");
140                     quantZeroPoint = (float) jsonTestModelEntry.getDouble("inputZeroPoint");
141                 }
142                 datasets = new InferenceInOutSequence.FromDataset[]{
143                         new InferenceInOutSequence.FromDataset(inputPath, labels, groundTruth,
144                                 preprocessor, dataSize, quantScale, quantZeroPoint, inputSize[1])
145                 };
146             }
147 
148             TestModels.registerModel(
149                     new TestModels.TestModelEntry(name, (float) baseline, inputSize,
150                             inputOutputs, datasets, testName, modelFile, evaluator, minSdkVersion));
151         }
152     }
153 
readAssetsFileAsString(InputStream inputStream)154     static String readAssetsFileAsString(InputStream inputStream) throws IOException {
155         Reader reader = new InputStreamReader(inputStream);
156         StringBuilder sb = new StringBuilder();
157         char buffer[] = new char[16384];
158         int len;
159         while ((len = reader.read(buffer)) > 0) {
160             sb.append(buffer, 0, len);
161         }
162         reader.close();
163         return sb.toString();
164     }
165 
166     /** Parse all ".json" files in root assets directory */
167     private static final String MODELS_LIST_ROOT = "models_list";
168 
parseFromAssets(AssetManager assetManager)169     static public void parseFromAssets(AssetManager assetManager) throws IOException {
170         for (String file : assetManager.list(MODELS_LIST_ROOT)) {
171             if (!file.endsWith(".json")) {
172                 continue;
173             }
174             try {
175                 parseJSONModelsList(readAssetsFileAsString(
176                         assetManager.open(MODELS_LIST_ROOT + "/" + file)));
177             } catch (JSONException e) {
178                 throw new IOException("JSON error in " + file, e);
179             } catch (Exception e) {
180                 // Wrap exception to add a filename to it
181                 throw new IOException("Error while parsing " + file, e);
182             }
183 
184         }
185     }
186 }
187