• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converts a frozen graph into a TFLite FlatBuffer."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import enum  # pylint: disable=g-bad-import-order
22
23import os as _os
24import platform as _platform
25import subprocess as _subprocess
26import tempfile as _tempfile
27
28from tensorflow.lite.python import lite_constants
29from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
30from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
31from tensorflow.lite.toco import types_pb2 as _types_pb2
32from tensorflow.python.framework import dtypes
33from tensorflow.python.platform import resource_loader as _resource_loader
34from tensorflow.python.util import deprecation
35from tensorflow.python.util.lazy_loader import LazyLoader
36from tensorflow.python.util.tf_export import tf_export as _tf_export
37
# Lazily load the TOCO SWIG wrapper: importing it eagerly breaks the
# dependencies of some of the performance benchmark skylark rules.
_toco_python = LazyLoader(
    "tensorflow_wrap_toco", globals(),
    "tensorflow.lite.toco.python.tensorflow_wrap_toco")
del LazyLoader

# Locate the toco_from_protos binary. An empty value makes
# `toco_convert_protos` call the TOCO API in-process instead. Under bazel the
# resource loader finds the binary relative to this package; when that path
# does not exist (e.g. a pip install) fall back to the `toco_from_protos`
# console_scripts tool, which is expected to be on PATH.
_toco_from_proto_bin = (
    "" if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY else
    _resource_loader.get_path_to_datafile("../toco/python/toco_from_protos"))
if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
  _toco_from_proto_bin = "toco_from_protos"
57
58
# Lookup table from tf.dtypes to the matching TFLite `types_pb2` enum values.
# Dtypes missing from this table are rejected by
# `convert_dtype_to_tflite_type`.
_MAP_TF_TO_TFLITE_TYPES = dict([
    (dtypes.float32, _types_pb2.FLOAT),
    (dtypes.int32, _types_pb2.INT32),
    (dtypes.int64, _types_pb2.INT64),
    (dtypes.string, _types_pb2.STRING),
    (dtypes.uint8, _types_pb2.QUANTIZED_UINT8),
    (dtypes.complex64, _types_pb2.COMPLEX64),
])
68
69
def _try_convert_to_unicode(output):
  """Converts captured subprocess output to text for use in error messages.

  Args:
    output: One stream as returned by `Popen.communicate()`: `None` when the
      stream was not captured, otherwise bytes (text is passed through).

  Returns:
    A text string. `None` becomes the empty string; bytes are decoded as
    UTF-8, with undecodable bytes replaced by U+FFFD; any other value is
    returned unchanged.
  """
  if output is None:
    return u""

  if isinstance(output, bytes):
    # Name the codec explicitly: Python 2's default codec is ASCII, which
    # fails on any non-ASCII tool output. The "replace" handler guarantees a
    # text result even for malformed bytes, so callers never format a raw
    # bytes object (rendered as b'...') into their error message.
    return output.decode("utf-8", "replace")
  return output
80
81
def convert_dtype_to_tflite_type(tf_dtype):
  """Maps a tf.dtype onto the corresponding TFLite `types_pb2` enum value.

  Args:
    tf_dtype: tf.dtype to translate.

  Raises:
    ValueError: If `tf_dtype` has no entry in the TFLite type table.

  Returns:
    The matching `types_pb2` enum value.
  """
  if tf_dtype not in _MAP_TF_TO_TFLITE_TYPES:
    raise ValueError("Unsupported tf.dtype {0}".format(tf_dtype))
  return _MAP_TF_TO_TFLITE_TYPES[tf_dtype]
98
99
@_tf_export("lite.OpsSet")
class OpsSet(enum.Enum):
  """Enum class defining the sets of ops available to generate TFLite models.

  WARNING: Experimental interface, subject to change.
  """
  # Convert model using TensorFlow Lite builtin ops.
  TFLITE_BUILTINS = "TFLITE_BUILTINS"

  # Convert model using TensorFlow ops. Not all TensorFlow ops are available.
  # WARNING: Experimental interface, subject to change.
  SELECT_TF_OPS = "SELECT_TF_OPS"

  def __str__(self):
    # Each member's value doubles as its user-facing string form.
    return self.value

  @staticmethod
  def get_options():
    """Returns the string form of every OpsSet member, in declaration order."""
    return list(map(str, OpsSet))
120
121
class ConverterError(Exception):
  """Raised when an error occurs during model conversion.

  The exception message carries the converter's failure text.
  """
