• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Testing for updating TensorFlow lite schema."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import json
22import tempfile
23from tensorflow.lite.schema import upgrade_schema as upgrade_schema_lib
24from tensorflow.python.framework import test_util
25from tensorflow.python.platform import test as test_lib
26
# Smallest version-1 model used by the tests: no operator codes, no subgraphs.
EMPTY_TEST_SCHEMA_V1 = {"version": 1, "operator_codes": [], "subgraphs": []}
32
# Smallest version-3 model used by the tests: no operator codes or subgraphs,
# and a single buffer whose data is empty.
EMPTY_TEST_SCHEMA_V3 = {
    "version": 3,
    "operator_codes": [],
    "subgraphs": [],
    "buffers": [{"data": []}],
}
41
# Version-0 layout: tensors, inputs, outputs and operators live at the top
# level instead of inside a "subgraphs" entry. Within this file it is only
# referenced by the currently disabled V0 upgrade test.
TEST_SCHEMA_V0 = {
    "operator_codes": [],
    "tensors": [], "inputs": [], "outputs": [], "operators": [],
    "version": 0,
}
50
# Version-3 model with one empty buffer and one subgraph whose tensor, input,
# output and operator lists are all empty.
TEST_SCHEMA_V3 = {
    "operator_codes": [],
    "buffers": [{"data": []}],
    "subgraphs": [
        {"tensors": [], "inputs": [], "outputs": [], "operators": []},
    ],
    "version": 3,
}
65
# Version-1 model that uses the old operator-code and options-type names.
# testV1Upgrade_RenameOps feeds it to the converter and expects the result to
# equal FULL_TEST_SCHEMA_V3.
FULL_TEST_SCHEMA_V1 = {
    "version": 1,
    "operator_codes": [
        {"builtin_code": code}
        for code in (
            "CONVOLUTION",
            "DEPTHWISE_CONVOLUTION",
            "AVERAGE_POOL",
            "MAX_POOL",
            "L2_POOL",
            "SIGMOID",
            "L2NORM",
            "LOCAL_RESPONSE_NORM",
            "ADD",
            "Basic_RNN",
        )
    ],
    "subgraphs": [{
        "operators": [
            {"builtin_options_type": options_type}
            for options_type in (
                "PoolOptions",
                "DepthwiseConvolutionOptions",
                "ConvolutionOptions",
                "LocalResponseNormOptions",
                "BasicRNNOptions",
            )
        ],
    }],
    "description": "",
}
123
# Version-3 counterpart of FULL_TEST_SCHEMA_V1: the same operators and options
# under their renamed identifiers, plus a single empty buffer.
FULL_TEST_SCHEMA_V3 = {
    "version": 3,
    "operator_codes": [
        {"builtin_code": code}
        for code in (
            "CONV_2D",
            "DEPTHWISE_CONV_2D",
            "AVERAGE_POOL_2D",
            "MAX_POOL_2D",
            "L2_POOL_2D",
            "LOGISTIC",
            "L2_NORMALIZATION",
            "LOCAL_RESPONSE_NORMALIZATION",
            "ADD",
            "RNN",
        )
    ],
    "subgraphs": [{
        "operators": [
            {"builtin_options_type": options_type}
            for options_type in (
                "Pool2DOptions",
                "DepthwiseConv2DOptions",
                "Conv2DOptions",
                "LocalResponseNormalizationOptions",
                "RNNOptions",
            )
        ],
    }],
    "description": "",
    "buffers": [{"data": []}],
}
184
# Version-2 model whose tensors hold their data inline under "data_buffer";
# the third tensor's data is empty and the top-level "buffers" list is empty.
BUFFER_TEST_V2 = {
    "operator_codes": [],
    "buffers": [],
    "subgraphs": [{
        "tensors": [
            {"data_buffer": [1, 2, 3, 4]},
            {"data_buffer": [1, 2, 3, 4, 5, 6, 7, 8]},
            {"data_buffer": []},
        ],
        "inputs": [],
        "outputs": [],
        "operators": [],
    }],
    "version": 2,
}
207
# Version-3 model in which each tensor names an entry of the top-level
# "buffers" list by index; entry 0 is the empty buffer.
BUFFER_TEST_V3 = {
    "operator_codes": [],
    "subgraphs": [{
        "tensors": [{"buffer": 1}, {"buffer": 2}, {"buffer": 0}],
        "inputs": [],
        "outputs": [],
        "operators": [],
    }],
    "buffers": [
        {"data": []},
        {"data": [1, 2, 3, 4]},
        {"data": [1, 2, 3, 4, 5, 6, 7, 8]},
    ],
    "version": 3,
}
240
241
def JsonDumpAndFlush(data, fp):
  """Serialize `data` as JSON into `fp` and flush `fp` afterwards.

  Args:
    data: A JSON-serializable dictionary.
    fp: Writable file-like object that receives the JSON text.
  """
  json.dump(obj=data, fp=fp)
  # Flush so that code which reopens the file by name sees the whole document.
  fp.flush()
