• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
"""Generate a series of TensorFlow graphs that become tflite test cases.

Usage:

generate_examples <output directory> --zip_to_output=<zip name> --toco=<path>

bazel run //tensorflow/lite/testing:generate_examples

To more easily debug failures, use (or override) the --save_graphdefs flag to
place text proto graphdefs into the generated zip files.
"""
27
28from __future__ import absolute_import
29from __future__ import division
30from __future__ import print_function
31
32import tensorflow.compat.v1 as tf
33import argparse
34import os
35import sys
36from tensorflow.lite.testing import generate_examples_lib
37from tensorflow.lite.testing import toco_convert
38
# TODO(aselle): Disable GPU for now.
# Setting CUDA_VISIBLE_DEVICES to "-1" (not a valid device index) hides every
# CUDA device from this process, so TensorFlow should only see the CPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
41
42
# Command-line interface. Flags are parsed in the `__main__` guard and then
# copied one-to-one onto `generate_examples_lib.Options` by `main`.
parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument("output_path",
                    help="Directory where the outputs will go.")
parser.add_argument(
    "--zip_to_output",
    type=str,
    help="Particular zip to output.",
    required=True)
parser.add_argument("--toco",
                    type=str,
                    help="Path to toco tool.",
                    required=True)
parser.add_argument(
    "--known_bugs_are_errors",
    action="store_true",
    help=("If a particular model is affected by a known bug,"
          " count it as a converter error."))
parser.add_argument(
    "--ignore_converter_errors",
    action="store_true",
    # The previous help text ("Raise an exception if any converter error is
    # encountered.") described the opposite of what setting this flag means.
    help="Ignore converter errors instead of raising an exception.")
parser.add_argument(
    "--save_graphdefs",
    action="store_true",
    help="Include intermediate graphdefs in the output zip files.")
parser.add_argument(
    "--run_with_flex",
    action="store_true",
    help="Whether the TFLite Flex converter is being used.")
parser.add_argument(
    "--make_edgetpu_tests",
    action="store_true",
    help="Whether to generate test cases for edgetpu.")
parser.add_argument(
    "--make_forward_compat_test",
    action="store_true",
    help="Make tests by setting TF forward compatibility horizon to the future.")
parser.add_argument(
    "--no_tests_limit",
    action="store_true",
    help="Remove the limit of the number of tests.")
parser.add_argument(
    "--no_conversion_report",
    action="store_true",
    help="Do not create conversion report.")
parser.add_argument(
    "--test_sets",
    type=str,
    help=("Comma-separated list of test set names to generate. "
          "If not specified, a test set is selected by parsing the name of "
          "'zip_to_output' file."))
parser.add_argument(
    "--mlir_quantizer",
    action="store_true",
    help="Whether the new MLIR quantizer is being used.")


# Toco binary path provided by the generate rule.
# NOTE(review): nothing in this file reads or assigns this global; `main` uses
# FLAGS.toco instead. Confirm whether external code still depends on it.
bin_path = None
102
103
def parse_test_sets(test_sets_flag):
  """Split a comma-separated `--test_sets` value into test set names.

  Whitespace around each name is stripped, and empty entries (from doubled or
  trailing commas) are dropped, so `" add, conv,"` yields `["add", "conv"]`.

  Args:
    test_sets_flag: The raw flag value; may be None or empty.

  Returns:
    A list of non-empty test set names; empty when none were given.
  """
  if not test_sets_flag:
    return []
  stripped = (name.strip() for name in test_sets_flag.split(","))
  return [name for name in stripped if name]


def main(unused_args):
  """Translate the parsed command-line flags into options and generate tests.

  Reads the module-level `FLAGS`, which is only assigned in this file's
  `__main__` guard. This function therefore expects to be launched through
  `tf.app.run` from that guard rather than imported and called directly.

  Args:
    unused_args: Positional arguments forwarded by `tf.app.run`; ignored.
  """
  # Eager execution is enabled by default in TF 2.0, but generated example
  # tests are still using non-eager features (e.g. `tf.placeholder`).
  tf.compat.v1.disable_eager_execution()

  options = generate_examples_lib.Options()

  options.output_path = FLAGS.output_path
  options.zip_to_output = FLAGS.zip_to_output
  options.toco = FLAGS.toco
  options.known_bugs_are_errors = FLAGS.known_bugs_are_errors
  options.ignore_converter_errors = FLAGS.ignore_converter_errors
  options.save_graphdefs = FLAGS.save_graphdefs
  options.run_with_flex = FLAGS.run_with_flex
  options.make_edgetpu_tests = FLAGS.make_edgetpu_tests
  options.make_forward_compat_test = FLAGS.make_forward_compat_test
  # The conversion entry point is fixed here; no flag selects a different one.
  options.tflite_convert_function = toco_convert.toco_convert
  options.no_tests_limit = FLAGS.no_tests_limit
  options.no_conversion_report = FLAGS.no_conversion_report
  options.mlir_quantizer = FLAGS.mlir_quantizer

  test_sets = parse_test_sets(FLAGS.test_sets)
  if test_sets:
    generate_examples_lib.generate_multi_set_examples(options, test_sets)
  else:
    # No usable names were given (flag absent, empty, or only commas). Per
    # the --test_sets help text, the set is then derived from --zip_to_output.
    generate_examples_lib.generate_examples(options)
130
131
if __name__ == "__main__":
  # Parse leniently so unrecognized arguments can be reported with a specific
  # message; `main` reads the resulting module-level FLAGS.
  FLAGS, unparsed = parser.parse_known_args()

  if unparsed:
    # Use sys.exit rather than the site-provided builtin `exit`, which is not
    # guaranteed to exist (e.g. when Python runs with -S).
    parser.print_usage(sys.stderr)
    print("\nGot the following unparsed args, %r please fix.\n" % unparsed,
          file=sys.stderr)
    sys.exit(1)

  # All arguments were consumed by argparse above, so only the program name is
  # forwarded to tf.app.run.
  tf.app.run(main=main, argv=[sys.argv[0]])
141