• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions to test TFLite models."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import numpy as np
24from six import PY2
25from tensorflow import keras
26
27from google.protobuf import text_format as _text_format
28from google.protobuf.message import DecodeError
29from tensorflow.core.framework import graph_pb2 as _graph_pb2
30from tensorflow.lite.python import convert_saved_model as _convert_saved_model
31from tensorflow.lite.python import interpreter as _interpreter
32from tensorflow.lite.python import lite as _lite
33from tensorflow.lite.python import util as _util
34from tensorflow.python.client import session as _session
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
39from tensorflow.python.lib.io import file_io as _file_io
40from tensorflow.python.platform import resource_loader as _resource_loader
41from tensorflow.python.platform import tf_logging as logging
42from tensorflow.python.saved_model import load as _load
43from tensorflow.python.saved_model import loader as _loader
44from tensorflow.python.saved_model import signature_constants as _signature_constants
45from tensorflow.python.saved_model import tag_constants as _tag_constants
46
47
# Banner logged by compare_model_golden() when golden values are being
# rewritten instead of checked.
_GOLDENS_UPDATE_WARNING = (
    "\n"
    "  Golden file update requested!\n"
    "  This test is now going to write new golden files.\n"
    "\n"
    "  Make sure to package the updates together with your CL.\n")
54
55
def get_filepath(filename, base_dir=None):
  """Builds the absolute path of a model file under the resource root.

  Args:
    filename: Subdirectory and name of the model file.
    base_dir: Base directory containing model file, relative to the resource
      root. When None, "learning/brain/mobile/tflite_compat_models" is used.

  Returns:
    str.
  """
  model_dir = ("learning/brain/mobile/tflite_compat_models"
               if base_dir is None else base_dir)
  resource_root = _resource_loader.get_root_dir_with_all_resources()
  return os.path.join(resource_root, model_dir, filename)
70
71
def get_golden_filepath(name):
  """Builds the path of the golden values file for `name`.

  The file is "<name>.npy.golden" inside "testdata/golden" under the
  directory returned by `_resource_loader.get_data_files_path()`.

  Args:
    name: the name of the golden data, usually same as the test name.

  Returns:
    str, the full path to the golden file.
  """
  golden_basename = "%s.npy.golden" % name
  data_dir = _resource_loader.get_data_files_path()
  return os.path.join(data_dir, "testdata", "golden", golden_basename)
81
82
def get_image(size):
  """Loads the bundled "testdata/grace_hopper.jpg" image as a batched array.

  The image is resized to `size` x `size` and given a leading batch axis, so
  the result has dims [1, size, size, 3].

  Args:
    size: Size of image.

  Returns:
    np.ndarray.
  """
  image_path = _resource_loader.get_path_to_datafile(
      "testdata/grace_hopper.jpg")
  image = keras.preprocessing.image.load_img(
      image_path, target_size=(size, size))
  pixels = keras.preprocessing.image.img_to_array(image)
  return np.expand_dims(pixels, axis=0)
99
100
def _get_calib_data_func(input_size, num_calibration=20):
  """Returns a function to generate a representative data set.

  The returned generator function yields `num_calibration` samples. Each
  sample is a one-element list holding a float32 array of shape
  [1, input_size[0], input_size[1], input_size[2]] with values drawn from
  Unif[0.0, 1.0).

  Args:
    input_size: 3D shape of the representative data. Only the first three
      entries are used.
    num_calibration: Number of samples yielded per call of the returned
      function. (default 20)

  Returns:
    A zero-argument generator function, suitable for assigning to
    `converter.representative_dataset`.

  Raises:
    ValueError: If `input_size` is None or has fewer than three entries.
  """
  # Validate eagerly: a missing model_input_size would otherwise only surface
  # as an opaque TypeError while the converter iterates the dataset.
  if input_size is None or len(input_size) < 3:
    raise ValueError(
        "A 3D model_input_size is required to generate representative data, "
        "got %r." % (input_size,))
  sample_shape = (1,) + tuple(input_size[:3])

  def representative_data_gen():
    for _ in range(num_calibration):
      yield [np.random.rand(*sample_shape).astype(np.float32)]

  return representative_data_gen
