# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a frozen graph into a TFLite FlatBuffer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import distutils.spawn
import enum  # pylint: disable=g-bad-import-order
import os as _os
import platform as _platform
import subprocess as _subprocess
import tempfile as _tempfile

import six
from six.moves import map

from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.python import wrap_toco
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import ConverterError
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.metrics_wrapper import metrics_wrapper as _metrics_wrapper
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export


def _requires_input_stats(toco_flags: _toco_flags_pb2.TocoFlags) -> bool:
  """Checks if the `input_stats` flag is required for conversion.

  Args:
    toco_flags: A protocol buffer describing the conversion process.

  Returns:
    True if the `inference_type` or the `inference_input_type` is a quantized
    type and it is not post training quantization, else False.
  """
  quantized_inference_types = [
      _types_pb2.QUANTIZED_UINT8, _types_pb2.QUANTIZED_INT8
  ]
  return ((toco_flags.inference_type in quantized_inference_types or
           toco_flags.inference_input_type in quantized_inference_types) and
          not toco_flags.post_training_quantize)


def convert_tensor_tf_type_to_tflite_type(
    tf_type: dtypes.DType, usage: str = "") -> _types_pb2.IODataType:
  """Convert tensor type from tf type to tflite type.

  Args:
    tf_type: TensorFlow type.
    usage: Text describing the reason for invoking this function.

  Raises:
    ValueError: If `tf_type` is unsupported.

  Returns:
    tflite_type: TFLite type. Refer to lite/toco/types.proto.
  """
  mapping = {
      dtypes.float16: _types_pb2.FLOAT16,
      dtypes.float32: _types_pb2.FLOAT,
      dtypes.float64: _types_pb2.FLOAT64,
      dtypes.int8: _types_pb2.INT8,
      dtypes.int16: _types_pb2.INT16,
      dtypes.int32: _types_pb2.INT32,
      dtypes.int64: _types_pb2.INT64,
      dtypes.uint8: _types_pb2.UINT8,
      dtypes.uint32: _types_pb2.UINT32,
      dtypes.uint64: _types_pb2.UINT64,
      dtypes.string: _types_pb2.STRING,
      dtypes.bool: _types_pb2.BOOL,
      dtypes.complex64: _types_pb2.COMPLEX64,
      dtypes.complex128: _types_pb2.COMPLEX128,
  }
  tflite_type = mapping.get(tf_type)
  if tflite_type is None:
    raise ValueError("Unsupported TensorFlow type `{0}` provided for the {1}"
                     .format(tf_type, usage))
  return tflite_type
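
# A minimal usage sketch of the mapping above (illustrative only, not part of
# the converter API; note that `dtypes.bfloat16` is simply absent from
# `mapping`, so it raises):
#
#   convert_tensor_tf_type_to_tflite_type(dtypes.float32)   # _types_pb2.FLOAT
#   convert_tensor_tf_type_to_tflite_type(dtypes.bool)      # _types_pb2.BOOL
#   convert_tensor_tf_type_to_tflite_type(dtypes.bfloat16)  # raises ValueError
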
# Only a few restricted tensor types are allowed for explicitly setting
# inference/input/output types.
def convert_inference_tf_type_to_tflite_type(
    tf_type: dtypes.DType, usage: str = "") -> _types_pb2.IODataType:
  """Convert inference type from tf type to tflite type.

  Args:
    tf_type: TensorFlow type.
    usage: Text describing the reason for invoking this function.

  Raises:
    ValueError: If `tf_type` is unsupported.

  Returns:
    tflite_type: TFLite type. Refer to lite/toco/types.proto.
  """
  mapping = {
      dtypes.float32: _types_pb2.FLOAT,
      dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
      dtypes.int8: _types_pb2.QUANTIZED_INT8,
      dtypes.int16: _types_pb2.QUANTIZED_INT16,
  }
  tflite_type = mapping.get(tf_type)
  if tflite_type is None:
    raise ValueError("Unsupported TensorFlow type `{0}` provided for the {1}"
                     .format(tf_type, usage))
  return tflite_type


# Find the toco_from_protos binary using the resource loader if running from
# bazel; otherwise we are in a pip package where console_scripts already
# provides the toco_from_protos tool.
if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
  _toco_from_proto_bin = ""
else:
  _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
      "../toco/python/toco_from_protos")

if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
  _toco_from_proto_bin = "toco_from_protos"


def _try_convert_to_unicode(output):
  if output is None:
    return u""

  if isinstance(output, bytes):
    try:
      return six.ensure_text(output)
    except UnicodeDecodeError:
      pass
  return output


@_tf_export("lite.OpsSet")
class OpsSet(enum.Enum):
  """Enum class defining the sets of ops available to generate TFLite models.

  WARNING: Experimental interface, subject to change.
  """
  # Convert model using TensorFlow Lite builtin ops.
  TFLITE_BUILTINS = "TFLITE_BUILTINS"

  # Convert model using TensorFlow ops. Not all TensorFlow ops are available.
  # WARNING: Experimental interface, subject to change.
  SELECT_TF_OPS = "SELECT_TF_OPS"

  # Convert model using only TensorFlow Lite quantized int8 operations.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"

  # Convert model using only TensorFlow Lite operations with quantized int8
  # weights, int16 activations and int64 bias.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  # This quantization mode may be used in models for super-resolution,
  # audio signal processing or image de-noising. It improves accuracy
  # significantly, but only slightly increases the model size.
  # WARNING: These ops are currently experimental and have not yet been
  # finalized. They are only compatible with CPU execution, and have not been
  # optimized for production.
  EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = (
      "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8")

  def __str__(self):
    return str(self.value)

  @staticmethod
  def get_options():
    """Returns a list of OpsSet options as a list of strings."""
    return [str(option) for option in list(OpsSet)]
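
# Illustrative sketch of how these enum values are typically consumed. The
# `target_spec.supported_ops` attribute below is the documented entry point on
# `tf.lite.TFLiteConverter`; `converter` itself is a hypothetical instance:
#
#   converter.target_spec.supported_ops = [
#       OpsSet.TFLITE_BUILTINS,
#       OpsSet.SELECT_TF_OPS,
#   ]
#   OpsSet.get_options()  # ["TFLITE_BUILTINS", "SELECT_TF_OPS", ...]
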
@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.QUANTIZE)
def mlir_quantize(input_data_str,
                  disable_per_channel=False,
                  fully_quantize=False,
                  inference_type=_types_pb2.QUANTIZED_INT8,
                  input_data_type=dtypes.float32,
                  output_data_type=dtypes.float32,
                  enable_numeric_verify=False,
                  enable_whole_model_verify=False,
                  denylisted_ops=None,
                  denylisted_nodes=None):
  """Quantize `input_data_str` with calibration results.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model with
      calibration results).
    disable_per_channel: Bool indicating whether to do per-channel or
      per-tensor quantization.
    fully_quantize: Bool indicating whether to fully quantize the model.
      Besides the model body, the input/output will be quantized as well.
    inference_type: Data type for the activations. The default value is int8.
    input_data_type: Data type for the inputs. The default value is float32.
    output_data_type: Data type for the outputs. The default value is float32.
    enable_numeric_verify: Experimental. Subject to change. Bool indicating
      whether to add NumericVerify ops into the debug mode quantized model.
    enable_whole_model_verify: Experimental. Subject to change. Bool indicating
      whether to add verification layer by layer, or on the whole model. When
      disabled (per-layer), float and quantized ops will be run from the same
      input (the output of the previous quantized layer). When enabled, float
      and quantized ops will run with the respective float and quantized
      outputs of the previous ops.
    denylisted_ops: Experimental. Subject to change. Set of ops to denylist.
    denylisted_nodes: Experimental. Subject to change. Set of nodes to
      denylist.

  Returns:
    Quantized model in serialized form (e.g. a TFLITE model) with
    floating-point inputs and outputs.
  """
  return wrap_toco.wrapped_experimental_mlir_quantize(
      input_data_str, disable_per_channel, fully_quantize, inference_type,
      convert_tensor_tf_type_to_tflite_type(input_data_type),
      convert_tensor_tf_type_to_tflite_type(output_data_type),
      enable_numeric_verify, enable_whole_model_verify, denylisted_ops,
      denylisted_nodes)


@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.SPARSIFY)
def mlir_sparsify(input_data_str):
  """Sparsify `input_data_str` to encode sparse tensors with proper format.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model).

  Returns:
    Sparsified model in serialized form (e.g. a TFLITE model).
  """
  return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str)
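
# Hedged usage sketch: `calibrated_model` stands for a serialized TFLite
# flatbuffer that already contains calibration statistics; the name is
# hypothetical and the flag values are arbitrary:
#
#   quantized_model = mlir_quantize(
#       calibrated_model,
#       fully_quantize=True,
#       input_data_type=dtypes.int8,
#       output_data_type=dtypes.int8)
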
def register_custom_opdefs(custom_opdefs_list):
  """Register the given custom opdefs to the TensorFlow global op registry.

  Args:
    custom_opdefs_list: String representing the custom ops OpDefs that are
      included in the GraphDef.

  Returns:
    True if the registration is successfully completed.
  """
  return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)


def toco_convert_protos(model_flags_str,
                        toco_flags_str,
                        input_data_str,
                        debug_info_str=None,
                        enable_mlir_converter=False):
  """Convert `input_data_str` according to model and toco parameters.

  Unless you know what you are doing, consider using
  the more friendly `tf.compat.v1.lite.toco_convert`.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `toco/model_flags.proto`.
    toco_flags_str: Serialized proto describing conversion properties, see
      `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common).
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)
    enable_mlir_converter: Enables MLIR-based conversion instead of the
      default TOCO conversion. (default False)

  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).

  Raises:
    ConverterError: When conversion fails in TFLiteConverter, usually due to
      ops not being supported.
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded.
  """
  # Historically, TOCO conversion failures would trigger a crash, so we would
  # attempt to run the converter out-of-process. The MLIR conversion pipeline
  # surfaces errors instead, and can be safely run in-process.
  if enable_mlir_converter or not _toco_from_proto_bin:
    try:
      model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
                                                 toco_flags_str,
                                                 input_data_str,
                                                 debug_info_str,
                                                 enable_mlir_converter)
      return model_str
    except Exception as e:
      converter_error = ConverterError(str(e))
      for error_data in _metrics_wrapper.retrieve_collected_errors():
        converter_error.append_error(error_data)
      raise converter_error

  return _run_toco_binary(model_flags_str, toco_flags_str, input_data_str,
                          debug_info_str)
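
# Minimal sketch of driving toco_convert_protos directly; callers normally go
# through the higher-level helpers below. `graph_def` is a hypothetical
# tf.compat.v1.GraphDef:
#
#   model_flags = _model_flags_pb2.ModelFlags()
#   toco_flags = _toco_flags_pb2.TocoFlags()
#   tflite_model = toco_convert_protos(
#       model_flags.SerializeToString(),
#       toco_flags.SerializeToString(),
#       graph_def.SerializeToString(),
#       enable_mlir_converter=True)
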
@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER)
def _run_toco_binary(model_flags_str,
                     toco_flags_str,
                     input_data_str,
                     debug_info_str=None):
  """Convert `input_data_str` using the TOCO converter binary.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `toco/model_flags.proto`.
    toco_flags_str: Serialized proto describing conversion properties, see
      `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common).
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)

  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).

  Raises:
    ConverterError: When the toco binary cannot be found.
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded.
  """
  if distutils.spawn.find_executable(_toco_from_proto_bin) is None:
    raise ConverterError("""Could not find toco_from_protos binary, make sure
your virtualenv bin directory or pip local bin directory is in your path.
In particular, if you have installed TensorFlow with --user, make sure you
add the install directory to your path.

For example:
Linux: export PATH=$PATH:~/.local/bin/
Mac: export PATH=$PATH:~/Library/Python/<version#>/bin

Alternatively, use virtualenv.""")
  # Windows and TemporaryFile are not that useful together,
  # since you cannot have two readers/writers. So we have to
  # make the temporaries and close and delete them explicitly.
  toco_filename, model_filename, input_filename, output_filename = (
      None, None, None, None)
  debug_filename = None
  try:
    # Build all input files.
    with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_model, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_input, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_debug:
      toco_filename = fp_toco.name
      input_filename = fp_input.name
      model_filename = fp_model.name
      debug_filename = fp_debug.name

      fp_model.write(model_flags_str)
      fp_toco.write(toco_flags_str)
      fp_input.write(six.ensure_binary(input_data_str))
      debug_info_str = debug_info_str if debug_info_str else ""
      # If debug_info_str is a str rather than bytes, the call to
      # fp_debug.write(debug_info_str) would fail with:
      #
      #   TypeError: a bytes-like object is required, not 'str'
      #
      # Some of the subtests within the "convert_test" unit-test trigger this
      # scenario, so convert debug_info_str to bytes where needed.
      if not isinstance(debug_info_str, bytes):
        fp_debug.write(debug_info_str.encode("utf-8"))
      else:
        fp_debug.write(debug_info_str)

    # Reserve an output file.
    with _tempfile.NamedTemporaryFile(delete=False) as fp:
      output_filename = fp.name

    # Run.
    cmd = [
        _toco_from_proto_bin,
        model_filename,
        toco_filename,
        input_filename,
        output_filename,
        "--debug_proto_file={}".format(debug_filename),
    ]
    cmdline = " ".join(cmd)
    is_windows = _platform.system() == "Windows"
    proc = _subprocess.Popen(
        cmdline,
        shell=True,
        stdout=_subprocess.PIPE,
        stderr=_subprocess.STDOUT,
        close_fds=not is_windows)
    stdout, stderr = proc.communicate()
    exitcode = proc.returncode
    if exitcode == 0:
      with open(output_filename, "rb") as fp:
        return fp.read()
    else:
      stdout = _try_convert_to_unicode(stdout)
      stderr = _try_convert_to_unicode(stderr)
      raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
  finally:
    # Must manually clean up files, including the debug proto temporary.
    for filename in [
        toco_filename, input_filename, model_filename, output_filename,
        debug_filename
    ]:
      try:
        _os.unlink(filename)
      except (OSError, TypeError):
        pass
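
# For reference, the subprocess invocation above is equivalent to running the
# standalone binary by hand; the bracketed names stand for the generated
# temporary files:
#
#   toco_from_protos <model_flags.pb> <toco_flags.pb> <input.pb> \
#       <output.tflite> --debug_proto_file=<debug.pb>
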
def build_toco_flags(inference_type=dtypes.float32,
                     inference_input_type=None,
                     input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                     output_format=lite_constants.TFLITE,
                     default_ranges_stats=None,
                     drop_control_dependency=True,
                     reorder_across_fake_quant=False,
                     allow_custom_ops=False,
                     post_training_quantize=False,
                     quantize_to_float16=False,
                     dump_graphviz_dir=None,
                     dump_graphviz_video=False,
                     target_ops=None,
                     conversion_summary_dir=None,
                     select_user_tf_ops=None,
                     allow_all_select_tf_ops=False,
                     enable_tflite_resource_variables=False,
                     unfold_batchmatmul=True,
                     lower_tensor_list_ops=True,
                     accumulation_type=None,
                     allow_bfloat16=False,
                     unfold_large_splat_constant=False,
                     **_):
  """Build the TOCO flags object from params."""
  toco = _toco_flags_pb2.TocoFlags()
  toco.input_format = input_format
  toco.output_format = output_format
  toco.inference_type = convert_inference_tf_type_to_tflite_type(
      inference_type, usage="inference_type flag")
  if inference_input_type:
    toco.inference_input_type = convert_inference_tf_type_to_tflite_type(
        inference_input_type, usage="inference_input_type flag")
  else:
    toco.inference_input_type = toco.inference_type
  toco.drop_control_dependency = drop_control_dependency
  toco.reorder_across_fake_quant = reorder_across_fake_quant
  toco.allow_custom_ops = allow_custom_ops
  if select_user_tf_ops:
    toco.select_user_tf_ops.extend(select_user_tf_ops)
  toco.allow_all_select_tf_ops = allow_all_select_tf_ops
  toco.post_training_quantize = post_training_quantize
  toco.quantize_to_float16 = quantize_to_float16
  if default_ranges_stats:
    toco.default_ranges_min = default_ranges_stats[0]
    toco.default_ranges_max = default_ranges_stats[1]
  if dump_graphviz_dir:
    toco.dump_graphviz_dir = dump_graphviz_dir
  toco.dump_graphviz_include_video = dump_graphviz_video
  if conversion_summary_dir:
    toco.conversion_summary_dir = conversion_summary_dir
  if target_ops:
    if OpsSet.SELECT_TF_OPS in set(target_ops):
      toco.enable_select_tf_ops = True
    if set(target_ops) == set([OpsSet.SELECT_TF_OPS]):
      toco.force_select_tf_ops = True
  toco.enable_tflite_resource_variables = enable_tflite_resource_variables
  toco.unfold_batchmatmul = unfold_batchmatmul
  toco.lower_tensor_list_ops = lower_tensor_list_ops
  toco.unfold_large_splat_constant = unfold_large_splat_constant
  if accumulation_type:
    toco.accumulation_type = convert_tensor_tf_type_to_tflite_type(
        accumulation_type, usage="accumulation_type flag")
  toco.allow_bfloat16 = allow_bfloat16

  return toco
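
# Illustrative call; the parameters are real arguments of build_toco_flags and
# the values are arbitrary:
#
#   flags = build_toco_flags(
#       inference_type=dtypes.float32,
#       allow_custom_ops=True,
#       target_ops=set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]))
#   flags.enable_select_tf_ops  # True
#   flags.force_select_tf_ops   # False, since TFLITE_BUILTINS is also set
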
def build_toco_convert_protos(input_tensors,
                              output_tensors,
                              inference_type=dtypes.float32,
                              inference_input_type=None,
                              input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                              input_shapes=None,
                              output_format=lite_constants.TFLITE,
                              quantized_input_stats=None,
                              default_ranges_stats=None,
                              drop_control_dependency=True,
                              reorder_across_fake_quant=False,
                              allow_custom_ops=False,
                              change_concat_input_ranges=False,
                              post_training_quantize=False,
                              quantize_to_float16=False,
                              dump_graphviz_dir=None,
                              dump_graphviz_video=False,
                              target_ops=None,
                              allow_nonexistent_arrays=False,
                              debug_info=None,
                              conversion_summary_dir=None,
                              saved_model_dir=None,
                              saved_model_version=0,
                              saved_model_tags=None,
                              saved_model_exported_names=None,
                              select_user_tf_ops=None,
                              allow_all_select_tf_ops=False,
                              unfold_batchmatmul=True,
                              lower_tensor_list_ops=True,
                              accumulation_type=None,
                              allow_bfloat16=False,
                              unfold_large_splat_constant=False):
  """Builds protocol buffers describing a conversion of a model using TOCO.

  Typically this is to convert from TensorFlow GraphDef to TFLite, in which
  case the default `input_format` and `output_format` are sufficient.

  Args:
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    inference_type: Data type of numeric arrays, excluding the input layer.
      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_input_type: Data type of the numeric arrays in the input layer.
      If `inference_input_type` is in {tf.int8, tf.uint8}, then
      `quantized_input_stats` must be provided. (default is the value assigned
      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
    input_format: Type of data to read.
      (default TENSORFLOW_GRAPHDEF, must be in {TENSORFLOW_GRAPHDEF})
    input_shapes: Input array shape. (default None, must be None or a list of
      the same length as `input_tensors`.)
    output_format: Output file format. (default TFLITE, must be in
      {TFLITE, GRAPHVIZ_DOT})
    quantized_input_stats: Map of input tensor names to a tuple of floats
      representing the mean and standard deviation of the training data.
      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is
      tf.int8 or tf.uint8. (default None)
    default_ranges_stats: Tuple of integers representing (min, max) range
      values for all arrays without a specified range. Intended for
      experimenting with quantization via "dummy quantization". (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the
      graph. Results in a graph that differs from the quantized training
      graph, potentially causing differing arithmetic behavior. (default
      False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false, any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges
      for inputs and outputs of the concat operator for quantized models.
      Changes the ranges of concat operator overlap when true. (default False)
    post_training_quantize: Boolean indicating whether to quantize the weights
      of the converted float model. Model size will be reduced and there will
      be latency improvements (at the cost of accuracy). (default False)
    quantize_to_float16: Boolean indicating whether to convert float buffers
      to float16. (default False)
    dump_graphviz_dir: Full filepath of the folder to dump the graphs at
      various stages of processing as GraphViz .dot files. Preferred over
      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
      output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet options
      indicating which converter to use. (default
      set([OpsSet.TFLITE_BUILTINS]))
    allow_nonexistent_arrays: Allow specifying array names that don't exist or
      are unused in the final graph. (default False)
    debug_info: `GraphDebugInfo` proto containing the stack traces for the
      original nodes referred to by the converted graph.
    conversion_summary_dir: A string, the path to the generated conversion
      logs.
    saved_model_dir: Filepath of the saved model to be converted. This value
      will be non-empty only when the SavedModel import path is used;
      otherwise, the GraphDef-based conversion will be processed.
    saved_model_version: SavedModel file format version of the saved model
      file to be converted. This value will be set only when the SavedModel
      import path is used.
    saved_model_tags: Set of saved model tags, formatted as comma-separated
      values. This value will be set only when the SavedModel import path is
      used.
    saved_model_exported_names: Names to be exported (default: export all)
      when the SavedModel import path is on. This value will be set only when
      the SavedModel import path is used.
    select_user_tf_ops: List of user-defined TensorFlow ops that need to be
      supported in the TensorFlow Lite runtime. These ops will be supported
      as select TensorFlow ops.
    allow_all_select_tf_ops: If True, automatically add all TF ops (including
      custom TF ops) to the converted model as flex ops.
    unfold_batchmatmul: Whether to unfold tf.BatchMatMul to a set of
      tfl.fully_connected ops. If not, translate to tfl.batch_matmul.
    lower_tensor_list_ops: Whether to lower tensor list ops to builtin ops. If
      not, use Flex tensor list ops.
    accumulation_type: Data type of the accumulators in quantized inference.
      Typically used for float16 quantization and is either fp16 or fp32.
    allow_bfloat16: Whether the converted model supports reduced precision
      inference with the bfloat16 type.
    unfold_large_splat_constant: Whether to unfold large splat constant
      tensors in the flatbuffer model to reduce size.

  Returns:
    model_flags, toco_flags, debug_info: three protocol buffers describing the
    conversion process and debug information.

  Raises:
    ValueError: If the input tensor type is unknown, or if `mean_values` or
      `std_dev_values` are required but missing.
    RuntimeError: If TOCO fails to convert (in which case the runtime error's
      error text will contain the TOCO error log).
  """
  toco = build_toco_flags(
      inference_type=inference_type,
      inference_input_type=inference_input_type,
      input_format=input_format,
      output_format=output_format,
      default_ranges_stats=default_ranges_stats,
      drop_control_dependency=drop_control_dependency,
      reorder_across_fake_quant=reorder_across_fake_quant,
      allow_custom_ops=allow_custom_ops,
      post_training_quantize=post_training_quantize,
      quantize_to_float16=quantize_to_float16,
      dump_graphviz_dir=dump_graphviz_dir,
      dump_graphviz_video=dump_graphviz_video,
      target_ops=target_ops,
      conversion_summary_dir=conversion_summary_dir,
      select_user_tf_ops=select_user_tf_ops,
      allow_all_select_tf_ops=allow_all_select_tf_ops,
      unfold_batchmatmul=unfold_batchmatmul,
      lower_tensor_list_ops=lower_tensor_list_ops,
      accumulation_type=accumulation_type,
      allow_bfloat16=allow_bfloat16,
      unfold_large_splat_constant=unfold_large_splat_constant)
  model = _model_flags_pb2.ModelFlags()
  model.change_concat_input_ranges = change_concat_input_ranges
  for idx, input_tensor in enumerate(input_tensors):
    input_array = model.input_arrays.add()
    if saved_model_dir:
      input_array.name = input_tensor.name
    else:
      input_array.name = util.get_tensor_name(input_tensor)
    input_array.data_type = convert_tensor_tf_type_to_tflite_type(
        input_tensor.dtype, usage="input type of the TensorFlow model")

    if _requires_input_stats(toco) and quantized_input_stats:
      input_array.mean_value, input_array.std_value = quantized_input_stats[
          idx]

    if input_shapes is None:
      shape = input_tensor.shape
    else:
      shape = input_shapes[idx]

    if shape.rank is not None:
      # Create shapes with -1 for unknown dimensions.
      dims = []
      for dim in shape:
        if (dim is None or
            (isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
          dims.append(-1)
        else:
          dims.append(int(dim))
      input_array.shape.dims.extend(dims)
      input_array.shape.unknown_rank = False
    else:
      input_array.shape.unknown_rank = True

  for output_tensor in output_tensors:
    if saved_model_dir:
      model.output_arrays.append(output_tensor.name)
    else:
      model.output_arrays.append(util.get_tensor_name(output_tensor))

  model.allow_nonexistent_arrays = allow_nonexistent_arrays

  if saved_model_dir:
    model.saved_model_dir = saved_model_dir
  model.saved_model_version = saved_model_version
  if saved_model_tags:
    model.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    model.saved_model_exported_names.extend(saved_model_exported_names)

  return model, toco, debug_info
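
# Sketch of the typical pairing with toco_convert_protos; tensors come from a
# frozen tf.compat.v1 Session graph, and `sess`, `in_tensor`, `out_tensor`
# are hypothetical:
#
#   model_flags, toco_flags, debug_info = build_toco_convert_protos(
#       input_tensors=[in_tensor], output_tensors=[out_tensor])
#   tflite_model = toco_convert_protos(
#       model_flags.SerializeToString(),
#       toco_flags.SerializeToString(),
#       sess.graph_def.SerializeToString())
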
@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF)
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
                           enable_mlir_converter, control_output_arrays, *args,
                           **kwargs):
  """Convert a model using TOCO.

  This function is used to convert GraphDefs that cannot be loaded into
  TensorFlow to TFLite. Conversion can be customized by providing arguments
  that are forwarded to `build_toco_convert_protos` (see documentation for
  details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_arrays_with_shape: Tuples of strings representing input tensor names
      and lists of integers representing input shapes
      (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
      into TensorFlow and when `input_tensors` is None.
    output_arrays: List of output tensors to freeze graph with. Use only when
      graph cannot be loaded into TensorFlow and when `output_tensors` is
      None.
    enable_mlir_converter: Enables MLIR-based conversion instead of TOCO
      conversion.
    control_output_arrays: Control output node names. This is used when
      converting a Graph with no output tensors. For example, if the graph's
      last operation is a Print op, just specify that op's name in this
      field. This can be used together with the `output_arrays` parameter.
    *args: See `build_toco_convert_protos`,
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then this
    will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  model_flags, toco_flags, _ = build_toco_convert_protos(
      input_tensors=[], output_tensors=[], *args, **kwargs)

  for idx, (name, shape) in enumerate(input_arrays_with_shape):
    input_array = model_flags.input_arrays.add()
    if _requires_input_stats(toco_flags):
      if (("quantized_input_stats" not in kwargs) or
          (not kwargs["quantized_input_stats"])):
        raise ValueError(
            "The `quantized_input_stats` flag must be defined when either "
            "`inference_type` flag or `inference_input_type` flag is set to "
            "tf.int8 or tf.uint8.")
      input_array.mean_value, input_array.std_value = kwargs[
          "quantized_input_stats"][idx]
    input_array.name = name
    input_array.shape.dims.extend(list(map(int, shape)))

  if output_arrays:
    for name in output_arrays:
      model_flags.output_arrays.append(name)
  if control_output_arrays:
    for name in control_output_arrays:
      model_flags.control_output_arrays.append(name)

  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      input_data.SerializeToString(),
      enable_mlir_converter=enable_mlir_converter)
  return data
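
# Hedged example of the GraphDef path; `graph_def` is a hypothetical
# tf.compat.v1.GraphDef that cannot be loaded into a Session, and the array
# names are placeholders:
#
#   tflite_model = toco_convert_graph_def(
#       graph_def,
#       input_arrays_with_shape=[("input", [1, 16, 16, 3])],
#       output_arrays=["softmax"],
#       enable_mlir_converter=True,
#       control_output_arrays=None)
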
@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF)
def toco_convert_impl(input_data, input_tensors, output_tensors,
                      enable_mlir_converter, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to
  TFLite. Conversion can be customized by providing arguments that are
  forwarded to `build_toco_convert_protos` (see documentation for details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    enable_mlir_converter: Enables MLIR-based conversion instead of TOCO
      conversion.
    *args: See `build_toco_convert_protos`,
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then this
    will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  model_flags, toco_flags, debug_info = build_toco_convert_protos(
      input_tensors, output_tensors, *args, **kwargs)
  debug_info_str = debug_info.SerializeToString() if debug_info else None
  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      input_data.SerializeToString(),
      debug_info_str=debug_info_str,
      enable_mlir_converter=enable_mlir_converter)
  return data


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_SAVED_MODEL)
def convert_saved_model(saved_model_dir=None,
                        saved_model_version=0,
                        saved_model_tags=None,
                        saved_model_exported_names=None,
                        **kwargs):
  """Converts a SavedModel using the TF Lite converter."""
  model_flags = _model_flags_pb2.ModelFlags()
  if saved_model_dir:
    model_flags.saved_model_dir = saved_model_dir
  model_flags.saved_model_version = saved_model_version
  if saved_model_tags:
    model_flags.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    model_flags.saved_model_exported_names.extend(saved_model_exported_names)
  toco_flags = build_toco_flags(**kwargs)
  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      None,  # input_data, unused
      None,  # debug_info_str, unused
      enable_mlir_converter=True)
  return data
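
# Minimal sketch of the SavedModel path; the directory path is hypothetical
# and the keyword arguments are forwarded to build_toco_flags:
#
#   tflite_model = convert_saved_model(
#       saved_model_dir="/tmp/my_saved_model",
#       saved_model_version=1,
#       saved_model_tags=["serve"],
#       inference_type=dtypes.float32)
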
@_tf_export(v1=["lite.toco_convert"])
@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to
  TFLite. Conversion can be customized by providing arguments that are
  forwarded to `build_toco_convert_protos` (see documentation for details).
  This function has been deprecated. Please use `tf.lite.TFLiteConverter`
  instead.

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    *args: See `build_toco_convert_protos`,
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then this
    will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  enable_mlir_converter = kwargs.get("enable_mlir_converter", False)
  return toco_convert_impl(input_data, input_tensors, output_tensors,
                           enable_mlir_converter, *args, **kwargs)


def deduplicate_readonly_buffers(tflite_model):
  """Generates a new model byte array after deduplicating readonly buffers.

  This function should be invoked after the model optimization toolkit. The
  model optimization toolkit assumes that each tensor object owns its own
  buffer separately.

  Args:
    tflite_model: TFLite flatbuffer in a byte array to be deduplicated.

  Returns:
    TFLite flatbuffer in a bytes array, processed with the deduplication
    method.
  """
  # Load the TFLite flatbuffer byte array into an object.
  model = flatbuffer_utils.convert_bytearray_to_object(tflite_model)

  # Get all the read-only buffers, which can be modified without causing any
  # issue in the graph invocation stage.
  read_only_buffer_indices = set()
  for subgraph in model.subgraphs:
    # To get all the read-only buffers:
    # (1) Get all read-only input tensors.
    # (2) Discard intermediate or output tensors.
    # (3) Discard the subgraph's input/output tensors.
    # (4) Gather the buffers of the read-only input tensors.

    # (1) Get read-only input tensors.
    read_only_input_tensor_indices = set()
    for op in subgraph.operators:
      if op.inputs is None:
        continue
      for i, input_tensor_idx in enumerate(op.inputs):
        # Ignore mutable tensors.
        if op.mutatingVariableInputs is not None:
          # Ignore invalid tensors.
          if (i < len(op.mutatingVariableInputs) and
              op.mutatingVariableInputs[i]):
            continue
        # Ignore variable tensors.
        if subgraph.tensors[input_tensor_idx].isVariable:
          continue
        read_only_input_tensor_indices.add(input_tensor_idx)

    # (2) Discard intermediate or output tensors.
    for op in subgraph.operators:
      if op.outputs is not None:
        for output_tensor_idx in op.outputs:
          read_only_input_tensor_indices.discard(output_tensor_idx)
      if op.intermediates is not None:
        for intermediate_tensor_idx in op.intermediates:
          read_only_input_tensor_indices.discard(intermediate_tensor_idx)

    # (3) Discard the subgraph's input and output tensors.
    if subgraph.inputs is not None:
      for input_tensor_idx in subgraph.inputs:
        read_only_input_tensor_indices.discard(input_tensor_idx)
    if subgraph.outputs is not None:
      for output_tensor_idx in subgraph.outputs:
        read_only_input_tensor_indices.discard(output_tensor_idx)

    # (4) Gather the buffers of the read-only input tensors.
    for tensor_idx in read_only_input_tensor_indices:
      read_only_buffer_indices.add(subgraph.tensors[tensor_idx].buffer)

  # Ignore invalid negative indices or zero-sized buffers.
  for buffer_idx in read_only_buffer_indices.copy():
    if (buffer_idx < 0 or (model.buffers[buffer_idx].data is None or
                           isinstance(model.buffers[buffer_idx].data, list) or
                           model.buffers[buffer_idx].data.size == 0)):
      read_only_buffer_indices.discard(buffer_idx)

  # Sort the buffers by size, in place, so that the duplicate scan below can
  # stop comparing as soon as the buffer sizes differ.
  read_only_buffer_indices = list(read_only_buffer_indices)
  read_only_buffer_indices.sort(key=lambda idx: model.buffers[idx].data.size)

  # Create a map of duplicate buffers (same size and same type).
  # E.g., in [1, 2, 3, 4, 5, 6], if (1, 4, 6) and (2, 5) are each a group of
  # buffer indices of the same size and type, then the map would be
  # {4: 1, 6: 1, 5: 2}.
  duplicate_buffer_map = {}
  for i, buffer_i_idx in enumerate(read_only_buffer_indices):
    # This buffer is a duplicate.
    if buffer_i_idx in duplicate_buffer_map:
      continue
    # This buffer is unique. Scan the rest of the list to find duplicates
    # of this buffer and mark them accordingly.
    buffer_i = model.buffers[buffer_i_idx]
    for buffer_j_idx in read_only_buffer_indices[i + 1:]:
      if buffer_j_idx in duplicate_buffer_map:
        continue
      buffer_j = model.buffers[buffer_j_idx]
      if buffer_i.data.size != buffer_j.data.size:
        break
      if buffer_i.data.data != buffer_j.data.data:
        continue
      # Found a duplicate. Nullify the j-th buffer and use the i-th buffer
      # instead.
      duplicate_buffer_map[buffer_j_idx] = buffer_i_idx

  # Make the duplicated tensors use the single shared buffer index.
  for subgraph in model.subgraphs:
    for op in subgraph.operators:
      if op.inputs is None:
        continue
      for input_tensor in op.inputs:
        buffer_idx = subgraph.tensors[input_tensor].buffer
        if buffer_idx in duplicate_buffer_map:
          subgraph.tensors[input_tensor].buffer = (
              duplicate_buffer_map[buffer_idx])

  # Nullify the unused buffers.
  for idx in duplicate_buffer_map:
    model.buffers[idx].data = None

  # Return a TFLite flatbuffer as a byte array.
  return flatbuffer_utils.convert_object_to_bytearray(model)
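
# Hedged usage sketch: `tflite_model` stands for any serialized TFLite
# flatbuffer (bytes), e.g. the output of toco_convert_impl above; the output
# path is hypothetical:
#
#   deduped_model = deduplicate_readonly_buffers(tflite_model)
#   with open("/tmp/model.tflite", "wb") as f:
#     f.write(deduped_model)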