# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite tooling helper functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import functools
import pprint
import shutil
import tempfile
import time
import warnings

from absl import logging
import six
from six import PY2

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op  # pylint: disable=unused-import
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import build_toco_convert_protos  # pylint: disable=unused-import
from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model
from tensorflow.lite.python.convert import ConverterError  # pylint: disable=unused-import
from tensorflow.lite.python.convert import deduplicate_readonly_buffers as _deduplicate_readonly_buffers
from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize
from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify
from tensorflow.lite.python.convert import OpsSet
from tensorflow.lite.python.convert import toco_convert  # pylint: disable=unused-import
from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.lite.python.convert import toco_convert_protos  # pylint: disable=unused-import
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.interpreter import Interpreter  # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import load_delegate  # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import OpResolverType  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
from tensorflow.lite.python.op_hint import OpHint  # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func
from tensorflow.lite.python.util import freeze_graph as _freeze_graph
from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.lite.python.util import trace_model_call as _trace_model_call
from tensorflow.python import saved_model as _saved_model
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function
from tensorflow.python.eager import function as _function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader_impl as _loader_impl
from tensorflow.python.saved_model import save_options as _save_options
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.saved_model.load import load as _load
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import tf_export as _tf_export

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.lite.python import metrics_portable as metrics
except ImportError:
  from tensorflow.lite.python import metrics_nonportable as metrics
# pylint: enable=g-import-not-at-top


@_tf_export("lite.Optimize")
class Optimize(enum.Enum):
  """Enum defining the optimizations to apply when generating a tflite model.

  DEFAULT
      Default optimization strategy that quantizes model weights. Enhanced
      optimizations are gained by providing a representative dataset that
      quantizes biases and activations as well.
      Converter will do its best to reduce size and latency, while minimizing
      the loss in accuracy.

  OPTIMIZE_FOR_SIZE
      Deprecated. Does the same as DEFAULT.

  OPTIMIZE_FOR_LATENCY
      Deprecated. Does the same as DEFAULT.

  EXPERIMENTAL_SPARSITY
      Experimental flag, subject to change.

      Enable optimization by taking advantage of the sparse model weights
      trained with pruning.

      The converter will inspect the sparsity pattern of the model weights and
      do its best to improve size and latency.
      The flag can be used alone to optimize float32 models with sparse weights.
      It can also be used together with the DEFAULT optimization mode to
      optimize quantized models with sparse weights.
  """

  # Default optimization strategy that quantizes model weights. Enhanced
  # optimizations are gained by providing a representative dataset that
  # quantizes biases and activations as well.
  # Converter will do its best to reduce size and latency, while minimizing
  # the loss in accuracy.
  DEFAULT = "DEFAULT"

  # Deprecated. Does the same as DEFAULT.
  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"

  # Deprecated. Does the same as DEFAULT.
  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"

  # Experimental flag, subject to change.
  # Enable optimization by taking advantage of the sparse model weights trained
  # with pruning.
  #
  # The converter will inspect the sparsity pattern of the model weights and do
  # its best to improve size and latency.
  # The flag can be used alone to optimize float32 models with sparse weights.
  # It can also be used together with the DEFAULT optimization mode to optimize
  # quantized models with sparse weights.
  # TODO(b/161560631): Add log message when this optimization is applied.
  EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY"

  def __str__(self):
    return str(self.value)


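# A minimal usage sketch (illustrative only, not part of this module's API):
# enabling the default optimization on a converter. `saved_model_dir` is a
# placeholder path supplied by the caller.
def _example_enable_default_optimization(saved_model_dir):
  import tensorflow as tf  # Deferred import; this helper is illustrative.
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  # Optimize.DEFAULT quantizes weights; supplying a representative dataset
  # would extend quantization to biases and activations as well.
  converter.optimizations = {tf.lite.Optimize.DEFAULT}
  return converter.convert()

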
@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset(object):
  """Representative dataset used to optimize the model.

  This is a generator function that provides a small dataset to calibrate or
  estimate the range, i.e., (min, max) of all floating-point arrays in the model
  (such as model input, activation outputs of intermediate layers, and model
  output) for quantization. Usually, this is a small subset of a few hundred
  samples randomly chosen, in no particular order, from the training or
  evaluation dataset.
  """

  def __init__(self, input_gen):
    """Creates a representative dataset.

    Args:
      input_gen: A generator function that generates input samples for the
        model and has the same order, type and shape as the inputs to the model.
        Usually, this is a small subset of a few hundred samples randomly
        chosen, in no particular order, from the training or evaluation dataset.
    """
    self.input_gen = input_gen


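# A minimal sketch of a representative dataset generator (illustrative only;
# the (1, 224, 224, 3) input shape is an assumed placeholder). Each yielded
# sample must match the order, type and shape of the model's inputs. A
# converter consumes it via `converter.representative_dataset = <generator>`.
def _example_representative_dataset():
  import numpy as np  # Deferred import; this helper is illustrative.
  for _ in range(100):  # A few hundred samples is typically enough.
    yield [np.random.uniform(size=(1, 224, 224, 3)).astype(np.float32)]

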
@_tf_export("lite.TargetSpec")
class TargetSpec(object):
  """Specification of target device used to optimize the model.

  Attributes:
    supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
      options, where each option represents a set of operators supported by the
      target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS})
    supported_types: Set of `tf.dtypes.DType` data types supported on the target
      device. If initialized, optimization might be driven by the smallest type
      in this set. (default set())
    experimental_select_user_tf_ops: Experimental flag, subject to change. Set
      of user's TensorFlow operators' names that are required in the TensorFlow
      Lite runtime. These ops will be exported as select TensorFlow ops in the
      model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
      an advanced feature that should only be used if the client is using TF ops
      that may not be linked in by default with the TF ops that are provided
      when using the SELECT_TF_OPS path. The client is responsible for linking
      these ops into the target runtime.
  """

  def __init__(self,
               supported_ops=None,
               supported_types=None,
               experimental_select_user_tf_ops=None):
    if supported_ops is None:
      supported_ops = {OpsSet.TFLITE_BUILTINS}
    self.supported_ops = supported_ops
    if supported_types is None:
      supported_types = set()
    self.supported_types = supported_types
    if experimental_select_user_tf_ops is None:
      experimental_select_user_tf_ops = set()
    self.experimental_select_user_tf_ops = experimental_select_user_tf_ops
    self._experimental_custom_op_registerers = []
    # Hint for the supported accumulation type used for inference. Typically
    # used for fp16 post-training quantization, where some models can use fp16
    # accumulators instead of the typical fp32 type.
    # TODO(b/188185962): Provide full API and authoring support for
    # reduced precision accumulation types.
    self._experimental_supported_accumulation_type = None


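# An illustrative sketch (not part of this module's API): restricting the
# target to int8 builtin ops for full integer quantization. `converter` is a
# caller-supplied `tf.lite.TFLiteConverter`; `_example_representative_dataset`
# is the placeholder generator sketched above.
def _example_full_integer_target(converter):
  import tensorflow as tf  # Deferred import; this helper is illustrative.
  converter.optimizations = {tf.lite.Optimize.DEFAULT}
  converter.representative_dataset = _example_representative_dataset
  # Requiring TFLITE_BUILTINS_INT8 makes the conversion fail for any op that
  # has no int8 builtin kernel, instead of silently falling back to float.
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  return converter.convert()

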
class QuantizationMode(object):
  """QuantizationMode determines the quantization type from user options."""

  def __init__(self, optimizations, target_spec, representative_dataset,
               graph_def):
    self._optimizations = optimizations
    for deprecated_optimization in [
        Optimize.OPTIMIZE_FOR_SIZE, Optimize.OPTIMIZE_FOR_LATENCY
    ]:
      if deprecated_optimization in self._optimizations:
        logging.warning(
            "Optimization option %s is deprecated, please use optimizations="
            "[Optimize.DEFAULT] instead.", deprecated_optimization)

    self._target_spec = target_spec
    self._representative_dataset = representative_dataset
    self._graph_def = graph_def

    self._validate_int8_required()

  # TODO(b/162537905): Refactor the following quantization functions -
  # re-organize and refactor for better readability.
  def post_training_int8_no_float(self):
    return (self.any_optimization_enabled() and
            self._is_int8_target_required() and
            not self._is_int16x8_target_required() and
            not self.is_allow_float() and
            self._representative_dataset is not None)

  def post_training_int8_allow_float(self):
    return (self.any_optimization_enabled() and
            not self._is_int16x8_target_required() and
            self._representative_dataset is not None and
            self._smallest_supported_type() == _dtypes.int8)

  def is_post_training_integer_quantize_8(self):
    return (self.post_training_int8_no_float() or
            self.post_training_int8_allow_float())

  def is_post_training_integer_quantize_16x8(self):
    return (self.post_training_int16x8_no_float() or
            self.post_training_int16x8_allow_float())

  def is_post_training_integer_quantize(self):
    return (self.is_post_training_integer_quantize_8() or
            self.is_post_training_integer_quantize_16x8())

  def is_integer_quantize(self):
    return (self.is_post_training_integer_quantize() or
            self.is_training_time_int8_allow_float())

  def is_training_time_int8_allow_float(self):
    return (self.any_optimization_enabled() and
            self.contains_training_quant_op())

  def is_bfloat16_inference_allowed(self):
    return (self.any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.bfloat16 in self._target_spec.supported_types)

  def post_training_int16x8_no_float(self):
    return (self.any_optimization_enabled() and
            not self._is_int8_target_required() and
            self._is_int16x8_target_required() and
            not self.is_allow_float() and
            self._representative_dataset is not None)

  def post_training_int16x8_allow_float(self):
    return (self.any_optimization_enabled() and
            self._is_int16x8_target_required() and
            self.is_allow_float())

  def post_training_dynamic_range_int8(self):
    # Post-training dynamic range quantization is only enabled when neither
    # post-training int8 quantization nor training-time quantization has been
    # done.
    return (self.any_optimization_enabled() and
            self._representative_dataset is None and
            not self.contains_training_quant_op() and
            self._smallest_supported_type() == _dtypes.int8)

  def post_training_fp16(self):
    return (self.any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.float16 in self._target_spec.supported_types)

  def fp32_execution(self):
    """If none of the above are true."""
    return not (self.is_integer_quantize() or
                self.post_training_dynamic_range_int8() or
                self.post_training_fp16())

  def activations_type(self):
    if self.is_integer_quantize():
      if self._is_int16x8_target_required():
        return _dtypes.int16
      else:
        return _dtypes.int8
    else:
      return _dtypes.float32

  def converter_flags(self, inference_ty=None, inference_input_ty=None):
    """Flags to the converter."""

    if self.is_integer_quantize():
      return {
          "inference_type": (
              inference_ty if inference_ty else self.activations_type()),
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization
      }
    elif self.post_training_dynamic_range_int8():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,  # enable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization
      }
    elif self.post_training_fp16():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,
          "quantize_to_float16": True,  # enable float16 quantization
          "accumulation_type":
              self._target_spec._experimental_supported_accumulation_type,
          "allow_bfloat16":
              self.is_bfloat16_inference_allowed()
      }
    else:
      # Note this might still trigger (uint8) quantization to be compatible with
      # TOCO.
      return {
          "inference_type": inference_ty if inference_ty else _dtypes.float32,
          "inference_input_type": inference_input_ty,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          "allow_bfloat16": self.is_bfloat16_inference_allowed()
      }

  # Below are helpers for the above functions.

  def _validate_int8_required(self):
    """Int8 mode requires certain parameters to exist and be compatible."""
    if not self._is_int8_target_required():
      return

    if self._target_spec.supported_types and (self._smallest_supported_type() !=
                                              _dtypes.int8):
      raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
                       "type to be INT8.")

    if self._representative_dataset:
      if not isinstance(self._representative_dataset, RepresentativeDataset):
        self._representative_dataset = RepresentativeDataset(
            self._representative_dataset)
      if self._representative_dataset.input_gen is None:
        raise ValueError(
            "Provide an input generator for representative_dataset")
    else:
      # TODO(b/162537905): Relax this check for QAT.
      raise ValueError("representative_dataset is required when specifying "
                       "TFLITE_BUILTINS_INT8 or INT8 supported types.")

  def _is_int8_target_required(self):
    return (OpsSet.TFLITE_BUILTINS_INT8 in set(
        self._target_spec.supported_ops)) or (set(
            self._target_spec.supported_types) == set([_dtypes.int8]))

  def _is_int16x8_target_required(self):
    return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
            in set(self._target_spec.supported_ops))

  def is_allow_float(self):
    return (OpsSet.TFLITE_BUILTINS in set(
        self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set(
            self._target_spec.supported_ops))

  def any_optimization_enabled(self):
    return bool(
        set(self._optimizations).intersection([
            Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
            Optimize.DEFAULT
        ]))

  def _smallest_supported_type(self):
    if self._target_spec.supported_types:
      return min(self._target_spec.supported_types, key=lambda x: x.size)
    else:
      # The default smallest supported type is INT8.
      return _dtypes.int8

  def contains_training_quant_op(self):
    """Checks if the graph contains any training-time quantization ops."""
    training_quant_ops = frozenset({
        "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel",
        "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsPerChannel",
        "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"
    })

    if self._graph_def:
      for node_def in self._graph_def.node:
        if node_def.op in training_quant_ops:
          return True
      for function in self._graph_def.library.function:
        for node_def in function.node_def:
          if node_def.op in training_quant_ops:
            return True
    return False


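# An illustrative sketch of how user options map to quantization modes,
# assuming `graph_def` contains no training-time quantization ops and
# `representative_dataset` is a non-None generator: Optimize.DEFAULT alone
# selects dynamic range quantization, while adding a representative dataset
# selects post-training integer quantization.
def _example_quantization_modes(graph_def, representative_dataset):
  dynamic_range = QuantizationMode(
      {Optimize.DEFAULT}, TargetSpec(), None, graph_def)
  assert dynamic_range.post_training_dynamic_range_int8()

  full_integer = QuantizationMode(
      {Optimize.DEFAULT}, TargetSpec(), representative_dataset, graph_def)
  assert full_integer.is_post_training_integer_quantize()

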
class TFLiteConverterBase(object):
  """Converter subclass to share functionality between V1 and V2 converters."""

  def __init__(self):
    self.optimizations = set()
    self.representative_dataset = None
    self.target_spec = TargetSpec()
    self.allow_custom_ops = False
    self.experimental_new_converter = True
    self.experimental_new_quantizer = True
    self.experimental_enable_resource_variables = False
    self._experimental_new_quantizer = None
    self._experimental_calibrate_only = False
    self._experimental_sparsify_model = False
    self._experimental_disable_per_channel = False
    # Contains the stack traces of all the original nodes in the `GraphDef`
    # passed to the converter.
    self._debug_info = None
    self.saved_model_dir = None
    self._saved_model_tags = None
    self._saved_model_version = 0
    self._saved_model_exported_names = []
    self._tflite_metrics = metrics.TFLiteConverterMetrics()
    self._collected_converter_params = {}
    self._experimental_disable_batchmatmul_unfold = False
    self._experimental_lower_tensor_list_ops = True
    self._experimental_unfold_large_splat_constant = False

  def _grappler_config(self, optimizers=None):
    """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

    Args:
      optimizers: List of strings that represents the list of optimizers.

    Returns:
      tf.ConfigProto.
    """
    if not optimizers:
      optimizers = []
    # MLIR converter will take care of constant folding instead of grappler.
    if not self.experimental_new_converter:
      optimizers.append("constfold")

    is_only_flex_enabled = (
        set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
    if is_only_flex_enabled:
      # The layout optimizer turns NHWC to NCHW. This provides performance
      # optimizations when Flex mode is enabled. However, it is not compatible
      # with builtin ops.
      optimizers.append("layout")
    return _get_grappler_config(optimizers)

  def _quantize(self, result, input_type, output_type, activations_type,
                allow_float):
    """Quantize the model."""
    # pylint: disable=protected-access
    custom_op_registerers_by_name = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if isinstance(x, str)
    ]
    custom_op_registerers_by_func = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if not isinstance(x, str)
    ]
    # pylint: enable=protected-access
    if not isinstance(self.representative_dataset, RepresentativeDataset):
      self.representative_dataset = RepresentativeDataset(
          self.representative_dataset)

    # Add intermediate tensors to the model if needed.
    result = _calibrator.add_intermediate_tensors(result)
    calibrate_quantize = _calibrator.Calibrator(result,
                                                custom_op_registerers_by_name,
                                                custom_op_registerers_by_func)
    if self._experimental_calibrate_only or self.experimental_new_quantizer:
      calibrated = calibrate_quantize.calibrate(
          self.representative_dataset.input_gen)

    if self._experimental_calibrate_only:
      return calibrated
    elif self.experimental_new_quantizer and (
        activations_type != _dtypes.int16):
      # TODO(b/175659372): remove the activations_type restriction and enable
      # it for all the activation types.
      return _mlir_quantize(
          calibrated,
          self._experimental_disable_per_channel,
          input_data_type=input_type,
          output_data_type=output_type)
    else:
      return calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen, input_type, output_type,
          allow_float, activations_type,
          disable_per_channel=self._experimental_disable_per_channel)

  def _is_unknown_shapes_allowed(self):
    # Unknown dimensions are only allowed with the new converter.
    return self.experimental_new_converter

  def _get_base_converter_args(self):
    """Returns the base converter args.

    Returns:
      {key str: val}
    """
    args = {
        "input_format":
            constants.TENSORFLOW_GRAPHDEF,
        "allow_custom_ops":
            self.allow_custom_ops,
        "debug_info":
            self._debug_info,
        "target_ops":
            self.target_spec.supported_ops,
        "enable_mlir_converter":
            self.experimental_new_converter,
        "select_user_tf_ops":
            self.target_spec.experimental_select_user_tf_ops,
        "unfold_batchmatmul":
            not self._experimental_disable_batchmatmul_unfold,
        "lower_tensor_list_ops":
            self._experimental_lower_tensor_list_ops,
        "unfold_large_splat_constant":
            self._experimental_unfold_large_splat_constant,
    }

    if self.saved_model_dir:
      args.update({
          "saved_model_dir": self.saved_model_dir,
          "saved_model_version": self._saved_model_version,
          "saved_model_tags": self._saved_model_tags,
          "saved_model_exported_names": self._saved_model_exported_names,
      })

    return args

  def _contains_function_with_implements_attr(self, saved_model_proto):
    meta_graph = saved_model_proto.meta_graphs[0]
    for function in meta_graph.graph_def.library.function:
      if function.attr.get("_implements", None) or function.attr.get(
          "api_implements", None):
        return True
    return False

  def _parse_saved_model_args(self, always_enable_saved_model_import=False):
    """Parses SavedModel arguments from the given Keras/RNN SavedModel.

    Args:
      always_enable_saved_model_import: Bool. When true, the MLIR SavedModel
        import path is enabled regardless of the other conditions.
    """
    if not self.experimental_new_converter:
      self.saved_model_dir = None
      return
    if self.saved_model_dir:
      try:
        saved_model_proto, _ = (
            _parse_saved_model_with_debug_info(self.saved_model_dir))
      except OSError:
        # If it fails to read the given saved model, it will fall back to the
        # frozen graph def path.
        self.saved_model_dir = None
        return
      if (not always_enable_saved_model_import and
          not self._contains_function_with_implements_attr(saved_model_proto)):
        self.saved_model_dir = None
        return

      if not self._saved_model_exported_names:
        self._saved_model_exported_names = []
      self._saved_model_version = saved_model_proto.saved_model_schema_version
      if self._saved_model_version == 0:
        self.saved_model_dir = None
        logging.warning("SavedModel schema version is zero.")
        return
      if self._saved_model_version not in [1, 2]:
        raise ValueError("SavedModel file format ({0}) is not supported".format(
            self._saved_model_version))

  def _sparsify_model(self):
    return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations

  def _validate_experimental_new_quantizer_flag(self):
    if self._experimental_new_quantizer is not None:
      raise ValueError("Please use 'experimental_new_quantizer' instead.")

  def _increase_conversion_attempt_metric(self):
    self._tflite_metrics.increase_counter_converter_attempt()

  def _increase_conversion_success_metric(self):
    self._tflite_metrics.increase_counter_converter_success()

  def _save_conversion_params_metric(self,
                                     graph_def=None,
                                     inference_type=None,
                                     inference_input_type=None):
    """Set conversion parameter metrics."""
    converter_kwargs = self._collected_converter_params
    converter_kwargs.update(self._get_base_converter_args())

    # Optimization parameters.
    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, graph_def)
    converter_kwargs.update({
        "optimization_default":
            quant_mode.any_optimization_enabled(),
        "optimization_post_training_dynamic_range":
            quant_mode.post_training_dynamic_range_int8(),
        "optimization_post_training_float16":
            quant_mode.post_training_fp16(),
        "optimization_post_training_integer_quantize":
            quant_mode.is_post_training_integer_quantize(),
        "optimization_qat":
            quant_mode.is_training_time_int8_allow_float(),
        "optimization_sparsify":
            self._sparsify_model(),
        "activations_type":
            quant_mode.activations_type()
    })
    converter_kwargs.update(
        quant_mode.converter_flags(inference_type, inference_input_type))

    # pylint: disable=protected-access
    if self.target_spec._experimental_supported_accumulation_type:
      converter_kwargs.update({
          "accumulation_type":
              self.target_spec._experimental_supported_accumulation_type
      })
    # pylint: enable=protected-access

    def format_element(elem):
      if isinstance(elem, enum.Enum):
        return str(elem.value)
      return pprint.pformat(elem)

    def format_param(param):
      if isinstance(param, (list, tuple, set)):
        if not param:
          return "None"  # Return the string "None" for empty containers.
        string_list = [format_element(x) for x in param]
        return ",".join(sorted(string_list))
      return format_element(param)

    for key, value in converter_kwargs.items():
      self._tflite_metrics.set_converter_param(key, format_param(value))
    self._tflite_metrics.set_export_required()

  def _set_conversion_latency_metric(self, value):
    self._tflite_metrics.set_converter_latency(value)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL)
  def _optimize_tflite_model(self, model, quant_mode, quant_io=True):
    """Apply optimizations on a TFLite model."""

    if quant_mode.is_integer_quantize():

      in_type, out_type = self.inference_input_type, self.inference_output_type

      if quant_mode.is_post_training_integer_quantize():
        q_in_type = in_type if in_type and quant_io else _dtypes.float32
        q_out_type = out_type if out_type and quant_io else _dtypes.float32
        q_activations_type = quant_mode.activations_type()
        q_allow_float = quant_mode.is_allow_float()
        model = self._quantize(
            model, q_in_type, q_out_type, q_activations_type, q_allow_float)

      m_in_type = in_type if in_type else _dtypes.float32
      m_out_type = out_type if out_type else _dtypes.float32
      model = _modify_model_io_type(model, m_in_type, m_out_type)

    if self._sparsify_model():
      model = _mlir_sparsify(model)

    try:
      model = _deduplicate_readonly_buffers(model)
    except Exception:  # pylint: disable=broad-except
      # Skip buffer deduplication when the flatbuffer library is not ready to
      # be utilized.
      logging.warning(
          "Buffer deduplication procedure will be skipped when flatbuffer "
          "library is not properly loaded")

    return model

  def _convert_and_export_metrics(self, convert_func, *args, **kwargs):
    """Wraps around convert function to export metrics.

    Args:
      convert_func: The convert function to wrap.
      *args: Positional arguments of the convert function.
      **kwargs: The keyword arguments of the convert function.

    Returns:
      The result of the wrapped convert function.
734    """
735    self._increase_conversion_attempt_metric()
736    self._save_conversion_params_metric()
737    start_time = time.process_time()
738    result = convert_func(self, *args, **kwargs)
739    elapsed_time_ms = (time.process_time() - start_time) * 1000
740    if result:
741      self._increase_conversion_success_metric()
742    self._set_conversion_latency_metric(round(elapsed_time_ms))
743    self._tflite_metrics.export_metrics()
744    return result
745
746
747def _export_metrics(convert_func):
748  """The decorator around convert function to export metrics."""
749  @functools.wraps(convert_func)
750  def wrapper(self, *args, **kwargs):
751    # pylint: disable=protected-access
752    return self._convert_and_export_metrics(convert_func, *args, **kwargs)
753    # pylint: enable=protected-access
754
755  return wrapper
756
757
758class TFLiteConverterBaseV2(TFLiteConverterBase):
759  """Converter subclass to share functionality between V2 converters."""
760
761  def __init__(self):
762    """Constructor for TFLiteConverter."""
763    super(TFLiteConverterBaseV2, self).__init__()
764    self.inference_input_type = _dtypes.float32
765    self.inference_output_type = _dtypes.float32
766    self._collected_converter_params.update({"api_version": 2})
767
768  def _validate_inference_input_output_types(self, quant_mode):
769    """Validate inference_input_type and inference_output_type flags."""
770    default_types = [_dtypes.float32]
771    # We support integer input/output for integer quantized models only.
772    if quant_mode.is_integer_quantize():
773      if quant_mode.is_post_training_integer_quantize_16x8():
774        all_types = default_types + [_dtypes.int16]
775      else:
776        all_types = default_types + [_dtypes.int8, _dtypes.uint8]
777      if (self.inference_input_type not in all_types or
778          self.inference_output_type not in all_types):
779        all_types_names = ["tf." + t.name for t in all_types]
780        raise ValueError("The inference_input_type and inference_output_type "
781                         "must be in {}.".format(all_types_names))
782    elif (self.inference_input_type not in default_types or
783          self.inference_output_type not in default_types):
784      raise ValueError("The inference_input_type and inference_output_type "
785                       "must be tf.float32.")
786
787  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.LOAD_SAVED_MODEL)
788  def _load_saved_model(self, saved_model_dir, saved_model_tags):
789    """Load graph_def from saved model with the default serving signature key.
790
791    Args:
792      saved_model_dir: Directory of the SavedModel.
793      saved_model_tags: Set of tags identifying the MetaGraphDef within the
794        SavedModel to analyze.
795
796    Returns:
797      graph_def: The loaded GraphDef.
798      input_tensors: List of input tensors.
799      output_tensors: List of output tensors.
800    """
801    graph = _ops.Graph()
802    saved_model = _loader_impl.SavedModelLoader(saved_model_dir)
803    saved_model.load_graph(graph, tags=saved_model_tags)
804    meta_graph = saved_model.get_meta_graph_def_from_tags(saved_model_tags)
805    graph_def = meta_graph.graph_def
806    signature_def = meta_graph.signature_def[
807        _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
808    input_tensors = [
809        graph.get_tensor_by_name(signature_def.inputs[key].name)
810        for key in signature_def.inputs
811    ]
812    output_tensors = [
813        graph.get_tensor_by_name(signature_def.outputs[key].name)
814        for key in signature_def.outputs
815    ]
816    return graph_def, input_tensors, output_tensors
817
818  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
819  def _validate_inputs(self, graph_def, input_tensors):
820    """Validate the input parameters.
821
822    Args:
823      graph_def: The TensorFlow GraphDef.
824      input_tensors: List of input tensors.
    Raises:
      ValueError:
        Input shape is not specified.
        Invalid quantization parameters.
    """
    # Update conversion params with graph_def.
    self._save_conversion_params_metric(graph_def)
    self._quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                        self.representative_dataset, graph_def)
    self._validate_inference_input_output_types(self._quant_mode)
    self._validate_experimental_new_quantizer_flag()

    if not self._is_unknown_shapes_allowed():
      # Checks dimensions in input tensor.
      for tensor in input_tensors:
        # Note that shape_list might be empty for scalar shapes.
        shape_list = tensor.shape.as_list()
        if None in shape_list[1:]:
          raise ValueError(
              "None is only supported in the 1st dimension. Tensor '{0}' has "
              "invalid shape '{1}'.".format(
                  _get_tensor_name(tensor), shape_list))
        elif shape_list and shape_list[0] is None:
          # Set the batch size to 1 if undefined.
          shape = tensor.shape.as_list()
          shape[0] = 1
          tensor.set_shape(shape)

    if (self._trackable_obj is None or
        not hasattr(self._trackable_obj, "graph_debug_info")):
      self._debug_info = _get_debug_info(
          _build_debug_info_func(self._funcs[0].graph), graph_def)
    else:
      self._debug_info = _get_debug_info(
          _convert_debug_info_func(self._trackable_obj.graph_debug_info),
          graph_def)

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
  def _optimize_tf_model(self, graph_def, input_tensors, output_tensors,
                         frozen_func):
    """Run a Grappler pass to optimize the TensorFlow graph.

    Args:
      graph_def: Frozen GraphDef to be optimized.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      frozen_func: TensorFlow Graph.

    Returns:
      The optimized TensorFlow graph.
    """
    grappler_config = self._grappler_config()
    # Skip running Grappler when there are no optimizers to run. Otherwise,
    # Grappler will run with the default optimizer set, which can lead to
    # unexpected behavior.
    if grappler_config.graph_options.rewrite_options.optimizers:
      graph_def = _run_graph_optimizations(
          graph_def,
          input_tensors,
          output_tensors,
          config=grappler_config,
          graph=frozen_func.graph)
    return graph_def

  def convert(self, graph_def, input_tensors, output_tensors):
    """Converts a TensorFlow GraphDef based on instance variables.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    self._validate_inputs(graph_def, input_tensors)
    converter_kwargs = self._get_base_converter_args()
    converter_kwargs.update(self._quant_mode.converter_flags())
    if not self.experimental_new_converter:
      logging.warning(
          "Please consider switching to the new converter by setting "
          "experimental_new_converter=True. "
          "The old converter (TOCO) is deprecated.")
    else:
      logging.info("Using new converter: If you encounter a problem "
                   "please file a bug. You can opt-out "
                   "by setting experimental_new_converter=False")

    # Converts model.
    result = _toco_convert_impl(
        input_data=graph_def,
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        **converter_kwargs)

    return self._optimize_tflite_model(
        result, self._quant_mode, quant_io=self.experimental_new_quantizer)


class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
  """Converts the given SavedModel into TensorFlow Lite model.

  Attributes:
      saved_model_dir: Directory of the SavedModel.
  """

  def __init__(self,
               saved_model_dir,
               saved_model_tags=None,
               saved_model_exported_names=None,
               trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      saved_model_dir: Directory of the SavedModel.
      saved_model_tags: Set of tags identifying the MetaGraphDef within the
        SavedModel to analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING}).
      saved_model_exported_names: Names to be exported when the saved model
        import path is on.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do not
        get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is not
        maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteSavedModelConverterV2, self).__init__()
    self.saved_model_dir = saved_model_dir
    self._saved_model_tags = saved_model_tags
    self._saved_model_exported_names = saved_model_exported_names
    self._trackable_obj = trackable_obj
    self._parse_saved_model_args(always_enable_saved_model_import=True)

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    graph_def, input_tensors, output_tensors = self._load_saved_model(
        self.saved_model_dir, self._saved_model_tags)
    # If we can't use the SavedModel importer, then fall back to the frozen
    # graph conversion path.
    if self.saved_model_dir is None or not self.experimental_new_converter:
      graph_def, _, _, _ = _freeze_saved_model(
          self.saved_model_dir, None, None, None, self._saved_model_tags,
          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
      # We make sure to clear the saved_model_dir as there is some
      # legacy code down in the caller that checks this.
      # TODO(b/162537905): Clean these indirect dependencies.
      self.saved_model_dir = None
      return super(TFLiteSavedModelConverterV2,
                   self).convert(graph_def, input_tensors, output_tensors)

    if self._trackable_obj is None:
      self._debug_info = _get_debug_info(
          _build_debug_info_func(self._funcs[0].graph), graph_def)
    else:
      self._debug_info = _get_debug_info(
          _convert_debug_info_func(self._trackable_obj.graph_debug_info),
          graph_def)

    # Update conversion params with graph_def.
    self._save_conversion_params_metric(graph_def)
    # Get quantization options and do some sanity checks.
    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, graph_def)
    self._validate_inference_input_output_types(quant_mode)

    converter_kwargs = {
        "enable_tflite_resource_variables":
            self.experimental_enable_resource_variables
    }
    converter_kwargs.update(self._get_base_converter_args())
    converter_kwargs.update(quant_mode.converter_flags())

    result = _convert_saved_model(**converter_kwargs)

    return self._optimize_tflite_model(
        result, quant_mode, quant_io=self.experimental_new_quantizer)


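# A minimal end-to-end sketch (illustrative only; "my_saved_model" is a
# placeholder path): converting a SavedModel and sanity-checking the result
# with the TFLite Interpreter.
def _example_convert_saved_model():
  import tensorflow as tf  # Deferred import; this helper is illustrative.
  converter = tf.lite.TFLiteConverter.from_saved_model("my_saved_model")
  tflite_model = converter.convert()
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  return tflite_model

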
class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
  """Converts the given Keras model into TensorFlow Lite model."""

  def __init__(self, keras_model, trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      keras_model: tf.keras.Model.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do not
        get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is not
        maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteKerasModelConverterV2, self).__init__()
    self._keras_model = keras_model
    self._trackable_obj = trackable_obj
    self.experimental_lower_to_saved_model = True

  @convert_phase(Component.PREPARE_TF_MODEL,
                 SubComponent.CONVERT_KERAS_TO_SAVED_MODEL)
  def _convert_keras_to_saved_model(self, output_dir):
    """Save Keras model to the SavedModel format.

    Args:
      output_dir: The output directory to save the SavedModel.

    Returns:
      graph_def: The frozen GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
    """
    try:
      _saved_model.save(
          self._keras_model,
          output_dir,
          options=_save_options.SaveOptions(save_debug_info=True))
    except Exception:  # pylint: disable=broad-except
      # If saving the given Keras model to the SavedModel format fails, fall
      # back to the original Keras model conversion pipeline.
      return None, None, None
    self.saved_model_dir = output_dir
    self._saved_model_tags = set([_tag_constants.SERVING])
    self._saved_model_exported_names = [
        _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    ]
    self._parse_saved_model_args(
        always_enable_saved_model_import=self.experimental_lower_to_saved_model)
    if self.saved_model_dir:
      graph_def, input_tensors, output_tensors = self._load_saved_model(
          self.saved_model_dir, self._saved_model_tags)
      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
      return graph_def, input_tensors, output_tensors
    return None, None, None

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
  def _freeze_keras_model(self):
    """Freeze Keras model to frozen graph.

    Returns:
      graph_def: The frozen GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      frozen_func: The frozen ConcreteFunction.
    """
    input_signature = None
    # If the model's call is not a `tf.function`, then we need to first get its
    # input signature from the `model_input_signature` method. We can't
    # directly call `trace_model_call` because otherwise the batch dimension is
    # set to None.
    # Once we have better support for dynamic shapes, we can remove this.
    if not isinstance(self._keras_model.call, _def_function.Function):
      # Passing `keep_original_batch_size=True` ensures that we get an input
      # signature including the batch dimension specified by the user.
      # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
      input_signature = _model_input_signature(
          self._keras_model, keep_original_batch_size=True)

    # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
    func = _trace_model_call(self._keras_model, input_signature)
    concrete_func = func.get_concrete_function()
    self._funcs = [concrete_func]

    frozen_func, graph_def = (
        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
            self._funcs[0], lower_control_flow=False))

    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs
    return graph_def, input_tensors, output_tensors, frozen_func

  def _convert_as_saved_model(self):
    """Converts a Keras model as a saved model.

    Returns:
      The converted data in serialized format.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      graph_def, input_tensors, output_tensors = (
          self._convert_keras_to_saved_model(temp_dir))
      if self.saved_model_dir:
        return super(TFLiteKerasModelConverterV2,
                     self).convert(graph_def, input_tensors, output_tensors)
    finally:
      shutil.rmtree(temp_dir, True)

  @_export_metrics
  def convert(self):
1134    """Converts a keras model based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    saved_model_convert_result = self._convert_as_saved_model()
    if saved_model_convert_result:
      return saved_model_convert_result

    graph_def, input_tensors, output_tensors, frozen_func = (
        self._freeze_keras_model())

    graph_def = self._optimize_tf_model(graph_def, input_tensors,
                                        output_tensors, frozen_func)

    return super(TFLiteKerasModelConverterV2,
                 self).convert(graph_def, input_tensors, output_tensors)


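# A minimal usage sketch (illustrative only): converting a small Keras model.
# The single-layer architecture is an arbitrary placeholder.
def _example_convert_keras_model():
  import tensorflow as tf  # Deferred import; this helper is illustrative.
  model = tf.keras.Sequential(
      [tf.keras.layers.Dense(units=1, input_shape=[1])])
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  return converter.convert()

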
class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
  """Converts the given frozen graph into TensorFlow Lite model."""

  def __init__(self, funcs, trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
        duplicate elements.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do not
        get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is not
        maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteFrozenGraphConverterV2, self).__init__()
    self._funcs = funcs
    self._trackable_obj = trackable_obj
    self.experimental_lower_to_saved_model = True

  @convert_phase(Component.PREPARE_TF_MODEL,
                 SubComponent.FREEZE_CONCRETE_FUNCTION)
  def _freeze_concrete_function(self):
    """Convert the given ConcreteFunction to frozen graph.

    Returns:
      graph_def: The frozen GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      frozen_func: The frozen ConcreteFunction.

    Raises:
      ValueError: none or multiple ConcreteFunctions provided.
    """
    # TODO(b/130297984): Add support for converting multiple functions.

    if len(self._funcs) == 0:  # pylint: disable=g-explicit-length-test
      raise ValueError("No ConcreteFunction is specified.")

    if len(self._funcs) > 1:
      raise ValueError("This converter can only convert a single "
                       "ConcreteFunction. Converting multiple functions is "
                       "under development.")

    frozen_func, graph_def = (
        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
            self._funcs[0], lower_control_flow=False))

    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs
    return graph_def, input_tensors, output_tensors, frozen_func

  @convert_phase(Component.PREPARE_TF_MODEL,
                 SubComponent.CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL)
  def _convert_concrete_functions_to_saved_model(self, output_dir):
    """Save concrete functions to the SavedModel format.

    Args:
      output_dir: The output directory to save the SavedModel.

    Returns:
      graph_def: The frozen GraphDef.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
    """
    if len(self._funcs) == 0:  # pylint: disable=g-explicit-length-test
      raise ValueError("No ConcreteFunction is specified.")

    if not self.experimental_lower_to_saved_model:
      return None, None, None

    # Without a provided trackable object, the given concrete functions cannot
    # be serialized in the SavedModel format.
    if not self._trackable_obj:
      return None, None, None

    signatures = {}
    signature_keys = []
    try:
      if len(self._funcs) == 1:
        signatures[_signature_constants
                   .DEFAULT_SERVING_SIGNATURE_DEF_KEY] = self._funcs[0]
        signature_keys = [
            _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        ]
      else:
        for func in self._funcs:
          signatures[func.graph.name] = func
          signature_keys.append(func.graph.name)

      _saved_model.save(
          self._trackable_obj,
          output_dir,
          signatures=signatures,
          options=_save_options.SaveOptions(save_debug_info=True))
    except Exception:  # pylint: disable=broad-except
      # If saving the given concrete functions to the SavedModel format fails,
      # fall back to the original concrete function conversion pipeline.
      return None, None, None

    self.saved_model_dir = output_dir
    self._saved_model_tags = set([_tag_constants.SERVING])
    self._saved_model_exported_names = signature_keys
    self._parse_saved_model_args(always_enable_saved_model_import=True)
    if self.saved_model_dir:
      graph_def, input_tensors, output_tensors = self._load_saved_model(
          self.saved_model_dir, self._saved_model_tags)
      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
      return graph_def, input_tensors, output_tensors
    return None, None, None

  def _convert_as_saved_model(self):
    """Converts the given concrete functions as a saved model format.

    Returns:
      The converted data in serialized format.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      graph_def, input_tensors, output_tensors = (
          self._convert_concrete_functions_to_saved_model(temp_dir))
      if self.saved_model_dir:
        return super(TFLiteFrozenGraphConverterV2,
                     self).convert(graph_def, input_tensors, output_tensors)
    finally:
      shutil.rmtree(temp_dir, True)
    return None

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    if self.experimental_lower_to_saved_model:
      saved_model_convert_result = self._convert_as_saved_model()
      if saved_model_convert_result:
        return saved_model_convert_result

    graph_def, input_tensors, output_tensors, frozen_func = (
        self._freeze_concrete_function())

    graph_def = self._optimize_tf_model(graph_def, input_tensors,
                                        output_tensors, frozen_func)

    return super(TFLiteFrozenGraphConverterV2,
                 self).convert(graph_def, input_tensors, output_tensors)


@_tf_export("lite.TFLiteConverter", v1=[])
class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
1321  """Converts a TensorFlow model into TensorFlow Lite model.

  Attributes:
    optimizations: Experimental flag, subject to change. Set of optimizations
      to apply. e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
      set of values of type `tf.lite.Optimize`)
    representative_dataset: A generator function used for integer quantization
      where each generated sample has the same order, type and shape as the
      inputs to the model. Usually, this is a small subset of a few hundred
      samples randomly chosen, in no particular order, from the training or
      evaluation dataset. This is an optional attribute, but required for full
      integer quantization, i.e., if `tf.int8` is the only supported type in
      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
      (default None)
    target_spec: Experimental flag, subject to change. Specifications of target
      device, including supported ops set, supported types and a set of user's
      defined TensorFlow operators required in the TensorFlow Lite runtime.
      Refer to `tf.lite.TargetSpec`.
    inference_input_type: Data type of the input layer. Note that integer types
      (tf.int8 and tf.uint8) are currently only supported for post training
      integer quantization and quantization aware training. (default tf.float32,
      must be in {tf.float32, tf.int8, tf.uint8})
    inference_output_type: Data type of the output layer. Note that integer
      types (tf.int8 and tf.uint8) are currently only supported for post
      training integer quantization and quantization aware training. (default
      tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When False, any unknown operation is an error. When True, custom ops are
      created for any op that is unknown. The developer needs to provide these
      to the TensorFlow Lite runtime with a custom resolver. (default False)
    experimental_new_converter: Experimental flag, subject to change. Enables
      MLIR-based conversion instead of TOCO conversion. (default True)
    experimental_new_quantizer: Experimental flag, subject to change. Enables
      MLIR-based quantization conversion instead of Flatbuffer-based conversion.
      (default True)
    experimental_enable_resource_variables: Experimental flag, subject to
      change. Enables resource variables to be converted by this converter.
      This is only allowed if the from_saved_model interface is used.
      (default False)

  Example usage:

    ```python
    # Converting a SavedModel to a TensorFlow Lite model.
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()

    # Converting a tf.Keras model to a TensorFlow Lite model.
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Converting ConcreteFunctions to a TensorFlow Lite model.
    converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
    tflite_model = converter.convert()
    ```
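
  The attributes above can be combined for post-training integer quantization.
  A minimal sketch (assuming `dataset` is an existing `tf.data.Dataset` whose
  elements match the model's input signature):

    ```python
    def representative_data_gen():
      # Yield a few hundred representative samples for calibration.
      for input_value in dataset.take(100):
        yield [input_value]

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.optimizations = {tf.lite.Optimize.DEFAULT}
    converter.representative_dataset = representative_data_gen
    tflite_quant_model = converter.convert()
    ```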
1376  """
1377
1378  # pylint: disable=useless-super-delegation
1379  def __init__(self, funcs, trackable_obj=None):
1380    """Constructor for TFLiteConverter.
1381
1382    Args:
1383      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1384        duplicate elements.
1385      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1386        reference to this object needs to be maintained so that Variables do not
1387        get garbage collected since functions have a weak reference to
1388        Variables. This is only required when the tf.AutoTrackable object is not
1389        maintained by the user (e.g. `from_saved_model`).
1390    """
1391    super(TFLiteConverterV2, self).__init__(funcs, trackable_obj)
1392
1393  @classmethod
1394  def from_concrete_functions(cls, funcs, trackable_obj=None):
1395    """Creates a TFLiteConverter object from ConcreteFunctions.
1396
1397    Args:
1398      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
        duplicate elements. Currently, the converter can only convert a single
        ConcreteFunction. Converting multiple functions is under development.
      trackable_obj: An `AutoTrackable` object (typically `tf.Module`)
        associated with `funcs`. A reference to this object needs to be
        maintained so that Variables do not get garbage collected since
        functions have a weak reference to Variables.

    Returns:
      TFLiteConverter object.

    Raises:
      ValueError: Invalid input type.
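
    Example usage (a minimal sketch; assumes `model` is a `tf.Module` whose
    `model.func` is a `tf.function`):

    ```python
    func = model.func.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
    tflite_model = converter.convert()
    ```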
1411    """
1412    if trackable_obj is None:
1413      logging.warning(
1414          "Please consider providing the trackable_obj argument in the "
1415          "from_concrete_functions. Providing without the trackable_obj "
1416          "argument is deprecated and it will use the deprecated conversion "
1417          "path.")
    for func in funcs:
      if not isinstance(func, _function.ConcreteFunction):
        message = "This function takes in a list of ConcreteFunctions."
        if isinstance(func, _def_function.Function):
          message += (" To get the ConcreteFunction from a Function,"
                      " call get_concrete_function.")
        raise ValueError(message)
    return cls(funcs, trackable_obj)

  @classmethod
  def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
    """Creates a TFLiteConverter object from a SavedModel directory.

    Args:
      saved_model_dir: SavedModel directory to convert.
      signature_keys: List of keys identifying SignatureDef containing inputs
        and outputs. Elements should not be duplicated. By default the
        `signatures` attribute of the MetaGraphDef is used. (default
        saved_model.signatures)
      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING} or {'serve'})

    Returns:
      TFLiteConverter object.

    Raises:
      ValueError: Invalid signature keys.
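
    Example usage (a sketch; assumes `saved_model_dir` contains a SavedModel
    exported with a default serving signature):

    ```python
    converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        signature_keys=['serving_default'],
        tags={tf.saved_model.SERVING})
    tflite_model = converter.convert()
    ```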
1446    """
1447    # When run without eager enabled, this will return the legacy
1448    # TFLiteConverter.
1449    if not context.executing_eagerly():
1450      signature_key = None
1451      if signature_keys:
1452        if len(signature_keys) != 1:
1453          raise ValueError("Only support a single signature key.")
        else:
          signature_key = signature_keys[0]
      logging.warning("Invoking the TF1 implementation of TFLiteConverter "
                      "because eager is disabled. Consider enabling eager.")
      return TFLiteConverter.from_saved_model(
          saved_model_dir, signature_key=signature_key, tag_set=tags)

    # Ensures any graphs created in Eager mode are able to run. This is required
    # in order to create a tf.estimator.Exporter that exports a TFLite model.
    if tags is None:
      tags = set([_tag_constants.SERVING])

    with context.eager_mode():
      saved_model = _load(saved_model_dir, tags)
    if not signature_keys:
      signature_keys = saved_model.signatures

    if not signature_keys:
1472      raise ValueError("Only support at least one signature key.")

    funcs = []
    for key in signature_keys:
      if key not in saved_model.signatures:
        raise ValueError("Invalid signature key '{}' found. Valid keys are "
                         "'{}'.".format(key, ",".join(saved_model.signatures)))
      funcs.append(saved_model.signatures[key])

    saved_model_converter = TFLiteSavedModelConverterV2(saved_model_dir, tags,
                                                        signature_keys,
                                                        saved_model)
    if saved_model_converter.saved_model_dir:
      return saved_model_converter

    return cls(funcs, saved_model)

  @classmethod
  def from_keras_model(cls, model):
    """Creates a TFLiteConverter object from a Keras model.

    Args:
      model: tf.keras.Model

    Returns:
      TFLiteConverter object.
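
    Example usage (a minimal sketch; the tiny Sequential model is an
    illustrative assumption):

    ```python
    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(units=1, input_shape=[1])])
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    ```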
1498    """
1499    return TFLiteKerasModelConverterV2(model)
1500
1501  # pylint: disable=useless-super-delegation
1502  def convert(self):
1503    """Converts a TensorFlow GraphDef based on instance variables.
1504
1505    Returns:
1506      The converted data in serialized format.
1507
1508    Raises:
1509      ValueError:
        No concrete function is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    return super(TFLiteConverterV2, self).convert()


class TFLiteConverterBaseV1(TFLiteConverterBase):
  """Converter subclass to share functionality between V1 converters."""

  def __init__(self, experimental_debug_info_func):
    """Constructor for TFLiteConverter.

    Args:
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.
    """
    super(TFLiteConverterBaseV1, self).__init__()
    self.inference_type = _dtypes.float32
    self.inference_input_type = None
    self.inference_output_type = None
    self.output_format = constants.TFLITE
    self.quantized_input_stats = {}
    self.default_ranges_stats = None
    self.drop_control_dependency = True
    self.reorder_across_fake_quant = False
    self.change_concat_input_ranges = False
    self.dump_graphviz_dir = None
    self.dump_graphviz_video = False
    self.conversion_summary_dir = None
    self._debug_info_func = experimental_debug_info_func
    self._experimental_allow_all_select_tf_ops = False

  def __setattr__(self, name, value):
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.DEFAULT]"
                    " instead." % name)
      if value:
        self.optimizations = [Optimize.DEFAULT]
      else:
        self.optimizations = []
      return
    if name == "target_ops":
      warnings.warn("Property %s is deprecated, please use "
                    "target_spec.supported_ops instead." % name)
      self.target_spec.supported_ops = value
      return
    object.__setattr__(self, name, value)

  def __getattribute__(self, name):
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.DEFAULT]"
                    " instead." % name)
      return Optimize.DEFAULT in set(self.optimizations)
    if name == "target_ops":
      warnings.warn("Property %s is deprecated, please use "
                    "target_spec.supported_ops instead." % name)
      return self.target_spec.supported_ops
    return object.__getattribute__(self, name)

  def _validate_quantized_input_stats(self, converter_kwargs, quant_mode):
    """Ensure the `quantized_input_stats` flag is provided if required."""

    quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})

    requires_quantized_input_stats = (
        (converter_kwargs["inference_type"] in quantized_types or
         converter_kwargs["inference_input_type"] in quantized_types) and
        not quant_mode.is_post_training_integer_quantize())

    if (requires_quantized_input_stats and
        not converter_kwargs["quantized_input_stats"]):
      raise ValueError(
          "The `quantized_input_stats` flag must be defined when either "
          "`inference_type` flag or `inference_input_type` flag is set to "
          "tf.int8 or tf.uint8. Currently, `inference_type={}` and "
          "`inference_input_type={}`.".format(
              _get_tf_type_name(converter_kwargs["inference_type"]),
              _get_tf_type_name(converter_kwargs["inference_input_type"])))

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
  def _validate_inputs(self, input_tensors, quantized_input_stats):
    """Validate input parameters.

    Args:
      input_tensors: List of input tensors.
      quantized_input_stats: Map of input tensor names to a tuple of floats
        representing the mean and standard deviation of the training data.

    Raises:
      ValueError:
        Input shape is not specified.
        Quantization input stats is required but not provided.
    """

    if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()):
      # Checks dimensions in input tensor.
      for tensor in input_tensors:
        shape = tensor.shape
        if not shape:
          raise ValueError("Provide an input shape for input array "
                           "'{0}'.".format(_get_tensor_name(tensor)))
        # Note that shape_list might be empty for scalar shapes.
        shape_list = shape.as_list()
        if None in shape_list[1:]:
          raise ValueError(
              "None is only supported in the 1st dimension. Tensor '{0}' has "
              "invalid shape '{1}'.".format(
                  _get_tensor_name(tensor), shape_list))
        elif shape_list and shape_list[0] is None:
          self._set_batch_size(batch_size=1)

    # Get quantization stats. Ensures there is one stat per name if the stats
    # are specified.
    if quantized_input_stats:
      self._quantized_stats = []
      invalid_stats = []
      for name in self.get_input_arrays():
        if name in quantized_input_stats:
          self._quantized_stats.append(quantized_input_stats[name])
        else:
          invalid_stats.append(name)

      if invalid_stats:
        raise ValueError("Quantization input stats are not available for input "
                         "tensors '{0}'.".format(",".join(invalid_stats)))
    else:
      self._quantized_stats = None

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
  def _optimize_tf_model(self, graph_def, input_tensors, output_tensors,
                         quant_mode):
    """Run a Grappler pass to optimize the TensorFlow graph.

    Args:
      graph_def: Frozen GraphDef to be optimized.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      quant_mode: the quantization mode.

    Returns:
      The optimized TensorFlow graph.
    """
    # Disable grappler constant folding if there are training quant ops.
    if self.saved_model_dir or quant_mode.contains_training_quant_op():
      return graph_def

    try:
      # TODO(b/150163103): Merge `disabling lower using switch merge' calls.
      # Grappler will also try to lower while loop into switch merge
      # representation which is undesired for Ophints, so we simply remove
      # those attributes to prevent Grappler from doing so.
      graph = _convert_to_constants.disable_lower_using_switch_merge(graph_def)
      # Run function inlining optimization to ensure any models generated
      # through the from_frozen_graph path have been inlined.
      optimized_graph = _run_graph_optimizations(
          graph,
          input_tensors,
          output_tensors,
          config=self._grappler_config(["function"]))
      return optimized_graph
    except Exception:  # pylint: disable=broad-except
      return graph_def

  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    self._validate_inputs(self._input_tensors, self.quantized_input_stats)

    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, self._graph_def)

    optimized_graph = self._optimize_tf_model(self._graph_def,
                                              self._input_tensors,
                                              self._output_tensors, quant_mode)

    self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)

    converter_kwargs = self._get_base_converter_args()
    converter_kwargs.update(
        quant_mode.converter_flags(self.inference_type,
                                   self.inference_input_type))
    converter_kwargs.update({
        "output_format": self.output_format,
        "quantized_input_stats": self._quantized_stats,
        "default_ranges_stats": self.default_ranges_stats,
        "drop_control_dependency": self.drop_control_dependency,
        "reorder_across_fake_quant": self.reorder_across_fake_quant,
        "change_concat_input_ranges": self.change_concat_input_ranges,
        "dump_graphviz_dir": self.dump_graphviz_dir,
        "dump_graphviz_video": self.dump_graphviz_video,
        "conversion_summary_dir": self.conversion_summary_dir,
        "allow_all_select_tf_ops": self._experimental_allow_all_select_tf_ops,
    })

    if not self.experimental_new_converter:
      logging.warning(
          "Please consider switching to the new converter by setting "
          "experimental_new_converter=True. "
          "The old converter (TOCO) is deprecated.")
    else:
1723      logging.info("Using experimental converter: If you encountered a problem "
1724                   "please file a bug. You can opt-out "
1725                   "by setting experimental_new_converter=False")

    self._validate_quantized_input_stats(converter_kwargs, quant_mode)
    self._validate_experimental_new_quantizer_flag()
    # Converts model.
    if self._has_valid_tensors():
      result = _toco_convert_impl(
          input_data=optimized_graph,
          input_tensors=self._input_tensors,
          output_tensors=self._output_tensors,
          **converter_kwargs)
    else:
      result = _toco_convert_graph_def(
          input_data=optimized_graph,
          input_arrays_with_shape=self._input_arrays_with_shape,
          output_arrays=self._output_arrays,
          control_output_arrays=self._control_output_arrays,
          **converter_kwargs)

    return self._optimize_tflite_model(
        result, quant_mode, quant_io=not self.experimental_new_converter)

  def get_input_arrays(self):
    """Returns a list of the names of the input tensors.

    Returns:
      List of strings.
    """
    if self._has_valid_tensors():
      return [_get_tensor_name(tensor) for tensor in self._input_tensors]
    else:
      return [name for name, _ in self._input_arrays_with_shape]

  def _has_valid_tensors(self):
    """Checks if the input and output tensors have been initialized.

    Returns:
      Bool.
    """
    return self._input_tensors is not None and self._output_tensors

  def _set_batch_size(self, batch_size):
    """Sets the first dimension of the input tensor to `batch_size`.

    Args:
      batch_size: Batch size for the model. Replaces the first dimension of an
        input size array if undefined. (default 1)

    Raises:
      ValueError: input_tensor is not defined.
    """
    if not self._has_valid_tensors():
      raise ValueError("The batch size cannot be set for this model. Please "
                       "use input_shapes parameter.")

    for tensor in self._input_tensors:
      shape = tensor.shape.as_list()
      if shape[0] is None:
        shape[0] = batch_size
        tensor.set_shape(shape)

  def _is_unknown_shapes_allowed(self):
    # Ophint Converted nodes will need the shapes to be known.
    if _is_ophint_converted(self._graph_def):
      return False

    if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed():
      return False

    # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
    # the MLIR converter.
    if self.conversion_summary_dir:
      logging.warning(
          "`conversion_summary_dir` does not work with unknown shapes. "
1799          "Graphs with unknown shapes might be different than when this flag "
1800          "is disabled.")
      return False
    return True

  def _save_conversion_params_metric(self):
    self._collected_converter_params.update({
        "output_format": self.output_format,
        "default_ranges_stats": self.default_ranges_stats,
        "drop_control_dependency": self.drop_control_dependency,
        "reorder_across_fake_quant": self.reorder_across_fake_quant,
        "change_concat_input_ranges": self.change_concat_input_ranges,
        "dump_graphviz_dir": self.dump_graphviz_dir,
        "dump_graphviz_video": self.dump_graphviz_video,
        "conversion_summary_dir": self.conversion_summary_dir,
        "api_version": 1,
    })
    super(TFLiteConverterBaseV1,
          self)._save_conversion_params_metric(self._graph_def,
                                               self.inference_type,
                                               self.inference_input_type)


class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
  """Converts the given SavedModel into TensorFlow Lite model.

  Attributes:
      saved_model_dir: Directory of the SavedModel.
  """

  def __init__(self,
               saved_model_dir,
               saved_model_tags,
               saved_model_exported_names,
               experimental_debug_info_func=None):
    """Constructor for TFLiteConverter.

    Args:
      saved_model_dir: Directory of the SavedModel.
      saved_model_tags: Set of tags identifying the MetaGraphDef within the
        SavedModel to analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING}).
      saved_model_exported_names: Names to be exported when the saved model
        import path is on.
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteSavedModelConverter,
          self).__init__(experimental_debug_info_func)
    self.saved_model_dir = saved_model_dir
    self._saved_model_tags = saved_model_tags
    self._saved_model_exported_names = saved_model_exported_names

    if len(self._saved_model_exported_names) != 1:
      raise ValueError("Only a single signature key is supported.")

    signature_key = self._saved_model_exported_names[0]

    result = _freeze_saved_model(self.saved_model_dir, None, None, None,
                                 self._saved_model_tags, signature_key)
    self._graph_def = result[0]
    self._input_tensors = result[1]
    self._output_tensors = result[2]
    self._parse_saved_model_args()

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    return super(TFLiteSavedModelConverter, self).convert()


class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
1886  """Converts the given SavedModel into TensorFlow Lite model."""

  def __init__(self,
               model_file,
               input_arrays=None,
               input_shapes=None,
               output_arrays=None,
               custom_objects=None):
    """Constructor for TFLiteConverter.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
          None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteKerasModelConverter,
          self).__init__(experimental_debug_info_func=None)
    # Handles Keras when Eager mode is enabled.
    if context.executing_eagerly():
      if input_arrays or output_arrays:
        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
                         "with Eager mode. If your model requires any of these "
                         "parameters, please use disable_eager_execution().")

      keras_model = keras_deps.get_load_model_function()(model_file,
                                                         custom_objects)
      function = _trace_model_call(keras_model)
      concrete_func = function.get_concrete_function()

      frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
          concrete_func, lower_control_flow=False)
      _set_tensor_shapes(frozen_func.inputs, input_shapes)
      self._keras_model = keras_model
      self._graph_def = frozen_func.graph.as_graph_def()
      self._input_tensors = frozen_func.inputs
      self._output_tensors = frozen_func.outputs
      self._debug_info_func = _build_debug_info_func(frozen_func.graph)
      return

    # Handles Keras when Eager mode is disabled.
    keras_deps.get_clear_session_function()()
    keras_model = keras_deps.get_load_model_function()(model_file,
                                                       custom_objects)
    sess = keras_deps.get_get_session_function()()

    # Get input and output tensors.
    if input_arrays:
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
    else:
      input_tensors = keras_model.inputs

    if output_arrays:
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
    else:
      output_tensors = keras_model.outputs
    _set_tensor_shapes(input_tensors, input_shapes)

    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
    self._keras_model = keras_model
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self._debug_info_func = _build_debug_info_func(sess.graph)

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
  def _freeze_keras_model(self, output_dir):
1963    """Save Keras model to Saved Model format.

    Args:
      output_dir: The output directory to save the SavedModel.
    """
    try:
      self._keras_model.save(output_dir, save_format="tf")
    except Exception:  # pylint: disable=broad-except
      # If saving the given Keras model as a SavedModel fails, fall back to
      # the original Keras model conversion pipeline.
      return None
    tag_set = set([_tag_constants.SERVING])
    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    graph_def, input_tensors, output_tensors, sess_graph = _freeze_saved_model(
        output_dir, None, None, None, tag_set, signature_key)

    self.saved_model_dir = output_dir
    self._saved_model_tags = tag_set
    self._saved_model_exported_names = [signature_key]
    self._parse_saved_model_args()
    if self.saved_model_dir:
      self._graph_def = graph_def
      self._input_tensors = input_tensors
      self._output_tensors = output_tensors
      self._debug_info_func = _build_debug_info_func(sess_graph)

  def _convert_as_saved_model(self):
1990    """Converts a Keras model as a saved model.

    Returns:
      The converted data in serialized format.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      self._freeze_keras_model(temp_dir)
      if self.saved_model_dir:
        return super(TFLiteKerasModelConverter, self).convert()
    finally:
      shutil.rmtree(temp_dir, True)

  @_export_metrics
  def convert(self):
    """Converts a Keras model based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    saved_model_convert_result = self._convert_as_saved_model()
    if saved_model_convert_result:
      return saved_model_convert_result

    return super(TFLiteKerasModelConverter, self).convert()


class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
  """Converts the given frozen graph def into TensorFlow Lite model."""

  def __init__(self,
               graph_def,
               input_tensors,
               output_tensors,
               input_arrays_with_shape=None,
               output_arrays=None,
               experimental_debug_info_func=None):
    """Constructor for TFLiteConverter.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      input_arrays_with_shape: Tuple of strings representing input tensor names
        and list of integers representing input shapes
        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
          into TensorFlow and when `input_tensors` and `output_tensors` are
          None. (default None)
      output_arrays: List of output tensors to freeze graph with. Use only when
        graph cannot be loaded into TensorFlow and when `input_tensors` and
        `output_tensors` are None. (default None)
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteFrozenGraphConverter,
          self).__init__(experimental_debug_info_func)
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self._control_output_arrays = None

    # Attributes are used by models that cannot be loaded into TensorFlow.
    if not self._has_valid_tensors():
      self._input_arrays_with_shape = input_arrays_with_shape
      self._output_arrays = output_arrays

    if input_tensors is not None and input_arrays_with_shape is not None:
      logging.warning("input_arrays_with_shape will be ignored when both the "
                      "given input_tensors and input_arrays_with_shape are not "
                      "None.")

    if output_tensors is not None and output_arrays is not None:
      logging.warning("output_arrays will be ignored when both the given "
                      "output_tensors and output_arrays are not None.")

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    if not self._has_valid_tensors():
      if not self._input_arrays_with_shape or not (self._output_arrays or
                                                   self._control_output_arrays):
        raise ValueError(
            "If input_tensors and output_tensors are None, both "
            "input_arrays_with_shape and output_arrays|control_output_arrays "
            "must be defined.")
    return super(TFLiteFrozenGraphConverter, self).convert()


@_tf_export(v1=["lite.TFLiteConverter"])
class TFLiteConverter(TFLiteFrozenGraphConverter):
  """Convert a TensorFlow model into `output_format`.

  This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras
  model into either a TFLite FlatBuffer or graph visualization.

  Attributes:
    optimizations: Experimental flag, subject to change. Set of optimizations to
      apply. e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
      set of values of type `tf.lite.Optimize`)
    representative_dataset: A generator function used for integer quantization
      where each generated sample has the same order, type and shape as the
      inputs to the model. Usually, this is a small subset of a few hundred
      samples randomly chosen, in no particular order, from the training or
      evaluation dataset. This is an optional attribute, but required for full
      integer quantization, i.e., if `tf.int8` is the only supported type in
      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
      (default None)
    target_spec: Experimental flag, subject to change. Specifications of target
      device, including supported ops set, supported types and a set of user's
      defined TensorFlow operators required in the TensorFlow Lite runtime.
      Refer to `tf.lite.TargetSpec`.
    inference_type: Data type of numeric arrays, excluding the input layer.
      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_input_type: Data type of the numeric arrays in the input layer. If
      `inference_input_type` is in {tf.int8, tf.uint8}, then
      `quantized_input_stats` must be provided. (default is the value assigned
      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
    inference_output_type: Data type of the numeric arrays in the output layer.
      (default is the value assigned to `inference_type`, must be in
      {tf.float32, tf.int8, tf.uint8})
    quantized_input_stats: Map of input tensor names to a tuple of floats
      representing the mean and standard deviation of the training data.
      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
        or tf.uint8. (default None)
    default_ranges_stats: Tuple of integers (min, max) representing range values
      for all numeric arrays without a specified range. Intended for
      experimenting with quantization via "dummy quantization". (default None)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When False, any unknown operation is an error. When True, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the graph.
      Results in a graph that differs from the quantized training graph,
      potentially causing differing arithmetic behavior. (default False)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
      inputs and outputs of the concat operator for quantized models. Changes
      the ranges of concat operator overlap when true. (default False)
    output_format: Output file format. (default
      tf.compat.v1.lite.constants.TFLITE, must be in
      {tf.compat.v1.lite.constants.TFLITE,
      tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
      stages of processing GraphViz .dot files. Preferred over
      `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep
      the requirements of the output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
      files after every graph transformation. Requires the `dump_graphviz_dir`
      flag to be specified. (default False)
    conversion_summary_dir: Full path of the directory to store conversion logs.
      (default None)
    target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
    post_training_quantize: Deprecated. Please use `optimizations` instead and
      set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
    experimental_new_converter: Experimental flag, subject to change. Enables
      MLIR-based conversion instead of TOCO conversion. (default True)
    experimental_new_quantizer: Experimental flag, subject to change. Enables
      MLIR-based quantization conversion instead of Flatbuffer-based conversion.
      (default True)

  Example usage:

    ```python
    # Converting a GraphDef from session.
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
      sess, in_tensors, out_tensors)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)

    # Converting a GraphDef from file.
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
      graph_def_file, input_arrays, output_arrays)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)

    # Converting a SavedModel.
    converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
        saved_model_dir)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)

    # Converting a tf.keras model.
    converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
        keras_model)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
    ```
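
  For quantized inference with integer input arrays, `quantized_input_stats`
  must be set, as enforced by `_validate_quantized_input_stats` above. A
  hedged sketch (the tensor name "input" and the (mean, stddev) pair
  (127.5, 127.5) are illustrative assumptions):

    ```python
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, input_arrays=["input"], output_arrays=["output"])
    converter.inference_type = tf.uint8
    converter.quantized_input_stats = {"input": (127.5, 127.5)}
    tflite_model = converter.convert()
    ```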
2202  """
2203
2204  # pylint: disable=useless-super-delegation
2205  def __init__(self,
2206               graph_def,
2207               input_tensors,
2208               output_tensors,
2209               input_arrays_with_shape=None,
2210               output_arrays=None,
2211               experimental_debug_info_func=None):
2212    """Constructor for TFLiteConverter.
2213
2214    Args:
2215      graph_def: Frozen TensorFlow GraphDef.
2216      input_tensors: List of input tensors. Type and shape are computed using
2217        `foo.shape` and `foo.dtype`.
2218      output_tensors: List of output tensors (only .name is used from this).
2219      input_arrays_with_shape: Tuple of strings representing input tensor names
2220        and list of integers representing input shapes
2221        (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
          into TensorFlow and when `input_tensors` and `output_tensors` are
          None. (default None)
      output_arrays: List of output tensors to freeze graph with. Use only when
        graph cannot be loaded into TensorFlow and when `input_tensors` and
        `output_tensors` are None. (default None)
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteConverter,
          self).__init__(graph_def, input_tensors, output_tensors,
                         input_arrays_with_shape, output_arrays,
                         experimental_debug_info_func)

  @classmethod
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TFLiteConverter class from a TensorFlow Session.

    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).

    Returns:
      TFLiteConverter class.
2251    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
2252    return cls(
2253        graph_def,
2254        input_tensors,
2255        output_tensors,
2256        experimental_debug_info_func=_build_debug_info_func(sess.graph))
2257
2258  @classmethod
2259  def from_frozen_graph(cls,
2260                        graph_def_file,
2261                        input_arrays,
2262                        output_arrays,
2263                        input_shapes=None):
2264    """Creates a TFLiteConverter class from a file containing a frozen GraphDef.
2265
2266    Args:
2267      graph_def_file: Full filepath of file containing frozen GraphDef.
2268      input_arrays: List of input tensors to freeze graph with.
2269      output_arrays: List of output tensors to freeze graph with.
2270      input_shapes: Dict of strings representing input tensor names to list of
2271        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
2272        Automatically determined when input shapes is None (e.g., {"foo" :
2273          None}). (default None)
2274
2275    Returns:
2276      TFLiteConverter class.
2277
2278    Raises:
2279      IOError:
2280        File not found.
2281        Unable to parse input file.
2282      ValueError:
2283        The graph is not frozen.
2284        input_arrays or output_arrays contains an invalid tensor name.
        input_shapes is not correctly defined when required.
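
    Example usage (a sketch; the file path, tensor names, and shape are
    illustrative assumptions):

    ```python
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        "/path/to/frozen_graph.pb",
        input_arrays=["input"],
        output_arrays=["output"],
        input_shapes={"input": [1, 16, 16, 3]})
    tflite_model = converter.convert()
    ```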
2286    """
2287    with _ops.Graph().as_default():
2288      with _session.Session() as sess:
2289        # Read GraphDef from file.
2290        if not gfile.Exists(graph_def_file):
2291          raise IOError("File '{0}' does not exist.".format(graph_def_file))
2292        with gfile.GFile(graph_def_file, "rb") as f:
2293          file_content = f.read()
2294
2295        try:
2296          graph_def = _graph_pb2.GraphDef()
2297          graph_def.ParseFromString(file_content)
2298        except (_text_format.ParseError, DecodeError):
2299          try:
2300            print("Ignore 'tcmalloc: large alloc' warnings.")
2301
2302            if not isinstance(file_content, str):
2303              if PY2:
2304                file_content = six.ensure_binary(file_content, "utf-8")
2305              else:
2306                file_content = six.ensure_text(file_content, "utf-8")
2307            graph_def = _graph_pb2.GraphDef()
2308            _text_format.Merge(file_content, graph_def)
2309          except (_text_format.ParseError, DecodeError):
2310            raise IOError(
2311                "Unable to parse input file '{}'.".format(graph_def_file))
2312
2313        # Handles models with custom TFLite ops that cannot be resolved in
2314        # TensorFlow.
2315        load_model_in_session = True
2316        try:
2317          _import_graph_def(graph_def, name="")
2318        except _NotFoundError:
2319          load_model_in_session = False
2320
2321        if load_model_in_session:
2322          # Check if graph is frozen.
2323          if not _is_frozen_graph(sess):
2324            raise ValueError("Please freeze the graph using freeze_graph.py.")
2325
2326          # Get input and output tensors.
2327          input_tensors = _get_tensors_from_tensor_names(
2328              sess.graph, input_arrays)
2329          output_tensors = _get_tensors_from_tensor_names(
2330              sess.graph, output_arrays)
2331          _set_tensor_shapes(input_tensors, input_shapes)
2332
2333          return cls(sess.graph_def, input_tensors, output_tensors)
2334        else:
2335          if not input_shapes:
2336            raise ValueError("input_shapes must be defined for this model.")
2337          if set(input_arrays) != set(input_shapes.keys()):
2338            raise ValueError("input_shapes must contain a value for each item "
                             "in input_arrays.")

          input_arrays_with_shape = [
              (name, input_shapes[name]) for name in input_arrays
          ]
          return cls(
              graph_def,
              input_tensors=None,
              output_tensors=None,
              input_arrays_with_shape=input_arrays_with_shape,
              output_arrays=output_arrays)

  @classmethod
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TFLiteConverter class from a SavedModel.

    Args:
      saved_model_dir: SavedModel directory to convert.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
          None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING})
      signature_key: Key identifying SignatureDef containing inputs and outputs.
        (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    Returns:
      TFLiteConverter class.
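
    Example usage (a sketch; `saved_model_dir` is assumed to hold a TF1
    SavedModel exported with the default serving tag and signature):

    ```python
    converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        tag_set={tf.saved_model.SERVING},
        signature_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    tflite_model = converter.convert()
    ```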
2379    """
2380    if tag_set is None:
2381      tag_set = set([_tag_constants.SERVING])
2382    if signature_key is None:
2383      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
2384
2385    saved_model_converter = TFLiteSavedModelConverter(saved_model_dir, tag_set,
2386                                                      [signature_key])
2387    if saved_model_converter.saved_model_dir:
2388      return saved_model_converter
2389
2390    result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
2391                                 output_arrays, tag_set, signature_key)
2392
2393    return cls(
2394        graph_def=result[0],
2395        input_tensors=result[1],
2396        output_tensors=result[2],
2397        experimental_debug_info_func=_build_debug_info_func(result[3]))
2398
2399  @classmethod
2400  def from_keras_model_file(cls,
2401                            model_file,
2402                            input_arrays=None,
2403                            input_shapes=None,
2404                            output_arrays=None,
2405                            custom_objects=None):
2406    """Creates a TFLiteConverter class from a tf.keras model file.
2407
2408    Args:
2409      model_file: Full filepath of HDF5 file containing the tf.keras model.
2410      input_arrays: List of input tensors to freeze graph with. Uses input
2411        arrays from SignatureDef when none are provided. (default None)
2412      input_shapes: Dict of strings representing input tensor names to list of
2413        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
2414        Automatically determined when input shapes is None (e.g., {"foo" :
2415          None}). (default None)
2416      output_arrays: List of output tensors to freeze graph with. Uses output
2417        arrays from SignatureDef when none are provided. (default None)
2418      custom_objects: Dict mapping names (strings) to custom classes or
2419        functions to be considered during model deserialization. (default None)
2420
2421    Returns:
2422      TFLiteConverter class.
2423    """
2424    return TFLiteKerasModelConverter(model_file, input_arrays, input_shapes,
2425                                     output_arrays, custom_objects)
2426
2427  # pylint: disable=useless-super-delegation
2428  def convert(self):
2429    """Converts a TensorFlow GraphDef based on instance variables.
2430
2431    Returns:
2432      The converted data in serialized format. Either a TFLite Flatbuffer or a
2433      Graphviz graph depending on value in `output_format`.
2434
2435    Raises:
2436      ValueError:
2437        Input shape is not specified.
2438        None value for dimension in input_tensor.
2439    """
2440    return super(TFLiteConverter, self).convert()
2441
2442
2443@_tf_export(v1=["lite.TocoConverter"])
2444class TocoConverter(object):
2445  """Convert a TensorFlow model into `output_format` using TOCO.
2446
2447  This class has been deprecated. Please use `lite.TFLiteConverter` instead.
2448  """
2449
2450  @classmethod
2451  @_deprecation.deprecated(None,
2452                           "Use `lite.TFLiteConverter.from_session` instead.")
2453  def from_session(cls, sess, input_tensors, output_tensors):
2454    """Creates a TocoConverter class from a TensorFlow Session."""
2455    return TFLiteConverter.from_session(sess, input_tensors, output_tensors)
2456
2457  @classmethod
2458  @_deprecation.deprecated(
2459      None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
2460  def from_frozen_graph(cls,
2461                        graph_def_file,
2462                        input_arrays,
2463                        output_arrays,
2464                        input_shapes=None):
2465    """Creates a TocoConverter class from a file containing a frozen graph."""
2466    return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
2467                                             output_arrays, input_shapes)
2468
2469  @classmethod
2470  @_deprecation.deprecated(
2471      None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
2472  def from_saved_model(cls,
2473                       saved_model_dir,
2474                       input_arrays=None,
2475                       input_shapes=None,
2476                       output_arrays=None,
2477                       tag_set=None,
2478                       signature_key=None):
2479    """Creates a TocoConverter class from a SavedModel."""
2480    return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
2481                                            input_shapes, output_arrays,
2482                                            tag_set, signature_key)
2483
2484  @classmethod
2485  @_deprecation.deprecated(
2486      None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
2487  def from_keras_model_file(cls,
2488                            model_file,
2489                            input_arrays=None,
2490                            input_shapes=None,
2491                            output_arrays=None):
2492    """Creates a TocoConverter class from a tf.keras model file."""
2493    return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
2494                                                 input_shapes, output_arrays)
2495