/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.io.File;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Map;
// Use Android annotation instead
import android.support.annotation.NonNull;
// import org.checkerframework.checker.nullness.qual.NonNull;

/**
 * Driver class to drive model inference with TensorFlow Lite.
 *
 * <p>Note: If you don't need access to any of the "experimental" API features below, prefer to use
 * InterpreterApi and InterpreterFactory rather than using Interpreter directly.
 *
 * <p>An {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
 * are executed for model inference.
 *
 * <p>For example, if a model takes only one input and returns only one output:
 *
 * <pre>{@code
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>If a model takes multiple inputs or outputs:
 *
 * <pre>{@code
 * Object[] inputs = {input0, input1, ...};
 * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 * FloatBuffer ith_output =
 *     ByteBuffer.allocateDirect(3 * 2 * 4 * 4)  // Float tensor, shape 3x2x4; 4 bytes per float.
 *         .order(ByteOrder.nativeOrder())
 *         .asFloatBuffer();
 * map_of_indices_to_outputs.put(i, ith_output);
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 * }
 * }</pre>
 *
 * <p>If a model takes or produces string tensors:
 *
 * <pre>{@code
 * String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>The order of inputs and outputs is determined when converting the TensorFlow model to the
 * TensorFlow Lite model with Toco, as are the default shapes of the inputs.
 *
 * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
 * be implicitly resized according to that array's shape. When inputs are provided as {@code Buffer}
 * types, no implicit resizing is done; the caller must ensure that the {@code Buffer} byte size
 * either matches that of the corresponding tensor, or that they first resize the tensor via {@link
 * #resizeInput(int, int[])}. Tensor shape and type information can be obtained via the {@link
 * Tensor} class, available via {@link #getInputTensor(int)} and {@link #getOutputTensor(int)}.
 *
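 * <p>For example, a sketch of feeding a direct {@code ByteBuffer} input after an explicit resize
 * (the input index and shape are illustrative):
 *
 * <pre>{@code
 * interpreter.resizeInput(0, new int[] {1, 224, 224, 3});
 * ByteBuffer input = ByteBuffer.allocateDirect(1 * 224 * 224 * 3 * 4)  // 4 bytes per float.
 *     .order(ByteOrder.nativeOrder());
 * // ... fill input ...
 * interpreter.run(input, output);
 * }</pre>
 *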
 * <p><b>WARNING:</b> {@code Interpreter} instances are <b>not</b> thread-safe. An {@code
 * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}.
 *
 * <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
 * but this is not guaranteed.
 */
public final class Interpreter extends InterpreterImpl implements InterpreterApi {

  /** An options class for controlling runtime interpreter behavior. */
  public static class Options extends InterpreterImpl.Options {
    public Options() {
    }

    public Options(InterpreterApi.Options options) {
      super(options);
    }

    Options(InterpreterImpl.Options options) {
      super(options);
    }

    @Override
    public Options setNumThreads(int numThreads) {
      super.setNumThreads(numThreads);
      return this;
    }

    @Override
    public Options setUseNNAPI(boolean useNNAPI) {
      super.setUseNNAPI(useNNAPI);
      return this;
    }

    /**
     * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to
     * false (disallow).
     *
     * @deprecated Prefer using <a
     *     href="https://github.com/tensorflow/tensorflow/blob/5dc7f6981fdaf74c8c5be41f393df705841fb7c5/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java#L127">NnApiDelegate.Options#setAllowFp16(boolean
     *     enable)</a>.
     */
    @Deprecated
    public Options setAllowFp16PrecisionForFp32(boolean allow) {
      this.allowFp16PrecisionForFp32 = allow;
      return this;
    }

    @Override
    public Options addDelegate(Delegate delegate) {
      super.addDelegate(delegate);
      return this;
    }

    /**
     * Advanced: Sets whether buffer handle output is allowed.
     *
     * <p>When a {@link Delegate} supports hardware acceleration, the interpreter will make the data
     * of output tensors available in the CPU-allocated tensor buffers by default. If the client can
     * consume the buffer handle directly (e.g. reading output from an OpenGL texture), it can set
     * this flag to true, avoiding the copy of data to the CPU buffer. The delegate documentation
     * should indicate whether this is supported and how it can be used.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options setAllowBufferHandleOutput(boolean allow) {
      this.allowBufferHandleOutput = allow;
      return this;
    }

    @Override
    public Options setCancellable(boolean allow) {
      super.setCancellable(allow);
      return this;
    }

    /**
     * Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK).
     *
     * <p>Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
     * via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
     * operations. Eventually, we plan to enable this by default, as it can provide significant
     * performance benefits for many classes of floating point models. See
     * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
     * for more details.
     *
     * <p>Things to keep in mind when enabling this flag (see the sketch after this list):
     *
     * <ul>
     *   <li>Startup time and resize time may increase.
     *   <li>Baseline memory consumption may increase.
     *   <li>May be ignored if another delegate (e.g. NNAPI) has been applied.
     *   <li>Quantized models will not see any benefit.
     * </ul>
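     *
     * <p>A minimal sketch of enabling the flag (the model file and I/O objects are as in the
     * class-level examples):
     *
     * <pre>{@code
     * Interpreter.Options options = new Interpreter.Options().setUseXNNPACK(true);
     * try (Interpreter interpreter = new Interpreter(modelFile, options)) {
     *   interpreter.run(input, output);
     * }
     * }</pre>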
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options setUseXNNPACK(boolean useXNNPACK) {
      this.useXNNPACK = useXNNPACK;
      return this;
    }
  }

  /**
   * Initializes an {@code Interpreter}.
   *
   * @param modelFile a File of a pre-trained TF Lite model.
   * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
   *     model.
   */
  public Interpreter(@NonNull File modelFile) {
    this(modelFile, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} and specifies options for customizing interpreter behavior.
   *
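   * <p>For example, a sketch of configuring the thread count and NNAPI use (the values are
   * illustrative):
   *
   * <pre>{@code
   * Interpreter.Options options = new Interpreter.Options()
   *     .setNumThreads(4)
   *     .setUseNNAPI(true);
   * Interpreter interpreter = new Interpreter(modelFile, options);
   * }</pre>
   *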
   * @param modelFile a file of a pre-trained TF Lite model
   * @param options a set of options for customizing interpreter behavior
   * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
   *     model.
   */
  public Interpreter(@NonNull File modelFile, Options options) {
    this(new NativeInterpreterWrapperExperimental(modelFile.getAbsolutePath(), options));
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file.
   *
   * <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
   * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   *
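   * <p>For example, a model file can be memory-mapped (a sketch; the path is illustrative):
   *
   * <pre>{@code
   * try (FileChannel channel =
   *     FileChannel.open(Paths.get("model.tflite"), StandardOpenOption.READ)) {
   *   MappedByteBuffer model = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
   *   Interpreter interpreter = new Interpreter(model);
   * }
   * }</pre>
   *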
   * @throws IllegalArgumentException if {@code byteBuffer} is not a {@code MappedByteBuffer} nor a
   *     direct {@code ByteBuffer} of nativeOrder.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of
   * custom {@link Interpreter.Options}.
   *
   * <p>The {@code ByteBuffer} should not be modified after the construction of an {@code
   * Interpreter}. The {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps
   * a model file, or a direct {@code ByteBuffer} of nativeOrder() that contains the bytes content
   * of a model.
   *
   * @throws IllegalArgumentException if {@code byteBuffer} is not a {@code MappedByteBuffer} nor a
   *     direct {@code ByteBuffer} of nativeOrder.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
    this(new NativeInterpreterWrapperExperimental(byteBuffer, options));
  }

  private Interpreter(NativeInterpreterWrapperExperimental wrapper) {
    super(wrapper);
    wrapperExperimental = wrapper;
    signatureKeyList = getSignatureKeys();
  }

  /**
   * Runs model inference based on SignatureDef provided through {@code signatureKey}.
   *
   * <p>See {@link Interpreter#run(Object, Object)} for more details on the allowed input and output
   * data types.
   *
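   * <p>For example, a sketch of running a named signature (the key {@code "my_signature"} and the
   * tensor names {@code "x"} and {@code "y"} are illustrative):
   *
   * <pre>{@code
   * Map<String, Object> inputs = new HashMap<>();
   * inputs.put("x", input);
   * Map<String, Object> outputs = new HashMap<>();
   * outputs.put("y", output);
   * interpreter.runSignature(inputs, outputs, "my_signature");
   * }</pre>
   *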
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param inputs A map from input name in the SignatureDef to an input object.
   * @param outputs A map from output name in the SignatureDef to output data. This may be empty if
   *     the caller wishes to query the {@link Tensor} data directly after inference (e.g., if the
   *     output shape is dynamic, or output buffer handles are used).
   * @param signatureKey Signature key identifying the SignatureDef.
   * @throws IllegalArgumentException if {@code inputs} is null or empty, if {@code outputs} or
   *     {@code signatureKey} is null, or if an error occurs when running inference.
   */
  public void runSignature(
      @NonNull Map<String, Object> inputs,
      @NonNull Map<String, Object> outputs,
      String signatureKey) {
    checkNotClosed();
    if (signatureKey == null && signatureKeyList.length == 1) {
      signatureKey = signatureKeyList[0];
    }
    if (signatureKey == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureKeyList));
    }
    wrapper.runSignature(inputs, outputs, signatureKey);
  }

  /**
   * Same as {@link #runSignature(Map, Map, String)} but doesn't require passing a signatureKey,
   * assuming the model has one SignatureDef. If the model has more than one SignatureDef it will
   * throw an exception.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public void runSignature(
      @NonNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs) {
    checkNotClosed();
    runSignature(inputs, outputs, null);
  }

  /**
   * Gets the Tensor associated with the provided input name and signature method name.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param inputName Input name in the signature.
   * @param signatureKey Signature key identifying the SignatureDef, can be null if the model has
   *     one signature.
   * @throws IllegalArgumentException if {@code inputName} or {@code signatureKey} is null or empty,
   *     or if an invalid name is provided.
   */
  public Tensor getInputTensorFromSignature(String inputName, String signatureKey) {
    checkNotClosed();
    if (signatureKey == null && signatureKeyList.length == 1) {
      signatureKey = signatureKeyList[0];
    }
    if (signatureKey == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureKeyList));
    }
    return wrapper.getInputTensor(inputName, signatureKey);
  }

  /**
   * Gets the list of SignatureDef exported method names available in the model.
   *
   * <p>WARNING: This is an experimental API and subject to change.
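   *
   * <p>For example, a sketch of listing each signature with its inputs (for illustration only):
   *
   * <pre>{@code
   * for (String key : interpreter.getSignatureKeys()) {
   *   System.out.println(key + ": " + Arrays.toString(interpreter.getSignatureInputs(key)));
   * }
   * }</pre>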
   */
  public String[] getSignatureKeys() {
    checkNotClosed();
    return wrapper.getSignatureKeys();
  }

  /**
   * Gets the list of SignatureDef inputs for method {@code signatureKey}.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public String[] getSignatureInputs(String signatureKey) {
    checkNotClosed();
    return wrapper.getSignatureInputs(signatureKey);
  }

  /**
   * Gets the list of SignatureDef outputs for method {@code signatureKey}.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public String[] getSignatureOutputs(String signatureKey) {
    checkNotClosed();
    return wrapper.getSignatureOutputs(signatureKey);
  }

  /**
   * Gets the Tensor associated with the provided output name in a specific signature method.
   *
   * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
   * is executed. If you need updated details *before* running inference (e.g., after resizing an
   * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
   * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
   * that are dependent on input *values*, the output shape may not be fully determined until
   * running inference.
   *
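   * <p>For example, a sketch of forcing shape propagation after a resize (the names are
   * illustrative):
   *
   * <pre>{@code
   * interpreter.resizeInput(0, new int[] {1, 224, 224, 3});
   * interpreter.allocateTensors();
   * Tensor output = interpreter.getOutputTensorFromSignature("y", "my_signature");
   * int[] shape = output.shape();  // Reflects the resized input.
   * }</pre>
   *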
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param outputName Output name in the signature.
   * @param signatureKey Signature key identifying the SignatureDef, can be null if the model has
   *     one signature.
   * @throws IllegalArgumentException if {@code outputName} or {@code signatureKey} is null or
   *     empty, or if an invalid name is provided.
   */
  public Tensor getOutputTensorFromSignature(String outputName, String signatureKey) {
    checkNotClosed();
    if (signatureKey == null && signatureKeyList.length == 1) {
      signatureKey = signatureKeyList[0];
    }
    if (signatureKey == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureKeyList));
    }
    return wrapper.getOutputTensor(outputName, signatureKey);
  }

  /**
   * Advanced: Resets all variable tensors to the default value.
   *
   * <p>If a variable tensor doesn't have an associated buffer, it will be reset to zero.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public void resetVariableTensors() {
    checkNotClosed();
    wrapperExperimental.resetVariableTensors();
  }

  /**
   * Advanced: Interrupts inference in the middle of a call to {@link Interpreter#run}.
   *
   * <p>A cancellation flag will be set to true when this function gets called. The interpreter will
   * check the flag between Op invocations, and if it's {@code true}, the interpreter will stop
   * execution. The interpreter will remain in a cancelled state until explicitly "uncancelled" by
   * {@code setCancelled(false)}.
   *
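   * <p>For example, a sketch of cancelling a long-running inference from another thread (the
   * executor and timeout are illustrative):
   *
   * <pre>{@code
   * Interpreter.Options options = new Interpreter.Options().setCancellable(true);
   * Interpreter interpreter = new Interpreter(modelFile, options);
   * ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
   * executor.schedule(() -> interpreter.setCancelled(true), 5, TimeUnit.SECONDS);
   * interpreter.run(input, output);  // Stops between ops once the flag is set.
   * }</pre>
   *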
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param cancelled {@code true} to cancel inference in a best-effort way; {@code false} to
   *     resume.
   * @throws IllegalStateException if the interpreter is not initialized with the cancellable
   *     option, which is by default off.
   * @see Interpreter.Options#setCancellable(boolean)
   */
  public void setCancelled(boolean cancelled) {
    wrapper.setCancelled(cancelled);
  }

  NativeInterpreterWrapperExperimental wrapperExperimental;
  String[] signatureKeyList;
}