251
252
class TestSchemaUpgrade(test_util.TensorFlowTestCase):
  """Tests for `upgrade_schema_lib.Converter`.

  Covers rejection of missing files and unsupported extensions, and checks
  that models written in older schema versions are upgraded to version 3.
  """

  def testNonExistentFile(self):
    """Converting from a path with no file behind it raises IOError."""
    converter = upgrade_schema_lib.Converter()
    # mktemp is deliberate here: the test needs a name that does not exist.
    non_existent = tempfile.mktemp(suffix=".json")
    with self.assertRaisesRegex(IOError, "No such file or directory"):
      converter.Convert(non_existent, non_existent)

  def testInvalidExtension(self):
    """Unsupported input or output extensions raise ValueError."""
    converter = upgrade_schema_lib.Converter()
    # Only the ".foo" suffix matters; the file itself is never created.
    invalid_extension = tempfile.mktemp(suffix=".foo")
    with self.assertRaisesRegex(ValueError, "Invalid extension on input"):
      converter.Convert(invalid_extension, invalid_extension)
    # A valid JSON input is required to reach the output-extension check.
    with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json:
      JsonDumpAndFlush(EMPTY_TEST_SCHEMA_V1, in_json)
      with self.assertRaisesRegex(ValueError, "Invalid extension on output"):
        converter.Convert(in_json.name, invalid_extension)

  def CheckConversion(self, data_old, data_expected):
    """Given a data dictionary, test upgrading to current version.

    Writes `data_old` to a temporary JSON file, converts it to JSON, to
    `.tflite` and (from the `.tflite`) to `.bin`, then asserts that the two
    binary outputs are byte-identical and that the JSON output equals
    `data_expected`.

    Args:
        data_old: TFLite model as a dictionary (arbitrary version).
        data_expected: TFLite model as a dictionary (upgraded).
    """
    converter = upgrade_schema_lib.Converter()
    with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json, \
            tempfile.NamedTemporaryFile(
                suffix=".json", mode="w+") as out_json, \
            tempfile.NamedTemporaryFile(
                suffix=".bin", mode="w+b") as out_bin, \
            tempfile.NamedTemporaryFile(
                suffix=".tflite", mode="w+b") as out_tflite:
      JsonDumpAndFlush(data_old, in_json)
      # Test JSON output
      converter.Convert(in_json.name, out_json.name)
      # Test binary output
      # Convert to .tflite  and then to .bin and check if binary is equal
      converter.Convert(in_json.name, out_tflite.name)
      converter.Convert(out_tflite.name, out_bin.name)
      # Reopen by name to read what the converter wrote; the context managers
      # close the handles instead of leaving them for the garbage collector.
      with open(out_bin.name, "rb") as bin_file:
        bin_contents = bin_file.read()
      with open(out_tflite.name, "rb") as tflite_file:
        tflite_contents = tflite_file.read()
      self.assertEqual(bin_contents, tflite_contents)
      # Test that conversion actually produced successful new json.
      converted_schema = json.load(out_json)
      self.assertEqual(converted_schema, data_expected)

  def testAlreadyUpgraded(self):
    """A file already at version 3 should stay at version 3."""
    self.CheckConversion(EMPTY_TEST_SCHEMA_V3, EMPTY_TEST_SCHEMA_V3)
    self.CheckConversion(TEST_SCHEMA_V3, TEST_SCHEMA_V3)
    self.CheckConversion(BUFFER_TEST_V3, BUFFER_TEST_V3)

  # Disable this while we have incorrectly versioned structures around.
  # def testV0Upgrade_IntroducesSubgraphs(self):
  #   """V0 did not have subgraphs; check to make sure they get introduced."""
  #   self.CheckConversion(TEST_SCHEMA_V0, TEST_SCHEMA_V3)

  def testV1Upgrade_RenameOps(self):
    """V1 had many different names for ops; check to make sure they rename."""
    self.CheckConversion(EMPTY_TEST_SCHEMA_V1, EMPTY_TEST_SCHEMA_V3)
    self.CheckConversion(FULL_TEST_SCHEMA_V1, FULL_TEST_SCHEMA_V3)

  def testV2Upgrade_CreateBuffers(self):
    """V2 did not have buffers; check to make sure they are created."""
    self.CheckConversion(BUFFER_TEST_V2, BUFFER_TEST_V3)
319
320
# Invoke the TensorFlow test entry point when this file is run as a script.
if __name__ == "__main__":
  test_lib.main()
323