/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package com.example.android.tflitecamerademo;

import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.os.SystemClock;
import android.text.SpannableString;
import android.text.SpannableStringBuilder;
import android.text.style.ForegroundColorSpan;
import android.text.style.RelativeSizeSpan;
import android.util.Log;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;
import org.tensorflow.lite.nnapi.NnApiDelegate;

/** Classifies images with TensorFlow Lite. */
public abstract class ImageClassifier {
  // Display preferences
  private static final float GOOD_PROB_THRESHOLD = 0.3f;
  private static final int SMALL_COLOR = 0xffddaa88;

  /** Tag for the {@link Log}. */
  private static final String TAG = "TfLiteCameraDemo";

  /** Number of results to show in the UI. */
  private static final int RESULTS_TO_SHOW = 3;

  /** Dimensions of inputs. */
  private static final int DIM_BATCH_SIZE = 1;

  private static final int DIM_PIXEL_SIZE = 3;

  /** Preallocated buffer for storing image pixel data. */
  private int[] intValues = new int[getImageSizeX() * getImageSizeY()];

  /** Options for configuring the Interpreter. */
  private final Interpreter.Options tfliteOptions = new Interpreter.Options();

  /** The loaded TensorFlow Lite model. */
  private MappedByteBuffer tfliteModel;

  /** An instance of the driver class to run model inference with TensorFlow Lite. */
  protected Interpreter tflite;

  /** Labels corresponding to the output of the vision model. */
  private List<String> labelList;

  /** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as input. */
  protected ByteBuffer imgData = null;
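  // Note: the constructor allocates this buffer as DIM_BATCH_SIZE * getImageSizeX()
  // * getImageSizeY() * DIM_PIXEL_SIZE * getNumBytesPerChannel() bytes, i.e. exactly one
  // preprocessed image (typically 4 bytes per channel for a float model, 1 for a quantized one).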
  /** Multi-stage low pass filter. */
  private float[][] filterLabelProbArray = null;

  private static final int FILTER_STAGES = 3;
  private static final float FILTER_FACTOR = 0.4f;

  private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
      new PriorityQueue<>(
          RESULTS_TO_SHOW,
          new Comparator<Map.Entry<String, Float>>() {
            @Override
            public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
              return (o1.getValue()).compareTo(o2.getValue());
            }
          });

  /** Holds a GPU delegate. */
  GpuDelegate gpuDelegate = null;

  /** Holds an NNAPI delegate. */
  NnApiDelegate nnapiDelegate = null;

  /** Initializes an {@code ImageClassifier}. */
  ImageClassifier(Activity activity) throws IOException {
    tfliteModel = loadModelFile(activity);
    tflite = new Interpreter(tfliteModel, tfliteOptions);
    labelList = loadLabelList(activity);
    imgData =
        ByteBuffer.allocateDirect(
            DIM_BATCH_SIZE
                * getImageSizeX()
                * getImageSizeY()
                * DIM_PIXEL_SIZE
                * getNumBytesPerChannel());
    imgData.order(ByteOrder.nativeOrder());
    filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()];
    Log.d(TAG, "Created a TensorFlow Lite Image Classifier.");
  }

  /** Classifies a frame from the preview stream. */
  void classifyFrame(Bitmap bitmap, SpannableStringBuilder builder) {
    if (tflite == null) {
      Log.e(TAG, "Image classifier has not been initialized; skipping frame.");
      builder.append(new SpannableString("Uninitialized Classifier."));
      return;
    }
    convertBitmapToByteBuffer(bitmap);
    // Run inference on the current frame and time it.
    long startTime = SystemClock.uptimeMillis();
    runInference();
    long endTime = SystemClock.uptimeMillis();
    Log.d(TAG, "Time cost to run model inference: " + Long.toString(endTime - startTime));

    // Smooth the results across frames.
    applyFilter();

    // Print the results.
    printTopKLabels(builder);
    long duration = endTime - startTime;
    SpannableString span = new SpannableString(duration + " ms");
    span.setSpan(new ForegroundColorSpan(android.graphics.Color.LTGRAY), 0, span.length(), 0);
    builder.append(span);
  }

  void applyFilter() {
    int numLabels = getNumLabels();

    // Low pass filter `labelProbArray` into the first stage of the filter.
    for (int j = 0; j < numLabels; ++j) {
      filterLabelProbArray[0][j] +=
          FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]);
    }
    // Low pass filter each stage into the next.
    for (int i = 1; i < FILTER_STAGES; ++i) {
      for (int j = 0; j < numLabels; ++j) {
        filterLabelProbArray[i][j] +=
            FILTER_FACTOR * (filterLabelProbArray[i - 1][j] - filterLabelProbArray[i][j]);
      }
    }

    // Copy the last stage filter output back to `labelProbArray`.
    for (int j = 0; j < numLabels; ++j) {
      setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]);
    }
  }
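  // Each stage above is an exponential moving average: y += FILTER_FACTOR * (x - y) is equivalent
  // to y = (1 - FILTER_FACTOR) * y + FILTER_FACTOR * x, so chaining FILTER_STAGES stages smooths
  // the displayed probabilities across frames at the cost of a slower response to scene changes.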
  private void recreateInterpreter() {
    if (tflite != null) {
      tflite.close();
      tflite = new Interpreter(tfliteModel, tfliteOptions);
    }
  }

  public void useGpu() {
    if (gpuDelegate == null) {
      GpuDelegate.Options options = new GpuDelegate.Options();
      options.setQuantizedModelsAllowed(true);

      gpuDelegate = new GpuDelegate(options);
      tfliteOptions.addDelegate(gpuDelegate);
      recreateInterpreter();
    }
  }

  public void useCPU() {
    recreateInterpreter();
  }

  public void useNNAPI() {
    nnapiDelegate = new NnApiDelegate();
    tfliteOptions.addDelegate(nnapiDelegate);
    recreateInterpreter();
  }

  public void setNumThreads(int numThreads) {
    tfliteOptions.setNumThreads(numThreads);
    recreateInterpreter();
  }

  /** Closes tflite to release resources. */
  public void close() {
    tflite.close();
    tflite = null;
    if (gpuDelegate != null) {
      gpuDelegate.close();
      gpuDelegate = null;
    }
    if (nnapiDelegate != null) {
      nnapiDelegate.close();
      nnapiDelegate = null;
    }
    tfliteModel = null;
  }

  /** Reads the label list from Assets. */
  private List<String> loadLabelList(Activity activity) throws IOException {
    List<String> labelList = new ArrayList<String>();
    BufferedReader reader =
        new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath())));
    String line;
    while ((line = reader.readLine()) != null) {
      labelList.add(line);
    }
    reader.close();
    return labelList;
  }

  /** Memory-maps the model file in Assets. */
  private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }

  /** Writes image data into a {@code ByteBuffer}. */
  private void convertBitmapToByteBuffer(Bitmap bitmap) {
    if (imgData == null) {
      return;
    }
    imgData.rewind();
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    // Convert the image pixel by pixel; each subclass decides how a pixel is encoded.
    int pixel = 0;
    long startTime = SystemClock.uptimeMillis();
    for (int i = 0; i < getImageSizeX(); ++i) {
      for (int j = 0; j < getImageSizeY(); ++j) {
        final int val = intValues[pixel++];
        addPixelValue(val);
      }
    }
    long endTime = SystemClock.uptimeMillis();
    Log.d(TAG, "Time cost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
  }
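  // The priority queue below keeps only the RESULTS_TO_SHOW entries with the highest
  // probabilities and polls them in ascending order; inserting each span at index 0 of the
  // builder therefore puts the most likely label first.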
  /** Prints the top-K labels, to be shown in the UI as the results. */
  private void printTopKLabels(SpannableStringBuilder builder) {
    for (int i = 0; i < getNumLabels(); ++i) {
      sortedLabels.add(
          new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i)));
      if (sortedLabels.size() > RESULTS_TO_SHOW) {
        sortedLabels.poll();
      }
    }

    final int size = sortedLabels.size();
    for (int i = 0; i < size; i++) {
      Map.Entry<String, Float> label = sortedLabels.poll();
      SpannableString span =
          new SpannableString(String.format("%s: %4.2f\n", label.getKey(), label.getValue()));
      int color;
      // Make the text white when the probability is larger than the threshold.
      if (label.getValue() > GOOD_PROB_THRESHOLD) {
        color = android.graphics.Color.WHITE;
      } else {
        color = SMALL_COLOR;
      }
      // Make the last-polled (most likely) item bigger.
      if (i == size - 1) {
        span.setSpan(new RelativeSizeSpan(1.25f), 0, span.length(), 0);
      }
      span.setSpan(new ForegroundColorSpan(color), 0, span.length(), 0);
      builder.insert(0, span);
    }
  }

  /**
   * Gets the name of the model file stored in Assets.
   *
   * @return the model file name
   */
  protected abstract String getModelPath();

  /**
   * Gets the name of the label file stored in Assets.
   *
   * @return the label file name
   */
  protected abstract String getLabelPath();

  /**
   * Gets the image size along the x axis.
   *
   * @return the input width in pixels
   */
  protected abstract int getImageSizeX();

  /**
   * Gets the image size along the y axis.
   *
   * @return the input height in pixels
   */
  protected abstract int getImageSizeY();

  /**
   * Gets the number of bytes used to store a single color channel value.
   *
   * @return the number of bytes per channel
   */
  protected abstract int getNumBytesPerChannel();

  /**
   * Adds pixelValue to imgData.
   *
   * @param pixelValue the ARGB pixel value to encode
   */
  protected abstract void addPixelValue(int pixelValue);

  /**
   * Reads the probability value for the specified label. This is either the original value as it
   * was read from the net's output or the updated value after the filter was applied.
   *
   * @param labelIndex index of the label
   * @return the probability for that label
   */
  protected abstract float getProbability(int labelIndex);

  /**
   * Sets the probability value for the specified label.
   *
   * @param labelIndex index of the label
   * @param value the new probability value
   */
  protected abstract void setProbability(int labelIndex, Number value);

  /**
   * Gets the normalized probability value for the specified label. This is the final value as it
   * will be shown to the user.
   *
   * @param labelIndex index of the label
   * @return the normalized probability for that label
   */
  protected abstract float getNormalizedProbability(int labelIndex);

  /**
   * Runs inference using the prepared input in {@link #imgData}. Afterwards, the result will be
   * provided by getProbability().
   *
   * <p>This additional method is necessary because we don't have a common base for different
   * primitive data types.
   */
  protected abstract void runInference();

  /**
   * Gets the total number of labels.
   *
   * @return the size of the label list
   */
  protected int getNumLabels() {
    return labelList.size();
  }
}
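// For reference, a concrete subclass only has to supply the model/label paths, the input
// geometry, and the pre/post-processing hooks. A minimal sketch follows; the class name, asset
// file names, and the /255 normalization are illustrative assumptions, not part of this demo:
//
//   class FloatMobileNetClassifier extends ImageClassifier {
//     private float[][] labelProbArray;
//
//     FloatMobileNetClassifier(Activity activity) throws IOException {
//       super(activity);
//       labelProbArray = new float[1][getNumLabels()];
//     }
//
//     @Override protected String getModelPath() { return "mobilenet_float.tflite"; }
//     @Override protected String getLabelPath() { return "labels.txt"; }
//     @Override protected int getImageSizeX() { return 224; }
//     @Override protected int getImageSizeY() { return 224; }
//     @Override protected int getNumBytesPerChannel() { return 4; } // float32 input
//
//     @Override protected void addPixelValue(int pixelValue) {
//       imgData.putFloat(((pixelValue >> 16) & 0xFF) / 255.f);
//       imgData.putFloat(((pixelValue >> 8) & 0xFF) / 255.f);
//       imgData.putFloat((pixelValue & 0xFF) / 255.f);
//     }
//
//     @Override protected float getProbability(int labelIndex) {
//       return labelProbArray[0][labelIndex];
//     }
//     @Override protected void setProbability(int labelIndex, Number value) {
//       labelProbArray[0][labelIndex] = value.floatValue();
//     }
//     @Override protected float getNormalizedProbability(int labelIndex) {
//       return labelProbArray[0][labelIndex];
//     }
//     @Override protected void runInference() { tflite.run(imgData, labelProbArray); }
//   }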