# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple converter files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import datetime
import sys

from absl import logging
import six
from six.moves import range

import flatbuffers
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation as _error_interpolation
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph

# Keras functions used by TFLite
model_input_signature = _tflite_keras_util.model_input_signature
trace_model_call = _tflite_keras_util.trace_model_call

# Defined as per TFLite schema
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
    16: dtypes.uint32,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}


def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts a TFLite enum type (e.g. 0) to a tf type (e.g. tf.float32).

  Args:
    tflite_enum_type: TFLite enum type (e.g. 0, which corresponds to float32).

  Raises:
    ValueError: If an invalid TFLite enum type is provided.

  Returns:
    tf type (e.g. tf.float32).
  """
  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is: {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type

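# A minimal usage sketch (illustrative, not part of the original module): the
# enum values come from the TFLite schema table above.
def _example_enum_type_conversion():  # Illustrative only.
  assert _convert_tflite_enum_type_to_tf_type(0) == dtypes.float32
  assert _convert_tflite_enum_type_to_tf_type(9) == dtypes.int8
  assert get_tf_type_name(dtypes.int8) == "tf.int8"
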

def get_tf_type_name(tf_type):
  """Converts a tf.DType (e.g. tf.float32) to a str (e.g. "tf.float32")."""
  return "tf." + tf_type.name if tf_type else None


def get_tensor_name(tensor):
  """Returns the name of the input tensor.

  Args:
    tensor: tf.Tensor

  Returns:
    str
  """
  parts = six.ensure_str(tensor.name).split(":")
  if len(parts) > 2:
    raise ValueError(
        "Tensor name invalid. Expected 0 or 1 colon, got {0}".format(
            len(parts) - 1))

  # To be consistent with the tensor naming scheme in TensorFlow, we need to
  # drop the ':0' suffix for the first tensor.
  if len(parts) > 1 and parts[1] != "0":
    return tensor.name
  return parts[0]

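# Usage sketch (illustrative, assumes TF1 graph mode): the ':0' suffix of an
# op's first output is dropped, while higher output indices are kept.
def _example_get_tensor_name():  # Illustrative only.
  import tensorflow.compat.v1 as tf  # Local import, only for this sketch.
  with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[2], name="inp")
    parts = tf.split(x, 2, name="split")
    assert get_tensor_name(x) == "inp"             # 'inp:0' -> 'inp'
    assert get_tensor_name(parts[1]) == "split:1"  # ':1' is preserved
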

def get_tensors_from_tensor_names(graph, tensor_names):
  """Gets the Tensors associated with the `tensor_names` in the provided graph.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects in the same order the names are provided.

  Raises:
    ValueError:
      tensor_names contains an invalid tensor name.
  """
  # Get the list of all of the tensors.
  tensor_name_to_tensor = {}
  for op in graph.get_operations():
    for tensor in op.values():
      tensor_name_to_tensor[get_tensor_name(tensor)] = tensor

  # Get the tensors associated with tensor_names.
  tensors = []
  invalid_tensors = []
  for name in tensor_names:
    if not isinstance(name, six.string_types):
      raise ValueError("Invalid type for a tensor name in the provided graph. "
                       "Expected type for a tensor name is 'str', instead got "
                       "type '{}' for tensor name '{}'".format(
                           type(name), name))

    tensor = tensor_name_to_tensor.get(name)
    if tensor is None:
      invalid_tensors.append(name)
    else:
      tensors.append(tensor)

  # Raise a ValueError if any of the requested names are not valid tensors.
  if invalid_tensors:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(invalid_tensors)))
  return tensors

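# Usage sketch (illustrative): looking tensors up by their graph names. The
# tensor name "x" and its shape are hypothetical.
def _example_get_tensors_from_tensor_names():  # Illustrative only.
  import tensorflow.compat.v1 as tf  # Local import, only for this sketch.
  graph = tf.Graph()
  with graph.as_default():
    tf.placeholder(tf.float32, shape=[1, 4], name="x")
  (x_tensor,) = get_tensors_from_tensor_names(graph, ["x"])
  assert x_tensor.shape.as_list() == [1, 4]
  # get_tensors_from_tensor_names(graph, ["missing"]) raises ValueError.
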

def set_tensor_shapes(tensors, shapes):
  """Sets the shape of each tensor for which a shape is defined.

  Args:
    tensors: List of TensorFlow ops.Tensor objects.
    shapes: Dict mapping input tensor names to lists of integers representing
      input shapes (e.g., {"foo": [1, 16, 16, 3]}).

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
  """
  if shapes:
    tensor_names_to_tensor = {
        get_tensor_name(tensor): tensor for tensor in tensors
    }
    for name, shape in shapes.items():
      if name not in tensor_names_to_tensor:
        raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                         "map.".format(name))
      if shape is not None:
        tensor = tensor_names_to_tensor[name]
        try:
          tensor.set_shape(shape)
        except ValueError as error:
          message = ("The shape of tensor '{0}' cannot be changed from {1} to "
                     "{2}. {3}".format(name, tensor.shape, shape, str(error)))
          raise ValueError(message)

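# Usage sketch (illustrative): pinning an unknown batch dimension before
# conversion. The tensor name "foo" is hypothetical.
def _example_set_tensor_shapes():  # Illustrative only.
  import tensorflow.compat.v1 as tf  # Local import, only for this sketch.
  with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 16, 16, 3], name="foo")
    set_tensor_shapes([x], {"foo": [1, 16, 16, 3]})
    assert x.shape.as_list() == [1, 16, 16, 3]
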

def get_grappler_config(optimizers_list):
  """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

  Args:
    optimizers_list: List of strings that represents the list of optimizers.

  Returns:
    tf.ConfigProto.
  """
  config = _config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  for optimizer in optimizers_list:
    rewrite_options.optimizers.append(optimizer)
  return config


def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
  """Applies standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  signature = _meta_graph_pb2.SignatureDef()
  for array in input_arrays:
    signature.inputs[array.name].name = array.name
    signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  for array in output_arrays:
    signature.outputs[array.name].name = array.name
    signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # We need to add a collection called 'train_op' so that Grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  return tf_optimizer.OptimizeGraph(config, meta_graph)

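# Usage sketch (illustrative): running the "function" optimizer to inline
# functions, mirroring what freeze_graph below does.
def _example_run_graph_optimizations():  # Illustrative only.
  import tensorflow.compat.v1 as tf  # Local import, only for this sketch.
  with tf.Graph().as_default() as graph:
    inp = tf.placeholder(tf.float32, shape=[1], name="inp")
    out = tf.identity(inp, name="out")
  config = get_grappler_config(["function"])
  optimized = run_graph_optimizations(
      graph.as_graph_def(), [inp], [out], config, graph=graph)
  assert optimized.node  # A new, optimized GraphDef.
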

def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  if is_frozen_graph(sess):
    raise ValueError("Op hint conversion requires an unfrozen graph, but the "
                     "provided graph is already frozen.")
  output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
  graph_def = tf_graph_util.convert_variables_to_constants(
      sess, graph_def, output_arrays + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  return graph_def


def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  Runs a Grappler pass to inline any functions in the graph, then freezes the
  graph by converting its Variables to constants. If the graph is already
  frozen, the existing GraphDef is returned. If OpHints are present, they are
  converted to stubs.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Aside from inlining simple functions, Grappler will also try to lower
  # while loops into a switch-merge representation, which is undesirable for
  # OpHints, so we disable that lowering here to prevent Grappler from doing
  # so.
  graph_def = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  config = get_grappler_config(["function"])
  graph_def = run_graph_optimizations(
      graph_def, input_tensors, output_tensors, config, graph=sess.graph)

  # If op hints are present, just convert them.
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                        hinted_outputs_nodes)

  if not is_frozen_graph(sess):
    output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
    return tf_graph_util.convert_variables_to_constants(sess, graph_def,
                                                        output_node_names)
  else:
    return sess.graph_def

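# Usage sketch (illustrative, assumes TF1 graph mode with ref variables): a
# one-variable graph is frozen into constants.
def _example_freeze_graph():  # Illustrative only.
  import tensorflow.compat.v1 as tf  # Local import, only for this sketch.
  with tf.Graph().as_default():
    inp = tf.placeholder(tf.float32, shape=[1], name="input")
    weights = tf.Variable([2.0], name="weights", use_resource=False)
    out = tf.multiply(inp, weights, name="output")
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      frozen_graph_def = freeze_graph(sess, [inp], [out])
  # The variable is now baked in as a Const node.
  assert any(node.op == "Const" for node in frozen_graph_def.node)
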

def is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if six.ensure_str(op.type).startswith("Variable") or six.ensure_str(
        op.type).endswith("VariableOp"):
      return False
  return True


def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Creates `GraphDebugInfo` for the given `original_nodes`."""
    if not original_graph:
      return None
    # For the given nodes, gets all the op definitions in the original graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
        else:
          sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
          if isinstance(sub_func, function._EagerDefinedFunction):  # pylint: disable=protected-access
            useful_ops.append(
                (func, sub_func.graph.get_operation_by_name(name)))
          else:
            sys.stderr.write(
                "Use '@tf.function' or '@defun' to decorate the function.\n")
            continue
      except KeyError:
        # New node created by the graph optimizer. No stack trace from source
        # code is available.
        continue
    # Convert all the op definitions to stack traces in terms of
    # GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f


def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Creates `GraphDebugInfo` for the given `original_nodes`."""
    if not saved_debug_info:
      return None

    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    # All the files are copied over, so the file indices are unchanged.
    output_debug_info.files[:] = saved_debug_info.files
    # We only copy over the debug info for the input nodes.
    for func, node in original_nodes:
      debug_key = node + "@" + func
      output_debug_info.traces[debug_key].CopyFrom(
          saved_debug_info.traces[debug_key])
    return output_debug_info

  return f

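# Usage sketch (illustrative): trace keys take the form "<node>@<func>", with
# an empty func name for ops in the main graph, matching `debug_key` above.
def _example_build_debug_info_func():  # Illustrative only.
  import tensorflow.compat.v1 as tf  # Local import, only for this sketch.
  with tf.Graph().as_default() as graph:
    tf.constant(1.0, name="c")
  debug_info = build_debug_info_func(graph)([("", "c")])
  assert "c@" in debug_info.traces
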

def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Returns the debug info for the original nodes in the `converted_graph`.

  Args:
    nodes_to_debug_info_func: The method to collect the op debug info for the
      nodes.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    `GraphDebugInfo` for all the original nodes in `converted_graph`.
  """
  if not nodes_to_debug_info_func:
    return None

  # Collect all the debug info nodes from the converted_graph.
  original_nodes = set()
  for node in converted_graph.node:
    debug_nodes = node.experimental_debug_info.original_node_names
    debug_funcs = node.experimental_debug_info.original_func_names
    # If `original_node_names` is empty, use the node name directly.
    if not debug_nodes:
      original_nodes.add(("", node.name))
    else:
      for i in range(len(debug_nodes)):
        debug_func = "" if i >= len(debug_funcs) else debug_funcs[i]
        original_nodes.add((debug_func, debug_nodes[i]))

  # Convert the nodes to the debug info proto object.
  return nodes_to_debug_info_func(original_nodes)


def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes.
    include_guard: Name to use for the include guard macro definition.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    Text that can be compiled as a C source file to link in the data as a
    literal array of values.
    Text that can be used as a C header file to reference the literal array.
  """

  starting_pad = "   "
  array_lines = []
  array_line = starting_pad
  for value in bytearray(data):
    if (len(array_line) + 4) > max_line_width:
      array_lines.append(array_line + "\n")
      array_line = starting_pad
    array_line += " 0x%02x," % (value)
  if len(array_line) > len(starting_pad):
    array_lines.append(array_line + "\n")
  array_values = "".join(array_lines)

  if include_guard is None:
    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"

  if include_path is not None:
    include_line = "#include \"{include_path}\"\n".format(
        include_path=include_path)
  else:
    include_line = ""

  if use_tensorflow_license:
    license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
  else:
    license_text = ""

  source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif

const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""

  source_text = source_template.format(
      array_name=array_name,
      array_length=len(data),
      array_values=array_values,
      license_text=license_text,
      include_line=include_line)

  header_template = """
{license_text}

// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

#ifndef {include_guard}
#define {include_guard}

extern const unsigned char {array_name}[];
extern const int {array_name}_len;

#endif  // {include_guard}
"""

  header_text = header_template.format(
      array_name=array_name,
      include_guard=include_guard,
      license_text=license_text)

  return source_text, header_text

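# Usage sketch (illustrative): embedding bytes as a C array and writing the
# generated files. The array name and file paths here are hypothetical.
def _example_convert_bytes_to_c_source():  # Illustrative only.
  source_text, header_text = convert_bytes_to_c_source(
      data=b"\x00\x01\x02", array_name="g_model",
      include_path="g_model_data.h")
  with open("/tmp/g_model_data.cc", "w") as source_file:
    source_file.write(source_text)
  with open("/tmp/g_model_data.h", "w") as header_file:
    header_file.write(header_text)
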

def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = schema_fb.ModelT.InitFromObj(model_object)
  model_object = copy.deepcopy(model_object)
  model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
  return model_object


def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Initial size of the buffer, which will grow automatically if needed
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())

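# Round-trip sketch (illustrative): the two helpers above are inverses, so a
# model can be parsed, edited as a ModelT object, and re-serialized. The
# `model_bytes` argument is a hypothetical TFLite flatbuffer.
def _example_model_round_trip(model_bytes):  # Illustrative only.
  model_object = _convert_model_from_bytearray_to_object(model_bytes)
  model_object.description = "edited"  # Any ModelT field can be modified.
  return _convert_model_from_object_to_bytearray(model_object)
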

def get_quantize_opcode_idx(model):
  """Returns the indices of the quantize opcodes in the model."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def get_dequantize_opcode_idx(model):
  """Returns the indices of the dequantize opcodes in the model."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
  """Updates the tensors in the SignatureDef's TensorMaps."""
  for i in range(len(tensor_maps)):
    if tensor_maps[i].tensorIndex in map_old_to_new_tensors:
      tensor_maps[i].tensorIndex = (
          map_old_to_new_tensors[tensor_maps[i].tensorIndex])


def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Removes tensors from the model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model.subgraphs)))
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices: %s", remove_tensors_idxs)
  # An optimized check to validate that "remove_tensors_idxs" (e.g. [4,5,6]) is
  # an exact, ordered suffix of the "tensors" indices (e.g. [0,1,2,3,4,5,6]).
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices.
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", str(d_old_to_new_tensors))

    # Update tensor indices referenced throughout the model.
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)

    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    if model.signatureDefs:
      signature_def = model.signatureDefs[0]
      _update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors)
      _update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors)
    # Delete the tensors marked for removal.
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")

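# Worked sketch (illustrative) of the remapping above: removing tensor index 1
# from a four-tensor list shifts every later index left by one.
def _example_tensor_index_remap():  # Illustrative only.
  remove_tensors_idxs, num_tensors = {1}, 4
  d_old_to_new, left_shift_by = {}, 0
  for idx in range(num_tensors):
    if idx in remove_tensors_idxs:
      left_shift_by += 1
    else:
      d_old_to_new[idx] = idx - left_shift_by
  assert d_old_to_new == {0: 0, 2: 1, 3: 2}
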

def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modifies the model input type."""

  if inference_input_type == dtypes.float32:
    return

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  if operators and not quant_opcode_idxs:
    for input_idx in subgraph.inputs:
      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
      if input_type == dtypes.float32:
        raise ValueError("Model input is not quantized.")
    # None of the inputs are float32, so they must already be int16, int8,
    # or bool.
    return

  # Validate that the model input is quantized.
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize the model input.
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float.
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_input_type:
          continue
        else:
          raise ValueError(
              "Initial model input type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type.")

  # Modify the model input type.
  if inference_input_type == dtypes.uint8:
    # Change each quant op (float to int8) into a quant op (uint8 to int8).
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the float inputs and the quant operators.
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for i in range(len(signature_def.inputs)):
          if signature_def.inputs[i].tensorIndex == op.inputs[0]:
            signature_def.inputs[i].tensorIndex = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            get_tf_type_name(inference_input_type)))


def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modifies the model output type."""

  if inference_output_type == dtypes.float32:
    return

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators.
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)
  if operators and not dequant_opcode_idxs:
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # None of the outputs are float32, so they must already be int16, int8,
    # or bool.
    return

  # Validate that the model output is dequantized.
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize the model output.
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # If found, validate that the operator's output type is float.
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          continue
        else:
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type.")

  # Modify the model output type.
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator.
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exists.
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change each dequant op (int8 to float) into a quant op (int8 to uint8).
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the float outputs and the dequant operators.
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for i in range(len(signature_def.outputs)):
          if signature_def.outputs[i].tensorIndex == op.outputs[0]:
            signature_def.outputs[i].tensorIndex = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_output_type` value {}.".format(
            get_tf_type_name(inference_output_type)))


def _remove_redundant_quantize_ops(model):
  """Finds back-to-back quantize ops and removes the first of each pair."""
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize and dequantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)

  # Find all redundant quant tensors.
  all_quant_ops = []
  redundant_quant_tensors = {}
  output_dequant_tensors = {}
  for op in operators:
    if op.opcodeIndex in quant_opcode_idxs:
      all_quant_ops.append(op)
      input_tensor = tensors[op.inputs[0]]
      output_tensor = tensors[op.outputs[0]]
      input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
      output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
      # This is a requantize op, so write down its input tensor index.
      if input_type != dtypes.float32 and output_type != dtypes.float32:
        redundant_quant_tensors[op.inputs[0]] = op
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      output_dequant_tensors[op.inputs[0]] = op

  # Remove all the quant ops which produce the redundant quant tensors.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in redundant_quant_tensors:
      requantize_op = redundant_quant_tensors[output_tensor_idx]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      # Reset the input of the requantize op to the float input.
      requantize_op.inputs[0] = op.inputs[0]
      operators.remove(op)

  # Remove all the quant ops which connect to the output dequant op.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in output_dequant_tensors:
      dequant_op = output_dequant_tensors[output_tensor_idx]
      subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == dequant_op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      operators.remove(op)
      operators.remove(dequant_op)


def modify_model_io_type(
    model, inference_input_type=dtypes.float32,
    inference_output_type=dtypes.float32):
  """Modifies the input/output type of a tflite model.

  Args:
    model: A tflite model.
    inference_input_type: tf.DType representing the modified input type.
      (default tf.float32. If the model input is int8 quantized, it must be in
      {tf.float32, tf.int8, tf.uint8}; if the model input is int16 quantized,
      it must be in {tf.float32, tf.int16}; otherwise it must be tf.float32.)
    inference_output_type: tf.DType representing the modified output type.
      (default tf.float32. If the model output is int8 dequantized, it must be
      in {tf.float32, tf.int8, tf.uint8}; if the model output is int16
      dequantized, it must be in {tf.float32, tf.int16}; otherwise it must be
      tf.float32.)

  Returns:
    A tflite model with a modified input/output type.

  Raises:
    ValueError: If `inference_input_type`/`inference_output_type` is
      unsupported, or if a supported integer type is specified for a model
      whose input/output is not quantized/dequantized.
    RuntimeError: If the modification was unsuccessful.
  """
  if (inference_input_type == dtypes.float32 and
      inference_output_type == dtypes.float32):
    return model

  model_object = _convert_model_from_bytearray_to_object(model)

  if len(model_object.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model_object.subgraphs)))

  _modify_model_input_type(model_object, inference_input_type)

  _modify_model_output_type(model_object, inference_output_type)

  _remove_redundant_quantize_ops(model_object)

  return _convert_model_from_object_to_bytearray(model_object)

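# Usage sketch (illustrative): converting the I/O of a hypothetical
# int8-quantized model so it takes and returns uint8 tensors.
def _example_modify_model_io_type(quantized_model_bytes):  # Illustrative only.
  return modify_model_io_type(
      quantized_model_bytes,
      inference_input_type=dtypes.uint8,
      inference_output_type=dtypes.uint8)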