# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a frozen graph into a TFLite FlatBuffer."""

import distutils.spawn
import enum
import os as _os
import platform as _platform
import subprocess as _subprocess
import tempfile as _tempfile
import warnings

from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.python import wrap_toco
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import ConverterError
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.metrics.wrapper import metrics_wrapper as _metrics_wrapper
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.lite.toco import toco_flags_pb2 as _conversion_flags_pb2
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export


def _is_quantized_input_stats_required(
    conversion_flags: _conversion_flags_pb2.TocoFlags) -> bool:
  """Checks if the `quantized_input_stats` flag is required for conversion.

  Args:
    conversion_flags: A protocol buffer describing the conversion process.

  Returns:
    True, if the `inference_type` or the `inference_input_type` is a quantized
    type and it is not post training quantization, else False.
  """
  quantized_inference_types = [
      _types_pb2.QUANTIZED_UINT8, _types_pb2.QUANTIZED_INT8
  ]
  return ((conversion_flags.inference_type in quantized_inference_types or
           conversion_flags.inference_input_type in quantized_inference_types)
          and not conversion_flags.post_training_quantize)
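
# Illustrative sketch (commented, not executed): the helper above only demands
# `quantized_input_stats` for quantize-aware-trained models; post-training
# quantization carries its own calibration data, so no stats are required.
#
#   flags = _conversion_flags_pb2.TocoFlags()
#   flags.inference_type = _types_pb2.QUANTIZED_UINT8
#   _is_quantized_input_stats_required(flags)   # True: quantized, not PTQ
#   flags.post_training_quantize = True
#   _is_quantized_input_stats_required(flags)   # False: PTQ supplies scales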


def convert_tensor_tf_type_to_tflite_type(tf_type: dtypes.DType,
                                          usage: str = ""
                                         ) -> _types_pb2.IODataType:
  """Convert tensor type from tf type to tflite type.

  Args:
    tf_type: TensorFlow type.
    usage: Text describing the reason for invoking this function.

  Raises:
    ValueError: If `tf_type` is unsupported.

  Returns:
    tflite_type: TFLite type. Refer to lite/toco/types.proto.
  """
  mapping = {
      dtypes.float16: _types_pb2.FLOAT16,
      dtypes.float32: _types_pb2.FLOAT,
      dtypes.float64: _types_pb2.FLOAT64,
      dtypes.int8: _types_pb2.INT8,
      dtypes.int16: _types_pb2.INT16,
      dtypes.uint16: _types_pb2.UINT16,
      dtypes.int32: _types_pb2.INT32,
      dtypes.int64: _types_pb2.INT64,
      dtypes.uint8: _types_pb2.UINT8,
      dtypes.uint32: _types_pb2.UINT32,
      dtypes.uint64: _types_pb2.UINT64,
      dtypes.string: _types_pb2.STRING,
      dtypes.bool: _types_pb2.BOOL,
      dtypes.complex64: _types_pb2.COMPLEX64,
      dtypes.complex128: _types_pb2.COMPLEX128,
  }
  tflite_type = mapping.get(tf_type)
  if tflite_type is None:
    raise ValueError(
        "Unsupported TensorFlow type `{0}` provided for the {1}".format(
            tf_type, usage))
  return tflite_type


# Only a few restricted tensor types are allowed for explicitly setting
# inference/input/output types.
def convert_inference_tf_type_to_tflite_type(tf_type: dtypes.DType,
                                             usage: str = ""
                                            ) -> _types_pb2.IODataType:
  """Convert inference type from tf type to tflite type.

  Args:
    tf_type: TensorFlow type.
    usage: Text describing the reason for invoking this function.

  Raises:
    ValueError: If `tf_type` is unsupported.

  Returns:
    tflite_type: TFLite type. Refer to lite/toco/types.proto.
  """
  mapping = {
      dtypes.float32: _types_pb2.FLOAT,
      dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
      dtypes.int8: _types_pb2.QUANTIZED_INT8,
      dtypes.int16: _types_pb2.QUANTIZED_INT16,
  }
  tflite_type = mapping.get(tf_type)
  if tflite_type is None:
    raise ValueError(
        "Unsupported TensorFlow type `{0}` provided for the {1}".format(
            tf_type, usage))
  return tflite_type
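
# Illustrative usage (commented): note that the two mappings intentionally
# differ; the same tf dtype maps to a plain tensor type in one and to a
# quantized inference type in the other.
#
#   convert_tensor_tf_type_to_tflite_type(dtypes.int8)       # -> INT8
#   convert_inference_tf_type_to_tflite_type(
#       dtypes.int8, usage="inference_type flag")            # -> QUANTIZED_INT8
#   convert_inference_tf_type_to_tflite_type(
#       dtypes.bool, usage="inference_type flag")            # raises ValueError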


# Find the deprecated conversion binary using the resource loader if running
# from bazel; otherwise we are in a pip package where console_scripts already
# has the tool.
if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
  _deprecated_conversion_binary = ""
else:
  _deprecated_conversion_binary = _resource_loader.get_path_to_datafile(
      "../toco/python/toco_from_protos")
  if not _os.path.exists(_deprecated_conversion_binary):
    _deprecated_conversion_binary = "toco_from_protos"


def _try_convert_to_unicode(output):
  if output is None:
    return u""

  if isinstance(output, bytes):
    try:
      return output.decode("utf-8")
    except UnicodeDecodeError:
      pass
  return output


@_tf_export("lite.OpsSet")
class OpsSet(enum.Enum):
  """Enum class defining the sets of ops available to generate TFLite models.

  WARNING: Experimental interface, subject to change.
  """
  # Convert model using TensorFlow Lite builtin ops.
  TFLITE_BUILTINS = "TFLITE_BUILTINS"

  # Convert model using TensorFlow ops. Not all TensorFlow ops are available.
  # WARNING: Experimental interface, subject to change.
  SELECT_TF_OPS = "SELECT_TF_OPS"

  # Convert model using only TensorFlow Lite quantized int8 operations.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"

  # Convert model using only TensorFlow Lite operations with quantized int8
  # weights, int16 activations and int64 bias.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  # This quantization mode may be used in models for super-resolution,
  # audio signal processing or image de-noising. It improves accuracy
  # significantly, but only slightly increases the model size.
  # WARNING: These ops are currently experimental and have not yet been
  # finalized. They are only compatible with CPU execution, and have not been
  # optimized for production.
  EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = (
      "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8")

  def __str__(self):
    return str(self.value)

  @staticmethod
  def get_options():
    """Returns a list of OpsSet options as a list of strings."""
    return [str(option) for option in list(OpsSet)]


@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.QUANTIZE)
def mlir_quantize(input_data_str,
                  disable_per_channel=False,
                  fully_quantize=False,
                  inference_type=_types_pb2.QUANTIZED_INT8,
                  input_data_type=dtypes.float32,
                  output_data_type=dtypes.float32,
                  enable_numeric_verify=False,
                  enable_whole_model_verify=False,
                  denylisted_ops=None,
                  denylisted_nodes=None):
  """Quantize `input_data_str` with calibration results.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLite model with
      calibration results).
    disable_per_channel: Bool indicating whether to do per-channel or
      per-tensor quantization.
    fully_quantize: Bool indicating whether to fully quantize the model.
      Besides the model body, the inputs/outputs will be quantized as well.
    inference_type: Data type for the activations. The default value is int8.
    input_data_type: Data type for the inputs. The default value is float32.
    output_data_type: Data type for the outputs. The default value is float32.
    enable_numeric_verify: Experimental. Subject to change. Bool indicating
      whether to add NumericVerify ops into the debug mode quantized model.
    enable_whole_model_verify: Experimental. Subject to change. Bool indicating
      whether to add verification layer by layer, or on the whole model. When
      disabled (per-layer), float and quantized ops will be run from the same
      input (the output of the previous quantized layer). When enabled, float
      and quantized ops will run with the respective float and quantized
      outputs of the previous ops.
    denylisted_ops: Experimental. Subject to change. Set of ops to denylist.
    denylisted_nodes: Experimental. Subject to change. Set of nodes to
      denylist.

  Returns:
    Quantized model in serialized form (e.g. a TFLite model) with
    floating-point inputs and outputs.
  """
  return wrap_toco.wrapped_experimental_mlir_quantize(
      input_data_str, disable_per_channel, fully_quantize, inference_type,
      convert_tensor_tf_type_to_tflite_type(input_data_type),
      convert_tensor_tf_type_to_tflite_type(output_data_type),
      enable_numeric_verify, enable_whole_model_verify, denylisted_ops,
      denylisted_nodes)
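
# Illustrative sketch (commented): `calibrated_model` is assumed to be the
# serialized output of a calibration pass, not produced here.
#
#   quantized_model = mlir_quantize(
#       calibrated_model,
#       fully_quantize=True,              # also quantize the inputs/outputs
#       input_data_type=dtypes.int8,
#       output_data_type=dtypes.int8)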


@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.SPARSIFY)
def mlir_sparsify(input_data_str):
  """Sparsify `input_data_str` to encode sparse tensors with the proper format.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLite model).

  Returns:
    Sparsified model in serialized form (e.g. a TFLite model).
  """
  return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str)


def register_custom_opdefs(custom_opdefs_list):
  """Register the given custom opdefs to the TensorFlow global op registry.

  Args:
    custom_opdefs_list: String representing the custom ops OpDefs that are
      included in the GraphDef.

  Returns:
    True if the registration is successfully completed.
  """
  return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)


def convert(model_flags_str,
            conversion_flags_str,
            input_data_str,
            debug_info_str=None,
            enable_mlir_converter=True):
  """Converts `input_data_str` to a TFLite model.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `model_flags.proto`.
    conversion_flags_str: Serialized proto describing conversion properties,
      see `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common,
      or it can be hlo text or proto).
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)
    enable_mlir_converter: Enables MLIR-based conversion. (default True)

  Returns:
    Converted model in serialized form (e.g. a TFLite model is common).

  Raises:
    ConverterError: When conversion fails in TFLiteConverter, usually due to
      ops not being supported.
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded.
  """
  # Historically, deprecated conversion failures would trigger a crash, so we
  # attempt to run the converter out-of-process. The current MLIR conversion
  # pipeline surfaces errors instead, and can be safely run in-process.
  if enable_mlir_converter or not _deprecated_conversion_binary:
    try:
      model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
                                                 conversion_flags_str,
                                                 input_data_str,
                                                 debug_info_str,
                                                 enable_mlir_converter)
      return model_str
    except Exception as e:
      converter_error = ConverterError(str(e))
      for error_data in _metrics_wrapper.retrieve_collected_errors():
        converter_error.append_error(error_data)
      raise converter_error

  return _run_deprecated_conversion_binary(model_flags_str,
                                           conversion_flags_str,
                                           input_data_str, debug_info_str)
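
# Illustrative sketch (commented): driving `convert` directly with serialized
# flag protos; `graph_def` is an assumed, already-frozen GraphDef. The
# `build_model_flags`/`build_conversion_flags` helpers are defined below.
#
#   model_flags = build_model_flags()
#   conversion_flags = build_conversion_flags()
#   tflite_model = convert(
#       model_flags.SerializeToString(),
#       conversion_flags.SerializeToString(),
#       graph_def.SerializeToString(),
#       enable_mlir_converter=True)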


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER)
def _run_deprecated_conversion_binary(model_flags_str,
                                      conversion_flags_str,
                                      input_data_str,
                                      debug_info_str=None):
  """Convert `input_data_str` using the deprecated conversion binary.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `model_flags.proto`.
    conversion_flags_str: Serialized proto describing TFLite converter
      properties, see `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common)
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)

  Returns:
    Converted model in serialized form (e.g. a TFLite model is common).

  Raises:
    ConverterError: When the deprecated conversion binary cannot be found.
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded.
  """
  if distutils.spawn.find_executable(_deprecated_conversion_binary) is None:
    raise ConverterError("""Could not find `toco_from_protos` binary, make sure
your virtualenv bin directory or pip local bin directory is in your path.
In particular, if you have installed TensorFlow with --user, make sure you
add the install directory to your path.

For example:
Linux: export PATH=$PATH:~/.local/bin/
Mac: export PATH=$PATH:~/Library/Python/<version#>/bin

Alternatively, use virtualenv.""")
  # Windows and TemporaryFile are not that useful together,
  # since you cannot have two readers/writers. So we have to
  # make the temporaries and close and delete them explicitly.
  conversion_filename, model_filename, input_filename, output_filename = (
      None, None, None, None)
  try:
    # Build all input files.
    with _tempfile.NamedTemporaryFile(delete=False) as fp_conversion, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_model, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_input, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_debug:
      conversion_filename = fp_conversion.name
      input_filename = fp_input.name
      model_filename = fp_model.name
      debug_filename = fp_debug.name

      fp_model.write(model_flags_str)
      fp_conversion.write(conversion_flags_str)
      fp_input.write(input_data_str)
      debug_info_str = debug_info_str if debug_info_str else ""
      # If `debug_info_str` is a str rather than bytes, the call to
      # fp_debug.write(debug_info_str) fails with:
      #
      #   TypeError: a bytes-like object is required, not 'str'
      #
      # Some of the subtests within the "convert_test" unit test fail with the
      # error shown above, so watch out for that scenario and encode
      # `debug_info_str` to bytes where needed.
      if not isinstance(debug_info_str, bytes):
        fp_debug.write(debug_info_str.encode("utf-8"))
      else:
        fp_debug.write(debug_info_str)

    # Reserve an output file.
    with _tempfile.NamedTemporaryFile(delete=False) as fp:
      output_filename = fp.name

    # Run.
    cmd = [
        _deprecated_conversion_binary,
        model_filename,
        conversion_filename,
        input_filename,
        output_filename,
        "--debug_proto_file={}".format(debug_filename),
    ]
    cmdline = " ".join(cmd)
    is_windows = _platform.system() == "Windows"
    proc = _subprocess.Popen(
        cmdline,
        shell=True,
        stdout=_subprocess.PIPE,
        stderr=_subprocess.STDOUT,
        close_fds=not is_windows)
    stdout, stderr = proc.communicate()
    exitcode = proc.returncode
    if exitcode == 0:
      with open(output_filename, "rb") as fp:
        return fp.read()
    else:
      stdout = _try_convert_to_unicode(stdout)
      stderr = _try_convert_to_unicode(stderr)
      raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
  finally:
    # Must manually clean up the files.
    for filename in [
        conversion_filename, input_filename, model_filename, output_filename
    ]:
      try:
        _os.unlink(filename)
      except (OSError, TypeError):
        pass


def build_model_flags(change_concat_input_ranges=False,
                      allow_nonexistent_arrays=False,
                      saved_model_dir=None,
                      saved_model_version=0,
                      saved_model_tags=None,
                      saved_model_exported_names=None,
                      **_):
  """Builds the model flags object from params.

  Args:
    change_concat_input_ranges: Boolean to change behavior of min/max ranges
      for inputs and outputs of the concat operator for quantized models.
      Changes the ranges of the concat operator to overlap when true.
      (default False)
    allow_nonexistent_arrays: Allow specifying array names that don't exist
      or are unused in the final graph. (default False)
    saved_model_dir: Filepath of the saved model to be converted. This value
      will be non-empty only when the SavedModel import path will be used.
      Otherwise, the GraphDef-based conversion will be processed.
    saved_model_version: SavedModel file format version of the saved model
      file to be converted. This value will be set only when the SavedModel
      import path will be used.
    saved_model_tags: Set of string SavedModel tags, formatted as a
      comma-separated value. This value will be set only when the SavedModel
      import path will be used.
    saved_model_exported_names: Names to be exported (default: export all)
      when the SavedModel import path is on. This value will be set only when
      the SavedModel import path will be used.

  Returns:
    model_flags: protocol buffer describing the model.
  """
  model_flags = _model_flags_pb2.ModelFlags()
  model_flags.change_concat_input_ranges = change_concat_input_ranges
  model_flags.allow_nonexistent_arrays = allow_nonexistent_arrays
  if saved_model_dir:
    model_flags.saved_model_dir = saved_model_dir
  model_flags.saved_model_version = saved_model_version
  if saved_model_tags:
    model_flags.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    model_flags.saved_model_exported_names.extend(saved_model_exported_names)
  return model_flags
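
# Illustrative sketch (commented): building model flags for a SavedModel
# import; the directory path, tag, and signature name are assumed.
#
#   model_flags = build_model_flags(
#       saved_model_dir="/tmp/my_saved_model",
#       saved_model_version=1,
#       saved_model_tags=["serve"],
#       saved_model_exported_names=["serving_default"])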


def build_conversion_flags(inference_type=dtypes.float32,
                           inference_input_type=None,
                           input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                           output_format=lite_constants.TFLITE,
                           default_ranges_stats=None,
                           drop_control_dependency=True,
                           reorder_across_fake_quant=False,
                           allow_custom_ops=False,
                           post_training_quantize=False,
                           quantize_to_float16=False,
                           dump_graphviz_dir=None,
                           dump_graphviz_video=False,
                           target_ops=None,
                           conversion_summary_dir=None,
                           select_user_tf_ops=None,
                           allow_all_select_tf_ops=False,
                           enable_tflite_resource_variables=True,
                           unfold_batchmatmul=True,
                           lower_tensor_list_ops=True,
                           default_to_single_batch_in_tensor_list_ops=False,
                           accumulation_type=None,
                           allow_bfloat16=False,
                           unfold_large_splat_constant=False,
                           supported_backends=None,
                           disable_per_channel_quantization=False,
                           enable_mlir_dynamic_range_quantizer=False,
                           tf_quantization_mode=None,
                           disable_infer_tensor_range=False,
                           use_fake_quant_num_bits=False,
                           enable_dynamic_update_slice=False,
                           preserve_assert_op=False,
                           guarantee_all_funcs_one_use=False,
                           **_):
  """Builds protocol buffer describing a conversion of a model.

  Typically this is to convert from TensorFlow GraphDef to TFLite, in which
  case the default `input_format` and `output_format` are sufficient.

  Args:
    inference_type: Data type of numeric arrays, excluding the input layer.
      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_input_type: Data type of the numeric arrays in the input layer.
      If `inference_input_type` is in {tf.int8, tf.uint8}, then
      `quantized_input_stats` must be provided. (default is the value assigned
      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
    input_format: Type of data to read. (default TENSORFLOW_GRAPHDEF, must be
      in {TENSORFLOW_GRAPHDEF})
    output_format: Output file format. (default TFLITE, must be in {TFLITE,
      GRAPHVIZ_DOT})
    default_ranges_stats: Tuple of integers representing (min, max) range
      values for all arrays without a specified range. Intended for
      experimenting with quantization via "dummy quantization". (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the
      graph. Results in a graph that differs from the quantized training
      graph, potentially causing differing arithmetic behavior. (default
      False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false, any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    post_training_quantize: Boolean indicating whether to quantize the weights
      of the converted float model. Model size will be reduced and there will
      be latency improvements (at the cost of accuracy). (default False)
    quantize_to_float16: Boolean indicating whether to convert float buffers
      to float16. (default False)
    dump_graphviz_dir: Full filepath of the folder to dump the graphs at
      various stages of processing as GraphViz .dot files. Preferred over
      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
      output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet options
      indicating which converter to use.
      (default set([OpsSet.TFLITE_BUILTINS]))
    conversion_summary_dir: A string, the path to the generated conversion
      logs.
    select_user_tf_ops: List of user-defined TensorFlow ops that need to be
      supported in the TensorFlow Lite runtime. These ops will be supported
      as select TensorFlow ops.
    allow_all_select_tf_ops: If True, automatically add all TF ops (including
      custom TF ops) to the converted model as flex ops.
    enable_tflite_resource_variables: Experimental flag, subject to change.
      Enables conversion of resource variables. (default True)
    unfold_batchmatmul: Whether to unfold tf.BatchMatMul to a set of
      tfl.fully_connected ops. If not, translate to tfl.batch_matmul.
    lower_tensor_list_ops: Whether to lower tensor list ops to builtin ops. If
      not, use Flex tensor list ops.
    default_to_single_batch_in_tensor_list_ops: Whether to force the batch
      size to one when a tensor list op has an unspecified batch size.
    accumulation_type: Data type of the accumulators in quantized inference.
      Typically used for float16 quantization and is either fp16 or fp32.
    allow_bfloat16: Whether the converted model supports reduced precision
      inference with the bfloat16 type.
    unfold_large_splat_constant: Whether to unfold large splat constant
      tensors in the flatbuffer model to reduce size.
    supported_backends: List of TFLite backends to check compatibility
      against.
    disable_per_channel_quantization: Disable per-channel quantized weights
      for dynamic range quantization. Only per-tensor quantization will be
      used.
    enable_mlir_dynamic_range_quantizer: Enable MLIR dynamic range
      quantization. If False, the old converter dynamic range quantizer is
      used.
    tf_quantization_mode: Indicates the mode of TF Quantization when the
      output model is used for TF Quantization.
    disable_infer_tensor_range: Disable inferring tensor ranges.
    use_fake_quant_num_bits: Allow quantization parameters to be calculated
      from the num_bits attribute.
    enable_dynamic_update_slice: Enable conversion to the DynamicUpdateSlice
      op. (default: False)
    preserve_assert_op: Whether to preserve `TF::AssertOp`. (default: False)
    guarantee_all_funcs_one_use: Whether to clone functions so that each
      function only has a single use. This option will be helpful if the
      conversion fails when the `PartitionedCall` or `StatefulPartitionedCall`
      can't be properly inlined. (default: False)

  Returns:
    conversion_flags: protocol buffer describing the conversion process.

  Raises:
    ValueError: If the input tensor type is unknown.
  """
  conversion_flags = _conversion_flags_pb2.TocoFlags()
  conversion_flags.inference_type = convert_inference_tf_type_to_tflite_type(
      inference_type, usage="inference_type flag")
  if inference_input_type:
    conversion_flags.inference_input_type = (
        convert_inference_tf_type_to_tflite_type(
            inference_input_type, usage="inference_input_type flag"))
  else:
    conversion_flags.inference_input_type = conversion_flags.inference_type
  conversion_flags.input_format = input_format
  conversion_flags.output_format = output_format
  if default_ranges_stats:
    conversion_flags.default_ranges_min = default_ranges_stats[0]
    conversion_flags.default_ranges_max = default_ranges_stats[1]
  conversion_flags.drop_control_dependency = drop_control_dependency
  conversion_flags.reorder_across_fake_quant = reorder_across_fake_quant
  conversion_flags.allow_custom_ops = allow_custom_ops
  conversion_flags.post_training_quantize = post_training_quantize
  conversion_flags.quantize_to_float16 = quantize_to_float16
  if dump_graphviz_dir:
    conversion_flags.dump_graphviz_dir = dump_graphviz_dir
  conversion_flags.dump_graphviz_include_video = dump_graphviz_video
  if target_ops:
    if OpsSet.SELECT_TF_OPS in target_ops:
      conversion_flags.enable_select_tf_ops = True
    if set(target_ops) == {OpsSet.SELECT_TF_OPS}:
      conversion_flags.force_select_tf_ops = True
  if conversion_summary_dir:
    conversion_flags.conversion_summary_dir = conversion_summary_dir
  if select_user_tf_ops:
    conversion_flags.select_user_tf_ops.extend(select_user_tf_ops)
  conversion_flags.allow_all_select_tf_ops = allow_all_select_tf_ops
  conversion_flags.enable_tflite_resource_variables = (
      enable_tflite_resource_variables)
  conversion_flags.unfold_batchmatmul = unfold_batchmatmul
  conversion_flags.lower_tensor_list_ops = lower_tensor_list_ops
  conversion_flags.default_to_single_batch_in_tensor_list_ops = (
      default_to_single_batch_in_tensor_list_ops)
  if accumulation_type:
    conversion_flags.accumulation_type = convert_tensor_tf_type_to_tflite_type(
        accumulation_type, usage="accumulation_type flag")
  conversion_flags.allow_bfloat16 = allow_bfloat16
  conversion_flags.unfold_large_splat_constant = unfold_large_splat_constant
  if supported_backends:
    conversion_flags.supported_backends.extend(supported_backends)
  conversion_flags.disable_per_channel_quantization = (
      disable_per_channel_quantization)
  conversion_flags.enable_mlir_dynamic_range_quantizer = (
      enable_mlir_dynamic_range_quantizer)
  conversion_flags.enable_dynamic_update_slice = enable_dynamic_update_slice
  conversion_flags.preserve_assert_op = preserve_assert_op
  conversion_flags.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use
  if tf_quantization_mode:
    conversion_flags.tf_quantization_mode = tf_quantization_mode
  conversion_flags.disable_infer_tensor_range = disable_infer_tensor_range
  conversion_flags.use_fake_quant_num_bits = use_fake_quant_num_bits
  return conversion_flags
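
# Illustrative sketch (commented): flags for a float model that may also fall
# back to select TF ops. Since the target set is not exactly
# {OpsSet.SELECT_TF_OPS}, builtins remain allowed.
#
#   conversion_flags = build_conversion_flags(
#       inference_type=dtypes.float32,
#       target_ops=set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]))
#   assert conversion_flags.enable_select_tf_ops
#   assert not conversion_flags.force_select_tf_ops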


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF)
def convert_graphdef_with_arrays(input_data, input_arrays_with_shape,
                                 output_arrays, control_output_arrays,
                                 **kwargs):
  """Convert a frozen GraphDef that can't be loaded in TF.

  Conversion can be customized by providing arguments that are forwarded to
  `build_model_flags` and `build_conversion_flags` (see documentation).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_arrays_with_shape: Tuple of strings representing input tensor names
      and list of integers representing input shapes
      (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
      into TensorFlow and when `input_tensors` is None.
    output_arrays: List of output tensors to freeze graph with. Use only when
      graph cannot be loaded into TensorFlow and when `output_tensors` is
      None.
    control_output_arrays: Control output node names. This is used when
      converting a Graph with no output tensors. For example, if the graph's
      last operation is a Print op, just specify that op's name in this field.
      This can be used together with the `output_arrays` parameter.
    **kwargs: See `build_model_flags` and `build_conversion_flags`.

  Returns:
    The converted data. For example if TFLite was the destination, then this
    will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_conversion_flags`.
  """
  model_flags = build_model_flags(**kwargs)
  conversion_flags = build_conversion_flags(**kwargs)
  enable_mlir_converter = kwargs.get("enable_mlir_converter", True)
  quantized_input_stats = kwargs.get("quantized_input_stats", None)

  for idx, (name, shape) in enumerate(input_arrays_with_shape):
    input_array = model_flags.input_arrays.add()
    if _is_quantized_input_stats_required(conversion_flags):
      if quantized_input_stats:
        input_array.mean_value, input_array.std_value = (
            quantized_input_stats[idx])
      else:
        raise ValueError(
            "The `quantized_input_stats` flag must be defined when either "
            "`inference_type` flag or `inference_input_type` flag is set to "
            "tf.int8 or tf.uint8.")
    input_array.name = name
    input_array.shape.dims.extend(list(map(int, shape)))

  if output_arrays:
    for name in output_arrays:
      model_flags.output_arrays.append(name)
  if control_output_arrays:
    for name in control_output_arrays:
      model_flags.control_output_arrays.append(name)

  data = convert(
      model_flags.SerializeToString(),
      conversion_flags.SerializeToString(),
      input_data.SerializeToString(),
      debug_info_str=None,
      enable_mlir_converter=enable_mlir_converter)
  return data
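
# Illustrative sketch (commented): converting a GraphDef that TF cannot load;
# `graph_def` and the array names are assumed.
#
#   tflite_model = convert_graphdef_with_arrays(
#       graph_def,
#       input_arrays_with_shape=[("input", [1, 224, 224, 3])],
#       output_arrays=["output"],
#       control_output_arrays=None)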


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF)
def convert_graphdef(input_data, input_tensors, output_tensors, **kwargs):
  """Convert a frozen GraphDef model using the TF Lite converter.

  Conversion can be customized by providing arguments that are forwarded to
  `build_model_flags` and `build_conversion_flags` (see documentation).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    **kwargs: See `build_model_flags` and `build_conversion_flags`.

  Returns:
    The converted data. For example if TFLite was the destination, then this
    will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_conversion_flags`.
  """
  model_flags = build_model_flags(**kwargs)
  conversion_flags = build_conversion_flags(**kwargs)
  saved_model_dir = kwargs.get("saved_model_dir", None)
  input_shapes = kwargs.get("input_shapes", None)
  enable_mlir_converter = kwargs.get("enable_mlir_converter", True)
  quantized_input_stats = kwargs.get("quantized_input_stats", None)
  debug_info = kwargs.get("debug_info", None)

  for idx, input_tensor in enumerate(input_tensors):
    input_array = model_flags.input_arrays.add()
    if saved_model_dir:
      input_array.name = input_tensor.name
    else:
      input_array.name = util.get_tensor_name(input_tensor)
    input_array.data_type = convert_tensor_tf_type_to_tflite_type(
        input_tensor.dtype, usage="input type of the TensorFlow model")

    if _is_quantized_input_stats_required(conversion_flags):
      if quantized_input_stats:
        input_array.mean_value, input_array.std_value = (
            quantized_input_stats[idx])
      else:
        # We should ideally raise an error here, but we don't as it would
        # break several models/projects that depend on this workflow.
        warnings.warn("Statistics for quantized inputs were expected, but "
                      "not specified; continuing anyway.")

    if input_shapes is None:
      shape = input_tensor.shape
    else:
      shape = input_shapes[idx]

    if shape.rank is not None:
      # Create shapes with -1 for unknown dimensions.
      dims = []
      for dim in shape:
        if (dim is None or
            (isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
          dims.append(-1)
        else:
          dims.append(int(dim))
      input_array.shape.dims.extend(dims)
      input_array.shape.unknown_rank = False
    else:
      input_array.shape.unknown_rank = True

  for output_tensor in output_tensors:
    if saved_model_dir:
      model_flags.output_arrays.append(output_tensor.name)
    else:
      model_flags.output_arrays.append(util.get_tensor_name(output_tensor))

  data = convert(
      model_flags.SerializeToString(),
      conversion_flags.SerializeToString(),
      input_data.SerializeToString(),
      debug_info_str=debug_info.SerializeToString() if debug_info else None,
      enable_mlir_converter=enable_mlir_converter)
  return data
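
# Illustrative sketch (commented): the tf.compat.v1 session workflow the
# docstring above refers to; `sess`, `in_tensor`, and `out_tensor` are assumed
# to come from an already-built graph.
#
#   tflite_model = convert_graphdef(
#       sess.graph_def,
#       input_tensors=[in_tensor],
#       output_tensors=[out_tensor])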


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_SAVED_MODEL)
def convert_saved_model(**kwargs):
  """Converts a SavedModel using the TF Lite converter."""
  model_flags = build_model_flags(**kwargs)
  conversion_flags = build_conversion_flags(**kwargs)
  data = convert(
      model_flags.SerializeToString(),
      conversion_flags.SerializeToString(),
      input_data_str=None,
      debug_info_str=None,
      enable_mlir_converter=True)
  return data


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_JAX_HLO)
def convert_jax_hlo(input_content, input_names, is_proto_format, **kwargs):
  """Converts a JAX HLO-based model using the TFLite converter."""
  model_flags = _model_flags_pb2.ModelFlags()
  model_flags.use_hlo_import = True
  if is_proto_format:
    model_flags.hlo_file_type = _model_flags_pb2.ModelFlags.HLO_PROTO
  else:
    model_flags.hlo_file_type = _model_flags_pb2.ModelFlags.HLO_TEXT

  # Build input names.
  for input_name in input_names:
    input_array = model_flags.input_arrays.add()
    input_array.name = input_name

  conversion_flags = build_conversion_flags(**kwargs)
  data = convert(
      model_flags.SerializeToString(),
      conversion_flags.SerializeToString(),
      input_data_str=input_content,
      debug_info_str=None,
      enable_mlir_converter=True)
  return data
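
# Illustrative sketch (commented): both entry points forward **kwargs to
# build_model_flags/build_conversion_flags; the path, tag, input name, and
# `hlo_proto_bytes` are assumed.
#
#   tflite_model = convert_saved_model(
#       saved_model_dir="/tmp/my_saved_model",
#       saved_model_tags=["serve"])
#   tflite_from_jax = convert_jax_hlo(
#       hlo_proto_bytes, input_names=["x"], is_proto_format=True)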


@_tf_export(v1=["lite.toco_convert"])
@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
  """Convert a TensorFlow GraphDef to TFLite.

  This function is deprecated. Please use the `tf.lite.TFLiteConverter` API
  instead. Conversion can be customized by providing arguments that are
  forwarded to `build_model_flags` and `build_conversion_flags` (see
  documentation for details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    *args: See `build_model_flags` and `build_conversion_flags`.
    **kwargs: See `build_model_flags` and `build_conversion_flags`.

  Returns:
    The converted TensorFlow Lite model in a bytes array.

  Raises:
    Defined in `convert`.
  """
  kwargs["enable_mlir_converter"] = kwargs.get("enable_mlir_converter", False)
  return convert_graphdef(input_data, input_tensors, output_tensors, *args,
                          **kwargs)


def deduplicate_readonly_buffers(tflite_model):
  """Generates a new model byte array after deduplicating readonly buffers.

  This function should be invoked after the model optimization toolkit. The
  model optimization toolkit assumes that each tensor object owns its buffer
  separately.

  Args:
    tflite_model: TFLite flatbuffer in a byte array to be deduplicated.

  Returns:
    TFLite flatbuffer in a bytes array, processed with the deduplication
    method.
  """
  # Load the TFLite flatbuffer byte array into an object.
  model = flatbuffer_utils.convert_bytearray_to_object(tflite_model)

  # Get all the read-only buffers, which can be modified without causing any
  # issue in the graph invocation stage.
  read_only_buffer_indices = set()
  for subgraph in model.subgraphs:
    # To get all the read-only buffers:
    # (1) Get all read-only input tensors.
    # (2) Discard intermediate or output tensors.
    # (3) Discard the subgraph's input/output tensors.
    # (4) Gather the buffers of the read-only input tensors.

    # (1) Get read-only input tensors.
    read_only_input_tensor_indices = set()
    for op in subgraph.operators:
      if op.inputs is None:
        continue
      for i, input_tensor_idx in enumerate(op.inputs):
        # Ignore mutable tensors.
        if op.mutatingVariableInputs is not None:
          # Ignore invalid tensors.
          if (i < len(op.mutatingVariableInputs) and
              op.mutatingVariableInputs[i]):
            continue
        # Ignore variable tensors.
        if subgraph.tensors[input_tensor_idx].isVariable:
          continue
        read_only_input_tensor_indices.add(input_tensor_idx)

    # (2) Discard intermediate or output tensors.
    for op in subgraph.operators:
      if op.outputs is not None:
        for output_tensor_idx in op.outputs:
          read_only_input_tensor_indices.discard(output_tensor_idx)
      if op.intermediates is not None:
        for intermediate_tensor_idx in op.intermediates:
          read_only_input_tensor_indices.discard(intermediate_tensor_idx)

    # (3) Discard the subgraph's input and output tensors.
    if subgraph.inputs is not None:
      for input_tensor_idx in subgraph.inputs:
        read_only_input_tensor_indices.discard(input_tensor_idx)
    if subgraph.outputs is not None:
      for output_tensor_idx in subgraph.outputs:
        read_only_input_tensor_indices.discard(output_tensor_idx)

    # (4) Gather the buffers of the read-only input tensors.
    for tensor_idx in read_only_input_tensor_indices:
      read_only_buffer_indices.add(subgraph.tensors[tensor_idx].buffer)

  # Ignore invalid negative indices or zero-sized buffers.
  for buffer_idx in read_only_buffer_indices.copy():
    if (buffer_idx < 0 or (model.buffers[buffer_idx].data is None or
                           isinstance(model.buffers[buffer_idx].data, list) or
                           model.buffers[buffer_idx].data.size == 0)):
      read_only_buffer_indices.discard(buffer_idx)

  # Sort the read-only buffer indices by buffer content so that identical
  # buffers become adjacent; the duplicate scan below relies on this ordering.
  # Note: the original bare `sorted(...)` call discarded its result, so the
  # list is now sorted in place.
  read_only_buffer_indices = list(read_only_buffer_indices)
  read_only_buffer_indices.sort(
      key=lambda idx: model.buffers[idx].data.data.tobytes())

  # Create a map of duplicate buffers (same size and same contents).
  # E.g., in [1, 2, 3, 4, 5, 6], if (1, 4, 6) and (2, 5) are each groups of
  # buffer indices with the same size and contents, the map would be
  # {4: 1, 6: 1, 5: 2}.
  duplicate_buffer_map = {}
  for i, buffer_i_idx in enumerate(read_only_buffer_indices):
    # This buffer is a duplicate.
    if buffer_i_idx in duplicate_buffer_map:
      continue
    # This buffer is unique. Scan the rest of the list to find duplicates
    # of this buffer and mark them accordingly.
    buffer_i = model.buffers[buffer_i_idx]
    for buffer_j_idx in read_only_buffer_indices[i + 1:]:
      if buffer_j_idx in duplicate_buffer_map:
        continue
      buffer_j = model.buffers[buffer_j_idx]
      if buffer_i.data.size != buffer_j.data.size:
        break
      if buffer_i.data.data != buffer_j.data.data:
        continue
      # Found a duplicate. Nullify the j-th buffer and use the i-th buffer
      # instead.
      duplicate_buffer_map[buffer_j_idx] = buffer_i_idx

  # Make the duplicated tensors use the single shared buffer index.
  for subgraph in model.subgraphs:
    for op in subgraph.operators:
      if op.inputs is None:
        continue
      for input_tensor in op.inputs:
        buffer_idx = subgraph.tensors[input_tensor].buffer
        if buffer_idx in duplicate_buffer_map:
          subgraph.tensors[input_tensor].buffer = (
              duplicate_buffer_map[buffer_idx])

  # Nullify the unused buffers.
  for idx in duplicate_buffer_map:
    model.buffers[idx].data = None

  # Return a TFLite flatbuffer as a byte array.
  return flatbuffer_utils.convert_object_to_bytearray(model)
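
# Illustrative sketch (commented): deduplication is a pure bytes-to-bytes
# post-processing step, typically applied after quantization;
# `calibrated_model` is assumed. Shared buffers should never enlarge the
# model.
#
#   quantized = mlir_quantize(calibrated_model)
#   deduped = deduplicate_readonly_buffers(quantized)
#   assert len(deduped) <= len(quantized)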