# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple converter files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import datetime
import sys

from absl import logging
import six
from six.moves import range

import flatbuffers
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation as _error_interpolation
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph

# Keras functions used by TFLite
model_input_signature = _tflite_keras_util.model_input_signature
trace_model_call = _tflite_keras_util.trace_model_call

# Defined as per TFLite schema
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
    16: dtypes.uint32,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}


def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).

  Args:
    tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)

  Raises:
    ValueError: If an invalid tflite enum type is provided.

  Returns:
    tf type (eg: tf.float32)
  """
  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is: {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type


def get_tf_type_name(tf_type):
  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
  return "tf." + tf_type.name if tf_type else None
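

# Example (illustrative sketch, not used by the converter itself): mapping a
# TFLite schema type enum to a tf.DType and printing its name. Per
# _MAP_TFLITE_ENUM_TO_TF_TYPES above, the enum value 0 corresponds to float32.
#
#   tf_type = _convert_tflite_enum_type_to_tf_type(0)  # -> tf.float32
#   print(get_tf_type_name(tf_type))                   # -> "tf.float32"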


def get_tensor_name(tensor):
  """Returns name of the input tensor.

  Args:
    tensor: tf.Tensor

  Returns:
    str
  """
  parts = six.ensure_str(tensor.name).split(":")
  if len(parts) > 2:
    raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
        len(parts) - 1))

  # To be consistent with the tensor naming scheme in tensorflow, we need to
  # drop the ':0' suffix for the first tensor.
  if len(parts) > 1 and parts[1] != "0":
    return tensor.name
  return parts[0]


def get_tensors_from_tensor_names(graph, tensor_names):
  """Gets the Tensors associated with the `tensor_names` in the provided graph.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects in the same order the names are provided.

  Raises:
    ValueError:
      tensor_names contains an invalid tensor name.
  """
  # Get the list of all of the tensors.
  tensor_name_to_tensor = {}
  for op in graph.get_operations():
    for tensor in op.values():
      tensor_name_to_tensor[get_tensor_name(tensor)] = tensor

  # Get the tensors associated with tensor_names.
  tensors = []
  invalid_tensors = []
  for name in tensor_names:
    if not isinstance(name, six.string_types):
      raise ValueError("Invalid type for a tensor name in the provided graph. "
                       "Expected type for a tensor name is 'str', instead got "
                       "type '{}' for tensor name '{}'".format(
                           type(name), name))

    tensor = tensor_name_to_tensor.get(name)
    if tensor is None:
      invalid_tensors.append(name)
    else:
      tensors.append(tensor)

  # Throw ValueError if any user input names are not valid tensors.
  if invalid_tensors:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(invalid_tensors)))
  return tensors


def set_tensor_shapes(tensors, shapes):
  """Sets Tensor shape for each tensor if the shape is defined.

  Args:
    tensors: TensorFlow ops.Tensor.
    shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
  """
  if shapes:
    tensor_names_to_tensor = {
        get_tensor_name(tensor): tensor for tensor in tensors
    }
    for name, shape in shapes.items():
      if name not in tensor_names_to_tensor:
        raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                         "map.".format(name))
      if shape is not None:
        tensor = tensor_names_to_tensor[name]
        try:
          tensor.set_shape(shape)
        except ValueError as error:
          message = ("The shape of tensor '{0}' cannot be changed from {1} to "
                     "{2}. {3}".format(name, tensor.shape, shape, str(error)))
          raise ValueError(message)
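

# Example (illustrative sketch; the graph and the tensor names "input" and
# "output" are hypothetical): looking up tensors by name and pinning an input
# shape with the helpers above, using a TF1-style graph built via tf.compat.v1.
#
#   import tensorflow as tf
#   g = tf.Graph()
#   with g.as_default():
#     inp = tf.compat.v1.placeholder(tf.float32, shape=[None, 16, 16, 3],
#                                    name="input")
#     _ = tf.identity(inp, name="output")
#   tensors = get_tensors_from_tensor_names(g, ["input", "output"])
#   set_tensor_shapes(tensors, {"input": [1, 16, 16, 3]})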


def get_grappler_config(optimizers_list):
  """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

  Args:
    optimizers_list: List of strings that represents the list of optimizers.

  Returns:
    tf.ConfigProto.
  """
  config = _config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  for optimizer in optimizers_list:
    rewrite_options.optimizers.append(optimizer)
  return config


def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
  """Apply standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  signature = _meta_graph_pb2.SignatureDef()
  for array in input_arrays:
    signature.inputs[array.name].name = array.name
    signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  for array in output_arrays:
    signature.outputs[array.name].name = array.name
    signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  if is_frozen_graph(sess):
    raise ValueError("Try to convert op hints, needs unfrozen graph.")
  output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
  graph_def = tf_graph_util.convert_variables_to_constants(
      sess, graph_def, output_arrays + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  return graph_def


def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the
  existing GraphDef is returned. The Grappler pass is only run on models that
  are frozen in order to inline the functions in the graph.
  If OpHints are present, it will try to convert the OpHint graph.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Aside from inlining any simple function, Grappler will also try to lower
  # while loops into a switch-merge representation, which is undesired for
  # OpHints, so we simply remove those attributes to prevent Grappler from
  # doing so.
  graph_def = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  config = get_grappler_config(["function"])
  graph_def = run_graph_optimizations(
      graph_def, input_tensors, output_tensors, config, graph=sess.graph)

  # If ophints are present, just convert them.
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                        hinted_outputs_nodes)

  if not is_frozen_graph(sess):
    output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
    return tf_graph_util.convert_variables_to_constants(sess, graph_def,
                                                        output_node_names)
  else:
    return sess.graph_def


def is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if six.ensure_str(op.type).startswith("Variable") or six.ensure_str(
        op.type).endswith("VariableOp"):
      return False
  return True


def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not original_graph:
      return None
    # For the given nodes, gets all the op definitions in the original graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
        else:
          sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
          if isinstance(sub_func, function._EagerDefinedFunction):  # pylint: disable=protected-access
            useful_ops.append(
                (func, sub_func.graph.get_operation_by_name(name)))
          else:
            sys.stderr.write(
                "Use '@tf.function' or '@defun' to decorate the function.\n")
            continue
      except KeyError:
        # New node created by graph optimizer. No stack trace from source code.
        continue
    # Convert all the op definitions to stack traces in terms of GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f


def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not saved_debug_info:
      return None

    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    # All the files are copied over, so the index wouldn't be changed.
    output_debug_info.files[:] = saved_debug_info.files
    # We only copy over the debug info for the input nodes.
    for func, node in original_nodes:
      debug_key = node + "@" + func
      output_debug_info.traces[debug_key].CopyFrom(
          saved_debug_info.traces[debug_key])
    return output_debug_info

  return f


def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Returns the debug info for the original nodes in the `converted_graph`.

  Args:
    nodes_to_debug_info_func: The method to collect the op debug info for the
      nodes.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    `GraphDebugInfo` for all the original nodes in `converted_graph`.
  """
  if not nodes_to_debug_info_func:
    return None

  # Collect all the debug info nodes from the converted_graph
  original_nodes = set()
  for node in converted_graph.node:
    debug_nodes = node.experimental_debug_info.original_node_names
    debug_funcs = node.experimental_debug_info.original_func_names
    # If the `original_node_names` are empty, uses the node name directly.
    if not debug_nodes:
      original_nodes.add(("", node.name))
    else:
      for i in range(len(debug_nodes)):
        debug_func = "" if i >= len(debug_funcs) else debug_funcs[i]
        original_nodes.add((debug_func, debug_nodes[i]))

  # Convert the nodes to the debug info proto object.
  return nodes_to_debug_info_func(original_nodes)


def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes.
    include_guard: Name to use for the include guard macro definition.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    Text that can be compiled as a C source file to link in the data as a
    literal array of values.
    Text that can be used as a C header file to reference the literal array.
  """

  starting_pad = " "
  array_lines = []
  array_line = starting_pad
  for value in bytearray(data):
    if (len(array_line) + 4) > max_line_width:
      array_lines.append(array_line + "\n")
      array_line = starting_pad
    array_line += " 0x%02x," % (value)
  if len(array_line) > len(starting_pad):
    array_lines.append(array_line + "\n")
  array_values = "".join(array_lines)

  if include_guard is None:
    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"

  if include_path is not None:
    include_line = "#include \"{include_path}\"\n".format(
        include_path=include_path)
  else:
    include_line = ""

  if use_tensorflow_license:
    license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
  else:
    license_text = ""

  source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif

const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""

  source_text = source_template.format(
      array_name=array_name,
      array_length=len(data),
      array_values=array_values,
      license_text=license_text,
      include_line=include_line)

  header_template = """
{license_text}

// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

#ifndef {include_guard}
#define {include_guard}

extern const unsigned char {array_name}[];
extern const int {array_name}_len;

#endif  // {include_guard}
"""

  header_text = header_template.format(
      array_name=array_name,
      include_guard=include_guard,
      license_text=license_text)

  return source_text, header_text
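

# Example (illustrative sketch; the file names and array name are
# hypothetical): embedding a converted .tflite flatbuffer as a C array for
# devices without a file system.
#
#   with open("model.tflite", "rb") as f:
#     model_bytes = f.read()
#   source, header = convert_bytes_to_c_source(
#       model_bytes, "g_model", include_path="model_data.h",
#       use_tensorflow_license=True)
#   with open("model_data.cc", "w") as f:
#     f.write(source)
#   with open("model_data.h", "w") as f:
#     f.write(header)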


def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = schema_fb.ModelT.InitFromObj(model_object)
  model_object = copy.deepcopy(model_object)
  model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
  return model_object


def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Initial size of the buffer, which will grow automatically if needed
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())


def get_quantize_opcode_idx(model):
  """Returns the quantize op idx."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def get_dequantize_opcode_idx(model):
  """Returns the dequantize op idx."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
  """Update the tensors in the SignatureDef's TensorMaps."""
  for i in range(len(tensor_maps)):
    if tensor_maps[i].tensorIndex in map_old_to_new_tensors:
      tensor_maps[i].tensorIndex = (
          map_old_to_new_tensors[tensor_maps[i].tensorIndex])


def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Remove tensors from model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model.subgraphs)))
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices: %s", remove_tensors_idxs)
  # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
  # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
    # Update tensor indices referenced throughout the model
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    if model.signatureDefs:
      signature_def = model.signatureDefs[0]
      _update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors)
      _update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors)
    # Delete the tensors
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")


def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modify model input type."""

  if inference_input_type == dtypes.float32:
    return

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  if operators and not quant_opcode_idxs:
    for input_idx in subgraph.inputs:
      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
      if input_type == dtypes.float32:
        raise ValueError("Model input is not dequantized.")
    # None of the inputs are float32, so they must be int16, int8, or bool.
    return

  # Validate that the model input is quantized
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_input_type:
          continue
        else:
          raise ValueError(
              "Initial model input type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type."
    )

  # Modify model input type
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8)
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for i in range(len(signature_def.inputs)):
          if signature_def.inputs[i].tensorIndex == op.inputs[0]:
            signature_def.inputs[i].tensorIndex = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            get_tf_type_name(inference_input_type)))


def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modify model output type."""

  if inference_output_type == dtypes.float32:
    return

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)
  if operators and not dequant_opcode_idxs:
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # None of the outputs are float32, so they must be int16, int8, or bool.
    return

  # Validate that the model output is dequantized
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize model output
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # If found, validate that the operator's output type is float
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          continue
        else:
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type."
    )

  # Modify model output type
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exist
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change dequant op (int8 to float) to quant op (int8 to uint8)
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the outputs and the dequant operator
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for i in range(len(signature_def.outputs)):
          if signature_def.outputs[i].tensorIndex == op.outputs[0]:
            signature_def.outputs[i].tensorIndex = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_output_type` value {}.".format(
            get_tf_type_name(inference_output_type)))


def _remove_redundant_quantize_ops(model):
  """Finds back-to-back quantize ops and removes the first quantize op."""
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)

  # Find all redundant quant tensors.
  all_quant_ops = []
  redundant_quant_tensors = {}
  output_dequant_tensors = {}
  for op in operators:
    if op.opcodeIndex in quant_opcode_idxs:
      all_quant_ops.append(op)
      input_tensor = tensors[op.inputs[0]]
      output_tensor = tensors[op.outputs[0]]
      input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
      output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
      # This is a requantize op, so write down its input tensor index.
      if input_type != dtypes.float32 and output_type != dtypes.float32:
        redundant_quant_tensors[op.inputs[0]] = op
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      output_dequant_tensors[op.inputs[0]] = op

  # Remove all the quant ops which produce the redundant quant tensors.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in redundant_quant_tensors:
      requantize_op = redundant_quant_tensors[output_tensor_idx]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      # Reset the input of the requantize op to the float input
      requantize_op.inputs[0] = op.inputs[0]
      operators.remove(op)

  # Remove all the quant ops which connect to the output dequant op.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in output_dequant_tensors:
      dequant_op = output_dequant_tensors[output_tensor_idx]
      subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == dequant_op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      operators.remove(op)
      operators.remove(dequant_op)


def modify_model_io_type(
    model, inference_input_type=dtypes.float32,
    inference_output_type=dtypes.float32):
  """Modify the input/output type of a tflite model.

  Args:
    model: A tflite model.
    inference_input_type: tf.DType representing modified input type.
      (default tf.float32. If model input is int8 quantized, it must be in
      {tf.float32, tf.int8, tf.uint8}, else if model input is int16 quantized,
      it must be in {tf.float32, tf.int16}, else it must be tf.float32)
    inference_output_type: tf.DType representing modified output type.
      (default tf.float32. If model output is int8 dequantized, it must be in
      {tf.float32, tf.int8, tf.uint8}, else if model output is int16
      dequantized, it must be in {tf.float32, tf.int16}, else it must be
      tf.float32)

  Returns:
    A tflite model with modified input/output type.

  Raises:
    ValueError: If `inference_input_type`/`inference_output_type` is unsupported
      or a supported integer type is specified for a model whose input/output is
      not quantized/dequantized.
    RuntimeError: If the modification was unsuccessful.
  """
  if (inference_input_type == dtypes.float32 and
      inference_output_type == dtypes.float32):
    return model

  model_object = _convert_model_from_bytearray_to_object(model)

  if len(model_object.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model_object.subgraphs)))

  _modify_model_input_type(model_object, inference_input_type)

  _modify_model_output_type(model_object, inference_output_type)

  _remove_redundant_quantize_ops(model_object)

  return _convert_model_from_object_to_bytearray(model_object)
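

# Example (illustrative sketch; `quantized_model` is a hypothetical fully
# integer-quantized flatbuffer produced by tf.lite.TFLiteConverter): rewriting
# the model interface so that it accepts and returns uint8 tensors.
#
#   updated_model = modify_model_io_type(
#       quantized_model,
#       inference_input_type=dtypes.uint8,
#       inference_output_type=dtypes.uint8)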