/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.io.File;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;

/**
 * Driver class to drive model inference with TensorFlow Lite.
 *
 * <p>Note: If you don't need access to any of the "experimental" API features below, prefer to use
 * InterpreterApi and InterpreterFactory rather than using Interpreter directly.
 *
 * <p>An {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
 * are executed for model inference.
 *
 * <p>For example, if a model takes only one input and returns only one output:
 *
 * <pre>{@code
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>If a model takes multiple inputs or outputs:
 *
 * <pre>{@code
 * Object[] inputs = {input0, input1, ...};
 * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 * FloatBuffer ith_output =
 *     ByteBuffer.allocateDirect(3 * 2 * 4 * 4)  // Float tensor, shape 3x2x4 (4 bytes per float).
 *         .order(ByteOrder.nativeOrder())
 *         .asFloatBuffer();
 * map_of_indices_to_outputs.put(i, ith_output);
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 * }
 * }</pre>
 *
 * <p>If a model takes or produces string tensors:
 *
 * <pre>{@code
 * String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>Orders of inputs and outputs are determined when converting the TensorFlow model to the
 * TensorFlow Lite model with Toco, as are the default shapes of the inputs.
 *
 * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
 * be implicitly resized according to that array's shape. When inputs are provided as {@code Buffer}
 * types, no implicit resizing is done; the caller must either ensure that the {@code Buffer} byte
 * size matches that of the corresponding tensor, or first resize the tensor via {@link
 * #resizeInput(int, int[])}. Tensor shape and type information can be obtained via the {@link
 * Tensor} class, available via {@link #getInputTensor(int)} and {@link #getOutputTensor(int)}.
 *
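 * <p>For example, a minimal sketch of feeding a direct {@code ByteBuffer} after an explicit
 * resize (the shape and output size below are hypothetical):
 *
 * <pre>{@code
 * interpreter.resizeInput(0, new int[] {1, 224, 224, 3});
 * ByteBuffer input = ByteBuffer.allocateDirect(1 * 224 * 224 * 3 * 4);  // Float32, 4 bytes each.
 * input.order(ByteOrder.nativeOrder());
 * // ... write input data into the buffer ...
 * float[][] output = new float[1][1000];
 * interpreter.run(input, output);
 * }</pre>
 *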
 * <p><b>WARNING:</b> {@code Interpreter} instances are <b>not</b> thread-safe. An {@code
 * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}.
 *
 * <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
 * but this is not guaranteed.
 */
public final class Interpreter extends InterpreterImpl implements InterpreterApi {

  /**
   * An options class for controlling runtime interpreter behavior.
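   *
   * <p>For example, a minimal sketch of configuring an interpreter to use four threads
   * ({@code modelFile}, {@code input}, and {@code output} are placeholders):
   *
   * <pre>{@code
   * Interpreter.Options options = new Interpreter.Options().setNumThreads(4);
   * try (Interpreter interpreter = new Interpreter(modelFile, options)) {
   *   interpreter.run(input, output);
   * }
   * }</pre>
   */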
  public static class Options extends InterpreterImpl.Options {
    public Options() {}

    public Options(InterpreterApi.Options options) {
      super(options);
    }

    Options(InterpreterImpl.Options options) {
      super(options);
    }

    @Override
    public Options setNumThreads(int numThreads) {
      super.setNumThreads(numThreads);
      return this;
    }

    @Override
    public Options setUseNNAPI(boolean useNNAPI) {
      super.setUseNNAPI(useNNAPI);
      return this;
    }

    /**
     * Sets whether to allow float16 precision for FP32 calculations when possible. Defaults to
     * false (disallow).
     *
     * @deprecated Prefer using <a
     *     href="https://github.com/tensorflow/tensorflow/blob/5dc7f6981fdaf74c8c5be41f393df705841fb7c5/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java#L127">NnApiDelegate.Options#setAllowFp16(boolean
     *     enable)</a>.
     */
    @Deprecated
    public Options setAllowFp16PrecisionForFp32(boolean allow) {
      this.allowFp16PrecisionForFp32 = allow;
      return this;
    }

    @Override
    public Options addDelegate(Delegate delegate) {
      super.addDelegate(delegate);
      return this;
    }

    @Override
    public Options addDelegateFactory(DelegateFactory delegateFactory) {
      super.addDelegateFactory(delegateFactory);
      return this;
    }

    /**
     * Advanced: Set if buffer handle output is allowed.
     *
     * <p>When a {@link Delegate} supports hardware acceleration, the interpreter will make the data
     * of output tensors available in the CPU-allocated tensor buffers by default. If the client can
     * consume the buffer handle directly (e.g. reading output from OpenGL texture), it can set this
     * flag to true, so the interpreter avoids copying the data into the CPU buffer. The delegate
     * documentation should indicate whether this is supported and how it can be used.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options setAllowBufferHandleOutput(boolean allow) {
      this.allowBufferHandleOutput = allow;
      return this;
    }

    @Override
    public Options setCancellable(boolean allow) {
      super.setCancellable(allow);
      return this;
    }

    /**
     * Experimental: Enables or disables an optimized set of CPU kernels (provided by XNNPACK).
     *
     * <p>Setting this flag to false disables the use of a highly optimized set of CPU kernels
     * provided via the XNNPACK delegate. Currently, this is restricted to a subset of floating
     * point operations. See
     * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
     * for more details.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options setUseXNNPACK(boolean useXNNPACK) {
      this.useXNNPACK = useXNNPACK;
      return this;
    }

    @Override
    public Options setRuntime(InterpreterApi.Options.TfLiteRuntime runtime) {
      super.setRuntime(runtime);
      return this;
    }
  }

  /**
   * Initializes an {@code Interpreter}.
   *
   * @param modelFile a File of a pre-trained TF Lite model.
   * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
   *     model.
   */
  public Interpreter(@NonNull File modelFile) {
    this(modelFile, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} and specifies options for customizing interpreter behavior.
   *
   * @param modelFile a File of a pre-trained TF Lite model
   * @param options a set of options for customizing interpreter behavior
   * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
   *     model.
   */
  public Interpreter(@NonNull File modelFile, Options options) {
    this(new NativeInterpreterWrapperExperimental(modelFile.getAbsolutePath(), options));
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file.
   *
   * <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
   * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   *
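   * <p>For example, a minimal sketch of memory-mapping a model file (the file name is
   * hypothetical):
   *
   * <pre>{@code
   * try (FileChannel channel =
   *     FileChannel.open(Paths.get("model.tflite"), StandardOpenOption.READ)) {
   *   MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
   *   Interpreter interpreter = new Interpreter(buffer);
   *   // ... use the interpreter, then close() it when done ...
   * }
   * }</pre>
   *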
   * @throws IllegalArgumentException if {@code byteBuffer} is neither a {@code MappedByteBuffer}
   *     nor a direct {@code ByteBuffer} of nativeOrder.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of
   * custom {@link Interpreter.Options}.
   *
   * <p>The {@code ByteBuffer} should not be modified after the construction of an {@code
   * Interpreter}. The {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps
   * a model file, or a direct {@code ByteBuffer} of nativeOrder() that contains the bytes content
   * of a model.
   *
   * @throws IllegalArgumentException if {@code byteBuffer} is neither a {@code MappedByteBuffer}
   *     nor a direct {@code ByteBuffer} of nativeOrder.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
    this(new NativeInterpreterWrapperExperimental(byteBuffer, options));
  }

  private Interpreter(NativeInterpreterWrapperExperimental wrapper) {
    super(wrapper);
    wrapperExperimental = wrapper;
    signatureKeyList = getSignatureKeys();
  }

  /**
   * Runs model inference based on SignatureDef provided through {@code signatureKey}.
   *
   * <p>See {@link Interpreter#run(Object, Object)} for more details on the allowed input and output
   * data types.
   *
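   * <p>For example, a minimal sketch (the signature and tensor names here are hypothetical):
   *
   * <pre>{@code
   * Map<String, Object> inputs = new HashMap<>();
   * inputs.put("input_x", new float[][] {{1.0f, 2.0f}});
   * Map<String, Object> outputs = new HashMap<>();
   * outputs.put("output_y", new float[1][2]);
   * interpreter.runSignature(inputs, outputs, "my_signature");
   * }</pre>
   *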
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param inputs A map from input name in the SignatureDef to an input object.
   * @param outputs A map from output name in the SignatureDef to output data. This may be empty if
   *     the caller wishes to query the {@link Tensor} data directly after inference (e.g., if the
   *     output shape is dynamic, or output buffer handles are used).
   * @param signatureKey Signature key identifying the SignatureDef.
   * @throws IllegalArgumentException if {@code inputs} is null or empty, if {@code outputs} or
   *     {@code signatureKey} is null, or if an error occurs when running inference.
   */
  public void runSignature(
      @NonNull Map<String, Object> inputs,
      @NonNull Map<String, Object> outputs,
      String signatureKey) {
    checkNotClosed();
    if (signatureKey == null && signatureKeyList.length == 1) {
      signatureKey = signatureKeyList[0];
    }
    if (signatureKey == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureKeyList));
    }
    wrapper.runSignature(inputs, outputs, signatureKey);
  }

  /**
   * Same as {@link #runSignature(Map, Map, String)} but doesn't require passing a signatureKey,
   * assuming the model has one SignatureDef. If the model has more than one SignatureDef, it will
   * throw an exception.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public void runSignature(
      @NonNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs) {
    checkNotClosed();
    runSignature(inputs, outputs, null);
  }

  /**
   * Gets the Tensor associated with the provided input name and signature method name.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param inputName Input name in the signature.
   * @param signatureKey Signature key identifying the SignatureDef, can be null if the model has
   *     one signature.
   * @throws IllegalArgumentException if {@code inputName} or {@code signatureKey} is null or empty,
   *     or if an invalid name is provided.
   */
  public Tensor getInputTensorFromSignature(String inputName, String signatureKey) {
    checkNotClosed();
    if (signatureKey == null && signatureKeyList.length == 1) {
      signatureKey = signatureKeyList[0];
    }
    if (signatureKey == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureKeyList));
    }
    return wrapper.getInputTensor(inputName, signatureKey);
  }

  /**
   * Gets the list of SignatureDef exported method names available in the model.
   *
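   * <p>For example, a minimal sketch of enumerating a model's signatures and their inputs:
   *
   * <pre>{@code
   * for (String key : interpreter.getSignatureKeys()) {
   *   System.out.println(key + ": " + Arrays.toString(interpreter.getSignatureInputs(key)));
   * }
   * }</pre>
   *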
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public String[] getSignatureKeys() {
    checkNotClosed();
    return wrapper.getSignatureKeys();
  }

  /**
   * Gets the list of SignatureDef inputs for the method {@code signatureKey}.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public String[] getSignatureInputs(String signatureKey) {
    checkNotClosed();
    return wrapper.getSignatureInputs(signatureKey);
  }

  /**
   * Gets the list of SignatureDef outputs for the method {@code signatureKey}.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public String[] getSignatureOutputs(String signatureKey) {
    checkNotClosed();
    return wrapper.getSignatureOutputs(signatureKey);
  }

  /**
   * Gets the Tensor associated with the provided output name in a specific signature method.
   *
   * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
   * is executed. If you need updated details *before* running inference (e.g., after resizing an
   * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
   * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
   * that are dependent on input *values*, the output shape may not be fully determined until
   * running inference.
   *
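   * <p>For example, a minimal sketch of reading an output shape after an explicit allocation (the
   * signature and tensor names here are hypothetical):
   *
   * <pre>{@code
   * interpreter.allocateTensors();
   * int[] shape = interpreter.getOutputTensorFromSignature("logits", "serving_default").shape();
   * }</pre>
   *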
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param outputName Output name in the signature.
   * @param signatureKey Signature key identifying the SignatureDef, can be null if the model has
   *     one signature.
   * @throws IllegalArgumentException if {@code outputName} or {@code signatureKey} is null or
   *     empty, or if an invalid name is provided.
   */
  public Tensor getOutputTensorFromSignature(String outputName, String signatureKey) {
    checkNotClosed();
    if (signatureKey == null && signatureKeyList.length == 1) {
      signatureKey = signatureKeyList[0];
    }
    if (signatureKey == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureKeyList));
    }
    return wrapper.getOutputTensor(outputName, signatureKey);
  }

  /**
   * Advanced: Resets all variable tensors to the default value.
   *
   * <p>If a variable tensor doesn't have an associated buffer, it will be reset to zero.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public void resetVariableTensors() {
    checkNotClosed();
    wrapperExperimental.resetVariableTensors();
  }

  /**
   * Advanced: Interrupts inference in the middle of a call to {@link Interpreter#run}.
   *
   * <p>A cancellation flag will be set to true when this function gets called. The interpreter will
   * check the flag between Op invocations, and if it's {@code true}, the interpreter will stop
   * execution. The interpreter will remain in a cancelled state until explicitly "uncancelled" by
   * {@code setCancelled(false)}.
   *
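   * <p>For example, a minimal sketch of cancelling a long-running inference from another thread
   * (requires constructing the interpreter with {@code setCancellable(true)}):
   *
   * <pre>{@code
   * // On a worker thread:
   * interpreter.run(input, output);
   *
   * // On another thread, to abort the in-flight run():
   * interpreter.setCancelled(true);
   * }</pre>
   *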
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param cancelled {@code true} to cancel inference in a best-effort way; {@code false} to
   *     resume.
   * @throws IllegalStateException if the interpreter is not initialized with the cancellable
   *     option, which is by default off.
   * @see Interpreter.Options#setCancellable(boolean)
   */
  public void setCancelled(boolean cancelled) {
    wrapper.setCancelled(cancelled);
  }

  private final NativeInterpreterWrapperExperimental wrapperExperimental;
  private final String[] signatureKeyList;
}