• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFLite metrics_wrapper module test cases."""
16
17import tensorflow as tf
18
19from tensorflow.lite.python import lite
20from tensorflow.lite.python.convert import ConverterError
21from tensorflow.lite.python.metrics.wrapper import metrics_wrapper
22from tensorflow.python.framework import test_util
23from tensorflow.python.platform import test
24
25
class MetricsWrapperTest(test_util.TensorFlowTestCase):
  """Tests for `metrics_wrapper.retrieve_collected_errors` and its use by
  the TFLite converter when reporting conversion failures."""

  def test_basic_retrieve_collected_errors_empty(self):
    """The collected-errors list is empty when nothing has been recorded."""
    # NOTE(review): this relies on no earlier conversion in the same process
    # having left errors behind. unittest's default loader runs test methods
    # in name order, so this case executes before the `_not_empty` one.
    errors = metrics_wrapper.retrieve_collected_errors()
    self.assertEmpty(errors)

  def test_basic_retrieve_collected_errors_not_empty(self):
    """A failed conversion raises `ConverterError` carrying collected errors."""

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def func(x):
      # NOTE(review): the test expects conversion of `tf.cosh` to fail. If
      # TFLite gains support for it, switch to another unsupported op.
      return tf.cosh(x)

    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [func.get_concrete_function()], func)
    # Using assertRaises makes an unexpectedly successful conversion show up
    # as a clear test failure. The previous try/except left the result name
    # unbound in that case and crashed with UnboundLocalError instead.
    with self.assertRaises(ConverterError) as raised:
      converter.convert()
    # retrieve_collected_errors is already captured in the exception's
    # `errors` attribute.
    self.assertNotEmpty(raised.exception.errors)
47
48
if __name__ == "__main__":
  # Entry point when the module is executed directly: hand control to
  # TensorFlow's test runner, which discovers and runs the cases above.
  test.main()
51