# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a model's graph def into a tflite model with MLIR-based conversion."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.lite.python import test_util as tflite_test_util
from tensorflow.lite.testing import zip_test_utils
from tensorflow.python.platform import resource_loader


def mlir_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
  """Converts a model's graph def into a tflite model with MLIR-based conversion.

  Args:
    options: A lite.testing.generate_examples_lib.Options instance.
    graph_def: A GraphDef object.
    input_tensors: List of input tensor tuples `(name, shape, type)`.
    output_tensors: List of output tensor names.
    **kwargs: Extra parameters.

  Returns:
    A `(tflite_model, log_txt)` tuple from the conversion, or
    `(None, log_txt)` if the model did not convert properly.
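
  Example:
    A minimal usage sketch; the tensor names and shapes here are
    illustrative, not taken from a real generated test:

      tflite_model, log = mlir_convert(
          options, graph_def,
          input_tensors=[("input", tf.TensorShape([1, 8]), tf.float32)],
          output_tensors=["output"])
      if tflite_model is None:
        print("Conversion failed:\n" + log)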
43  """
44  test_params = kwargs.get("test_params", {})
45  # TODO(b/146025965): Rename ExtraTocoOptions to ExtraConvertOptions or
46  #                    something else.
47  extra_toco_options = kwargs.get("extra_toco_options",
48                                  zip_test_utils.ExtraTocoOptions())
49  input_arrays = [x[0] for x in input_tensors]
50  input_shapes = zip_test_utils.get_input_shapes_map(input_tensors)
51
52  tflite_model = None
53  log = ""
54
55  with tempfile.NamedTemporaryFile() as graphdef_file:
56    graphdef_file.write(graph_def.SerializeToString())
57    graphdef_file.flush()
58    converter = tf.lite.TFLiteConverter.from_frozen_graph(
59        graphdef_file.name, input_arrays, output_tensors, input_shapes)
60    converter.allow_custom_ops = extra_toco_options.allow_custom_ops
61    converter.experimental_new_quantizer = options.mlir_quantizer
62
63    if options.run_with_flex:
64      converter.supported_ops = set([
65          tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS])
66
67    if test_params.get("dynamic_range_quantize", False):
68      converter.optimizations = [tf.lite.Optimize.DEFAULT]
69
70    if test_params.get("fully_quantize", False):
71      converter.optimizations = [tf.lite.Optimize.DEFAULT]
72
73      # Read the input range for the representative dataset from parameters.
74      min_value, max_value = test_params.get("input_range", (-1, 1))
75
76      def representative_dataset(input_tensors):
77        calibration_inputs = []
78        for _, shape, _ in input_tensors:
79          if shape:
80            dims = [1 if dim.value is None else dim.value for dim in shape.dims]
81            calibration_inputs.append(
82                np.random.uniform(min_value, max_value,
83                                  tuple(dims)).astype(np.float32))
84        return calibration_inputs
85
86      def representative_dataset_gen():
87        for _ in range(100):
88          yield representative_dataset(input_tensors)
89
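      # The converter draws random batches from the generator above to
      # calibrate activation ranges; the ops set below chooses between
      # standard int8 and the experimental 16x8 (int16 activations, int8
      # weights) quantization schemes.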
      if test_params.get("quant_16x8", False):
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
        ]
      else:
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]

      converter.representative_dataset = representative_dataset_gen
      if extra_toco_options.inference_input_type:
        converter.inference_input_type = (
            extra_toco_options.inference_input_type)

      if extra_toco_options.inference_output_type:
        converter.inference_output_type = (
            extra_toco_options.inference_output_type)

    try:
      tflite_model = converter.convert()
      if options.expected_ops_in_converted_model:
        ops_list = tflite_test_util.get_ops_list(tflite_model)
        for expected_op in options.expected_ops_in_converted_model:
          if expected_op not in ops_list:
            # Force the test to fail.
            tflite_model = None
            raise ValueError(
                "{} op not found in the converted model".format(expected_op))
    except Exception as e:  # pylint: disable=broad-except
      log = str(e)

  return tflite_model, log


def mlir_convert_file(graph_def_filename,
                      input_tensors,
                      output_tensors,
                      quantization_params=None,
                      additional_flags=""):
  """Converts a graphdef file into a tflite model with MLIR-based conversion.

  NOTE: This currently shells out to the MLIR binary, but we would like to
  convert to Python API tooling in the future.

  Args:
    graph_def_filename: A GraphDef file.
    input_tensors: List of input tensor tuples `(name, shape, type)`. `name`
      should be a string, `shape` a tuple of integers, and `type` a string
      such as 'DT_FLOAT'.
    output_tensors: List of output tensor names.
    quantization_params: Parameters `(inference_type, min_values, max_values)`
      used to quantize the model.
    additional_flags: A string of additional command line flags to be passed
      to the MLIR converter.

  Returns:
    A `(tflite_model, log_txt)` tuple from the conversion, or
    `(None, log_txt)` if the model did not convert properly.
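
  Example:
    A hedged sketch; the file path and tensor specs are illustrative:

      model, log = mlir_convert_file(
          "/tmp/frozen_graph.pb",
          input_tensors=[("input", (1, 8), "DT_FLOAT")],
          output_tensors=["output"])
      if model is None:
        print(log)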
149  """
150  bin_path = resource_loader.get_path_to_datafile(
151      "../../../../compiler/mlir/lite/tf_tfl_translate")
152
153  with tempfile.NamedTemporaryFile() as output_file, \
154       tempfile.NamedTemporaryFile("w+") as stdout_file:
155    input_shapes = []
156    for input_tensor in input_tensors:
157      shape = input_tensor[1]
158      input_shapes.append(",".join([str(dim) for dim in shape]))
159    input_shapes_str = ":".join(input_shapes)
160
161    input_types = ",".join([x[2] for x in input_tensors])
162
163    quant_flags = ""
164    if quantization_params is not None:
165      min_vals = ",".join([str(val) for val in quantization_params[1]])
166      max_vals = ",".join([str(val) for val in quantization_params[2]])
167      quant_flags = ("-tf-inference-type=" + quantization_params[0] +
168                     " -tf-input-min-values='" + min_vals +
169                     "' -tf-input-max-values='" + max_vals + "' " +
170                     "-emit-quant-adaptor-ops ")
171    cmd = ("%s -tf-input-arrays=%s -tf-input-data-types=%s -tf-input-shapes=%s "
172           "-tf-output-arrays=%s " + quant_flags + additional_flags +
173           "%s -o %s")
174    cmd = cmd % (
175        bin_path,
176        ",".join([x[0] for x in input_tensors]),
177        input_types,
178        input_shapes_str,
179        ",".join(output_tensors),
180        graph_def_filename,
181        output_file.name,
182    )
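    # The assembled command looks roughly like the following (binary path,
    # tensor names, and file paths are illustrative):
    #   tf_tfl_translate -tf-input-arrays=input -tf-input-data-types=DT_FLOAT \
    #     -tf-input-shapes=1,8 -tf-output-arrays=output graph.pb -o out.tflite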
    # Redirect the converter's output into stdout_file so it can be included
    # in the log.
    exit_code = os.system(cmd + " > {} 2>&1".format(stdout_file.name))
    log = (
        cmd + " exited with code %d" % exit_code + "\n------------------\n" +
        stdout_file.read())
    return (None if exit_code != 0 else output_file.read()), log


def mlir_convert_saved_model(saved_model_dir,
                             is_signature_def_saved_model,
                             tags=(),
                             exported_names=(),
                             additional_flags=""):
  """Converts a saved_model into a tflite model with MLIR-based conversion.

  Args:
    saved_model_dir: Saved model dir.
    is_signature_def_saved_model: Whether the SavedModel SignatureDef importer
      or the ObjectGraph importer should be used.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    exported_names: Names to export from the SavedModel.
    additional_flags: A string of additional command line flags to be passed
      to the MLIR converter.

  Returns:
    A `(tflite_model, log_txt)` tuple from the conversion, or
    `(None, log_txt)` if the model did not convert properly.
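
  Example:
    A hedged sketch; the directory, tags, and exported names are
    illustrative:

      model, log = mlir_convert_saved_model(
          "/tmp/saved_model",
          is_signature_def_saved_model=True,
          tags=("serve",),
          exported_names=("serving_default",))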
210  """
211  bin_path = resource_loader.get_path_to_datafile(
212      "../../../../compiler/mlir/lite/tf_tfl_translate")
213  with tempfile.NamedTemporaryFile() as output_file, \
214       tempfile.NamedTemporaryFile("w+") as stdout_file:
215    tags_str = ",".join(tags)
216    exported_names_str = ",".join(exported_names)
217
218    saved_model_flag = "-savedmodel-objectgraph-to-mlir"
219    if is_signature_def_saved_model:
220      saved_model_flag = "-savedmodel-signaturedefs-to-mlir"
221
222    cmd = ("%s %s --tf-savedmodel-tags=%s --tf-savedmodel-exported-names=%s " +
223           additional_flags + " %s --o=%s")
224    cmd = cmd % (
225        bin_path,
226        saved_model_flag,
227        tags_str,
228        exported_names_str,
229        saved_model_dir,
230        output_file.name,
231    )
232    exit_code = os.system(cmd)
233    log = (
234        cmd + "exited with code %d" % exit_code + "\n------------------\n" +
235        stdout_file.read())
236    return (None if exit_code != 0 else output_file.read()), log
237