125
126
def _write_temp_file(data, created_filenames):
  """Writes `data` to a new named temporary file that is not auto-deleted.

  The file's path is appended to `created_filenames` as soon as the file
  exists and before anything is written, so the caller can always delete it,
  even if the write fails.

  Args:
    data: Bytes to write.
    created_filenames: List that receives the new file's path.

  Returns:
    The path of the (closed) temporary file.
  """
  with _tempfile.NamedTemporaryFile(delete=False) as fp:
    created_filenames.append(fp.name)
    fp.write(data)
    return fp.name


# Don't expose these for now.
#  @_tf_export("lite.toco_convert_protos")
def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
  """Convert `input_data_str` according to model and toco parameters.

  Unless you know what you are doing consider using
  the more friendly `tf.lite.toco_convert`.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `toco/model_flags.proto`.
    toco_flags_str: Serialized proto describing conversion properties, see
      `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common)
  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).
  Raises:
    ConverterError: When conversion fails, usually due to ops not being
      supported, or when the `toco_from_protos` binary cannot be launched.
      The converter's error output is embedded in the message.
  """
  # TODO(aselle): When toco does not use fatal errors for failure, we can
  # switch this on.
  if not _toco_from_proto_bin:
    try:
      model_str = _toco_python.TocoConvert(model_flags_str, toco_flags_str,
                                           input_data_str)
      return model_str
    except Exception as e:
      raise ConverterError("TOCO failed: %s" % e)

  # Windows and TemporaryFile are not that useful together,
  # since you cannot have two readers/writers. So we have to
  # make the temporaries and close and delete them explicitly.
  temp_filenames = []
  try:
    # Build all input files. Each path is recorded as soon as its file is
    # created, so the `finally` clause removes it even if a later step fails.
    toco_filename = _write_temp_file(toco_flags_str, temp_filenames)
    model_filename = _write_temp_file(model_flags_str, temp_filenames)
    input_filename = _write_temp_file(input_data_str, temp_filenames)
    # Reserve an empty output file for the converter to fill in.
    output_filename = _write_temp_file(b"", temp_filenames)

    # Run. Pass an argument list and no shell, so that paths containing spaces
    # or shell metacharacters (e.g. a temp dir under a Windows user profile
    # such as "C:/Users/First Last") reach the binary intact.
    cmd = [
        _toco_from_proto_bin, model_filename, toco_filename, input_filename,
        output_filename
    ]
    is_windows = _platform.system() == "Windows"
    try:
      proc = _subprocess.Popen(
          cmd,
          stdout=_subprocess.PIPE,
          stderr=_subprocess.STDOUT,
          close_fds=not is_windows)
    except OSError as e:
      # Without a shell a missing or non-executable binary raises here instead
      # of yielding a non-zero exit code; keep reporting it as ConverterError.
      raise ConverterError(
          "TOCO failed. Could not run %s: %s" % (_toco_from_proto_bin, e))
    # stderr is merged into stdout above, so `stderr` is None here and
    # renders as an empty line in the error message.
    stdout, stderr = proc.communicate()
    if proc.returncode == 0:
      with open(output_filename, "rb") as fp:
        return fp.read()
    stdout = _try_convert_to_unicode(stdout)
    stderr = _try_convert_to_unicode(stderr)
    raise ConverterError(
        "TOCO failed. See console for info.\n%s\n%s\n" % (stdout, stderr))
  finally:
    # Must manually cleanup files; removal is best-effort.
    for filename in temp_filenames:
      try:
        _os.unlink(filename)
      except OSError:
        pass
214
215
def tensor_name(x):
  """Returns the TensorFlow-style name of tensor `x`.

  Matching TensorFlow's naming scheme, output 0 of an op is referred to by the
  bare op name (the ":0" suffix is dropped). Any other output keeps its
  ":<index>" suffix.

  Raises:
    ValueError: If the name contains more than one colon.
  """
  full_name = x.name
  colon_count = full_name.count(":")
  if colon_count > 1:
    raise ValueError(
        "Tensor name invalid. Expect 0 or 1 colon, got {0}".format(colon_count))

  op_name, _, output_index = full_name.partition(":")
  if colon_count == 1 and output_index != "0":
    return full_name
  return op_name
