# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python TF-Lite interpreter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ctypes
import enum
import os
import platform
import sys

import numpy as np

# pylint: disable=g-import-not-at-top
if not os.path.splitext(__file__)[0].endswith(
    os.path.join('tflite_runtime', 'interpreter')):
  # This file is part of tensorflow package.
  from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper
  from tensorflow.python.util.tf_export import tf_export as _tf_export
  try:
    from tensorflow.lite.python import metrics_portable as metrics
  except ImportError:
    from tensorflow.lite.python import metrics_nonportable as metrics
else:
  # This file is part of tflite_runtime package.
  from tflite_runtime import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper
  from tflite_runtime import metrics_portable as metrics

  def _tf_export(*x, **kwargs):
    del x, kwargs
    return lambda x: x


# pylint: enable=g-import-not-at-top


class Delegate(object):
  """Python wrapper class to manage TfLiteDelegate objects.

  The shared library is expected to have two functions:
    TfLiteDelegate* tflite_plugin_create_delegate(
        char**, char**, size_t, void (*report_error)(const char *))
    void tflite_plugin_destroy_delegate(TfLiteDelegate*)

  The first one creates a delegate object. It may return NULL to indicate an
  error (with a suitable error message reported by calling report_error()).
  The second one destroys the delegate object and must be called for every
  created delegate object. Passing NULL as the argument value is allowed, i.e.

    tflite_plugin_destroy_delegate(tflite_plugin_create_delegate(...))

  always works.
  """

  def __init__(self, library, options=None):
    """Loads delegate from the shared library.

    Args:
      library: Shared library name.
      options: Dictionary of options that are required to load the delegate.
        All keys and values in the dictionary should be serializable. Consult
        the documentation of the specific delegate for required and legal
        options. (default None)

    Raises:
      RuntimeError: This is raised if the Python implementation is not CPython.
    """

    # TODO(b/136468453): Remove need for __del__ ordering needs of CPython
    # by using explicit closes(). See implementation of Interpreter __del__.
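    # Other Python implementations (e.g. PyPy) lack CPython's immediate,
    # refcount-based finalization, which the delegate teardown ordering
    # below relies on.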
    if platform.python_implementation() != 'CPython':
      raise RuntimeError('Delegates are currently only supported on CPython '
                         'due to missing immediate reference counting.')

    self._library = ctypes.pydll.LoadLibrary(library)
    self._library.tflite_plugin_create_delegate.argtypes = [
        ctypes.POINTER(ctypes.c_char_p),
        ctypes.POINTER(ctypes.c_char_p), ctypes.c_int,
        ctypes.CFUNCTYPE(None, ctypes.c_char_p)
    ]
    self._library.tflite_plugin_create_delegate.restype = ctypes.c_void_p

    # Convert the options from a dictionary to lists of char pointers.
    options = options or {}
    options_keys = (ctypes.c_char_p * len(options))()
    options_values = (ctypes.c_char_p * len(options))()
    for idx, (key, value) in enumerate(options.items()):
      options_keys[idx] = str(key).encode('utf-8')
      options_values[idx] = str(value).encode('utf-8')

    class ErrorMessageCapture(object):

      def __init__(self):
        self.message = ''

      def report(self, x):
        self.message += x if isinstance(x, str) else x.decode('utf-8')

    capture = ErrorMessageCapture()
    error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report)
    # Do not make a copy of _delegate_ptr. It is freed by Delegate's finalizer.
    self._delegate_ptr = self._library.tflite_plugin_create_delegate(
        options_keys, options_values, len(options), error_capturer_cb)
    if self._delegate_ptr is None:
      raise ValueError(capture.message)

  def __del__(self):
    # __del__ can be called multiple times, so if the delegate is already
    # destroyed, don't try to destroy it again.
    if self._library is not None:
      self._library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p]
      self._library.tflite_plugin_destroy_delegate(self._delegate_ptr)
      self._library = None

  def _get_native_delegate_pointer(self):
    """Returns the native TfLiteDelegate pointer.

    It is not safe to copy this pointer because it needs to be freed.

    Returns:
      TfLiteDelegate *
    """
    return self._delegate_ptr


@_tf_export('lite.experimental.load_delegate')
def load_delegate(library, options=None):
  """Returns loaded Delegate object.

  Args:
    library: Name of shared library containing the
      [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates).
    options: Dictionary of options that are required to load the delegate. All
      keys and values in the dictionary should be convertible to str. Consult
      the documentation of the specific delegate for required and legal
      options. (default None)

  Returns:
    Delegate object.

  Raises:
    ValueError: Delegate failed to load.
    RuntimeError: If delegate loading is attempted on an unsupported platform.
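
  Example usage (a sketch; the library name and its options are placeholders
  for a real delegate shared library on disk):

  ```
  delegate = tf.lite.experimental.load_delegate(
      'libedgetpu.so.1', options={'device': 'usb'})
  interpreter = tf.lite.Interpreter(
      model_path='model.tflite', experimental_delegates=[delegate])
  ```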
  """
  try:
    delegate = Delegate(library, options)
  except ValueError as e:
    raise ValueError('Failed to load delegate from {}\n{}'.format(
        library, str(e)))
  return delegate


class SignatureRunner(object):
  """SignatureRunner class for running TFLite models using SignatureDef.

  This class should be instantiated only through the Interpreter's
  get_signature_runner method. Example:

  ```
  signature = interpreter.get_signature_runner("my_signature")
  result = signature(input_1=my_input_1, input_2=my_input_2)
  print(result["my_output"])
  print(result["my_second_output"])
  ```

  All names used are the names from this specific SignatureDef.

  Notes:
    No other function on this object or on the interpreter provided should be
    called while a call on this object has not finished.
  """

  def __init__(self, interpreter=None, signature_key=None):
    """Constructor.

    Args:
      interpreter: Interpreter object that is already initialized with the
        requested model.
      signature_key: SignatureDef key to be used.
    """
    if not interpreter:
      raise ValueError('None interpreter provided.')
    if not signature_key:
      raise ValueError('None signature_key provided.')
    self._interpreter = interpreter
    self._interpreter_wrapper = interpreter._interpreter
    self._signature_key = signature_key
    signature_defs = interpreter._get_full_signature_list()
    if signature_key not in signature_defs:
      raise ValueError('Invalid signature_key provided.')
    self._signature_def = signature_defs[signature_key]
    self._outputs = self._signature_def['outputs'].items()
    self._inputs = self._signature_def['inputs']

    self._subgraph_index = (
        self._interpreter_wrapper.GetSubgraphIndexFromSignature(
            self._signature_key))

  def __call__(self, **kwargs):
    """Runs the SignatureDef given the provided inputs in arguments.

    Args:
      **kwargs: Key-value pairs for the inputs to the model. Each key is the
        SignatureDef input name; each value is a numpy array with the input
        value.

    Returns:
      A dictionary of the results from the model invocation. Each key in the
      dictionary is a SignatureDef output name; each value is the result
      tensor.
    """

    if len(kwargs) != len(self._inputs):
      raise ValueError(
          'Invalid number of inputs provided for running a SignatureDef, '
          'expected %s vs provided %s' % (len(self._inputs), len(kwargs)))

    # Resize input tensors
    for input_name, value in kwargs.items():
      if input_name not in self._inputs:
        raise ValueError('Invalid Input name (%s) for SignatureDef' %
                         input_name)
      self._interpreter_wrapper.ResizeInputTensor(
          self._inputs[input_name], np.array(value.shape, dtype=np.int32),
          False, self._subgraph_index)
    # Allocate tensors.
    self._interpreter_wrapper.AllocateTensors(self._subgraph_index)
    # Set the input values.
    for input_name, value in kwargs.items():
      self._interpreter_wrapper.SetTensor(self._inputs[input_name], value,
                                          self._subgraph_index)

    self._interpreter_wrapper.Invoke(self._subgraph_index)
    result = {}
    for output_name, output_index in self._outputs:
      result[output_name] = self._interpreter_wrapper.GetTensor(
          output_index, self._subgraph_index)
    return result


@_tf_export('lite.experimental.OpResolverType')
@enum.unique
class OpResolverType(enum.Enum):
  """Different types of op resolvers for Tensorflow Lite.

  * `AUTO`: Indicates the op resolver that is chosen by default in TfLite
    Python, which is the "BUILTIN" as described below.
  * `BUILTIN`: Indicates the op resolver for built-in ops with optimized
    kernel implementation.
  * `BUILTIN_REF`: Indicates the op resolver for built-in ops with reference
    kernel implementation. It's generally used for testing and debugging.
  * `BUILTIN_WITHOUT_DEFAULT_DELEGATES`: Indicates the op resolver for
    built-in ops with optimized kernel implementation, but it will disable
    the application of default TfLite delegates (like the XNNPACK delegate) to
    the model graph. Generally this should not be used unless there are issues
    with the default configuration.
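
  For example, to debug a model against the reference kernels (a sketch; the
  model path is a placeholder):

  ```
  interpreter = tf.lite.Interpreter(
      model_path='model.tflite',
      experimental_op_resolver_type=(
          tf.lite.experimental.OpResolverType.BUILTIN_REF))
  ```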
  """
  # Corresponds to an op resolver chosen by default in TfLite Python.
  AUTO = 0

  # Corresponds to tflite::ops::builtin::BuiltinOpResolver in C++.
  BUILTIN = 1

  # Corresponds to tflite::ops::builtin::BuiltinRefOpResolver in C++.
  BUILTIN_REF = 2

  # Corresponds to
  # tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates in C++.
  BUILTIN_WITHOUT_DEFAULT_DELEGATES = 3


def _get_op_resolver_id(op_resolver_type=OpResolverType.AUTO):
  """Get an integer identifier for the op resolver."""

  # Note: the integer identifier value needs to match the op resolver ids
  # defined in interpreter_wrapper/interpreter_wrapper.cc.
  return {
      # Note AUTO and BUILTIN currently share the same identifier.
      OpResolverType.AUTO: 1,
      OpResolverType.BUILTIN: 1,
      OpResolverType.BUILTIN_REF: 2,
      OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES: 3
  }.get(op_resolver_type, None)


@_tf_export('lite.Interpreter')
class Interpreter(object):
  """Interpreter interface for running TensorFlow Lite models.

  Models obtained from `TfLiteConverter` can be run in Python with
  `Interpreter`.

  As an example, let's generate a simple Keras model and convert it to TFLite
  (`TfLiteConverter` also supports other input formats with `from_saved_model`
  and `from_concrete_functions`)

  >>> x = np.array([[1.], [2.]])
  >>> y = np.array([[2.], [4.]])
  >>> model = tf.keras.models.Sequential([
  ...   tf.keras.layers.Dropout(0.2),
  ...   tf.keras.layers.Dense(units=1, input_shape=[1])
  ... ])
  >>> model.compile(optimizer='sgd', loss='mean_squared_error')
  >>> model.fit(x, y, epochs=1)
  >>> converter = tf.lite.TFLiteConverter.from_keras_model(model)
  >>> tflite_model = converter.convert()

  `tflite_model` can be saved to a file and loaded later, or directly into the
  `Interpreter`. Since TensorFlow Lite pre-plans tensor allocations to optimize
  inference, the user needs to call `allocate_tensors()` before any inference.

  >>> interpreter = tf.lite.Interpreter(model_content=tflite_model)
  >>> interpreter.allocate_tensors()  # Needed before execution!

  Sample execution:

  >>> output = interpreter.get_output_details()[0]  # Model has single output.
  >>> input = interpreter.get_input_details()[0]  # Model has single input.
  >>> input_data = tf.constant(1., shape=[1, 1])
  >>> interpreter.set_tensor(input['index'], input_data)
  >>> interpreter.invoke()
  >>> interpreter.get_tensor(output['index']).shape
  (1, 1)
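
  To feed a batch instead of a single example, the input can be resized before
  allocation (a sketch; the batch size of 2 is illustrative):

  >>> interpreter.resize_tensor_input(input['index'], [2, 1])
  >>> interpreter.allocate_tensors()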

  Use `get_signature_runner()` for a more user-friendly inference API.
  """

  def __init__(self,
               model_path=None,
               model_content=None,
               experimental_delegates=None,
               num_threads=None,
               experimental_op_resolver_type=OpResolverType.AUTO,
               experimental_preserve_all_tensors=False):
    """Constructor.

    Args:
      model_path: Path to TF-Lite Flatbuffer file.
      model_content: Content of model.
      experimental_delegates: Experimental. Subject to change. List of
        [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
        objects returned by lite.load_delegate().
      num_threads: Sets the number of threads used by the interpreter and
        available to CPU kernels. If not set, the interpreter will use an
        implementation-dependent default number of threads. Currently, only a
        subset of kernels, such as conv, support multi-threading. num_threads
        should be >= -1. Setting num_threads to 0 has the effect of disabling
        multithreading, which is equivalent to setting num_threads to 1. If
        set to the value -1, the number of threads used will be
        implementation-defined and platform-dependent.
      experimental_op_resolver_type: The op resolver used by the interpreter.
        It must be an instance of OpResolverType. By default, we use the
        built-in op resolver which corresponds to
        tflite::ops::builtin::BuiltinOpResolver in C++.
      experimental_preserve_all_tensors: If true, then intermediate tensors
        used during computation are preserved for inspection, and if the
        passed op resolver type is AUTO or BUILTIN, the type will be changed
        to BUILTIN_WITHOUT_DEFAULT_DELEGATES so that no Tensorflow Lite
        default delegates are applied. If false, getting intermediate tensors
        could result in undefined values or None, especially when the graph is
        successfully modified by the Tensorflow Lite default delegate.

    Raises:
      ValueError: If the interpreter was unable to be created.
    """
    if not hasattr(self, '_custom_op_registerers'):
      self._custom_op_registerers = []

    actual_resolver_type = experimental_op_resolver_type
    if experimental_preserve_all_tensors and (
        experimental_op_resolver_type == OpResolverType.AUTO or
        experimental_op_resolver_type == OpResolverType.BUILTIN):
      actual_resolver_type = OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
    op_resolver_id = _get_op_resolver_id(actual_resolver_type)
    if op_resolver_id is None:
      raise ValueError('Unrecognized op resolver type passed in: {}'.format(
          experimental_op_resolver_type))

    if model_path and not model_content:
      custom_op_registerers_by_name = [
          x for x in self._custom_op_registerers if isinstance(x, str)
      ]
      custom_op_registerers_by_func = [
          x for x in self._custom_op_registerers if not isinstance(x, str)
      ]
      self._interpreter = (
          _interpreter_wrapper.CreateWrapperFromFile(
              model_path, op_resolver_id, custom_op_registerers_by_name,
              custom_op_registerers_by_func, experimental_preserve_all_tensors))
      if not self._interpreter:
        raise ValueError('Failed to open {}'.format(model_path))
    elif model_content and not model_path:
      custom_op_registerers_by_name = [
          x for x in self._custom_op_registerers if isinstance(x, str)
      ]
      custom_op_registerers_by_func = [
          x for x in self._custom_op_registerers if not isinstance(x, str)
      ]
      # Take a reference, so the pointer remains valid.
      # Since Python strings are immutable, the PyString_XX functions
      # will always return the same pointer.
      self._model_content = model_content
      self._interpreter = (
          _interpreter_wrapper.CreateWrapperFromBuffer(
              model_content, op_resolver_id, custom_op_registerers_by_name,
              custom_op_registerers_by_func, experimental_preserve_all_tensors))
    elif not model_content and not model_path:
      raise ValueError('`model_path` or `model_content` must be specified.')
    else:
      raise ValueError('Can\'t both provide `model_path` and `model_content`')

    if num_threads is not None:
      if not isinstance(num_threads, int):
        raise ValueError('type of num_threads should be int')
      if num_threads < 1:
        raise ValueError('num_threads should be >= 1')
      self._interpreter.SetNumThreads(num_threads)

    # Each delegate is a wrapper that owns the delegates that have been loaded
    # as plugins. The interpreter wrapper will be using them, but we need to
    # hold them in a list so that their lifetime is preserved at least as long
    # as the interpreter wrapper's.
    self._delegates = []
    if experimental_delegates:
      self._delegates = experimental_delegates
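      # Note: a delegate claims the subgraph nodes it supports; any nodes it
      # cannot handle keep running on the default TFLite CPU kernels.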
      for delegate in self._delegates:
        self._interpreter.ModifyGraphWithDelegate(
            delegate._get_native_delegate_pointer())  # pylint: disable=protected-access
    self._signature_defs = self.get_signature_list()

    self._metrics = metrics.TFLiteMetrics()
    self._metrics.increase_counter_interpreter_creation()

  def __del__(self):
    # Must make sure the interpreter is destroyed before things that
    # are used by it like the delegates. NOTE this only works on CPython
    # probably.
    # TODO(b/136468453): Remove need for __del__ ordering needs of CPython
    # by using explicit closes(). See implementation of Interpreter __del__.
    self._interpreter = None
    self._delegates = None

  def allocate_tensors(self):
    self._ensure_safe()
    return self._interpreter.AllocateTensors()

  def _safe_to_run(self):
    """Returns true if there exist no numpy array buffers.

    This means it is safe to run tflite calls that may destroy internally
    allocated memory. This works, because in the wrapper.cc we have made
    the numpy base be the self._interpreter.
    """
    # NOTE, our tensor() call in cpp will use _interpreter as a base pointer.
    # If this environment is the only _interpreter, then the ref count should
    # be 2 (1 in self and 1 in the temporary of sys.getrefcount).
    return sys.getrefcount(self._interpreter) == 2

  def _ensure_safe(self):
    """Makes sure no numpy arrays pointing to internal buffers are active.

    This should be called from any function that will call a function on
    _interpreter that may reallocate memory e.g. invoke(), ...

    Raises:
      RuntimeError: If there exist numpy objects pointing to internal memory
        then we throw.
    """
    if not self._safe_to_run():
      raise RuntimeError("""There is at least 1 reference to internal data
      in the interpreter in the form of a numpy array or slice. Be sure to
      only hold the function returned from tensor() if you are using raw
      data access.""")

  # Experimental and subject to change
  def _get_op_details(self, op_index):
    """Gets a dictionary with arrays of ids for tensors involved with an op.

    Args:
      op_index: Operation/node index of node to query.

    Returns:
      a dictionary containing the index, op name, and arrays with lists of the
      indices for the inputs and outputs of the op/node.
    """
    op_index = int(op_index)
    op_name = self._interpreter.NodeName(op_index)
    op_inputs = self._interpreter.NodeInputs(op_index)
    op_outputs = self._interpreter.NodeOutputs(op_index)

    details = {
        'index': op_index,
        'op_name': op_name,
        'inputs': op_inputs,
        'outputs': op_outputs,
    }

    return details

  def _get_tensor_details(self, tensor_index):
    """Gets tensor details.

    Args:
      tensor_index: Tensor index of tensor to query.

    Returns:
      A dictionary containing the following fields of the tensor:
        'name': The tensor name.
        'index': The tensor index in the interpreter.
        'shape': The shape of the tensor.
        'quantization': Deprecated, use 'quantization_parameters'. This field
          only works for per-tensor quantization, whereas
          'quantization_parameters' works in all cases.
        'quantization_parameters': The parameters used to quantize the tensor:
          'scales': List of scales (one if per-tensor quantization)
          'zero_points': List of zero_points (one if per-tensor quantization)
          'quantized_dimension': Specifies the dimension of per-axis
            quantization, in the case of multiple scales/zero_points.

    Raises:
      ValueError: If tensor_index is invalid.
    """
    tensor_index = int(tensor_index)
    tensor_name = self._interpreter.TensorName(tensor_index)
    tensor_size = self._interpreter.TensorSize(tensor_index)
    tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index)
    tensor_type = self._interpreter.TensorType(tensor_index)
    tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
    tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
        tensor_index)
    tensor_sparsity_params = self._interpreter.TensorSparsityParameters(
        tensor_index)

    if not tensor_type:
      raise ValueError('Could not get tensor details')

    details = {
        'name': tensor_name,
        'index': tensor_index,
        'shape': tensor_size,
        'shape_signature': tensor_size_signature,
        'dtype': tensor_type,
        'quantization': tensor_quantization,
        'quantization_parameters': {
            'scales': tensor_quantization_params[0],
            'zero_points': tensor_quantization_params[1],
            'quantized_dimension': tensor_quantization_params[2],
        },
        'sparsity_parameters': tensor_sparsity_params
    }

    return details

  # Experimental and subject to change
  def _get_ops_details(self):
    """Gets op details for every node.

    Returns:
      A list of dictionaries containing arrays with lists of tensor ids for
      tensors involved in the op.
    """
    return [
        self._get_op_details(idx) for idx in range(self._interpreter.NumNodes())
    ]

  def get_tensor_details(self):
    """Gets tensor details for every tensor with valid tensor details.

    Tensors where required information about the tensor is not found are not
    added to the list. This includes temporary tensors without a name.

    Returns:
      A list of dictionaries containing tensor information.
    """
    tensor_details = []
    for idx in range(self._interpreter.NumTensors()):
      try:
        tensor_details.append(self._get_tensor_details(idx))
      except ValueError:
        pass
    return tensor_details

  def get_input_details(self):
    """Gets model input tensor details.

    Returns:
      A list in which each item is a dictionary with details about
      an input tensor. Each dictionary contains the following fields
      that describe the tensor:

      + `name`: The tensor name.
      + `index`: The tensor index in the interpreter.
      + `shape`: The shape of the tensor.
      + `shape_signature`: Same as `shape` for models with known/fixed shapes.
        If any dimension sizes are unknown, they are indicated with `-1`.
      + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
      + `quantization`: Deprecated, use `quantization_parameters`. This field
        only works for per-tensor quantization, whereas
        `quantization_parameters` works in all cases.
      + `quantization_parameters`: A dictionary of parameters used to quantize
        the tensor:
        ~ `scales`: List of scales (one if per-tensor quantization).
        ~ `zero_points`: List of zero_points (one if per-tensor quantization).
        ~ `quantized_dimension`: Specifies the dimension of per-axis
          quantization, in the case of multiple scales/zero_points.
      + `sparsity_parameters`: A dictionary of parameters used to encode a
        sparse tensor. This is empty if the tensor is dense.
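
    For example, to inspect each input (a sketch; field values depend on the
    model):

    ```
    for d in interpreter.get_input_details():
      print(d['name'], d['shape'], d['dtype'])
    ```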
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.InputIndices()
    ]

  def set_tensor(self, tensor_index, value):
    """Sets the value of the input tensor.

    Note this copies data in `value`.

    If you want to avoid copying, you can use the `tensor()` function to get a
    numpy buffer pointing to the input buffer in the tflite interpreter.

    Args:
      tensor_index: Tensor index of tensor to set. This value can be obtained
        from the 'index' field in get_input_details.
      value: Value of tensor to set.

    Raises:
      ValueError: If the interpreter could not set the tensor.
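
    For a model with quantized inputs, float values must be quantized before
    being set (a sketch; `input_details` is assumed to come from
    get_input_details(), and a uint8-quantized input is assumed):

    ```
    scale, zero_point = input_details['quantization']
    # Quantize the float `value` to the tensor's uint8 representation.
    quantized = np.round(value / scale + zero_point).astype(np.uint8)
    interpreter.set_tensor(input_details['index'], quantized)
    ```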
    """
    self._interpreter.SetTensor(tensor_index, value)

  def resize_tensor_input(self, input_index, tensor_size, strict=False):
    """Resizes an input tensor.

    Args:
      input_index: Tensor index of input to set. This value can be obtained
        from the 'index' field in get_input_details.
      tensor_size: The tensor_shape to resize the input to.
      strict: Only unknown dimensions can be resized when `strict` is True.
        Unknown dimensions are indicated as `-1` in the `shape_signature`
        attribute of a given tensor. (default False)

    Raises:
      ValueError: If the interpreter could not resize the input tensor.

    Usage:
    ```
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.resize_tensor_input(0, [num_test_images, 224, 224, 3])
    interpreter.allocate_tensors()
    interpreter.set_tensor(0, test_images)
    interpreter.invoke()
    ```
    """
    self._ensure_safe()
    # `ResizeInputTensor` now only accepts an int32 numpy array as the
    # `tensor_size` parameter.
    tensor_size = np.array(tensor_size, dtype=np.int32)
    self._interpreter.ResizeInputTensor(input_index, tensor_size, strict)

  def get_output_details(self):
    """Gets model output tensor details.

    Returns:
      A list in which each item is a dictionary with details about
      an output tensor. The dictionary contains the same fields as
      described for `get_input_details()`.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.OutputIndices()
    ]

  def get_signature_list(self):
    """Gets list of SignatureDefs in the model.

    Example:
    ```
    signatures = interpreter.get_signature_list()
    print(signatures)

    # {
    #   'add': {'inputs': ['x', 'y'], 'outputs': ['output_0']}
    # }
    ```

    Then using the names in the signature list you can get a callable from
    get_signature_runner().

    Returns:
      A dictionary of SignatureDef details, keyed on the SignatureDef method
      name; each value holds a dictionary of input and output names.
    """
    full_signature_defs = self._interpreter.GetSignatureDefs()
    for _, signature_def in full_signature_defs.items():
      signature_def['inputs'] = list(signature_def['inputs'].keys())
      signature_def['outputs'] = list(signature_def['outputs'].keys())
    return full_signature_defs

  def _get_full_signature_list(self):
    """Gets list of SignatureDefs in the model.

    Example:
    ```
    signatures = interpreter._get_full_signature_list()
    print(signatures)

    # {
    #   'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}}
    # }
    ```

    Then using the names in the signature list you can get a callable from
    get_signature_runner().

    Returns:
      A dictionary of SignatureDef details, keyed on the SignatureDef method
      name; each value holds a dictionary mapping input and output names to
      tensor indices.
    """
    return self._interpreter.GetSignatureDefs()

  def get_signature_runner(self, signature_key=None):
    """Gets callable for inference of specific SignatureDef.

    Example usage:
    ```
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    fn = interpreter.get_signature_runner('div_with_remainder')
    output = fn(x=np.array([3]), y=np.array([2]))
    print(output)
    # {
    #   'quotient': array([1.], dtype=float32)
    #   'remainder': array([1.], dtype=float32)
    # }
    ```

    None can be passed for signature_key if the model has a single Signature
    only.

    All names used are the names from this specific SignatureDef.

    Args:
      signature_key: Signature key for the SignatureDef, it can be None if and
        only if the model has a single SignatureDef. Default value is None.

    Returns:
      A callable that can run inference for the SignatureDef defined by the
      argument 'signature_key'. The callable takes keyword arguments
      corresponding to the arguments of the SignatureDef, with numpy values.
      The callable returns a dictionary that maps from output names to numpy
      values of the computed results.

    Raises:
      ValueError: If passed signature_key is invalid.
    """
    if signature_key is None:
      if len(self._signature_defs) != 1:
        raise ValueError(
            'SignatureDef signature_key is None and model has {0} Signatures. '
            'None is only allowed when the model has 1 SignatureDef'.format(
                len(self._signature_defs)))
      else:
        signature_key = next(iter(self._signature_defs))
    return SignatureRunner(interpreter=self, signature_key=signature_key)

  def get_tensor(self, tensor_index):
    """Gets the value of the output tensor (get a copy).

    If you wish to avoid the copy, use `tensor()`. This function cannot be used
    to read intermediate results.

    Args:
      tensor_index: Tensor index of tensor to get. This value can be obtained
        from the 'index' field in get_output_details.

    Returns:
      a numpy array.
    """
    return self._interpreter.GetTensor(tensor_index)

  def tensor(self, tensor_index):
    """Returns function that gives a numpy view of the current tensor buffer.

    This allows reading and writing to these tensors w/o copies. This more
    closely mirrors the C++ Interpreter class interface's tensor() member,
    hence the name. Be careful to not hold these output references through
    calls to `allocate_tensors()` and `invoke()`. This function cannot be used
    to read intermediate results.

    Usage:

    ```
    interpreter.allocate_tensors()
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
    for i in range(10):
      input().fill(3.)
      interpreter.invoke()
      print("inference %s" % output())
    ```

    Notice how this function avoids making a numpy array directly. This is
    because it is important to not hold actual numpy views to the data longer
    than necessary. If you do, then the interpreter can no longer be invoked,
    because it is possible the interpreter would resize and invalidate the
    referenced tensors. The NumPy API doesn't allow any mutability of the
    underlying buffers.

    WRONG:

    ```
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
    interpreter.allocate_tensors()  # This will throw RuntimeError
    for i in range(10):
      input.fill(3.)
      interpreter.invoke()  # This will throw RuntimeError since input and
                            # output still reference internal buffers.
    ```

    Args:
      tensor_index: Tensor index of tensor to get. This value can be obtained
        from the 'index' field in get_output_details.

    Returns:
      A function that can return a new numpy array pointing to the internal
      TFLite tensor state at any point. It is safe to hold the function
      forever, but it is not safe to hold the numpy array forever.
    """
    return lambda: self._interpreter.tensor(self._interpreter, tensor_index)

  def invoke(self):
    """Invoke the interpreter.

    Be sure to set the input sizes, allocate tensors and fill values before
    calling this. Also, note that this function releases the GIL so heavy
    computation can be done in the background while the Python interpreter
    continues. No other function on this object should be called while the
    invoke() call has not finished.

    Raises:
      ValueError: When the underlying interpreter fails raise ValueError.
    """
    self._ensure_safe()
    self._interpreter.Invoke()

  def reset_all_variables(self):
    return self._interpreter.ResetVariableTensors()

  # Experimental and subject to change.
  def _native_handle(self):
    """Returns a pointer to the underlying tflite::Interpreter instance.

    This allows extending tflite.Interpreter's functionality in a custom C++
    function. Consider how that may work in a custom pybind wrapper:

      m.def("SomeNewFeature", ([](py::object handle) {
        auto* interpreter =
            reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>());
        ...
      }))

    and the corresponding Python call:

      SomeNewFeature(interpreter.native_handle())

    Note: This approach is fragile. Users must guarantee the C++ extension
    build is consistent with the tflite.Interpreter's underlying C++ build.
    """
    return self._interpreter.interpreter()


class InterpreterWithCustomOps(Interpreter):
  """Interpreter interface for TensorFlow Lite Models that accepts custom ops.

  The interface provided by this class is experimental and therefore not
  exposed as part of the public API.

  Wraps the tf.lite.Interpreter class and adds the ability to load custom ops
  by providing the names of functions that take a pointer to a
  BuiltinOpResolver and add a custom op.
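
  For example (a sketch; 'TF_RegisterMyOps' names a hypothetical registerer
  symbol exported by a custom-op shared library already loaded into the
  process):

  ```
  interpreter = InterpreterWithCustomOps(
      model_content=tflite_model,
      custom_op_registerers=['TF_RegisterMyOps'])
  interpreter.allocate_tensors()
  ```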
  """

  def __init__(self, custom_op_registerers=None, **kwargs):
    """Constructor.

    Args:
      custom_op_registerers: List of str (symbol names) or functions that take
        a pointer to a MutableOpResolver and register a custom op. When
        passing functions, use a pybind function that takes a uintptr_t that
        can be recast as a pointer to a MutableOpResolver.
      **kwargs: Additional arguments passed to Interpreter.

    Raises:
      ValueError: If the interpreter was unable to be created.
    """
    self._custom_op_registerers = custom_op_registerers or []
    super(InterpreterWithCustomOps, self).__init__(**kwargs)