# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python command line interface for converting TF models to TFLite models."""

import argparse
import os
import sys
import warnings

from absl import app
import tensorflow as tf  # pylint: disable=unused-import

from tensorflow.lite.python import lite
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco.logging import gen_html
from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.util import keras_deps

# Needed to enable TF2 by default.


def _parse_array(values, type_fn=str):
  """Splits a comma-separated string into a list of converted items.

  Empty items (e.g. produced by "a,,b" or a trailing comma) are skipped.

  Args:
    values: Comma-separated string, or None.
    type_fn: Callable applied to each non-empty item.

  Returns:
    List of `type_fn(item)` values, or None if `values` is None.
  """
  if values is not None:
    return [type_fn(val) for val in values.split(",") if val]
  return None


def _parse_set(values):
  """Splits a comma-separated string into a set of non-empty strings.

  Note that an empty string yields an empty set (not None), which is how an
  empty SavedModel tag set is expressed on the command line.

  Args:
    values: Comma-separated string, or None.

  Returns:
    Set of non-empty items, or None if `values` is None.
  """
  if values is not None:
    return {item for item in values.split(",") if item}
  return None


def _parse_inference_type(value, flag):
  """Converts the inference type to the value of the constant.

  Args:
    value: str representing the inference type.
    flag: str representing the flag name.

  Returns:
    tf.dtype.

  Raises:
    ValueError: Unsupported value.
  """
  if value == "FLOAT":
    return dtypes.float32
  if value == "INT8":
    return dtypes.int8
  if value == "UINT8" or value == "QUANTIZED_UINT8":
    return dtypes.uint8
  raise ValueError(
      "Unsupported value for `{}` flag. Expected FLOAT, INT8, UINT8, or "
      "QUANTIZED_UINT8 instead got {}.".format(flag, value))


class _ParseBooleanFlag(argparse.Action):
  """Helper class to parse boolean flag that optionally accepts truth value.

  Accepted forms are `--flag` (True), `--flag=true` and `--flag=false`
  (case-insensitive after the equal sign).
  """

  def __init__(self, option_strings, dest, nargs=None, **kwargs):
    if nargs != "?":
      # This should never happen. This class is only used with nargs="?".
      raise ValueError(
          "This parser only supports nargs='?' (0 or 1 additional arguments)")
    super(_ParseBooleanFlag, self).__init__(
        option_strings, dest, nargs=nargs, **kwargs)

  def __call__(self, parser, namespace, values, option_string=None):
    """Stores the parsed boolean on `namespace.<dest>`.

    Raises:
      ValueError: `values` is neither absent nor "true"/"false".
    """
    if values is None:
      # Handling `--boolean_flag`.
      # Without additional arguments, it implies true.
      flag_value = True
    elif values.lower() == "true":
      # Handling `--boolean_flag=true`.
      # (Case insensitive after the equal sign)
      flag_value = True
    elif values.lower() == "false":
      # Handling `--boolean_flag=false`.
      # (Case insensitive after the equal sign)
      flag_value = False
    else:
      raise ValueError("Invalid argument to --{}. Must use flag alone,"
                       " or specify true/false.".format(self.dest))
    setattr(namespace, self.dest, flag_value)


