• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# ==============================================================================
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Upgrade script to move from pre-release schema to new schema.
16
17Usage examples:
18
19bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.json
20bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.bin
21bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.json
22bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.bin
23bazel run tensorflow/lite/schema/upgrade_schema -- in.tflite out.tflite
24"""
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28
29import argparse
30import contextlib
31import json
32import os
33import shutil
34import subprocess
35import sys
36import tempfile
37
38import tensorflow as tf
39from tensorflow.python.platform import resource_loader
40
# Command-line interface: two positional arguments, the model to read and
# the upgraded model to write.
parser = argparse.ArgumentParser(
    description=("Script to move TFLite models from pre-release schema to "
                 "new schema."))
parser.add_argument(
    "input", type=str,
    help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.")
parser.add_argument(
    "output", type=str,
    help=("Output json or bin TensorFlow lite model compliant with "
          "the new schema. Extension must be `.json`, `.bin` or `.tflite`."))
53
54
# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles.
@contextlib.contextmanager
def TemporaryDirectoryResource():
  """Yield a fresh temporary directory, deleting it (and contents) on exit."""
  dirname = tempfile.mkdtemp()
  try:
    yield dirname
  finally:
    # Always clean up, even if the body raised.
    shutil.rmtree(dirname)
63
64
class Converter(object):
  """Converts TensorFlow flatbuffer models from old to new version of schema.

  This can convert between any version to the latest version. It uses
  an incremental upgrade strategy to go from version to version.

  Usage:
    converter = Converter()
    converter.Convert("a.tflite", "a.json")
    converter.Convert("b.json", "b.tflite")
  """

  def __init__(self):
    # TODO(aselle): make this work in the open source version with better
    # path.
    paths_to_try = [
        "../../../../flatbuffers/flatc",  # not bazel
        "../../../../external/flatbuffers/flatc"  # bazel
    ]
    for p in paths_to_try:
      self._flatc_path = resource_loader.get_path_to_datafile(p)
      if os.path.exists(self._flatc_path): break

    def FindSchema(base_name):
      return resource_loader.get_path_to_datafile("%s" % base_name)

    # Supported schemas for upgrade, as tuples of
    # (version, schema path, needs --raw-binary, upgrade function).
    self._schemas = [
        (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
        (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
        (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
        (3, FindSchema("schema_v3.fbs"), False, None)  # Non-callable by design.
    ]
    # Ensure schemas are sorted, and extract latest version and upgrade
    # dispatch function table.
    self._schemas.sort()
    self._new_version, self._new_schema = self._schemas[-1][:2]
    self._upgrade_dispatch = {
        version: dispatch
        for version, unused1, unused2, dispatch in self._schemas}

  def _Read(self, input_file, schema, raw_binary=False):
    """Read a tflite model assuming the given flatbuffer schema.

    If `input_file` is in bin, then we must use flatc to convert the schema
    from binary to json.

    Args:
      input_file: a binary (flatbuffer) or json file to read from. Extension
        must  be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
        FlatBuffer JSON.
      schema: which schema to use for reading
      raw_binary: whether to assume raw_binary (versions previous to v3)
        that lacked file_identifier require this.

    Raises:
      RuntimeError: When flatc cannot be invoked.
      ValueError: When the extension is not json or bin.

    Returns:
      A dictionary representing the read tflite model.
    """
    raw_binary = ["--raw-binary"] if raw_binary else []
    with TemporaryDirectoryResource() as tempdir:
      basename = os.path.basename(input_file)
      basename_no_extension, extension = os.path.splitext(basename)
      if extension in [".bin", ".tflite"]:
        # Convert to json using flatc
        returncode = subprocess.call([
            self._flatc_path,
            "-t",
            "--strict-json",
            "--defaults-json",
        ] + raw_binary + ["-o", tempdir, schema, "--", input_file])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert from binary to json.")
        json_file = os.path.join(tempdir, basename_no_extension + ".json")
        if not os.path.exists(json_file):
          raise RuntimeError("Could not find %r" % json_file)
      elif extension == ".json":
        json_file = input_file
      else:
        raise ValueError("Invalid extension on input file %r" % input_file)
      # Use a context manager so the handle is closed (the previous
      # json.load(open(...)) leaked the file object).
      with open(json_file) as json_fp:
        return json.load(json_fp)

  def _Write(self, data, output_file):
    """Output a json or bin version of the flatbuffer model.

    Args:
      data: Dict representing the TensorFlow Lite model to write.
      output_file: filename to write the converted flatbuffer to. (json,
        tflite, or bin extension is required).
    Raises:
      ValueError: When the extension is not json or bin
      RuntimeError: When flatc fails to convert json data to binary.
    """
    _, extension = os.path.splitext(output_file)
    with TemporaryDirectoryResource() as tempdir:
      if extension == ".json":
        # Use a context manager so the output is flushed and closed (the
        # previous json.dump(..., open(...)) leaked the file object).
        with open(output_file, "w") as fp:
          json.dump(data, fp, sort_keys=True, indent=2)
      elif extension in [".tflite", ".bin"]:
        # flatc only converts files on disk, so round-trip through a
        # temporary json file in tempdir.
        input_json = os.path.join(tempdir, "temp.json")
        with open(input_json, "w") as fp:
          json.dump(data, fp, sort_keys=True, indent=2)
        returncode = subprocess.call([
            self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
            tempdir, self._new_schema, input_json
        ])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert upgraded json to binary.")

        shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)
      else:
        raise ValueError("Invalid extension on output file %r" % output_file)

  def _Upgrade0To1(self, data):
    """Upgrade data from Version 0 to Version 1.

    Changes: Added subgraphs (which contains a subset of formally global
    entries).

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    subgraph = {}
    # Move the formerly-global entries into a single subgraph.
    for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
      subgraph[key_to_promote] = data[key_to_promote]
      del data[key_to_promote]
    data["subgraphs"] = [subgraph]

  def _Upgrade1To2(self, data):
    """Upgrade data from Version 1 to Version 2.

    Changes: Rename operators to Conform to NN API.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    Raises:
      ValueError: Throws when model builtins are numeric rather than symbols.
    """

    def RemapOperator(opcode_name):
      """Go from old schema op name to new schema op name.

      Args:
        opcode_name: String representing the ops (see :schema.fbs).
      Returns:
        Converted opcode_name from V1 to V2.
      """
      old_name_to_new_name = {
          "CONVOLUTION": "CONV_2D",
          "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
          "AVERAGE_POOL": "AVERAGE_POOL_2D",
          "MAX_POOL": "MAX_POOL_2D",
          "L2_POOL": "L2_POOL_2D",
          "SIGMOID": "LOGISTIC",
          "L2NORM": "L2_NORMALIZATION",
          "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
          "Basic_RNN": "RNN",
      }

      # Names not in the map are already valid V2 names; pass them through.
      return (old_name_to_new_name[opcode_name]
              if opcode_name in old_name_to_new_name else opcode_name)

    def RemapOperatorType(operator_type):
      """Remap operator structs from old names to new names.

      Args:
        operator_type: String representing the builtin operator data type
          string.
        (see :schema.fbs).
      Returns:
        Upgraded builtin operator data type as a string.
      """
      old_to_new = {
          "PoolOptions": "Pool2DOptions",
          "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
          "ConvolutionOptions": "Conv2DOptions",
          "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
          "BasicRNNOptions": "RNNOptions",
      }
      return (old_to_new[operator_type]
              if operator_type in old_to_new else operator_type)

    # Upgrade the per-operator option struct names in every subgraph.
    for subgraph in data["subgraphs"]:
      for ops in subgraph["operators"]:
        ops["builtin_options_type"] = RemapOperatorType(
            ops["builtin_options_type"])

    # Upgrade the operator codes
    for operator_code in data["operator_codes"]:
      # Check if builtin_code is the appropriate string type
      # use type("") instead of str or unicode. for py2and3
      if not isinstance(operator_code["builtin_code"], type(u"")):
        raise ValueError("builtin_code %r is non-string. this usually means "
                         "your model has consistency problems." %
                         (operator_code["builtin_code"]))
      operator_code["builtin_code"] = (RemapOperator(
          operator_code["builtin_code"]))

  def _Upgrade2To3(self, data):
    """Upgrade data from Version 2 to Version 3.

    Changed actual read-only tensor data to be in a buffers table instead
    of inline with the tensor.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    buffers = [{"data": []}]  # Start with 1 empty buffer
    for subgraph in data["subgraphs"]:
      if "tensors" not in subgraph:
        continue
      for tensor in subgraph["tensors"]:
        if "data_buffer" not in tensor:
          # No inline data: point at the shared empty buffer 0.
          tensor["buffer"] = 0
        else:
          if tensor["data_buffer"]:
            # Move the inline data into a new entry of the buffers table.
            tensor[u"buffer"] = len(buffers)
            buffers.append({"data": tensor["data_buffer"]})
          else:
            tensor["buffer"] = 0
          del tensor["data_buffer"]
    data["buffers"] = buffers

  def _PerformUpgrade(self, data):
    """Manipulate the `data` (parsed JSON) based on changes in format.

    This incrementally will upgrade from version to version within data.

    Args:
      data: Dictionary representing the TensorFlow data. This will be upgraded
        in place.
    """
    while data["version"] < self._new_version:
      self._upgrade_dispatch[data["version"]](data)
      data["version"] += 1

  def Convert(self, input_file, output_file):
    """Perform schema conversion from input_file to output_file.

    Args:
      input_file: Filename of TensorFlow Lite data to convert from. Must
        be `.json` or `.bin` extension files for JSON or Binary forms of
        the TensorFlow FlatBuffer schema.
      output_file: Filename to write to. Extension also must be `.json`
        or `.bin`.

    Raises:
      RuntimeError: Generated when none of the upgrader supported schemas
        match the `input_file` data.
    """
    # Read data in each schema (since they are incompatible). Version is
    # always present. Use the read data that matches the version of the
    # schema.
    for version, schema, raw_binary, _ in self._schemas:
      try:
        data_candidate = self._Read(input_file, schema, raw_binary)
      except RuntimeError:
        continue  # Skip and hope another schema works
      if "version" not in data_candidate:  # Assume version 1 if not present.
        data_candidate["version"] = 1
      elif data_candidate["version"] == 0:  # Version 0 doesn't exist in wild.
        data_candidate["version"] = 1

      if data_candidate["version"] == version:
        self._PerformUpgrade(data_candidate)
        self._Write(data_candidate, output_file)
        return
    raise RuntimeError("No schema that the converter understands worked with "
                       "the data file you provided.")
339
340
def main(argv):
  """Entry point: convert FLAGS.input into FLAGS.output using Converter."""
  del argv  # Unused: positional args were already parsed into FLAGS.
  converter = Converter()
  converter.Convert(FLAGS.input, FLAGS.output)
344
345
if __name__ == "__main__":
  # Parse only the flags this script defines; any unrecognized flags are
  # forwarded to tf.app.run so TensorFlow's own flags keep working.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
349