228
229
# Don't expose these for now.
# @_tf_export("lite.build_toco_convert_protos")
def build_toco_convert_protos(input_tensors,
                              output_tensors,
                              inference_type=lite_constants.FLOAT,
                              inference_input_type=None,
                              input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                              input_shapes=None,
                              output_format=lite_constants.TFLITE,
                              quantized_input_stats=None,
                              default_ranges_stats=None,
                              drop_control_dependency=True,
                              reorder_across_fake_quant=False,
                              allow_custom_ops=False,
                              change_concat_input_ranges=False,
                              post_training_quantize=False,
                              dump_graphviz_dir=None,
                              dump_graphviz_video=False,
                              target_ops=None,
                              allow_nonexistent_arrays=False):
  """Builds protocol buffers describing a conversion of a model using TOCO.

  Typically this is to convert from TensorFlow GraphDef to TFLite, in which
  case the default `input_format` and `output_format` are sufficient.

  Args:
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    inference_type: Target data type of real-number arrays in the output file.
      Must be `{tf.float32, tf.uint8}`.  (default tf.float32)
    inference_input_type: Target data type of real-number input arrays. Allows
      for a different type for input arrays in the case of quantization.
      Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
    input_format: Type of data to read. Currently must be
      `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
    input_shapes: Input array shapes. It needs to be a list with one shape per
      entry of `input_tensors`, or None to use each tensor's own shape.
      (default None)
    output_format: Output file format. Currently must be `{TFLITE,
      GRAPHVIZ_DOT}`. (default TFLITE)
    quantized_input_stats: List of tuples of floats representing the mean and
      standard deviation. Each tuple maps to the corresponding input tensor.
      Only needed if `inference_input_type` is `QUANTIZED_UINT8`.
      real_input_value = (quantized_input_value - mean_value) / std_dev_value.
      (default None)
    default_ranges_stats: Tuple of integers representing (min, max) range values
      for all arrays without a specified range. Intended for experimenting with
      quantization via "dummy quantization". (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the graph.
      Results in a graph that differs from the quantized training graph,
      potentially causing differing arithmetic behavior. (default False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver.
      (default False)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
      inputs and outputs of the concat operator for quantized models. Changes
      the ranges of concat operator overlap when true. (default False)
    post_training_quantize: Boolean indicating whether to quantize the weights
      of the converted float model. Model size will be reduced and there will be
      latency improvements (at the cost of accuracy).
      (default False)
    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
      stages of processing GraphViz .dot files. Preferred over
      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
      output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet
      options indicating which converter to use. `None`, and any combination
      other than {TFLITE_BUILTINS, SELECT_TF_OPS} or {SELECT_TF_OPS}, leaves
      the select-TF-ops flags unset, exactly like
      `set([OpsSet.TFLITE_BUILTINS])`. (default None)
    allow_nonexistent_arrays: Allow specifying array names that don't exist
      or are unused in the final graph. (default False)

  Returns:
    model_flags, toco_flags: two protocol buffers describing the conversion
    process.

  Raises:
    ValueError:
      If a requested or input tensor type is unsupported.
      If `inference_input_type` is QUANTIZED_UINT8 and `quantized_input_stats`
        is missing or has fewer entries than `input_tensors`.
      If `input_shapes` has fewer entries than `input_tensors`.
  """
  toco = _toco_flags_pb2.TocoFlags()
  toco.input_format = input_format
  toco.output_format = output_format
  toco.inference_type = convert_dtype_to_tflite_type(inference_type)
  if inference_input_type:
    toco.inference_input_type = convert_dtype_to_tflite_type(
        inference_input_type)
  else:
    toco.inference_input_type = toco.inference_type
  toco.drop_control_dependency = drop_control_dependency
  toco.reorder_across_fake_quant = reorder_across_fake_quant
  toco.allow_custom_ops = allow_custom_ops
  toco.post_training_quantize = post_training_quantize
  if default_ranges_stats:
    toco.default_ranges_min = default_ranges_stats[0]
    toco.default_ranges_max = default_ranges_stats[1]
  if dump_graphviz_dir:
    toco.dump_graphviz_dir = dump_graphviz_dir
  toco.dump_graphviz_include_video = dump_graphviz_video
  if target_ops:
    target_ops_set = set(target_ops)
    if target_ops_set == set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]):
      toco.enable_select_tf_ops = True
    elif target_ops_set == set([OpsSet.SELECT_TF_OPS]):
      toco.enable_select_tf_ops = True
      toco.force_select_tf_ops = True

  model = _model_flags_pb2.ModelFlags()
  model.change_concat_input_ranges = change_concat_input_ranges
  # Loop-invariant: whether every input array needs (mean, std_dev) stats.
  is_quantized_input = toco.inference_input_type == _types_pb2.QUANTIZED_UINT8
  for idx, input_tensor in enumerate(input_tensors):
    input_array = model.input_arrays.add()
    input_array.name = tensor_name(input_tensor)
    input_array.data_type = convert_dtype_to_tflite_type(input_tensor.dtype)

    if is_quantized_input:
      if not quantized_input_stats:
        raise ValueError("std_dev and mean must be defined when "
                         "inference_input_type is QUANTIZED_UINT8.")
      try:
        mean_and_std = quantized_input_stats[idx]
      except IndexError:
        # Report a short stats list as a user error instead of a bare
        # IndexError from deep inside the proto construction.
        raise ValueError(
            "quantized_input_stats must provide a (mean, std_dev) pair for "
            "every input tensor; missing entry for input {0} ({1}).".format(
                idx, input_array.name))
      input_array.mean_value, input_array.std_value = mean_and_std

    if input_shapes is None:
      shape = input_tensor.shape
    else:
      try:
        shape = input_shapes[idx]
      except IndexError:
        raise ValueError(
            "input_shapes must provide a shape for every input tensor; "
            "missing entry for input {0} ({1}).".format(idx, input_array.name))
    input_array.shape.dims.extend(map(int, shape))

  for output_tensor in output_tensors:
    model.output_arrays.append(tensor_name(output_tensor))

  model.allow_nonexistent_arrays = allow_nonexistent_arrays

  return model, toco
