• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.demo;
17 
18 import android.content.res.AssetManager;
19 import android.graphics.Bitmap;
20 import android.graphics.RectF;
21 import android.os.Trace;
22 import java.io.BufferedReader;
23 import java.io.FileInputStream;
24 import java.io.IOException;
25 import java.io.InputStream;
26 import java.io.InputStreamReader;
27 import java.util.ArrayList;
28 import java.util.Comparator;
29 import java.util.List;
30 import java.util.PriorityQueue;
31 import java.util.StringTokenizer;
32 import org.tensorflow.Graph;
33 import org.tensorflow.Operation;
34 import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
35 import org.tensorflow.demo.env.Logger;
36 
37 /**
38  * A detector for general purpose object detection as described in Scalable Object Detection using
39  * Deep Neural Networks (https://arxiv.org/abs/1312.2249).
40  */
41 public class TensorFlowMultiBoxDetector implements Classifier {
42   private static final Logger LOGGER = new Logger();
43 
44   // Only return this many results.
45   private static final int MAX_RESULTS = Integer.MAX_VALUE;
46 
47   // Config values.
48   private String inputName;
49   private int inputSize;
50   private int imageMean;
51   private float imageStd;
52 
53   // Pre-allocated buffers.
54   private int[] intValues;
55   private float[] floatValues;
56   private float[] outputLocations;
57   private float[] outputScores;
58   private String[] outputNames;
59   private int numLocations;
60 
61   private boolean logStats = false;
62 
63   private TensorFlowInferenceInterface inferenceInterface;
64 
65   private float[] boxPriors;
66 
67   /**
68    * Initializes a native TensorFlow session for classifying images.
69    *
70    * @param assetManager The asset manager to be used to load assets.
71    * @param modelFilename The filepath of the model GraphDef protocol buffer.
72    * @param locationFilename The filepath of label file for classes.
73    * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
74    * @param imageMean The assumed mean of the image values.
75    * @param imageStd The assumed std of the image values.
76    * @param inputName The label of the image input node.
77    * @param outputName The label of the output node.
78    */
create( final AssetManager assetManager, final String modelFilename, final String locationFilename, final int imageMean, final float imageStd, final String inputName, final String outputLocationsName, final String outputScoresName)79   public static Classifier create(
80       final AssetManager assetManager,
81       final String modelFilename,
82       final String locationFilename,
83       final int imageMean,
84       final float imageStd,
85       final String inputName,
86       final String outputLocationsName,
87       final String outputScoresName) {
88     final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();
89 
90     d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
91 
92     final Graph g = d.inferenceInterface.graph();
93 
94     d.inputName = inputName;
95     // The inputName node has a shape of [N, H, W, C], where
96     // N is the batch size
97     // H = W are the height and width
98     // C is the number of channels (3 for our purposes - RGB)
99     final Operation inputOp = g.operation(inputName);
100     if (inputOp == null) {
101       throw new RuntimeException("Failed to find input Node '" + inputName + "'");
102     }
103     d.inputSize = (int) inputOp.output(0).shape().size(1);
104     d.imageMean = imageMean;
105     d.imageStd = imageStd;
106     // The outputScoresName node has a shape of [N, NumLocations], where N
107     // is the batch size.
108     final Operation outputOp = g.operation(outputScoresName);
109     if (outputOp == null) {
110       throw new RuntimeException("Failed to find output Node '" + outputScoresName + "'");
111     }
112     d.numLocations = (int) outputOp.output(0).shape().size(1);
113 
114     d.boxPriors = new float[d.numLocations * 8];
115 
116     try {
117       d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
118     } catch (final IOException e) {
119       throw new RuntimeException("Error initializing box priors from " + locationFilename);
120     }
121 
122     // Pre-allocate buffers.
123     d.outputNames = new String[] {outputLocationsName, outputScoresName};
124     d.intValues = new int[d.inputSize * d.inputSize];
125     d.floatValues = new float[d.inputSize * d.inputSize * 3];
126     d.outputScores = new float[d.numLocations];
127     d.outputLocations = new float[d.numLocations * 4];
128 
129     return d;
130   }
131 
TensorFlowMultiBoxDetector()132   private TensorFlowMultiBoxDetector() {}
133 
loadCoderOptions( final AssetManager assetManager, final String locationFilename, final float[] boxPriors)134   private void loadCoderOptions(
135       final AssetManager assetManager, final String locationFilename, final float[] boxPriors)
136       throws IOException {
137     // Try to be intelligent about opening from assets or sdcard depending on prefix.
138     final String assetPrefix = "file:///android_asset/";
139     InputStream is;
140     if (locationFilename.startsWith(assetPrefix)) {
141       is = assetManager.open(locationFilename.split(assetPrefix)[1]);
142     } else {
143       is = new FileInputStream(locationFilename);
144     }
145 
146     // Read values. Number of values per line doesn't matter, as long as they are separated
147     // by commas and/or whitespace, and there are exactly numLocations * 8 values total.
148     // Values are in the order mean, std for each consecutive corner of each box, for a total of 8
149     // per location.
150     final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
151     int priorIndex = 0;
152     String line;
153     while ((line = reader.readLine()) != null) {
154       final StringTokenizer st = new StringTokenizer(line, ", ");
155       while (st.hasMoreTokens()) {
156         final String token = st.nextToken();
157         try {
158           final float number = Float.parseFloat(token);
159           boxPriors[priorIndex++] = number;
160         } catch (final NumberFormatException e) {
161           // Silently ignore.
162         }
163       }
164     }
165     if (priorIndex != boxPriors.length) {
166       throw new RuntimeException(
167           "BoxPrior length mismatch: " + priorIndex + " vs " + boxPriors.length);
168     }
169   }
170 
decodeLocationsEncoding(final float[] locationEncoding)171   private float[] decodeLocationsEncoding(final float[] locationEncoding) {
172     final float[] locations = new float[locationEncoding.length];
173     boolean nonZero = false;
174     for (int i = 0; i < numLocations; ++i) {
175       for (int j = 0; j < 4; ++j) {
176         final float currEncoding = locationEncoding[4 * i + j];
177         nonZero = nonZero || currEncoding != 0.0f;
178 
179         final float mean = boxPriors[i * 8 + j * 2];
180         final float stdDev = boxPriors[i * 8 + j * 2 + 1];
181         float currentLocation = currEncoding * stdDev + mean;
182         currentLocation = Math.max(currentLocation, 0.0f);
183         currentLocation = Math.min(currentLocation, 1.0f);
184         locations[4 * i + j] = currentLocation;
185       }
186     }
187 
188     if (!nonZero) {
189       LOGGER.w("No non-zero encodings; check log for inference errors.");
190     }
191     return locations;
192   }
193 
decodeScoresEncoding(final float[] scoresEncoding)194   private float[] decodeScoresEncoding(final float[] scoresEncoding) {
195     final float[] scores = new float[scoresEncoding.length];
196     for (int i = 0; i < scoresEncoding.length; ++i) {
197       scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i])));
198     }
199     return scores;
200   }
201 
202   @Override
recognizeImage(final Bitmap bitmap)203   public List<Recognition> recognizeImage(final Bitmap bitmap) {
204     // Log this method so that it can be analyzed with systrace.
205     Trace.beginSection("recognizeImage");
206 
207     Trace.beginSection("preprocessBitmap");
208     // Preprocess the image data from 0-255 int to normalized float based
209     // on the provided parameters.
210     bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
211 
212     for (int i = 0; i < intValues.length; ++i) {
213       floatValues[i * 3 + 0] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd;
214       floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd;
215       floatValues[i * 3 + 2] = ((intValues[i] & 0xFF) - imageMean) / imageStd;
216     }
217     Trace.endSection(); // preprocessBitmap
218 
219     // Copy the input data into TensorFlow.
220     Trace.beginSection("feed");
221     inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
222     Trace.endSection();
223 
224     // Run the inference call.
225     Trace.beginSection("run");
226     inferenceInterface.run(outputNames, logStats);
227     Trace.endSection();
228 
229     // Copy the output Tensor back into the output array.
230     Trace.beginSection("fetch");
231     final float[] outputScoresEncoding = new float[numLocations];
232     final float[] outputLocationsEncoding = new float[numLocations * 4];
233     inferenceInterface.fetch(outputNames[0], outputLocationsEncoding);
234     inferenceInterface.fetch(outputNames[1], outputScoresEncoding);
235     Trace.endSection();
236 
237     outputLocations = decodeLocationsEncoding(outputLocationsEncoding);
238     outputScores = decodeScoresEncoding(outputScoresEncoding);
239 
240     // Find the best detections.
241     final PriorityQueue<Recognition> pq =
242         new PriorityQueue<Recognition>(
243             1,
244             new Comparator<Recognition>() {
245               @Override
246               public int compare(final Recognition lhs, final Recognition rhs) {
247                 // Intentionally reversed to put high confidence at the head of the queue.
248                 return Float.compare(rhs.getConfidence(), lhs.getConfidence());
249               }
250             });
251 
252     // Scale them back to the input size.
253     for (int i = 0; i < outputScores.length; ++i) {
254       final RectF detection =
255           new RectF(
256               outputLocations[4 * i] * inputSize,
257               outputLocations[4 * i + 1] * inputSize,
258               outputLocations[4 * i + 2] * inputSize,
259               outputLocations[4 * i + 3] * inputSize);
260       pq.add(new Recognition("" + i, null, outputScores[i], detection));
261     }
262 
263     final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
264     for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
265       recognitions.add(pq.poll());
266     }
267     Trace.endSection(); // "recognizeImage"
268     return recognitions;
269   }
270 
271   @Override
enableStatLogging(final boolean logStats)272   public void enableStatLogging(final boolean logStats) {
273     this.logStats = logStats;
274   }
275 
276   @Override
getStatString()277   public String getStatString() {
278     return inferenceInterface.getStatString();
279   }
280 
281   @Override
close()282   public void close() {
283     inferenceInterface.close();
284   }
285 }
286