1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package com.example.android.tflitecamerademo; 17 18 import android.app.Activity; 19 import android.content.res.AssetFileDescriptor; 20 import android.graphics.Bitmap; 21 import android.os.SystemClock; 22 import android.text.SpannableString; 23 import android.text.SpannableStringBuilder; 24 import android.text.style.ForegroundColorSpan; 25 import android.text.style.RelativeSizeSpan; 26 import android.util.Log; 27 import java.io.BufferedReader; 28 import java.io.FileInputStream; 29 import java.io.IOException; 30 import java.io.InputStreamReader; 31 import java.nio.ByteBuffer; 32 import java.nio.ByteOrder; 33 import java.nio.MappedByteBuffer; 34 import java.nio.channels.FileChannel; 35 import java.util.AbstractMap; 36 import java.util.ArrayList; 37 import java.util.Comparator; 38 import java.util.List; 39 import java.util.Map; 40 import java.util.PriorityQueue; 41 import org.tensorflow.lite.Delegate; 42 import org.tensorflow.lite.Interpreter; 43 44 /** 45 * Classifies images with Tensorflow Lite. 46 */ 47 public abstract class ImageClassifier { 48 // Display preferences 49 private static final float GOOD_PROB_THRESHOLD = 0.3f; 50 private static final int SMALL_COLOR = 0xffddaa88; 51 52 /** Tag for the {@link Log}. */ 53 private static final String TAG = "TfLiteCameraDemo"; 54 55 /** Number of results to show in the UI. */ 56 private static final int RESULTS_TO_SHOW = 3; 57 58 /** Dimensions of inputs. */ 59 private static final int DIM_BATCH_SIZE = 1; 60 61 private static final int DIM_PIXEL_SIZE = 3; 62 63 /** Preallocated buffers for storing image data in. */ 64 private int[] intValues = new int[getImageSizeX() * getImageSizeY()]; 65 66 /** Options for configuring the Interpreter. */ 67 private final Interpreter.Options tfliteOptions = new Interpreter.Options(); 68 69 /** The loaded TensorFlow Lite model. */ 70 private MappedByteBuffer tfliteModel; 71 72 /** An instance of the driver class to run model inference with Tensorflow Lite. */ 73 protected Interpreter tflite; 74 75 /** Labels corresponding to the output of the vision model. */ 76 private List<String> labelList; 77 78 /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */ 79 protected ByteBuffer imgData = null; 80 81 /** multi-stage low pass filter * */ 82 private float[][] filterLabelProbArray = null; 83 84 private static final int FILTER_STAGES = 3; 85 private static final float FILTER_FACTOR = 0.4f; 86 87 private PriorityQueue<Map.Entry<String, Float>> sortedLabels = 88 new PriorityQueue<>( 89 RESULTS_TO_SHOW, 90 new Comparator<Map.Entry<String, Float>>() { 91 @Override 92 public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) { 93 return (o1.getValue()).compareTo(o2.getValue()); 94 } 95 }); 96 97 /** holds a gpu delegate */ 98 Delegate gpuDelegate = null; 99 100 /** Initializes an {@code ImageClassifier}. */ ImageClassifier(Activity activity)101 ImageClassifier(Activity activity) throws IOException { 102 tfliteModel = loadModelFile(activity); 103 tflite = new Interpreter(tfliteModel, tfliteOptions); 104 labelList = loadLabelList(activity); 105 imgData = 106 ByteBuffer.allocateDirect( 107 DIM_BATCH_SIZE 108 * getImageSizeX() 109 * getImageSizeY() 110 * DIM_PIXEL_SIZE 111 * getNumBytesPerChannel()); 112 imgData.order(ByteOrder.nativeOrder()); 113 filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()]; 114 Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); 115 } 116 117 /** Classifies a frame from the preview stream. */ classifyFrame(Bitmap bitmap, SpannableStringBuilder builder)118 void classifyFrame(Bitmap bitmap, SpannableStringBuilder builder) { 119 if (tflite == null) { 120 Log.e(TAG, "Image classifier has not been initialized; Skipped."); 121 builder.append(new SpannableString("Uninitialized Classifier.")); 122 } 123 convertBitmapToByteBuffer(bitmap); 124 // Here's where the magic happens!!! 125 long startTime = SystemClock.uptimeMillis(); 126 runInference(); 127 long endTime = SystemClock.uptimeMillis(); 128 Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime)); 129 130 // Smooth the results across frames. 131 applyFilter(); 132 133 // Print the results. 134 printTopKLabels(builder); 135 long duration = endTime - startTime; 136 SpannableString span = new SpannableString(duration + " ms"); 137 span.setSpan(new ForegroundColorSpan(android.graphics.Color.LTGRAY), 0, span.length(), 0); 138 builder.append(span); 139 } 140 applyFilter()141 void applyFilter() { 142 int numLabels = getNumLabels(); 143 144 // Low pass filter `labelProbArray` into the first stage of the filter. 145 for (int j = 0; j < numLabels; ++j) { 146 filterLabelProbArray[0][j] += 147 FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]); 148 } 149 // Low pass filter each stage into the next. 150 for (int i = 1; i < FILTER_STAGES; ++i) { 151 for (int j = 0; j < numLabels; ++j) { 152 filterLabelProbArray[i][j] += 153 FILTER_FACTOR * (filterLabelProbArray[i - 1][j] - filterLabelProbArray[i][j]); 154 } 155 } 156 157 // Copy the last stage filter output back to `labelProbArray`. 158 for (int j = 0; j < numLabels; ++j) { 159 setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]); 160 } 161 } 162 recreateInterpreter()163 private void recreateInterpreter() { 164 if (tflite != null) { 165 tflite.close(); 166 // TODO(b/120679982) 167 // gpuDelegate.close(); 168 tflite = new Interpreter(tfliteModel, tfliteOptions); 169 } 170 } 171 useGpu()172 public void useGpu() { 173 if (gpuDelegate == null && GpuDelegateHelper.isGpuDelegateAvailable()) { 174 gpuDelegate = GpuDelegateHelper.createGpuDelegate(); 175 tfliteOptions.addDelegate(gpuDelegate); 176 recreateInterpreter(); 177 } 178 } 179 useCPU()180 public void useCPU() { 181 tfliteOptions.setUseNNAPI(false); 182 recreateInterpreter(); 183 } 184 useNNAPI()185 public void useNNAPI() { 186 tfliteOptions.setUseNNAPI(true); 187 recreateInterpreter(); 188 } 189 setNumThreads(int numThreads)190 public void setNumThreads(int numThreads) { 191 tfliteOptions.setNumThreads(numThreads); 192 recreateInterpreter(); 193 } 194 195 /** Closes tflite to release resources. */ close()196 public void close() { 197 tflite.close(); 198 tflite = null; 199 tfliteModel = null; 200 } 201 202 /** Reads label list from Assets. */ loadLabelList(Activity activity)203 private List<String> loadLabelList(Activity activity) throws IOException { 204 List<String> labelList = new ArrayList<String>(); 205 BufferedReader reader = 206 new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath()))); 207 String line; 208 while ((line = reader.readLine()) != null) { 209 labelList.add(line); 210 } 211 reader.close(); 212 return labelList; 213 } 214 215 /** Memory-map the model file in Assets. */ loadModelFile(Activity activity)216 private MappedByteBuffer loadModelFile(Activity activity) throws IOException { 217 AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath()); 218 FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); 219 FileChannel fileChannel = inputStream.getChannel(); 220 long startOffset = fileDescriptor.getStartOffset(); 221 long declaredLength = fileDescriptor.getDeclaredLength(); 222 return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); 223 } 224 225 /** Writes Image data into a {@code ByteBuffer}. */ convertBitmapToByteBuffer(Bitmap bitmap)226 private void convertBitmapToByteBuffer(Bitmap bitmap) { 227 if (imgData == null) { 228 return; 229 } 230 imgData.rewind(); 231 bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); 232 // Convert the image to floating point. 233 int pixel = 0; 234 long startTime = SystemClock.uptimeMillis(); 235 for (int i = 0; i < getImageSizeX(); ++i) { 236 for (int j = 0; j < getImageSizeY(); ++j) { 237 final int val = intValues[pixel++]; 238 addPixelValue(val); 239 } 240 } 241 long endTime = SystemClock.uptimeMillis(); 242 Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime)); 243 } 244 245 /** Prints top-K labels, to be shown in UI as the results. */ printTopKLabels(SpannableStringBuilder builder)246 private void printTopKLabels(SpannableStringBuilder builder) { 247 for (int i = 0; i < getNumLabels(); ++i) { 248 sortedLabels.add( 249 new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i))); 250 if (sortedLabels.size() > RESULTS_TO_SHOW) { 251 sortedLabels.poll(); 252 } 253 } 254 255 final int size = sortedLabels.size(); 256 for (int i = 0; i < size; i++) { 257 Map.Entry<String, Float> label = sortedLabels.poll(); 258 SpannableString span = 259 new SpannableString(String.format("%s: %4.2f\n", label.getKey(), label.getValue())); 260 int color; 261 // Make it white when probability larger than threshold. 262 if (label.getValue() > GOOD_PROB_THRESHOLD) { 263 color = android.graphics.Color.WHITE; 264 } else { 265 color = SMALL_COLOR; 266 } 267 // Make first item bigger. 268 if (i == size - 1) { 269 float sizeScale = (i == size - 1) ? 1.25f : 0.8f; 270 span.setSpan(new RelativeSizeSpan(sizeScale), 0, span.length(), 0); 271 } 272 span.setSpan(new ForegroundColorSpan(color), 0, span.length(), 0); 273 builder.insert(0, span); 274 } 275 } 276 277 /** 278 * Get the name of the model file stored in Assets. 279 * 280 * @return 281 */ getModelPath()282 protected abstract String getModelPath(); 283 284 /** 285 * Get the name of the label file stored in Assets. 286 * 287 * @return 288 */ getLabelPath()289 protected abstract String getLabelPath(); 290 291 /** 292 * Get the image size along the x axis. 293 * 294 * @return 295 */ getImageSizeX()296 protected abstract int getImageSizeX(); 297 298 /** 299 * Get the image size along the y axis. 300 * 301 * @return 302 */ getImageSizeY()303 protected abstract int getImageSizeY(); 304 305 /** 306 * Get the number of bytes that is used to store a single color channel value. 307 * 308 * @return 309 */ getNumBytesPerChannel()310 protected abstract int getNumBytesPerChannel(); 311 312 /** 313 * Add pixelValue to byteBuffer. 314 * 315 * @param pixelValue 316 */ addPixelValue(int pixelValue)317 protected abstract void addPixelValue(int pixelValue); 318 319 /** 320 * Read the probability value for the specified label This is either the original value as it was 321 * read from the net's output or the updated value after the filter was applied. 322 * 323 * @param labelIndex 324 * @return 325 */ getProbability(int labelIndex)326 protected abstract float getProbability(int labelIndex); 327 328 /** 329 * Set the probability value for the specified label. 330 * 331 * @param labelIndex 332 * @param value 333 */ setProbability(int labelIndex, Number value)334 protected abstract void setProbability(int labelIndex, Number value); 335 336 /** 337 * Get the normalized probability value for the specified label. This is the final value as it 338 * will be shown to the user. 339 * 340 * @return 341 */ getNormalizedProbability(int labelIndex)342 protected abstract float getNormalizedProbability(int labelIndex); 343 344 /** 345 * Run inference using the prepared input in {@link #imgData}. Afterwards, the result will be 346 * provided by getProbability(). 347 * 348 * <p>This additional method is necessary, because we don't have a common base for different 349 * primitive data types. 350 */ runInference()351 protected abstract void runInference(); 352 353 /** 354 * Get the total number of labels. 355 * 356 * @return 357 */ getNumLabels()358 protected int getNumLabels() { 359 return labelList.size(); 360 } 361 } 362