def _get_tflite_converter(flags):
  """Makes a TFLiteConverter object based on the flags provided.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Returns:
    TFLiteConverter object.

  Raises:
    ValueError: Invalid flags.
  """
  # Parse input and output arrays.
  input_arrays = _parse_array(flags.input_arrays)
  input_shapes = None
  if flags.input_shapes:
    # Shapes are colon-separated per array; dimensions within a shape are
    # comma-separated. An empty shape string parses to [].
    input_shapes_list = [
        _parse_array(shape, type_fn=int)
        for shape in flags.input_shapes.split(":")
    ]
    # Pairs shapes with arrays positionally; _check_tf1_flags verifies that
    # the counts match before this is reached.
    input_shapes = dict(zip(input_arrays, input_shapes_list))
  output_arrays = _parse_array(flags.output_arrays)

  converter_kwargs = {
      "input_arrays": input_arrays,
      "input_shapes": input_shapes,
      "output_arrays": output_arrays
  }

  # Create TFLiteConverter.
  if flags.graph_def_file:
    converter_fn = lite.TFLiteConverter.from_frozen_graph
    converter_kwargs["graph_def_file"] = flags.graph_def_file
  elif flags.saved_model_dir:
    converter_fn = lite.TFLiteConverter.from_saved_model
    converter_kwargs["saved_model_dir"] = flags.saved_model_dir
    converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
    converter_kwargs["signature_key"] = flags.saved_model_signature_key
  elif flags.keras_model_file:
    converter_fn = lite.TFLiteConverter.from_keras_model_file
    converter_kwargs["model_file"] = flags.keras_model_file
  else:
    raise ValueError("--graph_def_file, --saved_model_dir, or "
                     "--keras_model_file must be specified.")

  return converter_fn(**converter_kwargs)


def _convert_tf1_model(flags):
  """Calls function to convert the TensorFlow 1.X model into a TFLite model.

  The resulting flatbuffer (or other requested format) is written to
  `flags.output_file`.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Invalid flags.
  """
  # Register custom opdefs before converter object creation.
  if flags.custom_opdefs:
    register_custom_opdefs(_parse_array(flags.custom_opdefs))

  # Create converter.
  converter = _get_tflite_converter(flags)
  if flags.inference_type:
    converter.inference_type = _parse_inference_type(flags.inference_type,
                                                     "inference_type")
  if flags.inference_input_type:
    converter.inference_input_type = _parse_inference_type(
        flags.inference_input_type, "inference_input_type")
  if flags.output_format:
    converter.output_format = _toco_flags_pb2.FileFormat.Value(
        flags.output_format)

  if flags.mean_values and flags.std_dev_values:
    input_arrays = converter.get_input_arrays()
    std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)

    # In quantized inference, mean_value has to be integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == dtypes.float32:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    quant_stats = list(zip(mean_values, std_dev_values))
    # Stats are mapped to inputs positionally, so with several inputs the
    # user must name them explicitly and the counts must agree.
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):
      raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
                       "--mean_values. The flags must have the same number of "
                       "items. The current input arrays are '{0}'. "
                       "--input_arrays must be present when specifying "
                       "--std_dev_values and --mean_values with multiple input "
                       "tensors in order to map between names and "
                       "values.".format(",".join(input_arrays)))
    converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
  if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
                                                 not None):
    converter.default_ranges_stats = (flags.default_ranges_min,
                                      flags.default_ranges_max)

  # The store_true flags below are only forwarded when set, leaving the
  # converter's own defaults in place otherwise.
  if flags.drop_control_dependency:
    converter.drop_control_dependency = flags.drop_control_dependency
  if flags.reorder_across_fake_quant:
    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
  if flags.change_concat_input_ranges:
    # The flag value is upper-cased by argparse and restricted to TRUE/FALSE.
    converter.change_concat_input_ranges = (
        flags.change_concat_input_ranges == "TRUE")

  if flags.allow_custom_ops:
    converter.allow_custom_ops = flags.allow_custom_ops

  if flags.target_ops:
    ops_set_options = lite.OpsSet.get_options()
    converter.target_spec.supported_ops = set()
    for option in flags.target_ops.split(","):
      if option not in ops_set_options:
        raise ValueError("Invalid value for --target_ops. Options: "
                         "{0}".format(",".join(ops_set_options)))
      converter.target_spec.supported_ops.add(lite.OpsSet(option))

  if flags.experimental_select_user_tf_ops:
    if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops:
      raise ValueError("--experimental_select_user_tf_ops can only be set if "
                       "--target_ops contains SELECT_TF_OPS.")
    # Deduplicate the requested op names.
    user_op_set = set(flags.experimental_select_user_tf_ops.split(","))
    converter.target_spec.experimental_select_user_tf_ops = list(user_op_set)

  if flags.post_training_quantize:
    converter.optimizations = [lite.Optimize.DEFAULT]
    if converter.inference_type != dtypes.float32:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference_type to FLOAT.")
      converter.inference_type = dtypes.float32

  if flags.quantize_to_float16:
    converter.target_spec.supported_types = [dtypes.float16]
    if not flags.post_training_quantize:
      print("--quantize_to_float16 will only take effect with the "
            "--post_training_quantize flag enabled.")

  if flags.dump_graphviz_dir:
    converter.dump_graphviz_dir = flags.dump_graphviz_dir
  if flags.dump_graphviz_video:
    # Previously misspelled as `dump_graphviz_vode`, which set an unused
    # attribute and silently ignored --dump_graphviz_video.
    converter.dump_graphviz_video = flags.dump_graphviz_video
  if flags.conversion_summary_dir:
    converter.conversion_summary_dir = flags.conversion_summary_dir

  converter.experimental_new_converter = flags.experimental_new_converter

  # None means the flag was not passed; keep the converter's default then.
  if flags.experimental_new_quantizer is not None:
    converter.experimental_new_quantizer = flags.experimental_new_quantizer

  # Convert model.
  output_data = converter.convert()
  with open(flags.output_file, "wb") as f:
    f.write(output_data)


