# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to TensorFlow 2.0."""

import ctypes
import functools
import itertools
import os
import sys

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

# Force loaded shared object symbols to be globally visible. This is needed so
# that the interpreter_wrapper, in one .so file, can see the test_registerer,
# in a different .so file. Note that this may already be set by default.
# pylint: disable=g-import-not-at-top
if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
  sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)

from tensorflow.lite.python import conversion_metadata_schema_py_generated as metadata_fb
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_v2_test_util
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import test_util as tflite_test_util
from tensorflow.lite.python import util
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.interpreter import InterpreterWithCustomOps
from tensorflow.lite.python.interpreter import OpResolverType
from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer
from tensorflow.lite.python.testdata import double_op
from tensorflow.lite.python.util import get_conversion_metadata
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.lite.tools.flatbuffer_utils import convert_bytearray_to_object as _convert_bytearray_to_object
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import map_ops
from tensorflow.python.ops import rnn
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import saved_model
from tensorflow.python.saved_model.loader_impl import parse_saved_model
from tensorflow.python.saved_model.save import save
from tensorflow.python.trackable import autotrackable

# Only run JAX-related tests when we can import jax.
DISABLE_JAX_TEST = False
try:
  import jax
  from jax import numpy as jnp
except ImportError:
  DISABLE_JAX_TEST = True
# pylint: enable=g-import-not-at-top


class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testTypeInvalid(self):
    root = self._getSimpleVariableModel()
    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_concrete_functions([root.f], root)
    self.assertIn('call get_concrete_function', str(error.exception))

  @test_util.run_v2_only
  def testFloat(self):
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    tflite_model = converter.convert()

    # Check output value from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @parameterized.named_parameters(('_INT8InputOutput', dtypes.int8),
                                  ('_UINT8InputOutput', dtypes.uint8),
                                  ('_INT16InputOutput', dtypes.int16))
  @test_util.run_v2_only
  def testInvalidFloat(self, inference_input_output_type):
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    with self.assertRaises(ValueError) as error:
      converter.inference_input_type = inference_input_output_type
      converter.inference_output_type = inference_input_output_type
      converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        'must be tf.float32.', str(error.exception))

  @test_util.run_v2_only
  def testScalarInput(self):
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testStringInput(self):

    class Model(tf.Module):

      @tf.function
      def __call__(self, x):
        return x

    root = Model()
    concrete_func = root.__call__.get_concrete_function(
        tf.constant([str(x) for x in range(11)]))
    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    tflite_model = converter.convert()
    input_data = tf.constant([str(x) for x in range(11)],
                             shape=(11,),
                             dtype=tf.dtypes.string)
    # Check values from converted model.
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    my_signature = interpreter.get_signature_runner()
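    # Note: with a single signature in the model, get_signature_runner() with
    # no explicit key returns the runner for that default signature.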

    with self.assertRaises(ValueError) as error:
      _ = my_signature(x=input_data)
    self.assertIn('Passed in value type is not a numpy array, got type ',
                  str(error.exception))

  @test_util.run_v2_only
  def testModelWithoutInputs(self):

    def _get_random_number_gen():
      root = autotrackable.AutoTrackable()

      @tf.function(input_signature=[])
      def func():
        return tf.random.uniform(shape=[1], dtype=tf.float32)

      root.f = func
      to_save = root.f.get_concrete_function()
      return (root, to_save)

    # Model with no input.
    root, concrete_func = _get_random_number_gen()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

  @test_util.run_v2_only
  def testMultiFunctionModel(self):
    """Convert a single model in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.add.get_concrete_function(input_data)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.add(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testConvertMultipleFunctions(self):
    """Convert multiple functions in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)
    sub_func = root.sub.get_concrete_function(input_data)

    # Try converting multiple functions.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [add_func, sub_func], root)
    tflite_model = converter.convert()

    # Check that the converted model's signatures are valid.
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 2)
    self.assertEqual(list(signature_defs.keys()), ['add', 'sub'])
    self.assertEqual(len(signature_defs.values()), 2)
    self.assertEqual(list(signature_defs['add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['add']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['add']['outputs']), ['output_0'])
    self.assertEqual(list(signature_defs['sub'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['sub']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['sub']['outputs']), ['output_0'])

    # Verify the signature runner executions.
    add_signature_runner = interpreter.get_signature_runner('add')
    add_output = add_signature_runner(x=input_data)
    self.assertEqual(add_output['output_0'], 3)
    input_details = add_signature_runner.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('add_x:0', input_details['x']['name'])
    self.assertEqual(np.float32, input_details['x']['dtype'])
    self.assertTrue(([1] == input_details['x']['shape']).all())
    self.assertEqual((0.0, 0), input_details['x']['quantization'])

    sub_signature_runner = interpreter.get_signature_runner('sub')
    sub_output = sub_signature_runner(x=input_data)
    self.assertEqual(sub_output['output_0'], -2)
    output_details = sub_signature_runner.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('StatefulPartitionedCall:0',
                     output_details['output_0']['name'])
    self.assertEqual(np.float32, output_details['output_0']['dtype'])
    self.assertTrue(([1] == output_details['output_0']['shape']).all())
    self.assertEqual((0.0, 0), output_details['output_0']['quantization'])

    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.environment.apiVersion, 2)
    self.assertEqual(metadata.environment.modelType,
                     metadata_fb.ModelType.TF_CONCRETE_FUNCTIONS)
    self.assertAllEqual([], metadata.options.modelOptimizationModes)

  def _getIntegerQuantizeModel(self, num_filters=16):
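    """Builds a small conv2d + relu model for integer quantization tests."""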
    np.random.seed(0)

    root = autotrackable.AutoTrackable()

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[1, 5, 5, 3], dtype=tf.float32)])
    def func(inp):
      conv = tf.nn.conv2d(
          inp,
          tf.ones([3, 3, 3, num_filters]), strides=[1, 1, 1, 1], padding='SAME')
      output = tf.nn.relu(conv, name='output')
      return output

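    # Representative dataset for post-training quantization: each yielded
    # element is a list with one array per model input.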
    def calibration_gen():
      for _ in range(5):
        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

    root.f = func
    to_save = root.f.get_concrete_function()
    return (root, to_save, calibration_gen)

  @parameterized.named_parameters(
      ('EnableMlirQuantizer', True),  # enable mlir quantizer
      ('DisableMlirQuantizer', False))  # disable mlir quantizer
  def testPostTrainingCalibrateAndQuantize(self, mlir_quantizer):
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                     root)
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                         root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_converter.experimental_new_quantizer = mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(
        metadata.environment.tensorflowVersion.decode('utf-8'),
        versions.__version__)
    self.assertEqual(metadata.environment.apiVersion, 2)
    self.assertEqual(metadata.environment.modelType,
                     metadata_fb.ModelType.TF_CONCRETE_FUNCTIONS)
    self.assertEqual(metadata.options.allowCustomOps, False)
    self.assertEqual(metadata.options.enableSelectTfOps, False)
    self.assertEqual(metadata.options.forceSelectTfOps, False)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER],
                        metadata.options.modelOptimizationModes)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(('_INT8InputOutput', dtypes.int8),
                                  ('_UINT8InputOutput', dtypes.uint8),
                                  ('_INT16InputOutput', dtypes.int16))
  @test_util.run_v2_only
  def testInvalidPostTrainingDynamicRangeQuantization(
      self, inference_input_output_type):
    root, func, _ = self._getIntegerQuantizeModel()

    # Convert float model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([func], root)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                         root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_input_type = inference_input_output_type
      quantized_converter.inference_output_type = inference_input_output_type
      quantized_converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        'must be tf.float32.', str(error.exception))

  def _createV2QATSavedModelWithFloatOpsAtEnd(self):
    """Create a simple QAT SavedModel that includes float ops at the end."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'qat_float_ops_at_end')
    input_tensor = tf.keras.layers.Input((32, 32, 128))
    x = tf.quantization.fake_quant_with_min_max_args(input_tensor, -3.0, 3.0)
    x = tf.keras.layers.Conv2D(1, (3, 3))(x)
    x = tf.quantization.fake_quant_with_min_max_args(x, -3.0, 3.0)
    # Exclude the following Dense layer from quantization by not putting a
    # fake quant layer after it.
    output_tensor = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    model = tf.keras.Model(input_tensor, output_tensor)
    model.save(saved_model_dir)
    return saved_model_dir

  def testQuantizationRemovesQDQsForFloatIOInQAT(self):
    saved_model_dir = self._createV2QATSavedModelWithFloatOpsAtEnd()
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_model = converter.convert()

    # Because we make assertions on the model below, we opt out of applying
    # the default TFLite delegates (i.e. the XNNPACK delegate).
    interpreter = Interpreter(
        model_content=quantized_model,
        experimental_op_resolver_type=OpResolverType
        .BUILTIN_WITHOUT_DEFAULT_DELEGATES)
    interpreter.allocate_tensors()
    # The model should have a LOGISTIC op, instead of a DEQUANTIZE op.
    op_details = interpreter._get_ops_details()
    self.assertEqual(op_details[len(op_details) - 1]['op_name'], 'LOGISTIC')

  @parameterized.named_parameters(
      ('EnableMlirQuantizer', True),  # enable mlir quantizer
      ('DisableMlirQuantizer', False))  # disable mlir quantizer
  def testQuantizationRemovesQDQsForFloatIO(self, mlir_quantizer):
    func, calibration_gen = self._getSqrtModel()
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func.get_concrete_function()])
    converter.representative_dataset = calibration_gen
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.experimental_new_quantizer = mlir_quantizer
    quantized_model = converter.convert()

    # Because we make assertions on the model below, we opt out of applying
    # the default TFLite delegates (i.e. the XNNPACK delegate).
    interpreter = Interpreter(
        model_content=quantized_model,
        experimental_op_resolver_type=OpResolverType
        .BUILTIN_WITHOUT_DEFAULT_DELEGATES)
    interpreter.allocate_tensors()
    # The model should have only one sqrt op.
    op_details = interpreter._get_ops_details()
    self.assertLen(op_details, 1)
    self.assertEqual(op_details[0]['op_name'], 'SQRT')

  @parameterized.named_parameters(
      ('_Default', False, False, dtypes.float32),
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize', False, True, dtypes.float32),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly', True, False, dtypes.float32),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
  def testIntegerQuantization(self, is_int_only, is_int16_quantize,
                              inference_input_output_type):
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([func], root)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                         root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    expected_opt_options = [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]
    if is_int16_quantize:
      expected_opt_options = [metadata_fb.ModelOptimizationMode.PTQ_INT16]
    self.assertAllEqual(expected_opt_options,
                        metadata.options.modelOptimizationModes)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

    # Ensure that the quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(tflite_model))

  @parameterized.named_parameters(
      ('_INT16Quantize_INT8InputOutput', True, dtypes.int8))
  def testInvalidIntegerQuantization(self, is_int16_quantize,
                                     inference_input_output_type):
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                         root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int16_quantize:
      quantized_converter.target_spec.supported_ops = [
          lite.OpsSet.
          EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
          lite.OpsSet.TFLITE_BUILTINS
      ]
    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_input_type = dtypes.int8
      quantized_converter.inference_output_type = dtypes.int8
      quantized_converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        "must be in ['tf.float32', 'tf.int16'].", str(error.exception))

  def testCalibrateAndQuantizeBuiltinInt16(self):
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                     root)
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    converter = lite.TFLiteConverterV2.from_concrete_functions([func], root)
    # TODO(b/156309549): We should add INT16 to the builtin types.
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.representative_dataset = calibration_gen
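    # Calibrate only; quantization is applied afterwards via mlir_quantize so
    # that the inference type can be set to int16.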
    converter._experimental_calibrate_only = True
    calibrated_tflite = converter.convert()
    quantized_tflite_model = mlir_quantize(
        calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)

    self.assertIsNotNone(quantized_tflite_model)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testSignatureDefs(self):
    """Test converting SignatureDef is correct and uses SignatureDef API."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2([add_func], trackable_obj=root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = add_func(input_data)
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    results = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'serving_default', {'x': input_data})
    self.assertLen(list(results.keys()), 1)
    self.assertStartsWith(list(results.keys())[0], 'output')
    self.assertAllClose(
        expected_value.numpy(),
        results[signature_defs['serving_default']['outputs'][0]])

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['serving_default'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['serving_default'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['serving_default']['inputs'], ['x'])
    self.assertLen(list(signature_defs['serving_default']['outputs']), 1)
    self.assertStartsWith(
        list(signature_defs['serving_default']['outputs'])[0], 'output')

  @test_util.run_v2_only
  def testNoSignatureDefsWhenTrackingObjIsNone(self):
    """Test that no SignatureDefs are exported when trackable obj is None."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               None)
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    # Verify that there is no SignatureDef structure found.
    self.assertEqual(len(signature_defs), 0)

  @test_util.run_v2_only
  def testNoSignatureDefsWhenInvalidTrackingObjIsGiven(self):
    """Test that no SignatureDefs are exported for an invalid trackable obj."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], trackable_obj=autotrackable.AutoTrackable())
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    # Verify that there is no SignatureDef structure found.
    self.assertEqual(len(signature_defs), 0)

  @test_util.run_v2_only
  def testTrackableObject(self):
    """Test converting with trackable objects."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [add_func], trackable_obj=root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = add_func(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  def _getTrainingTimeQuantizedModel(self):

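    # QLinear emulates quantization-aware training: the inputs, weights, and
    # outputs are all passed through fake-quant ops with a fixed [-6, 6] range.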
    class QLinear(tf.keras.layers.Layer):

      def __init__(self, units=3, **kwargs):
        super(QLinear, self).__init__(**kwargs)
        self.units = units

      def build(self, input_shape):
        self.w = self.add_weight(
            'weight',
            shape=(input_shape[-1], self.units),
            initializer='random_normal',
            trainable=True)
        self.min_var = self.add_weight(
            'min',
            initializer=tf.keras.initializers.Constant(-6.0),
            trainable=False)
        self.max_var = self.add_weight(
            'max',
            initializer=tf.keras.initializers.Constant(6.0),
            trainable=False)

      def call(self, inputs):
        x = tf.quantization.fake_quant_with_min_max_vars(
            inputs, self.min_var, self.max_var)

        w_fq = tf.quantization.fake_quant_with_min_max_vars(
            self.w, self.min_var, self.max_var)
        x = tf.matmul(x, w_fq)

        x = tf.quantization.fake_quant_with_min_max_vars(
            x, self.min_var, self.max_var)

        return x

    return tf.keras.Sequential(QLinear(3, input_shape=(2,)))

  @parameterized.named_parameters(
      ('_DefaultFLOAT32InputOutput', dtypes.float32),
      ('_INT8InputOutput', dtypes.int8), ('_UINT8InputOutput', dtypes.uint8))
  @test_util.run_v2_only
  def testTrainingTimeQuantization(self, inference_input_output_type):
    model = self._getTrainingTimeQuantizedModel()

    float_converter = lite.TFLiteConverterV2.from_keras_model(model)
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual(
        [metadata_fb.ModelOptimizationMode.QUANTIZATION_AWARE_TRAINING],
        metadata.options.modelOptimizationModes)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

    # Ensure that the quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testNewQuantizer(self):
    """Test the model quantized by the new converter."""
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                         root)
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    quantized_converter.representative_dataset = calibration_gen

    # Default quantizer.
    quantized_converter.experimental_new_quantizer = False
    old_tflite = quantized_converter.convert()

    # New quantizer.
    quantized_converter.experimental_new_quantizer = True
    new_tflite = quantized_converter.convert()

    for _ in range(5):
      input_data = tf.constant(
          np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
      old_value = self._evaluateTFLiteModel(old_tflite, [input_data])
      new_value = self._evaluateTFLiteModel(new_tflite, [input_data])
      self.assertAllClose(old_value, new_value, atol=1e-01)

  @test_util.run_v2_only
  def testEmbeddings(self):
    """Test model with embeddings."""
    input_data = tf.constant(
        np.array(np.random.random_sample((20)), dtype=np.int32))

    class EmbeddingModel(tf.keras.Model):

      def __init__(self):
        super(EmbeddingModel, self).__init__()
        self.shared_weights = self.add_weight(
            'weights',
            shape=(2000, 300),
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(
                mean=0.0, stddev=300**(-0.5)))

      @tf.function(input_signature=[tf.TensorSpec(shape=(20), dtype=tf.int32)])
      def func(self, x):
        return tf.gather(self.shared_weights, x)

    # Build the model.
    root = EmbeddingModel()
    concrete_func = root.func.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.func(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertAllClose(expected_value.numpy(), actual_value[0], atol=1e-05)

  @test_util.run_v2_only
  def testGraphDebugInfo(self):
    """Test a concrete function has debug info captured."""
    root = autotrackable.AutoTrackable()
    root.v1 = tf.Variable(3.)
    root.f = tf.function(lambda x: root.v1 * x)
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)

  def _getIntegerQuantizationModelWithFlexOp(self):
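    """Builds a model with ops that need the Select TF ops (Flex) fallback."""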
    np.random.seed(0)

    root = autotrackable.AutoTrackable()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
    ])
    def func(inp):
      tanh = tf.math.tanh(inp)
      # The Flex delegate will merge the consecutive conv3d and erf ops into
      # one delegate node.
      conv3d = tf.nn.conv3d(
          tanh,
          tf.ones([3, 3, 3, 3, 3]),
          strides=[1, 1, 1, 1, 1],
          padding='SAME')
      erf = tf.math.erf(conv3d)
      output = tf.math.tanh(erf)
      return output

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3, 3, 3, 3, 3)).astype(np.float32)
        ]

    root.f = func
    return (root, root.f.get_concrete_function(), calibration_gen)

  @parameterized.named_parameters(
      ('_Default', False, False, dtypes.float32),
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize', False, True, dtypes.float32),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly', True, False, dtypes.float32),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
  @test_util.run_v2_only
  def testIntegerQuantizationWithFlexOp(self, is_int_only, is_int16_quantize,
                                        inference_input_output_type):
    root, func, calibration_gen = self._getIntegerQuantizationModelWithFlexOp()

    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.SELECT_TF_OPS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.SELECT_TF_OPS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS,
            lite.OpsSet.SELECT_TF_OPS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS, lite.OpsSet.SELECT_TF_OPS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.options.enableSelectTfOps, True)
    expected_opt_options = [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]
    if is_int16_quantize:
      expected_opt_options = [metadata_fb.ModelOptimizationMode.PTQ_INT16]
    self.assertAllEqual(expected_opt_options,
                        metadata.options.modelOptimizationModes)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

  def _getIntegerQuantizationModelWithUnsupportedOps(self):
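    """Builds a model containing ceil ops, which have no int8/int16 kernels."""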
    np.random.seed(0)

    root = autotrackable.AutoTrackable()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[3], dtype=tf.float32),
        tf.TensorSpec(shape=[3], dtype=tf.float32)
    ])
    def func(a, b):
      # The ceil kernel supports neither int8 nor int16 types.
      left = tf.math.ceil(a)
      right = tf.nn.tanh(b)
      add = tf.math.add(left, right)
      # The ceil kernel supports neither int8 nor int16 types.
      output = tf.math.ceil(add)
      return (output, right)

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
        ]

    root.f = func
    return (root, root.f.get_concrete_function(), calibration_gen)

  @parameterized.named_parameters(
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
  @test_util.run_v2_only
  def testIntegerQuantizationWithUnsupportedOps(self,
                                                is_int_only,
                                                is_int16_quantize,
                                                inference_input_output_type,
                                                enable_mlir_quantizer=False):
    root, func, calib_gen = self._getIntegerQuantizationModelWithUnsupportedOps(
    )

    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calib_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    expected_dtype = inference_input_output_type.as_numpy_dtype
    # Allow float32 for fallback on non-quantizable op.
    expected_ceil_dtype = (
        expected_dtype if enable_mlir_quantizer else dtypes.float32)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual(input_details[0]['dtype'], expected_dtype)
    self.assertEqual(input_details[1]['dtype'], expected_ceil_dtype)
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(output_details[0]['dtype'], expected_dtype)
    self.assertEqual(output_details[1]['dtype'], expected_ceil_dtype)

  def _getIntegerQuantizationModelWithControlFlow(self):
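    """Builds a tf.cond model; calibration data covers both branches."""
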
    def true_fn(x):
      return x

    def false_fn(x):
      return x

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.bool)
    ])
    def model(x, b):
      x = x + x
      x = tf.cond(b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))
      return x + x

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(1, 2)).astype(np.float32),
            tf.constant(True),
        ]
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(1, 2)).astype(np.float32),
            tf.constant(False),
        ]

    return (model, model.get_concrete_function(), calibration_gen)

  @parameterized.named_parameters(
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      # TODO(b/198231624): Support control flow ops in MLIR quantizer
      # ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      # ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True),
  )
  @test_util.run_v2_only
  def testIntegerQuantizationWithControlFlow(self,
                                             is_int_only,
                                             is_int16_quantize,
                                             inference_input_output_type,
                                             enable_mlir_quantizer=False):
    root, func, calib_gen = self._getIntegerQuantizationModelWithControlFlow()

    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calib_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer

    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    expected_dtype = inference_input_output_type.as_numpy_dtype

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual(input_details[0]['dtype'], expected_dtype)
    self.assertEqual(input_details[1]['dtype'], dtypes.bool)
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(output_details[0]['dtype'], expected_dtype)

  @parameterized.named_parameters(
      ('_BlocklistedNoneWithLowering', None, None, True),
      ('_BlocklistedNoneWithoutLowering', None, None, False),
      ('_BlocklistedOpsWithLowering', {'CONV_2D'}, None, True),
      ('_BlocklistedOpsWithoutLowering', {'CONV_2D'}, None, False),
      ('_BlocklistedNodesWithLowering', None, {'PartitionedCall:0'}, True),
      ('_BlocklistedNodesWithoutLowering', None, {'Identity'}, False))
  @test_util.run_v2_only
  def testNewQuantizerBlocklistingArgs(self, denylisted_ops, denylisted_nodes,
                                       lower_to_saved_model):
    """Test the model quantized by the new converter and denylisted options."""
    root, func, calibration_gen = self._getIntegerQuantizeModel()
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                         root)
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    quantized_converter.representative_dataset = calibration_gen
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.experimental_new_quantizer = True
    quantized_converter._experimental_calibrate_only = True
    quantized_converter.experimental_lower_to_saved_model = lower_to_saved_model
    calibrated = quantized_converter.convert()
    quantized_tflite_model = mlir_quantize(
        calibrated,
        denylisted_ops=denylisted_ops,
        denylisted_nodes=denylisted_nodes)
    interpreter = Interpreter(model_content=quantized_tflite_model)
    details = interpreter.get_tensor_details()
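    # A tensor counts as quantized if it carries at least one quantization
    # scale.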
    num_quantized_tensors = sum(
        [1 for detail in details
         if len(detail['quantization_parameters']['scales'])])
    if denylisted_nodes or denylisted_ops:
      self.assertEqual(num_quantized_tensors, 0)
      return
    self.assertEqual(num_quantized_tensors, 4)  # quant, filter, bias, dequant

  @parameterized.named_parameters(
      ('_SingleLayer', False),
      ('_WholeModel', True),
  )
  @test_util.run_v2_only
  def testNewQuantizerNumericVerificationDebugMode(self, whole_model_verify):
    """Test the model quantized by the new converter with numeric verify ops."""
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                         root)
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    quantized_converter.representative_dataset = calibration_gen

    # Create a TFLite model with the new quantizer.
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.experimental_new_quantizer = True
    production_tflite = quantized_converter.convert()
    # Create a TFLite model with the new quantizer and numeric verify ops.
    quantized_converter._experimental_calibrate_only = True
    calibrated = quantized_converter.convert()
    debug_mode_tflite = mlir_quantize(
        calibrated,
        enable_numeric_verify=True,
        enable_whole_model_verify=whole_model_verify)

    # Check that adding debug mode outputs a different flatbuffer.
    self.assertNotEqual(production_tflite, debug_mode_tflite)

    # Check that the newly added ops are numeric verify ops.
    input_data = tf.constant(
        np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))

    def examine_tflite_model(tflite_content, input_data):
      interpreter = Interpreter(
          model_content=tflite_content,
          experimental_op_resolver_type=OpResolverType
          .BUILTIN_WITHOUT_DEFAULT_DELEGATES)
      interpreter.allocate_tensors()
      input_details = interpreter.get_input_details()
      interpreter.set_tensor(input_details[0]['index'], input_data.numpy())
      interpreter.invoke()
      tensor_details = interpreter.get_tensor_details()
      return {
          details['name']: interpreter.get_tensor(details['index'])
          for details in tensor_details
      }, tensor_details

    tflite_result, _ = examine_tflite_model(production_tflite, input_data)
    debug_mode_tflite_result, debug_tensor_details = examine_tflite_model(
        debug_mode_tflite, input_data)

    # The MLIR-based quantizer should emit `tfl.quantize` ops in the model.
    num_production_quantize_ops = len([
        None for output_tensor_name in tflite_result
        if 'tfl.quantize' in output_tensor_name
    ])
    self.assertEqual(num_production_quantize_ops, 1)
    # The MLIR-based quantizer should emit `tfl.quantize` ops in the model.
    num_debug_quantize_ops = len([
        None for output_tensor_name in debug_mode_tflite_result
        if 'tfl.quantize' in output_tensor_name
    ])
    # The two numbers should be equal.
    self.assertEqual(num_production_quantize_ops, num_debug_quantize_ops)
    # The debug-mode TFLite flatbuffer should contain more than zero
    # NumericVerify ops. Each output name has the prefix
    # "NumericVerify/{name}:{id}", where {name} is the tensor name of the
    # original quantized op's activation and {id} is its tensor id.
    num_debug_ops = 0
    for output_tensor_name in debug_mode_tflite_result:
      if 'NumericVerify' in output_tensor_name:
        pos_end_prefix = len('NumericVerify/')
        pos_colon = output_tensor_name.rfind(':')
        self.assertEqual('NumericVerify/', output_tensor_name[:pos_end_prefix])
        tensor_id = int(output_tensor_name[pos_colon + 1:])
        original_tensor_name = output_tensor_name[pos_end_prefix:pos_colon]
        self.assertEqual(original_tensor_name,
                         debug_tensor_details[tensor_id]['name'])
        num_debug_ops += 1
    self.assertEqual(num_debug_ops, 1)
    # The number of debug ops should be equal to that of quantized ops.
    self.assertEqual(num_debug_ops, num_debug_quantize_ops)

  @parameterized.named_parameters(
      ('_PerChannelQuant', False, False),
      ('_PerChannelMlirQuant', False, True),
      ('_PerTensorQuant', True, False),
      ('_PerTensorMlirQuant', True, True),
      ('_PerChannelDynamicRange', False, False, False),
      ('_PerTensorDynamicRange', True, False, False))
  @test_util.run_v2_only
  def testDisablePerChannelQuantization(self, disable_per_channel=False,
                                        enable_mlir_quantizer=False,
                                        representative_dataset=True):
    k_conv_name = 'Conv2D'
    # Dynamic range quant requires total num elements of filters > 1024.
    k_num_filters = 38
    root, func, calib_gen = self._getIntegerQuantizeModel(k_num_filters)
    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calib_gen
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS
    ]
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    if disable_per_channel:
      quantized_converter._experimental_disable_per_channel = (
          disable_per_channel)
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
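    # Inspect the quantization parameters of the tensor associated with the
    # Conv2D op: per-channel quantization yields one scale per output channel,
    # per-tensor quantization a single scale.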
1226    detail = next((d for d in interpreter.get_tensor_details()
1227                   if d['name'].startswith(k_conv_name)))
1228    quant_params = detail['quantization_parameters']
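    # Per-channel quantization keeps one (scale, zero_point) pair per output
    # channel (one per convolution filter here); per-tensor quantization
    # collapses them into a single pair for the whole weight tensor.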
    expected_num_params = 1 if disable_per_channel else k_num_filters
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

  @parameterized.named_parameters(('MlirQuantize', True),
                                  ('TocoQuantize', False))
  @test_util.run_v2_only
  def testQuantizeBiasOverflow(self, enable_mlir_quantizer):
    """Tests if the quantizer handles bias overflow by adjusting scales."""
    input_data = np.array([[-1e-3, 1e-3]], dtype=np.float32)

    def calibration_gen():
      yield {'x': input_data}

    root = self._getMatMulModelWithSmallWeights()
    input_data = tf.constant([-1e-3, 1e-3], shape=(1, 2))
    concrete_func = root.matmul.get_concrete_function(input_data)
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_model = converter.convert()

    interpreter = Interpreter(model_content=quantized_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_details = interpreter.get_output_details()
    output = interpreter.get_tensor(output_details[0]['index'])
    # The inputs and weights are far smaller than the biases, so the final
    # result should be approximately equal to the biases.
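    # In int8 quantization the bias scale is fixed at input_scale *
    # weight_scale; when the bias would overflow int32 at that scale, the
    # quantizer is expected to widen the scales rather than saturate.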
    self.assertAllClose(root.bias, output.flatten())

  @test_util.run_v2_only
  def testOpVersion(self):
    @tf.function(
        input_signature=[tf.TensorSpec(shape=[5, 5], dtype=tf.float32)])
    def custom_resize(image):
      # Add "batch" and "channels" dimensions.
      image = image[tf.newaxis, ..., tf.newaxis]
      # ResizeBilinear version 3.
      resize1 = tf.compat.v1.image.resize_bilinear(
          image, [2, 2], half_pixel_centers=True)
      # ResizeBilinear version 1.
      resize2 = tf.compat.v1.image.resize_bilinear(image, [2, 2])
      return resize1 + resize2

    concrete_func = custom_resize.get_concrete_function()
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               custom_resize)
    tflite_model = converter.convert()
    model_object = schema_fb.Model.GetRootAsModel(tflite_model, 0)
    model = schema_fb.ModelT.InitFromObj(model_object)

    for operator in model.operatorCodes:
      if operator.builtinCode == schema_fb.BuiltinOperator.RESIZE_BILINEAR:
        # half_pixel_centers is supported by ResizeBilinear version 3.
        self.assertEqual(operator.version, 3)
        break

  @test_util.run_v2_only
  def testForceSelectTFOps(self):
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.options.forceSelectTfOps, True)

    # Check output value from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  def testExcludeConversionMetadata(self):
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    converter.exclude_conversion_metadata = True
    tflite_model = converter.convert()
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNone(metadata)

  def testConversionMetadataForDynamicRange(self):
    func, _ = self._getSqrtModel()
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func.get_concrete_function()])
    converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_model = converter.convert()
    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE],
                        metadata.options.modelOptimizationModes)

  def testConversionMetadataForFloat16(self):
    root, func, calibration_gen = self._getIntegerQuantizeModel()
    converter = lite.TFLiteConverterV2.from_concrete_functions([func], root)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.target_spec.supported_types = [dtypes.float16]
    quantized_model = converter.convert()
    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.PTQ_FLOAT16],
                        metadata.options.modelOptimizationModes)


class FromSavedModelTest(lite_v2_test_util.ModelTest):

  def _createV1SavedModel(self, shape):
    """Create a simple SavedModel."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor_1 = tf.compat.v1.placeholder(
            shape=shape, dtype=tf.float32, name='inputB')
        in_tensor_2 = tf.compat.v1.placeholder(
            shape=shape, dtype=tf.float32, name='inputA')
        variable_node = tf.Variable(1.0, name='variable_node')
        out_tensor = in_tensor_1 + in_tensor_2 * variable_node
        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
        outputs = {'z': out_tensor}
        sess.run(tf.compat.v1.variables_initializer([variable_node]))
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def _createV2QATSavedModel(self, shape):
    """Create a simple QAT SavedModel in TF 2."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    input_name = 'input'
    output_name = 'scores'

    input_tensor = tf.keras.layers.Input((32, 32, 128), name=input_name)
    x = tf.quantization.fake_quant_with_min_max_args(input_tensor, -3.0, 3.0)
    x = tf.keras.layers.Conv2D(1, (3, 3))(x)
    x = tf.quantization.fake_quant_with_min_max_args(x, -3.0, 3.0)
    scores = tf.keras.layers.Reshape((-1,), name=output_name)(x)
    model = tf.keras.Model(input_tensor, scores)
    model.save(saved_model_dir)
    return saved_model_dir, input_name, output_name

  @test_util.run_v2_only
  def testV1SimpleModel(self):
    """Test a SavedModel."""
    with tf.Graph().as_default():
      saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

      # Convert model and ensure model is not None.
      converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
      tflite_model = converter.convert()
      self.assertTrue(tflite_model)

      interpreter = Interpreter(model_content=tflite_model)
      interpreter.allocate_tensors()

      input_details = interpreter.get_input_details()
      self.assertLen(input_details, 2)
      self.assertStartsWith(input_details[0]['name'], 'inputA')
      self.assertEqual(np.float32, input_details[0]['dtype'])
      self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
      self.assertEqual((0., 0.), input_details[0]['quantization'])

      self.assertStartsWith(input_details[1]['name'], 'inputB')
      self.assertEqual(np.float32, input_details[1]['dtype'])
      self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
      self.assertEqual((0., 0.), input_details[1]['quantization'])

      output_details = interpreter.get_output_details()
      self.assertLen(output_details, 1)
      self.assertStartsWith(output_details[0]['name'], 'add')
      self.assertEqual(np.float32, output_details[0]['dtype'])
      self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
      self.assertEqual((0., 0.), output_details[0]['quantization'])

  @parameterized.named_parameters(
      ('Default', False),
      ('UnfoldLargeConstant', True),
  )
  @test_util.run_v2_only
  def testUnfoldLargeConstant(self, unfold_large_constant):
    """Test unfolding a large splat constant in a TF Lite model."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor = tf.compat.v1.placeholder(
            shape=[1000, 1000], dtype=tf.float32, name='input')
        constant = tf.constant(value=1, dtype=tf.float32, shape=[1000, 1000])
        out_tensor = in_tensor + constant
        inputs = {'x': in_tensor}
        outputs = {'y': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
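    # With unfolding enabled, the converter should replace the 1000x1000 splat
    # constant (1M float32 values, ~4MB of raw data) with a scalar plus a FILL
    # op, trading flatbuffer size for a little work at runtime.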
    converter._experimental_unfold_large_splat_constant = unfold_large_constant
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    model = util._convert_model_from_bytearray_to_object(tflite_model)
    if unfold_large_constant:
      self.assertEqual(model.operatorCodes[0].builtinCode,
                       schema_fb.BuiltinOperator.FILL)
      self.assertEqual(model.operatorCodes[1].builtinCode,
                       schema_fb.BuiltinOperator.ADD)
    else:
      self.assertEqual(model.operatorCodes[0].builtinCode,
                       schema_fb.BuiltinOperator.ADD)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('input:0', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1000, 1000], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual('add:0', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1000, 1000], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    interpreter.set_tensor(input_details[0]['index'],
                           np.ones(shape=[1000, 1000], dtype=np.float32))
    interpreter.invoke()
    self.assertAllEqual(
        np.full(shape=[1000, 1000], fill_value=2.0, dtype=np.float32),
        interpreter.get_tensor(output_details[0]['index']))

  @test_util.run_v2_only
  def testPreserveAssert(self):
    """Test preserving AssertOp in a TF Lite model."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor = tf.compat.v1.placeholder(
            shape=[10, 10], dtype=tf.float32, name='input')
        constant = tf.constant(value=1, dtype=tf.float32, shape=[10, 10])
        assert_op = tf.Assert(tf.less_equal(in_tensor, constant), [in_tensor])
        with tf.control_dependencies([assert_op]):
          out_tensor = in_tensor + constant
        inputs = {'x': in_tensor}
        outputs = {'y': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter._experimental_preserve_assert_op = True
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    model = util._convert_model_from_bytearray_to_object(tflite_model)
    has_assert = any(op_code.customCode == b'FlexAssert'
                     for op_code in model.operatorCodes)
    self.assertTrue(has_assert)

  @test_util.run_v2_only
  def testTF1HubFormattedModel(self):
    """Test a TF1 Hub formatted model."""
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

    # TF1 Hub models are based on the V1 saved model format and omit the
    # saved model schema version setting.
    saved_model_proto = parse_saved_model(saved_model_dir)
    saved_model_proto.saved_model_schema_version = 0

    saved_model_pb_file_path = os.path.join(saved_model_dir, 'saved_model.pb')
    with file_io.FileIO(saved_model_pb_file_path, 'wb') as writer:
      writer.write(saved_model_proto.SerializeToString())

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  def _createV1ModelWithHashTableInitializer(self):
    # Create a v1 saved model with hash table initializers.
    tf.compat.v1.disable_eager_execution()
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'savedmodel_with_hashtable')

    table_initializer = tf.lookup.KeyValueTensorInitializer(
        keys=['a', 'b', 'c', 'd'],
        values=[1, 2, 3, 4],
        key_dtype=tf.string,
        value_dtype=tf.int64)
    table = tf.lookup.StaticHashTable(
        table_initializer, default_value=tf.constant(-1, dtype=tf.int64))

    x = tf.compat.v1.placeholder(tf.string, shape=(), name='input')
    y = table.lookup(x)

    tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)

    signature_def_map = {
        'serving_default':
            tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs={'x': tensor_info_x},
                outputs={'y': tensor_info_y},
                method_name='some_function')
    }
    init_op = tf.compat.v1.tables_initializer()
    assets_collection = None

    sess = tf.compat.v1.Session()
    sess.run(tf.compat.v1.initializers.global_variables())

    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(
        saved_model_dir)
    builder.add_meta_graph_and_variables(
        sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
        signature_def_map,
        main_op=init_op,
        assets_collection=assets_collection,
        strip_default_attrs=True)
    builder.save()

    # Restore TF v2 behavior.
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.enable_eager_execution()
    return saved_model_dir

  @test_util.run_v2_only
  def testModelWithHashTableInitializer(self):
    """Test a model with saved_model's session initializer for hash tables."""
    saved_model_dir = self._createV1ModelWithHashTableInitializer()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_data = np.array(['a', 'b', 'c', 'z'], dtype=np.string_)
    interpreter.resize_tensor_input(
        input_details[0]['index'], [4], strict=False)
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_details[0]['index'], input_data)

    # Invoke multiple times to ensure the initializer graph runs only once.
    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([1, 2, 3, -1], list(actual_value))

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([1, 2, 3, -1], list(actual_value))

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([1, 2, 3, -1], list(actual_value))

  def _createV1ModelWithMutableHashTable(self):
    # Create a v1 saved model with a mutable hash table.
    tf.compat.v1.disable_eager_execution()
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'savedmodel_with_mutable_hashtable')

    table = tf.raw_ops.MutableHashTableV2(
        key_dtype=tf.string, value_dtype=tf.int64)
    x = tf.compat.v1.placeholder(tf.string, shape=(), name='input')
    keys = tf.constant(['a', 'b'], tf.string)
    values = tf.constant([1, 5], tf.int64)
    default_value = tf.constant(-1, tf.int64)
    insert_call = tf.raw_ops.LookupTableInsertV2(
        table_handle=table, keys=keys, values=values)
    with tf.control_dependencies([insert_call]):
      y = tf.raw_ops.LookupTableFindV2(
          table_handle=table, keys=x, default_value=default_value)

    tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)

    signature_def_map = {
        'serving_default':
            tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs={'x': tensor_info_x},
                outputs={'y': tensor_info_y},
                method_name='some_function')
    }
    init_op = tf.compat.v1.tables_initializer()
    assets_collection = None

    sess = tf.compat.v1.Session()

    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(
        saved_model_dir)
    builder.add_meta_graph_and_variables(
        sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
        signature_def_map,
        main_op=init_op,
        assets_collection=assets_collection,
        strip_default_attrs=True)
    builder.save()

    # Restore TF v2 behavior.
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.enable_eager_execution()
    return saved_model_dir

  @test_util.run_v2_only
  def testModelWithMutableHashTable(self):
    """Test a model with a mutable hash table populated at initialization."""
    saved_model_dir = self._createV1ModelWithMutableHashTable()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_data = np.array(['a', 'b', 'c'], dtype=np.string_)
    interpreter.resize_tensor_input(
        input_details[0]['index'], [3], strict=False)
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([1, 5, -1], list(actual_value))

  @test_util.run_v2_only
  def testReduceSumWithInt16Quant(self):
    """Test a model with a quantized int16 reduce_sum op."""
    inp = tf.keras.Input([3, 3], 3, name='x')
    m = tf.keras.Model(inp, tf.reduce_sum(inp, axis=-1))

    converter = tf.lite.TFLiteConverter.from_keras_model(m)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet
        .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    ]
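    # The 16x8 mode quantizes activations to int16 and weights to int8,
    # preserving more activation precision than full int8 quantization at a
    # modest increase in model size.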
    converter.inference_input_type = tf.int16
    converter.inference_output_type = tf.int16
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    inputs = {
        i.name: np.random.normal(size=i.shape).astype(np.float32)
        for i in m.inputs
    }
    converter.representative_dataset = lambda: [inputs]
    content = converter.convert()

    interpreter = tf.lite.Interpreter(model_content=content)
    runner = interpreter.get_signature_runner('serving_default')
    y = runner(x=np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]).astype(np.int16))
    self.assertEqual([3, 6, 9], list(list(y.values())[0]))

  @test_util.run_v2_only
  def testConstModel(self):
    """Test a basic model with functions to make sure functions are inlined."""
    input_data = tf.constant(1., shape=[1])
    root = autotrackable.AutoTrackable()
    root.f = tf.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testVariableModel(self):
    """Test a basic model with variables, saved and loaded as a SavedModel."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.environment.modelType,
                     metadata_fb.ModelType.TF_SAVED_MODEL)

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @parameterized.named_parameters(('EnableResourceVariables', True),
                                  ('DisableResourceVariables', False))
  @test_util.run_v2_only
  def testNativeVariablesModel(self, enable_resource_variables):
    """Test a model with native variables, saved and loaded via SavedModel."""
    root = self._getSimpleModelWithVariables()
    input_data = tf.constant(1., shape=[1, 10])
    to_save = root.assign_add.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    converter.experimental_enable_resource_variables = enable_resource_variables

    if not enable_resource_variables:
      with self.assertRaises(convert.ConverterError) as error:
        tflite_model = converter.convert()
      self.assertIn(
          'Variable constant folding is failed. Please consider using enabling '
          '`experimental_enable_resource_variables` flag in the TFLite '
          'converter object.',
          str(error.exception))
      return

    # Enable resource variables.
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.assign_add(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    for tf_result, tflite_result in zip(expected_value, actual_value[0]):
      self.assertAllClose(tf_result, tflite_result, atol=1e-05)

  @test_util.run_v2_only
  def testSignatures(self):
    """Test values for `signature_keys` argument."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)

    # Convert model with invalid `signature_keys`.
    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_saved_model(
          save_dir, signature_keys=['INVALID'])
    self.assertIn("Invalid signature key 'INVALID'", str(error.exception))

    # Convert model with empty `signature_keys`.
    converter = lite.TFLiteConverterV2.from_saved_model(
        save_dir, signature_keys=[])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testSignatureDefsWithFullIntegerQuantization(self):
    # SETUP
    # 1. Define input shapes.
    tf_input_shape = (32, 32, 128)
    tflite_input_shape = (1,) + tf_input_shape
    # 2. Define the model.
    tf_saved_model_dir, input_name, output_name = (
        self._createV2QATSavedModel(tf_input_shape))

    # MODEL 1: TFLite (float) model
    # 1. Create the TFLite model.
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    # 2. Initialize the Interpreter.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    interpreter.resize_tensor_input(input_details['index'], tflite_input_shape)
    interpreter.allocate_tensors()
    signature_list = interpreter._get_full_signature_list()['serving_default']
    # 3. (Skipped) Verify that signature def input/output tensors are in the
    #    model.
    # 4. Evaluate the model.
    input_data = np.random.random(tflite_input_shape).astype(np.float32)
    result = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'serving_default', {input_name: input_data})[output_name]

    # MODEL 2: TFLite (full integer quantized) model
    # 1. Create the quantized TFLite model.
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_model_quant = converter.convert()
    # 2. Initialize the Interpreter.
    interpreter = Interpreter(model_content=tflite_model_quant)
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    interpreter.resize_tensor_input(input_details['index'], tflite_input_shape)
    interpreter.allocate_tensors()
    # 3. Verify that signature def input/output tensors are in the model.
    all_indices = {item['index'] for item in interpreter.get_tensor_details()}
    signature_list = interpreter._get_full_signature_list()['serving_default']
    input_tensor_indices = set(signature_list['inputs'].values())
    assert input_tensor_indices.issubset(all_indices)
    output_tensor_indices = set(signature_list['outputs'].values())
    assert output_tensor_indices.issubset(all_indices)

    # 4. Evaluate the model.
    input_data = np.random.random(tflite_input_shape)
    input_scale, input_zero_point = input_details['quantization']
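    # Quantize the float inputs with the affine mapping q = f / scale +
    # zero_point; a (0.0, 0) quantization tuple means the tensor is float.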
    if (input_scale, input_zero_point) != (0.0, 0):
      input_data = input_data / input_scale + input_zero_point
      input_data = input_data.astype(input_details['dtype'])
    result_quant = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model_quant, 'serving_default',
        {input_name: input_data})[output_name]
    output_scale, output_zero_point = output_details['quantization']
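    # Dequantize the outputs with the inverse mapping f = (q - zero_point) *
    # scale before comparing against the float model's results.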
    if (output_scale, output_zero_point) != (0.0, 0):
      result_quant = result_quant.astype(np.float32)
      result_quant = (result_quant - output_zero_point) * output_scale

    # COMPARE: Validate that results from both models are approx. the same.
    root_mean_squared = np.sqrt(np.mean((result - result_quant)**2))
    assert root_mean_squared < 1.0

  @test_util.run_v2_only
  def testSignatureDefs(self):
    """Test converting SignatureDef is correct and uses SignatureDef API."""
    root = self._getMultiFunctionModel()
    input_data_0 = tf.constant(1., shape=[1])
    input_data_1 = tf.constant(3., shape=[1])
    mul_add_func = root.mul_add.get_concrete_function(input_data_1,
                                                      input_data_0)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'mul_add': mul_add_func})

    converter = lite.TFLiteConverterV2.from_saved_model(
        save_dir, signature_keys=['mul_add'])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.mul_add(input_data_1, input_data_0)
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    results = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'mul_add', {
            'y': input_data_0,
            'x': input_data_1
        })
    self.assertEqual(list(results.keys()), ['output_0'])
    self.assertEqual(expected_value.numpy(), results['output_0'])

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['mul_add'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
    self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])

  @test_util.run_v2_only
  def testSignatureDefsWithDefaultValue(self):
    """Test converting SignatureDef is correct and uses SignatureDef API.

    This test uses None as signature_key to test default behavior.
    """
    root = self._getMultiFunctionModel()
    input_data_0 = tf.constant(1., shape=[1])
    input_data_1 = tf.constant(3., shape=[1])
    mul_add_func = root.mul_add.get_concrete_function(input_data_1,
                                                      input_data_0)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'mul_add': mul_add_func})

    converter = lite.TFLiteConverterV2.from_saved_model(
        save_dir, signature_keys=['mul_add'])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.mul_add(input_data_1, input_data_0)
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    results = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, None, {
            'y': input_data_0,
            'x': input_data_1
        })
    self.assertEqual(list(results.keys()), ['output_0'])
    self.assertEqual(expected_value.numpy(), results['output_0'])

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['mul_add'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
    self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])

  @test_util.run_v2_only
  def testSignatureDefsQuantizedModel(self):
    """Test converting SignatureDef on a quantized model."""
    root = self._getMultiFunctionModel()
    input_data_0 = tf.constant(1., shape=[1])
    input_data_1 = tf.constant(3., shape=[1])
    mul_add_func = root.mul_add.get_concrete_function(input_data_1,
                                                      input_data_0)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'mul_add': mul_add_func})

    converter = lite.TFLiteConverterV2.from_saved_model(
        save_dir, signature_keys=['mul_add'])

    def representative_dataset_gen():
      for _ in range(2):
        yield {
            'x':
                np.random.uniform(low=0, high=1,
                                  size=(1, 1)).astype(np.float32),
            'y':
                np.random.uniform(low=0, high=1, size=(1, 1)).astype(np.float32)
        }

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    tflite_model = converter.convert()

    # Check signatures are valid from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['mul_add'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
    self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])

  @test_util.run_v2_only
  def testMultipleFunctionModel(self):
    """Convert multiple functions in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)
    sub_func = root.sub.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'add': add_func, 'sub': sub_func})

    # Try converting multiple functions.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 2)
    self.assertEqual(list(signature_defs.keys()), ['add', 'sub'])
    self.assertEqual(len(signature_defs.values()), 2)
    self.assertEqual(list(signature_defs['add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['add']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['add']['outputs']), ['output_0'])
    self.assertEqual(list(signature_defs['sub'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['sub']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['sub']['outputs']), ['output_0'])

    # Verify the signature runner executions.
    add_signature_runner = interpreter.get_signature_runner('add')
    add_output = add_signature_runner(x=input_data)
    self.assertEqual(add_output['output_0'], 3)

    sub_signature_runner = interpreter.get_signature_runner('sub')
    sub_output = sub_signature_runner(x=input_data)
    self.assertEqual(sub_output['output_0'], -2)

  @parameterized.named_parameters(
      ('_Default', False, False, dtypes.float32, False),
      ('_DefaultMlirQuant', False, False, dtypes.float32, True),
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
  @test_util.run_v2_only
  def testMultipleFunctionQuantizedModel(self,
                                         is_int_only,
                                         is_int16_quantize,
                                         inference_input_output_type,
                                         enable_mlir_quantizer=False):
    """Convert multiple functions in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)
    sub_func = root.sub.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'add': add_func, 'sub': sub_func})

    # Try converting multiple functions.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)

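    # For a multi-signature model, the representative dataset yields
    # (signature_key, input_dict) pairs so that each signature is calibrated
    # with its own data.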
    def representative_dataset_gen():
      for _ in range(2):
        yield ('add', {
            'x': np.random.uniform(low=0, high=1, size=(1,)).astype(np.float32),
        })
      for _ in range(2):
        yield ('sub', {
            'x': np.random.uniform(low=0, high=1, size=(1,)).astype(np.float32),
        })

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset_gen
    if is_int16_quantize:
      converter.target_spec.supported_ops = [
          lite.OpsSet
          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
      ]
    elif is_int_only:
      converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
    else:
      converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS]
    converter.inference_input_type = inference_input_output_type
    converter.inference_output_type = inference_input_output_type
    converter.experimental_new_quantizer = enable_mlir_quantizer
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 2)
    self.assertEqual(list(signature_defs.keys()), ['add', 'sub'])
    self.assertEqual(len(signature_defs.values()), 2)
    self.assertEqual(list(signature_defs['add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['add']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['add']['outputs']), ['output_0'])
    self.assertEqual(list(signature_defs['sub'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['sub']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['sub']['outputs']), ['output_0'])

    # Verify the signature runner executions.
    input_data = tf.constant(
        np.random.uniform(-1, 1, size=(1,)).astype(
            inference_input_output_type.as_numpy_dtype))
    add_signature_runner = interpreter.get_signature_runner('add')
    add_output = add_signature_runner(x=input_data)
    self.assertIsNotNone(add_output['output_0'])
    input_details = add_signature_runner.get_input_details()
    self.assertLen(input_details, 1)
    self.assertStartsWith(input_details['x']['name'], 'add_x:0')
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details['x']['dtype'])
    self.assertAllEqual([1], input_details['x']['shape'])
    if inference_input_output_type == dtypes.float32:
      self.assertEqual((0.0, 0), input_details['x']['quantization'])

    sub_signature_runner = interpreter.get_signature_runner('sub')
    sub_output = sub_signature_runner(x=input_data)
    self.assertIsNotNone(sub_output['output_0'])
    output_details = sub_signature_runner.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details['output_0']['name'],
                          'StatefulPartitionedCall:0')
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details['output_0']['dtype'])
    self.assertAllEqual([1], output_details['output_0']['shape'])
    if inference_input_output_type == dtypes.float32:
      self.assertEqual((0.0, 0), output_details['output_0']['quantization'])

  @test_util.run_v2_only
  def testMultipleFunctionModelWithSharedWeight(self):
    """Convert multiple functions with a shared weight."""
    root = self._getMultiFunctionModelWithSharedWeight()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)
    sub_func = root.sub.get_concrete_function(input_data)
    mul_func = root.mul.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'add': add_func, 'sub': sub_func, 'mul': mul_func})

    # Try converting multiple functions.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Make sure that the weight tensors are shared.
    self.assertLess(len(tflite_model), 1100000)
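    # Rough size bound: with the weight stored once, the flatbuffer stays near
    # the size of a single copy instead of roughly tripling across the three
    # functions.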

    # TODO(b/184696047): Write down the test codes for multiple signature
    #                    runners once the Python API is ready to use.
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    self.assertLen(signature_defs, 3)
    add_signature_runner = interpreter.get_signature_runner('add')
    sub_signature_runner = interpreter.get_signature_runner('sub')
    mul_signature_runner = interpreter.get_signature_runner('mul')
    self.assertIsNotNone(add_signature_runner)
    self.assertIsNotNone(sub_signature_runner)
    self.assertIsNotNone(mul_signature_runner)

  @test_util.run_v2_only
  def testNoConcreteFunctionModel(self):
    root = self._getMultiFunctionModel()

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir)

    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_saved_model(save_dir)
    self.assertIn('Only support at least one signature key.',
                  str(error.exception))

  @test_util.run_v2_only
  def testKerasSequentialModel(self):
    """Test a simple sequential tf.Keras model."""
    input_data = tf.constant(1., shape=[1, 1])

    x = np.array([[1.], [2.]])
    y = np.array([[2.], [4.]])

    model = tf.keras.models.Sequential([
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(model, save_dir)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = model.predict(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value, actual_value)

  @test_util.run_v2_only
  def testGraphDebugInfo(self):
    """Test a SavedModel has debug info captured."""
    input_data = tf.constant(1., shape=[1])
    root = autotrackable.AutoTrackable()
    root.f = tf.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data)
    options = save_options.SaveOptions(save_debug_info=True)
    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save, options)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)

  @test_util.run_v2_only
  def testNonStatefulConvLSTM2D(self):
    """Test a saved model with a non-stateful ConvLSTM2D Keras layer."""
    # Create the Keras model.
    model = tf.keras.Sequential([
        tf.keras.layers.ConvLSTM2D(
            32, (3, 3),
            padding='same',
            return_sequences=True,
            stateful=False,
            batch_input_shape=(1, 1, 10, 10, 1))
    ])
    model.compile()

    # Export the Keras model to a saved model.
    saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_lstm_2d')
    model.save(saved_model_dir, save_format='tf', include_optimizer=False)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  @test_util.run_v2_only
  def testKerasConvLSTM2DWithMoreThanOneDilationRate(self):
    input_tensor = tf.keras.layers.Input(
        batch_size=8,
        shape=[9, 10, 11, 12],
        name='input_tensor',
        dtype=tf.float32)

    output = tf.keras.layers.ConvLSTM2D(
        filters=3,
        kernel_size=3,
        strides=1,
        padding='VALID',
        dilation_rate=2,
        use_bias=False,
        bias_initializer='ones',
        data_format='channels_last')(
            input_tensor)

    model = tf.keras.Model(inputs=[input_tensor], outputs=output)
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])

    # Export the Keras model to a saved model.
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'conv_lstm_2d_with_dilation_rate')
    model.save(saved_model_dir, save_format='tf', include_optimizer=False)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  @test_util.run_v2_only
  def testKerasFullyConnectedOutputShape3D(self):
    """Create a simple FullyConnected Model with an output of three dimensions."""
    input_tensor = tf.keras.layers.Input(
        batch_size=1, shape=[3, 3], name='input_tensor', dtype=tf.float32)

    x = tf.quantization.fake_quant_with_min_max_args(input_tensor, -3.0, 3.0)
    x = tf.keras.layers.Dense(3)(x)
    x = tf.quantization.fake_quant_with_min_max_args(x, -3.0, 3.0)
    model = tf.keras.Model(input_tensor, x)

    model.compile(
        optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])

    # Export the Keras model to a saved model.
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'fully_connected_output_3d')
    model.save(saved_model_dir, save_format='tf', include_optimizer=False)
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.optimizations = [lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)
    output_details = interpreter.get_output_details()
    input_details = interpreter.get_input_details()
    interpreter.allocate_tensors()

    input_data = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    actual_value = interpreter.get_tensor(output_details[0]['index'])
    expected_value = model.predict(input_data)

    self.assertLen(output_details[0]['shape_signature'], 3)
    self.assertAllClose(expected_value, actual_value, atol=1e-1)
    self.assertEqual(
        list(output_details[0]['shape_signature']),
        list(model.layers[-1].output_shape))

  @test_util.run_v2_only
  def testKerasConv2DTransposedWithMismatchQuantizedAxes(self):

    class QuantConv2DTransposed(tf.keras.layers.Layer):

      def build(self, input_shape):
        self.kernel = self.add_weight('kernel', [3, 3, input_shape[-1], 24])

      def call(self, inputs):
        filters = tf.quantization.fake_quant_with_min_max_vars_per_channel(
            self.kernel,
            -3.0 * tf.ones([24]),
            3.0 * tf.ones([24]),
            narrow_range=True)
        filters = tf.transpose(filters, (0, 1, 3, 2))
        return tf.nn.conv2d_transpose(inputs, filters, [*inputs.shape[:-1], 24],
                                      1)

    inp = tf.keras.Input(shape=(6, 8, 48), batch_size=1)
    x = tf.quantization.fake_quant_with_min_max_vars(
        inp, -3.0, 3.0, narrow_range=True)
    x = QuantConv2DTransposed()(x)
    x = tf.quantization.fake_quant_with_min_max_vars(
        x, -3.0, 3.0, narrow_range=True)

    model = tf.keras.Model(inp, x)

    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'keras_conv2d_transpose')
    model.save(saved_model_dir)
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    with self.assertRaises(convert.ConverterError) as error:
      _ = converter.convert()
    self.assertIn('mismatched quantized axes of input and output',
                  str(error.exception))

  def _createModelWithInputShape(self, shape):
    """Create a simple SavedModel with a certain shape."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'input_shape_model')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        unknown_shape = tf.TensorShape(shape)
        in_tensor = tf.compat.v1.placeholder(
            shape=unknown_shape, dtype=tf.float32, name='input')
        out_tensor = in_tensor + in_tensor
        inputs = {'input': in_tensor}
        outputs = {'output': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  @test_util.run_v2_only
  def testUnknownInputShapeModel(self):
    """Test a SavedModel with an unknown input shape."""
    saved_model_dir = self._createModelWithInputShape(None)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Validate that tensors with unknown shape have unknown rank.
    tflite_model_obj = _convert_bytearray_to_object(tflite_model)
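    # In the flatbuffer, hasRank=False with an empty shape means the rank is
    # unknown, while hasRank=True with an empty shape (tested below) denotes a
    # rank-0 scalar.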
    for tensor in tflite_model_obj.subgraphs[0].tensors:
      self.assertFalse(tensor.hasRank)
      self.assertEqual([], tensor.shape.tolist())

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_data = np.array([1., 2., 3.], dtype=np.float32)
    interpreter.resize_tensor_input(
        input_details[0]['index'], [3], strict=False)
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([2., 4., 6.], list(actual_value))

  @test_util.run_v2_only
  def testScalarInputShapeModel(self):
    """Test a SavedModel with a scalar input."""
    saved_model_dir = self._createModelWithInputShape([])

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Validate that scalar tensors have rank 0.
    tflite_model_obj = _convert_bytearray_to_object(tflite_model)
    for tensor in tflite_model_obj.subgraphs[0].tensors:
      self.assertTrue(tensor.hasRank)
      self.assertEqual([], tensor.shape.tolist())

  @test_util.run_v2_only
  def testMatrixInputShapeModel(self):
    """Test a SavedModel with a matrix input."""
    saved_model_dir = self._createModelWithInputShape([2, 3])

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Validate that matrix tensors have rank 2.
    tflite_model_obj = _convert_bytearray_to_object(tflite_model)
    for tensor in tflite_model_obj.subgraphs[0].tensors:
      self.assertTrue(tensor.hasRank)
      self.assertEqual([2, 3], tensor.shape.tolist())

2453  @parameterized.named_parameters(
2454      ('_PerChannelQuant', False, False),
2455      ('_PerChannelMlirQuant', False, True),
2456      ('_PerTensorQuant', True, False),
2457      ('_PerTensorMlirQuant', True, True),
2458      ('_PerChannelDynamicRange', False, False, True),
2459      ('_PerTensorDynamicRange', True, False, True))
2460  @test_util.run_v2_only
2461  def testDisablePerChannelQuantization(self,
2462                                        disable_per_channel=False,
2463                                        enable_mlir_quantizer=False,
2464                                        representative_dataset=True):
2465    # Dynamic range quant requires total num elements of filters > 1024.
2466    k_num_filters = 38
2467    model = tf.keras.models.Sequential([
2468        tf.keras.layers.Conv2D(k_num_filters, (3, 3), activation='relu')
2469    ])
2470    model.build(input_shape=(1, 5, 5, 3))
2471    saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_saved_model')
2472    save(model, saved_model_dir)
2473    k_conv_name = 'sequential/conv2d/Conv2D'
2474    quantized_converter = tf.lite.TFLiteConverter.from_saved_model(
2475        saved_model_dir)
2476    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
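    # With a representative dataset the converter calibrates and performs
    # static-range (full integer) quantization; without one, only the weights
    # are quantized, via dynamic range quantization.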
    if representative_dataset:

      def calib_gen():
        for _ in range(5):
          yield [
              np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)
          ]

      quantized_converter.representative_dataset = calib_gen
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS
    ]
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    if disable_per_channel:
      quantized_converter._experimental_disable_per_channel = (
          disable_per_channel)
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    detail = next(d for d in interpreter.get_tensor_details()
                  if d['name'].startswith(k_conv_name))
    quant_params = detail['quantization_parameters']
    expected_num_params = k_num_filters
    if disable_per_channel:
      expected_num_params = 1
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

  @parameterized.named_parameters(
      ('_INT8Quant_INT32Bias', False, False, dtypes.int32, True),
      ('_INT16Quant_INT64Bias', True, False, dtypes.int64, True),
      ('_INT8Quant_INT32Bias_Set', False, True, dtypes.int32, True),
      ('_INT8Quant_INT64Bias_Set', False, True, dtypes.int64, False),
      ('_INT16Quant_INT32Bias_Set', True, True, dtypes.int32, True),
      ('_INT16Quant_INT64Bias_Set', True, True, dtypes.int64, True),
      ('_INT16Quant_FLOAT32Bias_Set', True, True, dtypes.float32, False),
  )
  @test_util.run_v2_only
  def testBiasQuantization(self, is_int16_quantize, explicitly_set_bias,
                           bias_type, is_valid_bias_type):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            1024, input_shape=[1024], activation=None, bias_initializer='ones')
    ])
    saved_model_dir = os.path.join(self.get_temp_dir(), 'dense_saved_model')
    save(model, saved_model_dir)
    k_dense_bias_name = 'sequential/dense/BiasAdd/ReadVariableOp'
    quantized_converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]

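    # int8 full-integer quantization defaults to an int32 bias and 16x8
    # quantization to an int64 bias; other combinations must be requested
    # explicitly, and invalid ones are rejected at convert() time.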
    if explicitly_set_bias:
      quantized_converter._experimental_full_integer_quantization_bias_type = (
          bias_type)

    if is_int16_quantize:
      quantized_converter.target_spec.supported_ops = [
          lite.OpsSet
          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
      ]
    else:
      quantized_converter.target_spec.supported_ops = [
          lite.OpsSet.TFLITE_BUILTINS_INT8
      ]

    def calibration_gen():
      for _ in range(5):
        yield [np.random.randn(1, 1024).astype(np.float32)]

    quantized_converter.representative_dataset = calibration_gen

    if not is_valid_bias_type:
      with self.assertRaisesRegex(ValueError, 'Expected bias type to be'):
        quantized_converter.convert()
      return

    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    dense_bias = next(d for d in interpreter.get_tensor_details()
                      if d['name'].startswith(k_dense_bias_name))
    self.assertEqual(bias_type, dense_bias['dtype'])

  @parameterized.named_parameters(
      ('_Int8PerChannelMlirDynamicRangeQuant', True, False, False),
      ('_Int8PerChannelTocoDynamicRangeQuant', False, False, False),
      ('_Int8PerTensorMlirDynamicRangeQuant', True, True, False),
      ('_Int8PerTensorTocoDynamicRangeQuant', False, True, False),
      ('_Float16DynamicRangeQuant', True, False, True))
  @test_util.run_v2_only
  def testMlirDynamicRangeQuantization(self, enable_new_dynamic_range_quantizer,
                                       disable_per_channel,
                                       enable_float16_quant):
    num_filters = 1024
    conv_name = 'sequential/conv2d/Conv2D'
    model = tf.keras.models.Sequential(
        [tf.keras.layers.Conv2D(num_filters, (3, 3), activation='relu')])
    model.build(input_shape=(1, 32, 32, 3))
    saved_model_dir = self.create_tempdir()
    save(model, saved_model_dir.full_path)

    converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir.full_path)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.experimental_new_dynamic_range_quantizer = (
        enable_new_dynamic_range_quantizer)
    converter._experimental_disable_per_channel = disable_per_channel
    if enable_float16_quant:
      converter.target_spec.supported_types = [tf.float16]
    quantized_tflite_model = converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    quantized_weight = None
    quantized_weight_with_one_postfix = None
    quantized_weight_without_one_postfix = None
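    # Depending on the converter version, the quantized weight tensor may be
    # exported under the op name itself or with a '1' suffix, so search for
    # both and prefer the suffixed name.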
    for d in interpreter.get_tensor_details():
      if d['name'] == conv_name + '1':
        quantized_weight = d
        quantized_weight_with_one_postfix = d
        break
    for d in interpreter.get_tensor_details():
      if d['name'].startswith(conv_name):
        if quantized_weight is None:
          quantized_weight = d
        quantized_weight_without_one_postfix = d
        break

    self.assertIsNotNone(quantized_weight)
    quant_params = quantized_weight['quantization_parameters']

    if enable_float16_quant:
      expected_num_params = 0
    else:
      expected_num_params = 1 if disable_per_channel else num_filters
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    if enable_float16_quant:
      self.assertTrue(
          (quantized_weight_with_one_postfix is not None and
           np.float16 == quantized_weight_with_one_postfix['dtype']) or
          (quantized_weight_without_one_postfix is not None and
           np.float16 == quantized_weight_without_one_postfix['dtype']))
    else:
      self.assertEqual(np.int8, quantized_weight['dtype'])


class FromKerasModelTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testSequentialModel(self):
    """Test a simple sequential tf.Keras model."""
    input_data = tf.constant(1., shape=[1, 1])

    # Create a simple Keras model.
    x = np.array([[1.], [2.]])
    y = np.array([[2.], [4.]])

    model = tf.keras.models.Sequential([
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(units=1, input_shape=[1])
    ])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.environment.modelType,
                     metadata_fb.ModelType.KERAS_MODEL)

    # Check values from converted model.
    expected_value = model.predict(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value, actual_value)

  @test_util.run_v2_only
  def testSequentialMultiInputOutputModel(self):
    """Test a tf.Keras model with multiple inputs and outputs."""
    left_input_data = tf.constant(1., shape=[1, 3])
    right_input_data = tf.constant(1., shape=[1, 3])

    # Create a simple Keras model.
    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    output_c_np = np.random.random((10, 3))
    output_d_np = np.random.random((10, 2))

    input_a = tf.keras.layers.Input(shape=(3,), name='input_a')
    input_b = tf.keras.layers.Input(shape=(3,), name='input_b')

    dense = tf.keras.layers.Dense(8, name='dense_1')
    interm_a = dense(input_a)
    interm_b = dense(input_b)
    merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')

    output_c = tf.keras.layers.Dense(
        3, activation='softmax', name='dense_2')(merged)
    output_d = tf.keras.layers.Dense(
        2, activation='softmax', name='dense_3')(merged)

    model = tf.keras.models.Model(
        inputs=[input_a, input_b], outputs=[output_c, output_d])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit([input_a_np, input_b_np], [output_c_np, output_d_np], epochs=1)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()

    # Check values from converted model.
    input_data = [left_input_data, right_input_data]
    expected_value = model.predict(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, input_data)
    for tf_result, tflite_result in zip(expected_value, actual_value):
      self.assertAllClose(tf_result, tflite_result, atol=1e-05)

  @test_util.run_v2_only
  def testGraphDebugInfo(self):
    """Test that a tf.Keras model has debug info captured."""
    # Create a simple Keras model.
    x = [-1, 0, 1, 2, 3, 4]
    y = [-3, -1, 1, 3, 5, 7]
    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)

  @test_util.run_v2_only
  def testKerasFallbackPath(self):
    """Test a Keras model that fails when exported to a SavedModel."""
    input_data = tf.constant(
        np.array(np.random.random_sample((20)), dtype=np.float32))

    class Model(tf.keras.Model):

      def __init__(self):
        super(Model, self).__init__()
        # A None name will cause a failure in exporting to a SavedModel.
        self.shared_weights = self.add_weight(
            name=None,
            shape=(20, 1),
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(
                mean=0.0, stddev=300**(-0.5)))

      def call(self, x):
        return tf.add(self.shared_weights, x)

    # Build the model.
    model = Model()
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(input_data, input_data, epochs=1)

    # Convert the model; the converter should fall back to the concrete
    # function path instead of exporting to a SavedModel.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  @test_util.run_v2_only
  def testSignatureDefs(self):
    """Test that SignatureDef conversion works via the SignatureDef APIs."""
    keras_model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            32,
            kernel_size=3,
            padding='same',
            activation='relu',
            input_shape=(32, 32, 3),
            name='tensor'),
        tf.keras.layers.Dense(10, name='output_tensor')
    ])

    converter = lite.TFLiteConverterV2.from_keras_model(keras_model)
    tflite_model = converter.convert()

    # Check values from converted model.
    input_data = tf.constant(
        np.random.uniform(-1, 1, size=(1, 32, 32, 3)).astype(np.float32))
    expected_value = keras_model(input_data)
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
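    # get_signature_list() maps each signature key to the names of its input
    # and output tensors, e.g.
    # {'serving_default': {'inputs': [...], 'outputs': [...]}}.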
    results = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'serving_default', {'tensor_input': input_data})
    self.assertEqual(list(results.keys()), ['output_tensor'])
    self.assertAllClose(expected_value.numpy(), results['output_tensor'])

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['serving_default'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['serving_default'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['serving_default']['inputs'],
                          ['tensor_input'])
    self.assertEqual(
        list(signature_defs['serving_default']['outputs']), ['output_tensor'])

  @parameterized.named_parameters(
      ('_PerChannelMlirDynamicRangeQuant', True, False, False),
      ('_PerChannelTocoDynamicRangeQuant', False, False, False),
      ('_PerTensorMlirDynamicRangeQuant', True, True, False),
      ('_PerTensorTocoDynamicRangeQuant', False, True, False),
      ('_Float16DynamicRangeQuant', True, False, True))
  @test_util.run_v2_only
  def testMlirDynamicRangeQuantization(self, enable_new_dynamic_range_quantizer,
                                       disable_per_channel,
                                       enable_float16_quant):
    num_filters = 1024
    conv_name = 'sequential/conv2d/Conv2D'
    model = tf.keras.models.Sequential(
        [tf.keras.Input(shape=(32, 32, 3)),
         tf.keras.layers.Conv2D(num_filters, (3, 3), activation='relu')])
    model.build()

    converter = lite.TFLiteConverterV2.from_keras_model(model)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.experimental_new_dynamic_range_quantizer = (
        enable_new_dynamic_range_quantizer)
    converter._experimental_disable_per_channel = disable_per_channel
    if enable_float16_quant:
      converter.target_spec.supported_types = [tf.float16]
    quantized_tflite_model = converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    quantized_weight = None
    quantized_weight_with_one_postfix = None
    quantized_weight_without_one_postfix = None
    for d in interpreter.get_tensor_details():
      if d['name'] == conv_name + '1':
        quantized_weight = d
        quantized_weight_with_one_postfix = d
        break
    for d in interpreter.get_tensor_details():
      if d['name'].startswith(conv_name):
        if quantized_weight is None:
          quantized_weight = d
        quantized_weight_without_one_postfix = d
        break

    self.assertIsNotNone(quantized_weight)
    quant_params = quantized_weight['quantization_parameters']

    if enable_float16_quant:
      expected_num_params = 0
    else:
      expected_num_params = 1 if disable_per_channel else num_filters
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    if enable_float16_quant:
      self.assertTrue(
          (quantized_weight_with_one_postfix is not None and
           np.float16 == quantized_weight_with_one_postfix['dtype']) or
          (quantized_weight_without_one_postfix is not None and
           np.float16 == quantized_weight_without_one_postfix['dtype']))
    else:
      self.assertEqual(np.int8, quantized_weight['dtype'])

  @parameterized.named_parameters([
      ('{}BitWeightOnly={}LowBit={}'.format(num_bits, weight_only, low_bit),
       num_bits, weight_only, low_bit) for num_bits, weight_only, low_bit
      in itertools.product((2, 4, 6), (True, False), (True, False))])
  @test_util.run_v2_only
  def testQATLowBitKerasModel(self, num_bits, weight_only, low_bit):
    bit_max = (1 << (num_bits - 1)) - 1
    bit_min = -bit_max
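    # Symmetric low-bit range: e.g. for num_bits = 4 this yields [-7, 7],
    # leaving the most negative representable value unused.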
    tf_input_shape = (5, 5, 3)
    tflite_input_shape = (1,) + tf_input_shape
    model, input_name, output_name = (self._createV2QATLowBitKerasModel(
        tf_input_shape, weight_only, num_bits, bit_min, bit_max))
    input_data = np.linspace(
        0, 6, np.prod(tflite_input_shape)).reshape(tflite_input_shape)
    tf_result = model(input_data)

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if low_bit:
      converter._experimental_low_bit_qat = True
    tflite_model = converter.convert()
    result = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'serving_default',
        {input_name: input_data.astype(np.float32)})[output_name]
    self.assertAllClose(
        [np.linalg.norm(result - tf_result.numpy().astype(np.float32))], [0.0])
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    num_8bit_activations = 0
    num_8bit_weights = 0
    kernel_name = ('model/conv_wrapper/Conv2D;model/conv_wrapper/'
                   'FakeQuantWithMinMaxVarsPerChannel')

    for detail in interpreter.get_tensor_details():
      if (detail['dtype'] == np.int8 and detail['name'] and
          detail['name'] == kernel_name):
        num_8bit_weights += 1
        weights = interpreter.get_tensor(detail['index'])
        if low_bit:
          self.assertFalse((bit_min > weights).any() or
                           (weights > bit_max).any())
        else:
          self.assertTrue((bit_min > weights).any() or
                          (weights > bit_max).any())
        self.assertIn('scales', detail['quantization_parameters'])
        if low_bit and detail['quantization_parameters']['scales']:
          self.assertAllClose(
              detail['quantization_parameters']['scales'], [1.0])
      elif detail['dtype'] == np.int8 and detail['name']:
        self.assertFalse(weight_only)
        self.assertIn('scales', detail['quantization_parameters'])
        if detail['quantization_parameters']['scales']:
          self.assertAllClose(
              detail['quantization_parameters']['scales'], [6/255])
        num_8bit_activations += 1

    self.assertEqual(num_8bit_weights, 0 if weight_only and not low_bit else 1)
    # 3 activations with full integer: conv_input, conv_output, reshape_output.
    self.assertEqual(num_8bit_activations, 0 if weight_only else 3)


class FromJaxModelTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testInvalidInputsModel(self):
    if DISABLE_JAX_TEST:
      return

    def simple_model(input1, input2):
      return jnp.sin(input1) + jnp.cos(input2)

    input_tensor = jnp.zeros([10, 10])
    # Invalid case: no serving function is specified.
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        None, [{
            'input1': input_tensor
        }])
    with self.assertRaisesRegex(ValueError, 'No serving func is specified.'):
      converter.convert()

    # Invalid case: no input tensors are specified.
    converter = lite.TFLiteConverterV2.experimental_from_jax([simple_model],
                                                             None)
    with self.assertRaisesRegex(ValueError, 'Input tensors are not specified.'):
      converter.convert()

    converter = lite.TFLiteConverterV2.experimental_from_jax([simple_model], [])
    with self.assertRaisesRegex(ValueError, 'Input tensors are not specified.'):
      converter.convert()

    # Invalid case: input_tensor is not wrapped in a list.
    converter = lite.TFLiteConverterV2.experimental_from_jax([simple_model],
                                                             input_tensor)
    with self.assertRaisesRegex(
        ValueError,
        'The truth value of an array with more than one element is ambiguous.'):
      converter.convert()

    # Invalid case: only partial inputs are provided.
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        [simple_model], [[('input1', input_tensor)]])
    with self.assertRaisesRegex(
        ValueError, 'Failed to convert the given Jax function to hlo.'):
      converter.convert()

    # Invalid case: the number of serving functions does not match the number
    # of input mappings.
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        [simple_model, simple_model], [[
            ('input1', input_tensor),
            ('input2', input_tensor),
        ]])
    with self.assertRaisesRegex(
        ValueError,
        'Input tensor mapping len 1 does not match serving func len 2.'):
      converter.convert()

    # Invalid case: multiple serving functions are provided.
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        [simple_model, simple_model], [[
            ('input1', input_tensor),
            ('input2', input_tensor),
        ], [
            ('input1', input_tensor),
            ('input2', input_tensor),
        ]])
    with self.assertRaisesRegex(
        ValueError, 'Currently only support single serving function.'):
      converter.convert()

  @test_util.run_v2_only
  def testSingleInputModel(self):
    if DISABLE_JAX_TEST:
      return

    def single_input(input_tensor):
      return jnp.sin(input_tensor)

    # Convert model.
    input_tensor = jnp.zeros([10, 10])
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        [single_input], [[('input_tensor', input_tensor)]])
    tflite_model = converter.convert()
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.environment.modelType, metadata_fb.ModelType.JAX)

    # Check values from converted model.
    input_data = np.random.random_sample((10, 10))
    tf_input_data = tf.constant(input_data, dtype=np.float32)
    actual_value = self._evaluateTFLiteModel(tflite_model, [tf_input_data])[0]
    expected_value = single_input(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @test_util.run_v2_only
  def testMultipleInputsModel(self):
    if DISABLE_JAX_TEST:
      return

    def multiple_inputs(input1, input2):
      return input1 + input2

    # Convert model.
    input1 = jnp.zeros([10, 10])
    input2 = jnp.zeros([10, 1])
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        [multiple_inputs], [[('input1', input1), ('input2', input2)]])
    tflite_model = converter.convert()

    # Check values from converted model.
    input1_data = np.random.random_sample((10, 10))
    tf_input1_data = tf.constant(input1_data, dtype=np.float32)
    input2_data = np.random.random_sample((10, 1))
    tf_input2_data = tf.constant(input2_data, dtype=np.float32)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [tf_input1_data, tf_input2_data])[0]
    expected_value = multiple_inputs(input1_data, input2_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @test_util.run_v2_only
  def testInputSignaturesModel(self):
    if DISABLE_JAX_TEST:
      return

    def multiple_inputs(input1, input2):
      return input1 + input2

    # Convert model.
    input1 = jnp.zeros([10, 10])
    input2 = jnp.zeros([10, 1])
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        [multiple_inputs], [[('input1', input1), ('input2', input2)]])
    tflite_model = converter.convert()

    # Check values from converted model.
    input1_data = np.random.random_sample((10, 10))
    tf_input1_data = tf.constant(input1_data, dtype=np.float32)
    input2_data = np.random.random_sample((10, 1))
    tf_input2_data = tf.constant(input2_data, dtype=np.float32)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [tf_input1_data, tf_input2_data])[0]
    expected_value = multiple_inputs(input1_data, input2_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @test_util.run_v2_only
  def testModelWithParams(self):
    if DISABLE_JAX_TEST:
      return

    def model(inputs, weights):
      return jnp.matmul(weights, inputs)

    weights = np.random.random_sample((10, 10))
    serving_func = functools.partial(model, weights=weights)

    # Convert model.
    input_tensor = jnp.zeros([10, 10])
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        [serving_func], [[('inputs', input_tensor)]])
    tflite_model = converter.convert()

    # Check values from converted model.
    input_data = np.random.random_sample((10, 10))
    tf_input_data = tf.constant(input_data, dtype=np.float32)
    actual_value = self._evaluateTFLiteModel(tflite_model, [tf_input_data])[0]
    expected_value = serving_func(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @test_util.run_v2_only
  def testWhileLoop(self):
    if DISABLE_JAX_TEST:
      return

    def condition(x):
      return jnp.sum(x, keepdims=False) < 100

    def body(x):
      return jnp.add(x, 2.0)

    def model(x):
      result = jax.lax.while_loop(condition, body, x)
      return result[0]

    # Convert model.
    input_tensor = jnp.zeros([3, 3])
    converter = lite.TFLiteConverterV2.experimental_from_jax(
        [model], [[('x', input_tensor)]])
    tflite_model = converter.convert()

    # Check values from converted model.
    input_data = np.random.random_sample((3, 3))
    tf_input_data = tf.constant(input_data, dtype=np.float32)
    actual_value = self._evaluateTFLiteModel(tflite_model, [tf_input_data])[0]
    expected_value = model(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)


class ControlFlowTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testCond(self):
    input_data = {
        'x': tf.constant([1., 2.], shape=[1, 2]),
        'b': tf.constant(True)
    }

    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)

    def true_fn(x):
      return tf.matmul(x, weights)

    def false_fn(x):
      return tf.add(x, weights)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.bool)
    ])
    def model(x, b):
      return tf.cond(
          b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))

    concrete_func = model.get_concrete_function()

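    # With control flow v2, tf.cond is expected to lower to TFLite's builtin
    # IF op, with each branch captured as its own subgraph.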
    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(**input_data)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data['x'], input_data['b']])[0]
    self.assertAllClose(expected_value, actual_value)

  @test_util.run_v2_only
  def testCondWithFullIntegerQuantization(self):
    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)

    def true_fn(x):
      return tf.matmul(x, weights)

    def false_fn(x):
      return tf.add(x, weights)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.bool)
    ])
    def model(x, b):
      return tf.cond(
          b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))

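    # The calibration data deliberately exercises both branches of the cond
    # so that activations in each branch subgraph get quantization statistics.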
    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(1, 2)).astype(np.float32),
            tf.constant(True)
        ]
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(1, 2)).astype(np.float32),
            tf.constant(False)
        ]

    concrete_func = model.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

  @test_util.run_v2_only
  def testConverterErrorOnControlFlowV1Ops(self):
    filename = resource_loader.get_path_to_datafile(
        'testdata/control_flow_v1_saved_model')
    converter = lite.TFLiteConverterV2.from_saved_model(filename)
    with self.assertRaises(convert.ConverterError) as error:
      converter.convert()
    self.assertIn(
        'Failed to functionalize Control Flow V1 ops. Consider using Control '
        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))

  @test_util.run_v2_only
  def testStaticRnn(self):
    input_data = tf.constant(
        np.array(np.random.random_sample((3, 10)), dtype=np.float32))

    cell = tf.keras.layers.LSTMCell(10)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[3, 10], dtype=tf.float32)])
    def model(x):
      seq = tf.split(x, 3, 0)
      return rnn.static_rnn(cell, seq, dtype=tf.float32, sequence_length=[1])

    concrete_func = model.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)[0]
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    for expected, actual in zip(expected_value, actual_value):
      self.assertAllClose(expected, actual)

  @test_util.run_v2_only
  def testWhileLoop(self):
    input_data = tf.constant([1., 2., 3., 4.], shape=[2, 2])

    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)

    def condition(x):
      return tf.reduce_sum(x) < 100

    def body(x):
      return tf.add(x, weights)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[2, 2], dtype=tf.float32)])
    def model(x):
      return tf.while_loop(condition, body, [x])

    concrete_func = model.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)[0]
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
    self.assertAllClose(expected_value, actual_value)

  @test_util.run_v2_only
  def testDynamicRnn(self):
    input_data = tf.constant(
        np.array(np.random.random_sample((3, 10, 10)), dtype=np.float32))

    cell = tf.keras.layers.LSTMCell(10)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[3, 10, 10], dtype=tf.float32)])
    def model(x):
      rnn_layer = tf.keras.layers.RNN([cell], return_sequences=True)
      return rnn_layer(x)

    concrete_func = model.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)
    lite_outputs = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertLen(lite_outputs, 1)
    actual_value = lite_outputs[0]
    for expected, actual in zip(expected_value, actual_value):
      self.assertAllClose(expected, actual)

  @parameterized.named_parameters(
      ('LSTMBatchSizeOne', tf.keras.layers.LSTM, True),
      ('LSTM', tf.keras.layers.LSTM, False),
      ('SimpleRNNBatchSizeOne', tf.keras.layers.SimpleRNN, True),
      ('SimpleRNN', tf.keras.layers.SimpleRNN, False),
      ('GRUBatchSizeOne', tf.keras.layers.GRU, True),
      ('GRU', tf.keras.layers.GRU, False))
  @test_util.run_v2_only
  def testKerasRNN(self, rnn_layer, default_to_single_batch):
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    rnn_obj = rnn_layer(units=10, input_shape=(10, 10))
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(10, 10), name='input'),
        rnn_obj,
    ])

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
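    # Keras RNNs lower through TensorList ops with a dynamic batch dimension;
    # forcing a single batch lets them convert to builtin ops, while the
    # unforced variants fall back to SELECT_TF_OPS (Flex) below.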
    converter._experimental_default_to_single_batch_in_tensor_list_ops = (
        default_to_single_batch)
    if not default_to_single_batch:
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
      ]
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    # Check values from converted model.
    expected_value = model.predict(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @parameterized.named_parameters(('LSTM', tf.keras.layers.LSTM),
                                  ('SimpleRNN', tf.keras.layers.SimpleRNN),
                                  ('GRU', tf.keras.layers.GRU))
  @test_util.run_v2_only
  def testKerasRNNMultiBatches(self, rnn_layer):
    input_data = tf.constant(
        np.array(np.random.random_sample((4, 10, 10)), dtype=np.float32))
    # Specify a fixed batch size (4) for the test model.
    x = tf.keras.layers.Input(batch_shape=(4, 10, 10))
    y = rnn_layer(units=10, input_shape=(10, 10))(x)
    model = tf.keras.Model(inputs=[x], outputs=[y])

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    # Check values from converted model.
    expected_value = model.predict(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @parameterized.named_parameters(('ForceToUseBatchSizeOne', True),
                                  ('DontForceToUseBatchSizeOne', False))
  @test_util.run_v2_only
  def testKerasBidirectionalRNNReturnSequence(self, default_to_single_batch):
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Input(shape=(10, 10), name='input'))
    model.add(
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(units=10, return_sequences=True),
            input_shape=(10, 10)))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(5))
    model.add(tf.keras.layers.Activation('softmax'))

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    converter._experimental_default_to_single_batch_in_tensor_list_ops = (
        default_to_single_batch)
    if not default_to_single_batch:
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
      ]
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    # Check values from converted model.
    expected_value = model.predict(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @parameterized.named_parameters(('ForceToUseBatchSizeOne', True),
                                  ('DontForceToUseBatchSizeOne', False))
  @test_util.run_v2_only
  def testKerasBidirectionalRNN(self, default_to_single_batch):
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Input(shape=(10, 10), name='input'))
    model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=10)))
    model.add(tf.keras.layers.Dense(5))
    model.add(tf.keras.layers.Activation('softmax'))

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    converter._experimental_default_to_single_batch_in_tensor_list_ops = (
        default_to_single_batch)
    if not default_to_single_batch:
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
      ]
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    # Check values from converted model.
    expected_value = model.predict(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)


class GrapplerTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testConstantFolding(self):
    # Constant folding handles the tf.broadcast_to operation, which was not
    # supported by TFLite at the time this test was added.
    input_data = tf.constant(
        [1., 2., 3., 4., 5., 6., 7., 8., 9.], shape=[3, 3])

    @tf.function
    def func(x):
      y_const = tf.constant([1., 2., 3.])
      y_broadcast = tf.broadcast_to(y_const, [3, 3])
      return tf.matmul(x, y_broadcast)

    root = autotrackable.AutoTrackable()
    root.f = func
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
    self.assertAllClose(expected_value, actual_value)

    # Enable hybrid quantization; the result should stay the same.
    converter.optimizations = [lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
    self.assertAllClose(expected_value, actual_value)


class UnknownShapes(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testMatMul(self):
    input_data = tf.constant(
        np.array(np.random.random_sample((10, 4)), dtype=np.float32))

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
    def model(in_tensor):
      shape = tf.shape(in_tensor)
      fill = tf.transpose(tf.fill(shape, 1.))
      return tf.matmul(fill, in_tensor)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)
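    # Each input_shapes pair passed to the test helper gives a tensor's
    # dynamic shape signature and the concrete shape to resize it to before
    # invoking the interpreter, here [-1, 4] -> [10, 4].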
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])[0]
    self.assertAllClose(expected_value, actual_value, atol=1e-06)

  def _getIntegerQuantizeModelWithUnknownShapes(self):
    np.random.seed(0)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 33], dtype=tf.float32)])
    def model(input_tensor):
      """Define a model with tf.MatMul and unknown shapes."""
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
      const_tensor = tf.constant(
          np.random.uniform(low=-10., high=10., size=[33, 33]),
          shape=[33, 33],
          dtype=tf.float32,
          name='inputB')

      shape = tf.shape(input_tensor)
      fill = tf.transpose(tf.fill(shape, 1.))
      mult = tf.matmul(fill, input_tensor)
      return tf.matmul(mult, const_tensor)

    root = autotrackable.AutoTrackable()
    root.f = model
    concrete_func = root.f.get_concrete_function()

    def calibration_gen():
      for batch in range(5, 20, 5):
        for _ in range(5):
          yield [np.random.uniform(-1, 1, size=(batch, 33)).astype(np.float32)]

    return root, concrete_func, calibration_gen

  @test_util.run_v2_only
  def testMatMulQuantize(self):
    root, concrete_func, _ = self._getIntegerQuantizeModelWithUnknownShapes()
    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    float_tflite_model = float_converter.convert()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_tflite_model = quantized_converter.convert()

    # The default input and output types should be float.
    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
    quantized_interpreter.allocate_tensors()
    input_details = quantized_interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])

    # Ensure that the quantized-weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testMatMulCalibrateAndQuantize(self):
    root, concrete_func, calibration_gen = (
        self._getIntegerQuantizeModelWithUnknownShapes())
    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    float_tflite_model = float_converter.convert()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()

    # The default input and output types should be float.
    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
    quantized_interpreter.allocate_tensors()
    input_details = quantized_interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])

    # Ensure that the quantized-weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def testBatchMatMul(self):
    input_data_1 = tf.constant(
        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
    input_data_2 = tf.constant(
        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32),
        tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32)
    ])
    def model(in_tensor_1, in_tensor_2):
      return tf.matmul(in_tensor_1, in_tensor_2)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data_1, input_data_2)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data_1, input_data_2],
        input_shapes=[([-1, 256, 256], [1, 256, 256])])[0]
    self.assertAllClose(expected_value, actual_value, atol=4)

  def testBatchMatMulInputInt8Int8OutputInt32(self):
    input_data_1 = tf.constant(
        np.array(
            np.random.randint(-128, high=128, size=(1, 20, 30)),
            dtype=np.int8))
    input_data_2 = tf.constant(
        np.array(
            np.random.randint(-128, high=128, size=(1, 30, 10)),
            dtype=np.int8))

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 20, 30], dtype=tf.int8),
        tf.TensorSpec(shape=[None, 30, 10], dtype=tf.int8)
    ])
    def model(in_tensor_1, in_tensor_2):
      return tf.matmul(in_tensor_1, in_tensor_2, output_type=tf.int32)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data_1, input_data_2)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data_1, input_data_2],
        input_shapes=[([-1, 20, 30], [1, 20, 30]),
                      ([-1, 30, 10], [1, 30, 10])])[0]
    self.assertAllEqual(expected_value, actual_value)

  def testBatchMatMulHybrid(self):
    # Test a model that does a batch matmul of:
    # lhs input (1, 256, 128), rhs const (1, 128, 256).
    # Under dynamic range quantization this results in a hybrid batch matmul,
    # where the lhs type is float32 and the rhs type is int8.

    # The lhs and rhs sizes are chosen to satisfy two conditions:
    # 1. rhs const num_elements >= 1024, since dynamic range quantization
    # requires const tensor num_elements to be larger than
    # min_elements_for_weights (which defaults to 1024).
    # (https://github.com/tensorflow/tensorflow/blob/25e649ac3688655547da998eba2715cf70b3e5c9/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc#L262)
    # 2. batch_size (256) > accum_dim_size (128) and
    # num_units (256) > accum_dim_size (128), to test that the sizes are set
    # correctly according to the dimensions. See the
    # HybridAsymmetricBatchMatMulOpTest tests in
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/batch_matmul_test.cc.
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 256, 128)), dtype=np.float32))

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 256, 128], dtype=tf.float32)
    ])
    def model(in_tensor):
      rhs = tf.constant(
          np.array(np.random.random_sample((1, 128, 256)), dtype=np.float32))
      return tf.matmul(in_tensor, rhs)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data],
        input_shapes=[([-1, 256, 128], [1, 256, 128])])[0]
    self.assertAllClose(expected_value, actual_value, atol=4)

  def testSizeInvalid(self):

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, None, 16, 3], dtype=tf.float32)
    ])
    def model(in_tensor):
      return in_tensor + in_tensor

    concrete_func = model.get_concrete_function()

    # Test an invalid shape: None after the 1st dimension. Run with TOCO in
    # order to invoke the shape-checking code.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'None is only supported in the 1st dimension. Tensor '
        '\'in_tensor\' has invalid shape \'[1, None, 16, 3]\'.',
        str(error.exception))


class ResourceAndVariantTypes(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testVariants(self):

    @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
    def model(v):
      m = map_ops.empty_tensor_map()
      k = tf.constant(1.0)
      p = tf.add(k, v)
      with ops.control_dependencies([m]):
        m2 = map_ops.tensor_map_insert(m, p, v)
        with ops.control_dependencies([m2]):
          return map_ops.tensor_map_size(m2)

    concrete_func = model.get_concrete_function()

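    # TensorMap ops produce DT_VARIANT values, which have no TFLite builtin
    # equivalent; SELECT_TF_OPS lets them run through the Flex delegate.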
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(1, actual_value)

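    # Invoke repeatedly: the map is rebuilt from empty on every run, so its
    # size must stay 1 rather than grow across invocations.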
3694    interpreter.invoke()
3695    actual_value = interpreter.get_tensor(output_details[0]['index'])
3696    self.assertEqual(1, actual_value)
3697
3698    interpreter.invoke()
3699    actual_value = interpreter.get_tensor(output_details[0]['index'])
3700    self.assertEqual(1, actual_value)
3701
3702  @test_util.run_v2_only
3703  def testVariantsWithCond(self):

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(), 'variants_with_cond')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          m = map_ops.empty_tensor_map()

          def body(i, m):
            m = map_ops.tensor_map_insert(m, i, i)
            return i + 1, m

          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.int32, name='input')
          _, result_m = tf.cond(in_tensor < 10, lambda: body(in_tensor, m),
                                lambda: body(in_tensor + 1, m))
          out_tensor = in_tensor + map_ops.tensor_map_size(result_m)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([0], dtype=np.int32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    expected_value = np.array([1], dtype=np.int32)
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_value, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_value, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_value, actual_value)

  @test_util.run_v2_only
  def testVariantsWithWhile(self):
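    """Test variant (TensorMap) ops carried through a tf.while_loop."""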

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(), 'variants_with_while')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          m = map_ops.empty_tensor_map()

          def cond(i, m):
            del m
            return i < 10

          def body(i, m):
            m = map_ops.tensor_map_insert(m, i, i)
            return i + 1, m

          _, result_m = tf.while_loop(cond, body, [0, m])
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.int32, name='input')
          out_tensor = in_tensor + map_ops.tensor_map_size(result_m)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([0], dtype=np.int32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(10, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(10, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(10, actual_value)

  @test_util.run_v2_only
  def testResources(self):
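    """Test resource (stack) ops converted via Select TF ops."""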

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_resources')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

          stack = tf.raw_ops.StackV2(max_size=10, elem_type=tf.float32)
          w = tf.raw_ops.StackPushV2(handle=stack, elem=in_tensor)
          with ops.control_dependencies([w]):
            a = in_tensor + in_tensor
            with ops.control_dependencies([a]):
              out_tensor = a + tf.raw_ops.StackPopV2(
                  handle=stack, elem_type=tf.float32)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(3.0, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(3.0, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(3.0, actual_value)

  @test_util.run_v2_only
  def testResourcesWithCond(self):
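    """Test resource (stack) ops used inside a tf.cond branch."""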

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(), 'resources_with_cond')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

          def body(i, arr):
            n = tf.raw_ops.StackPushV2(
                handle=arr, elem=tf.cast(i, dtype=tf.float32))
            return n, arr

          arr = tf.raw_ops.StackV2(max_size=10, elem_type=tf.float32)
          n, result_arr = tf.cond(in_tensor < 10, lambda: body(0, arr),
                                  lambda: body(1, arr))

          with ops.control_dependencies([result_arr, n]):
            out_tensor = tf.raw_ops.StackPopV2(
                handle=result_arr, elem_type=tf.float32)

          inputs = {'x': in_tensor}
          outputs = {'a': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(0.0, actual_value)

  @test_util.run_v2_only
  def testResourcesWithWhile(self):
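    """Test resource (stack) ops threaded through a tf.while_loop."""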

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(),
                                     'resources_with_while')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

          def cond(i, arr, m):
            del arr
            del m
            return i < 10

          def body(i, arr, m):
            del m
            n = tf.raw_ops.StackPushV2(
                handle=arr, elem=tf.cast(i, dtype=tf.float32))
            return i + 1, arr, n

          arr = tf.raw_ops.StackV2(max_size=10, elem_type=tf.float32)
          _, result_arr, n = tf.while_loop(cond, body, [0, arr, 0.0])

          with ops.control_dependencies([result_arr, n]):
            out_tensor = tf.raw_ops.StackPopV2(
                handle=result_arr, elem_type=tf.float32)

          inputs = {'x': in_tensor}
          outputs = {'a': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(9.0, actual_value)

  @parameterized.named_parameters(('EnableLoweringTensorListOps', True),
                                  ('DisableLoweringTensorListOps', False))
  @test_util.run_v2_only
  def testTensorListWithStaticSize(self, lower_tensor_list_ops):
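    """Test a statically sized TensorArray with and without lowering."""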

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(),
                                     'simple_mutable_variable')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

          ta = tf.TensorArray(
              tf.float32, size=3, dynamic_size=False, clear_after_read=False)
          ta = ta.write(0, 10.0)
          ta = ta.write(1, 20.0)
          ta = ta.write(2, 30.0)

          out_tensor = ta.read(0) + ta.read(2)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    if not lower_tensor_list_ops:
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
      ]
    converter._experimental_lower_tensor_list_ops = lower_tensor_list_ops
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(40.0, actual_value)

  @parameterized.named_parameters(('EnableLoweringTensorListOps', True),
                                  ('DisableLoweringTensorListOps', False))
  @test_util.run_v2_only
  def testTensorListWithDynamicSize(self, lower_tensor_list_ops):
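    """Test a dynamically sized TensorArray, which cannot be lowered."""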

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(),
                                     'simple_mutable_variable')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

          ta = tf.TensorArray(
              tf.float32, size=0, dynamic_size=True, clear_after_read=False)
          ta = ta.write(0, 10.0)
          ta = ta.write(1, 20.0)
          ta = ta.write(2, 30.0)

          out_tensor = ta.read(0) + ta.read(2)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    if lower_tensor_list_ops:
      with self.assertRaises(convert.ConverterError) as error:
        converter.convert()
      self.assertIn(
          'Lowering tensor list ops is failed. Please consider using Select '
          'TF ops and disabling `_experimental_lower_tensor_list_ops` flag in '
          'the TFLite converter object.', str(error.exception))

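    # With Select TF ops enabled, conversion should succeed even though the
    # tensor list has a dynamic size.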
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(40.0, actual_value)


class CalibrateAndQuantizeWithCustomOpTest(lite_v2_test_util.ModelTest):

  def _createGraphWithCustomOp(self):
    # Create a graph that has one double op.
    np.random.seed(0)

    saved_model_dir = os.path.join(self.get_temp_dir(), 'double_model')
    with ops.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor = tf.compat.v1.placeholder(
            shape=[1, 4], dtype=dtypes.float32, name='input')
        out_tensor = double_op.double(in_tensor)
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    def calibration_gen():
      for _ in range(100):
        yield [np.random.uniform(-1, 1, size=(1, 4)).astype(np.float32)]

    return (saved_model_dir, calibration_gen)

  def testCustomOpRegistererByName(self):
    """Test a calibration with custom op registered by name."""
    saved_model_dir, calibration_gen = self._createGraphWithCustomOp()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.allow_custom_ops = True
    converter.target_spec._experimental_custom_op_registerers = [
        'TF_TestRegisterer'
    ]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
    self.assertGreater(test_registerer.get_num_test_registerer_calls(), 0)
    self.assertIn('Double', tflite_test_util.get_ops_list(tflite_model))
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.options.allowCustomOps, True)

    # Check the model works with custom ops.
    interpreter = InterpreterWithCustomOps(
        model_content=tflite_model, custom_op_registerers=['TF_TestRegisterer'])
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    test_input = np.array([[0.0, 0.1, 0.2, 0.3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    expected_output = np.array([[0.0, 0.2, 0.4, 0.6]], dtype=np.float32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertArrayNear(expected_output[0], output_data[0], err=1e-2)

  def testCustomOpRegistererByFunc(self):
    """Test a calibration with custom op registered by function."""
    saved_model_dir, calibration_gen = self._createGraphWithCustomOp()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.allow_custom_ops = True
    converter.target_spec._experimental_custom_op_registerers = [
        test_registerer.TF_TestRegisterer
    ]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
    self.assertGreater(test_registerer.get_num_test_registerer_calls(), 0)
    self.assertIn('Double', tflite_test_util.get_ops_list(tflite_model))

    # Check the model works with custom ops.
    interpreter = InterpreterWithCustomOps(
        model_content=tflite_model,
        custom_op_registerers=[test_registerer.TF_TestRegisterer])
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    test_input = np.array([[0.0, 0.1, 0.2, 0.3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    expected_output = np.array([[0.0, 0.2, 0.4, 0.6]], dtype=np.float32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertArrayNear(expected_output[0], output_data[0], err=1e-2)

  def testCustomOpRegistererFailure(self):
    """Test a calibration with wrong custom op registerer."""
    saved_model_dir, calibration_gen = self._createGraphWithCustomOp()

    bogus_name = 'CompletelyBogusRegistererName'

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.allow_custom_ops = True
    converter.target_spec._experimental_custom_op_registerers = [bogus_name]

    with self.assertRaisesRegex(
        ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'):
      converter.convert()


class IntermediatesTest(lite_v2_test_util.ModelTest):

  def _run(self, experimental_preserve_all_tensors):
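    """Converts and runs a three-add model; returns ({name: tensor}, out)."""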

    @tf.function
    def f(x):
      y = tf.add(x, x, name='y')
      z = tf.add(y, y, name='z')
      w = tf.add(z, z, name='w')
      return w

    # NOTE: the input below is exactly representable as a float, as are the
    # intermediates of f, so direct equality comparison is OK.

    input_data = np.array(2.0, np.float32)
    concrete_func = f.get_concrete_function(input_data)
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               f)
    tflite_model = converter.convert()
    interpreter = Interpreter(
        model_content=tflite_model,
        experimental_preserve_all_tensors=experimental_preserve_all_tensors)
    interpreter.allocate_tensors()
    interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
                           input_data)
    interpreter.invoke()
    out = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
    tensors = {}
    for t in interpreter.get_tensor_details():
      # With the TensorFlow Lite default delegate applied to the model graph,
      # accessing the original tensors of a delegated op can raise a ValueError
      # ('Tensor data is null. Run allocate_tensors() first') because the
      # tensor memory is never allocated.
      val = None
      try:
        val = interpreter.get_tensor(t['index'])
      except ValueError:
        pass
      tensors.update({t['name']: val})
    return (tensors, out)

  def testPreserve(self):
    tensors, result = self._run(experimental_preserve_all_tensors=True)
    # All intermediates should be preserved, and the result should be correct.
    self.assertAllClose(tensors['x'], 2.0)
    self.assertAllClose(tensors['y'], 4.0)
    self.assertAllClose(tensors['z'], 8.0)
    self.assertAllClose(result, 16.0)

  def testNoPreserve(self):
    tensors, result = self._run(experimental_preserve_all_tensors=False)
    # Without preservation, at least one intermediate should be stale, but the
    # result should still be correct. The input must remain intact so that
    # repeated invocations work.
    self.assertAllClose(tensors['x'], 2.0)
    self.assertTrue(tensors['y'] != 4.0 or tensors['z'] != 8.0)
    self.assertAllClose(result, 16.0)


class DatasetOpsTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testReduceDataset(self):

    @tf.function
    def model():
      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
      output = dataset.reduce(np.int32(0), lambda x, y: x + y)
      return output

    concrete_func = model.get_concrete_function()
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(10, actual_value)


class SparsityTest(lite_v2_test_util.ModelTest):

  def _getSparsificableModel(self, matrix_b_values):
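    """Returns (root, concrete function) for a matmul + relu model.

    The 4x8 constant matrix is built from matrix_b_values, so callers control
    its sparsity pattern.
    """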
    np.random.seed(0)
    root = autotrackable.AutoTrackable()

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[16, 4], dtype=tf.float32)])
    def func(inp):
      matrix_b = tf.constant(matrix_b_values, dtype=tf.float32)
      matrix_b = tf.reshape(matrix_b, [4, 8])
      matmul = tf.matmul(inp, matrix_b, transpose_a=False, transpose_b=False)
      output = tf.nn.relu(matmul, name='output')
      return output

    root.f = func
    to_save = root.f.get_concrete_function()
    return (root, to_save)

  def testRandomSparsity(self):
    matrix_b_values = [
        0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1
    ]
    root, func = self._getSparsificableModel(matrix_b_values)
    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                     root)
    float_converter.optimizations = [lite.Optimize.EXPERIMENTAL_SPARSITY]
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(float_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY],
                        metadata.options.modelOptimizationModes)

  def testBlockSparsity(self):
    matrix_b_values = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0
    ]
    root, func = self._getSparsificableModel(matrix_b_values)
    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                     root)
    float_converter.optimizations = [lite.Optimize.EXPERIMENTAL_SPARSITY]
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(float_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY],
                        metadata.options.modelOptimizationModes)

  def testQuantizedBlockSparsity(self):
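    """Test that full-integer PTQ and block sparsity are applied together."""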
    weight_values = np.array([
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 0, 0, 0, 5, 0, 0, 0, 3, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 0, 7, 0, 0, 0, -6, -2, 0, 0, 0, 0, 0, -2, 0, 6],
    ])

    custom_init = tf.constant_initializer(weight_values.transpose())
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            8, kernel_initializer=custom_init, input_shape=[16])
    ])

    def calibration_gen():
      for _ in range(10):
        yield [np.random.uniform(-1, 1, size=(1, 16)).astype(np.float32) * 16]

    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
    quantized_converter.optimizations = [
        lite.Optimize.EXPERIMENTAL_SPARSITY, lite.Optimize.DEFAULT
    ]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(
        metadata.environment.tensorflowVersion.decode('utf-8'),
        versions.__version__)
    self.assertEqual(metadata.environment.apiVersion, 2)
    self.assertAllEqual([
        metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER,
        metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY,
    ], metadata.options.modelOptimizationModes)

    # Check values from converted model.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.allocate_tensors()
    input_data = np.array(
        [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],
        dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertArrayNear(
        np.array([0, 87, 0, 0, 0, 0, 0, 34], dtype=np.float32),
        actual_value.flatten(),
        err=1)

  def testQuantizedButNotEnoughBlockSparsity(self):
    # Sparsity level is 25%, which is not enough to apply the sparse
    # conversion.
    weight_values = np.array(
        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [4, 4, -3, 4, 4, 1, -2, -2, 1, 3, 4, 1, 1, 1, -4, -5],
         [1, 1, 5, -1, 3, -1, 1, -3, 4, -3, 2, -3, 3, -1, 3, -4],
         [0, -3, -2, 5, 4, 2, 1, 4, -4, 4, 1, -2, 3, -2, -2, -1]])

    custom_init = tf.constant_initializer(weight_values.transpose())
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            4, kernel_initializer=custom_init, input_shape=[16])
    ])

    def calibration_gen():
      for _ in range(10):
        yield [np.random.uniform(-1, 1, size=(1, 16)).astype(np.float32) * 16]

    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
    quantized_converter.optimizations = [
        lite.Optimize.EXPERIMENTAL_SPARSITY, lite.Optimize.DEFAULT
    ]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(
        metadata.environment.tensorflowVersion.decode('utf-8'),
        versions.__version__)
    self.assertEqual(metadata.environment.apiVersion, 2)
    self.assertAllEqual([
        metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER,
    ], metadata.options.modelOptimizationModes)
    self.assertNotIn(metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY,
                     metadata.options.modelOptimizationModes)
    self.assertNotIn(metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY,
                     metadata.options.modelOptimizationModes)

    # Check values from converted model.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.allocate_tensors()
    input_data = np.array(
        [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],
        dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertArrayNear(
        np.array([0, -3, 4, 35], dtype=np.float32),
        actual_value.flatten(),
        err=1)


if __name__ == '__main__':
  test.main()