/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.io.File;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;
import org.tensorflow.lite.nnapi.NnApiDelegate;

/**
 * Interface to TensorFlow Lite model interpreter, excluding experimental methods.
 *
 * <p>An {@code InterpreterApi} instance encapsulates a pre-trained TensorFlow Lite model, in which
 * operations are executed for model inference.
 *
 * <p>For example, if a model takes only one input and returns only one output:
 *
 * <pre>{@code
 * try (InterpreterApi interpreter =
 *     InterpreterApi.create(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>If a model takes multiple inputs or outputs:
 *
 * <pre>{@code
 * Object[] inputs = {input0, input1, ...};
 * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 * FloatBuffer ith_output =
 *     ByteBuffer.allocateDirect(3 * 2 * 4 * 4)  // Float tensor, shape 3x2x4; 4 bytes per float.
 *         .order(ByteOrder.nativeOrder())
 *         .asFloatBuffer();
 * map_of_indices_to_outputs.put(i, ith_output);
 * try (InterpreterApi interpreter =
 *     InterpreterApi.create(file_of_a_tensorflowlite_model)) {
 *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 * }
 * }</pre>
 *
 * <p>If a model takes or produces string tensors:
 *
 * <pre>{@code
 * String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 * try (InterpreterApi interpreter =
 *     InterpreterApi.create(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>The order of inputs and outputs is determined when converting the TensorFlow model to the
 * TensorFlow Lite model with Toco, as are the default shapes of the inputs.
 *
 * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
 * be implicitly resized according to that array's shape. When inputs are provided as {@link
 * java.nio.Buffer} types, no implicit resizing is done; the caller must ensure that the {@link
 * java.nio.Buffer} byte size either matches that of the corresponding tensor, or that they first
 * resize the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be
 * obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link
 * #getOutputTensor(int)}.
 *
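 * <p>For instance, a minimal sketch of querying input tensor details and resizing an input before
 * inference (the index and shapes here are illustrative):
 *
 * <pre>{@code
 * Tensor inputTensor = interpreter.getInputTensor(0);
 * int[] shape = inputTensor.shape();        // e.g. [1, 224, 224, 3]
 * DataType type = inputTensor.dataType();   // e.g. DataType.FLOAT32
 * interpreter.resizeInput(0, new int[] {2, 224, 224, 3});  // resize to a batch of 2
 * }</pre>
 *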
 * <p><b>WARNING:</b> {@code InterpreterApi} instances are <b>not</b> thread-safe.
 *
 * <p><b>WARNING:</b> An {@code InterpreterApi} instance owns resources that <b>must</b> be
 * explicitly freed by invoking {@link #close()}.
 *
 * <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
 * but this is not guaranteed.
 */
public interface InterpreterApi extends AutoCloseable {

  /** An options class for controlling runtime interpreter behavior. */
  class Options {

    public Options() {
      this.delegates = new ArrayList<>();
      this.delegateFactories = new ArrayList<>();
    }

    public Options(Options other) {
      this.numThreads = other.numThreads;
      this.useNNAPI = other.useNNAPI;
      this.allowCancellation = other.allowCancellation;
      this.delegates = new ArrayList<>(other.delegates);
      this.delegateFactories = new ArrayList<>(other.delegateFactories);
      this.runtime = other.runtime;
    }

    /**
     * Sets the number of threads to be used for ops that support multi-threading.
     *
     * <p>{@code numThreads} should be {@code >= -1}. Setting {@code numThreads} to 0 has the effect
     * of disabling multithreading, which is equivalent to setting {@code numThreads} to 1. If
     * unspecified, or set to the value -1, the number of threads used will be
     * implementation-defined and platform-dependent.
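     *
     * <p>For example (a minimal sketch):
     *
     * <pre>{@code
     * InterpreterApi.Options options = new InterpreterApi.Options().setNumThreads(4);
     * }</pre>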
     */
    public Options setNumThreads(int numThreads) {
      this.numThreads = numThreads;
      return this;
    }

    /**
     * Returns the number of threads to be used for ops that support multi-threading.
     *
     * <p>{@code numThreads} should be {@code >= -1}. Values of 0 (or 1) disable multithreading.
     * Default value is -1: the number of threads used will be implementation-defined and
     * platform-dependent.
     */
    public int getNumThreads() {
      return numThreads;
    }

    /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */
    public Options setUseNNAPI(boolean useNNAPI) {
      this.useNNAPI = useNNAPI;
      return this;
    }

    /**
     * Returns whether to use NN API (if available) for op execution. Default value is false
     * (disabled).
     */
    public boolean getUseNNAPI() {
      return useNNAPI != null && useNNAPI;
    }

    /**
     * Advanced: Set if the interpreter is able to be cancelled.
     *
     * <p>Interpreters may have an experimental API <a
     * href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter#setCancelled(boolean)">setCancelled(boolean)</a>.
     * If this interpreter is cancellable and such a method is invoked, a cancellation flag will be
     * set to true. The interpreter will check the flag between Op invocations, and if it's {@code
     * true}, the interpreter will stop execution. The interpreter will remain in a cancelled state
     * until explicitly "uncancelled" by {@code setCancelled(false)}.
     */
    public Options setCancellable(boolean allow) {
      this.allowCancellation = allow;
      return this;
    }

    /**
     * Advanced: Returns whether the interpreter is able to be cancelled.
     *
     * <p>Interpreters may have an experimental API <a
     * href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter#setCancelled(boolean)">setCancelled(boolean)</a>.
     * If this interpreter is cancellable and such a method is invoked, a cancellation flag will be
     * set to true. The interpreter will check the flag between Op invocations, and if it's {@code
     * true}, the interpreter will stop execution. The interpreter will remain in a cancelled state
     * until explicitly "uncancelled" by {@code setCancelled(false)}.
     */
    public boolean isCancellable() {
      return allowCancellation != null && allowCancellation;
    }

    /**
     * Adds a {@link Delegate} to be applied during interpreter creation.
     *
     * <p>Delegates added here are applied before any delegates created from a {@link
     * DelegateFactory} that was added with {@link #addDelegateFactory}.
     *
     * <p>Note that TF Lite in Google Play Services (see {@link #setRuntime}) does not support
     * external (developer-provided) delegates, and adding a {@link Delegate} other than {@link
     * NnApiDelegate} here is not allowed when using TF Lite in Google Play Services.
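     *
     * <p>For example, a minimal sketch that registers an {@link NnApiDelegate}:
     *
     * <pre>{@code
     * InterpreterApi.Options options = new InterpreterApi.Options().addDelegate(new NnApiDelegate());
     * }</pre>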
     */
    public Options addDelegate(Delegate delegate) {
      delegates.add(delegate);
      return this;
    }

    /**
     * Returns the list of delegates intended to be applied during interpreter creation that have
     * been registered via {@code addDelegate}.
     */
    public List<Delegate> getDelegates() {
      return Collections.unmodifiableList(delegates);
    }

    /**
     * Adds a {@link DelegateFactory} which will be invoked to apply its created {@link Delegate}
     * during interpreter creation.
     *
     * <p>Delegates from a delegate factory that was added here are applied after any delegates
     * added with {@link #addDelegate}.
     */
    public Options addDelegateFactory(DelegateFactory delegateFactory) {
      delegateFactories.add(delegateFactory);
      return this;
    }

    /**
     * Returns the list of delegate factories that have been registered via {@code
     * addDelegateFactory}.
     */
    public List<DelegateFactory> getDelegateFactories() {
      return Collections.unmodifiableList(delegateFactories);
    }

    /**
     * Enum to represent where to get the TensorFlow Lite runtime implementation from.
     *
     * <p>The difference between this class and the RuntimeFlavor class: This class specifies a
     * <em>preference</em> for which runtime to use, whereas {@link RuntimeFlavor} specifies which
     * exact runtime <em>is</em> being used.
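     *
     * <p>For example, a minimal sketch of expressing a runtime preference (middleware might prefer
     * whichever runtime is available):
     *
     * <pre>{@code
     * InterpreterApi.Options options =
     *     new InterpreterApi.Options().setRuntime(TfLiteRuntime.PREFER_SYSTEM_OVER_APPLICATION);
     * }</pre>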
     */
    public enum TfLiteRuntime {
      /**
       * Use a TF Lite runtime implementation that is linked into the application. If there is no
       * suitable TF Lite runtime implementation linked into the application, then attempting to
       * create an InterpreterApi instance with this TfLiteRuntime setting will throw an
       * IllegalStateException (even if the OS or system services could provide a TF Lite runtime
       * implementation).
       *
       * <p>This is the default setting. This setting is also appropriate for apps that must run on
       * systems that don't provide a TF Lite runtime implementation.
       */
      FROM_APPLICATION_ONLY,

      /**
       * Use a TF Lite runtime implementation provided by the OS or system services. This will be
       * obtained from a system library / shared object / service, such as Google Play Services. It
       * may be newer than the version linked into the application (if any). If there is no suitable
       * TF Lite runtime implementation provided by the system, then attempting to create an
       * InterpreterApi instance with this TfLiteRuntime setting will throw an IllegalStateException
       * (even if there is a TF Lite runtime implementation linked into the application).
       *
       * <p>This setting is appropriate for code that will use a system-provided TF Lite runtime,
       * which can reduce app binary size and can be updated more frequently.
       */
      FROM_SYSTEM_ONLY,

      /**
       * Use a system-provided TF Lite runtime implementation, if any, otherwise use the TF Lite
       * runtime implementation linked into the application, if any. If no suitable TF Lite runtime
       * can be found in any location, then attempting to create an InterpreterApi instance with
       * this TfLiteRuntime setting will throw an IllegalStateException. If there is both a suitable
       * TF Lite runtime linked into the application and also a suitable TF Lite runtime provided by
       * the system, the one provided by the system will be used.
       *
       * <p>This setting is suitable for use in code that doesn't care where the TF Lite runtime is
       * coming from (e.g. middleware layers).
       */
      PREFER_SYSTEM_OVER_APPLICATION,
    }

    /** Specifies where to get the TF Lite runtime implementation from. */
    public Options setRuntime(TfLiteRuntime runtime) {
      this.runtime = runtime;
      return this;
    }

    /** Returns where to get the TF Lite runtime implementation from. */
    public TfLiteRuntime getRuntime() {
      return runtime;
    }

    TfLiteRuntime runtime = TfLiteRuntime.FROM_APPLICATION_ONLY;
    int numThreads = -1;
    Boolean useNNAPI;
    Boolean allowCancellation;

    // See InterpreterApi.Options#addDelegate.
    final List<Delegate> delegates;
    // See InterpreterApi.Options#addDelegateFactory.
    private final List<DelegateFactory> delegateFactories;
  }

  /**
   * Constructs an {@link InterpreterApi} instance, using the specified model and options. The model
   * will be loaded from a file.
   *
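   * <p>For example, a minimal sketch (the file name, thread count, and tensors are illustrative):
   *
   * <pre>{@code
   * InterpreterApi.Options options = new InterpreterApi.Options().setNumThreads(2);
   * try (InterpreterApi interpreter = InterpreterApi.create(new File("model.tflite"), options)) {
   *   interpreter.run(input, output);
   * }
   * }</pre>
   *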
   * @param modelFile A file containing a pre-trained TF Lite model.
   * @param options A set of options for customizing interpreter behavior.
   * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
   *     model.
   */
  @SuppressWarnings("StaticOrDefaultInterfaceMethod")
  static InterpreterApi create(@NonNull File modelFile, InterpreterApi.Options options) {
    TfLiteRuntime runtime = (options == null ? null : options.getRuntime());
    InterpreterFactoryApi factory = TensorFlowLite.getFactory(runtime);
    return factory.create(modelFile, options);
  }

  /**
   * Constructs an {@link InterpreterApi} instance, using the specified model and options. The model
   * will be read from a {@code ByteBuffer}.
   *
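   * <p>For example, a minimal sketch of memory-mapping a model file (the path is illustrative):
   *
   * <pre>{@code
   * try (FileChannel channel =
   *     FileChannel.open(Paths.get("model.tflite"), StandardOpenOption.READ)) {
   *   MappedByteBuffer model = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
   *   try (InterpreterApi interpreter = InterpreterApi.create(model, new InterpreterApi.Options())) {
   *     interpreter.run(input, output);
   *   }
   * }
   * }</pre>
   *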
   * @param byteBuffer A pre-trained TF Lite model, in binary serialized form. The ByteBuffer should
   *     not be modified after the construction of an {@link InterpreterApi} instance. The {@code
   *     ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
   *     direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   * @param options A set of options for customizing interpreter behavior.
   * @throws IllegalArgumentException if {@code byteBuffer} is not a {@code MappedByteBuffer} nor a
   *     direct {@code ByteBuffer} of nativeOrder.
   */
  @SuppressWarnings("StaticOrDefaultInterfaceMethod")
  static InterpreterApi create(@NonNull ByteBuffer byteBuffer, InterpreterApi.Options options) {
    TfLiteRuntime runtime = (options == null ? null : options.getRuntime());
    InterpreterFactoryApi factory = TensorFlowLite.getFactory(runtime);
    return factory.create(byteBuffer, options);
  }

  /**
   * Runs model inference if the model takes only one input, and provides only one output.
   *
   * <p>Warning: The API is more efficient if a {@code Buffer} (preferably direct, but not required)
   * is used as the input/output data type. Please consider using {@code Buffer} to feed and fetch
   * primitive data for better performance. The following concrete {@code Buffer} types are
   * supported:
   *
   * <ul>
   *   <li>{@code ByteBuffer} - compatible with any underlying primitive Tensor type.
   *   <li>{@code FloatBuffer} - compatible with float Tensors.
   *   <li>{@code IntBuffer} - compatible with int32 Tensors.
   *   <li>{@code LongBuffer} - compatible with int64 Tensors.
   * </ul>
   *
   * Note that boolean types are only supported as arrays, not {@code Buffer}s, or as scalar inputs.
   *
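   * <p>For example, a minimal sketch assuming a model with one float input of shape [1, 4] and one
   * float output of shape [1, 2] (the shapes are illustrative):
   *
   * <pre>{@code
   * float[][] input = new float[1][4];
   * float[][] output = new float[1][2];
   * // Populate input...
   * interpreter.run(input, output);
   * }</pre>
   *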
   * @param input an array or multidimensional array, or a {@code Buffer} of primitive types
   *     including int, float, long, and byte. {@code Buffer} is the preferred way to pass large
   *     input data for primitive types, whereas string types require using the (multi-dimensional)
   *     array input path. When a {@code Buffer} is used, its content should remain unchanged until
   *     model inference is done, and the caller must ensure that the {@code Buffer} is at the
   *     appropriate read position. A {@code null} value is allowed only if the caller is using a
   *     {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to the
   *     input {@link Tensor}.
   * @param output a multidimensional array of output data, or a {@code Buffer} of primitive types
   *     including int, float, long, and byte. When a {@code Buffer} is used, the caller must ensure
   *     that it is set to the appropriate write position. A null value is allowed, and is useful
   *     for certain cases, e.g., if the caller is using a {@link Delegate} that allows buffer handle
   *     interop, and such a buffer has been bound to the output {@link Tensor} (see also <a
   *     href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter.Options#setAllowBufferHandleOutput(boolean)">Interpreter.Options#setAllowBufferHandleOutput(boolean)</a>),
   *     or if the graph has dynamically shaped outputs and the caller must query the output {@link
   *     Tensor} shape after inference has been invoked, fetching the data directly from the output
   *     tensor (via {@link Tensor#asReadOnlyBuffer()}).
   * @throws IllegalArgumentException if {@code input} is null or empty, or if an error occurs when
   *     running inference.
   * @throws IllegalArgumentException (EXPERIMENTAL, subject to change) if the inference is
   *     interrupted by {@code setCancelled(true)}.
   */
  void run(Object input, Object output);

  /**
   * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
   *
   * <p>Warning: The API is more efficient if {@code Buffer}s (preferably direct, but not required)
   * are used as the input/output data types. Please consider using {@code Buffer} to feed and fetch
   * primitive data for better performance. The following concrete {@code Buffer} types are
   * supported:
   *
   * <ul>
   *   <li>{@code ByteBuffer} - compatible with any underlying primitive Tensor type.
   *   <li>{@code FloatBuffer} - compatible with float Tensors.
   *   <li>{@code IntBuffer} - compatible with int32 Tensors.
   *   <li>{@code LongBuffer} - compatible with int64 Tensors.
   * </ul>
   *
   * Note that boolean types are only supported as arrays, not {@code Buffer}s, or as scalar inputs.
   *
   * <p>Note: {@code null} values for individual elements of {@code inputs} and {@code outputs} are
   * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
   * such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
   *
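   * <p>For example, a minimal sketch for a model whose outputs are dynamically shaped: pass an
   * empty output map and read the results back from the output tensor after inference (the names
   * here are illustrative):
   *
   * <pre>{@code
   * Object[] inputs = {input0, input1};
   * Map<Integer, Object> outputs = new HashMap<>();
   * interpreter.runForMultipleInputsOutputs(inputs, outputs);
   * int[] outputShape = interpreter.getOutputTensor(0).shape();
   * ByteBuffer outputData = interpreter.getOutputTensor(0).asReadOnlyBuffer();
   * }</pre>
   *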
   * @param inputs an array of input data. The inputs should be in the same order as inputs of the
   *     model. Each input can be an array or multidimensional array, or a {@code Buffer} of
   *     primitive types including int, float, long, and byte. {@code Buffer} is the preferred way
   *     to pass large input data, whereas string types require using the (multi-dimensional) array
   *     input path. When {@code Buffer} is used, its content should remain unchanged until model
   *     inference is done, and the caller must ensure that the {@code Buffer} is at the appropriate
   *     read position.
   * @param outputs a map mapping output indices to multidimensional arrays of output data or {@code
   *     Buffer}s of primitive types including int, float, long, and byte. It only needs to keep
   *     entries for the outputs to be used. When a {@code Buffer} is used, the caller must ensure
   *     that it is set to the appropriate write position. The map may be empty for cases where
   *     either buffer handles are used for output tensor data, or cases where the outputs are
   *     dynamically shaped and the caller must query the output {@link Tensor} shape after
   *     inference has been invoked, fetching the data directly from the output tensor (via {@link
   *     Tensor#asReadOnlyBuffer()}).
   * @throws IllegalArgumentException if {@code inputs} is null or empty, if {@code outputs} is
   *     null, or if an error occurs when running inference.
   */
  void runForMultipleInputsOutputs(
      Object @NonNull [] inputs, @NonNull Map<Integer, Object> outputs);

  /**
   * Explicitly updates allocations for all tensors, if necessary.
   *
   * <p>This will propagate shapes and memory allocations for dependent tensors using the input
   * tensor shape(s) as given.
   *
   * <p>Note: This call is *purely optional*. Tensor allocation will occur automatically during
   * execution if any input tensors have been resized. This call is most useful in determining the
   * shapes for any output tensors before executing the graph, e.g.,
   *
   * <pre>{@code
   * interpreter.resizeInput(0, new int[]{1, 4, 4, 3});
   * interpreter.allocateTensors();
   * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
   * // Populate inputs...
   * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
   * interpreter.run(input, output);
   * // Process outputs...
   * }</pre>
   *
   * <p>Note: Some graphs have dynamically shaped outputs, in which case the output shape may not
   * fully propagate until inference is executed.
   *
   * @throws IllegalStateException if the graph's tensors could not be successfully allocated.
   */
  void allocateTensors();

  /**
   * Resizes the idx-th input of the native model to the given dims.
   *
   * @throws IllegalArgumentException if {@code idx} is negative or is not smaller than the number
   *     of model inputs; or if an error occurs when resizing the idx-th input.
   */
  void resizeInput(int idx, @NonNull int[] dims);

  /**
   * Resizes the idx-th input of the native model to the given dims.
   *
   * <p>When {@code strict} is true, only unknown dimensions can be resized. Unknown dimensions are
   * indicated as {@code -1} in the array returned by {@code Tensor.shapeSignature()}.
   *
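   * <p>For example, a minimal sketch that resizes only a dimension reported as unknown (the shape
   * is illustrative):
   *
   * <pre>{@code
   * int[] signature = interpreter.getInputTensor(0).shapeSignature();  // e.g. [-1, 224, 224, 3]
   * interpreter.resizeInput(0, new int[] {4, 224, 224, 3}, true);      // only the -1 dim changes
   * }</pre>
   *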
   * @throws IllegalArgumentException if {@code idx} is negative or is not smaller than the number
   *     of model inputs; or if an error occurs when resizing the idx-th input. The error also
   *     occurs when attempting to resize a tensor with fixed dimensions while {@code strict} is
   *     true.
   */
  void resizeInput(int idx, @NonNull int[] dims, boolean strict);

  /** Gets the number of input tensors. */
  int getInputTensorCount();

  /**
   * Gets the index of an input given the op name of the input.
   *
   * @throws IllegalArgumentException if {@code opName} does not match any input in the model used
   *     to initialize the interpreter.
   */
  int getInputIndex(String opName);

  /**
   * Gets the Tensor associated with the provided input index.
   *
   * @throws IllegalArgumentException if {@code inputIndex} is negative or is not smaller than the
   *     number of model inputs.
   */
  Tensor getInputTensor(int inputIndex);

  /** Gets the number of output Tensors. */
  int getOutputTensorCount();

  /**
   * Gets the index of an output given the op name of the output.
   *
   * @throws IllegalArgumentException if {@code opName} does not match any output in the model used
   *     to initialize the interpreter.
   */
  int getOutputIndex(String opName);

  /**
   * Gets the Tensor associated with the provided output index.
   *
   * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
   * is executed. If you need updated details *before* running inference (e.g., after resizing an
   * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
   * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
   * that are dependent on input *values*, the output shape may not be fully determined until
   * running inference.
   *
   * @throws IllegalArgumentException if {@code outputIndex} is negative or is not smaller than the
   *     number of model outputs.
   */
  Tensor getOutputTensor(int outputIndex);

  /**
   * Returns native inference timing.
   *
   * @throws IllegalArgumentException if the model is not initialized by the interpreter.
   */
  Long getLastNativeInferenceDurationNanoseconds();

  /** Releases resources associated with the {@code InterpreterApi} instance. */
  @Override
  void close();
}