• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utils for make_zip tests."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21import itertools
22import operator
23import os
24import re
25import string
26import traceback
27import zipfile
28
29import numpy as np
30from six import StringIO
31
32# pylint: disable=g-import-not-at-top
33import tensorflow as tf
34from google.protobuf import text_format
35from tensorflow.lite.testing import _pywrap_string_util
36from tensorflow.lite.testing import generate_examples_report as report_lib
37from tensorflow.python.framework import graph_util as tf_graph_util
38
# A map from names to functions which make test cases. It is filled in by the
# `register_make_test_function` decorator below.
_MAKE_TEST_FUNCTIONS_MAP = {}


# A decorator to register the make test functions.
# Usage:
# All the make_*_test should be registered. Example:
#   @register_make_test_function()
#   def make_conv_tests(options):
#     # ...
# If a function is decorated by other decorators, it's required to specify the
# name explicitly. Example:
#   @register_make_test_function(name="make_unidirectional_sequence_lstm_tests")
#   @test_util.enable_control_flow_v2
#   def make_unidirectional_sequence_lstm_tests(options):
#     # ...
def register_make_test_function(name=None):
  """Returns a decorator that records a function in the make-test registry.

  Args:
    name: Key to register the function under. When None, the function's own
      `__name__` is used.

  Returns:
    A decorator that stores the function in `_MAKE_TEST_FUNCTIONS_MAP` and
    returns that same function unchanged.
  """

  def decorate(function, name=name):
    if name is None:
      name = function.__name__
    _MAKE_TEST_FUNCTIONS_MAP[name] = function
    # Hand the function back: without this the decorated module-level name
    # would be rebound to None.
    return function

  return decorate
63
64
def get_test_function(test_function_name):
  """Look up a registered make-test function by name.

  Args:
    test_function_name: Name the function was registered under.

  Returns:
    The registered function, or None when nothing is registered under that
    name.
  """
  return _MAKE_TEST_FUNCTIONS_MAP.get(test_function_name)
71
72
# Seed passed to np.random.seed() before each example is built, so that the
# generated tensor data is reproducible.
RANDOM_SEED = 342

# Maps a TensorFlow dtype to a (numpy type, type-name string) pair.
# NOTE(review): the type-name strings are not read anywhere in this file;
# presumably they are converter type names -- confirm against the callers.
# np.bool_ / np.bytes_ are used instead of the np.bool / np.string_ aliases,
# which newer NumPy releases removed; on older releases they are the same types.
TF_TYPE_INFO = {
    tf.float32: (np.float32, "FLOAT"),
    tf.float16: (np.float16, "FLOAT"),
    tf.int32: (np.int32, "INT32"),
    tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
    tf.int16: (np.int16, "QUANTIZED_INT16"),
    tf.int64: (np.int64, "INT64"),
    tf.bool: (np.bool_, "BOOL"),
    tf.string: (np.bytes_, "STRING"),
}
85
86
class ExtraTocoOptions(object):
  """Additional toco options besides input, output, shape.

  Attributes:
    drop_control_dependency: Whether to ignore control dependency nodes.
    allow_custom_ops: Allow custom ops in the toco conversion.
    rnn_states: Rnn states that are used to support rnn / lstm cells.
    split_tflite_lstm_inputs: Split the LSTM inputs from 5 inputs to 18 inputs
      for TFLite.
    inference_input_type: The inference input type passed to TFLiteConvert.
    inference_output_type: The inference output type passed to TFLiteConvert.
  """

  def __init__(self):
    # Boolean switches, off by default.
    self.drop_control_dependency = False
    self.allow_custom_ops = False

    # Settings that stay unset (None) until a caller fills them in.
    self.rnn_states = None
    self.split_tflite_lstm_inputs = None
    self.inference_input_type = None
    self.inference_output_type = None
