1# Lint as: python2, python3 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Python command line interface for converting TF models to TFLite models.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import argparse 23import os 24import sys 25import warnings 26 27import six 28from six.moves import zip 29# Needed to enable TF2 by default. 30import tensorflow as tf # pylint: disable=unused-import 31 32from tensorflow.lite.python import lite 33from tensorflow.lite.python.convert import register_custom_opdefs 34from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 35from tensorflow.lite.toco.logging import gen_html 36from tensorflow.python import tf2 37from tensorflow.python.framework import dtypes 38from tensorflow.python.platform import app 39from tensorflow.python.util import keras_deps 40 41 42def _parse_array(values, type_fn=str): 43 if values is not None: 44 return [type_fn(val) for val in six.ensure_str(values).split(",") if val] 45 return None 46 47 48def _parse_set(values): 49 if values is not None: 50 return set([item for item in six.ensure_str(values).split(",") if item]) 51 return None 52 53 54def _parse_inference_type(value, flag): 55 """Converts the inference type to the value of the constant. 56 57 Args: 58 value: str representing the inference type. 59 flag: str representing the flag name. 60 61 Returns: 62 tf.dtype. 63 64 Raises: 65 ValueError: Unsupported value. 66 """ 67 if value == "FLOAT": 68 return dtypes.float32 69 if value == "INT8": 70 return dtypes.int8 71 if value == "UINT8" or value == "QUANTIZED_UINT8": 72 return dtypes.uint8 73 raise ValueError( 74 "Unsupported value for `{}` flag. Expected FLOAT, INT8, UINT8, or " 75 "QUANTIZED_UINT8 instead got {}.".format(flag, value)) 76 77 78def _get_tflite_converter(flags): 79 """Makes a TFLiteConverter object based on the flags provided. 80 81 Args: 82 flags: argparse.Namespace object containing TFLite flags. 83 84 Returns: 85 TFLiteConverter object. 86 87 Raises: 88 ValueError: Invalid flags. 89 """ 90 # Parse input and output arrays. 91 input_arrays = _parse_array(flags.input_arrays) 92 input_shapes = None 93 if flags.input_shapes: 94 input_shapes_list = [ 95 _parse_array(shape, type_fn=int) 96 for shape in six.ensure_str(flags.input_shapes).split(":") 97 ] 98 input_shapes = dict(list(zip(input_arrays, input_shapes_list))) 99 output_arrays = _parse_array(flags.output_arrays) 100 101 converter_kwargs = { 102 "input_arrays": input_arrays, 103 "input_shapes": input_shapes, 104 "output_arrays": output_arrays 105 } 106 107 # Create TFLiteConverter. 108 if flags.graph_def_file: 109 converter_fn = lite.TFLiteConverter.from_frozen_graph 110 converter_kwargs["graph_def_file"] = flags.graph_def_file 111 elif flags.saved_model_dir: 112 converter_fn = lite.TFLiteConverter.from_saved_model 113 converter_kwargs["saved_model_dir"] = flags.saved_model_dir 114 converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) 115 converter_kwargs["signature_key"] = flags.saved_model_signature_key 116 elif flags.keras_model_file: 117 converter_fn = lite.TFLiteConverter.from_keras_model_file 118 converter_kwargs["model_file"] = flags.keras_model_file 119 else: 120 raise ValueError("--graph_def_file, --saved_model_dir, or " 121 "--keras_model_file must be specified.") 122 123 return converter_fn(**converter_kwargs) 124 125 126def _convert_tf1_model(flags): 127 """Calls function to convert the TensorFlow 1.X model into a TFLite model. 128 129 Args: 130 flags: argparse.Namespace object. 131 132 Raises: 133 ValueError: Invalid flags. 134 """ 135 # Register custom opdefs before converter object creation. 136 if flags.custom_opdefs: 137 register_custom_opdefs(_parse_array(flags.custom_opdefs)) 138 139 # Create converter. 140 converter = _get_tflite_converter(flags) 141 if flags.inference_type: 142 converter.inference_type = _parse_inference_type(flags.inference_type, 143 "inference_type") 144 if flags.inference_input_type: 145 converter.inference_input_type = _parse_inference_type( 146 flags.inference_input_type, "inference_input_type") 147 if flags.output_format: 148 converter.output_format = _toco_flags_pb2.FileFormat.Value( 149 flags.output_format) 150 151 if flags.mean_values and flags.std_dev_values: 152 input_arrays = converter.get_input_arrays() 153 std_dev_values = _parse_array(flags.std_dev_values, type_fn=float) 154 155 # In quantized inference, mean_value has to be integer so that the real 156 # value 0.0 is exactly representable. 157 if converter.inference_type == dtypes.float32: 158 mean_values = _parse_array(flags.mean_values, type_fn=float) 159 else: 160 mean_values = _parse_array(flags.mean_values, type_fn=int) 161 quant_stats = list(zip(mean_values, std_dev_values)) 162 if ((not flags.input_arrays and len(input_arrays) > 1) or 163 (len(input_arrays) != len(quant_stats))): 164 raise ValueError("Mismatching --input_arrays, --std_dev_values, and " 165 "--mean_values. The flags must have the same number of " 166 "items. The current input arrays are '{0}'. " 167 "--input_arrays must be present when specifying " 168 "--std_dev_values and --mean_values with multiple input " 169 "tensors in order to map between names and " 170 "values.".format(",".join(input_arrays))) 171 converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats))) 172 if (flags.default_ranges_min is not None) and (flags.default_ranges_max is 173 not None): 174 converter.default_ranges_stats = (flags.default_ranges_min, 175 flags.default_ranges_max) 176 177 if flags.drop_control_dependency: 178 converter.drop_control_dependency = flags.drop_control_dependency 179 if flags.reorder_across_fake_quant: 180 converter.reorder_across_fake_quant = flags.reorder_across_fake_quant 181 if flags.change_concat_input_ranges: 182 converter.change_concat_input_ranges = ( 183 flags.change_concat_input_ranges == "TRUE") 184 185 if flags.allow_custom_ops: 186 converter.allow_custom_ops = flags.allow_custom_ops 187 188 if flags.target_ops: 189 ops_set_options = lite.OpsSet.get_options() 190 converter.target_spec.supported_ops = set() 191 for option in six.ensure_str(flags.target_ops).split(","): 192 if option not in ops_set_options: 193 raise ValueError("Invalid value for --target_ops. Options: " 194 "{0}".format(",".join(ops_set_options))) 195 converter.target_spec.supported_ops.add(lite.OpsSet(option)) 196 197 if flags.experimental_select_user_tf_ops: 198 if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops: 199 raise ValueError("--experimental_select_user_tf_ops can only be set if " 200 "--target_ops contains SELECT_TF_OPS.") 201 user_op_set = set() 202 for op_name in six.ensure_str( 203 flags.experimental_select_user_tf_ops).split(","): 204 user_op_set.add(op_name) 205 converter.target_spec.experimental_select_user_tf_ops = list(user_op_set) 206 207 if flags.post_training_quantize: 208 converter.optimizations = [lite.Optimize.DEFAULT] 209 if converter.inference_type != dtypes.float32: 210 print("--post_training_quantize quantizes a graph of inference_type " 211 "FLOAT. Overriding inference_type to FLOAT.") 212 converter.inference_type = dtypes.float32 213 214 if flags.quantize_to_float16: 215 converter.target_spec.supported_types = [dtypes.float16] 216 if not flags.post_training_quantize: 217 print("--quantize_to_float16 will only take effect with the " 218 "--post_training_quantize flag enabled.") 219 220 if flags.dump_graphviz_dir: 221 converter.dump_graphviz_dir = flags.dump_graphviz_dir 222 if flags.dump_graphviz_video: 223 converter.dump_graphviz_vode = flags.dump_graphviz_video 224 if flags.conversion_summary_dir: 225 converter.conversion_summary_dir = flags.conversion_summary_dir 226 227 if flags.experimental_new_converter is not None: 228 converter.experimental_new_converter = flags.experimental_new_converter 229 230 if flags.experimental_new_quantizer is not None: 231 converter.experimental_new_quantizer = flags.experimental_new_quantizer 232 233 # Convert model. 234 output_data = converter.convert() 235 with open(flags.output_file, "wb") as f: 236 f.write(six.ensure_binary(output_data)) 237 238 239def _convert_tf2_model(flags): 240 """Calls function to convert the TensorFlow 2.0 model into a TFLite model. 241 242 Args: 243 flags: argparse.Namespace object. 244 245 Raises: 246 ValueError: Unsupported file format. 247 """ 248 # Load the model. 249 if flags.saved_model_dir: 250 converter = lite.TFLiteConverterV2.from_saved_model( 251 flags.saved_model_dir, 252 signature_keys=_parse_array(flags.saved_model_signature_key), 253 tags=_parse_set(flags.saved_model_tag_set)) 254 elif flags.keras_model_file: 255 model = keras_deps.get_load_model_function()(flags.keras_model_file) 256 converter = lite.TFLiteConverterV2.from_keras_model(model) 257 258 if flags.experimental_new_converter is not None: 259 converter.experimental_new_converter = flags.experimental_new_converter 260 261 if flags.experimental_new_quantizer is not None: 262 converter.experimental_new_quantizer = flags.experimental_new_quantizer 263 264 # Convert the model. 265 tflite_model = converter.convert() 266 with open(flags.output_file, "wb") as f: 267 f.write(six.ensure_binary(tflite_model)) 268 269 270def _check_tf1_flags(flags, unparsed): 271 """Checks the parsed and unparsed flags to ensure they are valid in 1.X. 272 273 Raises an error if previously support unparsed flags are found. Raises an 274 error for parsed flags that don't meet the required conditions. 275 276 Args: 277 flags: argparse.Namespace object containing TFLite flags. 278 unparsed: List of unparsed flags. 279 280 Raises: 281 ValueError: Invalid flags. 282 """ 283 284 # Check unparsed flags for common mistakes based on previous TOCO. 285 def _get_message_unparsed(flag, orig_flag, new_flag): 286 if six.ensure_str(flag).startswith(orig_flag): 287 return "\n Use {0} instead of {1}".format(new_flag, orig_flag) 288 return "" 289 290 if unparsed: 291 output = "" 292 for flag in unparsed: 293 output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") 294 output += _get_message_unparsed(flag, "--savedmodel_directory", 295 "--saved_model_dir") 296 output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") 297 output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") 298 output += _get_message_unparsed(flag, "--dump_graphviz", 299 "--dump_graphviz_dir") 300 if output: 301 raise ValueError(output) 302 303 # Check that flags are valid. 304 if flags.graph_def_file and (not flags.input_arrays or 305 not flags.output_arrays): 306 raise ValueError("--input_arrays and --output_arrays are required with " 307 "--graph_def_file") 308 309 if flags.input_shapes: 310 if not flags.input_arrays: 311 raise ValueError("--input_shapes must be used with --input_arrays") 312 if flags.input_shapes.count(":") != flags.input_arrays.count(","): 313 raise ValueError("--input_shapes and --input_arrays must have the same " 314 "number of items") 315 316 if flags.std_dev_values or flags.mean_values: 317 if bool(flags.std_dev_values) != bool(flags.mean_values): 318 raise ValueError("--std_dev_values and --mean_values must be used " 319 "together") 320 if flags.std_dev_values.count(",") != flags.mean_values.count(","): 321 raise ValueError("--std_dev_values, --mean_values must have the same " 322 "number of items") 323 324 if (flags.default_ranges_min is None) != (flags.default_ranges_max is None): 325 raise ValueError("--default_ranges_min and --default_ranges_max must be " 326 "used together") 327 328 if flags.dump_graphviz_video and not flags.dump_graphviz_dir: 329 raise ValueError("--dump_graphviz_video must be used with " 330 "--dump_graphviz_dir") 331 332 if flags.custom_opdefs and not flags.experimental_new_converter: 333 raise ValueError("--custom_opdefs must be used with " 334 "--experimental_new_converter") 335 if flags.custom_opdefs and not flags.allow_custom_ops: 336 raise ValueError("--custom_opdefs must be used with --allow_custom_ops") 337 if (flags.experimental_select_user_tf_ops and 338 not flags.experimental_new_converter): 339 raise ValueError("--experimental_select_user_tf_ops must be used with " 340 "--experimental_new_converter") 341 342 343def _check_tf2_flags(flags): 344 """Checks the parsed and unparsed flags to ensure they are valid in 2.X. 345 346 Args: 347 flags: argparse.Namespace object containing TFLite flags. 348 349 Raises: 350 ValueError: Invalid flags. 351 """ 352 if not flags.keras_model_file and not flags.saved_model_dir: 353 raise ValueError("one of the arguments --saved_model_dir " 354 "--keras_model_file is required") 355 356 357def _get_tf1_flags(parser): 358 """Returns ArgumentParser for tflite_convert for TensorFlow 1.X. 359 360 Args: 361 parser: ArgumentParser 362 """ 363 # Input file flags. 364 input_file_group = parser.add_mutually_exclusive_group(required=True) 365 input_file_group.add_argument( 366 "--graph_def_file", 367 type=str, 368 help="Full filepath of file containing frozen TensorFlow GraphDef.") 369 input_file_group.add_argument( 370 "--saved_model_dir", 371 type=str, 372 help="Full filepath of directory containing the SavedModel.") 373 input_file_group.add_argument( 374 "--keras_model_file", 375 type=str, 376 help="Full filepath of HDF5 file containing tf.Keras model.") 377 378 # Model format flags. 379 parser.add_argument( 380 "--output_format", 381 type=str.upper, 382 choices=["TFLITE", "GRAPHVIZ_DOT"], 383 help="Output file format.") 384 parser.add_argument( 385 "--inference_type", 386 type=str.upper, 387 default="FLOAT", 388 help=("Target data type of real-number arrays in the output file. " 389 "Must be either FLOAT, INT8 or UINT8.")) 390 parser.add_argument( 391 "--inference_input_type", 392 type=str.upper, 393 help=("Target data type of real-number input arrays. Allows for a " 394 "different type for input arrays in the case of quantization. " 395 "Must be either FLOAT, INT8 or UINT8.")) 396 397 # Input and output arrays flags. 398 parser.add_argument( 399 "--input_arrays", 400 type=str, 401 help="Names of the input arrays, comma-separated.") 402 parser.add_argument( 403 "--input_shapes", 404 type=str, 405 help="Shapes corresponding to --input_arrays, colon-separated.") 406 parser.add_argument( 407 "--output_arrays", 408 type=str, 409 help="Names of the output arrays, comma-separated.") 410 411 # SavedModel related flags. 412 parser.add_argument( 413 "--saved_model_tag_set", 414 type=str, 415 help=("Comma-separated set of tags identifying the MetaGraphDef within " 416 "the SavedModel to analyze. All tags must be present. In order to " 417 "pass in an empty tag set, pass in \"\". (default \"serve\")")) 418 parser.add_argument( 419 "--saved_model_signature_key", 420 type=str, 421 help=("Key identifying the SignatureDef containing inputs and outputs. " 422 "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) 423 424 # Quantization flags. 425 parser.add_argument( 426 "--std_dev_values", 427 type=str, 428 help=("Standard deviation of training data for each input tensor, " 429 "comma-separated floats. Used for quantized input tensors. " 430 "(default None)")) 431 parser.add_argument( 432 "--mean_values", 433 type=str, 434 help=("Mean of training data for each input tensor, comma-separated " 435 "floats. Used for quantized input tensors. (default None)")) 436 parser.add_argument( 437 "--default_ranges_min", 438 type=float, 439 help=("Default value for min bound of min/max range values used for all " 440 "arrays without a specified range, Intended for experimenting with " 441 "quantization via \"dummy quantization\". (default None)")) 442 parser.add_argument( 443 "--default_ranges_max", 444 type=float, 445 help=("Default value for max bound of min/max range values used for all " 446 "arrays without a specified range, Intended for experimenting with " 447 "quantization via \"dummy quantization\". (default None)")) 448 # quantize_weights is DEPRECATED. 449 parser.add_argument( 450 "--quantize_weights", 451 dest="post_training_quantize", 452 action="store_true", 453 help=argparse.SUPPRESS) 454 parser.add_argument( 455 "--post_training_quantize", 456 dest="post_training_quantize", 457 action="store_true", 458 help=( 459 "Boolean indicating whether to quantize the weights of the " 460 "converted float model. Model size will be reduced and there will " 461 "be latency improvements (at the cost of accuracy). (default False)")) 462 parser.add_argument( 463 "--quantize_to_float16", 464 dest="quantize_to_float16", 465 action="store_true", 466 help=("Boolean indicating whether to quantize weights to fp16 instead of " 467 "the default int8 when post-training quantization " 468 "(--post_training_quantize) is enabled. (default False)")) 469 # Graph manipulation flags. 470 parser.add_argument( 471 "--drop_control_dependency", 472 action="store_true", 473 help=("Boolean indicating whether to drop control dependencies silently. " 474 "This is due to TensorFlow not supporting control dependencies. " 475 "(default True)")) 476 parser.add_argument( 477 "--reorder_across_fake_quant", 478 action="store_true", 479 help=("Boolean indicating whether to reorder FakeQuant nodes in " 480 "unexpected locations. Used when the location of the FakeQuant " 481 "nodes is preventing graph transformations necessary to convert " 482 "the graph. Results in a graph that differs from the quantized " 483 "training graph, potentially causing differing arithmetic " 484 "behavior. (default False)")) 485 # Usage for this flag is --change_concat_input_ranges=true or 486 # --change_concat_input_ranges=false in order to make it clear what the flag 487 # is set to. This keeps the usage consistent with other usages of the flag 488 # where the default is different. The default value here is False. 489 parser.add_argument( 490 "--change_concat_input_ranges", 491 type=str.upper, 492 choices=["TRUE", "FALSE"], 493 help=("Boolean to change behavior of min/max ranges for inputs and " 494 "outputs of the concat operator for quantized models. Changes the " 495 "ranges of concat operator overlap when true. (default False)")) 496 497 # Permitted ops flags. 498 parser.add_argument( 499 "--allow_custom_ops", 500 action="store_true", 501 help=("Boolean indicating whether to allow custom operations. When false " 502 "any unknown operation is an error. When true, custom ops are " 503 "created for any op that is unknown. The developer will need to " 504 "provide these to the TensorFlow Lite runtime with a custom " 505 "resolver. (default False)")) 506 parser.add_argument( 507 "--custom_opdefs", 508 type=str, 509 help=("String representing a list of custom ops OpDefs delineated with " 510 "commas that are included in the GraphDef. Required when using " 511 "custom operations with --experimental_new_converter.")) 512 parser.add_argument( 513 "--target_ops", 514 type=str, 515 help=("Experimental flag, subject to change. Set of OpsSet options " 516 "indicating which converter to use. Options: {0}. One or more " 517 "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))" 518 "".format(",".join(lite.OpsSet.get_options())))) 519 parser.add_argument( 520 "--experimental_select_user_tf_ops", 521 type=str, 522 help=("Experimental flag, subject to change. Comma separated list of " 523 "user's defined TensorFlow operators required in the runtime.")) 524 525 # Logging flags. 526 parser.add_argument( 527 "--dump_graphviz_dir", 528 type=str, 529 help=("Full filepath of folder to dump the graphs at various stages of " 530 "processing GraphViz .dot files. Preferred over --output_format=" 531 "GRAPHVIZ_DOT in order to keep the requirements of the output " 532 "file.")) 533 parser.add_argument( 534 "--dump_graphviz_video", 535 action="store_true", 536 help=("Boolean indicating whether to dump the graph after every graph " 537 "transformation")) 538 parser.add_argument( 539 "--conversion_summary_dir", 540 type=str, 541 help=("Full filepath to store the conversion logs, which includes " 542 "graphviz of the model before/after the conversion, an HTML report " 543 "and the conversion proto buffers. This will only be generated " 544 "when passing --experimental_new_converter")) 545 546 547def _get_tf2_flags(parser): 548 """Returns ArgumentParser for tflite_convert for TensorFlow 2.0. 549 550 Args: 551 parser: ArgumentParser 552 """ 553 # Input file flags. 554 input_file_group = parser.add_mutually_exclusive_group() 555 input_file_group.add_argument( 556 "--saved_model_dir", 557 type=str, 558 help="Full path of the directory containing the SavedModel.") 559 input_file_group.add_argument( 560 "--keras_model_file", 561 type=str, 562 help="Full filepath of HDF5 file containing tf.Keras model.") 563 # SavedModel related flags. 564 parser.add_argument( 565 "--saved_model_tag_set", 566 type=str, 567 help=("Comma-separated set of tags identifying the MetaGraphDef within " 568 "the SavedModel to analyze. All tags must be present. In order to " 569 "pass in an empty tag set, pass in \"\". (default \"serve\")")) 570 parser.add_argument( 571 "--saved_model_signature_key", 572 type=str, 573 help=("Key identifying the SignatureDef containing inputs and outputs. " 574 "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) 575 576 # Enables 1.X converter in 2.X. 577 parser.add_argument( 578 "--enable_v1_converter", 579 action="store_true", 580 help=("Enables the TensorFlow V1 converter in 2.0")) 581 582 583class _ParseExperimentalNewConverter(argparse.Action): 584 """Helper class to parse --experimental_new_converter argument.""" 585 586 def __init__(self, option_strings, dest, nargs=None, **kwargs): 587 if nargs != "?": 588 # This should never happen. This class is only used once below with 589 # nargs="?". 590 raise ValueError( 591 "This parser only supports nargs='?' (0 or 1 additional arguments)") 592 super(_ParseExperimentalNewConverter, self).__init__( 593 option_strings, dest, nargs=nargs, **kwargs) 594 595 def __call__(self, parser, namespace, values, option_string=None): 596 if values is None: 597 # Handling `--experimental_new_converter`. 598 # Without additional arguments, it implies enabling the new converter. 599 experimental_new_converter = True 600 elif values.lower() == "true": 601 # Handling `--experimental_new_converter=true`. 602 # (Case insensitive after the equal sign) 603 experimental_new_converter = True 604 elif values.lower() == "false": 605 # Handling `--experimental_new_converter=false`. 606 # (Case insensitive after the equal sign) 607 experimental_new_converter = False 608 else: 609 raise ValueError("Invalid --experimental_new_converter argument.") 610 setattr(namespace, self.dest, experimental_new_converter) 611 612 613def _get_parser(use_v2_converter): 614 """Returns an ArgumentParser for tflite_convert. 615 616 Args: 617 use_v2_converter: Indicates which converter to return. 618 Return: ArgumentParser. 619 """ 620 parser = argparse.ArgumentParser( 621 description=("Command line tool to run TensorFlow Lite Converter.")) 622 623 # Output file flag. 624 parser.add_argument( 625 "--output_file", 626 type=str, 627 help="Full filepath of the output file.", 628 required=True) 629 630 if use_v2_converter: 631 _get_tf2_flags(parser) 632 else: 633 _get_tf1_flags(parser) 634 635 parser.add_argument( 636 "--experimental_new_converter", 637 action=_ParseExperimentalNewConverter, 638 nargs="?", 639 help=("Experimental flag, subject to change. Enables MLIR-based " 640 "conversion instead of TOCO conversion. (default True)")) 641 642 parser.add_argument( 643 "--experimental_new_quantizer", 644 action="store_true", 645 help=("Experimental flag, subject to change. Enables MLIR-based " 646 "quantizer instead of flatbuffer conversion. (default False)")) 647 return parser 648 649 650def run_main(_): 651 """Main in tflite_convert.py.""" 652 use_v2_converter = tf2.enabled() 653 parser = _get_parser(use_v2_converter) 654 tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) 655 656 # If the user is running TensorFlow 2.X but has passed in enable_v1_converter 657 # then parse the flags again with the 1.X converter flags. 658 if tf2.enabled() and tflite_flags.enable_v1_converter: 659 use_v2_converter = False 660 parser = _get_parser(use_v2_converter) 661 tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) 662 663 # Checks if the flags are valid. 664 try: 665 if use_v2_converter: 666 _check_tf2_flags(tflite_flags) 667 else: 668 _check_tf1_flags(tflite_flags, unparsed) 669 except ValueError as e: 670 parser.print_usage() 671 file_name = os.path.basename(sys.argv[0]) 672 sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e))) 673 sys.exit(1) 674 675 # Convert the model according to the user provided flag. 676 if use_v2_converter: 677 _convert_tf2_model(tflite_flags) 678 else: 679 try: 680 _convert_tf1_model(tflite_flags) 681 finally: 682 if tflite_flags.conversion_summary_dir: 683 if tflite_flags.experimental_new_converter: 684 gen_html.gen_conversion_log_html(tflite_flags.conversion_summary_dir, 685 tflite_flags.post_training_quantize, 686 tflite_flags.output_file) 687 else: 688 warnings.warn( 689 "Conversion summary will only be generated when enabling" 690 " the new converter via --experimental_new_converter. ") 691 692 693def main(): 694 app.run(main=run_main, argv=sys.argv[:1]) 695 696 697if __name__ == "__main__": 698 main() 699