/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.demo;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.StringTokenizer;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.env.Logger;

/**
 * A detector for general purpose object detection as described in Scalable Object Detection using
 * Deep Neural Networks (https://arxiv.org/abs/1312.2249).
 */
public class TensorFlowMultiBoxDetector implements Classifier {
  private static final Logger LOGGER = new Logger();

  // Only return this many results.
  private static final int MAX_RESULTS = Integer.MAX_VALUE;

  // Config values.
  private String inputName;
  private int inputSize;
  private int imageMean;
  private float imageStd;

  // Pre-allocated buffers.
  private int[] intValues;
  private float[] floatValues;
  private float[] outputLocations;
  private float[] outputScores;
  private String[] outputNames;
  private int numLocations;

  private boolean logStats = false;

  private TensorFlowInferenceInterface inferenceInterface;

  private float[] boxPriors;

  /**
   * Initializes a native TensorFlow session for detecting objects in images. The input size and
   * the number of locations are read from the shapes of the input and scores nodes in the graph.
   *
   * @param assetManager The asset manager to be used to load assets.
   * @param modelFilename The filepath of the model GraphDef protocol buffer.
   * @param locationFilename The filepath of the box prior (mean/std) file used to decode locations.
   * @param imageMean The assumed mean of the image values.
   * @param imageStd The assumed std of the image values.
   * @param inputName The label of the image input node.
   * @param outputLocationsName The label of the output locations node.
   * @param outputScoresName The label of the output scores node.
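   *
   * <p>A minimal usage sketch; the asset paths, node names, and normalization values below are
   * illustrative and must match the MultiBox model actually shipped with the app:
   *
   * <pre>{@code
   * Classifier detector =
   *     TensorFlowMultiBoxDetector.create(
   *         getAssets(),
   *         "file:///android_asset/multibox_model.pb",
   *         "file:///android_asset/multibox_location_priors.txt",
   *         128,                          // imageMean
   *         128f,                         // imageStd
   *         "ResizeBilinear",             // inputName
   *         "output_locations/Reshape",   // outputLocationsName
   *         "output_scores/Reshape");     // outputScoresName
   * List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
   * }</pre>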
   */
  public static Classifier create(
      final AssetManager assetManager,
      final String modelFilename,
      final String locationFilename,
      final int imageMean,
      final float imageStd,
      final String inputName,
      final String outputLocationsName,
      final String outputScoresName) {
    final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();

    d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

    final Graph g = d.inferenceInterface.graph();

    d.inputName = inputName;
    // The inputName node has a shape of [N, H, W, C], where
    // N is the batch size
    // H = W are the height and width
    // C is the number of channels (3 for our purposes - RGB)
    final Operation inputOp = g.operation(inputName);
    if (inputOp == null) {
      throw new RuntimeException("Failed to find input Node '" + inputName + "'");
    }
    d.inputSize = (int) inputOp.output(0).shape().size(1);
    d.imageMean = imageMean;
    d.imageStd = imageStd;
    // The outputScoresName node has a shape of [N, NumLocations], where N
    // is the batch size.
    final Operation outputOp = g.operation(outputScoresName);
    if (outputOp == null) {
      throw new RuntimeException("Failed to find output Node '" + outputScoresName + "'");
    }
    d.numLocations = (int) outputOp.output(0).shape().size(1);

    d.boxPriors = new float[d.numLocations * 8];

    try {
      d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
    } catch (final IOException e) {
      throw new RuntimeException("Error initializing box priors from " + locationFilename);
    }

    // Pre-allocate buffers.
    d.outputNames = new String[] {outputLocationsName, outputScoresName};
    d.intValues = new int[d.inputSize * d.inputSize];
    d.floatValues = new float[d.inputSize * d.inputSize * 3];
    d.outputScores = new float[d.numLocations];
    d.outputLocations = new float[d.numLocations * 4];

    return d;
  }

  private TensorFlowMultiBoxDetector() {}

  private void loadCoderOptions(
      final AssetManager assetManager, final String locationFilename, final float[] boxPriors)
      throws IOException {
    // Try to be intelligent about opening from assets or sdcard depending on prefix.
    final String assetPrefix = "file:///android_asset/";
    InputStream is;
    if (locationFilename.startsWith(assetPrefix)) {
      is = assetManager.open(locationFilename.split(assetPrefix)[1]);
    } else {
      is = new FileInputStream(locationFilename);
    }

    // Read values. Number of values per line doesn't matter, as long as they are separated
    // by commas and/or whitespace, and there are exactly numLocations * 8 values total.
    // Values are in the order mean, std for each consecutive corner of each box, for a total of 8
    // per location.
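    // For example, one location might (hypothetically) contribute the eight values
    //   0.12, 0.05,  0.34, 0.07,  0.56, 0.05,  0.78, 0.07
    // i.e. a (mean, std) pair for each of the four box coordinates of that prior.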
    final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
    int priorIndex = 0;
    String line;
    while ((line = reader.readLine()) != null) {
      final StringTokenizer st = new StringTokenizer(line, ", ");
      while (st.hasMoreTokens()) {
        final String token = st.nextToken();
        try {
          final float number = Float.parseFloat(token);
          boxPriors[priorIndex++] = number;
        } catch (final NumberFormatException e) {
          // Silently ignore non-numeric tokens (e.g. header text).
        }
      }
    }
    reader.close();
    if (priorIndex != boxPriors.length) {
      throw new RuntimeException(
          "BoxPrior length mismatch: " + priorIndex + " vs " + boxPriors.length);
    }
  }

  private float[] decodeLocationsEncoding(final float[] locationEncoding) {
    final float[] locations = new float[locationEncoding.length];
    boolean nonZero = false;
    for (int i = 0; i < numLocations; ++i) {
      for (int j = 0; j < 4; ++j) {
        final float currEncoding = locationEncoding[4 * i + j];
        nonZero = nonZero || currEncoding != 0.0f;

        final float mean = boxPriors[i * 8 + j * 2];
        final float stdDev = boxPriors[i * 8 + j * 2 + 1];
        float currentLocation = currEncoding * stdDev + mean;
        currentLocation = Math.max(currentLocation, 0.0f);
        currentLocation = Math.min(currentLocation, 1.0f);
        locations[4 * i + j] = currentLocation;
      }
    }

    if (!nonZero) {
      LOGGER.w("No non-zero encodings; check log for inference errors.");
    }
    return locations;
  }

  private float[] decodeScoresEncoding(final float[] scoresEncoding) {
    final float[] scores = new float[scoresEncoding.length];
    for (int i = 0; i < scoresEncoding.length; ++i) {
      scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i])));
    }
    return scores;
  }

  @Override
  public List<Recognition> recognizeImage(final Bitmap bitmap) {
    // Log this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");

    Trace.beginSection("preprocessBitmap");
    // Preprocess the image data from 0-255 int to normalized float based
    // on the provided parameters.
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

    for (int i = 0; i < intValues.length; ++i) {
      floatValues[i * 3 + 0] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd;
      floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd;
      floatValues[i * 3 + 2] = ((intValues[i] & 0xFF) - imageMean) / imageStd;
    }
    Trace.endSection(); // preprocessBitmap

    // Copy the input data into TensorFlow.
    Trace.beginSection("feed");
    inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
    Trace.endSection();

    // Run the inference call.
    Trace.beginSection("run");
    inferenceInterface.run(outputNames, logStats);
    Trace.endSection();

    // Copy the output Tensor back into the output array.
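    // Note: outputNames was set to {outputLocationsName, outputScoresName} in create(),
    // so the two fetches below must read locations first and scores second.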
    Trace.beginSection("fetch");
    final float[] outputScoresEncoding = new float[numLocations];
    final float[] outputLocationsEncoding = new float[numLocations * 4];
    inferenceInterface.fetch(outputNames[0], outputLocationsEncoding);
    inferenceInterface.fetch(outputNames[1], outputScoresEncoding);
    Trace.endSection();

    outputLocations = decodeLocationsEncoding(outputLocationsEncoding);
    outputScores = decodeScoresEncoding(outputScoresEncoding);

    // Find the best detections.
    final PriorityQueue<Recognition> pq =
        new PriorityQueue<Recognition>(
            1,
            new Comparator<Recognition>() {
              @Override
              public int compare(final Recognition lhs, final Recognition rhs) {
                // Intentionally reversed to put high confidence at the head of the queue.
                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
              }
            });

    // Scale them back to the input size.
    for (int i = 0; i < outputScores.length; ++i) {
      final RectF detection =
          new RectF(
              outputLocations[4 * i] * inputSize,
              outputLocations[4 * i + 1] * inputSize,
              outputLocations[4 * i + 2] * inputSize,
              outputLocations[4 * i + 3] * inputSize);
      pq.add(new Recognition("" + i, null, outputScores[i], detection));
    }

    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
    // Capture the result count before polling, since pq.size() shrinks as elements are removed.
    final int numResults = Math.min(pq.size(), MAX_RESULTS);
    for (int i = 0; i < numResults; ++i) {
      recognitions.add(pq.poll());
    }
    Trace.endSection(); // "recognizeImage"
    return recognitions;
  }

  @Override
  public void enableStatLogging(final boolean logStats) {
    this.logStats = logStats;
  }

  @Override
  public String getStatString() {
    return inferenceInterface.getStatString();
  }

  @Override
  public void close() {
    inferenceInterface.close();
  }
}