103
104
def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
  """Build random tensor data of the given dtype and shape.

  Floating point values are drawn from [min_value, max_value). Integer values
  are drawn from [min_value, max_value], i.e. max_value is included. Boolean
  and string data ignore the range.

  Args:
    dtype: A TensorFlow dtype listed in TF_TYPE_INFO (it is translated to its
      numpy type first) or a type that compares equal to one of the supported
      TensorFlow dtypes.
    shape: Shape of the data to generate.
    min_value: Lower bound of the generated numeric values.
    max_value: Upper bound of the generated numeric values.

  Returns:
    The generated data converted to `dtype`.

  Raises:
    RuntimeError: if `dtype` is not one of the supported types.
  """

  if dtype in TF_TYPE_INFO:
    dtype = TF_TYPE_INFO[dtype][0]

  if dtype in (tf.float32, tf.float16):
    value = (max_value - min_value) * np.random.random_sample(shape) + min_value
  elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
    # randint's upper bound is exclusive, hence the + 1.
    value = np.random.randint(min_value, max_value + 1, shape)
  elif dtype == tf.bool:
    value = np.random.choice([True, False], size=shape)
  elif dtype == np.bytes_:
    # Not the best strings, but they will do for some basic testing.
    letters = list(string.ascii_uppercase)
    return np.random.choice(letters, size=shape).astype(dtype)
  else:
    # Previously this fell through and failed with an UnboundLocalError on
    # `value`. RuntimeError is used rather than ValueError so that callers
    # which treat ValueError as an expected TensorFlow failure do not swallow
    # what is really a programming error.
    raise RuntimeError("Unsupported dtype %r for create_tensor_data." %
                       (dtype,))
  return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
      dtype)
123
124
def create_scalar_data(dtype, min_value=-100, max_value=100):
  """Build a random scalar of the given dtype.

  Floating point values are drawn from [min_value, max_value). Integer values
  are drawn from [min_value, max_value], i.e. max_value is included.

  Args:
    dtype: A TensorFlow dtype listed in TF_TYPE_INFO (it is translated to its
      numpy type first) or a type that compares equal to one of the supported
      TensorFlow float / integer dtypes.
    min_value: Lower bound of the generated value.
    max_value: Upper bound of the generated value.

  Returns:
    A numpy array built from the single generated value, with dtype `dtype`.

  Raises:
    RuntimeError: if `dtype` is not a supported float or integer type.
  """

  if dtype in TF_TYPE_INFO:
    dtype = TF_TYPE_INFO[dtype][0]

  if dtype in (tf.float32, tf.float16):
    value = (max_value - min_value) * np.random.random() + min_value
  elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
    # randint's upper bound is exclusive, hence the + 1.
    value = np.random.randint(min_value, max_value + 1)
  else:
    # Previously this fell through and failed with an UnboundLocalError on
    # `value`; fail with an explicit message instead.
    raise RuntimeError("Unsupported dtype %r for create_scalar_data." %
                       (dtype,))
  return np.array(value, dtype=dtype)
136
137
def freeze_graph(session, outputs):
  """Freeze the graph held by `session`.

  Args:
    session: Tensorflow session containing the graph.
    outputs: List of output tensors; the names of the ops producing them are
      passed on as the output node names.

  Returns:
    The frozen graph_def, as returned by convert_variables_to_constants.
  """
  graph_def = session.graph.as_graph_def()
  output_node_names = [tensor.op.name for tensor in outputs]
  return tf_graph_util.convert_variables_to_constants(
      session, graph_def, output_node_names)
150
151
def format_result(t):
  """Convert a tensor to a format that can be used in test specs.

  Args:
    t: numpy array to format.

  Returns:
    For byte-string and object arrays, the result of
    `_pywrap_string_util.SerializeAsHexString` on the flattened array. For
    every other dtype, the flattened values formatted with 9 digits after the
    point and joined with commas.
  """
  # "S" and "O" are the numpy dtype kind codes of byte strings and Python
  # objects. Comparing against the codes directly avoids the np.string_ alias,
  # which newer NumPy releases no longer provide.
  if t.dtype.kind in ("S", "O"):
    return _pywrap_string_util.SerializeAsHexString(t.flatten())
  # Output 9 digits after the point to ensure the precision is good enough.
  return ",".join("{:.9f}".format(value) for value in t.flatten())
