• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# ==============================================================================
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Upgrade script to move from pre-release schema to new schema.
16
17Usage examples:
18
19bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.json
20bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.bin
21bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.json
22bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.bin
23bazel run tensorflow/lite/schema/upgrade_schema -- in.tflite out.tflite
24"""
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28
29import argparse
30import contextlib
31import json
32import os
33import shutil
34import subprocess
35import sys
36import tempfile
37
38import tensorflow as tf
39from tensorflow.python.platform import resource_loader
40
# Command-line interface: a positional input model path and output model path.
parser = argparse.ArgumentParser(
    description=("Script to move TFLite models from pre-release schema to "
                 "new schema."))
# (argument name, help text) pairs registered identically below.
_ARG_SPECS = (
    ("input",
     "Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format."),
    ("output",
     "Output json or bin TensorFlow lite model compliant with "
     "the new schema. Extension must be `.json`, `.bin` or `.tflite`."),
)
for _arg_name, _arg_help in _ARG_SPECS:
  parser.add_argument(_arg_name, type=str, help=_arg_help)
54
# RAII-style temporary directory, because flatc doesn't allow direct use of
# tempfiles: it needs a real on-disk directory path for its outputs.
@contextlib.contextmanager
def TemporaryDirectoryResource():
  """Yield a freshly created temporary directory, removing it on exit."""
  dirpath = tempfile.mkdtemp()
  try:
    yield dirpath
  finally:
    # Always clean up, even if the body raised.
    shutil.rmtree(dirpath)
63
64
class Converter(object):
  """Converts TensorFlow flatbuffer models from old to new version of schema.

  This can convert between any version to the latest version. It uses
  an incremental upgrade strategy to go from version to version.

  Usage:
    converter = Converter()
    converter.Convert("a.tflite", "a.json")
    converter.Convert("b.json", "b.tflite")
  """

  def __init__(self):
    # TODO(aselle): make this work in the open source version with better
    # path.
    paths_to_try = [
        "../../../../flatbuffers/flatc",  # not bazel
        "../../../../external/flatbuffers/flatc"  # bazel
    ]
    for p in paths_to_try:
      self._flatc_path = resource_loader.get_path_to_datafile(p)
      if os.path.exists(self._flatc_path):
        break
    # NOTE(review): if no candidate exists, _flatc_path silently keeps the
    # last candidate and _Read/_Write fail later when invoking flatc.

    def FindSchema(base_name):
      """Resolve a schema file name to its on-disk data-file path."""
      return resource_loader.get_path_to_datafile(base_name)

    # Supported schemas for upgrade, as tuples of
    # (version, schema path, needs --raw-binary, upgrade-to-next-version fn).
    self._schemas = [
        (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
        (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
        (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
        (3, FindSchema("schema_v3.fbs"), False, None)  # Non-callable by design.
    ]
    # Ensure schemas are sorted, and extract latest version and upgrade
    # dispatch function table. Sort on the version number only: comparing
    # whole tuples would compare bound methods / None on a version tie.
    self._schemas.sort(key=lambda entry: entry[0])
    self._new_version, self._new_schema = self._schemas[-1][:2]
    self._upgrade_dispatch = {
        version: dispatch
        for version, unused1, unused2, dispatch in self._schemas}

  def _Read(self, input_file, schema, raw_binary=False):
    """Read a tflite model assuming the given flatbuffer schema.

    If `input_file` is in bin, then we must use flatc to convert the schema
    from binary to json.

    Args:
      input_file: a binary (flatbuffer) or json file to read from. Extension
        must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
        FlatBuffer JSON.
      schema: which schema to use for reading
      raw_binary: whether to assume raw_binary (versions previous to v3)
        that lacked file_identifier require this.

    Raises:
      RuntimeError: 1. When flatc cannot be invoked.
                    2. When json file does not exists.
      ValueError: When the extension is not json or bin.

    Returns:
      A dictionary representing the read tflite model.
    """
    # Keep the bool parameter distinct from the flag list we build from it.
    raw_binary_args = ["--raw-binary"] if raw_binary else []
    with TemporaryDirectoryResource() as tempdir:
      basename = os.path.basename(input_file)
      basename_no_extension, extension = os.path.splitext(basename)
      if extension in [".bin", ".tflite"]:
        # Convert to json using flatc
        returncode = subprocess.call([
            self._flatc_path,
            "-t",
            "--strict-json",
            "--defaults-json",
        ] + raw_binary_args + ["-o", tempdir, schema, "--", input_file])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert from binary to json.")
        # flatc names its output after the input's basename.
        json_file = os.path.join(tempdir, basename_no_extension + ".json")
        if not os.path.exists(json_file):
          raise RuntimeError("Could not find %r" % json_file)
      elif extension == ".json":
        json_file = input_file
      else:
        raise ValueError("Invalid extension on input file %r" % input_file)
      # Close the handle deterministically instead of leaking it.
      with open(json_file) as json_fp:
        return json.load(json_fp)

  def _Write(self, data, output_file):
    """Output a json or bin version of the flatbuffer model.

    Args:
      data: Dict representing the TensorFlow Lite model to write.
      output_file: filename to write the converted flatbuffer to. (json,
        tflite, or bin extension is required).
    Raises:
      ValueError: When the extension is not json or bin
      RuntimeError: When flatc fails to convert json data to binary.
    """
    _, extension = os.path.splitext(output_file)
    with TemporaryDirectoryResource() as tempdir:
      if extension == ".json":
        # Close the handle deterministically instead of leaking it.
        with open(output_file, "w") as out_fp:
          json.dump(data, out_fp, sort_keys=True, indent=2)
      elif extension in [".tflite", ".bin"]:
        # Round-trip through json so flatc can serialize to binary.
        input_json = os.path.join(tempdir, "temp.json")
        with open(input_json, "w") as fp:
          json.dump(data, fp, sort_keys=True, indent=2)
        returncode = subprocess.call([
            self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
            tempdir, self._new_schema, input_json
        ])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert upgraded json to binary.")

        shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)
      else:
        raise ValueError("Invalid extension on output file %r" % output_file)

  def _Upgrade0To1(self, data):
    """Upgrade data from Version 0 to Version 1.

    Changes: Added subgraphs (which contains a subset of formerly global
    entries).

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    subgraph = {}
    # Move the formerly-global entries into a single subgraph.
    for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
      subgraph[key_to_promote] = data[key_to_promote]
      del data[key_to_promote]
    data["subgraphs"] = [subgraph]

  def _Upgrade1To2(self, data):
    """Upgrade data from Version 1 to Version 2.

    Changes: Rename operators to Conform to NN API.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    Raises:
      ValueError: Throws when model builtins are numeric rather than symbols.
    """

    def RemapOperator(opcode_name):
      """Go from old schema op name to new schema op name.

      Args:
        opcode_name: String representing the ops (see :schema.fbs).
      Returns:
        Converted opcode_name from V1 to V2.
      """
      old_name_to_new_name = {
          "CONVOLUTION": "CONV_2D",
          "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
          "AVERAGE_POOL": "AVERAGE_POOL_2D",
          "MAX_POOL": "MAX_POOL_2D",
          "L2_POOL": "L2_POOL_2D",
          "SIGMOID": "LOGISTIC",
          "L2NORM": "L2_NORMALIZATION",
          "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
          "Basic_RNN": "RNN",
      }
      # Unknown names pass through unchanged.
      return old_name_to_new_name.get(opcode_name, opcode_name)

    def RemapOperatorType(operator_type):
      """Remap operator structs from old names to new names.

      Args:
        operator_type: String representing the builtin operator data type
          string.
        (see :schema.fbs).
      Returns:
        Upgraded builtin operator data type as a string.
      """
      old_to_new = {
          "PoolOptions": "Pool2DOptions",
          "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
          "ConvolutionOptions": "Conv2DOptions",
          "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
          "BasicRNNOptions": "RNNOptions",
      }
      # Unknown option types pass through unchanged.
      return old_to_new.get(operator_type, operator_type)

    for subgraph in data["subgraphs"]:
      for ops in subgraph["operators"]:
        ops["builtin_options_type"] = RemapOperatorType(
            ops["builtin_options_type"])

    # Upgrade the operator codes
    for operator_code in data["operator_codes"]:
      # Check if builtin_code is the appropriate string type
      # use type("") instead of str or unicode. for py2and3
      if not isinstance(operator_code["builtin_code"], type(u"")):
        raise ValueError("builtin_code %r is non-string. this usually means "
                         "your model has consistency problems." %
                         (operator_code["builtin_code"]))
      operator_code["builtin_code"] = (RemapOperator(
          operator_code["builtin_code"]))

  def _Upgrade2To3(self, data):
    """Upgrade data from Version 2 to Version 3.

    Changed actual read-only tensor data to be in a buffers table instead
    of inline with the tensor.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    buffers = [{"data": []}]  # Start with 1 empty buffer
    for subgraph in data["subgraphs"]:
      if "tensors" not in subgraph:
        continue
      for tensor in subgraph["tensors"]:
        if "data_buffer" not in tensor:
          # No inline data: point at the shared empty buffer 0.
          tensor["buffer"] = 0
        else:
          if tensor["data_buffer"]:
            # Move the inline data into a new entry of the buffers table.
            tensor[u"buffer"] = len(buffers)
            buffers.append({"data": tensor["data_buffer"]})
          else:
            tensor["buffer"] = 0
          del tensor["data_buffer"]
    data["buffers"] = buffers

  def _PerformUpgrade(self, data):
    """Manipulate the `data` (parsed JSON) based on changes in format.

    This incrementally will upgrade from version to version within data.

    Args:
      data: Dictionary representing the TensorFlow data. This will be upgraded
        in place.
    """
    while data["version"] < self._new_version:
      self._upgrade_dispatch[data["version"]](data)
      data["version"] += 1

  def Convert(self, input_file, output_file):
    """Perform schema conversion from input_file to output_file.

    Args:
      input_file: Filename of TensorFlow Lite data to convert from. Must
        be `.json` or `.bin` extension files for JSON or Binary forms of
        the TensorFlow FlatBuffer schema.
      output_file: Filename to write to. Extension also must be `.json`
        or `.bin`.

    Raises:
      RuntimeError: Generated when none of the upgrader supported schemas
        match the `input_file` data.
    """
    # Read data in each schema (since they are incompatible). Version is
    # always present. Use the read data that matches the version of the
    # schema.
    for version, schema, raw_binary, _ in self._schemas:
      try:
        data_candidate = self._Read(input_file, schema, raw_binary)
      except RuntimeError:
        continue  # Skip and hope another schema works
      if "version" not in data_candidate:  # Assume version 1 if not present.
        data_candidate["version"] = 1
      elif data_candidate["version"] == 0:  # Version 0 doesn't exist in wild.
        data_candidate["version"] = 1

      if data_candidate["version"] == version:
        self._PerformUpgrade(data_candidate)
        self._Write(data_candidate, output_file)
        return
    raise RuntimeError("No schema that the converter understands worked with "
                       "the data file you provided.")
342
343
def main(argv):
  """Entry point: convert the model named by FLAGS.input into FLAGS.output.

  Args:
    argv: Remaining command-line arguments passed in by tf.app.run; unused
      because inputs are read from the module-level FLAGS namespace.
  """
  del argv  # Unused; flags were already parsed into FLAGS.
  Converter().Convert(FLAGS.input, FLAGS.output)
347
348
if __name__ == "__main__":
  # Parse only the flags this script defines; forward anything unrecognized
  # to tf.app.run so main still receives them in argv.
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
352