# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for tf_should_use."""

# pylint: disable=unused-import
import contextlib
import gc
import sys

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import tf_should_use


@contextlib.contextmanager
def reroute_error():
  """Temporarily replaces `tf_should_use.tf_logging.error` with a mock.

  Yields:
    The mock standing in for `tf_logging.error`. Callers inspect its
    `call_args` / `assert_not_called()` to see what would have been logged.
  """
  with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error:
    yield error


class TfShouldUseTest(test.TestCase):

  def _loggedMessage(self, error):
    """Returns the text passed to the mocked `tf_logging.error`.

    Args:
      error: The mock yielded by `reroute_error()`.

    Returns:
      The positional arguments of the most recent call, joined by newlines.
    """
    # Without this check a missing warning surfaces as an opaque `TypeError`
    # (`call_args` is None when the mock was never called) instead of a clear
    # test failure.
    error.assert_called()
    return '\n'.join(error.call_args[0])

  def _assertNoUncollectableGarbage(self):
    """Runs a full collection, then asserts nothing landed in `gc.garbage`."""
    # `gc.garbage` is only populated during a collection, so collecting first
    # makes the check reflect objects created by the current test.
    gc.collect()
    self.assertFalse(gc.garbage)

  def testAddShouldUseWarningWhenNotUsed(self):
    c = constant_op.constant(0, name='blah0')
    def in_this_function():
      h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
      del h
    with reroute_error() as error:
      in_this_function()
    msg = self._loggedMessage(error)
    self.assertIn('Object was never used', msg)
    # The tensor name is only checked outside eager mode.
    if not context.executing_eagerly():
      self.assertIn('blah0:0', msg)
    self.assertIn('in_this_function', msg)
    self._assertNoUncollectableGarbage()

  def testAddShouldUseExceptionInEagerAndFunction(self):
    def in_this_function():
      c = constant_op.constant(0, name='blah0')
      h = tf_should_use._add_should_use_warning(
          c, warn_in_eager=True, error_in_function=True)
      del h
    if context.executing_eagerly():
      # Called eagerly, the unused result is reported via tf_logging.error.
      with reroute_error() as error:
        in_this_function()
      msg = self._loggedMessage(error)
      self.assertIn('Object was never used', msg)
      self.assertIn('in_this_function', msg)
      self._assertNoUncollectableGarbage()

    # Traced inside a tf.function, `error_in_function=True` makes it raise.
    tf_fn_in_this_function = def_function.function(in_this_function)
    with self.assertRaisesRegex(RuntimeError,
                                r'Object was never used.*blah0:0'):
      tf_fn_in_this_function()
    self._assertNoUncollectableGarbage()

  def _testAddShouldUseWarningWhenUsed(self, fn, name):
    """Checks that using a wrapped constant via `fn` logs no warning.

    Args:
      fn: Callable that "uses" the wrapped object it is given.
      name: Name for the underlying constant op.
    """
    c = constant_op.constant(0, name=name)
    with reroute_error() as error:
      h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
      fn(h)
      del h
    error.assert_not_called()

  def testAddShouldUseWarningWhenUsedWithAdd(self):
    def add(h):
      _ = h + 1
    self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
    self._assertNoUncollectableGarbage()

  def testAddShouldUseWarningWhenUsedWithGetShape(self):
    def get_shape(h):
      _ = h.shape
    self._testAddShouldUseWarningWhenUsed(get_shape, name='blah_get_name')
    self._assertNoUncollectableGarbage()

  def testShouldUseResult(self):
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah2')
    with reroute_error() as error:
      return_const(0.0)
    msg = self._loggedMessage(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      self.assertIn('blah2:0', msg)
    self.assertIn('return_const', msg)
    self._assertNoUncollectableGarbage()

  def testShouldUseResultWhenNotReallyUsed(self):
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah3')
    with reroute_error() as error:
      with self.cached_session():
        return_const(0.0)
        # Creating another op and executing it does not mark the
        # unused op as being "used".
        v = constant_op.constant(1.0, name='meh')
        self.evaluate(v)
    msg = self._loggedMessage(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      self.assertIn('blah3:0', msg)
    self.assertIn('return_const', msg)
    self._assertNoUncollectableGarbage()

  # Tests that mark_used is available in the API and that calling it
  # suppresses the "never used" report.
  def testMarkUsed(self):
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah3')

    with reroute_error() as error:
      with self.cached_session():
        return_const(0.0).mark_used()
    error.assert_not_called()


if __name__ == '__main__':
  test.main()