/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.io.File;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;
import org.tensorflow.lite.nnapi.NnApiDelegate;

/**
 * Interface to TensorFlow Lite model interpreter, excluding experimental methods.
 *
 * <p>An {@code InterpreterApi} instance encapsulates a pre-trained TensorFlow Lite model, in which
 * operations are executed for model inference.
 *
 * <p>For example, if a model takes only one input and returns only one output:
 *
 * <pre>{@code
 * try (InterpreterApi interpreter =
 *     InterpreterApi.create(file_of_a_tensorflowlite_model, options)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>If a model takes multiple inputs or outputs:
 *
 * <pre>{@code
 * Object[] inputs = {input0, input1, ...};
 * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 * FloatBuffer ith_output =
 *     ByteBuffer.allocateDirect(3 * 2 * 4 * 4)  // Float tensor, shape 3x2x4 (4 bytes per float).
 *         .order(ByteOrder.nativeOrder())
 *         .asFloatBuffer();
 * map_of_indices_to_outputs.put(i, ith_output);
 * try (InterpreterApi interpreter =
 *     InterpreterApi.create(file_of_a_tensorflowlite_model, options)) {
 *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 * }
 * }</pre>
 *
 * <p>If a model takes or produces string tensors:
 *
 * <pre>{@code
 * String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 * try (InterpreterApi interpreter =
 *     InterpreterApi.create(file_of_a_tensorflowlite_model, options)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>The orders of inputs and outputs are determined when converting the TensorFlow model to the
 * TensorFlow Lite model with the TensorFlow Lite converter (Toco), as are the default shapes of
 * the inputs.
 *
 * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
 * be implicitly resized according to that array's shape. When inputs are provided as {@link
 * java.nio.Buffer} types, no implicit resizing is done; the caller must ensure that the {@link
 * java.nio.Buffer} byte size either matches that of the corresponding tensor, or that they first
 * resize the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be
 * obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link
 * #getOutputTensor(int)}.
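 *
 * <p>Runtime behavior can be configured via {@link InterpreterApi.Options} before creation. As an
 * illustrative sketch (the model file and the {@code input}/{@code output} objects are
 * placeholders, as in the examples above):
 *
 * <pre>{@code
 * InterpreterApi.Options options =
 *     new InterpreterApi.Options().setNumThreads(2).setUseNNAPI(true);
 * try (InterpreterApi interpreter =
 *     InterpreterApi.create(file_of_a_tensorflowlite_model, options)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>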
 *
 * <p><b>WARNING:</b> {@code InterpreterApi} instances are <b>not</b> thread-safe.
 *
 * <p><b>WARNING:</b> An {@code InterpreterApi} instance owns resources that <b>must</b> be
 * explicitly freed by invoking {@link #close()}.
 *
 * <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
 * but this is not guaranteed.
 */
public interface InterpreterApi extends AutoCloseable {

  /** An options class for controlling runtime interpreter behavior. */
  class Options {

    public Options() {
      this.delegates = new ArrayList<>();
      this.delegateFactories = new ArrayList<>();
    }

    public Options(Options other) {
      this.numThreads = other.numThreads;
      this.useNNAPI = other.useNNAPI;
      this.allowCancellation = other.allowCancellation;
      this.delegates = new ArrayList<>(other.delegates);
      this.delegateFactories = new ArrayList<>(other.delegateFactories);
      this.runtime = other.runtime;
    }

    /**
     * Sets the number of threads to be used for ops that support multi-threading.
     *
     * <p>{@code numThreads} should be {@code >= -1}. Setting {@code numThreads} to 0 has the
     * effect of disabling multithreading, which is equivalent to setting {@code numThreads} to 1.
     * If unspecified, or set to the value -1, the number of threads used will be
     * implementation-defined and platform-dependent.
     */
    public Options setNumThreads(int numThreads) {
      this.numThreads = numThreads;
      return this;
    }

    /**
     * Returns the number of threads to be used for ops that support multi-threading.
     *
     * <p>{@code numThreads} should be {@code >= -1}. Values of 0 (or 1) disable multithreading.
     * The default value is -1: the number of threads used will be implementation-defined and
     * platform-dependent.
     */
    public int getNumThreads() {
      return numThreads;
    }

    /**
     * Sets whether to use NN API (if available) for op execution. Defaults to false (disabled).
     */
    public Options setUseNNAPI(boolean useNNAPI) {
      this.useNNAPI = useNNAPI;
      return this;
    }

    /**
     * Returns whether to use NN API (if available) for op execution. The default value is false
     * (disabled).
     */
    public boolean getUseNNAPI() {
      return useNNAPI != null && useNNAPI;
    }

    /**
     * Advanced: Sets whether the interpreter can be cancelled.
     *
     * <p>Interpreters may have an experimental API <a
     * href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter#setCancelled(boolean)">setCancelled(boolean)</a>.
     * If this interpreter is cancellable and such a method is invoked, a cancellation flag will be
     * set to true. The interpreter will check the flag between Op invocations, and if it's {@code
     * true}, the interpreter will stop execution. The interpreter will remain in a cancelled state
     * until explicitly "uncancelled" by {@code setCancelled(false)}.
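     *
     * <p>As an illustrative sketch (the {@code setCancelled} method is experimental and, as noted
     * above, may or may not be present on the interpreter implementation in use):
     *
     * <pre>{@code
     * InterpreterApi.Options options = new InterpreterApi.Options().setCancellable(true);
     * // An interpreter created with these options checks the cancellation flag between
     * // Op invocations and stops execution once setCancelled(true) has been called.
     * }</pre>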
     */
    public Options setCancellable(boolean allow) {
      this.allowCancellation = allow;
      return this;
    }

    /**
     * Advanced: Returns whether the interpreter can be cancelled.
     *
     * <p>Interpreters may have an experimental API <a
     * href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter#setCancelled(boolean)">setCancelled(boolean)</a>.
     * If this interpreter is cancellable and such a method is invoked, a cancellation flag will be
     * set to true. The interpreter will check the flag between Op invocations, and if it's {@code
     * true}, the interpreter will stop execution. The interpreter will remain in a cancelled state
     * until explicitly "uncancelled" by {@code setCancelled(false)}.
     */
    public boolean isCancellable() {
      return allowCancellation != null && allowCancellation;
    }

    /**
     * Adds a {@link Delegate} to be applied during interpreter creation.
     *
     * <p>Delegates added here are applied before any delegates created from a {@link
     * DelegateFactory} that was added with {@link #addDelegateFactory}.
     *
     * <p>Note that TF Lite in Google Play Services (see {@link #setRuntime}) does not support
     * external (developer-provided) delegates, and adding a {@link Delegate} other than {@link
     * NnApiDelegate} here is not allowed when using TF Lite in Google Play Services.
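     *
     * <p>For example, a minimal sketch using the NNAPI delegate (whether NNAPI actually
     * accelerates a given model is device-dependent; the delegate should be closed once the
     * interpreter is no longer in use):
     *
     * <pre>{@code
     * NnApiDelegate nnApiDelegate = new NnApiDelegate();
     * InterpreterApi.Options options = new InterpreterApi.Options().addDelegate(nnApiDelegate);
     * }</pre>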
     */
    public Options addDelegate(Delegate delegate) {
      delegates.add(delegate);
      return this;
    }

    /**
     * Returns the list of delegates intended to be applied during interpreter creation that have
     * been registered via {@code addDelegate}.
     */
    public List<Delegate> getDelegates() {
      return Collections.unmodifiableList(delegates);
    }

    /**
     * Adds a {@link DelegateFactory} which will be invoked to apply its created {@link Delegate}
     * during interpreter creation.
     *
     * <p>Delegates from a delegate factory that was added here are applied after any delegates
     * added with {@link #addDelegate}.
     */
    public Options addDelegateFactory(DelegateFactory delegateFactory) {
      delegateFactories.add(delegateFactory);
      return this;
    }

    /**
     * Returns the list of delegate factories that have been registered via {@code
     * addDelegateFactory}.
     */
    public List<DelegateFactory> getDelegateFactories() {
      return Collections.unmodifiableList(delegateFactories);
    }

    /**
     * Enum to represent where to get the TensorFlow Lite runtime implementation from.
     *
     * <p>The difference between this class and the RuntimeFlavor class: This class specifies a
     * <em>preference</em> for which runtime to use, whereas {@link RuntimeFlavor} specifies which
     * exact runtime <em>is</em> being used.
     */
    public enum TfLiteRuntime {
      /**
       * Use a TF Lite runtime implementation that is linked into the application. If there is no
       * suitable TF Lite runtime implementation linked into the application, then attempting to
       * create an InterpreterApi instance with this TfLiteRuntime setting will throw an
       * IllegalStateException (even if the OS or system services could provide a TF Lite runtime
       * implementation).
       *
       * <p>This is the default setting. This setting is also appropriate for apps that must run on
       * systems that don't provide a TF Lite runtime implementation.
       */
      FROM_APPLICATION_ONLY,

      /**
       * Use a TF Lite runtime implementation provided by the OS or system services. This will be
       * obtained from a system library / shared object / service, such as Google Play Services. It
       * may be newer than the version linked into the application (if any). If there is no
       * suitable TF Lite runtime implementation provided by the system, then attempting to create
       * an InterpreterApi instance with this TfLiteRuntime setting will throw an
       * IllegalStateException (even if there is a TF Lite runtime implementation linked into the
       * application).
       *
       * <p>This setting is appropriate for code that will use a system-provided TF Lite runtime,
       * which can reduce app binary size and can be updated more frequently.
       */
      FROM_SYSTEM_ONLY,

      /**
       * Use a system-provided TF Lite runtime implementation, if any, otherwise use the TF Lite
       * runtime implementation linked into the application, if any. If no suitable TF Lite runtime
       * can be found in any location, then attempting to create an InterpreterApi instance with
       * this TfLiteRuntime setting will throw an IllegalStateException. If there is both a
       * suitable TF Lite runtime linked into the application and also a suitable TF Lite runtime
       * provided by the system, the one provided by the system will be used.
       *
       * <p>This setting is suitable for use in code that doesn't care where the TF Lite runtime is
       * coming from (e.g. middleware layers).
       */
      PREFER_SYSTEM_OVER_APPLICATION,
    }

    /** Specify where to get the TF Lite runtime implementation from. */
    public Options setRuntime(TfLiteRuntime runtime) {
      this.runtime = runtime;
      return this;
    }

    /** Return where to get the TF Lite runtime implementation from. */
    public TfLiteRuntime getRuntime() {
      return runtime;
    }

    TfLiteRuntime runtime = TfLiteRuntime.FROM_APPLICATION_ONLY;
    int numThreads = -1;
    Boolean useNNAPI;
    Boolean allowCancellation;

    // See InterpreterApi.Options#addDelegate.
    final List<Delegate> delegates;
    // See InterpreterApi.Options#addDelegateFactory.
    private final List<DelegateFactory> delegateFactories;
  }

  /**
   * Constructs an {@link InterpreterApi} instance, using the specified model and options. The model
   * will be loaded from a file.
   *
   * @param modelFile A file containing a pre-trained TF Lite model.
   * @param options A set of options for customizing interpreter behavior.
   * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
   *     model.
   */
  @SuppressWarnings("StaticOrDefaultInterfaceMethod")
  static InterpreterApi create(@NonNull File modelFile, InterpreterApi.Options options) {
    TfLiteRuntime runtime = (options == null ? null : options.getRuntime());
    InterpreterFactoryApi factory = TensorFlowLite.getFactory(runtime);
    return factory.create(modelFile, options);
  }

  /**
   * Constructs an {@link InterpreterApi} instance, using the specified model and options. The model
   * will be read from a {@code ByteBuffer}.
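   *
   * <p>For example, a memory-mapped model file can be passed in directly. A minimal sketch (the
   * file name and the {@code options}, {@code input}, and {@code output} objects are placeholders,
   * and error handling is omitted):
   *
   * <pre>{@code
   * try (FileInputStream fis = new FileInputStream(new File("model.tflite"));
   *     FileChannel channel = fis.getChannel()) {
   *   MappedByteBuffer modelBuffer =
   *       channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
   *   try (InterpreterApi interpreter = InterpreterApi.create(modelBuffer, options)) {
   *     interpreter.run(input, output);
   *   }
   * }
   * }</pre>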
   *
   * @param byteBuffer A pre-trained TF Lite model, in binary serialized form. The ByteBuffer should
   *     not be modified after the construction of an {@link InterpreterApi} instance. The {@code
   *     ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
   *     direct {@code ByteBuffer} of nativeOrder() that contains the byte content of a model.
   * @param options A set of options for customizing interpreter behavior.
   * @throws IllegalArgumentException if {@code byteBuffer} is neither a {@code MappedByteBuffer}
   *     nor a direct {@code ByteBuffer} of nativeOrder().
   */
  @SuppressWarnings("StaticOrDefaultInterfaceMethod")
  static InterpreterApi create(@NonNull ByteBuffer byteBuffer, InterpreterApi.Options options) {
    TfLiteRuntime runtime = (options == null ? null : options.getRuntime());
    InterpreterFactoryApi factory = TensorFlowLite.getFactory(runtime);
    return factory.create(byteBuffer, options);
  }

  /**
   * Runs model inference if the model takes only one input, and provides only one output.
   *
   * <p>Warning: The API is more efficient if a {@code Buffer} (preferably direct, but not required)
   * is used as the input/output data type. Please consider using {@code Buffer} to feed and fetch
   * primitive data for better performance. The following concrete {@code Buffer} types are
   * supported:
   *
   * <ul>
   *   <li>{@code ByteBuffer} - compatible with any underlying primitive Tensor type.
   *   <li>{@code FloatBuffer} - compatible with float Tensors.
   *   <li>{@code IntBuffer} - compatible with int32 Tensors.
   *   <li>{@code LongBuffer} - compatible with int64 Tensors.
   * </ul>
   *
   * Note that boolean types are only supported as arrays, not as {@code Buffer}s or scalar inputs.
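   *
   * <p>For example, feeding and fetching a single-input/single-output float model via {@code
   * FloatBuffer}s (a minimal sketch; {@code interpreter} denotes an already-created instance):
   *
   * <pre>{@code
   * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
   * // ... populate input, then reset its read position ...
   * input.rewind();
   * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
   * interpreter.run(input, output);
   * }</pre>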
   *
   * @param input an array or multidimensional array, or a {@code Buffer} of primitive types
   *     including int, float, long, and byte. {@code Buffer} is the preferred way to pass large
   *     input data for primitive types, whereas string types require using the (multi-dimensional)
   *     array input path. When a {@code Buffer} is used, its content should remain unchanged until
   *     model inference is done, and the caller must ensure that the {@code Buffer} is at the
   *     appropriate read position. A {@code null} value is allowed only if the caller is using a
   *     {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to the
   *     input {@link Tensor}.
   * @param output a multidimensional array of output data, or a {@code Buffer} of primitive types
   *     including int, float, long, and byte. When a {@code Buffer} is used, the caller must ensure
   *     that it is set to the appropriate write position. A null value is allowed, and is useful
   *     for certain cases, e.g., if the caller is using a {@link Delegate} that allows buffer
   *     handle interop, and such a buffer has been bound to the output {@link Tensor} (see also <a
   *     href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter.Options#setAllowBufferHandleOutput(boolean)">Interpreter.Options#setAllowBufferHandleOutput(boolean)</a>),
   *     or if the graph has dynamically shaped outputs and the caller must query the output {@link
   *     Tensor} shape after inference has been invoked, fetching the data directly from the output
   *     tensor (via {@link Tensor#asReadOnlyBuffer()}).
   * @throws IllegalArgumentException if {@code input} is null or empty, or if an error occurs when
   *     running inference.
   * @throws IllegalArgumentException (EXPERIMENTAL, subject to change) if the inference is
   *     interrupted by {@code setCancelled(true)}.
   */
  void run(Object input, Object output);

  /**
   * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
   *
   * <p>Warning: The API is more efficient if {@code Buffer}s (preferably direct, but not required)
   * are used as the input/output data types. Please consider using {@code Buffer} to feed and fetch
   * primitive data for better performance. The following concrete {@code Buffer} types are
   * supported:
   *
   * <ul>
   *   <li>{@code ByteBuffer} - compatible with any underlying primitive Tensor type.
   *   <li>{@code FloatBuffer} - compatible with float Tensors.
   *   <li>{@code IntBuffer} - compatible with int32 Tensors.
   *   <li>{@code LongBuffer} - compatible with int64 Tensors.
   * </ul>
   *
   * Note that boolean types are only supported as arrays, not as {@code Buffer}s or scalar inputs.
   *
   * <p>Note: {@code null} values for individual elements of {@code inputs} and {@code outputs} are
   * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
   * such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
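   *
   * <p>For example, running a model with two inputs and fetching only its first output (an
   * illustrative sketch; {@code interpreter} and the input objects are placeholders):
   *
   * <pre>{@code
   * Object[] inputs = {input0, input1};
   * Map<Integer, Object> outputs = new HashMap<>();
   * outputs.put(0, FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements()));
   * interpreter.runForMultipleInputsOutputs(inputs, outputs);
   * }</pre>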
   *
   * @param inputs an array of input data. The inputs should be in the same order as inputs of the
   *     model. Each input can be an array or multidimensional array, or a {@code Buffer} of
   *     primitive types including int, float, long, and byte. {@code Buffer} is the preferred way
   *     to pass large input data, whereas string types require using the (multi-dimensional) array
   *     input path. When {@code Buffer} is used, its content should remain unchanged until model
   *     inference is done, and the caller must ensure that the {@code Buffer} is at the appropriate
   *     read position.
   * @param outputs a map mapping output indices to multidimensional arrays of output data or {@code
   *     Buffer}s of primitive types including int, float, long, and byte. It only needs to keep
   *     entries for the outputs to be used. When a {@code Buffer} is used, the caller must ensure
   *     that it is set to the appropriate write position. The map may be empty for cases where
   *     either buffer handles are used for output tensor data, or cases where the outputs are
   *     dynamically shaped and the caller must query the output {@link Tensor} shape after
   *     inference has been invoked, fetching the data directly from the output tensor (via {@link
   *     Tensor#asReadOnlyBuffer()}).
   * @throws IllegalArgumentException if {@code inputs} is null or empty, if {@code outputs} is
   *     null, or if an error occurs when running inference.
   */
  void runForMultipleInputsOutputs(
      Object @NonNull [] inputs, @NonNull Map<Integer, Object> outputs);

  /**
   * Explicitly updates allocations for all tensors, if necessary.
   *
   * <p>This will propagate shapes and memory allocations for dependent tensors using the input
   * tensor shape(s) as given.
   *
   * <p>Note: This call is *purely optional*. Tensor allocation will occur automatically during
   * execution if any input tensors have been resized. This call is most useful in determining the
   * shapes for any output tensors before executing the graph, e.g.,
   *
   * <pre>{@code
   * interpreter.resizeInput(0, new int[] {1, 4, 4, 3});
   * interpreter.allocateTensors();
   * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
   * // Populate inputs...
   * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
   * interpreter.run(input, output);
   * // Process outputs...
   * }</pre>
   *
   * <p>Note: Some graphs have dynamically shaped outputs, in which case the output shape may not
   * fully propagate until inference is executed.
   *
   * @throws IllegalStateException if the graph's tensors could not be successfully allocated.
   */
  void allocateTensors();

  /**
   * Resizes idx-th input of the native model to the given dims.
   *
   * @throws IllegalArgumentException if {@code idx} is negative or is not smaller than the number
   *     of model inputs; or if an error occurs when resizing the idx-th input.
   */
  void resizeInput(int idx, @NonNull int[] dims);

  /**
   * Resizes idx-th input of the native model to the given dims.
   *
   * <p>When {@code strict} is true, only unknown dimensions can be resized. Unknown dimensions are
   * indicated as {@code -1} in the array returned by {@code Tensor.shapeSignature()}.
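   *
   * <p>For example, resizing only the batch dimension of an input whose shape signature is {@code
   * [-1, 224, 224, 3]} (an illustrative sketch; the index and shape are placeholders):
   *
   * <pre>{@code
   * // Only the -1 (unknown) dimension may change while strict is true.
   * interpreter.resizeInput(0, new int[] {4, 224, 224, 3}, true);
   * interpreter.allocateTensors();
   * }</pre>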
   *
   * @throws IllegalArgumentException if {@code idx} is negative or is not smaller than the number
   *     of model inputs; or if an error occurs when resizing the idx-th input. The error also
   *     occurs when attempting to resize a tensor with fixed dimensions when {@code strict} is
   *     true.
   */
  void resizeInput(int idx, @NonNull int[] dims, boolean strict);

  /** Gets the number of input tensors. */
  int getInputTensorCount();

  /**
   * Gets index of an input given the op name of the input.
   *
   * @throws IllegalArgumentException if {@code opName} does not match any input in the model used
   *     to initialize the interpreter.
   */
  int getInputIndex(String opName);

  /**
   * Gets the Tensor associated with the provided input index.
   *
   * @throws IllegalArgumentException if {@code inputIndex} is negative or is not smaller than the
   *     number of model inputs.
   */
  Tensor getInputTensor(int inputIndex);

  /** Gets the number of output Tensors. */
  int getOutputTensorCount();

  /**
   * Gets index of an output given the op name of the output.
   *
   * @throws IllegalArgumentException if {@code opName} does not match any output in the model used
   *     to initialize the interpreter.
   */
  int getOutputIndex(String opName);

  /**
   * Gets the Tensor associated with the provided output index.
   *
   * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
   * is executed. If you need updated details *before* running inference (e.g., after resizing an
   * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
   * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
   * that are dependent on input *values*, the output shape may not be fully determined until
   * running inference.
   *
   * @throws IllegalArgumentException if {@code outputIndex} is negative or is not smaller than the
   *     number of model outputs.
   */
  Tensor getOutputTensor(int outputIndex);

  /**
   * Returns native inference timing.
   *
   * @throws IllegalArgumentException if the model is not initialized by the interpreter.
   */
  Long getLastNativeInferenceDurationNanoseconds();

  /** Release resources associated with the {@code InterpreterApi} instance. */
  @Override
  void close();
}