# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFLite metrics_wrapper module test cases."""

import tensorflow as tf

from tensorflow.lite.python import lite
from tensorflow.lite.python.convert import ConverterError
from tensorflow.lite.python.metrics.wrapper import metrics_wrapper
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class MetricsWrapperTest(test_util.TensorFlowTestCase):
  """Tests for `metrics_wrapper.retrieve_collected_errors`."""

  def test_basic_retrieve_collected_errors_empty(self):
    """Checks that no errors are reported when this test runs no conversion."""
    # NOTE(review): this presumably relies on no earlier conversion in the same
    # process leaving errors behind — confirm retrieval clears the collection.
    errors = metrics_wrapper.retrieve_collected_errors()
    self.assertEmpty(errors)

  def test_basic_retrieve_collected_errors_not_empty(self):
    """Checks that a failed conversion exposes its collected errors."""

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def func(x):
      return tf.cosh(x)

    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func.get_concrete_function()], func)
    # This test depends on converting `func` raising ConverterError. Using
    # assertRaises makes the test fail with a clear message if conversion ever
    # succeeds. The bare try/except it replaces would instead pass vacuously or
    # hit an unbound local variable.
    with self.assertRaises(ConverterError) as error:
      converter.convert()
    # retrieve_collected_errors is already captured in the exception's errors.
    self.assertNotEmpty(error.exception.errors)


if __name__ == "__main__":
  test.main()