1# ============================================================================== 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15"""Upgrade script to move from pre-release schema to new schema. 16 17Usage examples: 18 19bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.json 20bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.bin 21bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.json 22bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.bin 23bazel run tensorflow/lite/schema/upgrade_schema -- in.tflite out.tflite 24""" 25from __future__ import absolute_import 26from __future__ import division 27from __future__ import print_function 28 29import argparse 30import contextlib 31import json 32import os 33import shutil 34import subprocess 35import sys 36import tempfile 37 38import tensorflow as tf 39from tensorflow.python.platform import resource_loader 40 41parser = argparse.ArgumentParser( 42 description="Script to move TFLite models from pre-release schema to " 43 "new schema.") 44parser.add_argument( 45 "input", 46 type=str, 47 help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.") 48parser.add_argument( 49 "output", 50 type=str, 51 help="Output json or bin TensorFlow lite model compliant with " 52 "the new schema. Extension must be `.json`, `.bin` or `.tflite`.") 53 54 55# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles. 56@contextlib.contextmanager 57def TemporaryDirectoryResource(): 58 temporary = tempfile.mkdtemp() 59 try: 60 yield temporary 61 finally: 62 shutil.rmtree(temporary) 63 64 65class Converter(object): 66 """Converts TensorFlow flatbuffer models from old to new version of schema. 67 68 This can convert between any version to the latest version. It uses 69 an incremental upgrade strategy to go from version to version. 70 71 Usage: 72 converter = Converter() 73 converter.Convert("a.tflite", "a.json") 74 converter.Convert("b.json", "b.tflite") 75 """ 76 77 def __init__(self): 78 # TODO(aselle): make this work in the open source version with better 79 # path. 80 paths_to_try = [ 81 "../../../../flatbuffers/flatc", # not bazel 82 "../../../../external/flatbuffers/flatc" # bazel 83 ] 84 for p in paths_to_try: 85 self._flatc_path = resource_loader.get_path_to_datafile(p) 86 if os.path.exists(self._flatc_path): break 87 88 def FindSchema(base_name): 89 return resource_loader.get_path_to_datafile("%s" % base_name) 90 91 # Supported schemas for upgrade. 92 self._schemas = [ 93 (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1), 94 (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2), 95 (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3), 96 (3, FindSchema("schema_v3.fbs"), False, None) # Non-callable by design. 97 ] 98 # Ensure schemas are sorted, and extract latest version and upgrade 99 # dispatch function table. 100 self._schemas.sort() 101 self._new_version, self._new_schema = self._schemas[-1][:2] 102 self._upgrade_dispatch = { 103 version: dispatch 104 for version, unused1, unused2, dispatch in self._schemas} 105 106 def _Read(self, input_file, schema, raw_binary=False): 107 """Read a tflite model assuming the given flatbuffer schema. 108 109 If `input_file` is in bin, then we must use flatc to convert the schema 110 from binary to json. 111 112 Args: 113 input_file: a binary (flatbuffer) or json file to read from. Extension 114 must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or 115 FlatBuffer JSON. 116 schema: which schema to use for reading 117 raw_binary: whether to assume raw_binary (versions previous to v3) 118 that lacked file_identifier require this. 119 120 Raises: 121 RuntimeError: When flatc cannot be invoked. 122 ValueError: When the extension is not json or bin. 123 124 Returns: 125 A dictionary representing the read tflite model. 126 """ 127 raw_binary = ["--raw-binary"] if raw_binary else [] 128 with TemporaryDirectoryResource() as tempdir: 129 basename = os.path.basename(input_file) 130 basename_no_extension, extension = os.path.splitext(basename) 131 if extension in [".bin", ".tflite"]: 132 # Convert to json using flatc 133 returncode = subprocess.call([ 134 self._flatc_path, 135 "-t", 136 "--strict-json", 137 "--defaults-json", 138 ] + raw_binary + ["-o", tempdir, schema, "--", input_file]) 139 if returncode != 0: 140 raise RuntimeError("flatc failed to convert from binary to json.") 141 json_file = os.path.join(tempdir, basename_no_extension + ".json") 142 if not os.path.exists(json_file): 143 raise RuntimeError("Could not find %r" % json_file) 144 elif extension == ".json": 145 json_file = input_file 146 else: 147 raise ValueError("Invalid extension on input file %r" % input_file) 148 return json.load(open(json_file)) 149 150 def _Write(self, data, output_file): 151 """Output a json or bin version of the flatbuffer model. 152 153 Args: 154 data: Dict representing the TensorFlow Lite model to write. 155 output_file: filename to write the converted flatbuffer to. (json, 156 tflite, or bin extension is required). 157 Raises: 158 ValueError: When the extension is not json or bin 159 RuntimeError: When flatc fails to convert json data to binary. 160 """ 161 _, extension = os.path.splitext(output_file) 162 with TemporaryDirectoryResource() as tempdir: 163 if extension == ".json": 164 json.dump(data, open(output_file, "w"), sort_keys=True, indent=2) 165 elif extension in [".tflite", ".bin"]: 166 input_json = os.path.join(tempdir, "temp.json") 167 with open(input_json, "w") as fp: 168 json.dump(data, fp, sort_keys=True, indent=2) 169 returncode = subprocess.call([ 170 self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o", 171 tempdir, self._new_schema, input_json 172 ]) 173 if returncode != 0: 174 raise RuntimeError("flatc failed to convert upgraded json to binary.") 175 176 shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file) 177 else: 178 raise ValueError("Invalid extension on output file %r" % output_file) 179 180 def _Upgrade0To1(self, data): 181 """Upgrade data from Version 0 to Version 1. 182 183 Changes: Added subgraphs (which contains a subset of formally global 184 entries). 185 186 Args: 187 data: Dictionary representing the TensorFlow lite data to be upgraded. 188 This will be modified in-place to be an upgraded version. 189 """ 190 subgraph = {} 191 for key_to_promote in ["tensors", "operators", "inputs", "outputs"]: 192 subgraph[key_to_promote] = data[key_to_promote] 193 del data[key_to_promote] 194 data["subgraphs"] = [subgraph] 195 196 def _Upgrade1To2(self, data): 197 """Upgrade data from Version 1 to Version 2. 198 199 Changes: Rename operators to Conform to NN API. 200 201 Args: 202 data: Dictionary representing the TensorFlow lite data to be upgraded. 203 This will be modified in-place to be an upgraded version. 204 Raises: 205 ValueError: Throws when model builtins are numeric rather than symbols. 206 """ 207 208 def RemapOperator(opcode_name): 209 """Go from old schema op name to new schema op name. 210 211 Args: 212 opcode_name: String representing the ops (see :schema.fbs). 213 Returns: 214 Converted opcode_name from V1 to V2. 215 """ 216 old_name_to_new_name = { 217 "CONVOLUTION": "CONV_2D", 218 "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D", 219 "AVERAGE_POOL": "AVERAGE_POOL_2D", 220 "MAX_POOL": "MAX_POOL_2D", 221 "L2_POOL": "L2_POOL_2D", 222 "SIGMOID": "LOGISTIC", 223 "L2NORM": "L2_NORMALIZATION", 224 "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION", 225 "Basic_RNN": "RNN", 226 } 227 228 return (old_name_to_new_name[opcode_name] 229 if opcode_name in old_name_to_new_name else opcode_name) 230 231 def RemapOperatorType(operator_type): 232 """Remap operator structs from old names to new names. 233 234 Args: 235 operator_type: String representing the builtin operator data type 236 string. 237 (see :schema.fbs). 238 Returns: 239 Upgraded builtin operator data type as a string. 240 """ 241 old_to_new = { 242 "PoolOptions": "Pool2DOptions", 243 "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions", 244 "ConvolutionOptions": "Conv2DOptions", 245 "LocalResponseNormOptions": "LocalResponseNormalizationOptions", 246 "BasicRNNOptions": "RNNOptions", 247 } 248 return (old_to_new[operator_type] 249 if operator_type in old_to_new else operator_type) 250 251 for subgraph in data["subgraphs"]: 252 for ops in subgraph["operators"]: 253 ops["builtin_options_type"] = RemapOperatorType( 254 ops["builtin_options_type"]) 255 256 # Upgrade the operator codes 257 for operator_code in data["operator_codes"]: 258 # Check if builtin_code is the appropriate string type 259 # use type("") instead of str or unicode. for py2and3 260 if not isinstance(operator_code["builtin_code"], type(u"")): 261 raise ValueError("builtin_code %r is non-string. this usually means " 262 "your model has consistency problems." % 263 (operator_code["builtin_code"])) 264 operator_code["builtin_code"] = (RemapOperator( 265 operator_code["builtin_code"])) 266 267 def _Upgrade2To3(self, data): 268 """Upgrade data from Version 2 to Version 3. 269 270 Changed actual read-only tensor data to be in a buffers table instead 271 of inline with the tensor. 272 273 Args: 274 data: Dictionary representing the TensorFlow lite data to be upgraded. 275 This will be modified in-place to be an upgraded version. 276 """ 277 buffers = [{"data": []}] # Start with 1 empty buffer 278 for subgraph in data["subgraphs"]: 279 if "tensors" not in subgraph: 280 continue 281 for tensor in subgraph["tensors"]: 282 if "data_buffer" not in tensor: 283 tensor["buffer"] = 0 284 else: 285 if tensor["data_buffer"]: 286 tensor[u"buffer"] = len(buffers) 287 buffers.append({"data": tensor["data_buffer"]}) 288 else: 289 tensor["buffer"] = 0 290 del tensor["data_buffer"] 291 data["buffers"] = buffers 292 293 def _PerformUpgrade(self, data): 294 """Manipulate the `data` (parsed JSON) based on changes in format. 295 296 This incrementally will upgrade from version to version within data. 297 298 Args: 299 data: Dictionary representing the TensorFlow data. This will be upgraded 300 in place. 301 """ 302 while data["version"] < self._new_version: 303 self._upgrade_dispatch[data["version"]](data) 304 data["version"] += 1 305 306 def Convert(self, input_file, output_file): 307 """Perform schema conversion from input_file to output_file. 308 309 Args: 310 input_file: Filename of TensorFlow Lite data to convert from. Must 311 be `.json` or `.bin` extension files for JSON or Binary forms of 312 the TensorFlow FlatBuffer schema. 313 output_file: Filename to write to. Extension also must be `.json` 314 or `.bin`. 315 316 Raises: 317 RuntimeError: Generated when none of the upgrader supported schemas 318 matche the `input_file` data. 319 """ 320 # Read data in each schema (since they are incompatible). Version is 321 # always present. Use the read data that matches the version of the 322 # schema. 323 for version, schema, raw_binary, _ in self._schemas: 324 try: 325 data_candidate = self._Read(input_file, schema, raw_binary) 326 except RuntimeError: 327 continue # Skip and hope another schema works 328 if "version" not in data_candidate: # Assume version 1 if not present. 329 data_candidate["version"] = 1 330 elif data_candidate["version"] == 0: # Version 0 doesn't exist in wild. 331 data_candidate["version"] = 1 332 333 if data_candidate["version"] == version: 334 self._PerformUpgrade(data_candidate) 335 self._Write(data_candidate, output_file) 336 return 337 raise RuntimeError("No schema that the converter understands worked with " 338 "the data file you provided.") 339 340 341def main(argv): 342 del argv 343 Converter().Convert(FLAGS.input, FLAGS.output) 344 345 346if __name__ == "__main__": 347 FLAGS, unparsed = parser.parse_known_args() 348 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 349