# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generate a series of TensorFlow graphs that become tflite test cases.

Usage:

generate_examples <output directory>

bazel run //tensorflow/lite/testing:generate_examples

To more easily debug failures, use (or override) the --save_graphdefs flag to
place text proto graphdefs into the generated zip files.
"""

import argparse
import os
import sys

import tensorflow.compat.v1 as tf

from tensorflow.lite.testing import generate_examples_lib
from tensorflow.lite.testing import mlir_convert

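# Regexes over generated test-case names, mapped to the IDs of the bugs that
# track them. The dict is handed to the test generator as options.known_bugs in
# main() below; --known_bugs_are_errors controls whether a failure that matches
# one of these patterns is counted as a converter error.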
MLIR_CONVERTER_KNOWN_BUGS = {
    # We need to support the dynamic_rnn case.
    r"unidirectional_sequence_rnn.*is_dynamic_rnn=True": "128997102",
    r"unidirectional_sequence_lstm.*is_dynamic_rnn=True": "128997102",
    # TODO(b/124314620): Test cases work with the tf_tfl_translate binary
    # but not the TFLiteConverter interface.
    # Concat & SpaceToDepth with uint8 don't work.
    r"concat.*type=tf\.uint8": "124314620",
    r"space_to_depth.*type=tf\.uint8": "124314620",
    r"l2norm.*fully_quantize=True": "134594898",
    # The cases below are not really converter bugs; our kernels simply don't
    # support int64.
    r"div.*dtype=tf\.int64": "119126484",
    r"floor_div.*dtype=tf\.int64": "119126484",
    r"relu.*dtype=tf\.int64": "119126484",
    r"squared_difference.*dtype=tf\.int64": "119126484",
    # Post-training quantization support is missing for the op below in MLIR.
    r"prelu.*fully_quantize=True": "156112683",
    # The ResizeBilinear op kernel supports only float32 and quantized 8-bit
    # integers.
    r"resize_bilinear.*dtype=tf\.int32": "156569626",
}

# Disable GPU for now since we are just testing in TF against CPU reference
# values and creating non-device-specific graphs to export.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument(
    "output_path", help="Directory where the outputs will go.")
parser.add_argument(
    "--zip_to_output",
    type=str,
    help="Particular zip to output.",
    required=True)
parser.add_argument(
    "--known_bugs_are_errors",
    action="store_true",
    help=("If a particular model is affected by a known bug,"
          " count it as a converter error."))
parser.add_argument(
    "--ignore_converter_errors",
    action="store_true",
    help="Ignore converter errors instead of raising an exception.")
parser.add_argument(
    "--save_graphdefs",
    action="store_true",
    help="Include intermediate graphdefs in the output zip files.")
parser.add_argument(
    "--run_with_flex",
    action="store_true",
    help="Whether the TFLite Flex converter is being used.")
parser.add_argument(
    "--make_edgetpu_tests",
    action="store_true",
    help="Whether to generate test cases for edgetpu.")
parser.add_argument(
    "--make_tf_ptq_tests",
    action="store_true",
    help="Whether to generate test cases for TF post-training quantization.")
parser.add_argument(
    "--hlo_aware_conversion",
    action="store_true",
    help="For TF quantization only: whether to convert for an HLO target.")
parser.add_argument(
    "--make_forward_compat_test",
    action="store_true",
    help="Make tests by setting TF forward compatibility horizon to the future.")
parser.add_argument(
    "--no_tests_limit",
    action="store_true",
    help="Remove the limit of the number of tests.")
parser.add_argument(
    "--test_sets",
    type=str,
    help=("Comma-separated list of test set names to generate. "
          "If not specified, a test set is selected by parsing the name of "
          "the 'zip_to_output' file."))
parser.add_argument(
    "--mlir_quantizer",
    action="store_true",
    help=("Whether the new MLIR quantizer is being used."))
parser.add_argument(
    "--skip_high_dimension_inputs",
    action="store_true",
    help=("Whether to skip generating tests with high-dimension input shapes."))


def main(unused_args):
  options = generate_examples_lib.Options()

  options.output_path = FLAGS.output_path
  options.zip_to_output = FLAGS.zip_to_output
  options.known_bugs_are_errors = FLAGS.known_bugs_are_errors
  options.ignore_converter_errors = FLAGS.ignore_converter_errors
  options.save_graphdefs = FLAGS.save_graphdefs
  options.run_with_flex = FLAGS.run_with_flex
  options.make_edgetpu_tests = FLAGS.make_edgetpu_tests
  options.make_tf_ptq_tests = FLAGS.make_tf_ptq_tests
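  # Use the MLIR-based converter and the known-bug list defined above.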
  options.tflite_convert_function = mlir_convert.mlir_convert
  options.known_bugs = MLIR_CONVERTER_KNOWN_BUGS
  options.make_forward_compat_test = FLAGS.make_forward_compat_test
  options.no_tests_limit = FLAGS.no_tests_limit
  options.mlir_quantizer = FLAGS.mlir_quantizer
  options.skip_high_dimension_inputs = FLAGS.skip_high_dimension_inputs

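  # With --test_sets, generate every listed test set; otherwise the test set is
  # derived from the name of the requested output zip.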
  if FLAGS.test_sets:
    test_sets = FLAGS.test_sets.split(",")
    generate_examples_lib.generate_multi_set_examples(options, test_sets)
  else:
    generate_examples_lib.generate_examples(options)


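# FLAGS is assigned at module scope here so that main() above can read it.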
if __name__ == "__main__":
  FLAGS, unparsed = parser.parse_known_args()

  if unparsed:
    print("\nGot the following unparsed args: %r, please fix.\n" % unparsed +
          "Usage: %s <path out> <zip file to generate>" % sys.argv[0])
    sys.exit(1)
  else:
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)