# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite tooling helper functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import functools
import pprint
import shutil
import tempfile
import time
import warnings

from absl import logging
import six
from six import PY2

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op  # pylint: disable=unused-import
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import build_toco_convert_protos  # pylint: disable=unused-import
from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model
from tensorflow.lite.python.convert import ConverterError  # pylint: disable=unused-import
from tensorflow.lite.python.convert import deduplicate_readonly_buffers as _deduplicate_readonly_buffers
from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize
from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify
from tensorflow.lite.python.convert import OpsSet
from tensorflow.lite.python.convert import toco_convert  # pylint: disable=unused-import
from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.lite.python.convert import toco_convert_protos  # pylint: disable=unused-import
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.interpreter import Interpreter  # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import load_delegate  # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import OpResolverType  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
from tensorflow.lite.python.op_hint import OpHint  # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func
from tensorflow.lite.python.util import freeze_graph as _freeze_graph
from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.lite.python.util import trace_model_call as _trace_model_call
from tensorflow.python import saved_model as _saved_model
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function
from tensorflow.python.eager import function as _function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader_impl as _loader_impl
from tensorflow.python.saved_model import save_options as _save_options
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.saved_model.load import load as _load
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import tf_export as _tf_export

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.lite.python import metrics_portable as metrics
except ImportError:
  from tensorflow.lite.python import metrics_nonportable as metrics
# pylint: enable=g-import-not-at-top


@_tf_export("lite.Optimize")
class Optimize(enum.Enum):
  """Enum defining the optimizations to apply when generating a tflite model.

  DEFAULT
    Default optimization strategy that quantizes model weights. Enhanced
    optimizations are gained by providing a representative dataset that
    quantizes biases and activations as well.
    The converter will do its best to reduce size and latency, while
    minimizing the loss in accuracy.

  OPTIMIZE_FOR_SIZE
    Deprecated. Does the same as DEFAULT.

  OPTIMIZE_FOR_LATENCY
    Deprecated. Does the same as DEFAULT.

  EXPERIMENTAL_SPARSITY
    Experimental flag, subject to change.

    Enable optimization by taking advantage of the sparse model weights
    trained with pruning.

    The converter will inspect the sparsity pattern of the model weights and
    do its best to improve size and latency.
    The flag can be used alone to optimize float32 models with sparse weights.
    It can also be used together with the DEFAULT optimization mode to
    optimize quantized models with sparse weights.
  """

  # Default optimization strategy that quantizes model weights. Enhanced
  # optimizations are gained by providing a representative dataset that
  # quantizes biases and activations as well.
  # The converter will do its best to reduce size and latency, while
  # minimizing the loss in accuracy.
  DEFAULT = "DEFAULT"

  # Deprecated. Does the same as DEFAULT.
  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"

  # Deprecated. Does the same as DEFAULT.
  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"

  # Experimental flag, subject to change.
  # Enable optimization by taking advantage of the sparse model weights trained
  # with pruning.
  #
  # The converter will inspect the sparsity pattern of the model weights and do
  # its best to improve size and latency.
  # The flag can be used alone to optimize float32 models with sparse weights.
  # It can also be used together with the DEFAULT optimization mode to optimize
  # quantized models with sparse weights.
  # TODO(b/161560631): Add log message when this optimization is applied.
  EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY"

  def __str__(self):
    return str(self.value)


@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset(object):
  """Representative dataset used to optimize the model.

  This is a generator function that provides a small dataset to calibrate or
  estimate the range, i.e., (min, max) of all floating-point arrays in the
  model (such as model input, activation outputs of intermediate layers, and
  model output) for quantization. Usually, this is a small subset of a few
  hundred samples randomly chosen, in no particular order, from the training
  or evaluation dataset.
  """

  def __init__(self, input_gen):
    """Creates a representative dataset.

    Args:
      input_gen: A generator function that generates input samples for the
        model and has the same order, type and shape as the inputs to the
        model. Usually, this is a small subset of a few hundred samples
        randomly chosen, in no particular order, from the training or
        evaluation dataset.
    """
    self.input_gen = input_gen
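

# A minimal usage sketch for `RepresentativeDataset`; `calibration_samples`
# below is a hypothetical sequence of float32 numpy arrays shaped like the
# model input, standing in for real calibration data:
#
#   def _gen():
#     for sample in calibration_samples[:200]:
#       yield [sample]
#
#   dataset = RepresentativeDataset(_gen)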


@_tf_export("lite.TargetSpec")
class TargetSpec(object):
  """Specification of target device used to optimize the model.

  Attributes:
    supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
      options, where each option represents a set of operators supported by the
      target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS})
    supported_types: Set of `tf.dtypes.DType` data types supported on the
      target device. If initialized, optimization might be driven by the
      smallest type in this set. (default set())
    experimental_select_user_tf_ops: Experimental flag, subject to change. Set
      of names of user's TensorFlow operators that are required in the
      TensorFlow Lite runtime. These ops will be exported as select TensorFlow
      ops in the model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS
      flag). This is an advanced feature that should only be used if the client
      is using TF ops that may not be linked in by default with the TF ops that
      are provided when using the SELECT_TF_OPS path. The client is responsible
      for linking these ops into the target runtime.
  """

  def __init__(self,
               supported_ops=None,
               supported_types=None,
               experimental_select_user_tf_ops=None):
    if supported_ops is None:
      supported_ops = {OpsSet.TFLITE_BUILTINS}
    self.supported_ops = supported_ops
    if supported_types is None:
      supported_types = set()
    self.supported_types = supported_types
    if experimental_select_user_tf_ops is None:
      experimental_select_user_tf_ops = set()
    self.experimental_select_user_tf_ops = experimental_select_user_tf_ops
    self._experimental_custom_op_registerers = []
    # Hint for the supported accumulation type used for inference. Typically
    # used for fp16 post-training quantization, where some models can use fp16
    # accumulators instead of the typical fp32 type.
    # TODO(b/188185962): Provide full API and authoring support for
    # reduced precision accumulation types.
    self._experimental_supported_accumulation_type = None


class QuantizationMode(object):
  """QuantizationMode determines the quantization type from user options."""

  def __init__(self, optimizations, target_spec, representative_dataset,
               graph_def):
    self._optimizations = optimizations
    for deprecated_optimization in [
        Optimize.OPTIMIZE_FOR_SIZE, Optimize.OPTIMIZE_FOR_LATENCY
    ]:
      if deprecated_optimization in self._optimizations:
        logging.warning(
            "Optimization option %s is deprecated, please use optimizations="
            "[Optimize.DEFAULT] instead.", deprecated_optimization)

    self._target_spec = target_spec
    self._representative_dataset = representative_dataset
    self._graph_def = graph_def

    self._validate_int8_required()

  # TODO(b/162537905): Refactor the following quantization functions -
  # re-organize and refactor for better readability.
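  # The predicates below classify the user-provided options into a single
  # quantization mode; converter_flags() further below maps that mode to
  # concrete converter arguments.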
  def post_training_int8_no_float(self):
    return (self.any_optimization_enabled() and
            self._is_int8_target_required() and
            not self._is_int16x8_target_required() and
            not self.is_allow_float() and
            self._representative_dataset is not None)

  def post_training_int8_allow_float(self):
    return (self.any_optimization_enabled() and
            not self._is_int16x8_target_required() and
            self._representative_dataset is not None and
            self._smallest_supported_type() == _dtypes.int8)

  def is_post_training_integer_quantize_8(self):
    return (self.post_training_int8_no_float() or
            self.post_training_int8_allow_float())

  def is_post_training_integer_quantize_16x8(self):
    return (self.post_training_int16x8_no_float() or
            self.post_training_int16x8_allow_float())

  def is_post_training_integer_quantize(self):
    return (self.is_post_training_integer_quantize_8() or
            self.is_post_training_integer_quantize_16x8())

  def is_integer_quantize(self):
    return (self.is_post_training_integer_quantize() or
            self.is_training_time_int8_allow_float())

  def is_training_time_int8_allow_float(self):
    return (self.any_optimization_enabled() and
            self.contains_training_quant_op())

  def is_bfloat16_inference_allowed(self):
    return (self.any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.bfloat16 in self._target_spec.supported_types)

  def post_training_int16x8_no_float(self):
    return (self.any_optimization_enabled() and
            not self._is_int8_target_required() and
            self._is_int16x8_target_required() and
            not self.is_allow_float() and
            self._representative_dataset is not None)

  def post_training_int16x8_allow_float(self):
    return (self.any_optimization_enabled() and
            self._is_int16x8_target_required() and
            self.is_allow_float())

  def post_training_dynamic_range_int8(self):
    # Post-training dynamic range quantization is only enabled if neither
    # post-training int8 quantization nor training-time quantization was done.
    return (self.any_optimization_enabled() and
            self._representative_dataset is None and
            not self.contains_training_quant_op() and
            self._smallest_supported_type() == _dtypes.int8)

  def post_training_fp16(self):
    return (self.any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.float16 in self._target_spec.supported_types)

  def fp32_execution(self):
    """If none of the above are true."""
    return not (self.is_integer_quantize() or
                self.post_training_dynamic_range_int8() or
                self.post_training_fp16())

  def activations_type(self):
    if self.is_integer_quantize():
      if self._is_int16x8_target_required():
        return _dtypes.int16
      else:
        return _dtypes.int8
    else:
      return _dtypes.float32

  def converter_flags(self, inference_ty=None, inference_input_ty=None):
    """Flags to the converter."""

    if self.is_integer_quantize():
      return {
          "inference_type": (
              inference_ty if inference_ty else self.activations_type()),
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization
      }
    elif self.post_training_dynamic_range_int8():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,  # enable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization
      }
    elif self.post_training_fp16():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,
          "quantize_to_float16": True,  # enable float16 quantization
          "accumulation_type":
              self._target_spec._experimental_supported_accumulation_type,
          "allow_bfloat16":
              self.is_bfloat16_inference_allowed()
      }
    else:
      # Note this might still trigger (uint8) quantization to be compatible
      # with TOCO.
      return {
          "inference_type": inference_ty if inference_ty else _dtypes.float32,
          "inference_input_type": inference_input_ty,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          "allow_bfloat16": self.is_bfloat16_inference_allowed()
      }
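
  # For example, with optimizations={Optimize.DEFAULT}, no representative
  # dataset, and the default TargetSpec, the predicates above select
  # dynamic-range quantization, so converter_flags() returns:
  #   {"inference_type": tf.float32, "inference_input_type": tf.float32,
  #    "post_training_quantize": True, "quantize_to_float16": False}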

  # Below are helpers for the above functions.

  def _validate_int8_required(self):
    """Int8 mode requires certain parameters to exist and be compatible."""
    if not self._is_int8_target_required():
      return

    if self._target_spec.supported_types and (self._smallest_supported_type() !=
                                              _dtypes.int8):
      raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
                       "type to be INT8.")

    if self._representative_dataset:
      if not isinstance(self._representative_dataset, RepresentativeDataset):
        self._representative_dataset = RepresentativeDataset(
            self._representative_dataset)
      if self._representative_dataset.input_gen is None:
        raise ValueError(
            "Provide an input generator for representative_dataset")
    else:
      # TODO(b/162537905): Relax this check for QAT.
      raise ValueError("representative_dataset is required when specifying "
                       "TFLITE_BUILTINS_INT8 or INT8 supported types.")

  def _is_int8_target_required(self):
    return (OpsSet.TFLITE_BUILTINS_INT8 in set(
        self._target_spec.supported_ops)) or (set(
            self._target_spec.supported_types) == set([_dtypes.int8]))

  def _is_int16x8_target_required(self):
    return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
            in set(self._target_spec.supported_ops))

  def is_allow_float(self):
    return (OpsSet.TFLITE_BUILTINS in set(
        self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set(
            self._target_spec.supported_ops))

  def any_optimization_enabled(self):
    return bool(
        set(self._optimizations).intersection([
            Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
            Optimize.DEFAULT
        ]))

  def _smallest_supported_type(self):
    if self._target_spec.supported_types:
      return min(self._target_spec.supported_types, key=lambda x: x.size)
    else:
      # The default smallest supported type is INT8.
      return _dtypes.int8

  def contains_training_quant_op(self):
    """Checks if the graph contains any training-time quantization ops."""
    training_quant_ops = frozenset({
        "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel",
        "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsPerChannel",
        "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"
    })

    if self._graph_def:
      for node_def in self._graph_def.node:
        if node_def.op in training_quant_ops:
          return True
      for function in self._graph_def.library.function:
        for node_def in function.node_def:
          if node_def.op in training_quant_ops:
            return True
    return False
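

# A minimal sketch of how QuantizationMode classifies user options; the
# `representative_gen` generator and `frozen_graph_def` GraphDef below are
# hypothetical stand-ins:
#
#   mode = QuantizationMode(
#       optimizations=[Optimize.DEFAULT],
#       target_spec=TargetSpec(supported_ops={OpsSet.TFLITE_BUILTINS_INT8}),
#       representative_dataset=representative_gen,
#       graph_def=frozen_graph_def)
#   assert mode.post_training_int8_no_float()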


class TFLiteConverterBase(object):
  """Converter subclass to share functionality between V1 and V2 converters."""

  def __init__(self):
    self.optimizations = set()
    self.representative_dataset = None
    self.target_spec = TargetSpec()
    self.allow_custom_ops = False
    self.experimental_new_converter = True
    self.experimental_new_quantizer = True
    self.experimental_enable_resource_variables = False
    self._experimental_new_quantizer = None
    self._experimental_calibrate_only = False
    self._experimental_sparsify_model = False
    self._experimental_disable_per_channel = False
    # Maps the original nodes in the `GraphDef` given to the converter to
    # their stack traces.
    self._debug_info = None
    self.saved_model_dir = None
    self._saved_model_tags = None
    self._saved_model_version = 0
    self._saved_model_exported_names = []
    self._tflite_metrics = metrics.TFLiteConverterMetrics()
    self._collected_converter_params = {}
    self._experimental_disable_batchmatmul_unfold = False
    self._experimental_lower_tensor_list_ops = True
    self._experimental_unfold_large_splat_constant = False

  def _grappler_config(self, optimizers=None):
    """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

    Args:
      optimizers: List of strings that represents the list of optimizers.

    Returns:
      tf.ConfigProto.
    """
    if not optimizers:
      optimizers = []
    # The MLIR converter takes care of constant folding instead of Grappler.
    if not self.experimental_new_converter:
      optimizers.append("constfold")

    is_only_flex_enabled = (
        set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
    if is_only_flex_enabled:
      # The layout optimizer turns NHWC to NCHW. This provides performance
      # optimizations when Flex mode is enabled. However, this is not
      # compatible with builtin ops.
      optimizers.append("layout")
    return _get_grappler_config(optimizers)

  def _quantize(self, result, input_type, output_type, activations_type,
                allow_float):
    """Quantize the model."""
    # pylint: disable=protected-access
    custom_op_registerers_by_name = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if isinstance(x, str)
    ]
    custom_op_registerers_by_func = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if not isinstance(x, str)
    ]
    # pylint: enable=protected-access
    if not isinstance(self.representative_dataset, RepresentativeDataset):
      self.representative_dataset = RepresentativeDataset(
          self.representative_dataset)

    # Add intermediate tensors to the model if needed.
    result = _calibrator.add_intermediate_tensors(result)
    calibrate_quantize = _calibrator.Calibrator(result,
                                                custom_op_registerers_by_name,
                                                custom_op_registerers_by_func)
    if self._experimental_calibrate_only or self.experimental_new_quantizer:
      calibrated = calibrate_quantize.calibrate(
          self.representative_dataset.input_gen)

    if self._experimental_calibrate_only:
      return calibrated
    elif self.experimental_new_quantizer and (
        activations_type != _dtypes.int16):
      # TODO(b/175659372): remove the activations_type restriction and enable
      # it for all the activation types.
      return _mlir_quantize(
          calibrated,
          self._experimental_disable_per_channel,
          input_data_type=input_type,
          output_data_type=output_type)
    else:
      return calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen, input_type, output_type,
          allow_float, activations_type,
          disable_per_channel=self._experimental_disable_per_channel)

  def _is_unknown_shapes_allowed(self):
    # Unknown dimensions are only allowed with the new converter.
    return self.experimental_new_converter

  def _get_base_converter_args(self):
    """Returns the base converter args.

    Returns:
      {key str: val}
    """
    args = {
        "input_format":
            constants.TENSORFLOW_GRAPHDEF,
        "allow_custom_ops":
            self.allow_custom_ops,
        "debug_info":
            self._debug_info,
        "target_ops":
            self.target_spec.supported_ops,
        "enable_mlir_converter":
            self.experimental_new_converter,
        "select_user_tf_ops":
            self.target_spec.experimental_select_user_tf_ops,
        "unfold_batchmatmul":
            not self._experimental_disable_batchmatmul_unfold,
        "lower_tensor_list_ops":
            self._experimental_lower_tensor_list_ops,
        "unfold_large_splat_constant":
            self._experimental_unfold_large_splat_constant,
    }

    if self.saved_model_dir:
      args.update({
          "saved_model_dir": self.saved_model_dir,
          "saved_model_version": self._saved_model_version,
          "saved_model_tags": self._saved_model_tags,
          "saved_model_exported_names": self._saved_model_exported_names,
      })

    return args

  def _contains_function_with_implements_attr(self, saved_model_proto):
    meta_graph = saved_model_proto.meta_graphs[0]
    for function in meta_graph.graph_def.library.function:
      if function.attr.get("_implements", None) or function.attr.get(
          "api_implements", None):
        return True
    return False

  def _parse_saved_model_args(self, always_enable_saved_model_import=False):
    """Parses SavedModel arguments from the given Keras/RNN SavedModel.

    Args:
      always_enable_saved_model_import: Bool. When True, it enables the MLIR
        SavedModel import path unconditionally.
    """
    if not self.experimental_new_converter:
      self.saved_model_dir = None
      return
    if self.saved_model_dir:
      try:
        saved_model_proto, _ = (
            _parse_saved_model_with_debug_info(self.saved_model_dir))
      except OSError:
        # If it fails to read the given saved model, it will fall back to the
        # frozen graph def path.
        self.saved_model_dir = None
        return
      if (not always_enable_saved_model_import and
          not self._contains_function_with_implements_attr(saved_model_proto)):
        self.saved_model_dir = None
        return

      if not self._saved_model_exported_names:
        self._saved_model_exported_names = []
      self._saved_model_version = saved_model_proto.saved_model_schema_version
      if self._saved_model_version == 0:
        self.saved_model_dir = None
        logging.warning("SavedModel schema version is zero.")
        return
      if self._saved_model_version not in [1, 2]:
        raise ValueError(
            "SavedModel file format ({0}) is not supported".format(
                self._saved_model_version))

  def _sparsify_model(self):
    return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations

  def _validate_experimental_new_quantizer_flag(self):
    if self._experimental_new_quantizer is not None:
      raise ValueError("Please use 'experimental_new_quantizer' instead.")

  def _increase_conversion_attempt_metric(self):
    self._tflite_metrics.increase_counter_converter_attempt()

  def _increase_conversion_success_metric(self):
    self._tflite_metrics.increase_counter_converter_success()

  def _save_conversion_params_metric(self,
                                     graph_def=None,
                                     inference_type=None,
                                     inference_input_type=None):
    """Set conversion parameter metrics."""
    converter_kwargs = self._collected_converter_params
    converter_kwargs.update(self._get_base_converter_args())

    # Optimization parameters.
    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, graph_def)
    converter_kwargs.update({
        "optimization_default":
            quant_mode.any_optimization_enabled(),
        "optimization_post_training_dynamic_range":
            quant_mode.post_training_dynamic_range_int8(),
        "optimization_post_training_float16":
            quant_mode.post_training_fp16(),
        "optimization_post_training_integer_quantize":
            quant_mode.is_post_training_integer_quantize(),
        "optimization_qat":
            quant_mode.is_training_time_int8_allow_float(),
        "optimization_sparsify":
            self._sparsify_model(),
        "activations_type":
            quant_mode.activations_type()
    })
    converter_kwargs.update(
        quant_mode.converter_flags(inference_type, inference_input_type))

    # pylint: disable=protected-access
    if self.target_spec._experimental_supported_accumulation_type:
      converter_kwargs.update({
          "accumulation_type":
              self.target_spec._experimental_supported_accumulation_type
      })
    # pylint: enable=protected-access

    def format_element(elem):
      if isinstance(elem, enum.Enum):
        return str(elem.value)
      return pprint.pformat(elem)

    def format_param(param):
      if isinstance(param, (list, tuple, set)):
        if not param:
          return "None"  # Return the string "None" if the container is empty.
        string_list = [format_element(x) for x in param]
        return ",".join(sorted(string_list))
      return format_element(param)

    for key, value in converter_kwargs.items():
      self._tflite_metrics.set_converter_param(key, format_param(value))
    self._tflite_metrics.set_export_required()

  def _set_conversion_latency_metric(self, value):
    self._tflite_metrics.set_converter_latency(value)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL)
  def _optimize_tflite_model(self, model, quant_mode, quant_io=True):
    """Apply optimizations on a TFLite model."""

    if quant_mode.is_integer_quantize():

      in_type, out_type = self.inference_input_type, self.inference_output_type

      if quant_mode.is_post_training_integer_quantize():
        q_in_type = in_type if in_type and quant_io else _dtypes.float32
        q_out_type = out_type if out_type and quant_io else _dtypes.float32
        q_activations_type = quant_mode.activations_type()
        q_allow_float = quant_mode.is_allow_float()
        model = self._quantize(
            model, q_in_type, q_out_type, q_activations_type, q_allow_float)

      m_in_type = in_type if in_type else _dtypes.float32
      m_out_type = out_type if out_type else _dtypes.float32
      model = _modify_model_io_type(model, m_in_type, m_out_type)

    if self._sparsify_model():
      model = _mlir_sparsify(model)

    try:
      model = _deduplicate_readonly_buffers(model)
    except Exception:  # pylint: disable=broad-except
      # Skip buffer deduplication when the flatbuffer library is not ready to
      # be utilized.
      logging.warning(
          "Buffer deduplication procedure will be skipped when flatbuffer "
          "library is not properly loaded")

    return model

  def _convert_and_export_metrics(self, convert_func, *args, **kwargs):
    """Wraps around convert function to export metrics.

    Args:
      convert_func: The convert function to wrap.
      *args: Positional arguments of the convert function.
      **kwargs: The keyword arguments of the convert function.

    Returns:
      The result of the wrapped convert function.
734 """ 735 self._increase_conversion_attempt_metric() 736 self._save_conversion_params_metric() 737 start_time = time.process_time() 738 result = convert_func(self, *args, **kwargs) 739 elapsed_time_ms = (time.process_time() - start_time) * 1000 740 if result: 741 self._increase_conversion_success_metric() 742 self._set_conversion_latency_metric(round(elapsed_time_ms)) 743 self._tflite_metrics.export_metrics() 744 return result 745 746 747def _export_metrics(convert_func): 748 """The decorator around convert function to export metrics.""" 749 @functools.wraps(convert_func) 750 def wrapper(self, *args, **kwargs): 751 # pylint: disable=protected-access 752 return self._convert_and_export_metrics(convert_func, *args, **kwargs) 753 # pylint: enable=protected-access 754 755 return wrapper 756 757 758class TFLiteConverterBaseV2(TFLiteConverterBase): 759 """Converter subclass to share functionality between V2 converters.""" 760 761 def __init__(self): 762 """Constructor for TFLiteConverter.""" 763 super(TFLiteConverterBaseV2, self).__init__() 764 self.inference_input_type = _dtypes.float32 765 self.inference_output_type = _dtypes.float32 766 self._collected_converter_params.update({"api_version": 2}) 767 768 def _validate_inference_input_output_types(self, quant_mode): 769 """Validate inference_input_type and inference_output_type flags.""" 770 default_types = [_dtypes.float32] 771 # We support integer input/output for integer quantized models only. 772 if quant_mode.is_integer_quantize(): 773 if quant_mode.is_post_training_integer_quantize_16x8(): 774 all_types = default_types + [_dtypes.int16] 775 else: 776 all_types = default_types + [_dtypes.int8, _dtypes.uint8] 777 if (self.inference_input_type not in all_types or 778 self.inference_output_type not in all_types): 779 all_types_names = ["tf." + t.name for t in all_types] 780 raise ValueError("The inference_input_type and inference_output_type " 781 "must be in {}.".format(all_types_names)) 782 elif (self.inference_input_type not in default_types or 783 self.inference_output_type not in default_types): 784 raise ValueError("The inference_input_type and inference_output_type " 785 "must be tf.float32.") 786 787 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.LOAD_SAVED_MODEL) 788 def _load_saved_model(self, saved_model_dir, saved_model_tags): 789 """Load graph_def from saved model with the default serving signature key. 790 791 Args: 792 saved_model_dir: Directory of the SavedModel. 793 saved_model_tags: Set of tags identifying the MetaGraphDef within the 794 SavedModel to analyze. 795 796 Returns: 797 graph_def: The loaded GraphDef. 798 input_tensors: List of input tensors. 799 output_tensors: List of output tensors. 
800 """ 801 graph = _ops.Graph() 802 saved_model = _loader_impl.SavedModelLoader(saved_model_dir) 803 saved_model.load_graph(graph, tags=saved_model_tags) 804 meta_graph = saved_model.get_meta_graph_def_from_tags(saved_model_tags) 805 graph_def = meta_graph.graph_def 806 signature_def = meta_graph.signature_def[ 807 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 808 input_tensors = [ 809 graph.get_tensor_by_name(signature_def.inputs[key].name) 810 for key in signature_def.inputs 811 ] 812 output_tensors = [ 813 graph.get_tensor_by_name(signature_def.outputs[key].name) 814 for key in signature_def.outputs 815 ] 816 return graph_def, input_tensors, output_tensors 817 818 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS) 819 def _validate_inputs(self, graph_def, input_tensors): 820 """Validate the input parameters. 821 822 Args: 823 graph_def: The TensorFlow GraphDef. 824 input_tensors: List of input tensors. 825 Raise: 826 ValueError: 827 Input shape is not specified. 828 Invalid quantization parameters. 829 """ 830 # Update conversion params with graph_def. 831 self._save_conversion_params_metric(graph_def) 832 self._quant_mode = QuantizationMode(self.optimizations, self.target_spec, 833 self.representative_dataset, graph_def) 834 self._validate_inference_input_output_types(self._quant_mode) 835 self._validate_experimental_new_quantizer_flag() 836 837 if not self._is_unknown_shapes_allowed(): 838 # Checks dimensions in input tensor. 839 for tensor in input_tensors: 840 # Note that shape_list might be empty for scalar shapes. 841 shape_list = tensor.shape.as_list() 842 if None in shape_list[1:]: 843 raise ValueError( 844 "None is only supported in the 1st dimension. Tensor '{0}' has " 845 "invalid shape '{1}'.".format( 846 _get_tensor_name(tensor), shape_list)) 847 elif shape_list and shape_list[0] is None: 848 # Set the batch size to 1 if undefined. 849 shape = tensor.shape.as_list() 850 shape[0] = 1 851 tensor.set_shape(shape) 852 853 if (self._trackable_obj is None or 854 not hasattr(self._trackable_obj, "graph_debug_info")): 855 self._debug_info = _get_debug_info( 856 _build_debug_info_func(self._funcs[0].graph), graph_def) 857 else: 858 self._debug_info = _get_debug_info( 859 _convert_debug_info_func(self._trackable_obj.graph_debug_info), 860 graph_def) 861 862 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL) 863 def _optimize_tf_model(self, graph_def, input_tensors, output_tensors, 864 frozen_func): 865 """Run a Grappler pass to optimize the TensorFlow graph. 866 867 Args: 868 graph_def: Frozen GraphDef to be optimized. 869 input_tensors: List of input tensors. 870 output_tensors: List of output tensors. 871 frozen_func: TensorFlow Graph. 872 873 Returns: 874 The optimized TensorFlow graph. 875 """ 876 grappler_config = self._grappler_config() 877 # Skip running grappler when there are no optimizers to run. If not, 878 # grappler will run with the default optimizer set and it will lead to 879 # causing an unexpected behavior. 880 if grappler_config.graph_options.rewrite_options.optimizers: 881 graph_def = _run_graph_optimizations( 882 graph_def, 883 input_tensors, 884 output_tensors, 885 config=grappler_config, 886 graph=frozen_func.graph) 887 return graph_def 888 889 def convert(self, graph_def, input_tensors, output_tensors): 890 """Converts a TensorFlow GraphDef based on instance variables. 891 892 Args: 893 graph_def: Frozen TensorFlow GraphDef. 894 input_tensors: List of input tensors. 

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
  def _validate_inputs(self, graph_def, input_tensors):
    """Validate the input parameters.

    Args:
      graph_def: The TensorFlow GraphDef.
      input_tensors: List of input tensors.

    Raises:
      ValueError:
        Input shape is not specified.
        Invalid quantization parameters.
    """
    # Update conversion params with graph_def.
    self._save_conversion_params_metric(graph_def)
    self._quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                        self.representative_dataset, graph_def)
    self._validate_inference_input_output_types(self._quant_mode)
    self._validate_experimental_new_quantizer_flag()

    if not self._is_unknown_shapes_allowed():
      # Checks dimensions in input tensor.
      for tensor in input_tensors:
        # Note that shape_list might be empty for scalar shapes.
        shape_list = tensor.shape.as_list()
        if None in shape_list[1:]:
          raise ValueError(
              "None is only supported in the 1st dimension. Tensor '{0}' has "
              "invalid shape '{1}'.".format(
                  _get_tensor_name(tensor), shape_list))
        elif shape_list and shape_list[0] is None:
          # Set the batch size to 1 if undefined.
          shape = tensor.shape.as_list()
          shape[0] = 1
          tensor.set_shape(shape)

    if (self._trackable_obj is None or
        not hasattr(self._trackable_obj, "graph_debug_info")):
      self._debug_info = _get_debug_info(
          _build_debug_info_func(self._funcs[0].graph), graph_def)
    else:
      self._debug_info = _get_debug_info(
          _convert_debug_info_func(self._trackable_obj.graph_debug_info),
          graph_def)

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
  def _optimize_tf_model(self, graph_def, input_tensors, output_tensors,
                         frozen_func):
    """Run a Grappler pass to optimize the TensorFlow graph.

    Args:
      graph_def: Frozen GraphDef to be optimized.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      frozen_func: TensorFlow Graph.

    Returns:
      The optimized TensorFlow graph.
    """
    grappler_config = self._grappler_config()
    # Skip running Grappler when there are no optimizers to run. Otherwise,
    # Grappler would run with the default optimizer set, which can lead to
    # unexpected behavior.
    if grappler_config.graph_options.rewrite_options.optimizers:
      graph_def = _run_graph_optimizations(
          graph_def,
          input_tensors,
          output_tensors,
          config=grappler_config,
          graph=frozen_func.graph)
    return graph_def

  def convert(self, graph_def, input_tensors, output_tensors):
    """Converts a TensorFlow GraphDef based on instance variables.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    self._validate_inputs(graph_def, input_tensors)
    converter_kwargs = self._get_base_converter_args()
    converter_kwargs.update(self._quant_mode.converter_flags())
    if not self.experimental_new_converter:
      logging.warning(
          "Please consider switching to the new converter by setting "
          "experimental_new_converter=True. "
          "The old converter (TOCO) is deprecated.")
    else:
      logging.info("Using new converter: If you encounter a problem "
                   "please file a bug. You can opt-out "
                   "by setting experimental_new_converter=False")

    # Converts model.
    result = _toco_convert_impl(
        input_data=graph_def,
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        **converter_kwargs)

    return self._optimize_tflite_model(
        result, self._quant_mode, quant_io=self.experimental_new_quantizer)


class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
  """Converts the given SavedModel into TensorFlow Lite model.

  Attributes:
    saved_model_dir: Directory of the SavedModel.
  """

  def __init__(self,
               saved_model_dir,
               saved_model_tags=None,
               saved_model_exported_names=None,
               trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      saved_model_dir: Directory of the SavedModel.
      saved_model_tags: Set of tags identifying the MetaGraphDef within the
        SavedModel to analyze. All tags in the tag set must be present.
        (default {tf.saved_model.SERVING}).
      saved_model_exported_names: Names to be exported when the SavedModel
        import path is enabled.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do
        not get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is
        not maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteSavedModelConverterV2, self).__init__()
    self.saved_model_dir = saved_model_dir
    self._saved_model_tags = saved_model_tags
    self._saved_model_exported_names = saved_model_exported_names
    self._trackable_obj = trackable_obj
    self._parse_saved_model_args(always_enable_saved_model_import=True)

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    graph_def, input_tensors, output_tensors = self._load_saved_model(
        self.saved_model_dir, self._saved_model_tags)
    # If the SavedModel importer can't be used, fall back to the frozen graph
    # conversion path.
    if self.saved_model_dir is None or not self.experimental_new_converter:
      graph_def, _, _, _ = _freeze_saved_model(
          self.saved_model_dir, None, None, None, self._saved_model_tags,
          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
      # We make sure to clear the saved_model_dir as there is some
      # legacy code down in the caller that checks this.
      # TODO(b/162537905): Clean these indirect dependencies.
      self.saved_model_dir = None
      return super(TFLiteSavedModelConverterV2,
                   self).convert(graph_def, input_tensors, output_tensors)

    if self._trackable_obj is None:
      self._debug_info = _get_debug_info(
          _build_debug_info_func(self._funcs[0].graph), graph_def)
    else:
      self._debug_info = _get_debug_info(
          _convert_debug_info_func(self._trackable_obj.graph_debug_info),
          graph_def)

    # Update conversion params with graph_def.
    self._save_conversion_params_metric(graph_def)
    # Get quantization options and do some sanity checks.
    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, graph_def)
    self._validate_inference_input_output_types(quant_mode)

    converter_kwargs = {
        "enable_tflite_resource_variables":
            self.experimental_enable_resource_variables
    }
    converter_kwargs.update(self._get_base_converter_args())
    converter_kwargs.update(quant_mode.converter_flags())

    result = _convert_saved_model(**converter_kwargs)

    return self._optimize_tflite_model(
        result, quant_mode, quant_io=self.experimental_new_quantizer)
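

# TFLiteKerasModelConverterV2 below first tries to round-trip the Keras model
# through a temporary SavedModel (see _convert_keras_to_saved_model) and only
# falls back to freezing the Keras model directly when that fails. A minimal
# sketch, with `keras_model` as a hypothetical compiled tf.keras.Model:
#
#   converter = TFLiteKerasModelConverterV2(keras_model)
#   tflite_model = converter.convert()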


class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
  """Converts the given Keras model into TensorFlow Lite model."""

  def __init__(self, keras_model, trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      keras_model: tf.Keras.Model.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do
        not get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is
        not maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteKerasModelConverterV2, self).__init__()
    self._keras_model = keras_model
    self._trackable_obj = trackable_obj
    self.experimental_lower_to_saved_model = True

  @convert_phase(Component.PREPARE_TF_MODEL,
                 SubComponent.CONVERT_KERAS_TO_SAVED_MODEL)
  def _convert_keras_to_saved_model(self, output_dir):
    """Save Keras model to the SavedModel format.

    Args:
      output_dir: The output directory to save the SavedModel.

    Returns:
      graph_def: The frozen GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
    """
    try:
      _saved_model.save(
          self._keras_model,
          output_dir,
          options=_save_options.SaveOptions(save_debug_info=True))
    except Exception:  # pylint: disable=broad-except
      # If saving the given Keras model to a SavedModel fails, fall back to
      # the original Keras model conversion pipeline.
      return None, None, None
    self.saved_model_dir = output_dir
    self._saved_model_tags = set([_tag_constants.SERVING])
    self._saved_model_exported_names = [
        _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    ]
    self._parse_saved_model_args(
        always_enable_saved_model_import=self.experimental_lower_to_saved_model)
    if self.saved_model_dir:
      graph_def, input_tensors, output_tensors = self._load_saved_model(
          self.saved_model_dir, self._saved_model_tags)
      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
      return graph_def, input_tensors, output_tensors
    return None, None, None

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
  def _freeze_keras_model(self):
    """Freeze Keras model to frozen graph.

    Returns:
      graph_def: The frozen GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      frozen_func: The frozen ConcreteFunction.
    """
    input_signature = None
    # If the model's call is not a `tf.function`, then we need to first get its
    # input signature from `model_input_signature` method. We can't directly
    # call `trace_model_call` because otherwise the batch dimension is set
    # to None.
    # Once we have better support for dynamic shapes, we can remove this.
    if not isinstance(self._keras_model.call, _def_function.Function):
      # Passing `keep_original_batch_size=True` ensures that we get an input
      # signature including the batch dimension specified by the user.
      # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
      input_signature = _model_input_signature(
          self._keras_model, keep_original_batch_size=True)

    # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
    func = _trace_model_call(self._keras_model, input_signature)
    concrete_func = func.get_concrete_function()
    self._funcs = [concrete_func]

    frozen_func, graph_def = (
        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
            self._funcs[0], lower_control_flow=False))

    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs
    return graph_def, input_tensors, output_tensors, frozen_func

  def _convert_as_saved_model(self):
    """Converts a Keras model through the SavedModel format.

    Returns:
      The converted data in serialized format.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      graph_def, input_tensors, output_tensors = (
          self._convert_keras_to_saved_model(temp_dir))
      if self.saved_model_dir:
        return super(TFLiteKerasModelConverterV2,
                     self).convert(graph_def, input_tensors, output_tensors)
    finally:
      shutil.rmtree(temp_dir, True)

  @_export_metrics
  def convert(self):
    """Converts a Keras model based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
1144 """ 1145 saved_model_convert_result = self._convert_as_saved_model() 1146 if saved_model_convert_result: 1147 return saved_model_convert_result 1148 1149 graph_def, input_tensors, output_tensors, frozen_func = ( 1150 self._freeze_keras_model()) 1151 1152 graph_def = self._optimize_tf_model(graph_def, input_tensors, 1153 output_tensors, frozen_func) 1154 1155 return super(TFLiteKerasModelConverterV2, 1156 self).convert(graph_def, input_tensors, output_tensors) 1157 1158 1159class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2): 1160 """Converts the given frozen graph into TensorFlow Lite model.""" 1161 1162 def __init__(self, funcs, trackable_obj=None): 1163 """Constructor for TFLiteConverter. 1164 1165 Args: 1166 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 1167 duplicate elements. 1168 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 1169 reference to this object needs to be maintained so that Variables do not 1170 get garbage collected since functions have a weak reference to 1171 Variables. This is only required when the tf.AutoTrackable object is not 1172 maintained by the user (e.g. `from_saved_model`). 1173 """ 1174 super(TFLiteFrozenGraphConverterV2, self).__init__() 1175 self._funcs = funcs 1176 self._trackable_obj = trackable_obj 1177 self.experimental_lower_to_saved_model = True 1178 1179 @convert_phase(Component.PREPARE_TF_MODEL, 1180 SubComponent.FREEZE_CONCRETE_FUNCTION) 1181 def _freeze_concrete_function(self): 1182 """Convert the given ConcreteFunction to frozen graph. 1183 1184 Returns: 1185 graph_def: The frozen GraphDef. 1186 input_tensors: List of input tensors. 1187 output_tensors: List of output tensors. 1188 frozen_func: The frozen ConcreteFunction. 1189 1190 Raises: 1191 ValueError: none or multiple ConcreteFunctions provided. 1192 """ 1193 # TODO(b/130297984): Add support for converting multiple function. 1194 1195 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test 1196 raise ValueError("No ConcreteFunction is specified.") 1197 1198 if len(self._funcs) > 1: 1199 raise ValueError("This converter can only convert a single " 1200 "ConcreteFunction. Converting multiple functions is " 1201 "under development.") 1202 1203 frozen_func, graph_def = ( 1204 _convert_to_constants.convert_variables_to_constants_v2_as_graph( 1205 self._funcs[0], lower_control_flow=False)) 1206 1207 input_tensors = [ 1208 tensor for tensor in frozen_func.inputs 1209 if tensor.dtype != _dtypes.resource 1210 ] 1211 output_tensors = frozen_func.outputs 1212 return graph_def, input_tensors, output_tensors, frozen_func 1213 1214 @convert_phase(Component.PREPARE_TF_MODEL, 1215 SubComponent.CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL) 1216 def _convert_concrete_functions_to_saved_model(self, output_dir): 1217 """Save concrete functions to the SavedModel format. 1218 1219 Args: 1220 output_dir: The output directory to save the SavedModel. 1221 1222 Returns: 1223 graph_def: The frozen GraphDef. 1224 input_tensors: List of input tensors. 1225 output_tensors: List of output tensors. 1226 """ 1227 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test 1228 raise ValueError("No ConcreteFunction is specified.") 1229 1230 if not self.experimental_lower_to_saved_model: 1231 return None, None, None 1232 1233 # Without the provided trackable obj, it is not able to serialize the given 1234 # concrete functions as a saved model format. 
    if not self._trackable_obj:
      return None, None, None

    signatures = {}
    signature_keys = []
    try:
      if len(self._funcs) == 1:
        signatures[_signature_constants
                   .DEFAULT_SERVING_SIGNATURE_DEF_KEY] = self._funcs[0]
        signature_keys = [
            _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        ]
      else:
        for func in self._funcs:
          signatures[func.graph.name] = func
          signature_keys.append(func.graph.name)

      _saved_model.save(
          self._trackable_obj,
          output_dir,
          signatures=signatures,
          options=_save_options.SaveOptions(save_debug_info=True))
    except Exception:  # pylint: disable=broad-except
      # If saving the given concrete functions to a SavedModel fails, fall
      # back to the original concrete function conversion pipeline.
      return None, None, None

    self.saved_model_dir = output_dir
    self._saved_model_tags = set([_tag_constants.SERVING])
    self._saved_model_exported_names = signature_keys
    self._parse_saved_model_args(always_enable_saved_model_import=True)
    if self.saved_model_dir:
      graph_def, input_tensors, output_tensors = self._load_saved_model(
          self.saved_model_dir, self._saved_model_tags)
      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
      return graph_def, input_tensors, output_tensors
    return None, None, None

  def _convert_as_saved_model(self):
    """Converts the given concrete functions through the SavedModel format.

    Returns:
      The converted data in serialized format.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      graph_def, input_tensors, output_tensors = (
          self._convert_concrete_functions_to_saved_model(temp_dir))
      if self.saved_model_dir:
        return super(TFLiteFrozenGraphConverterV2,
                     self).convert(graph_def, input_tensors, output_tensors)
    finally:
      shutil.rmtree(temp_dir, True)
    return None

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    if self.experimental_lower_to_saved_model:
      saved_model_convert_result = self._convert_as_saved_model()
      if saved_model_convert_result:
        return saved_model_convert_result

    graph_def, input_tensors, output_tensors, frozen_func = (
        self._freeze_concrete_function())

    graph_def = self._optimize_tf_model(graph_def, input_tensors,
                                        output_tensors, frozen_func)

    return super(TFLiteFrozenGraphConverterV2,
                 self).convert(graph_def, input_tensors, output_tensors)
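

# `TFLiteConverterV2` below is the public `tf.lite.TFLiteConverter` entry
# point. A minimal post-training full-integer quantization sketch; the
# `representative_gen` generator is a hypothetical stand-in for a real
# calibration dataset:
#
#   converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
#   converter.optimizations = {tf.lite.Optimize.DEFAULT}
#   converter.representative_dataset = representative_gen
#   converter.target_spec.supported_ops = {
#       tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
#   tflite_model = converter.convert()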

@_tf_export("lite.TFLiteConverter", v1=[])
class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
  """Converts a TensorFlow model into TensorFlow Lite model.

  Attributes:
    optimizations: Experimental flag, subject to change. Set of optimizations
      to apply. e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None
      or a set of values of type `tf.lite.Optimize`)
    representative_dataset: A generator function used for integer quantization
      where each generated sample has the same order, type and shape as the
      inputs to the model. Usually, this is a small subset of a few hundred
      samples randomly chosen, in no particular order, from the training or
      evaluation dataset. This is an optional attribute, but required for full
      integer quantization, i.e., if `tf.int8` is the only supported type in
      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
      (default None)
    target_spec: Experimental flag, subject to change. Specifications of target
      device, including supported ops set, supported types and a set of user's
      defined TensorFlow operators required in the TensorFlow Lite runtime.
      Refer to `tf.lite.TargetSpec`.
    inference_input_type: Data type of the input layer. Note that integer types
      (tf.int8 and tf.uint8) are currently only supported for post training
      integer quantization and quantization aware training. (default
      tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_output_type: Data type of the output layer. Note that integer
      types (tf.int8 and tf.uint8) are currently only supported for post
      training integer quantization and quantization aware training. (default
      tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When False, any unknown operation is an error. When True, custom ops are
      created for any op that is unknown. The developer needs to provide these
      to the TensorFlow Lite runtime with a custom resolver. (default False)
    experimental_new_converter: Experimental flag, subject to change. Enables
      MLIR-based conversion instead of TOCO conversion. (default True)
    experimental_new_quantizer: Experimental flag, subject to change. Enables
      MLIR-based quantization conversion instead of Flatbuffer-based
      conversion. (default True)
    experimental_enable_resource_variables: Experimental flag, subject to
      change. Enables resource variables to be converted by this converter.
      This is only allowed if the from_saved_model interface is used.
      (default False)

  Example usage:

  ```python
  # Converting a SavedModel to a TensorFlow Lite model.
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  tflite_model = converter.convert()

  # Converting a tf.Keras model to a TensorFlow Lite model.
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  tflite_model = converter.convert()

  # Converting ConcreteFunctions to a TensorFlow Lite model.
  converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
  tflite_model = converter.convert()
  ```
  """

  # pylint: disable=useless-super-delegation
  def __init__(self, funcs, trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
        duplicate elements.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do
        not get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is
        not maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteConverterV2, self).__init__(funcs, trackable_obj)
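
  # A minimal sketch of obtaining a ConcreteFunction for the classmethod
  # below; `model` is a hypothetical tf.Module with a @tf.function method
  # `add`:
  #
  #   concrete_func = model.add.get_concrete_function(
  #       tf.TensorSpec(shape=[1, 4], dtype=tf.float32))
  #   converter = tf.lite.TFLiteConverter.from_concrete_functions(
  #       [concrete_func], model)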

  @classmethod
  def from_concrete_functions(cls, funcs, trackable_obj=None):
    """Creates a TFLiteConverter object from ConcreteFunctions.

    Args:
      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
        duplicate elements. Currently converter can only convert a single
        ConcreteFunction. Converting multiple functions is under development.
      trackable_obj: An `AutoTrackable` object (typically `tf.module`)
        associated with `funcs`. A reference to this object needs to be
        maintained so that Variables do not get garbage collected since
        functions have a weak reference to Variables.

    Returns:
      TFLiteConverter object.

    Raises:
      ValueError: Invalid input type.
    """
    if trackable_obj is None:
      logging.warning(
          "Please consider providing the trackable_obj argument in the "
          "from_concrete_functions. Providing without the trackable_obj "
          "argument is deprecated and it will use the deprecated conversion "
          "path.")
    for func in funcs:
      if not isinstance(func, _function.ConcreteFunction):
        message = "This function takes in a list of ConcreteFunctions."
        if isinstance(func, _def_function.Function):
          message += (" To get the ConcreteFunction from a Function,"
                      " call get_concrete_function.")
        raise ValueError(message)
    return cls(funcs, trackable_obj)

  @classmethod
  def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
    """Creates a TFLiteConverter object from a SavedModel directory.

    Args:
      saved_model_dir: SavedModel directory to convert.
      signature_keys: List of keys identifying SignatureDef containing inputs
        and outputs. Elements should not be duplicated. By default the
        `signatures` attribute of the MetaGraphDef is used. (default
        saved_model.signatures)
      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING} or {'serve'})

    Returns:
      TFLiteConverter object.

    Raises:
      ValueError: Invalid signature keys.
    """
    # When run without eager enabled, this will return the legacy
    # TFLiteConverter.
    if not context.executing_eagerly():
      signature_key = None
      if signature_keys:
        if len(signature_keys) != 1:
          raise ValueError("Only a single signature key is supported.")
        else:
          signature_key = signature_keys[0]
      logging.warning("Invoking the TF1 implementation of TFLiteConverter "
                      "because eager is disabled. Consider enabling eager.")
      return TFLiteConverter.from_saved_model(
          saved_model_dir, signature_key=signature_key, tag_set=tags)

    # Ensures any graphs created in Eager mode are able to run. This is
    # required in order to create a tf.estimator.Exporter that exports a
    # TFLite model.
    if tags is None:
      tags = set([_tag_constants.SERVING])

    with context.eager_mode():
      saved_model = _load(saved_model_dir, tags)
    if not signature_keys:
      signature_keys = saved_model.signatures

    if not signature_keys:
      raise ValueError("At least one signature key is required.")

    funcs = []
    for key in signature_keys:
      if key not in saved_model.signatures:
        raise ValueError("Invalid signature key '{}' found. Valid keys are "
                         "'{}'.".format(key, ",".join(saved_model.signatures)))
      funcs.append(saved_model.signatures[key])

    saved_model_converter = TFLiteSavedModelConverterV2(saved_model_dir, tags,
                                                        signature_keys,
                                                        saved_model)
    if saved_model_converter.saved_model_dir:
      return saved_model_converter

    return cls(funcs, saved_model)

  @classmethod
  def from_keras_model(cls, model):
    """Creates a TFLiteConverter object from a Keras model.

    Args:
      model: tf.keras.Model instance.

    Returns:
      TFLiteConverter object.
    """
    return TFLiteKerasModelConverterV2(model)

  # pylint: disable=useless-super-delegation
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    return super(TFLiteConverterV2, self).convert()


class TFLiteConverterBaseV1(TFLiteConverterBase):
  """Converter subclass to share functionality between V1 converters."""

  def __init__(self, experimental_debug_info_func):
    """Constructor for TFLiteConverter.

    Args:
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.
    """
    super(TFLiteConverterBaseV1, self).__init__()
    self.inference_type = _dtypes.float32
    self.inference_input_type = None
    self.inference_output_type = None
    self.output_format = constants.TFLITE
    self.quantized_input_stats = {}
    self.default_ranges_stats = None
    self.drop_control_dependency = True
    self.reorder_across_fake_quant = False
    self.change_concat_input_ranges = False
    self.dump_graphviz_dir = None
    self.dump_graphviz_video = False
    self.conversion_summary_dir = None
    self._debug_info_func = experimental_debug_info_func
    self._experimental_allow_all_select_tf_ops = False

  def __setattr__(self, name, value):
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.DEFAULT]"
                    " instead." % name)
      if value:
        self.optimizations = [Optimize.DEFAULT]
      else:
        self.optimizations = []
      return
    if name == "target_ops":
      warnings.warn("Property %s is deprecated, please use "
                    "target_spec.supported_ops instead." % name)
      self.target_spec.supported_ops = value
      return
    object.__setattr__(self, name, value)

  def __getattribute__(self, name):
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.DEFAULT]"
                    " instead." % name)
      return Optimize.DEFAULT in set(self.optimizations)
    if name == "target_ops":
      warnings.warn("Property %s is deprecated, please use "
                    "target_spec.supported_ops instead." % name)
      return self.target_spec.supported_ops
    return object.__getattribute__(self, name)

  def _validate_quantized_input_stats(self, converter_kwargs, quant_mode):
    """Ensures the `quantized_input_stats` flag is provided if required."""

    quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})

    requires_quantized_input_stats = (
        (converter_kwargs["inference_type"] in quantized_types or
         converter_kwargs["inference_input_type"] in quantized_types) and
        not quant_mode.is_post_training_integer_quantize())

    if (requires_quantized_input_stats and
        not converter_kwargs["quantized_input_stats"]):
      raise ValueError(
          "The `quantized_input_stats` flag must be defined when either "
          "`inference_type` flag or `inference_input_type` flag is set to "
          "tf.int8 or tf.uint8. Currently, `inference_type={}` and "
          "`inference_input_type={}`.".format(
              _get_tf_type_name(converter_kwargs["inference_type"]),
              _get_tf_type_name(converter_kwargs["inference_input_type"])))

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
  def _validate_inputs(self, input_tensors, quantized_input_stats):
    """Validates input parameters.

    Args:
      input_tensors: List of input tensors.
      quantized_input_stats: Map of input tensor names to a tuple of floats
        representing the mean and standard deviation of the training data.

    Raises:
      ValueError:
        Input shape is not specified.
        Quantization input stats are required but not provided.
    """

    if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()):
      # Checks dimensions in input tensor.
      for tensor in input_tensors:
        shape = tensor.shape
        if not shape:
          raise ValueError("Provide an input shape for input array "
                           "'{0}'.".format(_get_tensor_name(tensor)))
        # Note that shape_list might be empty for scalar shapes.
        shape_list = shape.as_list()
        if None in shape_list[1:]:
          raise ValueError(
              "None is only supported in the 1st dimension. Tensor '{0}' has "
              "invalid shape '{1}'.".format(
                  _get_tensor_name(tensor), shape_list))
        elif shape_list and shape_list[0] is None:
          self._set_batch_size(batch_size=1)

    # Get quantization stats. Ensures there is one stat per name if the stats
    # are specified.
    if quantized_input_stats:
      self._quantized_stats = []
      invalid_stats = []
      for name in self.get_input_arrays():
        if name in quantized_input_stats:
          self._quantized_stats.append(quantized_input_stats[name])
        else:
          invalid_stats.append(name)

      if invalid_stats:
        raise ValueError(
            "Quantization input stats are not available for input "
            "tensors '{0}'.".format(",".join(invalid_stats)))
    else:
      self._quantized_stats = None

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
  def _optimize_tf_model(self, graph_def, input_tensors, output_tensors,
                         quant_mode):
    """Runs a Grappler pass to optimize the TensorFlow graph.

    Args:
      graph_def: Frozen GraphDef to be optimized.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      quant_mode: The quantization mode.

    Returns:
      The optimized TensorFlow graph.
    """
    # Disable grappler constant folding if there are training quant ops.
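    # (Note: returning early skips the entire Grappler pass, not just constant
    # folding; folding could otherwise collapse the FakeQuant ops, and their
    # recorded min/max ranges, that quantization-aware training inserted.)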
    if self.saved_model_dir or quant_mode.contains_training_quant_op():
      return graph_def

    try:
      # TODO(b/150163103): Merge the "disable lowering using switch merge"
      # calls. Grappler will also try to lower a while loop into the switch
      # merge representation, which is undesired for Ophints, so we simply
      # remove those attributes to prevent Grappler from doing so.
      graph = _convert_to_constants.disable_lower_using_switch_merge(graph_def)
      # Run function inlining optimization to ensure any models generated
      # through the from_frozen_graph path have been inlined.
      optimized_graph = _run_graph_optimizations(
          graph,
          input_tensors,
          output_tensors,
          config=self._grappler_config(["function"]))
      return optimized_graph
    except Exception:  # pylint: disable=broad-except
      return graph_def

  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    self._validate_inputs(self._input_tensors, self.quantized_input_stats)

    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, self._graph_def)

    optimized_graph = self._optimize_tf_model(self._graph_def,
                                              self._input_tensors,
                                              self._output_tensors, quant_mode)

    self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)

    converter_kwargs = self._get_base_converter_args()
    converter_kwargs.update(
        quant_mode.converter_flags(self.inference_type,
                                   self.inference_input_type))
    converter_kwargs.update({
        "output_format": self.output_format,
        "quantized_input_stats": self._quantized_stats,
        "default_ranges_stats": self.default_ranges_stats,
        "drop_control_dependency": self.drop_control_dependency,
        "reorder_across_fake_quant": self.reorder_across_fake_quant,
        "change_concat_input_ranges": self.change_concat_input_ranges,
        "dump_graphviz_dir": self.dump_graphviz_dir,
        "dump_graphviz_video": self.dump_graphviz_video,
        "conversion_summary_dir": self.conversion_summary_dir,
        "allow_all_select_tf_ops": self._experimental_allow_all_select_tf_ops,
    })

    if not self.experimental_new_converter:
      logging.warning(
          "Please consider switching to the new converter by setting "
          "experimental_new_converter=True. "
          "The old converter (TOCO) is deprecated.")
    else:
      logging.info("Using experimental converter: If you encounter a problem, "
                   "please file a bug. You can opt out "
                   "by setting experimental_new_converter=False")

    self._validate_quantized_input_stats(converter_kwargs, quant_mode)
    self._validate_experimental_new_quantizer_flag()
    # Converts model.
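    # (Two paths below: when concrete input/output tensors are available, the
    # tensor-based entry point is used; otherwise conversion falls back to the
    # name/shape-based GraphDef entry point, which serves graphs that cannot
    # be loaded into TensorFlow, e.g. ones containing custom TFLite ops.)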
    if self._has_valid_tensors():
      result = _toco_convert_impl(
          input_data=optimized_graph,
          input_tensors=self._input_tensors,
          output_tensors=self._output_tensors,
          **converter_kwargs)
    else:
      result = _toco_convert_graph_def(
          input_data=optimized_graph,
          input_arrays_with_shape=self._input_arrays_with_shape,
          output_arrays=self._output_arrays,
          control_output_arrays=self._control_output_arrays,
          **converter_kwargs)

    return self._optimize_tflite_model(
        result, quant_mode, quant_io=not self.experimental_new_converter)

  def get_input_arrays(self):
    """Returns a list of the names of the input tensors.

    Returns:
      List of strings.
    """
    if self._has_valid_tensors():
      return [_get_tensor_name(tensor) for tensor in self._input_tensors]
    else:
      return [name for name, _ in self._input_arrays_with_shape]

  def _has_valid_tensors(self):
    """Checks if the input and output tensors have been initialized.

    Returns:
      Bool.
    """
    return self._input_tensors is not None and self._output_tensors

  def _set_batch_size(self, batch_size):
    """Sets the first dimension of the input tensor to `batch_size`.

    Args:
      batch_size: Batch size for the model. Replaces the first dimension of an
        input size array if undefined. (default 1)

    Raises:
      ValueError: input_tensor is not defined.
    """
    if not self._has_valid_tensors():
      raise ValueError("The batch size cannot be set for this model. Please "
                       "use the input_shapes parameter.")

    for tensor in self._input_tensors:
      shape = tensor.shape.as_list()
      if shape[0] is None:
        shape[0] = batch_size
        tensor.set_shape(shape)

  def _is_unknown_shapes_allowed(self):
    # Ophint-converted nodes will need the shapes to be known.
    if _is_ophint_converted(self._graph_def):
      return False

    if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed():
      return False

    # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported
    # by the MLIR converter.
    if self.conversion_summary_dir:
      logging.warning(
          "`conversion_summary_dir` does not work with unknown shapes. "
          "Graphs with unknown shapes may convert differently than when this "
          "flag is disabled.")
      return False
    return True

  def _save_conversion_params_metric(self):
    self._collected_converter_params.update({
        "output_format": self.output_format,
        "default_ranges_stats": self.default_ranges_stats,
        "drop_control_dependency": self.drop_control_dependency,
        "reorder_across_fake_quant": self.reorder_across_fake_quant,
        "change_concat_input_ranges": self.change_concat_input_ranges,
        "dump_graphviz_dir": self.dump_graphviz_dir,
        "dump_graphviz_video": self.dump_graphviz_video,
        "conversion_summary_dir": self.conversion_summary_dir,
        "api_version": 1,
    })
    super(TFLiteConverterBaseV1,
          self)._save_conversion_params_metric(self._graph_def,
                                               self.inference_type,
                                               self.inference_input_type)


class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
  """Converts the given SavedModel into a TensorFlow Lite model.

  Attributes:
    saved_model_dir: Directory of the SavedModel.
  """

  def __init__(self,
               saved_model_dir,
               saved_model_tags,
               saved_model_exported_names,
               experimental_debug_info_func=None):
    """Constructor for TFLiteConverter.

    Args:
      saved_model_dir: Directory of the SavedModel.
      saved_model_tags: Set of tags identifying the MetaGraphDef within the
        SavedModel to analyze. All tags in the tag set must be present.
        (default {tf.saved_model.SERVING}).
      saved_model_exported_names: Names to be exported when the SavedModel
        import path is used.
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteSavedModelConverter,
          self).__init__(experimental_debug_info_func)
    self.saved_model_dir = saved_model_dir
    self._saved_model_tags = saved_model_tags
    self._saved_model_exported_names = saved_model_exported_names

    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    if len(self._saved_model_exported_names) != 1:
      raise ValueError("Only a single signature key is supported.")

    signature_key = self._saved_model_exported_names[0]

    result = _freeze_saved_model(self.saved_model_dir, None, None, None,
                                 self._saved_model_tags, signature_key)
    self._graph_def = result[0]
    self._input_tensors = result[1]
    self._output_tensors = result[2]
    self._parse_saved_model_args()

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    return super(TFLiteSavedModelConverter, self).convert()


class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
  """Converts the given Keras model file into a TensorFlow Lite model."""

  def __init__(self,
               model_file,
               input_arrays=None,
               input_shapes=None,
               output_arrays=None,
               custom_objects=None):
    """Constructor for TFLiteConverter.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteKerasModelConverter,
          self).__init__(experimental_debug_info_func=None)
    # Handles Keras when Eager mode is enabled.
    if context.executing_eagerly():
      if input_arrays or output_arrays:
        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
                         "with Eager mode. If your model requires any of "
                         "these parameters, please use "
                         "disable_eager_execution().")

      keras_model = keras_deps.get_load_model_function()(model_file,
                                                         custom_objects)
      function = _trace_model_call(keras_model)
      concrete_func = function.get_concrete_function()

      frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
          concrete_func, lower_control_flow=False)
      _set_tensor_shapes(frozen_func.inputs, input_shapes)
      self._keras_model = keras_model
      self._graph_def = frozen_func.graph.as_graph_def()
      self._input_tensors = frozen_func.inputs
      self._output_tensors = frozen_func.outputs
      self._debug_info_func = _build_debug_info_func(frozen_func.graph)
      return

    # Handles Keras when Eager mode is disabled.
    keras_deps.get_clear_session_function()()
    keras_model = keras_deps.get_load_model_function()(model_file,
                                                       custom_objects)
    sess = keras_deps.get_get_session_function()()

    # Get input and output tensors.
    if input_arrays:
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
    else:
      input_tensors = keras_model.inputs

    if output_arrays:
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
    else:
      output_tensors = keras_model.outputs
    _set_tensor_shapes(input_tensors, input_shapes)

    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
    self._keras_model = keras_model
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self._debug_info_func = _build_debug_info_func(sess.graph)

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
  def _freeze_keras_model(self, output_dir):
    """Saves the Keras model to SavedModel format.

    Args:
      output_dir: The output directory to save the SavedModel.
    """
    try:
      self._keras_model.save(output_dir, save_format="tf")
    except Exception:  # pylint: disable=broad-except
      # If saving the given Keras model to a SavedModel fails, fall back to
      # the original Keras model conversion pipeline.
      return None
    tag_set = set([_tag_constants.SERVING])
    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    graph_def, input_tensors, output_tensors, sess_graph = _freeze_saved_model(
        output_dir, None, None, None, tag_set, signature_key)

    self.saved_model_dir = output_dir
    self._saved_model_tags = tag_set
    self._saved_model_exported_names = [signature_key]
    self._parse_saved_model_args()
    if self.saved_model_dir:
      self._graph_def = graph_def
      self._input_tensors = input_tensors
      self._output_tensors = output_tensors
      self._debug_info_func = _build_debug_info_func(sess_graph)

  def _convert_as_saved_model(self):
    """Converts a Keras model as a SavedModel.

    Returns:
      The converted data in serialized format.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      self._freeze_keras_model(temp_dir)
      if self.saved_model_dir:
        return super(TFLiteKerasModelConverter, self).convert()
    finally:
      shutil.rmtree(temp_dir, True)

  @_export_metrics
  def convert(self):
    """Converts a Keras model based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    saved_model_convert_result = self._convert_as_saved_model()
    if saved_model_convert_result:
      return saved_model_convert_result

    return super(TFLiteKerasModelConverter, self).convert()


class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
  """Converts the given frozen graph def into a TensorFlow Lite model."""

  def __init__(self,
               graph_def,
               input_tensors,
               output_tensors,
               input_arrays_with_shape=None,
               output_arrays=None,
               experimental_debug_info_func=None):
    """Constructor for TFLiteConverter.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      input_arrays_with_shape: List of tuples of strings representing input
        tensor names and lists of integers representing input shapes
        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be
        loaded into TensorFlow and when `input_tensors` and `output_tensors`
        are None. (default None)
      output_arrays: List of output tensors to freeze graph with. Use only
        when graph cannot be loaded into TensorFlow and when `input_tensors`
        and `output_tensors` are None. (default None)
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteFrozenGraphConverter,
          self).__init__(experimental_debug_info_func)
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self._control_output_arrays = None

    # Attributes are used by models that cannot be loaded into TensorFlow.
    if not self._has_valid_tensors():
      self._input_arrays_with_shape = input_arrays_with_shape
      self._output_arrays = output_arrays

    if input_tensors is not None and input_arrays_with_shape is not None:
      logging.warning("input_arrays_with_shape will be ignored when both the "
                      "given input_tensors and input_arrays_with_shape are "
                      "not None.")

    if output_tensors is not None and output_arrays is not None:
      logging.warning("output_arrays will be ignored when both the given "
                      "output_tensors and output_arrays are not None.")

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    if not self._has_valid_tensors():
      if not self._input_arrays_with_shape or not (
          self._output_arrays or self._control_output_arrays):
        raise ValueError(
            "If input_tensors and output_tensors are None, both "
            "input_arrays_with_shape and output_arrays|control_output_arrays "
            "must be defined.")
    return super(TFLiteFrozenGraphConverter, self).convert()


@_tf_export(v1=["lite.TFLiteConverter"])
class TFLiteConverter(TFLiteFrozenGraphConverter):
  """Convert a TensorFlow model into `output_format`.

  This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras
  model into either a TFLite FlatBuffer or graph visualization.

  Attributes:
    optimizations: Experimental flag, subject to change. Set of optimizations
      to apply. e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None
      or a set of values of type `tf.lite.Optimize`)
    representative_dataset: A generator function used for integer quantization
      where each generated sample has the same order, type and shape as the
      inputs to the model. Usually, this is a small subset of a few hundred
      samples randomly chosen, in no particular order, from the training or
      evaluation dataset. This is an optional attribute, but required for full
      integer quantization, i.e., if `tf.int8` is the only supported type in
      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
      (default None)
    target_spec: Experimental flag, subject to change. Specifications of target
      device, including supported ops set, supported types and a set of
      user-defined TensorFlow operators required in the TensorFlow Lite
      runtime. Refer to `tf.lite.TargetSpec`.
    inference_type: Data type of numeric arrays, excluding the input layer.
      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_input_type: Data type of the numeric arrays in the input layer.
      If `inference_input_type` is in {tf.int8, tf.uint8}, then
      `quantized_input_stats` must be provided. (default is the value assigned
      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
    inference_output_type: Data type of the numeric arrays in the output layer.
      (default is the value assigned to `inference_type`, must be in
      {tf.float32, tf.int8, tf.uint8})
    quantized_input_stats: Map of input tensor names to a tuple of floats
      representing the mean and standard deviation of the training data.
      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is
      tf.int8 or tf.uint8. (default None)
    default_ranges_stats: Tuple of integers (min, max) representing range
      values for all numeric arrays without a specified range. Intended for
      experimenting with quantization via "dummy quantization". (default None)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When False, any unknown operation is an error. When True, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the
      graph. Results in a graph that differs from the quantized training
      graph, potentially causing differing arithmetic behavior. (default
      False)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges
      for inputs and outputs of the concat operator for quantized models.
      Changes the ranges of concat operator overlap when true. (default False)
    output_format: Output file format.
      (default tf.compat.v1.lite.constants.TFLITE, must be in
      {tf.compat.v1.lite.constants.TFLITE,
      tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
      stages of processing GraphViz .dot files. Preferred over
      `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to
      keep the requirements of the output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
      files after every graph transformation. Requires the `dump_graphviz_dir`
      flag to be specified. (default False)
    conversion_summary_dir: Full path of the directory to store conversion
      logs. (default None)
    target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
    post_training_quantize: Deprecated. Please use `optimizations` instead and
      set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
    experimental_new_converter: Experimental flag, subject to change. Enables
      MLIR-based conversion instead of TOCO conversion. (default True)
    experimental_new_quantizer: Experimental flag, subject to change. Enables
      MLIR-based quantization conversion instead of Flatbuffer-based
      conversion. (default True)

  Example usage:

  ```python
  # Converting a GraphDef from session.
  converter = tf.compat.v1.lite.TFLiteConverter.from_session(
      sess, in_tensors, out_tensors)
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)

  # Converting a GraphDef from file.
  converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
      graph_def_file, input_arrays, output_arrays)
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)

  # Converting a SavedModel.
  converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
      saved_model_dir)
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)

  # Converting a tf.keras model file.
  converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
      keras_model_file)
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)
  ```
  """

  # pylint: disable=useless-super-delegation
  def __init__(self,
               graph_def,
               input_tensors,
               output_tensors,
               input_arrays_with_shape=None,
               output_arrays=None,
               experimental_debug_info_func=None):
    """Constructor for TFLiteConverter.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      input_arrays_with_shape: List of tuples of strings representing input
        tensor names and lists of integers representing input shapes
        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be
        loaded into TensorFlow and when `input_tensors` and `output_tensors`
        are None. (default None)
      output_arrays: List of output tensors to freeze graph with. Use only
        when graph cannot be loaded into TensorFlow and when `input_tensors`
        and `output_tensors` are None. (default None)
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments.
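
    Note: `input_tensors`/`output_tensors` and
    `input_arrays_with_shape`/`output_arrays` are mutually exclusive ways of
    naming the model boundary. Provide one pair or the other; if both are
    given, the tensor-based pair takes precedence and the array-based pair is
    ignored with a warning.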
    """
    super(TFLiteConverter,
          self).__init__(graph_def, input_tensors, output_tensors,
                         input_arrays_with_shape, output_arrays,
                         experimental_debug_info_func)

  @classmethod
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TFLiteConverter class from a TensorFlow Session.

    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).

    Returns:
      TFLiteConverter class.
    """
    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
    return cls(
        graph_def,
        input_tensors,
        output_tensors,
        experimental_debug_info_func=_build_debug_info_func(sess.graph))

  @classmethod
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TFLiteConverter class from a file containing a frozen GraphDef.

    Args:
      graph_def_file: Full filepath of file containing frozen GraphDef.
      input_arrays: List of input tensors to freeze graph with.
      output_arrays: List of output tensors to freeze graph with.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)

    Returns:
      TFLiteConverter class.

    Raises:
      IOError:
        File not found.
        Unable to parse input file.
      ValueError:
        The graph is not frozen.
        input_arrays or output_arrays contains an invalid tensor name.
        input_shapes is not correctly defined when required.
    """
    with _ops.Graph().as_default():
      with _session.Session() as sess:
        # Read GraphDef from file.
        if not gfile.Exists(graph_def_file):
          raise IOError("File '{0}' does not exist.".format(graph_def_file))
        with gfile.GFile(graph_def_file, "rb") as f:
          file_content = f.read()

        try:
          graph_def = _graph_pb2.GraphDef()
          graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
          try:
            print("Ignore 'tcmalloc: large alloc' warnings.")

            if not isinstance(file_content, str):
              if PY2:
                file_content = six.ensure_binary(file_content, "utf-8")
              else:
                file_content = six.ensure_text(file_content, "utf-8")
            graph_def = _graph_pb2.GraphDef()
            _text_format.Merge(file_content, graph_def)
          except (_text_format.ParseError, DecodeError):
            raise IOError(
                "Unable to parse input file '{}'.".format(graph_def_file))

        # Handles models with custom TFLite ops that cannot be resolved in
        # TensorFlow.
        load_model_in_session = True
        try:
          _import_graph_def(graph_def, name="")
        except _NotFoundError:
          load_model_in_session = False

        if load_model_in_session:
          # Check if graph is frozen.
          if not _is_frozen_graph(sess):
            raise ValueError("Please freeze the graph using freeze_graph.py.")

          # Get input and output tensors.
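          # (The GraphDef was successfully imported above, so the tensors can
          # be resolved by name directly against the session graph.)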
          input_tensors = _get_tensors_from_tensor_names(
              sess.graph, input_arrays)
          output_tensors = _get_tensors_from_tensor_names(
              sess.graph, output_arrays)
          _set_tensor_shapes(input_tensors, input_shapes)

          return cls(sess.graph_def, input_tensors, output_tensors)
        else:
          if not input_shapes:
            raise ValueError("input_shapes must be defined for this model.")
          if set(input_arrays) != set(input_shapes.keys()):
            raise ValueError("input_shapes must contain a value for each item "
                             "in input_arrays.")

          input_arrays_with_shape = [
              (name, input_shapes[name]) for name in input_arrays
          ]
          return cls(
              graph_def,
              input_tensors=None,
              output_tensors=None,
              input_arrays_with_shape=input_arrays_with_shape,
              output_arrays=output_arrays)

  @classmethod
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TFLiteConverter class from a SavedModel.

    Args:
      saved_model_dir: SavedModel directory to convert.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel
        to analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING})
      signature_key: Key identifying SignatureDef containing inputs and
        outputs. (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    Returns:
      TFLiteConverter class.
    """
    if tag_set is None:
      tag_set = set([_tag_constants.SERVING])
    if signature_key is None:
      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    saved_model_converter = TFLiteSavedModelConverter(saved_model_dir, tag_set,
                                                      [signature_key])
    if saved_model_converter.saved_model_dir:
      return saved_model_converter

    result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                                 output_arrays, tag_set, signature_key)

    return cls(
        graph_def=result[0],
        input_tensors=result[1],
        output_tensors=result[2],
        experimental_debug_info_func=_build_debug_info_func(result[3]))

  @classmethod
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None,
                            custom_objects=None):
    """Creates a TFLiteConverter class from a tf.keras model file.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Returns:
      TFLiteConverter class.
    """
    return TFLiteKerasModelConverter(model_file, input_arrays, input_shapes,
                                     output_arrays, custom_objects)

  # pylint: disable=useless-super-delegation
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    return super(TFLiteConverter, self).convert()


@_tf_export(v1=["lite.TocoConverter"])
class TocoConverter(object):
  """Convert a TensorFlow model into `output_format` using TOCO.

  This class has been deprecated. Please use `lite.TFLiteConverter` instead.
  """

  @classmethod
  @_deprecation.deprecated(None,
                           "Use `lite.TFLiteConverter.from_session` instead.")
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TocoConverter class from a TensorFlow Session."""
    return TFLiteConverter.from_session(sess, input_tensors, output_tensors)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TocoConverter class from a file containing a frozen graph."""
    return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
                                             output_arrays, input_shapes)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TocoConverter class from a SavedModel."""
    return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
                                            input_shapes, output_arrays,
                                            tag_set, signature_key)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None):
    """Creates a TocoConverter class from a tf.keras model file."""
    return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
                                                 input_shapes, output_arrays)
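
# A minimal migration sketch (illustrative only; `sess`, `in_tensors` and
# `out_tensors` are placeholders): every deprecated TocoConverter entry point
# above simply forwards to the TFLiteConverter classmethod with the same
# signature, so
#
#   converter = tf.compat.v1.lite.TocoConverter.from_session(
#       sess, in_tensors, out_tensors)
#
# can be replaced one-for-one by
#
#   converter = tf.compat.v1.lite.TFLiteConverter.from_session(
#       sess, in_tensors, out_tensors)
#   tflite_model = converter.convert()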