1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utils for make_zip tests.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import functools 21import itertools 22import operator 23import os 24import re 25import string 26import traceback 27import zipfile 28 29import numpy as np 30from six import StringIO 31 32# pylint: disable=g-import-not-at-top 33import tensorflow.compat.v1 as tf 34from google.protobuf import text_format 35from tensorflow.lite.testing import _pywrap_string_util 36from tensorflow.lite.testing import generate_examples_report as report_lib 37from tensorflow.python.framework import graph_util as tf_graph_util 38 39# A map from names to functions which make test cases. 40_MAKE_TEST_FUNCTIONS_MAP = {} 41 42 43# A decorator to register the make test functions. 44# Usage: 45# All the make_*_test should be registered. Example: 46# @register_make_test_function() 47# def make_conv_tests(options): 48# # ... 49# If a function is decorated by other decorators, it's required to specify the 50# name explicitly. Example: 51# @register_make_test_function(name="make_unidirectional_sequence_lstm_tests") 52# @test_util.enable_control_flow_v2 53# def make_unidirectional_sequence_lstm_tests(options): 54# # ... 
def register_make_test_function(name=None):
  """Returns a decorator that registers a make-test function by name.

  Args:
    name: Key to register the function under in `_MAKE_TEST_FUNCTIONS_MAP`.
      When None, the decorated function's `__name__` is used.

  Returns:
    A decorator that stores the function in `_MAKE_TEST_FUNCTIONS_MAP` and
    returns it unchanged.
  """

  def decorate(function, name=name):
    if name is None:
      name = function.__name__
    _MAKE_TEST_FUNCTIONS_MAP[name] = function
    # Hand the function back: a decorator that returns nothing rebinds the
    # decorated module-level name to None.
    return function

  return decorate


def get_test_function(test_function_name):
  """Get the test function according to the test function name.

  Args:
    test_function_name: Name the function was registered under.

  Returns:
    The registered function, or None if nothing is registered for that name.
  """
  return _MAKE_TEST_FUNCTIONS_MAP.get(test_function_name)


# Seed applied to np.random before each example is built (see build_example
# inside make_zip_of_tests).
RANDOM_SEED = 342

# Maps a TensorFlow dtype to a tuple of (numpy type used to generate data for
# it, type-name string).
# NOTE(review): the strings look like converter type names -- confirm against
# the code that consumes them.
TF_TYPE_INFO = {
    tf.float32: (np.float32, "FLOAT"),
    tf.float16: (np.float16, "FLOAT"),
    tf.float64: (np.double, "FLOAT64"),
    tf.int32: (np.int32, "INT32"),
    tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
    tf.int16: (np.int16, "QUANTIZED_INT16"),
    tf.int64: (np.int64, "INT64"),
    # Use the canonical scalar types: the `np.bool` and `np.string_` aliases
    # are not available in every NumPy release.
    tf.bool: (np.bool_, "BOOL"),
    tf.string: (np.bytes_, "STRING"),
}


class ExtraTocoOptions(object):
  """Additional toco options besides input, output, shape."""

  def __init__(self):
    # Whether to ignore control dependency nodes.
    self.drop_control_dependency = False
    # Allow custom ops in the toco conversion.
    self.allow_custom_ops = False
    # Rnn states that are used to support rnn / lstm cells.
    self.rnn_states = None
    # Split the LSTM inputs from 5 inputs to 18 inputs for TFLite.
    self.split_tflite_lstm_inputs = None
    # The inference input type passed to TFLiteConvert.
    self.inference_input_type = None
    # The inference output type passed to TFLiteConvert.
    self.inference_output_type = None


