• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for FlatBuffers.
16
17All functions that are commonly used to work with FlatBuffers.
18
19Refer to the tensorflow lite flatbuffer schema here:
20tensorflow/lite/schema/schema.fbs
21
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28import copy
29import random
30import re
31
32import flatbuffers
33from tensorflow.lite.python import schema_py_generated as schema_fb
34from tensorflow.python.platform import gfile
35
# File identifier handed to `builder.Finish()` when a model object is
# serialized in convert_object_to_bytearray().
_TFLITE_FILE_IDENTIFIER = b'TFL3'
37
38
def convert_bytearray_to_object(model_bytearray):
  """Builds a tflite model object from its serialized flatbuffer bytes.

  Args:
    model_bytearray: Buffer holding the serialized tflite model; the root
      table is read starting at offset 0.

  Returns:
    The `schema_fb.ModelT` object initialized from the parsed model.
  """
  return schema_fb.ModelT.InitFromObj(
      schema_fb.Model.GetRootAsModel(model_bytearray, 0))
43
44
def read_model(input_tflite_file):
  """Loads a tflite file and returns its contents as a python object.

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If nothing exists at the input_tflite_file path.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  if gfile.Exists(input_tflite_file):
    # Read the whole file in binary mode before handing it to the parser.
    with gfile.GFile(input_tflite_file, 'rb') as tflite_file:
      serialized_model = bytearray(tflite_file.read())
    return convert_bytearray_to_object(serialized_model)
  raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
63
64
def read_model_with_mutable_tensors(input_tflite_file):
  """Reads a tflite model as a python object whose tensors can be modified.

  Behaves like read_model(), except that the result is a deep copy of the
  object read_model() returns, so that its tensors are mutable (read_model()
  returns an object with immutable tensors).

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file path is invalid.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A mutable python object corresponding to the input tflite file.
  """
  immutable_model = read_model(input_tflite_file)
  return copy.deepcopy(immutable_model)
82
83
def convert_object_to_bytearray(model_object):
  """Serializes a tflite model object into an immutable bytes value."""
  # 1024 is only the starting capacity; the buffer grows automatically if
  # the model needs more room.
  fb_builder = flatbuffers.Builder(1024)
  # Pack the model first, then finalize the buffer with the tflite identifier.
  fb_builder.Finish(
      model_object.Pack(fb_builder), file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(fb_builder.Output())
92
93
def write_model(model_object, output_tflite_file):
  """Serializes a tflite model object and stores it in the output file.

  Args:
    model_object: A tflite model as a python object
    output_tflite_file: Full path name to the output tflite file.

  Raises:
    IOError: If output_tflite_file path is invalid or cannot be opened.
  """
  # Serialize before opening the destination so a failed conversion never
  # touches the output file.
  serialized_model = convert_object_to_bytearray(model_object)
  with gfile.GFile(output_tflite_file, 'wb') as tflite_file:
    tflite_file.write(serialized_model)
107
108
def strip_strings(model):
  """Strips all nonessential strings from the model to reduce model size.

  We remove the following strings:
  (find strings by searching ":string" in the tensorflow lite flatbuffer schema)
  1. Model description
  2. SubGraph name
  3. Tensor names
  We retain OperatorCode custom_code and Metadata name.

  In addition, the whole signature_def structure is cleared, since it is
  useless once the names are gone.

  The model is modified in place; nothing is returned.

  Args:
    model: The model from which to remove nonessential strings.
  """

  model.description = None
  # `subgraphs` / `tensors` may be None instead of an empty list (presumably
  # when the vector is absent from the flatbuffer), so fall back to an empty
  # sequence rather than failing on iteration.
  for subgraph in model.subgraphs or ():
    subgraph.name = None
    for tensor in subgraph.tensors or ():
      tensor.name = None
  # We clear all signature_def structure, since without names it is useless.
  model.signatureDefs = None
130
131
def randomize_weights(model, random_seed=0):
  """Overwrites the weight buffers of a model with random bytes, in place.

  Args:
    model: The model in which to randomize weights.
    random_seed: The input to the random number generator (default value is 0).
  """

  # Seed the module-level generator so that runs are reproducible.
  random.seed(random_seed)

  # Model weights live in model.buffers; index 0 is skipped as it's always
  # None.
  for weight_buffer in model.buffers[1:]:
    raw_bytes = weight_buffer.data
    if raw_bytes is None:
      continue

    # Raw data buffers are of type ubyte (or uint8) whose values lie in the
    # range [0, 255]. Those ubytes (or uint8s) are the underlying
    # representation of each datatype. For example, a bias tensor of type
    # int32 appears as a buffer 4 times its length of type ubyte (or uint8).
    # TODO(b/152324470): This does not work for float as randomized weights may
    # end up as denormalized or NaN/Inf floating point numbers.
    for position in range(raw_bytes.size):
      raw_bytes[position] = random.randint(0, 255)
157
158
def rename_custom_ops(model, map_custom_op_renames):
  """Renames custom ops in place according to the given mapping.

  Intended to let custom ops use the same naming style as builtin ops.
  Operator codes without a custom code, or whose custom code is not a key of
  the mapping, are left untouched.

  Args:
    model: The input tflite model.
    map_custom_op_renames: A mapping from old to new custom op names.
  """
  for operator_code in model.operatorCodes:
    current_code = operator_code.customCode
    if not current_code:
      continue
    # Custom codes are stored as bytes; the mapping is keyed by str.
    current_name = current_code.decode('ascii')
    if current_name not in map_custom_op_renames:
      continue
    replacement = map_custom_op_renames[current_name]
    operator_code.customCode = replacement.encode('ascii')
171
172
def xxd_output_to_bytes(input_cc_file):
  """Converts xxd output C++ source file to bytes (immutable).

  Only lines that start (after optional non-word characters) with a `0x`
  literal are treated as array data; every other line is skipped.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file cannot be opened.
    ValueError: If a comma-separated entry on a data line is not a valid hex
      literal or does not fit in a single byte.

  Returns:
    A bytes object holding the values of the array in the input cc file.
  """
  # Match hex values in the string with comma as separator
  pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')

  model_bytearray = bytearray()

  with open(input_cc_file) as file_handle:
    for line in file_handle:
      values_match = pattern.match(line)

      if values_match is None:
        continue

      # Match in the parentheses (hex array only)
      # e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
      list_text = values_match.group(1)

      # Strip each entry before filtering: the pattern also captures spaces,
      # so a trailing "0x4c, " would otherwise leave a whitespace-only entry
      # that int() rejects.
      stripped_entries = (entry.strip() for entry in list_text.split(','))
      values = [int(entry, base=16) for entry in stripped_entries if entry]
      model_bytearray.extend(values)

  return bytes(model_bytearray)
210
211
def xxd_output_to_object(input_cc_file):
  """Converts xxd output C++ source file to a tflite model object.

  The byte array found in the file is extracted with xxd_output_to_bytes()
  and then parsed with convert_bytearray_to_object().

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file cannot be opened.

  Returns:
    A python object corresponding to the model stored in the input cc file.
  """
  return convert_bytearray_to_object(xxd_output_to_bytes(input_cc_file))
227