• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# ==============================================================================
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Upgrade script to move from pre-release schema to new schema.

Usage examples:

bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.json
bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.bin
bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.json
bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.bin
bazel run tensorflow/lite/schema/upgrade_schema -- in.tflite out.tflite
"""
25import argparse
26import contextlib
27import json
28import os
29import shutil
30import subprocess
31import sys
32import tempfile
33
34import tensorflow as tf
35from tensorflow.python.platform import resource_loader
36
# Command-line interface: one required input path and one output path.
parser = argparse.ArgumentParser(
    description=("Script to move TFLite models from pre-release schema to "
                 "new schema."))

# Help strings are hoisted to named constants to keep the add_argument
# calls short; the text itself is unchanged.
_INPUT_HELP = (
    "Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.")
_OUTPUT_HELP = (
    "Output json or bin TensorFlow lite model compliant with "
    "the new schema. Extension must be `.json`, `.bin` or `.tflite`.")

parser.add_argument("input", type=str, help=_INPUT_HELP)
parser.add_argument("output", type=str, help=_OUTPUT_HELP)
49
50
# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles.
@contextlib.contextmanager
def TemporaryDirectoryResource():
  """Yield a fresh temporary directory path, removing the tree on exit.

  Delegates to the standard library's `tempfile.TemporaryDirectory`, which
  provides the same contract as the previous hand-rolled mkdtemp/rmtree
  pair (cleanup happens even if the body raises) without manual cleanup
  code.

  Yields:
    Absolute path of a newly created temporary directory.
  """
  with tempfile.TemporaryDirectory() as temporary:
    yield temporary
59
60
class Converter:
  """Converts TensorFlow flatbuffer models from old to new version of schema.

  This can convert between any version to the latest version. It uses
  an incremental upgrade strategy to go from version to version.

  Usage:
    converter = Converter()
    converter.Convert("a.tflite", "a.json")
    converter.Convert("b.json", "b.tflite")
  """

  def __init__(self):
    # TODO(aselle): make this work in the open source version with better
    # path.
    paths_to_try = [
        "../../../../flatbuffers/flatc",  # not bazel
        "../../../../external/flatbuffers/flatc"  # bazel
    ]
    for p in paths_to_try:
      self._flatc_path = resource_loader.get_path_to_datafile(p)
      if os.path.exists(self._flatc_path):
        break

    def FindSchema(base_name):
      # Schema files ship next to this script; resolve via resource_loader.
      return resource_loader.get_path_to_datafile("%s" % base_name)

    # Supported schemas for upgrade, as tuples of
    # (version, schema path, needs --raw-binary, upgrade function).
    self._schemas = [
        (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
        (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
        (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
        (3, FindSchema("schema_v3.fbs"), False, None)  # Non-callable by design.
    ]
    # Ensure schemas are sorted, and extract latest version and upgrade
    # dispatch function table. Sort on the version key only: the 4th tuple
    # element is a bound method (or None), which Python 3 cannot order, so
    # a plain tuple sort would raise if two entries ever tied on version.
    self._schemas.sort(key=lambda entry: entry[0])
    self._new_version, self._new_schema = self._schemas[-1][:2]
    self._upgrade_dispatch = {
        version: dispatch
        for version, unused1, unused2, dispatch in self._schemas}

  def _Read(self, input_file, schema, raw_binary=False):
    """Read a tflite model assuming the given flatbuffer schema.

    If `input_file` is in bin, then we must use flatc to convert the schema
    from binary to json.

    Args:
      input_file: a binary (flatbuffer) or json file to read from. Extension
        must  be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
        FlatBuffer JSON.
      schema: which schema to use for reading
      raw_binary: whether to assume raw_binary (versions previous to v3)
        that lacked file_identifier require this.

    Raises:
      RuntimeError: 1. When flatc cannot be invoked.
                    2. When json file does not exists.
      ValueError: When the extension is not json or bin.

    Returns:
      A dictionary representing the read tflite model.
    """
    raw_binary = ["--raw-binary"] if raw_binary else []
    with TemporaryDirectoryResource() as tempdir:
      basename = os.path.basename(input_file)
      basename_no_extension, extension = os.path.splitext(basename)
      if extension in [".bin", ".tflite"]:
        # Convert to json using flatc
        returncode = subprocess.call([
            self._flatc_path,
            "-t",
            "--strict-json",
            "--defaults-json",
        ] + raw_binary + ["-o", tempdir, schema, "--", input_file])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert from binary to json.")
        json_file = os.path.join(tempdir, basename_no_extension + ".json")
        if not os.path.exists(json_file):
          raise RuntimeError("Could not find %r" % json_file)
      elif extension == ".json":
        json_file = input_file
      else:
        raise ValueError("Invalid extension on input file %r" % input_file)
      # Close the handle deterministically instead of leaking the file
      # object to the garbage collector (previously json.load(open(...))).
      with open(json_file) as json_fp:
        return json.load(json_fp)

  def _Write(self, data, output_file):
    """Output a json or bin version of the flatbuffer model.

    Args:
      data: Dict representing the TensorFlow Lite model to write.
      output_file: filename to write the converted flatbuffer to. (json,
        tflite, or bin extension is required).
    Raises:
      ValueError: When the extension is not json or bin
      RuntimeError: When flatc fails to convert json data to binary.
    """
    _, extension = os.path.splitext(output_file)
    with TemporaryDirectoryResource() as tempdir:
      if extension == ".json":
        # Close the handle deterministically instead of leaking the file
        # object opened inline (previously json.dump(data, open(...))).
        with open(output_file, "w") as json_fp:
          json.dump(data, json_fp, sort_keys=True, indent=2)
      elif extension in [".tflite", ".bin"]:
        input_json = os.path.join(tempdir, "temp.json")
        with open(input_json, "w") as fp:
          json.dump(data, fp, sort_keys=True, indent=2)
        returncode = subprocess.call([
            self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
            tempdir, self._new_schema, input_json
        ])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert upgraded json to binary.")

        # flatc names its output after the input basename, so "temp.json"
        # becomes "temp.tflite" regardless of the requested extension.
        shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)
      else:
        raise ValueError("Invalid extension on output file %r" % output_file)

  def _Upgrade0To1(self, data):
    """Upgrade data from Version 0 to Version 1.

    Changes: Added subgraphs (which contains a subset of formally global
    entries).

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    subgraph = {}
    # Promote previously-global entries into a single subgraph.
    for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
      subgraph[key_to_promote] = data[key_to_promote]
      del data[key_to_promote]
    data["subgraphs"] = [subgraph]

  def _Upgrade1To2(self, data):
    """Upgrade data from Version 1 to Version 2.

    Changes: Rename operators to Conform to NN API.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    Raises:
      ValueError: Throws when model builtins are numeric rather than symbols.
    """

    def RemapOperator(opcode_name):
      """Go from old schema op name to new schema op name.

      Args:
        opcode_name: String representing the ops (see :schema.fbs).
      Returns:
        Converted opcode_name from V1 to V2.
      """
      old_name_to_new_name = {
          "CONVOLUTION": "CONV_2D",
          "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
          "AVERAGE_POOL": "AVERAGE_POOL_2D",
          "MAX_POOL": "MAX_POOL_2D",
          "L2_POOL": "L2_POOL_2D",
          "SIGMOID": "LOGISTIC",
          "L2NORM": "L2_NORMALIZATION",
          "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
          "Basic_RNN": "RNN",
      }
      # Names without a remapping pass through unchanged.
      return old_name_to_new_name.get(opcode_name, opcode_name)

    def RemapOperatorType(operator_type):
      """Remap operator structs from old names to new names.

      Args:
        operator_type: String representing the builtin operator data type
          string.
        (see :schema.fbs).
      Returns:
        Upgraded builtin operator data type as a string.
      """
      old_to_new = {
          "PoolOptions": "Pool2DOptions",
          "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
          "ConvolutionOptions": "Conv2DOptions",
          "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
          "BasicRNNOptions": "RNNOptions",
      }
      # Types without a remapping pass through unchanged.
      return old_to_new.get(operator_type, operator_type)

    for subgraph in data["subgraphs"]:
      for ops in subgraph["operators"]:
        ops["builtin_options_type"] = RemapOperatorType(
            ops["builtin_options_type"])

    # Upgrade the operator codes
    for operator_code in data["operator_codes"]:
      # builtin_code must be a symbolic name, not a number; numeric codes
      # indicate an inconsistent model. (Python 2 is no longer supported,
      # so a plain str check replaces the old type(u"") idiom.)
      if not isinstance(operator_code["builtin_code"], str):
        raise ValueError("builtin_code %r is non-string. this usually means "
                         "your model has consistency problems." %
                         (operator_code["builtin_code"]))
      operator_code["builtin_code"] = (RemapOperator(
          operator_code["builtin_code"]))

  def _Upgrade2To3(self, data):
    """Upgrade data from Version 2 to Version 3.

    Changed actual read-only tensor data to be in a buffers table instead
    of inline with the tensor.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    buffers = [{"data": []}]  # Start with 1 empty buffer
    for subgraph in data["subgraphs"]:
      if "tensors" not in subgraph:
        continue
      for tensor in subgraph["tensors"]:
        if "data_buffer" not in tensor:
          # No inline data: point at the shared empty buffer 0.
          tensor["buffer"] = 0
        else:
          if tensor["data_buffer"]:
            # Move inline data into a new entry of the buffers table.
            tensor["buffer"] = len(buffers)
            buffers.append({"data": tensor["data_buffer"]})
          else:
            tensor["buffer"] = 0
          del tensor["data_buffer"]
    data["buffers"] = buffers

  def _PerformUpgrade(self, data):
    """Manipulate the `data` (parsed JSON) based on changes in format.

    This incrementally will upgrade from version to version within data.

    Args:
      data: Dictionary representing the TensorFlow data. This will be upgraded
        in place.
    """
    while data["version"] < self._new_version:
      self._upgrade_dispatch[data["version"]](data)
      data["version"] += 1

  def Convert(self, input_file, output_file):
    """Perform schema conversion from input_file to output_file.

    Args:
      input_file: Filename of TensorFlow Lite data to convert from. Must
        be `.json` or `.bin` extension files for JSON or Binary forms of
        the TensorFlow FlatBuffer schema.
      output_file: Filename to write to. Extension also must be `.json`
        or `.bin`.

    Raises:
      RuntimeError: Generated when none of the upgrader supported schemas
        matched the `input_file` data.
    """
    # Read data in each schema (since they are incompatible). Version is
    # always present. Use the read data that matches the version of the
    # schema.
    for version, schema, raw_binary, _ in self._schemas:
      try:
        data_candidate = self._Read(input_file, schema, raw_binary)
      except RuntimeError:
        continue  # Skip and hope another schema works
      if "version" not in data_candidate:  # Assume version 1 if not present.
        data_candidate["version"] = 1
      elif data_candidate["version"] == 0:  # Version 0 doesn't exist in wild.
        data_candidate["version"] = 1

      if data_candidate["version"] == version:
        self._PerformUpgrade(data_candidate)
        self._Write(data_candidate, output_file)
        return
    raise RuntimeError("No schema that the converter understands worked with "
                       "the data file you provided.")
338
339
def main(argv):
  """Entry point: convert FLAGS.input into FLAGS.output using Converter."""
  del argv  # Unused; flags were already parsed into FLAGS.
  converter = Converter()
  converter.Convert(FLAGS.input, FLAGS.output)
343
344
if __name__ == "__main__":
  # Parse the flags we know; hand anything left over back to tf's app runner.
  FLAGS, unparsed = parser.parse_known_args()
  remaining_argv = [sys.argv[0]] + unparsed
  tf.compat.v1.app.run(main=main, argv=remaining_argv)
348