• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""CodeLab for displaying error stack trace w/ MLIR-based converter."""
16
17import sys
18
19from absl import app
20
21import tensorflow as tf
22
23
def suppress_exception(f):
  """Wraps `f` so that any `Exception` it raises is silently discarded.

  Used to keep a conversion demo from aborting the whole script; remove the
  decorator from a demo to let its error (and stack trace) propagate.

  Unlike a bare `except:`, this does not swallow `KeyboardInterrupt` or
  `SystemExit`, so the script can still be interrupted or exit normally.

  Args:
    f: Callable to wrap.

  Returns:
    A wrapper that forwards its arguments to `f` and returns `f`'s result, or
    `None` if `f` raised an `Exception`.
  """
  import functools  # pylint: disable=g-import-not-at-top

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    try:
      return f(*args, **kwargs)
    except Exception:  # pylint: disable=broad-except
      # Deliberate best-effort: the demos are expected to fail conversion,
      # and the decorator exists precisely to hide that failure.
      return None

  return wrapped
31
32
class TestModule(tf.Module):
  """Module whose `model` function uses an op the converter does not support.

  Converting this module is expected to fail, which is what lets the codelab
  show the converter's error stack trace.
  """

  @tf.function(
      input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
  def model(self, x):
    inverse = tf.math.reciprocal(x)  # Not supported
    return inverse + inverse
40
41
# Comment out `@suppress_exception` below to let the conversion error
# propagate and display its stack trace.
@suppress_exception
def test_from_saved_model():
  """Displays the converter's stack trace when converting a SavedModel.

  Saves a `TestModule` (which uses an op the converter does not support) with
  debug info enabled, then converts it with `TFLiteConverter.from_saved_model`.

  The SavedModel is written to a fresh temporary directory that is removed
  afterwards (even if conversion raises), rather than to a fixed `/tmp` path
  that does not exist on every platform and can collide between runs.
  """
  import os  # pylint: disable=g-import-not-at-top
  import tempfile  # pylint: disable=g-import-not-at-top

  test_model = TestModule()
  with tempfile.TemporaryDirectory() as tmp_dir:
    saved_model_path = os.path.join(tmp_dir, 'test.saved_model')
    # Keep debug info in the SavedModel; presumably this is what lets the
    # converter's error point back at Python source lines — TODO confirm.
    save_options = tf.saved_model.SaveOptions(save_debug_info=True)
    tf.saved_model.save(test_model, saved_model_path, options=save_options)

    # Load the model and convert it; this is expected to raise.
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
    converter.convert()
54
55
# Comment out `@suppress_exception` to display the stack trace. It is already
# disabled here, so if this demo's conversion raises, the error propagates out
# of `main` and the from_saved_model demo that `main` runs next is not reached.
# @suppress_exception
def test_from_concrete_function():
  """Displays the converter's stack trace when converting a concrete function.

  Builds a `tf.function` that applies an op the converter does not support,
  traces it into a concrete function, and hands that to
  `TFLiteConverter.from_concrete_functions`.
  """
  @tf.function(
      input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
  def model(x):
    inverse = tf.math.reciprocal(x)  # not supported
    return inverse + inverse

  concrete_fn = model.get_concrete_function()
  tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [concrete_fn], model)
  tflite_converter.convert()
68
69
def main(argv):
  """Runs the concrete-function demo, then the SavedModel demo.

  Each demo is preceded by a banner on stdout. There is no error handling
  here, so a demo whose exception is not suppressed stops the remaining ones.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If any extra command-line arguments are given.
  """
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')

  demos = (
      ('==== Testing from_concrete_functions ====\n',
       test_from_concrete_function),
      ('==== Testing from_saved_model ====\n', test_from_saved_model),
  )
  for banner, run_demo in demos:
    sys.stdout.write(banner)
    run_demo()
79
80
if __name__ == '__main__':
  # Script entry point: hand control to absl's runner, which invokes `main`
  # with the command-line argv.
  app.run(main)
83