def _convert_tf2_model(flags):
  """Calls function to convert the TensorFlow 2.0 model into a TFLite model.

  The resulting flatbuffer is written to `flags.output_file`.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Neither --saved_model_dir nor --keras_model_file is set.
  """
  # Load the model.
  if flags.saved_model_dir:
    converter = lite.TFLiteConverterV2.from_saved_model(
        flags.saved_model_dir,
        signature_keys=_parse_array(flags.saved_model_signature_key),
        tags=_parse_set(flags.saved_model_tag_set))
  elif flags.keras_model_file:
    model = keras_deps.get_load_model_function()(flags.keras_model_file)
    converter = lite.TFLiteConverterV2.from_keras_model(model)
  else:
    # Previously fell through and crashed with UnboundLocalError on
    # `converter` when called without a model source.
    raise ValueError("--saved_model_dir or --keras_model_file must be "
                     "specified.")

  converter.experimental_new_converter = flags.experimental_new_converter

  # None means the flag was not passed; keep the converter's default then.
  if flags.experimental_new_quantizer is not None:
    converter.experimental_new_quantizer = flags.experimental_new_quantizer

  # Convert the model.
  tflite_model = converter.convert()
  with open(flags.output_file, "wb") as f:
    f.write(tflite_model)


def _check_tf1_flags(flags, unparsed):
  """Checks the parsed and unparsed flags to ensure they are valid in 1.X.

  Raises an error if previously supported unparsed flags are found. Raises an
  error for parsed flags that don't meet the required conditions.

  Args:
    flags: argparse.Namespace object containing TFLite flags.
    unparsed: List of unparsed flags.

  Raises:
    ValueError: Invalid flags.
  """

  # Check unparsed flags for common mistakes based on previous TOCO.
  def _get_message_unparsed(flag, orig_flag, new_flag):
    """Returns a hint to use `new_flag` if `flag` starts with `orig_flag`."""
    if flag.startswith(orig_flag):
      return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
    return ""

  if unparsed:
    output = ""
    for flag in unparsed:
      output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
      output += _get_message_unparsed(flag, "--savedmodel_directory",
                                      "--saved_model_dir")
      output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
      output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
      output += _get_message_unparsed(flag, "--dump_graphviz",
                                      "--dump_graphviz_dir")
    if output:
      raise ValueError(output)

  # Check that flags are valid.
  if flags.graph_def_file and (not flags.input_arrays or
                               not flags.output_arrays):
    raise ValueError("--input_arrays and --output_arrays are required with "
                     "--graph_def_file")

  if flags.input_shapes:
    if not flags.input_arrays:
      raise ValueError("--input_shapes must be used with --input_arrays")
    # N shapes are separated by N-1 colons; N arrays by N-1 commas.
    if flags.input_shapes.count(":") != flags.input_arrays.count(","):
      raise ValueError("--input_shapes and --input_arrays must have the same "
                       "number of items")

  if flags.std_dev_values or flags.mean_values:
    if bool(flags.std_dev_values) != bool(flags.mean_values):
      raise ValueError("--std_dev_values and --mean_values must be used "
                       "together")
    if flags.std_dev_values.count(",") != flags.mean_values.count(","):
      raise ValueError("--std_dev_values, --mean_values must have the same "
                       "number of items")

  if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
    raise ValueError("--default_ranges_min and --default_ranges_max must be "
                     "used together")

  if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
    raise ValueError("--dump_graphviz_video must be used with "
                     "--dump_graphviz_dir")

  if flags.custom_opdefs and not flags.experimental_new_converter:
    raise ValueError("--custom_opdefs must be used with "
                     "--experimental_new_converter")
  if flags.custom_opdefs and not flags.allow_custom_ops:
    raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
  if (flags.experimental_select_user_tf_ops and
      not flags.experimental_new_converter):
    raise ValueError("--experimental_select_user_tf_ops must be used with "
                     "--experimental_new_converter")


