# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a frozen graph into a TFLite FlatBuffer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import distutils.spawn
import enum  # pylint: disable=g-bad-import-order
import os as _os
import platform as _platform
import subprocess as _subprocess
import tempfile as _tempfile

import six
from six.moves import map

from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.python import wrap_toco
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import ConverterError
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.metrics_wrapper import metrics_wrapper as _metrics_wrapper
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export


def _requires_input_stats(toco_flags: _toco_flags_pb2.TocoFlags) -> bool:
  """Checks if the `input_stats` flag is required for conversion.

  Args:
    toco_flags: A protocol buffer describing the conversion process.

  Returns:
    True, if the `inference_type` or the `inference_input_type` is a quantized
    type and it is not post training quantization, else False.
  """
  quantized_inference_types = [
      _types_pb2.QUANTIZED_UINT8, _types_pb2.QUANTIZED_INT8
  ]
  return ((toco_flags.inference_type in quantized_inference_types or
           toco_flags.inference_input_type in quantized_inference_types) and
          not toco_flags.post_training_quantize)
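
# A minimal illustration of `_requires_input_stats` (not part of the library;
# the flag values below are illustrative):
#
#   flags = _toco_flags_pb2.TocoFlags()
#   flags.inference_type = _types_pb2.QUANTIZED_UINT8
#   assert _requires_input_stats(flags)      # quantized, no post-training quant
#   flags.post_training_quantize = True
#   assert not _requires_input_stats(flags)  # post-training quant needs no stats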


def convert_tensor_tf_type_to_tflite_type(
    tf_type: dtypes.DType, usage: str = "") -> _types_pb2.IODataType:
  """Convert tensor type from tf type to tflite type.

  Args:
    tf_type: TensorFlow type.
    usage: Text describing the reason for invoking this function.

  Raises:
    ValueError: If `tf_type` is unsupported.

  Returns:
    tflite_type: TFLite type. Refer to lite/toco/types.proto.
  """
  mapping = {
      dtypes.float16: _types_pb2.FLOAT16,
      dtypes.float32: _types_pb2.FLOAT,
      dtypes.float64: _types_pb2.FLOAT64,
      dtypes.int8: _types_pb2.INT8,
      dtypes.int16: _types_pb2.INT16,
      dtypes.int32: _types_pb2.INT32,
      dtypes.int64: _types_pb2.INT64,
      dtypes.uint8: _types_pb2.UINT8,
      dtypes.uint32: _types_pb2.UINT32,
      dtypes.uint64: _types_pb2.UINT64,
      dtypes.string: _types_pb2.STRING,
      dtypes.bool: _types_pb2.BOOL,
      dtypes.complex64: _types_pb2.COMPLEX64,
      dtypes.complex128: _types_pb2.COMPLEX128,
  }
  tflite_type = mapping.get(tf_type)
  if tflite_type is None:
    raise ValueError("Unsupported TensorFlow type `{0}` provided for the {1}"
                     .format(tf_type, usage))
  return tflite_type
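
# Illustrative example of the mapping above (not part of the library):
#
#   convert_tensor_tf_type_to_tflite_type(
#       dtypes.int32, usage="input type")   # -> _types_pb2.INT32
#   convert_tensor_tf_type_to_tflite_type(
#       dtypes.qint8, usage="input type")   # -> raises ValueError (no entry)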


# Only a few restricted tensor types are allowed for explicitly setting
# inference/input/output types.
def convert_inference_tf_type_to_tflite_type(
    tf_type: dtypes.DType, usage: str = "") -> _types_pb2.IODataType:
  """Convert inference type from tf type to tflite type.

  Args:
    tf_type: TensorFlow type.
    usage: Text describing the reason for invoking this function.

  Raises:
    ValueError: If `tf_type` is unsupported.

  Returns:
    tflite_type: TFLite type. Refer to lite/toco/types.proto.
  """
  mapping = {
      dtypes.float32: _types_pb2.FLOAT,
      dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
      dtypes.int8: _types_pb2.QUANTIZED_INT8,
      dtypes.int16: _types_pb2.QUANTIZED_INT16,
  }
  tflite_type = mapping.get(tf_type)
  if tflite_type is None:
    raise ValueError("Unsupported TensorFlow type `{0}` provided for the {1}"
                     .format(tf_type, usage))
  return tflite_type
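
# Illustrative example (not part of the library): unlike the tensor-type
# mapping above, integer inference types map to their quantized counterparts,
# and anything outside the four supported types is rejected.
#
#   convert_inference_tf_type_to_tflite_type(
#       dtypes.int8, usage="inference_type")     # -> _types_pb2.QUANTIZED_INT8
#   convert_inference_tf_type_to_tflite_type(
#       dtypes.float64, usage="inference_type")  # -> raises ValueError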


# Find the toco_from_protos binary using the resource loader if running from
# bazel; otherwise we are in a pip install, where console_scripts already
# provides the toco_from_protos tool.
if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
  _toco_from_proto_bin = ""
else:
  _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
      "../toco/python/toco_from_protos")

if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
  _toco_from_proto_bin = "toco_from_protos"


def _try_convert_to_unicode(output):
  if output is None:
    return u""

  if isinstance(output, bytes):
    try:
      return six.ensure_text(output)
    except UnicodeDecodeError:
      pass
  return output


@_tf_export("lite.OpsSet")
class OpsSet(enum.Enum):
  """Enum class defining the sets of ops available to generate TFLite models.

  WARNING: Experimental interface, subject to change.
  """
  # Convert model using TensorFlow Lite builtin ops.
  TFLITE_BUILTINS = "TFLITE_BUILTINS"

  # Convert model using TensorFlow ops. Not all TensorFlow ops are available.
  # WARNING: Experimental interface, subject to change.
  SELECT_TF_OPS = "SELECT_TF_OPS"

  # Convert model using only TensorFlow Lite quantized int8 operations.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"

  # Convert model using only TensorFlow Lite operations with quantized int8
  # weights, int16 activations and int64 bias.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  # This quantization mode may be used in models for super-resolution,
  # audio signal processing or image de-noising. It improves accuracy
  # significantly, but only slightly increases the model size.
  # WARNING: These ops are currently experimental and have not yet been
  # finalized.
  # They are only compatible with CPU execution, and have not been optimized
  # for production.
  EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = (
      "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8")

  def __str__(self):
    return str(self.value)

  @staticmethod
  def get_options():
    """Returns a list of OpsSet options as a list of strings."""
    return [str(option) for option in list(OpsSet)]
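
# Illustrative example (not part of the library): listing the available ops
# sets and composing a `target_ops` set for the converter.
#
#   OpsSet.get_options()
#   # -> ['TFLITE_BUILTINS', 'SELECT_TF_OPS', 'TFLITE_BUILTINS_INT8',
#   #     'EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8']
#   target_ops = set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS])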


@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.QUANTIZE)
def mlir_quantize(input_data_str,
                  disable_per_channel=False,
                  fully_quantize=False,
                  inference_type=_types_pb2.QUANTIZED_INT8,
                  input_data_type=dtypes.float32,
                  output_data_type=dtypes.float32,
                  enable_numeric_verify=False,
                  enable_whole_model_verify=False,
                  denylisted_ops=None,
                  denylisted_nodes=None):
  """Quantize `input_data_str` with calibration results.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model with
      calibration results).
    disable_per_channel: Bool indicating whether to do per-channel or per-tensor
      quantization.
    fully_quantize: Bool indicating whether to fully quantize the model. Besides
      the model body, the input/output will be quantized as well.
    inference_type: Data type for the activations. The default value is int8.
    input_data_type: Data type for the inputs. The default value is float32.
    output_data_type: Data type for the outputs. The default value is float32.
    enable_numeric_verify: Experimental. Subject to change. Bool indicating
      whether to add NumericVerify ops into the debug mode quantized model.
    enable_whole_model_verify: Experimental. Subject to change. Bool indicating
      whether to add verification layer by layer or on the whole model. When
      disabled (per-layer), float and quantized ops will be run from the same
      input (the output of the previous quantized layer). When enabled, float
      and quantized ops will run with their respective float and quantized
      outputs of the previous ops.
    denylisted_ops: Experimental. Subject to change. Set of ops to denylist.
    denylisted_nodes: Experimental. Subject to change. Set of nodes to denylist.

  Returns:
    Quantized model in serialized form (e.g. a TFLITE model) with floating-point
    inputs and outputs.
  """
  return wrap_toco.wrapped_experimental_mlir_quantize(
      input_data_str, disable_per_channel, fully_quantize, inference_type,
      convert_tensor_tf_type_to_tflite_type(input_data_type),
      convert_tensor_tf_type_to_tflite_type(output_data_type),
      enable_numeric_verify, enable_whole_model_verify, denylisted_ops,
      denylisted_nodes)
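
# A hedged usage sketch (not part of the library; assumes `calibrated_model`
# is a serialized TFLite model that already carries calibration statistics):
#
#   quantized_model = mlir_quantize(
#       calibrated_model,
#       fully_quantize=True,
#       input_data_type=dtypes.int8,
#       output_data_type=dtypes.int8)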


@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.SPARSIFY)
def mlir_sparsify(input_data_str):
  """Sparsify `input_data_str` to encode sparse tensors with the proper format.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model).

  Returns:
    Sparsified model in serialized form (e.g. a TFLITE model).
  """
  return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str)
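
# A hedged usage sketch (not part of the library; assumes `tflite_model` is a
# serialized flatbuffer whose weights are sparse enough to benefit):
#
#   sparse_model = mlir_sparsify(tflite_model)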


def register_custom_opdefs(custom_opdefs_list):
  """Register the given custom opdefs to the TensorFlow global op registry.

  Args:
    custom_opdefs_list: String representing the custom ops OpDefs that are
      included in the GraphDef.

  Returns:
    True if the registration is successfully completed.
  """
  return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)
def toco_convert_protos(model_flags_str,
                        toco_flags_str,
                        input_data_str,
                        debug_info_str=None,
                        enable_mlir_converter=False):
  """Convert `input_data_str` according to model and toco parameters.

  Unless you know what you are doing, consider using
  the more friendly `tf.compat.v1.lite.toco_convert`.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `toco/model_flags.proto`.
    toco_flags_str: Serialized proto describing conversion properties, see
      `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common).
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)
    enable_mlir_converter: Enables MLIR-based conversion instead of the default
      TOCO conversion. (default False)

  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).

  Raises:
    ConverterError: When conversion fails in TFLiteConverter, usually due to
      ops not being supported.
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded.
  """
  # Historically, TOCO conversion failures would trigger a crash, so we would
  # attempt to run the converter out-of-process. The MLIR conversion pipeline
  # surfaces errors instead, and can be safely run in-process.
  if enable_mlir_converter or not _toco_from_proto_bin:
    try:
      model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
                                                 toco_flags_str, input_data_str,
                                                 debug_info_str,
                                                 enable_mlir_converter)
      return model_str
    except Exception as e:
      converter_error = ConverterError(str(e))
      for error_data in _metrics_wrapper.retrieve_collected_errors():
        converter_error.append_error(error_data)
      raise converter_error

  return _run_toco_binary(model_flags_str, toco_flags_str, input_data_str,
                          debug_info_str)
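
# A hedged sketch of the serialized-proto flow (not part of the library;
# assumes `graph_def` is a frozen tf.compat.v1.GraphDef and that default
# flags are acceptable):
#
#   model_flags = _model_flags_pb2.ModelFlags()
#   toco_flags = _toco_flags_pb2.TocoFlags()
#   tflite_model = toco_convert_protos(model_flags.SerializeToString(),
#                                      toco_flags.SerializeToString(),
#                                      graph_def.SerializeToString(),
#                                      enable_mlir_converter=True)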


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER)
def _run_toco_binary(model_flags_str,
                     toco_flags_str,
                     input_data_str,
                     debug_info_str=None):
  """Convert `input_data_str` using the TOCO converter binary.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `toco/model_flags.proto`.
    toco_flags_str: Serialized proto describing conversion properties, see
      `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common).
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)

  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).

  Raises:
    ConverterError: When the toco_from_protos binary cannot be found.
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded.
  """
  if distutils.spawn.find_executable(_toco_from_proto_bin) is None:
    raise ConverterError("""Could not find the toco_from_protos binary. Make
sure your virtualenv bin directory or pip local bin directory is in your path.
In particular, if you have installed TensorFlow with --user, make sure you
add the install directory to your path.

For example:
Linux: export PATH=$PATH:~/.local/bin/
Mac: export PATH=$PATH:~/Library/Python/<version#>/bin

Alternatively, use virtualenv.""")
  # Windows and TemporaryFile are not that useful together,
  # since you cannot have two readers/writers. So we have to
  # make the temporaries and close and delete them explicitly.
  (toco_filename, model_filename, input_filename, debug_filename,
   output_filename) = (None, None, None, None, None)
  try:
    # Build all input files.
    with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
             _tempfile.NamedTemporaryFile(delete=False) as fp_model, \
             _tempfile.NamedTemporaryFile(delete=False) as fp_input, \
             _tempfile.NamedTemporaryFile(delete=False) as fp_debug:
      toco_filename = fp_toco.name
      input_filename = fp_input.name
      model_filename = fp_model.name
      debug_filename = fp_debug.name

      fp_model.write(model_flags_str)
      fp_toco.write(toco_flags_str)
      fp_input.write(six.ensure_binary(input_data_str))
      debug_info_str = debug_info_str if debug_info_str else ""
      # `debug_info_str` may arrive as either `str` or `bytes` (some callers,
      # including subtests of "convert_test", pass a plain string). Writing a
      # `str` to a binary file raises:
      #   TypeError: a bytes-like object is required, not 'str'
      # so encode it to bytes where needed.
      if not isinstance(debug_info_str, bytes):
        fp_debug.write(debug_info_str.encode("utf-8"))
      else:
        fp_debug.write(debug_info_str)

    # Reserve an output file.
    with _tempfile.NamedTemporaryFile(delete=False) as fp:
      output_filename = fp.name

    # Run the converter binary.
    cmd = [
        _toco_from_proto_bin,
        model_filename,
        toco_filename,
        input_filename,
        output_filename,
        "--debug_proto_file={}".format(debug_filename),
    ]
    cmdline = " ".join(cmd)
    is_windows = _platform.system() == "Windows"
    proc = _subprocess.Popen(
        cmdline,
        shell=True,
        stdout=_subprocess.PIPE,
        stderr=_subprocess.STDOUT,
        close_fds=not is_windows)
    stdout, stderr = proc.communicate()
    exitcode = proc.returncode
    if exitcode == 0:
      with open(output_filename, "rb") as fp:
        return fp.read()
    else:
      stdout = _try_convert_to_unicode(stdout)
      stderr = _try_convert_to_unicode(stderr)
      raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
  finally:
    # Must manually clean up the temporary files.
    for filename in [
        toco_filename, input_filename, model_filename, debug_filename,
        output_filename
    ]:
      try:
        _os.unlink(filename)
      except (OSError, TypeError):
        pass


def build_toco_flags(inference_type=dtypes.float32,
                     inference_input_type=None,
                     input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                     output_format=lite_constants.TFLITE,
                     default_ranges_stats=None,
                     drop_control_dependency=True,
                     reorder_across_fake_quant=False,
                     allow_custom_ops=False,
                     post_training_quantize=False,
                     quantize_to_float16=False,
                     dump_graphviz_dir=None,
                     dump_graphviz_video=False,
                     target_ops=None,
                     conversion_summary_dir=None,
                     select_user_tf_ops=None,
                     allow_all_select_tf_ops=False,
                     enable_tflite_resource_variables=False,
                     unfold_batchmatmul=True,
                     lower_tensor_list_ops=True,
                     accumulation_type=None,
                     allow_bfloat16=False,
                     unfold_large_splat_constant=False,
                     **_):
  """Build the TOCO flags object from params."""
  toco = _toco_flags_pb2.TocoFlags()
  toco.input_format = input_format
  toco.output_format = output_format
  toco.inference_type = convert_inference_tf_type_to_tflite_type(
      inference_type, usage="inference_type flag")
  if inference_input_type:
    toco.inference_input_type = convert_inference_tf_type_to_tflite_type(
        inference_input_type, usage="inference_input_type flag")
  else:
    toco.inference_input_type = toco.inference_type
  toco.drop_control_dependency = drop_control_dependency
  toco.reorder_across_fake_quant = reorder_across_fake_quant
  toco.allow_custom_ops = allow_custom_ops
  if select_user_tf_ops:
    toco.select_user_tf_ops.extend(select_user_tf_ops)
  toco.allow_all_select_tf_ops = allow_all_select_tf_ops
  toco.post_training_quantize = post_training_quantize
  toco.quantize_to_float16 = quantize_to_float16
  if default_ranges_stats:
    toco.default_ranges_min = default_ranges_stats[0]
    toco.default_ranges_max = default_ranges_stats[1]
  if dump_graphviz_dir:
    toco.dump_graphviz_dir = dump_graphviz_dir
  toco.dump_graphviz_include_video = dump_graphviz_video
  if conversion_summary_dir:
    toco.conversion_summary_dir = conversion_summary_dir
  if target_ops:
    if OpsSet.SELECT_TF_OPS in set(target_ops):
      toco.enable_select_tf_ops = True
    if set(target_ops) == set([OpsSet.SELECT_TF_OPS]):
      toco.force_select_tf_ops = True
  toco.enable_tflite_resource_variables = enable_tflite_resource_variables
  toco.unfold_batchmatmul = unfold_batchmatmul
  toco.lower_tensor_list_ops = lower_tensor_list_ops
  toco.unfold_large_splat_constant = unfold_large_splat_constant
  if accumulation_type:
    toco.accumulation_type = convert_tensor_tf_type_to_tflite_type(
        accumulation_type, usage="accumulation_type flag")
  toco.allow_bfloat16 = allow_bfloat16

  return toco
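
# A hedged sketch (not part of the library; flag values are illustrative):
# building flags for float16 post-training quantization of a float model.
#
#   toco_flags = build_toco_flags(
#       inference_type=dtypes.float32,
#       post_training_quantize=True,
#       quantize_to_float16=True,
#       target_ops=set([OpsSet.TFLITE_BUILTINS]))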


def build_toco_convert_protos(input_tensors,
                              output_tensors,
                              inference_type=dtypes.float32,
                              inference_input_type=None,
                              input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                              input_shapes=None,
                              output_format=lite_constants.TFLITE,
                              quantized_input_stats=None,
                              default_ranges_stats=None,
                              drop_control_dependency=True,
                              reorder_across_fake_quant=False,
                              allow_custom_ops=False,
                              change_concat_input_ranges=False,
                              post_training_quantize=False,
                              quantize_to_float16=False,
                              dump_graphviz_dir=None,
                              dump_graphviz_video=False,
                              target_ops=None,
                              allow_nonexistent_arrays=False,
                              debug_info=None,
                              conversion_summary_dir=None,
                              saved_model_dir=None,
                              saved_model_version=0,
                              saved_model_tags=None,
                              saved_model_exported_names=None,
                              select_user_tf_ops=None,
                              allow_all_select_tf_ops=False,
                              unfold_batchmatmul=True,
                              lower_tensor_list_ops=True,
                              accumulation_type=None,
                              allow_bfloat16=False,
                              unfold_large_splat_constant=False):
  """Builds protocol buffers describing a conversion of a model using TOCO.

  Typically this is to convert from TensorFlow GraphDef to TFLite, in which
  case the default `input_format` and `output_format` are sufficient.

  Args:
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    inference_type: Data type of numeric arrays, excluding the input layer.
      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_input_type: Data type of the numeric arrays in the input layer. If
      `inference_input_type` is in {tf.int8, tf.uint8}, then
      `quantized_input_stats` must be provided. (default is the value assigned
      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
    input_format: Type of data to read.
      (default TENSORFLOW_GRAPHDEF, must be in {TENSORFLOW_GRAPHDEF})
    input_shapes: Input array shape. (default None, must be None or a list of
      the same length as `input_tensors`.)
    output_format: Output file format. (default TFLITE, must be in
      {TFLITE, GRAPHVIZ_DOT})
    quantized_input_stats: Map of input tensor names to a tuple of floats
      representing the mean and standard deviation of the training data
      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
      or tf.uint8. (default None)
    default_ranges_stats: Tuple of integers representing (min, max) range values
      for all arrays without a specified range. Intended for experimenting with
      quantization via "dummy quantization". (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the graph.
      Results in a graph that differs from the quantized training graph,
      potentially causing differing arithmetic behavior. (default False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false, any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
      inputs and outputs of the concat operator for quantized models. Changes
      the ranges of concat operator overlap when true. (default False)
    post_training_quantize: Boolean indicating whether to quantize the weights
      of the converted float model. Model size will be reduced and there will be
      latency improvements (at the cost of accuracy). (default False)
    quantize_to_float16: Boolean indicating whether to convert float buffers to
      float16. (default False)
    dump_graphviz_dir: Full filepath of the folder to dump the graphs at various
      stages of processing as GraphViz .dot files. Preferred over
      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
      output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet options
      indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS]))
    allow_nonexistent_arrays: Allow specifying array names that don't exist or
      are unused in the final graph. (default False)
    debug_info: `GraphDebugInfo` proto containing the stack traces for the
      original nodes referred to by the converted graph.
    conversion_summary_dir: A string, the path to the generated conversion logs.
    saved_model_dir: Filepath of the saved model to be converted. This value
      will be non-empty only when the saved model import path will be used.
      Otherwise, the GraphDef-based conversion will be processed.
    saved_model_version: SavedModel file format version of the saved model file
      to be converted. This value will be set only when the SavedModel import
      path will be used.
    saved_model_tags: Set of string saved model tags, formatted as
      comma-separated values. This value will be set only when the SavedModel
      import path will be used.
    saved_model_exported_names: Names to be exported (default: export all) when
      the saved model import path is on. This value will be set only when the
      SavedModel import path will be used.
    select_user_tf_ops: List of user-defined TensorFlow ops that need to be
      supported in the TensorFlow Lite runtime. These ops will be supported as
      select TensorFlow ops.
    allow_all_select_tf_ops: If True, automatically add all TF ops (including
      custom TF ops) to the converted model as flex ops.
    unfold_batchmatmul: Whether to unfold tf.BatchMatMul to a set of
      tfl.fully_connected ops. If not, translate to tfl.batch_matmul.
    lower_tensor_list_ops: Whether to lower tensor list ops to builtin ops. If
      not, use Flex tensor list ops.
    accumulation_type: Data type of the accumulators in quantized inference.
      Typically used for float16 quantization and is either fp16 or fp32.
    allow_bfloat16: Whether the converted model supports reduced precision
      inference with the bfloat16 type.
    unfold_large_splat_constant: Whether to unfold large splat constant tensors
      in the flatbuffer model to reduce size.

  Returns:
    model_flags, toco_flags, debug_info: three protocol buffers describing the
      conversion process and debug information.

  Raises:
    ValueError: If the input tensor type is unknown, or if `mean_values` or
      `std_dev_values` are missing.
    RuntimeError: If TOCO fails to convert (in which case the runtime error's
      error text will contain the TOCO error log).
  """
  toco = build_toco_flags(
      inference_type=inference_type,
      inference_input_type=inference_input_type,
      input_format=input_format,
      output_format=output_format,
      default_ranges_stats=default_ranges_stats,
      drop_control_dependency=drop_control_dependency,
      reorder_across_fake_quant=reorder_across_fake_quant,
      allow_custom_ops=allow_custom_ops,
      post_training_quantize=post_training_quantize,
      quantize_to_float16=quantize_to_float16,
      dump_graphviz_dir=dump_graphviz_dir,
      dump_graphviz_video=dump_graphviz_video,
      target_ops=target_ops,
      conversion_summary_dir=conversion_summary_dir,
      select_user_tf_ops=select_user_tf_ops,
      allow_all_select_tf_ops=allow_all_select_tf_ops,
      unfold_batchmatmul=unfold_batchmatmul,
      lower_tensor_list_ops=lower_tensor_list_ops,
      accumulation_type=accumulation_type,
      allow_bfloat16=allow_bfloat16,
      unfold_large_splat_constant=unfold_large_splat_constant)
  model = _model_flags_pb2.ModelFlags()
  model.change_concat_input_ranges = change_concat_input_ranges
  for idx, input_tensor in enumerate(input_tensors):
    input_array = model.input_arrays.add()
    if saved_model_dir:
      input_array.name = input_tensor.name
    else:
      input_array.name = util.get_tensor_name(input_tensor)
    input_array.data_type = convert_tensor_tf_type_to_tflite_type(
        input_tensor.dtype, usage="input type of the TensorFlow model")

    if _requires_input_stats(toco) and quantized_input_stats:
      input_array.mean_value, input_array.std_value = quantized_input_stats[idx]

    if input_shapes is None:
      shape = input_tensor.shape
    else:
      shape = input_shapes[idx]

    if shape.rank is not None:
      # Create shapes with -1 for unknown dimensions.
      dims = []
      for dim in shape:
        if (dim is None or
            (isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
          dims.append(-1)
        else:
          dims.append(int(dim))
      input_array.shape.dims.extend(dims)
      input_array.shape.unknown_rank = False
    else:
      input_array.shape.unknown_rank = True

  for output_tensor in output_tensors:
    if saved_model_dir:
      model.output_arrays.append(output_tensor.name)
    else:
      model.output_arrays.append(util.get_tensor_name(output_tensor))

  model.allow_nonexistent_arrays = allow_nonexistent_arrays

  if saved_model_dir:
    model.saved_model_dir = saved_model_dir
  model.saved_model_version = saved_model_version
  if saved_model_tags:
    model.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    model.saved_model_exported_names.extend(saved_model_exported_names)

  return model, toco, debug_info
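
# A hedged sketch (not part of the library; assumes `in_tensor`/`out_tensor`
# come from a frozen tf.compat.v1 graph):
#
#   model_flags, toco_flags, debug_info = build_toco_convert_protos(
#       input_tensors=[in_tensor],
#       output_tensors=[out_tensor],
#       inference_type=dtypes.float32)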


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF)
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
                           enable_mlir_converter, control_output_arrays, *args,
                           **kwargs):
  """Convert a model using TOCO.

  This function is used to convert GraphDefs that cannot be loaded into
  TensorFlow to TFLite. Conversion can be customized by providing arguments
  that are forwarded to `build_toco_convert_protos` (see documentation for
  details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_arrays_with_shape: List of tuples of strings representing input
      tensor names and lists of integers representing input shapes
      (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
      into TensorFlow and when `input_tensors` is None.
    output_arrays: List of output tensors to freeze graph with. Use only when
      graph cannot be loaded into TensorFlow and when `output_tensors` is None.
    enable_mlir_converter: Enables MLIR-based conversion instead of TOCO
      conversion.
    control_output_arrays: Control output node names. This is used when
      converting a Graph with no output tensors. For example, if the
      graph's last operation is a Print op, just specify that op's name in
      this field. This can be used together with the `output_arrays`
      parameter.
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  model_flags, toco_flags, _ = build_toco_convert_protos(
      input_tensors=[], output_tensors=[], *args, **kwargs)

  for idx, (name, shape) in enumerate(input_arrays_with_shape):
    input_array = model_flags.input_arrays.add()
    if _requires_input_stats(toco_flags):
      if (("quantized_input_stats" not in kwargs) or
          (not kwargs["quantized_input_stats"])):
        raise ValueError(
            "The `quantized_input_stats` flag must be defined when either "
            "`inference_type` flag or `inference_input_type` flag is set to "
            "tf.int8 or tf.uint8.")
      input_array.mean_value, input_array.std_value = kwargs[
          "quantized_input_stats"][idx]
    input_array.name = name
    input_array.shape.dims.extend(list(map(int, shape)))

  if output_arrays:
    for name in output_arrays:
      model_flags.output_arrays.append(name)
  if control_output_arrays:
    for name in control_output_arrays:
      model_flags.control_output_arrays.append(name)

  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      input_data.SerializeToString(),
      enable_mlir_converter=enable_mlir_converter)
  return data
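
# A hedged sketch (not part of the library; names and shapes are illustrative,
# assuming a GraphDef that cannot be loaded into a TensorFlow session):
#
#   tflite_model = toco_convert_graph_def(
#       graph_def,
#       input_arrays_with_shape=[("input", [1, 224, 224, 3])],
#       output_arrays=["output"],
#       enable_mlir_converter=True,
#       control_output_arrays=None)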


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF)
def toco_convert_impl(input_data, input_tensors, output_tensors,
                      enable_mlir_converter, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to
  TFLite. Conversion can be customized by providing arguments that are
  forwarded to `build_toco_convert_protos` (see documentation for details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    enable_mlir_converter: Enables MLIR-based conversion instead of TOCO
      conversion.
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  model_flags, toco_flags, debug_info = build_toco_convert_protos(
      input_tensors, output_tensors, *args, **kwargs)
  debug_info_str = debug_info.SerializeToString() if debug_info else None
  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      input_data.SerializeToString(),
      debug_info_str=debug_info_str,
      enable_mlir_converter=enable_mlir_converter)
  return data


@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_SAVED_MODEL)
def convert_saved_model(saved_model_dir=None,
                        saved_model_version=0,
                        saved_model_tags=None,
                        saved_model_exported_names=None,
                        **kwargs):
  """Converts a SavedModel using the TF Lite converter."""
  model_flags = _model_flags_pb2.ModelFlags()
  if saved_model_dir:
    model_flags.saved_model_dir = saved_model_dir
  model_flags.saved_model_version = saved_model_version
  if saved_model_tags:
    model_flags.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    model_flags.saved_model_exported_names.extend(saved_model_exported_names)
  toco_flags = build_toco_flags(**kwargs)
  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      None,  # input_data, unused
      None,  # debug_info_str, unused
      enable_mlir_converter=True)
  return data
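
# A hedged sketch (not part of the library; the directory, tag, and exported
# name below are illustrative assumptions):
#
#   tflite_model = convert_saved_model(
#       saved_model_dir="/tmp/saved_model",
#       saved_model_version=1,
#       saved_model_tags=["serve"],
#       saved_model_exported_names=["serving_default"])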


@_tf_export(v1=["lite.toco_convert"])
@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to
  TFLite. Conversion can be customized by providing arguments that are
  forwarded to `build_toco_convert_protos` (see documentation for details).
  This function has been deprecated. Please use `tf.lite.TFLiteConverter`
  instead.

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  enable_mlir_converter = kwargs.get("enable_mlir_converter", False)
  return toco_convert_impl(input_data, input_tensors, output_tensors,
                           enable_mlir_converter, *args, **kwargs)
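
# A hedged sketch of this deprecated entry point (not part of the library;
# assumes a frozen graph in `sess`; prefer `tf.lite.TFLiteConverter`):
#
#   tflite_model = toco_convert(sess.graph_def, [in_tensor], [out_tensor])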


def deduplicate_readonly_buffers(tflite_model):
  """Generates a new model byte array after deduplicating readonly buffers.

  This function should be invoked after the model optimization toolkit. The
  model optimization toolkit assumes that each tensor object owns its own
  buffer exclusively.

  Args:
    tflite_model: TFLite flatbuffer in a byte array to be deduplicated.

  Returns:
    TFLite flatbuffer in a bytes array, processed with the deduplication method.
  """
  # Load the TFLite flatbuffer byte array into an object.
  model = flatbuffer_utils.convert_bytearray_to_object(tflite_model)

  # Get all the read-only buffers, which can be modified without causing any
  # issue in the graph invocation stage.
  read_only_buffer_indices = set()
  for subgraph in model.subgraphs:
    # To get all the read-only buffers:
    # (1) Get all read-only input tensors.
    # (2) Discard intermediate or output tensors.
    # (3) Discard the subgraph's input/output tensors.
    # (4) Gather the buffers of the read-only input tensors.

    # (1) Get read-only input tensors.
    read_only_input_tensor_indices = set()
    for op in subgraph.operators:
      if op.inputs is None:
        continue
      for i, input_tensor_idx in enumerate(op.inputs):
        # Ignore mutable tensors.
        if op.mutatingVariableInputs is not None:
          # Ignore invalid tensors.
          if (i < len(op.mutatingVariableInputs) and
              op.mutatingVariableInputs[i]):
            continue
        # Ignore variable tensors.
        if subgraph.tensors[input_tensor_idx].isVariable:
          continue
        read_only_input_tensor_indices.add(input_tensor_idx)

    # (2) Discard intermediate or output tensors.
    for op in subgraph.operators:
      if op.outputs is not None:
        for output_tensor_idx in op.outputs:
          read_only_input_tensor_indices.discard(output_tensor_idx)
      if op.intermediates is not None:
        for intermediate_tensor_idx in op.intermediates:
          read_only_input_tensor_indices.discard(intermediate_tensor_idx)

    # (3) Discard the subgraph's input and output tensors.
    if subgraph.inputs is not None:
      for input_tensor_idx in subgraph.inputs:
        read_only_input_tensor_indices.discard(input_tensor_idx)
    if subgraph.outputs is not None:
      for output_tensor_idx in subgraph.outputs:
        read_only_input_tensor_indices.discard(output_tensor_idx)

    # (4) Gather the buffers of the read-only input tensors.
    for tensor_idx in read_only_input_tensor_indices:
      read_only_buffer_indices.add(subgraph.tensors[tensor_idx].buffer)

  # Ignore invalid negative indices or zero-sized buffers.
  for buffer_idx in read_only_buffer_indices.copy():
    if (buffer_idx < 0 or (model.buffers[buffer_idx].data is None or
                           isinstance(model.buffers[buffer_idx].data, list) or
                           model.buffers[buffer_idx].data.size == 0)):
      read_only_buffer_indices.discard(buffer_idx)

  # Sort the buffers (in place) by their contents, so that identical buffers
  # end up adjacent and the duplicate scan below can stop early.
  read_only_buffer_indices = list(read_only_buffer_indices)
  read_only_buffer_indices.sort(
      key=lambda idx: model.buffers[idx].data.data.tobytes())

  # Create a map of duplicate buffers (same size and same contents).
  # E.g., in [1, 2, 3, 4, 5, 6], if (1, 4, 6) and (2, 5) are groups of
  # identical buffers, the map would be {4: 1, 6: 1, 5: 2}.
  duplicate_buffer_map = {}
  for i, buffer_i_idx in enumerate(read_only_buffer_indices):
    # This buffer is a duplicate.
    if buffer_i_idx in duplicate_buffer_map:
      continue
    # This buffer is unique. Scan the rest of the list to find duplicates
    # of this buffer and mark them accordingly.
    buffer_i = model.buffers[buffer_i_idx]
    for buffer_j_idx in read_only_buffer_indices[i + 1:]:
      if buffer_j_idx in duplicate_buffer_map:
        continue
      buffer_j = model.buffers[buffer_j_idx]
      if buffer_i.data.size != buffer_j.data.size:
        break
      if buffer_i.data.data != buffer_j.data.data:
        continue
      # Found a duplicate. Nullify the j-th buffer and use the i-th buffer
      # instead.
      duplicate_buffer_map[buffer_j_idx] = buffer_i_idx

  # Make the duplicated tensors use the single shared buffer index.
  for subgraph in model.subgraphs:
    for op in subgraph.operators:
      if op.inputs is None:
        continue
      for input_tensor in op.inputs:
        buffer_idx = subgraph.tensors[input_tensor].buffer
        if buffer_idx in duplicate_buffer_map:
          subgraph.tensors[input_tensor].buffer = (
              duplicate_buffer_map[buffer_idx])

  # Nullify the unused buffers.
  for idx in duplicate_buffer_map:
    model.buffers[idx].data = None

  # Return a TFLite flatbuffer as a byte array.
  return flatbuffer_utils.convert_object_to_bytearray(model)
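
# A hedged usage sketch (not part of the library; assumes `tflite_model` is a
# serialized model whose constant tensors may share identical contents):
#
#   deduped = deduplicate_readonly_buffers(tflite_model)
#   with open("/tmp/model.deduped.tflite", "wb") as f:  # path illustrative
#     f.write(deduped)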