1# Lint as: python2, python3 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Python TF-Lite interpreter.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import ctypes 22import platform 23import sys 24import os 25 26import numpy as np 27 28# pylint: disable=g-import-not-at-top 29if not os.path.splitext(__file__)[0].endswith( 30 os.path.join('tflite_runtime', 'interpreter')): 31 # This file is part of tensorflow package. 32 from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper 33 from tensorflow.python.util.tf_export import tf_export as _tf_export 34else: 35 # This file is part of tflite_runtime package. 36 from tflite_runtime import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper 37 38 def _tf_export(*x, **kwargs): 39 del x, kwargs 40 return lambda x: x 41 42 43class Delegate(object): 44 """Python wrapper class to manage TfLiteDelegate objects. 45 46 The shared library is expected to have two functions: 47 TfLiteDelegate* tflite_plugin_create_delegate( 48 char**, char**, size_t, void (*report_error)(const char *)) 49 void tflite_plugin_destroy_delegate(TfLiteDelegate*) 50 51 The first one creates a delegate object. It may return NULL to indicate an 52 error (with a suitable error message reported by calling report_error()). 53 The second one destroys delegate object and must be called for every 54 created delegate object. Passing NULL as argument value is allowed, i.e. 55 56 tflite_plugin_destroy_delegate(tflite_plugin_create_delegate(...)) 57 58 always works. 59 """ 60 61 def __init__(self, library, options=None): 62 """Loads delegate from the shared library. 63 64 Args: 65 library: Shared library name. 66 options: Dictionary of options that are required to load the delegate. All 67 keys and values in the dictionary should be serializable. Consult the 68 documentation of the specific delegate for required and legal options. 69 (default None) 70 71 Raises: 72 RuntimeError: This is raised if the Python implementation is not CPython. 73 """ 74 75 # TODO(b/136468453): Remove need for __del__ ordering needs of CPython 76 # by using explicit closes(). See implementation of Interpreter __del__. 77 if platform.python_implementation() != 'CPython': 78 raise RuntimeError('Delegates are currently only supported into CPython' 79 'due to missing immediate reference counting.') 80 81 self._library = ctypes.pydll.LoadLibrary(library) 82 self._library.tflite_plugin_create_delegate.argtypes = [ 83 ctypes.POINTER(ctypes.c_char_p), 84 ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, 85 ctypes.CFUNCTYPE(None, ctypes.c_char_p) 86 ] 87 self._library.tflite_plugin_create_delegate.restype = ctypes.c_void_p 88 89 # Convert the options from a dictionary to lists of char pointers. 90 options = options or {} 91 options_keys = (ctypes.c_char_p * len(options))() 92 options_values = (ctypes.c_char_p * len(options))() 93 for idx, (key, value) in enumerate(options.items()): 94 options_keys[idx] = str(key).encode('utf-8') 95 options_values[idx] = str(value).encode('utf-8') 96 97 class ErrorMessageCapture(object): 98 99 def __init__(self): 100 self.message = '' 101 102 def report(self, x): 103 self.message += x if isinstance(x, str) else x.decode('utf-8') 104 105 capture = ErrorMessageCapture() 106 error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report) 107 # Do not make a copy of _delegate_ptr. It is freed by Delegate's finalizer. 108 self._delegate_ptr = self._library.tflite_plugin_create_delegate( 109 options_keys, options_values, len(options), error_capturer_cb) 110 if self._delegate_ptr is None: 111 raise ValueError(capture.message) 112 113 def __del__(self): 114 # __del__ can not be called multiple times, so if the delegate is destroyed. 115 # don't try to destroy it twice. 116 if self._library is not None: 117 self._library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p] 118 self._library.tflite_plugin_destroy_delegate(self._delegate_ptr) 119 self._library = None 120 121 def _get_native_delegate_pointer(self): 122 """Returns the native TfLiteDelegate pointer. 123 124 It is not safe to copy this pointer because it needs to be freed. 125 126 Returns: 127 TfLiteDelegate * 128 """ 129 return self._delegate_ptr 130 131 132@_tf_export('lite.experimental.load_delegate') 133def load_delegate(library, options=None): 134 """Returns loaded Delegate object. 135 136 Args: 137 library: Name of shared library containing the 138 [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates). 139 options: Dictionary of options that are required to load the delegate. All 140 keys and values in the dictionary should be convertible to str. Consult 141 the documentation of the specific delegate for required and legal options. 142 (default None) 143 144 Returns: 145 Delegate object. 146 147 Raises: 148 ValueError: Delegate failed to load. 149 RuntimeError: If delegate loading is used on unsupported platform. 150 """ 151 try: 152 delegate = Delegate(library, options) 153 except ValueError as e: 154 raise ValueError('Failed to load delegate from {}\n{}'.format( 155 library, str(e))) 156 return delegate 157 158 159class SignatureRunner(object): 160 """SignatureRunner class for running TFLite models using SignatureDef. 161 162 This class should be instantiated through TFLite Interpreter only using 163 get_signature_runner method on Interpreter. 164 Example, 165 signature = interpreter.get_signature_runner("my_signature") 166 result = signature(input_1=my_input_1, input_2=my_input_2) 167 print(result["my_output"]) 168 print(result["my_second_output"]) 169 All names used are this specific SignatureDef names. 170 171 Notes: 172 No other function on this object or on the interpreter provided should be 173 called while this object call has not finished. 174 """ 175 176 def __init__(self, interpreter=None, signature_def_name=None): 177 """Constructor. 178 179 Args: 180 interpreter: Interpreter object that is already initialized with the 181 requested model. 182 signature_def_name: SignatureDef names to be used. 183 """ 184 if not interpreter: 185 raise ValueError('None interpreter provided.') 186 if not signature_def_name: 187 raise ValueError('None signature_def_name provided.') 188 self._interpreter = interpreter 189 self._signature_def_name = signature_def_name 190 signature_defs = interpreter._get_full_signature_list() 191 if signature_def_name not in signature_defs: 192 raise ValueError('Invalid signature_def_name provided.') 193 self._signature_def = signature_defs[signature_def_name] 194 self._outputs = self._signature_def['outputs'].items() 195 self._inputs = self._signature_def['inputs'] 196 197 def __call__(self, **kwargs): 198 """Runs the SignatureDef given the provided inputs in arguments. 199 200 Args: 201 **kwargs: key,value for inputs to the model. Key is the SignatureDef input 202 name. Value is numpy array with the value. 203 204 Returns: 205 dictionary of the results from the model invoke. 206 Key in the dictionary is SignatureDef output name. 207 Value is the result Tensor. 208 """ 209 210 if len(kwargs) != len(self._inputs): 211 raise ValueError( 212 'Invalid number of inputs provided for running a SignatureDef, ' 213 'expected %s vs provided %s' % (len(kwargs), len(self._inputs))) 214 # Resize input tensors 215 for input_name, value in kwargs.items(): 216 if input_name not in self._inputs: 217 raise ValueError('Invalid Input name (%s) for SignatureDef' % 218 input_name) 219 self._interpreter.resize_tensor_input(self._inputs[input_name], 220 value.shape) 221 # Allocate tensors. 222 self._interpreter.allocate_tensors() 223 # Set the input values. 224 for input_name, value in kwargs.items(): 225 self._interpreter._set_input_tensor( 226 input_name, value=value, method_name=self._signature_def_name) 227 self._interpreter.invoke() 228 result = {} 229 for output_name, output_index in self._outputs: 230 result[output_name] = self._interpreter.get_tensor(output_index) 231 return result 232 233 234@_tf_export('lite.Interpreter') 235class Interpreter(object): 236 """Interpreter interface for TensorFlow Lite Models. 237 238 This makes the TensorFlow Lite interpreter accessible in Python. 239 It is possible to use this interpreter in a multithreaded Python environment, 240 but you must be sure to call functions of a particular instance from only 241 one thread at a time. So if you want to have 4 threads running different 242 inferences simultaneously, create an interpreter for each one as thread-local 243 data. Similarly, if you are calling invoke() in one thread on a single 244 interpreter but you want to use tensor() on another thread once it is done, 245 you must use a synchronization primitive between the threads to ensure invoke 246 has returned before calling tensor(). 247 """ 248 249 def __init__(self, 250 model_path=None, 251 model_content=None, 252 experimental_delegates=None, 253 num_threads=None): 254 """Constructor. 255 256 Args: 257 model_path: Path to TF-Lite Flatbuffer file. 258 model_content: Content of model. 259 experimental_delegates: Experimental. Subject to change. List of 260 [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates) 261 objects returned by lite.load_delegate(). 262 num_threads: Sets the number of threads used by the interpreter and 263 available to CPU kernels. If not set, the interpreter will use an 264 implementation-dependent default number of threads. Currently, only a 265 subset of kernels, such as conv, support multi-threading. 266 267 Raises: 268 ValueError: If the interpreter was unable to create. 269 """ 270 if not hasattr(self, '_custom_op_registerers'): 271 self._custom_op_registerers = [] 272 if model_path and not model_content: 273 custom_op_registerers_by_name = [ 274 x for x in self._custom_op_registerers if isinstance(x, str) 275 ] 276 custom_op_registerers_by_func = [ 277 x for x in self._custom_op_registerers if not isinstance(x, str) 278 ] 279 self._interpreter = ( 280 _interpreter_wrapper.CreateWrapperFromFile( 281 model_path, custom_op_registerers_by_name, 282 custom_op_registerers_by_func)) 283 if not self._interpreter: 284 raise ValueError('Failed to open {}'.format(model_path)) 285 elif model_content and not model_path: 286 custom_op_registerers_by_name = [ 287 x for x in self._custom_op_registerers if isinstance(x, str) 288 ] 289 custom_op_registerers_by_func = [ 290 x for x in self._custom_op_registerers if not isinstance(x, str) 291 ] 292 # Take a reference, so the pointer remains valid. 293 # Since python strings are immutable then PyString_XX functions 294 # will always return the same pointer. 295 self._model_content = model_content 296 self._interpreter = ( 297 _interpreter_wrapper.CreateWrapperFromBuffer( 298 model_content, custom_op_registerers_by_name, 299 custom_op_registerers_by_func)) 300 elif not model_content and not model_path: 301 raise ValueError('`model_path` or `model_content` must be specified.') 302 else: 303 raise ValueError('Can\'t both provide `model_path` and `model_content`') 304 305 if num_threads is not None: 306 if not isinstance(num_threads, int): 307 raise ValueError('type of num_threads should be int') 308 if num_threads < 1: 309 raise ValueError('num_threads should >= 1') 310 self._interpreter.SetNumThreads(num_threads) 311 312 # Each delegate is a wrapper that owns the delegates that have been loaded 313 # as plugins. The interpreter wrapper will be using them, but we need to 314 # hold them in a list so that the lifetime is preserved at least as long as 315 # the interpreter wrapper. 316 self._delegates = [] 317 if experimental_delegates: 318 self._delegates = experimental_delegates 319 for delegate in self._delegates: 320 self._interpreter.ModifyGraphWithDelegate( 321 delegate._get_native_delegate_pointer()) # pylint: disable=protected-access 322 self._signature_defs = self.get_signature_list() 323 324 def __del__(self): 325 # Must make sure the interpreter is destroyed before things that 326 # are used by it like the delegates. NOTE this only works on CPython 327 # probably. 328 # TODO(b/136468453): Remove need for __del__ ordering needs of CPython 329 # by using explicit closes(). See implementation of Interpreter __del__. 330 self._interpreter = None 331 self._delegates = None 332 333 def allocate_tensors(self): 334 self._ensure_safe() 335 return self._interpreter.AllocateTensors() 336 337 def _safe_to_run(self): 338 """Returns true if there exist no numpy array buffers. 339 340 This means it is safe to run tflite calls that may destroy internally 341 allocated memory. This works, because in the wrapper.cc we have made 342 the numpy base be the self._interpreter. 343 """ 344 # NOTE, our tensor() call in cpp will use _interpreter as a base pointer. 345 # If this environment is the only _interpreter, then the ref count should be 346 # 2 (1 in self and 1 in temporary of sys.getrefcount). 347 return sys.getrefcount(self._interpreter) == 2 348 349 def _ensure_safe(self): 350 """Makes sure no numpy arrays pointing to internal buffers are active. 351 352 This should be called from any function that will call a function on 353 _interpreter that may reallocate memory e.g. invoke(), ... 354 355 Raises: 356 RuntimeError: If there exist numpy objects pointing to internal memory 357 then we throw. 358 """ 359 if not self._safe_to_run(): 360 raise RuntimeError("""There is at least 1 reference to internal data 361 in the interpreter in the form of a numpy array or slice. Be sure to 362 only hold the function returned from tensor() if you are using raw 363 data access.""") 364 365 # Experimental and subject to change 366 def _get_op_details(self, op_index): 367 """Gets a dictionary with arrays of ids for tensors involved with an op. 368 369 Args: 370 op_index: Operation/node index of node to query. 371 372 Returns: 373 a dictionary containing the index, op name, and arrays with lists of the 374 indices for the inputs and outputs of the op/node. 375 """ 376 op_index = int(op_index) 377 op_name = self._interpreter.NodeName(op_index) 378 op_inputs = self._interpreter.NodeInputs(op_index) 379 op_outputs = self._interpreter.NodeOutputs(op_index) 380 381 details = { 382 'index': op_index, 383 'op_name': op_name, 384 'inputs': op_inputs, 385 'outputs': op_outputs, 386 } 387 388 return details 389 390 def _get_tensor_details(self, tensor_index): 391 """Gets tensor details. 392 393 Args: 394 tensor_index: Tensor index of tensor to query. 395 396 Returns: 397 A dictionary containing the following fields of the tensor: 398 'name': The tensor name. 399 'index': The tensor index in the interpreter. 400 'shape': The shape of the tensor. 401 'quantization': Deprecated, use 'quantization_parameters'. This field 402 only works for per-tensor quantization, whereas 403 'quantization_parameters' works in all cases. 404 'quantization_parameters': The parameters used to quantize the tensor: 405 'scales': List of scales (one if per-tensor quantization) 406 'zero_points': List of zero_points (one if per-tensor quantization) 407 'quantized_dimension': Specifies the dimension of per-axis 408 quantization, in the case of multiple scales/zero_points. 409 410 Raises: 411 ValueError: If tensor_index is invalid. 412 """ 413 tensor_index = int(tensor_index) 414 tensor_name = self._interpreter.TensorName(tensor_index) 415 tensor_size = self._interpreter.TensorSize(tensor_index) 416 tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index) 417 tensor_type = self._interpreter.TensorType(tensor_index) 418 tensor_quantization = self._interpreter.TensorQuantization(tensor_index) 419 tensor_quantization_params = self._interpreter.TensorQuantizationParameters( 420 tensor_index) 421 tensor_sparsity_params = self._interpreter.TensorSparsityParameters( 422 tensor_index) 423 424 if not tensor_type: 425 raise ValueError('Could not get tensor details') 426 427 details = { 428 'name': tensor_name, 429 'index': tensor_index, 430 'shape': tensor_size, 431 'shape_signature': tensor_size_signature, 432 'dtype': tensor_type, 433 'quantization': tensor_quantization, 434 'quantization_parameters': { 435 'scales': tensor_quantization_params[0], 436 'zero_points': tensor_quantization_params[1], 437 'quantized_dimension': tensor_quantization_params[2], 438 }, 439 'sparsity_parameters': tensor_sparsity_params 440 } 441 442 return details 443 444 # Experimental and subject to change 445 def _get_ops_details(self): 446 """Gets op details for every node. 447 448 Returns: 449 A list of dictionaries containing arrays with lists of tensor ids for 450 tensors involved in the op. 451 """ 452 return [ 453 self._get_op_details(idx) for idx in range(self._interpreter.NumNodes()) 454 ] 455 456 def get_tensor_details(self): 457 """Gets tensor details for every tensor with valid tensor details. 458 459 Tensors where required information about the tensor is not found are not 460 added to the list. This includes temporary tensors without a name. 461 462 Returns: 463 A list of dictionaries containing tensor information. 464 """ 465 tensor_details = [] 466 for idx in range(self._interpreter.NumTensors()): 467 try: 468 tensor_details.append(self._get_tensor_details(idx)) 469 except ValueError: 470 pass 471 return tensor_details 472 473 def get_input_details(self): 474 """Gets model input details. 475 476 Returns: 477 A list of input details. 478 """ 479 return [ 480 self._get_tensor_details(i) for i in self._interpreter.InputIndices() 481 ] 482 483 def set_tensor(self, tensor_index, value): 484 """Sets the value of the input tensor. 485 486 Note this copies data in `value`. 487 488 If you want to avoid copying, you can use the `tensor()` function to get a 489 numpy buffer pointing to the input buffer in the tflite interpreter. 490 491 Args: 492 tensor_index: Tensor index of tensor to set. This value can be gotten from 493 the 'index' field in get_input_details. 494 value: Value of tensor to set. 495 496 Raises: 497 ValueError: If the interpreter could not set the tensor. 498 """ 499 self._interpreter.SetTensor(tensor_index, value) 500 501 def resize_tensor_input(self, input_index, tensor_size, strict=False): 502 """Resizes an input tensor. 503 504 Args: 505 input_index: Tensor index of input to set. This value can be gotten from 506 the 'index' field in get_input_details. 507 tensor_size: The tensor_shape to resize the input to. 508 strict: Only unknown dimensions can be resized when `strict` is True. 509 Unknown dimensions are indicated as `-1` in the `shape_signature` 510 attribute of a given tensor. (default False) 511 512 Raises: 513 ValueError: If the interpreter could not resize the input tensor. 514 515 Usage: 516 ``` 517 interpreter = Interpreter(model_content=tflite_model) 518 interpreter.resize_tensor_input(0, [num_test_images, 224, 224, 3]) 519 interpreter.allocate_tensors() 520 interpreter.set_tensor(0, test_images) 521 interpreter.invoke() 522 ``` 523 """ 524 self._ensure_safe() 525 # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size 526 # parameter. 527 tensor_size = np.array(tensor_size, dtype=np.int32) 528 self._interpreter.ResizeInputTensor(input_index, tensor_size, strict) 529 530 def get_output_details(self): 531 """Gets model output details. 532 533 Returns: 534 A list of output details. 535 """ 536 return [ 537 self._get_tensor_details(i) for i in self._interpreter.OutputIndices() 538 ] 539 540 def get_signature_list(self): 541 """Gets list of SignatureDefs in the model. 542 543 Example, 544 ``` 545 signatures = interpreter.get_signature_list() 546 print(signatures) 547 548 # { 549 # 'add': {'inputs': ['x', 'y'], 'outputs': ['output_0']} 550 # } 551 552 Then using the names in the signature list you can get a callable from 553 get_signature_runner(). 554 ``` 555 556 Returns: 557 A list of SignatureDef details in a dictionary structure. 558 It is keyed on the SignatureDef method name, and the value holds 559 dictionary of inputs and outputs. 560 """ 561 full_signature_defs = self._interpreter.GetSignatureDefs() 562 for _, signature_def in full_signature_defs.items(): 563 signature_def['inputs'] = list(signature_def['inputs'].keys()) 564 signature_def['outputs'] = list(signature_def['outputs'].keys()) 565 return full_signature_defs 566 567 def _get_full_signature_list(self): 568 """Gets list of SignatureDefs in the model. 569 570 Example, 571 ``` 572 signatures = interpreter._get_full_signature_list() 573 print(signatures) 574 575 # { 576 # 'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}} 577 # } 578 579 Then using the names in the signature list you can get a callable from 580 get_signature_runner(). 581 ``` 582 583 Returns: 584 A list of SignatureDef details in a dictionary structure. 585 It is keyed on the SignatureDef method name, and the value holds 586 dictionary of inputs and outputs. 587 """ 588 return self._interpreter.GetSignatureDefs() 589 590 def _set_input_tensor(self, input_name, value, method_name=None): 591 """Sets the value of the input tensor. 592 593 Input tensor is identified by `input_name` in the SignatureDef identified 594 by `method_name`. 595 If the model has a single SignatureDef then you can pass None as 596 `method_name`. 597 598 Note this copies data in `value`. 599 600 Example, 601 ``` 602 input_data = np.array([1.2, 1.4], np.float32) 603 signatures = interpreter.get_signature_list() 604 print(signatures) 605 # { 606 # 'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}} 607 # } 608 interpreter._set_input_tensor(input_name='x', value=input_data, 609 method_name='add_fn') 610 ``` 611 612 Args: 613 input_name: Name of the output tensor in the SignatureDef. 614 value: Value of tensor to set as a numpy array. 615 method_name: The exported method name for the SignatureDef, it can be None 616 if and only if the model has a single SignatureDef. Default value is 617 None. 618 619 Raises: 620 ValueError: If the interpreter could not set the tensor. Or 621 if `method_name` is None and model doesn't have a single 622 Signature. 623 """ 624 if method_name is None: 625 if len(self._signature_defs) != 1: 626 raise ValueError( 627 'SignatureDef method_name is None and model has {0} Signatures. ' 628 'None is only allowed when the model has 1 SignatureDef'.format( 629 len(self._signature_defs))) 630 else: 631 method_name = next(iter(self._signature_defs)) 632 self._interpreter.SetInputTensorFromSignatureDefName( 633 input_name, method_name, value) 634 635 def get_signature_runner(self, method_name=None): 636 """Gets callable for inference of specific SignatureDef. 637 638 Example usage, 639 ``` 640 interpreter = tf.lite.Interpreter(model_content=tflite_model) 641 interpreter.allocate_tensors() 642 fn = interpreter.get_signature_runner('div_with_remainder') 643 output = fn(x=np.array([3]), y=np.array([2])) 644 print(output) 645 # { 646 # 'quotient': array([1.], dtype=float32) 647 # 'remainder': array([1.], dtype=float32) 648 # } 649 ``` 650 651 None can be passed for method_name if the model has a single Signature only. 652 653 All names used are this specific SignatureDef names. 654 655 656 Args: 657 method_name: The exported method name for the SignatureDef, it can be None 658 if and only if the model has a single SignatureDef. Default value is 659 None. 660 661 Returns: 662 This returns a callable that can run inference for SignatureDef defined 663 by argument 'method_name'. 664 The callable will take key arguments corresponding to the arguments of the 665 SignatureDef, that should have numpy values. 666 The callable will returns dictionary that maps from output names to numpy 667 values of the computed results. 668 669 Raises: 670 ValueError: If passed method_name is invalid. 671 """ 672 if method_name is None: 673 if len(self._signature_defs) != 1: 674 raise ValueError( 675 'SignatureDef method_name is None and model has {0} Signatures. ' 676 'None is only allowed when the model has 1 SignatureDef'.format( 677 len(self._signature_defs))) 678 else: 679 method_name = next(iter(self._signature_defs)) 680 return SignatureRunner(interpreter=self, signature_def_name=method_name) 681 682 def get_tensor(self, tensor_index): 683 """Gets the value of the output tensor (get a copy). 684 685 If you wish to avoid the copy, use `tensor()`. This function cannot be used 686 to read intermediate results. 687 688 Args: 689 tensor_index: Tensor index of tensor to get. This value can be gotten from 690 the 'index' field in get_output_details. 691 692 Returns: 693 a numpy array. 694 """ 695 return self._interpreter.GetTensor(tensor_index) 696 697 def tensor(self, tensor_index): 698 """Returns function that gives a numpy view of the current tensor buffer. 699 700 This allows reading and writing to this tensors w/o copies. This more 701 closely mirrors the C++ Interpreter class interface's tensor() member, hence 702 the name. Be careful to not hold these output references through calls 703 to `allocate_tensors()` and `invoke()`. This function cannot be used to read 704 intermediate results. 705 706 Usage: 707 708 ``` 709 interpreter.allocate_tensors() 710 input = interpreter.tensor(interpreter.get_input_details()[0]["index"]) 711 output = interpreter.tensor(interpreter.get_output_details()[0]["index"]) 712 for i in range(10): 713 input().fill(3.) 714 interpreter.invoke() 715 print("inference %s" % output()) 716 ``` 717 718 Notice how this function avoids making a numpy array directly. This is 719 because it is important to not hold actual numpy views to the data longer 720 than necessary. If you do, then the interpreter can no longer be invoked, 721 because it is possible the interpreter would resize and invalidate the 722 referenced tensors. The NumPy API doesn't allow any mutability of the 723 the underlying buffers. 724 725 WRONG: 726 727 ``` 728 input = interpreter.tensor(interpreter.get_input_details()[0]["index"])() 729 output = interpreter.tensor(interpreter.get_output_details()[0]["index"])() 730 interpreter.allocate_tensors() # This will throw RuntimeError 731 for i in range(10): 732 input.fill(3.) 733 interpreter.invoke() # this will throw RuntimeError since input,output 734 ``` 735 736 Args: 737 tensor_index: Tensor index of tensor to get. This value can be gotten from 738 the 'index' field in get_output_details. 739 740 Returns: 741 A function that can return a new numpy array pointing to the internal 742 TFLite tensor state at any point. It is safe to hold the function forever, 743 but it is not safe to hold the numpy array forever. 744 """ 745 return lambda: self._interpreter.tensor(self._interpreter, tensor_index) 746 747 def invoke(self): 748 """Invoke the interpreter. 749 750 Be sure to set the input sizes, allocate tensors and fill values before 751 calling this. Also, note that this function releases the GIL so heavy 752 computation can be done in the background while the Python interpreter 753 continues. No other function on this object should be called while the 754 invoke() call has not finished. 755 756 Raises: 757 ValueError: When the underlying interpreter fails raise ValueError. 758 """ 759 self._ensure_safe() 760 self._interpreter.Invoke() 761 762 def reset_all_variables(self): 763 return self._interpreter.ResetVariableTensors() 764 765 # Experimental and subject to change. 766 def _native_handle(self): 767 """Returns a pointer to the underlying tflite::Interpreter instance. 768 769 This allows extending tflite.Interpreter's functionality in a custom C++ 770 function. Consider how that may work in a custom pybind wrapper: 771 772 m.def("SomeNewFeature", ([](py::object handle) { 773 auto* interpreter = 774 reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>()); 775 ... 776 })) 777 778 and corresponding Python call: 779 780 SomeNewFeature(interpreter.native_handle()) 781 782 Note: This approach is fragile. Users must guarantee the C++ extension build 783 is consistent with the tflite.Interpreter's underlying C++ build. 784 """ 785 return self._interpreter.interpreter() 786 787 788class InterpreterWithCustomOps(Interpreter): 789 """Interpreter interface for TensorFlow Lite Models that accepts custom ops. 790 791 The interface provided by this class is experimental and therefore not exposed 792 as part of the public API. 793 794 Wraps the tf.lite.Interpreter class and adds the ability to load custom ops 795 by providing the names of functions that take a pointer to a BuiltinOpResolver 796 and add a custom op. 797 """ 798 799 def __init__(self, 800 model_path=None, 801 model_content=None, 802 experimental_delegates=None, 803 custom_op_registerers=None): 804 """Constructor. 805 806 Args: 807 model_path: Path to TF-Lite Flatbuffer file. 808 model_content: Content of model. 809 experimental_delegates: Experimental. Subject to change. List of 810 [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates) 811 objects returned by lite.load_delegate(). 812 custom_op_registerers: List of str (symbol names) or functions that take a 813 pointer to a MutableOpResolver and register a custom op. When passing 814 functions, use a pybind function that takes a uintptr_t that can be 815 recast as a pointer to a MutableOpResolver. 816 817 Raises: 818 ValueError: If the interpreter was unable to create. 819 """ 820 self._custom_op_registerers = custom_op_registerers or [] 821 super(InterpreterWithCustomOps, self).__init__( 822 model_path=model_path, 823 model_content=model_content, 824 experimental_delegates=experimental_delegates) 825