• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.lite;
17 
18 import java.io.File;
19 import java.nio.ByteBuffer;
20 import java.nio.MappedByteBuffer;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23 import java.util.HashMap;
24 import java.util.List;
25 import java.util.Map;
26 // Use Android annotation instead
27 import android.support.annotation.NonNull;
28 // import org.checkerframework.checker.nullness.qual.NonNull;
29 
30 /**
31  * Driver class to drive model inference with TensorFlow Lite.
32  *
33  * <p>A {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
34  * are executed for model inference.
35  *
36  * <p>For example, if a model takes only one input and returns only one output:
37  *
38  * <pre>{@code
39  * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
40  *   interpreter.run(input, output);
41  * }
42  * }</pre>
43  *
44  * <p>If a model takes multiple inputs or outputs:
45  *
46  * <pre>{@code
47  * Object[] inputs = {input0, input1, ...};
48  * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
49  * FloatBuffer ith_output = FloatBuffer.allocateDirect(3 * 2 * 4);  // Float tensor, shape 3x2x4.
50  * ith_output.order(ByteOrder.nativeOrder());
51  * map_of_indices_to_outputs.put(i, ith_output);
52  * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
53  *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
54  * }
55  * }</pre>
56  *
57  * <p>If a model takes or produces string tensors:
58  *
59  * <pre>{@code
60  * String[] input = {"foo", "bar"};  // Input tensor shape is [2].
61  * String[] output = new String[3][2];  // Output tensor shape is [3, 2].
62  * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
63  *   interpreter.runForMultipleInputsOutputs(input, output);
64  * }
65  * }</pre>
66  *
67  * <p>Orders of inputs and outputs are determined when converting TensorFlow model to TensorFlowLite
68  * model with Toco, as are the default shapes of the inputs.
69  *
70  * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
71  * be implicitly resized according to that array's shape. When inputs are provided as {@link
72  * java.nio.Buffer} types, no implicit resizing is done; the caller must ensure that the {@link
73  * java.nio.Buffer} byte size either matches that of the corresponding tensor, or that they first
74  * resize the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be
75  * obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link
76  * #getOutputTensor(int)}.
77  *
78  * <p><b>WARNING:</b>Instances of a {@code Interpreter} is <b>not</b> thread-safe. A {@code
79  * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}
80  *
81  * <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
82  * but is not guaranteed.
83  *
84  * <p>Note: This class is not thread safe.
85  */
86 public final class Interpreter implements AutoCloseable {
87 
88   /** An options class for controlling runtime interpreter behavior. */
89   public static class Options {
Options()90     public Options() {}
91 
92     /**
93      * Sets the number of threads to be used for ops that support multi-threading. Defaults to a
94      * platform-dependent value.
95      */
setNumThreads(int numThreads)96     public Options setNumThreads(int numThreads) {
97       this.numThreads = numThreads;
98       return this;
99     }
100 
101     /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */
setUseNNAPI(boolean useNNAPI)102     public Options setUseNNAPI(boolean useNNAPI) {
103       this.useNNAPI = useNNAPI;
104       return this;
105     }
106 
107     /**
108      * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to false
109      * (disallow).
110      *
111      * @deprecated Prefer using {@link
112      *     org.tensorflow.lite.nnapi.NnApiDelegate.Options#setAllowFp16(boolean enable)}.
113      */
114     @Deprecated
setAllowFp16PrecisionForFp32(boolean allow)115     public Options setAllowFp16PrecisionForFp32(boolean allow) {
116       this.allowFp16PrecisionForFp32 = allow;
117       return this;
118     }
119 
120     /**
121      * Adds a {@link Delegate} to be applied during interpreter creation.
122      *
123      * <p>WARNING: This is an experimental interface that is subject to change.
124      */
addDelegate(Delegate delegate)125     public Options addDelegate(Delegate delegate) {
126       delegates.add(delegate);
127       return this;
128     }
129 
130     /**
131      * Advanced: Set if buffer handle output is allowed.
132      *
133      * <p>When a {@link Delegate} supports hardware acceleration, the interpreter will make the data
134      * of output tensors available in the CPU-allocated tensor buffers by default. If the client can
135      * consume the buffer handle directly (e.g. reading output from OpenGL texture), it can set this
136      * flag to false, avoiding the copy of data to the CPU buffer. The delegate documentation should
137      * indicate whether this is supported and how it can be used.
138      *
139      * <p>WARNING: This is an experimental interface that is subject to change.
140      */
setAllowBufferHandleOutput(boolean allow)141     public Options setAllowBufferHandleOutput(boolean allow) {
142       this.allowBufferHandleOutput = allow;
143       return this;
144     }
145 
146     /**
147      * Advanced: Set if the interpreter is able to be cancelled.
148      *
149      * @see {@link Interpreter#setCancelled(boolean)}.
150      */
setCancellable(boolean allow)151     public Options setCancellable(boolean allow) {
152       this.allowCancellation = allow;
153       return this;
154     }
155 
156     /**
157      * Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK).
158      *
159      * <p>Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
160      * via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
161      * operations. Eventually, we plan to enable this by default, as it can provide significant
162      * peformance benefits for many classes of floating point models. See
163      * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
164      * for more details.
165      *
166      * <p>Things to keep in mind when enabling this flag:
167      *
168      * <ul>
169      *   <li>Startup time and resize time may increase.
170      *   <li>Baseline memory consumption may increase.
171      *   <li>May be ignored if another delegate (eg NNAPI) have been applied.
172      *   <li>Quantized models will not see any benefit.
173      * </ul>
174      *
175      * <p>WARNING: This is an experimental interface that is subject to change.
176      */
setUseXNNPACK(boolean useXNNPACK)177     public Options setUseXNNPACK(boolean useXNNPACK) {
178       this.useXNNPACK = useXNNPACK;
179       return this;
180     }
181 
182     int numThreads = -1;
183     Boolean useNNAPI;
184     Boolean allowFp16PrecisionForFp32;
185     Boolean allowBufferHandleOutput;
186     Boolean allowCancellation;
187 
188     // TODO(b/171856982): update the comment when applying XNNPACK delegate by default is
189     // enabled for C++ TfLite library on Android platform.
190     // Note: the initial "null" value indicates default behavior which may mean XNNPACK
191     // delegate will be applied by default.
192     Boolean useXNNPACK;
193     final List<Delegate> delegates = new ArrayList<>();
194   }
195 
196   /**
197    * Initializes a {@code Interpreter}
198    *
199    * @param modelFile: a File of a pre-trained TF Lite model.
200    * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
201    *     model.
202    */
Interpreter(@onNull File modelFile)203   public Interpreter(@NonNull File modelFile) {
204     this(modelFile, /*options = */ null);
205   }
206 
207   /**
208    * Initializes a {@code Interpreter} and specifies the number of threads used for inference.
209    *
210    * @param modelFile: a file of a pre-trained TF Lite model
211    * @param numThreads: number of threads to use for inference
212    * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will
213    *     be removed in a future release.
214    */
215   @Deprecated
Interpreter(@onNull File modelFile, int numThreads)216   public Interpreter(@NonNull File modelFile, int numThreads) {
217     this(modelFile, new Options().setNumThreads(numThreads));
218   }
219 
220   /**
221    * Initializes a {@code Interpreter} and specifies the number of threads used for inference.
222    *
223    * @param modelFile: a file of a pre-trained TF Lite model
224    * @param options: a set of options for customizing interpreter behavior
225    * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
226    *     model.
227    */
Interpreter(@onNull File modelFile, Options options)228   public Interpreter(@NonNull File modelFile, Options options) {
229     wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
230     signatureNameList = getSignatureDefNames();
231   }
232 
233   /**
234    * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file.
235    *
236    * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
237    * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
238    * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
239    *
240    * @throws IllegalArgumentException if {@code byteBuffer} is not a {@link MappedByteBuffer} nor a
241    *     direct {@link Bytebuffer} of nativeOrder.
242    */
Interpreter(@onNull ByteBuffer byteBuffer)243   public Interpreter(@NonNull ByteBuffer byteBuffer) {
244     this(byteBuffer, /* options= */ null);
245   }
246 
247   /**
248    * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and specifies the
249    * number of threads used for inference.
250    *
251    * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
252    * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
253    * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
254    *
255    * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
256    *     will be removed in a future release.
257    */
258   @Deprecated
Interpreter(@onNull ByteBuffer byteBuffer, int numThreads)259   public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
260     this(byteBuffer, new Options().setNumThreads(numThreads));
261   }
262 
263   /**
264    * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
265    *
266    * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
267    * Interpreter}.
268    *
269    * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
270    *     will be removed in a future release.
271    */
272   @Deprecated
Interpreter(@onNull MappedByteBuffer mappedByteBuffer)273   public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
274     this(mappedByteBuffer, /* options= */ null);
275   }
276 
277   /**
278    * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
279    * {@link Interpreter.Options}.
280    *
281    * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
282    * {@code ByteBuffer} can be either a {@link MappedByteBuffer} that memory-maps a model file, or a
283    * direct {@link ByteBuffer} of nativeOrder() that contains the bytes content of a model.
284    *
285    * @throws IllegalArgumentException if {@code byteBuffer} is not a {@link MappedByteBuffer} nor a
286    *     direct {@link Bytebuffer} of nativeOrder.
287    */
Interpreter(@onNull ByteBuffer byteBuffer, Options options)288   public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
289     wrapper = new NativeInterpreterWrapper(byteBuffer, options);
290     signatureNameList = getSignatureDefNames();
291   }
292 
293   /**
294    * Runs model inference if the model takes only one input, and provides only one output.
295    *
296    * <p>Warning: The API is more efficient if a {@link java.nio.Buffer} (preferably direct, but not
297    * required) is used as the input/output data type. Please consider using {@link java.nio.Buffer}
298    * to feed and fetch primitive data for better performance. The following concrete {@link
299    * java.nio.Buffer} types are supported:
300    *
301    * <ul>
302    *   <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
303    *   <li>{@link java.nio.FloatBuffer} - compatible with float Tensors.
304    *   <li>{@link java.nio.IntBuffer} - compatible with int32 Tensors.
305    *   <li>{@link java.nio.LongBuffer} - compatible with int64 Tensors.
306    * </ul>
307    *
308    * Note that boolean types are only supported as arrays, not {@link java.nio.Buffer}s, or as
309    * scalar inputs.
310    *
311    * @param input an array or multidimensional array, or a {@link java.nio.Buffer} of primitive
312    *     types including int, float, long, and byte. {@link java.nio.Buffer} is the preferred way to
313    *     pass large input data for primitive types, whereas string types require using the
314    *     (multi-dimensional) array input path. When a {@link java.nio.Buffer} is used, its content
315    *     should remain unchanged until model inference is done, and the caller must ensure that the
316    *     {@link java.nio.Buffer} is at the appropriate read position. A {@code null} value is
317    *     allowed only if the caller is using a {@link Delegate} that allows buffer handle interop,
318    *     and such a buffer has been bound to the input {@link Tensor}.
319    * @param output a multidimensional array of output data, or a {@link java.nio.Buffer} of
320    *     primitive types including int, float, long, and byte. When a {@link java.nio.Buffer} is
321    *     used, the caller must ensure that it is set the appropriate write position. A null value is
322    *     allowed only if the caller is using a {@link Delegate} that allows buffer handle interop,
323    *     and such a buffer has been bound to the output {@link Tensor}. See {@link
324    *     Interpreter.Options#setAllowBufferHandleOutput(boolean)}.
325    * @throws IllegalArgumentException if {@code input} or {@code output} is null or empty, or if
326    *     error occurs when running the inference.
327    * @throws IllegalArgumentException (EXPERIMENTAL, subject to change) if the inference is
328    *     interrupted by {@code setCancelled(true)}.
329    */
run(Object input, Object output)330   public void run(Object input, Object output) {
331     Object[] inputs = {input};
332     Map<Integer, Object> outputs = new HashMap<>();
333     outputs.put(0, output);
334     runForMultipleInputsOutputs(inputs, outputs);
335   }
336 
337   /**
338    * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
339    *
340    * <p>Warning: The API is more efficient if {@link java.nio.Buffer}s (preferably direct, but not
341    * required) are used as the input/output data types. Please consider using {@link
342    * java.nio.Buffer} to feed and fetch primitive data for better performance. The following
343    * concrete {@link java.nio.Buffer} types are supported:
344    *
345    * <ul>
346    *   <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
347    *   <li>{@link java.nio.FloatBuffer} - compatible with float Tensors.
348    *   <li>{@link java.nio.IntBuffer} - compatible with int32 Tensors.
349    *   <li>{@link java.nio.LongBuffer} - compatible with int64 Tensors.
350    * </ul>
351    *
352    * Note that boolean types are only supported as arrays, not {@link java.nio.Buffer}s, or as
353    * scalar inputs.
354    *
355    * <p>Note: {@code null} values for invididual elements of {@code inputs} and {@code outputs} is
356    * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
357    * such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
358    *
359    * @param inputs an array of input data. The inputs should be in the same order as inputs of the
360    *     model. Each input can be an array or multidimensional array, or a {@link java.nio.Buffer}
361    *     of primitive types including int, float, long, and byte. {@link java.nio.Buffer} is the
362    *     preferred way to pass large input data, whereas string types require using the
363    *     (multi-dimensional) array input path. When {@link java.nio.Buffer} is used, its content
364    *     should remain unchanged until model inference is done, and the caller must ensure that the
365    *     {@link java.nio.Buffer} is at the appropriate read position.
366    * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
367    *     java.nio.Buffer}s of primitive types including int, float, long, and byte. It only needs to
368    *     keep entries for the outputs to be used. When a {@link java.nio.Buffer} is used, the caller
369    *     must ensure that it is set the appropriate write position.
370    * @throws IllegalArgumentException if {@code inputs} or {@code outputs} is null or empty, or if
371    *     error occurs when running the inference.
372    */
runForMultipleInputsOutputs( @onNull Object[] inputs, @NonNull Map<Integer, Object> outputs)373   public void runForMultipleInputsOutputs(
374       @NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
375     checkNotClosed();
376     wrapper.run(inputs, outputs);
377   }
378 
379   /**
380    * Runs model inference based on SignatureDef provided through @code methodName.
381    *
382    * <p>See {@link Interpreter#run(Object, Object)} for more details on the allowed input and output
383    * data types.
384    *
385    * @param inputs A Map of inputs from input name in the signatureDef to an input object.
386    * @param outputs a map mapping from output name in SignatureDef to output data.
387    * @param methodName The exported method name identifying the SignatureDef.
388    * @throws IllegalArgumentException if {@code inputs} or {@code outputs} or {@code methodName} is
389    *     null or empty, or if error occurs when running the inference.
390    *
391    * <p>WARNING: This is an experimental API and subject to change.
392    */
runSignature( @onNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs, String methodName)393   public void runSignature(
394       @NonNull Map<String, Object> inputs,
395       @NonNull Map<String, Object> outputs,
396       String methodName) {
397     checkNotClosed();
398     if (methodName == null && signatureNameList.length == 1) {
399       methodName = signatureNameList[0];
400     }
401     if (methodName == null) {
402       throw new IllegalArgumentException(
403           "Input error: SignatureDef methodName should not be null. null is only allowed if the"
404               + " model has a single Signature. Available Signatures: "
405               + Arrays.toString(signatureNameList));
406     }
407     wrapper.runSignature(inputs, outputs, methodName);
408   }
409 
410   /* Same as {@link Interpreter#runSignature(Object, Object, String)} but doesn't require
411    * passing a methodName, assuming the model has one SignatureDef. If the model has more than
412    * one SignatureDef it will throw an exception.
413    *
414    * * <p>WARNING: This is an experimental API and subject to change.
415    * */
runSignature( @onNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs)416   public void runSignature(
417       @NonNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs) {
418     checkNotClosed();
419     runSignature(inputs, outputs, null);
420   }
421 
422   /**
423    * Expicitly updates allocations for all tensors, if necessary.
424    *
425    * <p>This will propagate shapes and memory allocations for all dependent tensors using the input
426    * tensor shape(s) as given.
427    *
428    * <p>Note: This call is *purely optional*. Tensor allocation will occur automatically during
429    * execution if any input tensors have been resized. This call is most useful in determining the
430    * shapes for any output tensors before executing the graph, e.g.,
431    *
432    * <pre>{@code
433    * interpreter.resizeInput(0, new int[]{1, 4, 4, 3}));
434    * interpreter.allocateTensors();
435    * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0),numElements());
436    * // Populate inputs...
437    * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
438    * interpreter.run(input, output)
439    * // Process outputs...
440    * }</pre>
441    *
442    * @throws IllegalStateException if the graph's tensors could not be successfully allocated.
443    */
allocateTensors()444   public void allocateTensors() {
445     checkNotClosed();
446     wrapper.allocateTensors();
447   }
448 
449   /**
450    * Resizes idx-th input of the native model to the given dims.
451    *
452    * @throws IllegalArgumentException if {@code idx} is negtive or is not smaller than the number of
453    *     model inputs; or if error occurs when resizing the idx-th input.
454    */
resizeInput(int idx, @NonNull int[] dims)455   public void resizeInput(int idx, @NonNull int[] dims) {
456     checkNotClosed();
457     wrapper.resizeInput(idx, dims, false);
458   }
459 
460   /**
461    * Resizes idx-th input of the native model to the given dims.
462    *
463    * <p>When `strict` is True, only unknown dimensions can be resized. Unknown dimensions are
464    * indicated as `-1` in the array returned by `Tensor.shapeSignature()`.
465    *
466    * @throws IllegalArgumentException if {@code idx} is negtive or is not smaller than the number of
467    *     model inputs; or if error occurs when resizing the idx-th input. Additionally, the error
468    *     occurs when attempting to resize a tensor with fixed dimensions when `struct` is True.
469    */
resizeInput(int idx, @NonNull int[] dims, boolean strict)470   public void resizeInput(int idx, @NonNull int[] dims, boolean strict) {
471     checkNotClosed();
472     wrapper.resizeInput(idx, dims, strict);
473   }
474 
475   /** Gets the number of input tensors. */
getInputTensorCount()476   public int getInputTensorCount() {
477     checkNotClosed();
478     return wrapper.getInputTensorCount();
479   }
480 
481   /**
482    * Gets index of an input given the op name of the input.
483    *
484    * @throws IllegalArgumentException if {@code opName} does not match any input in the model used
485    *     to initialize the {@link Interpreter}.
486    */
getInputIndex(String opName)487   public int getInputIndex(String opName) {
488     checkNotClosed();
489     return wrapper.getInputIndex(opName);
490   }
491 
492   /**
493    * Gets the Tensor associated with the provdied input index.
494    *
495    * @throws IllegalArgumentException if {@code inputIndex} is negtive or is not smaller than the
496    *     number of model inputs.
497    */
getInputTensor(int inputIndex)498   public Tensor getInputTensor(int inputIndex) {
499     checkNotClosed();
500     return wrapper.getInputTensor(inputIndex);
501   }
502 
503   /**
504    * Gets the Tensor associated with the provdied input name and signature method name.
505    *
506    * @param inputName Input name in the signature.
507    * @param methodName The exported method name identifying the SignatureDef, can be null if the
508    *     model has one signature.
509    * @throws IllegalArgumentException if {@code inputName} or {@code methodName} is null or empty,
510    *     or invalid name provided.
511    *
512    * <p>WARNING: This is an experimental API and subject to change.
513    */
getInputTensorFromSignature(String inputName, String methodName)514   public Tensor getInputTensorFromSignature(String inputName, String methodName) {
515     checkNotClosed();
516     if (methodName == null && signatureNameList.length == 1) {
517       methodName = signatureNameList[0];
518     }
519     if (methodName == null) {
520       throw new IllegalArgumentException(
521           "Input error: SignatureDef methodName should not be null. null is only allowed if the"
522               + " model has a single Signature. Available Signatures: "
523               + Arrays.toString(signatureNameList));
524     }
525     return wrapper.getInputTensor(inputName, methodName);
526   }
527 
528   /**
529    * Gets the list of SignatureDef exported method names available in the model.
530    *
531    * <p>WARNING: This is an experimental API and subject to change.
532    */
getSignatureDefNames()533   public String[] getSignatureDefNames() {
534     checkNotClosed();
535     return wrapper.getSignatureDefNames();
536   }
537 
538   /**
539    * Gets the list of SignatureDefs inputs for method {@code methodName}
540    *
541    * <p>WARNING: This is an experimental API and subject to change.
542    */
getSignatureInputs(String methodName)543   public String[] getSignatureInputs(String methodName) {
544     checkNotClosed();
545     return wrapper.getSignatureInputs(methodName);
546   }
547 
548   /**
549    * Gets the list of SignatureDefs outputs for method {@code methodName}
550    *
551    * <p>WARNING: This is an experimental API and subject to change.
552    */
getSignatureOutputs(String methodName)553   public String[] getSignatureOutputs(String methodName) {
554     checkNotClosed();
555     return wrapper.getSignatureOutputs(methodName);
556   }
557 
558   /** Gets the number of output Tensors. */
getOutputTensorCount()559   public int getOutputTensorCount() {
560     checkNotClosed();
561     return wrapper.getOutputTensorCount();
562   }
563 
564   /**
565    * Gets index of an output given the op name of the output.
566    *
567    * @throws IllegalArgumentException if {@code opName} does not match any output in the model used
568    *     to initialize the {@link Interpreter}.
569    */
getOutputIndex(String opName)570   public int getOutputIndex(String opName) {
571     checkNotClosed();
572     return wrapper.getOutputIndex(opName);
573   }
574 
575   /**
576    * Gets the Tensor associated with the provdied output index.
577    *
578    * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
579    * is executed. If you need updated details *before* running inference (e.g., after resizing an
580    * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
581    * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
582    * that are dependent on input *values*, the output shape may not be fully determined until
583    * running inference.
584    *
585    * @throws IllegalArgumentException if {@code outputIndex} is negtive or is not smaller than the
586    *     number of model outputs.
587    */
getOutputTensor(int outputIndex)588   public Tensor getOutputTensor(int outputIndex) {
589     checkNotClosed();
590     return wrapper.getOutputTensor(outputIndex);
591   }
592 
593   /**
594    * Gets the Tensor associated with the provdied output name in specifc signature method.
595    *
596    * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
597    * is executed. If you need updated details *before* running inference (e.g., after resizing an
598    * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
599    * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
600    * that are dependent on input *values*, the output shape may not be fully determined until
601    * running inference.
602    *
603    * @param outputName Output name in the signature.
604    * @param methodName The exported method name identifying the SignatureDef, can be null if the
605    *     model has one signature.
606    * @throws IllegalArgumentException if {@code outputName} or {@code methodName} is null or empty,
607    *     or invalid name provided.
608    *
609    * <p>WARNING: This is an experimental API and subject to change.
610    */
getOutputTensorFromSignature(String outputName, String methodName)611   public Tensor getOutputTensorFromSignature(String outputName, String methodName) {
612     checkNotClosed();
613     if (methodName == null && signatureNameList.length == 1) {
614       methodName = signatureNameList[0];
615     }
616     if (methodName == null) {
617       throw new IllegalArgumentException(
618           "Input error: SignatureDef methodName should not be null. null is only allowed if the"
619               + " model has a single Signature. Available Signatures: "
620               + Arrays.toString(signatureNameList));
621     }
622     return wrapper.getOutputTensor(outputName, methodName);
623   }
624 
625   /**
626    * Returns native inference timing.
627    *
628    * @throws IllegalArgumentException if the model is not initialized by the {@link Interpreter}.
629    */
getLastNativeInferenceDurationNanoseconds()630   public Long getLastNativeInferenceDurationNanoseconds() {
631     checkNotClosed();
632     return wrapper.getLastNativeInferenceDurationNanoseconds();
633   }
634 
635   /**
636    * Sets the number of threads to be used for ops that support multi-threading.
637    *
638    * @deprecated Prefer using {@link Interpreter.Options#setNumThreads(int)} directly for
639    *     controlling thread multi-threading. This method will be removed in a future release.
640    */
641   @Deprecated
setNumThreads(int numThreads)642   public void setNumThreads(int numThreads) {
643     checkNotClosed();
644     wrapper.setNumThreads(numThreads);
645   }
646 
647   /**
648    * Advanced: Modifies the graph with the provided {@link Delegate}.
649    *
650    * @throws IllegalArgumentException if error occurs when modifying graph with {@code delegate}.
651    * @deprecated Prefer using {@link Interpreter.Options#addDelegate} to provide delegates at
652    *     creation time. This method will be removed in a future release.
653    */
654   @Deprecated
modifyGraphWithDelegate(Delegate delegate)655   public void modifyGraphWithDelegate(Delegate delegate) {
656     checkNotClosed();
657     wrapper.modifyGraphWithDelegate(delegate);
658   }
659 
660   /**
661    * Advanced: Resets all variable tensors to the default value.
662    *
663    * <p>If a variable tensor doesn't have an associated buffer, it will be reset to zero.
664    *
665    * <p>WARNING: This is an experimental API and subject to change.
666    */
resetVariableTensors()667   public void resetVariableTensors() {
668     checkNotClosed();
669     wrapper.resetVariableTensors();
670   }
671 
672   /**
673    * Advanced: Interrupts inference in the middle of a call to {@link Interpreter#run}.
674    *
675    * <p>A cancellation flag will be set to true when this function gets called. The interpreter will
676    * check the flag between Op invocations, and if it's {@code true}, the interpreter will stop
677    * execution. The interpreter will remain a cancelled state until explicitly "uncancelled" by
678    * {@code setCancelled(false)}.
679    *
680    * <p>WARNING: This is an experimental API and subject to change.
681    *
682    * @param cancelled {@code true} to cancel inference in a best-effort way; {@code false} to
683    *     resume.
684    * @throws IllegalStateException if the interpreter is not initialized with the cancellable
685    *     option, which is by default off.
686    * @see {@link Interpreter.Options#setCancellable(boolean)}.
687    */
setCancelled(boolean cancelled)688   public void setCancelled(boolean cancelled) {
689     wrapper.setCancelled(cancelled);
690   }
691 
getExecutionPlanLength()692   int getExecutionPlanLength() {
693     checkNotClosed();
694     return wrapper.getExecutionPlanLength();
695   }
696 
697   /** Release resources associated with the {@code Interpreter}. */
698   @Override
close()699   public void close() {
700     if (wrapper != null) {
701       wrapper.close();
702       wrapper = null;
703     }
704   }
705 
706   // for Object.finalize, see https://bugs.openjdk.java.net/browse/JDK-8165641
707   @SuppressWarnings("deprecation")
708   @Override
finalize()709   protected void finalize() throws Throwable {
710     try {
711       close();
712     } finally {
713       super.finalize();
714     }
715   }
716 
checkNotClosed()717   private void checkNotClosed() {
718     if (wrapper == null) {
719       throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
720     }
721   }
722 
723   NativeInterpreterWrapper wrapper;
724   String[] signatureNameList;
725 }
726