• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Generates a series of test cases using MLIR-based conversion."""
16
17# This is forked from `tensorflow/lite/testing/generate_examples.py`.
18# TODO(b/136499575): Merge this back to TFLite codebase when open sourcing.
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import argparse
25import os
26import sys
27
28import tensorflow.compat.v1 as tf
29
30from tensorflow.lite.experimental.mlir.testing import mlir_convert
31# pylint: disable=unused-import
32from tensorflow.lite.experimental.mlir.testing.op_tests.batchmatmul import make_batchmatmul_tests
33from tensorflow.lite.experimental.mlir.testing.op_tests.broadcast_args import make_broadcast_args_tests
34from tensorflow.lite.experimental.mlir.testing.op_tests.broadcast_gradient_args import make_broadcast_gradient_args_tests
35from tensorflow.lite.experimental.mlir.testing.op_tests.broadcast_to import make_broadcast_to_tests
36from tensorflow.lite.experimental.mlir.testing.op_tests.complex_abs import make_complex_abs_tests
37from tensorflow.lite.experimental.mlir.testing.op_tests.cond import make_cond_tests
38from tensorflow.lite.experimental.mlir.testing.op_tests.control_dep import make_control_dep_tests
39from tensorflow.lite.experimental.mlir.testing.op_tests.conv3d import make_conv3d_tests
40from tensorflow.lite.experimental.mlir.testing.op_tests.conv3d_transpose import make_conv3d_transpose_tests
41from tensorflow.lite.experimental.mlir.testing.op_tests.conv_bias_activation import make_conv_bias_relu6_tests
42from tensorflow.lite.experimental.mlir.testing.op_tests.cumsum import make_cumsum_tests
43# Placeholder for make_dense_image_warp_tests import
44from tensorflow.lite.experimental.mlir.testing.op_tests.dynamic_rnn import make_dynamic_rnn_tests
45from tensorflow.lite.experimental.mlir.testing.op_tests.einsum import make_einsum_tests
46from tensorflow.lite.experimental.mlir.testing.op_tests.identify_dilated_conv import make_identify_dilated_conv_tests
47from tensorflow.lite.experimental.mlir.testing.op_tests.identify_dilated_conv1d import make_identify_dilated_conv1d_tests
48from tensorflow.lite.experimental.mlir.testing.op_tests.imag import make_imag_tests
49from tensorflow.lite.experimental.mlir.testing.op_tests.irfft2d import make_irfft2d_tests
50from tensorflow.lite.experimental.mlir.testing.op_tests.is_finite import make_is_finite_tests
51from tensorflow.lite.experimental.mlir.testing.op_tests.max_pool_with_argmax import make_max_pool_with_argmax_tests
52from tensorflow.lite.experimental.mlir.testing.op_tests.parse_example import make_parse_example_tests
53from tensorflow.lite.experimental.mlir.testing.op_tests.pool3d import make_avg_pool3d_tests
54from tensorflow.lite.experimental.mlir.testing.op_tests.pool3d import make_max_pool3d_tests
55from tensorflow.lite.experimental.mlir.testing.op_tests.real import make_real_tests
56from tensorflow.lite.experimental.mlir.testing.op_tests.reciprocal import make_reciprocal_tests
57from tensorflow.lite.experimental.mlir.testing.op_tests.rfft import make_rfft_tests
58from tensorflow.lite.experimental.mlir.testing.op_tests.rfft2d import make_rfft2d_tests
59from tensorflow.lite.experimental.mlir.testing.op_tests.roll import make_roll_tests
60from tensorflow.lite.experimental.mlir.testing.op_tests.roll import make_roll_with_constant_tests
61from tensorflow.lite.experimental.mlir.testing.op_tests.segment_sum import make_segment_sum_tests
62from tensorflow.lite.experimental.mlir.testing.op_tests.shape_to_strided_slice import make_shape_to_strided_slice_tests
63from tensorflow.lite.experimental.mlir.testing.op_tests.softplus import make_softplus_tests
64from tensorflow.lite.experimental.mlir.testing.op_tests.static_hashtable import make_static_hashtable_tests
65from tensorflow.lite.experimental.mlir.testing.op_tests.static_rnn_with_control_flow_v2 import make_static_rnn_with_control_flow_v2_tests
66from tensorflow.lite.experimental.mlir.testing.op_tests.stft import make_stft_tests
67from tensorflow.lite.experimental.mlir.testing.op_tests.tensor_list_concat import make_tensor_list_concat_tests
68from tensorflow.lite.experimental.mlir.testing.op_tests.tensor_list_dynamic_shape import make_tensor_list_dynamic_shape_tests
69from tensorflow.lite.experimental.mlir.testing.op_tests.tensor_list_get_item import make_tensor_list_get_item_tests
70from tensorflow.lite.experimental.mlir.testing.op_tests.tensor_list_length import make_tensor_list_length_tests
71from tensorflow.lite.experimental.mlir.testing.op_tests.tensor_list_resize import make_tensor_list_resize_tests
72from tensorflow.lite.experimental.mlir.testing.op_tests.tensor_list_set_item import make_tensor_list_set_item_tests
73from tensorflow.lite.experimental.mlir.testing.op_tests.tensor_scatter_add import make_tensor_scatter_add_tests
74from tensorflow.lite.experimental.mlir.testing.op_tests.tensor_scatter_update import make_tensor_scatter_update_tests
75from tensorflow.lite.experimental.mlir.testing.op_tests.where_v2 import make_where_v2_tests
76from tensorflow.lite.experimental.mlir.testing.op_tests.while_loop import make_while_tests
77
78from tensorflow.lite.testing import generate_examples_lib
79
80
# Known failures of the MLIR-based converter. Each key is a regular expression
# and each value is the ID (as a string) of the bug tracking that failure.
# `main` hands this table to the example generator via `options.known_bugs`.
# NOTE(review): the patterns look like they are matched against generated
# test-case names/parameters (e.g. "type=tf.uint8") -- confirm in
# generate_examples_lib.
MLIR_CONVERTER_KNOWN_BUGS = {
    # We need to support dynamic_rnn case.
    r"unidirectional_sequence_rnn.*is_dynamic_rnn=True": "128997102",
    r"unidirectional_sequence_lstm.*is_dynamic_rnn=True": "128997102",
    # TODO(b/124314620): Test cases work with tf_tfl_translate binary
    # but not TFLiteConverter interface.
    # Concat & SpaceToDepth with uint8 don't work.
    r"concat.*type=tf\.uint8": "124314620",
    r"space_to_depth.*type=tf\.uint8": "124314620",
    # l2norm with fully_quantize=True is tracked under its own bug ID, separate
    # from the uint8 cases above.
    r"l2norm.*fully_quantize=True": "134594898",
    # The entries below are not really converter bugs; our kernels don't
    # support int64.
    r"div.*dtype=tf\.int64": "119126484",
    r"floor_div.*dtype=tf\.int64": "119126484",
    r"relu.*dtype=tf\.int64": "119126484",
    r"squared_difference.*dtype=tf\.int64": "119126484",
    # Post-training quantization support is missing for the op below in MLIR.
    r"prelu.*fully_quantize=True": "156112683",
    # ResizeBilinear op kernel supports only float32 and quantized 8-bit
    # integers.
    r"resize_bilinear.*dtype=tf\.int32": "156569626",
}
103
# Disable GPU for now since we are just testing in TF against CPU reference
# value and creating non-device-specific graphs to export.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Command-line interface of the generator. `output_path` is the only positional
# argument, `--zip_to_output` is required, and every other flag is optional.
parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument("output_path",
                    help="Directory where the outputs will go.")
parser.add_argument(
    "--zip_to_output",
    type=str,
    help="Particular zip to output.",
    required=True)
parser.add_argument(
    "--known_bugs_are_errors",
    action="store_true",
    help=("If a particular model is affected by a known bug,"
          " count it as a converter error."))
parser.add_argument(
    "--ignore_converter_errors",
    action="store_true",
    # The help text used to read "Raise an exception if any converter error is
    # encountered", i.e. the opposite of what passing this flag asks for.
    help="Ignore converter errors instead of raising an exception.")
parser.add_argument(
    "--save_graphdefs",
    action="store_true",
    help="Include intermediate graphdefs in the output zip files.")
parser.add_argument(
    "--run_with_flex",
    action="store_true",
    help="Whether the TFLite Flex converter is being used.")
parser.add_argument(
    "--make_edgetpu_tests",
    action="store_true",
    help="Whether to generate test cases for edgetpu.")
parser.add_argument(
    "--make_forward_compat_test",
    action="store_true",
    help="Make tests by setting TF forward compatibility horizon to the future")
parser.add_argument(
    "--test_sets",
    type=str,
    help=("Comma-separated list of test set names to generate. "
          "If not specified, a test set is selected by parsing the name of "
          "'zip_to_output' file."))
parser.add_argument(
    "--mlir_quantizer",
    action="store_true",
    help="Whether the new MLIR quantizer is being used.")
151
152
def main(unused_args):
  """Generates MLIR-converted test examples as configured by `FLAGS`.

  Populates a `generate_examples_lib.Options` object from the parsed
  command-line flags and the MLIR-specific settings, then generates either the
  test sets named in `--test_sets` or, when that flag is not given, whatever
  `generate_examples_lib.generate_examples` produces for the options.

  Args:
    unused_args: Ignored.
  """
  opts = generate_examples_lib.Options()

  # Settings copied one-to-one from the command-line flag of the same name.
  mirrored_flags = (
      "output_path",
      "zip_to_output",
      "known_bugs_are_errors",
      "ignore_converter_errors",
      "save_graphdefs",
      "run_with_flex",
      "make_edgetpu_tests",
      "make_forward_compat_test",
      "mlir_quantizer",
  )
  for flag_name in mirrored_flags:
    setattr(opts, flag_name, getattr(FLAGS, flag_name))

  # Settings that are fixed for the MLIR-based conversion path.
  opts.tflite_convert_function = mlir_convert.mlir_convert
  opts.known_bugs = MLIR_CONVERTER_KNOWN_BUGS
  opts.use_experimental_converter = True

  if not FLAGS.test_sets:
    generate_examples_lib.generate_examples(opts)
    return

  requested_sets = FLAGS.test_sets.split(",")
  generate_examples_lib.generate_multi_set_examples(opts, requested_sets)
174
175
if __name__ == "__main__":
  # `main` reads the parsed flags from the module-level name FLAGS.
  FLAGS, unparsed = parser.parse_known_args()

  if unparsed:
    # Arguments the parser does not recognise are an error. The old message
    # printed a literal, never-interpolated "%s" and described a positional zip
    # argument that is actually the required --zip_to_output flag, so the
    # parser's own usage line is reported instead.
    parser.print_usage(sys.stderr)
    print("Unrecognized arguments: %s" % " ".join(unparsed), file=sys.stderr)
    # sys.exit rather than the `site`-provided exit(), which is not guaranteed
    # to exist (e.g. under `python -S`).
    sys.exit(1)

  # `unparsed` is empty here, so only the program name is forwarded.
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
184