• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for lite.py."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import io
23import logging
24import os
25import tempfile
26
27from absl.testing import parameterized
28import numpy as np
29import six
30from six.moves import range
31from tensorflow import keras
32
33from tensorflow.lite.python import lite
34from tensorflow.lite.python import lite_constants
35from tensorflow.lite.python import schema_py_generated as schema_fb
36from tensorflow.lite.python import util
37from tensorflow.lite.python.convert import ConverterError
38from tensorflow.lite.python.convert import mlir_quantize
39from tensorflow.lite.python.interpreter import Interpreter
40from tensorflow.python.client import session
41from tensorflow.python.eager import context
42from tensorflow.python.eager import def_function
43from tensorflow.python.framework import constant_op
44from tensorflow.python.framework import convert_to_constants
45from tensorflow.python.framework import dtypes
46from tensorflow.python.framework import ops
47from tensorflow.python.framework import test_util
48from tensorflow.python.ops import array_ops
49from tensorflow.python.ops import gen_array_ops
50from tensorflow.python.ops import logging_ops
51from tensorflow.python.ops import math_ops
52from tensorflow.python.ops import nn_ops
53from tensorflow.python.ops import random_ops
54from tensorflow.python.ops import variable_scope
55from tensorflow.python.ops import variables
56from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
57from tensorflow.python.platform import gfile
58from tensorflow.python.platform import resource_loader
59from tensorflow.python.platform import test
60from tensorflow.python.saved_model import saved_model
61from tensorflow.python.training.training_util import write_graph
62
63
64class LiteTest(test_util.TensorFlowTestCase):
65  """Base class of all the tests in this module."""
66
67
68class TestModels(LiteTest):
69
70  def assertValidDebugInfo(self, debug_info):
71    """Verify the DebugInfo is valid."""
72    file_names = set()
73    for file_path in debug_info.files:
74      file_names.add(os.path.basename(file_path))
75    # To make the test independent on how the nodes are created, we only assert
76    # the name of this test file.
77    self.assertIn('lite_test.py', file_names)
78    self.assertNotIn('lite_v2_test.py', file_names)
79
80
81class FromConstructor(TestModels):
82
83  # Tests invalid constructors using a dummy value for the GraphDef.
84  def testInvalidConstructor(self):
85    message = (
86        'If input_tensors and output_tensors are None, both '
87        'input_arrays_with_shape and output_arrays|control_output_arrays must '
88        'be defined.')
89
90    # `output_arrays` is not defined.
91    with self.assertRaises(ValueError) as error:
92      lite.TFLiteConverter(
93          None, None, [], input_arrays_with_shape=[('input', [3,
94                                                              9])]).convert()
95    self.assertEqual(message, str(error.exception))
96
97    # `input_arrays_with_shape` is not defined.
98    with self.assertRaises(ValueError) as error:
99      lite.TFLiteConverter(None, [], None, output_arrays=['output']).convert()
100    self.assertEqual(message, str(error.exception))
101
102  # Tests valid constructors using a dummy value for the GraphDef.
103  def testValidConstructor(self):
104    converter = lite.TFLiteConverter(
105        None,
106        None,
107        None,
108        input_arrays_with_shape=[('input', [3, 9])],
109        output_arrays=['output'])
110    self.assertFalse(converter._has_valid_tensors())
111    self.assertEqual(converter.get_input_arrays(), ['input'])
112
113    with self.assertRaises(ValueError) as error:
114      converter._set_batch_size(1)
115    self.assertEqual(
116        'The batch size cannot be set for this model. Please use '
117        'input_shapes parameter.', str(error.exception))
118
119    converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
120    self.assertTrue(converter._has_valid_tensors())
121
122  def testRedundantArgumentsWarning(self):
123    """Test if the warning message when there are redundant arguments."""
124    with ops.Graph().as_default():
125      in_tensor = array_ops.placeholder(
126          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
127      out_tensor = math_ops.add(in_tensor, in_tensor, name='add')
128      sess = session.Session()
129
130    frozen_graph_def = (
131        convert_to_constants.convert_variables_to_constants_from_session_graph(
132            sess, sess.graph_def, ['add']))
133
134    # Convert model and ensure model is not None.
135    log = io.BytesIO() if six.PY2 else io.StringIO()
136    handler = logging.StreamHandler(log)
137    logging.root.addHandler(handler)
138    converter = lite.TFLiteConverter(frozen_graph_def, [in_tensor],
139                                     [out_tensor],
140                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
141
142    input_warning_message = 'input_arrays_with_shape will be ignored'
143    output_warning_message = 'output_arrays will be ignored'
144
145    # Convert model and ensure model is not None.
146    tflite_model = converter.convert()
147    self.assertIsNotNone(tflite_model)
148    self.assertIn(input_warning_message, log.getvalue())
149    self.assertIn(output_warning_message, log.getvalue())
150    logging.root.removeHandler(handler)
151
152  def testShapeOverriding(self):
153    """Test a shape overriding case via the constructor."""
154    with ops.Graph().as_default():
155      in_tensor = array_ops.placeholder(
156          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
157      math_ops.add(in_tensor, in_tensor, name='add')
158      sess = session.Session()
159
160    frozen_graph_def = (
161        convert_to_constants.convert_variables_to_constants_from_session_graph(
162            sess, sess.graph_def, ['add']))
163
164    # Convert model and ensure model is not None.
165    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
166                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
167    tflite_model = converter.convert()
168    self.assertIsNotNone(tflite_model)
169
170    # Check values from converted model.
171    interpreter = Interpreter(model_content=tflite_model)
172    interpreter.allocate_tensors()
173
174    input_details = interpreter.get_input_details()
175    self.assertLen(input_details, 1)
176    self.assertEqual('in_tensor', input_details[0]['name'])
177    self.assertEqual(np.float32, input_details[0]['dtype'])
178    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
179    self.assertEqual((0., 0.), input_details[0]['quantization'])
180
181    output_details = interpreter.get_output_details()
182    self.assertLen(output_details, 1)
183    self.assertEqual('add', output_details[0]['name'])
184    self.assertEqual(np.float32, output_details[0]['dtype'])
185    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
186    self.assertEqual((0., 0.), output_details[0]['quantization'])
187
188  def testPartialShapeOverriding(self):
189    """Test a partial shape overriding case via the constructor."""
190    with ops.Graph().as_default():
191      in_tensor_a = array_ops.placeholder(
192          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_a')
193      in_tensor_b = array_ops.placeholder(
194          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_b')
195      math_ops.add(in_tensor_a, in_tensor_b, name='add')
196      sess = session.Session()
197
198    frozen_graph_def = (
199        convert_to_constants.convert_variables_to_constants_from_session_graph(
200            sess, sess.graph_def, ['add']))
201
202    # Convert model and ensure model is not None.
203    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
204                                     [('in_tensor_a', [2, 16, 16, 3])], ['add'])
205    # There is an unhandled Placeholder op.
206    with self.assertRaises(ConverterError):
207      converter.convert()
208
209  def testInvalidShapeOverriding(self):
210    """Test an invalid shape overriding case via the constructor."""
211    with ops.Graph().as_default():
212      in_tensor = array_ops.placeholder(
213          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
214      math_ops.add(in_tensor, in_tensor, name='add')
215      sess = session.Session()
216
217    frozen_graph_def = (
218        convert_to_constants.convert_variables_to_constants_from_session_graph(
219            sess, sess.graph_def, ['add']))
220
221    # Convert model and ensure model is not None.
222    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
223                                     [('wrong_tensor', [2, 16, 16, 3])],
224                                     ['add'])
225    with self.assertRaises(ConverterError):
226      converter.convert()
227
228
229class FromSessionTest(TestModels, parameterized.TestCase):
230
231  def testFloatModel(self):
232    with ops.Graph().as_default():
233      in_tensor = array_ops.placeholder(
234          shape=[1, 16, 16, 3], dtype=dtypes.float32)
235      out_tensor = in_tensor + in_tensor
236      sess = session.Session()
237
238    # Convert model and ensure model is not None.
239    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
240                                                  [out_tensor])
241    tflite_model = converter.convert()
242    self.assertIsNotNone(tflite_model)
243
244    # Check values from converted model.
245    interpreter = Interpreter(model_content=tflite_model)
246    interpreter.allocate_tensors()
247
248    input_details = interpreter.get_input_details()
249    self.assertLen(input_details, 1)
250    self.assertEqual('Placeholder', input_details[0]['name'])
251    self.assertEqual(np.float32, input_details[0]['dtype'])
252    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
253    self.assertEqual((0., 0.), input_details[0]['quantization'])
254
255    output_details = interpreter.get_output_details()
256    self.assertLen(output_details, 1)
257    self.assertEqual('add', output_details[0]['name'])
258    self.assertEqual(np.float32, output_details[0]['dtype'])
259    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
260    self.assertEqual((0., 0.), output_details[0]['quantization'])
261
262  def testFloatModelQuantizedInput(self):
263    with ops.Graph().as_default():
264      in_tensor = array_ops.placeholder(
265          shape=[1, 16, 16, 3], dtype=dtypes.float32)
266      out_tensor = in_tensor + in_tensor
267      sess = session.Session()
268
269    # Convert model and ensure model is not None.
270    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
271                                                  [out_tensor])
272    converter.inference_input_type = dtypes.uint8
273    converter.inference_type = dtypes.float32
274    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
275    tflite_model = converter.convert()
276    self.assertIsNotNone(tflite_model)
277
278    # Check values from converted model.
279    interpreter = Interpreter(model_content=tflite_model)
280    interpreter.allocate_tensors()
281
282    input_details = interpreter.get_input_details()
283    self.assertLen(input_details, 1)
284    self.assertEqual('Placeholder', input_details[0]['name'])
285    self.assertEqual(np.uint8, input_details[0]['dtype'])
286    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
287    self.assertEqual((1., 0.), input_details[0]['quantization'])
288
289    output_details = interpreter.get_output_details()
290    self.assertLen(output_details, 1)
291    self.assertEqual('add', output_details[0]['name'])
292    self.assertEqual(np.float32, output_details[0]['dtype'])
293    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
294    self.assertEqual((0., 0.), output_details[0]['quantization'])  # float
295
296  def testForgottenCallToAllocateTensors(self):
297    with ops.Graph().as_default():
298      in_tensor = array_ops.placeholder(
299          shape=[1, 16, 16, 3], dtype=dtypes.float32)
300      out_tensor = in_tensor + in_tensor
301      sess = session.Session()
302    # Convert model and ensure model is not None.
303    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
304                                                  [out_tensor])
305    tflite_model = converter.convert()
306    self.assertIsNotNone(tflite_model)
307
308    # Check values from converted model.
309    interpreter = Interpreter(model_content=tflite_model)
310    input_index = interpreter.get_input_details()[0]['index']
311    dummy_tensor = np.ones(shape=[1, 16, 16, 3], dtype=np.float32)
312    with self.assertRaises(ValueError):
313      interpreter.set_tensor(input_index, dummy_tensor)
314
315  @parameterized.named_parameters(
316      ('_INT8InputOutput', False, False, dtypes.int8),
317      ('_UINT8InputOutput', False, False, dtypes.uint8),
318      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
319      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
320      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
321      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
322      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
323      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
324  def testIntegerQuantizationWithUnsupportedOps(self,
325                                                is_int_only,
326                                                is_int16_quantize,
327                                                inference_input_output_type,
328                                                enable_mlir_quantizer=False):
329    with ops.Graph().as_default():
330      in_tensor_a = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
331      in_tensor_b = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
332      # ceil kernel does not support int8 nor int16 types neither.
333      left = math_ops.ceil(in_tensor_a)
334      out_tensor_b = math_ops.tanh(in_tensor_b)
335      add = math_ops.add(left, out_tensor_b)
336      # ceil kernel does not support int8 nor int16 types neither.
337      out_tensor_a = math_ops.ceil(add)
338      sess = session.Session()
339
340    def calibration_gen():
341      for _ in range(5):
342        yield [
343            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
344            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
345        ]
346
347    quantized_converter = lite.TFLiteConverter.from_session(
348        sess, [in_tensor_a, in_tensor_b], [out_tensor_a, out_tensor_b])
349    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
350    quantized_converter.representative_dataset = calibration_gen
351    if is_int_only:
352      if is_int16_quantize:
353        quantized_converter.target_spec.supported_ops = [
354            lite.OpsSet
355            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
356            lite.OpsSet.TFLITE_BUILTINS
357        ]
358      else:
359        quantized_converter.target_spec.supported_ops = [
360            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
361        ]
362    else:
363      if is_int16_quantize:
364        quantized_converter.target_spec.supported_ops = [
365            lite.OpsSet
366            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
367            lite.OpsSet.TFLITE_BUILTINS
368        ]
369      else:
370        quantized_converter.target_spec.supported_ops = [
371            lite.OpsSet.TFLITE_BUILTINS
372        ]
373
374    quantized_converter.inference_input_type = inference_input_output_type
375    quantized_converter.inference_output_type = inference_input_output_type
376    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
377    quantized_tflite_model = quantized_converter.convert()
378    self.assertIsNotNone(quantized_tflite_model)
379
380    expected_dtype = inference_input_output_type.as_numpy_dtype
381    # Allow float32 for fallback on non-quantizable op.
382    expected_ceil_dtype = (
383        expected_dtype if enable_mlir_quantizer else dtypes.float32)
384
385    interpreter = Interpreter(model_content=quantized_tflite_model)
386    interpreter.allocate_tensors()
387    input_details = interpreter.get_input_details()
388    self.assertLen(input_details, 2)
389    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
390    self.assertEqual(input_details[1]['dtype'], expected_dtype)
391    output_details = interpreter.get_output_details()
392    self.assertLen(output_details, 2)
393    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
394    self.assertEqual(output_details[1]['dtype'], expected_dtype)
395
396  @parameterized.named_parameters(
397      ('_PerChannelQuant', False, False), ('_PerChannelMlirQuant', False, True),
398      ('_PerTensorQuant', True, False), ('_PerTensorMlirQuant', True, True))
399  def testDisablePerChannelQuantization(self,
400                                        disable_per_channel=False,
401                                        enable_mlir_quantizer=False):
402    k_conv_name = 'Conv2D1'
403    k_num_filters = 16
404    with ops.Graph().as_default():
405      inp, output, calibration_gen = self._getIntegerQuantizeModel()
406      sess = session.Session()
407
408    quantized_converter = lite.TFLiteConverter.from_session(
409        sess, [inp], [output])
410    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
411    quantized_converter.representative_dataset = calibration_gen
412    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
413    if disable_per_channel:
414      quantized_converter._experimental_disable_per_channel = (
415          disable_per_channel)
416    quantized_tflite_model = quantized_converter.convert()
417    self.assertIsNotNone(quantized_tflite_model)
418
419    interpreter = Interpreter(model_content=quantized_tflite_model)
420    interpreter.allocate_tensors()
421    detail = next((d for d in interpreter.get_tensor_details()
422                   if d['name'] == k_conv_name))
423    quant_params = detail['quantization_parameters']
424    expected_num_params = 1 if disable_per_channel else k_num_filters
425    self.assertLen(quant_params['scales'], expected_num_params)
426    self.assertLen(quant_params['zero_points'], expected_num_params)
427
428  @parameterized.named_parameters(
429      ('EnableMlirConverter', True),  # enable mlir
430      ('DisableMlirConverter', False))  # disable mlir
431  def testString(self, enable_mlir_converter):
432    with ops.Graph().as_default():
433      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
434      out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
435      sess = session.Session()
436
437    # Convert model and ensure model is not None.
438    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
439                                                  [out_tensor])
440    converter.experimental_new_converter = enable_mlir_converter
441    tflite_model = converter.convert()
442    self.assertIsNotNone(tflite_model)
443
444    # Check values from converted model.
445    interpreter = Interpreter(model_content=tflite_model)
446    interpreter.allocate_tensors()
447
448    input_details = interpreter.get_input_details()
449    self.assertLen(input_details, 1)
450    self.assertEqual('Placeholder', input_details[0]['name'])
451    self.assertEqual(np.string_, input_details[0]['dtype'])
452    self.assertAllEqual([4], input_details[0]['shape'])
453
454    output_details = interpreter.get_output_details()
455    self.assertLen(output_details, 1)
456    self.assertEqual('Reshape', output_details[0]['name'])
457    self.assertEqual(np.string_, output_details[0]['dtype'])
458    self.assertAllEqual([2, 2], output_details[0]['shape'])
459    # TODO(b/122659643): Test setting/getting string data via the python
460    # interpreter API after support has been added.
461
462  def testIntermediateInputArray(self):
463    """Convert a model from an intermediate input array."""
464    with ops.Graph().as_default():
465      in_tensor_init = array_ops.placeholder(
466          shape=[1, 16, 16, 3], dtype=dtypes.float32)
467      in_tensor_final = in_tensor_init + in_tensor_init
468      out_tensor = in_tensor_final + in_tensor_final
469      sess = session.Session()
470
471    # Convert model and ensure model is not None.
472    converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final],
473                                                  [out_tensor])
474    tflite_model = converter.convert()
475    self.assertIsNotNone(tflite_model)
476
477    # Check values from converted model.
478    interpreter = Interpreter(model_content=tflite_model)
479    interpreter.allocate_tensors()
480
481    input_details = interpreter.get_input_details()
482    self.assertLen(input_details, 1)
483    self.assertEqual('add', input_details[0]['name'])
484    self.assertEqual(np.float32, input_details[0]['dtype'])
485    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
486    self.assertEqual((0., 0.), input_details[0]['quantization'])
487
488    output_details = interpreter.get_output_details()
489    self.assertLen(output_details, 1)
490    self.assertEqual('add_1', output_details[0]['name'])
491    self.assertEqual(np.float32, output_details[0]['dtype'])
492    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
493    self.assertEqual((0., 0.), output_details[0]['quantization'])
494
495  def testSizeNoneInvalid(self):
496    with ops.Graph().as_default():
497      in_tensor = array_ops.placeholder(dtype=dtypes.float32)
498      out_tensor = in_tensor + in_tensor
499      sess = session.Session()
500
501    # Test None as shape when dynamic shapes are disabled. Run with TOCO in
502    # order to invoke shape checking code.
503    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
504                                                  [out_tensor])
505    converter.experimental_new_converter = False
506    with self.assertRaises(ValueError) as error:
507      converter.convert()
508    self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
509                     str(error.exception))
510
511  @parameterized.named_parameters(
512      ('EnableMlirConverter', True),  # enable mlir
513      ('DisableMlirConverter', False))  # disable mlir
514  def testScalarValid(self, enable_mlir_converter):
515    # Construct a graph using a scalar (empty shape) input.
516    with ops.Graph().as_default():
517      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
518      out_tensor = in_tensor + in_tensor
519      sess = session.Session()
520
521    # Test conversion with the scalar input shape.
522    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
523                                                  [out_tensor])
524    converter.experimental_new_converter = enable_mlir_converter
525    tflite_model = converter.convert()
526    self.assertIsNotNone(tflite_model)
527
528    # Check values from converted model.
529    interpreter = Interpreter(model_content=tflite_model)
530    interpreter.allocate_tensors()
531
532    input_details = interpreter.get_input_details()
533    self.assertLen(input_details, 1)
534    self.assertEqual('Placeholder', input_details[0]['name'])
535    self.assertEqual(np.float32, input_details[0]['dtype'])
536    self.assertEmpty(input_details[0]['shape'])
537
538    output_details = interpreter.get_output_details()
539    self.assertLen(output_details, 1)
540    self.assertEqual('add', output_details[0]['name'])
541    self.assertEqual(np.float32, output_details[0]['dtype'])
542    self.assertEmpty(input_details[0]['shape'])
543
544    # Validate inference using the scalar inputs/outputs.
545    test_input = np.array(4.0, dtype=np.float32)
546    expected_output = np.array(8.0, dtype=np.float32)
547    interpreter.set_tensor(input_details[0]['index'], test_input)
548    interpreter.invoke()
549
550    output_data = interpreter.get_tensor(output_details[0]['index'])
551    self.assertEqual(expected_output, output_data)
552
553  def testSizeInvalid(self):
554    with ops.Graph().as_default():
555      in_tensor = array_ops.placeholder(
556          shape=[1, None, 16, 3], dtype=dtypes.float32)
557      out_tensor = in_tensor + in_tensor
558      sess = session.Session()
559
560    # Test invalid shape. None after 1st dimension. Run with TOCO in order to
561    # invoke shape checking code.
562    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
563                                                  [out_tensor])
564    converter.experimental_new_converter = False
565    with self.assertRaises(ValueError) as error:
566      converter.convert()
567    self.assertEqual(
568        'None is only supported in the 1st dimension. Tensor '
569        '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
570        str(error.exception))
571
572  def testSizeNone(self):
573    with ops.Graph().as_default():
574      in_tensor = array_ops.placeholder(
575          shape=[1, None, 16, 3], dtype=dtypes.float32)
576      out_tensor = in_tensor + in_tensor
577      sess = session.Session()
578
579    # Test None after 1st dimension.
580    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
581                                                  [out_tensor])
582    tflite_model = converter.convert()
583
584    # Check values from converted model.
585    interpreter = Interpreter(model_content=tflite_model)
586    input_details = interpreter.get_input_details()
587    self.assertLen(input_details, 1)
588    self.assertEqual('Placeholder', input_details[0]['name'])
589    self.assertEqual(np.float32, input_details[0]['dtype'])
590    self.assertAllEqual([1, 1, 16, 3], input_details[0]['shape'])
591    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])
592    self.assertEqual((0., 0.), input_details[0]['quantization'])
593
594    # Resize tensor with strict checking.
595    with self.assertRaises(RuntimeError) as error:
596      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
597    self.assertIn(
598        'ResizeInputTensorStrict only allows mutating unknown dimensions '
599        'identified by -1.', str(error.exception))
600
601    # Resize tensor and invoke.
602    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
603    interpreter.allocate_tensors()
604    interpreter.invoke()
605
606    input_details = interpreter.get_input_details()
607    self.assertLen(input_details, 1)
608    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
609    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])
610
611    output_details = interpreter.get_output_details()
612    self.assertAllEqual([1, -1, 16, 3], output_details[0]['shape_signature'])
613
614  def testResizeTensorInputStrict(self):
615    # Ensures that resize_tensor_input(strict=True) works as expected.
616    with ops.Graph().as_default():
617      in_tensor = array_ops.placeholder(
618          shape=[1, 16, 16, 3], dtype=dtypes.float32)
619      out_tensor = in_tensor + in_tensor
620      sess = session.Session()
621
622    # Convert model and ensure model is not None.
623    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
624                                                  [out_tensor])
625    tflite_model = converter.convert()
626    self.assertIsNotNone(tflite_model)
627
628    # Check values from converted model.
629    interpreter = Interpreter(model_content=tflite_model)
630
631    # Resize incorrect value.
632    with self.assertRaises(RuntimeError) as error:
633      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
634    self.assertIn(
635        'ResizeInputTensorStrict only allows mutating unknown dimensions '
636        'identified by -1.', str(error.exception))
637
638    # Resize correct value.
639    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
640    interpreter.allocate_tensors()
641
642  def testBatchSizeValid(self):
643    with ops.Graph().as_default():
644      in_tensor = array_ops.placeholder(
645          shape=[None, 16, 16, 3], dtype=dtypes.float32)
646      out_tensor = in_tensor + in_tensor
647      sess = session.Session()
648
649    # Convert model and ensure model is not None.
650    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
651                                                  [out_tensor])
652    tflite_model = converter.convert()
653    self.assertIsNotNone(tflite_model)
654
655    # Check values from converted model.
656    interpreter = Interpreter(model_content=tflite_model)
657    interpreter.allocate_tensors()
658
659    input_details = interpreter.get_input_details()
660    self.assertLen(input_details, 1)
661    self.assertEqual('Placeholder', input_details[0]['name'])
662    self.assertEqual(np.float32, input_details[0]['dtype'])
663    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
664    self.assertEqual((0., 0.), input_details[0]['quantization'])
665
666    output_details = interpreter.get_output_details()
667    self.assertLen(output_details, 1)
668    self.assertEqual('add', output_details[0]['name'])
669    self.assertEqual(np.float32, output_details[0]['dtype'])
670    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
671    self.assertEqual((0., 0.), output_details[0]['quantization'])
672
673  def testBatchSizeNonZero(self):
674    with ops.Graph().as_default():
675      in_tensor_1 = array_ops.placeholder(
676          shape=[None, 4], dtype=dtypes.float32, name='input1')
677      in_tensor_2 = array_ops.placeholder(
678          shape=[4, 10], dtype=dtypes.float32, name='input2')
679      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2)
680      sess = session.Session()
681
682    # Convert model and ensure model is not None.
683    converter = lite.TFLiteConverter.from_session(sess,
684                                                  [in_tensor_1, in_tensor_2],
685                                                  [out_tensor])
686    tflite_model = converter.convert()
687    self.assertIsNotNone(tflite_model)
688
689    # Check values from converted model.
690    interpreter = Interpreter(model_content=tflite_model)
691    interpreter.allocate_tensors()
692
693    input_details = interpreter.get_input_details()
694    self.assertLen(input_details, 2)
695    self.assertEqual('input1', input_details[0]['name'])
696    self.assertAllEqual([1, 4], input_details[0]['shape'])
697    self.assertEqual('input2', input_details[1]['name'])
698    self.assertAllEqual([4, 10], input_details[1]['shape'])
699
700  def testFreezeGraph(self):
701    with ops.Graph().as_default():
702      in_tensor = array_ops.placeholder(
703          shape=[1, 16, 16, 3], dtype=dtypes.float32)
704      var = variable_scope.get_variable(
705          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
706      # Get the second output to ensure freezing properly processes tensor names
707      # like 'X:1'.
708      out_tensor = nn_ops.top_k(in_tensor + var, name='top_k')[1]
709      sess = session.Session()
710      sess.run(_global_variables_initializer())
711
712    # Convert model and ensure model is not None.
713    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
714                                                  [out_tensor])
715    tflite_model = converter.convert()
716    self.assertIsNotNone(tflite_model)
717
718    # Check values from converted model.
719    interpreter = Interpreter(model_content=tflite_model)
720    interpreter.allocate_tensors()
721
722    input_details = interpreter.get_input_details()
723    self.assertLen(input_details, 1)
724    self.assertEqual('Placeholder', input_details[0]['name'])
725    self.assertEqual(np.float32, input_details[0]['dtype'])
726    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
727    self.assertEqual((0., 0.), input_details[0]['quantization'])
728
729    output_details = interpreter.get_output_details()
730    self.assertLen(output_details, 1)
731    self.assertEqual('top_k:1', output_details[0]['name'])
732    self.assertEqual(np.int32, output_details[0]['dtype'])
733    self.assertAllEqual([1, 16, 16, 1], output_details[0]['shape'])
734    self.assertEqual((0., 0.), output_details[0]['quantization'])
735
736  def testGraphviz(self):
737    with ops.Graph().as_default():
738      in_tensor = array_ops.placeholder(
739          shape=[1, 16, 16, 3], dtype=dtypes.float32)
740      out_tensor = in_tensor + in_tensor
741      sess = session.Session()
742
743    # Convert model and ensure model is not None.
744    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
745                                                  [out_tensor])
746    converter.output_format = lite_constants.GRAPHVIZ_DOT
747    graphviz_output = converter.convert()
748    self.assertIsNotNone(graphviz_output)
749
750  @parameterized.named_parameters(
751      ('EnableMlirConverter', True),  # enable mlir
752      ('DisableMlirConverter', False))  # disable mlir
753  def testDumpGraphviz(self, enable_mlir_converter):
754    with ops.Graph().as_default():
755      in_tensor = array_ops.placeholder(
756          shape=[1, 16, 16, 3], dtype=dtypes.float32)
757      out_tensor = in_tensor + in_tensor
758      sess = session.Session()
759
760    # Convert model and ensure model is not None.
761    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
762                                                  [out_tensor])
763    converter.experimental_new_converter = enable_mlir_converter
764    graphviz_dir = self.get_temp_dir()
765    converter.dump_graphviz_dir = graphviz_dir
766    tflite_model = converter.convert()
767    self.assertIsNotNone(tflite_model)
768
769    # Ensure interpreter is able to allocate and check graphviz data.
770    interpreter = Interpreter(model_content=tflite_model)
771    interpreter.allocate_tensors()
772
773    num_items_graphviz = len(os.listdir(graphviz_dir))
774    self.assertIsNotNone(num_items_graphviz)
775    self.assertIsNotNone(
776        os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
777    self.assertIsNotNone(
778        os.path.exists(
779            os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))
780
781    # new converter doesn't support `dump_graphviz_video` flag
782    if not enable_mlir_converter:
783      # Convert model and ensure model is not None.
784      converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
785                                                    [out_tensor])
786      converter.experimental_new_converter = enable_mlir_converter
787      graphviz_dir = self.get_temp_dir()
788      converter.dump_graphviz_dir = graphviz_dir
789      converter.dump_graphviz_video = True
790      tflite_model = converter.convert()
791      self.assertIsNotNone(tflite_model)
792
793      # Ensure graphviz folder has more data after using video flag.
794      num_items_graphviz_video = len(os.listdir(graphviz_dir))
795      self.assertGreater(num_items_graphviz_video, num_items_graphviz)
796
797  def testDumpConversionSummary(self):
798    with ops.Graph().as_default():
799      in_tensor = array_ops.placeholder(
800          shape=[1, 16, 16, 3], dtype=dtypes.float32)
801      out_tensor = in_tensor + in_tensor
802      sess = session.Session()
803
804    # Convert model and ensure model is not None.
805    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
806                                                  [out_tensor])
807    log_dir = self.get_temp_dir()
808    converter.conversion_summary_dir = log_dir
809    tflite_model = converter.convert()
810    self.assertIsNotNone(tflite_model)
811
812    self.assertNotEmpty(os.listdir(log_dir))
813
814  def testDumpConversionSummaryWithOldConverter(self):
815    with ops.Graph().as_default():
816      in_tensor = array_ops.placeholder(
817          shape=[1, 16, 16, 3], dtype=dtypes.float32)
818      out_tensor = in_tensor + in_tensor
819      sess = session.Session()
820
821    # Convert model and ensure model is not None.
822    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
823                                                  [out_tensor])
824    converter.experimental_new_converter = False
825    log_dir = self.get_temp_dir()
826    converter.conversion_summary_dir = log_dir
827    tflite_model = converter.convert()
828    self.assertIsNotNone(tflite_model)
829    # Check nothing is generated under the conversion summary path.
830    num_items_conversion_summary = len(os.listdir(log_dir))
831    self.assertEqual(num_items_conversion_summary, 0)
832
833  @parameterized.named_parameters(
834      ('EnableMlirConverter', True),  # enable mlir
835      ('DisableMlirConverter', False))  # disable mlir
836  def testQuantizeDynamicRange(self, enable_mlir_converter):
837    np.random.seed(0)
838    with ops.Graph().as_default():
839      # We need the tensor to have more than 1024 elements for quantize_weights
840      # to kick in. Thus, the [33, 33] shape.
841      in_tensor_1 = array_ops.placeholder(
842          shape=[33, 33], dtype=dtypes.float32, name='inputA')
843      in_tensor_2 = constant_op.constant(
844          np.random.uniform(low=-10., high=10., size=(33, 33)),
845          shape=[33, 33],
846          dtype=dtypes.float32,
847          name='inputB')
848      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
849      sess = session.Session()
850
851    # Convert float model.
852    float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
853                                                        [out_tensor])
854    float_converter.experimental_new_converter = enable_mlir_converter
855    float_tflite_model = float_converter.convert()
856    self.assertIsNotNone(float_tflite_model)
857
858    # Convert quantized weights model.
859    quantized_converter = lite.TFLiteConverter.from_session(
860        sess, [in_tensor_1], [out_tensor])
861
862    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
863    quantized_converter.experimental_new_converter = enable_mlir_converter
864    quantized_tflite_model = quantized_converter.convert()
865    self.assertIsNotNone(quantized_tflite_model)
866
867    # Ensure that the quantized weights tflite model is smaller.
868    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
869
870  @parameterized.named_parameters(
871      ('EnableMlirConverter', True),  # enable mlir
872      ('DisableMlirConverter', False))  # disable mlir
873  def testQuantizeDynamicRangeDeprecatedPostTrainingQuantizeAttribute(
874      self, enable_mlir_converter):
875    with ops.Graph().as_default():
876      in_tensor_1 = array_ops.placeholder(
877          shape=[33, 33], dtype=dtypes.float32, name='inputA')
878      in_tensor_2 = constant_op.constant(
879          np.random.uniform(low=-10., high=10., size=(33, 33)),
880          shape=[33, 33],
881          dtype=dtypes.float32,
882          name='inputB')
883      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
884      sess = session.Session()
885
886    quantized_converter = lite.TFLiteConverter.from_session(
887        sess, [in_tensor_1], [out_tensor])
888    self.assertFalse(quantized_converter.post_training_quantize)
889    quantized_converter.experimental_new_converter = enable_mlir_converter
890
891    quantized_converter.post_training_quantize = True
892    self.assertTrue(quantized_converter.post_training_quantize)
893    self.assertEqual(quantized_converter.optimizations, [lite.Optimize.DEFAULT])
894
895    quantized_tflite_model = quantized_converter.convert()
896    self.assertIsNotNone(quantized_tflite_model)
897
898  def _getIntegerQuantizeModel(self):
899    np.random.seed(0)
900    inp = array_ops.placeholder(
901        dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input')
902    conv = nn_ops.conv2d(
903        inp,
904        filter=array_ops.ones([3, 3, 3, 16]),
905        strides=[1, 1, 1, 1],
906        padding='SAME')
907    output = nn_ops.relu(conv, name='output')
908
909    def calibration_gen():
910      for _ in range(5):
911        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
912
913    return (inp, output, calibration_gen)
914
915  @parameterized.named_parameters(
916      ('EnableMlirConverter', True),  # enable mlir
917      ('DisableMlirConverter', False))  # disable mlir
918  def testQuantizeInt8AllowFloat(self, enable_mlir_converter):
919    with ops.Graph().as_default():
920      inp, output, calibration_gen = self._getIntegerQuantizeModel()
921      sess = session.Session()
922
923    # Convert float model.
924    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
925    float_tflite_model = float_converter.convert()
926    self.assertIsNotNone(float_tflite_model)
927
928    # Convert quantized model.
929    quantized_converter = lite.TFLiteConverter.from_session(
930        sess, [inp], [output])
931    quantized_converter.experimental_new_converter = enable_mlir_converter
932    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
933    quantized_converter.representative_dataset = calibration_gen
934    quantized_tflite_model = quantized_converter.convert()
935    self.assertIsNotNone(quantized_tflite_model)
936
937    # The default input and output types should be float.
938    interpreter = Interpreter(model_content=quantized_tflite_model)
939    interpreter.allocate_tensors()
940    input_details = interpreter.get_input_details()
941    self.assertLen(input_details, 1)
942    self.assertEqual(np.float32, input_details[0]['dtype'])
943    output_details = interpreter.get_output_details()
944    self.assertLen(output_details, 1)
945    self.assertEqual(np.float32, output_details[0]['dtype'])
946
947    # Ensure that the quantized weights tflite model is smaller.
948    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
949
950  @parameterized.named_parameters(
951      # Quantize model to Int8: with enable mlir
952      ('UseTfliteBuiltinsIntEnableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8
953                                         ], True),
954      # Quantize model to Int8: with disable mlir
955      ('UseTfliteBuiltinsIntDisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8
956                                          ], False),
957      # Quantize model to Int16: with disable mlir
958      ('UseTfliteBuiltinsInt16DisableMLIR', [
959          lite.OpsSet
960          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
961      ], False),
962      ('UseTfliteBuiltinsInt16EnableMLIR', [
963          lite.OpsSet
964          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
965      ], True))
966  def testQuantizeInt8And16x8(self, supported_ops, enable_mlir_converter):
967    with ops.Graph().as_default():
968      inp, output, calibration_gen = self._getIntegerQuantizeModel()
969      sess = session.Session()
970
971    # Convert float model.
972    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
973    float_converter.experimental_new_converter = enable_mlir_converter
974    float_tflite_model = float_converter.convert()
975    self.assertIsNotNone(float_tflite_model)
976
977    # Convert model by specifying target spec (instead of optimizations), since
978    # when targeting an integer only backend, quantization is mandatory.
979    quantized_converter = lite.TFLiteConverter.from_session(
980        sess, [inp], [output])
981    quantized_converter.experimental_new_converter = enable_mlir_converter
982    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
983    quantized_converter.target_spec.supported_ops = supported_ops
984    quantized_converter.representative_dataset = calibration_gen
985    quantized_tflite_model = quantized_converter.convert()
986    self.assertIsNotNone(quantized_tflite_model)
987
988    # The default input and output types should be float.
989    interpreter = Interpreter(model_content=quantized_tflite_model)
990    interpreter.allocate_tensors()
991    input_details = interpreter.get_input_details()
992    self.assertLen(input_details, 1)
993    self.assertEqual(np.float32, input_details[0]['dtype'])
994    output_details = interpreter.get_output_details()
995    self.assertLen(output_details, 1)
996    self.assertEqual(np.float32, output_details[0]['dtype'])
997
998    # Ensure that the quantized weights tflite model is smaller.
999    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
1000
1001  @parameterized.named_parameters(
1002      ('EnableMlirConverter', True),  # enable mlir
1003      ('DisableMlirConverter', False))  # disable mlir
1004  def testQuantizeInt8InputOutput(self, enable_mlir_converter):
1005    with ops.Graph().as_default():
1006      inp, output, calibration_gen = self._getIntegerQuantizeModel()
1007      sess = session.Session()
1008
1009    # Convert float model.
1010    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1011    float_converter.experimental_new_converter = enable_mlir_converter
1012    float_tflite_model = float_converter.convert()
1013    self.assertIsNotNone(float_tflite_model)
1014
1015    # Convert quantized weights model.
1016    quantized_converter = lite.TFLiteConverter.from_session(
1017        sess, [inp], [output])
1018    quantized_converter.experimental_new_converter = enable_mlir_converter
1019    quantized_converter.inference_input_type = dtypes.int8
1020    quantized_converter.inference_output_type = dtypes.int8
1021    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1022    quantized_converter.representative_dataset = calibration_gen
1023    quantized_tflite_model = quantized_converter.convert()
1024    self.assertIsNotNone(quantized_tflite_model)
1025
1026    # The input and output types should be int8.
1027    interpreter = Interpreter(model_content=quantized_tflite_model)
1028    interpreter.allocate_tensors()
1029    input_details = interpreter.get_input_details()
1030    self.assertLen(input_details, 1)
1031    self.assertEqual(np.int8, input_details[0]['dtype'])
1032    output_details = interpreter.get_output_details()
1033    self.assertLen(output_details, 1)
1034    self.assertEqual(np.int8, output_details[0]['dtype'])
1035
1036    # Ensure that the quantized weights tflite model is smaller.
1037    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
1038
1039  @parameterized.named_parameters(
1040      ('EnableMlirConverter', True),  # enable mlir
1041      ('DisableMlirConverter', False))  # disable mlir
1042  def testInvalidQuantizeInt8(self, enable_mlir_converter):
1043    np.random.seed(0)
1044    with ops.Graph().as_default():
1045      # We need the tensor to have more than 1024 elements for quantize_weights
1046      # to kick in. Thus, the [33, 33] shape.
1047      in_tensor_1 = array_ops.placeholder(
1048          shape=[33, 33], dtype=dtypes.float32, name='inputA')
1049      in_tensor_2 = constant_op.constant(
1050          np.random.uniform(low=-10., high=10., size=(33, 33)),
1051          shape=[33, 33],
1052          dtype=dtypes.float32,
1053          name='inputB')
1054      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
1055      sess = session.Session()
1056
1057    # Attempt to convert to quantized weights model.
1058    quantized_converter = lite.TFLiteConverter.from_session(
1059        sess, [in_tensor_1], [out_tensor])
1060    quantized_converter.experimental_new_converter = enable_mlir_converter
1061    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1062    # Restricting to int8 type only
1063    quantized_converter.target_spec.supported_types = [dtypes.int8]
1064    # A representative dataset is required for full fixed point quantization.
1065    with self.assertRaises(ValueError) as error:
1066      quantized_converter.convert()
1067    self.assertEqual(
1068        'representative_dataset is required when specifying '
1069        'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
1070
1071  @parameterized.named_parameters(
1072      ('EnableMlirConverter', True),  # enable mlir
1073      ('DisableMlirConverter', False))  # disable mlir
1074  def testQuantizeUInt8(self, enable_mlir_converter):
1075    with ops.Graph().as_default():
1076      in_tensor_1 = array_ops.placeholder(
1077          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
1078      in_tensor_2 = array_ops.placeholder(
1079          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
1080      out_tensor = array_ops.fake_quant_with_min_max_args(
1081          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
1082      sess = session.Session()
1083
1084    # Convert model and ensure model is not None.
1085    converter = lite.TFLiteConverter.from_session(sess,
1086                                                  [in_tensor_1, in_tensor_2],
1087                                                  [out_tensor])
1088    converter.inference_type = dtypes.uint8
1089    converter.quantized_input_stats = {
1090        'inputA': (0., 1.),
1091        'inputB': (0., 1.)
1092    }  # mean, std_dev
1093    converter.experimental_new_converter = enable_mlir_converter
1094    tflite_model = converter.convert()
1095    self.assertIsNotNone(tflite_model)
1096
1097    # Check values from converted model.
1098    interpreter = Interpreter(model_content=tflite_model)
1099    interpreter.allocate_tensors()
1100
1101    input_details = interpreter.get_input_details()
1102    self.assertLen(input_details, 2)
1103    self.assertEqual('inputA', input_details[0]['name'])
1104    self.assertEqual(np.uint8, input_details[0]['dtype'])
1105    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1106    self.assertEqual((1., 0.), input_details[0]['quantization'])
1107
1108    self.assertEqual('inputB', input_details[1]['name'])
1109    self.assertEqual(np.uint8, input_details[1]['dtype'])
1110    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
1111    self.assertEqual((1., 0.), input_details[1]['quantization'])
1112
1113    output_details = interpreter.get_output_details()
1114    self.assertLen(output_details, 1)
1115    self.assertEqual(np.uint8, output_details[0]['dtype'])
1116    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1117    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale
1118
1119  def testQuantizeUInt8UsingDefaultRangeStats(self):
1120    with ops.Graph().as_default():
1121      in_tensor = array_ops.placeholder(
1122          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1123      out_tensor = in_tensor + in_tensor
1124      sess = session.Session()
1125
1126    # Convert model and ensure model is not None.
1127    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1128                                                  [out_tensor])
1129    converter.inference_type = dtypes.uint8
1130    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
1131    converter.default_ranges_stats = (0, 6)  # min, max
1132    tflite_model = converter.convert()
1133    self.assertIsNotNone(tflite_model)
1134
1135    # Check values from converted model.
1136    interpreter = Interpreter(model_content=tflite_model)
1137    interpreter.allocate_tensors()
1138
1139    input_details = interpreter.get_input_details()
1140    self.assertLen(input_details, 1)
1141    self.assertEqual('Placeholder', input_details[0]['name'])
1142    self.assertEqual(np.uint8, input_details[0]['dtype'])
1143    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1144    self.assertEqual((1., 0.), input_details[0]['quantization'])
1145
1146    output_details = interpreter.get_output_details()
1147    self.assertLen(output_details, 1)
1148    self.assertEqual('add', output_details[0]['name'])
1149    self.assertEqual(np.uint8, output_details[0]['dtype'])
1150    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1151    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale
1152
1153  @parameterized.named_parameters(
1154      # Quantize to Float16 even if rep data provided.
1155      ('UseRepresentativeData', True, False, True, False, False, False, False,
1156       False),
1157      # Quantize to Float16 if no rep data provided.
1158      ('NoRepresentativeData', False, False, True, False, False, False, False,
1159       False),
1160      # Quantize to Float16 and set Float16Accumulation
1161      ('SpecifyFloat16Accumulation', False, False, True, True, False, False,
1162       False, False),
1163      # Post training quantization if both rep data and int8 included.
1164      ('UseSampleDataIncludeInt8', True, True, False, False, False, True, False,
1165       False),
1166      # Quantize to Float16 even if rep data provided with mlir.
1167      ('UseRepresentativeDataMlir', True, False, True, False, False, False,
1168       True, False),
1169      # Quantize to Float16 if no rep data provided with mlir.
1170      ('NoRepresentativeDataMlir', False, False, True, False, False, False,
1171       True, False),
1172      # Post training quantization if both rep data and int8 included with mlir.
1173      ('SampleDataIncludeInt8Mlir', True, True, False, False, False, True, True,
1174       False),
1175      # Same as above, but using MLIR quantizer
1176      ('SampleDataIncludeInt8MlirQuant', True, True, False, False, False, True,
1177       True, True))
1178  def testQuantizeFloat16(self, use_rep_data, include_int8,
1179                          is_float16_quantized, is_float16_accumulation,
1180                          is_error, is_post_training_quantized,
1181                          enable_mlir_converter, enable_mlir_quantizer):
1182    with ops.Graph().as_default():
1183      inp, output, calibration_gen = self._getIntegerQuantizeModel()
1184      sess = session.Session()
1185
1186    bias_idx = 1 if enable_mlir_converter else 0
1187    bias_name = 'Conv2D' if enable_mlir_converter else 'Conv2D_bias'
1188
1189    # Convert float model.
1190    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1191    float_converter.experimental_new_converter = enable_mlir_converter
1192    float_tflite_model = float_converter.convert()
1193    self.assertIsNotNone(float_tflite_model)
1194    interpreter = Interpreter(model_content=float_tflite_model)
1195    interpreter.allocate_tensors()
1196    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
1197                     bias_name)
1198    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
1199                     dtypes.float32)
1200
1201    # MLIR quantizer has different bias index.
1202    if enable_mlir_quantizer:
1203      bias_idx = 2
1204
1205    # Convert model to quantized version
1206    quantized_converter = lite.TFLiteConverter.from_session(
1207        sess, [inp], [output])
1208    quantized_converter.experimental_new_converter = enable_mlir_converter
1209    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
1210    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1211    quantized_converter.target_spec.supported_types = [dtypes.float16]
1212    if include_int8:
1213      quantized_converter.target_spec.supported_types.append(dtypes.int8)
1214    if use_rep_data:
1215      quantized_converter.representative_dataset = calibration_gen
1216    if is_float16_accumulation:
1217      quantized_converter.target_spec.experimental_supported_accumulation_type = dtypes.float16  # pylint: disable=line-too-long
1218
1219    if is_error:
1220      with self.assertRaises(ValueError) as error:
1221        quantized_converter.convert()
1222      self.assertEqual(
1223          'representative_dataset is required when specifying '
1224          'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
1225
1226    else:
1227      quantized_tflite_model = quantized_converter.convert()
1228      self.assertIsNotNone(quantized_tflite_model)
1229      interpreter = Interpreter(model_content=quantized_tflite_model)
1230      interpreter.allocate_tensors()
1231      self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
1232                       bias_name)
1233
1234      if is_float16_quantized:
1235        # Verify that bias constant is float16 type.
1236        self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
1237                         dtypes.float16)
1238      elif is_post_training_quantized:
1239        # Verify that bias constants is int32 type.
1240        self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
1241                         dtypes.int32)
1242      else:
1243        raise ValueError('Invalid test options.')
1244
1245  @parameterized.named_parameters(
1246      ('EnableMlirConverter', True),  # enable mlir
1247      ('DisableMlirConverter', False))  # disable mlir
1248  def testInvalidQuantizeFloat16(self, enable_mlir_converter):
1249    with ops.Graph().as_default():
1250      inp, output, _ = self._getIntegerQuantizeModel()
1251      sess = session.Session()
1252
1253    # Specify float16 quantization
1254    quantized_converter = lite.TFLiteConverter.from_session(
1255        sess, [inp], [output])
1256    quantized_converter.experimental_new_converter = enable_mlir_converter
1257    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1258    quantized_converter.target_spec.supported_types = [dtypes.float16]
1259    # Specify only int8 builtin ops
1260    quantized_converter.target_spec.supported_ops = [
1261        lite.OpsSet.TFLITE_BUILTINS_INT8
1262    ]
1263    with self.assertRaises(ValueError) as error:
1264      quantized_converter.convert()
1265    self.assertEqual(
1266        'TFLITE_BUILTINS_INT8 requires smallest supported type to be INT8.',
1267        str(error.exception))
1268
1269  @parameterized.named_parameters(('InferenceType_INT8', dtypes.int8),
1270                                  ('InferenceType_UINT8', dtypes.uint8))
1271  def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type):
1272    with ops.Graph().as_default():
1273      in_tensor = array_ops.placeholder(
1274          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1275      out_tensor = array_ops.fake_quant_with_min_max_args(
1276          in_tensor + in_tensor, min=0., max=1.)
1277      sess = session.Session()
1278
1279    quantized_converter = lite.TFLiteConverter.from_session(
1280        sess, [in_tensor], [out_tensor])
1281
1282    with self.assertRaises(ValueError) as error:
1283      quantized_converter.inference_type = quantized_type
1284      quantized_converter.convert()
1285    self.assertEqual(
1286        'The `quantized_input_stats` flag must be defined when either '
1287        '`inference_type` flag or `inference_input_type` flag is set to '
1288        'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
1289        '`inference_input_type=None`.'.format(quantized_type.name),
1290        str(error.exception))
1291
1292    with self.assertRaises(ValueError) as error:
1293      quantized_converter.inference_type = dtypes.float32
1294      quantized_converter.inference_input_type = quantized_type
1295      quantized_converter.convert()
1296    self.assertEqual(
1297        'The `quantized_input_stats` flag must be defined when either '
1298        '`inference_type` flag or `inference_input_type` flag is set to '
1299        'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
1300        '`inference_input_type=tf.{}`.'.format(quantized_type.name),
1301        str(error.exception))
1302
1303    quantized_converter.inference_type = quantized_type
1304    quantized_converter.inference_input_type = quantized_type
1305
1306    input_arrays = quantized_converter.get_input_arrays()
1307    quantized_converter.quantized_input_stats = {input_arrays[0]: (0., 1.)}
1308    quantized_converter.convert()
1309
1310  def testInvalidQuantizeQATModelMissingInputStats(self):
1311    with ops.Graph().as_default():
1312      in_tensor_1 = array_ops.placeholder(
1313          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
1314      in_tensor_2 = array_ops.placeholder(
1315          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
1316      out_tensor = array_ops.fake_quant_with_min_max_args(
1317          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
1318      sess = session.Session()
1319
1320    # Convert model and ensure model is not None.
1321    converter = lite.TFLiteConverter.from_session(sess,
1322                                                  [in_tensor_1, in_tensor_2],
1323                                                  [out_tensor])
1324    converter.inference_type = dtypes.uint8
1325    converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
1326    with self.assertRaises(ValueError) as error:
1327      converter.convert()
1328    self.assertEqual(
1329        'Quantization input stats are not available for input tensors '
1330        '\'inputB\'.', str(error.exception))
1331
1332  def testTrainingTimeAndPostTrainingCalibrateAndQuantize(self):
1333    with ops.Graph().as_default():
1334      inp, output, calibration_gen = self._getIntegerQuantizeModel()
1335      sess = session.Session()
1336
1337    # Convert float model.
1338    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1339    float_tflite_model = float_converter.convert()
1340    self.assertIsNotNone(float_tflite_model)
1341
1342    converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1343
1344    # extra flags to trigger training time quantization conversion
1345    converter.inference_type = dtypes.int8
1346    converter.inference_input_type = dtypes.float32
1347    converter.inference_output_type = dtypes.float32
1348    input_arrays = converter.get_input_arrays()
1349    converter.quantized_input_stats = {input_arrays[0]: (0., 1.)}
1350    # trigger post-training quantization
1351    converter.optimizations = [lite.Optimize.DEFAULT]
1352    converter.representative_dataset = calibration_gen
1353    converter.experimental_new_quantizer = True
1354    quantized_tflite_model = converter.convert()
1355    self.assertIsNotNone(quantized_tflite_model)
1356    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
1357
1358    # calibration only api
1359    converter._experimental_calibrate_only = True
1360    calibrated_tflite = converter.convert()
1361    quantized_tflite_model = mlir_quantize(
1362        calibrated_tflite, fully_quantize=True)
1363    interpreter = Interpreter(model_content=quantized_tflite_model)
1364    interpreter.allocate_tensors()
1365    input_details = interpreter.get_input_details()
1366    self.assertEqual(np.int8, input_details[0]['dtype'])
1367    self.assertEqual((1., 0.), input_details[0]['quantization'])
1368
1369    output_details = interpreter.get_output_details()
1370    self.assertEqual(np.int8, output_details[0]['dtype'])
1371
1372  def testFloatTocoConverter(self):
1373    """Tests deprecated test TocoConverter."""
1374    with ops.Graph().as_default():
1375      in_tensor = array_ops.placeholder(
1376          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1377      out_tensor = in_tensor + in_tensor
1378      sess = session.Session()
1379
1380    # Convert model and ensure model is not None.
1381    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
1382    tflite_model = converter.convert()
1383    self.assertIsNotNone(tflite_model)
1384
1385    # Ensure the interpreter is able to load.
1386    interpreter = Interpreter(model_content=tflite_model)
1387    interpreter.allocate_tensors()
1388
1389  def testMultipleOutputNodeNames(self):
1390    """Tests converting a graph with an op that have multiple outputs."""
1391    with ops.Graph().as_default():
1392      input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
1393      out0, out1, out2, out3 = array_ops.split(
1394          input_tensor, [1, 1, 1, 1], axis=0)
1395      sess = session.Session()
1396
1397    # Convert model and ensure model is not None.
1398    converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
1399                                                  [out0, out1, out2, out3])
1400    tflite_model = converter.convert()
1401    self.assertIsNotNone(tflite_model)
1402
1403    # Check values from converted model.
1404    interpreter = Interpreter(model_content=tflite_model)
1405    interpreter.allocate_tensors()
1406
1407    input_details = interpreter.get_input_details()
1408    self.assertLen(input_details, 1)
1409    interpreter.set_tensor(input_details[0]['index'],
1410                           np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
1411    interpreter.invoke()
1412
1413    output_details = interpreter.get_output_details()
1414    self.assertLen(output_details, 4)
1415    self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
1416    self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
1417    self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
1418    self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))
1419
1420  @parameterized.named_parameters(
1421      ('EnableMlirConverter', True),  # enable mlir
1422      ('DisableMlirConverter', False))  # disable mlir
1423  @test_util.run_in_graph_and_eager_modes
1424  def testFunctions(self, enable_mlir_converter):
1425    """Tests tf.function in 1.X."""
1426
1427    @def_function.function
1428    def plus_placeholder(x, placeholder):
1429      return x + placeholder
1430
1431    with ops.Graph().as_default():
1432      placeholder = array_ops.placeholder(
1433          dtype=dtypes.float32, shape=[1], name='input')
1434      variable_node = variables.Variable(1.0, name='variable_node')
1435      defun_node = plus_placeholder(variable_node, placeholder)
1436      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
1437
1438      # Initialize variables in the model.
1439      sess = session.Session()
1440      sess.run(variables.variables_initializer([variable_node]))
1441
1442    # Convert model and ensure model is not None.
1443    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
1444                                                  [output_node])
1445    converter.experimental_new_converter = enable_mlir_converter
1446    tflite_model = converter.convert()
1447    self.assertIsNotNone(tflite_model)
1448
1449    # Check values from converted model.
1450    interpreter = Interpreter(model_content=tflite_model)
1451    interpreter.allocate_tensors()
1452
1453    input_details = interpreter.get_input_details()
1454    self.assertLen(input_details, 1)
1455    self.assertEqual('input', input_details[0]['name'])
1456    self.assertEqual(np.float32, input_details[0]['dtype'])
1457    self.assertAllEqual([1], input_details[0]['shape'])
1458    self.assertEqual((0., 0.), input_details[0]['quantization'])
1459
1460    output_details = interpreter.get_output_details()
1461    self.assertLen(output_details, 1)
1462    self.assertEqual('output_node', output_details[0]['name'])
1463    self.assertEqual(np.float32, output_details[0]['dtype'])
1464    self.assertAllEqual([1], output_details[0]['shape'])
1465    self.assertEqual((0., 0.), output_details[0]['quantization'])
1466
1467  def testInferenceInputOutputTypeFloatDefault(self):
1468    with ops.Graph().as_default():
1469      in_tensor = array_ops.placeholder(
1470          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1471      out_tensor = in_tensor + in_tensor
1472      sess = session.Session()
1473
1474    # Convert model and ensure model is not None.
1475    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1476                                                  [out_tensor])
1477    tflite_model = converter.convert()
1478    self.assertIsNotNone(tflite_model)
1479
1480    # Check values from converted model.
1481    interpreter = Interpreter(model_content=tflite_model)
1482    interpreter.allocate_tensors()
1483
1484    input_details = interpreter.get_input_details()
1485    self.assertLen(input_details, 1)
1486    self.assertEqual('Placeholder', input_details[0]['name'])
1487    self.assertEqual(np.float32, input_details[0]['dtype'])
1488    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1489
1490    output_details = interpreter.get_output_details()
1491    self.assertLen(output_details, 1)
1492    self.assertEqual('add', output_details[0]['name'])
1493    self.assertEqual(np.float32, output_details[0]['dtype'])
1494    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1495
1496  def testInferenceInputOutputTypeQuantizedUint8Default(self):
1497    with ops.Graph().as_default():
1498      in_tensor = array_ops.placeholder(
1499          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1500      out_tensor = array_ops.fake_quant_with_min_max_args(
1501          in_tensor + in_tensor, min=0., max=1., name='output')
1502      sess = session.Session()
1503
1504    # Convert model and ensure model is not None.
1505    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1506                                                  [out_tensor])
1507    converter.inference_type = dtypes.uint8
1508    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
1509    tflite_model = converter.convert()
1510    self.assertIsNotNone(tflite_model)
1511
1512    # Check values from converted model.
1513    interpreter = Interpreter(model_content=tflite_model)
1514    interpreter.allocate_tensors()
1515
1516    input_details = interpreter.get_input_details()
1517    self.assertLen(input_details, 1)
1518    self.assertEqual('Placeholder', input_details[0]['name'])
1519    self.assertEqual(np.uint8, input_details[0]['dtype'])
1520    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1521
1522    output_details = interpreter.get_output_details()
1523    self.assertLen(output_details, 1)
1524    self.assertEqual('output', output_details[0]['name'])
1525    self.assertEqual(np.uint8, output_details[0]['dtype'])
1526    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1527
1528  def testReusingConverterWithDifferentPostTrainingQuantization(self):
1529    with ops.Graph().as_default():
1530      in_tensor = array_ops.placeholder(
1531          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1532      out_tensor = array_ops.fake_quant_with_min_max_args(
1533          in_tensor + in_tensor, min=0., max=1., name='output')
1534      sess = session.Session()
1535
1536    # Convert model and ensure model is not None.
1537    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1538                                                  [out_tensor])
1539
1540    converter.post_training_quantize = True
1541    tflite_model = converter.convert()
1542    self.assertIsNotNone(tflite_model)
1543
1544    converter.post_training_quantize = False
1545    tflite_model = converter.convert()
1546    self.assertIsNotNone(tflite_model)
1547
1548  def testResizeWithShape(self):
1549    with ops.Graph().as_default():
1550      # Construct a graph with a dynamically shapped input and an internal node
1551      # that relies on the output of that input's shape.
1552      in_tensor = array_ops.placeholder(
1553          shape=[None, None], dtype=dtypes.float32)
1554      in_tensor2 = [[1, 2], [3, 4]]
1555      out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor))
1556      sess = session.Session()
1557
1558    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1559                                                  [out_tensor])
1560    tflite_model = converter.convert()
1561
1562    # Check values from converted model.
1563    interpreter = Interpreter(model_content=tflite_model)
1564    input_details = interpreter.get_input_details()
1565    self.assertLen(input_details, 1)
1566    self.assertAllEqual([1, 1], input_details[0]['shape'])
1567    self.assertAllEqual([-1, -1], input_details[0]['shape_signature'])
1568
1569    # Resize tensor and invoke.
1570    interpreter.resize_tensor_input(0, [4])
1571    interpreter.allocate_tensors()
1572    interpreter.invoke()
1573
1574    # The output should be reshaped properly according to the resized input.
1575    output_details = interpreter.get_output_details()
1576    self.assertLen(output_details, 1)
1577    self.assertEqual(np.int32, output_details[0]['dtype'])
1578    self.assertAllEqual([4], output_details[0]['shape'])
1579    output_data = interpreter.get_tensor(output_details[0]['index'])
1580    self.assertAllEqual([1, 2, 3, 4], output_data)
1581
1582  def testResizingIntermediateDynamicTensor(self):
1583    # This is a regression test for the case where shape of dynamic output
1584    # tensors changes between invocations.
1585    # See also https://github.com/tensorflow/tensorflow/issues/26549
1586    with ops.Graph().as_default():
1587      input_tensor = array_ops.placeholder(shape=[1, 1], dtype=dtypes.float32)
1588      input2_tensor = array_ops.placeholder(shape=[1], dtype=dtypes.float32)
1589
1590      # The bug is triggered only when dynamic tensor is intermediate. Putting
1591      # some other ops around it.
1592      neg = math_ops.negative(input2_tensor)
1593      padding = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int32)
1594      output_tensor = array_ops.pad(input_tensor, padding) + neg
1595
1596      sess = session.Session()
1597
1598    converter = lite.TFLiteConverter.from_session(
1599        sess, [input_tensor, padding, input2_tensor], [output_tensor])
1600    tflite_model = converter.convert()
1601
1602    interpreter = Interpreter(model_content=tflite_model)
1603    interpreter.allocate_tensors()
1604
1605    input_details = interpreter.get_input_details()
1606    interpreter.set_tensor(input_details[1]['index'],
1607                           np.array([[1, 1], [1, 1]], dtype=np.int32))
1608    interpreter.invoke()
1609
1610    # Without the fix, invocation will fail when changing the shape of
1611    # intermediate dynamic tensors.
1612    interpreter.set_tensor(input_details[1]['index'],
1613                           np.array([[2, 2], [2, 2]], dtype=np.int32))
1614    interpreter.invoke()
1615
1616  def testGraphDebugInfo(self):
1617    """Test a session has debug info captured."""
1618
1619    @def_function.function
1620    def plus_placeholder(x, placeholder):
1621      return x + placeholder
1622
1623    with ops.Graph().as_default():
1624      placeholder = array_ops.placeholder(
1625          dtype=dtypes.float32, shape=[1], name='input')
1626      variable_node = variables.Variable(1.0, name='variable_node')
1627      defun_node = plus_placeholder(variable_node, placeholder)
1628      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
1629
1630      # Initialize variables in the model.
1631      sess = session.Session()
1632      sess.run(variables.variables_initializer([variable_node]))
1633
1634    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
1635                                                  [output_node])
1636    converter.convert()
1637    self.assertValidDebugInfo(converter._debug_info)
1638
1639    # Check the add node in the inlined function is included.
1640    func = sess.graph.as_graph_def().library.function[0].signature.name
1641    self.assertIn(('add@' + six.ensure_str(func)), converter._debug_info.traces)
1642
1643  def testOutputOnlyModel(self):
1644    with ops.Graph().as_default():
1645      out_tensor = random_ops.random_normal(shape=[3])
1646      sess = session.Session()
1647
1648    # Convert model and ensure model is not None.
1649    converter = lite.TFLiteConverter.from_session(sess, [], [out_tensor])
1650    converter.target_spec.supported_ops = [
1651        lite.OpsSet.TFLITE_BUILTINS,
1652        lite.OpsSet.SELECT_TF_OPS,
1653    ]
1654
1655    # Empty input array is a valid input.
1656    self.assertTrue(converter._has_valid_tensors())
1657
1658    tflite_model = converter.convert()
1659    self.assertIsNotNone(tflite_model)
1660
1661
1662class FromFrozenGraphFile(LiteTest):
1663
1664  def testFloat(self):
1665    with ops.Graph().as_default():
1666      in_tensor = array_ops.placeholder(
1667          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1668      _ = in_tensor + in_tensor
1669      sess = session.Session()
1670
1671    # Write graph to file.
1672    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1673    write_graph(sess.graph_def, '', graph_def_file, False)
1674    sess.close()
1675
1676    # Convert model and ensure model is not None.
1677    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
1678                                                       ['Placeholder'], ['add'])
1679    tflite_model = converter.convert()
1680    self.assertIsNotNone(tflite_model)
1681
1682    # Check values from converted model.
1683    interpreter = Interpreter(model_content=tflite_model)
1684    interpreter.allocate_tensors()
1685
1686    input_details = interpreter.get_input_details()
1687    self.assertLen(input_details, 1)
1688    self.assertEqual('Placeholder', input_details[0]['name'])
1689    self.assertEqual(np.float32, input_details[0]['dtype'])
1690    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1691    self.assertEqual((0., 0.), input_details[0]['quantization'])
1692
1693    output_details = interpreter.get_output_details()
1694    self.assertLen(output_details, 1)
1695    self.assertEqual('add', output_details[0]['name'])
1696    self.assertEqual(np.float32, output_details[0]['dtype'])
1697    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1698    self.assertEqual((0., 0.), output_details[0]['quantization'])
1699
1700  def testFloatWithShapesArray(self):
1701    """Test a shape overriding case."""
1702    with ops.Graph().as_default():
1703      in_tensor = array_ops.placeholder(
1704          shape=[None, 16, 16, 3], dtype=dtypes.float32)
1705      _ = in_tensor + in_tensor
1706      sess = session.Session()
1707
1708    # Write graph to file.
1709    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1710    write_graph(sess.graph_def, '', graph_def_file, False)
1711    sess.close()
1712
1713    # Convert model and ensure model is not None.
1714    converter = lite.TFLiteConverter.from_frozen_graph(
1715        graph_def_file, ['Placeholder'], ['add'],
1716        input_shapes={'Placeholder': [2, 16, 16, 3]})
1717    tflite_model = converter.convert()
1718    self.assertIsNotNone(tflite_model)
1719
1720    # Check values from converted model.
1721    interpreter = Interpreter(model_content=tflite_model)
1722    interpreter.allocate_tensors()
1723
1724    input_details = interpreter.get_input_details()
1725    self.assertLen(input_details, 1)
1726    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
1727
1728  def testInvalidShapesArray(self):
1729    """Test an invalid shape overriding case, which has a wrong input name."""
1730    with ops.Graph().as_default():
1731      in_tensor = array_ops.placeholder(
1732          shape=[None, 16, 16, 3], dtype=dtypes.float32)
1733      _ = in_tensor + in_tensor
1734      sess = session.Session()
1735
1736    # Write graph to file.
1737    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1738    write_graph(sess.graph_def, '', graph_def_file, False)
1739    sess.close()
1740
1741    # Convert model and ensure model is not None.
1742    with self.assertRaises(ValueError):
1743      lite.TFLiteConverter.from_frozen_graph(
1744          graph_def_file, ['Placeholder'], ['add'],
1745          input_shapes={'wrong_input': [2, 16, 16, 3]})
1746
1747  def testPartialShapesArray(self):
1748    """Test a shape overriding case, with the only one input among two."""
1749    with ops.Graph().as_default():
1750      a = array_ops.placeholder(
1751          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='a')
1752      b = array_ops.placeholder(
1753          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='b')
1754      _ = math_ops.add(a, b, name='add')
1755      sess = session.Session()
1756
1757    # Write graph to file.
1758    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1759    write_graph(sess.graph_def, '', graph_def_file, False)
1760    sess.close()
1761
1762    # Convert model and ensure model is not None.
1763    converter = lite.TFLiteConverter.from_frozen_graph(
1764        graph_def_file, ['a', 'b'], ['add'], input_shapes={'a': [2, 16, 16, 3]})
1765    tflite_model = converter.convert()
1766    self.assertIsNotNone(tflite_model)
1767
1768    # Check values from converted model.
1769    interpreter = Interpreter(model_content=tflite_model)
1770    interpreter.allocate_tensors()
1771
1772    input_details = interpreter.get_input_details()
1773    self.assertLen(input_details, 2)
1774    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
1775    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
1776
1777  def testFreezeGraph(self):
1778    with ops.Graph().as_default():
1779      in_tensor = array_ops.placeholder(
1780          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1781      var = variable_scope.get_variable(
1782          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
1783      _ = in_tensor + var
1784      sess = session.Session()
1785
1786    # Write graph to file.
1787    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1788    write_graph(sess.graph_def, '', graph_def_file, False)
1789    sess.close()
1790
1791    # Ensure the graph with variables cannot be converted.
1792    with self.assertRaises(ValueError) as error:
1793      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
1794                                             ['add'])
1795    self.assertEqual('Please freeze the graph using freeze_graph.py.',
1796                     str(error.exception))
1797
1798  def testPbtxt(self):
1799    with ops.Graph().as_default():
1800      in_tensor = array_ops.placeholder(
1801          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1802      _ = in_tensor + in_tensor
1803      sess = session.Session()
1804
1805    # Write graph to file.
1806    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
1807    write_graph(sess.graph_def, '', graph_def_file, True)
1808    sess.close()
1809
1810    # Convert model and ensure model is not None.
1811    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
1812                                                       ['Placeholder'], ['add'])
1813    tflite_model = converter.convert()
1814    self.assertIsNotNone(tflite_model)
1815
1816    # Check values from converted model.
1817    interpreter = Interpreter(model_content=tflite_model)
1818    interpreter.allocate_tensors()
1819
1820    input_details = interpreter.get_input_details()
1821    self.assertLen(input_details, 1)
1822    self.assertEqual('Placeholder', input_details[0]['name'])
1823    self.assertEqual(np.float32, input_details[0]['dtype'])
1824    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1825    self.assertEqual((0., 0.), input_details[0]['quantization'])
1826
1827    output_details = interpreter.get_output_details()
1828    self.assertLen(output_details, 1)
1829    self.assertEqual('add', output_details[0]['name'])
1830    self.assertEqual(np.float32, output_details[0]['dtype'])
1831    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1832    self.assertEqual((0., 0.), output_details[0]['quantization'])
1833
1834  def testInvalidFileNotFound(self):
1835    with self.assertRaises(IOError) as error:
1836      lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
1837                                             ['add'])
1838    self.assertEqual('File \'invalid_file\' does not exist.',
1839                     str(error.exception))
1840
1841  def testInvalidFileBadData(self):
1842    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
1843    with gfile.Open(graph_def_file, 'wb') as temp_file:
1844      temp_file.write('bad data')
1845      temp_file.flush()
1846
1847    # Attempts to convert the invalid model.
1848    with self.assertRaises(IOError) as error:
1849      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
1850                                             ['add'])
1851    self.assertEqual(
1852        'Unable to parse input file \'{}\'.'.format(graph_def_file),
1853        str(error.exception))
1854
1855  def testFloatTocoConverter(self):
1856    with ops.Graph().as_default():
1857      in_tensor = array_ops.placeholder(
1858          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1859      _ = in_tensor + in_tensor
1860      sess = session.Session()
1861
1862    # Write graph to file.
1863    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1864    write_graph(sess.graph_def, '', graph_def_file, False)
1865    sess.close()
1866
1867    # Convert model and ensure model is not None.
1868    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
1869                                                     ['Placeholder'], ['add'])
1870    tflite_model = converter.convert()
1871    self.assertIsNotNone(tflite_model)
1872
1873    # Ensure the model is able to load.
1874    interpreter = Interpreter(model_content=tflite_model)
1875    interpreter.allocate_tensors()
1876
1877  def testGraphDebugInfo(self):
1878    """Test a frozen graph doesn't have debug info captured."""
1879    with ops.Graph().as_default():
1880      in_tensor = array_ops.placeholder(
1881          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1882      _ = in_tensor + in_tensor
1883      sess = session.Session()
1884
1885    # Write graph to file.
1886    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1887    write_graph(sess.graph_def, '', graph_def_file, False)
1888    sess.close()
1889
1890    # Convert model and ensure model is not None.
1891    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
1892                                                     ['Placeholder'], ['add'])
1893    converter.convert()
1894    # GraphDebugInfo should be none for frozen graph.
1895    self.assertFalse(converter._debug_info)
1896
1897
1898class FromFrozenGraphObjectDetection(LiteTest):
1899
1900  def _initObjectDetectionArgs(self):
1901    # Initializes the arguments required for the object detection model.
1902    # Looks for the model file which is saved in a different location internally
1903    # and externally.
1904    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
1905    if not os.path.exists(filename):
1906      filename = os.path.join(
1907          resource_loader.get_root_dir_with_all_resources(),
1908          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
1909      if not os.path.exists(filename):
1910        raise IOError("File '{0}' does not exist.".format(filename))
1911
1912    self._graph_def_file = filename
1913    self._input_arrays = ['normalized_input_image_tensor']
1914    self._output_arrays = [
1915        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
1916        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
1917    ]
1918    self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
1919
1920  def testTFLiteGraphDef(self):
1921    # Tests the object detection model that cannot be loaded in TensorFlow.
1922    self._initObjectDetectionArgs()
1923
1924    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
1925                                                       self._input_arrays,
1926                                                       self._output_arrays,
1927                                                       self._input_shapes)
1928    converter.allow_custom_ops = True
1929    tflite_model = converter.convert()
1930    self.assertIsNotNone(tflite_model)
1931
1932    # Check values from converted model.
1933    interpreter = Interpreter(model_content=tflite_model)
1934    interpreter.allocate_tensors()
1935
1936    input_details = interpreter.get_input_details()
1937    self.assertLen(input_details, 1)
1938    self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
1939    self.assertEqual(np.float32, input_details[0]['dtype'])
1940    self.assertAllEqual([1, 300, 300, 3], input_details[0]['shape'])
1941    self.assertEqual((0., 0.), input_details[0]['quantization'])
1942
1943    output_details = interpreter.get_output_details()
1944    self.assertLen(output_details, 4)
1945    self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
1946    self.assertEqual(np.float32, output_details[0]['dtype'])
1947    self.assertAllEqual([1, 10, 4], output_details[0]['shape'])
1948    self.assertEqual((0., 0.), output_details[0]['quantization'])
1949
1950    self.assertEqual('TFLite_Detection_PostProcess:1',
1951                     output_details[1]['name'])
1952    self.assertAllEqual([1, 10], output_details[1]['shape'])
1953    self.assertEqual('TFLite_Detection_PostProcess:2',
1954                     output_details[2]['name'])
1955    self.assertAllEqual([1, 10], output_details[2]['shape'])
1956    self.assertEqual('TFLite_Detection_PostProcess:3',
1957                     output_details[3]['name'])
1958    self.assertAllEqual([1], output_details[3]['shape'])
1959
1960  def testTFLiteGraphDefWithControlOutput(self):
1961    with ops.Graph().as_default():
1962      in_tensor = array_ops.placeholder(
1963          shape=[5, 5], dtype=dtypes.float32, name='input')
1964      out_tensor = in_tensor + in_tensor
1965      logging_ops.print_v2(out_tensor)
1966      sess = session.Session()
1967
1968    converter = lite.TFLiteConverter(
1969        sess.graph_def,
1970        input_tensors=None,
1971        output_tensors=None,
1972        input_arrays_with_shape=[('input', [5, 5])],
1973        output_arrays=None,
1974        experimental_debug_info_func=None)
1975    converter._control_output_arrays = ['PrintV2']
1976    converter.target_spec.supported_ops = [
1977        lite.OpsSet.TFLITE_BUILTINS,
1978        lite.OpsSet.SELECT_TF_OPS,
1979    ]
1980    tflite_model = converter.convert()
1981    self.assertIsNotNone(tflite_model)
1982
1983    model = util._convert_model_from_bytearray_to_object(tflite_model)
1984    self.assertEqual(model.operatorCodes[0].builtinCode,
1985                     schema_fb.BuiltinOperator.ADD)
1986    self.assertEqual(model.operatorCodes[1].builtinCode,
1987                     schema_fb.BuiltinOperator.CUSTOM)
1988    self.assertEqual(model.operatorCodes[1].customCode, b'FlexStringFormat')
1989    self.assertEqual(model.operatorCodes[2].builtinCode,
1990                     schema_fb.BuiltinOperator.CUSTOM)
1991    self.assertEqual(model.operatorCodes[2].customCode, b'FlexPrintV2')
1992
1993    # Check values from converted model.
1994    interpreter = Interpreter(model_content=tflite_model)
1995    interpreter.allocate_tensors()
1996
1997    input_details = interpreter.get_input_details()
1998    self.assertLen(input_details, 1)
1999    self.assertEqual('input', input_details[0]['name'])
2000    self.assertEqual(np.float32, input_details[0]['dtype'])
2001    self.assertAllEqual([5, 5], input_details[0]['shape'])
2002    self.assertEqual((0., 0.), input_details[0]['quantization'])
2003
2004    output_details = interpreter.get_output_details()
2005    self.assertLen(output_details, 0)
2006
2007  def testModifyIOToUint8(self):
2008    # Tests the object detection model that cannot be loaded in TensorFlow.
2009    self._initObjectDetectionArgs()
2010
2011    def representative_dataset_gen():
2012      for _ in range(2):
2013        yield [
2014            np.random.uniform(low=0, high=1,
2015                              size=(1, 300, 300, 3)).astype(np.float32)
2016        ]
2017
2018    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
2019                                                       self._input_arrays,
2020                                                       self._output_arrays,
2021                                                       self._input_shapes)
2022    converter.representative_dataset = representative_dataset_gen
2023    converter.target_spec.supported_ops = {lite.OpsSet.TFLITE_BUILTINS_INT8}
2024    converter.inference_type = dtypes.int8
2025    converter.inference_input_type = dtypes.uint8
2026    converter.inference_output_type = dtypes.uint8
2027    converter.experimental_new_quantizer = True
2028    converter.quantized_input_stats = {
2029        'normalized_input_image_tensor': (0., 1.)
2030    }  # mean, std_dev
2031    converter.allow_custom_ops = True
2032    tflite_model = converter.convert()
2033
2034    self.assertIsNotNone(tflite_model)
2035
2036    model = util._convert_model_from_bytearray_to_object(tflite_model)
2037    quant_opcode_idxs = util.get_quantize_opcode_idx(model)
2038
2039    subgraph = model.subgraphs[0]
2040    tensors = subgraph.tensors
2041    operators = subgraph.operators
2042    for op in operators:
2043      if op.opcodeIndex in quant_opcode_idxs:
2044        input_type = util._convert_tflite_enum_type_to_tf_type(
2045            tensors[op.inputs[0]].type)
2046        if op.outputs[0] in subgraph.outputs:
2047          self.assertEqual(input_type, dtypes.float32)
2048
2049
2050class FromSavedModelTest(TestModels):
2051
2052  def _createSavedModel(self, shape):
2053    """Create a simple SavedModel."""
2054    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
2055    with ops.Graph().as_default():
2056      with session.Session() as sess:
2057        in_tensor_1 = array_ops.placeholder(
2058            shape=shape, dtype=dtypes.float32, name='inputB')
2059        in_tensor_2 = array_ops.placeholder(
2060            shape=shape, dtype=dtypes.float32, name='inputA')
2061        out_tensor = in_tensor_1 + in_tensor_2
2062        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
2063        outputs = {'z': out_tensor}
2064        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
2065    return saved_model_dir
2066
2067  def testSimpleModel(self):
2068    """Test a SavedModel."""
2069    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2070
2071    # Convert model and ensure model is not None.
2072    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2073    tflite_model = converter.convert()
2074    self.assertIsNotNone(tflite_model)
2075
2076    interpreter = Interpreter(model_content=tflite_model)
2077    interpreter.allocate_tensors()
2078
2079    input_details = interpreter.get_input_details()
2080    self.assertLen(input_details, 2)
2081    self.assertStartsWith(input_details[0]['name'], 'inputA')
2082    self.assertEqual(np.float32, input_details[0]['dtype'])
2083    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
2084    self.assertEqual((0., 0.), input_details[0]['quantization'])
2085
2086    self.assertStartsWith(input_details[1]['name'], 'inputB')
2087    self.assertEqual(np.float32, input_details[1]['dtype'])
2088    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
2089    self.assertEqual((0., 0.), input_details[1]['quantization'])
2090
2091    output_details = interpreter.get_output_details()
2092    self.assertLen(output_details, 1)
2093    self.assertStartsWith(output_details[0]['name'], 'add')
2094    self.assertEqual(np.float32, output_details[0]['dtype'])
2095    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
2096    self.assertEqual((0., 0.), output_details[0]['quantization'])
2097
2098  def testOldConverterWarning(self):
2099    """Test if the warning message when using TOCO is logged."""
2100    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2101    log = io.BytesIO() if six.PY2 else io.StringIO()
2102    handler = logging.StreamHandler(log)
2103    logging.root.addHandler(handler)
2104    warning_message = 'Please consider switching to the new converter'
2105    # Convert model and ensure model is not None.
2106    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2107    converter.experimental_new_converter = False
2108    tflite_model = converter.convert()
2109    self.assertIsNotNone(tflite_model)
2110    self.assertIn(warning_message, log.getvalue())
2111    logging.root.removeHandler(handler)
2112
2113  def testNewConverterOptOut(self):
2114    """Test if the opt out message when using New converter is logged."""
2115    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2116    log = io.BytesIO() if six.PY2 else io.StringIO()
2117    handler = logging.StreamHandler(log)
2118    logging.root.addHandler(handler)
2119    optout_message = ('Using experimental converter: '
2120                      'If you encountered a problem')
2121    # Convert model and ensure model is not None.
2122    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2123    tflite_model = converter.convert()
2124    self.assertIsNotNone(tflite_model)
2125    self.assertIn(optout_message, log.getvalue())
2126    logging.root.removeHandler(handler)
2127
2128  def testNoneBatchSize(self):
2129    """Test a SavedModel, with None in input tensor's shape."""
2130    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
2131
2132    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2133    tflite_model = converter.convert()
2134    self.assertIsNotNone(tflite_model)
2135
2136    # Check values from converted model.
2137    interpreter = Interpreter(model_content=tflite_model)
2138    interpreter.allocate_tensors()
2139
2140    input_details = interpreter.get_input_details()
2141    self.assertLen(input_details, 2)
2142    self.assertStartsWith(input_details[0]['name'], 'inputA')
2143    self.assertEqual(np.float32, input_details[0]['dtype'])
2144    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
2145    self.assertEqual((0., 0.), input_details[0]['quantization'])
2146
2147    self.assertStartsWith(input_details[1]['name'], 'inputB')
2148    self.assertEqual(np.float32, input_details[1]['dtype'])
2149    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
2150    self.assertEqual((0., 0.), input_details[1]['quantization'])
2151
2152    output_details = interpreter.get_output_details()
2153    self.assertLen(output_details, 1)
2154    self.assertStartsWith(output_details[0]['name'], 'add')
2155    self.assertEqual(np.float32, output_details[0]['dtype'])
2156    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
2157    self.assertEqual((0., 0.), output_details[0]['quantization'])
2158
2159  def testOrderInputArrays(self):
2160    """Test a SavedModel ordering of input arrays."""
2161    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2162
2163    converter = lite.TFLiteConverter.from_saved_model(
2164        saved_model_dir, input_arrays=['inputB', 'inputA'])
2165    tflite_model = converter.convert()
2166    self.assertIsNotNone(tflite_model)
2167
2168    # Check values from converted model.
2169    interpreter = Interpreter(model_content=tflite_model)
2170    interpreter.allocate_tensors()
2171
2172    input_details = interpreter.get_input_details()
2173    self.assertLen(input_details, 2)
2174    self.assertStartsWith(input_details[0]['name'], 'inputA')
2175    self.assertEqual(np.float32, input_details[0]['dtype'])
2176    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
2177    self.assertEqual((0., 0.), input_details[0]['quantization'])
2178
2179    self.assertStartsWith(input_details[1]['name'], 'inputB')
2180    self.assertEqual(np.float32, input_details[1]['dtype'])
2181    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
2182    self.assertEqual((0., 0.), input_details[1]['quantization'])
2183
2184    output_details = interpreter.get_output_details()
2185    self.assertLen(output_details, 1)
2186    self.assertStartsWith(output_details[0]['name'], 'add')
2187    self.assertEqual(np.float32, output_details[0]['dtype'])
2188    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
2189    self.assertEqual((0., 0.), output_details[0]['quantization'])
2190
2191  def testShapeOverriding(self):
2192    """Test a SavedModel with the input_shapes arugment."""
2193    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
2194
2195    # Convert model and ensure model is not None.
2196    converter = lite.TFLiteConverter.from_saved_model(
2197        saved_model_dir,
2198        input_shapes={
2199            'inputA': [2, 16, 16, 3],
2200            'inputB': [2, 16, 16, 3]
2201        })
2202    tflite_model = converter.convert()
2203    self.assertIsNotNone(tflite_model)
2204
2205    interpreter = Interpreter(model_content=tflite_model)
2206    interpreter.allocate_tensors()
2207
2208    input_details = interpreter.get_input_details()
2209    self.assertLen(input_details, 2)
2210    self.assertStartsWith(input_details[0]['name'], 'inputA')
2211    self.assertEqual(np.float32, input_details[0]['dtype'])
2212    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
2213    self.assertEqual((0., 0.), input_details[0]['quantization'])
2214
2215    self.assertStartsWith(input_details[1]['name'], 'inputB')
2216    self.assertEqual(np.float32, input_details[1]['dtype'])
2217    self.assertAllEqual([2, 16, 16, 3], input_details[1]['shape'])
2218    self.assertEqual((0., 0.), input_details[1]['quantization'])
2219
2220    output_details = interpreter.get_output_details()
2221    self.assertLen(output_details, 1)
2222    self.assertStartsWith(output_details[0]['name'], 'add')
2223    self.assertEqual(np.float32, output_details[0]['dtype'])
2224    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
2225    self.assertEqual((0., 0.), output_details[0]['quantization'])
2226
2227  def testWrongInputShapes(self):
2228    """Test a SavedModel with a wrong name in the input_shapes argument."""
2229    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2230
2231    # Check case where input shape is given.
2232    with self.assertRaises(ValueError):
2233      lite.TFLiteConverter.from_saved_model(
2234          saved_model_dir,
2235          input_arrays=['inputA'],
2236          input_shapes={'wrong_input': [1, 16, 16, 3]})
2237
2238  def testSubsetInputShaapes(self):
2239    """Test a SavedModel with a subset of the input array names of the model."""
2240    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2241
2242    # Check case where input shape is given.
2243    converter = lite.TFLiteConverter.from_saved_model(
2244        saved_model_dir,
2245        input_arrays=['inputA'],
2246        input_shapes={'inputA': [1, 16, 16, 3]})
2247
2248    # Since we only partially specify the input, this is not allowed.
2249    with self.assertRaises(ConverterError):
2250      _ = converter.convert()
2251
2252    # Check case where input shape is None.
2253    converter = lite.TFLiteConverter.from_saved_model(
2254        saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
2255
2256    # Since we only partially specify the input, this is not allowed.
2257    with self.assertRaises(ConverterError):
2258      _ = converter.convert()
2259
2260  def testSimpleModelTocoConverter(self):
2261    """Test a SavedModel with deprecated TocoConverter."""
2262    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2263
2264    # Convert model and ensure model is not None.
2265    converter = lite.TocoConverter.from_saved_model(saved_model_dir)
2266    tflite_model = converter.convert()
2267    self.assertIsNotNone(tflite_model)
2268
2269    # Ensure the model is able to load.
2270    interpreter = Interpreter(model_content=tflite_model)
2271    interpreter.allocate_tensors()
2272
2273  def testGraphDebugInfo(self):
2274    """Test a SavedModel has debug info captured."""
2275    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2276    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2277    converter.convert()
2278    self.assertValidDebugInfo(converter._debug_info)
2279
2280
2281class MyAddLayer(keras.layers.Layer):
2282
2283  def __init__(self, increment, **kwargs):
2284    super(MyAddLayer, self).__init__(**kwargs)
2285    self._increment = increment
2286
2287  def call(self, inputs):
2288    return inputs + self._increment
2289
2290  def get_config(self):
2291    config = super(MyAddLayer, self).get_config()
2292    config['increment'] = self._increment
2293    return config
2294
2295
2296class FromKerasFile(TestModels, parameterized.TestCase):
2297
2298  def setUp(self):
2299    super(FromKerasFile, self).setUp()
2300    self._keras_file = None
2301    self._custom_objects = None
2302    if not context.executing_eagerly():
2303      keras.backend.clear_session()
2304
2305  def tearDown(self):
2306    if self._keras_file:
2307      os.remove(self._keras_file)
2308    super(FromKerasFile, self).tearDown()
2309
2310  def _getSequentialModel(self, include_custom_layer=False):
2311    model = keras.models.Sequential()
2312    model.add(keras.layers.Dense(2, input_shape=(3,)))
2313    if include_custom_layer:
2314      model.add(MyAddLayer(1.0))
2315    model.add(keras.layers.RepeatVector(3))
2316    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
2317    model.compile(
2318        loss=keras.losses.MSE,
2319        optimizer='sgd',
2320        metrics=[keras.metrics.categorical_accuracy],
2321        sample_weight_mode='temporal')
2322    x = np.random.random((1, 3))
2323    y = np.random.random((1, 3, 3))
2324    model.train_on_batch(x, y)
2325    model.predict(x)
2326
2327    try:
2328      fd, self._keras_file = tempfile.mkstemp('.h5')
2329      keras.models.save_model(model, self._keras_file)
2330    finally:
2331      os.close(fd)
2332
2333    if include_custom_layer:
2334      self._custom_objects = {'MyAddLayer': MyAddLayer}
2335
2336  @parameterized.named_parameters(('_graph', context.graph_mode),
2337                                  ('_eager', context.eager_mode))
2338  def testSequentialModel(self, test_context):
2339    """Test a Sequential tf.keras model with default inputs."""
2340    with test_context():
2341      self._getSequentialModel()
2342
2343      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2344      tflite_model = converter.convert()
2345      self.assertIsNotNone(tflite_model)
2346
2347    # Check tensor details of converted model.
2348    interpreter = Interpreter(model_content=tflite_model)
2349    interpreter.allocate_tensors()
2350
2351    input_details = interpreter.get_input_details()
2352    self.assertLen(input_details, 1)
2353    self.assertEndsWith(input_details[0]['name'], 'dense_input')
2354    self.assertEqual(np.float32, input_details[0]['dtype'])
2355    self.assertAllEqual([1, 3], input_details[0]['shape'])
2356    self.assertEqual((0., 0.), input_details[0]['quantization'])
2357
2358    output_details = interpreter.get_output_details()
2359    self.assertLen(output_details, 1)
2360    self.assertEqual(np.float32, output_details[0]['dtype'])
2361    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
2362    self.assertEqual((0., 0.), output_details[0]['quantization'])
2363
2364    # Check inference of converted model.
2365    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2366    interpreter.set_tensor(input_details[0]['index'], input_data)
2367    interpreter.invoke()
2368    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2369
2370    keras_model = keras.models.load_model(self._keras_file)
2371    keras_result = keras_model.predict(input_data)
2372
2373    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2374
2375  @parameterized.named_parameters(('_graph', context.graph_mode),
2376                                  ('_eager', context.eager_mode))
2377  def testCustomLayer(self, test_context):
2378    """Test a Sequential tf.keras model with default inputs."""
2379    with test_context():
2380      self._getSequentialModel(include_custom_layer=True)
2381
2382      converter = lite.TFLiteConverter.from_keras_model_file(
2383          self._keras_file, custom_objects=self._custom_objects)
2384      tflite_model = converter.convert()
2385      self.assertIsNotNone(tflite_model)
2386
2387    # Check tensor details of converted model.
2388    interpreter = Interpreter(model_content=tflite_model)
2389    interpreter.allocate_tensors()
2390
2391    input_details = interpreter.get_input_details()
2392    output_details = interpreter.get_output_details()
2393
2394    # Check inference of converted model.
2395    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2396    interpreter.set_tensor(input_details[0]['index'], input_data)
2397    interpreter.invoke()
2398    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2399
2400    keras_model = keras.models.load_model(
2401        self._keras_file, custom_objects=self._custom_objects)
2402    keras_result = keras_model.predict(input_data)
2403
2404    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2405
2406  def testSequentialModelInputArray(self):
2407    """Test a Sequential tf.keras model testing input arrays argument."""
2408    ops.disable_eager_execution()
2409    self._getSequentialModel()
2410
2411    # Invalid input array raises error.
2412    with self.assertRaises(ValueError) as error:
2413      lite.TFLiteConverter.from_keras_model_file(
2414          self._keras_file, input_arrays=['invalid-input'])
2415    self.assertEqual("Invalid tensors 'invalid-input' were found.",
2416                     str(error.exception))
2417
2418    # Valid input array.
2419    converter = lite.TFLiteConverter.from_keras_model_file(
2420        self._keras_file, input_arrays=['dense_input'])
2421    tflite_model = converter.convert()
2422    self.assertIsNotNone(tflite_model)
2423
2424  def testSequentialModelInputShape(self):
2425    """Test a Sequential tf.keras model testing input shapes argument."""
2426    self._getSequentialModel()
2427
2428    # Passing in shape of invalid input array raises error.
2429    with self.assertRaises(ValueError) as error:
2430      converter = lite.TFLiteConverter.from_keras_model_file(
2431          self._keras_file, input_shapes={'invalid-input': [2, 3]})
2432    self.assertEqual(
2433        "Invalid tensor 'invalid-input' found in tensor shapes map.",
2434        str(error.exception))
2435
2436    # Passing in shape of valid input array.
2437    converter = lite.TFLiteConverter.from_keras_model_file(
2438        self._keras_file, input_shapes={'dense_input': [2, 3]})
2439    tflite_model = converter.convert()
2440    self.assertIsNotNone(tflite_model)
2441
2442    # Check input shape from converted model.
2443    interpreter = Interpreter(model_content=tflite_model)
2444    interpreter.allocate_tensors()
2445
2446    input_details = interpreter.get_input_details()
2447    self.assertLen(input_details, 1)
2448    self.assertEndsWith(input_details[0]['name'], 'dense_input')
2449    self.assertAllEqual([2, 3], input_details[0]['shape'])
2450
2451  def testSequentialModelOutputArray(self):
2452    """Test a Sequential tf.keras model testing output arrays argument."""
2453    ops.disable_eager_execution()
2454    self._getSequentialModel()
2455
2456    # Invalid output array raises error.
2457    with self.assertRaises(ValueError) as error:
2458      lite.TFLiteConverter.from_keras_model_file(
2459          self._keras_file, output_arrays=['invalid-output'])
2460    self.assertEqual("Invalid tensors 'invalid-output' were found.",
2461                     str(error.exception))
2462
2463    # Valid output array.
2464    converter = lite.TFLiteConverter.from_keras_model_file(
2465        self._keras_file, output_arrays=['time_distributed/Reshape_1'])
2466    tflite_model = converter.convert()
2467    self.assertIsNotNone(tflite_model)
2468
2469  @parameterized.named_parameters(('_graph', context.graph_mode),
2470                                  ('_eager', context.eager_mode))
2471  def testFunctionalModel(self, test_context):
2472    """Test a Functional tf.keras model with default inputs."""
2473    with test_context():
2474      inputs = keras.layers.Input(shape=(3,), name='input')
2475      x = keras.layers.Dense(2)(inputs)
2476      output = keras.layers.Dense(3)(x)
2477
2478      model = keras.models.Model(inputs, output)
2479      model.compile(
2480          loss=keras.losses.MSE,
2481          optimizer='sgd',
2482          metrics=[keras.metrics.categorical_accuracy])
2483      x = np.random.random((1, 3))
2484      y = np.random.random((1, 3))
2485      model.train_on_batch(x, y)
2486
2487      model.predict(x)
2488      fd, self._keras_file = tempfile.mkstemp('.h5')
2489      try:
2490        keras.models.save_model(model, self._keras_file)
2491      finally:
2492        os.close(fd)
2493
2494      # Convert to TFLite model.
2495      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2496      tflite_model = converter.convert()
2497      self.assertIsNotNone(tflite_model)
2498
2499    # Check tensor details of converted model.
2500    interpreter = Interpreter(model_content=tflite_model)
2501    interpreter.allocate_tensors()
2502
2503    input_details = interpreter.get_input_details()
2504    self.assertLen(input_details, 1)
2505    self.assertEqual('input', input_details[0]['name'])
2506    self.assertEqual(np.float32, input_details[0]['dtype'])
2507    self.assertAllEqual([1, 3], input_details[0]['shape'])
2508    self.assertEqual((0., 0.), input_details[0]['quantization'])
2509
2510    output_details = interpreter.get_output_details()
2511    self.assertLen(output_details, 1)
2512    self.assertEqual(np.float32, output_details[0]['dtype'])
2513    self.assertAllEqual([1, 3], output_details[0]['shape'])
2514    self.assertEqual((0., 0.), output_details[0]['quantization'])
2515
2516    # Check inference of converted model.
2517    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2518    interpreter.set_tensor(input_details[0]['index'], input_data)
2519    interpreter.invoke()
2520    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2521
2522    keras_model = keras.models.load_model(self._keras_file)
2523    keras_result = keras_model.predict(input_data)
2524
2525    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2526
2527  def _getFunctionalModelMultipleInputs(self):
2528    a = keras.layers.Input(shape=(3,), name='input_a')
2529    b = keras.layers.Input(shape=(3,), name='input_b')
2530    dense = keras.layers.Dense(4, name='dense')
2531    c = dense(a)
2532    d = dense(b)
2533    e = keras.layers.Dropout(0.5, name='dropout')(c)
2534
2535    model = keras.models.Model([a, b], [d, e])
2536    model.compile(
2537        loss=keras.losses.MSE,
2538        optimizer='sgd',
2539        metrics=[keras.metrics.mae],
2540        loss_weights=[1., 0.5])
2541
2542    input_a_np = np.random.random((10, 3))
2543    input_b_np = np.random.random((10, 3))
2544    output_d_np = np.random.random((10, 4))
2545    output_e_np = np.random.random((10, 4))
2546    model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
2547
2548    model.predict([input_a_np, input_b_np], batch_size=5)
2549    fd, self._keras_file = tempfile.mkstemp('.h5')
2550    try:
2551      keras.models.save_model(model, self._keras_file)
2552    finally:
2553      os.close(fd)
2554
2555  def testFunctionalModelMultipleInputs(self):
2556    """Test a Functional tf.keras model with multiple inputs and outputs."""
2557    self._getFunctionalModelMultipleInputs()
2558
2559    # Convert to TFLite model.
2560    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2561    tflite_model = converter.convert()
2562    self.assertIsNotNone(tflite_model)
2563
2564    # Check values from converted model.
2565    interpreter = Interpreter(model_content=tflite_model)
2566    interpreter.allocate_tensors()
2567
2568    input_details = interpreter.get_input_details()
2569    self.assertLen(input_details, 2)
2570    self.assertEndsWith(input_details[0]['name'], 'input_a')
2571    self.assertEqual(np.float32, input_details[0]['dtype'])
2572    self.assertAllEqual([1, 3], input_details[0]['shape'])
2573    self.assertEqual((0., 0.), input_details[0]['quantization'])
2574
2575    self.assertEndsWith(input_details[1]['name'], 'input_b')
2576    self.assertEqual(np.float32, input_details[1]['dtype'])
2577    self.assertAllEqual([1, 3], input_details[1]['shape'])
2578    self.assertEqual((0., 0.), input_details[1]['quantization'])
2579
2580    output_details = interpreter.get_output_details()
2581    self.assertLen(output_details, 2)
2582    self.assertEqual(np.float32, output_details[0]['dtype'])
2583    self.assertAllEqual([1, 4], output_details[0]['shape'])
2584    self.assertEqual((0., 0.), output_details[0]['quantization'])
2585
2586    self.assertEqual(np.float32, output_details[1]['dtype'])
2587    self.assertAllEqual([1, 4], output_details[1]['shape'])
2588    self.assertEqual((0., 0.), output_details[1]['quantization'])
2589
2590  def testShapeOverriding(self):
2591    """Test a Functional tf.keras model with input shape overriding."""
2592    self._getFunctionalModelMultipleInputs()
2593
2594    # Convert to TFLite model.
2595    converter = lite.TFLiteConverter.from_keras_model_file(
2596        self._keras_file, input_shapes={
2597            'input_a': {2, 3},
2598            'input_b': {2, 3}
2599        })
2600    tflite_model = converter.convert()
2601    self.assertIsNotNone(tflite_model)
2602
2603    # Check values from converted model.
2604    interpreter = Interpreter(model_content=tflite_model)
2605    interpreter.allocate_tensors()
2606
2607    input_details = interpreter.get_input_details()
2608    self.assertLen(input_details, 2)
2609    self.assertEndsWith(input_details[0]['name'], 'input_a')
2610    self.assertEqual(np.float32, input_details[0]['dtype'])
2611    self.assertAllEqual([2, 3], input_details[0]['shape'])
2612    self.assertEqual((0., 0.), input_details[0]['quantization'])
2613
2614    self.assertEndsWith(input_details[1]['name'], 'input_b')
2615    self.assertEqual(np.float32, input_details[1]['dtype'])
2616    self.assertAllEqual([2, 3], input_details[1]['shape'])
2617    self.assertEqual((0., 0.), input_details[1]['quantization'])
2618
2619    output_details = interpreter.get_output_details()
2620    self.assertLen(output_details, 2)
2621    self.assertEqual(np.float32, output_details[0]['dtype'])
2622    self.assertAllEqual([2, 4], output_details[0]['shape'])
2623    self.assertEqual((0., 0.), output_details[0]['quantization'])
2624
2625    self.assertEqual(np.float32, output_details[1]['dtype'])
2626    self.assertAllEqual([2, 4], output_details[1]['shape'])
2627    self.assertEqual((0., 0.), output_details[1]['quantization'])
2628
2629  def testPartialShapeOverriding(self):
2630    """Test a Functional tf.keras model with partial input shape overriding."""
2631    self._getFunctionalModelMultipleInputs()
2632
2633    # Convert to TFLite model.
2634    converter = lite.TFLiteConverter.from_keras_model_file(
2635        self._keras_file, input_shapes={'input_a': {2, 3}})
2636    tflite_model = converter.convert()
2637    self.assertIsNotNone(tflite_model)
2638
2639    # Check values from converted model.
2640    interpreter = Interpreter(model_content=tflite_model)
2641    interpreter.allocate_tensors()
2642
2643    input_details = interpreter.get_input_details()
2644    self.assertLen(input_details, 2)
2645    self.assertEndsWith(input_details[0]['name'], 'input_a')
2646    self.assertEqual(np.float32, input_details[0]['dtype'])
2647    self.assertAllEqual([2, 3], input_details[0]['shape'])
2648    self.assertEqual((0., 0.), input_details[0]['quantization'])
2649
2650    self.assertEndsWith(input_details[1]['name'], 'input_b')
2651    self.assertEqual(np.float32, input_details[1]['dtype'])
2652    self.assertAllEqual([1, 3], input_details[1]['shape'])
2653    self.assertEqual((0., 0.), input_details[1]['quantization'])
2654
2655    output_details = interpreter.get_output_details()
2656    self.assertLen(output_details, 2)
2657    self.assertEqual(np.float32, output_details[0]['dtype'])
2658    self.assertAllEqual([1, 4], output_details[0]['shape'])
2659    self.assertEqual((0., 0.), output_details[0]['quantization'])
2660
2661    self.assertEqual(np.float32, output_details[1]['dtype'])
2662    self.assertAllEqual([2, 4], output_details[1]['shape'])
2663    self.assertEqual((0., 0.), output_details[1]['quantization'])
2664
2665  def testWrongShapeOverriding(self):
2666    """Test a Functional tf.keras model with wrong input shape overriding."""
2667    self._getFunctionalModelMultipleInputs()
2668
2669    # Convert to TFLite model.
2670    with self.assertRaises(ValueError):
2671      lite.TFLiteConverter.from_keras_model_file(
2672          self._keras_file, input_shapes={'wrong_input': {2, 3}})
2673
2674  def testFunctionalSequentialModel(self):
2675    """Test a Functional tf.keras model containing a Sequential model."""
2676    model = keras.models.Sequential()
2677    model.add(keras.layers.Dense(2, input_shape=(3,)))
2678    model.add(keras.layers.RepeatVector(3))
2679    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
2680    model = keras.models.Model(model.input, model.output)
2681
2682    model.compile(
2683        loss=keras.losses.MSE,
2684        optimizer='sgd',
2685        metrics=[keras.metrics.categorical_accuracy],
2686        sample_weight_mode='temporal')
2687    x = np.random.random((1, 3))
2688    y = np.random.random((1, 3, 3))
2689    model.train_on_batch(x, y)
2690    model.predict(x)
2691
2692    model.predict(x)
2693    fd, self._keras_file = tempfile.mkstemp('.h5')
2694    try:
2695      keras.models.save_model(model, self._keras_file)
2696    finally:
2697      os.close(fd)
2698
2699    # Convert to TFLite model.
2700    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2701    tflite_model = converter.convert()
2702    self.assertIsNotNone(tflite_model)
2703
2704    # Check tensor details of converted model.
2705    interpreter = Interpreter(model_content=tflite_model)
2706    interpreter.allocate_tensors()
2707
2708    input_details = interpreter.get_input_details()
2709    self.assertLen(input_details, 1)
2710    self.assertEndsWith(input_details[0]['name'], 'dense_input')
2711    self.assertEqual(np.float32, input_details[0]['dtype'])
2712    self.assertAllEqual([1, 3], input_details[0]['shape'])
2713    self.assertEqual((0., 0.), input_details[0]['quantization'])
2714
2715    output_details = interpreter.get_output_details()
2716    self.assertLen(output_details, 1)
2717    self.assertEqual(np.float32, output_details[0]['dtype'])
2718    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
2719    self.assertEqual((0., 0.), output_details[0]['quantization'])
2720
2721    # Check inference of converted model.
2722    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2723    interpreter.set_tensor(input_details[0]['index'], input_data)
2724    interpreter.invoke()
2725    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2726
2727    keras_model = keras.models.load_model(self._keras_file)
2728    keras_result = keras_model.predict(input_data)
2729
2730    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2731
2732  def testSequentialModelTocoConverter(self):
2733    """Test a Sequential tf.keras model with deprecated TocoConverter."""
2734    self._getSequentialModel()
2735
2736    converter = lite.TocoConverter.from_keras_model_file(self._keras_file)
2737    tflite_model = converter.convert()
2738    self.assertIsNotNone(tflite_model)
2739
2740    # Ensure the model is able to load.
2741    interpreter = Interpreter(model_content=tflite_model)
2742    interpreter.allocate_tensors()
2743
2744  @parameterized.named_parameters(('_graph', context.graph_mode),
2745                                  ('_eager', context.eager_mode))
2746  def testGraphDebugInfo(self, test_context):
2747    """Test a Sequential tf.keras model has debug info captured."""
2748    with test_context():
2749      self._getSequentialModel()
2750      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2751      converter.convert()
2752      self.assertValidDebugInfo(converter._debug_info)
2753
2754  def testSparsifyModel(self):
2755    self._getSequentialModel()
2756
2757    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2758    converter.optimizations = {lite.Optimize.EXPERIMENTAL_SPARSITY}
2759    tflite_model = converter.convert()
2760    self.assertTrue(tflite_model)
2761
2762  def testSparsifyQuantizedModel(self):
2763    self._getSequentialModel()
2764
2765    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2766    converter.optimizations = {
2767        lite.Optimize.DEFAULT, lite.Optimize.EXPERIMENTAL_SPARSITY
2768    }
2769    tflite_model = converter.convert()
2770    self.assertIsNotNone(tflite_model)
2771
2772
2773class GrapplerTest(TestModels, parameterized.TestCase):
2774
2775  def testConstantFolding(self):
2776    ops.disable_eager_execution()
2777    # Constant folding handles the tf.broadcast_to operation which was not
2778    # supported by the TFLite at the time this test was added.
2779    with ops.Graph().as_default():
2780      in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
2781      y_const = constant_op.constant([1., 2., 3.])
2782      y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
2783      out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output')
2784      sess = session.Session()
2785
2786    # Convert model.
2787    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
2788                                                  [out_tensor])
2789    tflite_model = converter.convert()
2790
2791    # Check values from converted model.
2792    interpreter = Interpreter(model_content=tflite_model)
2793    interpreter.allocate_tensors()
2794
2795    input_details = interpreter.get_input_details()
2796    self.assertLen(input_details, 1)
2797    self.assertEqual('Placeholder', input_details[0]['name'])
2798    self.assertEqual(np.float32, input_details[0]['dtype'])
2799    self.assertAllEqual([3, 3], input_details[0]['shape'])
2800
2801    output_details = interpreter.get_output_details()
2802    self.assertLen(output_details, 1)
2803    self.assertEqual('output', output_details[0]['name'])
2804    self.assertEqual(np.float32, output_details[0]['dtype'])
2805    self.assertAllEqual([3, 3], output_details[0]['shape'])
2806
2807  @parameterized.named_parameters(
2808      ('EnableMlirConverter', True),  # enable mlir
2809      ('DisableMlirConverter', False))  # disable mlir
2810  def testInputNodeIsNotFolded(self, enable_mlir_converter):
2811    ops.disable_eager_execution()
2812    # Constant folding handles the tf.broadcast_to operation which was not
2813    # supported by the TFLite at the time this test was added.
2814    with ops.Graph().as_default():
2815      in_tensor = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
2816      y_const = constant_op.constant([1., 2., 3.])
2817      y_add = y_const + y_const
2818      out_tensor = in_tensor * y_add
2819      sess = session.Session()
2820
2821    # Convert model.
2822    converter = lite.TFLiteConverter.from_session(sess, [in_tensor, y_const],
2823                                                  [out_tensor])
2824    converter.experimental_new_converter = enable_mlir_converter
2825    tflite_model = converter.convert()
2826
2827    # Check values from converted model.
2828    interpreter = Interpreter(model_content=tflite_model)
2829    interpreter.allocate_tensors()
2830
2831    input_details = interpreter.get_input_details()
2832    self.assertLen(input_details, 2)
2833    self.assertEqual('Placeholder', input_details[0]['name'])
2834    self.assertEqual('Const', input_details[1]['name'])
2835
2836  def testGrapplerConstFolding(self):
2837    # Constant folding converts the following add operation to tf.broadcast_to
2838    # operation which was not supported by the TFLite at the time this test was
2839    # added.
2840    @def_function.function
2841    def plus_placeholder(x, placeholder):
2842      return x + placeholder
2843
2844    with ops.Graph().as_default():
2845      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
2846      out_tensor = plus_placeholder(
2847          array_ops.zeros([2, 2, 2]),
2848          array_ops.reshape(in_tensor, shape=[2, 2]))
2849      sess = session.Session()
2850
2851    # Convert model.
2852    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
2853                                                  [out_tensor])
2854    tflite_model = converter.convert()
2855
2856    # Check values from converted model.
2857    interpreter = Interpreter(model_content=tflite_model)
2858    interpreter.allocate_tensors()
2859
2860    input_details = interpreter.get_input_details()
2861    self.assertLen(input_details, 1)
2862    self.assertEqual('Placeholder', input_details[0]['name'])
2863
2864
2865class DefaultConverterAttrsTest(LiteTest):
2866
2867  def testAttrs(self):
2868    with ops.Graph().as_default():
2869      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
2870      out_tensor = in_tensor + in_tensor
2871      sess = session.Session()
2872
2873    # Convert model.
2874    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
2875                                                  [out_tensor])
2876
2877    # Assert output format.
2878    self.assertEqual(converter.output_format, lite_constants.TFLITE)
2879
2880    # Assert the default inference type is float.
2881    self.assertEqual(converter.inference_type, dtypes.float32)
2882
2883    # Assert the default inference type overrides are None.
2884    self.assertIsNone(converter.inference_input_type)
2885    self.assertIsNone(converter.inference_output_type)
2886
2887    # Assert the default quantization options are not set.
2888    self.assertEqual(converter.quantized_input_stats, {})
2889    self.assertIsNone(converter.default_ranges_stats)
2890    self.assertFalse(converter.reorder_across_fake_quant)
2891    self.assertFalse(converter.change_concat_input_ranges)
2892
2893    # Assert dropping control dependency is enabled by default.
2894    self.assertIsNotNone(converter.drop_control_dependency)
2895
2896    # Assert dumping extra information is disabled by default.
2897    self.assertIsNone(converter.dump_graphviz_dir)
2898    self.assertFalse(converter.dump_graphviz_video)
2899    self.assertIsNone(converter.conversion_summary_dir)
2900
2901
2902class ControlFlowV1OpsTest(LiteTest):
2903
2904  def testConverterErrorOnControlFlowV1Ops(self):
2905    graph_def_file = resource_loader.get_path_to_datafile(
2906        'testdata/control_flow_v1.pbtxt')
2907    input_arrays = ['a', 'b', 'c', 'd']
2908    output_arrays = ['Merge']
2909
2910    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
2911                                                       input_arrays,
2912                                                       output_arrays)
2913    with self.assertRaises(ConverterError) as error:
2914      converter.convert()
2915    self.assertIn(
2916        'Failed to functionalize Control Flow V1 ops. Consider using Control '
2917        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
2918        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
2919
2920
2921class QuantizationModeTest(LiteTest, parameterized.TestCase):
2922
2923  @parameterized.named_parameters(
2924      ('size', lite.Optimize.OPTIMIZE_FOR_SIZE),
2925      ('latency', lite.Optimize.OPTIMIZE_FOR_LATENCY))
2926  def testDeprecatedOptionWarning(self, optimization):
2927    """Test if the warning message when using TOCO is logged."""
2928    log = io.StringIO()
2929    handler = logging.StreamHandler(log)
2930    logging.root.addHandler(handler)
2931    warning_message = 'please use optimizations=[Optimize.DEFAULT] instead.'
2932    lite.QuantizationMode([optimization], lite.TargetSpec(), None, None)
2933    self.assertIn(warning_message, log.getvalue())
2934    logging.root.removeHandler(handler)
2935
2936
2937if __name__ == '__main__':
2938  test.main()
2939