# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utils for make_zip tests."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21import itertools
22import operator
23import os
24import re
25import string
26import traceback
27import zipfile
28
29import numpy as np
30from six import StringIO
31
32# pylint: disable=g-import-not-at-top
33import tensorflow.compat.v1 as tf
34from google.protobuf import text_format
35from tensorflow.lite.testing import _pywrap_string_util
36from tensorflow.lite.testing import generate_examples_report as report_lib
37from tensorflow.python.framework import graph_util as tf_graph_util
38
39# A map from names to functions which make test cases.
40_MAKE_TEST_FUNCTIONS_MAP = {}
41
42
# A decorator to register the make test functions.
# Usage:
# Every make_*_tests function should be registered. Example:
#   @register_make_test_function()
#   def make_conv_tests(options):
#     # ...
# If a function is wrapped by other decorators, the name must be specified
# explicitly. Example:
#   @register_make_test_function(name="make_unidirectional_sequence_lstm_tests")
#   @test_util.enable_control_flow_v2
#   def make_unidirectional_sequence_lstm_tests(options):
#     # ...
def register_make_test_function(name=None):

  def decorate(function, name=name):
    if name is None:
      name = function.__name__
    _MAKE_TEST_FUNCTIONS_MAP[name] = function
    # Return the function so the decorated name still refers to it rather than
    # silently becoming None.
    return function

  return decorate


def get_test_function(test_function_name):
  """Get the test function according to the test function name."""

  if test_function_name not in _MAKE_TEST_FUNCTIONS_MAP:
    return None
  return _MAKE_TEST_FUNCTIONS_MAP[test_function_name]

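# Illustrative lookup of a registered factory (the name "make_conv_tests" is
# just an example; real names come from the make_*_tests modules that import
# this file and register themselves at import time):
#   make_conv = get_test_function("make_conv_tests")
#   if make_conv is not None:
#     make_conv(options)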

RANDOM_SEED = 342

TF_TYPE_INFO = {
    tf.float32: (np.float32, "FLOAT"),
    tf.float16: (np.float16, "FLOAT"),
    tf.float64: (np.double, "FLOAT64"),
    tf.int32: (np.int32, "INT32"),
    tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
    tf.int16: (np.int16, "QUANTIZED_INT16"),
    tf.int64: (np.int64, "INT64"),
    tf.bool: (np.bool_, "BOOL"),
    tf.string: (np.string_, "STRING"),
}


class ExtraTocoOptions(object):
  """Additional toco options besides input, output, shape."""

  def __init__(self):
    # Whether to ignore control dependency nodes.
    self.drop_control_dependency = False
    # Allow custom ops in the toco conversion.
    self.allow_custom_ops = False
    # RNN states that are used to support RNN/LSTM cells.
    self.rnn_states = None
    # Split the LSTM inputs from 5 inputs to 18 inputs for TFLite.
    self.split_tflite_lstm_inputs = None
    # The inference input type passed to TFLiteConvert.
    self.inference_input_type = None
    # The inference output type passed to TFLiteConvert.
    self.inference_output_type = None

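# Illustrative use of ExtraTocoOptions (the field values below are arbitrary
# examples, not defaults):
#   toco_options = ExtraTocoOptions()
#   toco_options.allow_custom_ops = True
#   toco_options.inference_input_type = tf.uint8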

def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
  """Build tensor data spreading the range [min_value, max_value).

  Integer dtypes draw from [min_value, max_value] inclusive; bool and string
  dtypes ignore the bounds.
  """

  if dtype in TF_TYPE_INFO:
    dtype = TF_TYPE_INFO[dtype][0]

  if dtype in (tf.float32, tf.float16, tf.float64):
    value = (max_value - min_value) * np.random.random_sample(shape) + min_value
  elif dtype in (tf.complex64, tf.complex128):
    real = (max_value - min_value) * np.random.random_sample(shape) + min_value
    imag = (max_value - min_value) * np.random.random_sample(shape) + min_value
    value = real + imag * 1j
  elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
    value = np.random.randint(min_value, max_value + 1, shape)
  elif dtype == tf.bool:
    value = np.random.choice([True, False], size=shape)
  elif dtype == np.string_:
    # Not the best strings, but they will do for some basic testing.
    letters = list(string.ascii_uppercase)
    return np.random.choice(letters, size=shape).astype(dtype)
  return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
      dtype)

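# Example sketch for create_tensor_data (dtype, shape, and bounds are
# illustrative only):
#   floats = create_tensor_data(tf.float32, [2, 3], min_value=0, max_value=1)
#   ints = create_tensor_data(tf.int32, [4], min_value=-5, max_value=5)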

def create_scalar_data(dtype, min_value=-100, max_value=100):
  """Build a scalar of the given dtype between min_value and max_value."""

  if dtype in TF_TYPE_INFO:
    dtype = TF_TYPE_INFO[dtype][0]

  if dtype in (tf.float32, tf.float16, tf.float64):
    value = (max_value - min_value) * np.random.random() + min_value
  elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
    value = np.random.randint(min_value, max_value + 1)
  elif dtype == tf.bool:
    value = np.random.choice([True, False])
  elif dtype == np.string_:
    length = np.random.randint(1, 6)
    value = "".join(
        np.random.choice(list(string.ascii_uppercase), size=length))
  return np.array(value, dtype=dtype)


def freeze_graph(session, outputs):
  """Freeze the current graph.

  Args:
    session: TensorFlow session containing the graph.
    outputs: List of output tensors.

  Returns:
    The frozen graph_def.
  """
  return tf_graph_util.convert_variables_to_constants(
      session, session.graph.as_graph_def(), [x.op.name for x in outputs])

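# A minimal usage sketch for freeze_graph, assuming a TF1-style session and an
# `output` tensor built elsewhere:
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     frozen_graph_def = freeze_graph(sess, [output])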

def format_result(t):
  """Convert a tensor to a format that can be used in test specs."""
  if t.dtype.kind not in [np.dtype(np.string_).kind, np.dtype(np.object_).kind]:
    # Output 9 digits after the point to ensure the precision is good enough.
    values = ["{:.9f}".format(value) for value in list(t.flatten())]
    return ",".join(values)
  else:
    # SerializeAsHexString returns bytes in PY3, so decode if appropriate.
    return _pywrap_string_util.SerializeAsHexString(t.flatten()).decode("utf-8")


def write_examples(fp, examples):
  """Given a list `examples`, write a text format representation.

  The file format is CSV-like with a simple repeated pattern. We would like
  to use proto here, but we can't yet due to interfacing with the Android
  team using this format.

  Args:
    fp: File-like object to write to.
    examples: List of example dicts with keys "inputs" and "outputs".
  """

  def write_tensor(fp, x):
    """Write tensor in file format supported by TFLITE example."""
    fp.write("dtype,%s\n" % x.dtype)
    fp.write("shape," + ",".join(map(str, x.shape)) + "\n")
    fp.write("values," + format_result(x) + "\n")

  fp.write("test_cases,%d\n" % len(examples))
  for example in examples:
    fp.write("inputs,%d\n" % len(example["inputs"]))
    for i in example["inputs"]:
      write_tensor(fp, i)
    fp.write("outputs,%d\n" % len(example["outputs"]))
    for i in example["outputs"]:
      write_tensor(fp, i)

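# For a single example with one input and one output, write_examples emits
# text shaped roughly like this (the dtype/shape/values rows repeat per
# tensor; the concrete values here are illustrative):
#   test_cases,1
#   inputs,1
#   dtype,float32
#   shape,1,2
#   values,0.100000000,0.200000000
#   outputs,1
#   dtype,float32
#   shape,1
#   values,0.300000000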

def write_test_cases(fp, model_name, examples):
  """Given a list of `examples`, write a text format representation.

  The file format is protocol-buffer-like, even though we don't use proto due
  to the needs of the Android team.

  Args:
    fp: File-like object to write to.
    model_name: Filename of the model; only its basename is recorded.
    examples: List of example dicts with keys "inputs" and "outputs".
  """

  fp.write("load_model: %s\n" % os.path.basename(model_name))
  for example in examples:
    fp.write("reshape {\n")
    for t in example["inputs"]:
      fp.write("  input: \"" + ",".join(map(str, t.shape)) + "\"\n")
    fp.write("}\n")
    fp.write("invoke {\n")

    for t in example["inputs"]:
      fp.write("  input: \"" + format_result(t) + "\"\n")
    for t in example["outputs"]:
      fp.write("  output: \"" + format_result(t) + "\"\n")
      fp.write("  output_shape: \"" + ",".join([str(dim) for dim in t.shape]) +
               "\"\n")
    fp.write("}\n")

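# For one example, write_test_cases emits text shaped roughly like this
# (tensor payloads elided; "foo.bin" stands in for the model filename):
#   load_model: foo.bin
#   reshape {
#     input: "1,8,8,3"
#   }
#   invoke {
#     input: "..."
#     output: "..."
#     output_shape: "1,4"
#   }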

def get_input_shapes_map(input_tensors):
  """Gets a map of input names to shapes.

  Args:
    input_tensors: List of input tensor tuples `(name, shape, type)`.

  Returns:
    {string : list of integers}.
  """
  input_arrays = [tensor[0] for tensor in input_tensors]
  input_shapes_list = []

  for _, shape, _ in input_tensors:
    dims = None
    if shape:
      dims = [dim.value for dim in shape.dims]
    input_shapes_list.append(dims)

  input_shapes = {
      name: shape
      for name, shape in zip(input_arrays, input_shapes_list)
      if shape
  }
  return input_shapes

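# Illustrative input/output for get_input_shapes_map (the tensor name and dims
# are made up; real call sites pass TF1 TensorShape objects):
#   get_input_shapes_map([("input_a", tf.TensorShape([1, 224, 224, 3]),
#                          tf.float32)])
#   # -> {"input_a": [1, 224, 224, 3]}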

def _normalize_output_name(output_name):
  """Remove :0 suffix from tensor names."""
  return output_name.split(":")[0] if output_name.endswith(
      ":0") else output_name


# The maximum number of test cases we allow in a zip file. Too many test cases
# will slow down the test data generation process.
_MAX_TESTS_PER_ZIP = 500


def make_zip_of_tests(options,
                      test_parameters,
                      make_graph,
                      make_test_inputs,
                      extra_toco_options=ExtraTocoOptions(),
                      use_frozen_graph=False,
                      expected_tf_failures=0):
  """Helper to make a zip file of a bunch of TensorFlow models.

  This takes the Cartesian product of the lists in each test_parameters
  dictionary and calls make_graph() for each resulting parameter combination.
  If the graph is built successfully, make_test_inputs() is called to build
  expected input/output value pairs. The model is then converted to TFLite
  with toco, and the examples are serialized with the TFLite model into a zip
  file (2 files per parameter combination). See the usage sketch below.

  Args:
    options: An Options instance.
    test_parameters: Dictionary mapping parameter names to lists of values,
      e.g. `{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}`.
    make_graph: function that takes the current parameters and returns a tuple
      `[input1, input2, ...], [output1, output2, ...]`.
    make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
      `output_tensors` and returning a tuple `(input_values, output_values)`.
    extra_toco_options: Additional toco options.
    use_frozen_graph: Whether or not to freeze the graph before the toco
      converter.
    expected_tf_failures: Number of times TensorFlow is expected to fail while
      executing the input graphs. In some cases it is OK for TensorFlow to
      fail because one or more combinations of parameters is invalid.

  Raises:
    RuntimeError: if there are converter errors that can't be ignored.
  """
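  # A minimal usage sketch from a registered test factory (all names and
  # parameter values below are hypothetical, for illustration only):
  #   test_parameters = [{"input_shape": [[1, 3, 4, 3]]}]
  #   def build_graph(parameters):
  #     inp = tf.placeholder(tf.float32, shape=parameters["input_shape"])
  #     return [inp], [tf.nn.relu(inp)]
  #   def build_inputs(parameters, sess, inputs, outputs):
  #     values = [create_tensor_data(np.float32, parameters["input_shape"])]
  #     return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
  #   make_zip_of_tests(options, test_parameters, build_graph, build_inputs)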
  zip_path = os.path.join(options.output_path, options.zip_to_output)
  parameter_count = 0
  for parameters in test_parameters:
    parameter_count += functools.reduce(
        operator.mul, [len(values) for values in parameters.values()])

  all_parameter_count = parameter_count
  if options.multi_gen_state:
    all_parameter_count += options.multi_gen_state.parameter_count
  if not options.no_tests_limit and all_parameter_count > _MAX_TESTS_PER_ZIP:
    raise RuntimeError(
        "Too many parameter combinations for generating '%s'.\n"
        "There are at least %d combinations while the upper limit is %d.\n"
        "Having too many combinations will slow down the tests.\n"
        "Please consider splitting the test into multiple functions.\n" %
        (zip_path, all_parameter_count, _MAX_TESTS_PER_ZIP))
  if options.multi_gen_state:
    options.multi_gen_state.parameter_count = all_parameter_count

  # TODO(aselle): Make this allow multiple inputs outputs.
  if options.multi_gen_state:
    archive = options.multi_gen_state.archive
  else:
    archive = zipfile.PyZipFile(zip_path, "w")
  zip_manifest = []
  convert_report = []
  toco_errors = 0

  processed_labels = set()

  if options.make_edgetpu_tests:
    extra_toco_options.inference_input_type = tf.uint8
    extra_toco_options.inference_output_type = tf.uint8
    # Only count parameters when fully_quantize is True.
    parameter_count = 0
    for parameters in test_parameters:
      if True in parameters.get("fully_quantize",
                                []) and False in parameters.get(
                                    "quant_16x8", [False]):
        parameter_count += functools.reduce(operator.mul, [
            len(values)
            for key, values in parameters.items()
            if key != "fully_quantize" and key != "quant_16x8"
        ])

  label_base_path = zip_path
  if options.multi_gen_state:
    label_base_path = options.multi_gen_state.label_base_path

  i = 1
  for parameters in test_parameters:
    keys = parameters.keys()
    for curr in itertools.product(*parameters.values()):
      label = label_base_path.replace(".zip", "_") + (",".join(
          "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
      if label[0] == "/":
        label = label[1:]

      zip_path_label = label
      if len(os.path.basename(zip_path_label)) > 245:
        zip_path_label = label_base_path.replace(".zip", "_") + str(i)

      i += 1
      if label in processed_labels:
        # Do not populate data for the same label more than once. It will cause
        # errors when unzipping.
        continue
      processed_labels.add(label)

      param_dict = dict(zip(keys, curr))

      if options.make_edgetpu_tests and (not param_dict.get(
          "fully_quantize", False) or param_dict.get("quant_16x8", False)):
        continue

      def generate_inputs_outputs(tflite_model_binary,
                                  min_value=0,
                                  max_value=255):
        """Generate input values and output values of the given tflite model.

        Args:
          tflite_model_binary: A serialized flatbuffer as a string.
          min_value: min value for the input tensor.
          max_value: max value for the input tensor.

        Returns:
          (input_values, output_values): input values and output values built.
        """
        interpreter = tf.lite.Interpreter(model_content=tflite_model_binary)
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        input_values = []
        for input_detail in input_details:
          input_value = create_tensor_data(
              input_detail["dtype"],
              input_detail["shape"],
              min_value=min_value,
              max_value=max_value)
          interpreter.set_tensor(input_detail["index"], input_value)
          input_values.append(input_value)

        interpreter.invoke()

        output_details = interpreter.get_output_details()
        output_values = []
        for output_detail in output_details:
          output_values.append(interpreter.get_tensor(output_detail["index"]))

        return input_values, output_values

      def build_example(label, param_dict_real, zip_path_label):
        """Build the model with parameter values set in param_dict_real.

        Args:
          label: Label of the model.
          param_dict_real: Parameter dictionary (arguments to the factories
            make_graph and make_test_inputs).
          zip_path_label: Filename in the zip.

        Returns:
          (tflite_model_binary, report) where tflite_model_binary is the
          serialized flatbuffer as a string and report is a dictionary with
          keys `converter_log` (log of the conversion), `tf_log` (log of the
          TF build/run), `converter` (a string with the success status of the
          conversion) and `tf` (a string with the success status of the TF
          build/run).
        """

        np.random.seed(RANDOM_SEED)
        report = {"converter": report_lib.NOTRUN, "tf": report_lib.FAILED}

        # Build graph
        report["tf_log"] = ""
        report["converter_log"] = ""
        tf.reset_default_graph()

        with tf.Graph().as_default():
          with tf.device("/cpu:0"):
            try:
              inputs, outputs = make_graph(param_dict_real)
            except (tf.errors.UnimplementedError,
                    tf.errors.InvalidArgumentError, ValueError):
              report["tf_log"] += traceback.format_exc()
              return None, report

          sess = tf.Session()
          try:
            baseline_inputs, baseline_outputs = (
                make_test_inputs(param_dict_real, sess, inputs, outputs))
          except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
                  ValueError):
            report["tf_log"] += traceback.format_exc()
            return None, report
          report["converter"] = report_lib.FAILED
          report["tf"] = report_lib.SUCCESS
          # Convert graph to toco
          input_tensors = [(input_tensor.name.split(":")[0], input_tensor.shape,
                            input_tensor.dtype) for input_tensor in inputs]
          output_tensors = [_normalize_output_name(out.name) for out in outputs]
          # pylint: disable=g-long-ternary
          graph_def = freeze_graph(
              sess,
              tf.global_variables() + inputs +
              outputs) if use_frozen_graph else sess.graph_def

        if "split_tflite_lstm_inputs" in param_dict_real:
          extra_toco_options.split_tflite_lstm_inputs = param_dict_real[
              "split_tflite_lstm_inputs"]
        tflite_model_binary, toco_log = options.tflite_convert_function(
            options,
            graph_def,
            input_tensors,
            output_tensors,
            extra_toco_options=extra_toco_options,
            test_params=param_dict_real)
        report["converter"] = (
            report_lib.SUCCESS
            if tflite_model_binary is not None else report_lib.FAILED)
        report["converter_log"] = toco_log

        if options.save_graphdefs:
          archive.writestr(zip_path_label + ".pbtxt",
                           text_format.MessageToString(graph_def),
                           zipfile.ZIP_DEFLATED)

        if tflite_model_binary:
          if options.make_edgetpu_tests:
            # Set proper min max values according to input dtype.
            baseline_inputs, baseline_outputs = generate_inputs_outputs(
                tflite_model_binary, min_value=0, max_value=255)
          archive.writestr(zip_path_label + ".bin", tflite_model_binary,
                           zipfile.ZIP_DEFLATED)
          example = {"inputs": baseline_inputs, "outputs": baseline_outputs}

          example_fp = StringIO()
          write_examples(example_fp, [example])
          archive.writestr(zip_path_label + ".inputs", example_fp.getvalue(),
                           zipfile.ZIP_DEFLATED)

          example_fp2 = StringIO()
          write_test_cases(example_fp2, zip_path_label + ".bin", [example])
          archive.writestr(zip_path_label + "_tests.txt",
                           example_fp2.getvalue(), zipfile.ZIP_DEFLATED)

          zip_manifest_label = zip_path_label + " " + label
          if zip_path_label == label:
            zip_manifest_label = zip_path_label

          zip_manifest.append(zip_manifest_label + "\n")

        return tflite_model_binary, report

      _, report = build_example(label, param_dict, zip_path_label)

      if report["converter"] == report_lib.FAILED:
        ignore_error = False
        if not options.known_bugs_are_errors:
          for pattern, bug_number in options.known_bugs.items():
            if re.search(pattern, label):
              print("Ignored converter error due to bug %s" % bug_number)
              ignore_error = True
        if not ignore_error:
          toco_errors += 1
          print("-----------------\nconverter error!\n%s\n-----------------\n" %
                report["converter_log"])

      convert_report.append((param_dict, report))

  if not options.no_conversion_report:
    report_io = StringIO()
    report_lib.make_report_table(report_io, zip_path, convert_report)
    if options.multi_gen_state:
      archive.writestr("report_" + options.multi_gen_state.test_name + ".html",
                       report_io.getvalue())
    else:
      archive.writestr("report.html", report_io.getvalue())

  if options.multi_gen_state:
    options.multi_gen_state.zip_manifest.extend(zip_manifest)
  else:
    archive.writestr("manifest.txt", "".join(zip_manifest),
                     zipfile.ZIP_DEFLATED)

  # Log statistics of what succeeded.
  total_conversions = len(convert_report)
  tf_success = sum(
      1 for x in convert_report if x[1]["tf"] == report_lib.SUCCESS)
  toco_success = sum(
      1 for x in convert_report if x[1]["converter"] == report_lib.SUCCESS)
  percent = 0
  if tf_success > 0:
    percent = float(toco_success) / float(tf_success) * 100.
  tf.logging.info(("Archive %s: considered %d graphs, %d TF evaluated graphs, "
                   "and %d TOCO converted graphs (%.1f%%)."), zip_path,
                  total_conversions, tf_success, toco_success, percent)

  tf_failures = parameter_count - tf_success

  if tf_failures / parameter_count > 0.8:
    raise RuntimeError(("Test for '%s' is not very useful. "
                        "TensorFlow fails in %d percent of the cases.") %
                       (zip_path, int(100 * tf_failures / parameter_count)))

  if not options.make_edgetpu_tests and tf_failures != expected_tf_failures:
    raise RuntimeError(("Expected TF to fail %d times while generating '%s', "
                        "but that happened %d times") %
                       (expected_tf_failures, zip_path, tf_failures))

  if not options.ignore_converter_errors and toco_errors > 0:
    raise RuntimeError("Found %d errors while generating toco models" %
                       toco_errors)