def _as_numpy_dtype(dtype):
  """Returns the numpy type to generate data with for `dtype`.

  Args:
    dtype: A TensorFlow dtype or a numpy type.

  Returns:
    The numpy type listed in `TF_TYPE_INFO` when `dtype` is a key there, the
    `as_numpy_dtype` of any other TensorFlow dtype, and `dtype` itself
    otherwise.
  """
  if dtype in TF_TYPE_INFO:
    return TF_TYPE_INFO[dtype][0]
  if isinstance(dtype, tf.DType):
    # numpy cannot interpret a tf.DType passed to astype()/np.array(), so
    # convert dtypes that are missing from TF_TYPE_INFO (e.g. complex ones).
    return dtype.as_numpy_dtype
  return dtype


def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
  """Build random tensor data of the given dtype and shape.

  Float values (and the real/imaginary parts of complex values) are drawn from
  [min_value, max_value). Integer values are drawn from
  [min_value, max_value], i.e. including max_value, and then cast to the
  target type. Booleans and strings ignore the range.

  Args:
    dtype: TensorFlow dtype or numpy type of the data to generate.
    shape: Shape of the generated data.
    min_value: Lower bound of the generated values.
    max_value: Upper bound of the generated values.

  Returns:
    A numpy array (or numpy scalar) of the requested type.

  Raises:
    TypeError: if data generation is not supported for `dtype`.
  """
  dtype = _as_numpy_dtype(dtype)

  if dtype in (tf.float32, tf.float16, tf.float64):
    value = (max_value - min_value) * np.random.random_sample(shape) + min_value
  elif dtype in (tf.complex64, tf.complex128):
    real = (max_value - min_value) * np.random.random_sample(shape) + min_value
    imag = (max_value - min_value) * np.random.random_sample(shape) + min_value
    value = real + imag * 1j
  elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
    value = np.random.randint(min_value, max_value + 1, shape)
  elif dtype == tf.bool:
    value = np.random.choice([True, False], size=shape)
  elif dtype == np.bytes_:
    # Not the best strings, but they will do for some basic testing.
    letters = list(string.ascii_uppercase)
    return np.random.choice(letters, size=shape).astype(dtype)
  else:
    # TypeError (not ValueError) so that callers which treat ValueError as an
    # expected TensorFlow failure do not silently swallow this.
    raise TypeError("Unsupported dtype for test data generation: %r" % (dtype,))
  return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
      dtype)


def create_scalar_data(dtype, min_value=-100, max_value=100):
  """Build random scalar data of the given dtype.

  Float values are drawn from [min_value, max_value). Integer values are drawn
  from [min_value, max_value], i.e. including max_value. Booleans and strings
  ignore the range; strings are 1 to 5 uppercase ASCII letters.

  Args:
    dtype: TensorFlow dtype or numpy type of the data to generate.
    min_value: Lower bound of the generated value.
    max_value: Upper bound of the generated value.

  Returns:
    A numpy array built from a single scalar value.

  Raises:
    TypeError: if data generation is not supported for `dtype`.
  """
  dtype = _as_numpy_dtype(dtype)

  if dtype in (tf.float32, tf.float16, tf.float64):
    value = (max_value - min_value) * np.random.random() + min_value
  elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
    value = np.random.randint(min_value, max_value + 1)
  elif dtype == tf.bool:
    value = np.random.choice([True, False])
  elif dtype == np.bytes_:
    length = np.random.randint(1, 6)
    value = "".join(
        np.random.choice(list(string.ascii_uppercase), size=length))
  else:
    # See create_tensor_data for why this is a TypeError.
    raise TypeError("Unsupported dtype for test data generation: %r" % (dtype,))
  return np.array(value, dtype=dtype)


def freeze_graph(session, outputs):
  """Freeze the current graph.

  Args:
    session: Tensorflow sessions containing the graph
    outputs: List of output tensors

  Returns:
    The frozen graph_def.
  """
  return tf_graph_util.convert_variables_to_constants(
      session, session.graph.as_graph_def(), [x.op.name for x in outputs])


