1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Unit tests for tf_should_use.""" 16 17# pylint: disable=unused-import 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import contextlib 23import gc 24import sys 25 26from tensorflow.python.eager import context 27from tensorflow.python.eager import def_function 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import test_util 30from tensorflow.python.platform import test 31from tensorflow.python.platform import tf_logging 32from tensorflow.python.util import tf_should_use 33 34 35@contextlib.contextmanager 36def reroute_error(): 37 """Temporarily reroute errors written to tf_logging.error into `captured`.""" 38 with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error: 39 yield error 40 41 42class TfShouldUseTest(test.TestCase): 43 44 def testAddShouldUseWarningWhenNotUsed(self): 45 c = constant_op.constant(0, name='blah0') 46 def in_this_function(): 47 h = tf_should_use._add_should_use_warning(c, warn_in_eager=True) 48 del h 49 with reroute_error() as error: 50 in_this_function() 51 msg = '\n'.join(error.call_args[0]) 52 self.assertIn('Object was never used', msg) 53 if not context.executing_eagerly(): 54 self.assertIn('blah0:0', msg) 55 self.assertIn('in_this_function', msg) 56 self.assertFalse(gc.garbage) 57 58 def testAddShouldUseExceptionInEagerAndFunction(self): 59 def in_this_function(): 60 c = constant_op.constant(0, name='blah0') 61 h = tf_should_use._add_should_use_warning( 62 c, warn_in_eager=True, error_in_function=True) 63 del h 64 if context.executing_eagerly(): 65 with reroute_error() as error: 66 in_this_function() 67 msg = '\n'.join(error.call_args[0]) 68 self.assertIn('Object was never used', msg) 69 self.assertIn('in_this_function', msg) 70 self.assertFalse(gc.garbage) 71 72 tf_fn_in_this_function = def_function.function(in_this_function) 73 with self.assertRaisesRegex(RuntimeError, 74 r'Object was never used.*blah0:0'): 75 tf_fn_in_this_function() 76 self.assertFalse(gc.garbage) 77 78 def _testAddShouldUseWarningWhenUsed(self, fn, name): 79 c = constant_op.constant(0, name=name) 80 with reroute_error() as error: 81 h = tf_should_use._add_should_use_warning(c, warn_in_eager=True) 82 fn(h) 83 del h 84 error.assert_not_called() 85 86 def testAddShouldUseWarningWhenUsedWithAdd(self): 87 def add(h): 88 _ = h + 1 89 self._testAddShouldUseWarningWhenUsed(add, name='blah_add') 90 gc.collect() 91 self.assertFalse(gc.garbage) 92 93 def testAddShouldUseWarningWhenUsedWithGetShape(self): 94 def get_shape(h): 95 _ = h.shape 96 self._testAddShouldUseWarningWhenUsed(get_shape, name='blah_get_name') 97 gc.collect() 98 self.assertFalse(gc.garbage) 99 100 def testShouldUseResult(self): 101 @tf_should_use.should_use_result(warn_in_eager=True) 102 def return_const(value): 103 return constant_op.constant(value, name='blah2') 104 with reroute_error() as error: 105 return_const(0.0) 106 msg = '\n'.join(error.call_args[0]) 107 self.assertIn('Object was never used', msg) 108 if not context.executing_eagerly(): 109 self.assertIn('blah2:0', msg) 110 self.assertIn('return_const', msg) 111 gc.collect() 112 self.assertFalse(gc.garbage) 113 114 def testShouldUseResultWhenNotReallyUsed(self): 115 @tf_should_use.should_use_result(warn_in_eager=True) 116 def return_const(value): 117 return constant_op.constant(value, name='blah3') 118 with reroute_error() as error: 119 with self.cached_session(): 120 return_const(0.0) 121 # Creating another op and executing it does not mark the 122 # unused op as being "used". 123 v = constant_op.constant(1.0, name='meh') 124 self.evaluate(v) 125 msg = '\n'.join(error.call_args[0]) 126 self.assertIn('Object was never used', msg) 127 if not context.executing_eagerly(): 128 self.assertIn('blah3:0', msg) 129 self.assertIn('return_const', msg) 130 gc.collect() 131 self.assertFalse(gc.garbage) 132 133 # Tests that mark_used is available in the API. 134 def testMarkUsed(self): 135 @tf_should_use.should_use_result(warn_in_eager=True) 136 def return_const(value): 137 return constant_op.constant(value, name='blah3') 138 139 with self.cached_session(): 140 return_const(0.0).mark_used() 141 142if __name__ == '__main__': 143 test.main() 144