# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the TensorFlow Lite Python metrics helper (TFLiteMetrics)."""
import gc
import os
import tempfile
import time
from unittest import mock

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.core.framework import graph_pb2
from tensorflow.lite.python import lite
from tensorflow.lite.python import metrics_nonportable as metrics
from tensorflow.lite.python.convert import ConverterError
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.tracking import tracking


class MetricsNonportableTest(test_util.TensorFlowTestCase):

  def test_TFLiteMetrics_creation_no_arg_success(self):
    metrics.TFLiteMetrics()

  def test_TFLiteMetrics_creation_arg_success(self):
    metrics.TFLiteMetrics('hash', '/path/to/model')

  def test_TFLiteMetrics_creation_fails_with_only_hash(self):
    with self.assertRaises(ValueError):
      metrics.TFLiteMetrics(model_hash='hash')

  def test_TFLiteMetrics_creation_fail2_with_only_model_path(self):
    with self.assertRaises(ValueError):
      metrics.TFLiteMetrics(model_path='/path/to/model')

  def test_debugger_creation_counter_increase_multiple_same_topic_success(self):
    try:
      stub = metrics.TFLiteMetrics()
      stub.increase_counter_debugger_creation()
      self.assertEqual(metrics._counter_debugger_creation.get_cell().value(), 1)
      stub2 = metrics.TFLiteMetrics()
      stub2.increase_counter_debugger_creation()
      self.assertEqual(metrics._counter_debugger_creation.get_cell().value(), 2)
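      # Deleting one TFLiteMetrics instance must not reset the shared counter.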
      del stub
      gc.collect()
      stub2.increase_counter_debugger_creation()
      self.assertEqual(metrics._counter_debugger_creation.get_cell().value(), 3)
    except Exception:  # pylint: disable=broad-except
      self.fail('No exception should be raised.')

  def test_interpreter_creation_counter_increase_success(self):
    stub = metrics.TFLiteMetrics()
    stub.increase_counter_interpreter_creation()
    self.assertEqual(
        metrics._counter_interpreter_creation.get_cell('python').value(), 1)

  def test_converter_attempt_counter_increase_success(self):
    stub = metrics.TFLiteMetrics()
    stub.increase_counter_converter_attempt()
    self.assertEqual(metrics._counter_conversion_attempt.get_cell().value(), 1)

  def test_converter_success_counter_increase_success(self):
    stub = metrics.TFLiteMetrics()
    stub.increase_counter_converter_success()
    self.assertEqual(metrics._counter_conversion_success.get_cell().value(), 1)

  def test_converter_params_set_success(self):
    stub = metrics.TFLiteMetrics()
    stub.set_converter_param('name', 'value')
    self.assertEqual(
        metrics._gauge_conversion_params.get_cell('name').value(), 'value')

  def test_converter_params_multiple_set_success(self):
    stub = metrics.TFLiteMetrics()
    stub.set_converter_param('name', 'value')
    stub.set_converter_param('name', 'value1')
    self.assertEqual(
        metrics._gauge_conversion_params.get_cell('name').value(), 'value1')

  def test_converter_params_multiple_label_success(self):
    stub = metrics.TFLiteMetrics()
    stub.set_converter_param('name1', 'value1')
    stub.set_converter_param('name2', 'value2')
    self.assertEqual(
        metrics._gauge_conversion_params.get_cell('name1').value(), 'value1')
    self.assertEqual(
        metrics._gauge_conversion_params.get_cell('name2').value(), 'value2')

  def test_converter_params_set_latency(self):
    stub = metrics.TFLiteMetrics()
    stub.set_converter_latency(34566)
    self.assertEqual(metrics._gauge_conversion_latency.get_cell().value(),
                     34566)


class ConverterMetricsTest(test_util.TensorFlowTestCase):
  """Testing conversion metrics."""

  def _constructGraphDef(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    return (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

  def test_conversion_from_constructor_success(self):
    frozen_graph_def = self._constructGraphDef()

    # Check metrics when conversion succeeded.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
    mock_metrics = mock.create_autospec(
        metrics.TFLiteConverterMetrics, instance=True)
    converter._tflite_metrics = mock_metrics
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    mock_metrics.assert_has_calls([
        mock.call.increase_counter_converter_attempt(),
        mock.call.increase_counter_converter_success(),
        mock.call.export_metrics(),
        mock.call.set_converter_param('input_format', '1'),
        mock.call.set_converter_param('enable_mlir_converter', 'True'),
        mock.call.set_converter_param('allow_custom_ops', 'False'),
        mock.call.set_converter_param('api_version', '1'),
    ], any_order=True)  # pyformat: disable

  def test_conversion_from_constructor_fail(self):
    frozen_graph_def = self._constructGraphDef()

    # Check metrics when conversion failed.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('wrong_tensor', [2, 16, 16, 3])],
                                     ['add'])
    mock_metrics = mock.create_autospec(
        metrics.TFLiteConverterMetrics, instance=True)
    converter._tflite_metrics = mock_metrics
    with self.assertRaises(ConverterError):
      converter.convert()
    mock_metrics.assert_has_calls([
        mock.call.increase_counter_converter_attempt(),
        mock.call.set_converter_param('output_format', '2'),
        mock.call.set_converter_param('select_user_tf_ops', 'None'),
        mock.call.set_converter_param('post_training_quantize', 'False'),
    ], any_order=True)  # pyformat: disable
    mock_metrics.increase_counter_converter_success.assert_not_called()

  def _getIntegerQuantizeModel(self):
    np.random.seed(0)

    root = tracking.AutoTrackable()

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[1, 5, 5, 3], dtype=tf.float32)])
    def func(inp):
      conv = tf.nn.conv2d(
          inp, tf.ones([3, 3, 3, 16]), strides=[1, 1, 1, 1], padding='SAME')
      output = tf.nn.relu(conv, name='output')
      return output

    def calibration_gen():
      for _ in range(5):
        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

    root.f = func
    to_save = root.f.get_concrete_function()
    return (root, to_save, calibration_gen)

  def test_conversion_from_frozen_graph_v2(self):
    model, func, calibration_gen = self._getIntegerQuantizeModel()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func],
                                                                         model)
    mock_metrics = mock.create_autospec(
        metrics.TFLiteConverterMetrics, instance=True)
    quantized_converter._tflite_metrics = mock_metrics
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    mock_metrics.assert_has_calls([
        mock.call.increase_counter_converter_attempt(),
        mock.call.increase_counter_converter_success(),
        mock.call.set_converter_param(
            'optimization_post_training_integer_quantize', 'True'),
        mock.call.set_converter_param('inference_type', 'tf.int8'),
        mock.call.set_converter_param('select_user_tf_ops', 'None'),
        mock.call.set_converter_param('activations_type', 'tf.int8'),
    ], any_order=True)  # pyformat: disable

  def test_conversion_from_keras_v2(self):
    x = [-1, 0, 1, 2, 3, 4]
    y = [-3, -1, 1, 3, 5, 7]
    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    mock_metrics = mock.create_autospec(
        metrics.TFLiteConverterMetrics, instance=True)
    converter._tflite_metrics = mock_metrics
    converter.convert()
    mock_metrics.assert_has_calls([
        mock.call.increase_counter_converter_attempt(),
        mock.call.increase_counter_converter_success(),
        mock.call.export_metrics(),
        mock.call.set_converter_param('inference_type', 'tf.float32'),
        mock.call.set_converter_param('target_ops', 'TFLITE_BUILTINS'),
        mock.call.set_converter_param('optimization_default', 'False'),
    ], any_order=True)  # pyformat: disable

  def _createV1SavedModel(self, shape):
    """Create a simple SavedModel."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor_1 = tf.compat.v1.placeholder(
            shape=shape, dtype=tf.float32, name='inputB')
        in_tensor_2 = tf.compat.v1.placeholder(
            shape=shape, dtype=tf.float32, name='inputA')
        variable_node = tf.Variable(1.0, name='variable_node')
        out_tensor = in_tensor_1 + in_tensor_2 * variable_node
        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
        outputs = {'z': out_tensor}
        sess.run(tf.compat.v1.variables_initializer([variable_node]))
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def test_conversion_from_saved_model(self):
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])
    converter = lite.TFLiteSavedModelConverter(saved_model_dir, set(['serve']),
                                               ['serving_default'])
    converter.experimental_new_converter = True
    mock_metrics = mock.create_autospec(
        metrics.TFLiteConverterMetrics, instance=True)
    converter._tflite_metrics = mock_metrics
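    # Mock the process clock so successive time.process_time() calls return
    # 1, 3, 5, ...; every timed interval is then exactly 2 seconds, matching
    # the set_converter_latency(2000) call asserted below.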
    time.process_time = mock.Mock(side_effect=np.arange(1, 1000, 2).tolist())
    converter.convert()
    mock_metrics.assert_has_calls([
        mock.call.increase_counter_converter_attempt(),
        mock.call.increase_counter_converter_success(),
        mock.call.set_converter_latency(2000),
        mock.call.export_metrics(),
        mock.call.set_converter_param('enable_mlir_converter', 'True'),
    ], any_order=True)  # pyformat: disable

  def test_conversion_from_saved_model_v2(self):
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.experimental_new_converter = False
    mock_metrics = mock.create_autospec(
        metrics.TFLiteConverterMetrics, instance=True)
    converter._tflite_metrics = mock_metrics
    converter.convert()
    mock_metrics.assert_has_calls([
        mock.call.increase_counter_converter_attempt(),
        mock.call.increase_counter_converter_success(),
        mock.call.export_metrics(),
        mock.call.set_converter_param('enable_mlir_converter', 'False'),
        mock.call.set_converter_param('api_version', '2'),
    ], any_order=True)  # pyformat: disable

  def disable_converter_counter_metrics(self, tflite_metrics):
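    """Replaces the attempt/success counter methods with no-ops."""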

    def empty_func():
      pass

    tflite_metrics.increase_counter_converter_attempt = empty_func
    tflite_metrics.increase_counter_converter_success = empty_func

  def test_export_at_conversion_done(self):
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    tflite_metrics = converter._tflite_metrics
    mock_exporter = mock.MagicMock()
    tflite_metrics._metrics_exporter = mock_exporter
    self.disable_converter_counter_metrics(tflite_metrics)
    mock_exporter.ExportMetrics.assert_not_called()
    converter.convert()
    mock_exporter.ExportMetrics.assert_called_once()
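    # Deleting the metrics object must not trigger a second export.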
    tflite_metrics.__del__()
    mock_exporter.ExportMetrics.assert_called_once()

  def test_export_at_exit(self):
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    tflite_metrics = converter._tflite_metrics
    mock_exporter = mock.MagicMock()
    tflite_metrics._metrics_exporter = mock_exporter
    self.disable_converter_counter_metrics(tflite_metrics)
    mock_exporter.ExportMetrics.assert_not_called()
    tflite_metrics.__del__()
    mock_exporter.ExportMetrics.assert_called_once()


def mock_ngrams(data, width, axis=-1, string_separator=' ', name=None):
  """This mock Ngrams lacks the width attr, causing conversion to fail."""

  experimental_implements = [
      'name: "tftext:Ngrams"',
      'attr { key: "axis" value { i: %d } }' % axis,
      'attr { key: "reduction_type" value { s: "STRING_JOIN" } }',
      'attr { key: "string_separator" value { s: "%s" } }' % string_separator,
  ]
  experimental_implements = ' '.join(experimental_implements)

  @tf.function(experimental_implements=experimental_implements)
  def func(data):
    with ops.name_scope(name, 'NGrams', [data, width]):
      data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
      slices = []
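      # Collect `width` shifted slices of the data along `axis`; joining them
      # below with string_separator produces the n-grams.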
      for start in range(width):
        stop = None if start - width + 1 == 0 else start - width + 1
        if axis >= 0:
          idx = [slice(None)] * axis + [slice(start, stop)]
        else:
          idx = [Ellipsis, slice(start, stop)] + [slice(None)] * (-axis - 1)
        slices.append(data[idx])

      # Stack the slices.
      stack_axis = axis + 1 if axis >= 0 else axis
      windowed_data = array_ops.stack(slices, stack_axis)

      return string_ops.reduce_join(
          windowed_data, axis=axis, separator=string_separator)

  return func(data)


class ConverterErrorMetricTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
  """Testing conversion error metric."""

  def setUp(self):
    super(ConverterErrorMetricTest, self).setUp()

    # Mock the metrics instances (except the error metrics) so that other test
    # cases are not affected.
    mock_attempt = mock.create_autospec(monitoring.Counter, instance=True)
    self._counter_conversion_attempt = metrics._counter_conversion_attempt
    metrics._counter_conversion_attempt = mock_attempt

    mock_success = mock.create_autospec(monitoring.Counter, instance=True)
    self._counter_conversion_success = metrics._counter_conversion_success
    metrics._counter_conversion_success = mock_success

    mock_params = mock.create_autospec(monitoring.StringGauge, instance=True)
    self._gauge_conversion_params = metrics._gauge_conversion_params
    metrics._gauge_conversion_params = mock_params

  def tearDown(self):
    super(ConverterErrorMetricTest, self).tearDown()
    # Restore metrics instances.
    metrics._counter_conversion_attempt = self._counter_conversion_attempt
    metrics._counter_conversion_success = self._counter_conversion_success
    metrics._gauge_conversion_params = self._gauge_conversion_params

  def convert_and_check_location_info(self,
                                      converter,
                                      expected_type,
                                      expected_sources=None):
    # The custom attribute of ConverterError can't be accessed with
    # assertRaises, so use a try/except block instead.
    try:
      tflite_model = converter.convert()
      self.assertIsNone(tflite_model)
    except ConverterError as converter_error:
      # pylint: disable=g-assert-in-except
      self.assertLen(converter_error.errors, 1)
      location = converter_error.errors[0].location
      self.assertEqual(location.type, expected_type)

      if expected_sources:
        debug_string = str(location)
        for source in expected_sources:
          self.assertIn(source, debug_string)
      # pylint: enable=g-assert-in-except

  def test_failure_at_PrepareCompositeFunctionsPass(self):

    class NgramsLayer(tf.keras.layers.Layer):

      def call(self, input_tensor, **kwargs):
        return mock_ngrams(input_tensor, width=2, axis=-1, string_separator=' ')

    # Register a fake WhitespaceTokenizeWithOffsets so the TFText fusing logic
    # is enabled on the MLIR side.
    custom_opdefs_str = (
        'name: \'WhitespaceTokenizeWithOffsets\' input_arg: {name: \'Input1\' '
        'type: DT_FLOAT} input_arg: {name: \'Input2\' type: DT_FLOAT} '
        'output_arg: {name: \'Output\' type: DT_FLOAT}')
    register_custom_opdefs([custom_opdefs_str])

    model = tf.keras.models.Sequential([NgramsLayer()])
    model.predict(tf.constant(['test']))
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.allow_custom_ops = True
    self.convert_and_check_location_info(
        converter, converter_error_data_pb2.ConverterErrorData.UNKNOWNLOC)
    exported_error = metrics._gauge_conversion_errors.get_cell(
        'CONVERT_TF_TO_TFLITE_MODEL', 'PrepareCompositeFunctionsPass', '',
        'UNKNOWN').value()
    self.assertEqual(exported_error,
                     "\'width\' attribute is not set or not an integer\n")

  def test_need_flex_ops(self):

    def create_graph_with_custom_add(opname='CustomAdd'):
      custom_opdefs_str = (
          'name: \'' + opname +
          '\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
          'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
          '\'Output\' type: DT_FLOAT}')

      # Create a graph that has one add op.
      new_graph = graph_pb2.GraphDef()
      with ops.Graph().as_default():
        with session.Session() as sess:
          in_tensor = array_ops.placeholder(
              shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
          out_tensor = in_tensor + in_tensor
          inputs = {'x': in_tensor}
          outputs = {'z': out_tensor}

          new_graph.CopyFrom(sess.graph_def)

      # Rename the Add op to opname.
      for node in new_graph.node:
        if node.op.startswith('Add'):
          node.op = opname
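          # The registered custom opdef does not declare a 'T' attr, so remove
          # it from the renamed node.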
          del node.attr['T']

      # Register custom op defs to import modified graph def.
      register_custom_opdefs([custom_opdefs_str])

      return (new_graph, inputs, outputs)

    new_graph, inputs, outputs = create_graph_with_custom_add()

    # Import to load the custom opdef.
    saved_model_dir = os.path.join(self.get_temp_dir(), 'model')
    with ops.Graph().as_default():
      with session.Session() as sess:
        import_graph_def(new_graph, name='')
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    self.convert_and_check_location_info(
        converter,
        converter_error_data_pb2.ConverterErrorData.NAMELOC,
        expected_sources='add')
    exported_error = metrics._gauge_conversion_errors.get_cell(
        'CONVERT_TF_TO_TFLITE_MODEL', 'CONVERT_SAVED_MODEL', 'tf.CustomAdd',
        'ERROR_NEEDS_CUSTOM_OPS').value()
    self.assertEqual(
        exported_error,
        "\'tf.CustomAdd\' op is neither a custom op nor a flex op\n"
        "Error code: ERROR_NEEDS_CUSTOM_OPS"
    )

  def test_unsupported_control_flow_v1(self):
    filename = resource_loader.get_path_to_datafile(
        'testdata/control_flow_v1_saved_model')
    converter = lite.TFLiteConverterV2.from_saved_model(filename)
    self.convert_and_check_location_info(
        converter, converter_error_data_pb2.ConverterErrorData.UNKNOWNLOC)
    exported_error = metrics._gauge_conversion_errors.get_cell(
        'CONVERT_TF_TO_TFLITE_MODEL', 'CONVERT_SAVED_MODEL', '',
        'ERROR_UNSUPPORTED_CONTROL_FLOW_V1').value()
    self.assertEqual(
        exported_error,
        'Merge only has 4 inputs, while only merge nodes with two inputs '
        'supported.\n\tFailed to functionalize Control Flow V1 ops. Consider '
        'using Control Flow V2 ops instead. See https://www.tensorflow.org/'
        'api_docs/python/tf/compat/v1/enable_control_flow_v2.')

  def test_location_from_concrete_functions(self):

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, None, 2, 3, 3], dtype=tf.complex64),
        tf.TensorSpec(shape=[None, None, 1, 3, 3], dtype=tf.complex64),
    ])
    def model(a, b):
      return tf.add(a, b, name='add')

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [model.get_concrete_function()], model)
    self.convert_and_check_location_info(
        converter,
        converter_error_data_pb2.ConverterErrorData.CALLSITELOC,
        expected_sources=[
            'tensorflow/lite/python/metrics_nonportable_test.py',
        ])

  def test_location_from_saved_model(self):

    with tempfile.TemporaryDirectory() as tmp_dir:

      class Adder(tf.Module):

        @tf.function(input_signature=[
            tf.TensorSpec(shape=[None, None, 2, 3, 3], dtype=tf.complex64),
            tf.TensorSpec(shape=[None, None, 1, 3, 3], dtype=tf.complex64),
        ])
        def serving_default(self, a, b):
          return tf.add(a, b, name='add')

      tf.saved_model.save(
          Adder(),
          tmp_dir,
          options=tf.saved_model.SaveOptions(save_debug_info=True))

      converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
      self.convert_and_check_location_info(
          converter,
          converter_error_data_pb2.ConverterErrorData.CALLSITELOC,
          expected_sources=[
              'tensorflow/lite/python/metrics_nonportable_test.py',
          ])

  @parameterized.named_parameters(
      ('_WithoutLoweringToSavedModel', False, None),
      ('_WithLoweringToSavedModel', True,
       'tensorflow/lite/python/metrics_nonportable_test.py'))
  def test_location_from_keras_model(self, lower_to_saved_model,
                                     expected_source):
    input_tensor1 = tf.keras.layers.Input(
        shape=[None, None, 2, 3, 3], dtype=tf.complex64)
    input_tensor2 = tf.keras.layers.Input(
        shape=[None, None, 2, 3, 3], dtype=tf.complex64)
    output = tf.keras.layers.Add()([input_tensor1, input_tensor2])
    model = tf.keras.Model(
        inputs=[input_tensor1, input_tensor2], outputs=output)
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.experimental_lower_to_saved_model = lower_to_saved_model
    # Without lowering to SavedModel, the location does not contain a callsite
    # to the current file.
    self.convert_and_check_location_info(
        converter,
        converter_error_data_pb2.ConverterErrorData.CALLSITELOC,
        expected_sources=[expected_source] if expected_source else None)


if __name__ == '__main__':
  test.main()