# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python TF-Lite interpreter."""
import ctypes
import enum
import os
import platform
import sys

import numpy as np

# pylint: disable=g-import-not-at-top
if not os.path.splitext(__file__)[0].endswith(
    os.path.join('tflite_runtime', 'interpreter')):
  # This file is part of the tensorflow package.
  from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper
  from tensorflow.lite.python.metrics import metrics
  from tensorflow.python.util.tf_export import tf_export as _tf_export
else:
  # This file is part of the tflite_runtime package.
  from tflite_runtime import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper
  from tflite_runtime import metrics_portable as metrics

  def _tf_export(*x, **kwargs):
    del x, kwargs
    return lambda x: x


# pylint: enable=g-import-not-at-top

class Delegate:
  """Python wrapper class to manage TfLiteDelegate objects.

  The shared library is expected to have two functions:
    TfLiteDelegate* tflite_plugin_create_delegate(
        char**, char**, size_t, void (*report_error)(const char *))
    void tflite_plugin_destroy_delegate(TfLiteDelegate*)

  The first one creates a delegate object. It may return NULL to indicate an
  error (with a suitable error message reported by calling report_error()).
  The second one destroys the delegate object and must be called for every
  created delegate object. Passing NULL as the argument value is allowed, i.e.

    tflite_plugin_destroy_delegate(tflite_plugin_create_delegate(...))

  always works.
  """

  def __init__(self, library, options=None):
    """Loads a delegate from the shared library.

    Args:
      library: Shared library name.
      options: Dictionary of options that are required to load the delegate. All
        keys and values in the dictionary should be serializable. Consult the
        documentation of the specific delegate for required and legal options.
        (default None)

    Raises:
      RuntimeError: This is raised if the Python implementation is not CPython.
    """

    # TODO(b/136468453): Remove need for __del__ ordering needs of CPython
    # by using explicit closes(). See implementation of Interpreter __del__.
    if platform.python_implementation() != 'CPython':
      raise RuntimeError('Delegates are currently only supported on CPython '
                         'due to missing immediate reference counting.')

    self._library = ctypes.pydll.LoadLibrary(library)
    self._library.tflite_plugin_create_delegate.argtypes = [
        ctypes.POINTER(ctypes.c_char_p),
        ctypes.POINTER(ctypes.c_char_p), ctypes.c_int,
        ctypes.CFUNCTYPE(None, ctypes.c_char_p)
    ]
    self._library.tflite_plugin_create_delegate.restype = ctypes.c_void_p

    # Convert the options from a dictionary to lists of char pointers.
    options = options or {}
    options_keys = (ctypes.c_char_p * len(options))()
    options_values = (ctypes.c_char_p * len(options))()
    for idx, (key, value) in enumerate(options.items()):
      options_keys[idx] = str(key).encode('utf-8')
      options_values[idx] = str(value).encode('utf-8')

    class ErrorMessageCapture:

      def __init__(self):
        self.message = ''

      def report(self, x):
        self.message += x if isinstance(x, str) else x.decode('utf-8')

    capture = ErrorMessageCapture()
    error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report)
    # Do not make a copy of _delegate_ptr. It is freed by Delegate's finalizer.
    self._delegate_ptr = self._library.tflite_plugin_create_delegate(
        options_keys, options_values, len(options), error_capturer_cb)
    if self._delegate_ptr is None:
      raise ValueError(capture.message)
  def __del__(self):
    # __del__ can be called multiple times, so if the delegate has already
    # been destroyed, don't try to destroy it twice.
    if self._library is not None:
      self._library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p]
      self._library.tflite_plugin_destroy_delegate(self._delegate_ptr)
      self._library = None
  def _get_native_delegate_pointer(self):
    """Returns the native TfLiteDelegate pointer.

    It is not safe to copy this pointer because it needs to be freed.

    Returns:
      TfLiteDelegate *
    """
    return self._delegate_ptr


@_tf_export('lite.experimental.load_delegate')
def load_delegate(library, options=None):
  """Returns a loaded Delegate object.

  Example usage:

  ```
  import tensorflow as tf

  try:
    delegate = tf.lite.experimental.load_delegate('delegate.so')
  except ValueError:
    delegate = None  # Fall back to CPU.

  if delegate:
    interpreter = tf.lite.Interpreter(
        model_path='model.tflite',
        experimental_delegates=[delegate])
  else:
    interpreter = tf.lite.Interpreter(model_path='model.tflite')
  ```

  This is typically used to leverage EdgeTPU for running TensorFlow Lite models.
  For more information see: https://coral.ai/docs/edgetpu/tflite-python/

  Args:
    library: Name of shared library containing the
      [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates).
    options: Dictionary of options that are required to load the delegate. All
      keys and values in the dictionary should be convertible to str. Consult
      the documentation of the specific delegate for required and legal options.
      (default None)

  Returns:
    Delegate object.

  Raises:
    ValueError: Delegate failed to load.
    RuntimeError: If delegate loading is used on an unsupported platform.
  """
  try:
    delegate = Delegate(library, options)
  except ValueError as e:
    raise ValueError('Failed to load delegate from {}\n{}'.format(
        library, str(e)))
  return delegate


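# What follows is an illustrative, non-API sketch of combining load_delegate()
# with an options dictionary. The library name and option keys below are
# hypothetical; consult your delegate's documentation for the real ones.
def _example_load_delegate_with_options():  # Illustrative only.
  """Sketch: loads a hypothetical delegate, falling back to CPU on failure."""
  try:
    # Keys and values are stringified before being handed to the plugin.
    delegate = load_delegate(
        'my_delegate.so', options={'device': 'gpu', 'precision': 'fp16'})
  except ValueError:
    delegate = None  # The plugin reported an error; fall back to CPU.
  return delegate

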
class SignatureRunner:
  """SignatureRunner class for running TFLite models using SignatureDef.

  This class should be instantiated only through the TFLite Interpreter's
  get_signature_runner method.
  Example:
    signature = interpreter.get_signature_runner("my_signature")
    result = signature(input_1=my_input_1, input_2=my_input_2)
    print(result["my_output"])
    print(result["my_second_output"])
  All names used are the names defined in this specific SignatureDef.

  Notes:
    No other function on this object or on the provided interpreter should be
    called while a call on this object is in progress.
  """

  def __init__(self, interpreter=None, signature_key=None):
    """Constructor.

    Args:
      interpreter: Interpreter object that is already initialized with the
        requested model.
      signature_key: SignatureDef key to be used.
    """
    if not interpreter:
      raise ValueError('None interpreter provided.')
    if not signature_key:
      raise ValueError('None signature_key provided.')
    self._interpreter = interpreter
    self._interpreter_wrapper = interpreter._interpreter
    self._signature_key = signature_key
    signature_defs = interpreter._get_full_signature_list()
    if signature_key not in signature_defs:
      raise ValueError('Invalid signature_key provided.')
    self._signature_def = signature_defs[signature_key]
    self._outputs = self._signature_def['outputs'].items()
    self._inputs = self._signature_def['inputs']

    self._subgraph_index = (
        self._interpreter_wrapper.GetSubgraphIndexFromSignature(
            self._signature_key))

  def __call__(self, **kwargs):
    """Runs the SignatureDef given the provided inputs in arguments.

    Args:
      **kwargs: Keyword arguments mapping SignatureDef input names to numpy
        array values.

    Returns:
      A dictionary of results from invoking the model. Each key is a
      SignatureDef output name, and each value is the result tensor.
    """

    if len(kwargs) != len(self._inputs):
      raise ValueError(
          'Invalid number of inputs provided for running a SignatureDef, '
          'expected %s vs provided %s' % (len(self._inputs), len(kwargs)))

    # Resize input tensors.
    for input_name, value in kwargs.items():
      if input_name not in self._inputs:
        raise ValueError('Invalid input name (%s) for SignatureDef' %
                         input_name)
      self._interpreter_wrapper.ResizeInputTensor(
          self._inputs[input_name], np.array(value.shape, dtype=np.int32),
          False, self._subgraph_index)
    # Allocate tensors.
    self._interpreter_wrapper.AllocateTensors(self._subgraph_index)
    # Set the input values.
    for input_name, value in kwargs.items():
      self._interpreter_wrapper.SetTensor(self._inputs[input_name], value,
                                          self._subgraph_index)

    self._interpreter_wrapper.Invoke(self._subgraph_index)
    result = {}
    for output_name, output_index in self._outputs:
      result[output_name] = self._interpreter_wrapper.GetTensor(
          output_index, self._subgraph_index)
    return result

  def get_input_details(self):
    """Gets input tensor details.

    Returns:
      A dictionary from input name to tensor details, where each item is a
      dictionary with details about an input tensor. Each dictionary contains
      the following fields that describe the tensor:

      + `name`: The tensor name.
      + `index`: The tensor index in the interpreter.
      + `shape`: The shape of the tensor.
      + `shape_signature`: Same as `shape` for models with known/fixed shapes.
        If any dimension sizes are unknown, they are indicated with `-1`.
      + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
      + `quantization`: Deprecated, use `quantization_parameters`. This field
        only works for per-tensor quantization, whereas
        `quantization_parameters` works in all cases.
      + `quantization_parameters`: A dictionary of parameters used to quantize
        the tensor:
        ~ `scales`: List of scales (one if per-tensor quantization).
        ~ `zero_points`: List of zero_points (one if per-tensor quantization).
        ~ `quantized_dimension`: Specifies the dimension of per-axis
        quantization, in the case of multiple scales/zero_points.
      + `sparsity_parameters`: A dictionary of parameters used to encode a
        sparse tensor. This is empty if the tensor is dense.
    """
    result = {}
    for input_name, tensor_index in self._inputs.items():
      result[input_name] = self._interpreter._get_tensor_details(tensor_index)  # pylint: disable=protected-access
    return result

  def get_output_details(self):
    """Gets output tensor details.

    Returns:
      A dictionary from output name to tensor details, where each item is a
      dictionary with details about an output tensor. The dictionary contains
      the same fields as described for `get_input_details()`.
    """
    result = {}
    for output_name, tensor_index in self._outputs:
      result[output_name] = self._interpreter._get_tensor_details(tensor_index)  # pylint: disable=protected-access
    return result


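# Illustrative, non-API sketch: SignatureRunner also exposes tensor metadata
# through get_input_details()/get_output_details(). The signature name
# 'my_signature' is hypothetical.
def _example_inspect_signature_tensors(interpreter):  # Illustrative only.
  """Sketch: prints shape and dtype for a signature's inputs and outputs."""
  runner = interpreter.get_signature_runner('my_signature')
  for name, details in runner.get_input_details().items():
    print('input', name, details['shape'], details['dtype'])
  for name, details in runner.get_output_details().items():
    print('output', name, details['shape'], details['dtype'])

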
@_tf_export('lite.experimental.OpResolverType')
@enum.unique
class OpResolverType(enum.Enum):
  """Different types of op resolvers for TensorFlow Lite.

  * `AUTO`: Indicates the op resolver that is chosen by default in TfLite
     Python, which is the "BUILTIN" as described below.
  * `BUILTIN`: Indicates the op resolver for built-in ops with optimized kernel
    implementation.
  * `BUILTIN_REF`: Indicates the op resolver for built-in ops with reference
    kernel implementation. It's generally used for testing and debugging.
  * `BUILTIN_WITHOUT_DEFAULT_DELEGATES`: Indicates the op resolver for
    built-in ops with optimized kernel implementation, but it will disable
    the application of default TfLite delegates (like the XNNPACK delegate) to
    the model graph. Generally this should not be used unless there are issues
    with the default configuration.
  """
  # Corresponds to an op resolver chosen by default in TfLite Python.
  AUTO = 0

  # Corresponds to tflite::ops::builtin::BuiltinOpResolver in C++.
  BUILTIN = 1

  # Corresponds to tflite::ops::builtin::BuiltinRefOpResolver in C++.
  BUILTIN_REF = 2

  # Corresponds to
  # tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates in C++.
  BUILTIN_WITHOUT_DEFAULT_DELEGATES = 3


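# Illustrative, non-API sketch: the op resolver type is chosen at Interpreter
# construction time. BUILTIN_REF runs reference kernels, which is useful when
# debugging numerical differences. The model path is a placeholder.
def _example_reference_kernel_interpreter():  # Illustrative only.
  """Sketch: builds an Interpreter that runs reference kernels."""
  return Interpreter(
      model_path='model.tflite',
      experimental_op_resolver_type=OpResolverType.BUILTIN_REF)

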
def _get_op_resolver_id(op_resolver_type=OpResolverType.AUTO):
  """Gets an integer identifier for the op resolver."""

  # Note: the integer identifier value needs to match the op resolver ids
  # defined in interpreter_wrapper/interpreter_wrapper.cc.
  return {
      # Note AUTO and BUILTIN currently share the same identifier.
      OpResolverType.AUTO: 1,
      OpResolverType.BUILTIN: 1,
      OpResolverType.BUILTIN_REF: 2,
      OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES: 3
  }.get(op_resolver_type, None)


@_tf_export('lite.Interpreter')
class Interpreter:
  """Interpreter interface for running TensorFlow Lite models.

  Models obtained from `TfLiteConverter` can be run in Python with
  `Interpreter`.

  As an example, let's generate a simple Keras model and convert it to TFLite
  (`TfLiteConverter` also supports other input formats with `from_saved_model`
  and `from_concrete_functions`):

  >>> x = np.array([[1.], [2.]])
  >>> y = np.array([[2.], [4.]])
  >>> model = tf.keras.models.Sequential([
  ...           tf.keras.layers.Dropout(0.2),
  ...           tf.keras.layers.Dense(units=1, input_shape=[1])
  ...         ])
  >>> model.compile(optimizer='sgd', loss='mean_squared_error')
  >>> model.fit(x, y, epochs=1)
  >>> converter = tf.lite.TFLiteConverter.from_keras_model(model)
  >>> tflite_model = converter.convert()

  `tflite_model` can be saved to a file and loaded later, or loaded directly
  into the `Interpreter`. Since TensorFlow Lite pre-plans tensor allocations to
  optimize inference, the user needs to call `allocate_tensors()` before any
  inference.

  >>> interpreter = tf.lite.Interpreter(model_content=tflite_model)
  >>> interpreter.allocate_tensors()  # Needed before execution!

  Sample execution:

  >>> output = interpreter.get_output_details()[0]  # Model has single output.
  >>> input = interpreter.get_input_details()[0]  # Model has single input.
  >>> input_data = tf.constant(1., shape=[1, 1])
  >>> interpreter.set_tensor(input['index'], input_data)
  >>> interpreter.invoke()
  >>> interpreter.get_tensor(output['index']).shape
  (1, 1)

  Use `get_signature_runner()` for a more user-friendly inference API.
  """

  def __init__(self,
               model_path=None,
               model_content=None,
               experimental_delegates=None,
               num_threads=None,
               experimental_op_resolver_type=OpResolverType.AUTO,
               experimental_preserve_all_tensors=False):
    """Constructor.

    Args:
      model_path: Path to TF-Lite Flatbuffer file.
      model_content: Content of model.
      experimental_delegates: Experimental. Subject to change. List of
        [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
          objects returned by lite.load_delegate().
      num_threads: Sets the number of threads used by the interpreter and
        available to CPU kernels. If not set, the interpreter will use an
        implementation-dependent default number of threads. Currently, only a
        subset of kernels, such as conv, support multi-threading. num_threads
        should be >= -1. Setting num_threads to 0 disables multithreading,
        which is equivalent to setting num_threads to 1. If set to the value
        -1, the number of threads used will be implementation-defined and
        platform-dependent.
      experimental_op_resolver_type: The op resolver used by the interpreter.
        It must be an instance of OpResolverType. By default, we use the
        built-in op resolver which corresponds to
        tflite::ops::builtin::BuiltinOpResolver in C++.
      experimental_preserve_all_tensors: If true, then intermediate tensors
        used during computation are preserved for inspection, and if the passed
        op resolver type is AUTO or BUILTIN, the type will be changed to
        BUILTIN_WITHOUT_DEFAULT_DELEGATES so that no TensorFlow Lite default
        delegates are applied. If false, getting intermediate tensors could
        result in undefined values or None, especially when the graph is
        successfully modified by the TensorFlow Lite default delegate.

    Raises:
      ValueError: If the interpreter was unable to be created.
    """
    if not hasattr(self, '_custom_op_registerers'):
      self._custom_op_registerers = []

    actual_resolver_type = experimental_op_resolver_type
    if experimental_preserve_all_tensors and (
        experimental_op_resolver_type == OpResolverType.AUTO or
        experimental_op_resolver_type == OpResolverType.BUILTIN):
      actual_resolver_type = OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
    op_resolver_id = _get_op_resolver_id(actual_resolver_type)
    if op_resolver_id is None:
      raise ValueError('Unrecognized op resolver type: {}'.format(
          experimental_op_resolver_type))

    if model_path and not model_content:
      custom_op_registerers_by_name = [
          x for x in self._custom_op_registerers if isinstance(x, str)
      ]
      custom_op_registerers_by_func = [
          x for x in self._custom_op_registerers if not isinstance(x, str)
      ]
      self._interpreter = (
          _interpreter_wrapper.CreateWrapperFromFile(
              model_path, op_resolver_id, custom_op_registerers_by_name,
              custom_op_registerers_by_func, experimental_preserve_all_tensors))
      if not self._interpreter:
        raise ValueError('Failed to open {}'.format(model_path))
    elif model_content and not model_path:
      custom_op_registerers_by_name = [
          x for x in self._custom_op_registerers if isinstance(x, str)
      ]
      custom_op_registerers_by_func = [
          x for x in self._custom_op_registerers if not isinstance(x, str)
      ]
      # Take a reference, so the pointer remains valid.
      # Since Python strings are immutable, PyString_XX functions
      # will always return the same pointer.
      self._model_content = model_content
      self._interpreter = (
          _interpreter_wrapper.CreateWrapperFromBuffer(
              model_content, op_resolver_id, custom_op_registerers_by_name,
              custom_op_registerers_by_func, experimental_preserve_all_tensors))
    elif not model_content and not model_path:
      raise ValueError('`model_path` or `model_content` must be specified.')
    else:
      raise ValueError('Can\'t both provide `model_path` and `model_content`')

    if num_threads is not None:
      if not isinstance(num_threads, int):
        raise ValueError('type of num_threads should be int')
      if num_threads < 1:
        raise ValueError('num_threads should be >= 1')
      self._interpreter.SetNumThreads(num_threads)

    # Each delegate is a wrapper that owns the delegates that have been loaded
    # as plugins. The interpreter wrapper will be using them, but we need to
    # hold them in a list so that the lifetime is preserved at least as long as
    # the interpreter wrapper.
    self._delegates = []
    if experimental_delegates:
      self._delegates = experimental_delegates
      for delegate in self._delegates:
        self._interpreter.ModifyGraphWithDelegate(
            delegate._get_native_delegate_pointer())  # pylint: disable=protected-access
    self._signature_defs = self.get_signature_list()

    self._metrics = metrics.TFLiteMetrics()
    self._metrics.increase_counter_interpreter_creation()

  def __del__(self):
    # Must make sure the interpreter is destroyed before things that are used
    # by it, like the delegates. NOTE: this probably only works on CPython.
    # TODO(b/136468453): Remove need for __del__ ordering needs of CPython
    # by using explicit closes(). See implementation of Interpreter __del__.
    self._interpreter = None
    self._delegates = None

  def allocate_tensors(self):
    self._ensure_safe()
    return self._interpreter.AllocateTensors()

  def _safe_to_run(self):
    """Returns true if there exist no numpy array buffers.

    This means it is safe to run tflite calls that may destroy internally
    allocated memory. This works because, in the wrapper.cc, we have made
    the numpy base be self._interpreter.
    """
    # NOTE, our tensor() call in cpp will use _interpreter as a base pointer.
    # If this is the only reference to _interpreter, then the ref count should
    # be 2 (1 in self and 1 in the temporary of sys.getrefcount).
    return sys.getrefcount(self._interpreter) == 2

  def _ensure_safe(self):
    """Makes sure no numpy arrays pointing to internal buffers are active.

    This should be called from any function that will call a function on
    _interpreter that may reallocate memory, e.g. invoke().

    Raises:
      RuntimeError: If there exist numpy objects pointing to internal memory
        then we throw.
    """
    if not self._safe_to_run():
      raise RuntimeError("""There is at least 1 reference to internal data
      in the interpreter in the form of a numpy array or slice. Be sure to
      only hold the function returned from tensor() if you are using raw
      data access.""")

  # Experimental and subject to change
  def _get_op_details(self, op_index):
    """Gets a dictionary with arrays of ids for tensors involved with an op.

    Args:
      op_index: Operation/node index of node to query.

    Returns:
      A dictionary containing the index, op name, and arrays of the tensor
      indices for the inputs and outputs of the op/node.
    """
    op_index = int(op_index)
    op_name = self._interpreter.NodeName(op_index)
    op_inputs = self._interpreter.NodeInputs(op_index)
    op_outputs = self._interpreter.NodeOutputs(op_index)

    details = {
        'index': op_index,
        'op_name': op_name,
        'inputs': op_inputs,
        'outputs': op_outputs,
    }

    return details

  def _get_tensor_details(self, tensor_index):
    """Gets tensor details.

    Args:
      tensor_index: Tensor index of tensor to query.

    Returns:
      A dictionary containing the following fields of the tensor:
        'name': The tensor name.
        'index': The tensor index in the interpreter.
        'shape': The shape of the tensor.
        'shape_signature': Same as 'shape' for models with known/fixed shapes.
            If any dimension sizes are unknown, they are indicated with -1.
        'dtype': The numpy data type (such as np.int32 or np.uint8).
        'quantization': Deprecated, use 'quantization_parameters'. This field
            only works for per-tensor quantization, whereas
            'quantization_parameters' works in all cases.
        'quantization_parameters': The parameters used to quantize the tensor:
          'scales': List of scales (one if per-tensor quantization)
          'zero_points': List of zero_points (one if per-tensor quantization)
          'quantized_dimension': Specifies the dimension of per-axis
              quantization, in the case of multiple scales/zero_points.
        'sparsity_parameters': A dictionary of parameters used to encode a
            sparse tensor. This is empty if the tensor is dense.

    Raises:
      ValueError: If tensor_index is invalid.
    """
    tensor_index = int(tensor_index)
    tensor_name = self._interpreter.TensorName(tensor_index)
    tensor_size = self._interpreter.TensorSize(tensor_index)
    tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index)
    tensor_type = self._interpreter.TensorType(tensor_index)
    tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
    tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
        tensor_index)
    tensor_sparsity_params = self._interpreter.TensorSparsityParameters(
        tensor_index)

    if not tensor_type:
      raise ValueError('Could not get tensor details')

    details = {
        'name': tensor_name,
        'index': tensor_index,
        'shape': tensor_size,
        'shape_signature': tensor_size_signature,
        'dtype': tensor_type,
        'quantization': tensor_quantization,
        'quantization_parameters': {
            'scales': tensor_quantization_params[0],
            'zero_points': tensor_quantization_params[1],
            'quantized_dimension': tensor_quantization_params[2],
        },
        'sparsity_parameters': tensor_sparsity_params
    }

    return details

  # Experimental and subject to change
  def _get_ops_details(self):
    """Gets op details for every node.

    Returns:
      A list of dictionaries, one per op/node, each containing arrays of the
      indices of the tensors involved in that op.
    """
    return [
        self._get_op_details(idx) for idx in range(self._interpreter.NumNodes())
    ]

  def get_tensor_details(self):
    """Gets tensor details for every tensor with valid tensor details.

    Tensors where required information about the tensor is not found are not
    added to the list. This includes temporary tensors without a name.

    Returns:
      A list of dictionaries containing tensor information.
    """
    tensor_details = []
    for idx in range(self._interpreter.NumTensors()):
      try:
        tensor_details.append(self._get_tensor_details(idx))
      except ValueError:
        pass
    return tensor_details

  def get_input_details(self):
    """Gets model input tensor details.

    Returns:
      A list in which each item is a dictionary with details about
      an input tensor. Each dictionary contains the following fields
      that describe the tensor:

      + `name`: The tensor name.
      + `index`: The tensor index in the interpreter.
      + `shape`: The shape of the tensor.
      + `shape_signature`: Same as `shape` for models with known/fixed shapes.
        If any dimension sizes are unknown, they are indicated with `-1`.
      + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
      + `quantization`: Deprecated, use `quantization_parameters`. This field
        only works for per-tensor quantization, whereas
        `quantization_parameters` works in all cases.
      + `quantization_parameters`: A dictionary of parameters used to quantize
        the tensor:
        ~ `scales`: List of scales (one if per-tensor quantization).
        ~ `zero_points`: List of zero_points (one if per-tensor quantization).
        ~ `quantized_dimension`: Specifies the dimension of per-axis
        quantization, in the case of multiple scales/zero_points.
      + `sparsity_parameters`: A dictionary of parameters used to encode a
        sparse tensor. This is empty if the tensor is dense.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.InputIndices()
    ]

  def set_tensor(self, tensor_index, value):
    """Sets the value of the input tensor.

    Note this copies data in `value`.

    If you want to avoid copying, you can use the `tensor()` function to get a
    numpy buffer pointing to the input buffer in the tflite interpreter.

    Args:
      tensor_index: Tensor index of tensor to set. This value can be obtained
        from the 'index' field in get_input_details.
      value: Value of tensor to set.

    Raises:
      ValueError: If the interpreter could not set the tensor.
    """
    self._interpreter.SetTensor(tensor_index, value)

  def resize_tensor_input(self, input_index, tensor_size, strict=False):
    """Resizes an input tensor.

    Args:
      input_index: Tensor index of input to set. This value can be obtained
        from the 'index' field in get_input_details.
      tensor_size: The tensor_shape to resize the input to.
      strict: Only unknown dimensions can be resized when `strict` is True.
        Unknown dimensions are indicated as `-1` in the `shape_signature`
        attribute of a given tensor. (default False)

    Raises:
      ValueError: If the interpreter could not resize the input tensor.

    Usage:
    ```
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.resize_tensor_input(0, [num_test_images, 224, 224, 3])
    interpreter.allocate_tensors()
    interpreter.set_tensor(0, test_images)
    interpreter.invoke()
    ```
    """
    self._ensure_safe()
    # `ResizeInputTensor` only accepts an int32 numpy array as the
    # `tensor_size` parameter.
    tensor_size = np.array(tensor_size, dtype=np.int32)
    self._interpreter.ResizeInputTensor(input_index, tensor_size, strict)

  def get_output_details(self):
    """Gets model output tensor details.

    Returns:
      A list in which each item is a dictionary with details about
      an output tensor. The dictionary contains the same fields as
      described for `get_input_details()`.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.OutputIndices()
    ]

  def get_signature_list(self):
    """Gets the list of SignatureDefs in the model.

    Example:
    ```
    signatures = interpreter.get_signature_list()
    print(signatures)

    # {
    #   'add': {'inputs': ['x', 'y'], 'outputs': ['output_0']}
    # }
    ```

    Then, using the names in the signature list, you can get a callable from
    get_signature_runner().

    Returns:
      A dictionary of SignatureDef details, keyed on the SignatureDef method
      name. Each value holds a dictionary of inputs and outputs.
    """
    full_signature_defs = self._interpreter.GetSignatureDefs()
    for _, signature_def in full_signature_defs.items():
      signature_def['inputs'] = list(signature_def['inputs'].keys())
      signature_def['outputs'] = list(signature_def['outputs'].keys())
    return full_signature_defs

  def _get_full_signature_list(self):
    """Gets the list of SignatureDefs in the model.

    Example:
    ```
    signatures = interpreter._get_full_signature_list()
    print(signatures)

    # {
    #   'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}}
    # }
    ```

    Then, using the names in the signature list, you can get a callable from
    get_signature_runner().

    Returns:
      A dictionary of SignatureDef details, keyed on the SignatureDef method
      name. Each value holds a dictionary of inputs and outputs.
    """
    return self._interpreter.GetSignatureDefs()

  def get_signature_runner(self, signature_key=None):
    """Gets a callable for inference on a specific SignatureDef.

    Example usage:
    ```
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    fn = interpreter.get_signature_runner('div_with_remainder')
    output = fn(x=np.array([3]), y=np.array([2]))
    print(output)
    # {
    #   'quotient': array([1.], dtype=float32)
    #   'remainder': array([1.], dtype=float32)
    # }
    ```

    None can be passed for signature_key only if the model has a single
    SignatureDef.

    All names used are the names defined in this specific SignatureDef.

    Args:
      signature_key: Signature key for the SignatureDef; it can be None if and
        only if the model has a single SignatureDef. Default value is None.

    Returns:
      A callable that can run inference for the SignatureDef defined by the
      'signature_key' argument. The callable takes keyword arguments
      corresponding to the arguments of the SignatureDef, whose values should
      be numpy arrays. The callable returns a dictionary that maps from output
      names to numpy values of the computed results.

    Raises:
      ValueError: If the passed signature_key is invalid.
    """
    if signature_key is None:
      if len(self._signature_defs) != 1:
        raise ValueError(
            'SignatureDef signature_key is None and model has {0} Signatures. '
            'None is only allowed when the model has 1 SignatureDef'.format(
                len(self._signature_defs)))
      else:
        signature_key = next(iter(self._signature_defs))
    return SignatureRunner(interpreter=self, signature_key=signature_key)

  def get_tensor(self, tensor_index, subgraph_index=0):
    """Gets the value of the output tensor (returns a copy).

    If you wish to avoid the copy, use `tensor()`. This function cannot be used
    to read intermediate results.

    Args:
      tensor_index: Tensor index of tensor to get. This value can be obtained
        from the 'index' field in get_output_details.
      subgraph_index: Index of the subgraph to fetch the tensor from. Default
        value is 0, which means to fetch from the primary subgraph.

    Returns:
      A numpy array.
    """
    return self._interpreter.GetTensor(tensor_index, subgraph_index)

  def tensor(self, tensor_index):
    """Returns a function that gives a numpy view of the current tensor buffer.

    This allows reading and writing to the tensor without copies. This more
    closely mirrors the C++ Interpreter class interface's tensor() member,
    hence the name. Be careful not to hold these output references through
    calls to `allocate_tensors()` and `invoke()`. This function cannot be used
    to read intermediate results.

    Usage:

    ```
    interpreter.allocate_tensors()
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
    for i in range(10):
      input().fill(3.)
      interpreter.invoke()
      print("inference %s" % output())
    ```

    Notice how this function avoids making a numpy array directly. This is
    because it is important to not hold actual numpy views to the data longer
    than necessary. If you do, then the interpreter can no longer be invoked,
    because it is possible the interpreter would resize and invalidate the
    referenced tensors. The NumPy API doesn't allow any mutability of the
    underlying buffers.

    WRONG:

    ```
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
    interpreter.allocate_tensors()  # This will throw RuntimeError
    for i in range(10):
      input.fill(3.)
      interpreter.invoke()  # This will throw RuntimeError since input and
                            # output still reference internal memory.
    ```

    Args:
      tensor_index: Tensor index of tensor to get. This value can be obtained
        from the 'index' field in get_output_details.

    Returns:
      A function that can return a new numpy array pointing to the internal
      TFLite tensor state at any point. It is safe to hold the function
      forever, but it is not safe to hold the numpy array forever.
    """
    return lambda: self._interpreter.tensor(self._interpreter, tensor_index)

  def invoke(self):
    """Invokes the interpreter.

    Be sure to set the input sizes, allocate tensors and fill values before
    calling this. Also, note that this function releases the GIL, so heavy
    computation can be done in the background while the Python interpreter
    continues. No other function on this object should be called while the
    invoke() call has not finished.

    Raises:
      ValueError: If the underlying interpreter fails.
    """
    self._ensure_safe()
    self._interpreter.Invoke()

  def reset_all_variables(self):
    return self._interpreter.ResetVariableTensors()

  # Experimental and subject to change.
  def _native_handle(self):
    """Returns a pointer to the underlying tflite::Interpreter instance.

    This allows extending tflite.Interpreter's functionality in a custom C++
    function. Consider how that may work in a custom pybind wrapper:

      m.def("SomeNewFeature", ([](py::object handle) {
        auto* interpreter =
          reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>());
        ...
      }))

    and corresponding Python call:

      SomeNewFeature(interpreter.native_handle())

    Note: This approach is fragile. Users must guarantee the C++ extension build
    is consistent with the tflite.Interpreter's underlying C++ build.
    """
    return self._interpreter.interpreter()


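# Illustrative, non-API sketch: intermediate tensors can only be read reliably
# when the Interpreter is constructed with
# experimental_preserve_all_tensors=True (see __init__ above). Input values
# are omitted for brevity; set them with set_tensor() before invoke().
def _example_dump_intermediate_tensors(tflite_model):  # Illustrative only.
  """Sketch: prints every op's output tensors after one run."""
  interpreter = Interpreter(
      model_content=tflite_model, experimental_preserve_all_tensors=True)
  interpreter.allocate_tensors()
  interpreter.invoke()
  for op in interpreter._get_ops_details():  # pylint: disable=protected-access
    for tensor_index in op['outputs']:
      print(op['op_name'], interpreter.get_tensor(int(tensor_index)))

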
class InterpreterWithCustomOps(Interpreter):
  """Interpreter interface for TensorFlow Lite models that accept custom ops.

  The interface provided by this class is experimental and therefore not
  exposed as part of the public API.

  Wraps the tf.lite.Interpreter class and adds the ability to load custom ops
  by providing the names of functions that take a pointer to a
  BuiltinOpResolver and add a custom op.
  """

  def __init__(self, custom_op_registerers=None, **kwargs):
    """Constructor.

    Args:
      custom_op_registerers: List of str (symbol names) or functions that take
        a pointer to a MutableOpResolver and register a custom op. When passing
        functions, use a pybind function that takes a uintptr_t that can be
        recast as a pointer to a MutableOpResolver.
      **kwargs: Additional arguments passed to Interpreter.

    Raises:
      ValueError: If the interpreter was unable to be created.
    """
    self._custom_op_registerers = custom_op_registerers or []
    super(InterpreterWithCustomOps, self).__init__(**kwargs)