• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Python command line interface for converting TF models to TFLite models."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import argparse
23import os
24import sys
25import warnings
26
27import six
28from six.moves import zip
29# Needed to enable TF2 by default.
30import tensorflow as tf  # pylint: disable=unused-import
31
32from tensorflow.lite.python import lite
33from tensorflow.lite.python.convert import register_custom_opdefs
34from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
35from tensorflow.lite.toco.logging import gen_html
36from tensorflow.python import tf2
37from tensorflow.python.framework import dtypes
38from tensorflow.python.platform import app
39from tensorflow.python.util import keras_deps
40
41
42def _parse_array(values, type_fn=str):
43  if values is not None:
44    return [type_fn(val) for val in six.ensure_str(values).split(",") if val]
45  return None
46
47
48def _parse_set(values):
49  if values is not None:
50    return set([item for item in six.ensure_str(values).split(",") if item])
51  return None
52
53
54def _parse_inference_type(value, flag):
55  """Converts the inference type to the value of the constant.
56
57  Args:
58    value: str representing the inference type.
59    flag: str representing the flag name.
60
61  Returns:
62    tf.dtype.
63
64  Raises:
65    ValueError: Unsupported value.
66  """
67  if value == "FLOAT":
68    return dtypes.float32
69  if value == "INT8":
70    return dtypes.int8
71  if value == "UINT8" or value == "QUANTIZED_UINT8":
72    return dtypes.uint8
73  raise ValueError(
74      "Unsupported value for `{}` flag. Expected FLOAT, INT8, UINT8, or "
75      "QUANTIZED_UINT8 instead got {}.".format(flag, value))
76
77
def _get_tflite_converter(flags):
  """Makes a TFLiteConverter object based on the flags provided.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Returns:
    TFLiteConverter object.

  Raises:
    ValueError: Invalid flags.
  """
  # Parse the input/output tensor names and any per-input shapes. Shapes are
  # colon-separated per input, with comma-separated dims inside each shape.
  input_arrays = _parse_array(flags.input_arrays)
  output_arrays = _parse_array(flags.output_arrays)
  input_shapes = None
  if flags.input_shapes:
    shapes = [
        _parse_array(shape, type_fn=int)
        for shape in six.ensure_str(flags.input_shapes).split(":")
    ]
    input_shapes = dict(zip(input_arrays, shapes))

  converter_kwargs = {
      "input_arrays": input_arrays,
      "input_shapes": input_shapes,
      "output_arrays": output_arrays,
  }

  # Dispatch to the appropriate TFLiteConverter factory for the input format.
  if flags.graph_def_file:
    converter_kwargs["graph_def_file"] = flags.graph_def_file
    return lite.TFLiteConverter.from_frozen_graph(**converter_kwargs)
  if flags.saved_model_dir:
    converter_kwargs["saved_model_dir"] = flags.saved_model_dir
    converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
    converter_kwargs["signature_key"] = flags.saved_model_signature_key
    return lite.TFLiteConverter.from_saved_model(**converter_kwargs)
  if flags.keras_model_file:
    converter_kwargs["model_file"] = flags.keras_model_file
    return lite.TFLiteConverter.from_keras_model_file(**converter_kwargs)
  raise ValueError("--graph_def_file, --saved_model_dir, or "
                   "--keras_model_file must be specified.")