def format_result(t):
  """Convert a tensor to a format that can be used in test specs.

  Args:
    t: A numpy array.

  Returns:
    For string/object arrays, the hex string produced by
    `_pywrap_string_util.SerializeAsHexString`; otherwise the flattened values
    formatted with 9 decimals and joined by commas.
  """
  if t.dtype.kind not in [np.dtype(np.bytes_).kind, np.dtype(np.object_).kind]:
    # Output 9 digits after the point to ensure the precision is good enough.
    values = ["{:.9f}".format(value) for value in list(t.flatten())]
    return ",".join(values)
  else:
    # SerializeAsHexString returns bytes in PY3, so decode if appropriate.
    return _pywrap_string_util.SerializeAsHexString(t.flatten()).decode("utf-8")


def write_examples(fp, examples):
  """Given a list `examples`, write a text format representation.

  The file format is csv like with a simple repeated pattern. We would like
  to use proto here, but we can't yet due to interfacing with the Android
  team using this format.

  Args:
    fp: File-like object to write to.
    examples: List of example dictionaries, each with keys "inputs" and
      "outputs" holding lists of numpy arrays.
  """

  def write_tensor(fp, x):
    """Write tensor in file format supported by TFLITE example."""
    fp.write("dtype,%s\n" % x.dtype)
    fp.write("shape," + ",".join(map(str, x.shape)) + "\n")
    fp.write("values," + format_result(x) + "\n")

  fp.write("test_cases,%d\n" % len(examples))
  for example in examples:
    fp.write("inputs,%d\n" % len(example["inputs"]))
    for i in example["inputs"]:
      write_tensor(fp, i)
    fp.write("outputs,%d\n" % len(example["outputs"]))
    for i in example["outputs"]:
      write_tensor(fp, i)


def write_test_cases(fp, model_name, examples):
  """Given a list of `examples`, write a text format representation.

  The file format is protocol-buffer-like, even though we don't use proto due
  to the needs of the Android team.

  Args:
    fp: File-like object to write to.
    model_name: Filename where the model was written to, relative to filename.
      Only its basename is written.
    examples: List of example dictionaries, each with keys "inputs" and
      "outputs" holding lists of numpy arrays.
  """

  fp.write("load_model: %s\n" % os.path.basename(model_name))
  for example in examples:
    fp.write("reshape {\n")
    for t in example["inputs"]:
      fp.write("  input: \"" + ",".join(map(str, t.shape)) + "\"\n")
    fp.write("}\n")
    fp.write("invoke {\n")

    for t in example["inputs"]:
      fp.write("  input: \"" + format_result(t) + "\"\n")
    for t in example["outputs"]:
      fp.write("  output: \"" + format_result(t) + "\"\n")
      fp.write("  output_shape: \"" + ",".join([str(dim) for dim in t.shape]) +
               "\"\n")
    fp.write("}\n")


def get_input_shapes_map(input_tensors):
  """Gets a map of input names to shapes.

  Inputs whose shape is falsy, or whose list of dimensions is empty, are left
  out of the result.

  Args:
    input_tensors: List of input tensor tuples `(name, shape, type)`.

  Returns:
    {string : list of integers}.
  """
  input_shapes = {}
  for name, shape, _ in input_tensors:
    dims = [dim.value for dim in shape.dims] if shape else None
    if dims:
      input_shapes[name] = dims
  return input_shapes


def _normalize_output_name(output_name):
  """Remove :0 suffix from tensor names."""
  return output_name.split(":")[0] if output_name.endswith(
      ":0") else output_name


# How many test cases we may have in a zip file. Too many test cases will
# slow down the test data generation process.
_MAX_TESTS_PER_ZIP = 500


