1# ============================================================================== 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15"""Upgrade script to move from pre-release schema to new schema. 16 17Usage examples: 18 19bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.json 20bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.bin 21bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.json 22bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.bin 23bazel run tensorflow/lite/schema/upgrade_schema -- in.tflite out.tflite 24""" 25from __future__ import absolute_import 26from __future__ import division 27from __future__ import print_function 28 29import argparse 30import contextlib 31import json 32import os 33import shutil 34import subprocess 35import sys 36import tempfile 37 38import tensorflow as tf 39from tensorflow.python.platform import resource_loader 40 41parser = argparse.ArgumentParser( 42 description="Script to move TFLite models from pre-release schema to " 43 "new schema.") 44parser.add_argument( 45 "input", 46 type=str, 47 help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.") 48parser.add_argument( 49 "output", 50 type=str, 51 help="Output json or bin TensorFlow lite model compliant with " 52 "the new schema. Extension must be `.json`, `.bin` or `.tflite`.") 53 54 55# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles. 56@contextlib.contextmanager 57def TemporaryDirectoryResource(): 58 temporary = tempfile.mkdtemp() 59 try: 60 yield temporary 61 finally: 62 shutil.rmtree(temporary) 63 64 65class Converter(object): 66 """Converts TensorFlow flatbuffer models from old to new version of schema. 67 68 This can convert between any version to the latest version. It uses 69 an incremental upgrade strategy to go from version to version. 70 71 Usage: 72 converter = Converter() 73 converter.Convert("a.tflite", "a.json") 74 converter.Convert("b.json", "b.tflite") 75 """ 76 77 def __init__(self): 78 # TODO(aselle): make this work in the open source version with better 79 # path. 80 paths_to_try = [ 81 "../../../../flatbuffers/flatc", # not bazel 82 "../../../../external/flatbuffers/flatc" # bazel 83 ] 84 for p in paths_to_try: 85 self._flatc_path = resource_loader.get_path_to_datafile(p) 86 if os.path.exists(self._flatc_path): break 87 88 def FindSchema(base_name): 89 return resource_loader.get_path_to_datafile("%s" % base_name) 90 91 # Supported schemas for upgrade. 92 self._schemas = [ 93 (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1), 94 (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2), 95 (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3), 96 (3, FindSchema("schema_v3.fbs"), False, None) # Non-callable by design. 97 ] 98 # Ensure schemas are sorted, and extract latest version and upgrade 99 # dispatch function table. 100 self._schemas.sort() 101 self._new_version, self._new_schema = self._schemas[-1][:2] 102 self._upgrade_dispatch = { 103 version: dispatch 104 for version, unused1, unused2, dispatch in self._schemas} 105 106 def _Read(self, input_file, schema, raw_binary=False): 107 """Read a tflite model assuming the given flatbuffer schema. 108 109 If `input_file` is in bin, then we must use flatc to convert the schema 110 from binary to json. 111 112 Args: 113 input_file: a binary (flatbuffer) or json file to read from. Extension 114 must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or 115 FlatBuffer JSON. 116 schema: which schema to use for reading 117 raw_binary: whether to assume raw_binary (versions previous to v3) 118 that lacked file_identifier require this. 119 120 Raises: 121 RuntimeError: 1. When flatc cannot be invoked. 122 2. When json file does not exists. 123 ValueError: When the extension is not json or bin. 124 125 Returns: 126 A dictionary representing the read tflite model. 127 """ 128 raw_binary = ["--raw-binary"] if raw_binary else [] 129 with TemporaryDirectoryResource() as tempdir: 130 basename = os.path.basename(input_file) 131 basename_no_extension, extension = os.path.splitext(basename) 132 if extension in [".bin", ".tflite"]: 133 # Convert to json using flatc 134 returncode = subprocess.call([ 135 self._flatc_path, 136 "-t", 137 "--strict-json", 138 "--defaults-json", 139 ] + raw_binary + ["-o", tempdir, schema, "--", input_file]) 140 if returncode != 0: 141 raise RuntimeError("flatc failed to convert from binary to json.") 142 json_file = os.path.join(tempdir, basename_no_extension + ".json") 143 if not os.path.exists(json_file): 144 raise RuntimeError("Could not find %r" % json_file) 145 elif extension == ".json": 146 json_file = input_file 147 else: 148 raise ValueError("Invalid extension on input file %r" % input_file) 149 return json.load(open(json_file)) 150 151 def _Write(self, data, output_file): 152 """Output a json or bin version of the flatbuffer model. 153 154 Args: 155 data: Dict representing the TensorFlow Lite model to write. 156 output_file: filename to write the converted flatbuffer to. (json, 157 tflite, or bin extension is required). 158 Raises: 159 ValueError: When the extension is not json or bin 160 RuntimeError: When flatc fails to convert json data to binary. 161 """ 162 _, extension = os.path.splitext(output_file) 163 with TemporaryDirectoryResource() as tempdir: 164 if extension == ".json": 165 json.dump(data, open(output_file, "w"), sort_keys=True, indent=2) 166 elif extension in [".tflite", ".bin"]: 167 input_json = os.path.join(tempdir, "temp.json") 168 with open(input_json, "w") as fp: 169 json.dump(data, fp, sort_keys=True, indent=2) 170 returncode = subprocess.call([ 171 self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o", 172 tempdir, self._new_schema, input_json 173 ]) 174 if returncode != 0: 175 raise RuntimeError("flatc failed to convert upgraded json to binary.") 176 177 shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file) 178 else: 179 raise ValueError("Invalid extension on output file %r" % output_file) 180 181 def _Upgrade0To1(self, data): 182 """Upgrade data from Version 0 to Version 1. 183 184 Changes: Added subgraphs (which contains a subset of formally global 185 entries). 186 187 Args: 188 data: Dictionary representing the TensorFlow lite data to be upgraded. 189 This will be modified in-place to be an upgraded version. 190 """ 191 subgraph = {} 192 for key_to_promote in ["tensors", "operators", "inputs", "outputs"]: 193 subgraph[key_to_promote] = data[key_to_promote] 194 del data[key_to_promote] 195 data["subgraphs"] = [subgraph] 196 197 def _Upgrade1To2(self, data): 198 """Upgrade data from Version 1 to Version 2. 199 200 Changes: Rename operators to Conform to NN API. 201 202 Args: 203 data: Dictionary representing the TensorFlow lite data to be upgraded. 204 This will be modified in-place to be an upgraded version. 205 Raises: 206 ValueError: Throws when model builtins are numeric rather than symbols. 207 """ 208 209 def RemapOperator(opcode_name): 210 """Go from old schema op name to new schema op name. 211 212 Args: 213 opcode_name: String representing the ops (see :schema.fbs). 214 Returns: 215 Converted opcode_name from V1 to V2. 216 """ 217 old_name_to_new_name = { 218 "CONVOLUTION": "CONV_2D", 219 "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D", 220 "AVERAGE_POOL": "AVERAGE_POOL_2D", 221 "MAX_POOL": "MAX_POOL_2D", 222 "L2_POOL": "L2_POOL_2D", 223 "SIGMOID": "LOGISTIC", 224 "L2NORM": "L2_NORMALIZATION", 225 "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION", 226 "Basic_RNN": "RNN", 227 } 228 229 return (old_name_to_new_name[opcode_name] 230 if opcode_name in old_name_to_new_name else opcode_name) 231 232 def RemapOperatorType(operator_type): 233 """Remap operator structs from old names to new names. 234 235 Args: 236 operator_type: String representing the builtin operator data type 237 string. 238 (see :schema.fbs). 239 Raises: 240 ValueError: When the model has consistency problems. 241 Returns: 242 Upgraded builtin operator data type as a string. 243 """ 244 old_to_new = { 245 "PoolOptions": "Pool2DOptions", 246 "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions", 247 "ConvolutionOptions": "Conv2DOptions", 248 "LocalResponseNormOptions": "LocalResponseNormalizationOptions", 249 "BasicRNNOptions": "RNNOptions", 250 } 251 return (old_to_new[operator_type] 252 if operator_type in old_to_new else operator_type) 253 254 for subgraph in data["subgraphs"]: 255 for ops in subgraph["operators"]: 256 ops["builtin_options_type"] = RemapOperatorType( 257 ops["builtin_options_type"]) 258 259 # Upgrade the operator codes 260 for operator_code in data["operator_codes"]: 261 # Check if builtin_code is the appropriate string type 262 # use type("") instead of str or unicode. for py2and3 263 if not isinstance(operator_code["builtin_code"], type(u"")): 264 raise ValueError("builtin_code %r is non-string. this usually means " 265 "your model has consistency problems." % 266 (operator_code["builtin_code"])) 267 operator_code["builtin_code"] = (RemapOperator( 268 operator_code["builtin_code"])) 269 270 def _Upgrade2To3(self, data): 271 """Upgrade data from Version 2 to Version 3. 272 273 Changed actual read-only tensor data to be in a buffers table instead 274 of inline with the tensor. 275 276 Args: 277 data: Dictionary representing the TensorFlow lite data to be upgraded. 278 This will be modified in-place to be an upgraded version. 279 """ 280 buffers = [{"data": []}] # Start with 1 empty buffer 281 for subgraph in data["subgraphs"]: 282 if "tensors" not in subgraph: 283 continue 284 for tensor in subgraph["tensors"]: 285 if "data_buffer" not in tensor: 286 tensor["buffer"] = 0 287 else: 288 if tensor["data_buffer"]: 289 tensor[u"buffer"] = len(buffers) 290 buffers.append({"data": tensor["data_buffer"]}) 291 else: 292 tensor["buffer"] = 0 293 del tensor["data_buffer"] 294 data["buffers"] = buffers 295 296 def _PerformUpgrade(self, data): 297 """Manipulate the `data` (parsed JSON) based on changes in format. 298 299 This incrementally will upgrade from version to version within data. 300 301 Args: 302 data: Dictionary representing the TensorFlow data. This will be upgraded 303 in place. 304 """ 305 while data["version"] < self._new_version: 306 self._upgrade_dispatch[data["version"]](data) 307 data["version"] += 1 308 309 def Convert(self, input_file, output_file): 310 """Perform schema conversion from input_file to output_file. 311 312 Args: 313 input_file: Filename of TensorFlow Lite data to convert from. Must 314 be `.json` or `.bin` extension files for JSON or Binary forms of 315 the TensorFlow FlatBuffer schema. 316 output_file: Filename to write to. Extension also must be `.json` 317 or `.bin`. 318 319 Raises: 320 RuntimeError: Generated when none of the upgrader supported schemas 321 matche the `input_file` data. 322 """ 323 # Read data in each schema (since they are incompatible). Version is 324 # always present. Use the read data that matches the version of the 325 # schema. 326 for version, schema, raw_binary, _ in self._schemas: 327 try: 328 data_candidate = self._Read(input_file, schema, raw_binary) 329 except RuntimeError: 330 continue # Skip and hope another schema works 331 if "version" not in data_candidate: # Assume version 1 if not present. 332 data_candidate["version"] = 1 333 elif data_candidate["version"] == 0: # Version 0 doesn't exist in wild. 334 data_candidate["version"] = 1 335 336 if data_candidate["version"] == version: 337 self._PerformUpgrade(data_candidate) 338 self._Write(data_candidate, output_file) 339 return 340 raise RuntimeError("No schema that the converter understands worked with " 341 "the data file you provided.") 342 343 344def main(argv): 345 del argv 346 Converter().Convert(FLAGS.input, FLAGS.output) 347 348 349if __name__ == "__main__": 350 FLAGS, unparsed = parser.parse_known_args() 351 tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed) 352