def _check_tf2_flags(flags):
  """Checks the parsed and unparsed flags to ensure they are valid in 2.X.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Raises:
    ValueError: Invalid flags.
  """
  # The 2.X input group is not `required=True` (so that --enable_v1_converter
  # can be passed alone), hence the manual check here.
  if not flags.keras_model_file and not flags.saved_model_dir:
    raise ValueError("one of the arguments --saved_model_dir "
                     "--keras_model_file is required")


def _get_tf1_flags(parser):
  """Adds the tflite_convert flags for TensorFlow 1.X to `parser`.

  Args:
    parser: ArgumentParser, modified in place.
  """
  # Input file flags.
  input_file_group = parser.add_mutually_exclusive_group(required=True)
  input_file_group.add_argument(
      "--graph_def_file",
      type=str,
      help="Full filepath of file containing frozen TensorFlow GraphDef.")
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full filepath of directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")

  # Model format flags.
  parser.add_argument(
      "--output_format",
      type=str.upper,
      choices=["TFLITE", "GRAPHVIZ_DOT"],
      help="Output file format.")
  parser.add_argument(
      "--inference_type",
      type=str.upper,
      default="FLOAT",
      help=("Target data type of real-number arrays in the output file. "
            "Must be either FLOAT, INT8 or UINT8."))
  parser.add_argument(
      "--inference_input_type",
      type=str.upper,
      help=("Target data type of real-number input arrays. Allows for a "
            "different type for input arrays in the case of quantization. "
            "Must be either FLOAT, INT8 or UINT8."))

  # Input and output arrays flags.
  parser.add_argument(
      "--input_arrays",
      type=str,
      help="Names of the input arrays, comma-separated.")
  parser.add_argument(
      "--input_shapes",
      type=str,
      help="Shapes corresponding to --input_arrays, colon-separated.")
  parser.add_argument(
      "--output_arrays",
      type=str,
      help="Names of the output arrays, comma-separated.")

  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Quantization flags.
  parser.add_argument(
      "--std_dev_values",
      type=str,
      help=("Standard deviation of training data for each input tensor, "
            "comma-separated floats. Used for quantized input tensors. "
            "(default None)"))
  parser.add_argument(
      "--mean_values",
      type=str,
      help=("Mean of training data for each input tensor, comma-separated "
            "floats. Used for quantized input tensors. (default None)"))
  parser.add_argument(
      "--default_ranges_min",
      type=float,
      help=("Default value for min bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  parser.add_argument(
      "--default_ranges_max",
      type=float,
      help=("Default value for max bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  # quantize_weights is DEPRECATED; it is kept as a hidden alias that writes
  # to the same destination as --post_training_quantize.
  parser.add_argument(
      "--quantize_weights",
      dest="post_training_quantize",
      action="store_true",
      help=argparse.SUPPRESS)
  parser.add_argument(
      "--post_training_quantize",
      dest="post_training_quantize",
      action="store_true",
      help=(
          "Boolean indicating whether to quantize the weights of the "
          "converted float model. Model size will be reduced and there will "
          "be latency improvements (at the cost of accuracy). (default False)"))
  parser.add_argument(
      "--quantize_to_float16",
      dest="quantize_to_float16",
      action="store_true",
      help=("Boolean indicating whether to quantize weights to fp16 instead of "
            "the default int8 when post-training quantization "
            "(--post_training_quantize) is enabled. (default False)"))
  # Graph manipulation flags.
  parser.add_argument(
      "--drop_control_dependency",
      action="store_true",
      help=("Boolean indicating whether to drop control dependencies silently. "
            "This is due to TensorFlow not supporting control dependencies. "
            "(default True)"))
  parser.add_argument(
      "--reorder_across_fake_quant",
      action="store_true",
      help=("Boolean indicating whether to reorder FakeQuant nodes in "
            "unexpected locations. Used when the location of the FakeQuant "
            "nodes is preventing graph transformations necessary to convert "
            "the graph. Results in a graph that differs from the quantized "
            "training graph, potentially causing differing arithmetic "
            "behavior. (default False)"))
  # Usage for this flag is --change_concat_input_ranges=true or
  # --change_concat_input_ranges=false in order to make it clear what the flag
  # is set to. This keeps the usage consistent with other usages of the flag
  # where the default is different. The default value here is False.
  parser.add_argument(
      "--change_concat_input_ranges",
      type=str.upper,
      choices=["TRUE", "FALSE"],
      help=("Boolean to change behavior of min/max ranges for inputs and "
            "outputs of the concat operator for quantized models. Changes the "
            "ranges of concat operator overlap when true. (default False)"))

  # Permitted ops flags.
  parser.add_argument(
      "--allow_custom_ops",
      action=_ParseBooleanFlag,
      nargs="?",
      help=("Boolean indicating whether to allow custom operations. When false "
            "any unknown operation is an error. When true, custom ops are "
            "created for any op that is unknown. The developer will need to "
            "provide these to the TensorFlow Lite runtime with a custom "
            "resolver. (default False)"))
  parser.add_argument(
      "--custom_opdefs",
      type=str,
      help=("String representing a list of custom ops OpDefs delineated with "
            "commas that are included in the GraphDef. Required when using "
            "custom operations with --experimental_new_converter."))
  parser.add_argument(
      "--target_ops",
      type=str,
      help=("Experimental flag, subject to change. Set of OpsSet options "
            "indicating which converter to use. Options: {0}. One or more "
            "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
            "".format(",".join(lite.OpsSet.get_options()))))
  parser.add_argument(
      "--experimental_select_user_tf_ops",
      type=str,
      help=("Experimental flag, subject to change. Comma separated list of "
            "user's defined TensorFlow operators required in the runtime."))

  # Logging flags.
  parser.add_argument(
      "--dump_graphviz_dir",
      type=str,
      help=("Full filepath of folder to dump the graphs at various stages of "
            "processing GraphViz .dot files. Preferred over --output_format="
            "GRAPHVIZ_DOT in order to keep the requirements of the output "
            "file."))
  parser.add_argument(
      "--dump_graphviz_video",
      action="store_true",
      help=("Boolean indicating whether to dump the graph after every graph "
            "transformation"))
  parser.add_argument(
      "--conversion_summary_dir",
      type=str,
      help=("Full filepath to store the conversion logs, which includes "
            "graphviz of the model before/after the conversion, an HTML report "
            "and the conversion proto buffers. This will only be generated "
            "when passing --experimental_new_converter"))


def _get_tf2_flags(parser):
  """Adds the tflite_convert flags for TensorFlow 2.0 to `parser`.

  Args:
    parser: ArgumentParser, modified in place.
  """
  # Input file flags. Not required here; _check_tf2_flags enforces that one
  # is present once --enable_v1_converter has been ruled out.
  input_file_group = parser.add_mutually_exclusive_group()
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full path of the directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")
  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Enables 1.X converter in 2.X.
  parser.add_argument(
      "--enable_v1_converter",
      action="store_true",
      help=("Enables the TensorFlow V1 converter in 2.0"))


def _get_parser(use_v2_converter):
  """Returns an ArgumentParser for tflite_convert.

  Args:
    use_v2_converter: Indicates which converter to return.

  Returns:
    ArgumentParser.
  """
  parser = argparse.ArgumentParser(
      description=("Command line tool to run TensorFlow Lite Converter."))

  # Output file flag.
  parser.add_argument(
      "--output_file",
      type=str,
      help="Full filepath of the output file.",
      required=True)

  if use_v2_converter:
    _get_tf2_flags(parser)
  else:
    _get_tf1_flags(parser)

  parser.add_argument(
      "--experimental_new_converter",
      action=_ParseBooleanFlag,
      nargs="?",
      default=True,
      help=("Experimental flag, subject to change. Enables MLIR-based "
            "conversion instead of TOCO conversion. (default True)"))

  # No default: None lets the converter keep its own setting.
  parser.add_argument(
      "--experimental_new_quantizer",
      action=_ParseBooleanFlag,
      nargs="?",
      help=("Experimental flag, subject to change. Enables MLIR-based "
            "quantizer instead of flatbuffer conversion. (default True)"))
  return parser


def run_main(_):
  """Main in tflite_convert.py.

  Parses sys.argv[1:] with argparse, validates the flags, and runs either the
  TF 2.X or TF 1.X conversion path.
  """
  use_v2_converter = tf2.enabled()
  parser = _get_parser(use_v2_converter)
  # Unknown flags are collected rather than rejected so that _check_tf1_flags
  # can suggest replacements for legacy TOCO flag names.
  tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])

  # If the user is running TensorFlow 2.X but has passed in enable_v1_converter
  # then parse the flags again with the 1.X converter flags.
  if use_v2_converter and tflite_flags.enable_v1_converter:
    use_v2_converter = False
    parser = _get_parser(use_v2_converter)
    tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])

  # Checks if the flags are valid.
  try:
    if use_v2_converter:
      _check_tf2_flags(tflite_flags)
    else:
      _check_tf1_flags(tflite_flags, unparsed)
  except ValueError as e:
    parser.print_usage()
    file_name = os.path.basename(sys.argv[0])
    sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
    sys.exit(1)

  # Convert the model according to the user provided flag.
  if use_v2_converter:
    _convert_tf2_model(tflite_flags)
  else:
    try:
      _convert_tf1_model(tflite_flags)
    finally:
      # The summary report is generated even if conversion fails, so that
      # the logs written so far can be inspected.
      if tflite_flags.conversion_summary_dir:
        if tflite_flags.experimental_new_converter:
          gen_html.gen_conversion_log_html(tflite_flags.conversion_summary_dir,
                                           tflite_flags.post_training_quantize,
                                           tflite_flags.output_file)
        else:
          warnings.warn(
              "Conversion summary will only be generated when enabling"
              " the new converter via --experimental_new_converter. ")


def main():
  # Only the program name is handed to absl; the converter flags are parsed
  # by argparse inside run_main from sys.argv[1:].
  app.run(main=run_main, argv=sys.argv[:1])


if __name__ == "__main__":
  main()