def make_zip_of_tests(options,
                      test_parameters,
                      make_graph,
                      make_test_inputs,
                      extra_toco_options=None,
                      use_frozen_graph=False,
                      expected_tf_failures=0):
  """Helper to make a zip file of a bunch of TensorFlow models.

  This does a cartesian product of the dictionary of test_parameters and
  calls make_graph() for each item in the cartesian product set.
  If the graph is built successfully, then make_test_inputs() is called to
  build expected input/output value pairs. The model is then converted to tflite
  with toco, and the examples are serialized with the tflite model into a zip
  file (2 files per item in the cartesian product set).

  Args:
    options: An Options instance.
    test_parameters: List of dictionaries, each mapping parameter names to
      lists of values.
      e.g. `[{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}]`
    make_graph: function that takes current parameters and returns tuple
      `[input1, input2, ...], [output1, output2, ...]`
    make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
      `output_tensors` and returns tuple `(input_values, output_values)`.
    extra_toco_options: Additional toco options. When None, a fresh
      `ExtraTocoOptions()` is created for this call. The object is modified
      here: the inference input/output types are set when
      `options.make_edgetpu_tests` is on, and `split_tflite_lstm_inputs` is
      set from the test parameters when present.
    use_frozen_graph: Whether or not freeze graph before toco converter.
    expected_tf_failures: Number of times tensorflow is expected to fail in
      executing the input graphs. In some cases it is OK for TensorFlow to fail
      because the one or more combination of parameters is invalid.

  Raises:
    RuntimeError: if there are too many parameter combinations, if TensorFlow
      fails more often than expected, or if there are converter errors that
      can't be ignored.
  """
  if extra_toco_options is None:
    # Create the default per call. A default instance in the signature would
    # be shared by every call, and the mutations below would leak from one
    # test generator into the next.
    extra_toco_options = ExtraTocoOptions()

  zip_path = os.path.join(options.output_path, options.zip_to_output)
  parameter_count = 0
  for parameters in test_parameters:
    # The initial value 1 keeps an empty dictionary from raising, and matches
    # the single empty combination itertools.product() yields for it.
    parameter_count += functools.reduce(
        operator.mul, [len(values) for values in parameters.values()], 1)

  all_parameter_count = parameter_count
  if options.multi_gen_state:
    all_parameter_count += options.multi_gen_state.parameter_count
  if not options.no_tests_limit and all_parameter_count > _MAX_TESTS_PER_ZIP:
    raise RuntimeError(
        "Too many parameter combinations for generating '%s'.\n"
        "There are at least %d combinations while the upper limit is %d.\n"
        "Having too many combinations will slow down the tests.\n"
        "Please consider splitting the test into multiple functions.\n" %
        (zip_path, all_parameter_count, _MAX_TESTS_PER_ZIP))
  if options.multi_gen_state:
    options.multi_gen_state.parameter_count = all_parameter_count

  if options.make_edgetpu_tests:
    extra_toco_options.inference_input_type = tf.uint8
    extra_toco_options.inference_output_type = tf.uint8
    # Only count parameters when fully_quantize is True.
    parameter_count = 0
    for parameters in test_parameters:
      if True in parameters.get("fully_quantize",
                                []) and False in parameters.get(
                                    "quant_16x8", [False]):
        parameter_count += functools.reduce(operator.mul, [
            len(values)
            for key, values in parameters.items()
            if key != "fully_quantize" and key != "quant_16x8"
        ], 1)

  label_base_path = zip_path
  if options.multi_gen_state:
    label_base_path = options.multi_gen_state.label_base_path

  # TODO(aselle): Make this allow multiple inputs outputs.
  # The archive is only ours to close when we open it here; a multi-gen
  # archive belongs to options.multi_gen_state.
  owns_archive = not options.multi_gen_state
  if owns_archive:
    archive = zipfile.PyZipFile(zip_path, "w")
  else:
    archive = options.multi_gen_state.archive
  zip_manifest = []
  convert_report = []
  toco_errors = 0

  processed_labels = set()

  def generate_inputs_outputs(tflite_model_binary, min_value=0, max_value=255):
    """Generate input values and output values of the given tflite model.

    Args:
      tflite_model_binary: A serialized flatbuffer as a string.
      min_value: min value for the input tensor.
      max_value: max value for the input tensor.

    Returns:
      (input_values, output_values): input values and output values built.
    """
    interpreter = tf.lite.Interpreter(model_content=tflite_model_binary)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    input_values = []
    for input_detail in input_details:
      input_value = create_tensor_data(
          input_detail["dtype"],
          input_detail["shape"],
          min_value=min_value,
          max_value=max_value)
      interpreter.set_tensor(input_detail["index"], input_value)
      input_values.append(input_value)

    interpreter.invoke()

    output_details = interpreter.get_output_details()
    output_values = []
    for output_detail in output_details:
      output_values.append(interpreter.get_tensor(output_detail["index"]))

    return input_values, output_values

  def build_example(label, param_dict_real, zip_path_label):
    """Build the model with parameter values set in param_dict_real.

    Args:
      label: Label of the model
      param_dict_real: Parameter dictionary (arguments to the factories
        make_graph and make_test_inputs)
      zip_path_label: Filename in the zip

    Returns:
      (tflite_model_binary, report) where tflite_model_binary is the
      serialized flatbuffer as a string (None when TensorFlow failed before
      conversion) and report is a dictionary with keys `converter_log` (log
      of the conversion), `tf_log` (tracebacks of TensorFlow failures),
      `converter` (status of the conversion) and `tf` (status of building
      and running the TensorFlow graph).
    """

    np.random.seed(RANDOM_SEED)
    report = {"converter": report_lib.NOTRUN, "tf": report_lib.FAILED}

    # Build graph
    report["tf_log"] = ""
    report["converter_log"] = ""
    tf.reset_default_graph()

    with tf.Graph().as_default():
      with tf.device("/cpu:0"):
        try:
          inputs, outputs = make_graph(param_dict_real)
        except (tf.errors.UnimplementedError,
                tf.errors.InvalidArgumentError, ValueError):
          report["tf_log"] += traceback.format_exc()
          return None, report

      sess = tf.Session()
      try:
        try:
          baseline_inputs, baseline_outputs = (
              make_test_inputs(param_dict_real, sess, inputs, outputs))
        except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
                ValueError):
          report["tf_log"] += traceback.format_exc()
          return None, report
        report["converter"] = report_lib.FAILED
        report["tf"] = report_lib.SUCCESS
        # Convert graph to toco
        input_tensors = [(input_tensor.name.split(":")[0], input_tensor.shape,
                          input_tensor.dtype) for input_tensor in inputs]
        output_tensors = [_normalize_output_name(out.name) for out in outputs]
        # pylint: disable=g-long-ternary
        graph_def = freeze_graph(
            sess,
            tf.global_variables() + inputs +
            outputs) if use_frozen_graph else sess.graph_def
      finally:
        # One session is created per example; release it once the graph_def
        # has been extracted instead of leaving it to garbage collection.
        sess.close()

    if "split_tflite_lstm_inputs" in param_dict_real:
      extra_toco_options.split_tflite_lstm_inputs = param_dict_real[
          "split_tflite_lstm_inputs"]
    tflite_model_binary, toco_log = options.tflite_convert_function(
        options,
        graph_def,
        input_tensors,
        output_tensors,
        extra_toco_options=extra_toco_options,
        test_params=param_dict_real)
    report["converter"] = (
        report_lib.SUCCESS
        if tflite_model_binary is not None else report_lib.FAILED)
    report["converter_log"] = toco_log

    if options.save_graphdefs:
      archive.writestr(zip_path_label + ".pbtxt",
                       text_format.MessageToString(graph_def),
                       zipfile.ZIP_DEFLATED)

    if tflite_model_binary:
      if options.make_edgetpu_tests:
        # Set proper min max values according to input dtype.
        baseline_inputs, baseline_outputs = generate_inputs_outputs(
            tflite_model_binary, min_value=0, max_value=255)
      archive.writestr(zip_path_label + ".bin", tflite_model_binary,
                       zipfile.ZIP_DEFLATED)
      example = {"inputs": baseline_inputs, "outputs": baseline_outputs}

      example_fp = StringIO()
      write_examples(example_fp, [example])
      archive.writestr(zip_path_label + ".inputs", example_fp.getvalue(),
                       zipfile.ZIP_DEFLATED)

      example_fp2 = StringIO()
      write_test_cases(example_fp2, zip_path_label + ".bin", [example])
      archive.writestr(zip_path_label + "_tests.txt",
                       example_fp2.getvalue(), zipfile.ZIP_DEFLATED)

      zip_manifest_label = zip_path_label + " " + label
      if zip_path_label == label:
        zip_manifest_label = zip_path_label

      zip_manifest.append(zip_manifest_label + "\n")

    return tflite_model_binary, report

  try:
    i = 1
    for parameters in test_parameters:
      keys = parameters.keys()
      for curr in itertools.product(*parameters.values()):
        label = label_base_path.replace(".zip", "_") + (",".join(
            "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
        if label[0] == "/":
          label = label[1:]

        # Fall back to a numbered name when the label is too long to be used
        # as a file name inside the zip.
        zip_path_label = label
        if len(os.path.basename(zip_path_label)) > 245:
          zip_path_label = label_base_path.replace(".zip", "_") + str(i)

        i += 1
        if label in processed_labels:
          # Do not populate data for the same label more than once. It will
          # cause errors when unzipping.
          continue
        processed_labels.add(label)

        param_dict = dict(zip(keys, curr))

        if options.make_edgetpu_tests and (not param_dict.get(
            "fully_quantize", False) or param_dict.get("quant_16x8", False)):
          continue

        _, report = build_example(label, param_dict, zip_path_label)

        if report["converter"] == report_lib.FAILED:
          ignore_error = False
          if not options.known_bugs_are_errors:
            for pattern, bug_number in options.known_bugs.items():
              if re.search(pattern, label):
                print("Ignored converter error due to bug %s" % bug_number)
                ignore_error = True
          if not ignore_error:
            toco_errors += 1
            print("-----------------\nconverter error!\n%s\n-----------------\n" %
                  report["converter_log"])

        convert_report.append((param_dict, report))

    if not options.no_conversion_report:
      report_io = StringIO()
      report_lib.make_report_table(report_io, zip_path, convert_report)
      if options.multi_gen_state:
        archive.writestr(
            "report_" + options.multi_gen_state.test_name + ".html",
            report_io.getvalue())
      else:
        archive.writestr("report.html", report_io.getvalue())

    if options.multi_gen_state:
      options.multi_gen_state.zip_manifest.extend(zip_manifest)
    else:
      archive.writestr("manifest.txt", "".join(zip_manifest),
                       zipfile.ZIP_DEFLATED)
  finally:
    if owns_archive:
      # Finalize the zip before any of the checks below can raise.
      archive.close()

  # Log statistics of what succeeded
  total_conversions = len(convert_report)
  tf_success = sum(
      1 for x in convert_report if x[1]["tf"] == report_lib.SUCCESS)
  toco_success = sum(
      1 for x in convert_report if x[1]["converter"] == report_lib.SUCCESS)
  percent = 0
  if tf_success > 0:
    percent = float(toco_success) / float(tf_success) * 100.
  tf.logging.info(("Archive %s Considered %d graphs, %d TF evaluated graphs "
                   " and %d TOCO converted graphs (%.1f%%)"), zip_path,
                  total_conversions, tf_success, toco_success, percent)

  tf_failures = parameter_count - tf_success

  # parameter_count can be 0 (e.g. edgetpu tests with no fully-quantized
  # parameter set); there is no failure ratio to check in that case.
  if parameter_count and tf_failures / parameter_count > 0.8:
    raise RuntimeError(("Test for '%s' is not very useful. "
                        "TensorFlow fails in %d percent of the cases.") %
                       (zip_path, int(100 * tf_failures / parameter_count)))

  if not options.make_edgetpu_tests and tf_failures != expected_tf_failures:
    raise RuntimeError(("Expected TF to fail %d times while generating '%s', "
                        "but that happened %d times") %
                       (expected_tf_failures, zip_path, tf_failures))

  if not options.ignore_converter_errors and toco_errors > 0:
    raise RuntimeError("Found %d errors while generating toco models" %
                       toco_errors)