1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Converts a frozen graph into a TFLite FlatBuffer.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import enum # pylint: disable=g-bad-import-order 22 23import os as _os 24import platform as _platform 25import subprocess as _subprocess 26import tempfile as _tempfile 27 28from tensorflow.lite.python import lite_constants 29from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2 30from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 31from tensorflow.lite.toco import types_pb2 as _types_pb2 32from tensorflow.python.framework import dtypes 33from tensorflow.python.platform import resource_loader as _resource_loader 34from tensorflow.python.util import deprecation 35from tensorflow.python.util.lazy_loader import LazyLoader 36from tensorflow.python.util.tf_export import tf_export as _tf_export 37 38# Lazy load since some of the performance benchmark skylark rules 39# break dependencies. 40_toco_python = LazyLoader( 41 "tensorflow_wrap_toco", globals(), 42 "tensorflow.lite.toco.python." 43 "tensorflow_wrap_toco") 44del LazyLoader 45 46# Find the toco_from_protos binary using the resource loader if using from 47# bazel, otherwise we are in a pip where console_scripts already has 48# the toco_from_protos tool. 49if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY: 50 _toco_from_proto_bin = "" 51else: 52 _toco_from_proto_bin = _resource_loader.get_path_to_datafile( 53 "../toco/python/toco_from_protos") 54 55if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin): 56 _toco_from_proto_bin = "toco_from_protos" 57 58 59# Map of tf.dtypes to TFLite types_flag_pb2. 60_MAP_TF_TO_TFLITE_TYPES = { 61 dtypes.float32: _types_pb2.FLOAT, 62 dtypes.int32: _types_pb2.INT32, 63 dtypes.int64: _types_pb2.INT64, 64 dtypes.string: _types_pb2.STRING, 65 dtypes.uint8: _types_pb2.QUANTIZED_UINT8, 66 dtypes.complex64: _types_pb2.COMPLEX64 67} 68 69 70def _try_convert_to_unicode(output): 71 if output is None: 72 return u"" 73 74 if isinstance(output, bytes): 75 try: 76 return output.decode() 77 except UnicodeDecodeError: 78 pass 79 return output 80 81 82def convert_dtype_to_tflite_type(tf_dtype): 83 """Converts tf.dtype to TFLite proto type. 84 85 Args: 86 tf_dtype: tf.dtype 87 88 Raises: 89 ValueError: Unsupported tf.dtype. 90 91 Returns: 92 types_flag_pb2. 93 """ 94 result = _MAP_TF_TO_TFLITE_TYPES.get(tf_dtype) 95 if result is None: 96 raise ValueError("Unsupported tf.dtype {0}".format(tf_dtype)) 97 return result 98 99 100@_tf_export("lite.OpsSet") 101class OpsSet(enum.Enum): 102 """Enum class defining the sets of ops available to generate TFLite models. 103 104 WARNING: Experimental interface, subject to change. 105 """ 106 # Convert model using TensorFlow Lite builtin ops. 107 TFLITE_BUILTINS = "TFLITE_BUILTINS" 108 109 # Convert model using TensorFlow ops. Not all TensorFlow ops are available. 110 # WARNING: Experimental interface, subject to change. 111 SELECT_TF_OPS = "SELECT_TF_OPS" 112 113 def __str__(self): 114 return self.value 115 116 @staticmethod 117 def get_options(): 118 """Returns a list of OpsSet options as a list of strings.""" 119 return [str(option) for option in list(OpsSet)] 120 121 122class ConverterError(Exception): 123 """Raised when an error occurs during model conversion.""" 124 pass 125 126 127# Don't expose these for now. 128# @_tf_export("lite.toco_convert_protos") 129def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): 130 """Convert `input_data_str` according to model and toco parameters. 131 132 Unless you know what you are doing consider using 133 the more friendly `tf.lite.toco_convert`. 134 135 Args: 136 model_flags_str: Serialized proto describing model properties, see 137 `toco/model_flags.proto`. 138 toco_flags_str: Serialized proto describing conversion properties, see 139 `toco/toco_flags.proto`. 140 input_data_str: Input data in serialized form (e.g. a graphdef is common) 141 Returns: 142 Converted model in serialized form (e.g. a TFLITE model is common). 143 Raises: 144 ConverterError: When conversion fails in TFLiteConverter, usually due to 145 ops not being supported. 146 RuntimeError: When conversion fails, an exception is raised with the error 147 message embedded. 148 """ 149 # TODO(aselle): When toco does not use fatal errors for failure, we can 150 # switch this on. 151 if not _toco_from_proto_bin: 152 try: 153 model_str = _toco_python.TocoConvert(model_flags_str, toco_flags_str, 154 input_data_str) 155 return model_str 156 except Exception as e: 157 raise ConverterError("TOCO failed: %s" % e) 158 159 # Windows and TemporaryFile are not that useful together, 160 # since you cannot have two readers/writers. So we have to 161 # make the temporaries and close and delete them explicitly. 162 toco_filename, model_filename, input_filename, output_filename = ( 163 None, None, None, None) 164 try: 165 # Build all input files 166 with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \ 167 _tempfile.NamedTemporaryFile(delete=False) as fp_model, \ 168 _tempfile.NamedTemporaryFile(delete=False) as fp_input: 169 toco_filename = fp_toco.name 170 input_filename = fp_input.name 171 model_filename = fp_model.name 172 fp_model.write(model_flags_str) 173 fp_toco.write(toco_flags_str) 174 fp_input.write(input_data_str) 175 fp_model.flush() 176 fp_toco.flush() 177 fp_input.flush() 178 179 # Reserve an output file 180 with _tempfile.NamedTemporaryFile(delete=False) as fp: 181 output_filename = fp.name 182 183 # Run 184 cmd = [ 185 _toco_from_proto_bin, model_filename, toco_filename, input_filename, 186 output_filename 187 ] 188 cmdline = " ".join(cmd) 189 is_windows = _platform.system() == "Windows" 190 proc = _subprocess.Popen( 191 cmdline, 192 shell=True, 193 stdout=_subprocess.PIPE, 194 stderr=_subprocess.STDOUT, 195 close_fds=not is_windows) 196 stdout, stderr = proc.communicate() 197 exitcode = proc.returncode 198 if exitcode == 0: 199 with open(output_filename, "rb") as fp: 200 return fp.read() 201 else: 202 stdout = _try_convert_to_unicode(stdout) 203 stderr = _try_convert_to_unicode(stderr) 204 raise ConverterError( 205 "TOCO failed. See console for info.\n%s\n%s\n" % (stdout, stderr)) 206 finally: 207 # Must manually cleanup files. 208 for filename in [ 209 toco_filename, input_filename, model_filename, output_filename]: 210 try: 211 _os.unlink(filename) 212 except (OSError, TypeError): 213 pass 214 215 216def tensor_name(x): 217 """Returns name of the input tensor.""" 218 parts = x.name.split(":") 219 if len(parts) > 2: 220 raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format( 221 len(parts) - 1)) 222 223 # To be consistent with the tensor naming scheme in tensorflow, we need 224 # drop the ':0' suffix for the first tensor. 225 if len(parts) > 1 and parts[1] != "0": 226 return x.name 227 return parts[0] 228 229 230# Don't expose these for now. 231# @_tf_export("lite.build_toco_convert_protos") 232def build_toco_convert_protos(input_tensors, 233 output_tensors, 234 inference_type=lite_constants.FLOAT, 235 inference_input_type=None, 236 input_format=lite_constants.TENSORFLOW_GRAPHDEF, 237 input_shapes=None, 238 output_format=lite_constants.TFLITE, 239 quantized_input_stats=None, 240 default_ranges_stats=None, 241 drop_control_dependency=True, 242 reorder_across_fake_quant=False, 243 allow_custom_ops=False, 244 change_concat_input_ranges=False, 245 post_training_quantize=False, 246 dump_graphviz_dir=None, 247 dump_graphviz_video=False, 248 target_ops=None, 249 allow_nonexistent_arrays=False): 250 """Builds protocol buffers describing a conversion of a model using TOCO. 251 252 Typically this is to convert from TensorFlow GraphDef to TFLite, in which 253 case the default `input_format` and `output_format` are sufficient. 254 255 Args: 256 input_tensors: List of input tensors. Type and shape are computed using 257 `foo.shape` and `foo.dtype`. 258 output_tensors: List of output tensors (only .name is used from this). 259 inference_type: Target data type of real-number arrays in the output file. 260 Must be `{tf.float32, tf.uint8}`. (default tf.float32) 261 inference_input_type: Target data type of real-number input arrays. Allows 262 for a different type for input arrays in the case of quantization. 263 Must be `{tf.float32, tf.uint8}`. (default `inference_type`) 264 input_format: Type of data to read Currently must be 265 `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) 266 input_shapes: Input array shape. It needs to be a list of the same length 267 as `input_tensors`, or None. (default None) 268 output_format: Output file format. Currently must be `{TFLITE, 269 GRAPHVIZ_DOT}`. (default TFLITE) 270 quantized_input_stats: List of tuples of floats representing the mean and 271 standard deviation. Each tuple maps to the corresponding input tensor. 272 Only need if `inference_input_type` is `QUANTIZED_UINT8`. 273 real_input_value = (quantized_input_value - mean_value) / std_dev_value. 274 (default None) 275 default_ranges_stats: Tuple of integers representing (min, max) range values 276 for all arrays without a specified range. Intended for experimenting with 277 quantization via "dummy quantization". (default None) 278 drop_control_dependency: Boolean indicating whether to drop control 279 dependencies silently. This is due to TFLite not supporting control 280 dependencies. (default True) 281 reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant 282 nodes in unexpected locations. Used when the location of the FakeQuant 283 nodes is preventing graph transformations necessary to convert the graph. 284 Results in a graph that differs from the quantized training graph, 285 potentially causing differing arithmetic behavior. (default False) 286 allow_custom_ops: Boolean indicating whether to allow custom operations. 287 When false any unknown operation is an error. When true, custom ops are 288 created for any op that is unknown. The developer will need to provide 289 these to the TensorFlow Lite runtime with a custom resolver. 290 (default False) 291 change_concat_input_ranges: Boolean to change behavior of min/max ranges for 292 inputs and outputs of the concat operator for quantized models. Changes 293 the ranges of concat operator overlap when true. (default False) 294 post_training_quantize: Boolean indicating whether to quantize the weights 295 of the converted float model. Model size will be reduced and there will be 296 latency improvements (at the cost of accuracy). 297 (default False) 298 dump_graphviz_dir: Full filepath of folder to dump the graphs at various 299 stages of processing GraphViz .dot files. Preferred over 300 --output_format=GRAPHVIZ_DOT in order to keep the requirements of the 301 output file. (default None) 302 dump_graphviz_video: Boolean indicating whether to dump the graph after 303 every graph transformation. (default False) 304 target_ops: Experimental flag, subject to change. Set of OpsSet 305 options indicating which converter to use. 306 (default set([OpsSet.TFLITE_BUILTINS])) 307 allow_nonexistent_arrays: Allow specifying array names that don't exist 308 or are unused in the final graph. (default False) 309 310 Returns: 311 model_flags, toco_flags: two protocol buffers describing the conversion 312 process. 313 314 Raises: 315 ValueError: 316 If the input tensor type is unknown 317 Missing mean_values or std_dev_values 318 RuntimeError: If TOCO fails to convert (in which case the runtime error's 319 error text will contain the TOCO error log) 320 """ 321 toco = _toco_flags_pb2.TocoFlags() 322 toco.input_format = input_format 323 toco.output_format = output_format 324 toco.inference_type = convert_dtype_to_tflite_type(inference_type) 325 if inference_input_type: 326 toco.inference_input_type = convert_dtype_to_tflite_type( 327 inference_input_type) 328 else: 329 toco.inference_input_type = toco.inference_type 330 toco.drop_control_dependency = drop_control_dependency 331 toco.reorder_across_fake_quant = reorder_across_fake_quant 332 toco.allow_custom_ops = allow_custom_ops 333 toco.post_training_quantize = post_training_quantize 334 if default_ranges_stats: 335 toco.default_ranges_min = default_ranges_stats[0] 336 toco.default_ranges_max = default_ranges_stats[1] 337 if dump_graphviz_dir: 338 toco.dump_graphviz_dir = dump_graphviz_dir 339 toco.dump_graphviz_include_video = dump_graphviz_video 340 if target_ops: 341 if set(target_ops) == set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]): 342 toco.enable_select_tf_ops = True 343 elif set(target_ops) == set([OpsSet.SELECT_TF_OPS]): 344 toco.enable_select_tf_ops = True 345 toco.force_select_tf_ops = True 346 347 model = _model_flags_pb2.ModelFlags() 348 model.change_concat_input_ranges = change_concat_input_ranges 349 for idx, input_tensor in enumerate(input_tensors): 350 input_array = model.input_arrays.add() 351 input_array.name = tensor_name(input_tensor) 352 input_array.data_type = convert_dtype_to_tflite_type(input_tensor.dtype) 353 354 if toco.inference_input_type == _types_pb2.QUANTIZED_UINT8: 355 if not quantized_input_stats: 356 raise ValueError("std_dev and mean must be defined when " 357 "inference_input_type is QUANTIZED_UINT8.") 358 input_array.mean_value, input_array.std_value = quantized_input_stats[idx] 359 if input_shapes is None: 360 shape = input_tensor.shape 361 else: 362 shape = input_shapes[idx] 363 input_array.shape.dims.extend(map(int, shape)) 364 365 for output_tensor in output_tensors: 366 model.output_arrays.append(tensor_name(output_tensor)) 367 368 model.allow_nonexistent_arrays = allow_nonexistent_arrays 369 370 return model, toco 371 372 373def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays, 374 *args, **kwargs): 375 """"Convert a model using TOCO. 376 377 This function is used to convert GraphDefs that cannot be loaded into 378 TensorFlow to TFLite. Conversion can be customized by providing arguments 379 that are forwarded to `build_toco_convert_protos` (see documentation for 380 details). 381 382 Args: 383 input_data: Input data (i.e. often `sess.graph_def`), 384 input_arrays_with_shape: Tuple of strings representing input tensor names 385 and list of integers representing input shapes 386 (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded 387 into TensorFlow and when `input_tensors` is None. (default None) 388 output_arrays: List of output tensors to freeze graph with. Use only when 389 graph cannot be loaded into TensorFlow and when `output_tensors` is None. 390 (default None) 391 *args: See `build_toco_convert_protos`, 392 **kwargs: See `build_toco_convert_protos`. 393 394 Returns: 395 The converted data. For example if TFLite was the destination, then 396 this will be a tflite flatbuffer in a bytes array. 397 398 Raises: 399 Defined in `build_toco_convert_protos`. 400 """ 401 model_flags, toco_flags = build_toco_convert_protos( 402 input_tensors=[], output_tensors=[], *args, **kwargs) 403 404 for idx, (name, shape) in enumerate(input_arrays_with_shape): 405 input_array = model_flags.input_arrays.add() 406 if toco_flags.inference_input_type == _types_pb2.QUANTIZED_UINT8: 407 if (("quantized_input_stats" not in kwargs) or 408 (not kwargs["quantized_input_stats"])): 409 raise ValueError("std_dev and mean must be defined when " 410 "inference_input_type is QUANTIZED_UINT8.") 411 input_array.mean_value, input_array.std_value = kwargs[ 412 "quantized_input_stats"][idx] 413 input_array.name = name 414 input_array.shape.dims.extend(map(int, shape)) 415 416 for name in output_arrays: 417 model_flags.output_arrays.append(name) 418 419 data = toco_convert_protos(model_flags.SerializeToString(), 420 toco_flags.SerializeToString(), 421 input_data.SerializeToString()) 422 return data 423 424 425def toco_convert_impl(input_data, input_tensors, output_tensors, *args, 426 **kwargs): 427 """"Convert a model using TOCO. 428 429 Typically this function is used to convert from TensorFlow GraphDef to TFLite. 430 Conversion can be customized by providing arguments that are forwarded to 431 `build_toco_convert_protos` (see documentation for details). 432 433 Args: 434 input_data: Input data (i.e. often `sess.graph_def`), 435 input_tensors: List of input tensors. Type and shape are computed using 436 `foo.shape` and `foo.dtype`. 437 output_tensors: List of output tensors (only .name is used from this). 438 *args: See `build_toco_convert_protos`, 439 **kwargs: See `build_toco_convert_protos`. 440 441 Returns: 442 The converted data. For example if TFLite was the destination, then 443 this will be a tflite flatbuffer in a bytes array. 444 445 Raises: 446 Defined in `build_toco_convert_protos`. 447 """ 448 model_flags, toco_flags = build_toco_convert_protos( 449 input_tensors, output_tensors, *args, **kwargs) 450 data = toco_convert_protos(model_flags.SerializeToString(), 451 toco_flags.SerializeToString(), 452 input_data.SerializeToString()) 453 return data 454 455 456@_tf_export(v1=["lite.toco_convert"]) 457@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.") 458def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): 459 """Convert a model using TOCO. 460 461 Typically this function is used to convert from TensorFlow GraphDef to TFLite. 462 Conversion can be customized by providing arguments that are forwarded to 463 `build_toco_convert_protos` (see documentation for details). This function has 464 been deprecated. Please use `lite.TFLiteConverter` instead. 465 466 Args: 467 input_data: Input data (i.e. often `sess.graph_def`), 468 input_tensors: List of input tensors. Type and shape are computed using 469 `foo.shape` and `foo.dtype`. 470 output_tensors: List of output tensors (only .name is used from this). 471 *args: See `build_toco_convert_protos`, 472 **kwargs: See `build_toco_convert_protos`. 473 474 Returns: 475 The converted data. For example if TFLite was the destination, then 476 this will be a tflite flatbuffer in a bytes array. 477 478 Raises: 479 Defined in `build_toco_convert_protos`. 480 """ 481 return toco_convert_impl(input_data, input_tensors, output_tensors, *args, 482 **kwargs) 483