120
121
def _convert(converter, **kwargs):
  """Converts the model.

  Args:
    converter: TFLiteConverter object.
    **kwargs: Additional arguments to be passed into the converter. Supported
      flags are {"target_ops", "post_training_quantize", "quantize_to_float16",
      "post_training_quantize_int8", "post_training_quantize_16x8",
      "model_input_size"}. The boolean flags take effect only when truthy.
      "model_input_size" is the 3D input shape used to build the
      representative dataset for the int8 and 16x8 modes.

  Returns:
    The converted TFLite model in serialized format.
  """
  if "target_ops" in kwargs:
    converter.target_spec.supported_ops = kwargs["target_ops"]
  # Checked by value, like the other boolean flags, so that passing
  # post_training_quantize=False does not turn quantization on.
  if kwargs.get("post_training_quantize", False):
    converter.optimizations = [_lite.Optimize.DEFAULT]
  if kwargs.get("quantize_to_float16", False):
    converter.target_spec.supported_types = [dtypes.float16]
  if kwargs.get("post_training_quantize_int8", False):
    converter.optimizations = [_lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [_lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.representative_dataset = _get_calib_data_func(
        kwargs.get("model_input_size"))
    # Note that the full integer quantization is by the mlir quantizer
    converter.experimental_new_quantizer = True
  if kwargs.get("post_training_quantize_16x8", False):
    ops_set = _lite.OpsSet
    converter.optimizations = [_lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        ops_set.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    ]
    converter.representative_dataset = _get_calib_data_func(
        kwargs.get("model_input_size"))
  return converter.convert()
160
161
def _check_model_quantized_to_16x8(tflite_model):
  """Checks that the activations are quantized into int16.

  Tensors whose name contains "_int16" are treated as the int16 activations.
  At least one such tensor must exist, and each must have dtype int16.

  Args:
    tflite_model: Serialized TensorFlow Lite model.

  Raises:
    ValueError: Activations with int16 type are not found, or a tensor named
      as an int16 activation has a different dtype.
  """
  interpreter = _get_tflite_interpreter(tflite_model)
  interpreter.allocate_tensors()

  int16_activations = [
      tensor for tensor in interpreter.get_tensor_details()
      if "_int16" in tensor["name"]
  ]
  for tensor in int16_activations:
    # Compare by value rather than identity: the check then holds whether
    # `dtype` is reported as a NumPy scalar type or as an np.dtype instance
    # (`is` would only match the former).
    if np.dtype(tensor["dtype"]) != np.int16:
      raise ValueError("Activations should be int16.")

  # Check that we found activations in the correct type: int16
  if not int16_activations:
    raise ValueError("Could not find int16 activations.")
185
186
def _get_tflite_interpreter(tflite_model,
                            input_shapes_resize=None,
                            custom_op_registerers=None):
  """Builds a TFLite interpreter, optionally resizing some of its inputs.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    input_shapes_resize: A map where the key is the input tensor name and the
      value is the shape of the input tensor. This resize happens after model
      conversion, prior to calling allocate tensors. (default None)
    custom_op_registerers: Op registerers for custom ops. None registers no
      custom ops. (default None)

  Returns:
    lite.Interpreter
  """
  registerers = ([] if custom_op_registerers is None
                 else custom_op_registerers)
  interpreter = _interpreter.InterpreterWithCustomOps(
      model_content=tflite_model, custom_op_registerers=registerers)
  if not input_shapes_resize:
    return interpreter

  index_by_name = dict((detail["name"], detail["index"])
                       for detail in interpreter.get_input_details())
  for tensor_name, new_shape in input_shapes_resize.items():
    interpreter.resize_tensor_input(index_by_name[tensor_name], new_shape)
  return interpreter
215
216
def _get_input_data_map(tflite_model, input_data, custom_op_registerers=None):
  """Keys each input array by the name of the matching TFLite input tensor.

  Arrays are paired positionally with the tensors from the interpreter's
  `get_input_details()`. Unmatched entries on either side are dropped.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    input_data: List of np.ndarray.
    custom_op_registerers: Op registerers for custom ops.

  Returns:
    {str: np.ndarray}.
  """
  interpreter = _get_tflite_interpreter(
      tflite_model, custom_op_registerers=custom_op_registerers)
  interpreter.allocate_tensors()
  tensor_names = [detail["name"] for detail in interpreter.get_input_details()]
  return dict(zip(tensor_names, input_data))
236
237
def _generate_random_input_data(tflite_model,
                                seed=None,
                                input_data_range=None,
                                input_shapes_resize=None,
                                custom_op_registerers=None):
  """Generates input data based on the input tensors in the TFLite model.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    seed: Integer seed for NumPy's global random generator. Any non-None value,
      including 0, seeds the generator. (default None)
    input_data_range: A map where the key is the input tensor name and
      the value is a tuple (min_val, max_val) which specifies the value range of
      the corresponding input tensor. For example, '{'input1': (1, 5)}' means to
      generate a random value for tensor `input1` within range [1.0, 5.0)
      (half-inclusive). (default None)
    input_shapes_resize: A map where the key is the input tensor name and the
      value is the shape of the input tensor. This resize happens after model
      conversion, prior to calling allocate tensors. (default None)
    custom_op_registerers: Op registerers for custom ops.

  Returns:
    ([np.ndarray], {str : np.ndarray}): the generated arrays in input order,
    and the same arrays keyed by input tensor name.
  """
  interpreter = _get_tflite_interpreter(
      tflite_model,
      input_shapes_resize,
      custom_op_registerers=custom_op_registerers)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()

  # `is not None` so that seed=0 is honoured; a truthiness test skips it.
  if seed is not None:
    np.random.seed(seed=seed)

  # Generate random input data. If a tensor's value range is specified, say
  # [a, b), then the generated value will be (b - a) * Unif[0.0, 1.0) + a,
  # otherwise it's Unif[0.0, 1.0).
  input_data = []
  for input_tensor in input_details:
    val = np.random.random_sample(input_tensor["shape"])
    name = input_tensor["name"]
    if input_data_range is not None and name in input_data_range:
      min_val = input_data_range[name][0]
      max_val = input_data_range[name][1]
      val = (max_val - min_val) * val + min_val
    input_data.append(np.array(val, dtype=input_tensor["dtype"]))

  # Reuse the input details already fetched instead of building a second
  # interpreter just to look up the same tensor names.
  input_data_map = {
      input_tensor["name"]: data
      for input_tensor, data in zip(input_details, input_data)
  }
  return input_data, input_data_map
287
288
def _evaluate_tflite_model(tflite_model,
                           input_data,
                           input_shapes_resize=None,
                           custom_op_registerers=None):
  """Runs the TFLite model once on `input_data`.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    input_data: List of np.ndarray, fed positionally to the model's input
      tensors.
    input_shapes_resize: A map where the key is the input tensor name and the
      value is the shape of the input tensor. This resize happens after model
      conversion, prior to calling allocate tensors. (default None)
    custom_op_registerers: Op registerers for custom ops.

  Returns:
    A tuple (output_data, output_labels). `output_data` is a list of
    np.ndarray, one per output tensor, read after invocation.
    `output_labels` lists the corresponding output tensor names in the same
    order.
  """
  interpreter = _get_tflite_interpreter(
      tflite_model,
      input_shapes_resize,
      custom_op_registerers=custom_op_registerers)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  for input_tensor, tensor_data in zip(input_details, input_data):
    interpreter.set_tensor(input_tensor["index"], tensor_data)

  interpreter.invoke()
  output_data = [
      interpreter.get_tensor(output_tensor["index"])
      for output_tensor in output_details
  ]
  output_labels = [output_tensor["name"] for output_tensor in output_details]
  return output_data, output_labels
325
326
def evaluate_frozen_graph(filename, input_arrays, output_arrays):
  """Builds a callable that runs a frozen GraphDef on input data.

  The file is parsed as a binary GraphDef first. If that fails, it is decoded
  and merged as a text-format GraphDef instead.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.

  Returns:
    Function ([np.ndarray data] : [np.ndarray result]). Each call opens a new
    session on the imported graph and feeds the data positionally to
    `input_arrays`.
  """
  with _file_io.FileIO(filename, "rb") as graph_file:
    serialized = graph_file.read()

  graph_def = _graph_pb2.GraphDef()
  try:
    graph_def.ParseFromString(serialized)
  except (_text_format.ParseError, DecodeError):
    # Not a binary proto: convert to `str` and try the text format.
    if not isinstance(serialized, str):
      serialized = (serialized.encode("utf-8") if PY2
                    else serialized.decode("utf-8"))
    _text_format.Merge(serialized, graph_def)

  graph = ops.Graph()
  with graph.as_default():
    _import_graph_def(graph_def, name="")
  input_tensors = _util.get_tensors_from_tensor_names(graph, input_arrays)
  output_tensors = _util.get_tensors_from_tensor_names(graph, output_arrays)

  def run_session(input_data):
    feed_dict = dict(zip(input_tensors, input_data))
    with _session.Session(graph=graph) as sess:
      return sess.run(output_tensors, feed_dict)

  return run_session
363
364
def evaluate_saved_model(directory, tag_set, signature_key):
  """Builds a callable that runs a SavedModel signature on input data.

  Args:
    directory: SavedModel directory to convert.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present. None means
      {tag_constants.SERVING}.
    signature_key: Key identifying SignatureDef containing inputs and outputs.
      None means the default serving signature key.

  Returns:
    Function ([np.ndarray data] : [np.ndarray result]).
  """
  if tag_set is None:
    tag_set = set([_tag_constants.SERVING])
  if signature_key is None:
    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

  with _session.Session().as_default() as sess:
    meta_graph = _loader.load(sess, tag_set, directory)
    signature_def = _convert_saved_model.get_signature_def(
        meta_graph, signature_key)
    inputs, outputs = _convert_saved_model.get_inputs_outputs(signature_def)

  # The returned function keeps using `sess` after the `as_default()` block
  # has exited; the session itself is never closed here.
  def run_signature(input_data):
    return sess.run(outputs, dict(zip(inputs, input_data)))

  return run_signature
389
390
def evaluate_keras_model(filename):
  """Builds a callable that runs a saved tf.keras model on input data.

  Args:
    filename: Full filepath of HDF5 file containing the tf.keras model.

  Returns:
    Function ([np.ndarray data] : [np.ndarray result]). The output of
    `predict` is wrapped in a one-element list.
  """
  keras_model = keras.models.load_model(filename)

  def predict(input_data):
    return [keras_model.predict(input_data)]

  return predict
402
403
def compare_models(tflite_model,
                   tf_eval_func,
                   input_shapes_resize=None,
                   input_data=None,
                   input_data_range=None,
                   tolerance=5):
  """Checks that a TensorFlow model and its TFLite conversion agree.

  When `input_data` is None, random inputs are generated from the TFLite
  model's input tensors. Outputs are paired positionally and compared to
  `tolerance` decimal places.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    tf_eval_func: Lambda function that takes in input data and outputs the
      results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]).
    input_shapes_resize: A map where the key is the input tensor name and the
      value is the shape of the input tensor. This resize happens after model
      conversion, prior to calling allocate tensors. (default None)
    input_data: List of np.ndarray fed to both models. (default None)
    input_data_range: A map where the key is the input tensor name and
      the value is a tuple (min_val, max_val) which specifies the value range of
      the corresponding input tensor. For example, '{'input1': (1, 5)}' means to
      generate a random value for tensor `input1` within range [1.0, 5.0)
      (half-inclusive). (default None)
    tolerance: Decimal place to check accuracy to. (default 5).
  """
  if input_data is None:
    input_data = _generate_random_input_data(
        tflite_model=tflite_model,
        input_data_range=input_data_range,
        input_shapes_resize=input_shapes_resize)[0]

  expected_outputs = tf_eval_func(input_data)
  actual_outputs = _evaluate_tflite_model(
      tflite_model, input_data, input_shapes_resize=input_shapes_resize)[0]
  for expected, actual in zip(expected_outputs, actual_outputs):
    np.testing.assert_almost_equal(expected, actual, decimal=tolerance)
439
440
def _compare_tf_tflite_results(tf_results,
                               tflite_results,
                               tflite_labels,
                               tolerance=5):
  """Asserts that TF and TFLite outputs match to `tolerance` decimal places.

  Args:
    tf_results: Output of the TF model. A non-dict value is one output. A
      one-entry dict is compared regardless of its key. A larger dict is
      ordered by looking up each name in `tflite_labels`.
    tflite_results: List of outputs returned by the TFLite model.
    tflite_labels: names of the output tensors in the TFLite model.
    tolerance: Decimal place to check accuracy to. (default 5).
  """
  if not isinstance(tf_results, dict):
    ordered_tf_results = [tf_results]
  elif len(tf_results) == 1:
    ordered_tf_results = list(tf_results.values())
  else:
    ordered_tf_results = [tf_results[label] for label in tflite_labels]

  for expected, actual in zip(ordered_tf_results, tflite_results):
    np.testing.assert_almost_equal(expected, actual, decimal=tolerance)
464
465
def compare_models_v2(tflite_model,
                      tf_eval_func,
                      input_data=None,
                      input_data_range=None,
                      tolerance=5):
  """Compares TensorFlow and TFLite models for TensorFlow 2.0.

  Random inputs are generated from the TFLite model's input tensors unless
  `input_data` is given.
  - A single input is passed positionally to `tf_eval_func` as a constant
    tensor.
  - Several inputs are passed as keyword arguments named after the TFLite
    input tensors.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    tf_eval_func: Function to evaluate TensorFlow model. Either a lambda
      function that takes in input data and outputs the results or a TensorFlow
      ConcreteFunction.
    input_data: List of np.ndarray to pass into models during inference.
      (default None).
    input_data_range: A map where the key is the input tensor name and
      the value is a tuple (min_val, max_val) which specifies the value range of
      the corresponding input tensor. For example, '{'input1': (1, 5)}' means to
      generate a random value for tensor `input1` within range [1.0, 5.0)
      (half-inclusive). (default None)
    tolerance: Decimal place to check accuracy to. (default 5)
  """
  if input_data is not None:
    input_data_map = _get_input_data_map(tflite_model, input_data)
  else:
    input_data, input_data_map = _generate_random_input_data(
        tflite_model=tflite_model, input_data_range=input_data_range)

  # Named constant tensors for the multi-input call. Distinct loop names avoid
  # visually shadowing `input_data`.
  tensors_by_name = dict((name, constant_op.constant(array))
                         for name, array in input_data_map.items())

  if len(input_data) <= 1:
    tf_results = tf_eval_func(constant_op.constant(input_data[0]))
  else:
    tf_results = tf_eval_func(**tensors_by_name)
  tflite_results, tflite_labels = _evaluate_tflite_model(
      tflite_model, input_data)

  _compare_tf_tflite_results(
      tf_results, tflite_results, tflite_labels, tolerance)
509
510
def compare_tflite_keras_models_v2(tflite_model,
                                   keras_model,
                                   input_data=None,
                                   input_data_range=None,
                                   tolerance=5,
                                   custom_op_registerers=None):
  """Like compare_models_v2, but evaluates a Keras model via `predict`.

  Random inputs are generated from the TFLite model's input tensors unless
  `input_data` is given. A single input is passed to `predict` as the bare
  array; several inputs are passed as the whole list.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    keras_model: Keras model to evaluate.
    input_data: List of np.ndarray to pass into models during inference.
      (default None).
    input_data_range: A map where the key is the input tensor name and the value
      is a tuple (min_val, max_val) which specifies the value range of
      the corresponding input tensor. For example, '{'input1': (1, 5)}' means to
      generate a random value for tensor `input1` within range [1.0, 5.0)
      (half-inclusive). (default None)
    tolerance: Decimal place to check accuracy to. (default 5)
    custom_op_registerers: Op registerers for custom ops.
  """
  if input_data is None:
    input_data = _generate_random_input_data(
        tflite_model=tflite_model,
        input_data_range=input_data_range,
        custom_op_registerers=custom_op_registerers)[0]

  keras_input = input_data if len(input_data) > 1 else input_data[0]
  tf_results = keras_model.predict(keras_input)
  tflite_results, tflite_labels = _evaluate_tflite_model(
      tflite_model, input_data, custom_op_registerers=custom_op_registerers)

  _compare_tf_tflite_results(
      tf_results, tflite_results, tflite_labels, tolerance)
550
551
def compare_model_golden(tflite_model,
                         input_data,
                         golden_name,
                         update_golden=False,
                         tolerance=5):
  """Checks TFLite model output against stored golden values, or rewrites them.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    input_data: Input arrays, fed positionally to the model's input tensors.
    golden_name: Name of the file containing the (expected) golden values.
    update_golden: Whether to update the golden values with the model output
      instead of comparing against them. This should only be done when a change
      in TFLite warrants it.
    tolerance: Decimal place to check accuracy to. (default 5).
  """
  tflite_results = _evaluate_tflite_model(tflite_model, input_data)[0]
  golden_file = get_golden_filepath(golden_name)

  if not update_golden:
    golden_data = np.load(golden_file, allow_pickle=False)
    np.testing.assert_almost_equal(
        golden_data, tflite_results, decimal=tolerance)
    return

  logging.warning(_GOLDENS_UPDATE_WARNING)
  logging.warning("Updating golden values in file %s.", golden_file)
  if not os.path.exists(golden_file):
    resource_root = _resource_loader.get_root_dir_with_all_resources()
    logging.warning(
        "Golden file not found. Manually create it first:\ntouch %r",
        os.path.relpath(golden_file, resource_root))

  with open(golden_file, "wb") as golden_out:
    np.save(golden_out, tflite_results, allow_pickle=False)
585
586
def _has_quantization_params(tensor_detail):
  """Returns True if a tensor-detail dict carries at least one quant scale."""
  # The TFLite interpreter reports `quantization_parameters` as a dict with
  # "scales"/"zero_points"/"quantized_dimension" entries for every tensor
  # (empty arrays when unquantized). The dict is therefore always truthy, so
  # look for a non-empty "scales" entry instead.
  params = tensor_detail.get("quantization_parameters") or {}
  scales = params.get("scales")
  return scales is not None and np.size(scales) > 0


def test_frozen_graph_quant(filename,
                            input_arrays,
                            output_arrays,
                            input_shapes=None,
                            **kwargs):
  """Sanity check to validate post quantize flag alters the graph.

  This test does not check correctness of the converted model. It converts the
  TensorFlow frozen graph to TFLite with and without the post_training_quantized
  flag. It ensures some tensors have different types between the float and
  quantized models in the case of an all TFLite model or mix-and-match model.
  It ensures tensor types do not change in the case of an all Flex model.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
        (default None)
    **kwargs: Additional arguments to be passed into the converter.

  Raises:
    ValueError: post_training_quantize flag doesn't act as intended.
  """
  # Convert and load the float model.
  converter = _lite.TFLiteConverter.from_frozen_graph(
      filename, input_arrays, output_arrays, input_shapes)
  tflite_model_float = _convert(converter, **kwargs)

  interpreter_float = _get_tflite_interpreter(tflite_model_float)
  interpreter_float.allocate_tensors()
  float_tensors = interpreter_float.get_tensor_details()

  # Convert and load the quantized model.
  converter = _lite.TFLiteConverter.from_frozen_graph(
      filename, input_arrays, output_arrays, input_shapes)
  tflite_model_quant = _convert(
      converter, post_training_quantize=True, **kwargs)

  interpreter_quant = _get_tflite_interpreter(tflite_model_quant)
  interpreter_quant.allocate_tensors()
  quant_tensors = interpreter_quant.get_tensor_details()
  quant_tensors_map = {
      tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors
  }
  has_quantization_params = any(
      _has_quantization_params(tensor_detail) for tensor_detail in quant_tensors)

  # Check if weights are of different types in the float and quantized models.
  has_quant_tensor = any(
      float_tensor["dtype"] != quant_tensors_map[float_tensor["name"]]["dtype"]
      for float_tensor in float_tensors)

  # For the "flex" case, post_training_quantize should not alter the graph,
  # unless we are quantizing to float16.
  is_full_flex = (
      "target_ops" in kwargs and
      not kwargs.get("quantize_to_float16", False) and
      not kwargs.get("post_training_quantize_int8", False) and
      not kwargs.get("post_training_quantize_16x8", False) and
      set(kwargs["target_ops"]) == set([_lite.OpsSet.SELECT_TF_OPS]))
  if is_full_flex:
    if has_quant_tensor:
      raise ValueError("--post_training_quantize flag unexpectedly altered the "
                       "full Flex mode graph.")
  elif kwargs.get("post_training_quantize_int8", False):
    # Both models are int8-quantized in this mode (the flag is in kwargs for
    # both conversions), so verify via quantization parameters rather than by
    # comparing dtypes.
    if not has_quantization_params:
      raise ValueError("--post_training_quantize flag was unable to quantize "
                       "the graph as expected in TFLite.")
  elif not has_quant_tensor:
    raise ValueError("--post_training_quantize flag was unable to quantize the "
                     "graph as expected in TFLite and mix-and-match mode.")
667
668
def test_frozen_graph(filename,
                      input_arrays,
                      output_arrays,
                      input_shapes=None,
                      input_shapes_resize=None,
                      input_data=None,
                      input_data_range=None,
                      **kwargs):
  """Converts a frozen graph to TFLite and checks it matches TensorFlow.

  Both models are evaluated on `input_data`, or on random data when it is
  None, and their outputs are compared with compare_models' default
  tolerance.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
        (default None)
    input_shapes_resize: A map where the key is the input tensor name and the
      value is the shape of the input tensor. This resize happens after model
      conversion, prior to calling allocate tensors. (default None)
    input_data: List of np.ndarray to pass into models during inference.
      (default None).
    input_data_range: A map where the key is the input tensor name and
      the value is a tuple (min_val, max_val) which specifies the value range of
      the corresponding input tensor. For example, '{'input1': (1, 5)}' means to
      generate a random value for tensor `input1` within range [1.0, 5.0)
      (half-inclusive). (default None)
    **kwargs: Additional arguments to be passed into the converter.
  """
  tflite_model = _convert(
      _lite.TFLiteConverter.from_frozen_graph(
          filename, input_arrays, output_arrays, input_shapes),
      **kwargs)
  compare_models(
      tflite_model,
      evaluate_frozen_graph(filename, input_arrays, output_arrays),
      input_shapes_resize=input_shapes_resize,
      input_data=input_data,
      input_data_range=input_data_range)
712
713
def test_saved_model(directory,
                     input_shapes=None,
                     tag_set=None,
                     signature_key=None,
                     input_data=None,
                     input_data_range=None,
                     **kwargs):
  """Converts a SavedModel to TFLite and checks it matches TensorFlow.

  For 16x8 quantization the converted model is first checked for int16
  activations. Outputs are then compared to 2 decimal places; all other
  modes use 5.

  Args:
    directory: SavedModel directory to convert.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
        (default None)
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.
    input_data: List of np.ndarray to pass into models during inference.
      (default None).
    input_data_range: A map where the key is the input tensor name and
      the value is a tuple (min_val, max_val) which specifies the value range of
      the corresponding input tensor. For example, '{'input1': (1, 5)}' means to
      generate a random value for tensor `input1` within range [1.0, 5.0)
      (half-inclusive). (default None)
    **kwargs: Additional arguments to be passed into the converter.
  """
  converter = _lite.TFLiteConverter.from_saved_model(
      directory,
      input_shapes=input_shapes,
      tag_set=tag_set,
      signature_key=signature_key)
  tflite_model = _convert(converter, **kwargs)

  is_16x8 = kwargs.get("post_training_quantize_16x8", False)
  if is_16x8:
    _check_model_quantized_to_16x8(tflite_model)
  decimal_places = 2 if is_16x8 else 5

  compare_models(
      tflite_model,
      evaluate_saved_model(directory, tag_set, signature_key),
      input_data=input_data,
      input_data_range=input_data_range,
      tolerance=decimal_places)
764
765
def test_saved_model_v2(directory,
                        tag_set=None,
                        signature_key=None,
                        input_data=None,
                        input_data_range=None,
                        **kwargs):
  """Converts a SavedModel's signature with `TFLiteConverterV2` and checks it.

  The SavedModel is loaded and one of its signatures is selected as a concrete
  function. That function is converted to TFLite, and the TFLite model's
  outputs are compared with the concrete function's outputs on the same inputs
  (random data unless `input_data` is supplied).

  Args:
    directory: Path of the SavedModel directory to convert.
    tag_set: Set of tags selecting the MetaGraphDef within the SavedModel to
      analyze. Every tag in the set must be present.
    signature_key: Key of the SignatureDef that holds the inputs and outputs.
      A falsy value selects the default serving signature.
    input_data: np.ndarray fed to the models during inference. (default None).
    input_data_range: Dict mapping an input tensor name to a (min_val, max_val)
      tuple that bounds the random values generated for that tensor. For
      example, {'input1': (1, 5)} generates values for `input1` in the
      half-open range [1.0, 5.0). (default None)
    **kwargs: Extra arguments forwarded to `_convert`.
  """
  loaded = _load.load(directory, tags=tag_set)
  chosen_key = (
      signature_key or _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
  signature_fn = loaded.signatures[chosen_key]

  tflite_model = _convert(
      _lite.TFLiteConverterV2.from_concrete_functions([signature_fn]),
      **kwargs)

  # The concrete function itself serves as the TensorFlow reference.
  compare_models_v2(
      tflite_model,
      signature_fn,
      input_data=input_data,
      input_data_range=input_data_range)
803
804
def _test_conversion_quant_float16(converter,
                                   input_data,
                                   golden_name=None,
                                   update_golden=False,
                                   **kwargs):
  """Checks that float16 post-training quantization changes the model.

  The converter is run twice: once without quantization and once with
  `post_training_quantize=True, quantize_to_float16=True`. Tensors are matched
  by name, and at least one must have a different dtype in the quantized model.
  The quantized model can optionally be checked against golden values.

  Args:
    converter: TFLite converter instance for the model to convert.
    input_data: np.ndarray fed to the model during inference.
    golden_name: Optional name of golden values to compare the quantized
      model's output against.
    update_golden: If True, overwrite the golden values with the model output
      instead of comparing against them.
    **kwargs: Extra arguments forwarded to both `_convert` calls.

  Raises:
    ValueError: If no tensor's dtype differs between the float and quantized
      models.
  """
  # Baseline conversion without quantization.
  float_model = _convert(converter, version=2, **kwargs)
  float_interpreter = _get_tflite_interpreter(float_model)
  float_interpreter.allocate_tensors()
  float_details = float_interpreter.get_tensor_details()

  # The same conversion with float16 post-training quantization requested.
  quant_model = _convert(
      converter,
      version=2,
      post_training_quantize=True,
      quantize_to_float16=True,
      **kwargs)
  quant_interpreter = _get_tflite_interpreter(quant_model)
  quant_interpreter.allocate_tensors()
  quant_dtype_by_name = {
      detail["name"]: detail["dtype"]
      for detail in quant_interpreter.get_tensor_details()
  }

  # Every float tensor is looked up in the quantized model (a missing name
  # raises KeyError). Quantization counts as effective if any dtype changed.
  changed_tensors = [
      detail["name"]
      for detail in float_details
      if detail["dtype"] != quant_dtype_by_name[detail["name"]]
  ]
  if not changed_tensors:
    raise ValueError("--post_training_quantize flag was unable to quantize the "
                     "graph as expected.")

  if golden_name:
    compare_model_golden(quant_model, input_data, golden_name, update_golden)
855
856
def test_saved_model_v2_quant_float16(directory,
                                      input_data,
                                      golden_name=None,
                                      update_golden=False,
                                      **kwargs):
  """Checks float16 quantized conversion of a SavedModel to TFLite.

  Builds a `TFLiteConverterV2` from the SavedModel directory and delegates to
  `_test_conversion_quant_float16`.

  Args:
    directory: Path of the SavedModel directory to convert.
    input_data: np.ndarray fed to the model during inference.
    golden_name: Optional name of golden values to compare the model output
      against.
    update_golden: If True, overwrite the golden values with the model output
      instead of comparing against them.
    **kwargs: Extra arguments forwarded to the converter.
  """
  _test_conversion_quant_float16(
      _lite.TFLiteConverterV2.from_saved_model(directory),
      input_data,
      golden_name=golden_name,
      update_golden=update_golden,
      **kwargs)
876
877
def test_frozen_graph_quant_float16(filename,
                                    input_arrays,
                                    output_arrays,
                                    input_data,
                                    input_shapes=None,
                                    golden_name=None,
                                    update_golden=False,
                                    **kwargs):
  """Checks float16 quantized conversion of a frozen graph to TFLite.

  Builds a `TFLiteConverter` from the frozen GraphDef file and delegates to
  `_test_conversion_quant_float16`.

  Args:
    filename: Full filepath of the file containing the frozen GraphDef.
    input_arrays: List of input tensors to freeze the graph with.
    output_arrays: List of output tensors to freeze the graph with.
    input_data: np.ndarray fed to the model during inference.
    input_shapes: Dict mapping input tensor names to lists of integers that give
      their shapes (e.g., {"foo" : [1, 16, 16, 3]}). Shapes are determined
      automatically when a shape is None (e.g., {"foo" : None}).
        (default None)
    golden_name: Optional name of golden values to compare the model output
      against.
    update_golden: If True, overwrite the golden values with the model output
      instead of comparing against them.
    **kwargs: Extra arguments forwarded to the converter.
  """
  frozen_converter = _lite.TFLiteConverter.from_frozen_graph(
      filename, input_arrays, output_arrays, input_shapes)
  _test_conversion_quant_float16(
      frozen_converter,
      input_data,
      golden_name=golden_name,
      update_golden=update_golden,
      **kwargs)
908
909
def test_keras_model(filename,
                     input_arrays=None,
                     input_shapes=None,
                     input_data=None,
                     input_data_range=None,
                     **kwargs):
  """Converts a tf.keras model file with `TFLiteConverter` and checks it.

  The HDF5 model file is converted to TFLite. Its outputs are then compared
  with those of the Keras model loaded from the same file, on the same inputs
  (random data unless `input_data` is supplied).

  Args:
    filename: Full filepath of the HDF5 file containing the tf.keras model.
    input_arrays: List of input tensors to freeze the graph with.
    input_shapes: Dict mapping input tensor names to lists of integers that give
      their shapes (e.g., {"foo" : [1, 16, 16, 3]}). Shapes are determined
      automatically when a shape is None (e.g., {"foo" : None}).
        (default None)
    input_data: np.ndarray fed to the models during inference. (default None).
    input_data_range: Dict mapping an input tensor name to a (min_val, max_val)
      tuple that bounds the random values generated for that tensor. For
      example, {'input1': (1, 5)} generates values for `input1` in the
      half-open range [1.0, 5.0). (default None)
    **kwargs: Extra arguments forwarded to `_convert`.
  """
  tflite_model = _convert(
      _lite.TFLiteConverter.from_keras_model_file(
          filename, input_arrays=input_arrays, input_shapes=input_shapes),
      **kwargs)

  # The reference evaluator is built only after conversion has finished.
  compare_models(
      tflite_model,
      evaluate_keras_model(filename),
      input_data=input_data,
      input_data_range=input_data_range)
946
947
def test_keras_model_v2(filename,
                        input_shapes=None,
                        input_data=None,
                        input_data_range=None,
                        **kwargs):
  """Converts a tf.keras model with `TFLiteConverterV2` and checks it.

  The model is loaded from the HDF5 file, its inputs are optionally reshaped,
  and it is converted to TFLite. The TFLite outputs are then compared with a
  TensorFlow reference evaluator built from `filename`, on the same inputs
  (random data unless `input_data` is supplied).

  Args:
    filename: Full filepath of the HDF5 file containing the tf.keras model.
    input_shapes: List of lists of integers giving input shapes in the order of
      the tf.keras model's .inputs attribute (e.g., [[1, 16, 16, 3]]).
      (default None)
    input_data: np.ndarray fed to the models during inference. (default None).
    input_data_range: Dict mapping an input tensor name to a (min_val, max_val)
      tuple that bounds the random values generated for that tensor. For
      example, {'input1': (1, 5)} generates values for `input1` in the
      half-open range [1.0, 5.0). (default None)
    **kwargs: Extra arguments forwarded to `_convert`.
  """
  keras_model = keras.models.load_model(filename)
  if input_shapes:
    # Pairs are taken positionally via zip, so any surplus entries on either
    # side are ignored.
    for model_input, new_shape in zip(keras_model.inputs, input_shapes):
      model_input.set_shape(new_shape)

  tflite_model = _convert(
      _lite.TFLiteConverterV2.from_keras_model(keras_model), **kwargs)

  # NOTE(review): the reference evaluator receives only `filename`, so the
  # `input_shapes` overrides above are not passed to it. Confirm this is
  # intended.
  compare_models_v2(
      tflite_model,
      evaluate_keras_model(filename),
      input_data=input_data,
      input_data_range=input_data_range)
985