# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple converter files."""

import copy
import datetime
import sys

from absl import logging

import flatbuffers
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation as _error_interpolation
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph

# The field name of conversion metadata in the flatbuffer file.
CONVERSION_METADATA_FIELD_NAME = "CONVERSION_METADATA"

# Keras functions used by TFLite.
model_input_signature = _tflite_keras_util.model_input_signature
trace_model_call = _tflite_keras_util.trace_model_call

# Jax functions used by TFLite.
# pylint: disable=g-import-not-at-top
# pylint: disable=unused-import
try:
  from jax import xla_computation as _xla_computation
except ImportError:
  _xla_computation = None
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import

# Defined as per the TFLite schema.
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
    16: dtypes.uint32,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}


def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts a TFLite enum type (e.g. 0) to a tf type (e.g. tf.float32).

  Args:
    tflite_enum_type: TFLite enum type (e.g. 0, which corresponds to float32).

  Raises:
    ValueError: If an invalid TFLite enum type is provided.

  Returns:
    tf type (e.g. tf.float32).
  """
  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is: {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type


def get_tf_type_name(tf_type):
  """Converts a tf.DType (e.g. tf.float32) to a str (e.g. "tf.float32")."""
  return "tf." + tf_type.name if tf_type else None
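

# Example sketch (illustration only, not used by the converter): how the
# enum/type helpers above behave. The enum value 0 comes from
# _MAP_TFLITE_ENUM_TO_TF_TYPES; the helper itself is hypothetical.
def _example_enum_type_round_trip():
  """Sketch: TFLite enum 0 maps to tf.float32, printed as "tf.float32"."""
  tf_type = _convert_tflite_enum_type_to_tf_type(0)  # -> dtypes.float32
  return get_tf_type_name(tf_type)  # -> "tf.float32"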


def get_tensor_name(tensor):
  """Returns the name of the input tensor.

  Args:
    tensor: tf.Tensor

  Returns:
    str
  """
  parts = tensor.name.split(":")
  if len(parts) > 2:
    raise ValueError(
        "Tensor name invalid. Expected 0 or 1 colon, got {0}".format(
            len(parts) - 1))

  # To be consistent with the tensor naming scheme in tensorflow, we need to
  # drop the ':0' suffix for the first tensor.
  if len(parts) > 1 and parts[1] != "0":
    return tensor.name
  return parts[0]


def get_tensors_from_tensor_names(graph, tensor_names):
  """Gets the Tensors associated with the `tensor_names` in the provided graph.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects in the same order the names are provided.

  Raises:
    ValueError:
      tensor_names contains an invalid tensor name.
  """
  # Get the list of all of the tensors.
  tensor_name_to_tensor = {}
  for op in graph.get_operations():
    for tensor in op.values():
      tensor_name_to_tensor[get_tensor_name(tensor)] = tensor

  # Get the tensors associated with tensor_names.
  tensors = []
  invalid_tensors = []
  for name in tensor_names:
    if not isinstance(name, str):
      raise ValueError("Invalid type for a tensor name in the provided graph. "
                       "Expected type for a tensor name is 'str', instead got "
                       "type '{}' for tensor name '{}'".format(
                           type(name), name))

    tensor = tensor_name_to_tensor.get(name)
    if tensor is None:
      invalid_tensors.append(name)
    else:
      tensors.append(tensor)

  # Throw ValueError if any user input names are not valid tensors.
  if invalid_tensors:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(invalid_tensors)))
  return tensors


def set_tensor_shapes(tensors, shapes):
  """Sets the Tensor shape for each tensor if the shape is defined.

  Args:
    tensors: List of TensorFlow ops.Tensor objects.
    shapes: Dict mapping input tensor names to lists of integers representing
      the input shapes (e.g., {"foo": [1, 16, 16, 3]}).

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
  """
  if shapes:
    tensor_names_to_tensor = {
        get_tensor_name(tensor): tensor for tensor in tensors
    }
    for name, shape in shapes.items():
      if name not in tensor_names_to_tensor:
        raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                         "map.".format(name))
      if shape is not None:
        tensor = tensor_names_to_tensor[name]
        try:
          tensor.set_shape(shape)
        except ValueError as error:
          message = ("The shape of tensor '{0}' cannot be changed from {1} to "
                     "{2}. {3}".format(name, tensor.shape, shape, str(error)))
          raise ValueError(message)
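

# Example sketch (illustration only, not used by the converter): looking up
# tensors by name and pinning down an unknown batch dimension. The graph and
# tensor names here are hypothetical; a TF1-style graph is assumed.
def _example_tensor_name_helpers():
  """Sketch: resolve "input:0" to "input" and fix its batch dimension to 1."""
  import tensorflow as tf  # Local import; the module itself does not need it.

  with tf.Graph().as_default() as graph:
    tf.compat.v1.placeholder(tf.float32, shape=[None, 16, 16, 3], name="input")
  # get_tensor_name() normalizes "input:0" to "input".
  tensors = get_tensors_from_tensor_names(graph, ["input"])
  set_tensor_shapes(tensors, {"input": [1, 16, 16, 3]})
  return tensors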


def get_grappler_config(optimizers_list):
  """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

  Args:
    optimizers_list: List of strings that represents the list of optimizers.

  Returns:
    tf.ConfigProto.
  """
  config = _config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  for optimizer in optimizers_list:
    rewrite_options.optimizers.append(optimizer)
  return config


def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
  """Applies standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  signature = _meta_graph_pb2.SignatureDef()
  for array in input_arrays:
    signature.inputs[array.name].name = array.name
    signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  for array in output_arrays:
    signature.outputs[array.name].name = array.name
    signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # We need to add a collection called 'train_op' so that Grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  if is_frozen_graph(sess):
    raise ValueError("Op hint conversion requires an unfrozen graph.")
  output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
  graph_def = tf_graph_util.convert_variables_to_constants(
      sess, graph_def, output_arrays + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  return graph_def
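

# Example sketch (illustration only, not used by the converter): wiring
# get_grappler_config() into run_graph_optimizations(). "function" and
# "constfold" are standard Grappler optimizer names; the session and tensors
# are assumed to come from the caller.
def _example_run_grappler_passes(sess, input_tensors, output_tensors):
  """Sketch: inline functions and fold constants in a session's graph."""
  config = get_grappler_config(["function", "constfold"])
  return run_graph_optimizations(
      sess.graph_def,
      input_tensors,
      output_tensors,
      config,
      graph=sess.graph)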


def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  Runs a Grappler pass to inline functions in the graph and freezes a graph
  that contains Variables by converting them to constants. If the graph is
  already frozen, the existing GraphDef is returned. If OpHints are present,
  it will try to convert the OpHint graph.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Aside from inlining simple functions, Grappler will also try to lower
  # while loops into the switch/merge representation, which is undesired for
  # OpHints, so we simply remove those attributes to prevent Grappler from
  # doing so.
  graph_def = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  config = get_grappler_config(["function"])
  graph_def = run_graph_optimizations(
      graph_def, input_tensors, output_tensors, config, graph=sess.graph)

  # If OpHints are present, just convert them.
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                        hinted_outputs_nodes)

  if not is_frozen_graph(sess):
    output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
    return tf_graph_util.convert_variables_to_constants(sess, graph_def,
                                                        output_node_names)
  else:
    return sess.graph_def


def is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not
  frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True
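

# Example sketch (illustration only, not used by the converter): freezing a
# tiny TF1-style graph that contains a variable. The model below is
# hypothetical, and the TF1 graph-mode APIs (placeholder, ref variables,
# Session) are assumed to be available; details may need adjustment for
# resource variables.
def _example_freeze_graph():
  """Sketch: freeze a graph so its variable becomes a constant node."""
  import tensorflow as tf  # Local import; the module itself does not need it.

  with tf.Graph().as_default() as graph:
    inp = tf.compat.v1.placeholder(tf.float32, shape=[1, 4], name="input")
    weights = tf.compat.v1.get_variable(
        "weights", shape=[4, 2], dtype=tf.float32, use_resource=False)
    out = tf.matmul(inp, weights, name="output")
    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      return freeze_graph(sess, [inp], [out])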


def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not original_graph:
      return None
    # For the given nodes, gets all the op definitions in the original graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
        else:
          sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
          if isinstance(sub_func, function._EagerDefinedFunction):  # pylint: disable=protected-access
            useful_ops.append(
                (func, sub_func.graph.get_operation_by_name(name)))
          else:
            sys.stderr.write(
                "Use '@tf.function' or '@defun' to decorate the function.\n")
            continue
      except KeyError:
        # New node created by the graph optimizer. No stack trace from source
        # code is available.
        continue
    # Convert all the op definitions to stack traces in terms of GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f


def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not saved_debug_info:
      return None

    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    # All the files are copied over, so the indices will not change.
    output_debug_info.files[:] = saved_debug_info.files
    # We only copy over the debug info for the input nodes.
    for func, node in original_nodes:
      debug_key = node + "@" + func
      output_debug_info.traces[debug_key].CopyFrom(
          saved_debug_info.traces[debug_key])
    return output_debug_info

  return f


def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Returns the debug info for the original nodes in the `converted_graph`.

  Args:
    nodes_to_debug_info_func: The method to collect the op debug info for the
      nodes.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    `GraphDebugInfo` for all the original nodes in `converted_graph`.
  """
  if not nodes_to_debug_info_func:
    return None

  # Collect all the debug info nodes from the converted_graph.
  original_nodes = set()
  for node in converted_graph.node:
    debug_nodes = node.experimental_debug_info.original_node_names
    debug_funcs = node.experimental_debug_info.original_func_names
    # If the `original_node_names` are empty, use the node name directly.
    if not debug_nodes:
      original_nodes.add(("", node.name))
    else:
      for i in range(len(debug_nodes)):
        debug_func = "" if i >= len(debug_funcs) else debug_funcs[i]
        original_nodes.add((debug_func, debug_nodes[i]))

  # Convert the nodes to the debug info proto object.
  return nodes_to_debug_info_func(original_nodes)
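

# Example sketch (illustration only): the two-step pattern used to attach
# stack traces to a converted graph, i.e. build a debug-info retrieval
# function from the original graph and apply it to the optimized GraphDef.
# The session and optimized GraphDef are assumed to come from the caller.
def _example_collect_debug_info(sess, optimized_graph_def):
  """Sketch: build a GraphDebugInfo for the nodes that survive optimization."""
  debug_info_fn = build_debug_info_func(sess.graph)
  return get_debug_info(debug_info_fn, optimized_graph_def)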


def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes.
    include_guard: Name to use for the include guard macro definition.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    Text that can be compiled as a C source file to link in the data as a
    literal array of values.
    Text that can be used as a C header file to reference the literal array.
  """

  starting_pad = " "
  array_lines = []
  array_line = starting_pad
  for value in bytearray(data):
    if (len(array_line) + 4) > max_line_width:
      array_lines.append(array_line + "\n")
      array_line = starting_pad
    array_line += " 0x%02x," % (value,)
  if len(array_line) > len(starting_pad):
    array_lines.append(array_line + "\n")
  array_values = "".join(array_lines)

  if include_guard is None:
    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"

  if include_path is not None:
    include_line = "#include \"{include_path}\"\n".format(
        include_path=include_path)
  else:
    include_line = ""

  if use_tensorflow_license:
    license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
  else:
    license_text = ""

  source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif

const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""

  source_text = source_template.format(
      array_name=array_name,
      array_length=len(data),
      array_values=array_values,
      license_text=license_text,
      include_line=include_line)

  header_template = """
{license_text}

// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

#ifndef {include_guard}
#define {include_guard}

extern const unsigned char {array_name}[];
extern const int {array_name}_len;

#endif  // {include_guard}
"""

  header_text = header_template.format(
      array_name=array_name,
      include_guard=include_guard,
      license_text=license_text)

  return source_text, header_text
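

# Example sketch (illustration only): emitting a C source/header pair for a
# model buffer so it can be compiled into firmware. The byte string and array
# name below are made up for demonstration.
def _example_convert_bytes_to_c_source():
  """Sketch: turn a tiny, fake model buffer into C source and header text."""
  fake_model_bytes = b"\x1c\x00\x00\x00TFL3"
  source_text, header_text = convert_bytes_to_c_source(
      data=fake_model_bytes,
      array_name="g_example_model_data",
      include_guard="EXAMPLE_MODEL_DATA_H_",
      use_tensorflow_license=False)
  return source_text, header_text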


def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = schema_fb.ModelT.InitFromObj(model_object)
  model_object = copy.deepcopy(model_object)
  return model_object


def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Initial size of the buffer, which will grow automatically if needed.
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())


def get_quantize_opcode_idx(model):
  """Returns the indices of the quantize opcodes."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def get_dequantize_opcode_idx(model):
  """Returns the indices of the dequantize opcodes."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
  """Updates the tensors in the SignatureDef's TensorMaps."""
  for i in range(len(tensor_maps)):
    if tensor_maps[i].tensorIndex in map_old_to_new_tensors:
      tensor_maps[i].tensorIndex = (
          map_old_to_new_tensors[tensor_maps[i].tensorIndex])
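

# Example sketch (illustration only): the bytearray <-> object round trip the
# helpers above provide, plus locating QUANTIZE opcodes. `tflite_model` is
# assumed to be a valid .tflite flatbuffer.
def _example_flatbuffer_round_trip(tflite_model):
  """Sketch: parse a model buffer, inspect it, and serialize it back."""
  model_object = _convert_model_from_bytearray_to_object(tflite_model)
  quant_opcode_idxs = get_quantize_opcode_idx(model_object)
  logging.info("QUANTIZE opcode indices: %s", quant_opcode_idxs)
  return _convert_model_from_object_to_bytearray(model_object)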


def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Removes tensors from the model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    logging.info("Skipping the removal of dangling tensors since the model "
                 "has multiple subgraphs and tensors can be used in "
                 "different subgraphs.")
    return
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices: %s", remove_tensors_idxs)
  # An optimized check to validate if "remove_tensors_idxs" (e.g. [4,5,6]) is
  # an exact subset, with ordering, of "tensors" indices (e.g. [0,1,2,3,4,5,6]).
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices.
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
    # Update tensor indices referenced throughout the model.
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    if model.signatureDefs:
      signature_def = model.signatureDefs[0]
      _update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors)
      _update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors)
    # Delete the tensors.
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")


def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modifies the model input type."""
  if inference_input_type == dtypes.float32:
    return

  if not model.signatureDefs:
    _modify_model_input_type_per_subgraph(model, 0, -1, inference_input_type)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _modify_model_input_type_per_subgraph(model, signature_def.subgraphIndex,
                                          signature_index,
                                          inference_input_type)


def _modify_model_input_type_per_subgraph(model, subgraph_index,
                                          signature_index,
                                          inference_input_type):
  """Modifies the model input type per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  if operators and not quant_opcode_idxs:
    for input_idx in subgraph.inputs:
      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
      if input_type == dtypes.float32:
        raise ValueError("Model input is not dequantized.")
    # None of the inputs is float32, so they must already be int16, int8, or
    # bool.
    return

  # Validate that the model input is quantized.
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input.
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float.
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_input_type:
          continue
        else:
          raise ValueError(
              "Initial model input type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type."
    )

  # Modify model input type.
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8).
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator.
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.inputs)):
          if signature_def.inputs[i].tensorIndex == op.inputs[0]:
            signature_def.inputs[i].tensorIndex = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            get_tf_type_name(inference_input_type)))


def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modifies the model output type."""
  if inference_output_type == dtypes.float32:
    return

  if not model.signatureDefs:
    _modify_model_output_type_per_subgraph(model, 0, -1, inference_output_type)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _modify_model_output_type_per_subgraph(model, signature_def.subgraphIndex,
                                           signature_index,
                                           inference_output_type)


def _modify_model_output_type_per_subgraph(model, subgraph_index,
                                           signature_index,
                                           inference_output_type):
  """Modifies the model output type per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators.
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)
  if operators and not dequant_opcode_idxs:
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # None of the outputs is float32, so they must already be int16, int8, or
    # bool.
    return

  # Validate that the model output is dequantized.
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize model output.
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # If found, validate that the operator's output type is float.
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          continue
        else:
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type "
              "for tensor with name '{}' is tf.float32, instead type is "
              "{}".format(float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type."
    )

  # Modify model output type.
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator.
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exist.
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change dequant op (int8 to float) to quant op (int8 to uint8).
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the outputs and the dequant operator.
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.outputs)):
          if signature_def.outputs[i].tensorIndex == op.outputs[0]:
            signature_def.outputs[i].tensorIndex = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_output_type` value {}.".format(
            get_tf_type_name(inference_output_type)))
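

# Example sketch (illustration only): the int8 -> uint8 rewrite used in both
# the input and output paths above. The scale is reused and the zero point is
# shifted by 128, which maps the int8 range [-128, 127] onto the uint8 range
# [0, 255]. The helper itself is hypothetical.
def _example_int8_to_uint8_quant_params(int8_quantization):
  """Sketch: derive uint8 QuantizationParametersT from int8 parameters."""
  uint8_quantization = schema_fb.QuantizationParametersT()
  uint8_quantization.scale = [int8_quantization.scale[0]]
  uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
  return uint8_quantization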


def _remove_redundant_quantize_ops(model):
  """Finds back-to-back quantize ops and removes the first quantize op."""
  if not model.signatureDefs:
    _remove_redundant_quantize_ops_per_subgraph(model, 0, -1)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _remove_redundant_quantize_ops_per_subgraph(model,
                                                signature_def.subgraphIndex,
                                                signature_index)


def _remove_redundant_quantize_ops_per_subgraph(model, subgraph_index,
                                                signature_index):
  """Removes redundant quantize ops per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize and dequantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)

  # Find all redundant quant tensors.
  all_quant_ops = []
  redundant_quant_tensors = {}
  output_dequant_tensors = {}
  for op in operators:
    if op.opcodeIndex in quant_opcode_idxs:
      all_quant_ops.append(op)
      input_tensor = tensors[op.inputs[0]]
      output_tensor = tensors[op.outputs[0]]
      input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
      output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
      # This is a requantize op, so write down its input tensor index.
      if input_type != dtypes.float32 and output_type != dtypes.float32:
        redundant_quant_tensors[op.inputs[0]] = op
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      output_dequant_tensors[op.inputs[0]] = op

  # Remove all the quant ops which produce the redundant quant tensors.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in redundant_quant_tensors:
      requantize_op = redundant_quant_tensors[output_tensor_idx]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      # Reset the input of the requantize op to the float input.
      requantize_op.inputs[0] = op.inputs[0]
      operators.remove(op)

  # Remove all the quant ops which connect to the output dequant op.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in output_dequant_tensors:
      dequant_op = output_dequant_tensors[output_tensor_idx]
      subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for output in signature_def.outputs:
          if output.tensorIndex == dequant_op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      operators.remove(op)
      operators.remove(dequant_op)


def modify_model_io_type(
    model, inference_input_type=dtypes.float32,
    inference_output_type=dtypes.float32):
  """Modifies the input/output type of a tflite model.

  Args:
    model: A tflite model.
    inference_input_type: tf.DType representing the modified input type.
      (default tf.float32. If the model input is int8 quantized, it must be in
      {tf.float32, tf.int8, tf.uint8}; if the model input is int16 quantized,
      it must be in {tf.float32, tf.int16}; otherwise it must be tf.float32.)
    inference_output_type: tf.DType representing the modified output type.
      (default tf.float32. If the model output is int8 dequantized, it must be
      in {tf.float32, tf.int8, tf.uint8}; if the model output is int16
      dequantized, it must be in {tf.float32, tf.int16}; otherwise it must be
      tf.float32.)

  Returns:
    A tflite model with modified input/output type.

  Raises:
    ValueError: If `inference_input_type`/`inference_output_type` is
      unsupported or a supported integer type is specified for a model whose
      input/output is not quantized/dequantized.
    RuntimeError: If the modification was unsuccessful.
  """
  if (inference_input_type == dtypes.float32 and
      inference_output_type == dtypes.float32):
    return model

  model_object = _convert_model_from_bytearray_to_object(model)

  _modify_model_input_type(model_object, inference_input_type)

  _modify_model_output_type(model_object, inference_output_type)

  _remove_redundant_quantize_ops(model_object)

  return _convert_model_from_object_to_bytearray(model_object)
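

# Example sketch (illustration only): rewriting a fully integer-quantized
# model so that its interface tensors are uint8. The buffer is assumed to
# already contain quantize/dequantize ops at its inputs/outputs, as produced
# by the converter; otherwise modify_model_io_type() raises ValueError.
def _example_modify_model_io_type(tflite_model):
  """Sketch: make a quantized model accept and return uint8 tensors."""
  return modify_model_io_type(
      tflite_model,
      inference_input_type=dtypes.uint8,
      inference_output_type=dtypes.uint8)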
993 """ 994 if not model_object or not model_object.metadata: 995 return [] 996 997 result = set() 998 for subgraph in model_object.subgraphs: 999 for tensor in subgraph.tensors: 1000 if not tensor.sparsity: 1001 continue 1002 1003 # Block map is the list if indexes where the block size is larger than 1. 1004 # So empty block map means it is random sparsity. 1005 if not tensor.sparsity.blockMap: 1006 result.add( 1007 conversion_metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY) 1008 else: 1009 result.add( 1010 conversion_metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY) 1011 1012 return list(result) 1013 1014 1015def populate_conversion_metadata(model_object, metadata): 1016 """Add or update conversion metadata to a tflite model. 1017 1018 Args: 1019 model_object: A tflite model in object form. 1020 metadata: The conversion metadata. 1021 1022 Returns: 1023 A tflite model object with embedded conversion metadata. 1024 """ 1025 try: 1026 metadata_builder = flatbuffers.Builder(0) 1027 metadata_builder.Finish(metadata.Pack(metadata_builder)) 1028 buffer_field = schema_fb.BufferT() 1029 buffer_field.data = metadata_builder.Output() 1030 1031 if not model_object.metadata: 1032 model_object.metadata = [] 1033 else: 1034 # Check if metadata has already been populated. 1035 for meta in model_object.metadata: 1036 if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME: 1037 model_object.buffers[meta.buffer] = buffer_field 1038 return model_object 1039 1040 if not model_object.buffers: 1041 model_object.buffers = [] 1042 model_object.buffers.append(buffer_field) 1043 # Creates a new metadata field. 1044 metadata_field = schema_fb.MetadataT() 1045 metadata_field.name = CONVERSION_METADATA_FIELD_NAME 1046 metadata_field.buffer = len(model_object.buffers) - 1 1047 model_object.metadata.append(metadata_field) 1048 1049 return model_object 1050 except Exception: # pylint: disable=broad-except 1051 return model_object 1052 1053 1054def get_conversion_metadata(model_buffer): 1055 """Read conversion metadata from a tflite model. 1056 1057 Args: 1058 model_buffer: A tflite model. 1059 1060 Returns: 1061 The conversion metadata or None if it is not populated. 1062 """ 1063 model_object = flatbuffer_utils.convert_bytearray_to_object(model_buffer) 1064 if not model_object or not model_object.metadata: 1065 return None 1066 1067 for meta in model_object.metadata: 1068 if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME: 1069 metadata_buf = model_object.buffers[meta.buffer].data.tobytes() 1070 return conversion_metadata_fb.ConversionMetadataT.InitFromObj( 1071 conversion_metadata_fb.ConversionMetadata.GetRootAsConversionMetadata( 1072 metadata_buf, 0)) 1073 1074 return None 1075