371
372
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
                           *args, **kwargs):
  """Convert a model using TOCO.

  This function is used to convert GraphDefs that cannot be loaded into
  TensorFlow to TFLite. Conversion can be customized by providing arguments
  that are forwarded to `build_toco_convert_protos` (see documentation for
  details).

  Args:
    input_data: Input data (e.g. often `sess.graph_def`).
    input_arrays_with_shape: List of tuples, each pairing an input tensor name
      (string) with its shape (list of integers), e.g.
      `[("foo", [1, 16, 16, 3])]`. Use only when the graph cannot be loaded
      into TensorFlow.
    output_arrays: List of output tensor names (strings). Use only when the
      graph cannot be loaded into TensorFlow.
    *args: See `build_toco_convert_protos`. Forwarded positionally after its
      `input_tensors` and `output_tensors` parameters, i.e. starting at
      `inference_type`.
    **kwargs: See `build_toco_convert_protos`. For the arrays in
      `input_arrays_with_shape`, `quantized_input_stats` is only honored when
      passed as a keyword argument.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    ValueError: If `inference_input_type` is QUANTIZED_UINT8 and
      `quantized_input_stats` is missing or shorter than
      `input_arrays_with_shape`; also anything `build_toco_convert_protos`
      raises.
    ConverterError: If `toco_convert_protos` fails to convert the model.
  """
  # Pass the empty tensor lists positionally: naming them as keywords while
  # *args is non-empty would bind the first extra positional argument to
  # `input_tensors` a second time and raise TypeError.
  model_flags, toco_flags = build_toco_convert_protos([], [], *args, **kwargs)

  quantized_input_stats = kwargs.get("quantized_input_stats")
  is_quantized_input = (
      toco_flags.inference_input_type == _types_pb2.QUANTIZED_UINT8)
  for idx, (name, shape) in enumerate(input_arrays_with_shape):
    input_array = model_flags.input_arrays.add()
    if is_quantized_input:
      if not quantized_input_stats:
        raise ValueError("std_dev and mean must be defined when "
                         "inference_input_type is QUANTIZED_UINT8.")
      try:
        mean_and_std = quantized_input_stats[idx]
      except IndexError:
        raise ValueError(
            "quantized_input_stats must provide a (mean, std_dev) pair for "
            "every input array; missing entry for input {0} ({1}).".format(
                idx, name))
      input_array.mean_value, input_array.std_value = mean_and_std
    input_array.name = name
    input_array.shape.dims.extend(map(int, shape))

  model_flags.output_arrays.extend(output_arrays)

  data = toco_convert_protos(model_flags.SerializeToString(),
                             toco_flags.SerializeToString(),
                             input_data.SerializeToString())
  return data
423
424
def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
                      **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to TFLite.
  Conversion can be customized by providing arguments that are forwarded to
  `build_toco_convert_protos` (see documentation for details).

  Args:
    input_data: Input data (e.g. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos` and `toco_convert_protos`.
  """
  model_flags, toco_flags = build_toco_convert_protos(
      input_tensors, output_tensors, *args, **kwargs)
  # Serialize in the order toco_convert_protos expects its arguments:
  # model flags, toco flags, then the input graph.
  serialized = [
      proto.SerializeToString()
      for proto in (model_flags, toco_flags, input_data)
  ]
  return toco_convert_protos(*serialized)
454
455
@_tf_export(v1=["lite.toco_convert"])
@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to TFLite.
  Conversion can be customized by providing arguments that are forwarded to
  `build_toco_convert_protos` (see documentation for details). This function has
  been deprecated. Please use `lite.TFLiteConverter` instead.

  Args:
    input_data: Input data (e.g. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  # Thin deprecated public alias; all of the work happens in the impl.
  converted = toco_convert_impl(input_data, input_tensors, output_tensors,
                                *args, **kwargs)
  return converted
483