1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python TF-Lite interpreter.""" 16import ctypes 17import enum 18import os 19import platform 20import sys 21 22import numpy as np 23 24# pylint: disable=g-import-not-at-top 25if not os.path.splitext(__file__)[0].endswith( 26 os.path.join('tflite_runtime', 'interpreter')): 27 # This file is part of tensorflow package. 28 from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper 29 from tensorflow.lite.python.metrics import metrics 30 from tensorflow.python.util.tf_export import tf_export as _tf_export 31else: 32 # This file is part of tflite_runtime package. 33 from tflite_runtime import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper 34 from tflite_runtime import metrics_portable as metrics 35 36 def _tf_export(*x, **kwargs): 37 del x, kwargs 38 return lambda x: x 39 40 41# pylint: enable=g-import-not-at-top 42 43 44class Delegate: 45 """Python wrapper class to manage TfLiteDelegate objects. 46 47 The shared library is expected to have two functions: 48 TfLiteDelegate* tflite_plugin_create_delegate( 49 char**, char**, size_t, void (*report_error)(const char *)) 50 void tflite_plugin_destroy_delegate(TfLiteDelegate*) 51 52 The first one creates a delegate object. It may return NULL to indicate an 53 error (with a suitable error message reported by calling report_error()). 54 The second one destroys delegate object and must be called for every 55 created delegate object. Passing NULL as argument value is allowed, i.e. 56 57 tflite_plugin_destroy_delegate(tflite_plugin_create_delegate(...)) 58 59 always works. 60 """ 61 62 def __init__(self, library, options=None): 63 """Loads delegate from the shared library. 64 65 Args: 66 library: Shared library name. 67 options: Dictionary of options that are required to load the delegate. All 68 keys and values in the dictionary should be serializable. Consult the 69 documentation of the specific delegate for required and legal options. 70 (default None) 71 72 Raises: 73 RuntimeError: This is raised if the Python implementation is not CPython. 74 """ 75 76 # TODO(b/136468453): Remove need for __del__ ordering needs of CPython 77 # by using explicit closes(). See implementation of Interpreter __del__. 78 if platform.python_implementation() != 'CPython': 79 raise RuntimeError('Delegates are currently only supported into CPython' 80 'due to missing immediate reference counting.') 81 82 self._library = ctypes.pydll.LoadLibrary(library) 83 self._library.tflite_plugin_create_delegate.argtypes = [ 84 ctypes.POINTER(ctypes.c_char_p), 85 ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, 86 ctypes.CFUNCTYPE(None, ctypes.c_char_p) 87 ] 88 self._library.tflite_plugin_create_delegate.restype = ctypes.c_void_p 89 90 # Convert the options from a dictionary to lists of char pointers. 91 options = options or {} 92 options_keys = (ctypes.c_char_p * len(options))() 93 options_values = (ctypes.c_char_p * len(options))() 94 for idx, (key, value) in enumerate(options.items()): 95 options_keys[idx] = str(key).encode('utf-8') 96 options_values[idx] = str(value).encode('utf-8') 97 98 class ErrorMessageCapture: 99 100 def __init__(self): 101 self.message = '' 102 103 def report(self, x): 104 self.message += x if isinstance(x, str) else x.decode('utf-8') 105 106 capture = ErrorMessageCapture() 107 error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report) 108 # Do not make a copy of _delegate_ptr. It is freed by Delegate's finalizer. 109 self._delegate_ptr = self._library.tflite_plugin_create_delegate( 110 options_keys, options_values, len(options), error_capturer_cb) 111 if self._delegate_ptr is None: 112 raise ValueError(capture.message) 113 114 def __del__(self): 115 # __del__ can not be called multiple times, so if the delegate is destroyed. 116 # don't try to destroy it twice. 117 if self._library is not None: 118 self._library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p] 119 self._library.tflite_plugin_destroy_delegate(self._delegate_ptr) 120 self._library = None 121 122 def _get_native_delegate_pointer(self): 123 """Returns the native TfLiteDelegate pointer. 124 125 It is not safe to copy this pointer because it needs to be freed. 126 127 Returns: 128 TfLiteDelegate * 129 """ 130 return self._delegate_ptr 131 132 133@_tf_export('lite.experimental.load_delegate') 134def load_delegate(library, options=None): 135 """Returns loaded Delegate object. 136 137 Example usage: 138 139 ``` 140 import tensorflow as tf 141 142 try: 143 delegate = tf.lite.experimental.load_delegate('delegate.so') 144 except ValueError: 145 // Fallback to CPU 146 147 if delegate: 148 interpreter = tf.lite.Interpreter( 149 model_path='model.tflite', 150 experimental_delegates=[delegate]) 151 else: 152 interpreter = tf.lite.Interpreter(model_path='model.tflite') 153 ``` 154 155 This is typically used to leverage EdgeTPU for running TensorFlow Lite models. 156 For more information see: https://coral.ai/docs/edgetpu/tflite-python/ 157 158 Args: 159 library: Name of shared library containing the 160 [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates). 161 options: Dictionary of options that are required to load the delegate. All 162 keys and values in the dictionary should be convertible to str. Consult 163 the documentation of the specific delegate for required and legal options. 164 (default None) 165 166 Returns: 167 Delegate object. 168 169 Raises: 170 ValueError: Delegate failed to load. 171 RuntimeError: If delegate loading is used on unsupported platform. 172 """ 173 try: 174 delegate = Delegate(library, options) 175 except ValueError as e: 176 raise ValueError('Failed to load delegate from {}\n{}'.format( 177 library, str(e))) 178 return delegate 179 180 181class SignatureRunner: 182 """SignatureRunner class for running TFLite models using SignatureDef. 183 184 This class should be instantiated through TFLite Interpreter only using 185 get_signature_runner method on Interpreter. 186 Example, 187 signature = interpreter.get_signature_runner("my_signature") 188 result = signature(input_1=my_input_1, input_2=my_input_2) 189 print(result["my_output"]) 190 print(result["my_second_output"]) 191 All names used are this specific SignatureDef names. 192 193 Notes: 194 No other function on this object or on the interpreter provided should be 195 called while this object call has not finished. 196 """ 197 198 def __init__(self, interpreter=None, signature_key=None): 199 """Constructor. 200 201 Args: 202 interpreter: Interpreter object that is already initialized with the 203 requested model. 204 signature_key: SignatureDef key to be used. 205 """ 206 if not interpreter: 207 raise ValueError('None interpreter provided.') 208 if not signature_key: 209 raise ValueError('None signature_key provided.') 210 self._interpreter = interpreter 211 self._interpreter_wrapper = interpreter._interpreter 212 self._signature_key = signature_key 213 signature_defs = interpreter._get_full_signature_list() 214 if signature_key not in signature_defs: 215 raise ValueError('Invalid signature_key provided.') 216 self._signature_def = signature_defs[signature_key] 217 self._outputs = self._signature_def['outputs'].items() 218 self._inputs = self._signature_def['inputs'] 219 220 self._subgraph_index = ( 221 self._interpreter_wrapper.GetSubgraphIndexFromSignature( 222 self._signature_key)) 223 224 def __call__(self, **kwargs): 225 """Runs the SignatureDef given the provided inputs in arguments. 226 227 Args: 228 **kwargs: key,value for inputs to the model. Key is the SignatureDef input 229 name. Value is numpy array with the value. 230 231 Returns: 232 dictionary of the results from the model invoke. 233 Key in the dictionary is SignatureDef output name. 234 Value is the result Tensor. 235 """ 236 237 if len(kwargs) != len(self._inputs): 238 raise ValueError( 239 'Invalid number of inputs provided for running a SignatureDef, ' 240 'expected %s vs provided %s' % (len(self._inputs), len(kwargs))) 241 242 # Resize input tensors 243 for input_name, value in kwargs.items(): 244 if input_name not in self._inputs: 245 raise ValueError('Invalid Input name (%s) for SignatureDef' % 246 input_name) 247 self._interpreter_wrapper.ResizeInputTensor( 248 self._inputs[input_name], np.array(value.shape, dtype=np.int32), 249 False, self._subgraph_index) 250 # Allocate tensors. 251 self._interpreter_wrapper.AllocateTensors(self._subgraph_index) 252 # Set the input values. 253 for input_name, value in kwargs.items(): 254 self._interpreter_wrapper.SetTensor(self._inputs[input_name], value, 255 self._subgraph_index) 256 257 self._interpreter_wrapper.Invoke(self._subgraph_index) 258 result = {} 259 for output_name, output_index in self._outputs: 260 result[output_name] = self._interpreter_wrapper.GetTensor( 261 output_index, self._subgraph_index) 262 return result 263 264 def get_input_details(self): 265 """Gets input tensor details. 266 267 Returns: 268 A dictionary from input name to tensor details where each item is a 269 dictionary with details about an input tensor. Each dictionary contains 270 the following fields that describe the tensor: 271 272 + `name`: The tensor name. 273 + `index`: The tensor index in the interpreter. 274 + `shape`: The shape of the tensor. 275 + `shape_signature`: Same as `shape` for models with known/fixed shapes. 276 If any dimension sizes are unknown, they are indicated with `-1`. 277 + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`). 278 + `quantization`: Deprecated, use `quantization_parameters`. This field 279 only works for per-tensor quantization, whereas 280 `quantization_parameters` works in all cases. 281 + `quantization_parameters`: A dictionary of parameters used to quantize 282 the tensor: 283 ~ `scales`: List of scales (one if per-tensor quantization). 284 ~ `zero_points`: List of zero_points (one if per-tensor quantization). 285 ~ `quantized_dimension`: Specifies the dimension of per-axis 286 quantization, in the case of multiple scales/zero_points. 287 + `sparsity_parameters`: A dictionary of parameters used to encode a 288 sparse tensor. This is empty if the tensor is dense. 289 """ 290 result = {} 291 for input_name, tensor_index in self._inputs.items(): 292 result[input_name] = self._interpreter._get_tensor_details(tensor_index) # pylint: disable=protected-access 293 return result 294 295 def get_output_details(self): 296 """Gets output tensor details. 297 298 Returns: 299 A dictionary from input name to tensor details where each item is a 300 dictionary with details about an output tensor. The dictionary contains 301 the same fields as described for `get_input_details()`. 302 """ 303 result = {} 304 for output_name, tensor_index in self._outputs: 305 result[output_name] = self._interpreter._get_tensor_details(tensor_index) # pylint: disable=protected-access 306 return result 307 308 309@_tf_export('lite.experimental.OpResolverType') 310@enum.unique 311class OpResolverType(enum.Enum): 312 """Different types of op resolvers for Tensorflow Lite. 313 314 * `AUTO`: Indicates the op resolver that is chosen by default in TfLite 315 Python, which is the "BUILTIN" as described below. 316 * `BUILTIN`: Indicates the op resolver for built-in ops with optimized kernel 317 implementation. 318 * `BUILTIN_REF`: Indicates the op resolver for built-in ops with reference 319 kernel implementation. It's generally used for testing and debugging. 320 * `BUILTIN_WITHOUT_DEFAULT_DELEGATES`: Indicates the op resolver for 321 built-in ops with optimized kernel implementation, but it will disable 322 the application of default TfLite delegates (like the XNNPACK delegate) to 323 the model graph. Generally this should not be used unless there are issues 324 with the default configuration. 325 """ 326 # Corresponds to an op resolver chosen by default in TfLite Python. 327 AUTO = 0 328 329 # Corresponds to tflite::ops::builtin::BuiltinOpResolver in C++. 330 BUILTIN = 1 331 332 # Corresponds to tflite::ops::builtin::BuiltinRefOpResolver in C++. 333 BUILTIN_REF = 2 334 335 # Corresponds to 336 # tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates in C++. 337 BUILTIN_WITHOUT_DEFAULT_DELEGATES = 3 338 339 340def _get_op_resolver_id(op_resolver_type=OpResolverType.AUTO): 341 """Get a integer identifier for the op resolver.""" 342 343 # Note: the integer identifier value needs to be same w/ op resolver ids 344 # defined in interpreter_wrapper/interpreter_wrapper.cc. 345 return { 346 # Note AUTO and BUILTIN currently share the same identifier. 347 OpResolverType.AUTO: 1, 348 OpResolverType.BUILTIN: 1, 349 OpResolverType.BUILTIN_REF: 2, 350 OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES: 3 351 }.get(op_resolver_type, None) 352 353 354@_tf_export('lite.Interpreter') 355class Interpreter: 356 """Interpreter interface for running TensorFlow Lite models. 357 358 Models obtained from `TfLiteConverter` can be run in Python with 359 `Interpreter`. 360 361 As an example, lets generate a simple Keras model and convert it to TFLite 362 (`TfLiteConverter` also supports other input formats with `from_saved_model` 363 and `from_concrete_function`) 364 365 >>> x = np.array([[1.], [2.]]) 366 >>> y = np.array([[2.], [4.]]) 367 >>> model = tf.keras.models.Sequential([ 368 ... tf.keras.layers.Dropout(0.2), 369 ... tf.keras.layers.Dense(units=1, input_shape=[1]) 370 ... ]) 371 >>> model.compile(optimizer='sgd', loss='mean_squared_error') 372 >>> model.fit(x, y, epochs=1) 373 >>> converter = tf.lite.TFLiteConverter.from_keras_model(model) 374 >>> tflite_model = converter.convert() 375 376 `tflite_model` can be saved to a file and loaded later, or directly into the 377 `Interpreter`. Since TensorFlow Lite pre-plans tensor allocations to optimize 378 inference, the user needs to call `allocate_tensors()` before any inference. 379 380 >>> interpreter = tf.lite.Interpreter(model_content=tflite_model) 381 >>> interpreter.allocate_tensors() # Needed before execution! 382 383 Sample execution: 384 385 >>> output = interpreter.get_output_details()[0] # Model has single output. 386 >>> input = interpreter.get_input_details()[0] # Model has single input. 387 >>> input_data = tf.constant(1., shape=[1, 1]) 388 >>> interpreter.set_tensor(input['index'], input_data) 389 >>> interpreter.invoke() 390 >>> interpreter.get_tensor(output['index']).shape 391 (1, 1) 392 393 Use `get_signature_runner()` for a more user-friendly inference API. 394 """ 395 396 def __init__(self, 397 model_path=None, 398 model_content=None, 399 experimental_delegates=None, 400 num_threads=None, 401 experimental_op_resolver_type=OpResolverType.AUTO, 402 experimental_preserve_all_tensors=False): 403 """Constructor. 404 405 Args: 406 model_path: Path to TF-Lite Flatbuffer file. 407 model_content: Content of model. 408 experimental_delegates: Experimental. Subject to change. List of 409 [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates) 410 objects returned by lite.load_delegate(). 411 num_threads: Sets the number of threads used by the interpreter and 412 available to CPU kernels. If not set, the interpreter will use an 413 implementation-dependent default number of threads. Currently, only a 414 subset of kernels, such as conv, support multi-threading. num_threads 415 should be >= -1. Setting num_threads to 0 has the effect to disable 416 multithreading, which is equivalent to setting num_threads to 1. If set 417 to the value -1, the number of threads used will be 418 implementation-defined and platform-dependent. 419 experimental_op_resolver_type: The op resolver used by the interpreter. It 420 must be an instance of OpResolverType. By default, we use the built-in 421 op resolver which corresponds to tflite::ops::builtin::BuiltinOpResolver 422 in C++. 423 experimental_preserve_all_tensors: If true, then intermediate tensors used 424 during computation are preserved for inspection, and if the passed op 425 resolver type is AUTO or BUILTIN, the type will be changed to 426 BUILTIN_WITHOUT_DEFAULT_DELEGATES so that no Tensorflow Lite default 427 delegates are applied. If false, getting intermediate tensors could 428 result in undefined values or None, especially when the graph is 429 successfully modified by the Tensorflow Lite default delegate. 430 431 Raises: 432 ValueError: If the interpreter was unable to create. 433 """ 434 if not hasattr(self, '_custom_op_registerers'): 435 self._custom_op_registerers = [] 436 437 actual_resolver_type = experimental_op_resolver_type 438 if experimental_preserve_all_tensors and ( 439 experimental_op_resolver_type == OpResolverType.AUTO or 440 experimental_op_resolver_type == OpResolverType.BUILTIN): 441 actual_resolver_type = OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES 442 op_resolver_id = _get_op_resolver_id(actual_resolver_type) 443 if op_resolver_id is None: 444 raise ValueError('Unrecognized passed in op resolver type: {}'.format( 445 experimental_op_resolver_type)) 446 447 if model_path and not model_content: 448 custom_op_registerers_by_name = [ 449 x for x in self._custom_op_registerers if isinstance(x, str) 450 ] 451 custom_op_registerers_by_func = [ 452 x for x in self._custom_op_registerers if not isinstance(x, str) 453 ] 454 self._interpreter = ( 455 _interpreter_wrapper.CreateWrapperFromFile( 456 model_path, op_resolver_id, custom_op_registerers_by_name, 457 custom_op_registerers_by_func, experimental_preserve_all_tensors)) 458 if not self._interpreter: 459 raise ValueError('Failed to open {}'.format(model_path)) 460 elif model_content and not model_path: 461 custom_op_registerers_by_name = [ 462 x for x in self._custom_op_registerers if isinstance(x, str) 463 ] 464 custom_op_registerers_by_func = [ 465 x for x in self._custom_op_registerers if not isinstance(x, str) 466 ] 467 # Take a reference, so the pointer remains valid. 468 # Since python strings are immutable then PyString_XX functions 469 # will always return the same pointer. 470 self._model_content = model_content 471 self._interpreter = ( 472 _interpreter_wrapper.CreateWrapperFromBuffer( 473 model_content, op_resolver_id, custom_op_registerers_by_name, 474 custom_op_registerers_by_func, experimental_preserve_all_tensors)) 475 elif not model_content and not model_path: 476 raise ValueError('`model_path` or `model_content` must be specified.') 477 else: 478 raise ValueError('Can\'t both provide `model_path` and `model_content`') 479 480 if num_threads is not None: 481 if not isinstance(num_threads, int): 482 raise ValueError('type of num_threads should be int') 483 if num_threads < 1: 484 raise ValueError('num_threads should >= 1') 485 self._interpreter.SetNumThreads(num_threads) 486 487 # Each delegate is a wrapper that owns the delegates that have been loaded 488 # as plugins. The interpreter wrapper will be using them, but we need to 489 # hold them in a list so that the lifetime is preserved at least as long as 490 # the interpreter wrapper. 491 self._delegates = [] 492 if experimental_delegates: 493 self._delegates = experimental_delegates 494 for delegate in self._delegates: 495 self._interpreter.ModifyGraphWithDelegate( 496 delegate._get_native_delegate_pointer()) # pylint: disable=protected-access 497 self._signature_defs = self.get_signature_list() 498 499 self._metrics = metrics.TFLiteMetrics() 500 self._metrics.increase_counter_interpreter_creation() 501 502 def __del__(self): 503 # Must make sure the interpreter is destroyed before things that 504 # are used by it like the delegates. NOTE this only works on CPython 505 # probably. 506 # TODO(b/136468453): Remove need for __del__ ordering needs of CPython 507 # by using explicit closes(). See implementation of Interpreter __del__. 508 self._interpreter = None 509 self._delegates = None 510 511 def allocate_tensors(self): 512 self._ensure_safe() 513 return self._interpreter.AllocateTensors() 514 515 def _safe_to_run(self): 516 """Returns true if there exist no numpy array buffers. 517 518 This means it is safe to run tflite calls that may destroy internally 519 allocated memory. This works, because in the wrapper.cc we have made 520 the numpy base be the self._interpreter. 521 """ 522 # NOTE, our tensor() call in cpp will use _interpreter as a base pointer. 523 # If this environment is the only _interpreter, then the ref count should be 524 # 2 (1 in self and 1 in temporary of sys.getrefcount). 525 return sys.getrefcount(self._interpreter) == 2 526 527 def _ensure_safe(self): 528 """Makes sure no numpy arrays pointing to internal buffers are active. 529 530 This should be called from any function that will call a function on 531 _interpreter that may reallocate memory e.g. invoke(), ... 532 533 Raises: 534 RuntimeError: If there exist numpy objects pointing to internal memory 535 then we throw. 536 """ 537 if not self._safe_to_run(): 538 raise RuntimeError("""There is at least 1 reference to internal data 539 in the interpreter in the form of a numpy array or slice. Be sure to 540 only hold the function returned from tensor() if you are using raw 541 data access.""") 542 543 # Experimental and subject to change 544 def _get_op_details(self, op_index): 545 """Gets a dictionary with arrays of ids for tensors involved with an op. 546 547 Args: 548 op_index: Operation/node index of node to query. 549 550 Returns: 551 a dictionary containing the index, op name, and arrays with lists of the 552 indices for the inputs and outputs of the op/node. 553 """ 554 op_index = int(op_index) 555 op_name = self._interpreter.NodeName(op_index) 556 op_inputs = self._interpreter.NodeInputs(op_index) 557 op_outputs = self._interpreter.NodeOutputs(op_index) 558 559 details = { 560 'index': op_index, 561 'op_name': op_name, 562 'inputs': op_inputs, 563 'outputs': op_outputs, 564 } 565 566 return details 567 568 def _get_tensor_details(self, tensor_index): 569 """Gets tensor details. 570 571 Args: 572 tensor_index: Tensor index of tensor to query. 573 574 Returns: 575 A dictionary containing the following fields of the tensor: 576 'name': The tensor name. 577 'index': The tensor index in the interpreter. 578 'shape': The shape of the tensor. 579 'quantization': Deprecated, use 'quantization_parameters'. This field 580 only works for per-tensor quantization, whereas 581 'quantization_parameters' works in all cases. 582 'quantization_parameters': The parameters used to quantize the tensor: 583 'scales': List of scales (one if per-tensor quantization) 584 'zero_points': List of zero_points (one if per-tensor quantization) 585 'quantized_dimension': Specifies the dimension of per-axis 586 quantization, in the case of multiple scales/zero_points. 587 588 Raises: 589 ValueError: If tensor_index is invalid. 590 """ 591 tensor_index = int(tensor_index) 592 tensor_name = self._interpreter.TensorName(tensor_index) 593 tensor_size = self._interpreter.TensorSize(tensor_index) 594 tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index) 595 tensor_type = self._interpreter.TensorType(tensor_index) 596 tensor_quantization = self._interpreter.TensorQuantization(tensor_index) 597 tensor_quantization_params = self._interpreter.TensorQuantizationParameters( 598 tensor_index) 599 tensor_sparsity_params = self._interpreter.TensorSparsityParameters( 600 tensor_index) 601 602 if not tensor_type: 603 raise ValueError('Could not get tensor details') 604 605 details = { 606 'name': tensor_name, 607 'index': tensor_index, 608 'shape': tensor_size, 609 'shape_signature': tensor_size_signature, 610 'dtype': tensor_type, 611 'quantization': tensor_quantization, 612 'quantization_parameters': { 613 'scales': tensor_quantization_params[0], 614 'zero_points': tensor_quantization_params[1], 615 'quantized_dimension': tensor_quantization_params[2], 616 }, 617 'sparsity_parameters': tensor_sparsity_params 618 } 619 620 return details 621 622 # Experimental and subject to change 623 def _get_ops_details(self): 624 """Gets op details for every node. 625 626 Returns: 627 A list of dictionaries containing arrays with lists of tensor ids for 628 tensors involved in the op. 629 """ 630 return [ 631 self._get_op_details(idx) for idx in range(self._interpreter.NumNodes()) 632 ] 633 634 def get_tensor_details(self): 635 """Gets tensor details for every tensor with valid tensor details. 636 637 Tensors where required information about the tensor is not found are not 638 added to the list. This includes temporary tensors without a name. 639 640 Returns: 641 A list of dictionaries containing tensor information. 642 """ 643 tensor_details = [] 644 for idx in range(self._interpreter.NumTensors()): 645 try: 646 tensor_details.append(self._get_tensor_details(idx)) 647 except ValueError: 648 pass 649 return tensor_details 650 651 def get_input_details(self): 652 """Gets model input tensor details. 653 654 Returns: 655 A list in which each item is a dictionary with details about 656 an input tensor. Each dictionary contains the following fields 657 that describe the tensor: 658 659 + `name`: The tensor name. 660 + `index`: The tensor index in the interpreter. 661 + `shape`: The shape of the tensor. 662 + `shape_signature`: Same as `shape` for models with known/fixed shapes. 663 If any dimension sizes are unknown, they are indicated with `-1`. 664 + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`). 665 + `quantization`: Deprecated, use `quantization_parameters`. This field 666 only works for per-tensor quantization, whereas 667 `quantization_parameters` works in all cases. 668 + `quantization_parameters`: A dictionary of parameters used to quantize 669 the tensor: 670 ~ `scales`: List of scales (one if per-tensor quantization). 671 ~ `zero_points`: List of zero_points (one if per-tensor quantization). 672 ~ `quantized_dimension`: Specifies the dimension of per-axis 673 quantization, in the case of multiple scales/zero_points. 674 + `sparsity_parameters`: A dictionary of parameters used to encode a 675 sparse tensor. This is empty if the tensor is dense. 676 """ 677 return [ 678 self._get_tensor_details(i) for i in self._interpreter.InputIndices() 679 ] 680 681 def set_tensor(self, tensor_index, value): 682 """Sets the value of the input tensor. 683 684 Note this copies data in `value`. 685 686 If you want to avoid copying, you can use the `tensor()` function to get a 687 numpy buffer pointing to the input buffer in the tflite interpreter. 688 689 Args: 690 tensor_index: Tensor index of tensor to set. This value can be gotten from 691 the 'index' field in get_input_details. 692 value: Value of tensor to set. 693 694 Raises: 695 ValueError: If the interpreter could not set the tensor. 696 """ 697 self._interpreter.SetTensor(tensor_index, value) 698 699 def resize_tensor_input(self, input_index, tensor_size, strict=False): 700 """Resizes an input tensor. 701 702 Args: 703 input_index: Tensor index of input to set. This value can be gotten from 704 the 'index' field in get_input_details. 705 tensor_size: The tensor_shape to resize the input to. 706 strict: Only unknown dimensions can be resized when `strict` is True. 707 Unknown dimensions are indicated as `-1` in the `shape_signature` 708 attribute of a given tensor. (default False) 709 710 Raises: 711 ValueError: If the interpreter could not resize the input tensor. 712 713 Usage: 714 ``` 715 interpreter = Interpreter(model_content=tflite_model) 716 interpreter.resize_tensor_input(0, [num_test_images, 224, 224, 3]) 717 interpreter.allocate_tensors() 718 interpreter.set_tensor(0, test_images) 719 interpreter.invoke() 720 ``` 721 """ 722 self._ensure_safe() 723 # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size 724 # parameter. 725 tensor_size = np.array(tensor_size, dtype=np.int32) 726 self._interpreter.ResizeInputTensor(input_index, tensor_size, strict) 727 728 def get_output_details(self): 729 """Gets model output tensor details. 730 731 Returns: 732 A list in which each item is a dictionary with details about 733 an output tensor. The dictionary contains the same fields as 734 described for `get_input_details()`. 735 """ 736 return [ 737 self._get_tensor_details(i) for i in self._interpreter.OutputIndices() 738 ] 739 740 def get_signature_list(self): 741 """Gets list of SignatureDefs in the model. 742 743 Example, 744 ``` 745 signatures = interpreter.get_signature_list() 746 print(signatures) 747 748 # { 749 # 'add': {'inputs': ['x', 'y'], 'outputs': ['output_0']} 750 # } 751 752 Then using the names in the signature list you can get a callable from 753 get_signature_runner(). 754 ``` 755 756 Returns: 757 A list of SignatureDef details in a dictionary structure. 758 It is keyed on the SignatureDef method name, and the value holds 759 dictionary of inputs and outputs. 760 """ 761 full_signature_defs = self._interpreter.GetSignatureDefs() 762 for _, signature_def in full_signature_defs.items(): 763 signature_def['inputs'] = list(signature_def['inputs'].keys()) 764 signature_def['outputs'] = list(signature_def['outputs'].keys()) 765 return full_signature_defs 766 767 def _get_full_signature_list(self): 768 """Gets list of SignatureDefs in the model. 769 770 Example, 771 ``` 772 signatures = interpreter._get_full_signature_list() 773 print(signatures) 774 775 # { 776 # 'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}} 777 # } 778 779 Then using the names in the signature list you can get a callable from 780 get_signature_runner(). 781 ``` 782 783 Returns: 784 A list of SignatureDef details in a dictionary structure. 785 It is keyed on the SignatureDef method name, and the value holds 786 dictionary of inputs and outputs. 787 """ 788 return self._interpreter.GetSignatureDefs() 789 790 def get_signature_runner(self, signature_key=None): 791 """Gets callable for inference of specific SignatureDef. 792 793 Example usage, 794 ``` 795 interpreter = tf.lite.Interpreter(model_content=tflite_model) 796 interpreter.allocate_tensors() 797 fn = interpreter.get_signature_runner('div_with_remainder') 798 output = fn(x=np.array([3]), y=np.array([2])) 799 print(output) 800 # { 801 # 'quotient': array([1.], dtype=float32) 802 # 'remainder': array([1.], dtype=float32) 803 # } 804 ``` 805 806 None can be passed for signature_key if the model has a single Signature 807 only. 808 809 All names used are this specific SignatureDef names. 810 811 812 Args: 813 signature_key: Signature key for the SignatureDef, it can be None if and 814 only if the model has a single SignatureDef. Default value is None. 815 816 Returns: 817 This returns a callable that can run inference for SignatureDef defined 818 by argument 'signature_key'. 819 The callable will take key arguments corresponding to the arguments of the 820 SignatureDef, that should have numpy values. 821 The callable will returns dictionary that maps from output names to numpy 822 values of the computed results. 823 824 Raises: 825 ValueError: If passed signature_key is invalid. 826 """ 827 if signature_key is None: 828 if len(self._signature_defs) != 1: 829 raise ValueError( 830 'SignatureDef signature_key is None and model has {0} Signatures. ' 831 'None is only allowed when the model has 1 SignatureDef'.format( 832 len(self._signature_defs))) 833 else: 834 signature_key = next(iter(self._signature_defs)) 835 return SignatureRunner(interpreter=self, signature_key=signature_key) 836 837 def get_tensor(self, tensor_index, subgraph_index=0): 838 """Gets the value of the output tensor (get a copy). 839 840 If you wish to avoid the copy, use `tensor()`. This function cannot be used 841 to read intermediate results. 842 843 Args: 844 tensor_index: Tensor index of tensor to get. This value can be gotten from 845 the 'index' field in get_output_details. 846 subgraph_index: Index of the subgraph to fetch the tensor. Default value 847 is 0, which means to fetch from the primary subgraph. 848 849 Returns: 850 a numpy array. 851 """ 852 return self._interpreter.GetTensor(tensor_index, subgraph_index) 853 854 def tensor(self, tensor_index): 855 """Returns function that gives a numpy view of the current tensor buffer. 856 857 This allows reading and writing to this tensors w/o copies. This more 858 closely mirrors the C++ Interpreter class interface's tensor() member, hence 859 the name. Be careful to not hold these output references through calls 860 to `allocate_tensors()` and `invoke()`. This function cannot be used to read 861 intermediate results. 862 863 Usage: 864 865 ``` 866 interpreter.allocate_tensors() 867 input = interpreter.tensor(interpreter.get_input_details()[0]["index"]) 868 output = interpreter.tensor(interpreter.get_output_details()[0]["index"]) 869 for i in range(10): 870 input().fill(3.) 871 interpreter.invoke() 872 print("inference %s" % output()) 873 ``` 874 875 Notice how this function avoids making a numpy array directly. This is 876 because it is important to not hold actual numpy views to the data longer 877 than necessary. If you do, then the interpreter can no longer be invoked, 878 because it is possible the interpreter would resize and invalidate the 879 referenced tensors. The NumPy API doesn't allow any mutability of the 880 the underlying buffers. 881 882 WRONG: 883 884 ``` 885 input = interpreter.tensor(interpreter.get_input_details()[0]["index"])() 886 output = interpreter.tensor(interpreter.get_output_details()[0]["index"])() 887 interpreter.allocate_tensors() # This will throw RuntimeError 888 for i in range(10): 889 input.fill(3.) 890 interpreter.invoke() # this will throw RuntimeError since input,output 891 ``` 892 893 Args: 894 tensor_index: Tensor index of tensor to get. This value can be gotten from 895 the 'index' field in get_output_details. 896 897 Returns: 898 A function that can return a new numpy array pointing to the internal 899 TFLite tensor state at any point. It is safe to hold the function forever, 900 but it is not safe to hold the numpy array forever. 901 """ 902 return lambda: self._interpreter.tensor(self._interpreter, tensor_index) 903 904 def invoke(self): 905 """Invoke the interpreter. 906 907 Be sure to set the input sizes, allocate tensors and fill values before 908 calling this. Also, note that this function releases the GIL so heavy 909 computation can be done in the background while the Python interpreter 910 continues. No other function on this object should be called while the 911 invoke() call has not finished. 912 913 Raises: 914 ValueError: When the underlying interpreter fails raise ValueError. 915 """ 916 self._ensure_safe() 917 self._interpreter.Invoke() 918 919 def reset_all_variables(self): 920 return self._interpreter.ResetVariableTensors() 921 922 # Experimental and subject to change. 923 def _native_handle(self): 924 """Returns a pointer to the underlying tflite::Interpreter instance. 925 926 This allows extending tflite.Interpreter's functionality in a custom C++ 927 function. Consider how that may work in a custom pybind wrapper: 928 929 m.def("SomeNewFeature", ([](py::object handle) { 930 auto* interpreter = 931 reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>()); 932 ... 933 })) 934 935 and corresponding Python call: 936 937 SomeNewFeature(interpreter.native_handle()) 938 939 Note: This approach is fragile. Users must guarantee the C++ extension build 940 is consistent with the tflite.Interpreter's underlying C++ build. 941 """ 942 return self._interpreter.interpreter() 943 944 945class InterpreterWithCustomOps(Interpreter): 946 """Interpreter interface for TensorFlow Lite Models that accepts custom ops. 947 948 The interface provided by this class is experimental and therefore not exposed 949 as part of the public API. 950 951 Wraps the tf.lite.Interpreter class and adds the ability to load custom ops 952 by providing the names of functions that take a pointer to a BuiltinOpResolver 953 and add a custom op. 954 """ 955 956 def __init__(self, custom_op_registerers=None, **kwargs): 957 """Constructor. 958 959 Args: 960 custom_op_registerers: List of str (symbol names) or functions that take a 961 pointer to a MutableOpResolver and register a custom op. When passing 962 functions, use a pybind function that takes a uintptr_t that can be 963 recast as a pointer to a MutableOpResolver. 964 **kwargs: Additional arguments passed to Interpreter. 965 966 Raises: 967 ValueError: If the interpreter was unable to create. 968 """ 969 self._custom_op_registerers = custom_op_registerers or [] 970 super(InterpreterWithCustomOps, self).__init__(**kwargs) 971