# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Converts a model's graph def into a tflite model with MLIR-based conversion."""
16import os
17import tempfile
18
19import numpy as np
20import tensorflow as tf
21
22from tensorflow.lite.python import test_util as tflite_test_util
23from tensorflow.lite.testing import zip_test_utils
24from tensorflow.python.platform import resource_loader
25from tensorflow.python.saved_model import signature_constants
26
27
def mlir_convert(
    options,
    saved_model_dir,
    input_tensors,
    output_tensors,  # pylint: disable=unused-argument
    **kwargs):
  """Convert a saved model into a tflite model with MLIR-based conversion.

  Args:
    options: A lite.testing.generate_examples_lib.Options instance.
    saved_model_dir: Path to the saved model.
    input_tensors: List of input tensor tuples `(name, shape, type)`.
    output_tensors: List of output tensors (names).
    **kwargs: Extra parameters.

  Returns:
    A tuple (tflite_model, log_txt): the converted model and the conversion
    log, or (None, log_txt) if conversion failed.
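
  Example (an illustrative sketch only; the saved model path and tensor
  names below are hypothetical):
    tflite_model, log = mlir_convert(
        options, "/tmp/saved_model",
        input_tensors=[("input", tf.TensorShape([1, 4]), tf.float32)],
        output_tensors=["output"])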
46  """
  test_params = kwargs.get("test_params", {})
  extra_convert_options = kwargs.get("extra_convert_options",
                                     zip_test_utils.ExtraConvertOptions())
  tflite_model = None
  log = ""

  signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  converter = tf.lite.TFLiteConverter.from_saved_model(
      saved_model_dir, [signature_key])
  converter.allow_custom_ops = extra_convert_options.allow_custom_ops
  converter.experimental_new_quantizer = options.mlir_quantizer
  if options.make_tf_ptq_tests:
    if options.hlo_aware_conversion:
      tf_quantization_mode = "DEFAULT"
    else:
      tf_quantization_mode = "LEGACY_INTEGER"
    converter._experimental_tf_quantization_mode = tf_quantization_mode  # pylint: disable=protected-access

  if options.run_with_flex:
    converter.target_spec.supported_ops = set(
        [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS])

  if options.enable_dynamic_update_slice:
    converter._experimental_enable_dynamic_update_slice = True  # pylint: disable=protected-access

  if options.disable_batchmatmul_unfold:
    converter._experimental_disable_batchmatmul_unfold = True  # pylint: disable=protected-access

  if test_params.get("dynamic_range_quantize", False):
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

  if test_params.get("fully_quantize", False):
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # Read the input range for the representative dataset from parameters.
    min_value, max_value = test_params.get("input_range", (-1, 1))

    def representative_dataset(input_tensors):
      calibration_inputs = {}
      for name, shape, dtype in input_tensors:
        if shape:
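          # Unknown dims (dim.value is None) are replaced with 1 so concrete
          # random calibration data can be generated for this input.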
          dims = [1 if dim.value is None else dim.value for dim in shape.dims]
          calibration_inputs[name] = np.random.uniform(
              min_value, max_value, tuple(dims)).astype(dtype.as_numpy_dtype)
      return calibration_inputs

    def representative_dataset_gen():
      for _ in range(100):
        yield representative_dataset(input_tensors)

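    # The 16x8 scheme keeps activations in int16 and weights in int8, which
    # preserves more activation precision than the default full-int8 path.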
    if test_params.get("quant_16x8", False):
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet
          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
      ]
    else:
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS_INT8
      ]

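    # The converter draws all 100 batches from this generator and runs them
    # through the float model to calibrate per-tensor quantization ranges.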
    converter.representative_dataset = representative_dataset_gen
    if extra_convert_options.inference_input_type:
      converter.inference_input_type = (
          extra_convert_options.inference_input_type)

    if extra_convert_options.inference_output_type:
      converter.inference_output_type = (
          extra_convert_options.inference_output_type)

  try:
    tflite_model = converter.convert()
    if options.expected_ops_in_converted_model:
      ops_list = tflite_test_util.get_ops_list(tflite_model)
      for expected_op in options.expected_ops_in_converted_model:
        if expected_op not in ops_list:
          # Force the test to fail.
          tflite_model = None
          raise ValueError(
              "{} op not found in the converted model".format(expected_op))
  except Exception as e:  # pylint: disable=broad-except
    log = str(e)

  return tflite_model, log


def mlir_convert_file(graph_def_filename,
                      input_tensors,
                      output_tensors,
                      quantization_params=None,
                      additional_flags=""):
137  """Convert a graphdef file into a tflite model with MLIR-based conversion.
138
139  NOTE: this currently shells out to the MLIR binary binary, but we would like
140  convert to Python API tooling in the future.
141
142  Args:
143    graph_def_filename: A GraphDef file.
144    input_tensors: List of input tensor tuples `(name, shape, type)`. name
145      should be a string. shape should be a tuple of integers. type should be a
146      string, for example 'DT_FLOAT'
147    output_tensors: List of output tensors (names).
148    quantization_params: parameters `(inference_type, min_values, max_values)`
149      to quantize the model.
150    additional_flags: A string of additional command line flags to be passed to
151      MLIR converter.
152
153  Returns:
154    output tflite model, log_txt from conversion
155    or None, log_txt if it did not convert properly.
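
  Example (an illustrative sketch only; the file path and tensor names below
  are hypothetical):
    model, log = mlir_convert_file(
        "/tmp/frozen_graph.pb",
        input_tensors=[("input", (1, 224, 224, 3), "DT_FLOAT")],
        output_tensors=["output"])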
156  """
  bin_path = resource_loader.get_path_to_datafile(
      "../../../../compiler/mlir/lite/tf_tfl_translate")

  with tempfile.NamedTemporaryFile() as output_file, \
       tempfile.NamedTemporaryFile("w+") as stdout_file:
    input_shapes = []
    for input_tensor in input_tensors:
      shape = input_tensor[1]
      input_shapes.append(",".join([str(dim) for dim in shape]))
    input_shapes_str = ":".join(input_shapes)
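    # e.g. shapes (1, 2) and (3,) become "1,2:3": dims are comma-separated
    # and tensors are colon-separated.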

    input_types = ",".join([x[2] for x in input_tensors])

    quant_flags = ""
    if quantization_params is not None:
      min_vals = ",".join([str(val) for val in quantization_params[1]])
      max_vals = ",".join([str(val) for val in quantization_params[2]])
      quant_flags = ("-tf-inference-type=" + quantization_params[0] +
                     " -tf-input-min-values='" + min_vals +
                     "' -tf-input-max-values='" + max_vals + "' " +
                     "-emit-quant-adaptor-ops ")
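    # quant_flags now holds the inference type and per-input min/max range
    # flags for the translate binary (it stays empty when no quantization
    # is requested).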
    cmd = ("%s -tf-input-arrays=%s -tf-input-data-types=%s -tf-input-shapes=%s "
           "-tf-output-arrays=%s " + quant_flags + additional_flags +
           "%s -o %s")
    cmd = cmd % (
        bin_path,
        ",".join([x[0] for x in input_tensors]),
        input_types,
        input_shapes_str,
        ",".join(output_tensors),
        graph_def_filename,
        output_file.name,
    )
    # Redirect the converter's stdout/stderr into stdout_file so its output
    # can be included in the returned log.
    exit_code = os.system(cmd + " > " + stdout_file.name + " 2>&1")
    log = (
        cmd + " exited with code %d" % exit_code + "\n------------------\n" +
        stdout_file.read())
    return (None if exit_code != 0 else output_file.read()), log