# TensorFlow Lite inference

The term *inference* refers to the process of executing a TensorFlow Lite model
on-device in order to make predictions based on input data. To perform an
inference with a TensorFlow Lite model, you must run it through an
*interpreter*. The TensorFlow Lite interpreter is designed to be lean and fast.
The interpreter uses a static graph ordering and a custom (less-dynamic) memory
allocator to ensure minimal load, initialization, and execution latency.

This page describes how to access the TensorFlow Lite interpreter and perform
an inference using C++, Java, and Python, plus links to other resources for each
[supported platform](#supported-platforms).

[TOC]

## Important concepts

TensorFlow Lite inference typically involves the following steps:

1.  **Loading a model**

    You must load the `.tflite` model into memory, which contains the model's
    execution graph.

1.  **Transforming data**

    Raw input data for the model generally does not match the input data format
    expected by the model. For example, you might need to resize an image or
    change the image format to be compatible with the model.

1.  **Running inference**

    This step involves using the TensorFlow Lite API to execute the model. It
    requires a few steps, such as building the interpreter and allocating
    tensors, as described in the following sections.

1.  **Interpreting output**

    When you receive results from the model inference, you must interpret the
    tensors in a meaningful way that's useful in your application.

    For example, a model might return only a list of probabilities. It's up to
    you to map the probabilities to relevant categories and present them to your
    end-user, as in the sketch after this list.

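For instance, with an image classification model, interpreting the output often
amounts to an argmax over the returned probabilities. The following minimal Java
sketch assumes a hypothetical `probabilities` array already filled by the
interpreter and a `labels` list loaded by your application:

```java
// Hypothetical example: `probabilities` is a [1, N] array of class scores
// produced by the interpreter, and `labels` is a parallel list of N category
// names loaded by your application.
int best = 0;
for (int i = 1; i < probabilities[0].length; i++) {
  if (probabilities[0][i] > probabilities[0][best]) {
    best = i;
  }
}
String predictedLabel = labels.get(best);  // Map the winning index to a category.
```
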
## Supported platforms

TensorFlow Lite inference APIs are provided for most common mobile/embedded
platforms such as [Android](#android-platform), [iOS](#ios-platform) and
[Linux](#linux-platform), in multiple programming languages.

In most cases, the API design reflects a preference for performance over ease of
use. TensorFlow Lite is designed for fast inference on small devices, so it
should be no surprise that the APIs try to avoid unnecessary copies at the
expense of convenience. Similarly, consistency with TensorFlow APIs was not an
explicit goal and some variance between languages is to be expected.

Across all libraries, the TensorFlow Lite API enables you to load models, feed
inputs, and retrieve inference outputs.

### Android Platform

On Android, TensorFlow Lite inference can be performed using either Java or C++
APIs. The Java APIs provide convenience and can be used directly within your
Android Activity classes. The C++ APIs offer more flexibility and speed, but may
require writing JNI wrappers to move data between Java and C++ layers.

See below for details about using [C++](#load-and-run-a-model-in-c) and
[Java](#load-and-run-a-model-in-java), or follow the
[Android quickstart](../android) for a tutorial and example code.

#### TensorFlow Lite Android wrapper code generator

Note: The TensorFlow Lite wrapper code generator is in an experimental (beta)
phase and currently only supports Android.

For TensorFlow Lite models enhanced with [metadata](../inference_with_metadata/overview),
developers can use the TensorFlow Lite Android wrapper code generator to create
platform-specific wrapper code. The wrapper code removes the need to interact
directly with `ByteBuffer` on Android. Instead, developers can interact with the
TensorFlow Lite model using typed objects such as `Bitmap` and `Rect`. For more
information, please refer to the
[TensorFlow Lite Android wrapper code generator](../inference_with_metadata/codegen.md).

### iOS Platform

On iOS, TensorFlow Lite is available with native iOS libraries written in
[Swift](https://www.tensorflow.org/code/tensorflow/lite/swift)
and
[Objective-C](https://www.tensorflow.org/code/tensorflow/lite/objc).
You can also use the
[C API](https://www.tensorflow.org/code/tensorflow/lite/c/c_api.h)
directly in Objective-C code.

See below for details about using [Swift](#load-and-run-a-model-in-swift),
[Objective-C](#load-and-run-a-model-in-objective-c) and the
[C API](#using-c-api-in-objective-c-code), or follow the
[iOS quickstart](ios.md) for a tutorial and example code.

### Linux Platform

On Linux platforms (including [Raspberry Pi](build_arm)), you can run
inferences using TensorFlow Lite APIs available in
[C++](#load-and-run-a-model-in-c) and [Python](#load-and-run-a-model-in-python),
as shown in the following sections.

## Running a model

Running a TensorFlow Lite model involves a few simple steps:

1.  Load the model into memory.
2.  Build an `Interpreter` based on an existing model.
3.  Set input tensor values. (Optionally resize input tensors if the predefined
    sizes are not desired.)
4.  Invoke inference.
5.  Read output tensor values.

The following sections describe how these steps can be done in each language.

## Load and run a model in Java

*Platform: Android*

The Java API for running an inference with TensorFlow Lite is primarily designed
for use with Android, so it's available as an Android library dependency:
`org.tensorflow:tensorflow-lite`.

In Java, you'll use the `Interpreter` class to load a model and drive model
inference. In many cases, this may be the only API you need.

You can initialize an `Interpreter` using a `.tflite` file:

```java
public Interpreter(@NotNull File modelFile);
```

Or with a `MappedByteBuffer`:

```java
public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer);
```

In both cases, you must provide a valid TensorFlow Lite model or the API throws
`IllegalArgumentException`. If you use `MappedByteBuffer` to initialize an
`Interpreter`, it must remain unchanged for the whole lifetime of the
`Interpreter`.

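For example, one common way to obtain a `MappedByteBuffer` is to memory-map the
model file with `java.nio`. The sketch below is illustrative only and assumes a
hypothetical `modelPath` string pointing at a `.tflite` file on disk:

```java
import java.io.FileInputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

// Memory-map the model file so the Interpreter can read it without copying
// the whole model into the Java heap.
try (FileInputStream inputStream = new FileInputStream(modelPath);
     FileChannel fileChannel = inputStream.getChannel()) {
  MappedByteBuffer modelBuffer =
      fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
  try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    // Run inference here; the buffer must stay unchanged while in use.
  }
}
```
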
The preferred way to run inference on a model is to use signatures, which are
available for models converted starting from TensorFlow 2.5:

```java
try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {
  Map<String, Object> inputs = new HashMap<>();
  inputs.put("input_1", input1);
  inputs.put("input_2", input2);
  Map<String, Object> outputs = new HashMap<>();
  outputs.put("output_1", output1);
  interpreter.runSignature(inputs, outputs, "mySignature");
}
```

The `runSignature` method takes three arguments:

-   **Inputs**: a map from each input name in the signature to an input
    object.

-   **Outputs**: a map from each output name in the signature to the output
    data.

-   **Signature name** (optional): the signature name (can be left empty if the
    model has a single signature).

If the model doesn't have defined signatures, another way to run an inference is
to simply call `Interpreter.run()`. For example:

```java
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
  interpreter.run(input, output);
}
```

The `run()` method takes only one input and returns only one output. So if your
model has multiple inputs or multiple outputs, instead use:

```java
interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
```

In this case, each entry in `inputs` corresponds to an input tensor and
`map_of_indices_to_outputs` maps indices of output tensors to the corresponding
output data.

In both cases, the tensor indices should correspond to the values you gave to
the [TensorFlow Lite Converter](../models/convert/) when you created the model. Be
aware that the order of tensors in `inputs` must match the order given to the
TensorFlow Lite Converter.

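As an illustration, here is a hedged sketch for a hypothetical model with two
inputs and two outputs; the variable names and output shapes are made up for the
example:

```java
// Hypothetical two-input, two-output model. `input0` and `input1` are arrays
// (or ByteBuffers) shaped to match the model's input tensors.
Object[] inputs = {input0, input1};

Map<Integer, Object> outputs = new HashMap<>();
float[][] output0 = new float[1][10];  // Assumed shape of output tensor 0.
float[][] output1 = new float[1][4];   // Assumed shape of output tensor 1.
outputs.put(0, output0);
outputs.put(1, output1);

interpreter.runForMultipleInputsOutputs(inputs, outputs);
```
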
The `Interpreter` class also provides convenient functions for you to get the
index of any model input or output using an operation name:

```java
public int getInputIndex(String opName);
public int getOutputIndex(String opName);
```

If `opName` is not a valid operation in the model, it throws an
`IllegalArgumentException`.

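For example (the tensor names here are hypothetical and depend on how your model
was converted):

```java
// Look up tensor indices by name; "image" and "probability" are assumed names.
int imageInputIndex = interpreter.getInputIndex("image");
int probabilityOutputIndex = interpreter.getOutputIndex("probability");
```
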
Also beware that `Interpreter` owns resources. To avoid a memory leak, the
resources must be released after use by:

```java
interpreter.close();
```

For an example project with Java, see the
[Android image classification sample](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android).

### Supported data types (in Java)

To use TensorFlow Lite, the data types of the input and output tensors must be
one of the following primitive types:

*   `float`
*   `int`
*   `long`
*   `byte`

`String` types are also supported, but they are encoded differently than the
primitive types. In particular, the shape of a string Tensor dictates the number
and arrangement of strings in the Tensor, with each element itself being a
variable-length string. In this sense, the (byte) size of the Tensor cannot be
computed from the shape and type alone, and consequently strings cannot be
provided as a single, flat `ByteBuffer` argument.

If other data types, including boxed types like `Integer` and `Float`, are used,
an `IllegalArgumentException` will be thrown.

#### Inputs

Each input should be an array or multi-dimensional array of the supported
primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is
an array or multi-dimensional array, the associated input tensor will be
implicitly resized to the array's dimensions at inference time. If the input is
a `ByteBuffer`, the caller should first manually resize the associated input
tensor (via `Interpreter.resizeInput()`) before running inference.

When using `ByteBuffer`, prefer using direct byte buffers, as this allows the
`Interpreter` to avoid unnecessary copies. If the `ByteBuffer` is a direct byte
buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a
model inference, it must remain unchanged until the model inference is finished.

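As a hedged sketch, preparing a direct `ByteBuffer` input for a hypothetical
float model with input shape `[1, 224, 224, 3]` might look like this (the shape,
the `pixelValues` array, and the output size are assumptions for the example):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Allocate a direct buffer in native byte order; 4 bytes per float.
int numFloats = 1 * 224 * 224 * 3;
ByteBuffer inputBuffer =
    ByteBuffer.allocateDirect(numFloats * 4).order(ByteOrder.nativeOrder());
for (float value : pixelValues) {  // `pixelValues` prepared by your app.
  inputBuffer.putFloat(value);
}
inputBuffer.rewind();

// If the model's input shape differs from the buffer you filled, resize first.
interpreter.resizeInput(0, new int[] {1, 224, 224, 3});

float[][] output = new float[1][1001];  // Assumed output shape.
interpreter.run(inputBuffer, output);
```
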
#### Outputs

Each output should be an array or multi-dimensional array of the supported
primitive types, or a `ByteBuffer` of the appropriate size. Note that some models
have dynamic outputs, where the shape of output tensors can vary depending on
the input. There's no straightforward way of handling this with the existing
Java inference API, but planned extensions will make this possible.

## Load and run a model in Swift

*Platform: iOS*

The
[Swift API](https://www.tensorflow.org/code/tensorflow/lite/swift)
is available in the `TensorFlowLiteSwift` pod from CocoaPods.

First, you need to import the `TensorFlowLite` module:

```swift
import TensorFlowLite
```

```swift
// Getting model path
guard
  let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
else {
  // Error handling...
  return
}

do {
  // Initialize an interpreter with the model.
  let interpreter = try Interpreter(modelPath: modelPath)

  // Allocate memory for the model's input `Tensor`s.
  try interpreter.allocateTensors()

  let inputData: Data  // Should be initialized

  // input data preparation...

  // Copy the input data to the input `Tensor`.
  try interpreter.copy(inputData, toInputAt: 0)

  // Run inference by invoking the `Interpreter`.
  try interpreter.invoke()

  // Get the output `Tensor`
  let outputTensor = try interpreter.output(at: 0)

  // Copy output to `Data` to process the inference results.
  let outputSize = outputTensor.shape.dimensions.reduce(1, {x, y in x * y})
  let outputData =
        UnsafeMutableBufferPointer<Float32>.allocate(capacity: outputSize)
  outputTensor.data.copyBytes(to: outputData)
} catch {
  // Error handling...
}
```

## Load and run a model in Objective-C

*Platform: iOS*

The
[Objective-C API](https://www.tensorflow.org/code/tensorflow/lite/objc)
is available in the `TensorFlowLiteObjC` pod from CocoaPods.

First, you need to import the `TensorFlowLite` module:

```objc
@import TensorFlowLite;
```

```objc
NSString *modelPath = [[NSBundle mainBundle] pathForResource:@"model"
                                                      ofType:@"tflite"];
NSError *error;

// Initialize an interpreter with the model.
TFLInterpreter *interpreter = [[TFLInterpreter alloc] initWithModelPath:modelPath
                                                                  error:&error];
if (error != nil) { /* Error handling... */ }

// Allocate memory for the model's input `TFLTensor`s.
[interpreter allocateTensorsWithError:&error];
if (error != nil) { /* Error handling... */ }

NSMutableData *inputData;  // Should be initialized
// input data preparation...

// Get the input `TFLTensor`
TFLTensor *inputTensor = [interpreter inputTensorAtIndex:0 error:&error];
if (error != nil) { /* Error handling... */ }

// Copy the input data to the input `TFLTensor`.
[inputTensor copyData:inputData error:&error];
if (error != nil) { /* Error handling... */ }

// Run inference by invoking the `TFLInterpreter`.
[interpreter invokeWithError:&error];
if (error != nil) { /* Error handling... */ }

// Get the output `TFLTensor`
TFLTensor *outputTensor = [interpreter outputTensorAtIndex:0 error:&error];
if (error != nil) { /* Error handling... */ }

// Copy output to `NSData` to process the inference results.
NSData *outputData = [outputTensor dataWithError:&error];
if (error != nil) { /* Error handling... */ }
```

### Using C API in Objective-C code

Currently, the Objective-C API does not support delegates. In order to use
delegates with Objective-C code, you need to directly call the underlying
[C API](https://www.tensorflow.org/code/tensorflow/lite/c/c_api.h).

```c
#include "tensorflow/lite/c/c_api.h"
```

```c
TfLiteModel* model = TfLiteModelCreateFromFile([modelPath UTF8String]);
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();

// Create the interpreter.
TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);

// Allocate tensors and populate the input tensor data.
TfLiteInterpreterAllocateTensors(interpreter);
TfLiteTensor* input_tensor =
    TfLiteInterpreterGetInputTensor(interpreter, 0);
TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
                           input.size() * sizeof(float));

// Execute inference.
TfLiteInterpreterInvoke(interpreter);

// Extract the output tensor data.
const TfLiteTensor* output_tensor =
    TfLiteInterpreterGetOutputTensor(interpreter, 0);
TfLiteTensorCopyToBuffer(output_tensor, output.data(),
                         output.size() * sizeof(float));

// Dispose of the model and interpreter objects.
TfLiteInterpreterDelete(interpreter);
TfLiteInterpreterOptionsDelete(options);
TfLiteModelDelete(model);
```

## Load and run a model in C++

*Platforms: Android, iOS, and Linux*

Note: The C++ API on iOS is only available when using Bazel.

In C++, the model is stored in the
[`FlatBufferModel`](https://www.tensorflow.org/lite/api_docs/cc/class/tflite/flat-buffer-model.html)
class. It encapsulates a TensorFlow Lite model and you can build it in a couple
of different ways, depending on where the model is stored:

```c++
class FlatBufferModel {
  // Build a model based on a file. Return a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromFile(
      const char* filename,
      ErrorReporter* error_reporter);

  // Build a model based on a pre-loaded flatbuffer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Return a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
      const char* buffer,
      size_t buffer_size,
      ErrorReporter* error_reporter);
};
```

Note: If TensorFlow Lite detects the presence of the
[Android NNAPI](https://developer.android.com/ndk/guides/neuralnetworks), it
will automatically try to use shared memory to store the `FlatBufferModel`.

Now that you have the model as a `FlatBufferModel` object, you can execute it
with an
[`Interpreter`](https://www.tensorflow.org/lite/api_docs/cc/class/tflite/interpreter.html).
A single `FlatBufferModel` can be used simultaneously by more than one
`Interpreter`.

Caution: The `FlatBufferModel` object must remain valid until all instances of
`Interpreter` using it have been destroyed.

The important parts of the `Interpreter` API are shown in the code snippet
below. It should be noted that:

*   Tensors are represented by integers, in order to avoid string comparisons
    (and any fixed dependency on string libraries).
*   An interpreter must not be accessed from concurrent threads.
*   Memory allocation for input and output tensors must be triggered by calling
    `AllocateTensors()` right after resizing tensors.

The simplest usage of TensorFlow Lite with C++ looks like this:

```c++
// Load the model
std::unique_ptr<tflite::FlatBufferModel> model =
    tflite::FlatBufferModel::BuildFromFile(filename);

// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);

// Resize input tensors, if desired.
interpreter->AllocateTensors();

float* input = interpreter->typed_input_tensor<float>(0);
// Fill `input`.

interpreter->Invoke();

float* output = interpreter->typed_output_tensor<float>(0);
```

For more example code, see
[`minimal.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/examples/minimal/minimal.cc)
and
[`label_image.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/examples/label_image/label_image.cc).

## Load and run a model in Python

*Platform: Linux*

The Python API for running an inference is provided in the `tf.lite` module.
From it, you mostly need only
[`tf.lite.Interpreter`](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
to load a model and run an inference.

The following example shows how to use the Python interpreter to load a
`.tflite` file and run inference with random input data.

This example is recommended if you're converting from a SavedModel with a
defined SignatureDef. It is available starting from TensorFlow 2.5.

```python
import tensorflow as tf

class TestModel(tf.Module):
  def __init__(self):
    super(TestModel, self).__init__()

  @tf.function(input_signature=[tf.TensorSpec(shape=[1, 10], dtype=tf.float32)])
  def add(self, x):
    '''
    Simple method that accepts single input 'x' and returns 'x' + 4.
    '''
    # Name the output 'result' for convenience.
    return {'result' : x + 4}


SAVED_MODEL_PATH = 'content/saved_models/test_variable'
TFLITE_FILE_PATH = 'content/test_variable.tflite'

# Save the model
module = TestModel()
# You can omit the signatures argument and a default signature name will be
# created with name 'serving_default'.
tf.saved_model.save(
    module, SAVED_MODEL_PATH,
    signatures={'my_signature':module.add.get_concrete_function()})

# Convert the model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
tflite_model = converter.convert()
with open(TFLITE_FILE_PATH, 'wb') as f:
  f.write(tflite_model)

# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(TFLITE_FILE_PATH)
# There is only 1 signature defined in the model,
# so it will return it by default.
# If there are multiple signatures then we can pass the name.
my_signature = interpreter.get_signature_runner()

# my_signature is callable with input as arguments.
output = my_signature(x=tf.constant([1.0], shape=(1,10), dtype=tf.float32))
# 'output' is dictionary with all outputs from the inference.
# In this case we have single output 'result'.
print(output['result'])
```

The following is another example, for a model that doesn't have SignatureDefs
defined.

```python
import numpy as np
import tensorflow as tf

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
```

As an alternative to loading the model as a pre-converted `.tflite` file, you
can combine your code with the
[TensorFlow Lite Converter Python API](https://www.tensorflow.org/lite/api_docs/python/tf/lite/TFLiteConverter)
(`tf.lite.TFLiteConverter`), allowing you to convert your TensorFlow model into
the TensorFlow Lite format and then run inference:

```python
import numpy as np
import tensorflow as tf

# This example builds a TF1-style graph, so it uses the `tf.compat.v1` APIs.
tf.compat.v1.disable_eager_execution()

img = tf.compat.v1.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
val = img + const
out = tf.identity(val, name="out")

# Convert to TF Lite format
with tf.compat.v1.Session() as sess:
  converter = tf.compat.v1.lite.TFLiteConverter.from_session(sess, [img], [out])
  tflite_model = converter.convert()

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Continue to get tensors and so forth, as shown above...
```

For more Python sample code, see
[`label_image.py`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/examples/python/label_image.py).

Tip: Run `help(tf.lite.Interpreter)` in a Python terminal to get detailed
documentation about the interpreter.

## Supported operations

TensorFlow Lite supports a subset of TensorFlow operations with some
limitations. For a full list of operations and limitations, see the
[TF Lite Ops page](https://www.tensorflow.org/mlir/tfl_ops).