# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generate a series of TensorFlow graphs that become tflite test cases.

Usage:

generate_examples <output directory> --zip_to_output=<zip name> --toco=<toco path>

bazel run //tensorflow/lite/testing:generate_examples

To debug failures more easily, use (or override) the --save_graphdefs flag to
place text proto graphdefs into the generated zip files.
26""" 27 28from __future__ import absolute_import 29from __future__ import division 30from __future__ import print_function 31 32import tensorflow.compat.v1 as tf 33import argparse 34import os 35import sys 36from tensorflow.lite.testing import generate_examples_lib 37from tensorflow.lite.testing import toco_convert 38 39# TODO(aselle): Disable GPU for now 40os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 41 42 43parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") 44parser.add_argument("output_path", 45 help="Directory where the outputs will be go.") 46parser.add_argument( 47 "--zip_to_output", 48 type=str, 49 help="Particular zip to output.", 50 required=True) 51parser.add_argument("--toco", 52 type=str, 53 help="Path to toco tool.", 54 required=True) 55parser.add_argument( 56 "--known_bugs_are_errors", 57 action="store_true", 58 help=("If a particular model is affected by a known bug," 59 " count it as a converter error.")) 60parser.add_argument( 61 "--ignore_converter_errors", 62 action="store_true", 63 help="Raise an exception if any converter error is encountered.") 64parser.add_argument( 65 "--save_graphdefs", 66 action="store_true", 67 help="Include intermediate graphdefs in the output zip files.") 68parser.add_argument( 69 "--run_with_flex", 70 action="store_true", 71 help="Whether the TFLite Flex converter is being used.") 72parser.add_argument( 73 "--make_edgetpu_tests", 74 action="store_true", 75 help="Whether to generate test cases for edgetpu.") 76parser.add_argument( 77 "--make_forward_compat_test", 78 action="store_true", 79 help="Make tests by setting TF forward compatibility horizon to the future") 80parser.add_argument( 81 "--no_tests_limit", 82 action="store_true", 83 help="Remove the limit of the number of tests.") 84parser.add_argument( 85 "--no_conversion_report", 86 action="store_true", 87 help="Do not create conversion report.") 88parser.add_argument( 89 "--test_sets", 90 type=str, 91 help=("Comma-separated list of test 
set names to generate. " 92 "If not specified, a test set is selected by parsing the name of " 93 "'zip_to_output' file.")) 94parser.add_argument( 95 "--mlir_quantizer", 96 action="store_true", 97 help=("Whether the new MLIR quantizer is being used.")) 98 99 100# Toco binary path provided by the generate rule. 101bin_path = None 102 103 104def main(unused_args): 105 # Eager execution is enabled by default in TF 2.0, but generated example 106 # tests are still using non-eager features (e.g. `tf.placeholder`). 107 tf.compat.v1.disable_eager_execution() 108 109 options = generate_examples_lib.Options() 110 111 options.output_path = FLAGS.output_path 112 options.zip_to_output = FLAGS.zip_to_output 113 options.toco = FLAGS.toco 114 options.known_bugs_are_errors = FLAGS.known_bugs_are_errors 115 options.ignore_converter_errors = FLAGS.ignore_converter_errors 116 options.save_graphdefs = FLAGS.save_graphdefs 117 options.run_with_flex = FLAGS.run_with_flex 118 options.make_edgetpu_tests = FLAGS.make_edgetpu_tests 119 options.make_forward_compat_test = FLAGS.make_forward_compat_test 120 options.tflite_convert_function = toco_convert.toco_convert 121 options.no_tests_limit = FLAGS.no_tests_limit 122 options.no_conversion_report = FLAGS.no_conversion_report 123 options.mlir_quantizer = FLAGS.mlir_quantizer 124 125 if FLAGS.test_sets: 126 test_sets = FLAGS.test_sets.split(",") 127 generate_examples_lib.generate_multi_set_examples(options, test_sets) 128 else: 129 generate_examples_lib.generate_examples(options) 130 131 132if __name__ == "__main__": 133 FLAGS, unparsed = parser.parse_known_args() 134 135 if unparsed: 136 parser.print_usage() 137 print("\nGot the following unparsed args, %r please fix.\n" % unparsed) 138 exit(1) 139 else: 140 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 141