/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;

/**
 * Driver class to drive model inference with TensorFlow Lite.
 *
 * <p>An {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
 * are executed for model inference.
 *
 * <p>For example, if a model takes only one input and returns only one output:
 *
 * <pre>{@code
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>If a model takes multiple inputs or outputs:
 *
 * <pre>{@code
 * Object[] inputs = {input0, input1, ...};
 * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 * ByteBuffer ith_output = ByteBuffer.allocateDirect(3 * 2 * 4 * 4);  // Float tensor, shape 3x2x4.
 * ith_output.order(ByteOrder.nativeOrder());
 * map_of_indices_to_outputs.put(i, ith_output);
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 * }
 * }</pre>
 *
 * <p>If a model takes or produces string tensors:
 *
 * <pre>{@code
 * String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>Orders of inputs and outputs are determined when converting the TensorFlow model to the
 * TensorFlow Lite model with Toco, as are the default shapes of the inputs.
 *
 * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
 * be implicitly resized according to that array's shape. When inputs are provided as {@link
 * ByteBuffer} types, no implicit resizing is done; the caller must ensure that the {@link
 * ByteBuffer} byte size either matches that of the corresponding tensor, or that they first resize
 * the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be
 * obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link
 * #getOutputTensor(int)}.
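 *
 * <p>For example, a sketch that feeds a float input through a direct {@link ByteBuffer} whose byte
 * size matches the corresponding input tensor (the shapes here are illustrative):
 *
 * <pre>{@code
 * ByteBuffer input = ByteBuffer.allocateDirect(1 * 224 * 224 * 3 * 4);  // Float tensor, shape 1x224x224x3.
 * input.order(ByteOrder.nativeOrder());
 * // ... fill input ...
 * float[][] output = new float[1][1000];  // Output tensor shape is [1, 1000].
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>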
 *
 * <p><b>WARNING:</b> Instances of an {@code Interpreter} are <b>not</b> thread-safe. An {@code
 * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}.
 */
public final class Interpreter implements AutoCloseable {

  /** An options class for controlling runtime interpreter behavior. */
  public static class Options {
    public Options() {}

    /**
     * Sets the number of threads to be used for ops that support multi-threading. Defaults to a
     * platform-dependent value.
     */
    public Options setNumThreads(int numThreads) {
      this.numThreads = numThreads;
      return this;
    }

    /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */
    public Options setUseNNAPI(boolean useNNAPI) {
      this.useNNAPI = useNNAPI;
      return this;
    }

    /**
     * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to
     * false (disallow).
     *
     * <p>WARNING: This is an experimental API and subject to change.
     */
    public Options setAllowFp16PrecisionForFp32(boolean allow) {
      this.allowFp16PrecisionForFp32 = allow;
      return this;
    }

    /**
     * Adds a {@link Delegate} to be applied during interpreter creation.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options addDelegate(Delegate delegate) {
      delegates.add(delegate);
      return this;
    }

    /**
     * Advanced: Sets whether buffer handle output is allowed.
     *
     * <p>When a {@link Delegate} supports hardware acceleration, the interpreter will make the
     * data of output tensors available in the CPU-allocated tensor buffers by default. If the
     * client can consume the buffer handle directly (e.g. reading output from an OpenGL texture),
     * it can set this flag to false, avoiding the copy of data to the CPU buffer. The delegate
     * documentation should indicate whether this is supported and how it can be used.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options setAllowBufferHandleOutput(boolean allow) {
      this.allowBufferHandleOutput = allow;
      return this;
    }

    int numThreads = -1;
    Boolean useNNAPI;
    Boolean allowFp16PrecisionForFp32;
    Boolean allowBufferHandleOutput;
    final List<Delegate> delegates = new ArrayList<>();
  }

  /**
   * Initializes an {@code Interpreter}.
   *
   * @param modelFile a File of a pre-trained TF Lite model.
   */
  public Interpreter(@NonNull File modelFile) {
    this(modelFile, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} and specifies the number of threads used for inference.
   *
   * @param modelFile a file of a pre-trained TF Lite model
   * @param numThreads number of threads to use for inference
   * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will
   *     be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull File modelFile, int numThreads) {
    this(modelFile, new Options().setNumThreads(numThreads));
  }

  /**
   * Initializes an {@code Interpreter} with a {@code File} of a model and a set of custom {@link
   * Options}.
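   *
   * <p>For example (a sketch; the thread count and NNAPI setting here are illustrative choices,
   * not defaults):
   *
   * <pre>{@code
   * Interpreter.Options options = new Interpreter.Options().setNumThreads(4).setUseNNAPI(true);
   * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model, options)) {
   *   interpreter.run(input, output);
   * }
   * }</pre>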
   *
   * @param modelFile a file of a pre-trained TF Lite model
   * @param options a set of options for customizing interpreter behavior
   */
  public Interpreter(@NonNull File modelFile, Options options) {
    wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file.
   *
   * <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or
   * a direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and specifies the
   * number of threads used for inference.
   *
   * <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or
   * a direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   *
   * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
   *     will be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
    this(byteBuffer, new Options().setNumThreads(numThreads));
  }

  /**
   * Initializes an {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
   *
   * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of an {@code
   * Interpreter}.
   *
   * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
   *     will be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
    this(mappedByteBuffer, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of
   * custom {@link Options}.
   *
   * <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or
   * a direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
    wrapper = new NativeInterpreterWrapper(byteBuffer, options);
  }

  /**
   * Runs model inference if the model takes only one input, and provides only one output.
   *
   * <p>Warning: The API runs much faster if a {@link ByteBuffer} is used as the input data type.
   * Please consider using {@link ByteBuffer} to feed primitive input data for better performance.
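   *
   * <p>For example, a minimal sketch for a model with a single float input of shape {@code [1, 4]}
   * and a single float output of shape {@code [1, 3]} (the shapes are illustrative):
   *
   * <pre>{@code
   * float[][] input = {{1.f, 2.f, 3.f, 4.f}};
   * float[][] output = new float[1][3];
   * interpreter.run(input, output);
   * }</pre>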
   *
   * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types
   *     including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass
   *     large input data for primitive types, whereas string types require using the
   *     (multi-dimensional) array input path. When a {@link ByteBuffer} is used, its content
   *     should remain unchanged until model inference is done. A {@code null} value is allowed
   *     only if the caller is using a {@link Delegate} that allows buffer handle interop, and such
   *     a buffer has been bound to the input {@link Tensor}.
   * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
   *     types including int, float, long, and byte. A {@code null} value is allowed only if the
   *     caller is using a {@link Delegate} that allows buffer handle interop, and such a buffer
   *     has been bound to the output {@link Tensor}. See also {@link
   *     Options#setAllowBufferHandleOutput(boolean)}.
   */
  public void run(Object input, Object output) {
    Object[] inputs = {input};
    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, output);
    runForMultipleInputsOutputs(inputs, outputs);
  }

  /**
   * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
   *
   * <p>Warning: The API runs much faster if a {@link ByteBuffer} is used as the input data type.
   * Please consider using {@link ByteBuffer} to feed primitive input data for better performance.
   *
   * <p>Note: {@code null} values for individual elements of {@code inputs} and {@code outputs} are
   * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
   * such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
   *
   * @param inputs an array of input data. The inputs should be in the same order as the inputs of
   *     the model. Each input can be an array or multidimensional array, or a {@link ByteBuffer}
   *     of primitive types including int, float, long, and byte. {@link ByteBuffer} is the
   *     preferred way to pass large input data, whereas string types require using the
   *     (multi-dimensional) array input path. When a {@link ByteBuffer} is used, its content
   *     should remain unchanged until model inference is done.
   * @param outputs a map mapping output indices to multidimensional arrays of output data or
   *     {@link ByteBuffer}s of primitive types including int, float, long, and byte. It only needs
   *     to keep entries for the outputs to be used.
   */
  public void runForMultipleInputsOutputs(
      @NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
    checkNotClosed();
    wrapper.run(inputs, outputs);
  }

  /**
   * Resizes the idx-th input of the native model to the given dims.
   *
   * <p>An {@link IllegalArgumentException} will be thrown if the resize fails.
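   *
   * <p>For example, a sketch that resizes input 0 to a batch of four images before running
   * inference (the index and shape are illustrative):
   *
   * <pre>{@code
   * interpreter.resizeInput(0, new int[] {4, 224, 224, 3});
   * interpreter.run(input, output);
   * }</pre>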
   */
  public void resizeInput(int idx, @NonNull int[] dims) {
    checkNotClosed();
    wrapper.resizeInput(idx, dims);
  }

  /** Gets the number of input tensors. */
  public int getInputTensorCount() {
    checkNotClosed();
    return wrapper.getInputTensorCount();
  }

  /**
   * Gets the index of an input given the op name of the input.
   *
   * <p>An {@link IllegalArgumentException} will be thrown if the op name does not exist in the
   * model file used to initialize the {@link Interpreter}.
   */
  public int getInputIndex(String opName) {
    checkNotClosed();
    return wrapper.getInputIndex(opName);
  }

  /**
   * Gets the Tensor associated with the provided input index.
   *
   * <p>An {@link IllegalArgumentException} will be thrown if the provided index is invalid.
   */
  public Tensor getInputTensor(int inputIndex) {
    checkNotClosed();
    return wrapper.getInputTensor(inputIndex);
  }

  /** Gets the number of output Tensors. */
  public int getOutputTensorCount() {
    checkNotClosed();
    return wrapper.getOutputTensorCount();
  }

  /**
   * Gets the index of an output given the op name of the output.
   *
   * <p>An {@link IllegalArgumentException} will be thrown if the op name does not exist in the
   * model file used to initialize the {@link Interpreter}.
   */
  public int getOutputIndex(String opName) {
    checkNotClosed();
    return wrapper.getOutputIndex(opName);
  }

  /**
   * Gets the Tensor associated with the provided output index.
   *
   * <p>An {@link IllegalArgumentException} will be thrown if the provided index is invalid.
   */
  public Tensor getOutputTensor(int outputIndex) {
    checkNotClosed();
    return wrapper.getOutputTensor(outputIndex);
  }

  /**
   * Returns native inference timing.
   *
   * <p>An {@link IllegalArgumentException} will be thrown if the model is not initialized by the
   * {@link Interpreter}.
   */
  public Long getLastNativeInferenceDurationNanoseconds() {
    checkNotClosed();
    return wrapper.getLastNativeInferenceDurationNanoseconds();
  }

  /**
   * Turns on/off Android NNAPI for hardware acceleration when it is available.
   *
   * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API.
   *     This method will be removed in a future release.
   */
  @Deprecated
  public void setUseNNAPI(boolean useNNAPI) {
    checkNotClosed();
    wrapper.setUseNNAPI(useNNAPI);
  }

  /**
   * Sets the number of threads to be used for ops that support multi-threading.
   *
   * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling
   *     multi-threading. This method will be removed in a future release.
   */
  @Deprecated
  public void setNumThreads(int numThreads) {
    checkNotClosed();
    wrapper.setNumThreads(numThreads);
  }

  /**
   * Advanced: Modifies the graph with the provided {@link Delegate}.
   *
   * <p>Note: The typical path for providing delegates is via {@link Options#addDelegate}, at
   * creation time. This path should only be used when a delegate might require coordinated
   * interaction between Interpreter creation and delegate application.
   *
   * <p>WARNING: This is an experimental API and subject to change.
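   *
   * <p>For example, a sketch that applies a delegate after construction; {@code someDelegate} is a
   * placeholder for any concrete {@link Delegate} implementation:
   *
   * <pre>{@code
   * Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model);
   * interpreter.modifyGraphWithDelegate(someDelegate);
   * interpreter.run(input, output);
   * }</pre>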
   */
  public void modifyGraphWithDelegate(Delegate delegate) {
    checkNotClosed();
    wrapper.modifyGraphWithDelegate(delegate);
  }

  /** Releases resources associated with the {@code Interpreter}. */
  @Override
  public void close() {
    if (wrapper != null) {
      wrapper.close();
      wrapper = null;
    }
  }

  @Override
  protected void finalize() throws Throwable {
    try {
      close();
    } finally {
      super.finalize();
    }
  }

  private void checkNotClosed() {
    if (wrapper == null) {
      throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
    }
  }

  NativeInterpreterWrapper wrapper;
}