# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to test TFLite models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np

from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.python import convert_saved_model as _convert_saved_model
from tensorflow.lite.python import lite as _lite
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.keras.preprocessing import image
from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.saved_model import load as _load
from tensorflow.python.saved_model import loader as _loader
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants


def get_filepath(filename, base_dir=None):
  """Returns the full path of the filename.

  Args:
    filename: Subdirectory and name of the model file.
    base_dir: Base directory containing model file.

  Returns:
    str.
  """
  if base_dir is None:
    base_dir = "learning/brain/mobile/tflite_compat_models"
  return os.path.join(_resource_loader.get_root_dir_with_all_resources(),
                      base_dir, filename)


def get_image(size):
  """Returns an image loaded into an np.ndarray with dims [1, size, size, 3].

  Args:
    size: Size of image.

  Returns:
    np.ndarray.
  """
  img_filename = _resource_loader.get_path_to_datafile(
      "testdata/grace_hopper.jpg")
  img = image.load_img(img_filename, target_size=(size, size))
  img_array = image.img_to_array(img)
  img_array = np.expand_dims(img_array, axis=0)
  return img_array


def _convert(converter, **kwargs):
  """Converts the model.

  Args:
    converter: TFLiteConverter object.
    **kwargs: Additional arguments to be passed into the converter. Supported
      flags are {"target_ops", "post_training_quantize"}.

  Returns:
    The converted TFLite model in serialized format.
  """
  if "target_ops" in kwargs:
    converter.target_ops = kwargs["target_ops"]
  if "post_training_quantize" in kwargs:
    converter.post_training_quantize = kwargs["post_training_quantize"]
  return converter.convert()
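
# Example (illustrative sketch): passing the supported converter flags through
# _convert. The frozen graph path and tensor names below are hypothetical.
#
#   converter = _lite.TFLiteConverter.from_frozen_graph(
#       "/tmp/frozen_graph.pb", ["input"], ["output"])
#   tflite_model = _convert(
#       converter,
#       target_ops=set([_lite.OpsSet.TFLITE_BUILTINS,
#                       _lite.OpsSet.SELECT_TF_OPS]),
#       post_training_quantize=True)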


def _generate_random_input_data(tflite_model, seed=None):
  """Generates input data based on the input tensors in the TFLite model.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    seed: Integer seed for the random generator. (default None)

  Returns:
    List of np.ndarray.
  """
  interpreter = _lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()

  if seed is not None:
    np.random.seed(seed=seed)
  return [
      np.array(
          np.random.random_sample(input_tensor["shape"]),
          dtype=input_tensor["dtype"]) for input_tensor in input_details
  ]


def _evaluate_tflite_model(tflite_model, input_data):
  """Returns evaluation of input data on TFLite model.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    input_data: List of np.ndarray.

  Returns:
    List of np.ndarray.
  """
  interpreter = _lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  for input_tensor, tensor_data in zip(input_details, input_data):
    interpreter.set_tensor(input_tensor["index"], tensor_data)

  interpreter.invoke()
  output_data = [
      interpreter.get_tensor(output_tensor["index"])
      for output_tensor in output_details
  ]
  return output_data
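
# Example (illustrative sketch): running a serialized TFLite model on random
# inputs. Assumes `tflite_model` holds flatbuffer bytes produced by _convert.
#
#   random_inputs = _generate_random_input_data(tflite_model, seed=42)
#   tflite_outputs = _evaluate_tflite_model(tflite_model, random_inputs)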


def evaluate_frozen_graph(filename, input_arrays, output_arrays):
  """Returns a function that evaluates the frozen graph on input data.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.

  Returns:
    Lambda function ([np.ndarray data] : [np.ndarray result]).
  """
  with _session.Session().as_default() as sess:
    with _file_io.FileIO(filename, "rb") as f:
      file_content = f.read()

    graph_def = _graph_pb2.GraphDef()
    graph_def.ParseFromString(file_content)
    _import_graph_def(graph_def, name="")

    inputs = _convert_saved_model.get_tensors_from_tensor_names(
        sess.graph, input_arrays)
    outputs = _convert_saved_model.get_tensors_from_tensor_names(
        sess.graph, output_arrays)

    return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data)))
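
# Example (illustrative sketch): building a TensorFlow-side evaluation function
# from a frozen graph. The file path and tensor names are hypothetical.
#
#   tf_eval_func = evaluate_frozen_graph(
#       "/tmp/mobilenet.pb", ["input"], ["MobilenetV1/Predictions/Softmax"])
#   tf_results = tf_eval_func([get_image(224)])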


def evaluate_saved_model(directory, tag_set, signature_key):
  """Returns a function that evaluates the SavedModel on input data.

  Args:
    directory: SavedModel directory to convert.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.

  Returns:
    Lambda function ([np.ndarray data] : [np.ndarray result]).
  """
  with _session.Session().as_default() as sess:
    if tag_set is None:
      tag_set = set([_tag_constants.SERVING])
    if signature_key is None:
      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    meta_graph = _loader.load(sess, tag_set, directory)
    signature_def = _convert_saved_model.get_signature_def(
        meta_graph, signature_key)
    inputs, outputs = _convert_saved_model.get_inputs_outputs(signature_def)

    return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data)))


def evaluate_keras_model(filename):
  """Returns a function that evaluates the tf.keras model on input data.

  Args:
    filename: Full filepath of HDF5 file containing the tf.keras model.

  Returns:
    Lambda function ([np.ndarray data] : [np.ndarray result]).
  """
  keras_model = _keras.models.load_model(filename)
  return lambda input_data: [keras_model.predict(input_data)]


def compare_models(tflite_model, tf_eval_func, input_data=None, tolerance=5):
  """Compares TensorFlow and TFLite models.

  Unless input data is provided, the models are compared on random data.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    tf_eval_func: Lambda function that takes in input data and outputs the
      results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]).
    input_data: np.ndarray to pass into models during inference. (default None)
    tolerance: Decimal place to check accuracy to. (default 5)
  """
  if input_data is None:
    input_data = _generate_random_input_data(tflite_model)
  tf_results = tf_eval_func(input_data)
  tflite_results = _evaluate_tflite_model(tflite_model, input_data)
  for tf_result, tflite_result in zip(tf_results, tflite_results):
    np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
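
# Example (illustrative sketch): end-to-end comparison of a frozen graph and
# its TFLite conversion on random data. Paths and tensor names are
# hypothetical.
#
#   converter = _lite.TFLiteConverter.from_frozen_graph(
#       "/tmp/mobilenet.pb", ["input"], ["MobilenetV1/Predictions/Softmax"])
#   tflite_model = _convert(converter)
#   tf_eval_func = evaluate_frozen_graph(
#       "/tmp/mobilenet.pb", ["input"], ["MobilenetV1/Predictions/Softmax"])
#   compare_models(tflite_model, tf_eval_func)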


def compare_models_v2(tflite_model, concrete_func, input_data=None,
                      tolerance=5):
  """Compares TensorFlow and TFLite models for TensorFlow 2.0.

  Unless input data is provided, the models are compared on random data.
  Currently, only one input and one output are supported by this function.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    concrete_func: TensorFlow ConcreteFunction.
    input_data: np.ndarray to pass into models during inference. (default None)
    tolerance: Decimal place to check accuracy to. (default 5)
  """
  if input_data is None:
    input_data = _generate_random_input_data(tflite_model)
  input_data_func = constant_op.constant(input_data[0])

  # Gets the TensorFlow results as a map from the output names to outputs.
  # Converts the map into a list that is equivalent to the TFLite list.
  tf_results_map = concrete_func(input_data_func)
  tf_results = [tf_results_map[list(tf_results_map.keys())[0]]]
  tflite_results = _evaluate_tflite_model(tflite_model, input_data)
  for tf_result, tflite_result in zip(tf_results, tflite_results):
    np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
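
# Example (illustrative sketch): comparing a concrete function against its
# TFLite conversion. Assumes `concrete_func` came from a loaded SavedModel
# signature.
#
#   converter = _lite.TFLiteConverterV2.from_concrete_function(concrete_func)
#   tflite_model = _convert(converter)
#   compare_models_v2(tflite_model, concrete_func)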


def test_frozen_graph_quant(filename,
                            input_arrays,
                            output_arrays,
                            input_shapes=None,
                            **kwargs):
  """Sanity check that the post_training_quantize flag alters the graph.

  This test does not check correctness of the converted model. It converts the
  TensorFlow frozen graph to TFLite with and without the post_training_quantize
  flag. It ensures some tensors have different types between the float and
  quantized models in the case of an all-TFLite model or a mix-and-match
  (TFLite and Flex) model. It ensures tensor types do not change in the case
  of an all-Flex model.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
        (default None)
    **kwargs: Additional arguments to be passed into the converter.

  Raises:
    ValueError: post_training_quantize flag doesn't act as intended.
  """
  # Convert and load the float model.
  converter = _lite.TFLiteConverter.from_frozen_graph(
      filename, input_arrays, output_arrays, input_shapes)
  tflite_model_float = _convert(converter, **kwargs)

  interpreter_float = _lite.Interpreter(model_content=tflite_model_float)
  interpreter_float.allocate_tensors()
  float_tensors = interpreter_float.get_tensor_details()

  # Convert and load the quantized model.
  converter = _lite.TFLiteConverter.from_frozen_graph(
      filename, input_arrays, output_arrays, input_shapes)
  tflite_model_quant = _convert(
      converter, post_training_quantize=True, **kwargs)

  interpreter_quant = _lite.Interpreter(model_content=tflite_model_quant)
  interpreter_quant.allocate_tensors()
  quant_tensors = interpreter_quant.get_tensor_details()
  quant_tensors_map = {
      tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors
  }

  # Check if weights are of different types in the float and quantized models.
  num_tensors_float = len(float_tensors)
  num_tensors_same_dtypes = sum(
      float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"]
      for float_tensor in float_tensors)
  has_quant_tensor = num_tensors_float != num_tensors_same_dtypes

  if ("target_ops" in kwargs and
      set(kwargs["target_ops"]) == set([_lite.OpsSet.SELECT_TF_OPS])):
    if has_quant_tensor:
      raise ValueError("--post_training_quantize flag unexpectedly altered the "
                       "full Flex mode graph.")
  elif not has_quant_tensor:
    raise ValueError("--post_training_quantize flag was unable to quantize the "
                     "graph as expected in TFLite and mix-and-match mode.")
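
# Example (illustrative sketch): checking that post-training quantization
# changes tensor types for a built-in-ops conversion. The path and tensor
# names are hypothetical.
#
#   test_frozen_graph_quant(
#       "/tmp/mobilenet.pb", ["input"], ["MobilenetV1/Predictions/Softmax"],
#       target_ops=set([_lite.OpsSet.TFLITE_BUILTINS]))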


def test_frozen_graph(filename,
                      input_arrays,
                      output_arrays,
                      input_shapes=None,
                      input_data=None,
                      **kwargs):
  """Validates that the TensorFlow frozen graph converts to a TFLite model.

  Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the
  model on random data.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
        (default None)
    input_data: np.ndarray to pass into models during inference. (default None)
    **kwargs: Additional arguments to be passed into the converter.
  """
  converter = _lite.TFLiteConverter.from_frozen_graph(
      filename, input_arrays, output_arrays, input_shapes)
  tflite_model = _convert(converter, **kwargs)

  tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays)
  compare_models(tflite_model, tf_eval_func, input_data=input_data)
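
# Example (illustrative sketch): exercising a frozen graph end to end, here
# with Flex (SELECT_TF_OPS) fallback enabled. The model path and tensor names
# are hypothetical.
#
#   test_frozen_graph(
#       get_filepath("mobilenet/frozen_graph.pb"), ["input"], ["output"],
#       target_ops=set([_lite.OpsSet.TFLITE_BUILTINS,
#                       _lite.OpsSet.SELECT_TF_OPS]))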


def test_saved_model(directory,
                     input_shapes=None,
                     tag_set=None,
                     signature_key=None,
                     input_data=None,
                     **kwargs):
  """Validates that the TensorFlow SavedModel converts to a TFLite model.

  Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the
  model on random data.

  Args:
    directory: SavedModel directory to convert.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
        (default None)
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.
    input_data: np.ndarray to pass into models during inference. (default None)
    **kwargs: Additional arguments to be passed into the converter.
  """
  converter = _lite.TFLiteConverter.from_saved_model(
      directory,
      input_shapes=input_shapes,
      tag_set=tag_set,
      signature_key=signature_key)
  tflite_model = _convert(converter, **kwargs)

  tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key)
  compare_models(tflite_model, tf_eval_func, input_data=input_data)
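
# Example (illustrative sketch): validating a TF 1.x SavedModel with the
# default serving tag and signature. The directory is hypothetical.
#
#   test_saved_model(
#       "/tmp/saved_model_dir",
#       tag_set=set([_tag_constants.SERVING]),
#       signature_key=_signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)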


# TODO(nupurgarg): Remove input_shape parameter after bug with shapes is fixed.
def test_saved_model_v2(directory,
                        input_shape=None,
                        tag_set=None,
                        signature_key=None,
                        input_data=None,
                        **kwargs):
  """Validates that the TensorFlow SavedModel converts to a TFLite model.

  Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the
  model on random data.

  Args:
    directory: SavedModel directory to convert.
    input_shape: Input shape for the single input array as a list of integers.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.
    input_data: np.ndarray to pass into models during inference. (default None)
    **kwargs: Additional arguments to be passed into the converter.
  """
  model = _load.load(directory, tags=tag_set)
  if not signature_key:
    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  concrete_func = model.signatures[signature_key]
  concrete_func.inputs[0].set_shape(input_shape)

  converter = _lite.TFLiteConverterV2.from_concrete_function(concrete_func)
  tflite_model = _convert(converter, **kwargs)

  compare_models_v2(tflite_model, concrete_func, input_data=input_data)
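
# Example (illustrative sketch): validating a TF 2.0 SavedModel whose single
# input is a 224x224 RGB image batch. The directory is hypothetical.
#
#   test_saved_model_v2("/tmp/saved_model_v2_dir", input_shape=[1, 224, 224, 3])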


def test_keras_model(filename,
                     input_arrays=None,
                     input_shapes=None,
                     input_data=None,
                     **kwargs):
  """Validates that the tf.keras model converts to a TFLite model.

  Converts the tf.keras model to TFLite and checks the accuracy of the model on
  random data.

  Args:
    filename: Full filepath of HDF5 file containing the tf.keras model.
    input_arrays: List of input tensors to freeze graph with.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
        (default None)
    input_data: np.ndarray to pass into models during inference. (default None)
    **kwargs: Additional arguments to be passed into the converter.
  """
  converter = _lite.TFLiteConverter.from_keras_model_file(
      filename, input_arrays=input_arrays, input_shapes=input_shapes)
  tflite_model = _convert(converter, **kwargs)

  tf_eval_func = evaluate_keras_model(filename)
  compare_models(tflite_model, tf_eval_func, input_data=input_data)
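
# Example (illustrative sketch): validating a Keras HDF5 model, feeding a
# fixed test image instead of random data. The filename is hypothetical.
#
#   test_keras_model(
#       get_filepath("mobilenet/keras_model.h5"),
#       input_data=[get_image(224)])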