# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to TensorFlow 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ctypes
import os
import sys

from absl.testing import parameterized
import numpy as np
from six.moves import range
from six.moves import zip
import tensorflow as tf

# Force loaded shared object symbols to be globally visible. This is needed so
# that the interpreter_wrapper, in one .so file, can see the test_registerer,
# in a different .so file. Note that this may already be set by default.
# pylint: disable=g-import-not-at-top
if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
  sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
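# (sys.setdlopenflags only takes effect for extension modules imported after
# this point, which is why the flag is set before the tensorflow.lite imports
# below.)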

from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_v2_test_util
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import test_util as tflite_test_util
from tensorflow.lite.python import util
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.interpreter import InterpreterWithCustomOps
from tensorflow.lite.python.interpreter import OpResolverType
from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer
from tensorflow.lite.python.testdata import double_op
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import map_ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import saved_model
from tensorflow.python.saved_model.loader_impl import parse_saved_model
from tensorflow.python.saved_model.save import save
from tensorflow.python.training.tracking import tracking
# pylint: enable=g-import-not-at-top


class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testTypeInvalid(self):
    root = self._getSimpleVariableModel()
    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_concrete_functions([root.f], root)
    self.assertIn('call get_concrete_function', str(error.exception))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  @test_util.run_v2_only
  def testFloat(self, enable_mlir_converter):
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()

    # Check output value from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @parameterized.named_parameters(
      ('_INT8InputOutput', dtypes.int8),
      ('_UINT8InputOutput', dtypes.uint8),
      ('_INT16InputOutput', dtypes.int16))
  @test_util.run_v2_only
  def testInvalidFloat(self, inference_input_output_type):
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    with self.assertRaises(ValueError) as error:
      converter.inference_input_type = inference_input_output_type
      converter.inference_output_type = inference_input_output_type
      converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        'must be tf.float32.', str(error.exception))

  @test_util.run_v2_only
  def testScalarInput(self):
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testMultiFunctionModel(self):
    """Convert a single function in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.add.get_concrete_function(input_data)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.add(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testConvertMultipleFunctions(self):
    """Convert multiple functions in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)
    sub_func = root.sub.get_concrete_function(input_data)

    # Try converting multiple functions.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [add_func, sub_func], root)
    tflite_model = converter.convert()

    # Check signatures are valid from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 2)
    self.assertEqual(list(signature_defs.keys()), ['add', 'sub'])
    self.assertEqual(len(signature_defs.values()), 2)
    self.assertEqual(list(signature_defs['add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['add']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['add']['outputs']), ['output_0'])
    self.assertEqual(list(signature_defs['sub'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['sub']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['sub']['outputs']), ['output_0'])

    # Verify the Signature runner executions.
    add_signature_runner = interpreter.get_signature_runner('add')
    add_output = add_signature_runner(x=input_data)
    self.assertEqual(add_output['output_0'], 3)

    sub_signature_runner = interpreter.get_signature_runner('sub')
    sub_output = sub_signature_runner(x=input_data)
    self.assertEqual(sub_output['output_0'], -2)

  def _getIntegerQuantizeModel(self):
    np.random.seed(0)

    root = tracking.AutoTrackable()

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[1, 5, 5, 3], dtype=tf.float32)])
    def func(inp):
      conv = tf.nn.conv2d(
          inp, tf.ones([3, 3, 3, 16]), strides=[1, 1, 1, 1], padding='SAME')
      output = tf.nn.relu(conv, name='output')
      return output

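    # The converter calls this generator during post-training calibration;
    # each yielded list supplies one numpy array per model input, and a
    # handful of samples is enough for this tiny model.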
    def calibration_gen():
      for _ in range(5):
        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

    root.f = func
    to_save = root.f.get_concrete_function()
    return (root, to_save, calibration_gen)

  @parameterized.named_parameters(
      ('EnableMlirQuantizer', True),  # enable mlir quantizer
      ('DisableMlirQuantizer', False))  # disable mlir quantizer
  def testPostTrainingCalibrateAndQuantize(self, mlir_quantizer):
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_converter.experimental_new_quantizer = mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized model is smaller than the float model.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      ('_INT8InputOutput', dtypes.int8),
      ('_UINT8InputOutput', dtypes.uint8),
      ('_INT16InputOutput', dtypes.int16))
  @test_util.run_v2_only
  def testInvalidPostTrainingDynamicRangeQuantization(
      self, inference_input_output_type):
    root, func, _ = self._getIntegerQuantizeModel()

    # Convert float model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([func], root)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_input_type = inference_input_output_type
      quantized_converter.inference_output_type = inference_input_output_type
      quantized_converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        'must be tf.float32.', str(error.exception))

  @parameterized.named_parameters(
      ('_Default', False, False, dtypes.float32),
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize', False, True, dtypes.float32),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly', True, False, dtypes.float32),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
  def testIntegerQuantization(
      self, is_int_only, is_int16_quantize, inference_input_output_type):
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([func], root)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
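    # TFLITE_BUILTINS_INT8 restricts the model to int8 builtin ops, while the
    # experimental 16x8 set quantizes activations to int16 and weights to
    # int8; listing TFLITE_BUILTINS alongside it permits float fallback for
    # ops the 16x8 mode does not cover.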
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

    # Ensure that the quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(tflite_model))

  @parameterized.named_parameters(
      ('_INT16Quantize_INT8InputOutput', True, dtypes.int8))
  def testInvalidIntegerQuantization(
      self, is_int16_quantize, inference_input_output_type):
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int16_quantize:
      quantized_converter.target_spec.supported_ops = [
          lite.OpsSet.
          EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
          lite.OpsSet.TFLITE_BUILTINS
      ]
    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_input_type = dtypes.int8
      quantized_converter.inference_output_type = dtypes.int8
      quantized_converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        "must be in ['tf.float32', 'tf.int16'].", str(error.exception))

  def testCalibrateAndQuantizeBuiltinInt16(self):
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    converter = lite.TFLiteConverterV2.from_concrete_functions([func], root)
    # TODO(b/156309549): We should add INT16 to the builtin types.
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.representative_dataset = calibration_gen
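    # Run calibration only, then drive the MLIR quantizer directly so that a
    # non-default inference type (int16 here) can be requested.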
    converter._experimental_calibrate_only = True
    calibrated_tflite = converter.convert()
    quantized_tflite_model = mlir_quantize(
        calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)

    self.assertIsNotNone(quantized_tflite_model)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized model is smaller than the float model.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testSignatureDefs(self):
    """Test that SignatureDefs are converted and exposed correctly."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2([add_func], trackable_obj=root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = add_func(input_data)
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    results = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'serving_default', {'x': input_data})
    self.assertLen(list(results.keys()), 1)
    self.assertStartsWith(list(results.keys())[0], 'output')
    self.assertAllClose(
        expected_value.numpy(),
        results[signature_defs['serving_default']['outputs'][0]])

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['serving_default'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['serving_default'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['serving_default']['inputs'], ['x'])
    self.assertLen(list(signature_defs['serving_default']['outputs']), 1)
    self.assertStartsWith(
        list(signature_defs['serving_default']['outputs'])[0], 'output')

  @test_util.run_v2_only
  def testNoSignatureDefsWhenTrackingObjIsNone(self):
    """Test that no SignatureDef is exported when trackable_obj is None."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], None)
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    # Verify that there is no SignatureDef structure found.
    self.assertEqual(len(signature_defs), 0)

  @test_util.run_v2_only
  def testNoSignatureDefsWhenInvalidTrackingObjIsGiven(self):
    """Test that no SignatureDef is exported for an unrelated trackable_obj."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], trackable_obj=tracking.AutoTrackable())
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    # Verify that there is no SignatureDef structure found.
    self.assertEqual(len(signature_defs), 0)

  @test_util.run_v2_only
  def testTrackableObject(self):
    """Test converting with trackable objects."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [add_func], trackable_obj=root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = add_func(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  def _getTrainingTimeQuantizedModel(self):

    class QLinear(tf.keras.layers.Layer):

      def __init__(self, units=3, **kwargs):
        super(QLinear, self).__init__(**kwargs)
        self.units = units

      def build(self, input_shape):
        self.w = self.add_weight(
            'weight',
            shape=(input_shape[-1], self.units),
            initializer='random_normal',
            trainable=True)
        self.min_var = self.add_weight(
            'min',
            initializer=tf.keras.initializers.Constant(-6.0),
            trainable=False)
        self.max_var = self.add_weight(
            'max',
            initializer=tf.keras.initializers.Constant(6.0),
            trainable=False)

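      # The fake_quant ops below record (min, max) ranges alongside the layer
      # weights, so the converter can derive quantization parameters directly
      # and no representative dataset is needed at conversion time.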
      def call(self, inputs):
        x = tf.quantization.fake_quant_with_min_max_vars(
            inputs, self.min_var, self.max_var)

        w_fq = tf.quantization.fake_quant_with_min_max_vars(
            self.w, self.min_var, self.max_var)
        x = tf.matmul(x, w_fq)

        x = tf.quantization.fake_quant_with_min_max_vars(
            x, self.min_var, self.max_var)

        return x

    return tf.keras.Sequential(QLinear(3, input_shape=(2,)))

  @parameterized.named_parameters(
      ('_DefaultFLOAT32InputOutput', dtypes.float32),
      ('_INT8InputOutput', dtypes.int8),
      ('_UINT8InputOutput', dtypes.uint8))
  @test_util.run_v2_only
  def testTrainingTimeQuantization(self, inference_input_output_type):
    model = self._getTrainingTimeQuantizedModel()

    float_converter = lite.TFLiteConverterV2.from_keras_model(model)
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

    # Ensure that the quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testNewQuantizer(self):
    """Test the model quantized by the new converter."""
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    quantized_converter.representative_dataset = calibration_gen

    # default quantizer
    quantized_converter.experimental_new_quantizer = False
    old_tflite = quantized_converter.convert()

    # new quantizer
    quantized_converter.experimental_new_quantizer = True
    new_tflite = quantized_converter.convert()

    for _ in range(5):
      input_data = tf.constant(
          np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
      old_value = self._evaluateTFLiteModel(old_tflite, [input_data])
      new_value = self._evaluateTFLiteModel(new_tflite, [input_data])
      self.assertAllClose(old_value, new_value, atol=1e-01)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  @test_util.run_v2_only
  def testEmbeddings(self, enable_mlir_converter):
    """Test model with embeddings."""
    input_data = tf.constant(
        np.array(np.random.random_sample((20)), dtype=np.int32))

    class EmbeddingModel(tf.keras.Model):

      def __init__(self):
        super(EmbeddingModel, self).__init__()
        self.shared_weights = self.add_weight(
            'weights',
            shape=(2000, 300),
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(
                mean=0.0, stddev=300**(-0.5)))

      @tf.function(input_signature=[tf.TensorSpec(shape=(20), dtype=tf.int32)])
      def func(self, x):
        return tf.gather(self.shared_weights, x)

    # Building the model.
    root = EmbeddingModel()
    concrete_func = root.func.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.func(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertAllClose(expected_value.numpy(), actual_value[0], atol=1e-05)

  @test_util.run_v2_only
  def testGraphDebugInfo(self):
    """Test a concrete function has debug info captured."""
    root = tracking.AutoTrackable()
    root.v1 = tf.Variable(3.)
    root.f = tf.function(lambda x: root.v1 * x)
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)

  def _getIntegerQuantizationModelWithFlexOp(self):
    np.random.seed(0)

    root = tracking.AutoTrackable()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
    ])
    def func(inp):
      tanh = tf.math.tanh(inp)
      # Flex delegate will merge the consecutive conv3d and erf ops into one
      # Delegate node.
      conv3d = tf.nn.conv3d(
          tanh,
          tf.ones([3, 3, 3, 3, 3]),
          strides=[1, 1, 1, 1, 1],
          padding='SAME')
      erf = tf.math.erf(conv3d)
      output = tf.math.tanh(erf)
      return output

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3, 3, 3, 3, 3)).astype(np.float32)
        ]

    root.f = func
    return (root, root.f.get_concrete_function(), calibration_gen)

  @parameterized.named_parameters(
      ('_Default', False, False, dtypes.float32),
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize', False, True, dtypes.float32),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly', True, False, dtypes.float32),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
  @test_util.run_v2_only
  def testIntegerQuantizationWithFlexOp(
      self, is_int_only, is_int16_quantize, inference_input_output_type):
    root, func, calibration_gen = self._getIntegerQuantizationModelWithFlexOp()

    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.SELECT_TF_OPS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.SELECT_TF_OPS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS,
            lite.OpsSet.SELECT_TF_OPS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS, lite.OpsSet.SELECT_TF_OPS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

  def _getIntegerQuantizationModelWithUnsupportedOps(self):
    np.random.seed(0)

    root = tracking.AutoTrackable()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[3], dtype=tf.float32),
        tf.TensorSpec(shape=[3], dtype=tf.float32)
    ])
    def func(a, b):
      # The ceil kernel supports neither int8 nor int16 types.
      left = tf.math.ceil(a)
      right = tf.nn.tanh(b)
      add = tf.math.add(left, right)
      # The ceil kernel supports neither int8 nor int16 types.
      output = tf.math.ceil(add)
      return (output, right)

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
        ]

    root.f = func
    return (root, root.f.get_concrete_function(), calibration_gen)

  @parameterized.named_parameters(
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
  @test_util.run_v2_only
  def testIntegerQuantizationWithUnsupportedOps(
      self, is_int_only, is_int16_quantize, inference_input_output_type,
      enable_mlir_quantizer=False):
    root, func, calib_gen = (
        self._getIntegerQuantizationModelWithUnsupportedOps())

    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calib_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    expected_dtype = inference_input_output_type.as_numpy_dtype
    # Allow float32 for fallback on non-quantizable op.
    expected_ceil_dtype = (
        expected_dtype if enable_mlir_quantizer else dtypes.float32)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual(input_details[0]['dtype'], expected_dtype)
    self.assertEqual(input_details[1]['dtype'], expected_ceil_dtype)
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(output_details[0]['dtype'], expected_dtype)
    self.assertEqual(output_details[1]['dtype'], expected_ceil_dtype)

  @parameterized.named_parameters(
      ('_BlocklistedNoneWithLowering', None, None, True),
      ('_BlocklistedNoneWithoutLowering', None, None, False),
      ('_BlocklistedOpsWithLowering', {'CONV_2D'}, None, True),
      ('_BlocklistedOpsWithoutLowering', {'CONV_2D'}, None, False),
      ('_BlocklistedNodesWithLowering', None, {'PartitionedCall:0'}, True),
      ('_BlocklistedNodesWithoutLowering', None, {'Identity'}, False))
  @test_util.run_v2_only
  def testNewQuantizerBlocklistingArgs(
      self, denylisted_ops, denylisted_nodes, lower_to_saved_model):
    """Test the model quantized by the new converter and denylisted options."""
    root, func, calibration_gen = self._getIntegerQuantizeModel()
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    quantized_converter.representative_dataset = calibration_gen
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.experimental_new_quantizer = True
    quantized_converter._experimental_calibrate_only = True
    quantized_converter.experimental_lower_to_saved_model = lower_to_saved_model
    calibrated = quantized_converter.convert()
    quantized_tflite_model = mlir_quantize(
        calibrated,
        denylisted_ops=denylisted_ops,
        denylisted_nodes=denylisted_nodes)
    interpreter = Interpreter(model_content=quantized_tflite_model)
    details = interpreter.get_tensor_details()
    num_quantized_tensors = sum(
        [1 for detail in details
         if len(detail['quantization_parameters']['scales'])])
    if denylisted_nodes or denylisted_ops:
      self.assertEqual(num_quantized_tensors, 0)
      return
    self.assertEqual(num_quantized_tensors, 4)  # quant, filter, bias, dequant

  @parameterized.named_parameters(
      ('_SingleLayer', False),
      ('_WholeModel', True),
  )
  @test_util.run_v2_only
  def testNewQuantizerNumericVerificationDebugMode(self, whole_model_verify):
    """Test the model quantized by the new converter with numeric verify ops."""
    root, func, calibration_gen = self._getIntegerQuantizeModel()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func], root)
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    quantized_converter.representative_dataset = calibration_gen

    # Create a TFLite model with new quantizer.
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.experimental_new_quantizer = True
    production_tflite = quantized_converter.convert()
    # Create a TFLite model with new quantizer and numeric verify ops.
    quantized_converter._experimental_calibrate_only = True
    calibrated = quantized_converter.convert()
    debug_mode_tflite = mlir_quantize(
        calibrated,
        enable_numeric_verify=True,
        enable_whole_model_verify=whole_model_verify)
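    # (NumericVerify ops compare dequantized activation values against their
    # float reference values at inference time.)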

    # Adding debug mode should produce a different flatbuffer.
    self.assertNotEqual(production_tflite, debug_mode_tflite)

    # Check if newly added ops are numeric verify ops.
    input_data = tf.constant(
        np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))

    def examine_tflite_model(tflite_content, input_data):
      interpreter = Interpreter(
          model_content=tflite_content,
          experimental_op_resolver_type=OpResolverType
          .BUILTIN_WITHOUT_DEFAULT_DELEGATES)
      interpreter.allocate_tensors()
      input_details = interpreter.get_input_details()
      interpreter.set_tensor(input_details[0]['index'], input_data.numpy())
      interpreter.invoke()
      tensor_details = interpreter.get_tensor_details()
      return {
          details['name']: interpreter.get_tensor(details['index'])
          for details in tensor_details
      }, tensor_details

    tflite_result, _ = examine_tflite_model(production_tflite, input_data)
    debug_mode_tflite_result, debug_tensor_details = examine_tflite_model(
        debug_mode_tflite, input_data)

    # MLIR-based quantizer should output flatbuffer model with `tfl.quantize`.
    num_production_quantize_ops = len([
        None for output_tensor_name in tflite_result
        if 'tfl.quantize' in output_tensor_name
    ])
    self.assertEqual(num_production_quantize_ops, 1)
    # MLIR-based quantizer should output flatbuffer model with `tfl.quantize`.
    num_debug_quantize_ops = len([
        None for output_tensor_name in debug_mode_tflite_result
        if 'tfl.quantize' in output_tensor_name
    ])
    # The two counts should be equal.
    self.assertEqual(num_production_quantize_ops, num_debug_quantize_ops)
    # The debug-mode flatbuffer should contain at least one NumericVerify op.
    # Each is named with the prefix "NumericVerify/{name}:{id}", where {name}
    # is the tensor name of the original quantized op's activation and {id} is
    # its tensor id.
    num_debug_ops = 0
    for output_tensor_name in debug_mode_tflite_result:
      if 'NumericVerify' in output_tensor_name:
        pos_end_prefix = len('NumericVerify/')
        pos_colon = output_tensor_name.rfind(':')
        self.assertEqual('NumericVerify/', output_tensor_name[:pos_end_prefix])
        tensor_id = int(output_tensor_name[pos_colon + 1:])
        original_tensor_name = output_tensor_name[pos_end_prefix:pos_colon]
        self.assertEqual(original_tensor_name,
                         debug_tensor_details[tensor_id]['name'])
        num_debug_ops += 1
    self.assertEqual(num_debug_ops, 1)
    # The number of debug ops should be equal to that of quantized ops.
    self.assertEqual(num_debug_ops, num_debug_quantize_ops)

  @parameterized.named_parameters(
      ('_PerChannelQuant', False, False),
      ('_PerChannelMlirQuant', False, True),
      ('_PerTensorQuant', True, False),
      ('_PerTensorMlirQuant', True, True))
  @test_util.run_v2_only
  def testDisablePerChannelQuantization(
      self, disable_per_channel=False, enable_mlir_quantizer=False):
    k_conv_name = 'Conv2D1'
    k_num_filters = 16
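    # Conv weights are quantized per output channel by default, so the weight
    # tensor carries one (scale, zero_point) pair per filter; disabling
    # per-channel quantization collapses this to a single pair.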
    root, func, calib_gen = self._getIntegerQuantizeModel()
    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calib_gen
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS
    ]
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    if disable_per_channel:
      quantized_converter._experimental_disable_per_channel = (
          disable_per_channel)
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    detail = next(d for d in interpreter.get_tensor_details()
                  if d['name'] == k_conv_name)
    quant_params = detail['quantization_parameters']
    expected_num_params = 1 if disable_per_channel else k_num_filters
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

  @test_util.run_v2_only
  def testOpVersion(self):
    @tf.function(
        input_signature=[tf.TensorSpec(shape=[5, 5], dtype=tf.float32)])
    def custom_resize(image):
      # Add "batch" and "channels" dimensions
      image = image[tf.newaxis, ..., tf.newaxis]
      # ResizeBilinear version 3.
      resize1 = tf.compat.v1.image.resize_bilinear(
          image, [2, 2], half_pixel_centers=True)
      # ResizeBilinear version 1.
      resize2 = tf.compat.v1.image.resize_bilinear(image, [2, 2])
      return resize1 + resize2

    concrete_func = custom_resize.get_concrete_function()
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], custom_resize)
    tflite_model = converter.convert()
    model_object = schema_fb.Model.GetRootAsModel(tflite_model, 0)
    model = schema_fb.ModelT.InitFromObj(model_object)

    for operator in model.operatorCodes:
      if operator.builtinCode == schema_fb.BuiltinOperator.RESIZE_BILINEAR:
        # half_pixel_centers is supported by ResizeBilinear version 3.
        self.assertEqual(operator.version, 3)
        break


class FromSavedModelTest(lite_v2_test_util.ModelTest):

  def _createV1SavedModel(self, shape):
    """Create a simple SavedModel."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor_1 = tf.compat.v1.placeholder(
            shape=shape, dtype=tf.float32, name='inputB')
        in_tensor_2 = tf.compat.v1.placeholder(
            shape=shape, dtype=tf.float32, name='inputA')
        variable_node = tf.Variable(1.0, name='variable_node')
        out_tensor = in_tensor_1 + in_tensor_2 * variable_node
        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
        outputs = {'z': out_tensor}
        sess.run(tf.compat.v1.variables_initializer([variable_node]))
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def _createV2QATSavedModel(self, shape):
    """Create a simple QAT SavedModel in TF 2."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    input_name = 'input'
    output_name = 'scores'

    input_tensor = tf.keras.layers.Input((32, 32, 128), name=input_name)
    x = tf.quantization.fake_quant_with_min_max_args(input_tensor, -3.0, 3.0)
    x = tf.keras.layers.Conv2D(1, (3, 3))(x)
    x = tf.quantization.fake_quant_with_min_max_args(x, -3.0, 3.0)
    scores = tf.keras.layers.Reshape((-1,), name=output_name)(x)
    model = tf.keras.Model(input_tensor, scores)
    model.save(saved_model_dir)
    return saved_model_dir, input_name, output_name

  @test_util.run_v2_only
  def testV1SimpleModel(self):
    """Test a SavedModel."""
    with tf.Graph().as_default():
      saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

      # Convert model and ensure model is not None.
      converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
      tflite_model = converter.convert()
      self.assertTrue(tflite_model)

      interpreter = Interpreter(model_content=tflite_model)
      interpreter.allocate_tensors()

      input_details = interpreter.get_input_details()
      self.assertLen(input_details, 2)
      self.assertStartsWith(input_details[0]['name'], 'inputA')
      self.assertEqual(np.float32, input_details[0]['dtype'])
      self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
      self.assertEqual((0., 0.), input_details[0]['quantization'])

      self.assertStartsWith(input_details[1]['name'], 'inputB')
      self.assertEqual(np.float32, input_details[1]['dtype'])
      self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
      self.assertEqual((0., 0.), input_details[1]['quantization'])

      output_details = interpreter.get_output_details()
      self.assertLen(output_details, 1)
      self.assertStartsWith(output_details[0]['name'], 'add')
      self.assertEqual(np.float32, output_details[0]['dtype'])
      self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
      self.assertEqual((0., 0.), output_details[0]['quantization'])

  @parameterized.named_parameters(
      ('Default', False),
      ('UnfoldLargeConstant', True),
  )
  @test_util.run_v2_only
  def testUnfoldLargeConstant(self, unfold_large_constant):
    """Test unfolding large splat constant in a TF Lite model."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor = tf.compat.v1.placeholder(
            shape=[1000, 1000], dtype=tf.float32, name='input')
        constant = tf.constant(value=1, dtype=tf.float32, shape=[1000, 1000])
        out_tensor = in_tensor + constant
        inputs = {'x': in_tensor}
        outputs = {'y': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter._experimental_unfold_large_splat_constant = unfold_large_constant
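    # When unfolding is enabled, the splat constant is expected to be emitted
    # as a FILL op instead of a large constant buffer; the operator-code
    # checks below verify this.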
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    model = util._convert_model_from_bytearray_to_object(tflite_model)
    if unfold_large_constant:
      self.assertEqual(model.operatorCodes[0].builtinCode,
                       schema_fb.BuiltinOperator.FILL)
      self.assertEqual(model.operatorCodes[1].builtinCode,
                       schema_fb.BuiltinOperator.ADD)
    else:
      self.assertEqual(model.operatorCodes[0].builtinCode,
                       schema_fb.BuiltinOperator.ADD)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('input:0', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1000, 1000], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual('add:0', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1000, 1000], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    interpreter.set_tensor(input_details[0]['index'],
                           np.ones(shape=[1000, 1000], dtype=np.float32))
    interpreter.invoke()
    self.assertAllEqual(
        np.full(shape=[1000, 1000], fill_value=2.0, dtype=np.float32),
        interpreter.get_tensor(output_details[0]['index']))

  @test_util.run_v2_only
  def testTF1HubFormattedModel(self):
    """Test a TF1 hub formatted model."""
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

    # TF1 Hub models are based on V1 SavedModels and omit the SavedModel
    # schema version.
    saved_model_proto = parse_saved_model(saved_model_dir)
    saved_model_proto.saved_model_schema_version = 0

    saved_model_pb_file_path = os.path.join(saved_model_dir, 'saved_model.pb')
    with file_io.FileIO(saved_model_pb_file_path, 'wb') as writer:
      writer.write(saved_model_proto.SerializeToString())

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  def _createV1ModelWithHashTableInitializer(self):
    # Create a v1 saved model with hash table initializers.
    tf.compat.v1.disable_eager_execution()
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'savedmodel_with_hashtable')

    table_initializer = tf.lookup.KeyValueTensorInitializer(
        keys=['a', 'b', 'c', 'd'],
        values=[1, 2, 3, 4],
        key_dtype=tf.string,
        value_dtype=tf.int64)
    table = tf.lookup.StaticHashTable(
        table_initializer, default_value=tf.constant(-1, dtype=tf.int64))

    x = tf.compat.v1.placeholder(tf.string, shape=(), name='input')
    y = table.lookup(x)

    tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)

    signature_def_map = {
        'serving_default':
            tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs={'x': tensor_info_x},
                outputs={'y': tensor_info_y},
                method_name='some_function')
    }
    init_op = tf.compat.v1.tables_initializer()
    assets_collection = None

    sess = tf.compat.v1.Session()
    sess.run(tf.compat.v1.initializers.global_variables())

    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(
        saved_model_dir)
    builder.add_meta_graph_and_variables(
        sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
        signature_def_map,
        main_op=init_op,
        assets_collection=assets_collection,
        strip_default_attrs=True)
    builder.save()

    # Restore TF v2 behavior.
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.enable_eager_execution()
    return saved_model_dir

  @test_util.run_v2_only
  def testModelWithHashTableInitializer(self):
    """Test a model with saved_model's session initializer for hash tables."""
    saved_model_dir = self._createV1ModelWithHashTableInitializer()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_data = np.array(['a', 'b', 'c', 'z'], dtype=np.string_)
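    # The serving signature declares a scalar string input, so the input
    # tensor is resized (strict=False allows resizing a non-dynamic dimension)
    # before feeding a batch of four keys.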
    interpreter.resize_tensor_input(
        input_details[0]['index'], [4], strict=False)
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_details[0]['index'], input_data)

    # Invoke multiple times to ensure the initializer graph runs only once.
    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([1, 2, 3, -1], list(actual_value))

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([1, 2, 3, -1], list(actual_value))

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([1, 2, 3, -1], list(actual_value))

  def _createV1ModelWithMutableHashTable(self):
    # Create a v1 saved model with a mutable hash table.
    tf.compat.v1.disable_eager_execution()
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'savedmodel_with_mutable_hashtable')

    table = tf.raw_ops.MutableHashTableV2(
        key_dtype=tf.string, value_dtype=tf.int64)
    x = tf.compat.v1.placeholder(tf.string, shape=(), name='input')
    keys = tf.constant(['a', 'b'], tf.string)
    values = tf.constant([1, 5], tf.int64)
    default_value = tf.constant(-1, tf.int64)
1229    insert_call = tf.raw_ops.LookupTableInsertV2(
1230        table_handle=table, keys=keys, values=values)
1231    with tf.control_dependencies([insert_call]):
1232      y = tf.raw_ops.LookupTableFindV2(
1233          table_handle=table, keys=x, default_value=default_value)
1234
1235    tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
1236    tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)
1237
1238    signature_def_map, init_op, assets_collection = {
1239        'serving_default':
1240            (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
1241                inputs={'x': tensor_info_x},
1242                outputs={'y': tensor_info_y},
1243                method_name='some_function'))
1244    }, tf.compat.v1.tables_initializer(), None
1245
1246    sess = tf.compat.v1.Session()
1247
1248    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(
1249        saved_model_dir)
1250    builder.add_meta_graph_and_variables(
1251        sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
1252        signature_def_map,
1253        main_op=init_op,
1254        assets_collection=assets_collection,
1255        strip_default_attrs=True)
1256    builder.save()
1257
1258    # Restore TF v2 behavior.
1259    tf.compat.v1.reset_default_graph()
1260    tf.compat.v1.enable_eager_execution()
1261    return saved_model_dir
1262
1263  @test_util.run_v2_only
1264  def testModelWithMutableHashTable(self):
1265    """Test a model with saved_model's session initializer for hash tables."""
1266    saved_model_dir = self._createV1ModelWithMutableHashTable()
1267
1268    # Convert model and ensure model is not None.
1269    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
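    # Mutable hash table ops are not TFLite builtins, so the converted model
    # has to fall back to the Select TF ops (Flex) path.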
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_data = np.array(['a', 'b', 'c'], dtype=np.string_)
    interpreter.resize_tensor_input(
        input_details[0]['index'], [3], strict=False)
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([1, 5, -1], list(actual_value))

  @test_util.run_v2_only
  def testConstModel(self):
    """Test a basic model with functions to make sure functions are inlined."""
    input_data = tf.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = tf.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testVariableModel(self):
    """Test a basic model with variables, saving/loading the SavedModel."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @parameterized.named_parameters(('EnableResourceVariables', True),
                                  ('DisableResourceVariables', False))
  @test_util.run_v2_only
  def testNativeVariablesModel(self, enable_resource_variables):
    """Test a model with native (resource) variables via the SavedModel."""
    root = self._getSimpleModelWithVariables()
    input_data = tf.constant(1., shape=[1, 10])
    to_save = root.assign_add.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    converter.experimental_enable_resource_variables = (
        enable_resource_variables)

    if not enable_resource_variables:
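      # Without resource variable support the converter tries to constant-fold
      # the variables and fails; the message below is matched verbatim against
      # the converter's error text.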
      with self.assertRaises(convert.ConverterError) as error:
        tflite_model = converter.convert()
      self.assertIn(
          'Variable constant folding is failed. Please consider using enabling '
          '`experimental_enable_resource_variables` flag in the TFLite '
          'converter object.',
          str(error.exception))
      return

    # Enable resource variables.
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.assign_add(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    for tf_result, tflite_result in zip(expected_value, actual_value[0]):
      self.assertAllClose(tf_result, tflite_result, atol=1e-05)

  @test_util.run_v2_only
  def testSignatures(self):
    """Test values for the `signature_keys` argument."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)

    # Convert model with invalid `signature_keys`.
    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_saved_model(
          save_dir, signature_keys=['INVALID'])
    self.assertIn("Invalid signature key 'INVALID'", str(error.exception))

    # Convert model with empty `signature_keys`.
    converter = lite.TFLiteConverterV2.from_saved_model(
        save_dir, signature_keys=[])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testSignatureDefsWithFullIntegerQuantization(self):
    # SETUP
    # 1. Define input shapes
    tf_input_shape = (32, 32, 128)
    tflite_input_shape = (1,) + tf_input_shape
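    # The TF model is defined on a single example, so prepend a batch
    # dimension of 1 for the TFLite input.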
    # 2. Define model
    tf_saved_model_dir, input_name, output_name = (
        self._createV2QATSavedModel(tf_input_shape))

    # MODEL 1: TFLite (float) model
    # 1. Create TFLite model
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    # 2. Initialize the Interpreter
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    interpreter.resize_tensor_input(input_details['index'], tflite_input_shape)
    interpreter.allocate_tensors()
    signature_list = interpreter._get_full_signature_list()['serving_default']
    # 3. (Skip) Verify that signature def input/output tensors are in the model.
    # 4. Evaluate the model
    input_data = np.random.random(tflite_input_shape).astype(np.float32)
    result = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'serving_default', {input_name: input_data})[output_name]

    # MODEL 2: TFLite (full integer quantized) model
    # 1. Create TFLite model
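    # The QAT SavedModel already carries fake-quant ranges, so full-integer
    # conversion needs no representative dataset here.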
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_model_quant = converter.convert()
    # 2. Initialize the Interpreter
    interpreter = Interpreter(model_content=tflite_model_quant)
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    interpreter.resize_tensor_input(input_details['index'], tflite_input_shape)
    interpreter.allocate_tensors()
    # 3. Verify that signature def input/output tensors are in the model.
    all_indices = {item['index'] for item in interpreter.get_tensor_details()}
    signature_list = interpreter._get_full_signature_list()['serving_default']
    input_tensor_indices = set(signature_list['inputs'].values())
    assert input_tensor_indices.issubset(all_indices)
    output_tensor_indices = set(signature_list['outputs'].values())
    assert output_tensor_indices.issubset(all_indices)

    # 4. Evaluate the model
    input_data = np.random.random(tflite_input_shape)
    input_scale, input_zero_point = input_details['quantization']
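    # A (scale, zero_point) of (0.0, 0) marks an unquantized tensor; otherwise
    # map the float input into the int8 domain via q = x / scale + zero_point.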
    if (input_scale, input_zero_point) != (0.0, 0):
      input_data = input_data / input_scale + input_zero_point
      input_data = input_data.astype(input_details['dtype'])
    result_quant = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model_quant, 'serving_default',
        {input_name: input_data})[output_name]
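    # Dequantize the output back to float via x = (q - zero_point) * scale.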
    output_scale, output_zero_point = output_details['quantization']
    if (output_scale, output_zero_point) != (0.0, 0):
      result_quant = result_quant.astype(np.float32)
      result_quant = (result_quant - output_zero_point) * output_scale

    # COMPARE: Validate that results from both models are approx. the same.
    root_mean_squared = np.sqrt(np.mean((result - result_quant)**2))
    assert root_mean_squared < 1.0

  @test_util.run_v2_only
  def testSignatureDefs(self):
    """Test that a converted SignatureDef is correct via the SignatureDef API."""
    root = self._getMultiFunctionModel()
    input_data_0 = tf.constant(1., shape=[1])
    input_data_1 = tf.constant(3., shape=[1])
    mul_add_func = root.mul_add.get_concrete_function(input_data_1,
                                                      input_data_0)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'mul_add': mul_add_func})

    converter = lite.TFLiteConverterV2.from_saved_model(
        save_dir, signature_keys=['mul_add'])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.mul_add(input_data_1, input_data_0)
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    results = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'mul_add', {
            'y': input_data_0,
            'x': input_data_1
        })
    self.assertEqual(list(results.keys()), ['output_0'])
    self.assertEqual(expected_value.numpy(), results['output_0'])

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['mul_add'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
    self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])

  @test_util.run_v2_only
  def testSignatureDefsWithDefaultValue(self):
    """Test that a converted SignatureDef is correct via the SignatureDef API.

    This test uses None as the signature_key to test the default behavior.
    """
    root = self._getMultiFunctionModel()
    input_data_0 = tf.constant(1., shape=[1])
    input_data_1 = tf.constant(3., shape=[1])
    mul_add_func = root.mul_add.get_concrete_function(input_data_1,
                                                      input_data_0)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'mul_add': mul_add_func})

    converter = lite.TFLiteConverterV2.from_saved_model(
        save_dir, signature_keys=['mul_add'])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.mul_add(input_data_1, input_data_0)
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    results = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, None, {
            'y': input_data_0,
            'x': input_data_1
        })
    self.assertEqual(list(results.keys()), ['output_0'])
    self.assertEqual(expected_value.numpy(), results['output_0'])

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['mul_add'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
    self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])

  @test_util.run_v2_only
  def testSignatureDefsQuantizedModel(self):
    """Test converting a SignatureDef on a quantized model."""
    root = self._getMultiFunctionModel()
    input_data_0 = tf.constant(1., shape=[1])
    input_data_1 = tf.constant(3., shape=[1])
    mul_add_func = root.mul_add.get_concrete_function(input_data_1,
                                                      input_data_0)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'mul_add': mul_add_func})

    converter = lite.TFLiteConverterV2.from_saved_model(
        save_dir, signature_keys=['mul_add'])

    def representative_dataset_gen():
      for _ in range(2):
        yield [
            np.random.uniform(low=0, high=1, size=(1, 1)).astype(np.float32),
            np.random.uniform(low=0, high=1, size=(1, 1)).astype(np.float32)
        ]

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    tflite_model = converter.convert()

    # Check that the signatures are valid in the converted model.
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['mul_add'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
    self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])

  @test_util.run_v2_only
  def testMultipleFunctionModel(self):
    """Convert multiple functions in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)
    sub_func = root.sub.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'add': add_func, 'sub': sub_func})

    # Try converting multiple functions.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 2)
    self.assertEqual(list(signature_defs.keys()), ['add', 'sub'])
    self.assertEqual(len(signature_defs.values()), 2)
    self.assertEqual(list(signature_defs['add'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['add']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['add']['outputs']), ['output_0'])
    self.assertEqual(list(signature_defs['sub'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['sub']['inputs'], ['x'])
    self.assertEqual(list(signature_defs['sub']['outputs']), ['output_0'])

    # Verify the signature runner executions.
    add_signature_runner = interpreter.get_signature_runner('add')
    add_output = add_signature_runner(x=input_data)
    self.assertEqual(add_output['output_0'], 3)

    sub_signature_runner = interpreter.get_signature_runner('sub')
    sub_output = sub_signature_runner(x=input_data)
    self.assertEqual(sub_output['output_0'], -2)

  @test_util.run_v2_only
  def testMultipleFunctionModelWithSharedWeight(self):
    """Convert multiple functions that share a weight."""
    root = self._getMultiFunctionModelWithSharedWeight()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)
    sub_func = root.sub.get_concrete_function(input_data)
    mul_func = root.mul.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, {'add': add_func, 'sub': sub_func, 'mul': mul_func})

    # Try converting multiple functions.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Make sure that the weight tensors are shared.
    self.assertLess(len(tflite_model), 1100000)
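    # The bound is a rough sanity check: the three signatures share one large
    # weight, so duplicated weights would push the size well past this limit.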

    # TODO(b/184696047): Write tests for multiple signature runners once the
    #                    Python API is ready to use.
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    self.assertLen(signature_defs, 3)
    add_signature_runner = interpreter.get_signature_runner('add')
    sub_signature_runner = interpreter.get_signature_runner('sub')
    mul_signature_runner = interpreter.get_signature_runner('mul')
    self.assertIsNotNone(add_signature_runner)
    self.assertIsNotNone(sub_signature_runner)
    self.assertIsNotNone(mul_signature_runner)

  @test_util.run_v2_only
  def testNoConcreteFunctionModel(self):
    root = self._getMultiFunctionModel()

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir)

    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_saved_model(save_dir)
    self.assertIn('Only support at least one signature key.',
                  str(error.exception))

  @test_util.run_v2_only
  def testKerasSequentialModel(self):
    """Test a simple sequential tf.Keras model."""
    input_data = tf.constant(1., shape=[1, 1])

    x = np.array([[1.], [2.]])
    y = np.array([[2.], [4.]])

    model = tf.keras.models.Sequential([
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(model, save_dir)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = model.predict(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value, actual_value)

  @test_util.run_v2_only
  def testGraphDebugInfo(self):
    """Test that a SavedModel has debug info captured."""
    input_data = tf.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = tf.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data)
    options = save_options.SaveOptions(save_debug_info=True)
    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save, options)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)

  @test_util.run_v2_only
  def testFallbackPath(self):
    """Test the SavedModel fallback path using the old converter."""
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.experimental_new_converter = False
    tflite_model = converter.convert()

    self.assertTrue(tflite_model)

  @test_util.run_v2_only
  def testNonStatefulConvLSTM2D(self):
    """Test a saved model with a non-stateful ConvLSTM2D Keras layer."""
    # Create the Keras model.
    model = tf.keras.Sequential([
        tf.keras.layers.ConvLSTM2D(
            32, (3, 3),
            padding='same',
            return_sequences=True,
            stateful=False,
            batch_input_shape=(1, 1, 10, 10, 1))
    ])
    model.compile()

    # Export the Keras model to a SavedModel.
    saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_lstm_2d')
    model.save(saved_model_dir, save_format='tf', include_optimizer=False)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  @test_util.run_v2_only
  def testKerasConvLSTM2DWithMoreThanOneDilationRate(self):
    input_tensor = tf.keras.layers.Input(
        batch_size=8,
        shape=[9, 10, 11, 12],
        name='input_tensor',
        dtype=tf.float32)

    output = tf.keras.layers.ConvLSTM2D(
        filters=3,
        kernel_size=3,
        strides=1,
        padding='VALID',
        dilation_rate=2,
        use_bias=False,
        bias_initializer='ones',
        data_format='channels_last')(
            input_tensor)

    model = tf.keras.Model(inputs=[input_tensor], outputs=output)
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])

    # Export the Keras model to a SavedModel.
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'conv_lstm_2d_with_dilation_rate')
    model.save(saved_model_dir, save_format='tf', include_optimizer=False)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  def _createUnknownInputShapeModel(self):
    """Create a simple SavedModel with an unknown input shape."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'unknown_input_shape')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        unknown_shape = tf.TensorShape(None)
        in_tensor = tf.compat.v1.placeholder(
            shape=unknown_shape, dtype=tf.float32, name='input')
        out_tensor = in_tensor + in_tensor
        inputs = {'input': in_tensor}
        outputs = {'output': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  @test_util.run_v2_only
  def testUnknownInputShapeModel(self):
    """Test a SavedModel with an unknown input shape."""
    saved_model_dir = self._createUnknownInputShapeModel()

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_data = np.array([1., 2., 3.], dtype=np.float32)
    interpreter.resize_tensor_input(
        input_details[0]['index'], [3], strict=False)
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual([2., 4., 6.], list(actual_value))

  @parameterized.named_parameters(
      ('_PerChannelQuant', False, False),
      ('_PerChannelMlirQuant', False, True),
      ('_PerTensorQuant', True, False),
      ('_PerTensorMlirQuant', True, True))
  @test_util.run_v2_only
  def testDisablePerChannelQuantization(self, disable_per_channel=False,
                                        enable_mlir_quantizer=False):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu')
    ])
    model.build(input_shape=(1, 5, 5, 3))
    saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_saved_model')
    save(model, saved_model_dir)
    k_conv_name = 'sequential/conv2d/Conv2D1'
    k_num_filters = 16
    quantized_converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]

    def calib_gen():
      for _ in range(5):
        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

    quantized_converter.representative_dataset = calib_gen
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS
    ]
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
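    # Per-channel quantization keeps one scale/zero-point pair per output
    # channel; disabling it collapses them into a single per-tensor pair.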
    if disable_per_channel:
      quantized_converter._experimental_disable_per_channel = (
          disable_per_channel)
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    detail = next(d for d in interpreter.get_tensor_details()
                  if d['name'] == k_conv_name)
    quant_params = detail['quantization_parameters']
    expected_num_params = k_num_filters
    if disable_per_channel:
      expected_num_params = 1
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)


class FromKerasModelTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testSequentialModel(self):
    """Test a simple sequential tf.Keras model."""
    input_data = tf.constant(1., shape=[1, 1])

    # Create a simple Keras model.
    x = np.array([[1.], [2.]])
    y = np.array([[2.], [4.]])

    model = tf.keras.models.Sequential([
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(units=1, input_shape=[1])
    ])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = model.predict(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value, actual_value)

  @test_util.run_v2_only
  def testSequentialMultiInputOutputModel(self):
    """Test a tf.Keras model with multiple inputs and outputs."""
    left_input_data = tf.constant(1., shape=[1, 3])
    right_input_data = tf.constant(1., shape=[1, 3])

    # Create a simple Keras model.
    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    output_c_np = np.random.random((10, 3))
    output_d_np = np.random.random((10, 2))

    input_a = tf.keras.layers.Input(shape=(3,), name='input_a')
    input_b = tf.keras.layers.Input(shape=(3,), name='input_b')

    dense = tf.keras.layers.Dense(8, name='dense_1')
    interm_a = dense(input_a)
    interm_b = dense(input_b)
    merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')

    output_c = tf.keras.layers.Dense(
        3, activation='softmax', name='dense_2')(
            merged)
    output_d = tf.keras.layers.Dense(
        2, activation='softmax', name='dense_3')(
            merged)

    model = tf.keras.models.Model(
        inputs=[input_a, input_b], outputs=[output_c, output_d])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit([input_a_np, input_b_np], [output_c_np, output_d_np], epochs=1)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()

    # Check values from converted model.
    input_data = [left_input_data, right_input_data]
    expected_value = model.predict(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, input_data)
    for tf_result, tflite_result in zip(expected_value, actual_value):
      self.assertAllClose(tf_result, tflite_result, atol=1e-05)

  @test_util.run_v2_only
  def testGraphDebugInfo(self):
    """Test that a tf.Keras model has debug info captured."""
    # Create a simple Keras model.
    x = [-1, 0, 1, 2, 3, 4]
    y = [-3, -1, 1, 3, 5, 7]
    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)

  @test_util.run_v2_only
  def testKerasFallbackPath(self):
    """Test a Keras model which fails when exported to a SavedModel."""
    input_data = tf.constant(
        np.array(np.random.random_sample((20)), dtype=np.float32))

    class Model(tf.keras.Model):

      def __init__(self):
        super(Model, self).__init__()
        # A None name will cause a failure when exporting to a SavedModel.
        self.shared_weights = self.add_weight(
            name=None,
            shape=(20, 1),
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(
                mean=0.0, stddev=300**(-0.5)))

      def call(self, x):
        return tf.add(self.shared_weights, x)

    # Build the model.
    model = Model()
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(input_data, input_data, epochs=1)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  @test_util.run_v2_only
  def testSignatureDefs(self):
    """Test that a converted SignatureDef is correct via the SignatureDef API."""
    keras_model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            32,
            kernel_size=3,
            padding='same',
            activation='relu',
            input_shape=(32, 32, 3),
            name='tensor'),
        tf.keras.layers.Dense(10, name='output_tensor')
    ])

    converter = lite.TFLiteConverterV2.from_keras_model(keras_model)
    tflite_model = converter.convert()

    # Check values from converted model.
    input_data = tf.constant(
        np.random.uniform(-1, 1, size=(1, 32, 32, 3)).astype(np.float32))
    expected_value = keras_model(input_data)
    interpreter = Interpreter(model_content=tflite_model)
    signature_defs = interpreter.get_signature_list()
    results = self._evaluateTFLiteModelUsingSignatureDef(
        tflite_model, 'serving_default', {'tensor_input': input_data})
    self.assertEqual(list(results.keys()), ['output_tensor'])
    self.assertAllClose(expected_value.numpy(), results['output_tensor'])

    # Verify the SignatureDef structure returned is as expected.
    self.assertEqual(len(signature_defs), 1)
    self.assertEqual(list(signature_defs.keys()), ['serving_default'])
    self.assertEqual(len(signature_defs.values()), 1)
    self.assertEqual(
        list(signature_defs['serving_default'].keys()), ['inputs', 'outputs'])
    self.assertCountEqual(signature_defs['serving_default']['inputs'],
                          ['tensor_input'])
    self.assertEqual(
        list(signature_defs['serving_default']['outputs']), ['output_tensor'])


class ControlFlowTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testCond(self):
    input_data = {
        'x': tf.constant([1., 2.], shape=[1, 2]),
        'b': tf.constant(True)
    }

    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)

    def true_fn(x):
      return tf.matmul(x, weights)

    def false_fn(x):
      return tf.add(x, weights)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.bool)
    ])
    def model(x, b):
      return tf.cond(
          b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))

    concrete_func = model.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(**input_data)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data['x'], input_data['b']])[0]
    self.assertAllClose(expected_value, actual_value)

  @test_util.run_v2_only
  def testConverterErrorOnControlFlowV1Ops(self):
    filename = resource_loader.get_path_to_datafile(
        'testdata/control_flow_v1_saved_model')
    converter = lite.TFLiteConverterV2.from_saved_model(filename)
    with self.assertRaises(convert.ConverterError) as error:
      converter.convert()
    self.assertIn(
        'Failed to functionalize Control Flow V1 ops. Consider using Control '
        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))

  @test_util.run_v2_only
  def testStaticRnn(self):
    input_data = tf.constant(
        np.array(np.random.random_sample((3, 10)), dtype=np.float32))

    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(10)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[3, 10], dtype=tf.float32)])
    def model(x):
      seq = tf.split(x, 3, 0)
      return tf.compat.v1.nn.static_rnn(
          cell, seq, dtype=tf.float32, sequence_length=[1])

    concrete_func = model.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)[0]
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    for expected, actual in zip(expected_value, actual_value):
      self.assertAllClose(expected, actual)

  @test_util.run_v2_only
  def testWhileLoop(self):
    input_data = tf.constant([1., 2., 3., 4.], shape=[2, 2])

    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)

    def condition(x):
      return tf.reduce_sum(x) < 100

    def body(x):
      return tf.add(x, weights)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[2, 2], dtype=tf.float32)])
    def model(x):
      return tf.while_loop(condition, body, [x])

    concrete_func = model.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)[0]
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
    self.assertAllClose(expected_value, actual_value)

  @test_util.run_v2_only
  def testDynamicRnn(self):
    input_data = tf.constant(
        np.array(np.random.random_sample((3, 10, 10)), dtype=np.float32))

    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(10)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[3, 10, 10], dtype=tf.float32)])
    def model(x):
      return tf.compat.v1.nn.dynamic_rnn(cell, x, dtype=tf.float32)

    concrete_func = model.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    for expected, actual in zip(expected_value, actual_value):
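      # dynamic_rnn returns the final state as an LSTMStateTuple; compare its
      # cell state (.c) when the element is not a plain tensor.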
      if not isinstance(expected, ops.EagerTensor):
        expected = expected.c
      self.assertAllClose(expected, actual)

  @parameterized.named_parameters(('LSTM', tf.keras.layers.LSTM),
                                  ('SimpleRNN', tf.keras.layers.SimpleRNN),
                                  ('GRU', tf.keras.layers.GRU))
  @test_util.run_v2_only
  def testKerasRNN(self, rnn_layer):
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    rnn_obj = rnn_layer(units=10, input_shape=(10, 10))
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(10, 10), name='input'),
        rnn_obj,
    ])

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    # Check values from converted model.
    expected_value = model.predict(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @parameterized.named_parameters(('LSTM', tf.keras.layers.LSTM),
                                  ('SimpleRNN', tf.keras.layers.SimpleRNN),
                                  ('GRU', tf.keras.layers.GRU))
  @test_util.run_v2_only
  def testKerasRNNMultiBatches(self, rnn_layer):
    input_data = tf.constant(
        np.array(np.random.random_sample((4, 10, 10)), dtype=np.float32))
    # Specify a fixed batch size (4) for the test model.
    x = tf.keras.layers.Input(batch_shape=(4, 10, 10))
    y = rnn_layer(units=10, input_shape=(10, 10))(x)
    model = tf.keras.Model(inputs=[x], outputs=[y])

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    # Check values from converted model.
    expected_value = model.predict(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @test_util.run_v2_only
  def testKerasBidirectionalRNNReturnSequence(self):
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Input(shape=(10, 10), name='input'))
    model.add(
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(units=10, return_sequences=True),
            input_shape=(10, 10)))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(5))
    model.add(tf.keras.layers.Activation('softmax'))

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    # Check values from converted model.
    expected_value = model.predict(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @test_util.run_v2_only
  def testKerasBidirectionalRNN(self):
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Input(shape=(10, 10), name='input'))
    model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=10)))
    model.add(tf.keras.layers.Dense(5))
    model.add(tf.keras.layers.Activation('softmax'))

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    # Check values from converted model.
    expected_value = model.predict(input_data)
    self.assertAllClose(expected_value, actual_value, atol=1e-05)


class GrapplerTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testConstantFolding(self):
    # Constant folding handles the tf.broadcast_to operation which was not
    # supported by TFLite at the time this test was added.
    input_data = tf.constant(
        [1., 2., 3., 4., 5., 6., 7., 8., 9.], shape=[3, 3])

    @tf.function
    def func(x):
      y_const = tf.constant([1., 2., 3.])
      y_broadcast = tf.broadcast_to(y_const, [3, 3])
      return tf.matmul(x, y_broadcast)

    root = tracking.AutoTrackable()
    root.f = func
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
    self.assertAllClose(expected_value, actual_value)

    # Enable hybrid quantization; expect the same result.
    converter.optimizations = [lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
    self.assertAllClose(expected_value, actual_value)


class UnknownShapes(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testMatMul(self):
    input_data = tf.constant(
        np.array(np.random.random_sample((10, 4)), dtype=np.float32))

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
    def model(in_tensor):
      shape = tf.shape(in_tensor)
      fill = tf.transpose(tf.fill(shape, 1.))
      return tf.matmul(fill, in_tensor)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)
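    # Each input_shapes entry pairs the dynamic signature shape with the
    # concrete shape the helper resizes the input to before invoking.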
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])[0]
    self.assertAllClose(expected_value, actual_value, atol=1e-06)

  def _getIntegerQuantizeModelWithUnknownShapes(self):
    np.random.seed(0)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 33], dtype=tf.float32)])
    def model(input_tensor):
      """Define a model with tf.MatMul and unknown shapes."""
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
      const_tensor = tf.constant(
          np.random.uniform(low=-10., high=10., size=[33, 33]),
          shape=[33, 33],
          dtype=tf.float32,
          name='inputB')

      shape = tf.shape(input_tensor)
      fill = tf.transpose(tf.fill(shape, 1.))
      mult = tf.matmul(fill, input_tensor)
      return tf.matmul(mult, const_tensor)

    root = tracking.AutoTrackable()
    root.f = model
    concrete_func = root.f.get_concrete_function()

    def calibration_gen():
      for batch in range(5, 20, 5):
        for _ in range(5):
          yield [np.random.uniform(-1, 1, size=(batch, 33)).astype(np.float32)]

    return root, concrete_func, calibration_gen

  @test_util.run_v2_only
  def testMatMulQuantize(self):
    root, concrete_func, _ = self._getIntegerQuantizeModelWithUnknownShapes()
    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    float_tflite_model = float_converter.convert()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_tflite_model = quantized_converter.convert()

    # The default input and output types should be float.
    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
    quantized_interpreter.allocate_tensors()
    input_details = quantized_interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testMatMulCalibrateAndQuantize(self):
    root, concrete_func, calibration_gen = (
        self._getIntegerQuantizeModelWithUnknownShapes())
    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    float_tflite_model = float_converter.convert()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], root)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()

    # The default input and output types should be float.
    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
    quantized_interpreter.allocate_tensors()
    input_details = quantized_interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def testBatchMatMul(self):
    input_data_1 = tf.constant(
        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
    input_data_2 = tf.constant(
        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32),
        tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32)
    ])
    def model(in_tensor_1, in_tensor_2):
      return tf.matmul(in_tensor_1, in_tensor_2)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data_1, input_data_2)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data_1, input_data_2],
        input_shapes=[([-1, 256, 256], [1, 256, 256])])[0]
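    # A loose tolerance: each output element accumulates float error across a
    # 256-element dot product.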
    self.assertAllClose(expected_value, actual_value, atol=4)

  def testSizeInvalid(self):

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, None, 16, 3], dtype=tf.float32)
    ])
    def model(in_tensor):
      return in_tensor + in_tensor

    concrete_func = model.get_concrete_function()

    # Test an invalid shape: None after the 1st dimension. Run with TOCO in
    # order to invoke the shape-checking code.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'None is only supported in the 1st dimension. Tensor '
        '\'in_tensor\' has invalid shape \'[1, None, 16, 3]\'.',
        str(error.exception))


class ResourceAndVariantTypes(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testVariants(self):

    @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
    def model(v):
      m = map_ops.empty_tensor_map()
      k = tf.constant(1.0)
      p = tf.add(k, v)
      with ops.control_dependencies([m]):
        m2 = map_ops.tensor_map_insert(m, p, v)
        with ops.control_dependencies([m2]):
          return map_ops.tensor_map_size(m2)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               model)
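    # TensorMap ops produce variant-typed tensors, which run through the
    # Select TF ops (Flex) path rather than TFLite builtins.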
2441    converter.target_spec.supported_ops = [
2442        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
2443    ]
2444    tflite_model = converter.convert()
2445    self.assertIsNotNone(tflite_model)
2446
2447    # Check values from converted model.
2448    interpreter = Interpreter(model_content=tflite_model)
2449    input_details = interpreter.get_input_details()
2450    output_details = interpreter.get_output_details()
2451
2452    interpreter.allocate_tensors()
2453
2454    input_data = np.array([1.0], dtype=np.float32)
2455    interpreter.set_tensor(input_details[0]['index'], input_data)
2456
2457    interpreter.invoke()
2458    actual_value = interpreter.get_tensor(output_details[0]['index'])
2459    self.assertEqual(1, actual_value)
2460
2461    interpreter.invoke()
2462    actual_value = interpreter.get_tensor(output_details[0]['index'])
2463    self.assertEqual(1, actual_value)
2464
2465    interpreter.invoke()
2466    actual_value = interpreter.get_tensor(output_details[0]['index'])
2467    self.assertEqual(1, actual_value)
2468
2469  @test_util.run_v2_only
2470  def testVariantsWithCond(self):
2471
2472    def create_v1_saved_model():
2473      saved_model_dir = os.path.join(self.get_temp_dir(), 'variants_with_cond')
2474      with tf.Graph().as_default():
2475        with tf.compat.v1.Session() as sess:
          m = map_ops.empty_tensor_map()

          def body(i, m):
            m = map_ops.tensor_map_insert(m, i, i)
            return i + 1, m

          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.int32, name='input')
          _, result_m = tf.cond(in_tensor < 10, lambda: body(in_tensor, m),
                                lambda: body(in_tensor + 1, m))
          out_tensor = in_tensor + map_ops.tensor_map_size(result_m)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([0], dtype=np.int32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    expected_value = np.array([1], dtype=np.int32)
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_value, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_value, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_value, actual_value)

  @test_util.run_v2_only
  def testVariantsWithWhile(self):

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(), 'variants_with_while')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
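          # The while loop inserts keys 0..9 into the TensorMap, so the map
          # size is 10 and the output is input + 10.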
          m = map_ops.empty_tensor_map()

          def cond(i, m):
            del m
            return i < 10

          def body(i, m):
            m = map_ops.tensor_map_insert(m, i, i)
            return i + 1, m

          _, result_m = tf.while_loop(cond, body, [0, m])
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.int32, name='input')
          out_tensor = in_tensor + map_ops.tensor_map_size(result_m)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([0], dtype=np.int32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(10, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(10, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(10, actual_value)

  @test_util.run_v2_only
  def testResources(self):

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_resources')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

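          # Push the input onto a stack resource, then pop it back and add it
          # to 2 * input, so the expected output is 3 * input.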
          stack = tf.raw_ops.StackV2(max_size=10, elem_type=tf.float32)
          w = tf.raw_ops.StackPushV2(handle=stack, elem=in_tensor)
          with ops.control_dependencies([w]):
            a = in_tensor + in_tensor
            with ops.control_dependencies([a]):
              out_tensor = a + tf.raw_ops.StackPopV2(
                  handle=stack, elem_type=tf.float32)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(3.0, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(3.0, actual_value)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(3.0, actual_value)

  @test_util.run_v2_only
  def testResourcesWithCond(self):

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(), 'resources_with_cond')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

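          # Both cond branches push a constant onto the stack; with input
          # 1.0 < 10 the true branch runs and pushes 0.0, which the final pop
          # returns.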
          def body(i, arr):
            n = tf.raw_ops.StackPushV2(
                handle=arr, elem=tf.cast(i, dtype=tf.float32))
            return n, arr

          arr = tf.raw_ops.StackV2(max_size=10, elem_type=tf.float32)
          n, result_arr = tf.cond(in_tensor < 10, lambda: body(0, arr),
                                  lambda: body(1, arr))

          with ops.control_dependencies([result_arr, n]):
            out_tensor = tf.raw_ops.StackPopV2(
                handle=result_arr, elem_type=tf.float32)

          inputs = {'x': in_tensor}
          outputs = {'a': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(0.0, actual_value)

  @test_util.run_v2_only
  def testResourcesWithWhile(self):

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(),
                                     'resources_with_while')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

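          # The loop pushes 0.0 .. 9.0 onto the stack; the final pop returns
          # the last value pushed, 9.0.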
          def cond(i, arr, m):
            del arr
            del m
            return i < 10

          def body(i, arr, m):
            del m
            n = tf.raw_ops.StackPushV2(
                handle=arr, elem=tf.cast(i, dtype=tf.float32))
            return i + 1, arr, n

          arr = tf.raw_ops.StackV2(max_size=10, elem_type=tf.float32)
          _, result_arr, n = tf.while_loop(cond, body, [0, arr, 0.0])

          with ops.control_dependencies([result_arr, n]):
            out_tensor = tf.raw_ops.StackPopV2(
                handle=result_arr, elem_type=tf.float32)

          inputs = {'x': in_tensor}
          outputs = {'a': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(9.0, actual_value)

  @parameterized.named_parameters(('EnableLoweringTensorListOps', True),
                                  ('DisableLoweringTensorListOps', False))
  @test_util.run_v2_only
  def testTensorListWithStaticSize(self, lower_tensor_list_ops):

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(),
                                     'tensor_list_with_static_size')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

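          # Write three values into a fixed-size TensorArray and add the first
          # and last reads: 10.0 + 30.0 = 40.0.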
          ta = tf.TensorArray(
              tf.float32, size=3, dynamic_size=False, clear_after_read=False)
          ta = ta.write(0, 10.0)
          ta = ta.write(1, 20.0)
          ta = ta.write(2, 30.0)

          out_tensor = ta.read(0) + ta.read(2)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    if not lower_tensor_list_ops:
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
      ]
    converter._experimental_lower_tensor_list_ops = lower_tensor_list_ops
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(40.0, actual_value)

  @parameterized.named_parameters(('EnableLoweringTensorListOps', True),
                                  ('DisableLoweringTensorListOps', False))
  @test_util.run_v2_only
  def testTensorListWithDynamicSize(self, lower_tensor_list_ops):

    def create_v1_saved_model():
      saved_model_dir = os.path.join(self.get_temp_dir(),
                                     'tensor_list_with_dynamic_size')
      with tf.Graph().as_default():
        with tf.compat.v1.Session() as sess:
          in_tensor = tf.compat.v1.placeholder(
              shape=[1], dtype=tf.float32, name='input')

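          # Same computation as the static-size test, but the TensorArray
          # grows dynamically, so it cannot be lowered to TFLite builtins.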
          ta = tf.TensorArray(
              tf.float32, size=0, dynamic_size=True, clear_after_read=False)
          ta = ta.write(0, 10.0)
          ta = ta.write(1, 20.0)
          ta = ta.write(2, 30.0)

          out_tensor = ta.read(0) + ta.read(2)

          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}
          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
      return saved_model_dir

    saved_model_dir = create_v1_saved_model()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    if lower_tensor_list_ops:
      with self.assertRaises(convert.ConverterError) as error:
        converter.convert()
      self.assertIn(
          'Lowering tensor list ops is failed. Please consider using Select '
          'TF ops and disabling `_experimental_lower_tensor_list_ops` flag in '
          'the TFLite converter object.', str(error.exception))

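    # Conversion succeeds once Select TF ops are enabled; the dynamic tensor
    # list ops then run through the TF fallback instead of being lowered.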
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    input_data = np.array([1.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(40.0, actual_value)


class CalibrateAndQuantizeWithCustomOpTest(lite_v2_test_util.ModelTest):

  def _createGraphWithCustomOp(self):
    # Create a graph that has one double op.
    np.random.seed(0)

    saved_model_dir = os.path.join(self.get_temp_dir(), 'double_model')
    with ops.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor = tf.compat.v1.placeholder(
            shape=[1, 4], dtype=dtypes.float32, name='input')
        out_tensor = double_op.double(in_tensor)
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

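    # Yield 100 random samples in [-1, 1) as the representative dataset for
    # post-training quantization calibration.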
    def calibration_gen():
      for _ in range(100):
        yield [np.random.uniform(-1, 1, size=(1, 4)).astype(np.float32)]

    return (saved_model_dir, calibration_gen)

  def testCustomOpRegistererByName(self):
    """Tests calibration with a custom op registerer given by name."""
    saved_model_dir, calibration_gen = self._createGraphWithCustomOp()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.allow_custom_ops = True
    converter.target_spec._experimental_custom_op_registerers = [
        'TF_TestRegisterer'
    ]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
    self.assertGreater(test_registerer.get_num_test_registerer_calls(), 0)
    self.assertIn('Double', tflite_test_util.get_ops_list(tflite_model))

    # Check that the model works with custom ops.
    interpreter = InterpreterWithCustomOps(
        model_content=tflite_model, custom_op_registerers=['TF_TestRegisterer'])
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    test_input = np.array([[0.0, 0.1, 0.2, 0.3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    expected_output = np.array([[0.0, 0.2, 0.4, 0.6]], dtype=np.float32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertArrayNear(expected_output[0], output_data[0], err=1e-2)

  def testCustomOpRegistererByFunc(self):
    """Tests calibration with a custom op registerer given as a function."""
    saved_model_dir, calibration_gen = self._createGraphWithCustomOp()

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.allow_custom_ops = True
    converter.target_spec._experimental_custom_op_registerers = [
        test_registerer.TF_TestRegisterer
    ]
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
    self.assertGreater(test_registerer.get_num_test_registerer_calls(), 0)
    self.assertIn('Double', tflite_test_util.get_ops_list(tflite_model))

    # Check that the model works with custom ops.
    interpreter = InterpreterWithCustomOps(
        model_content=tflite_model,
        custom_op_registerers=[test_registerer.TF_TestRegisterer])
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    test_input = np.array([[0.0, 0.1, 0.2, 0.3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    expected_output = np.array([[0.0, 0.2, 0.4, 0.6]], dtype=np.float32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertArrayNear(expected_output[0], output_data[0], err=1e-2)

  def testCustomOpRegistererFailure(self):
    """Tests that calibration fails with an invalid custom op registerer."""
    saved_model_dir, calibration_gen = self._createGraphWithCustomOp()

    bogus_name = 'CompletelyBogusRegistererName'

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.allow_custom_ops = True
    converter.target_spec._experimental_custom_op_registerers = [bogus_name]

    with self.assertRaisesRegex(
        ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'):
      converter.convert()


class IntermediatesTest(lite_v2_test_util.ModelTest):

  def _run(self, experimental_preserve_all_tensors):

    @tf.function
    def f(x):
      y = tf.add(x, x, name='y')
      z = tf.add(y, y, name='z')
      w = tf.add(z, z, name='w')
      return w

    # NOTE: the input and all intermediates of f are exactly representable as
    # floats, so direct equality comparison is safe below.

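    # With input_data = 2.0 below: y = 4.0, z = 8.0, w = 16.0.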
    input_data = np.array(2.0, np.float32)
    concrete_func = f.get_concrete_function(input_data)
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], f)
    tflite_model = converter.convert()
    interpreter = Interpreter(
        model_content=tflite_model,
        experimental_preserve_all_tensors=experimental_preserve_all_tensors)
    interpreter.allocate_tensors()
    interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
                           input_data)
    interpreter.invoke()
    out = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
    tensors = {}
    for t in interpreter.get_tensor_details():
      # When the TensorFlow Lite default delegate is applied to the model
      # graph, accessing the original tensors of a delegated op can raise a
      # ValueError ('Tensor data is null. Run allocate_tensors() first')
      # because memory for those tensors is never allocated.
      val = None
      try:
        val = interpreter.get_tensor(t['index'])
      except ValueError:
        pass
      tensors.update({t['name']: val})
    return (tensors, out)

  def testPreserve(self):
    tensors, result = self._run(experimental_preserve_all_tensors=True)
    # All intermediate tensors should be preserved with their correct values,
    # and the result should be correct.
    self.assertAllClose(tensors['x'], 2.0)
    self.assertAllClose(tensors['y'], 4.0)
    self.assertAllClose(tensors['z'], 8.0)
    self.assertAllClose(result, 16.0)

  def testNoPreserve(self):
    tensors, result = self._run(experimental_preserve_all_tensors=False)
    # Without preservation, at least one intermediate should have been
    # overwritten, but the result should still be correct and the input should
    # remain valid for repeated invocation.
    self.assertAllClose(tensors['x'], 2.0)
    self.assertTrue(tensors['y'] != 4.0 or tensors['z'] != 8.0)
    self.assertAllClose(result, 16.0)


class DatasetOpsTest(lite_v2_test_util.ModelTest):

  @test_util.run_v2_only
  def testReduceDataset(self):

    @tf.function
    def model():
      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
      output = dataset.reduce(np.int32(0), lambda x, y: x + y)
      return output

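    # dataset.reduce sums the elements: 1 + 2 + 3 + 4 = 10.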
    concrete_func = model.get_concrete_function()
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    output_details = interpreter.get_output_details()

    interpreter.allocate_tensors()

    interpreter.invoke()
    actual_value = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(10, actual_value)


if __name__ == '__main__':
  test.main()