# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exceptions module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.autograph.operators import exceptions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class ExceptionsTest(test.TestCase):
  """Tests `exceptions.assert_stmt` with tensor and Python-bool conditions."""

  def test_assert_tf_untriggered(self):
    """A true tensor condition yields an op that evaluates without error."""
    with self.cached_session():
      t = exceptions.assert_stmt(
          constant_op.constant(True), lambda: constant_op.constant('ignored'))
      self.evaluate(t)

  @test_util.run_deprecated_v1
  def test_assert_tf_triggered(self):
    """A false tensor condition raises InvalidArgumentError with the message."""
    with self.cached_session():
      t = exceptions.assert_stmt(
          constant_op.constant(False),
          lambda: constant_op.constant('test message'))

      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   'test message'):
        self.evaluate(t)

  @test_util.run_deprecated_v1
  def test_assert_tf_multiple_printed_values(self):
    """A message callable returning several tensors reports all of them.

    The expected error text contains both messages, in list order.
    """
    two_tensors = [
        constant_op.constant('test message'),
        constant_op.constant('another message')
    ]
    with self.cached_session():
      t = exceptions.assert_stmt(
          constant_op.constant(False), lambda: two_tensors)

      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   'test message.*another message'):
        self.evaluate(t)

  def test_assert_python_untriggered(self):
    """A true Python condition must not call the message callable at all."""
    side_effect_trace = []

    def expression_with_side_effects():
      side_effect_trace.append(object())
      return 'test message'

    exceptions.assert_stmt(True, expression_with_side_effects)

    self.assertListEqual(side_effect_trace, [])

  def test_assert_python_triggered(self):
    """A false Python condition raises AssertionError with the message.

    Also checks that the message callable was invoked exactly once.
    """
    if not __debug__:
      # Python assertions are stripped under `python -O`, so this test can only
      # run in debug mode. Report it as skipped rather than silently passing.
      self.skipTest('Python assertions can only be tested in debug mode.')

    side_effect_trace = []
    tracer = object()

    def expression_with_side_effects():
      side_effect_trace.append(tracer)
      return 'test message'

    with self.assertRaisesRegexp(AssertionError, 'test message'):
      exceptions.assert_stmt(False, expression_with_side_effects)
    # Must sit outside the assertRaises block: statements after the raising
    # call inside that block never execute.
    self.assertListEqual(side_effect_trace, [tracer])


if __name__ == '__main__':
  test.main()