160
161
def write_examples(fp, examples):
  """Given a list `examples`, write a text format representation.

  The file format is csv like with a simple repeated pattern: a `test_cases`
  count, then for every example an `inputs` count followed by its tensors and
  an `outputs` count followed by its tensors. Each tensor is written as three
  lines (`dtype`, `shape`, `values`). We would like to use proto here, but we
  can't yet due to interfacing with the Android team using this format.

  Args:
    fp: File-like object to write to.
    examples: List of example dictionaries, each consisting of the keys
      "inputs" and "outputs".
  """

  def emit_tensor(tensor):
    """Write one tensor in the file format supported by TFLITE examples."""
    fp.write("dtype,%s\n" % tensor.dtype)
    shape_text = ",".join(str(dim) for dim in tensor.shape)
    fp.write("shape," + shape_text + "\n")
    fp.write("values," + format_result(tensor) + "\n")

  fp.write("test_cases,%d\n" % len(examples))
  for example in examples:
    # Inputs are always written before outputs.
    for section in ("inputs", "outputs"):
      tensors = example[section]
      fp.write("%s,%d\n" % (section, len(tensors)))
      for tensor in tensors:
        emit_tensor(tensor)
188
189
def write_test_cases(fp, model_name, examples):
  """Given a list of `examples`, write a text format representation.

  The file format is protocol-buffer-like, even though we don't use proto due
  to the needs of the Android team. A `load_model` line is followed, for every
  example, by a `reshape` section listing the input shapes and an `invoke`
  section listing input values, output values and output shapes.

  Args:
    fp: File-like object to write to.
    model_name: Filename where the model was written to; only its base name
      is written out.
    examples: List of example dictionaries, each consisting of the keys
      "inputs" and "outputs".
  """

  fp.write("load_model: %s\n" % os.path.basename(model_name))
  for example in examples:
    fp.write("reshape {\n")
    for tensor in example["inputs"]:
      input_dims = ",".join(str(dim) for dim in tensor.shape)
      fp.write("  input: \"" + input_dims + "\"\n")
    fp.write("}\n")

    fp.write("invoke {\n")
    for tensor in example["inputs"]:
      fp.write("  input: \"" + format_result(tensor) + "\"\n")
    for tensor in example["outputs"]:
      fp.write("  output: \"" + format_result(tensor) + "\"\n")
      output_dims = ",".join(map(str, tensor.shape))
      fp.write("  output_shape: \"" + output_dims + "\"\n")
    fp.write("}\n")
217
218
def get_input_shapes_map(input_tensors):
  """Gets a map of input names to shapes.

  Inputs whose shape is falsy, or whose list of dimensions is empty, are left
  out of the map.

  Args:
    input_tensors: List of input tensor tuples `(name, shape, type)`.

  Returns:
    {string : list of integers}.
  """
  names = [entry[0] for entry in input_tensors]
  dims_per_input = [
      [dim.value for dim in shape.dims] if shape else None
      for _, shape, _ in input_tensors
  ]

  shapes_by_name = {}
  for name, dims in zip(names, dims_per_input):
    if dims:
      shapes_by_name[name] = dims
  return shapes_by_name
243
244
245def _normalize_output_name(output_name):
246  """Remove :0 suffix from tensor names."""
247  return output_name.split(":")[0] if output_name.endswith(
248      ":0") else output_name
249
250
# How many test cases we may have in a zip file. Too many test cases will
# slow down the test data generation process.
_MAX_TESTS_PER_ZIP = 500


