# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple converter files."""

import copy
import datetime
import sys

from absl import logging

import flatbuffers
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation as _error_interpolation
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph

# The field name of conversion metadata in the flatbuffer file.
CONVERSION_METADATA_FIELD_NAME = "CONVERSION_METADATA"

# Keras functions used by TFLite.
model_input_signature = _tflite_keras_util.model_input_signature
trace_model_call = _tflite_keras_util.trace_model_call

# Jax functions used by TFLite.
# pylint: disable=g-import-not-at-top
# pylint: disable=unused-import
try:
  from jax import xla_computation as _xla_computation
except ImportError:
  _xla_computation = None
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import

# Defined as per the TFLite schema.
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
    16: dtypes.uint32,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}


def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts a tflite enum type (eg: 0) to a tf type (eg: tf.float32).

  Args:
    tflite_enum_type: tflite enum type (eg: 0, which corresponds to float32).

  Raises:
    ValueError: If an invalid tflite enum type is provided.

  Returns:
    tf type (eg: tf.float32)
  """
  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is: {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type


def get_tf_type_name(tf_type):
  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
  return "tf." + tf_type.name if tf_type else None


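# Example (illustrative sketch): round-tripping a TFLite schema enum through
# the helpers above. Enum value 0 is FLOAT32 in the TFLite schema.
#
#   tf_type = _convert_tflite_enum_type_to_tf_type(0)  # tf.float32
#   get_tf_type_name(tf_type)                          # "tf.float32"

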
def get_tensor_name(tensor):
  """Returns the name of the input tensor.

  Args:
    tensor: tf.Tensor

  Returns:
    str
  """
  parts = tensor.name.split(":")
  if len(parts) > 2:
    raise ValueError("Tensor name invalid. Expected 0 or 1 colons, got {0}"
                     .format(len(parts) - 1))

  # To be consistent with the tensor naming scheme in TensorFlow, we need to
  # drop the ':0' suffix for the first tensor.
  if len(parts) > 1 and parts[1] != "0":
    return tensor.name
  return parts[0]


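# Example (illustrative sketch; assumes TF1-style graph mode):
#
#   import tensorflow as tf
#   x = tf.compat.v1.placeholder(tf.float32, name="foo")
#   get_tensor_name(x)    # "foo"  (the ":0" suffix is dropped)

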
def get_tensors_from_tensor_names(graph, tensor_names):
  """Gets the Tensors associated with the `tensor_names` in the provided graph.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects in the same order the names are provided.

  Raises:
    ValueError:
      tensor_names contains an invalid tensor name.
  """
  # Get the list of all of the tensors.
  tensor_name_to_tensor = {}
  for op in graph.get_operations():
    for tensor in op.values():
      tensor_name_to_tensor[get_tensor_name(tensor)] = tensor

  # Get the tensors associated with tensor_names.
  tensors = []
  invalid_tensors = []
  for name in tensor_names:
    if not isinstance(name, str):
      raise ValueError("Invalid type for a tensor name in the provided graph. "
                       "Expected type for a tensor name is 'str', instead got "
                       "type '{}' for tensor name '{}'".format(
                           type(name), name))

    tensor = tensor_name_to_tensor.get(name)
    if tensor is None:
      invalid_tensors.append(name)
    else:
      tensors.append(tensor)

  # Throw ValueError if any user input names are not valid tensors.
  if invalid_tensors:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(invalid_tensors)))
  return tensors


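# Example (illustrative sketch; assumes TF1-style graph mode):
#
#   import tensorflow as tf
#   with tf.compat.v1.Graph().as_default() as g:
#     tf.compat.v1.placeholder(tf.float32, [None, 16, 16, 3], name="input")
#     tensors = get_tensors_from_tensor_names(g, ["input"])

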
def set_tensor_shapes(tensors, shapes):
  """Sets the Tensor shape for each tensor if the shape is defined.

  Args:
    tensors: List of TensorFlow ops.Tensor objects.
    shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
  """
  if shapes:
    tensor_names_to_tensor = {
        get_tensor_name(tensor): tensor for tensor in tensors
    }
    for name, shape in shapes.items():
      if name not in tensor_names_to_tensor:
        raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                         "map.".format(name))
      if shape is not None:
        tensor = tensor_names_to_tensor[name]
        try:
          tensor.set_shape(shape)
        except ValueError as error:
          message = ("The shape of tensor '{0}' cannot be changed from {1} to "
                     "{2}. {3}".format(name, tensor.shape, shape, str(error)))
          raise ValueError(message)


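# Example (illustrative sketch, reusing the `tensors` list from the sketch
# above): pinning a dynamic batch dimension to a concrete value.
#
#   set_tensor_shapes(tensors, {"input": [1, 16, 16, 3]})

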
def get_grappler_config(optimizers_list):
  """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

  Args:
    optimizers_list: List of strings representing the Grappler optimizers to
      apply.

  Returns:
    tf.ConfigProto.
  """
  config = _config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  for optimizer in optimizers_list:
    rewrite_options.optimizers.append(optimizer)
  return config


def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
  """Apply standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  signature = _meta_graph_pb2.SignatureDef()
  for array in input_arrays:
    signature.inputs[array.name].name = array.name
    signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  for array in output_arrays:
    signature.outputs[array.name].name = array.name
    signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # We need to add a collection called 'train_op' so that Grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  return tf_optimizer.OptimizeGraph(config, meta_graph)


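# Example (illustrative sketch; assumes TF1-style graph mode): run the
# "constfold" Grappler pass over a tiny graph.
#
#   import tensorflow as tf
#   with tf.compat.v1.Graph().as_default() as g:
#     x = tf.compat.v1.placeholder(tf.float32, [1, 4], name="x")
#     y = tf.identity(x * 2.0, name="y")
#     config = get_grappler_config(["constfold"])
#     optimized_graph_def = run_graph_optimizations(
#         g.as_graph_def(), [x], [y], config, graph=g)

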
def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  if is_frozen_graph(sess):
    raise ValueError("Op hint conversion requires an unfrozen graph.")
  output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
  graph_def = tf_graph_util.convert_variables_to_constants(
      sess, graph_def, output_arrays + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  return graph_def


def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  Runs a Grappler pass to inline functions, then freezes the graph if it
  contains Variables; otherwise the existing GraphDef is returned.
  If OpHints are present, it will try to convert the OpHint graph.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Aside from inlining any simple functions, Grappler will also try to lower
  # while loops into a switch/merge representation, which is undesired for
  # OpHints, so we simply remove those attributes to prevent Grappler from
  # doing so.
  graph_def = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  config = get_grappler_config(["function"])
  graph_def = run_graph_optimizations(
      graph_def, input_tensors, output_tensors, config, graph=sess.graph)

  # If op hints are present, just convert them.
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                        hinted_outputs_nodes)

  if not is_frozen_graph(sess):
    output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
    return tf_graph_util.convert_variables_to_constants(sess, graph_def,
                                                        output_node_names)
  else:
    return sess.graph_def


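# Example (illustrative sketch; assumes a TF1-style session holding a model
# with Variables):
#
#   import tensorflow as tf
#   with tf.compat.v1.Session() as sess:
#     inp = tf.compat.v1.placeholder(tf.float32, [1, 4], name="inp")
#     w = tf.Variable(tf.ones([4, 2]))
#     out = tf.matmul(inp, w, name="out")
#     sess.run(tf.compat.v1.global_variables_initializer())
#     frozen_graph_def = freeze_graph(sess, [inp], [out])

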
def is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not
  frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True


def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not original_graph:
      return None
    # For the given nodes, gets all the op definitions in the original graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
        else:
          sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
          if isinstance(sub_func, function._EagerDefinedFunction):  # pylint: disable=protected-access
            useful_ops.append(
                (func, sub_func.graph.get_operation_by_name(name)))
          else:
            sys.stderr.write(
                "Use '@tf.function' or '@defun' to decorate the function.\n")
            continue
      except KeyError:
        # New node created by graph optimizer. No stack trace from source code.
        continue
    # Convert all the op definitions to stack traces in terms of GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f


def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not saved_debug_info:
      return None

    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    # All the files are copied over, so the indices won't change.
    output_debug_info.files[:] = saved_debug_info.files
    # We only copy over the debug info for the input nodes.
    for func, node in original_nodes:
      debug_key = node + "@" + func
      output_debug_info.traces[debug_key].CopyFrom(
          saved_debug_info.traces[debug_key])
    return output_debug_info

  return f


def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Returns the debug info for the original nodes in the `converted_graph`.

  Args:
    nodes_to_debug_info_func: The method to collect the op debug info for the
      nodes.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    `GraphDebugInfo` for all the original nodes in `converted_graph`.
  """
  if not nodes_to_debug_info_func:
    return None

  # Collect all the debug info nodes from the converted_graph.
  original_nodes = set()
  for node in converted_graph.node:
    debug_nodes = node.experimental_debug_info.original_node_names
    debug_funcs = node.experimental_debug_info.original_func_names
    # If `original_node_names` is empty, use the node name directly.
    if not debug_nodes:
      original_nodes.add(("", node.name))
    else:
      for i in range(len(debug_nodes)):
        debug_func = "" if i >= len(debug_funcs) else debug_funcs[i]
        original_nodes.add((debug_func, debug_nodes[i]))

  # Convert the nodes to the debug info proto object.
  return nodes_to_debug_info_func(original_nodes)


def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes.
    include_guard: Name to use for the include guard macro definition.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    A tuple of two strings: text that can be compiled as a C source file to
    link in the data as a literal array of values, and text that can be used
    as a C header file to reference the literal array.
  """

  starting_pad = "   "
  array_lines = []
  array_line = starting_pad
  for value in bytearray(data):
    if (len(array_line) + 4) > max_line_width:
      array_lines.append(array_line + "\n")
      array_line = starting_pad
    array_line += " 0x%02x," % (value,)
  if len(array_line) > len(starting_pad):
    array_lines.append(array_line + "\n")
  array_values = "".join(array_lines)

  if include_guard is None:
    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"

  if include_path is not None:
    include_line = "#include \"{include_path}\"\n".format(
        include_path=include_path)
  else:
    include_line = ""

  if use_tensorflow_license:
    license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
  else:
    license_text = ""

  source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif

const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""

  source_text = source_template.format(
      array_name=array_name,
      array_length=len(data),
      array_values=array_values,
      license_text=license_text,
      include_line=include_line)

  header_template = """
{license_text}

// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

#ifndef {include_guard}
#define {include_guard}

extern const unsigned char {array_name}[];
extern const int {array_name}_len;

#endif  // {include_guard}
"""

  header_text = header_template.format(
      array_name=array_name,
      include_guard=include_guard,
      license_text=license_text)

  return source_text, header_text


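# Example (illustrative sketch): embedding a small byte blob as a C array.
# The include path is hypothetical.
#
#   source, header = convert_bytes_to_c_source(
#       data=b"\x01\x02\x03", array_name="g_model",
#       include_path="g_model.h")
#   # `source` defines `const unsigned char g_model[]` and `g_model_len`;
#   # `header` declares them behind an include guard.

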
def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = schema_fb.ModelT.InitFromObj(model_object)
  model_object = copy.deepcopy(model_object)
  return model_object


def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Initial size of the buffer, which will grow automatically if needed.
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())


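# Example (illustrative sketch; assumes `tflite_bytes` holds the contents of a
# valid .tflite file): round-tripping a model through the mutable object form.
#
#   model_obj = _convert_model_from_bytearray_to_object(tflite_bytes)
#   # ... mutate model_obj (e.g., edit tensors or metadata) ...
#   new_tflite_bytes = _convert_model_from_object_to_bytearray(model_obj)

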
def get_quantize_opcode_idx(model):
  """Returns the indices of all QUANTIZE operator codes."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def get_dequantize_opcode_idx(model):
  """Returns the indices of all DEQUANTIZE operator codes."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
  """Update the tensors in the SignatureDef's TensorMaps."""
  for i in range(len(tensor_maps)):
    if tensor_maps[i].tensorIndex in map_old_to_new_tensors:
      tensor_maps[i].tensorIndex = (
          map_old_to_new_tensors[tensor_maps[i].tensorIndex])


def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Remove tensors from model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    logging.info("Skipping the removal of dangling tensors since the model "
                 "has multiple subgraphs and tensors can be shared across "
                 "subgraphs")
    return
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices: %s", remove_tensors_idxs)
  # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
  # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices.
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors)
    # Update tensor indices referenced throughout the model.
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    if model.signatureDefs:
      signature_def = model.signatureDefs[0]
      _update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors)
      _update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors)
    # Delete the tensors.
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")


def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modify model input type."""
  if inference_input_type == dtypes.float32:
    return

  if not model.signatureDefs:
    _modify_model_input_type_per_subgraph(model, 0, -1, inference_input_type)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _modify_model_input_type_per_subgraph(model, signature_def.subgraphIndex,
                                          signature_index, inference_input_type)


def _modify_model_input_type_per_subgraph(model, subgraph_index,
                                          signature_index,
                                          inference_input_type):
  """Modify model input type per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  if operators and not quant_opcode_idxs:
    for input_idx in subgraph.inputs:
      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
      if input_type == dtypes.float32:
        raise ValueError("Model input is not dequantized.")
    # None of the inputs are float32, so they must be int16, int8, or bool.
    return

  # Validate that the model input is quantized.
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input.
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float.
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_input_type:
          continue
        else:
          raise ValueError(
              "Initial model input type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type.")

  # Modify model input type.
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8).
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator.
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.inputs)):
          if signature_def.inputs[i].tensorIndex == op.inputs[0]:
            signature_def.inputs[i].tensorIndex = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            get_tf_type_name(inference_input_type)))


def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modify model output type."""
  if inference_output_type == dtypes.float32:
    return

  if not model.signatureDefs:
    _modify_model_output_type_per_subgraph(model, 0, -1, inference_output_type)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _modify_model_output_type_per_subgraph(model, signature_def.subgraphIndex,
                                           signature_index,
                                           inference_output_type)


def _modify_model_output_type_per_subgraph(model, subgraph_index,
                                           signature_index,
                                           inference_output_type):
  """Modify model output type per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators.
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)
  if operators and not dequant_opcode_idxs:
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # None of the outputs are float32, so they must be int16, int8, or bool.
    return

  # Validate that the model output is dequantized.
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize model output.
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # If found, validate that the operator's output type is float.
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          continue
        else:
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type.")

  # Modify model output type.
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator.
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exist.
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change dequant op (int8 to float) to quant op (int8 to uint8).
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the outputs and the dequant operator.
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.outputs)):
          if signature_def.outputs[i].tensorIndex == op.outputs[0]:
            signature_def.outputs[i].tensorIndex = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_output_type` value {}.".format(
            get_tf_type_name(inference_output_type)))


def _remove_redundant_quantize_ops(model):
  """Finds back-to-back quantize ops and removes the first quantize op."""
  if not model.signatureDefs:
    _remove_redundant_quantize_ops_per_subgraph(model, 0, -1)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _remove_redundant_quantize_ops_per_subgraph(model,
                                                signature_def.subgraphIndex,
                                                signature_index)


def _remove_redundant_quantize_ops_per_subgraph(model, subgraph_index,
                                                signature_index):
  """Remove redundant quantize ops per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize and dequantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)

  # Find all redundant quant tensors.
  all_quant_ops = []
  redundant_quant_tensors = {}
  output_dequant_tensors = {}
  for op in operators:
    if op.opcodeIndex in quant_opcode_idxs:
      all_quant_ops.append(op)
      input_tensor = tensors[op.inputs[0]]
      output_tensor = tensors[op.outputs[0]]
      input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
      output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
      # This is a requantize op, so write down its input tensor index.
      if input_type != dtypes.float32 and output_type != dtypes.float32:
        redundant_quant_tensors[op.inputs[0]] = op
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      output_dequant_tensors[op.inputs[0]] = op

  # Remove all the quant ops which produce the redundant quant tensors.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in redundant_quant_tensors:
      requantize_op = redundant_quant_tensors[output_tensor_idx]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      # Reset the input of the requantize op to the float input.
      requantize_op.inputs[0] = op.inputs[0]
      operators.remove(op)

  # Remove all the quant ops which connect to the output dequant op.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in output_dequant_tensors:
      dequant_op = output_dequant_tensors[output_tensor_idx]
      subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for output in signature_def.outputs:
          if output.tensorIndex == dequant_op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      operators.remove(op)
      operators.remove(dequant_op)


def modify_model_io_type(
    model, inference_input_type=dtypes.float32,
    inference_output_type=dtypes.float32):
  """Modify the input/output type of a tflite model.

  Args:
    model: A tflite model.
    inference_input_type: tf.DType representing the modified input type.
      (default tf.float32. If the model input is int8 quantized, it must be in
      {tf.float32, tf.int8, tf.uint8}; if the model input is int16 quantized,
      it must be in {tf.float32, tf.int16}; otherwise it must be tf.float32.)
    inference_output_type: tf.DType representing the modified output type.
      (default tf.float32. If the model output is int8 dequantized, it must be
      in {tf.float32, tf.int8, tf.uint8}; if the model output is int16
      dequantized, it must be in {tf.float32, tf.int16}; otherwise it must be
      tf.float32.)

  Returns:
    A tflite model with modified input/output type.

  Raises:
    ValueError: If `inference_input_type`/`inference_output_type` is unsupported
      or a supported integer type is specified for a model whose input/output is
      not quantized/dequantized.
    RuntimeError: If the modification was unsuccessful.
  """
  if (inference_input_type == dtypes.float32 and
      inference_output_type == dtypes.float32):
    return model

  model_object = _convert_model_from_bytearray_to_object(model)

  _modify_model_input_type(model_object, inference_input_type)

  _modify_model_output_type(model_object, inference_output_type)

  _remove_redundant_quantize_ops(model_object)

  return _convert_model_from_object_to_bytearray(model_object)


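# Example (illustrative sketch; assumes `tflite_bytes` is a fully int8
# quantized model with float32 input/output): switch the model I/O to uint8.
#
#   new_model_bytes = modify_model_io_type(
#       tflite_bytes,
#       inference_input_type=dtypes.uint8,
#       inference_output_type=dtypes.uint8)

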
def get_sparsity_modes(model_object):
  """Get sparsity modes used in a tflite model.

  The sparsity modes are listed in the conversion_metadata.fbs file.

  Args:
    model_object: A tflite model in object form.

  Returns:
    The list of sparsity modes used in the model.
  """
  if not model_object or not model_object.metadata:
    return []

  result = set()
  for subgraph in model_object.subgraphs:
    for tensor in subgraph.tensors:
      if not tensor.sparsity:
        continue

      # The block map is the list of indexes where the block size is larger
      # than 1, so an empty block map means random sparsity.
      if not tensor.sparsity.blockMap:
        result.add(
            conversion_metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY)
      else:
        result.add(
            conversion_metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY)

  return list(result)


def populate_conversion_metadata(model_object, metadata):
  """Add or update conversion metadata in a tflite model.

  Args:
    model_object: A tflite model in object form.
    metadata: The conversion metadata.

  Returns:
    A tflite model object with embedded conversion metadata.
  """
  try:
    metadata_builder = flatbuffers.Builder(0)
    metadata_builder.Finish(metadata.Pack(metadata_builder))
    buffer_field = schema_fb.BufferT()
    buffer_field.data = metadata_builder.Output()

    if not model_object.metadata:
      model_object.metadata = []
    else:
      # Check if metadata has already been populated.
      for meta in model_object.metadata:
        if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
          model_object.buffers[meta.buffer] = buffer_field
          return model_object

    if not model_object.buffers:
      model_object.buffers = []
    model_object.buffers.append(buffer_field)
    # Creates a new metadata field.
    metadata_field = schema_fb.MetadataT()
    metadata_field.name = CONVERSION_METADATA_FIELD_NAME
    metadata_field.buffer = len(model_object.buffers) - 1
    model_object.metadata.append(metadata_field)

    return model_object
  except Exception:  # pylint: disable=broad-except
    return model_object


def get_conversion_metadata(model_buffer):
  """Read conversion metadata from a tflite model.

  Args:
    model_buffer: A tflite model.

  Returns:
    The conversion metadata or None if it is not populated.
  """
  model_object = flatbuffer_utils.convert_bytearray_to_object(model_buffer)
  if not model_object or not model_object.metadata:
    return None

  for meta in model_object.metadata:
    if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
      metadata_buf = model_object.buffers[meta.buffer].data.tobytes()
      return conversion_metadata_fb.ConversionMetadataT.InitFromObj(
          conversion_metadata_fb.ConversionMetadata.GetRootAsConversionMetadata(
              metadata_buf, 0))

  return None
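

# Example (illustrative sketch; the path is hypothetical): inspecting the
# conversion metadata embedded in a converted model.
#
#   with open("/tmp/model.tflite", "rb") as f:
#     metadata = get_conversion_metadata(f.read())
#   if metadata is not None:
#     # Assumes the `environment` table defined in conversion_metadata.fbs.
#     print(metadata.environment.tensorflowVersion)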