/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;

/**
 * Driver class to drive model inference with TensorFlow Lite.
 *
 * <p>An {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
 * are executed for model inference.
 *
 * <p>For example, if a model takes only one input and returns only one output:
 *
 * <pre>{@code
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>If a model takes multiple inputs or outputs:
 *
 * <pre>{@code
 * Object[] inputs = {input0, input1, ...};
 * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 * ByteBuffer ith_output = ByteBuffer.allocateDirect(3 * 2 * 4 * 4);  // Float tensor, shape 3x2x4.
 * ith_output.order(ByteOrder.nativeOrder());
 * map_of_indices_to_outputs.put(i, ith_output);
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 * }
 * }</pre>
 *
 * <p>If a model takes or produces string tensors:
 *
 * <pre>{@code
 * String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>The orders of inputs and outputs are determined when converting a TensorFlow model to a
 * TensorFlow Lite model with Toco, as are the default shapes of the inputs.
 *
 * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
 * be implicitly resized according to that array's shape. When inputs are provided as {@link
 * ByteBuffer} types, no implicit resizing is done; the caller must ensure that the {@link
 * ByteBuffer} byte size either matches that of the corresponding tensor, or that they first resize
 * the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be
 * obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link
 * #getOutputTensor(int)}.
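 *
 * <p>For example, a hedged sketch of sizing a {@link ByteBuffer} input after an explicit resize;
 * the input index and shape here are illustrative assumptions, and {@link Tensor#numBytes()} is
 * used to derive the buffer size:
 *
 * <pre>{@code
 * interpreter.resizeInput(0, new int[] {1, 224, 224, 3});
 * ByteBuffer input = ByteBuffer.allocateDirect(interpreter.getInputTensor(0).numBytes());
 * input.order(ByteOrder.nativeOrder());
 * }</pre>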
 *
 * <p><b>WARNING:</b> Instances of an {@code Interpreter} are <b>not</b> thread-safe. An {@code
 * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}.
 */
public final class Interpreter implements AutoCloseable {

  /** An options class for controlling runtime interpreter behavior. */
  public static class Options {
    public Options() {}

    /**
     * Sets the number of threads to be used for ops that support multi-threading. Defaults to a
     * platform-dependent value.
     */
    public Options setNumThreads(int numThreads) {
      this.numThreads = numThreads;
      return this;
    }

    /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */
    public Options setUseNNAPI(boolean useNNAPI) {
      this.useNNAPI = useNNAPI;
      return this;
    }

    /**
     * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to
     * false (disallow).
     *
     * <p>WARNING: This is an experimental API and subject to change.
     */
    public Options setAllowFp16PrecisionForFp32(boolean allow) {
      this.allowFp16PrecisionForFp32 = allow;
      return this;
    }

    /**
     * Adds a {@link Delegate} to be applied during interpreter creation.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options addDelegate(Delegate delegate) {
      delegates.add(delegate);
      return this;
    }

    /**
     * Advanced: Set if buffer handle output is allowed.
     *
     * <p>When a {@link Delegate} supports hardware acceleration, the interpreter will make the
     * data of output tensors available in the CPU-allocated tensor buffers by default. If the
     * client can consume the buffer handle directly (e.g. reading output from an OpenGL texture),
     * it can set this flag to false, avoiding the copy of data to the CPU buffer. The delegate
     * documentation should indicate whether this is supported and how it can be used.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options setAllowBufferHandleOutput(boolean allow) {
      this.allowBufferHandleOutput = allow;
      return this;
    }

    int numThreads = -1;
    Boolean useNNAPI;
    Boolean allowFp16PrecisionForFp32;
    Boolean allowBufferHandleOutput;
    final List<Delegate> delegates = new ArrayList<>();
  }
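
  // A minimal usage sketch for Options; the thread count and file name below are illustrative
  // assumptions, not recommendations from this file:
  //
  //   Interpreter.Options options = new Interpreter.Options().setNumThreads(2);
  //   try (Interpreter interpreter = new Interpreter(new File("model.tflite"), options)) {
  //     interpreter.run(input, output);
  //   }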

  /**
   * Initializes an {@code Interpreter}.
   *
   * @param modelFile a file of a pre-trained TF Lite model
   */
  public Interpreter(@NonNull File modelFile) {
    this(modelFile, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} and specifies the number of threads used for inference.
   *
   * @param modelFile a file of a pre-trained TF Lite model
   * @param numThreads number of threads to use for inference
   * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will
   *     be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull File modelFile, int numThreads) {
    this(modelFile, new Options().setNumThreads(numThreads));
  }

  /**
   * Initializes an {@code Interpreter} with a set of options for customizing interpreter behavior.
   *
   * @param modelFile a file of a pre-trained TF Lite model
   * @param options a set of options for customizing interpreter behavior
   */
  public Interpreter(@NonNull File modelFile, Options options) {
    wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file.
   *
   * <p>The {@code ByteBuffer} should not be modified after the construction of an {@code
   * Interpreter}. It can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
   * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and specifies the
   * number of threads used for inference.
   *
   * <p>The {@code ByteBuffer} should not be modified after the construction of an {@code
   * Interpreter}. It can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
   * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   *
   * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
   *     will be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
    this(byteBuffer, new Options().setNumThreads(numThreads));
  }

  /**
   * Initializes an {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
   *
   * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of an {@code
   * Interpreter}.
   *
   * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
   *     will be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
    this(mappedByteBuffer, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of
   * custom {@link Options}.
   *
   * <p>The {@code ByteBuffer} should not be modified after the construction of an {@code
   * Interpreter}. It can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
   * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
    wrapper = new NativeInterpreterWrapper(byteBuffer, options);
  }
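
  // A hedged sketch of memory-mapping a model file for this constructor; the path and the use of
  // java.nio.channels.FileChannel are illustrative assumptions, not requirements of this API:
  //
  //   try (FileChannel channel =
  //       FileChannel.open(Paths.get("model.tflite"), StandardOpenOption.READ)) {
  //     MappedByteBuffer model = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
  //     Interpreter interpreter = new Interpreter(model);
  //   }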

  /**
   * Runs model inference if the model takes only one input, and provides only one output.
   *
   * <p>Warning: The API runs much faster if a {@link ByteBuffer} is used as the input data type.
   * Please consider using {@link ByteBuffer} to feed primitive input data for better performance.
   *
   * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types
   *     including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
   *     input data for primitive types, whereas string types require using the (multi-dimensional)
   *     array input path. When a {@link ByteBuffer} is used, its content should remain unchanged
   *     until model inference is done. A {@code null} value is allowed only if the caller is using
   *     a {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to
   *     the input {@link Tensor}.
   * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
   *     types including int, float, long, and byte. A {@code null} value is allowed only if the
   *     caller is using a {@link Delegate} that allows buffer handle interop, and such a buffer
   *     has been bound to the output {@link Tensor}. See also {@link
   *     Options#setAllowBufferHandleOutput(boolean)}.
   */
  public void run(Object input, Object output) {
    Object[] inputs = {input};
    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, output);
    runForMultipleInputsOutputs(inputs, outputs);
  }
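
  // A hedged, concrete instance of the single input/output path; the float shapes are
  // illustrative assumptions about some model:
  //
  //   float[][] input = new float[1][4];
  //   float[][] output = new float[1][2];
  //   interpreter.run(input, output);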

  /**
   * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
   *
   * <p>Warning: The API runs much faster if a {@link ByteBuffer} is used as the input data type.
   * Please consider using {@link ByteBuffer} to feed primitive input data for better performance.
   *
   * <p>Note: {@code null} values for individual elements of {@code inputs} and {@code outputs} are
   * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
   * such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
   *
   * @param inputs an array of input data. The inputs should be in the same order as the inputs of
   *     the model. Each input can be an array or multidimensional array, or a {@link ByteBuffer}
   *     of primitive types including int, float, long, and byte. {@link ByteBuffer} is the
   *     preferred way to pass large input data, whereas string types require using the
   *     (multi-dimensional) array input path. When a {@link ByteBuffer} is used, its content
   *     should remain unchanged until model inference is done.
   * @param outputs a map mapping output indices to multidimensional arrays of output data or
   *     {@link ByteBuffer}s of primitive types including int, float, long, and byte. It only needs
   *     to keep entries for the outputs to be used.
   */
  public void runForMultipleInputsOutputs(
      @NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
    checkNotClosed();
    wrapper.run(inputs, outputs);
  }

  /**
   * Resizes the idx-th input of the native model to the given dims.
   *
   * <p>An {@code IllegalArgumentException} will be thrown if the resize fails.
   */
  public void resizeInput(int idx, @NonNull int[] dims) {
    checkNotClosed();
    wrapper.resizeInput(idx, dims);
  }

  /** Gets the number of input tensors. */
  public int getInputTensorCount() {
    checkNotClosed();
    return wrapper.getInputTensorCount();
  }

  /**
   * Gets the index of an input given the op name of the input.
   *
   * <p>An {@code IllegalArgumentException} will be thrown if the op name does not exist in the
   * model file used to initialize the {@link Interpreter}.
   */
  public int getInputIndex(String opName) {
    checkNotClosed();
    return wrapper.getInputIndex(opName);
  }

  /**
   * Gets the Tensor associated with the provided input index.
   *
   * <p>An {@code IllegalArgumentException} will be thrown if the provided index is invalid.
   */
  public Tensor getInputTensor(int inputIndex) {
    checkNotClosed();
    return wrapper.getInputTensor(inputIndex);
  }
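
  // A hedged sketch of inspecting tensor metadata before allocating buffers; input index 0 and
  // the commented values are illustrative assumptions:
  //
  //   Tensor t = interpreter.getInputTensor(0);
  //   int[] shape = t.shape();       // e.g. {1, 224, 224, 3}
  //   DataType type = t.dataType();  // e.g. DataType.FLOAT32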

  /** Gets the number of output Tensors. */
  public int getOutputTensorCount() {
    checkNotClosed();
    return wrapper.getOutputTensorCount();
  }

  /**
   * Gets the index of an output given the op name of the output.
   *
   * <p>An {@code IllegalArgumentException} will be thrown if the op name does not exist in the
   * model file used to initialize the {@link Interpreter}.
   */
  public int getOutputIndex(String opName) {
    checkNotClosed();
    return wrapper.getOutputIndex(opName);
  }

  /**
   * Gets the Tensor associated with the provided output index.
   *
   * <p>An {@code IllegalArgumentException} will be thrown if the provided index is invalid.
   */
  public Tensor getOutputTensor(int outputIndex) {
    checkNotClosed();
    return wrapper.getOutputTensor(outputIndex);
  }

  /**
   * Returns the timing of the most recent native inference, in nanoseconds.
   *
   * <p>An {@code IllegalArgumentException} will be thrown if the model is not initialized by the
   * {@link Interpreter}.
   */
  public Long getLastNativeInferenceDurationNanoseconds() {
    checkNotClosed();
    return wrapper.getLastNativeInferenceDurationNanoseconds();
  }
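
  // For example (a hedged sketch; that the return value may be null before any inference has
  // completed is an assumption about the native wrapper's behavior):
  //
  //   interpreter.run(input, output);
  //   Long nanos = interpreter.getLastNativeInferenceDurationNanoseconds();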

  /**
   * Turns on/off Android NNAPI for hardware acceleration when it is available.
   *
   * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API.
   *     This method will be removed in a future release.
   */
  @Deprecated
  public void setUseNNAPI(boolean useNNAPI) {
    checkNotClosed();
    wrapper.setUseNNAPI(useNNAPI);
  }

  /**
   * Sets the number of threads to be used for ops that support multi-threading.
   *
   * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling
   *     multi-threading. This method will be removed in a future release.
   */
  @Deprecated
  public void setNumThreads(int numThreads) {
    checkNotClosed();
    wrapper.setNumThreads(numThreads);
  }

  /**
   * Advanced: Modifies the graph with the provided {@link Delegate}.
   *
   * <p>Note: The typical path for providing delegates is via {@link Options#addDelegate}, at
   * creation time. This path should only be used when a delegate might require coordinated
   * interaction between Interpreter creation and delegate application.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public void modifyGraphWithDelegate(Delegate delegate) {
    checkNotClosed();
    wrapper.modifyGraphWithDelegate(delegate);
  }
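
  // A hedged sketch of post-construction delegate application. GpuDelegate is an assumption
  // here: it ships in the separate org.tensorflow.lite.gpu artifact and may not be on the
  // classpath; modelBuffer is an illustrative placeholder:
  //
  //   try (Interpreter interpreter = new Interpreter(modelBuffer)) {
  //     interpreter.modifyGraphWithDelegate(new GpuDelegate());
  //     interpreter.run(input, output);
  //   }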

  /** Releases resources associated with the {@code Interpreter}. */
  @Override
  public void close() {
    if (wrapper != null) {
      wrapper.close();
      wrapper = null;
    }
  }

  @Override
  protected void finalize() throws Throwable {
    try {
      close();
    } finally {
      super.finalize();
    }
  }

  private void checkNotClosed() {
    if (wrapper == null) {
      throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
    }
  }

  NativeInterpreterWrapper wrapper;
}