def make_zip_of_tests(options,
                      test_parameters,
                      make_graph,
                      make_test_inputs,
                      extra_toco_options=None,
                      use_frozen_graph=False,
                      expected_tf_failures=0):
  """Helper to make a zip file of a bunch of TensorFlow models.

  This does a cartesian product of the dictionary of test_parameters and
  calls make_graph() for each item in the cartesian product set.
  If the graph is built successfully, then make_test_inputs() is called to
  build expected input/output value pairs. The model is then converted to tflite
  with toco, and the examples are serialized with the tflite model into a zip
  file (2 files per item in the cartesian product set).

  When `options.multi_gen_state` is set, the archive, the manifest and the
  parameter count are shared through that state object and the archive is left
  open for its owner. Otherwise the archive is created at
  `options.output_path`/`options.zip_to_output` and closed before returning.

  Args:
    options: An Options instance.
    test_parameters: List of dictionaries mapping to lists for each parameter.
      e.g. `[{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}]`
    make_graph: function that takes current parameters and returns tuple
      `[input1, input2, ...], [output1, output2, ...]`
    make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
      `output_tensors` and returns tuple `(input_values, output_values)`.
    extra_toco_options: Additional toco options. A fresh ExtraTocoOptions is
      created when None. Note that the object is modified by this function.
    use_frozen_graph: Whether or not freeze graph before toco converter.
    expected_tf_failures: Number of times tensorflow is expected to fail in
      executing the input graphs. In some cases it is OK for TensorFlow to fail
      because the one or more combination of parameters is invalid.

  Raises:
    RuntimeError: if there are too many parameter combinations, if TensorFlow
      fails for too many or an unexpected number of them, or if there are
      converter errors that can't be ignored.
  """
  # Build a fresh options object per call. A default instance created in the
  # signature would be shared by all calls and would leak the changes made
  # below (inference types, LSTM input splitting) from one call into the next.
  if extra_toco_options is None:
    extra_toco_options = ExtraTocoOptions()

  zip_path = os.path.join(options.output_path, options.zip_to_output)
  parameter_count = 0
  for parameters in test_parameters:
    # The initial value 1 covers an empty parameter dictionary, for which
    # itertools.product() below still yields one (empty) combination.
    parameter_count += functools.reduce(
        operator.mul, [len(values) for values in parameters.values()], 1)

  all_parameter_count = parameter_count
  if options.multi_gen_state:
    all_parameter_count += options.multi_gen_state.parameter_count
  if not options.no_tests_limit and all_parameter_count > _MAX_TESTS_PER_ZIP:
    raise RuntimeError(
        "Too many parameter combinations for generating '%s'.\n"
        "There are at least %d combinations while the upper limit is %d.\n"
        "Having too many combinations will slow down the tests.\n"
        "Please consider splitting the test into multiple functions.\n" %
        (zip_path, all_parameter_count, _MAX_TESTS_PER_ZIP))
  if options.multi_gen_state:
    options.multi_gen_state.parameter_count = all_parameter_count

  zip_manifest = []
  convert_report = []
  toco_errors = 0

  processed_labels = set()

  if options.make_edgetpu_tests:
    extra_toco_options.inference_input_type = tf.uint8
    extra_toco_options.inference_output_type = tf.uint8
    # Only count parameters when fully_quantize is True.
    parameter_count = 0
    for parameters in test_parameters:
      if True in parameters.get("fully_quantize", []):
        parameter_count += functools.reduce(operator.mul, [
            len(values)
            for key, values in parameters.items()
            if key != "fully_quantize"
        ], 1)

  label_base_path = zip_path
  if options.multi_gen_state:
    label_base_path = options.multi_gen_state.label_base_path

  def generate_inputs_outputs(tflite_model_binary, min_value=0, max_value=255):
    """Generate input values and output values of the given tflite model.

    Args:
      tflite_model_binary: A serialized flatbuffer as a string.
      min_value: min value for the input tensor.
      max_value: max value for the input tensor.

    Returns:
      (input_values, output_values): input values and output values built.
    """
    interpreter = tf.lite.Interpreter(model_content=tflite_model_binary)
    interpreter.allocate_tensors()

    input_values = []
    for input_detail in interpreter.get_input_details():
      input_value = create_tensor_data(
          input_detail["dtype"],
          input_detail["shape"],
          min_value=min_value,
          max_value=max_value)
      interpreter.set_tensor(input_detail["index"], input_value)
      input_values.append(input_value)

    interpreter.invoke()

    output_values = [
        interpreter.get_tensor(output_detail["index"])
        for output_detail in interpreter.get_output_details()
    ]
    return input_values, output_values

  def build_example(archive, label, param_dict_real):
    """Build the model with parameter values set in param_dict_real.

    Args:
      archive: Zip archive the generated files are written to.
      label: Label of the model (i.e. the filename in the zip).
      param_dict_real: Parameter dictionary (arguments to the factories
        make_graph and make_test_inputs)

    Returns:
      (tflite_model_binary, report) where tflite_model_binary is the
      serialized flatbuffer as a string and report is a dictionary with
      keys `toco_log` (log of toco conversion), `tf_log` (log of tf
      conversion), `toco` (a string of success status of the conversion),
      `tf` (a string success status of the conversion).
    """

    np.random.seed(RANDOM_SEED)
    report = {"toco": report_lib.NOTRUN, "tf": report_lib.FAILED}

    # Build graph
    report["tf_log"] = ""
    report["toco_log"] = ""
    tf.compat.v1.reset_default_graph()

    with tf.device("/cpu:0"):
      try:
        inputs, outputs = make_graph(param_dict_real)
      except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
              ValueError):
        report["tf_log"] += traceback.format_exc()
        return None, report

    sess = tf.compat.v1.Session()
    try:
      try:
        baseline_inputs, baseline_outputs = (
            make_test_inputs(param_dict_real, sess, inputs, outputs))
      except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
              ValueError):
        report["tf_log"] += traceback.format_exc()
        return None, report
      report["toco"] = report_lib.FAILED
      report["tf"] = report_lib.SUCCESS
      # Convert graph to toco
      input_tensors = [(input_tensor.name.split(":")[0], input_tensor.shape,
                        input_tensor.dtype) for input_tensor in inputs]
      output_tensors = [_normalize_output_name(out.name) for out in outputs]
      if use_frozen_graph:
        graph_def = freeze_graph(
            sess, tf.compat.v1.global_variables() + inputs + outputs)
      else:
        graph_def = sess.graph_def

      if "split_tflite_lstm_inputs" in param_dict_real:
        extra_toco_options.split_tflite_lstm_inputs = param_dict_real[
            "split_tflite_lstm_inputs"]
      tflite_model_binary, toco_log = options.tflite_convert_function(
          options,
          graph_def,
          input_tensors,
          output_tensors,
          extra_toco_options=extra_toco_options,
          test_params=param_dict_real)
    finally:
      # One session is created per parameter combination; close it so they do
      # not pile up. Nothing below needs the session any more.
      sess.close()

    report["toco"] = (
        report_lib.SUCCESS
        if tflite_model_binary is not None else report_lib.FAILED)
    report["toco_log"] = toco_log

    if options.save_graphdefs:
      archive.writestr(label + ".pbtxt",
                       text_format.MessageToString(graph_def),
                       zipfile.ZIP_DEFLATED)

    if tflite_model_binary:
      if options.make_edgetpu_tests:
        # Set proper min max values according to input dtype.
        baseline_inputs, baseline_outputs = generate_inputs_outputs(
            tflite_model_binary, min_value=0, max_value=255)
      archive.writestr(label + ".bin", tflite_model_binary,
                       zipfile.ZIP_DEFLATED)
      example = {"inputs": baseline_inputs, "outputs": baseline_outputs}

      example_fp = StringIO()
      write_examples(example_fp, [example])
      archive.writestr(label + ".inputs", example_fp.getvalue(),
                       zipfile.ZIP_DEFLATED)

      example_fp2 = StringIO()
      write_test_cases(example_fp2, label + ".bin", [example])
      archive.writestr(label + "_tests.txt", example_fp2.getvalue(),
                       zipfile.ZIP_DEFLATED)

      zip_manifest.append(label + "\n")

    return tflite_model_binary, report

  # TODO(aselle): Make this allow multiple inputs outputs.
  if options.multi_gen_state:
    archive = options.multi_gen_state.archive
  else:
    archive = zipfile.PyZipFile(zip_path, "w")

  try:
    for parameters in test_parameters:
      keys = parameters.keys()
      for curr in itertools.product(*parameters.values()):
        label = label_base_path.replace(".zip", "_") + (",".join(
            "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
        if label.startswith("/"):
          label = label[1:]
        if label in processed_labels:
          # Do not populate data for the same label more than once. It will
          # cause errors when unzipping.
          continue
        processed_labels.add(label)

        param_dict = dict(zip(keys, curr))

        if options.make_edgetpu_tests and not param_dict.get(
            "fully_quantize", False):
          continue

        _, report = build_example(archive, label, param_dict)

        if report["toco"] == report_lib.FAILED:
          ignore_error = False
          if not options.known_bugs_are_errors:
            for pattern, bug_number in options.known_bugs.items():
              if re.search(pattern, label):
                print("Ignored converter error due to bug %s" % bug_number)
                ignore_error = True
          if not ignore_error:
            toco_errors += 1
            print(
                "-----------------\nconverter error!\n%s\n-----------------\n"
                % report["toco_log"])

        convert_report.append((param_dict, report))

    if not options.no_conversion_report:
      report_io = StringIO()
      report_lib.make_report_table(report_io, zip_path, convert_report)
      if options.multi_gen_state:
        archive.writestr(
            "report_" + options.multi_gen_state.test_name + ".html",
            report_io.getvalue())
      else:
        archive.writestr("report.html", report_io.getvalue())

    if options.multi_gen_state:
      options.multi_gen_state.zip_manifest.extend(zip_manifest)
    else:
      archive.writestr("manifest.txt", "".join(zip_manifest),
                       zipfile.ZIP_DEFLATED)
  finally:
    # Only close an archive that was opened here; a shared archive stays open
    # for whoever owns options.multi_gen_state.
    if not options.multi_gen_state:
      archive.close()

  # Log statistics of what succeeded
  total_conversions = len(convert_report)
  tf_success = sum(
      1 for x in convert_report if x[1]["tf"] == report_lib.SUCCESS)
  toco_success = sum(
      1 for x in convert_report if x[1]["toco"] == report_lib.SUCCESS)
  percent = 0
  if tf_success > 0:
    percent = float(toco_success) / float(tf_success) * 100.
  tf.compat.v1.logging.info(
      ("Archive %s Considered %d graphs, %d TF evaluated graphs "
       " and %d TOCO converted graphs (%.1f%%"), zip_path, total_conversions,
      tf_success, toco_success, percent)

  tf_failures = parameter_count - tf_success

  # With no parameter combinations there is no failure ratio to compute;
  # skipping the check avoids a division by zero.
  if parameter_count > 0 and tf_failures / parameter_count > 0.8:
    raise RuntimeError(("Test for '%s' is not very useful. "
                        "TensorFlow fails in %d percent of the cases.") %
                       (zip_path, int(100 * tf_failures / parameter_count)))

  if not options.make_edgetpu_tests and tf_failures != expected_tf_failures:
    raise RuntimeError(("Expected TF to fail %d times while generating '%s', "
                        "but that happened %d times") %
                       (expected_tf_failures, zip_path, tf_failures))

  if not options.ignore_converter_errors and toco_errors > 0:
    raise RuntimeError("Found %d errors while generating toco models" %
                       toco_errors)
546