124
125
def _convert_tf1_model(flags):
  """Calls function to convert the TensorFlow 1.X model into a TFLite model.

  Configures a TF 1.X TFLiteConverter from the parsed command-line flags,
  runs the conversion, and writes the result to `flags.output_file`.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Invalid flags.
  """
  # Register custom opdefs before converter object creation.
  if flags.custom_opdefs:
    register_custom_opdefs(_parse_array(flags.custom_opdefs))

  # Create converter.
  converter = _get_tflite_converter(flags)
  if flags.inference_type:
    converter.inference_type = _parse_inference_type(flags.inference_type,
                                                     "inference_type")
  if flags.inference_input_type:
    converter.inference_input_type = _parse_inference_type(
        flags.inference_input_type, "inference_input_type")
  if flags.output_format:
    converter.output_format = _toco_flags_pb2.FileFormat.Value(
        flags.output_format)

  if flags.mean_values and flags.std_dev_values:
    input_arrays = converter.get_input_arrays()
    std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)

    # In quantized inference, mean_value has to be integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == dtypes.float32:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    quant_stats = list(zip(mean_values, std_dev_values))
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):
      raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
                       "--mean_values. The flags must have the same number of "
                       "items. The current input arrays are '{0}'. "
                       "--input_arrays must be present when specifying "
                       "--std_dev_values and --mean_values with multiple input "
                       "tensors in order to map between names and "
                       "values.".format(",".join(input_arrays)))
    converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
  if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
                                                 not None):
    converter.default_ranges_stats = (flags.default_ranges_min,
                                      flags.default_ranges_max)

  if flags.drop_control_dependency:
    converter.drop_control_dependency = flags.drop_control_dependency
  if flags.reorder_across_fake_quant:
    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
  if flags.change_concat_input_ranges:
    converter.change_concat_input_ranges = (
        flags.change_concat_input_ranges == "TRUE")

  if flags.allow_custom_ops:
    converter.allow_custom_ops = flags.allow_custom_ops

  if flags.target_ops:
    ops_set_options = lite.OpsSet.get_options()
    converter.target_spec.supported_ops = set()
    for option in six.ensure_str(flags.target_ops).split(","):
      if option not in ops_set_options:
        raise ValueError("Invalid value for --target_ops. Options: "
                         "{0}".format(",".join(ops_set_options)))
      converter.target_spec.supported_ops.add(lite.OpsSet(option))

  if flags.experimental_select_user_tf_ops:
    if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops:
      raise ValueError("--experimental_select_user_tf_ops can only be set if "
                       "--target_ops contains SELECT_TF_OPS.")
    user_op_set = set()
    for op_name in six.ensure_str(
        flags.experimental_select_user_tf_ops).split(","):
      user_op_set.add(op_name)
    converter.target_spec.experimental_select_user_tf_ops = list(user_op_set)

  if flags.post_training_quantize:
    converter.optimizations = [lite.Optimize.DEFAULT]
    if converter.inference_type != dtypes.float32:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference_type to FLOAT.")
      converter.inference_type = dtypes.float32

  if flags.quantize_to_float16:
    converter.target_spec.supported_types = [dtypes.float16]
    if not flags.post_training_quantize:
      print("--quantize_to_float16 will only take effect with the "
            "--post_training_quantize flag enabled.")

  if flags.dump_graphviz_dir:
    converter.dump_graphviz_dir = flags.dump_graphviz_dir
  if flags.dump_graphviz_video:
    # Bug fix: this previously assigned to `converter.dump_graphviz_vode`,
    # which silently created an unused attribute and made the
    # --dump_graphviz_video flag a no-op.
    converter.dump_graphviz_video = flags.dump_graphviz_video
  if flags.conversion_summary_dir:
    converter.conversion_summary_dir = flags.conversion_summary_dir

  if flags.experimental_new_converter is not None:
    converter.experimental_new_converter = flags.experimental_new_converter

  if flags.experimental_new_quantizer is not None:
    converter.experimental_new_quantizer = flags.experimental_new_quantizer

  # Convert model.
  output_data = converter.convert()
  with open(flags.output_file, "wb") as f:
    f.write(six.ensure_binary(output_data))
