1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python TF-Lite interpreter.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import sys 21import numpy as np 22 23# pylint: disable=g-import-not-at-top 24try: 25 from tensorflow.python.util.lazy_loader import LazyLoader 26 from tensorflow.python.util.tf_export import tf_export as _tf_export 27 28 # Lazy load since some of the performance benchmark skylark rules 29 # break dependencies. Must use double quotes to match code internal rewrite 30 # rule. 31 # pylint: disable=g-inconsistent-quotes 32 _interpreter_wrapper = LazyLoader( 33 "_interpreter_wrapper", globals(), 34 "tensorflow.lite.python.interpreter_wrapper." 35 "tensorflow_wrap_interpreter_wrapper") 36 # pylint: enable=g-inconsistent-quotes 37 38 del LazyLoader 39except ImportError: 40 # When full Tensorflow Python PIP is not available do not use lazy load 41 # and instead uf the tflite_runtime path. 42 from tflite_runtime.lite.python import interpreter_wrapper as _interpreter_wrapper 43 44 def tf_export_dummy(*x, **kwargs): 45 del x, kwargs 46 return lambda x: x 47 _tf_export = tf_export_dummy 48 49 50@_tf_export('lite.Interpreter') 51class Interpreter(object): 52 """Interpreter inferace for TF-Lite Models.""" 53 54 def __init__(self, model_path=None, model_content=None): 55 """Constructor. 56 57 Args: 58 model_path: Path to TF-Lite Flatbuffer file. 59 model_content: Content of model. 60 61 Raises: 62 ValueError: If the interpreter was unable to create. 63 """ 64 if model_path and not model_content: 65 self._interpreter = ( 66 _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile( 67 model_path)) 68 if not self._interpreter: 69 raise ValueError('Failed to open {}'.format(model_path)) 70 elif model_content and not model_path: 71 # Take a reference, so the pointer remains valid. 72 # Since python strings are immutable then PyString_XX functions 73 # will always return the same pointer. 74 self._model_content = model_content 75 self._interpreter = ( 76 _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( 77 model_content)) 78 elif not model_path and not model_path: 79 raise ValueError('`model_path` or `model_content` must be specified.') 80 else: 81 raise ValueError('Can\'t both provide `model_path` and `model_content`') 82 83 def allocate_tensors(self): 84 self._ensure_safe() 85 return self._interpreter.AllocateTensors() 86 87 def _safe_to_run(self): 88 """Returns true if there exist no numpy array buffers. 89 90 This means it is safe to run tflite calls that may destroy internally 91 allocated memory. This works, because in the wrapper.cc we have made 92 the numpy base be the self._interpreter. 93 """ 94 # NOTE, our tensor() call in cpp will use _interpreter as a base pointer. 95 # If this environment is the only _interpreter, then the ref count should be 96 # 2 (1 in self and 1 in temporary of sys.getrefcount). 97 return sys.getrefcount(self._interpreter) == 2 98 99 def _ensure_safe(self): 100 """Makes sure no numpy arrays pointing to internal buffers are active. 101 102 This should be called from any function that will call a function on 103 _interpreter that may reallocate memory e.g. invoke(), ... 104 105 Raises: 106 RuntimeError: If there exist numpy objects pointing to internal memory 107 then we throw. 108 """ 109 if not self._safe_to_run(): 110 raise RuntimeError("""There is at least 1 reference to internal data 111 in the interpreter in the form of a numpy array or slice. Be sure to 112 only hold the function returned from tensor() if you are using raw 113 data access.""") 114 115 def _get_tensor_details(self, tensor_index): 116 """Gets tensor details. 117 118 Args: 119 tensor_index: Tensor index of tensor to query. 120 121 Returns: 122 a dictionary containing the name, index, shape and type of the tensor. 123 124 Raises: 125 ValueError: If tensor_index is invalid. 126 """ 127 tensor_index = int(tensor_index) 128 tensor_name = self._interpreter.TensorName(tensor_index) 129 tensor_size = self._interpreter.TensorSize(tensor_index) 130 tensor_type = self._interpreter.TensorType(tensor_index) 131 tensor_quantization = self._interpreter.TensorQuantization(tensor_index) 132 133 if not tensor_name or not tensor_type: 134 raise ValueError('Could not get tensor details') 135 136 details = { 137 'name': tensor_name, 138 'index': tensor_index, 139 'shape': tensor_size, 140 'dtype': tensor_type, 141 'quantization': tensor_quantization, 142 } 143 144 return details 145 146 def get_tensor_details(self): 147 """Gets tensor details for every tensor with valid tensor details. 148 149 Tensors where required information about the tensor is not found are not 150 added to the list. This includes temporary tensors without a name. 151 152 Returns: 153 A list of dictionaries containing tensor information. 154 """ 155 tensor_details = [] 156 for idx in range(self._interpreter.NumTensors()): 157 try: 158 tensor_details.append(self._get_tensor_details(idx)) 159 except ValueError: 160 pass 161 return tensor_details 162 163 def get_input_details(self): 164 """Gets model input details. 165 166 Returns: 167 A list of input details. 168 """ 169 return [ 170 self._get_tensor_details(i) for i in self._interpreter.InputIndices() 171 ] 172 173 def set_tensor(self, tensor_index, value): 174 """Sets the value of the input tensor. Note this copies data in `value`. 175 176 If you want to avoid copying, you can use the `tensor()` function to get a 177 numpy buffer pointing to the input buffer in the tflite interpreter. 178 179 Args: 180 tensor_index: Tensor index of tensor to set. This value can be gotten from 181 the 'index' field in get_input_details. 182 value: Value of tensor to set. 183 184 Raises: 185 ValueError: If the interpreter could not set the tensor. 186 """ 187 self._interpreter.SetTensor(tensor_index, value) 188 189 def resize_tensor_input(self, input_index, tensor_size): 190 """Resizes an input tensor. 191 192 Args: 193 input_index: Tensor index of input to set. This value can be gotten from 194 the 'index' field in get_input_details. 195 tensor_size: The tensor_shape to resize the input to. 196 197 Raises: 198 ValueError: If the interpreter could not resize the input tensor. 199 """ 200 self._ensure_safe() 201 # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size 202 # parameter. 203 tensor_size = np.array(tensor_size, dtype=np.int32) 204 self._interpreter.ResizeInputTensor(input_index, tensor_size) 205 206 def get_output_details(self): 207 """Gets model output details. 208 209 Returns: 210 A list of output details. 211 """ 212 return [ 213 self._get_tensor_details(i) for i in self._interpreter.OutputIndices() 214 ] 215 216 def get_tensor(self, tensor_index): 217 """Gets the value of the input tensor (get a copy). 218 219 If you wish to avoid the copy, use `tensor()`. This function cannot be used 220 to read intermediate results. 221 222 Args: 223 tensor_index: Tensor index of tensor to get. This value can be gotten from 224 the 'index' field in get_output_details. 225 226 Returns: 227 a numpy array. 228 """ 229 return self._interpreter.GetTensor(tensor_index) 230 231 def tensor(self, tensor_index): 232 """Returns function that gives a numpy view of the current tensor buffer. 233 234 This allows reading and writing to this tensors w/o copies. This more 235 closely mirrors the C++ Interpreter class interface's tensor() member, hence 236 the name. Be careful to not hold these output references through calls 237 to `allocate_tensors()` and `invoke()`. This function cannot be used to read 238 intermediate results. 239 240 Usage: 241 242 ``` 243 interpreter.allocate_tensors() 244 input = interpreter.tensor(interpreter.get_input_details()[0]["index"]) 245 output = interpreter.tensor(interpreter.get_output_details()[0]["index"]) 246 for i in range(10): 247 input().fill(3.) 248 interpreter.invoke() 249 print("inference %s" % output()) 250 ``` 251 252 Notice how this function avoids making a numpy array directly. This is 253 because it is important to not hold actual numpy views to the data longer 254 than necessary. If you do, then the interpreter can no longer be invoked, 255 because it is possible the interpreter would resize and invalidate the 256 referenced tensors. The NumPy API doesn't allow any mutability of the 257 the underlying buffers. 258 259 WRONG: 260 261 ``` 262 input = interpreter.tensor(interpreter.get_input_details()[0]["index"])() 263 output = interpreter.tensor(interpreter.get_output_details()[0]["index"])() 264 interpreter.allocate_tensors() # This will throw RuntimeError 265 for i in range(10): 266 input.fill(3.) 267 interpreter.invoke() # this will throw RuntimeError since input,output 268 ``` 269 270 Args: 271 tensor_index: Tensor index of tensor to get. This value can be gotten from 272 the 'index' field in get_output_details. 273 274 Returns: 275 A function that can return a new numpy array pointing to the internal 276 TFLite tensor state at any point. It is safe to hold the function forever, 277 but it is not safe to hold the numpy array forever. 278 """ 279 return lambda: self._interpreter.tensor(self._interpreter, tensor_index) 280 281 def invoke(self): 282 """Invoke the interpreter. 283 284 Be sure to set the input sizes, allocate tensors and fill values before 285 calling this. 286 287 Raises: 288 ValueError: When the underlying interpreter fails raise ValueError. 289 """ 290 self._ensure_safe() 291 self._interpreter.Invoke() 292 293 def reset_all_variables(self): 294 return self._interpreter.ResetVariableTensors() 295