1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility functions for FlatBuffers. 16 17All functions that are commonly used to work with FlatBuffers. 18 19Refer to the tensorflow lite flatbuffer schema here: 20tensorflow/lite/schema/schema.fbs 21 22""" 23 24from __future__ import absolute_import 25from __future__ import division 26from __future__ import print_function 27 28import copy 29import random 30import re 31 32import flatbuffers 33from tensorflow.lite.python import schema_py_generated as schema_fb 34from tensorflow.python.platform import gfile 35 36_TFLITE_FILE_IDENTIFIER = b'TFL3' 37 38 39def convert_bytearray_to_object(model_bytearray): 40 """Converts a tflite model from a bytearray to an object for parsing.""" 41 model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0) 42 return schema_fb.ModelT.InitFromObj(model_object) 43 44 45def read_model(input_tflite_file): 46 """Reads a tflite model as a python object. 47 48 Args: 49 input_tflite_file: Full path name to the input tflite file 50 51 Raises: 52 RuntimeError: If input_tflite_file path is invalid. 53 IOError: If input_tflite_file cannot be opened. 54 55 Returns: 56 A python object corresponding to the input tflite file. 57 """ 58 if not gfile.Exists(input_tflite_file): 59 raise RuntimeError('Input file not found at %r\n' % input_tflite_file) 60 with gfile.GFile(input_tflite_file, 'rb') as input_file_handle: 61 model_bytearray = bytearray(input_file_handle.read()) 62 return convert_bytearray_to_object(model_bytearray) 63 64 65def read_model_with_mutable_tensors(input_tflite_file): 66 """Reads a tflite model as a python object with mutable tensors. 67 68 Similar to read_model() with the addition that the returned object has 69 mutable tensors (read_model() returns an object with immutable tensors). 70 71 Args: 72 input_tflite_file: Full path name to the input tflite file 73 74 Raises: 75 RuntimeError: If input_tflite_file path is invalid. 76 IOError: If input_tflite_file cannot be opened. 77 78 Returns: 79 A mutable python object corresponding to the input tflite file. 80 """ 81 return copy.deepcopy(read_model(input_tflite_file)) 82 83 84def convert_object_to_bytearray(model_object): 85 """Converts a tflite model from an object to a immutable bytearray.""" 86 # Initial size of the buffer, which will grow automatically if needed 87 builder = flatbuffers.Builder(1024) 88 model_offset = model_object.Pack(builder) 89 builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER) 90 model_bytearray = bytes(builder.Output()) 91 return model_bytearray 92 93 94def write_model(model_object, output_tflite_file): 95 """Writes the tflite model, a python object, into the output file. 96 97 Args: 98 model_object: A tflite model as a python object 99 output_tflite_file: Full path name to the output tflite file. 100 101 Raises: 102 IOError: If output_tflite_file path is invalid or cannot be opened. 103 """ 104 model_bytearray = convert_object_to_bytearray(model_object) 105 with gfile.GFile(output_tflite_file, 'wb') as output_file_handle: 106 output_file_handle.write(model_bytearray) 107 108 109def strip_strings(model): 110 """Strips all nonessential strings from the model to reduce model size. 111 112 We remove the following strings: 113 (find strings by searching ":string" in the tensorflow lite flatbuffer schema) 114 1. Model description 115 2. SubGraph name 116 3. Tensor names 117 We retain OperatorCode custom_code and Metadata name. 118 119 Args: 120 model: The model from which to remove nonessential strings. 121 """ 122 123 model.description = None 124 for subgraph in model.subgraphs: 125 subgraph.name = None 126 for tensor in subgraph.tensors: 127 tensor.name = None 128 # We clear all signature_def structure, since without names it is useless. 129 model.signatureDefs = None 130 131 132def randomize_weights(model, random_seed=0): 133 """Randomize weights in a model. 134 135 Args: 136 model: The model in which to randomize weights. 137 random_seed: The input to the random number generator (default value is 0). 138 """ 139 140 # The input to the random seed generator. The default value is 0. 141 random.seed(random_seed) 142 143 # Parse model buffers which store the model weights 144 buffers = model.buffers 145 for i in range(1, len(buffers)): # ignore index 0 as it's always None 146 buffer_i_data = buffers[i].data 147 buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size 148 149 # Raw data buffers are of type ubyte (or uint8) whose values lie in the 150 # range [0, 255]. Those ubytes (or unint8s) are the underlying 151 # representation of each datatype. For example, a bias tensor of type 152 # int32 appears as a buffer 4 times it's length of type ubyte (or uint8). 153 # TODO(b/152324470): This does not work for float as randomized weights may 154 # end up as denormalized or NaN/Inf floating point numbers. 155 for j in range(buffer_i_size): 156 buffer_i_data[j] = random.randint(0, 255) 157 158 159def rename_custom_ops(model, map_custom_op_renames): 160 """Rename custom ops so they use the same naming style as builtin ops. 161 162 Args: 163 model: The input tflite model. 164 map_custom_op_renames: A mapping from old to new custom op names. 165 """ 166 for op_code in model.operatorCodes: 167 if op_code.customCode: 168 op_code_str = op_code.customCode.decode('ascii') 169 if op_code_str in map_custom_op_renames: 170 op_code.customCode = map_custom_op_renames[op_code_str].encode('ascii') 171 172 173def xxd_output_to_bytes(input_cc_file): 174 """Converts xxd output C++ source file to bytes (immutable). 175 176 Args: 177 input_cc_file: Full path name to th C++ source file dumped by xxd 178 179 Raises: 180 RuntimeError: If input_cc_file path is invalid. 181 IOError: If input_cc_file cannot be opened. 182 183 Returns: 184 A bytearray corresponding to the input cc file array. 185 """ 186 # Match hex values in the string with comma as separator 187 pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*') 188 189 model_bytearray = bytearray() 190 191 with open(input_cc_file) as file_handle: 192 for line in file_handle: 193 values_match = pattern.match(line) 194 195 if values_match is None: 196 continue 197 198 # Match in the parentheses (hex array only) 199 list_text = values_match.group(1) 200 201 # Extract hex values (text) from the line 202 # e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 203 values_text = filter(None, list_text.split(',')) 204 205 # Convert to hex 206 values = [int(x, base=16) for x in values_text] 207 model_bytearray.extend(values) 208 209 return bytes(model_bytearray) 210 211 212def xxd_output_to_object(input_cc_file): 213 """Converts xxd output C++ source file to object. 214 215 Args: 216 input_cc_file: Full path name to th C++ source file dumped by xxd 217 218 Raises: 219 RuntimeError: If input_cc_file path is invalid. 220 IOError: If input_cc_file cannot be opened. 221 222 Returns: 223 A python object corresponding to the input tflite file. 224 """ 225 model_bytes = xxd_output_to_bytes(input_cc_file) 226 return convert_bytearray_to_object(model_bytes) 227