/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package com.example.android.tflitecamerademo;
17 
18 import android.app.Activity;
19 import android.content.res.AssetFileDescriptor;
20 import android.graphics.Bitmap;
21 import android.os.SystemClock;
22 import android.text.SpannableString;
23 import android.text.SpannableStringBuilder;
24 import android.text.style.ForegroundColorSpan;
25 import android.text.style.RelativeSizeSpan;
26 import android.util.Log;
27 import java.io.BufferedReader;
28 import java.io.FileInputStream;
29 import java.io.IOException;
30 import java.io.InputStreamReader;
31 import java.nio.ByteBuffer;
32 import java.nio.ByteOrder;
33 import java.nio.MappedByteBuffer;
34 import java.nio.channels.FileChannel;
35 import java.util.AbstractMap;
36 import java.util.ArrayList;
37 import java.util.Comparator;
38 import java.util.List;
39 import java.util.Map;
40 import java.util.PriorityQueue;
41 import org.tensorflow.lite.Delegate;
42 import org.tensorflow.lite.Interpreter;
43 
44 /**
45  * Classifies images with Tensorflow Lite.
46  */
47 public abstract class ImageClassifier {
48   // Display preferences
49   private static final float GOOD_PROB_THRESHOLD = 0.3f;
50   private static final int SMALL_COLOR = 0xffddaa88;
51 
52   /** Tag for the {@link Log}. */
53   private static final String TAG = "TfLiteCameraDemo";
54 
55   /** Number of results to show in the UI. */
56   private static final int RESULTS_TO_SHOW = 3;
57 
58   /** Dimensions of inputs. */
59   private static final int DIM_BATCH_SIZE = 1;
60 
61   private static final int DIM_PIXEL_SIZE = 3;
62 
63   /** Preallocated buffers for storing image data in. */
64   private int[] intValues = new int[getImageSizeX() * getImageSizeY()];
65 
66   /** Options for configuring the Interpreter. */
67   private final Interpreter.Options tfliteOptions = new Interpreter.Options();
68 
69   /** The loaded TensorFlow Lite model. */
70   private MappedByteBuffer tfliteModel;
71 
72   /** An instance of the driver class to run model inference with Tensorflow Lite. */
73   protected Interpreter tflite;
74 
75   /** Labels corresponding to the output of the vision model. */
76   private List<String> labelList;
77 
78   /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */
79   protected ByteBuffer imgData = null;
80 
81   /** multi-stage low pass filter * */
82   private float[][] filterLabelProbArray = null;
83 
84   private static final int FILTER_STAGES = 3;
85   private static final float FILTER_FACTOR = 0.4f;
86 
87   private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
88       new PriorityQueue<>(
89           RESULTS_TO_SHOW,
90           new Comparator<Map.Entry<String, Float>>() {
91             @Override
92             public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
93               return (o1.getValue()).compareTo(o2.getValue());
94             }
95           });
96 
97   /** holds a gpu delegate */
98   Delegate gpuDelegate = null;
99 
100   /** Initializes an {@code ImageClassifier}. */
ImageClassifier(Activity activity)101   ImageClassifier(Activity activity) throws IOException {
102     tfliteModel = loadModelFile(activity);
103     tflite = new Interpreter(tfliteModel, tfliteOptions);
104     labelList = loadLabelList(activity);
105     imgData =
106         ByteBuffer.allocateDirect(
107             DIM_BATCH_SIZE
108                 * getImageSizeX()
109                 * getImageSizeY()
110                 * DIM_PIXEL_SIZE
111                 * getNumBytesPerChannel());
112     imgData.order(ByteOrder.nativeOrder());
113     filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()];
114     Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
115   }
116 
117   /** Classifies a frame from the preview stream. */
classifyFrame(Bitmap bitmap, SpannableStringBuilder builder)118   void classifyFrame(Bitmap bitmap, SpannableStringBuilder builder) {
119     if (tflite == null) {
120       Log.e(TAG, "Image classifier has not been initialized; Skipped.");
121       builder.append(new SpannableString("Uninitialized Classifier."));
122     }
123     convertBitmapToByteBuffer(bitmap);
124     // Here's where the magic happens!!!
125     long startTime = SystemClock.uptimeMillis();
126     runInference();
127     long endTime = SystemClock.uptimeMillis();
128     Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
129 
130     // Smooth the results across frames.
131     applyFilter();
132 
133     // Print the results.
134     printTopKLabels(builder);
135     long duration = endTime - startTime;
136     SpannableString span = new SpannableString(duration + " ms");
137     span.setSpan(new ForegroundColorSpan(android.graphics.Color.LTGRAY), 0, span.length(), 0);
138     builder.append(span);
139   }
140 
applyFilter()141   void applyFilter() {
142     int numLabels = getNumLabels();
143 
144     // Low pass filter `labelProbArray` into the first stage of the filter.
145     for (int j = 0; j < numLabels; ++j) {
146       filterLabelProbArray[0][j] +=
147           FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]);
148     }
149     // Low pass filter each stage into the next.
150     for (int i = 1; i < FILTER_STAGES; ++i) {
151       for (int j = 0; j < numLabels; ++j) {
152         filterLabelProbArray[i][j] +=
153             FILTER_FACTOR * (filterLabelProbArray[i - 1][j] - filterLabelProbArray[i][j]);
154       }
155     }
156 
157     // Copy the last stage filter output back to `labelProbArray`.
158     for (int j = 0; j < numLabels; ++j) {
159       setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]);
160     }
161   }
162 
recreateInterpreter()163   private void recreateInterpreter() {
164     if (tflite != null) {
165       tflite.close();
166       // TODO(b/120679982)
167       // gpuDelegate.close();
168       tflite = new Interpreter(tfliteModel, tfliteOptions);
169     }
170   }
171 
useGpu()172   public void useGpu() {
173     if (gpuDelegate == null && GpuDelegateHelper.isGpuDelegateAvailable()) {
174       gpuDelegate = GpuDelegateHelper.createGpuDelegate();
175       tfliteOptions.addDelegate(gpuDelegate);
176       recreateInterpreter();
177     }
178   }
179 
useCPU()180   public void useCPU() {
181     tfliteOptions.setUseNNAPI(false);
182     recreateInterpreter();
183   }
184 
useNNAPI()185   public void useNNAPI() {
186     tfliteOptions.setUseNNAPI(true);
187     recreateInterpreter();
188   }
189 
setNumThreads(int numThreads)190   public void setNumThreads(int numThreads) {
191     tfliteOptions.setNumThreads(numThreads);
192     recreateInterpreter();
193   }
194 
195   /** Closes tflite to release resources. */
close()196   public void close() {
197     tflite.close();
198     tflite = null;
199     tfliteModel = null;
200   }
201 
202   /** Reads label list from Assets. */
loadLabelList(Activity activity)203   private List<String> loadLabelList(Activity activity) throws IOException {
204     List<String> labelList = new ArrayList<String>();
205     BufferedReader reader =
206         new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath())));
207     String line;
208     while ((line = reader.readLine()) != null) {
209       labelList.add(line);
210     }
211     reader.close();
212     return labelList;
213   }
214 
215   /** Memory-map the model file in Assets. */
loadModelFile(Activity activity)216   private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
217     AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
218     FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
219     FileChannel fileChannel = inputStream.getChannel();
220     long startOffset = fileDescriptor.getStartOffset();
221     long declaredLength = fileDescriptor.getDeclaredLength();
222     return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
223   }
224 
225   /** Writes Image data into a {@code ByteBuffer}. */
convertBitmapToByteBuffer(Bitmap bitmap)226   private void convertBitmapToByteBuffer(Bitmap bitmap) {
227     if (imgData == null) {
228       return;
229     }
230     imgData.rewind();
231     bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
232     // Convert the image to floating point.
233     int pixel = 0;
234     long startTime = SystemClock.uptimeMillis();
235     for (int i = 0; i < getImageSizeX(); ++i) {
236       for (int j = 0; j < getImageSizeY(); ++j) {
237         final int val = intValues[pixel++];
238         addPixelValue(val);
239       }
240     }
241     long endTime = SystemClock.uptimeMillis();
242     Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
243   }
244 
245   /** Prints top-K labels, to be shown in UI as the results. */
printTopKLabels(SpannableStringBuilder builder)246   private void printTopKLabels(SpannableStringBuilder builder) {
247     for (int i = 0; i < getNumLabels(); ++i) {
248       sortedLabels.add(
249           new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i)));
250       if (sortedLabels.size() > RESULTS_TO_SHOW) {
251         sortedLabels.poll();
252       }
253     }
254 
255     final int size = sortedLabels.size();
256     for (int i = 0; i < size; i++) {
257       Map.Entry<String, Float> label = sortedLabels.poll();
258       SpannableString span =
259           new SpannableString(String.format("%s: %4.2f\n", label.getKey(), label.getValue()));
260       int color;
261       // Make it white when probability larger than threshold.
262       if (label.getValue() > GOOD_PROB_THRESHOLD) {
263         color = android.graphics.Color.WHITE;
264       } else {
265         color = SMALL_COLOR;
266       }
267       // Make first item bigger.
268       if (i == size - 1) {
269         float sizeScale = (i == size - 1) ? 1.25f : 0.8f;
270         span.setSpan(new RelativeSizeSpan(sizeScale), 0, span.length(), 0);
271       }
272       span.setSpan(new ForegroundColorSpan(color), 0, span.length(), 0);
273       builder.insert(0, span);
274     }
275   }
276 
277   /**
278    * Get the name of the model file stored in Assets.
279    *
280    * @return
281    */
getModelPath()282   protected abstract String getModelPath();
283 
284   /**
285    * Get the name of the label file stored in Assets.
286    *
287    * @return
288    */
getLabelPath()289   protected abstract String getLabelPath();
290 
291   /**
292    * Get the image size along the x axis.
293    *
294    * @return
295    */
getImageSizeX()296   protected abstract int getImageSizeX();
297 
298   /**
299    * Get the image size along the y axis.
300    *
301    * @return
302    */
getImageSizeY()303   protected abstract int getImageSizeY();
304 
305   /**
306    * Get the number of bytes that is used to store a single color channel value.
307    *
308    * @return
309    */
getNumBytesPerChannel()310   protected abstract int getNumBytesPerChannel();
311 
312   /**
313    * Add pixelValue to byteBuffer.
314    *
315    * @param pixelValue
316    */
addPixelValue(int pixelValue)317   protected abstract void addPixelValue(int pixelValue);
318 
319   /**
320    * Read the probability value for the specified label This is either the original value as it was
321    * read from the net's output or the updated value after the filter was applied.
322    *
323    * @param labelIndex
324    * @return
325    */
getProbability(int labelIndex)326   protected abstract float getProbability(int labelIndex);
327 
328   /**
329    * Set the probability value for the specified label.
330    *
331    * @param labelIndex
332    * @param value
333    */
setProbability(int labelIndex, Number value)334   protected abstract void setProbability(int labelIndex, Number value);
335 
336   /**
337    * Get the normalized probability value for the specified label. This is the final value as it
338    * will be shown to the user.
339    *
340    * @return
341    */
getNormalizedProbability(int labelIndex)342   protected abstract float getNormalizedProbability(int labelIndex);
343 
344   /**
345    * Run inference using the prepared input in {@link #imgData}. Afterwards, the result will be
346    * provided by getProbability().
347    *
348    * <p>This additional method is necessary, because we don't have a common base for different
349    * primitive data types.
350    */
runInference()351   protected abstract void runInference();
352 
353   /**
354    * Get the total number of labels.
355    *
356    * @return
357    */
getNumLabels()358   protected int getNumLabels() {
359     return labelList.size();
360   }
361 }
362