# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite tooling helper functionality."""

import enum
import functools
import pprint
import shutil
import tempfile
import time
import warnings

from absl import logging

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op  # pylint: disable=unused-import
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metdata_fb
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import convert_graphdef as _convert_graphdef
from tensorflow.lite.python.convert import convert_graphdef_with_arrays as _convert_graphdef_with_arrays
from tensorflow.lite.python.convert import convert_jax_hlo as _convert_jax_hlo
from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model
from tensorflow.lite.python.convert import ConverterError  # pylint: disable=unused-import
from tensorflow.lite.python.convert import deduplicate_readonly_buffers as _deduplicate_readonly_buffers
from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize
from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify
from tensorflow.lite.python.convert import OpsSet
from tensorflow.lite.python.convert import toco_convert  # pylint: disable=unused-import
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.interpreter import Interpreter  # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import load_delegate  # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import OpResolverType  # pylint: disable=unused-import
from tensorflow.lite.python.metrics import metrics
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
from tensorflow.lite.python.op_hint import OpHint  # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.lite.python.util import _xla_computation
from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func
from tensorflow.lite.python.util import freeze_graph as _freeze_graph
from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
from tensorflow.lite.python.util import get_sparsity_modes as _get_sparsity_modes
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
from tensorflow.lite.python.util import populate_conversion_metadata as _populate_conversion_metadata
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.lite.python.util import trace_model_call as _trace_model_call
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugger  # pylint: disable=unused-import
from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugOptions  # pylint: disable=unused-import
from tensorflow.python import saved_model as _saved_model
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function
from tensorflow.python.eager import function as _function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import versions
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader_impl as _loader_impl
from tensorflow.python.saved_model import save_options as _save_options
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.saved_model.load import load as _load
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import tf_export as _tf_export


@_tf_export("lite.Optimize")
class Optimize(enum.Enum):
  """Enum defining the optimizations to apply when generating a tflite model.

  DEFAULT
      Default optimization strategy that quantizes model weights. Enhanced
      optimizations are gained by providing a representative dataset that
      quantizes biases and activations as well.
      Converter will do its best to reduce size and latency, while minimizing
      the loss in accuracy.

  OPTIMIZE_FOR_SIZE
      Deprecated. Does the same as DEFAULT.

  OPTIMIZE_FOR_LATENCY
      Deprecated. Does the same as DEFAULT.

  EXPERIMENTAL_SPARSITY
      Experimental flag, subject to change.

      Enable optimization by taking advantage of the sparse model weights
      trained with pruning.

      The converter will inspect the sparsity pattern of the model weights and
      do its best to improve size and latency.
      The flag can be used alone to optimize float32 models with sparse weights.
      It can also be used together with the DEFAULT optimization mode to
      optimize quantized models with sparse weights.
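
  Example usage (a minimal sketch; assumes `converter` was created with
  `tf.lite.TFLiteConverter.from_saved_model` and that
  `representative_dataset_gen` is a user-defined generator yielding input
  samples):

    ```python
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset_gen
    tflite_quant_model = converter.convert()
    ```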
  """

  # Default optimization strategy that quantizes model weights. Enhanced
  # optimizations are gained by providing a representative dataset that
  # quantizes biases and activations as well.
  # Converter will do its best to reduce size and latency, while minimizing
  # the loss in accuracy.
  DEFAULT = "DEFAULT"

  # Deprecated. Does the same as DEFAULT.
  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"

  # Deprecated. Does the same as DEFAULT.
  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"

  # Experimental flag, subject to change.
  # Enable optimization by taking advantage of the sparse model weights trained
  # with pruning.
  #
  # The converter will inspect the sparsity pattern of the model weights and do
  # its best to improve size and latency.
  # The flag can be used alone to optimize float32 models with sparse weights.
  # It can also be used together with the DEFAULT optimization mode to optimize
  # quantized models with sparse weights.
  # TODO(b/161560631): Add log message when this optimization is applied.
  EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY"

  def __str__(self):
    return str(self.value)


# TODO(b/198099651): move converter implementation out of lite.py
@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset:
  """Representative dataset used to optimize the model.

  This wraps a generator function that provides a small dataset to calibrate or
  estimate the range, i.e., (min, max), of all floating-point arrays in the
  model (such as model input, activation outputs of intermediate layers, and
  model output) for quantization. Usually, this is a small subset of a few
  hundred samples randomly chosen, in no particular order, from the training or
  evaluation dataset.
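
  Example usage (a minimal sketch; `calibration_images` stands in for a small
  set of user-provided numpy arrays matching the model's input shape, and
  `converter` is an existing `tf.lite.TFLiteConverter`):

    ```python
    def representative_dataset_gen():
      for image in calibration_images:
        # Yield a list with one element per model input.
        yield [image]

    converter.representative_dataset = tf.lite.RepresentativeDataset(
        representative_dataset_gen)
    ```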
  """

  def __init__(self, input_gen):
    """Creates a representative dataset.

    Args:
      input_gen: A generator function that generates input samples for the
        model and has the same order, type and shape as the inputs to the model.
        Usually, this is a small subset of a few hundred samples randomly
        chosen, in no particular order, from the training or evaluation dataset.
    """
    self.input_gen = input_gen


@_tf_export("lite.TargetSpec")
class TargetSpec:
  """Specification of target device used to optimize the model.

  Attributes:
    supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
      options, where each option represents a set of operators supported by the
      target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS})
    supported_types: Set of `tf.dtypes.DType` data types supported on the target
      device. If initialized, optimization might be driven by the smallest type
      in this set. (default set())
    experimental_select_user_tf_ops: Experimental flag, subject to change. Set
      of user's TensorFlow operators' names that are required in the TensorFlow
      Lite runtime. These ops will be exported as select TensorFlow ops in the
      model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
      an advanced feature that should only be used if the client is using TF ops
      that may not be linked in by default with the TF ops that are provided
      when using the SELECT_TF_OPS path. The client is responsible for linking
      these ops into the target runtime.
    experimental_supported_backends: Experimental flag, subject to change.
      Set containing names of supported backends. Currently only "GPU" is
      supported; more options will be available later.
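
  Example usage (a minimal sketch; assumes `converter` already exists and the
  goal is full-integer int8 quantization for a target that only runs built-in
  int8 ops):

    ```python
    converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
    converter.target_spec.supported_types = {tf.int8}
    ```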
  """

  def __init__(self,
               supported_ops=None,
               supported_types=None,
               experimental_select_user_tf_ops=None,
               experimental_supported_backends=None):
    if supported_ops is None:
      supported_ops = {OpsSet.TFLITE_BUILTINS}
    self.supported_ops = supported_ops
    if supported_types is None:
      supported_types = set()
    self.supported_types = supported_types
    if experimental_select_user_tf_ops is None:
      experimental_select_user_tf_ops = set()
    self.experimental_select_user_tf_ops = experimental_select_user_tf_ops
    self.experimental_supported_backends = experimental_supported_backends
    self._experimental_custom_op_registerers = []
    # Hint for the supported accumulation type used for inference. Typically
    # used for fp16 post-training quantization, where some models can use fp16
    # accumulators instead of the typical fp32 type.
    # TODO(b/188185962): Provide full API and authoring support for
    # reduced precision accumulation types.
    self._experimental_supported_accumulation_type = None


class QuantizationMode:
  """QuantizationMode determines the quantization type from user options."""

  def __init__(self,
               optimizations,
               target_spec,
               representative_dataset,
               graph_def,
               disable_per_channel=False,
               experimental_new_dynamic_range_quantizer=False,
               experimental_low_bit_qat=False,
               full_integer_quantization_bias_type=None):
    self._optimizations = optimizations
    for deprecated_optimization in [
        Optimize.OPTIMIZE_FOR_SIZE, Optimize.OPTIMIZE_FOR_LATENCY
    ]:
      if deprecated_optimization in self._optimizations:
        logging.warning(
            "Optimization option %s is deprecated, please use optimizations="
            "[Optimize.DEFAULT] instead.", deprecated_optimization)

    self._target_spec = target_spec
    self._representative_dataset = representative_dataset
    self._graph_def = graph_def

    self._validate_int8_required()
    self._disable_per_channel = disable_per_channel

    self._enable_new_dynamic_range_quantizer = (
        experimental_new_dynamic_range_quantizer)
    # Allow training with lower than 8 bit weights to be converted
    # to constants with trained scale.
    self._experimental_low_bit_qat = experimental_low_bit_qat

    self._full_integer_quantization_bias_type = full_integer_quantization_bias_type
    self._validate_full_integer_quantization_bias_type()

  def is_post_training_int8_only_quantization(self):
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is not None and
            not self._is_int16x8_target_required() and
            not self.is_allow_float() and
            self._is_int8_target_required())

  def is_post_training_int8_quantization_with_float_fallback(self):
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is not None and
            not self._is_int16x8_target_required() and
            self.is_allow_float() and
            self._smallest_supported_type() == _dtypes.int8)

  def is_post_training_int8_quantization(self):
    return (self.is_post_training_int8_only_quantization() or
            self.is_post_training_int8_quantization_with_float_fallback())

  def is_post_training_int16x8_only_quantization(self):
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is not None and
            self._is_int16x8_target_required() and
            not self.is_allow_float())

  def is_post_training_int16x8_quantization_with_float_fallback(self):
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is not None and
            self._is_int16x8_target_required() and
            self.is_allow_float())

  def is_post_training_int16x8_quantization(self):
    return (self.is_post_training_int16x8_only_quantization() or
            self.is_post_training_int16x8_quantization_with_float_fallback())

  def is_post_training_integer_quantization(self):
    return (self.is_post_training_int8_quantization() or
            self.is_post_training_int16x8_quantization())

  def is_low_bit_quantize_aware_training(self):
    return (self.is_any_optimization_enabled() and
            self.is_quantization_aware_trained_model() and
            self._experimental_low_bit_qat)

  def is_quantization_aware_training(self):
    return (self.is_any_optimization_enabled() and
            self.is_quantization_aware_trained_model() and
            not self.is_low_bit_quantize_aware_training())

  def is_integer_quantization(self):
    return (self.is_post_training_integer_quantization() or
            self.is_quantization_aware_training() or
            self.is_low_bit_quantize_aware_training())

  def is_post_training_dynamic_range_quantization(self):
    # Post-training dynamic range quantization is only enabled if post-training
    # int8 quantization and training-time quantization were not done.
    return (self.is_any_optimization_enabled() and
            self._representative_dataset is None and
            not self.is_quantization_aware_trained_model() and
            self._smallest_supported_type() == _dtypes.int8)

  def is_post_training_float16_quantization(self):
    return (self.is_any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.float16 in self._target_spec.supported_types)

  def is_bfloat16_quantization(self):
    return (self.is_any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.bfloat16 in self._target_spec.supported_types)

  def activations_type(self):
    if self.is_integer_quantization():
      if self._is_int16x8_target_required():
        return _dtypes.int16
      else:
        return _dtypes.int8
    else:
      return _dtypes.float32

  def bias_type(self):
    if self._full_integer_quantization_bias_type:
      return self._full_integer_quantization_bias_type

    if self.activations_type() == _dtypes.int16:
      return _dtypes.int64
    elif self.activations_type() == _dtypes.int8:
      return _dtypes.int32
    else:
      return _dtypes.float32

  def converter_flags(self, inference_ty=None, inference_input_ty=None):
    """Flags to the converter."""

    if self.is_integer_quantization():
      is_low_bit_qat = self.is_low_bit_quantize_aware_training()
      return {
          "inference_type": (inference_ty if inference_ty is not None else
                             self.activations_type()),
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          "disable_infer_tensor_range": is_low_bit_qat,
          "use_fake_quant_num_bits": is_low_bit_qat,
      }
    elif self.is_post_training_dynamic_range_quantization():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,  # enable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          # experimental: disable per-channel (per-axis) quantization.
          "disable_per_channel_quantization":
              self._disable_per_channel,
          "enable_mlir_dynamic_range_quantizer":
              self._enable_new_dynamic_range_quantizer
      }
    elif self.is_post_training_float16_quantization():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,
          "quantize_to_float16": True,  # enable float16 quantization
          "accumulation_type":
              self._target_spec._experimental_supported_accumulation_type,  # pylint: disable=protected-access
          "allow_bfloat16":
              self.is_bfloat16_quantization(),
          "enable_mlir_dynamic_range_quantizer":
              self._enable_new_dynamic_range_quantizer
      }
    else:
      # Note this might still trigger (uint8) quantization to be compatible with
      # the old converter.
      return {
          "inference_type": (
              inference_ty if inference_ty is not None else _dtypes.float32),
          "inference_input_type": inference_input_ty,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          "allow_bfloat16": self.is_bfloat16_quantization()
      }

  # Below are helpers for the above functions.

  def _validate_int8_required(self):
    """Int8 mode requires certain parameters to exist and be compatible."""
    if not self._is_int8_target_required():
      return

    # Validate the target_spec attribute.
    if (set(self._target_spec.supported_ops) == {OpsSet.TFLITE_BUILTINS_INT8}
        and not (set(self._target_spec.supported_types) == set() or
                 set(self._target_spec.supported_types) == {_dtypes.int8})):
      raise ValueError(
          "As full integer quantization has been enabled by setting "
          "`target_spec.supported_ops`={tf.lite.OpsSet.TFLITE_BUILTINS_INT8}, "
          "`target_spec.supported_types` should be left uninitialized "
          "or set to {tf.int8}.")
    if set(self._target_spec.supported_types) == {_dtypes.int8}:
      self._target_spec.supported_ops = {OpsSet.TFLITE_BUILTINS_INT8}

    # Check if representative_dataset is specified.
    if (not self._representative_dataset and
        not self.is_quantization_aware_training()):
      raise ValueError("For full integer quantization, a "
                       "`representative_dataset` must be specified.")

    # Update representative dataset to the expected format.
    if self._representative_dataset:
      if not isinstance(self._representative_dataset, RepresentativeDataset):
        self._representative_dataset = RepresentativeDataset(
            self._representative_dataset)

  def _validate_full_integer_quantization_bias_type(self):
    """Validates bias type for full integer quantization."""
    bias_type = self._full_integer_quantization_bias_type
    if not bias_type:
      return

    if self.activations_type() == _dtypes.float32:
      raise ValueError(
          "`full_integer_quantization_bias_type` is only supported for full integer quantization."
      )

    if self.activations_type() == _dtypes.int8 and bias_type != _dtypes.int32:
      raise ValueError(
          f"Expected bias type to be `dtypes.int32` for Int8Quant. "
          f"Current setting bias type: {bias_type}")

    if (self.activations_type() == _dtypes.int16 and
        bias_type != _dtypes.int32 and bias_type != _dtypes.int64):
      raise ValueError(
          f"Expected bias type to be `dtypes.int32` or `dtypes.int64` for "
          f"Int16Quant. Current setting bias type: {bias_type}")

  def _is_int8_target_required(self):
    return (OpsSet.TFLITE_BUILTINS_INT8 in set(
        self._target_spec.supported_ops)) or (set(
            self._target_spec.supported_types) == set([_dtypes.int8]))

  def _is_int16x8_target_required(self):
    return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
            in set(self._target_spec.supported_ops))

  def is_allow_float(self):
    return (OpsSet.TFLITE_BUILTINS in set(
        self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set(
            self._target_spec.supported_ops))

  def is_any_optimization_enabled(self):
    return bool(
        set(self._optimizations).intersection([
            Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
            Optimize.DEFAULT
        ]))

  def _smallest_supported_type(self):
    if self._target_spec.supported_types:
      return min(self._target_spec.supported_types, key=lambda x: x.size)
    else:
      # The default smallest supported type is INT8.
      return _dtypes.int8

  def is_quantization_aware_trained_model(self):
    """Checks if the graph contains any training-time quantization ops."""
    training_quant_ops = frozenset({
        "FakeQuantWithMinMaxVars",
        "FakeQuantWithMinMaxVarsPerChannel",
        "FakeQuantWithMinMaxArgs",
        "QuantizeAndDequantizeV2",
        "QuantizeAndDequantizeV3",
    })

    if self._graph_def:
      for node_def in self._graph_def.node:
        if node_def.op in training_quant_ops:
          return True
      for function in self._graph_def.library.function:
        for node_def in function.node_def:
          if node_def.op in training_quant_ops:
            return True
    return False


class TFLiteConverterBase:
  """Converter subclass to share functionality between V1 and V2 converters."""

  # Stores the original model type temporarily to transmit the information
  # from the factory class methods to the TFLiteConverterBase init function.
  _original_model_type = conversion_metdata_fb.ModelType.NONE

  def __init__(self):
    self.optimizations = set()
    self.representative_dataset = None
    self.target_spec = TargetSpec()
    self.allow_custom_ops = False
    self.experimental_new_converter = True
    self.experimental_new_quantizer = True
    self.experimental_enable_resource_variables = True
    self._experimental_calibrate_only = False
    self._experimental_sparsify_model = False
    self._experimental_disable_per_channel = False
    self._debug_info = None  # contains the stack traces of all the original
    # nodes in the `GraphDef` passed to the converter.
    self.saved_model_dir = None
    self._saved_model_tags = None
    self._saved_model_version = 0
    self._saved_model_exported_names = []
    self._tflite_metrics = metrics.TFLiteConverterMetrics()
    self._collected_converter_params = {}
    self._experimental_disable_batchmatmul_unfold = False
    self._experimental_lower_tensor_list_ops = True
    self._experimental_default_to_single_batch_in_tensor_list_ops = False
    self._experimental_unfold_large_splat_constant = False
    self._experimental_tf_quantization_mode = None
    # If unset, the bias type defaults to int32, except for 16x8 quantization,
    # where int64 is used by default to prevent overflow.
    self._experimental_full_integer_quantization_bias_type = None
    # Initializes conversion metadata.
    self.exclude_conversion_metadata = False
    self._metadata = conversion_metdata_fb.ConversionMetadataT()
    self._metadata.environment = conversion_metdata_fb.EnvironmentT()
    self._metadata.options = conversion_metdata_fb.ConversionOptionsT()
    self._metadata.environment.tensorflowVersion = versions.__version__
    self._metadata.environment.modelType = self._get_original_model_type()
    self._experimental_enable_dynamic_update_slice = False
    self._experimental_preserve_assert_op = False
    self._experimental_guarantee_all_funcs_one_use = False

    # When the value is true, the MLIR quantizer triggers dynamic range
    # quantization in MLIR instead of the old quantizer. Used only if
    # experimental_new_quantizer is on.
    self.experimental_new_dynamic_range_quantizer = True
    # Experimental flag to enable low-bit QAT in 8 bit.
    self._experimental_low_bit_qat = False
    # Experimental flag to add all TF ops (including custom TF ops) to the
    # converted model as flex ops.
    self._experimental_allow_all_select_tf_ops = False

  def _grappler_config(self, optimizers=None):
    """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

    Args:
      optimizers: List of strings representing the optimizers to run.

    Returns:
      tf.ConfigProto.
    """
    if not optimizers:
      optimizers = []
    # The MLIR converter takes care of constant folding instead of Grappler.
    if not self.experimental_new_converter:
      optimizers.append("constfold")

    is_only_flex_enabled = (
        set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
    if is_only_flex_enabled:
      # The layout optimizer turns NHWC to NCHW. This provides performance
      # optimizations when Flex mode is enabled. However, this is not compatible
      # with builtin ops.
      optimizers.append("layout")
    return _get_grappler_config(optimizers)

  def _quantize(self, result, input_type, output_type, activations_type,
                bias_type, allow_float):
    """Quantize the model."""
    # pylint: disable=protected-access
    custom_op_registerers_by_name = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if isinstance(x, str)
    ]
    custom_op_registerers_by_func = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if not isinstance(x, str)
    ]
    # pylint: enable=protected-access
    if not isinstance(self.representative_dataset, RepresentativeDataset):
      self.representative_dataset = RepresentativeDataset(
          self.representative_dataset)

    # Add intermediate tensors to the model if needed.
    result = _calibrator.add_intermediate_tensors(result)
    calibrate_quantize = _calibrator.Calibrator(result,
                                                custom_op_registerers_by_name,
                                                custom_op_registerers_by_func)
    if self._experimental_calibrate_only or self.experimental_new_quantizer:
      calibrated = calibrate_quantize.calibrate(
          self.representative_dataset.input_gen)

    if self._experimental_calibrate_only:
      return calibrated
    elif self.experimental_new_quantizer and (
        activations_type != _dtypes.int16):
      # TODO(b/175659372): remove the activations_type restriction and enable
      # it for all the activation types.
      return _mlir_quantize(
          calibrated,
          self._experimental_disable_per_channel,
          input_data_type=input_type,
          output_data_type=output_type)
    else:
      return calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen,
          input_type,
          output_type,
          allow_float,
          activations_type,
          bias_type,
          disable_per_channel=self._experimental_disable_per_channel)

  def _is_unknown_shapes_allowed(self):
    # Unknown dimensions are only allowed with the new converter.
    return self.experimental_new_converter

  def _get_base_converter_args(self):
    """Returns the base converter args.

    Returns:
      {key str: val}
    """
    args = {
        "input_format":
            constants.TENSORFLOW_GRAPHDEF,
        "allow_custom_ops":
            self.allow_custom_ops,
        "debug_info":
            self._debug_info,
        "target_ops":
            self.target_spec.supported_ops,
        "enable_mlir_converter":
            self.experimental_new_converter,
        "select_user_tf_ops":
            self.target_spec.experimental_select_user_tf_ops,
        "supported_backends":
            self.target_spec.experimental_supported_backends,
        "unfold_batchmatmul":
            not self._experimental_disable_batchmatmul_unfold,
        "lower_tensor_list_ops":
            self._experimental_lower_tensor_list_ops,
        "unfold_large_splat_constant":
            self._experimental_unfold_large_splat_constant,
        "default_to_single_batch_in_tensor_list_ops":
            self._experimental_default_to_single_batch_in_tensor_list_ops,
        "tf_quantization_mode":
            self._experimental_tf_quantization_mode,
        "experimental_enable_resource_variables":
            self.experimental_enable_resource_variables,
        "enable_dynamic_update_slice":
            self._experimental_enable_dynamic_update_slice,
        "preserve_assert_op":
            self._experimental_preserve_assert_op,
        "guarantee_all_funcs_one_use":
            self._experimental_guarantee_all_funcs_one_use,
        "allow_all_select_tf_ops":
            self._experimental_allow_all_select_tf_ops,
    }

    if self.saved_model_dir:
      args.update({
          "saved_model_dir": self.saved_model_dir,
          "saved_model_version": self._saved_model_version,
          "saved_model_tags": self._saved_model_tags,
          "saved_model_exported_names": self._saved_model_exported_names,
      })

    return args

  def _contains_function_with_implements_attr(self, saved_model_proto):
    meta_graph = saved_model_proto.meta_graphs[0]
    for function in meta_graph.graph_def.library.function:
      if function.attr.get("_implements", None) or function.attr.get(
          "api_implements", None):
        return True
    return False

  def _parse_saved_model_args(self, always_enable_saved_model_import=False):
    """Parses SavedModel arguments from the given Keras/RNN SavedModel.

    Args:
      always_enable_saved_model_import: Bool. When true, enables the MLIR
        SavedModel import path regardless of the conditions checked below.
    """
    if not self.experimental_new_converter:
      self.saved_model_dir = None
      return
    if self.saved_model_dir:
      try:
        saved_model_proto, _ = (
            _parse_saved_model_with_debug_info(self.saved_model_dir))
      except OSError:
        # If it fails to read the given saved model, it will fall back to the
        # frozen graph def path.
        self.saved_model_dir = None
        return
      if (not always_enable_saved_model_import and
          not self._contains_function_with_implements_attr(saved_model_proto)):
        self.saved_model_dir = None
        return

      if not self._saved_model_exported_names:
        self._saved_model_exported_names = []
      self._saved_model_version = saved_model_proto.saved_model_schema_version
      if self._saved_model_version == 0:
        self.saved_model_dir = None
        logging.warning("SavedModel schema version is zero.")
        return
      if self._saved_model_version not in [1, 2]:
        raise ValueError("SavedModel file format ({0}) is not supported".format(
            self._saved_model_version))

  def _sparsify_model(self):
    return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations

  def _increase_conversion_attempt_metric(self):
    self._tflite_metrics.increase_counter_converter_attempt()

  def _increase_conversion_success_metric(self):
    self._tflite_metrics.increase_counter_converter_success()

  @classmethod
  def _set_original_model_type(cls, model_type):
    """Stores the original model type."""
    if model_type == conversion_metdata_fb.ModelType.NONE:
      raise ValueError("The original model type should be specified.")
    cls._original_model_type = model_type

  def _get_original_model_type(self):
    """One-time getter to return original model type and set it to NONE."""
    model_type = TFLiteConverterBase._original_model_type
    TFLiteConverterBase._original_model_type = conversion_metdata_fb.ModelType.NONE
    return model_type

  def _save_conversion_params_metric(self,
                                     graph_def=None,
                                     inference_type=None,
                                     inference_input_type=None):
    """Set conversion parameter metrics."""
    converter_kwargs = self._collected_converter_params
    converter_kwargs.update(self._get_base_converter_args())

    # Optimization parameters.
    quant_mode = QuantizationMode(
        self.optimizations, self.target_spec, self.representative_dataset,
        graph_def, self._experimental_disable_per_channel,
        self.experimental_new_dynamic_range_quantizer,
        self._experimental_low_bit_qat,
        self._experimental_full_integer_quantization_bias_type)
    converter_kwargs.update({
        "tf_version":
            self._metadata.environment.tensorflowVersion,
        "api_version":
            self._metadata.environment.apiVersion,
        "original_model_format":
            self._metadata.environment.modelType,
        "optimization_default":
            quant_mode.is_any_optimization_enabled(),
        "optimization_post_training_dynamic_range":
            quant_mode.is_post_training_dynamic_range_quantization(),
        "optimization_post_training_float16":
            quant_mode.is_post_training_float16_quantization(),
        "optimization_post_training_integer_quantize":
            quant_mode.is_post_training_integer_quantization(),
        "optimization_qat":
            quant_mode.is_quantization_aware_training(),
        "optimization_low_bit_qat":
            quant_mode.is_low_bit_quantize_aware_training(),
        "optimization_sparsify":
            self._sparsify_model(),
        "activations_type":
            quant_mode.activations_type()
    })
    converter_kwargs.update(
        quant_mode.converter_flags(inference_type, inference_input_type))

    # pylint: disable=protected-access
    if self.target_spec._experimental_supported_accumulation_type:
      converter_kwargs.update({
          "accumulation_type":
              self.target_spec._experimental_supported_accumulation_type
      })
    # pylint: enable=protected-access

    def format_element(elem):
      if isinstance(elem, enum.Enum):
        return str(elem.value)
      return pprint.pformat(elem)

    def format_param(param):
      if isinstance(param, (list, tuple, set)):
        if not param:
          return "None"  # Return None if empty.
        string_list = [format_element(x) for x in param]
        return ",".join(sorted(string_list))
      return format_element(param)

    for key, value in converter_kwargs.items():
      self._tflite_metrics.set_converter_param(key, format_param(value))
    self._tflite_metrics.set_export_required()

    # Set conversion option metadata.
    self._metadata.options.allowCustomOps = self.allow_custom_ops
    self._metadata.options.enableSelectTfOps = (
        OpsSet.SELECT_TF_OPS in self.target_spec.supported_ops)
    self._metadata.options.forceSelectTfOps = (
        set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
    self._metadata.options.modelOptimizationModes = []

    if quant_mode.is_post_training_float16_quantization():
      self._metadata.options.modelOptimizationModes.append(
          conversion_metdata_fb.ModelOptimizationMode.PTQ_FLOAT16)

    if quant_mode.is_post_training_dynamic_range_quantization():
      self._metadata.options.modelOptimizationModes.append(
          conversion_metdata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE)

    if quant_mode.is_post_training_int8_quantization():
      self._metadata.options.modelOptimizationModes.append(
          conversion_metdata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER)

    if quant_mode.is_post_training_int16x8_quantization():
      self._metadata.options.modelOptimizationModes.append(
          conversion_metdata_fb.ModelOptimizationMode.PTQ_INT16)

    if quant_mode.is_quantization_aware_training():
      self._metadata.options.modelOptimizationModes.append(
          conversion_metdata_fb.ModelOptimizationMode
          .QUANTIZATION_AWARE_TRAINING)

  def _set_conversion_latency_metric(self, value):
    self._tflite_metrics.set_converter_latency(value)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL)
  def _optimize_tflite_model(self, model, quant_mode, quant_io=True):
    """Apply optimizations on a TFLite model."""

    if quant_mode.is_integer_quantization():
      in_type, out_type = self.inference_input_type, self.inference_output_type

      if quant_mode.is_post_training_integer_quantization():
        q_in_type = in_type if in_type and quant_io else _dtypes.float32
        q_out_type = out_type if out_type and quant_io else _dtypes.float32
        q_activations_type = quant_mode.activations_type()
        q_bias_type = quant_mode.bias_type()
        q_allow_float = quant_mode.is_allow_float()
        model = self._quantize(model, q_in_type, q_out_type, q_activations_type,
                               q_bias_type, q_allow_float)

      m_in_type = in_type if in_type else _dtypes.float32
      m_out_type = out_type if out_type else _dtypes.float32
      # Skip updating model io types if MLIR quantizer already takes care of it.
      if not (quant_mode.is_post_training_integer_quantization() and
              self.experimental_new_quantizer and quant_io and
              (m_in_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32]) and
              (m_out_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32])):
        model = _modify_model_io_type(model, m_in_type, m_out_type)

    if self._sparsify_model():
      model = _mlir_sparsify(model)

    try:
      model = _deduplicate_readonly_buffers(model)
    except Exception:  # pylint: disable=broad-except
      # Skip buffer deduplication when flatbuffer library is not ready to be
      # utilized.
      logging.warning(
          "Buffer deduplication procedure will be skipped when flatbuffer "
          "library is not properly loaded")

    return model

  def _convert_and_export_metrics(self, convert_func, *args, **kwargs):
    """Wraps the convert function to export metrics.

    Args:
      convert_func: The convert function to wrap.
      *args: Positional arguments of the convert function.
      **kwargs: The keyword arguments of the convert function.

    Returns:
      The result of the wrapped convert function, with conversion metrics
      recorded and exported.
    """
    self._increase_conversion_attempt_metric()
    self._save_conversion_params_metric()
    start_time = time.process_time()
    result = convert_func(self, *args, **kwargs)
    elapsed_time_ms = (time.process_time() - start_time) * 1000
    if result:
      self._increase_conversion_success_metric()
    self._set_conversion_latency_metric(round(elapsed_time_ms))
    self._tflite_metrics.export_metrics()
    if self.exclude_conversion_metadata:
      return result
    model_object = flatbuffer_utils.convert_bytearray_to_object(result)
    # Populates the conversion metadata.
    # TODO(b/202090541): Collects sparsity block size information.
    sparsity_modes = _get_sparsity_modes(model_object)
    self._metadata.options.modelOptimizationModes.extend(sparsity_modes)
    model_object = _populate_conversion_metadata(model_object, self._metadata)
    return flatbuffer_utils.convert_object_to_bytearray(model_object)


def _export_metrics(convert_func):
  """Decorator that wraps a convert function to export metrics."""
  @functools.wraps(convert_func)
  def wrapper(self, *args, **kwargs):
    # pylint: disable=protected-access
    return self._convert_and_export_metrics(convert_func, *args, **kwargs)
    # pylint: enable=protected-access

  return wrapper


class TFLiteConverterBaseV2(TFLiteConverterBase):
  """Converter subclass to share functionality between V2 converters."""

  def __init__(self):
    """Constructor for TFLiteConverter."""
    super(TFLiteConverterBaseV2, self).__init__()
    self.inference_input_type = _dtypes.float32
    self.inference_output_type = _dtypes.float32
    self._metadata.environment.apiVersion = 2

  def _validate_inference_input_output_types(self, quant_mode):
    """Validate inference_input_type and inference_output_type flags."""
    default_types = [_dtypes.float32]
    # We support integer input/output for integer quantized models only.
    if quant_mode.is_integer_quantization():
      if quant_mode.is_post_training_int16x8_quantization():
        all_types = default_types + [_dtypes.int16]
      else:
        all_types = default_types + [_dtypes.int8, _dtypes.uint8]
      if (self.inference_input_type not in all_types or
          self.inference_output_type not in all_types):
        all_types_names = ["tf." + t.name for t in all_types]
        raise ValueError("The inference_input_type and inference_output_type "
                         "must be in {}.".format(all_types_names))
    elif (self.inference_input_type not in default_types or
          self.inference_output_type not in default_types):
      raise ValueError("The inference_input_type and inference_output_type "
                       "must be tf.float32.")

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.LOAD_SAVED_MODEL)
  def _load_saved_model(self, saved_model_dir, saved_model_tags):
    """Load graph_def from saved model with the default serving signature key.

    Args:
      saved_model_dir: Directory of the SavedModel.
      saved_model_tags: Set of tags identifying the MetaGraphDef within the
        SavedModel to analyze.

    Returns:
      graph_def: The loaded GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
    """
    graph = _ops.Graph()
    saved_model = _loader_impl.SavedModelLoader(saved_model_dir)
    saved_model.load_graph(graph, tags=saved_model_tags)
    meta_graph = saved_model.get_meta_graph_def_from_tags(saved_model_tags)
    graph_def = meta_graph.graph_def
    signature_def = meta_graph.signature_def[
        _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    input_tensors = [
        graph.get_tensor_by_name(signature_def.inputs[key].name)
        for key in signature_def.inputs
    ]
    output_tensors = [
        graph.get_tensor_by_name(signature_def.outputs[key].name)
        for key in signature_def.outputs
    ]
    return graph_def, input_tensors, output_tensors

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
  def _validate_inputs(self, graph_def, input_tensors):
    """Validate the input parameters.

    Args:
      graph_def: The TensorFlow GraphDef.
      input_tensors: List of input tensors.

    Raises:
      ValueError:
        Input shape is not specified.
        Invalid quantization parameters.
    """
    # Update conversion params with graph_def.
    self._save_conversion_params_metric(graph_def)
    self._quant_mode = QuantizationMode(
        self.optimizations, self.target_spec, self.representative_dataset,
        graph_def, self._experimental_disable_per_channel,
        self.experimental_new_dynamic_range_quantizer,
        self._experimental_low_bit_qat,
        self._experimental_full_integer_quantization_bias_type)
    self._validate_inference_input_output_types(self._quant_mode)

    if not self._is_unknown_shapes_allowed():
      # Checks dimensions in input tensor.
      for tensor in input_tensors:
        # Note that shape_list might be empty for scalar shapes.
        shape_list = tensor.shape.as_list()
        if None in shape_list[1:]:
          raise ValueError(
              "None is only supported in the 1st dimension. Tensor '{0}' has "
              "invalid shape '{1}'.".format(
                  _get_tensor_name(tensor), shape_list))
        elif shape_list and shape_list[0] is None:
          # Set the batch size to 1 if undefined.
          shape = tensor.shape.as_list()
          shape[0] = 1
          tensor.set_shape(shape)

    if (self._trackable_obj is None or
        not hasattr(self._trackable_obj, "graph_debug_info")):
      self._debug_info = _get_debug_info(
          _build_debug_info_func(self._funcs[0].graph), graph_def)
    else:
      self._debug_info = _get_debug_info(
          _convert_debug_info_func(self._trackable_obj.graph_debug_info),
          graph_def)

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
  def _optimize_tf_model(self, graph_def, input_tensors, output_tensors,
                         frozen_func):
    """Run a Grappler pass to optimize the TensorFlow graph.

    Args:
      graph_def: Frozen GraphDef to be optimized.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      frozen_func: TensorFlow Graph.

    Returns:
      The optimized TensorFlow graph.
    """
    grappler_config = self._grappler_config()
    # Skip running Grappler when there are no optimizers to run. Otherwise,
    # Grappler would run with the default optimizer set, which can lead to
    # unexpected behavior.
    if grappler_config.graph_options.rewrite_options.optimizers:
      graph_def = _run_graph_optimizations(
          graph_def,
          input_tensors,
          output_tensors,
          config=grappler_config,
          graph=frozen_func.graph)
    return graph_def

  def _convert_from_saved_model(self, graph_def):
    """Helper method that converts saved model.

    Args:
      graph_def: GraphDef object for the model, used only for stats.

    Returns:
      The converted TFLite model.
    """
    # Update conversion params with graph_def.
    self._save_conversion_params_metric(graph_def)
    # Get quantization options and do some sanity checks.
    quant_mode = QuantizationMode(
        self.optimizations, self.target_spec, self.representative_dataset,
        graph_def, self._experimental_disable_per_channel,
        self.experimental_new_dynamic_range_quantizer,
        self._experimental_low_bit_qat,
        self._experimental_full_integer_quantization_bias_type)
    self._validate_inference_input_output_types(quant_mode)
    converter_kwargs = {
        "enable_tflite_resource_variables":
            self.experimental_enable_resource_variables
    }
    converter_kwargs.update(self._get_base_converter_args())
    converter_kwargs.update(quant_mode.converter_flags())

    result = _convert_saved_model(**converter_kwargs)
    return self._optimize_tflite_model(
        result, quant_mode, quant_io=self.experimental_new_quantizer)

  def convert(self, graph_def, input_tensors, output_tensors):
    """Converts a TensorFlow GraphDef based on instance variables.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    self._validate_inputs(graph_def, input_tensors)
    converter_kwargs = self._get_base_converter_args()
    converter_kwargs.update(self._quant_mode.converter_flags())
    if not self.experimental_new_converter:
      logging.warning(
          "Please consider switching to the new converter by setting "
          "experimental_new_converter=True. "
          "The old converter is deprecated.")
    else:
      logging.info("Using new converter: If you encounter a problem "
                   "please file a bug. You can opt-out "
                   "by setting experimental_new_converter=False")

    # Converts model.
    result = _convert_graphdef(
        input_data=graph_def,
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        **converter_kwargs)

    return self._optimize_tflite_model(
        result, self._quant_mode, quant_io=self.experimental_new_quantizer)


class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
  """Converts the given SavedModel into a TensorFlow Lite model.

  Attributes:
      saved_model_dir: Directory of the SavedModel.
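
  Example usage (a minimal sketch; end users typically reach this class through
  the public `tf.lite.TFLiteConverter.from_saved_model` API, and
  "/path/to/saved_model" is a placeholder directory):

    ```python
    converter = tf.lite.TFLiteConverter.from_saved_model("/path/to/saved_model")
    tflite_model = converter.convert()
    ```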
  """

  def __init__(self,
               saved_model_dir,
               saved_model_tags=None,
               saved_model_exported_names=None,
               trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      saved_model_dir: Directory of the SavedModel.
      saved_model_tags: Set of tags identifying the MetaGraphDef within the
        SavedModel to analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING}).
      saved_model_exported_names: Names to be exported when the saved model
        import path is on.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do not
        get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is not
        maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteSavedModelConverterV2, self).__init__()
    self.saved_model_dir = saved_model_dir
    self._saved_model_tags = saved_model_tags
    self._saved_model_exported_names = saved_model_exported_names
    self._trackable_obj = trackable_obj
    self._parse_saved_model_args(always_enable_saved_model_import=True)

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    graph_def, input_tensors, output_tensors = self._load_saved_model(
        self.saved_model_dir, self._saved_model_tags)
    # If we can't use the SavedModel importer, then fall back
    # to the frozen graph conversion path.
    if self.saved_model_dir is None or not self.experimental_new_converter:
      graph_def, _, _, _ = _freeze_saved_model(
          self.saved_model_dir, None, None, None, self._saved_model_tags,
          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
      # We make sure to clear the saved_model_dir as there is some
      # legacy code down in the caller that checks this.
      # TODO(b/162537905): Clean these indirect dependencies.
      self.saved_model_dir = None
      return super(TFLiteSavedModelConverterV2,
                   self).convert(graph_def, input_tensors, output_tensors)

    if self._trackable_obj is None:
      self._debug_info = _get_debug_info(
          _build_debug_info_func(self._funcs[0].graph), graph_def)
    else:
      self._debug_info = _get_debug_info(
          _convert_debug_info_func(self._trackable_obj.graph_debug_info),
          graph_def)

    return self._convert_from_saved_model(graph_def)


1219class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
1220  """Converts the given Keras model into TensorFlow Lite model."""
1221
1222  def __init__(self, keras_model, trackable_obj=None):
1223    """Constructor for TFLiteConverter.
1224
1225    Args:
1226      keras_model: tf.Keras.Model.
1227      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1228        reference to this object needs to be maintained so that Variables do not
1229        get garbage collected since functions have a weak reference to
1230        Variables. This is only required when the tf.AutoTrackable object is not
1231        maintained by the user (e.g. `from_saved_model`).
1232    """
1233    super(TFLiteKerasModelConverterV2, self).__init__()
1234    self._keras_model = keras_model
1235    self._trackable_obj = trackable_obj
1236    self.experimental_lower_to_saved_model = True
1237
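  # A minimal usage sketch (assuming `model` is an already-built
  # `tf.keras.Model`): the public `tf.lite.TFLiteConverter.from_keras_model`
  # API returns an instance of this class.
  #
  #   converter = tf.lite.TFLiteConverter.from_keras_model(model)
  #   tflite_model = converter.convert()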
1238  @convert_phase(Component.PREPARE_TF_MODEL,
1239                 SubComponent.CONVERT_KERAS_TO_SAVED_MODEL)
1240  def _convert_keras_to_saved_model(self, output_dir):
1241    """Save Keras model to the SavedModel format.
1242
1243    Args:
1244      output_dir: The output directory to save the SavedModel.
1245
1246    Returns:
1247      graph_def: The frozen GraphDef.
1248      input_tensors: List of input tensors.
1249      output_tensors: List of output tensors.
1250    """
1251    try:
1252      _saved_model.save(
1253          self._keras_model,
1254          output_dir,
1255          options=_save_options.SaveOptions(save_debug_info=True))
1256    except Exception:  # pylint: disable=broad-except
1257      # If saving the given Keras model to the SavedModel format fails, fall
1258      # back to the original Keras model conversion pipeline.
1259      return None, None, None
1260    self.saved_model_dir = output_dir
1261    self._saved_model_tags = set([_tag_constants.SERVING])
1262    self._saved_model_exported_names = [
1263        _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1264    ]
1265    self._parse_saved_model_args(
1266        always_enable_saved_model_import=self.experimental_lower_to_saved_model)
1267    if self.saved_model_dir:
1268      graph_def, input_tensors, output_tensors = self._load_saved_model(
1269          self.saved_model_dir, self._saved_model_tags)
1270      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
1271      return graph_def, input_tensors, output_tensors
1272    return None, None, None
1273
1274  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
1275  def _freeze_keras_model(self):
1276    """Freeze Keras model to frozen graph.
1277
1278    Returns:
1279      graph_def: The frozen GraphDef.
1280      input_tensors: List of input tensors.
1281      output_tensors: List of output tensors.
1282      frozen_func: The frozen ConcreteFunction.
1283    """
1284    input_signature = None
1285    # If the model's call is not a `tf.function`, then we need to first get its
1286    # input signature from `model_input_signature` method. We can't directly
1287    # call `trace_model_call` because otherwise the batch dimension is set
1288    # to None.
1289    # Once we have better support for dynamic shapes, we can remove this.
1290    if not isinstance(self._keras_model.call, _def_function.Function):
1291      # Passing `keep_original_batch_size=True` ensures that we get an input
1292      # signature that includes the batch dimension specified by the user.
1293      # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
1294      input_signature = _model_input_signature(
1295          self._keras_model, keep_original_batch_size=True)
1296
1297    # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
1298    func = _trace_model_call(self._keras_model, input_signature)
1299    concrete_func = func.get_concrete_function()
1300    self._funcs = [concrete_func]
1301
1302    frozen_func, graph_def = (
1303        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
1304            self._funcs[0], lower_control_flow=False))
1305
1306    input_tensors = [
1307        tensor for tensor in frozen_func.inputs
1308        if tensor.dtype != _dtypes.resource
1309    ]
1310    output_tensors = frozen_func.outputs
1311    return graph_def, input_tensors, output_tensors, frozen_func
1312
1313  def _convert_as_saved_model(self):
1314    """Converts a Keras model as a saved model.
1315
1316    Returns:
1317      The converted data in serialized format.
1318    """
1319    temp_dir = tempfile.mkdtemp()
1320    try:
1321      graph_def, input_tensors, output_tensors = (
1322          self._convert_keras_to_saved_model(temp_dir))
1323      if self.saved_model_dir:
1324        return super(TFLiteKerasModelConverterV2,
1325                     self).convert(graph_def, input_tensors, output_tensors)
1326    finally:
1327      shutil.rmtree(temp_dir, True)
1328
1329  @_export_metrics
1330  def convert(self):
1331    """Converts a keras model based on instance variables.
1332
1333    Returns:
1334      The converted data in serialized format.
1335
1336    Raises:
1337      ValueError:
1338        Multiple concrete functions are specified.
1339        Input shape is not specified.
1340        Invalid quantization parameters.
1341    """
1342    saved_model_convert_result = self._convert_as_saved_model()
1343    if saved_model_convert_result:
1344      return saved_model_convert_result
1345
1346    graph_def, input_tensors, output_tensors, frozen_func = (
1347        self._freeze_keras_model())
1348
1349    graph_def = self._optimize_tf_model(graph_def, input_tensors,
1350                                        output_tensors, frozen_func)
1351
1352    return super(TFLiteKerasModelConverterV2,
1353                 self).convert(graph_def, input_tensors, output_tensors)
1354
1355
1356class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
1357  """Converts the given frozen graph into TensorFlow Lite model."""
1358
1359  def __init__(self, funcs, trackable_obj=None):
1360    """Constructor for TFLiteConverter.
1361
1362    Args:
1363      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1364        duplicate elements.
1365      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1366        reference to this object needs to be maintained so that Variables do not
1367        get garbage collected since functions have a weak reference to
1368        Variables. This is only required when the tf.AutoTrackable object is not
1369        maintained by the user (e.g. `from_saved_model`).
1370    """
1371    super(TFLiteFrozenGraphConverterV2, self).__init__()
1372    self._funcs = funcs
1373    self._trackable_obj = trackable_obj
1374    self.experimental_lower_to_saved_model = True
1375
1376  @convert_phase(Component.PREPARE_TF_MODEL,
1377                 SubComponent.FREEZE_CONCRETE_FUNCTION)
1378  def _freeze_concrete_function(self):
1379    """Convert the given ConcreteFunction to frozen graph.
1380
1381    Returns:
1382      graph_def: The frozen GraphDef.
1383      input_tensors: List of input tensors.
1384      output_tensors: List of output tensors.
1385      frozen_func: The frozen ConcreteFunction.
1386
1387    Raises:
1388      ValueError: none or multiple ConcreteFunctions provided.
1389    """
1390    # TODO(b/130297984): Add support for converting multiple functions.
1391
1392    if len(self._funcs) == 0:  # pylint: disable=g-explicit-length-test
1393      raise ValueError("No ConcreteFunction is specified.")
1394
1395    if len(self._funcs) > 1:
1396      raise ValueError("This converter can only convert a single "
1397                       "ConcreteFunction. Converting multiple functions is "
1398                       "under development.")
1399
1400    frozen_func, graph_def = (
1401        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
1402            self._funcs[0], lower_control_flow=False))
1403
1404    input_tensors = [
1405        tensor for tensor in frozen_func.inputs
1406        if tensor.dtype != _dtypes.resource
1407    ]
1408    output_tensors = frozen_func.outputs
1409    return graph_def, input_tensors, output_tensors, frozen_func
1410
1411  @convert_phase(Component.PREPARE_TF_MODEL,
1412                 SubComponent.CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL)
1413  def _convert_concrete_functions_to_saved_model(self, output_dir):
1414    """Save concrete functions to the SavedModel format.
1415
1416    Args:
1417      output_dir: The output directory to save the SavedModel.
1418
1419    Returns:
1420      graph_def: The frozen GraphDef.
1421      input_tensors: List of input tensors.
1422      output_tensors: List of output tensors.
1423    """
1424    if len(self._funcs) == 0:  # pylint: disable=g-explicit-length-test
1425      raise ValueError("No ConcreteFunction is specified.")
1426
1427    if not self.experimental_lower_to_saved_model:
1428      return None, None, None
1429
1430    # Without a provided trackable object, the given concrete functions cannot
1431    # be serialized in the SavedModel format. Also, when the trackable object is
1432    # a function, use the original concrete function conversion pipeline.
1433    if (not self._trackable_obj or
1434        isinstance(self._trackable_obj, (_function.ConcreteFunction,
1435                                         _def_function.Function))):
1436      return None, None, None
1437
1438    signatures = {}
1439    signature_keys = []
1440    try:
1441      if len(self._funcs) == 1:
1442        signatures[_signature_constants
1443                   .DEFAULT_SERVING_SIGNATURE_DEF_KEY] = self._funcs[0]
1444        signature_keys = [
1445            _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1446        ]
1447      else:
1448        for func in self._funcs:
1449          signatures[func.graph.name] = func
1450          signature_keys.append(func.graph.name)
1451
1452      _saved_model.save(
1453          self._trackable_obj,
1454          output_dir,
1455          signatures=signatures,
1456          options=_save_options.SaveOptions(save_debug_info=True))
1457    except Exception:  # pylint: disable=broad-except
1458      # If saving the given concrete functions to the SavedModel format fails,
1459      # fall back to the original concrete function conversion pipeline.
1460      return None, None, None
1461
1462    self.saved_model_dir = output_dir
1463    self._saved_model_tags = set([_tag_constants.SERVING])
1464    self._saved_model_exported_names = signature_keys
1465    self._parse_saved_model_args(always_enable_saved_model_import=True)
1466    if self.saved_model_dir:
1467      graph_def, input_tensors, output_tensors = self._load_saved_model(
1468          self.saved_model_dir, self._saved_model_tags)
1469      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
1470      return graph_def, input_tensors, output_tensors
1471    return None, None, None
1472
1473  def _convert_as_saved_model(self):
1474    """Converts the given concrete functions as a saved model format.
1475
1476    Returns:
1477      The converted data in serialized format.
1478    """
1479    temp_dir = tempfile.mkdtemp()
1480    try:
1481      graph_def, input_tensors, _ = (
1482          self._convert_concrete_functions_to_saved_model(temp_dir))
1483      if self.saved_model_dir:
1484        self._validate_inputs(graph_def, input_tensors)
1485        return self._convert_from_saved_model(graph_def)
1486    finally:
1487      shutil.rmtree(temp_dir, True)
1488    return None
1489
1490  @_export_metrics
1491  def convert(self):
1492    """Converts a TensorFlow GraphDef based on instance variables.
1493
1494    Returns:
1495      The converted data in serialized format.
1496
1497    Raises:
1498      ValueError:
1499        No concrete function is specified.
1500        Multiple concrete functions are specified.
1501        Input shape is not specified.
1502        Invalid quantization parameters.
1503    """
1504    if self.experimental_lower_to_saved_model:
1505      saved_model_convert_result = self._convert_as_saved_model()
1506      if saved_model_convert_result:
1507        return saved_model_convert_result
1508
1509    graph_def, input_tensors, output_tensors, frozen_func = (
1510        self._freeze_concrete_function())
1511
1512    graph_def = self._optimize_tf_model(graph_def, input_tensors,
1513                                        output_tensors, frozen_func)
1514
1515    return super(TFLiteFrozenGraphConverterV2,
1516                 self).convert(graph_def, input_tensors, output_tensors)
1517
1518
1519class TFLiteJaxConverterV2(TFLiteConverterBaseV2):
1520  """Converts the given jax model into TensorFlow Lite model."""
1521
1522  def __init__(self, serving_funcs, inputs):
1523    """Constructor for TFLiteConverter.
1524
1525    Args:
1526      serving_funcs: A list of serving functions of the Jax module. The model
1527        params should already be inlined (e.g., `serving_func =
1528        functools.partial(model, params=params)`).
1529      inputs: A list of input placeholder tuples, e.g. built from `jnp.zeros`,
1530        wrapped in a list such as
1531        "[[('input1', input1), ('input2', input2)]]".
1532    A Jax function is polymorphic, for example:
1533    ```python
1534    def add(a, b):
1535      return a + b
1536    ```
1537    It will yield different computations depending on the input signature:
1538    calling `add(10.0, 20.0)` yields a scalar `add`, while calling it with
1539      arrays of shape (100, 1) and (100, 100) yields a broadcasting add. The
1540      converter needs this input information to trace and properly convert the
1541      model, so it is important to pass in input placeholders with the correct
1542      input shape/type.
1543
1544    In the converted TFLite model, the function name currently defaults to
1545    "main" and the output names are the traced outputs. The output ordering
1546    matches the serving function.
1547    """
1548    super(TFLiteJaxConverterV2, self).__init__()
1549    self._serving_funcs = serving_funcs
1550    self._inputs = inputs
1551
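  # A minimal usage sketch (the Jax `model`, `params`, and input shape are
  # assumptions): the params are inlined with functools.partial and the inputs
  # are a list of (name, placeholder) tuple lists, one per serving function.
  #
  #   import functools
  #   import jax.numpy as jnp
  #
  #   serving_func = functools.partial(model, params=params)
  #   converter = tf.lite.TFLiteConverter.experimental_from_jax(
  #       [serving_func], [[('inputs', jnp.zeros((1, 28, 28, 1), jnp.float32))]])
  #   tflite_model = converter.convert()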
1552  @_export_metrics
1553  def convert(self):
1554    """Converts a Jax serving func based on instance variables.
1555
1556    Returns:
1557      The converted data in serialized format.
1558
1559    Raises:
1560      ImportError:
1561        If cannot import the xla_computation from jax.
1562      ValueError:
1563        No serving function is specified.
1564        Input tensors are not specified.
1565        The truth value of an array with more than one element is ambiguous.
1566        Failed to convert the given Jax function to hlo.
1567
1568    """
1569    if not _xla_computation:
1570      raise ImportError("Cannot import xla_computation from jax.")
1571
1572    if not self._serving_funcs:
1573      raise ValueError("No serving func is specified.")
1574
1575    if not self._inputs:
1576      raise ValueError("Input tensors are not specified.")
1577
1578    if len(self._inputs) != len(self._serving_funcs):
1579      msg = ("Input tensor mapping len {} does not match serving func len {}."
1580             .format(len(self._inputs), len(self._serving_funcs)))
1581      raise ValueError(msg)
1582
1583    if not isinstance(self._inputs, (tuple, list)):
1584      raise ValueError(
1585          "Input tensors should be pass in a tuple list wrapped in an array.")
1586
1587    # TODO(b/197690428): Support multiple functions.
1588    # Currently only support one serving function.
1589    if len(self._serving_funcs) > 1:
1590      raise ValueError("Currently only support single serving function.")
1591
1592    if not isinstance(self._inputs[0], (tuple, list)):
1593      raise ValueError("The input placeholders are not a dictionary.")
1594
1595    input_names = []
1596    ordered_inputs = []
1597    for input_name, tensor in self._inputs[0]:
1598      input_names.append(input_name)
1599      ordered_inputs.append(tensor)
1600
1601    try:
1602      xla_computation = _xla_computation(self._serving_funcs[0], backend="cpu")
1603      hlo_proto = xla_computation(
1604          *ordered_inputs).as_serialized_hlo_module_proto()
1605    except Exception:  # pylint: disable=broad-except
1606      raise ValueError("Failed to convert the given Jax function to hlo.")
1607
1608    # We need to set the hlo proto, and here we use serialized proto format
1609    # since it's more compact.
1610    converter_kwargs = {
1611        "input_content": hlo_proto,
1612        "input_names": input_names,
1613        "is_proto_format": True
1614    }
1615    converter_kwargs.update(self._get_base_converter_args())
1616
1617    # Get quantization options and do some checks.
1618    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
1619                                  self.representative_dataset, None)
1620    self._validate_inference_input_output_types(quant_mode)
1621    converter_kwargs.update(quant_mode.converter_flags())
1622    result = _convert_jax_hlo(**converter_kwargs)
1623
1624    return self._optimize_tflite_model(
1625        result, quant_mode, quant_io=self.experimental_new_quantizer)
1626
1627
1628@_tf_export("lite.TFLiteConverter", v1=[])
1629class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
1630  """Converts a TensorFlow model into TensorFlow Lite model.
1631
1632  Attributes:
1633    optimizations: Experimental flag, subject to change. Set of optimizations to
1634      apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
1635      set of values of type `tf.lite.Optimize`)
1636    representative_dataset: A generator function used for integer quantization
1637      where each generated sample has the same order, type and shape as the
1638      inputs to the model. Usually, this is a small subset of a few hundred
1639      samples randomly chosen, in no particular order, from the training or
1640      evaluation dataset. This is an optional attribute, but required for full
1641      integer quantization, i.e, if `tf.int8` is the only supported type in
1642      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
1643      (default None)
1644    target_spec: Experimental flag, subject to change. Specifications of target
1645      device, including supported ops set, supported types and a set of user's
1646      defined TensorFlow operators required in the TensorFlow Lite runtime.
1647      Refer to `tf.lite.TargetSpec`.
1648    inference_input_type: Data type of the input layer. Note that integer types
1649      (tf.int8 and tf.uint8) are currently only supported for post training
1650      integer quantization and quantization aware training. (default tf.float32,
1651      must be in {tf.float32, tf.int8, tf.uint8})
1652    inference_output_type: Data type of the output layer. Note that integer
1653      types (tf.int8 and tf.uint8) are currently only supported for post
1654      training integer quantization and quantization aware training. (default
1655      tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
1656    allow_custom_ops: Boolean indicating whether to allow custom operations.
1657      When False, any unknown operation is an error. When True, custom ops are
1658      created for any op that is unknown. The developer needs to provide these
1659      to the TensorFlow Lite runtime with a custom resolver. (default False)
1660    exclude_conversion_metadata: Whether not to embed the conversion metadata
1661      into the converted model. (default False)
1662    experimental_new_converter: Experimental flag, subject to change. Enables
1663      MLIR-based conversion. (default True)
1664    experimental_new_quantizer: Experimental flag, subject to change. Enables
1665      MLIR-based quantization conversion instead of Flatbuffer-based conversion.
1666      (default True)
1667    experimental_enable_resource_variables: Experimental flag, subject to
1668      change. Enables
1669      [resource variables](https://tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables)
1670      to be converted by this converter. This is only allowed if the
1671      from_saved_model interface is used. (default True)
1672
1673  Example usage:
1674
1675  ```python
1676  # Converting a SavedModel to a TensorFlow Lite model.
1677  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
1678  tflite_model = converter.convert()
1679
1680  # Converting a tf.Keras model to a TensorFlow Lite model.
1681  converter = tf.lite.TFLiteConverter.from_keras_model(model)
1682  tflite_model = converter.convert()
1683
1684  # Converting ConcreteFunctions to a TensorFlow Lite model.
1685  converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
1686  tflite_model = converter.convert()
1687
1688  # Converting a Jax model to a TensorFlow Lite model.
1689  converter = tf.lite.TFLiteConverter.experimental_from_jax([func], [[
1690      ('input1', input1), ('input2', input2)]])
1691  tflite_model = converter.convert()
1692  ```
1693  """
1694
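  # A minimal full-integer quantization sketch using the attributes documented
  # above (the SavedModel directory and input shape are assumptions for
  # illustration):
  #
  #   import numpy as np
  #
  #   def representative_dataset():
  #     for _ in range(100):
  #       yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]
  #
  #   converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/my_model')
  #   converter.optimizations = {tf.lite.Optimize.DEFAULT}
  #   converter.representative_dataset = representative_dataset
  #   converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  #   converter.inference_input_type = tf.int8
  #   converter.inference_output_type = tf.int8
  #   tflite_model = converter.convert()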
1695  # pylint: disable=useless-super-delegation
1696  def __init__(self, funcs, trackable_obj=None):
1697    """Constructor for TFLiteConverter.
1698
1699    Args:
1700      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1701        duplicate elements.
1702      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1703        reference to this object needs to be maintained so that Variables do not
1704        get garbage collected since functions have a weak reference to
1705        Variables. This is only required when the tf.AutoTrackable object is not
1706        maintained by the user (e.g. `from_saved_model`).
1707    """
1708    super(TFLiteConverterV2, self).__init__(funcs, trackable_obj)
1709
1710  @classmethod
1711  def from_concrete_functions(cls, funcs, trackable_obj=None):
1712    """Creates a TFLiteConverter object from ConcreteFunctions.
1713
1714    Args:
1715      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1716        duplicate elements. Currently converter can only convert a single
1717        ConcreteFunction. Converting multiple functions is under development.
1718      trackable_obj: An `AutoTrackable` object (typically `tf.Module`)
1719        associated with `funcs`. A reference to this object needs to be
1720        maintained so that Variables do not get garbage collected since
1721        functions have a weak reference to Variables.
1722
1723    Returns:
1724      TFLiteConverter object.
1725
1726    Raises:
1727      ValueError: Invalid input type.
1728    """
1729    # pylint: disable=protected-access
1730    TFLiteConverterBase._set_original_model_type(
1731        conversion_metdata_fb.ModelType.TF_CONCRETE_FUNCTIONS)
1732    # pylint: enable=protected-access
1733    if trackable_obj is None:
1734      logging.warning(
1735          "Please consider providing the trackable_obj argument in the "
1736          "from_concrete_functions. Providing without the trackable_obj "
1737          "argument is deprecated and it will use the deprecated conversion "
1738          "path.")
1739    for func in funcs:
1740      if not isinstance(func, _function.ConcreteFunction):
1741        message = "This function takes in a list of ConcreteFunction."
1742        if isinstance(func, _def_function.Function):
1743          message += (" To get the ConcreteFunction from a Function,"
1744                      " call get_concrete_function.")
1745        raise ValueError(message)
1746    return cls(funcs, trackable_obj)
1747
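  # A minimal usage sketch (the `Adder` module below is a hypothetical
  # example): obtain a ConcreteFunction from a tf.function and pass the owning
  # tf.Module as trackable_obj so its variables stay alive during conversion.
  #
  #   class Adder(tf.Module):
  #
  #     @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
  #     def add_one(self, x):
  #       return x + 1.0
  #
  #   module = Adder()
  #   concrete_func = module.add_one.get_concrete_function()
  #   converter = tf.lite.TFLiteConverter.from_concrete_functions(
  #       [concrete_func], module)
  #   tflite_model = converter.convert()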
1748  @classmethod
1749  def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
1750    """Creates a TFLiteConverter object from a SavedModel directory.
1751
1752    Args:
1753      saved_model_dir: SavedModel directory to convert.
1754      signature_keys: List of keys identifying SignatureDef containing inputs
1755        and outputs. Elements should not be duplicated. By default the
1756        `signatures` attribute of the MetaGraphdef is used. (default
1757        saved_model.signatures)
1758      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
1759        analyze. All tags in the tag set must be present. (default
1760        {tf.saved_model.SERVING} or {'serve'})
1761
1762    Returns:
1763      TFLiteConverter object.
1764
1765    Raises:
1766      ValueError: Invalid signature keys.
1767    """
1768    # pylint: disable=protected-access
1769    TFLiteConverterBase._set_original_model_type(
1770        conversion_metdata_fb.ModelType.TF_SAVED_MODEL)
1771    # pylint: enable=protected-access
1772    # When run without eager enabled, this will return the legacy
1773    # TFLiteConverter.
1774    if not context.executing_eagerly():
1775      signature_key = None
1776      if signature_keys:
1777        if len(signature_keys) != 1:
1778          raise ValueError("Only support a single signature key.")
1779        else:
1780          signature_key = signature_keys[0]
1781      logging.warning("Invoking the TF1 implementation of TFLiteConverter "
1782                      "because eager is disabled. Consider enabling eager.")
1783      return TFLiteConverter.from_saved_model(
1784          saved_model_dir, signature_key=signature_key, tag_set=tags)
1785
1786    # Ensures any graphs created in Eager mode are able to run. This is required
1787    # in order to create a tf.estimator.Exporter that exports a TFLite model.
1788    if tags is None:
1789      tags = set([_tag_constants.SERVING])
1790
1791    with context.eager_mode():
1792      saved_model = _load(saved_model_dir, tags)
1793    if not signature_keys:
1794      signature_keys = saved_model.signatures
1795
1796    if not signature_keys:
1797      raise ValueError("Only support at least one signature key.")
1798
1799    funcs = []
1800    for key in signature_keys:
1801      if key not in saved_model.signatures:
1802        raise ValueError("Invalid signature key '{}' found. Valid keys are "
1803                         "'{}'.".format(key, ",".join(saved_model.signatures)))
1804      funcs.append(saved_model.signatures[key])
1805
1806    saved_model_converter = TFLiteSavedModelConverterV2(saved_model_dir, tags,
1807                                                        signature_keys,
1808                                                        saved_model)
1809    if saved_model_converter.saved_model_dir:
1810      return saved_model_converter
1811
1812    return cls(funcs, saved_model)
1813
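  # A minimal sketch with explicit signature selection (the directory and key
  # names are assumptions; 'serving_default' is the usual default signature
  # key):
  #
  #   converter = tf.lite.TFLiteConverter.from_saved_model(
  #       '/tmp/my_model',
  #       signature_keys=['serving_default'],
  #       tags={tf.saved_model.SERVING})
  #   tflite_model = converter.convert()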
1814  @classmethod
1815  def from_keras_model(cls, model):
1816    """Creates a TFLiteConverter object from a Keras model.
1817
1818    Args:
1819      model: tf.Keras.Model
1820
1821    Returns:
1822      TFLiteConverter object.
1823    """
1824    # pylint: disable=protected-access
1825    TFLiteConverterBase._set_original_model_type(
1826        conversion_metdata_fb.ModelType.KERAS_MODEL)
1827    # pylint: enable=protected-access
1828    return TFLiteKerasModelConverterV2(model)
1829
1830  @classmethod
1831  def experimental_from_jax(cls, serving_funcs, inputs):
1832    # Experimental API, subject to changes.
1833    # TODO(b/197690428): Currently only support single function.
1834    """Creates a TFLiteConverter object from a Jax model with its inputs.
1835
1836    Args:
1837      serving_funcs: A list of Jax functions with all the weights applied
1838        already.
1839      inputs: A list of Jax input placeholder tuple lists, e.g.,
1840        jnp.zeros(INPUT_SHAPE). Each tuple list should correspond to a
1841        serving function.
1842
1843    Returns:
1844      TFLiteConverter object.
1845    """
1846    # pylint: disable=protected-access
1847    TFLiteConverterBase._set_original_model_type(
1848        conversion_metdata_fb.ModelType.JAX)
1849    # pylint: enable=protected-access
1850    return TFLiteJaxConverterV2(serving_funcs, inputs)
1851
1852  # pylint: disable=useless-super-delegation
1853  def convert(self):
1854    """Converts a TensorFlow GraphDef based on instance variables.
1855
1856    Returns:
1857      The converted data in serialized format.
1858
1859    Raises:
1860      ValueError:
1861        No concrete function is specified.
1862        Multiple concrete functions are specified.
1863        Input shape is not specified.
1864        Invalid quantization parameters.
1865    """
1866    return super(TFLiteConverterV2, self).convert()
1867
1868
1869class TFLiteConverterBaseV1(TFLiteConverterBase):
1870  """Converter subclass to share functionality between V1 converters."""
1871
1872  def __init__(self, experimental_debug_info_func):
1873    """Constructor for TFLiteConverter.
1874
1875    Args:
1876      experimental_debug_info_func: An experimental function to retrieve the
1877        graph debug info for a set of nodes from the `graph_def`.
1878    """
1879    super(TFLiteConverterBaseV1, self).__init__()
1880    self.inference_type = _dtypes.float32
1881    self.inference_input_type = None
1882    self.inference_output_type = None
1883    self.output_format = constants.TFLITE
1884    self.quantized_input_stats = {}
1885    self.default_ranges_stats = None
1886    self.drop_control_dependency = True
1887    self.reorder_across_fake_quant = False
1888    self.change_concat_input_ranges = False
1889    self.dump_graphviz_dir = None
1890    self.dump_graphviz_video = False
1891    self.conversion_summary_dir = None
1892    self._debug_info_func = experimental_debug_info_func
1893    self._metadata.environment.apiVersion = 1
1894
1895  def __setattr__(self, name, value):
1896    if name == "post_training_quantize":
1897      warnings.warn("Property %s is deprecated, "
1898                    "please use optimizations=[Optimize.DEFAULT]"
1899                    " instead." % name)
1900      if value:
1901        self.optimizations = [Optimize.DEFAULT]
1902      else:
1903        self.optimizations = []
1904      return
1905    if name == "target_ops":
1906      warnings.warn("Property %s is deprecated, please use "
1907                    "target_spec.supported_ops instead." % name)
1908      self.target_spec.supported_ops = value
1909      return
1910    object.__setattr__(self, name, value)
1911
1912  def __getattribute__(self, name):
1913    if name == "post_training_quantize":
1914      warnings.warn("Property %s is deprecated, "
1915                    "please use optimizations=[Optimize.DEFAULT]"
1916                    " instead." % name)
1917      return Optimize.DEFAULT in set(self.optimizations)
1918    if name == "target_ops":
1919      warnings.warn("Property %s is deprecated, please use "
1920                    "target_spec.supported_ops instead." % name)
1921      return self.target_spec.supported_ops
1922    return object.__getattribute__(self, name)
1923
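  # A short sketch of the deprecation shims above (V1 API; the SavedModel
  # directory is hypothetical): assigning the deprecated properties forwards to
  # the new fields and emits a warning.
  #
  #   converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model('/tmp/m')
  #   converter.post_training_quantize = True       # same effect as next line
  #   converter.optimizations = [tf.lite.Optimize.DEFAULT]
  #   converter.target_ops = {tf.lite.OpsSet.TFLITE_BUILTINS}  # forwarded to
  #                                                 # target_spec.supported_ops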
1924  def _validate_quantized_input_stats(self, converter_kwargs, quant_mode):
1925    """Ensure the `quantized_input_stats` flag is provided if required."""
1926
1927    quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})
1928
1929    requires_quantized_input_stats = (
1930        (converter_kwargs["inference_type"] in quantized_types or
1931         converter_kwargs["inference_input_type"] in quantized_types) and
1932        not quant_mode.is_post_training_integer_quantization())
1933
1934    if (requires_quantized_input_stats and
1935        not converter_kwargs["quantized_input_stats"]):
1936      raise ValueError(
1937          "The `quantized_input_stats` flag must be defined when either "
1938          "`inference_type` flag or `inference_input_type` flag is set to "
1939          "tf.int8 or tf.uint8. Currently, `inference_type={}` and "
1940          "`inference_input_type={}`.".format(
1941              _get_tf_type_name(converter_kwargs["inference_type"]),
1942              _get_tf_type_name(converter_kwargs["inference_input_type"])))
1943
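  # A minimal sketch of when `quantized_input_stats` is required (V1 API; the
  # file, array names, and stats values are assumptions): with uint8/int8
  # inference and no post-training integer quantization, each input needs
  # (mean, std_dev) stats keyed by input tensor name.
  #
  #   converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
  #       'quant_graph.pb', input_arrays=['input'], output_arrays=['output'])
  #   converter.inference_type = tf.uint8
  #   converter.quantized_input_stats = {'input': (127.5, 127.5)}  # illustrative
  #   tflite_model = converter.convert()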
1944  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
1945  def _validate_inputs(self, input_tensors, quantized_input_stats):
1946    """Validate input parameters.
1947
1948    Args:
1949      input_tensors: List of input tensors.
1950      quantized_input_stats: Map of input tensor names to a tuple of floats
1951        representing the mean and standard deviation of the training data.
1952
1953    Raises:
1954      ValueError:
1955        Input shape is not specified.
1956        Quantization input stats is required but not provided.
1957    """
1958
1959    if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()):
1960      # Checks dimensions in input tensor.
1961      for tensor in input_tensors:
1962        shape = tensor.shape
1963        if not shape:
1964          raise ValueError("Provide an input shape for input array "
1965                           "'{0}'.".format(_get_tensor_name(tensor)))
1966        # Note that shape_list might be empty for scalar shapes.
1967        shape_list = shape.as_list()
1968        if None in shape_list[1:]:
1969          raise ValueError(
1970              "None is only supported in the 1st dimension. Tensor '{0}' has "
1971              "invalid shape '{1}'.".format(
1972                  _get_tensor_name(tensor), shape_list))
1973        elif shape_list and shape_list[0] is None:
1974          self._set_batch_size(batch_size=1)
1975
1976    # Get quantization stats. Ensures there is one stat per name if the stats
1977    # are specified.
1978    if quantized_input_stats:
1979      self._quantized_stats = []
1980      invalid_stats = []
1981      for name in self.get_input_arrays():
1982        if name in quantized_input_stats:
1983          self._quantized_stats.append(quantized_input_stats[name])
1984        else:
1985          invalid_stats.append(name)
1986
1987      if invalid_stats:
1988        raise ValueError("Quantization input stats are not available for input "
1989                         "tensors '{0}'.".format(",".join(invalid_stats)))
1990    else:
1991      self._quantized_stats = None
1992
1993  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
1994  def _optimize_tf_model(self, graph_def, input_tensors, output_tensors,
1995                         quant_mode):
1996    """Run a Grappler pass to optimize the TensorFlow graph.
1997
1998    Args:
1999      graph_def: Frozen GraphDef to be optimized.
2000      input_tensors: List of input tensors.
2001      output_tensors: List of output tensors.
2002      quant_mode: the quantization mode.
2003
2004    Returns:
2005      The optimized TensorFlow graph.
2006    """
2007    # Disable grappler constant folding if there are training quant ops.
2008    if self.saved_model_dir or quant_mode.is_quantization_aware_trained_model():
2009      return graph_def
2010
2011    try:
2012      # TODO(b/150163103): Merge `disabling lower using switch merge' calls.
2013      # Grappler will also try to lower while loop into switch merge
2014      # representation which is undesired for Ophints, so we simply remove
2015      # those attributes to prevent Grappler from doing so.
2016      graph = _convert_to_constants.disable_lower_using_switch_merge(graph_def)
2017      # Run function inlining optimization to ensure any models generated
2018      # through the from_frozen_graph path have been inlined.
2019      optimized_graph = _run_graph_optimizations(
2020          graph,
2021          input_tensors,
2022          output_tensors,
2023          config=self._grappler_config(["function"]))
2024      return optimized_graph
2025    except Exception:  # pylint: disable=broad-except
2026      return graph_def
2027
2028  def convert(self):
2029    """Converts a TensorFlow GraphDef based on instance variables.
2030
2031    Returns:
2032      The converted data in serialized format. Either a TFLite Flatbuffer or a
2033      Graphviz graph depending on value in `output_format`.
2034
2035    Raises:
2036      ValueError:
2037        Input shape is not specified.
2038        None value for dimension in input_tensor.
2039    """
2040    self._validate_inputs(self._input_tensors, self.quantized_input_stats)
2041
2042    quant_mode = QuantizationMode(
2043        self.optimizations, self.target_spec, self.representative_dataset,
2044        self._graph_def, self._experimental_disable_per_channel,
2045        self.experimental_new_dynamic_range_quantizer,
2046        self._experimental_low_bit_qat,
2047        self._experimental_full_integer_quantization_bias_type)
2048
2049    optimized_graph = self._optimize_tf_model(self._graph_def,
2050                                              self._input_tensors,
2051                                              self._output_tensors, quant_mode)
2052
2053    self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
2054
2055    converter_kwargs = self._get_base_converter_args()
2056    converter_kwargs.update(
2057        quant_mode.converter_flags(self.inference_type,
2058                                   self.inference_input_type))
2059    converter_kwargs.update({
2060        "output_format": self.output_format,
2061        "quantized_input_stats": self._quantized_stats,
2062        "default_ranges_stats": self.default_ranges_stats,
2063        "drop_control_dependency": self.drop_control_dependency,
2064        "reorder_across_fake_quant": self.reorder_across_fake_quant,
2065        "change_concat_input_ranges": self.change_concat_input_ranges,
2066        "dump_graphviz_dir": self.dump_graphviz_dir,
2067        "dump_graphviz_video": self.dump_graphviz_video,
2068        "conversion_summary_dir": self.conversion_summary_dir,
2069    })
2070
2071    self._validate_quantized_input_stats(converter_kwargs, quant_mode)
2072    if not self.experimental_new_converter:
2073      logging.warning(
2074          "Please consider switching to the new converter by setting "
2075          "experimental_new_converter=True. "
2076          "The old converter is deprecated.")
2077    else:
2078      logging.info("Using experimental converter: If you encountered a problem "
2079                   "please file a bug. You can opt-out "
2080                   "by setting experimental_new_converter=False")
2081    # Converts model.
2082    if self._has_valid_tensors():
2083      result = _convert_graphdef(
2084          input_data=optimized_graph,
2085          input_tensors=self._input_tensors,
2086          output_tensors=self._output_tensors,
2087          **converter_kwargs)
2088    else:
2089      result = _convert_graphdef_with_arrays(
2090          input_data=optimized_graph,
2091          input_arrays_with_shape=self._input_arrays_with_shape,
2092          output_arrays=self._output_arrays,
2093          control_output_arrays=self._control_output_arrays,
2094          **converter_kwargs)
2095
2096    return self._optimize_tflite_model(
2097        result, quant_mode, quant_io=self.experimental_new_quantizer)
2098
2099  def get_input_arrays(self):
2100    """Returns a list of the names of the input tensors.
2101
2102    Returns:
2103      List of strings.
2104    """
2105    if self._has_valid_tensors():
2106      return [_get_tensor_name(tensor) for tensor in self._input_tensors]
2107    else:
2108      return [name for name, _ in self._input_arrays_with_shape]
2109
2110  def _has_valid_tensors(self):
2111    """Checks if the input and output tensors have been initialized.
2112
2113    Returns:
2114      Bool.
2115    """
2116    return self._input_tensors is not None and self._output_tensors
2117
2118  def _set_batch_size(self, batch_size):
2119    """Sets the first dimension of the input tensor to `batch_size`.
2120
2121    Args:
2122      batch_size: Batch size for the model. Replaces the first dimension of an
2123        input size array if undefined. (default 1)
2124
2125    Raises:
2126      ValueError: input_tensor is not defined.
2127    """
2128    if not self._has_valid_tensors():
2129      raise ValueError("The batch size cannot be set for this model. Please "
2130                       "use input_shapes parameter.")
2131
2132    for tensor in self._input_tensors:
2133      shape = tensor.shape.as_list()
2134      if shape[0] is None:
2135        shape[0] = batch_size
2136        tensor.set_shape(shape)
2137
2138  def _is_unknown_shapes_allowed(self):
2139    # Ophint Converted nodes will need the shapes to be known.
2140    if _is_ophint_converted(self._graph_def):
2141      return False
2142
2143    if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed():
2144      return False
2145
2146    # `conversion_summary_dir` calls the old converter. Unknown shapes are only
2147    # supported by the MLIR converter.
2148    if self.conversion_summary_dir:
2149      logging.warning(
2150          "`conversion_summary_dir` does not work with unknown shapes. "
2151          "Graphs with unknown shapes might be different than when this flag "
2152          "is disabled.")
2153      return False
2154    return True
2155
2156  def _save_conversion_params_metric(self):
2157    self._collected_converter_params.update({
2158        "output_format": self.output_format,
2159        "default_ranges_stats": self.default_ranges_stats,
2160        "drop_control_dependency": self.drop_control_dependency,
2161        "reorder_across_fake_quant": self.reorder_across_fake_quant,
2162        "change_concat_input_ranges": self.change_concat_input_ranges,
2163        "dump_graphviz_dir": self.dump_graphviz_dir,
2164        "dump_graphviz_video": self.dump_graphviz_video,
2165        "conversion_summary_dir": self.conversion_summary_dir,
2166    })
2167    super(TFLiteConverterBaseV1,
2168          self)._save_conversion_params_metric(self._graph_def,
2169                                               self.inference_type,
2170                                               self.inference_input_type)
2171
2172
2173class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
2174  """Converts the given SavedModel into TensorFlow Lite model.
2175
2176  Attributes:
2177      saved_model_dir: Directory of the SavedModel.
2178  """
2179
2180  def __init__(self,
2181               saved_model_dir,
2182               saved_model_tags,
2183               saved_model_exported_names,
2184               experimental_debug_info_func=None):
2185    """Constructor for TFLiteConverter.
2186
2187    Args:
2188      saved_model_dir: Directory of the SavedModel.
2189      saved_model_tags: Set of tags identifying the MetaGraphDef within the
2190        SavedModel to analyze. All tags in the tag set must be present. (default
2191        {tf.saved_model.SERVING}).
2192      saved_model_exported_names: Names to be exported when the SavedModel
2193        import path is enabled.
2194      experimental_debug_info_func: An experimental function to retrieve the
2195        graph debug info for a set of nodes from the `graph_def`.
2196
2197    Raises:
2198      ValueError: Invalid arguments.
2199    """
2200    super(TFLiteSavedModelConverter,
2201          self).__init__(experimental_debug_info_func)
2202    self.saved_model_dir = saved_model_dir
2203    self._saved_model_tags = saved_model_tags
2204    self._saved_model_exported_names = saved_model_exported_names
2205
2208    if len(self._saved_model_exported_names) != 1:
2209      raise ValueError("Only support a single signature key.")
2210
2211    signature_key = self._saved_model_exported_names[0]
2212
2213    result = _freeze_saved_model(self.saved_model_dir, None, None, None,
2214                                 self._saved_model_tags, signature_key)
2215    self._graph_def = result[0]
2216    self._input_tensors = result[1]
2217    self._output_tensors = result[2]
2218    self._parse_saved_model_args()
2219
2220  @_export_metrics
2221  def convert(self):
2222    """Converts a TensorFlow GraphDef based on instance variables.
2223
2224    Returns:
2225      The converted data in serialized format. Either a TFLite Flatbuffer or a
2226      Graphviz graph depending on value in `output_format`.
2227
2228    Raises:
2229      ValueError:
2230        Input shape is not specified.
2231        None value for dimension in input_tensor.
2232    """
2233    return super(TFLiteSavedModelConverter, self).convert()
2234
2235
2236class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
2237  """Converts the given SavedModel into TensorFlow Lite model."""
2238
2239  def __init__(self,
2240               model_file,
2241               input_arrays=None,
2242               input_shapes=None,
2243               output_arrays=None,
2244               custom_objects=None):
2245    """Constructor for TFLiteConverter.
2246
2247    Args:
2248      model_file: Full filepath of HDF5 file containing the tf.keras model.
2249      input_arrays: List of input tensors to freeze graph with. Uses input
2250        arrays from SignatureDef when none are provided. (default None)
2251      input_shapes: Dict of strings representing input tensor names to list of
2252        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
2253        Automatically determined when input shapes is None (e.g., {"foo" :
2254          None}). (default None)
2255      output_arrays: List of output tensors to freeze graph with. Uses output
2256        arrays from SignatureDef when none are provided. (default None)
2257      custom_objects: Dict mapping names (strings) to custom classes or
2258        functions to be considered during model deserialization. (default None)
2259
2260    Raises:
2261      ValueError: Invalid arguments.
2262    """
2263    super(TFLiteKerasModelConverter,
2264          self).__init__(experimental_debug_info_func=None)
2265    # Handles Keras when Eager mode is enabled.
2266    if context.executing_eagerly():
2267      if input_arrays or output_arrays:
2268        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
2269                         "with Eager mode. If your model requires any of these "
2270                         "parameters, please use disable_eager_execution().")
2271
2272      keras_model = keras_deps.get_load_model_function()(model_file,
2273                                                         custom_objects)
2274      function = _trace_model_call(keras_model)
2275      concrete_func = function.get_concrete_function()
2276
2277      frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
2278          concrete_func, lower_control_flow=False)
2279      _set_tensor_shapes(frozen_func.inputs, input_shapes)
2280      self._keras_model = keras_model
2281      self._graph_def = frozen_func.graph.as_graph_def()
2282      self._input_tensors = frozen_func.inputs
2283      self._output_tensors = frozen_func.outputs
2284      self._debug_info_func = _build_debug_info_func(frozen_func.graph)
2285      return
2286
2287    # Handles Keras when Eager mode is disabled.
2288    keras_deps.get_clear_session_function()()
2289    keras_model = keras_deps.get_load_model_function()(model_file,
2290                                                       custom_objects)
2291    sess = keras_deps.get_get_session_function()()
2292
2293    # Get input and output tensors.
2294    if input_arrays:
2295      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
2296    else:
2297      input_tensors = keras_model.inputs
2298
2299    if output_arrays:
2300      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
2301    else:
2302      output_tensors = keras_model.outputs
2303    _set_tensor_shapes(input_tensors, input_shapes)
2304
2305    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
2306    self._keras_model = keras_model
2307    self._graph_def = graph_def
2308    self._input_tensors = input_tensors
2309    self._output_tensors = output_tensors
2310    self._debug_info_func = _build_debug_info_func(sess.graph)
2311
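  # A minimal usage sketch: this converter backs the public V1
  # `from_keras_model_file` API (the HDF5 path and input name below are
  # assumptions).
  #
  #   converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
  #       'model.h5', input_shapes={'input_1': [1, 224, 224, 3]})
  #   tflite_model = converter.convert()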
2312  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
2313  def _freeze_keras_model(self, output_dir):
2314    """Save Keras model to Saved Model format.
2315
2316    Args:
2317      output_dir: The output directory to save the SavedModel.
2318    """
2319    try:
2320      self._keras_model.save(output_dir, save_format="tf")
2321    except Exception:  # pylint: disable=broad-except
2322      # If saving the given Keras model to the SavedModel format fails, fall
2323      # back to the original Keras model conversion pipeline.
2324      return None
2325    tag_set = set([_tag_constants.SERVING])
2326    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
2327    graph_def, input_tensors, output_tensors, sess_graph = _freeze_saved_model(
2328        output_dir, None, None, None, tag_set, signature_key)
2329
2330    self.saved_model_dir = output_dir
2331    self._saved_model_tags = tag_set
2332    self._saved_model_exported_names = [signature_key]
2333    self._parse_saved_model_args()
2334    if self.saved_model_dir:
2335      self._graph_def = graph_def
2336      self._input_tensors = input_tensors
2337      self._output_tensors = output_tensors
2338      self._debug_info_func = _build_debug_info_func(sess_graph)
2339
2340  def _convert_as_saved_model(self):
2341    """Converts a Keras model as a saved model.
2342
2343    Returns:
2344      The converted data in serialized format.
2345    """
2346    temp_dir = tempfile.mkdtemp()
2347    try:
2348      self._freeze_keras_model(temp_dir)
2349      if self.saved_model_dir:
2350        return super(TFLiteKerasModelConverter, self).convert()
2351    finally:
2352      shutil.rmtree(temp_dir, True)
2353
2354  @_export_metrics
2355  def convert(self):
2356    """Converts a Keras model based on instance variables.
2357
2358    Returns:
2359      The converted data in serialized format. Either a TFLite Flatbuffer or a
2360      Graphviz graph depending on value in `output_format`.
2361
2362    Raises:
2363      ValueError:
2364        Input shape is not specified.
2365        None value for dimension in input_tensor.
2366    """
2367    saved_model_convert_result = self._convert_as_saved_model()
2368    if saved_model_convert_result:
2369      return saved_model_convert_result
2370
2371    return super(TFLiteKerasModelConverter, self).convert()
2372
2373
2374class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
2375  """Converts the given frozen graph def into TensorFlow Lite model."""
2376
2377  def __init__(self,
2378               graph_def,
2379               input_tensors,
2380               output_tensors,
2381               input_arrays_with_shape=None,
2382               output_arrays=None,
2383               experimental_debug_info_func=None):
2384    """Constructor for TFLiteConverter.
2385
2386    Args:
2387      graph_def: Frozen TensorFlow GraphDef.
2388      input_tensors: List of input tensors. Type and shape are computed using
2389        `foo.shape` and `foo.dtype`.
2390      output_tensors: List of output tensors (only .name is used from this).
2391      input_arrays_with_shape: Tuple of strings representing input tensor names
2392        and list of integers representing input shapes
2393        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
2394          into TensorFlow and when `input_tensors` and `output_tensors` are
2395          None. (default None)
2396      output_arrays: List of output tensors to freeze graph with. Use only when
2397        graph cannot be loaded into TensorFlow and when `input_tensors` and
2398        `output_tensors` are None. (default None)
2399      experimental_debug_info_func: An experimental function to retrieve the
2400        graph debug info for a set of nodes from the `graph_def`.
2401
2402    Raises:
2403      ValueError: Invalid arguments.
2404    """
2405    super(TFLiteFrozenGraphConverter,
2406          self).__init__(experimental_debug_info_func)
2407    self._graph_def = graph_def
2408    self._input_tensors = input_tensors
2409    self._output_tensors = output_tensors
2410    self._control_output_arrays = None
2411
2412    # Attributes are used by models that cannot be loaded into TensorFlow.
2413    if not self._has_valid_tensors():
2414      self._input_arrays_with_shape = input_arrays_with_shape
2415      self._output_arrays = output_arrays
2416
2417    if input_tensors is not None and input_arrays_with_shape is not None:
2418      logging.warning("input_arrays_with_shape will be ignored when both the "
2419                      "given input_tensors and input_arrays_with_shape are not "
2420                      "None.")
2421
2422    if output_tensors is not None and output_arrays is not None:
2423      logging.warning("output_arrays will be ignored when both the given "
2424                      "output_tensors and output_arrays are not None.")
2425
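  # A minimal sketch of the name-based fallback (all names and shapes below are
  # assumptions): when the GraphDef cannot be loaded into TensorFlow, input and
  # output info may be passed by name instead of as tensors.
  #
  #   converter = TFLiteFrozenGraphConverter(
  #       graph_def, input_tensors=None, output_tensors=None,
  #       input_arrays_with_shape=[('input', [1, 16, 16, 3])],
  #       output_arrays=['output'])
  #   tflite_model = converter.convert()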
2426  @_export_metrics
2427  def convert(self):
2428    """Converts a TensorFlow GraphDef based on instance variables.
2429
2430    Returns:
2431      The converted data in serialized format. Either a TFLite Flatbuffer or a
2432      Graphviz graph depending on value in `output_format`.
2433
2434    Raises:
2435      ValueError:
2436        Input shape is not specified.
2437        None value for dimension in input_tensor.
2438    """
2439    if not self._has_valid_tensors():
2440      if not self._input_arrays_with_shape or not (self._output_arrays or
2441                                                   self._control_output_arrays):
2442        raise ValueError(
2443            "If input_tensors and output_tensors are None, both "
2444            "input_arrays_with_shape and output_arrays|control_output_arrays "
2445            "must be defined.")
2446    return super(TFLiteFrozenGraphConverter, self).convert()
2447
2448
2449@_tf_export(v1=["lite.TFLiteConverter"])
2450class TFLiteConverter(TFLiteFrozenGraphConverter):
2451  """Convert a TensorFlow model into `output_format`.
2452
2453  This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras
2454  model into either a TFLite FlatBuffer or graph visualization.
2455
2456  Attributes:
2457    optimizations: Experimental flag, subject to change. Set of optimizations to
2458      apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
2459      set of values of type `tf.lite.Optimize`)
2460    representative_dataset: A generator function used for integer quantization
2461      where each generated sample has the same order, type and shape as the
2462      inputs to the model. Usually, this is a small subset of a few hundred
2463      samples randomly chosen, in no particular order, from the training or
2464      evaluation dataset. This is an optional attribute, but required for full
2465      integer quantization, i.e, if `tf.int8` is the only supported type in
2466      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
2467      (default None)
2468    target_spec: Experimental flag, subject to change. Specifications of target
2469      device, including supported ops set, supported types and a set of user's
2470      defined TensorFlow operators required in the TensorFlow Lite runtime.
2471      Refer to `tf.lite.TargetSpec`.
2472    inference_type: Data type of numeric arrays, excluding the input layer.
2473      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
2474    inference_input_type: Data type of the numeric arrays in the input layer. If
2475      `inference_input_type` is in {tf.int8, tf.uint8}, then
2476      `quantized_input_stats` must be provided. (default is the value assigned
2477      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
2478    inference_output_type: Data type of the numeric arrays in the output layer.
2479      (default is the value assigned to `inference_type`, must be in
2480      {tf.float32, tf.int8, tf.uint8})
2481    quantized_input_stats: Map of input tensor names to a tuple of floats
2482      representing the mean and standard deviation of the training data.
2483      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
2484        or tf.uint8. (default None)
2485    default_ranges_stats: Tuple of integers (min, max) representing range values
2486      for all numeric arrays without a specified range. Intended for
2487      experimenting with quantization via "dummy quantization". (default None)
2488    allow_custom_ops: Boolean indicating whether to allow custom operations.
2489      When False any unknown operation is an error. When True, custom ops are
2490      created for any op that is unknown. The developer will need to provide
2491      these to the TensorFlow Lite runtime with a custom resolver. (default
2492      False)
2493    drop_control_dependency: Boolean indicating whether to drop control
2494      dependencies silently. This is due to TFLite not supporting control
2495      dependencies. (default True)
2496    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
2497      nodes in unexpected locations. Used when the location of the FakeQuant
2498      nodes is preventing graph transformations necessary to convert the graph.
2499      Results in a graph that differs from the quantized training graph,
2500      potentially causing differing arithmetic behavior. (default False)
2501    change_concat_input_ranges: Boolean to change behavior of min/max ranges
2502      for inputs and outputs of the concat operator for quantized models. When
2503      True, the concat input and output ranges are changed to overlap. (default False)
2504    output_format: Output file format. (default
2505      tf.compat.v1.lite.constants.TFLITE, must be in
2506      {tf.compat.v1.lite.constants.TFLITE,
2507      tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
2508    dump_graphviz_dir: Full filepath of the folder in which to dump GraphViz
2509      .dot files of the graph at various stages of processing. Preferred over
2510      `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` because the
2511      converted model is still produced in the requested format. (default None)
2512    dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
2513      files after every graph transformation. Requires the `dump_graphviz_dir`
2514      flag to be specified. (default False)
2515    conversion_summary_dir: Full path of the directory to store conversion logs.
2516      (default None)
2517    exclude_conversion_metadata: Whether to omit the conversion metadata from
2518      the converted model. (default False)
2519    target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
2520    post_training_quantize: Deprecated. Please use `optimizations` instead and
2521      set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
2522    experimental_new_converter: Experimental flag, subject to change. Enables
2523      MLIR-based conversion. (default True)
2524    experimental_new_quantizer: Experimental flag, subject to change. Enables
2525      MLIR-based quantization conversion instead of Flatbuffer-based conversion.
2526      (default True)
2527
2528  Example usage:
2529
2530    ```python
2531    # Converting a GraphDef from session.
2532    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
2533      sess, in_tensors, out_tensors)
2534    tflite_model = converter.convert()
2535    open("converted_model.tflite", "wb").write(tflite_model)
2536
2537    # Converting a GraphDef from file.
2538    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
2539      graph_def_file, input_arrays, output_arrays)
2540    tflite_model = converter.convert()
2541    open("converted_model.tflite", "wb").write(tflite_model)
2542
2543    # Converting a SavedModel.
2544    converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
2545        saved_model_dir)
2546    tflite_model = converter.convert()
2547    open("converted_model.tflite", "wb").write(tflite_model)
2548
2549    # Converting a tf.keras model.
2550    converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
2551        keras_model_file)
2552    tflite_model = converter.convert()
2553    open("converted_model.tflite", "wb").write(tflite_model)
2554    ```
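
  Quantization example (an illustrative sketch of the quantization-related
  attributes described above; the tensor name "input" and the statistics are
  placeholders, and a real model generally needs ranges produced by
  quantization-aware training or supplied via `default_ranges_stats`):

    ```python
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, input_arrays, output_arrays)
    converter.inference_type = tf.uint8
    # Mean/stddev of the training data for each quantized input tensor.
    converter.quantized_input_stats = {"input": (127.5, 127.5)}
    # "Dummy quantization" range for arrays without a specified range.
    converter.default_ranges_stats = (0, 6)
    tflite_model = converter.convert()
    ```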
2555  """
2556
2557  # pylint: disable=useless-super-delegation
2558  def __init__(self,
2559               graph_def,
2560               input_tensors,
2561               output_tensors,
2562               input_arrays_with_shape=None,
2563               output_arrays=None,
2564               experimental_debug_info_func=None):
2565    """Constructor for TFLiteConverter.
2566
2567    Args:
2568      graph_def: Frozen TensorFlow GraphDef.
2569      input_tensors: List of input tensors. Type and shape are computed using
2570        `foo.shape` and `foo.dtype`.
2571      output_tensors: List of output tensors (only .name is used from this).
2572      input_arrays_with_shape: List of tuples pairing each input tensor name
2573        (string) with its input shape (list of integers)
2574        (e.g., [("foo", [1, 16, 16, 3])]). Use only when the graph cannot be
2575        loaded into TensorFlow and when `input_tensors` and `output_tensors`
2576        are None. (default None)
2577      output_arrays: List of output tensors to freeze graph with. Use only when
2578        graph cannot be loaded into TensorFlow and when `input_tensors` and
2579        `output_tensors` are None. (default None)
2580      experimental_debug_info_func: An experimental function to retrieve the
2581        graph debug info for a set of nodes from the `graph_def`.
2582
2583    Raises:
2584      ValueError: Invalid arguments.
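
    Example usage (an illustrative sketch; `graph_def`, the tensor names and
    the shape below are placeholders for a graph that cannot be loaded into
    TensorFlow):

      ```python
      converter = tf.compat.v1.lite.TFLiteConverter(
          graph_def,
          input_tensors=None,
          output_tensors=None,
          input_arrays_with_shape=[("foo", [1, 16, 16, 3])],
          output_arrays=["bar"])
      tflite_model = converter.convert()
      ```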
2585    """
2586    super(TFLiteConverter,
2587          self).__init__(graph_def, input_tensors, output_tensors,
2588                         input_arrays_with_shape, output_arrays,
2589                         experimental_debug_info_func)
2590
2591  @classmethod
2592  def from_session(cls, sess, input_tensors, output_tensors):
2593    """Creates a TFLiteConverter class from a TensorFlow Session.
2594
2595    Args:
2596      sess: TensorFlow Session.
2597      input_tensors: List of input tensors. Type and shape are computed using
2598        `foo.shape` and `foo.dtype`.
2599      output_tensors: List of output tensors (only .name is used from this).
2600
2601    Returns:
2602      TFLiteConverter class.
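
    Example usage (a minimal self-contained sketch; the toy graph built below
    is only for illustration):

      ```python
      import tensorflow as tf

      tf.compat.v1.disable_eager_execution()  # v1-style graph building
      with tf.compat.v1.Session() as sess:
        in_tensor = tf.compat.v1.placeholder(
            tf.float32, shape=[1, 4], name="input")
        out_tensor = in_tensor + tf.constant(1.0)
        converter = tf.compat.v1.lite.TFLiteConverter.from_session(
            sess, [in_tensor], [out_tensor])
        tflite_model = converter.convert()
      ```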
2603    """
2604    # pylint: disable=protected-access
2605    TFLiteConverterBase._set_original_model_type(
2606        conversion_metdata_fb.ModelType.TF_SESSION)
2607    # pylint: enable=protected-access
2608    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
2609    return cls(
2610        graph_def,
2611        input_tensors,
2612        output_tensors,
2613        experimental_debug_info_func=_build_debug_info_func(sess.graph))
2614
2615  @classmethod
2616  def from_frozen_graph(cls,
2617                        graph_def_file,
2618                        input_arrays,
2619                        output_arrays,
2620                        input_shapes=None):
2621    """Creates a TFLiteConverter class from a file containing a frozen GraphDef.
2622
2623    Args:
2624      graph_def_file: Full filepath of file containing frozen GraphDef.
2625      input_arrays: List of input tensors to freeze graph with.
2626      output_arrays: List of output tensors to freeze graph with.
2627      input_shapes: Dict mapping input tensor names (strings) to input shapes
2628        (lists of integers) (e.g., {"foo" : [1, 16, 16, 3]}).
2629        Automatically determined when `input_shapes` is None (e.g., {"foo" :
2630        None}). (default None)
2631
2632    Returns:
2633      TFLiteConverter class.
2634
2635    Raises:
2636      IOError:
2637        File not found.
2638        Unable to parse input file.
2639      ValueError:
2640        The graph is not frozen.
2641        input_arrays or output_arrays contains an invalid tensor name.
2642        input_shapes is not correctly defined when required.
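
    Example usage (an illustrative sketch; the file path, tensor names and
    shape below are placeholders, not values required by this API):

      ```python
      converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
          "frozen_graph.pb",
          input_arrays=["input"],
          output_arrays=["output"],
          input_shapes={"input": [1, 224, 224, 3]})
      tflite_model = converter.convert()
      ```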
2643    """
2644    # pylint: disable=protected-access
2645    TFLiteConverterBase._set_original_model_type(
2646        conversion_metdata_fb.ModelType.TF_GRAPH_DEF)
2647    # pylint: enable=protected-access
2648    with _ops.Graph().as_default():
2649      with _session.Session() as sess:
2650        # Read GraphDef from file.
2651        if not gfile.Exists(graph_def_file):
2652          raise IOError("File '{0}' does not exist.".format(graph_def_file))
2653        with gfile.GFile(graph_def_file, "rb") as f:
2654          file_content = f.read()
2655
2656        try:
2657          graph_def = _graph_pb2.GraphDef()
2658          graph_def.ParseFromString(file_content)
2659        except (_text_format.ParseError, DecodeError):
2660          try:
2661            print("Ignore 'tcmalloc: large alloc' warnings.")
2662
2663            if not isinstance(file_content, str):
2664              file_content = file_content.decode("utf-8")
2665            graph_def = _graph_pb2.GraphDef()
2666            _text_format.Merge(file_content, graph_def)
2667          except (_text_format.ParseError, DecodeError):
2668            raise IOError(
2669                "Unable to parse input file '{}'.".format(graph_def_file))
2670
2671        # Handles models with custom TFLite ops that cannot be resolved in
2672        # TensorFlow.
2673        load_model_in_session = True
2674        try:
2675          _import_graph_def(graph_def, name="")
2676        except _NotFoundError:
2677          load_model_in_session = False
2678
2679        if load_model_in_session:
2680          # Check if graph is frozen.
2681          if not _is_frozen_graph(sess):
2682            raise ValueError("Please freeze the graph using freeze_graph.py.")
2683
2684          # Get input and output tensors.
2685          input_tensors = _get_tensors_from_tensor_names(
2686              sess.graph, input_arrays)
2687          output_tensors = _get_tensors_from_tensor_names(
2688              sess.graph, output_arrays)
2689          _set_tensor_shapes(input_tensors, input_shapes)
2690
2691          return cls(sess.graph_def, input_tensors, output_tensors)
2692        else:
2693          if not input_shapes:
2694            raise ValueError("input_shapes must be defined for this model.")
2695          if set(input_arrays) != set(input_shapes.keys()):
2696            raise ValueError("input_shapes must contain a value for each item "
2697                             "in input_arrays.")
2698
2699          input_arrays_with_shape = [
2700              (name, input_shapes[name]) for name in input_arrays
2701          ]
2702          return cls(
2703              graph_def,
2704              input_tensors=None,
2705              output_tensors=None,
2706              input_arrays_with_shape=input_arrays_with_shape,
2707              output_arrays=output_arrays)
2708
2709  @classmethod
2710  def from_saved_model(cls,
2711                       saved_model_dir,
2712                       input_arrays=None,
2713                       input_shapes=None,
2714                       output_arrays=None,
2715                       tag_set=None,
2716                       signature_key=None):
2717    """Creates a TFLiteConverter class from a SavedModel.
2718
2719    Args:
2720      saved_model_dir: SavedModel directory to convert.
2721      input_arrays: List of input tensors to freeze graph with. Uses input
2722        arrays from SignatureDef when none are provided. (default None)
2723      input_shapes: Dict mapping input tensor names (strings) to input shapes
2724        (lists of integers) (e.g., {"foo" : [1, 16, 16, 3]}).
2725        Automatically determined when `input_shapes` is None (e.g., {"foo" :
2726        None}). (default None)
2727      output_arrays: List of output tensors to freeze graph with. Uses output
2728        arrays from SignatureDef when none are provided. (default None)
2729      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
2730        analyze. All tags in the tag set must be present. (default
2731        {tf.saved_model.SERVING})
2732      signature_key: Key identifying SignatureDef containing inputs and outputs.
2733        (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
2734
2735    Returns:
2736      TFLiteConverter class.
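
    Example usage (an illustrative sketch; the SavedModel directory is a
    placeholder, and the tag set and signature key shown simply spell out the
    defaults explicitly):

      ```python
      converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
          "/tmp/saved_model",
          tag_set={tf.saved_model.SERVING},
          signature_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
      tflite_model = converter.convert()
      ```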
2737    """
2738    # pylint: disable=protected-access
2739    TFLiteConverterBase._set_original_model_type(
2740        conversion_metdata_fb.ModelType.TF_SAVED_MODEL)
2741    # pylint: enable=protected-access
2742    if tag_set is None:
2743      tag_set = set([_tag_constants.SERVING])
2744    if signature_key is None:
2745      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
2746
2747    saved_model_converter = TFLiteSavedModelConverter(saved_model_dir, tag_set,
2748                                                      [signature_key])
2749    if saved_model_converter.saved_model_dir:
2750      return saved_model_converter
2751
2752    result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
2753                                 output_arrays, tag_set, signature_key)
2754
2755    return cls(
2756        graph_def=result[0],
2757        input_tensors=result[1],
2758        output_tensors=result[2],
2759        experimental_debug_info_func=_build_debug_info_func(result[3]))
2760
2761  @classmethod
2762  def from_keras_model_file(cls,
2763                            model_file,
2764                            input_arrays=None,
2765                            input_shapes=None,
2766                            output_arrays=None,
2767                            custom_objects=None):
2768    """Creates a TFLiteConverter class from a tf.keras model file.
2769
2770    Args:
2771      model_file: Full filepath of HDF5 file containing the tf.keras model.
2772      input_arrays: List of input tensors to freeze graph with. Uses input
2773        arrays from the tf.keras model when none are provided. (default None)
2774      input_shapes: Dict mapping input tensor names (strings) to input shapes
2775        (lists of integers) (e.g., {"foo" : [1, 16, 16, 3]}).
2776        Automatically determined when `input_shapes` is None (e.g., {"foo" :
2777        None}). (default None)
2778      output_arrays: List of output tensors to freeze graph with. Uses output
2779        arrays from the tf.keras model when none are provided. (default None)
2780      custom_objects: Dict mapping names (strings) to custom classes or
2781        functions to be considered during model deserialization. (default None)
2782
2783    Returns:
2784      TFLiteConverter class.
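
    Example usage (an illustrative sketch; the HDF5 path is a placeholder and
    `MyLayer` stands for a hypothetical custom layer class used by the model):

      ```python
      converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
          "keras_model.h5",
          custom_objects={"MyLayer": MyLayer})
      tflite_model = converter.convert()
      ```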
2785    """
2786    # pylint: disable=protected-access
2787    TFLiteConverterBase._set_original_model_type(
2788        conversion_metdata_fb.ModelType.KERAS_MODEL)
2789    # pylint: enable=protected-access
2790    return TFLiteKerasModelConverter(model_file, input_arrays, input_shapes,
2791                                     output_arrays, custom_objects)
2792
2793  # pylint: disable=useless-super-delegation
2794  def convert(self):
2795    """Converts a TensorFlow GraphDef based on instance variables.
2796
2797    Returns:
2798      The converted data in serialized format. Either a TFLite FlatBuffer or a
2799      GraphViz graph, depending on the value of `output_format`.
2800
2801    Raises:
2802      ValueError:
2803        Input shape is not specified.
2804        None value for dimension in input_tensor.
2805    """
2806    return super(TFLiteConverter, self).convert()
2807
2808
2809@_tf_export(v1=["lite.TocoConverter"])
2810class TocoConverter:
2811  """Convert a TensorFlow model into `output_format`.
2812
2813  This class has been deprecated. Please use `lite.TFLiteConverter` instead.
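
  Migration sketch (mirrors the deprecation notices on the methods below):

    ```python
    # Deprecated:
    converter = tf.compat.v1.lite.TocoConverter.from_session(
        sess, in_tensors, out_tensors)
    # Preferred:
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
        sess, in_tensors, out_tensors)
    ```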
2814  """
2815
2816  @classmethod
2817  @_deprecation.deprecated(None,
2818                           "Use `lite.TFLiteConverter.from_session` instead.")
2819  def from_session(cls, sess, input_tensors, output_tensors):
2820    """Creates a TocoConverter class from a TensorFlow Session."""
2821    return TFLiteConverter.from_session(sess, input_tensors, output_tensors)
2822
2823  @classmethod
2824  @_deprecation.deprecated(
2825      None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
2826  def from_frozen_graph(cls,
2827                        graph_def_file,
2828                        input_arrays,
2829                        output_arrays,
2830                        input_shapes=None):
2831    """Creates a TocoConverter class from a file containing a frozen graph."""
2832    return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
2833                                             output_arrays, input_shapes)
2834
2835  @classmethod
2836  @_deprecation.deprecated(
2837      None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
2838  def from_saved_model(cls,
2839                       saved_model_dir,
2840                       input_arrays=None,
2841                       input_shapes=None,
2842                       output_arrays=None,
2843                       tag_set=None,
2844                       signature_key=None):
2845    """Creates a TocoConverter class from a SavedModel."""
2846    return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
2847                                            input_shapes, output_arrays,
2848                                            tag_set, signature_key)
2849
2850  @classmethod
2851  @_deprecation.deprecated(
2852      None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
2853  def from_keras_model_file(cls,
2854                            model_file,
2855                            input_arrays=None,
2856                            input_shapes=None,
2857                            output_arrays=None):
2858    """Creates a TocoConverter class from a tf.keras model file."""
2859    return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
2860                                                 input_shapes, output_arrays)
2861