237
238
239def _convert_tf2_model(flags):
240  """Calls function to convert the TensorFlow 2.0 model into a TFLite model.
241
242  Args:
243    flags: argparse.Namespace object.
244
245  Raises:
246    ValueError: Unsupported file format.
247  """
248  # Load the model.
249  if flags.saved_model_dir:
250    converter = lite.TFLiteConverterV2.from_saved_model(
251        flags.saved_model_dir,
252        signature_keys=_parse_array(flags.saved_model_signature_key),
253        tags=_parse_set(flags.saved_model_tag_set))
254  elif flags.keras_model_file:
255    model = keras_deps.get_load_model_function()(flags.keras_model_file)
256    converter = lite.TFLiteConverterV2.from_keras_model(model)
257
258  if flags.experimental_new_converter is not None:
259    converter.experimental_new_converter = flags.experimental_new_converter
260
261  if flags.experimental_new_quantizer is not None:
262    converter.experimental_new_quantizer = flags.experimental_new_quantizer
263
264  # Convert the model.
265  tflite_model = converter.convert()
266  with open(flags.output_file, "wb") as f:
267    f.write(six.ensure_binary(tflite_model))
268
269
270def _check_tf1_flags(flags, unparsed):
271  """Checks the parsed and unparsed flags to ensure they are valid in 1.X.
272
273  Raises an error if previously support unparsed flags are found. Raises an
274  error for parsed flags that don't meet the required conditions.
275
276  Args:
277    flags: argparse.Namespace object containing TFLite flags.
278    unparsed: List of unparsed flags.
279
280  Raises:
281    ValueError: Invalid flags.
282  """
283
284  # Check unparsed flags for common mistakes based on previous TOCO.
285  def _get_message_unparsed(flag, orig_flag, new_flag):
286    if six.ensure_str(flag).startswith(orig_flag):
287      return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
288    return ""
289
290  if unparsed:
291    output = ""
292    for flag in unparsed:
293      output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
294      output += _get_message_unparsed(flag, "--savedmodel_directory",
295                                      "--saved_model_dir")
296      output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
297      output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
298      output += _get_message_unparsed(flag, "--dump_graphviz",
299                                      "--dump_graphviz_dir")
300    if output:
301      raise ValueError(output)
302
303  # Check that flags are valid.
304  if flags.graph_def_file and (not flags.input_arrays or
305                               not flags.output_arrays):
306    raise ValueError("--input_arrays and --output_arrays are required with "
307                     "--graph_def_file")
308
309  if flags.input_shapes:
310    if not flags.input_arrays:
311      raise ValueError("--input_shapes must be used with --input_arrays")
312    if flags.input_shapes.count(":") != flags.input_arrays.count(","):
313      raise ValueError("--input_shapes and --input_arrays must have the same "
314                       "number of items")
315
316  if flags.std_dev_values or flags.mean_values:
317    if bool(flags.std_dev_values) != bool(flags.mean_values):
318      raise ValueError("--std_dev_values and --mean_values must be used "
319                       "together")
320    if flags.std_dev_values.count(",") != flags.mean_values.count(","):
321      raise ValueError("--std_dev_values, --mean_values must have the same "
322                       "number of items")
323
324  if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
325    raise ValueError("--default_ranges_min and --default_ranges_max must be "
326                     "used together")
327
328  if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
329    raise ValueError("--dump_graphviz_video must be used with "
330                     "--dump_graphviz_dir")
331
332  if flags.custom_opdefs and not flags.experimental_new_converter:
333    raise ValueError("--custom_opdefs must be used with "
334                     "--experimental_new_converter")
335  if flags.custom_opdefs and not flags.allow_custom_ops:
336    raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
337  if (flags.experimental_select_user_tf_ops and
338      not flags.experimental_new_converter):
339    raise ValueError("--experimental_select_user_tf_ops must be used with "
340                     "--experimental_new_converter")
341
342
343def _check_tf2_flags(flags):
344  """Checks the parsed and unparsed flags to ensure they are valid in 2.X.
345
346  Args:
347    flags: argparse.Namespace object containing TFLite flags.
348
349  Raises:
350    ValueError: Invalid flags.
351  """
352  if not flags.keras_model_file and not flags.saved_model_dir:
353    raise ValueError("one of the arguments --saved_model_dir "
354                     "--keras_model_file is required")
355
356
def _get_tf1_flags(parser):
  """Adds flags for converting TensorFlow 1.X models to the parser.

  Mutates `parser` in place; nothing is returned.

  Args:
    parser: ArgumentParser
  """
  # Input file flags. Exactly one model source must be provided.
  input_file_group = parser.add_mutually_exclusive_group(required=True)
  input_file_group.add_argument(
      "--graph_def_file",
      type=str,
      help="Full filepath of file containing frozen TensorFlow GraphDef.")
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full filepath of directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")

  # Model format flags. `type=str.upper` normalizes values so the choices /
  # comparisons downstream are case-insensitive.
  parser.add_argument(
      "--output_format",
      type=str.upper,
      choices=["TFLITE", "GRAPHVIZ_DOT"],
      help="Output file format.")
  parser.add_argument(
      "--inference_type",
      type=str.upper,
      default="FLOAT",
      help=("Target data type of real-number arrays in the output file. "
            "Must be either FLOAT, INT8 or UINT8."))
  parser.add_argument(
      "--inference_input_type",
      type=str.upper,
      help=("Target data type of real-number input arrays. Allows for a "
            "different type for input arrays in the case of quantization. "
            "Must be either FLOAT, INT8 or UINT8."))

  # Input and output arrays flags.
  parser.add_argument(
      "--input_arrays",
      type=str,
      help="Names of the input arrays, comma-separated.")
  parser.add_argument(
      "--input_shapes",
      type=str,
      help="Shapes corresponding to --input_arrays, colon-separated.")
  parser.add_argument(
      "--output_arrays",
      type=str,
      help="Names of the output arrays, comma-separated.")

  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Quantization flags.
  parser.add_argument(
      "--std_dev_values",
      type=str,
      help=("Standard deviation of training data for each input tensor, "
            "comma-separated floats. Used for quantized input tensors. "
            "(default None)"))
  parser.add_argument(
      "--mean_values",
      type=str,
      help=("Mean of training data for each input tensor, comma-separated "
            "floats. Used for quantized input tensors. (default None)"))
  parser.add_argument(
      "--default_ranges_min",
      type=float,
      help=("Default value for min bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  parser.add_argument(
      "--default_ranges_max",
      type=float,
      help=("Default value for max bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  # quantize_weights is DEPRECATED. It is kept as a hidden alias
  # (help=argparse.SUPPRESS) that writes to the same dest as
  # --post_training_quantize.
  parser.add_argument(
      "--quantize_weights",
      dest="post_training_quantize",
      action="store_true",
      help=argparse.SUPPRESS)
  parser.add_argument(
      "--post_training_quantize",
      dest="post_training_quantize",
      action="store_true",
      help=(
          "Boolean indicating whether to quantize the weights of the "
          "converted float model. Model size will be reduced and there will "
          "be latency improvements (at the cost of accuracy). (default False)"))
  parser.add_argument(
      "--quantize_to_float16",
      dest="quantize_to_float16",
      action="store_true",
      help=("Boolean indicating whether to quantize weights to fp16 instead of "
            "the default int8 when post-training quantization "
            "(--post_training_quantize) is enabled. (default False)"))
  # Graph manipulation flags.
  parser.add_argument(
      "--drop_control_dependency",
      action="store_true",
      help=("Boolean indicating whether to drop control dependencies silently. "
            "This is due to TensorFlow not supporting control dependencies. "
            "(default True)"))
  parser.add_argument(
      "--reorder_across_fake_quant",
      action="store_true",
      help=("Boolean indicating whether to reorder FakeQuant nodes in "
            "unexpected locations. Used when the location of the FakeQuant "
            "nodes is preventing graph transformations necessary to convert "
            "the graph. Results in a graph that differs from the quantized "
            "training graph, potentially causing differing arithmetic "
            "behavior. (default False)"))
  # Usage for this flag is --change_concat_input_ranges=true or
  # --change_concat_input_ranges=false in order to make it clear what the flag
  # is set to. This keeps the usage consistent with other usages of the flag
  # where the default is different. The default value here is False.
  parser.add_argument(
      "--change_concat_input_ranges",
      type=str.upper,
      choices=["TRUE", "FALSE"],
      help=("Boolean to change behavior of min/max ranges for inputs and "
            "outputs of the concat operator for quantized models. Changes the "
            "ranges of concat operator overlap when true. (default False)"))

  # Permitted ops flags.
  parser.add_argument(
      "--allow_custom_ops",
      action="store_true",
      help=("Boolean indicating whether to allow custom operations. When false "
            "any unknown operation is an error. When true, custom ops are "
            "created for any op that is unknown. The developer will need to "
            "provide these to the TensorFlow Lite runtime with a custom "
            "resolver. (default False)"))
  parser.add_argument(
      "--custom_opdefs",
      type=str,
      help=("String representing a list of custom ops OpDefs delineated with "
            "commas that are included in the GraphDef. Required when using "
            "custom operations with --experimental_new_converter."))
  parser.add_argument(
      "--target_ops",
      type=str,
      help=("Experimental flag, subject to change. Set of OpsSet options "
            "indicating which converter to use. Options: {0}. One or more "
            "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
            "".format(",".join(lite.OpsSet.get_options()))))
  parser.add_argument(
      "--experimental_select_user_tf_ops",
      type=str,
      help=("Experimental flag, subject to change. Comma separated list of "
            "user's defined TensorFlow operators required in the runtime."))

  # Logging flags.
  parser.add_argument(
      "--dump_graphviz_dir",
      type=str,
      help=("Full filepath of folder to dump the graphs at various stages of "
            "processing GraphViz .dot files. Preferred over --output_format="
            "GRAPHVIZ_DOT in order to keep the requirements of the output "
            "file."))
  parser.add_argument(
      "--dump_graphviz_video",
      action="store_true",
      help=("Boolean indicating whether to dump the graph after every graph "
            "transformation"))
  parser.add_argument(
      "--conversion_summary_dir",
      type=str,
      help=("Full filepath to store the conversion logs, which includes "
            "graphviz of the model before/after the conversion, an HTML report "
            "and the conversion proto buffers. This will only be generated "
            "when passing --experimental_new_converter"))
545
546
def _get_tf2_flags(parser):
  """Adds flags for converting TensorFlow 2.0 models to the parser.

  Mutates `parser` in place; nothing is returned.

  Args:
    parser: ArgumentParser
  """
  # Input file flags. At most one model source may be given; presence of at
  # least one is checked later by _check_tf2_flags.
  input_file_group = parser.add_mutually_exclusive_group()
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full path of the directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")
  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Enables 1.X converter in 2.X.
  parser.add_argument(
      "--enable_v1_converter",
      action="store_true",
      help=("Enables the TensorFlow V1 converter in 2.0"))
581
582
583class _ParseExperimentalNewConverter(argparse.Action):
584  """Helper class to parse --experimental_new_converter argument."""
585
586  def __init__(self, option_strings, dest, nargs=None, **kwargs):
587    if nargs != "?":
588      # This should never happen. This class is only used once below with
589      # nargs="?".
590      raise ValueError(
591          "This parser only supports nargs='?' (0 or 1 additional arguments)")
592    super(_ParseExperimentalNewConverter, self).__init__(
593        option_strings, dest, nargs=nargs, **kwargs)
594
595  def __call__(self, parser, namespace, values, option_string=None):
596    if values is None:
597      # Handling `--experimental_new_converter`.
598      # Without additional arguments, it implies enabling the new converter.
599      experimental_new_converter = True
600    elif values.lower() == "true":
601      # Handling `--experimental_new_converter=true`.
602      # (Case insensitive after the equal sign)
603      experimental_new_converter = True
604    elif values.lower() == "false":
605      # Handling `--experimental_new_converter=false`.
606      # (Case insensitive after the equal sign)
607      experimental_new_converter = False
608    else:
609      raise ValueError("Invalid --experimental_new_converter argument.")
610    setattr(namespace, self.dest, experimental_new_converter)
611
612
def _get_parser(use_v2_converter):
  """Returns an ArgumentParser for tflite_convert.

  Args:
    use_v2_converter: Bool. Whether to add the TF 2.X converter flags
      (otherwise the TF 1.X flags are added).

  Returns:
    ArgumentParser.
  """
  parser = argparse.ArgumentParser(
      description=("Command line tool to run TensorFlow Lite Converter."))

  # Output file flag.
  parser.add_argument(
      "--output_file",
      type=str,
      help="Full filepath of the output file.",
      required=True)

  if use_v2_converter:
    _get_tf2_flags(parser)
  else:
    _get_tf1_flags(parser)

  # This flag takes an optional true/false value (nargs="?"), so it needs the
  # custom _ParseExperimentalNewConverter action rather than store_true.
  parser.add_argument(
      "--experimental_new_converter",
      action=_ParseExperimentalNewConverter,
      nargs="?",
      help=("Experimental flag, subject to change. Enables MLIR-based "
            "conversion instead of TOCO conversion. (default True)"))

  parser.add_argument(
      "--experimental_new_quantizer",
      action="store_true",
      help=("Experimental flag, subject to change. Enables MLIR-based "
            "quantizer instead of flatbuffer conversion. (default False)"))
  return parser
648
649
def run_main(_):
  """Main in tflite_convert.py.

  Parses the command line flags from sys.argv, validates them, and dispatches
  to the TF 1.X or TF 2.X conversion path.

  Args:
    _: Unused positional argument supplied by app.run.
  """
  use_v2_converter = tf2.enabled()
  parser = _get_parser(use_v2_converter)
  tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])

  # If the user is running TensorFlow 2.X but has passed in enable_v1_converter
  # then parse the flags again with the 1.X converter flags.
  if tf2.enabled() and tflite_flags.enable_v1_converter:
    use_v2_converter = False
    parser = _get_parser(use_v2_converter)
    tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])

  # Checks if the flags are valid.
  try:
    if use_v2_converter:
      _check_tf2_flags(tflite_flags)
    else:
      _check_tf1_flags(tflite_flags, unparsed)
  except ValueError as e:
    # Mimic argparse's own error output: usage line, then "prog: error: ...",
    # and exit with a non-zero status.
    parser.print_usage()
    file_name = os.path.basename(sys.argv[0])
    sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
    sys.exit(1)

  # Convert the model according to the user provided flag.
  if use_v2_converter:
    _convert_tf2_model(tflite_flags)
  else:
    try:
      _convert_tf1_model(tflite_flags)
    finally:
      # The `finally` ensures the HTML conversion report is generated even if
      # conversion raised, so any logs already written to
      # conversion_summary_dir are still surfaced.
      if tflite_flags.conversion_summary_dir:
        if tflite_flags.experimental_new_converter:
          gen_html.gen_conversion_log_html(tflite_flags.conversion_summary_dir,
                                           tflite_flags.post_training_quantize,
                                           tflite_flags.output_file)
        else:
          warnings.warn(
              "Conversion summary will only be generated when enabling"
              " the new converter via --experimental_new_converter. ")
691
692
def main():
  """Entry point that hands control to run_main via TensorFlow's app.run.

  Only the program name is forwarded in argv; run_main re-reads the actual
  command line flags from sys.argv itself.
  """
  app.run(main=run_main, argv=sys.argv[:1])


if __name__ == "__main__":
  main()
699