# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CodeLab for displaying error stack traces with the MLIR-based converter.

Each step builds a tiny model that uses an op marked as unsupported and runs
the TFLite converter on it, so the converter's error message and stack trace
can be inspected. A step decorated with `@suppress_exception` swallows its
error; remove (or comment out) the decorator to see that step's stack trace.
"""

import functools
import sys

from absl import app

import tensorflow as tf


def suppress_exception(f):
  """Decorator that runs `f` and swallows any `Exception` it raises.

  Only `Exception` subclasses are suppressed, so `KeyboardInterrupt` and
  `SystemExit` still propagate. A one-line note naming the exception type is
  written to stderr so the suppression is visible without the full trace.

  Args:
    f: The callable to wrap.

  Returns:
    A wrapper that forwards its arguments to `f` and returns `f`'s result, or
    `None` when `f` raised an `Exception`.
  """

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    try:
      return f(*args, **kwargs)
    except Exception as e:  # pylint: disable=broad-except
      # Deliberate best-effort: keep the codelab running, but say what
      # happened instead of discarding the error silently.
      sys.stderr.write(
          '%s raised %s (suppressed; remove @suppress_exception to see the '
          'stack trace)\n' % (getattr(f, '__name__', repr(f)),
                              type(e).__name__))
      return None

  return wrapped


class TestModule(tf.Module):
  """The test model has unsupported op."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
  def model(self, x):
    y = tf.math.reciprocal(x)  # Not supported
    return y + y


# comment out the `@suppress_exception` to display the stack trace
@suppress_exception
def test_from_saved_model():
  """Saves `TestModule` as a SavedModel and converts it with TFLiteConverter.

  The SavedModel is written with debug info so the converter's error can point
  back at the Python source that created the unsupported op.
  """
  test_model = TestModule()
  saved_model_path = '/tmp/test.saved_model'
  save_options = tf.saved_model.SaveOptions(save_debug_info=True)
  tf.saved_model.save(test_model, saved_model_path, options=save_options)

  # load the model and convert
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
  converter.convert()


# The `@suppress_exception` decorator is commented out here so this step's
# stack trace is displayed. Note: the uncaught error also ends `main`, so the
# following `test_from_saved_model` step does not run; re-enable the decorator
# here when you want to inspect that step instead.
# @suppress_exception
def test_from_concrete_function():
  """Converts a concrete function containing an unsupported op."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
  def model(x):
    y = tf.math.reciprocal(x)  # not supported
    return y + y

  func = model.get_concrete_function()
  # The second argument is the trackable object the converter keeps a
  # reference to alongside `func`.
  converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
  converter.convert()


def main(argv):
  """Runs both codelab steps; takes no command-line arguments."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  sys.stdout.write('==== Testing from_concrete_functions ====\n')
  test_from_concrete_function()

  sys.stdout.write('==== Testing from_saved_model ====\n')
  test_from_saved_model()


if __name__ == '__main__':
  app.run(main)