• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package com.example.android.tflitecamerademo;
17 
import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.os.SystemClock;
import android.text.SpannableString;
import android.text.SpannableStringBuilder;
import android.text.style.ForegroundColorSpan;
import android.text.style.RelativeSizeSpan;
import android.util.Log;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;
import org.tensorflow.lite.nnapi.NnApiDelegate;
44 
45 /**
46  * Classifies images with Tensorflow Lite.
47  */
48 public abstract class ImageClassifier {
49   // Display preferences
50   private static final float GOOD_PROB_THRESHOLD = 0.3f;
51   private static final int SMALL_COLOR = 0xffddaa88;
52 
53   /** Tag for the {@link Log}. */
54   private static final String TAG = "TfLiteCameraDemo";
55 
56   /** Number of results to show in the UI. */
57   private static final int RESULTS_TO_SHOW = 3;
58 
59   /** Dimensions of inputs. */
60   private static final int DIM_BATCH_SIZE = 1;
61 
62   private static final int DIM_PIXEL_SIZE = 3;
63 
64   /** Preallocated buffers for storing image data in. */
65   private int[] intValues = new int[getImageSizeX() * getImageSizeY()];
66 
67   /** Options for configuring the Interpreter. */
68   private final Interpreter.Options tfliteOptions = new Interpreter.Options();
69 
70   /** The loaded TensorFlow Lite model. */
71   private MappedByteBuffer tfliteModel;
72 
73   /** An instance of the driver class to run model inference with Tensorflow Lite. */
74   protected Interpreter tflite;
75 
76   /** Labels corresponding to the output of the vision model. */
77   private List<String> labelList;
78 
79   /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */
80   protected ByteBuffer imgData = null;
81 
82   /** multi-stage low pass filter * */
83   private float[][] filterLabelProbArray = null;
84 
85   private static final int FILTER_STAGES = 3;
86   private static final float FILTER_FACTOR = 0.4f;
87 
88   private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
89       new PriorityQueue<>(
90           RESULTS_TO_SHOW,
91           new Comparator<Map.Entry<String, Float>>() {
92             @Override
93             public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
94               return (o1.getValue()).compareTo(o2.getValue());
95             }
96           });
97 
98   /** holds a gpu delegate */
99   GpuDelegate gpuDelegate = null;
100   /** holds an nnapi delegate */
101   NnApiDelegate nnapiDelegate = null;
102 
103   /** Initializes an {@code ImageClassifier}. */
ImageClassifier(Activity activity)104   ImageClassifier(Activity activity) throws IOException {
105     tfliteModel = loadModelFile(activity);
106     tflite = new Interpreter(tfliteModel, tfliteOptions);
107     labelList = loadLabelList(activity);
108     imgData =
109         ByteBuffer.allocateDirect(
110             DIM_BATCH_SIZE
111                 * getImageSizeX()
112                 * getImageSizeY()
113                 * DIM_PIXEL_SIZE
114                 * getNumBytesPerChannel());
115     imgData.order(ByteOrder.nativeOrder());
116     filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()];
117     Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
118   }
119 
120   /** Classifies a frame from the preview stream. */
classifyFrame(Bitmap bitmap, SpannableStringBuilder builder)121   void classifyFrame(Bitmap bitmap, SpannableStringBuilder builder) {
122     if (tflite == null) {
123       Log.e(TAG, "Image classifier has not been initialized; Skipped.");
124       builder.append(new SpannableString("Uninitialized Classifier."));
125     }
126     convertBitmapToByteBuffer(bitmap);
127     // Here's where the magic happens!!!
128     long startTime = SystemClock.uptimeMillis();
129     runInference();
130     long endTime = SystemClock.uptimeMillis();
131     Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
132 
133     // Smooth the results across frames.
134     applyFilter();
135 
136     // Print the results.
137     printTopKLabels(builder);
138     long duration = endTime - startTime;
139     SpannableString span = new SpannableString(duration + " ms");
140     span.setSpan(new ForegroundColorSpan(android.graphics.Color.LTGRAY), 0, span.length(), 0);
141     builder.append(span);
142   }
143 
applyFilter()144   void applyFilter() {
145     int numLabels = getNumLabels();
146 
147     // Low pass filter `labelProbArray` into the first stage of the filter.
148     for (int j = 0; j < numLabels; ++j) {
149       filterLabelProbArray[0][j] +=
150           FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]);
151     }
152     // Low pass filter each stage into the next.
153     for (int i = 1; i < FILTER_STAGES; ++i) {
154       for (int j = 0; j < numLabels; ++j) {
155         filterLabelProbArray[i][j] +=
156             FILTER_FACTOR * (filterLabelProbArray[i - 1][j] - filterLabelProbArray[i][j]);
157       }
158     }
159 
160     // Copy the last stage filter output back to `labelProbArray`.
161     for (int j = 0; j < numLabels; ++j) {
162       setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]);
163     }
164   }
165 
recreateInterpreter()166   private void recreateInterpreter() {
167     if (tflite != null) {
168       tflite.close();
169       tflite = new Interpreter(tfliteModel, tfliteOptions);
170     }
171   }
172 
useGpu()173   public void useGpu() {
174     if (gpuDelegate == null) {
175       GpuDelegate.Options options = new GpuDelegate.Options();
176       options.setQuantizedModelsAllowed(true);
177 
178       gpuDelegate = new GpuDelegate(options);
179       tfliteOptions.addDelegate(gpuDelegate);
180       recreateInterpreter();
181     }
182   }
183 
useCPU()184   public void useCPU() {
185     recreateInterpreter();
186   }
187 
useNNAPI()188   public void useNNAPI() {
189     nnapiDelegate = new NnApiDelegate();
190     tfliteOptions.addDelegate(nnapiDelegate);
191     recreateInterpreter();
192   }
193 
setNumThreads(int numThreads)194   public void setNumThreads(int numThreads) {
195     tfliteOptions.setNumThreads(numThreads);
196     recreateInterpreter();
197   }
198 
199   /** Closes tflite to release resources. */
close()200   public void close() {
201     tflite.close();
202     tflite = null;
203     if (gpuDelegate != null) {
204       gpuDelegate.close();
205       gpuDelegate = null;
206     }
207     if (nnapiDelegate != null) {
208       nnapiDelegate.close();
209       nnapiDelegate = null;
210     }
211     tfliteModel = null;
212   }
213 
214   /** Reads label list from Assets. */
loadLabelList(Activity activity)215   private List<String> loadLabelList(Activity activity) throws IOException {
216     List<String> labelList = new ArrayList<String>();
217     BufferedReader reader =
218         new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath())));
219     String line;
220     while ((line = reader.readLine()) != null) {
221       labelList.add(line);
222     }
223     reader.close();
224     return labelList;
225   }
226 
227   /** Memory-map the model file in Assets. */
loadModelFile(Activity activity)228   private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
229     AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
230     FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
231     FileChannel fileChannel = inputStream.getChannel();
232     long startOffset = fileDescriptor.getStartOffset();
233     long declaredLength = fileDescriptor.getDeclaredLength();
234     return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
235   }
236 
237   /** Writes Image data into a {@code ByteBuffer}. */
convertBitmapToByteBuffer(Bitmap bitmap)238   private void convertBitmapToByteBuffer(Bitmap bitmap) {
239     if (imgData == null) {
240       return;
241     }
242     imgData.rewind();
243     bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
244     // Convert the image to floating point.
245     int pixel = 0;
246     long startTime = SystemClock.uptimeMillis();
247     for (int i = 0; i < getImageSizeX(); ++i) {
248       for (int j = 0; j < getImageSizeY(); ++j) {
249         final int val = intValues[pixel++];
250         addPixelValue(val);
251       }
252     }
253     long endTime = SystemClock.uptimeMillis();
254     Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
255   }
256 
257   /** Prints top-K labels, to be shown in UI as the results. */
printTopKLabels(SpannableStringBuilder builder)258   private void printTopKLabels(SpannableStringBuilder builder) {
259     for (int i = 0; i < getNumLabels(); ++i) {
260       sortedLabels.add(
261           new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i)));
262       if (sortedLabels.size() > RESULTS_TO_SHOW) {
263         sortedLabels.poll();
264       }
265     }
266 
267     final int size = sortedLabels.size();
268     for (int i = 0; i < size; i++) {
269       Map.Entry<String, Float> label = sortedLabels.poll();
270       SpannableString span =
271           new SpannableString(String.format("%s: %4.2f\n", label.getKey(), label.getValue()));
272       int color;
273       // Make it white when probability larger than threshold.
274       if (label.getValue() > GOOD_PROB_THRESHOLD) {
275         color = android.graphics.Color.WHITE;
276       } else {
277         color = SMALL_COLOR;
278       }
279       // Make first item bigger.
280       if (i == size - 1) {
281         float sizeScale = (i == size - 1) ? 1.25f : 0.8f;
282         span.setSpan(new RelativeSizeSpan(sizeScale), 0, span.length(), 0);
283       }
284       span.setSpan(new ForegroundColorSpan(color), 0, span.length(), 0);
285       builder.insert(0, span);
286     }
287   }
288 
289   /**
290    * Get the name of the model file stored in Assets.
291    *
292    * @return
293    */
getModelPath()294   protected abstract String getModelPath();
295 
296   /**
297    * Get the name of the label file stored in Assets.
298    *
299    * @return
300    */
getLabelPath()301   protected abstract String getLabelPath();
302 
303   /**
304    * Get the image size along the x axis.
305    *
306    * @return
307    */
getImageSizeX()308   protected abstract int getImageSizeX();
309 
310   /**
311    * Get the image size along the y axis.
312    *
313    * @return
314    */
getImageSizeY()315   protected abstract int getImageSizeY();
316 
317   /**
318    * Get the number of bytes that is used to store a single color channel value.
319    *
320    * @return
321    */
getNumBytesPerChannel()322   protected abstract int getNumBytesPerChannel();
323 
324   /**
325    * Add pixelValue to byteBuffer.
326    *
327    * @param pixelValue
328    */
addPixelValue(int pixelValue)329   protected abstract void addPixelValue(int pixelValue);
330 
331   /**
332    * Read the probability value for the specified label This is either the original value as it was
333    * read from the net's output or the updated value after the filter was applied.
334    *
335    * @param labelIndex
336    * @return
337    */
getProbability(int labelIndex)338   protected abstract float getProbability(int labelIndex);
339 
340   /**
341    * Set the probability value for the specified label.
342    *
343    * @param labelIndex
344    * @param value
345    */
setProbability(int labelIndex, Number value)346   protected abstract void setProbability(int labelIndex, Number value);
347 
348   /**
349    * Get the normalized probability value for the specified label. This is the final value as it
350    * will be shown to the user.
351    *
352    * @return
353    */
getNormalizedProbability(int labelIndex)354   protected abstract float getNormalizedProbability(int labelIndex);
355 
356   /**
357    * Run inference using the prepared input in {@link #imgData}. Afterwards, the result will be
358    * provided by getProbability().
359    *
360    * <p>This additional method is necessary, because we don't have a common base for different
361    * primitive data types.
362    */
runInference()363   protected abstract void runInference();
364 
365   /**
366    * Get the total number of labels.
367    *
368    * @return
369    */
getNumLabels()370   protected int getNumLabels() {
371     return labelList.size();
372   }
373 }
374