1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorFlow Lite tooling helper functionality.""" 16 17import enum 18import functools 19import pprint 20import shutil 21import tempfile 22import time 23import warnings 24 25from absl import logging 26 27from google.protobuf import text_format as _text_format 28from google.protobuf.message import DecodeError 29from tensorflow.core.framework import graph_pb2 as _graph_pb2 30from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import 31from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metdata_fb 32from tensorflow.lite.python import lite_constants as constants 33from tensorflow.lite.python.convert import convert_graphdef as _convert_graphdef 34from tensorflow.lite.python.convert import convert_graphdef_with_arrays as _convert_graphdef_with_arrays 35from tensorflow.lite.python.convert import convert_jax_hlo as _convert_jax_hlo 36from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model 37from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import 38from tensorflow.lite.python.convert import deduplicate_readonly_buffers as _deduplicate_readonly_buffers 39from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize 40from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify 41from tensorflow.lite.python.convert import OpsSet 42from tensorflow.lite.python.convert import toco_convert # pylint: disable=unused-import 43from tensorflow.lite.python.convert_phase import Component 44from tensorflow.lite.python.convert_phase import convert_phase 45from tensorflow.lite.python.convert_phase import SubComponent 46from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model 47from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import 48from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import 49from tensorflow.lite.python.interpreter import OpResolverType # pylint: disable=unused-import 50from tensorflow.lite.python.metrics import metrics 51from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import 52from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted 53from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import 54from tensorflow.lite.python.optimize import calibrator as _calibrator 55from tensorflow.lite.python.util import _xla_computation 56from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func 57from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func 58from tensorflow.lite.python.util import freeze_graph as _freeze_graph 59from tensorflow.lite.python.util 
import get_debug_info as _get_debug_info 60from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config 61from tensorflow.lite.python.util import get_sparsity_modes as _get_sparsity_modes 62from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name 63from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names 64from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name 65from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph 66from tensorflow.lite.python.util import model_input_signature as _model_input_signature 67from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type 68from tensorflow.lite.python.util import populate_conversion_metadata as _populate_conversion_metadata 69from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations 70from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes 71from tensorflow.lite.python.util import trace_model_call as _trace_model_call 72from tensorflow.lite.tools import flatbuffer_utils 73from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugger # pylint: disable=unused-import 74from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugOptions # pylint: disable=unused-import 75from tensorflow.python import saved_model as _saved_model 76from tensorflow.python.client import session as _session 77from tensorflow.python.eager import context 78from tensorflow.python.eager import def_function as _def_function 79from tensorflow.python.eager import function as _function 80from tensorflow.python.framework import convert_to_constants as _convert_to_constants 81from tensorflow.python.framework import dtypes as _dtypes 82from tensorflow.python.framework import ops as _ops 83from tensorflow.python.framework import versions 84from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError 85from tensorflow.python.framework.importer import import_graph_def as _import_graph_def 86from tensorflow.python.platform import gfile 87from tensorflow.python.saved_model import loader_impl as _loader_impl 88from tensorflow.python.saved_model import save_options as _save_options 89from tensorflow.python.saved_model import signature_constants as _signature_constants 90from tensorflow.python.saved_model import tag_constants as _tag_constants 91from tensorflow.python.saved_model.load import load as _load 92from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info 93from tensorflow.python.util import deprecation as _deprecation 94from tensorflow.python.util import keras_deps 95from tensorflow.python.util.tf_export import tf_export as _tf_export 96 97 98@_tf_export("lite.Optimize") 99class Optimize(enum.Enum): 100 """Enum defining the optimizations to apply when generating a tflite model. 101 102 DEFAULT 103 Default optimization strategy that quantizes model weights. Enhanced 104 optimizations are gained by providing a representative dataset that 105 quantizes biases and activations as well. 106 Converter will do its best to reduce size and latency, while minimizing 107 the loss in accuracy. 108 109 OPTIMIZE_FOR_SIZE 110 Deprecated. Does the same as DEFAULT. 111 112 OPTIMIZE_FOR_LATENCY 113 Deprecated. Does the same as DEFAULT. 114 115 EXPERIMENTAL_SPARSITY 116 Experimental flag, subject to change. 

    Enable optimization by taking advantage of the sparse model weights
    trained with pruning.

    The converter will inspect the sparsity pattern of the model weights and
    do its best to improve size and latency.
    The flag can be used alone to optimize float32 models with sparse weights.
    It can also be used together with the DEFAULT optimization mode to
    optimize quantized models with sparse weights.
  """

  # Default optimization strategy that quantizes model weights. Enhanced
  # optimizations are gained by providing a representative dataset that
  # quantizes biases and activations as well.
  # Converter will do its best to reduce size and latency, while minimizing
  # the loss in accuracy.
  DEFAULT = "DEFAULT"

  # Deprecated. Does the same as DEFAULT.
  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"

  # Deprecated. Does the same as DEFAULT.
  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"

  # Experimental flag, subject to change.
  # Enable optimization by taking advantage of the sparse model weights trained
  # with pruning.
  #
  # The converter will inspect the sparsity pattern of the model weights and do
  # its best to improve size and latency.
  # The flag can be used alone to optimize float32 models with sparse weights.
  # It can also be used together with the DEFAULT optimization mode to optimize
  # quantized models with sparse weights.
  # TODO(b/161560631): Add log message when this optimization is applied.
  EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY"

  def __str__(self):
    return str(self.value)


# TODO(b/198099651): move converter implementation out of lite.py
@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset:
  """Representative dataset used to optimize the model.

  This is a generator function that provides a small dataset to calibrate or
  estimate the range, i.e., (min, max) of all floating-point arrays in the
  model (such as model input, activation outputs of intermediate layers, and
  model output) for quantization. Usually, this is a small subset of a few
  hundred samples randomly chosen, in no particular order, from the training
  or evaluation dataset.
  """

  def __init__(self, input_gen):
    """Creates a representative dataset.

    Args:
      input_gen: A generator function that generates input samples for the
        model and has the same order, type and shape as the inputs to the
        model. Usually, this is a small subset of a few hundred samples
        randomly chosen, in no particular order, from the training or
        evaluation dataset.
    """
    self.input_gen = input_gen


@_tf_export("lite.TargetSpec")
class TargetSpec:
  """Specification of target device used to optimize the model.

  Attributes:
    supported_ops: Experimental flag, subject to change. Set of
      `tf.lite.OpsSet` options, where each option represents a set of
      operators supported by the target device.
      (default {tf.lite.OpsSet.TFLITE_BUILTINS})
    supported_types: Set of `tf.dtypes.DType` data types supported on the
      target device. If initialized, optimization might be driven by the
      smallest type in this set. (default set())
    experimental_select_user_tf_ops: Experimental flag, subject to change.
      Set of user's TensorFlow operators' names that are required in the
      TensorFlow Lite runtime.
      These ops will be exported as select TensorFlow ops in the model (in
      conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is an
      advanced feature that should only be used if the client is using TF ops
      that may not be linked in by default with the TF ops that are provided
      when using the SELECT_TF_OPS path. The client is responsible for linking
      these ops into the target runtime.
    experimental_supported_backends: Experimental flag, subject to change.
      Set containing names of supported backends. Currently only "GPU" is
      supported, more options will be available later.
  """

  def __init__(self,
               supported_ops=None,
               supported_types=None,
               experimental_select_user_tf_ops=None,
               experimental_supported_backends=None):
    if supported_ops is None:
      supported_ops = {OpsSet.TFLITE_BUILTINS}
    self.supported_ops = supported_ops
    if supported_types is None:
      supported_types = set()
    self.supported_types = supported_types
    if experimental_select_user_tf_ops is None:
      experimental_select_user_tf_ops = set()
    self.experimental_select_user_tf_ops = experimental_select_user_tf_ops
    self.experimental_supported_backends = experimental_supported_backends
    self._experimental_custom_op_registerers = []
    # Hint for the supported accumulation type used for inference. Typically
    # used for fp16 post-training quantization, where some models can use fp16
    # accumulators instead of the typical fp32 type.
    # TODO(b/188185962): Provide full API and authoring support for
    # reduced precision accumulation types.
    self._experimental_supported_accumulation_type = None


class QuantizationMode:
  """QuantizationMode determines the quantization type from user options."""

  def __init__(self,
               optimizations,
               target_spec,
               representative_dataset,
               graph_def,
               disable_per_channel=False,
               experimental_new_dynamic_range_quantizer=False,
               experimental_low_bit_qat=False,
               full_integer_quantization_bias_type=None):
    self._optimizations = optimizations
    for deprecated_optimization in [
        Optimize.OPTIMIZE_FOR_SIZE, Optimize.OPTIMIZE_FOR_LATENCY
    ]:
      if deprecated_optimization in self._optimizations:
        logging.warning(
            "Optimization option %s is deprecated, please use optimizations="
            "[Optimize.DEFAULT] instead.", deprecated_optimization)

    self._target_spec = target_spec
    self._representative_dataset = representative_dataset
    self._graph_def = graph_def

    self._validate_int8_required()
    self._disable_per_channel = disable_per_channel

    self._enable_new_dynamic_range_quantizer = (
        experimental_new_dynamic_range_quantizer)
    # Allow training with lower than 8 bit weights to be converted
    # to constants with trained scale.
    self._experimental_low_bit_qat = experimental_low_bit_qat

    self._full_integer_quantization_bias_type = full_integer_quantization_bias_type
    self._validate_full_integer_quantization_bias_type()

  def is_post_training_int8_only_quantization(self):
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is not None and
            not self._is_int16x8_target_required() and
            not self.is_allow_float() and
            self._is_int8_target_required())

  def is_post_training_int8_quantization_with_float_fallback(self):
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is not None and
            not self._is_int16x8_target_required() and
            self.is_allow_float() and
            self._smallest_supported_type() == _dtypes.int8)

  def is_post_training_int8_quantization(self):
    return (self.is_post_training_int8_only_quantization() or
            self.is_post_training_int8_quantization_with_float_fallback())

  def is_post_training_int16x8_only_quantization(self):
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is not None and
            self._is_int16x8_target_required() and
            not self.is_allow_float())

  def is_post_training_int16x8_quantization_with_float_fallback(self):
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is not None and
            self._is_int16x8_target_required() and
            self.is_allow_float())

  def is_post_training_int16x8_quantization(self):
    return (self.is_post_training_int16x8_only_quantization() or
            self.is_post_training_int16x8_quantization_with_float_fallback())

  def is_post_training_integer_quantization(self):
    return (self.is_post_training_int8_quantization() or
            self.is_post_training_int16x8_quantization())

  def is_low_bit_quantize_aware_training(self):
    return (self.is_any_optimization_enabled() and
            self.is_quantization_aware_trained_model() and
            self._experimental_low_bit_qat)

  def is_quantization_aware_training(self):
    return (self.is_any_optimization_enabled() and
            self.is_quantization_aware_trained_model() and
            not self.is_low_bit_quantize_aware_training())

  def is_integer_quantization(self):
    return (self.is_post_training_integer_quantization() or
            self.is_quantization_aware_training() or
            self.is_low_bit_quantize_aware_training())

  def is_post_training_dynamic_range_quantization(self):
    # Post-training dynamic range quantization is only enabled if neither
    # post-training int8 quantization nor training-time quantization was done.
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is None and
            not self.is_quantization_aware_trained_model() and
            self._smallest_supported_type() == _dtypes.int8)

  def is_post_training_float16_quantization(self):
    return (self.is_any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.float16 in self._target_spec.supported_types)

  def is_bfloat16_quantization(self):
    return (self.is_any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.bfloat16 in self._target_spec.supported_types)

  def activations_type(self):
    if self.is_integer_quantization():
      if self._is_int16x8_target_required():
        return _dtypes.int16
      else:
        return _dtypes.int8
    else:
      return _dtypes.float32

  def bias_type(self):
    if self._full_integer_quantization_bias_type:
      return self._full_integer_quantization_bias_type

    if self.activations_type() == _dtypes.int16:
      return _dtypes.int64
    elif self.activations_type() == _dtypes.int8:
      return _dtypes.int32
    else:
      return _dtypes.float32

  def converter_flags(self, inference_ty=None, inference_input_ty=None):
    """Flags to the converter."""

    if self.is_integer_quantization():
      is_low_bit_qat = self.is_low_bit_quantize_aware_training()
      return {
          "inference_type": (inference_ty if inference_ty is not None else
                             self.activations_type()),
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          "disable_infer_tensor_range": is_low_bit_qat,
          "use_fake_quant_num_bits": is_low_bit_qat,
      }
    elif self.is_post_training_dynamic_range_quantization():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,  # enable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          # experimental: disable per-channel (per-axis) quantization.
          "disable_per_channel_quantization":
              self._disable_per_channel,
          "enable_mlir_dynamic_range_quantizer":
              self._enable_new_dynamic_range_quantizer
      }
    elif self.is_post_training_float16_quantization():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,
          "quantize_to_float16": True,  # enable float16 quantization
          "accumulation_type":
              self._target_spec._experimental_supported_accumulation_type,  # pylint: disable=protected-access
          "allow_bfloat16":
              self.is_bfloat16_quantization(),
          "enable_mlir_dynamic_range_quantizer":
              self._enable_new_dynamic_range_quantizer
      }
    else:
      # Note this might still trigger (uint8) quantization to be compatible
      # with the old converter.
      return {
          "inference_type": (
              inference_ty if inference_ty is not None else _dtypes.float32),
          "inference_input_type": inference_input_ty,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          "allow_bfloat16": self.is_bfloat16_quantization()
      }

  # Below are helpers for the above functions.
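
  # Illustrative sketch (kept as a comment, not executed): how the predicates
  # above correspond to common settings on the public converter API. The user
  # code below is hypothetical and only meant to show the mapping.
  #
  #   converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  #   converter.optimizations = [tf.lite.Optimize.DEFAULT]
  #   # With no representative_dataset set, the mode resolves to
  #   # is_post_training_dynamic_range_quantization().
  #   converter.representative_dataset = representative_gen
  #   converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  #   # Now is_post_training_int8_only_quantization() holds, activations_type()
  #   # is int8, and converter_flags() reflects full integer quantization.
  #   tflite_model = converter.convert()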

  def _validate_int8_required(self):
    """Int8 mode requires certain parameters to exist and be compatible."""
    if not self._is_int8_target_required():
      return

    # Validate the target_spec attribute.
    if (set(self._target_spec.supported_ops) == {OpsSet.TFLITE_BUILTINS_INT8}
        and not (set(self._target_spec.supported_types) == set() or
                 set(self._target_spec.supported_types) == {_dtypes.int8})):
      raise ValueError(
          "As full integer quantization has been enabled by setting "
          "`target_spec.supported_ops`={tf.lite.OpsSet.TFLITE_BUILTINS_INT8}, "
          "`target_spec.supported_types` should be left uninitialized "
          "or set to {tf.int8}.")
    if set(self._target_spec.supported_types) == {_dtypes.int8}:
      self._target_spec.supported_ops = {OpsSet.TFLITE_BUILTINS_INT8}

    # Check if representative_dataset is specified.
    if (not self._representative_dataset and
        not self.is_quantization_aware_training()):
      raise ValueError("For full integer quantization, a "
                       "`representative_dataset` must be specified.")

    # Update the representative dataset to the expected format.
    if self._representative_dataset:
      if not isinstance(self._representative_dataset, RepresentativeDataset):
        self._representative_dataset = RepresentativeDataset(
            self._representative_dataset)

  def _validate_full_integer_quantization_bias_type(self):
    """Validates the bias type for full integer quantization."""
    bias_type = self._full_integer_quantization_bias_type
    if not bias_type:
      return

    if self.activations_type() == _dtypes.float32:
      raise ValueError(
          "`full_integer_quantization_bias_type` is only supported for full "
          "integer quantization.")

    if self.activations_type() == _dtypes.int8 and bias_type != _dtypes.int32:
      raise ValueError(
          f"Expected bias type to be `dtypes.int32` for Int8Quant. "
          f"Current setting bias type: {bias_type}")

    if (self.activations_type() == _dtypes.int16 and
        bias_type != _dtypes.int32 and bias_type != _dtypes.int64):
      raise ValueError(
          f"Expected bias type to be `dtypes.int32` or `dtypes.int64` for "
          f"Int16Quant. Current setting bias type: {bias_type}")

  def _is_int8_target_required(self):
    return (OpsSet.TFLITE_BUILTINS_INT8 in set(
        self._target_spec.supported_ops)) or (set(
            self._target_spec.supported_types) == set([_dtypes.int8]))

  def _is_int16x8_target_required(self):
    return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
            in set(self._target_spec.supported_ops))

  def is_allow_float(self):
    return (OpsSet.TFLITE_BUILTINS in set(
        self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set(
            self._target_spec.supported_ops))

  def is_any_optimization_enabled(self):
    return bool(
        set(self._optimizations).intersection([
            Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
            Optimize.DEFAULT
        ]))

  def _smallest_supported_type(self):
    if self._target_spec.supported_types:
      return min(self._target_spec.supported_types, key=lambda x: x.size)
    else:
      # The default smallest supported type is INT8.
      return _dtypes.int8

  def is_quantization_aware_trained_model(self):
    """Checks if the graph contains any training-time quantization ops."""
    training_quant_ops = frozenset({
        "FakeQuantWithMinMaxVars",
        "FakeQuantWithMinMaxVarsPerChannel",
        "FakeQuantWithMinMaxArgs",
        "QuantizeAndDequantizeV2",
        "QuantizeAndDequantizeV3",
    })

    if self._graph_def:
      for node_def in self._graph_def.node:
        if node_def.op in training_quant_ops:
          return True
      for function in self._graph_def.library.function:
        for node_def in function.node_def:
          if node_def.op in training_quant_ops:
            return True
    return False


class TFLiteConverterBase:
  """Converter subclass to share functionality between V1 and V2 converters."""

  # Stores the original model type temporarily to transmit the information
  # from the factory class methods to the TFLiteConverterBase init function.
  _original_model_type = conversion_metdata_fb.ModelType.NONE

  def __init__(self):
    self.optimizations = set()
    self.representative_dataset = None
    self.target_spec = TargetSpec()
    self.allow_custom_ops = False
    self.experimental_new_converter = True
    self.experimental_new_quantizer = True
    self.experimental_enable_resource_variables = True
    self._experimental_calibrate_only = False
    self._experimental_sparsify_model = False
    self._experimental_disable_per_channel = False
    # Contains the stack traces of all the original nodes in the `GraphDef`
    # passed to the converter.
    self._debug_info = None
    self.saved_model_dir = None
    self._saved_model_tags = None
    self._saved_model_version = 0
    self._saved_model_exported_names = []
    self._tflite_metrics = metrics.TFLiteConverterMetrics()
    self._collected_converter_params = {}
    self._experimental_disable_batchmatmul_unfold = False
    self._experimental_lower_tensor_list_ops = True
    self._experimental_default_to_single_batch_in_tensor_list_ops = False
    self._experimental_unfold_large_splat_constant = False
    self._experimental_tf_quantization_mode = None
    # If unset, bias:int32 is used by default, except for 16x8 quantization.
    # For 16x8 quantization, bias:int64 is used to prevent overflow by default.
    self._experimental_full_integer_quantization_bias_type = None
    # Initializes conversion metadata.
    self.exclude_conversion_metadata = False
    self._metadata = conversion_metdata_fb.ConversionMetadataT()
    self._metadata.environment = conversion_metdata_fb.EnvironmentT()
    self._metadata.options = conversion_metdata_fb.ConversionOptionsT()
    self._metadata.environment.tensorflowVersion = versions.__version__
    self._metadata.environment.modelType = self._get_original_model_type()
    self._experimental_enable_dynamic_update_slice = False
    self._experimental_preserve_assert_op = False
    self._experimental_guarantee_all_funcs_one_use = False

    # When the value is true, the MLIR quantizer triggers dynamic range
    # quantization in MLIR instead of the old quantizer. Used only if
    # experimental_new_quantizer is on.
    self.experimental_new_dynamic_range_quantizer = True
    # Experimental flag to enable low-bit QAT in 8 bit.
    self._experimental_low_bit_qat = False
    # Experimental flag to add all TF ops (including custom TF ops) to the
    # converted model as flex ops.
    self._experimental_allow_all_select_tf_ops = False

  def _grappler_config(self, optimizers=None):
    """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

    Args:
      optimizers: List of strings that represents the list of optimizers.

    Returns:
      tf.ConfigProto.
    """
    if not optimizers:
      optimizers = []
    # MLIR converter will take care of constant folding instead of grappler.
    if not self.experimental_new_converter:
      optimizers.append("constfold")

    is_only_flex_enabled = (
        set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
    if is_only_flex_enabled:
      # The layout optimizer turns NHWC to NCHW. This provides performance
      # optimizations when Flex mode is enabled. However, this is not
      # compatible with builtin ops.
      optimizers.append("layout")
    return _get_grappler_config(optimizers)

  def _quantize(self, result, input_type, output_type, activations_type,
                bias_type, allow_float):
    """Quantize the model."""
    # pylint: disable=protected-access
    custom_op_registerers_by_name = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if isinstance(x, str)
    ]
    custom_op_registerers_by_func = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if not isinstance(x, str)
    ]
    # pylint: enable=protected-access
    if not isinstance(self.representative_dataset, RepresentativeDataset):
      self.representative_dataset = RepresentativeDataset(
          self.representative_dataset)

    # Add intermediate tensors to the model if needed.
    result = _calibrator.add_intermediate_tensors(result)
    calibrate_quantize = _calibrator.Calibrator(result,
                                                custom_op_registerers_by_name,
                                                custom_op_registerers_by_func)
    if self._experimental_calibrate_only or self.experimental_new_quantizer:
      calibrated = calibrate_quantize.calibrate(
          self.representative_dataset.input_gen)

    if self._experimental_calibrate_only:
      return calibrated
    elif self.experimental_new_quantizer and (
        activations_type != _dtypes.int16):
      # TODO(b/175659372): remove the activations_type restriction and enable
      # it for all the activation types.
      return _mlir_quantize(
          calibrated,
          self._experimental_disable_per_channel,
          input_data_type=input_type,
          output_data_type=output_type)
    else:
      return calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen,
          input_type,
          output_type,
          allow_float,
          activations_type,
          bias_type,
          disable_per_channel=self._experimental_disable_per_channel)

  def _is_unknown_shapes_allowed(self):
    # Unknown dimensions are only allowed with the new converter.
    return self.experimental_new_converter

  def _get_base_converter_args(self):
    """Returns the base converter args.
643 644 Returns: 645 {key str: val} 646 """ 647 args = { 648 "input_format": 649 constants.TENSORFLOW_GRAPHDEF, 650 "allow_custom_ops": 651 self.allow_custom_ops, 652 "debug_info": 653 self._debug_info, 654 "target_ops": 655 self.target_spec.supported_ops, 656 "enable_mlir_converter": 657 self.experimental_new_converter, 658 "select_user_tf_ops": 659 self.target_spec.experimental_select_user_tf_ops, 660 "supported_backends": 661 self.target_spec.experimental_supported_backends, 662 "unfold_batchmatmul": 663 not self._experimental_disable_batchmatmul_unfold, 664 "lower_tensor_list_ops": 665 self._experimental_lower_tensor_list_ops, 666 "unfold_large_splat_constant": 667 self._experimental_unfold_large_splat_constant, 668 "default_to_single_batch_in_tensor_list_ops": 669 self._experimental_default_to_single_batch_in_tensor_list_ops, 670 "tf_quantization_mode": 671 self._experimental_tf_quantization_mode, 672 "experimental_enable_resource_variables": 673 self.experimental_enable_resource_variables, 674 "enable_dynamic_update_slice": 675 self._experimental_enable_dynamic_update_slice, 676 "preserve_assert_op": 677 self._experimental_preserve_assert_op, 678 "guarantee_all_funcs_one_use": 679 self._experimental_guarantee_all_funcs_one_use, 680 "allow_all_select_tf_ops": 681 self._experimental_allow_all_select_tf_ops, 682 } 683 684 if self.saved_model_dir: 685 args.update({ 686 "saved_model_dir": self.saved_model_dir, 687 "saved_model_version": self._saved_model_version, 688 "saved_model_tags": self._saved_model_tags, 689 "saved_model_exported_names": self._saved_model_exported_names, 690 }) 691 692 return args 693 694 def _contains_function_with_implements_attr(self, saved_model_proto): 695 meta_graph = saved_model_proto.meta_graphs[0] 696 for function in meta_graph.graph_def.library.function: 697 if function.attr.get("_implements", None) or function.attr.get( 698 "api_implements", None): 699 return True 700 return False 701 702 def _parse_saved_model_args(self, always_enable_saved_model_import=False): 703 """Parses SavedModel arguments from the given Keras/RNN SavedModel. 704 705 Args: 706 always_enable_saved_model_import: Bool. When the value is true, it enables 707 MLIR saved model import path regardless of checking the conditions. 708 """ 709 if not self.experimental_new_converter: 710 self.saved_model_dir = None 711 return 712 if self.saved_model_dir: 713 try: 714 saved_model_proto, _ = ( 715 _parse_saved_model_with_debug_info(self.saved_model_dir)) 716 except OSError: 717 # If it fails to read the given saved model, it will fall back to the 718 # frozen graph def path. 
719 self.saved_model_dir = None 720 return 721 if (not always_enable_saved_model_import and 722 not self._contains_function_with_implements_attr(saved_model_proto)): 723 self.saved_model_dir = None 724 return 725 726 if not self._saved_model_exported_names: 727 self._saved_model_exported_names = [] 728 self._saved_model_version = saved_model_proto.saved_model_schema_version 729 if self._saved_model_version == 0: 730 self.saved_model_dir = None 731 logging.warning("SavedModel schema version is zero.") 732 return 733 if self._saved_model_version not in [1, 2]: 734 raise ValueError("SavedModel file format({0}) is not supported".format( 735 self._saved_model_version)) 736 737 def _sparsify_model(self): 738 return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations 739 740 def _increase_conversion_attempt_metric(self): 741 self._tflite_metrics.increase_counter_converter_attempt() 742 743 def _increase_conversion_success_metric(self): 744 self._tflite_metrics.increase_counter_converter_success() 745 746 @classmethod 747 def _set_original_model_type(cls, model_type): 748 """Stores the original model type.""" 749 if model_type == conversion_metdata_fb.ModelType.NONE: 750 raise ValueError("The original model type should be specified.") 751 cls._original_model_type = model_type 752 753 def _get_original_model_type(self): 754 """One-time getter to return original model type and set it to NONE.""" 755 model_type = TFLiteConverterBase._original_model_type 756 TFLiteConverterBase._original_model_type = conversion_metdata_fb.ModelType.NONE 757 return model_type 758 759 def _save_conversion_params_metric(self, 760 graph_def=None, 761 inference_type=None, 762 inference_input_type=None): 763 """Set conversion parameter metrics.""" 764 converter_kwargs = self._collected_converter_params 765 converter_kwargs.update(self._get_base_converter_args()) 766 767 # Optimization parameters. 
768 quant_mode = QuantizationMode( 769 self.optimizations, self.target_spec, self.representative_dataset, 770 graph_def, self._experimental_disable_per_channel, 771 self.experimental_new_dynamic_range_quantizer, 772 self._experimental_low_bit_qat, 773 self._experimental_full_integer_quantization_bias_type) 774 converter_kwargs.update({ 775 "tf_version": 776 self._metadata.environment.tensorflowVersion, 777 "api_version": 778 self._metadata.environment.apiVersion, 779 "original_model_format": 780 self._metadata.environment.modelType, 781 "optimization_default": 782 quant_mode.is_any_optimization_enabled(), 783 "optimization_post_training_dynamic_range": 784 quant_mode.is_post_training_dynamic_range_quantization(), 785 "optimization_post_training_float16": 786 quant_mode.is_post_training_float16_quantization(), 787 "optimization_post_training_integer_quantize": 788 quant_mode.is_post_training_integer_quantization(), 789 "optimization_qat": 790 quant_mode.is_quantization_aware_training(), 791 "optimization_low_bit_qat": 792 quant_mode.is_low_bit_quantize_aware_training(), 793 "optimization_sparsify": 794 self._sparsify_model(), 795 "activations_type": 796 quant_mode.activations_type() 797 }) 798 converter_kwargs.update( 799 quant_mode.converter_flags(inference_type, inference_input_type)) 800 801 # pylint: disable=protected-access 802 if self.target_spec._experimental_supported_accumulation_type: 803 converter_kwargs.update({ 804 "accumulation_type": 805 self.target_spec._experimental_supported_accumulation_type 806 }) 807 # pylint: enable=protected-access 808 809 def format_element(elem): 810 if isinstance(elem, enum.Enum): 811 return str(elem.value) 812 return pprint.pformat(elem) 813 814 def format_param(param): 815 if isinstance(param, (list, tuple, set)): 816 if not param: 817 return "None" # Return None if empty. 818 string_list = [format_element(x) for x in param] 819 return ",".join(sorted(string_list)) 820 return format_element(param) 821 822 for key, value in converter_kwargs.items(): 823 self._tflite_metrics.set_converter_param(key, format_param(value)) 824 self._tflite_metrics.set_export_required() 825 826 # Set conversion option metadata. 
827 self._metadata.options.allowCustomOps = self.allow_custom_ops 828 self._metadata.options.enableSelectTfOps = ( 829 OpsSet.SELECT_TF_OPS in self.target_spec.supported_ops) 830 self._metadata.options.forceSelectTfOps = ( 831 set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops)) 832 self._metadata.options.modelOptimizationModes = [] 833 834 if quant_mode.is_post_training_float16_quantization(): 835 self._metadata.options.modelOptimizationModes.append( 836 conversion_metdata_fb.ModelOptimizationMode.PTQ_FLOAT16) 837 838 if quant_mode.is_post_training_dynamic_range_quantization(): 839 self._metadata.options.modelOptimizationModes.append( 840 conversion_metdata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE) 841 842 if quant_mode.is_post_training_int8_quantization(): 843 self._metadata.options.modelOptimizationModes.append( 844 conversion_metdata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER) 845 846 if quant_mode.is_post_training_int16x8_quantization(): 847 self._metadata.options.modelOptimizationModes.append( 848 conversion_metdata_fb.ModelOptimizationMode.PTQ_INT16) 849 850 if quant_mode.is_quantization_aware_training(): 851 self._metadata.options.modelOptimizationModes.append( 852 conversion_metdata_fb.ModelOptimizationMode 853 .QUANTIZATION_AWARE_TRAINING) 854 855 def _set_conversion_latency_metric(self, value): 856 self._tflite_metrics.set_converter_latency(value) 857 858 @convert_phase(Component.OPTIMIZE_TFLITE_MODEL) 859 def _optimize_tflite_model(self, model, quant_mode, quant_io=True): 860 """Apply optimizations on a TFLite model.""" 861 862 if quant_mode.is_integer_quantization(): 863 in_type, out_type = self.inference_input_type, self.inference_output_type 864 865 if quant_mode.is_post_training_integer_quantization(): 866 q_in_type = in_type if in_type and quant_io else _dtypes.float32 867 q_out_type = out_type if out_type and quant_io else _dtypes.float32 868 q_activations_type = quant_mode.activations_type() 869 q_bias_type = quant_mode.bias_type() 870 q_allow_float = quant_mode.is_allow_float() 871 model = self._quantize(model, q_in_type, q_out_type, q_activations_type, 872 q_bias_type, q_allow_float) 873 874 m_in_type = in_type if in_type else _dtypes.float32 875 m_out_type = out_type if out_type else _dtypes.float32 876 # Skip updating model io types if MLIR quantizer already takes care of it 877 if not (quant_mode.is_post_training_integer_quantization() and 878 self.experimental_new_quantizer and quant_io and 879 (m_in_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32]) and 880 (m_out_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32])): 881 model = _modify_model_io_type(model, m_in_type, m_out_type) 882 883 if self._sparsify_model(): 884 model = _mlir_sparsify(model) 885 886 try: 887 model = _deduplicate_readonly_buffers(model) 888 except Exception: # pylint: disable=broad-except 889 # Skip buffer deduplication when flatbuffer library is not ready to be 890 # utilized. 891 logging.warning( 892 "Buffer deduplication procedure will be skipped when flatbuffer " 893 "library is not properly loaded") 894 895 return model 896 897 def _convert_and_export_metrics(self, convert_func, *args, **kwargs): 898 """Wraps around convert function to export metrics. 899 900 Args: 901 convert_func: The convert function to wrap. 902 *args: Positional arguments of the convert function. 903 **kwargs: The keyword arguments of the convert function. 904 905 Returns: 906 The decorator to wrap the convert function. 
907 """ 908 self._increase_conversion_attempt_metric() 909 self._save_conversion_params_metric() 910 start_time = time.process_time() 911 result = convert_func(self, *args, **kwargs) 912 elapsed_time_ms = (time.process_time() - start_time) * 1000 913 if result: 914 self._increase_conversion_success_metric() 915 self._set_conversion_latency_metric(round(elapsed_time_ms)) 916 self._tflite_metrics.export_metrics() 917 if self.exclude_conversion_metadata: 918 return result 919 model_object = flatbuffer_utils.convert_bytearray_to_object(result) 920 # Populates the conversion metadata. 921 # TODO(b/202090541): Collects sparsity block size information. 922 sparsity_modes = _get_sparsity_modes(model_object) 923 self._metadata.options.modelOptimizationModes.extend(sparsity_modes) 924 model_object = _populate_conversion_metadata(model_object, self._metadata) 925 return flatbuffer_utils.convert_object_to_bytearray(model_object) 926 927 928def _export_metrics(convert_func): 929 """The decorator around convert function to export metrics.""" 930 @functools.wraps(convert_func) 931 def wrapper(self, *args, **kwargs): 932 # pylint: disable=protected-access 933 return self._convert_and_export_metrics(convert_func, *args, **kwargs) 934 # pylint: enable=protected-access 935 936 return wrapper 937 938 939class TFLiteConverterBaseV2(TFLiteConverterBase): 940 """Converter subclass to share functionality between V2 converters.""" 941 942 def __init__(self): 943 """Constructor for TFLiteConverter.""" 944 super(TFLiteConverterBaseV2, self).__init__() 945 self.inference_input_type = _dtypes.float32 946 self.inference_output_type = _dtypes.float32 947 self._metadata.environment.apiVersion = 2 948 949 def _validate_inference_input_output_types(self, quant_mode): 950 """Validate inference_input_type and inference_output_type flags.""" 951 default_types = [_dtypes.float32] 952 # We support integer input/output for integer quantized models only. 953 if quant_mode.is_integer_quantization(): 954 if quant_mode.is_post_training_int16x8_quantization(): 955 all_types = default_types + [_dtypes.int16] 956 else: 957 all_types = default_types + [_dtypes.int8, _dtypes.uint8] 958 if (self.inference_input_type not in all_types or 959 self.inference_output_type not in all_types): 960 all_types_names = ["tf." + t.name for t in all_types] 961 raise ValueError("The inference_input_type and inference_output_type " 962 "must be in {}.".format(all_types_names)) 963 elif (self.inference_input_type not in default_types or 964 self.inference_output_type not in default_types): 965 raise ValueError("The inference_input_type and inference_output_type " 966 "must be tf.float32.") 967 968 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.LOAD_SAVED_MODEL) 969 def _load_saved_model(self, saved_model_dir, saved_model_tags): 970 """Load graph_def from saved model with the default serving signature key. 971 972 Args: 973 saved_model_dir: Directory of the SavedModel. 974 saved_model_tags: Set of tags identifying the MetaGraphDef within the 975 SavedModel to analyze. 976 977 Returns: 978 graph_def: The loaded GraphDef. 979 input_tensors: List of input tensors. 980 output_tensors: List of output tensors. 
981 """ 982 graph = _ops.Graph() 983 saved_model = _loader_impl.SavedModelLoader(saved_model_dir) 984 saved_model.load_graph(graph, tags=saved_model_tags) 985 meta_graph = saved_model.get_meta_graph_def_from_tags(saved_model_tags) 986 graph_def = meta_graph.graph_def 987 signature_def = meta_graph.signature_def[ 988 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 989 input_tensors = [ 990 graph.get_tensor_by_name(signature_def.inputs[key].name) 991 for key in signature_def.inputs 992 ] 993 output_tensors = [ 994 graph.get_tensor_by_name(signature_def.outputs[key].name) 995 for key in signature_def.outputs 996 ] 997 return graph_def, input_tensors, output_tensors 998 999 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS) 1000 def _validate_inputs(self, graph_def, input_tensors): 1001 """Validate the input parameters. 1002 1003 Args: 1004 graph_def: The TensorFlow GraphDef. 1005 input_tensors: List of input tensors. 1006 Raise: 1007 ValueError: 1008 Input shape is not specified. 1009 Invalid quantization parameters. 1010 """ 1011 # Update conversion params with graph_def. 1012 self._save_conversion_params_metric(graph_def) 1013 self._quant_mode = QuantizationMode( 1014 self.optimizations, self.target_spec, self.representative_dataset, 1015 graph_def, self._experimental_disable_per_channel, 1016 self.experimental_new_dynamic_range_quantizer, 1017 self._experimental_low_bit_qat, 1018 self._experimental_full_integer_quantization_bias_type) 1019 self._validate_inference_input_output_types(self._quant_mode) 1020 1021 if not self._is_unknown_shapes_allowed(): 1022 # Checks dimensions in input tensor. 1023 for tensor in input_tensors: 1024 # Note that shape_list might be empty for scalar shapes. 1025 shape_list = tensor.shape.as_list() 1026 if None in shape_list[1:]: 1027 raise ValueError( 1028 "None is only supported in the 1st dimension. Tensor '{0}' has " 1029 "invalid shape '{1}'.".format( 1030 _get_tensor_name(tensor), shape_list)) 1031 elif shape_list and shape_list[0] is None: 1032 # Set the batch size to 1 if undefined. 1033 shape = tensor.shape.as_list() 1034 shape[0] = 1 1035 tensor.set_shape(shape) 1036 1037 if (self._trackable_obj is None or 1038 not hasattr(self._trackable_obj, "graph_debug_info")): 1039 self._debug_info = _get_debug_info( 1040 _build_debug_info_func(self._funcs[0].graph), graph_def) 1041 else: 1042 self._debug_info = _get_debug_info( 1043 _convert_debug_info_func(self._trackable_obj.graph_debug_info), 1044 graph_def) 1045 1046 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL) 1047 def _optimize_tf_model(self, graph_def, input_tensors, output_tensors, 1048 frozen_func): 1049 """Run a Grappler pass to optimize the TensorFlow graph. 1050 1051 Args: 1052 graph_def: Frozen GraphDef to be optimized. 1053 input_tensors: List of input tensors. 1054 output_tensors: List of output tensors. 1055 frozen_func: TensorFlow Graph. 1056 1057 Returns: 1058 The optimized TensorFlow graph. 1059 """ 1060 grappler_config = self._grappler_config() 1061 # Skip running grappler when there are no optimizers to run. If not, 1062 # grappler will run with the default optimizer set and it will lead to 1063 # causing an unexpected behavior. 
1064 if grappler_config.graph_options.rewrite_options.optimizers: 1065 graph_def = _run_graph_optimizations( 1066 graph_def, 1067 input_tensors, 1068 output_tensors, 1069 config=grappler_config, 1070 graph=frozen_func.graph) 1071 return graph_def 1072 1073 def _convert_from_saved_model(self, graph_def): 1074 """Helper method that converts saved model. 1075 1076 Args: 1077 graph_def: GraphDef object for the model, used only for stats. 1078 1079 Returns: 1080 The converted TFLite model. 1081 """ 1082 # Update conversion params with graph_def. 1083 self._save_conversion_params_metric(graph_def) 1084 # Get quantization options and do some sanity checks. 1085 quant_mode = QuantizationMode( 1086 self.optimizations, self.target_spec, self.representative_dataset, 1087 graph_def, self._experimental_disable_per_channel, 1088 self.experimental_new_dynamic_range_quantizer, 1089 self._experimental_low_bit_qat, 1090 self._experimental_full_integer_quantization_bias_type) 1091 self._validate_inference_input_output_types(quant_mode) 1092 converter_kwargs = { 1093 "enable_tflite_resource_variables": 1094 self.experimental_enable_resource_variables 1095 } 1096 converter_kwargs.update(self._get_base_converter_args()) 1097 converter_kwargs.update(quant_mode.converter_flags()) 1098 1099 result = _convert_saved_model(**converter_kwargs) 1100 return self._optimize_tflite_model( 1101 result, quant_mode, quant_io=self.experimental_new_quantizer) 1102 1103 def convert(self, graph_def, input_tensors, output_tensors): 1104 """Converts a TensorFlow GraphDef based on instance variables. 1105 1106 Args: 1107 graph_def: Frozen TensorFlow GraphDef. 1108 input_tensors: List of input tensors. 1109 output_tensors: List of output tensors. 1110 1111 Returns: 1112 The converted data in serialized format. 1113 1114 Raises: 1115 ValueError: 1116 No concrete functions is specified. 1117 Multiple concrete functions are specified. 1118 Input shape is not specified. 1119 Invalid quantization parameters. 1120 """ 1121 self._validate_inputs(graph_def, input_tensors) 1122 converter_kwargs = self._get_base_converter_args() 1123 converter_kwargs.update(self._quant_mode.converter_flags()) 1124 if not self.experimental_new_converter: 1125 logging.warning( 1126 "Please consider switching to the new converter by setting " 1127 "experimental_new_converter=True. " 1128 "The old converter is deprecated.") 1129 else: 1130 logging.info("Using new converter: If you encounter a problem " 1131 "please file a bug. You can opt-out " 1132 "by setting experimental_new_converter=False") 1133 1134 # Converts model. 1135 result = _convert_graphdef( 1136 input_data=graph_def, 1137 input_tensors=input_tensors, 1138 output_tensors=output_tensors, 1139 **converter_kwargs) 1140 1141 return self._optimize_tflite_model( 1142 result, self._quant_mode, quant_io=self.experimental_new_quantizer) 1143 1144 1145class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2): 1146 """Converts the given SavedModel into TensorFlow Lite model. 1147 1148 Attributes: 1149 saved_model_dir: Directory of the SavedModel. 1150 """ 1151 1152 def __init__(self, 1153 saved_model_dir, 1154 saved_model_tags=None, 1155 saved_model_exported_names=None, 1156 trackable_obj=None): 1157 """Constructor for TFLiteConverter. 1158 1159 Args: 1160 saved_model_dir: Directory of the SavedModel. 1161 saved_model_tags: Set of tags identifying the MetaGraphDef within the 1162 SavedModel to analyze. All tags in the tag set must be present. (default 1163 {tf.saved_model.SERVING}). 
1164 saved_model_exported_names: Names to be exported when the saved model 1165 import path is on. 1166 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 1167 reference to this object needs to be maintained so that Variables do not 1168 get garbage collected since functions have a weak reference to 1169 Variables. This is only required when the tf.AutoTrackable object is not 1170 maintained by the user (e.g. `from_saved_model`). 1171 """ 1172 super(TFLiteSavedModelConverterV2, self).__init__() 1173 self.saved_model_dir = saved_model_dir 1174 self._saved_model_tags = saved_model_tags 1175 self._saved_model_exported_names = saved_model_exported_names 1176 self._trackable_obj = trackable_obj 1177 self._parse_saved_model_args(always_enable_saved_model_import=True) 1178 1179 @_export_metrics 1180 def convert(self): 1181 """Converts a TensorFlow GraphDef based on instance variables. 1182 1183 Returns: 1184 The converted data in serialized format. 1185 1186 Raises: 1187 ValueError: 1188 No concrete functions is specified. 1189 Multiple concrete functions are specified. 1190 Input shape is not specified. 1191 Invalid quantization parameters. 1192 """ 1193 graph_def, input_tensors, output_tensors = self._load_saved_model( 1194 self.saved_model_dir, self._saved_model_tags) 1195 # If we can't use saved model importer, then fallback 1196 # to frozen graph conversion path. 1197 if self.saved_model_dir is None or not self.experimental_new_converter: 1198 graph_def, _, _, _ = _freeze_saved_model( 1199 self.saved_model_dir, None, None, None, self._saved_model_tags, 1200 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 1201 # We make sure to clear the saved_model_dir as there is some 1202 # legacy code down in the caller that checks this. 1203 # TODO(b/162537905): Clean these indirect dependencies. 1204 self.saved_model_dir = None 1205 return super(TFLiteSavedModelConverterV2, 1206 self).convert(graph_def, input_tensors, output_tensors) 1207 1208 if self._trackable_obj is None: 1209 self._debug_info = _get_debug_info( 1210 _build_debug_info_func(self._funcs[0].graph), graph_def) 1211 else: 1212 self._debug_info = _get_debug_info( 1213 _convert_debug_info_func(self._trackable_obj.graph_debug_info), 1214 graph_def) 1215 1216 return self._convert_from_saved_model(graph_def) 1217 1218 1219class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2): 1220 """Converts the given Keras model into TensorFlow Lite model.""" 1221 1222 def __init__(self, keras_model, trackable_obj=None): 1223 """Constructor for TFLiteConverter. 1224 1225 Args: 1226 keras_model: tf.Keras.Model. 1227 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 1228 reference to this object needs to be maintained so that Variables do not 1229 get garbage collected since functions have a weak reference to 1230 Variables. This is only required when the tf.AutoTrackable object is not 1231 maintained by the user (e.g. `from_saved_model`). 1232 """ 1233 super(TFLiteKerasModelConverterV2, self).__init__() 1234 self._keras_model = keras_model 1235 self._trackable_obj = trackable_obj 1236 self.experimental_lower_to_saved_model = True 1237 1238 @convert_phase(Component.PREPARE_TF_MODEL, 1239 SubComponent.CONVERT_KERAS_TO_SAVED_MODEL) 1240 def _convert_keras_to_saved_model(self, output_dir): 1241 """Save Keras model to the SavedModel format. 1242 1243 Args: 1244 output_dir: The output directory to save the SavedModel. 1245 1246 Returns: 1247 graph_def: The frozen GraphDef. 1248 input_tensors: List of input tensors. 
1249 output_tensors: List of output tensors. 1250 """ 1251 try: 1252 _saved_model.save( 1253 self._keras_model, 1254 output_dir, 1255 options=_save_options.SaveOptions(save_debug_info=True)) 1256 except Exception: # pylint: disable=broad-except 1257 # When storing the given keras model to a saved model is failed, let's 1258 # use original keras model conversion pipeline. 1259 return None, None, None 1260 self.saved_model_dir = output_dir 1261 self._saved_model_tags = set([_tag_constants.SERVING]) 1262 self._saved_model_exported_names = [ 1263 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 1264 ] 1265 self._parse_saved_model_args( 1266 always_enable_saved_model_import=self.experimental_lower_to_saved_model) 1267 if self.saved_model_dir: 1268 graph_def, input_tensors, output_tensors = self._load_saved_model( 1269 self.saved_model_dir, self._saved_model_tags) 1270 self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags) 1271 return graph_def, input_tensors, output_tensors 1272 return None, None, None 1273 1274 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL) 1275 def _freeze_keras_model(self): 1276 """Freeze Keras model to frozen graph. 1277 1278 Returns: 1279 graph_def: The frozen GraphDef. 1280 input_tensors: List of input tensors. 1281 output_tensors: List of output tensors. 1282 frozen_func: The frozen ConcreteFunction. 1283 """ 1284 input_signature = None 1285 # If the model's call is not a `tf.function`, then we need to first get its 1286 # input signature from `model_input_signature` method. We can't directly 1287 # call `trace_model_call` because otherwise the batch dimension is set 1288 # to None. 1289 # Once we have better support for dynamic shapes, we can remove this. 1290 if not isinstance(self._keras_model.call, _def_function.Function): 1291 # Pass `keep_original_batch_size=True` will ensure that we get an input 1292 # signature including the batch dimension specified by the user. 1293 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF 1294 input_signature = _model_input_signature( 1295 self._keras_model, keep_original_batch_size=True) 1296 1297 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF 1298 func = _trace_model_call(self._keras_model, input_signature) 1299 concrete_func = func.get_concrete_function() 1300 self._funcs = [concrete_func] 1301 1302 frozen_func, graph_def = ( 1303 _convert_to_constants.convert_variables_to_constants_v2_as_graph( 1304 self._funcs[0], lower_control_flow=False)) 1305 1306 input_tensors = [ 1307 tensor for tensor in frozen_func.inputs 1308 if tensor.dtype != _dtypes.resource 1309 ] 1310 output_tensors = frozen_func.outputs 1311 return graph_def, input_tensors, output_tensors, frozen_func 1312 1313 def _convert_as_saved_model(self): 1314 """Converts a Keras model as a saved model. 1315 1316 Returns: 1317 The converted data in serialized format. 1318 """ 1319 temp_dir = tempfile.mkdtemp() 1320 try: 1321 graph_def, input_tensors, output_tensors = ( 1322 self._convert_keras_to_saved_model(temp_dir)) 1323 if self.saved_model_dir: 1324 return super(TFLiteKerasModelConverterV2, 1325 self).convert(graph_def, input_tensors, output_tensors) 1326 finally: 1327 shutil.rmtree(temp_dir, True) 1328 1329 @_export_metrics 1330 def convert(self): 1331 """Converts a keras model based on instance variables. 1332 1333 Returns: 1334 The converted data in serialized format. 1335 1336 Raises: 1337 ValueError: 1338 Multiple concrete functions are specified. 
1339 Input shape is not specified. 1340 Invalid quantization parameters. 1341 """ 1342 saved_model_convert_result = self._convert_as_saved_model() 1343 if saved_model_convert_result: 1344 return saved_model_convert_result 1345 1346 graph_def, input_tensors, output_tensors, frozen_func = ( 1347 self._freeze_keras_model()) 1348 1349 graph_def = self._optimize_tf_model(graph_def, input_tensors, 1350 output_tensors, frozen_func) 1351 1352 return super(TFLiteKerasModelConverterV2, 1353 self).convert(graph_def, input_tensors, output_tensors) 1354 1355 1356class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2): 1357 """Converts the given frozen graph into TensorFlow Lite model.""" 1358 1359 def __init__(self, funcs, trackable_obj=None): 1360 """Constructor for TFLiteConverter. 1361 1362 Args: 1363 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 1364 duplicate elements. 1365 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 1366 reference to this object needs to be maintained so that Variables do not 1367 get garbage collected since functions have a weak reference to 1368 Variables. This is only required when the tf.AutoTrackable object is not 1369 maintained by the user (e.g. `from_saved_model`). 1370 """ 1371 super(TFLiteFrozenGraphConverterV2, self).__init__() 1372 self._funcs = funcs 1373 self._trackable_obj = trackable_obj 1374 self.experimental_lower_to_saved_model = True 1375 1376 @convert_phase(Component.PREPARE_TF_MODEL, 1377 SubComponent.FREEZE_CONCRETE_FUNCTION) 1378 def _freeze_concrete_function(self): 1379 """Convert the given ConcreteFunction to frozen graph. 1380 1381 Returns: 1382 graph_def: The frozen GraphDef. 1383 input_tensors: List of input tensors. 1384 output_tensors: List of output tensors. 1385 frozen_func: The frozen ConcreteFunction. 1386 1387 Raises: 1388 ValueError: none or multiple ConcreteFunctions provided. 1389 """ 1390 # TODO(b/130297984): Add support for converting multiple function. 1391 1392 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test 1393 raise ValueError("No ConcreteFunction is specified.") 1394 1395 if len(self._funcs) > 1: 1396 raise ValueError("This converter can only convert a single " 1397 "ConcreteFunction. Converting multiple functions is " 1398 "under development.") 1399 1400 frozen_func, graph_def = ( 1401 _convert_to_constants.convert_variables_to_constants_v2_as_graph( 1402 self._funcs[0], lower_control_flow=False)) 1403 1404 input_tensors = [ 1405 tensor for tensor in frozen_func.inputs 1406 if tensor.dtype != _dtypes.resource 1407 ] 1408 output_tensors = frozen_func.outputs 1409 return graph_def, input_tensors, output_tensors, frozen_func 1410 1411 @convert_phase(Component.PREPARE_TF_MODEL, 1412 SubComponent.CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL) 1413 def _convert_concrete_functions_to_saved_model(self, output_dir): 1414 """Save concrete functions to the SavedModel format. 1415 1416 Args: 1417 output_dir: The output directory to save the SavedModel. 1418 1419 Returns: 1420 graph_def: The frozen GraphDef. 1421 input_tensors: List of input tensors. 1422 output_tensors: List of output tensors. 1423 """ 1424 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test 1425 raise ValueError("No ConcreteFunction is specified.") 1426 1427 if not self.experimental_lower_to_saved_model: 1428 return None, None, None 1429 1430 # Without the provided trackable obj, it is not able to serialize the given 1431 # concrete functions as a saved model format. 
Also, when the trackable obj is
    # a function, use the original concrete function conversion pipeline.
    if (not self._trackable_obj or
        isinstance(self._trackable_obj, (_function.ConcreteFunction,
                                         _def_function.Function))):
      return None, None, None

    signatures = {}
    signature_keys = []
    try:
      if len(self._funcs) == 1:
        signatures[_signature_constants
                   .DEFAULT_SERVING_SIGNATURE_DEF_KEY] = self._funcs[0]
        signature_keys = [
            _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        ]
      else:
        for func in self._funcs:
          signatures[func.graph.name] = func
          signature_keys.append(func.graph.name)

      _saved_model.save(
          self._trackable_obj,
          output_dir,
          signatures=signatures,
          options=_save_options.SaveOptions(save_debug_info=True))
    except Exception:  # pylint: disable=broad-except
      # If saving the given concrete functions as a SavedModel fails, fall
      # back to the original concrete function conversion pipeline.
      return None, None, None

    self.saved_model_dir = output_dir
    self._saved_model_tags = set([_tag_constants.SERVING])
    self._saved_model_exported_names = signature_keys
    self._parse_saved_model_args(always_enable_saved_model_import=True)
    if self.saved_model_dir:
      graph_def, input_tensors, output_tensors = self._load_saved_model(
          self.saved_model_dir, self._saved_model_tags)
      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
      return graph_def, input_tensors, output_tensors
    return None, None, None

  def _convert_as_saved_model(self):
    """Converts the given concrete functions via the SavedModel format.

    Returns:
      The converted data in serialized format.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      graph_def, input_tensors, _ = (
          self._convert_concrete_functions_to_saved_model(temp_dir))
      if self.saved_model_dir:
        self._validate_inputs(graph_def, input_tensors)
        return self._convert_from_saved_model(graph_def)
    finally:
      shutil.rmtree(temp_dir, True)
    return None

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    if self.experimental_lower_to_saved_model:
      saved_model_convert_result = self._convert_as_saved_model()
      if saved_model_convert_result:
        return saved_model_convert_result

    graph_def, input_tensors, output_tensors, frozen_func = (
        self._freeze_concrete_function())

    graph_def = self._optimize_tf_model(graph_def, input_tensors,
                                        output_tensors, frozen_func)

    return super(TFLiteFrozenGraphConverterV2,
                 self).convert(graph_def, input_tensors, output_tensors)


class TFLiteJaxConverterV2(TFLiteConverterBaseV2):
  """Converts the given Jax model into a TensorFlow Lite model."""

  def __init__(self, serving_funcs, inputs):
    """Constructor for TFLiteConverter.

    Args:
      serving_funcs: A list of serving functions of the Jax module; the model
        params should already be inlined (e.g., `serving_func =
        functools.partial(model, params=params)`).
      inputs: Array of input tensor placeholder tuples, like `jnp.zeros`. For
        example, wrapped in an array like
        "[[('input1', input1), ('input2', input2)]]".
    A Jax function is polymorphic. For example:
    ```python
    def add(a, b):
      return a + b
    ```
    will yield different computations depending on the input signature: passing
    `add(10.0, 20.0)` yields a scalar `add`, while passing
    `add(np.random.random((100, 1)), np.random.random((100, 100)))` yields a
    broadcasting add. The converter needs this input information to trace the
    function and properly convert the model, so it is important to pass in the
    desired input placeholders with the correct input shape/type.

    In the converted tflite model:
    Currently the function name defaults to "main", and the output names are
    the traced outputs. The output ordering matches the serving function.
    """
    super(TFLiteJaxConverterV2, self).__init__()
    self._serving_funcs = serving_funcs
    self._inputs = inputs

  @_export_metrics
  def convert(self):
    """Converts a Jax serving func based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ImportError:
        If `xla_computation` cannot be imported from jax.
      ValueError:
        No serving function is specified.
        Input tensors are not specified.
        The truth value of an array with more than one element is ambiguous.
        Failed to convert the given Jax function to hlo.
    """
    if not _xla_computation:
      raise ImportError("Cannot import xla_computation from jax.")

    if not self._serving_funcs:
      raise ValueError("No serving func is specified.")

    if not self._inputs:
      raise ValueError("Input tensors are not specified.")

    if len(self._inputs) != len(self._serving_funcs):
      msg = ("Input tensor mapping len {} does not match serving func len {}."
             .format(len(self._inputs), len(self._serving_funcs)))
      raise ValueError(msg)

    if not isinstance(self._inputs, (tuple, list)):
      raise ValueError(
          "Input tensors should be passed in as a tuple list wrapped in an "
          "array.")

    # TODO(b/197690428): Support multiple functions.
    # Currently only one serving function is supported.
    if len(self._serving_funcs) > 1:
      raise ValueError(
          "Currently only a single serving function is supported.")

    if not isinstance(self._inputs[0], (tuple, list)):
      raise ValueError(
          "The input placeholders are not a list of (name, tensor) tuples.")

    input_names = []
    ordered_inputs = []
    for input_name, tensor in self._inputs[0]:
      input_names.append(input_name)
      ordered_inputs.append(tensor)

    try:
      xla_computation = _xla_computation(self._serving_funcs[0], backend="cpu")
      hlo_proto = xla_computation(
          *ordered_inputs).as_serialized_hlo_module_proto()
    except Exception:  # pylint: disable=broad-except
      raise ValueError("Failed to convert the given Jax function to hlo.")

    # We need to set the HLO proto, and here we use the serialized proto format
    # since it's more compact.
    converter_kwargs = {
        "input_content": hlo_proto,
        "input_names": input_names,
        "is_proto_format": True
    }
    converter_kwargs.update(self._get_base_converter_args())

    # Get quantization options and do some checks.
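    # Illustrative note (not part of the conversion logic): QuantizationMode
    # below derives converter flags from `optimizations`, `target_spec` and
    # `representative_dataset`. A hedged sketch of how this path is typically
    # driven through the public API; `serving_func` and the input name below
    # are hypothetical placeholders:
    #
    #   import jax.numpy as jnp
    #   converter = tf.lite.TFLiteConverter.experimental_from_jax(
    #       [serving_func], [[("input1", jnp.zeros((1, 28, 28, 1)))]])
    #   converter.optimizations = [tf.lite.Optimize.DEFAULT]
    #   tflite_model = converter.convert()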
1618 quant_mode = QuantizationMode(self.optimizations, self.target_spec, 1619 self.representative_dataset, None) 1620 self._validate_inference_input_output_types(quant_mode) 1621 converter_kwargs.update(quant_mode.converter_flags()) 1622 result = _convert_jax_hlo(**converter_kwargs) 1623 1624 return self._optimize_tflite_model( 1625 result, quant_mode, quant_io=self.experimental_new_quantizer) 1626 1627 1628@_tf_export("lite.TFLiteConverter", v1=[]) 1629class TFLiteConverterV2(TFLiteFrozenGraphConverterV2): 1630 """Converts a TensorFlow model into TensorFlow Lite model. 1631 1632 Attributes: 1633 optimizations: Experimental flag, subject to change. Set of optimizations to 1634 apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a 1635 set of values of type `tf.lite.Optimize`) 1636 representative_dataset: A generator function used for integer quantization 1637 where each generated sample has the same order, type and shape as the 1638 inputs to the model. Usually, this is a small subset of a few hundred 1639 samples randomly chosen, in no particular order, from the training or 1640 evaluation dataset. This is an optional attribute, but required for full 1641 integer quantization, i.e, if `tf.int8` is the only supported type in 1642 `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`. 1643 (default None) 1644 target_spec: Experimental flag, subject to change. Specifications of target 1645 device, including supported ops set, supported types and a set of user's 1646 defined TensorFlow operators required in the TensorFlow Lite runtime. 1647 Refer to `tf.lite.TargetSpec`. 1648 inference_input_type: Data type of the input layer. Note that integer types 1649 (tf.int8 and tf.uint8) are currently only supported for post training 1650 integer quantization and quantization aware training. (default tf.float32, 1651 must be in {tf.float32, tf.int8, tf.uint8}) 1652 inference_output_type: Data type of the output layer. Note that integer 1653 types (tf.int8 and tf.uint8) are currently only supported for post 1654 training integer quantization and quantization aware training. (default 1655 tf.float32, must be in {tf.float32, tf.int8, tf.uint8}) 1656 allow_custom_ops: Boolean indicating whether to allow custom operations. 1657 When False, any unknown operation is an error. When True, custom ops are 1658 created for any op that is unknown. The developer needs to provide these 1659 to the TensorFlow Lite runtime with a custom resolver. (default False) 1660 exclude_conversion_metadata: Whether not to embed the conversion metadata 1661 into the converted model. (default False) 1662 experimental_new_converter: Experimental flag, subject to change. Enables 1663 MLIR-based conversion. (default True) 1664 experimental_new_quantizer: Experimental flag, subject to change. Enables 1665 MLIR-based quantization conversion instead of Flatbuffer-based conversion. 1666 (default True) 1667 experimental_enable_resource_variables: Experimental flag, subject to 1668 change. Enables 1669 [resource variables](https://tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables) 1670 to be converted by this converter. This is only allowed if the 1671 from_saved_model interface is used. (default True) 1672 1673 Example usage: 1674 1675 ```python 1676 # Converting a SavedModel to a TensorFlow Lite model. 
1677 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) 1678 tflite_model = converter.convert() 1679 1680 # Converting a tf.Keras model to a TensorFlow Lite model. 1681 converter = tf.lite.TFLiteConverter.from_keras_model(model) 1682 tflite_model = converter.convert() 1683 1684 # Converting ConcreteFunctions to a TensorFlow Lite model. 1685 converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model) 1686 tflite_model = converter.convert() 1687 1688 # Converting a Jax model to a TensorFlow Lite model. 1689 converter = tf.lite.TFLiteConverter.experimental_from_jax([func], [[ 1690 ('input1', input1), ('input2', input2)]]) 1691 tflite_model = converter.convert() 1692 ``` 1693 """ 1694 1695 # pylint: disable=useless-super-delegation 1696 def __init__(self, funcs, trackable_obj=None): 1697 """Constructor for TFLiteConverter. 1698 1699 Args: 1700 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 1701 duplicate elements. 1702 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 1703 reference to this object needs to be maintained so that Variables do not 1704 get garbage collected since functions have a weak reference to 1705 Variables. This is only required when the tf.AutoTrackable object is not 1706 maintained by the user (e.g. `from_saved_model`). 1707 """ 1708 super(TFLiteConverterV2, self).__init__(funcs, trackable_obj) 1709 1710 @classmethod 1711 def from_concrete_functions(cls, funcs, trackable_obj=None): 1712 """Creates a TFLiteConverter object from ConcreteFunctions. 1713 1714 Args: 1715 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 1716 duplicate elements. Currently converter can only convert a single 1717 ConcreteFunction. Converting multiple functions is under development. 1718 trackable_obj: An `AutoTrackable` object (typically `tf.module`) 1719 associated with `funcs`. A reference to this object needs to be 1720 maintained so that Variables do not get garbage collected since 1721 functions have a weak reference to Variables. 1722 1723 Returns: 1724 TFLiteConverter object. 1725 1726 Raises: 1727 Invalid input type. 1728 """ 1729 # pylint: disable=protected-access 1730 TFLiteConverterBase._set_original_model_type( 1731 conversion_metdata_fb.ModelType.TF_CONCRETE_FUNCTIONS) 1732 # pylint: enable=protected-access 1733 if trackable_obj is None: 1734 logging.warning( 1735 "Please consider providing the trackable_obj argument in the " 1736 "from_concrete_functions. Providing without the trackable_obj " 1737 "argument is deprecated and it will use the deprecated conversion " 1738 "path.") 1739 for func in funcs: 1740 if not isinstance(func, _function.ConcreteFunction): 1741 message = "This function takes in a list of ConcreteFunction." 1742 if isinstance(func, _def_function.Function): 1743 message += (" To get the ConcreteFunction from a Function," 1744 " call get_concrete_function.") 1745 raise ValueError(message) 1746 return cls(funcs, trackable_obj) 1747 1748 @classmethod 1749 def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None): 1750 """Creates a TFLiteConverter object from a SavedModel directory. 1751 1752 Args: 1753 saved_model_dir: SavedModel directory to convert. 1754 signature_keys: List of keys identifying SignatureDef containing inputs 1755 and outputs. Elements should not be duplicated. By default the 1756 `signatures` attribute of the MetaGraphdef is used. 
(default
        saved_model.signatures)
      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING} or {'serve'})

    Returns:
      TFLiteConverter object.

    Raises:
      Invalid signature keys.
    """
    # pylint: disable=protected-access
    TFLiteConverterBase._set_original_model_type(
        conversion_metdata_fb.ModelType.TF_SAVED_MODEL)
    # pylint: enable=protected-access
    # When run without eager enabled, this will return the legacy
    # TFLiteConverter.
    if not context.executing_eagerly():
      signature_key = None
      if signature_keys:
        if len(signature_keys) != 1:
          raise ValueError("Only a single signature key is supported.")
        else:
          signature_key = signature_keys[0]
      logging.warning("Invoking the TF1 implementation of TFLiteConverter "
                      "because eager is disabled. Consider enabling eager.")
      return TFLiteConverter.from_saved_model(
          saved_model_dir, signature_key=signature_key, tag_set=tags)

    # Ensures any graphs created in Eager mode are able to run. This is
    # required in order to create a tf.estimator.Exporter that exports a
    # TFLite model.
    if tags is None:
      tags = set([_tag_constants.SERVING])

    with context.eager_mode():
      saved_model = _load(saved_model_dir, tags)
    if not signature_keys:
      signature_keys = saved_model.signatures

    if not signature_keys:
      raise ValueError("At least one signature key is required.")

    funcs = []
    for key in signature_keys:
      if key not in saved_model.signatures:
        raise ValueError("Invalid signature key '{}' found. Valid keys are "
                         "'{}'.".format(key, ",".join(saved_model.signatures)))
      funcs.append(saved_model.signatures[key])

    saved_model_converter = TFLiteSavedModelConverterV2(saved_model_dir, tags,
                                                        signature_keys,
                                                        saved_model)
    if saved_model_converter.saved_model_dir:
      return saved_model_converter

    return cls(funcs, saved_model)

  @classmethod
  def from_keras_model(cls, model):
    """Creates a TFLiteConverter object from a Keras model.

    Args:
      model: tf.keras.Model

    Returns:
      TFLiteConverter object.
    """
    # pylint: disable=protected-access
    TFLiteConverterBase._set_original_model_type(
        conversion_metdata_fb.ModelType.KERAS_MODEL)
    # pylint: enable=protected-access
    return TFLiteKerasModelConverterV2(model)

  @classmethod
  def experimental_from_jax(cls, serving_funcs, inputs):
    # Experimental API, subject to changes.
    # TODO(b/197690428): Currently only supports a single function.
    """Creates a TFLiteConverter object from a Jax model with its inputs.

    Args:
      serving_funcs: An array of Jax functions with all the weights applied
        already.
      inputs: An array of Jax input placeholder tuple lists, e.g.,
        jnp.zeros(INPUT_SHAPE). Each tuple list should correspond to a serving
        function.

    Returns:
      TFLiteConverter object.
    """
    # pylint: disable=protected-access
    TFLiteConverterBase._set_original_model_type(
        conversion_metdata_fb.ModelType.JAX)
    # pylint: enable=protected-access
    return TFLiteJaxConverterV2(serving_funcs, inputs)

  # pylint: disable=useless-super-delegation
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.
1855 1856 Returns: 1857 The converted data in serialized format. 1858 1859 Raises: 1860 ValueError: 1861 No concrete functions is specified. 1862 Multiple concrete functions are specified. 1863 Input shape is not specified. 1864 Invalid quantization parameters. 1865 """ 1866 return super(TFLiteConverterV2, self).convert() 1867 1868 1869class TFLiteConverterBaseV1(TFLiteConverterBase): 1870 """Converter subclass to share functionality between V1 converters.""" 1871 1872 def __init__(self, experimental_debug_info_func): 1873 """Constructor for TFLiteConverter. 1874 1875 Args: 1876 experimental_debug_info_func: An experimental function to retrieve the 1877 graph debug info for a set of nodes from the `graph_def`. 1878 """ 1879 super(TFLiteConverterBaseV1, self).__init__() 1880 self.inference_type = _dtypes.float32 1881 self.inference_input_type = None 1882 self.inference_output_type = None 1883 self.output_format = constants.TFLITE 1884 self.quantized_input_stats = {} 1885 self.default_ranges_stats = None 1886 self.drop_control_dependency = True 1887 self.reorder_across_fake_quant = False 1888 self.change_concat_input_ranges = False 1889 self.dump_graphviz_dir = None 1890 self.dump_graphviz_video = False 1891 self.conversion_summary_dir = None 1892 self._debug_info_func = experimental_debug_info_func 1893 self._metadata.environment.apiVersion = 1 1894 1895 def __setattr__(self, name, value): 1896 if name == "post_training_quantize": 1897 warnings.warn("Property %s is deprecated, " 1898 "please use optimizations=[Optimize.DEFAULT]" 1899 " instead." % name) 1900 if value: 1901 self.optimizations = [Optimize.DEFAULT] 1902 else: 1903 self.optimizations = [] 1904 return 1905 if name == "target_ops": 1906 warnings.warn("Property %s is deprecated, please use " 1907 "target_spec.supported_ops instead." % name) 1908 self.target_spec.supported_ops = value 1909 return 1910 object.__setattr__(self, name, value) 1911 1912 def __getattribute__(self, name): 1913 if name == "post_training_quantize": 1914 warnings.warn("Property %s is deprecated, " 1915 "please use optimizations=[Optimize.DEFAULT]" 1916 " instead." % name) 1917 return Optimize.DEFAULT in set(self.optimizations) 1918 if name == "target_ops": 1919 warnings.warn("Property %s is deprecated, please use " 1920 "target_spec.supported_ops instead." % name) 1921 return self.target_spec.supported_ops 1922 return object.__getattribute__(self, name) 1923 1924 def _validate_quantized_input_stats(self, converter_kwargs, quant_mode): 1925 """Ensure the `quantized_input_stats` flag is provided if required.""" 1926 1927 quantized_types = frozenset({_dtypes.int8, _dtypes.uint8}) 1928 1929 requires_quantized_input_stats = ( 1930 (converter_kwargs["inference_type"] in quantized_types or 1931 converter_kwargs["inference_input_type"] in quantized_types) and 1932 not quant_mode.is_post_training_integer_quantization()) 1933 1934 if (requires_quantized_input_stats and 1935 not converter_kwargs["quantized_input_stats"]): 1936 raise ValueError( 1937 "The `quantized_input_stats` flag must be defined when either " 1938 "`inference_type` flag or `inference_input_type` flag is set to " 1939 "tf.int8 or tf.uint8. 
Currently, `inference_type={}` and " 1940 "`inference_input_type={}`.".format( 1941 _get_tf_type_name(converter_kwargs["inference_type"]), 1942 _get_tf_type_name(converter_kwargs["inference_input_type"]))) 1943 1944 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS) 1945 def _validate_inputs(self, input_tensors, quantized_input_stats): 1946 """Validate input parameters. 1947 1948 Args: 1949 input_tensors: List of input tensors. 1950 quantized_input_stats: Map of input tensor names to a tuple of floats 1951 representing the mean and standard deviation of the training data. 1952 1953 Raises: 1954 ValueError: 1955 Input shape is not specified. 1956 Quantization input stats is required but not provided. 1957 """ 1958 1959 if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()): 1960 # Checks dimensions in input tensor. 1961 for tensor in input_tensors: 1962 shape = tensor.shape 1963 if not shape: 1964 raise ValueError("Provide an input shape for input array " 1965 "'{0}'.".format(_get_tensor_name(tensor))) 1966 # Note that shape_list might be empty for scalar shapes. 1967 shape_list = shape.as_list() 1968 if None in shape_list[1:]: 1969 raise ValueError( 1970 "None is only supported in the 1st dimension. Tensor '{0}' has " 1971 "invalid shape '{1}'.".format( 1972 _get_tensor_name(tensor), shape_list)) 1973 elif shape_list and shape_list[0] is None: 1974 self._set_batch_size(batch_size=1) 1975 1976 # Get quantization stats. Ensures there is one stat per name if the stats 1977 # are specified. 1978 if quantized_input_stats: 1979 self._quantized_stats = [] 1980 invalid_stats = [] 1981 for name in self.get_input_arrays(): 1982 if name in quantized_input_stats: 1983 self._quantized_stats.append(quantized_input_stats[name]) 1984 else: 1985 invalid_stats.append(name) 1986 1987 if invalid_stats: 1988 raise ValueError("Quantization input stats are not available for input " 1989 "tensors '{0}'.".format(",".join(invalid_stats))) 1990 else: 1991 self._quantized_stats = None 1992 1993 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL) 1994 def _optimize_tf_model(self, graph_def, input_tensors, output_tensors, 1995 quant_mode): 1996 """Run a Grappler pass to optimize the TensorFlow graph. 1997 1998 Args: 1999 graph_def: Frozen GraphDef to be optimized. 2000 input_tensors: List of input tensors. 2001 output_tensors: List of output tensors. 2002 quant_mode: the quantization mode. 2003 2004 Returns: 2005 The optimized TensorFlow graph. 2006 """ 2007 # Disable grappler constant folding if there are training quant ops. 2008 if self.saved_model_dir or quant_mode.is_quantization_aware_trained_model(): 2009 return graph_def 2010 2011 try: 2012 # TODO(b/150163103): Merge `disabling lower using switch merge' calls. 2013 # Grappler will also try to lower while loop into switch merge 2014 # representation which is undesired for Ophints, so we simply remove 2015 # those attributes to prevent Grappler from doing so. 2016 graph = _convert_to_constants.disable_lower_using_switch_merge(graph_def) 2017 # Run function inlining optimization to ensure any models generated 2018 # through the from_frozen_graph path have been inlined. 
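      # Note (added for clarity): the Grappler config built from ["function"]
      # requests only the function-inlining optimizer here; if the
      # optimization raises for any reason, the original graph_def is returned
      # unchanged by the except clause below.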
2019 optimized_graph = _run_graph_optimizations( 2020 graph, 2021 input_tensors, 2022 output_tensors, 2023 config=self._grappler_config(["function"])) 2024 return optimized_graph 2025 except Exception: # pylint: disable=broad-except 2026 return graph_def 2027 2028 def convert(self): 2029 """Converts a TensorFlow GraphDef based on instance variables. 2030 2031 Returns: 2032 The converted data in serialized format. Either a TFLite Flatbuffer or a 2033 Graphviz graph depending on value in `output_format`. 2034 2035 Raises: 2036 ValueError: 2037 Input shape is not specified. 2038 None value for dimension in input_tensor. 2039 """ 2040 self._validate_inputs(self._input_tensors, self.quantized_input_stats) 2041 2042 quant_mode = QuantizationMode( 2043 self.optimizations, self.target_spec, self.representative_dataset, 2044 self._graph_def, self._experimental_disable_per_channel, 2045 self.experimental_new_dynamic_range_quantizer, 2046 self._experimental_low_bit_qat, 2047 self._experimental_full_integer_quantization_bias_type) 2048 2049 optimized_graph = self._optimize_tf_model(self._graph_def, 2050 self._input_tensors, 2051 self._output_tensors, quant_mode) 2052 2053 self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph) 2054 2055 converter_kwargs = self._get_base_converter_args() 2056 converter_kwargs.update( 2057 quant_mode.converter_flags(self.inference_type, 2058 self.inference_input_type)) 2059 converter_kwargs.update({ 2060 "output_format": self.output_format, 2061 "quantized_input_stats": self._quantized_stats, 2062 "default_ranges_stats": self.default_ranges_stats, 2063 "drop_control_dependency": self.drop_control_dependency, 2064 "reorder_across_fake_quant": self.reorder_across_fake_quant, 2065 "change_concat_input_ranges": self.change_concat_input_ranges, 2066 "dump_graphviz_dir": self.dump_graphviz_dir, 2067 "dump_graphviz_video": self.dump_graphviz_video, 2068 "conversion_summary_dir": self.conversion_summary_dir, 2069 }) 2070 2071 self._validate_quantized_input_stats(converter_kwargs, quant_mode) 2072 if not self.experimental_new_converter: 2073 logging.warning( 2074 "Please consider switching to the new converter by setting " 2075 "experimental_new_converter=True. " 2076 "The old converter is deprecated.") 2077 else: 2078 logging.info("Using experimental converter: If you encountered a problem " 2079 "please file a bug. You can opt-out " 2080 "by setting experimental_new_converter=False") 2081 # Converts model. 2082 if self._has_valid_tensors(): 2083 result = _convert_graphdef( 2084 input_data=optimized_graph, 2085 input_tensors=self._input_tensors, 2086 output_tensors=self._output_tensors, 2087 **converter_kwargs) 2088 else: 2089 result = _convert_graphdef_with_arrays( 2090 input_data=optimized_graph, 2091 input_arrays_with_shape=self._input_arrays_with_shape, 2092 output_arrays=self._output_arrays, 2093 control_output_arrays=self._control_output_arrays, 2094 **converter_kwargs) 2095 2096 return self._optimize_tflite_model( 2097 result, quant_mode, quant_io=self.experimental_new_quantizer) 2098 2099 def get_input_arrays(self): 2100 """Returns a list of the names of the input tensors. 2101 2102 Returns: 2103 List of strings. 2104 """ 2105 if self._has_valid_tensors(): 2106 return [_get_tensor_name(tensor) for tensor in self._input_tensors] 2107 else: 2108 return [name for name, _ in self._input_arrays_with_shape] 2109 2110 def _has_valid_tensors(self): 2111 """Checks if the input and output tensors have been initialized. 2112 2113 Returns: 2114 Bool. 
2115 """ 2116 return self._input_tensors is not None and self._output_tensors 2117 2118 def _set_batch_size(self, batch_size): 2119 """Sets the first dimension of the input tensor to `batch_size`. 2120 2121 Args: 2122 batch_size: Batch size for the model. Replaces the first dimension of an 2123 input size array if undefined. (default 1) 2124 2125 Raises: 2126 ValueError: input_tensor is not defined. 2127 """ 2128 if not self._has_valid_tensors(): 2129 raise ValueError("The batch size cannot be set for this model. Please " 2130 "use input_shapes parameter.") 2131 2132 for tensor in self._input_tensors: 2133 shape = tensor.shape.as_list() 2134 if shape[0] is None: 2135 shape[0] = batch_size 2136 tensor.set_shape(shape) 2137 2138 def _is_unknown_shapes_allowed(self): 2139 # Ophint Converted nodes will need the shapes to be known. 2140 if _is_ophint_converted(self._graph_def): 2141 return False 2142 2143 if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed(): 2144 return False 2145 2146 # `conversion_summary_dir` calls the old converter. Unknown shapes are only 2147 # supported by the MLIR converter. 2148 if self.conversion_summary_dir: 2149 logging.warning( 2150 "`conversion_summary_dir` does not work with unknown shapes. " 2151 "Graphs with unknown shapes might be different than when this flag " 2152 "is disabled.") 2153 return False 2154 return True 2155 2156 def _save_conversion_params_metric(self): 2157 self._collected_converter_params.update({ 2158 "output_format": self.output_format, 2159 "default_ranges_stats": self.default_ranges_stats, 2160 "drop_control_dependency": self.drop_control_dependency, 2161 "reorder_across_fake_quant": self.reorder_across_fake_quant, 2162 "change_concat_input_ranges": self.change_concat_input_ranges, 2163 "dump_graphviz_dir": self.dump_graphviz_dir, 2164 "dump_graphviz_video": self.dump_graphviz_video, 2165 "conversion_summary_dir": self.conversion_summary_dir, 2166 }) 2167 super(TFLiteConverterBaseV1, 2168 self)._save_conversion_params_metric(self._graph_def, 2169 self.inference_type, 2170 self.inference_input_type) 2171 2172 2173class TFLiteSavedModelConverter(TFLiteConverterBaseV1): 2174 """Converts the given SavedModel into TensorFlow Lite model. 2175 2176 Attributes: 2177 saved_model_dir: Directory of the SavedModel. 2178 """ 2179 2180 def __init__(self, 2181 saved_model_dir, 2182 saved_model_tags, 2183 saved_model_exported_names, 2184 experimental_debug_info_func=None): 2185 """Constructor for TFLiteConverter. 2186 2187 Args: 2188 saved_model_dir: Directory of the SavedModel. 2189 saved_model_tags: Set of tags identifying the MetaGraphDef within the 2190 SavedModel to analyze. All tags in the tag set must be present. (default 2191 {tf.saved_model.SERVING}). 2192 saved_model_exported_names: Names to be exported when the saved model 2193 import path is on. 2194 experimental_debug_info_func: An experimental function to retrieve the 2195 graph debug info for a set of nodes from the `graph_def`. 2196 2197 Raises: 2198 ValueError: Invalid arguments. 
    """
    super(TFLiteSavedModelConverter,
          self).__init__(experimental_debug_info_func)
    self.saved_model_dir = saved_model_dir
    self._saved_model_tags = saved_model_tags
    self._saved_model_exported_names = saved_model_exported_names

    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    if len(self._saved_model_exported_names) != 1:
      raise ValueError("Only a single signature key is supported.")

    signature_key = self._saved_model_exported_names[0]

    result = _freeze_saved_model(self.saved_model_dir, None, None, None,
                                 self._saved_model_tags, signature_key)
    self._graph_def = result[0]
    self._input_tensors = result[1]
    self._output_tensors = result[2]
    self._parse_saved_model_args()

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    return super(TFLiteSavedModelConverter, self).convert()


class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
  """Converts the given Keras model file into a TensorFlow Lite model."""

  def __init__(self,
               model_file,
               input_arrays=None,
               input_shapes=None,
               output_arrays=None,
               custom_objects=None):
    """Constructor for TFLiteConverter.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteKerasModelConverter,
          self).__init__(experimental_debug_info_func=None)
    # Handles Keras when Eager mode is enabled.
    if context.executing_eagerly():
      if input_arrays or output_arrays:
        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
                         "with Eager mode. If your model requires any of these "
                         "parameters, please use disable_eager_execution().")

      keras_model = keras_deps.get_load_model_function()(model_file,
                                                         custom_objects)
      function = _trace_model_call(keras_model)
      concrete_func = function.get_concrete_function()

      frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
          concrete_func, lower_control_flow=False)
      _set_tensor_shapes(frozen_func.inputs, input_shapes)
      self._keras_model = keras_model
      self._graph_def = frozen_func.graph.as_graph_def()
      self._input_tensors = frozen_func.inputs
      self._output_tensors = frozen_func.outputs
      self._debug_info_func = _build_debug_info_func(frozen_func.graph)
      return

    # Handles Keras when Eager mode is disabled.
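    # (Clarifying note) In graph mode the Keras session is cleared and the
    # model is reloaded below; the input/output tensors are then resolved from
    # the backing session's graph and the graph is frozen via `_freeze_graph`.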
2288 keras_deps.get_clear_session_function()() 2289 keras_model = keras_deps.get_load_model_function()(model_file, 2290 custom_objects) 2291 sess = keras_deps.get_get_session_function()() 2292 2293 # Get input and output tensors. 2294 if input_arrays: 2295 input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) 2296 else: 2297 input_tensors = keras_model.inputs 2298 2299 if output_arrays: 2300 output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays) 2301 else: 2302 output_tensors = keras_model.outputs 2303 _set_tensor_shapes(input_tensors, input_shapes) 2304 2305 graph_def = _freeze_graph(sess, input_tensors, output_tensors) 2306 self._keras_model = keras_model 2307 self._graph_def = graph_def 2308 self._input_tensors = input_tensors 2309 self._output_tensors = output_tensors 2310 self._debug_info_func = _build_debug_info_func(sess.graph) 2311 2312 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL) 2313 def _freeze_keras_model(self, output_dir): 2314 """Save Keras model to Saved Model format. 2315 2316 Args: 2317 output_dir: The output directory to save the SavedModel. 2318 """ 2319 try: 2320 self._keras_model.save(output_dir, save_format="tf") 2321 except Exception: # pylint: disable=broad-except 2322 # When storing the given keras model to a saved model is failed, let's 2323 # use original keras model conversion pipeline. 2324 return None 2325 tag_set = set([_tag_constants.SERVING]) 2326 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 2327 graph_def, input_tensors, output_tensors, sess_graph = _freeze_saved_model( 2328 output_dir, None, None, None, tag_set, signature_key) 2329 2330 self.saved_model_dir = output_dir 2331 self._saved_model_tags = tag_set 2332 self._saved_model_exported_names = [signature_key] 2333 self._parse_saved_model_args() 2334 if self.saved_model_dir: 2335 self._graph_def = graph_def 2336 self._input_tensors = input_tensors 2337 self._output_tensors = output_tensors 2338 self._debug_info_func = _build_debug_info_func(sess_graph) 2339 2340 def _convert_as_saved_model(self): 2341 """Converts a Keras model as a saved model. 2342 2343 Returns: 2344 The converted data in serialized format. 2345 """ 2346 temp_dir = tempfile.mkdtemp() 2347 try: 2348 self._freeze_keras_model(temp_dir) 2349 if self.saved_model_dir: 2350 return super(TFLiteKerasModelConverter, self).convert() 2351 finally: 2352 shutil.rmtree(temp_dir, True) 2353 2354 @_export_metrics 2355 def convert(self): 2356 """Converts a Keras model based on instance variables. 2357 2358 Returns: 2359 The converted data in serialized format. Either a TFLite Flatbuffer or a 2360 Graphviz graph depending on value in `output_format`. 2361 2362 Raises: 2363 ValueError: 2364 Input shape is not specified. 2365 None value for dimension in input_tensor. 2366 """ 2367 saved_model_convert_result = self._convert_as_saved_model() 2368 if saved_model_convert_result: 2369 return saved_model_convert_result 2370 2371 return super(TFLiteKerasModelConverter, self).convert() 2372 2373 2374class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1): 2375 """Converts the given frozen graph def into TensorFlow Lite model.""" 2376 2377 def __init__(self, 2378 graph_def, 2379 input_tensors, 2380 output_tensors, 2381 input_arrays_with_shape=None, 2382 output_arrays=None, 2383 experimental_debug_info_func=None): 2384 """Constructor for TFLiteConverter. 2385 2386 Args: 2387 graph_def: Frozen TensorFlow GraphDef. 2388 input_tensors: List of input tensors. 
Type and shape are computed using 2389 `foo.shape` and `foo.dtype`. 2390 output_tensors: List of output tensors (only .name is used from this). 2391 input_arrays_with_shape: Tuple of strings representing input tensor names 2392 and list of integers representing input shapes 2393 (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded 2394 into TensorFlow and when `input_tensors` and `output_tensors` are 2395 None. (default None) 2396 output_arrays: List of output tensors to freeze graph with. Use only when 2397 graph cannot be loaded into TensorFlow and when `input_tensors` and 2398 `output_tensors` are None. (default None) 2399 experimental_debug_info_func: An experimental function to retrieve the 2400 graph debug info for a set of nodes from the `graph_def`. 2401 2402 Raises: 2403 ValueError: Invalid arguments. 2404 """ 2405 super(TFLiteFrozenGraphConverter, 2406 self).__init__(experimental_debug_info_func) 2407 self._graph_def = graph_def 2408 self._input_tensors = input_tensors 2409 self._output_tensors = output_tensors 2410 self._control_output_arrays = None 2411 2412 # Attributes are used by models that cannot be loaded into TensorFlow. 2413 if not self._has_valid_tensors(): 2414 self._input_arrays_with_shape = input_arrays_with_shape 2415 self._output_arrays = output_arrays 2416 2417 if input_tensors is not None and input_arrays_with_shape is not None: 2418 logging.warning("input_arrays_with_shape will be ignored when both the " 2419 "given input_tensors and input_arrays_with_shape are not " 2420 "None.") 2421 2422 if output_tensors is not None and output_arrays is not None: 2423 logging.warning("output_arrays will be ignored when both the given " 2424 "output_tensors and output_arrays are not None.") 2425 2426 @_export_metrics 2427 def convert(self): 2428 """Converts a TensorFlow GraphDef based on instance variables. 2429 2430 Returns: 2431 The converted data in serialized format. Either a TFLite Flatbuffer or a 2432 Graphviz graph depending on value in `output_format`. 2433 2434 Raises: 2435 ValueError: 2436 Input shape is not specified. 2437 None value for dimension in input_tensor. 2438 """ 2439 if not self._has_valid_tensors(): 2440 if not self._input_arrays_with_shape or not (self._output_arrays or 2441 self._control_output_arrays): 2442 raise ValueError( 2443 "If input_tensors and output_tensors are None, both " 2444 "input_arrays_with_shape and output_arrays|control_output_arrays " 2445 "must be defined.") 2446 return super(TFLiteFrozenGraphConverter, self).convert() 2447 2448 2449@_tf_export(v1=["lite.TFLiteConverter"]) 2450class TFLiteConverter(TFLiteFrozenGraphConverter): 2451 """Convert a TensorFlow model into `output_format`. 2452 2453 This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras 2454 model into either a TFLite FlatBuffer or graph visualization. 2455 2456 Attributes: 2457 optimizations: Experimental flag, subject to change. Set of optimizations to 2458 apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a 2459 set of values of type `tf.lite.Optimize`) 2460 representative_dataset: A generator function used for integer quantization 2461 where each generated sample has the same order, type and shape as the 2462 inputs to the model. Usually, this is a small subset of a few hundred 2463 samples randomly chosen, in no particular order, from the training or 2464 evaluation dataset. 
This is an optional attribute, but required for full 2465 integer quantization, i.e, if `tf.int8` is the only supported type in 2466 `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`. 2467 (default None) 2468 target_spec: Experimental flag, subject to change. Specifications of target 2469 device, including supported ops set, supported types and a set of user's 2470 defined TensorFlow operators required in the TensorFlow Lite runtime. 2471 Refer to `tf.lite.TargetSpec`. 2472 inference_type: Data type of numeric arrays, excluding the input layer. 2473 (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8}) 2474 inference_input_type: Data type of the numeric arrays in the input layer. If 2475 `inference_input_type` is in {tf.int8, tf.uint8}, then 2476 `quantized_input_stats` must be provided. (default is the value assigned 2477 to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8}) 2478 inference_output_type: Data type of the numeric arrays in the output layer. 2479 (default is the value assigned to `inference_type`, must be in 2480 {tf.float32, tf.int8, tf.uint8}) 2481 quantized_input_stats: Map of input tensor names to a tuple of floats 2482 representing the mean and standard deviation of the training data. 2483 (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8 2484 or tf.uint8. (default None) 2485 default_ranges_stats: Tuple of integers (min, max) representing range values 2486 for all numeric arrays without a specified range. Intended for 2487 experimenting with quantization via "dummy quantization". (default None) 2488 allow_custom_ops: Boolean indicating whether to allow custom operations. 2489 When False any unknown operation is an error. When True, custom ops are 2490 created for any op that is unknown. The developer will need to provide 2491 these to the TensorFlow Lite runtime with a custom resolver. (default 2492 False) 2493 drop_control_dependency: Boolean indicating whether to drop control 2494 dependencies silently. This is due to TFLite not supporting control 2495 dependencies. (default True) 2496 reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant 2497 nodes in unexpected locations. Used when the location of the FakeQuant 2498 nodes is preventing graph transformations necessary to convert the graph. 2499 Results in a graph that differs from the quantized training graph, 2500 potentially causing differing arithmetic behavior. (default False) 2501 change_concat_input_ranges: Boolean to change behavior of min/max ranges for 2502 inputs and outputs of the concat operator for quantized models. Changes 2503 the ranges of concat operator overlap when true. (default False) 2504 output_format: Output file format. (default 2505 tf.compat.v1.lite.constants.TFLITE, must be in 2506 {tf.compat.v1.lite.constants.TFLITE, 2507 tf.compat.v1.lite.constants.GRAPHVIZ_DOT}) 2508 dump_graphviz_dir: Full filepath of folder to dump the graphs at various 2509 stages of processing GraphViz .dot files. Preferred over 2510 `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep 2511 the requirements of the output file. (default None) 2512 dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot 2513 files after every graph transformation. Requires the `dump_graphviz_dir` 2514 flag to be specified. (default False) 2515 conversion_summary_dir: Full path of the directory to store conversion logs. 
2516 (default None) 2517 exclude_conversion_metadata: Whether not to embed the conversion metadata 2518 into the converted model. (default False) 2519 target_ops: Deprecated. Please use `target_spec.supported_ops` instead. 2520 post_training_quantize: Deprecated. Please use `optimizations` instead and 2521 set it to `{tf.lite.Optimize.DEFAULT}`. (default False) 2522 experimental_new_converter: Experimental flag, subject to change. Enables 2523 MLIR-based conversion. (default True) 2524 experimental_new_quantizer: Experimental flag, subject to change. Enables 2525 MLIR-based quantization conversion instead of Flatbuffer-based conversion. 2526 (default True) 2527 2528 Example usage: 2529 2530 ```python 2531 # Converting a GraphDef from session. 2532 converter = tf.compat.v1.lite.TFLiteConverter.from_session( 2533 sess, in_tensors, out_tensors) 2534 tflite_model = converter.convert() 2535 open("converted_model.tflite", "wb").write(tflite_model) 2536 2537 # Converting a GraphDef from file. 2538 converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( 2539 graph_def_file, input_arrays, output_arrays) 2540 tflite_model = converter.convert() 2541 open("converted_model.tflite", "wb").write(tflite_model) 2542 2543 # Converting a SavedModel. 2544 converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model( 2545 saved_model_dir) 2546 tflite_model = converter.convert() 2547 open("converted_model.tflite", "wb").write(tflite_model) 2548 2549 # Converting a tf.keras model. 2550 converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file( 2551 keras_model) 2552 tflite_model = converter.convert() 2553 open("converted_model.tflite", "wb").write(tflite_model) 2554 ``` 2555 """ 2556 2557 # pylint: disable=useless-super-delegation 2558 def __init__(self, 2559 graph_def, 2560 input_tensors, 2561 output_tensors, 2562 input_arrays_with_shape=None, 2563 output_arrays=None, 2564 experimental_debug_info_func=None): 2565 """Constructor for TFLiteConverter. 2566 2567 Args: 2568 graph_def: Frozen TensorFlow GraphDef. 2569 input_tensors: List of input tensors. Type and shape are computed using 2570 `foo.shape` and `foo.dtype`. 2571 output_tensors: List of output tensors (only .name is used from this). 2572 input_arrays_with_shape: Tuple of strings representing input tensor names 2573 and list of integers representing input shapes 2574 (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded 2575 into TensorFlow and when `input_tensors` and `output_tensors` are 2576 None. (default None) 2577 output_arrays: List of output tensors to freeze graph with. Use only when 2578 graph cannot be loaded into TensorFlow and when `input_tensors` and 2579 `output_tensors` are None. (default None) 2580 experimental_debug_info_func: An experimental function to retrieve the 2581 graph debug info for a set of nodes from the `graph_def`. 2582 2583 Raises: 2584 ValueError: Invalid arguments. 2585 """ 2586 super(TFLiteConverter, 2587 self).__init__(graph_def, input_tensors, output_tensors, 2588 input_arrays_with_shape, output_arrays, 2589 experimental_debug_info_func) 2590 2591 @classmethod 2592 def from_session(cls, sess, input_tensors, output_tensors): 2593 """Creates a TFLiteConverter class from a TensorFlow Session. 2594 2595 Args: 2596 sess: TensorFlow Session. 2597 input_tensors: List of input tensors. Type and shape are computed using 2598 `foo.shape` and `foo.dtype`. 2599 output_tensors: List of output tensors (only .name is used from this). 2600 2601 Returns: 2602 TFLiteConverter class. 
2603 """ 2604 # pylint: disable=protected-access 2605 TFLiteConverterBase._set_original_model_type( 2606 conversion_metdata_fb.ModelType.TF_SESSION) 2607 # pylint: enable=protected-access 2608 graph_def = _freeze_graph(sess, input_tensors, output_tensors) 2609 return cls( 2610 graph_def, 2611 input_tensors, 2612 output_tensors, 2613 experimental_debug_info_func=_build_debug_info_func(sess.graph)) 2614 2615 @classmethod 2616 def from_frozen_graph(cls, 2617 graph_def_file, 2618 input_arrays, 2619 output_arrays, 2620 input_shapes=None): 2621 """Creates a TFLiteConverter class from a file containing a frozen GraphDef. 2622 2623 Args: 2624 graph_def_file: Full filepath of file containing frozen GraphDef. 2625 input_arrays: List of input tensors to freeze graph with. 2626 output_arrays: List of output tensors to freeze graph with. 2627 input_shapes: Dict of strings representing input tensor names to list of 2628 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 2629 Automatically determined when input shapes is None (e.g., {"foo" : 2630 None}). (default None) 2631 2632 Returns: 2633 TFLiteConverter class. 2634 2635 Raises: 2636 IOError: 2637 File not found. 2638 Unable to parse input file. 2639 ValueError: 2640 The graph is not frozen. 2641 input_arrays or output_arrays contains an invalid tensor name. 2642 input_shapes is not correctly defined when required 2643 """ 2644 # pylint: disable=protected-access 2645 TFLiteConverterBase._set_original_model_type( 2646 conversion_metdata_fb.ModelType.TF_GRAPH_DEF) 2647 # pylint: enable=protected-access 2648 with _ops.Graph().as_default(): 2649 with _session.Session() as sess: 2650 # Read GraphDef from file. 2651 if not gfile.Exists(graph_def_file): 2652 raise IOError("File '{0}' does not exist.".format(graph_def_file)) 2653 with gfile.GFile(graph_def_file, "rb") as f: 2654 file_content = f.read() 2655 2656 try: 2657 graph_def = _graph_pb2.GraphDef() 2658 graph_def.ParseFromString(file_content) 2659 except (_text_format.ParseError, DecodeError): 2660 try: 2661 print("Ignore 'tcmalloc: large alloc' warnings.") 2662 2663 if not isinstance(file_content, str): 2664 file_content = file_content.decode("utf-8") 2665 graph_def = _graph_pb2.GraphDef() 2666 _text_format.Merge(file_content, graph_def) 2667 except (_text_format.ParseError, DecodeError): 2668 raise IOError( 2669 "Unable to parse input file '{}'.".format(graph_def_file)) 2670 2671 # Handles models with custom TFLite ops that cannot be resolved in 2672 # TensorFlow. 2673 load_model_in_session = True 2674 try: 2675 _import_graph_def(graph_def, name="") 2676 except _NotFoundError: 2677 load_model_in_session = False 2678 2679 if load_model_in_session: 2680 # Check if graph is frozen. 2681 if not _is_frozen_graph(sess): 2682 raise ValueError("Please freeze the graph using freeze_graph.py.") 2683 2684 # Get input and output tensors. 
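          # (Clarifying note) Tensor names from `input_arrays`/`output_arrays`
          # are resolved against the session graph, and `input_shapes`, when
          # provided, is applied to the resolved inputs via
          # `_set_tensor_shapes`.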
2685 input_tensors = _get_tensors_from_tensor_names( 2686 sess.graph, input_arrays) 2687 output_tensors = _get_tensors_from_tensor_names( 2688 sess.graph, output_arrays) 2689 _set_tensor_shapes(input_tensors, input_shapes) 2690 2691 return cls(sess.graph_def, input_tensors, output_tensors) 2692 else: 2693 if not input_shapes: 2694 raise ValueError("input_shapes must be defined for this model.") 2695 if set(input_arrays) != set(input_shapes.keys()): 2696 raise ValueError("input_shapes must contain a value for each item " 2697 "in input_array.") 2698 2699 input_arrays_with_shape = [ 2700 (name, input_shapes[name]) for name in input_arrays 2701 ] 2702 return cls( 2703 graph_def, 2704 input_tensors=None, 2705 output_tensors=None, 2706 input_arrays_with_shape=input_arrays_with_shape, 2707 output_arrays=output_arrays) 2708 2709 @classmethod 2710 def from_saved_model(cls, 2711 saved_model_dir, 2712 input_arrays=None, 2713 input_shapes=None, 2714 output_arrays=None, 2715 tag_set=None, 2716 signature_key=None): 2717 """Creates a TFLiteConverter class from a SavedModel. 2718 2719 Args: 2720 saved_model_dir: SavedModel directory to convert. 2721 input_arrays: List of input tensors to freeze graph with. Uses input 2722 arrays from SignatureDef when none are provided. (default None) 2723 input_shapes: Dict of strings representing input tensor names to list of 2724 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 2725 Automatically determined when input shapes is None (e.g., {"foo" : 2726 None}). (default None) 2727 output_arrays: List of output tensors to freeze graph with. Uses output 2728 arrays from SignatureDef when none are provided. (default None) 2729 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 2730 analyze. All tags in the tag set must be present. (default 2731 {tf.saved_model.SERVING}) 2732 signature_key: Key identifying SignatureDef containing inputs and outputs. 2733 (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 2734 2735 Returns: 2736 TFLiteConverter class. 2737 """ 2738 # pylint: disable=protected-access 2739 TFLiteConverterBase._set_original_model_type( 2740 conversion_metdata_fb.ModelType.TF_SAVED_MODEL) 2741 # pylint: enable=protected-access 2742 if tag_set is None: 2743 tag_set = set([_tag_constants.SERVING]) 2744 if signature_key is None: 2745 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 2746 2747 saved_model_converter = TFLiteSavedModelConverter(saved_model_dir, tag_set, 2748 [signature_key]) 2749 if saved_model_converter.saved_model_dir: 2750 return saved_model_converter 2751 2752 result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, 2753 output_arrays, tag_set, signature_key) 2754 2755 return cls( 2756 graph_def=result[0], 2757 input_tensors=result[1], 2758 output_tensors=result[2], 2759 experimental_debug_info_func=_build_debug_info_func(result[3])) 2760 2761 @classmethod 2762 def from_keras_model_file(cls, 2763 model_file, 2764 input_arrays=None, 2765 input_shapes=None, 2766 output_arrays=None, 2767 custom_objects=None): 2768 """Creates a TFLiteConverter class from a tf.keras model file. 2769 2770 Args: 2771 model_file: Full filepath of HDF5 file containing the tf.keras model. 2772 input_arrays: List of input tensors to freeze graph with. Uses input 2773 arrays from SignatureDef when none are provided. (default None) 2774 input_shapes: Dict of strings representing input tensor names to list of 2775 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 
2776 Automatically determined when input shapes is None (e.g., {"foo" : 2777 None}). (default None) 2778 output_arrays: List of output tensors to freeze graph with. Uses output 2779 arrays from SignatureDef when none are provided. (default None) 2780 custom_objects: Dict mapping names (strings) to custom classes or 2781 functions to be considered during model deserialization. (default None) 2782 2783 Returns: 2784 TFLiteConverter class. 2785 """ 2786 # pylint: disable=protected-access 2787 TFLiteConverterBase._set_original_model_type( 2788 conversion_metdata_fb.ModelType.KERAS_MODEL) 2789 # pylint: enable=protected-access 2790 return TFLiteKerasModelConverter(model_file, input_arrays, input_shapes, 2791 output_arrays, custom_objects) 2792 2793 # pylint: disable=useless-super-delegation 2794 def convert(self): 2795 """Converts a TensorFlow GraphDef based on instance variables. 2796 2797 Returns: 2798 The converted data in serialized format. Either a TFLite Flatbuffer or a 2799 Graphviz graph depending on value in `output_format`. 2800 2801 Raises: 2802 ValueError: 2803 Input shape is not specified. 2804 None value for dimension in input_tensor. 2805 """ 2806 return super(TFLiteConverter, self).convert() 2807 2808 2809@_tf_export(v1=["lite.TocoConverter"]) 2810class TocoConverter: 2811 """Convert a TensorFlow model into `output_format`. 2812 2813 This class has been deprecated. Please use `lite.TFLiteConverter` instead. 2814 """ 2815 2816 @classmethod 2817 @_deprecation.deprecated(None, 2818 "Use `lite.TFLiteConverter.from_session` instead.") 2819 def from_session(cls, sess, input_tensors, output_tensors): 2820 """Creates a TocoConverter class from a TensorFlow Session.""" 2821 return TFLiteConverter.from_session(sess, input_tensors, output_tensors) 2822 2823 @classmethod 2824 @_deprecation.deprecated( 2825 None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.") 2826 def from_frozen_graph(cls, 2827 graph_def_file, 2828 input_arrays, 2829 output_arrays, 2830 input_shapes=None): 2831 """Creates a TocoConverter class from a file containing a frozen graph.""" 2832 return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, 2833 output_arrays, input_shapes) 2834 2835 @classmethod 2836 @_deprecation.deprecated( 2837 None, "Use `lite.TFLiteConverter.from_saved_model` instead.") 2838 def from_saved_model(cls, 2839 saved_model_dir, 2840 input_arrays=None, 2841 input_shapes=None, 2842 output_arrays=None, 2843 tag_set=None, 2844 signature_key=None): 2845 """Creates a TocoConverter class from a SavedModel.""" 2846 return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays, 2847 input_shapes, output_arrays, 2848 tag_set, signature_key) 2849 2850 @classmethod 2851 @_deprecation.deprecated( 2852 None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.") 2853 def from_keras_model_file(cls, 2854 model_file, 2855 input_arrays=None, 2856 input_shapes=None, 2857 output_arrays=None): 2858 """Creates a TocoConverter class from a tf.keras model file.""" 2859 return TFLiteConverter.from_keras_model_file(model_file, input_arrays, 2860 input_shapes, output_arrays) 2861
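# ==============================================================================
# Illustrative end-to-end usage (a hedged sketch, comments only; the paths,
# shapes and the dataset generator below are hypothetical placeholders, and
# only the public APIs defined or re-exported above are used):
#
#   import numpy as np
#   import tensorflow as tf
#
#   converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
#   converter.optimizations = [tf.lite.Optimize.DEFAULT]
#
#   def representative_dataset():
#     for _ in range(100):
#       # Each sample must match the model inputs' order, type and shape.
#       yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]
#
#   converter.representative_dataset = representative_dataset
#   converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
#   converter.inference_input_type = tf.int8
#   converter.inference_output_type = tf.int8
#   tflite_model = converter.convert()
#
#   with open("/tmp/model.tflite", "wb") as f:
#     f.write(tflite_model)
#
#   interpreter = tf.lite.Interpreter(model_content=tflite_model)
#   interpreter.allocate_tensors()
# ==============================================================================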