1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python command line interface for converting TF models to TFLite models."""
16
17import argparse
18import os
19import sys
20import warnings
21
22from absl import app
23import tensorflow as tf  # pylint: disable=unused-import
24
25from tensorflow.lite.python import lite
26from tensorflow.lite.python.convert import register_custom_opdefs
27from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
28from tensorflow.lite.toco.logging import gen_html
29from tensorflow.python import tf2
30from tensorflow.python.framework import dtypes
31from tensorflow.python.util import keras_deps
32
33# Needed to enable TF2 by default.
34
35
36def _parse_array(values, type_fn=str):
37  if values is not None:
38    return [type_fn(val) for val in values.split(",") if val]
39  return None
40
41
42def _parse_set(values):
43  if values is not None:
44    return set([item for item in values.split(",") if item])
45  return None
46
47
48def _parse_inference_type(value, flag):
49  """Converts the inference type to the value of the constant.
50
51  Args:
52    value: str representing the inference type.
53    flag: str representing the flag name.
54
55  Returns:
56    tf.dtype.
57
58  Raises:
59    ValueError: Unsupported value.
60  """
61  if value == "FLOAT":
62    return dtypes.float32
63  if value == "INT8":
64    return dtypes.int8
65  if value == "UINT8" or value == "QUANTIZED_UINT8":
66    return dtypes.uint8
67  raise ValueError(
68      "Unsupported value for `{}` flag. Expected FLOAT, INT8, UINT8, or "
69      "QUANTIZED_UINT8 instead got {}.".format(flag, value))
70
71
72class _ParseBooleanFlag(argparse.Action):
73  """Helper class to parse boolean flag that optionally accepts truth value."""
74
75  def __init__(self, option_strings, dest, nargs=None, **kwargs):
76    if nargs != "?":
77      # This should never happen. This class is only used once below with
78      # nargs="?".
79      raise ValueError(
80          "This parser only supports nargs='?' (0 or 1 additional arguments)")
81    super(_ParseBooleanFlag, self).__init__(
82        option_strings, dest, nargs=nargs, **kwargs)
83
84  def __call__(self, parser, namespace, values, option_string=None):
85    if values is None:
86      # Handling `--boolean_flag`.
87      # Without additional arguments, it implies true.
88      flag_value = True
89    elif values.lower() == "true":
90      # Handling `--boolean_flag=true`.
91      # (Case insensitive after the equal sign)
92      flag_value = True
93    elif values.lower() == "false":
94      # Handling `--boolean_flag=false`.
95      # (Case insensitive after the equal sign)
96      flag_value = False
97    else:
98      raise ValueError("Invalid argument to --{}. Must use flag alone,"
99                       " or specify true/false.".format(self.dest))
100    setattr(namespace, self.dest, flag_value)
101
102
def _get_tflite_converter(flags):
  """Makes a TFLiteConverter object based on the flags provided.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Returns:
    TFLiteConverter object.

  Raises:
    ValueError: Invalid flags.
  """
  # Parse input and output arrays.
  input_arrays = _parse_array(flags.input_arrays)
  output_arrays = _parse_array(flags.output_arrays)

  # Input shapes are colon-separated per input array, comma-separated
  # within a shape; map them onto input names by position.
  input_shapes = None
  if flags.input_shapes:
    shapes = [
        _parse_array(shape, type_fn=int)
        for shape in flags.input_shapes.split(":")
    ]
    input_shapes = dict(zip(input_arrays, shapes))

  converter_kwargs = {
      "input_arrays": input_arrays,
      "input_shapes": input_shapes,
      "output_arrays": output_arrays,
  }

  # Select the constructor matching the provided model format.
  if flags.graph_def_file:
    converter_fn = lite.TFLiteConverter.from_frozen_graph
    converter_kwargs["graph_def_file"] = flags.graph_def_file
  elif flags.saved_model_dir:
    converter_fn = lite.TFLiteConverter.from_saved_model
    converter_kwargs["saved_model_dir"] = flags.saved_model_dir
    converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
    converter_kwargs["signature_key"] = flags.saved_model_signature_key
  elif flags.keras_model_file:
    converter_fn = lite.TFLiteConverter.from_keras_model_file
    converter_kwargs["model_file"] = flags.keras_model_file
  else:
    raise ValueError("--graph_def_file, --saved_model_dir, or "
                     "--keras_model_file must be specified.")

  return converter_fn(**converter_kwargs)
149
150
def _convert_tf1_model(flags):
  """Calls function to convert the TensorFlow 1.X model into a TFLite model.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Invalid flags.
  """
  # Register custom opdefs before converter object creation.
  if flags.custom_opdefs:
    register_custom_opdefs(_parse_array(flags.custom_opdefs))

  # Create converter.
  converter = _get_tflite_converter(flags)
  if flags.inference_type:
    converter.inference_type = _parse_inference_type(flags.inference_type,
                                                     "inference_type")
  if flags.inference_input_type:
    converter.inference_input_type = _parse_inference_type(
        flags.inference_input_type, "inference_input_type")
  if flags.output_format:
    converter.output_format = _toco_flags_pb2.FileFormat.Value(
        flags.output_format)

  if flags.mean_values and flags.std_dev_values:
    input_arrays = converter.get_input_arrays()
    std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)

    # In quantized inference, mean_value has to be integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == dtypes.float32:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    quant_stats = list(zip(mean_values, std_dev_values))
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):
      raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
                       "--mean_values. The flags must have the same number of "
                       "items. The current input arrays are '{0}'. "
                       "--input_arrays must be present when specifying "
                       "--std_dev_values and --mean_values with multiple input "
                       "tensors in order to map between names and "
                       "values.".format(",".join(input_arrays)))
    converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
  if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
                                                 not None):
    converter.default_ranges_stats = (flags.default_ranges_min,
                                      flags.default_ranges_max)

  if flags.drop_control_dependency:
    converter.drop_control_dependency = flags.drop_control_dependency
  if flags.reorder_across_fake_quant:
    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
  if flags.change_concat_input_ranges:
    # Flag is a "TRUE"/"FALSE" string choice; convert to bool here.
    converter.change_concat_input_ranges = (
        flags.change_concat_input_ranges == "TRUE")

  if flags.allow_custom_ops:
    converter.allow_custom_ops = flags.allow_custom_ops

  if flags.target_ops:
    ops_set_options = lite.OpsSet.get_options()
    converter.target_spec.supported_ops = set()
    for option in flags.target_ops.split(","):
      if option not in ops_set_options:
        raise ValueError("Invalid value for --target_ops. Options: "
                         "{0}".format(",".join(ops_set_options)))
      converter.target_spec.supported_ops.add(lite.OpsSet(option))

  if flags.experimental_select_user_tf_ops:
    if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops:
      raise ValueError("--experimental_select_user_tf_ops can only be set if "
                       "--target_ops contains SELECT_TF_OPS.")
    user_op_set = set()
    for op_name in flags.experimental_select_user_tf_ops.split(","):
      user_op_set.add(op_name)
    converter.target_spec.experimental_select_user_tf_ops = list(user_op_set)

  if flags.post_training_quantize:
    converter.optimizations = [lite.Optimize.DEFAULT]
    if converter.inference_type != dtypes.float32:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference_type to FLOAT.")
      converter.inference_type = dtypes.float32

  if flags.quantize_to_float16:
    converter.target_spec.supported_types = [dtypes.float16]
    if not flags.post_training_quantize:
      print("--quantize_to_float16 will only take effect with the "
            "--post_training_quantize flag enabled.")

  if flags.dump_graphviz_dir:
    converter.dump_graphviz_dir = flags.dump_graphviz_dir
  if flags.dump_graphviz_video:
    # Bug fix: this previously assigned to the misspelled attribute
    # `dump_graphviz_vode`, so --dump_graphviz_video silently had no effect.
    converter.dump_graphviz_video = flags.dump_graphviz_video
  if flags.conversion_summary_dir:
    converter.conversion_summary_dir = flags.conversion_summary_dir

  converter.experimental_new_converter = flags.experimental_new_converter

  if flags.experimental_new_quantizer is not None:
    converter.experimental_new_quantizer = flags.experimental_new_quantizer

  # Convert model.
  output_data = converter.convert()
  with open(flags.output_file, "wb") as f:
    f.write(output_data)
260
261
262def _convert_tf2_model(flags):
263  """Calls function to convert the TensorFlow 2.0 model into a TFLite model.
264
265  Args:
266    flags: argparse.Namespace object.
267
268  Raises:
269    ValueError: Unsupported file format.
270  """
271  # Load the model.
272  if flags.saved_model_dir:
273    converter = lite.TFLiteConverterV2.from_saved_model(
274        flags.saved_model_dir,
275        signature_keys=_parse_array(flags.saved_model_signature_key),
276        tags=_parse_set(flags.saved_model_tag_set))
277  elif flags.keras_model_file:
278    model = keras_deps.get_load_model_function()(flags.keras_model_file)
279    converter = lite.TFLiteConverterV2.from_keras_model(model)
280
281  converter.experimental_new_converter = flags.experimental_new_converter
282
283  if flags.experimental_new_quantizer is not None:
284    converter.experimental_new_quantizer = flags.experimental_new_quantizer
285
286  # Convert the model.
287  tflite_model = converter.convert()
288  with open(flags.output_file, "wb") as f:
289    f.write(tflite_model)
290
291
292def _check_tf1_flags(flags, unparsed):
293  """Checks the parsed and unparsed flags to ensure they are valid in 1.X.
294
295  Raises an error if previously support unparsed flags are found. Raises an
296  error for parsed flags that don't meet the required conditions.
297
298  Args:
299    flags: argparse.Namespace object containing TFLite flags.
300    unparsed: List of unparsed flags.
301
302  Raises:
303    ValueError: Invalid flags.
304  """
305
306  # Check unparsed flags for common mistakes based on previous TOCO.
307  def _get_message_unparsed(flag, orig_flag, new_flag):
308    if flag.startswith(orig_flag):
309      return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
310    return ""
311
312  if unparsed:
313    output = ""
314    for flag in unparsed:
315      output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
316      output += _get_message_unparsed(flag, "--savedmodel_directory",
317                                      "--saved_model_dir")
318      output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
319      output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
320      output += _get_message_unparsed(flag, "--dump_graphviz",
321                                      "--dump_graphviz_dir")
322    if output:
323      raise ValueError(output)
324
325  # Check that flags are valid.
326  if flags.graph_def_file and (not flags.input_arrays or
327                               not flags.output_arrays):
328    raise ValueError("--input_arrays and --output_arrays are required with "
329                     "--graph_def_file")
330
331  if flags.input_shapes:
332    if not flags.input_arrays:
333      raise ValueError("--input_shapes must be used with --input_arrays")
334    if flags.input_shapes.count(":") != flags.input_arrays.count(","):
335      raise ValueError("--input_shapes and --input_arrays must have the same "
336                       "number of items")
337
338  if flags.std_dev_values or flags.mean_values:
339    if bool(flags.std_dev_values) != bool(flags.mean_values):
340      raise ValueError("--std_dev_values and --mean_values must be used "
341                       "together")
342    if flags.std_dev_values.count(",") != flags.mean_values.count(","):
343      raise ValueError("--std_dev_values, --mean_values must have the same "
344                       "number of items")
345
346  if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
347    raise ValueError("--default_ranges_min and --default_ranges_max must be "
348                     "used together")
349
350  if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
351    raise ValueError("--dump_graphviz_video must be used with "
352                     "--dump_graphviz_dir")
353
354  if flags.custom_opdefs and not flags.experimental_new_converter:
355    raise ValueError("--custom_opdefs must be used with "
356                     "--experimental_new_converter")
357  if flags.custom_opdefs and not flags.allow_custom_ops:
358    raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
359  if (flags.experimental_select_user_tf_ops and
360      not flags.experimental_new_converter):
361    raise ValueError("--experimental_select_user_tf_ops must be used with "
362                     "--experimental_new_converter")
363
364
365def _check_tf2_flags(flags):
366  """Checks the parsed and unparsed flags to ensure they are valid in 2.X.
367
368  Args:
369    flags: argparse.Namespace object containing TFLite flags.
370
371  Raises:
372    ValueError: Invalid flags.
373  """
374  if not flags.keras_model_file and not flags.saved_model_dir:
375    raise ValueError("one of the arguments --saved_model_dir "
376                     "--keras_model_file is required")
377
378
def _get_tf1_flags(parser):
  """Adds the TensorFlow 1.X converter flags to the given ArgumentParser.

  Note: despite the historical name, this mutates `parser` in place and
  returns nothing.

  Args:
    parser: ArgumentParser to receive the flags.
  """
  # Input file flags. Exactly one model source must be provided.
  input_file_group = parser.add_mutually_exclusive_group(required=True)
  input_file_group.add_argument(
      "--graph_def_file",
      type=str,
      help="Full filepath of file containing frozen TensorFlow GraphDef.")
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full filepath of directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")

  # Model format flags. `type=str.upper` normalizes user input so the
  # choices/parsing below are case-insensitive.
  parser.add_argument(
      "--output_format",
      type=str.upper,
      choices=["TFLITE", "GRAPHVIZ_DOT"],
      help="Output file format.")
  parser.add_argument(
      "--inference_type",
      type=str.upper,
      default="FLOAT",
      help=("Target data type of real-number arrays in the output file. "
            "Must be either FLOAT, INT8 or UINT8."))
  parser.add_argument(
      "--inference_input_type",
      type=str.upper,
      help=("Target data type of real-number input arrays. Allows for a "
            "different type for input arrays in the case of quantization. "
            "Must be either FLOAT, INT8 or UINT8."))

  # Input and output arrays flags.
  parser.add_argument(
      "--input_arrays",
      type=str,
      help="Names of the input arrays, comma-separated.")
  parser.add_argument(
      "--input_shapes",
      type=str,
      help="Shapes corresponding to --input_arrays, colon-separated.")
  parser.add_argument(
      "--output_arrays",
      type=str,
      help="Names of the output arrays, comma-separated.")

  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Quantization flags.
  parser.add_argument(
      "--std_dev_values",
      type=str,
      help=("Standard deviation of training data for each input tensor, "
            "comma-separated floats. Used for quantized input tensors. "
            "(default None)"))
  parser.add_argument(
      "--mean_values",
      type=str,
      help=("Mean of training data for each input tensor, comma-separated "
            "floats. Used for quantized input tensors. (default None)"))
  parser.add_argument(
      "--default_ranges_min",
      type=float,
      help=("Default value for min bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  parser.add_argument(
      "--default_ranges_max",
      type=float,
      help=("Default value for max bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  # quantize_weights is DEPRECATED. It shares `dest` with
  # --post_training_quantize so old invocations keep working, but the flag
  # itself is hidden from --help via argparse.SUPPRESS.
  parser.add_argument(
      "--quantize_weights",
      dest="post_training_quantize",
      action="store_true",
      help=argparse.SUPPRESS)
  parser.add_argument(
      "--post_training_quantize",
      dest="post_training_quantize",
      action="store_true",
      help=(
          "Boolean indicating whether to quantize the weights of the "
          "converted float model. Model size will be reduced and there will "
          "be latency improvements (at the cost of accuracy). (default False)"))
  parser.add_argument(
      "--quantize_to_float16",
      dest="quantize_to_float16",
      action="store_true",
      help=("Boolean indicating whether to quantize weights to fp16 instead of "
            "the default int8 when post-training quantization "
            "(--post_training_quantize) is enabled. (default False)"))
  # Graph manipulation flags.
  parser.add_argument(
      "--drop_control_dependency",
      action="store_true",
      help=("Boolean indicating whether to drop control dependencies silently. "
            "This is due to TensorFlow not supporting control dependencies. "
            "(default True)"))
  parser.add_argument(
      "--reorder_across_fake_quant",
      action="store_true",
      help=("Boolean indicating whether to reorder FakeQuant nodes in "
            "unexpected locations. Used when the location of the FakeQuant "
            "nodes is preventing graph transformations necessary to convert "
            "the graph. Results in a graph that differs from the quantized "
            "training graph, potentially causing differing arithmetic "
            "behavior. (default False)"))
  # Usage for this flag is --change_concat_input_ranges=true or
  # --change_concat_input_ranges=false in order to make it clear what the flag
  # is set to. This keeps the usage consistent with other usages of the flag
  # where the default is different. The default value here is False.
  parser.add_argument(
      "--change_concat_input_ranges",
      type=str.upper,
      choices=["TRUE", "FALSE"],
      help=("Boolean to change behavior of min/max ranges for inputs and "
            "outputs of the concat operator for quantized models. Changes the "
            "ranges of concat operator overlap when true. (default False)"))

  # Permitted ops flags.
  parser.add_argument(
      "--allow_custom_ops",
      action=_ParseBooleanFlag,
      nargs="?",
      help=("Boolean indicating whether to allow custom operations. When false "
            "any unknown operation is an error. When true, custom ops are "
            "created for any op that is unknown. The developer will need to "
            "provide these to the TensorFlow Lite runtime with a custom "
            "resolver. (default False)"))
  parser.add_argument(
      "--custom_opdefs",
      type=str,
      help=("String representing a list of custom ops OpDefs delineated with "
            "commas that are included in the GraphDef. Required when using "
            "custom operations with --experimental_new_converter."))
  parser.add_argument(
      "--target_ops",
      type=str,
      help=("Experimental flag, subject to change. Set of OpsSet options "
            "indicating which converter to use. Options: {0}. One or more "
            "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
            "".format(",".join(lite.OpsSet.get_options()))))
  parser.add_argument(
      "--experimental_select_user_tf_ops",
      type=str,
      help=("Experimental flag, subject to change. Comma separated list of "
            "user's defined TensorFlow operators required in the runtime."))

  # Logging flags.
  parser.add_argument(
      "--dump_graphviz_dir",
      type=str,
      help=("Full filepath of folder to dump the graphs at various stages of "
            "processing GraphViz .dot files. Preferred over --output_format="
            "GRAPHVIZ_DOT in order to keep the requirements of the output "
            "file."))
  parser.add_argument(
      "--dump_graphviz_video",
      action="store_true",
      help=("Boolean indicating whether to dump the graph after every graph "
            "transformation"))
  parser.add_argument(
      "--conversion_summary_dir",
      type=str,
      help=("Full filepath to store the conversion logs, which includes "
            "graphviz of the model before/after the conversion, an HTML report "
            "and the conversion proto buffers. This will only be generated "
            "when passing --experimental_new_converter"))
568
569
def _get_tf2_flags(parser):
  """Adds the TensorFlow 2.0 converter flags to the given ArgumentParser.

  Note: despite the historical name, this mutates `parser` in place and
  returns nothing.

  Args:
    parser: ArgumentParser to receive the flags.
  """
  # Input file flags. Unlike the 1.X group this is not `required=True`;
  # absence is reported later by _check_tf2_flags with a clearer message.
  input_file_group = parser.add_mutually_exclusive_group()
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full path of the directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")
  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Enables 1.X converter in 2.X.
  parser.add_argument(
      "--enable_v1_converter",
      action="store_true",
      help=("Enables the TensorFlow V1 converter in 2.0"))
604
605
def _get_parser(use_v2_converter):
  """Returns an ArgumentParser for tflite_convert.

  Args:
    use_v2_converter: Bool indicating whether to build the parser with the
      TF 2.X flag set (True) or the TF 1.X flag set (False).

  Returns:
    ArgumentParser.
  """
  parser = argparse.ArgumentParser(
      description=("Command line tool to run TensorFlow Lite Converter."))

  # Output file flag.
  parser.add_argument(
      "--output_file",
      type=str,
      help="Full filepath of the output file.",
      required=True)

  if use_v2_converter:
    _get_tf2_flags(parser)
  else:
    _get_tf1_flags(parser)

  # Flags shared by both converter versions.
  parser.add_argument(
      "--experimental_new_converter",
      action=_ParseBooleanFlag,
      nargs="?",
      default=True,
      help=("Experimental flag, subject to change. Enables MLIR-based "
            "conversion instead of TOCO conversion. (default True)"))

  # Default is None (not False) so downstream code can distinguish
  # "flag not given" from an explicit true/false.
  parser.add_argument(
      "--experimental_new_quantizer",
      action=_ParseBooleanFlag,
      nargs="?",
      help=("Experimental flag, subject to change. Enables MLIR-based "
            "quantizer instead of flatbuffer conversion. (default True)"))
  return parser
643
644
def run_main(_):
  """Main in tflite_convert.py.

  Args:
    _: Unused positional argument supplied by absl's app.run.
  """
  # Choose the converter version from the installed TF behavior mode.
  use_v2_converter = tf2.enabled()
  parser = _get_parser(use_v2_converter)
  # parse_known_args keeps unrecognized flags so 1.X validation can emit
  # targeted hints for renamed TOCO flags.
  tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])

  # If the user is running TensorFlow 2.X but has passed in enable_v1_converter
  # then parse the flags again with the 1.X converter flags.
  if tf2.enabled() and tflite_flags.enable_v1_converter:
    use_v2_converter = False
    parser = _get_parser(use_v2_converter)
    tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])

  # Checks if the flags are valid.
  try:
    if use_v2_converter:
      _check_tf2_flags(tflite_flags)
    else:
      _check_tf1_flags(tflite_flags, unparsed)
  except ValueError as e:
    parser.print_usage()
    file_name = os.path.basename(sys.argv[0])
    sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
    sys.exit(1)

  # Convert the model according to the user provided flag.
  if use_v2_converter:
    _convert_tf2_model(tflite_flags)
  else:
    try:
      _convert_tf1_model(tflite_flags)
    finally:
      # Generate the HTML conversion report even when conversion fails, so
      # partial logs are still surfaced (new-converter only).
      if tflite_flags.conversion_summary_dir:
        if tflite_flags.experimental_new_converter:
          gen_html.gen_conversion_log_html(tflite_flags.conversion_summary_dir,
                                           tflite_flags.post_training_quantize,
                                           tflite_flags.output_file)
        else:
          warnings.warn(
              "Conversion summary will only be generated when enabling"
              " the new converter via --experimental_new_converter. ")
686
687
def main():
  """Entry point: delegates to `run_main` via absl, keeping only argv[0].

  User-provided flags are parsed inside `run_main` from sys.argv directly,
  so only the program name is handed to app.run here.
  """
  app.run(main=run_main, argv=sys.argv[:1])
690
691
692if __name__ == "__main__":
693  main()
694