1# ============================================================================== 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15"""Upgrade script to move from pre-release schema to new schema. 16 17Usage examples: 18 19bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.json 20bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.bin 21bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.json 22bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.bin 23bazel run tensorflow/lite/schema/upgrade_schema -- in.tflite out.tflite 24""" 25import argparse 26import contextlib 27import json 28import os 29import shutil 30import subprocess 31import sys 32import tempfile 33 34import tensorflow as tf 35from tensorflow.python.platform import resource_loader 36 37parser = argparse.ArgumentParser( 38 description="Script to move TFLite models from pre-release schema to " 39 "new schema.") 40parser.add_argument( 41 "input", 42 type=str, 43 help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.") 44parser.add_argument( 45 "output", 46 type=str, 47 help="Output json or bin TensorFlow lite model compliant with " 48 "the new schema. Extension must be `.json`, `.bin` or `.tflite`.") 49 50 51# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles. 
@contextlib.contextmanager
def TemporaryDirectoryResource():
  """Yield a freshly created temp directory, removing it (and contents) on exit."""
  temporary = tempfile.mkdtemp()
  try:
    yield temporary
  finally:
    shutil.rmtree(temporary)


class Converter:
  """Converts TensorFlow flatbuffer models from old to new version of schema.

  This can convert between any version to the latest version. It uses
  an incremental upgrade strategy to go from version to version.

  Usage:
    converter = Converter()
    converter.Convert("a.tflite", "a.json")
    converter.Convert("b.json", "b.tflite")
  """

  def __init__(self):
    # TODO(aselle): make this work in the open source version with better
    # path.
    paths_to_try = [
        "../../../../flatbuffers/flatc",  # not bazel
        "../../../../external/flatbuffers/flatc"  # bazel
    ]
    for p in paths_to_try:
      self._flatc_path = resource_loader.get_path_to_datafile(p)
      if os.path.exists(self._flatc_path):
        break
    # NOTE(review): if no candidate exists, self._flatc_path silently keeps
    # the last candidate and flatc invocations fail later in _Read/_Write.

    def FindSchema(base_name):
      return resource_loader.get_path_to_datafile("%s" % base_name)

    # Supported schemas for upgrade, as tuples of
    # (version, schema path, needs --raw-binary, upgrade-to-next-version fn).
    self._schemas = [
        (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
        (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
        (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
        (3, FindSchema("schema_v3.fbs"), False, None)  # Non-callable by design.
    ]
    # Ensure schemas are sorted, and extract latest version and upgrade
    # dispatch function table. Sort on the version field only: the tuples also
    # hold bound methods and None, which are not orderable in Python 3.
    self._schemas.sort(key=lambda schema: schema[0])
    self._new_version, self._new_schema = self._schemas[-1][:2]
    self._upgrade_dispatch = {
        version: dispatch
        for version, _, _, dispatch in self._schemas
    }

  def _Read(self, input_file, schema, raw_binary=False):
    """Read a tflite model assuming the given flatbuffer schema.

    If `input_file` is in bin, then we must use flatc to convert the schema
    from binary to json.

    Args:
      input_file: a binary (flatbuffer) or json file to read from. Extension
        must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
        FlatBuffer JSON.
      schema: which schema to use for reading
      raw_binary: whether to assume raw_binary (versions previous to v3)
        that lacked file_identifier require this.

    Raises:
      RuntimeError: 1. When flatc cannot be invoked.
                    2. When json file does not exist.
      ValueError: When the extension is not json or bin.

    Returns:
      A dictionary representing the read tflite model.
    """
    raw_binary_flags = ["--raw-binary"] if raw_binary else []
    with TemporaryDirectoryResource() as tempdir:
      basename = os.path.basename(input_file)
      basename_no_extension, extension = os.path.splitext(basename)
      if extension in [".bin", ".tflite"]:
        # Convert to json using flatc
        returncode = subprocess.call([
            self._flatc_path,
            "-t",
            "--strict-json",
            "--defaults-json",
        ] + raw_binary_flags + ["-o", tempdir, schema, "--", input_file])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert from binary to json.")
        # flatc names its output after the input basename, swapping extension.
        json_file = os.path.join(tempdir, basename_no_extension + ".json")
        if not os.path.exists(json_file):
          raise RuntimeError("Could not find %r" % json_file)
      elif extension == ".json":
        json_file = input_file
      else:
        raise ValueError("Invalid extension on input file %r" % input_file)
      # Use a context manager so the handle is closed deterministically.
      with open(json_file) as fp:
        return json.load(fp)

  def _Write(self, data, output_file):
    """Output a json or bin version of the flatbuffer model.

    Args:
      data: Dict representing the TensorFlow Lite model to write.
      output_file: filename to write the converted flatbuffer to. (json,
        tflite, or bin extension is required).
    Raises:
      ValueError: When the extension is not json or bin
      RuntimeError: When flatc fails to convert json data to binary.
    """
    _, extension = os.path.splitext(output_file)
    with TemporaryDirectoryResource() as tempdir:
      if extension == ".json":
        with open(output_file, "w") as fp:
          json.dump(data, fp, sort_keys=True, indent=2)
      elif extension in [".tflite", ".bin"]:
        # Dump to a temp json, then let flatc serialize it to binary.
        input_json = os.path.join(tempdir, "temp.json")
        with open(input_json, "w") as fp:
          json.dump(data, fp, sort_keys=True, indent=2)
        returncode = subprocess.call([
            self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
            tempdir, self._new_schema, input_json
        ])
        if returncode != 0:
          raise RuntimeError("flatc failed to convert upgraded json to binary.")

        shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)
      else:
        raise ValueError("Invalid extension on output file %r" % output_file)

  def _Upgrade0To1(self, data):
    """Upgrade data from Version 0 to Version 1.

    Changes: Added subgraphs (which contains a subset of formally global
    entries).

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    subgraph = {}
    # Move the formerly-global entries under a single subgraph.
    for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
      subgraph[key_to_promote] = data[key_to_promote]
      del data[key_to_promote]
    data["subgraphs"] = [subgraph]

  def _Upgrade1To2(self, data):
    """Upgrade data from Version 1 to Version 2.

    Changes: Rename operators to Conform to NN API.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    Raises:
      ValueError: Throws when model builtins are numeric rather than symbols.
    """

    def RemapOperator(opcode_name):
      """Go from old schema op name to new schema op name.

      Args:
        opcode_name: String representing the ops (see :schema.fbs).
      Returns:
        Converted opcode_name from V1 to V2.
      """
      old_name_to_new_name = {
          "CONVOLUTION": "CONV_2D",
          "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
          "AVERAGE_POOL": "AVERAGE_POOL_2D",
          "MAX_POOL": "MAX_POOL_2D",
          "L2_POOL": "L2_POOL_2D",
          "SIGMOID": "LOGISTIC",
          "L2NORM": "L2_NORMALIZATION",
          "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
          "Basic_RNN": "RNN",
      }
      # Names not in the map are already current; pass them through unchanged.
      return old_name_to_new_name.get(opcode_name, opcode_name)

    def RemapOperatorType(operator_type):
      """Remap operator structs from old names to new names.

      Args:
        operator_type: String representing the builtin operator data type
          string. (see :schema.fbs).
      Returns:
        Upgraded builtin operator data type as a string.
      """
      old_to_new = {
          "PoolOptions": "Pool2DOptions",
          "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
          "ConvolutionOptions": "Conv2DOptions",
          "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
          "BasicRNNOptions": "RNNOptions",
      }
      return old_to_new.get(operator_type, operator_type)

    for subgraph in data["subgraphs"]:
      for ops in subgraph["operators"]:
        ops["builtin_options_type"] = RemapOperatorType(
            ops["builtin_options_type"])

    # Upgrade the operator codes
    for operator_code in data["operator_codes"]:
      # Numeric builtin codes indicate a corrupted / inconsistent model; the
      # remap tables below are keyed by symbolic (string) names only.
      if not isinstance(operator_code["builtin_code"], str):
        raise ValueError("builtin_code %r is non-string. this usually means "
                         "your model has consistency problems." %
                         (operator_code["builtin_code"]))
      operator_code["builtin_code"] = (RemapOperator(
          operator_code["builtin_code"]))

  def _Upgrade2To3(self, data):
    """Upgrade data from Version 2 to Version 3.

    Changed actual read-only tensor data to be in a buffers table instead
    of inline with the tensor.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    buffers = [{"data": []}]  # Start with 1 empty buffer
    for subgraph in data["subgraphs"]:
      if "tensors" not in subgraph:
        continue
      for tensor in subgraph["tensors"]:
        if "data_buffer" not in tensor:
          # No inline data: point at the shared empty buffer (index 0).
          tensor["buffer"] = 0
        else:
          if tensor["data_buffer"]:
            # Move inline data into the global buffers table.
            tensor["buffer"] = len(buffers)
            buffers.append({"data": tensor["data_buffer"]})
          else:
            tensor["buffer"] = 0
          del tensor["data_buffer"]
    data["buffers"] = buffers

  def _PerformUpgrade(self, data):
    """Manipulate the `data` (parsed JSON) based on changes in format.

    This incrementally will upgrade from version to version within data.

    Args:
      data: Dictionary representing the TensorFlow data. This will be upgraded
        in place.
    """
    while data["version"] < self._new_version:
      self._upgrade_dispatch[data["version"]](data)
      data["version"] += 1

  def Convert(self, input_file, output_file):
    """Perform schema conversion from input_file to output_file.

    Args:
      input_file: Filename of TensorFlow Lite data to convert from. Must
        be `.json` or `.bin` extension files for JSON or Binary forms of
        the TensorFlow FlatBuffer schema.
      output_file: Filename to write to. Extension also must be `.json`
        or `.bin`.

    Raises:
      RuntimeError: Generated when none of the upgrader supported schemas
        match the `input_file` data.
    """
    # Read data in each schema (since they are incompatible). Version is
    # always present. Use the read data that matches the version of the
    # schema.
    for version, schema, raw_binary, _ in self._schemas:
      try:
        data_candidate = self._Read(input_file, schema, raw_binary)
      except RuntimeError:
        continue  # Skip and hope another schema works
      if "version" not in data_candidate:  # Assume version 1 if not present.
        data_candidate["version"] = 1
      elif data_candidate["version"] == 0:  # Version 0 doesn't exist in wild.
        data_candidate["version"] = 1

      if data_candidate["version"] == version:
        self._PerformUpgrade(data_candidate)
        self._Write(data_candidate, output_file)
        return
    raise RuntimeError("No schema that the converter understands worked with "
                       "the data file you provided.")


def main(argv):
  del argv  # Unused; flags are parsed at module scope into FLAGS.
  Converter().Convert(FLAGS.input, FLAGS.output)


if __name__ == "__main__":
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)