1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Testing for updating TensorFlow lite schema.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import json 22import tempfile 23from tensorflow.lite.schema import upgrade_schema as upgrade_schema_lib 24from tensorflow.python.framework import test_util 25from tensorflow.python.platform import test as test_lib 26 27EMPTY_TEST_SCHEMA_V1 = { 28 "version": 1, 29 "operator_codes": [], 30 "subgraphs": [], 31} 32 33EMPTY_TEST_SCHEMA_V3 = { 34 "version": 3, 35 "operator_codes": [], 36 "subgraphs": [], 37 "buffers": [{ 38 "data": [] 39 }] 40} 41 42TEST_SCHEMA_V0 = { 43 "operator_codes": [], 44 "tensors": [], 45 "inputs": [], 46 "outputs": [], 47 "operators": [], 48 "version": 0 49} 50 51TEST_SCHEMA_V3 = { 52 "operator_codes": [], 53 "buffers": [{ 54 "data": [] 55 }], 56 "subgraphs": [{ 57 "tensors": [], 58 "inputs": [], 59 "outputs": [], 60 "operators": [], 61 }], 62 "version": 63 3 64} 65 66FULL_TEST_SCHEMA_V1 = { 67 "version": 68 1, 69 "operator_codes": [ 70 { 71 "builtin_code": "CONVOLUTION" 72 }, 73 { 74 "builtin_code": "DEPTHWISE_CONVOLUTION" 75 }, 76 { 77 "builtin_code": "AVERAGE_POOL" 78 }, 79 { 80 "builtin_code": "MAX_POOL" 81 }, 82 { 83 "builtin_code": "L2_POOL" 84 }, 85 { 86 "builtin_code": "SIGMOID" 87 }, 88 { 89 "builtin_code": "L2NORM" 90 }, 91 { 92 "builtin_code": "LOCAL_RESPONSE_NORM" 93 }, 94 { 95 "builtin_code": "ADD" 96 }, 97 { 98 "builtin_code": "Basic_RNN" 99 }, 100 ], 101 "subgraphs": [{ 102 "operators": [ 103 { 104 "builtin_options_type": "PoolOptions" 105 }, 106 { 107 "builtin_options_type": "DepthwiseConvolutionOptions" 108 }, 109 { 110 "builtin_options_type": "ConvolutionOptions" 111 }, 112 { 113 "builtin_options_type": "LocalResponseNormOptions" 114 }, 115 { 116 "builtin_options_type": "BasicRNNOptions" 117 }, 118 ], 119 }], 120 "description": 121 "", 122} 123 124FULL_TEST_SCHEMA_V3 = { 125 "version": 126 3, 127 "operator_codes": [ 128 { 129 "builtin_code": "CONV_2D" 130 }, 131 { 132 "builtin_code": "DEPTHWISE_CONV_2D" 133 }, 134 { 135 "builtin_code": "AVERAGE_POOL_2D" 136 }, 137 { 138 "builtin_code": "MAX_POOL_2D" 139 }, 140 { 141 "builtin_code": "L2_POOL_2D" 142 }, 143 { 144 "builtin_code": "LOGISTIC" 145 }, 146 { 147 "builtin_code": "L2_NORMALIZATION" 148 }, 149 { 150 "builtin_code": "LOCAL_RESPONSE_NORMALIZATION" 151 }, 152 { 153 "builtin_code": "ADD" 154 }, 155 { 156 "builtin_code": "RNN" 157 }, 158 ], 159 "subgraphs": [{ 160 "operators": [ 161 { 162 "builtin_options_type": "Pool2DOptions" 163 }, 164 { 165 "builtin_options_type": "DepthwiseConv2DOptions" 166 }, 167 { 168 "builtin_options_type": "Conv2DOptions" 169 }, 170 { 171 "builtin_options_type": "LocalResponseNormalizationOptions" 172 }, 173 { 174 "builtin_options_type": "RNNOptions" 175 }, 176 ], 177 }], 178 "description": 179 "", 180 "buffers": [{ 181 "data": [] 182 }] 183} 184 185BUFFER_TEST_V2 = { 186 "operator_codes": [], 187 "buffers": [], 188 "subgraphs": [{ 189 "tensors": [ 190 { 191 "data_buffer": [1, 2, 3, 4] 192 }, 193 { 194 "data_buffer": [1, 2, 3, 4, 5, 6, 7, 8] 195 }, 196 { 197 "data_buffer": [] 198 }, 199 ], 200 "inputs": [], 201 "outputs": [], 202 "operators": [], 203 }], 204 "version": 205 2 206} 207 208BUFFER_TEST_V3 = { 209 "operator_codes": [], 210 "subgraphs": [{ 211 "tensors": [ 212 { 213 "buffer": 1 214 }, 215 { 216 "buffer": 2 217 }, 218 { 219 "buffer": 0 220 }, 221 ], 222 "inputs": [], 223 "outputs": [], 224 "operators": [], 225 }], 226 "buffers": [ 227 { 228 "data": [] 229 }, 230 { 231 "data": [1, 2, 3, 4] 232 }, 233 { 234 "data": [1, 2, 3, 4, 5, 6, 7, 8] 235 }, 236 ], 237 "version": 238 3 239} 240 241 242def JsonDumpAndFlush(data, fp): 243 """Write the dictionary `data` to a JSON file `fp` (and flush). 244 245 Args: 246 data: in a dictionary that is JSON serializable. 247 fp: File-like object 248 """ 249 json.dump(data, fp) 250 fp.flush() 251 252 253class TestSchemaUpgrade(test_util.TensorFlowTestCase): 254 255 def testNonExistentFile(self): 256 converter = upgrade_schema_lib.Converter() 257 non_existent = tempfile.mktemp(suffix=".json") 258 with self.assertRaisesRegex(IOError, "No such file or directory"): 259 converter.Convert(non_existent, non_existent) 260 261 def testInvalidExtension(self): 262 converter = upgrade_schema_lib.Converter() 263 invalid_extension = tempfile.mktemp(suffix=".foo") 264 with self.assertRaisesRegex(ValueError, "Invalid extension on input"): 265 converter.Convert(invalid_extension, invalid_extension) 266 with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json: 267 JsonDumpAndFlush(EMPTY_TEST_SCHEMA_V1, in_json) 268 with self.assertRaisesRegex(ValueError, "Invalid extension on output"): 269 converter.Convert(in_json.name, invalid_extension) 270 271 def CheckConversion(self, data_old, data_expected): 272 """Given a data dictionary, test upgrading to current version. 273 274 Args: 275 data_old: TFLite model as a dictionary (arbitrary version). 276 data_expected: TFLite model as a dictionary (upgraded). 277 """ 278 converter = upgrade_schema_lib.Converter() 279 with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json, \ 280 tempfile.NamedTemporaryFile( 281 suffix=".json", mode="w+") as out_json, \ 282 tempfile.NamedTemporaryFile( 283 suffix=".bin", mode="w+b") as out_bin, \ 284 tempfile.NamedTemporaryFile( 285 suffix=".tflite", mode="w+b") as out_tflite: 286 JsonDumpAndFlush(data_old, in_json) 287 # Test JSON output 288 converter.Convert(in_json.name, out_json.name) 289 # Test binary output 290 # Convert to .tflite and then to .bin and check if binary is equal 291 converter.Convert(in_json.name, out_tflite.name) 292 converter.Convert(out_tflite.name, out_bin.name) 293 self.assertEqual( 294 open(out_bin.name, "rb").read(), 295 open(out_tflite.name, "rb").read()) 296 # Test that conversion actually produced successful new json. 297 converted_schema = json.load(out_json) 298 self.assertEqual(converted_schema, data_expected) 299 300 def testAlreadyUpgraded(self): 301 """A file already at version 3 should stay at version 3.""" 302 self.CheckConversion(EMPTY_TEST_SCHEMA_V3, EMPTY_TEST_SCHEMA_V3) 303 self.CheckConversion(TEST_SCHEMA_V3, TEST_SCHEMA_V3) 304 self.CheckConversion(BUFFER_TEST_V3, BUFFER_TEST_V3) 305 306 # Disable this while we have incorrectly versioned structures around. 307 # def testV0Upgrade_IntroducesSubgraphs(self): 308 # """V0 did not have subgraphs; check to make sure they get introduced.""" 309 # self.CheckConversion(TEST_SCHEMA_V0, TEST_SCHEMA_V3) 310 311 def testV1Upgrade_RenameOps(self): 312 """V1 had many different names for ops; check to make sure they rename.""" 313 self.CheckConversion(EMPTY_TEST_SCHEMA_V1, EMPTY_TEST_SCHEMA_V3) 314 self.CheckConversion(FULL_TEST_SCHEMA_V1, FULL_TEST_SCHEMA_V3) 315 316 def testV2Upgrade_CreateBuffers(self): 317 """V2 did not have buffers; check to make sure they are created.""" 318 self.CheckConversion(BUFFER_TEST_V2, BUFFER_TEST_V3) 319 320 321if __name__ == "__main__": 322 test_lib.main() 323