• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for tf_should_use."""
16
17# pylint: disable=unused-import
18import contextlib
19import gc
20import sys
21
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import test_util
26from tensorflow.python.platform import test
27from tensorflow.python.platform import tf_logging
28from tensorflow.python.util import tf_should_use
29
30
@contextlib.contextmanager
def reroute_error():
  """Temporarily replaces `tf_logging.error` with a mock.

  Patches the `error` attribute of the `tf_logging` module reached through
  `tf_should_use.tf_logging`, for the duration of the `with` block only, so
  tests can inspect what would have been logged instead of emitting it.

  Yields:
    The mock (created by `test.mock.patch.object`) standing in for
    `tf_logging.error`. Its `call_args` / `assert_not_called()` describe the
    calls made while the block was active.
  """
  with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error:
    yield error
36
37
class TfShouldUseTest(test.TestCase):
  """Tests for the "should use" wrappers in `tf_should_use`.

  A wrapped value that is dropped without being used is reported with an
  'Object was never used' message via `tf_logging.error`. Inside a
  `tf.function` with `error_in_function=True`, a RuntimeError is raised
  instead. `reroute_error` mocks the logger so each test can inspect (or rule
  out) that report.
  """

  def _logged_message(self, error):
    """Returns the most recent `error(...)` call's positional args, joined.

    Args:
      error: The mock yielded by `reroute_error`.

    Returns:
      The positional arguments of the last call, joined with newlines.
    """
    # Without this, a missing report surfaces as an opaque TypeError from
    # indexing `call_args`, which is None when the mock was never called.
    error.assert_called()
    return '\n'.join(error.call_args[0])

  def _assertNoGarbage(self):
    """Runs a full collection, then asserts `gc.garbage` is empty."""
    # Uncollectable objects are only appended to `gc.garbage` during a
    # collection, so collect first to make the check meaningful.
    gc.collect()
    self.assertFalse(gc.garbage)

  def testAddShouldUseWarningWhenNotUsed(self):
    c = constant_op.constant(0, name='blah0')
    def in_this_function():
      h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
      # Dropping the wrapper without using it should trigger the report.
      del h
    with reroute_error() as error:
      in_this_function()
    msg = self._logged_message(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      # The tensor name is only checked in graph mode.
      self.assertIn('blah0:0', msg)
    # The report should name the function that created the wrapper.
    self.assertIn('in_this_function', msg)
    self._assertNoGarbage()

  def testAddShouldUseExceptionInEagerAndFunction(self):
    def in_this_function():
      c = constant_op.constant(0, name='blah0')
      h = tf_should_use._add_should_use_warning(
          c, warn_in_eager=True, error_in_function=True)
      del h
    if context.executing_eagerly():
      # Called eagerly, the unused wrapper is logged rather than raised.
      with reroute_error() as error:
        in_this_function()
      msg = self._logged_message(error)
      self.assertIn('Object was never used', msg)
      self.assertIn('in_this_function', msg)
      self._assertNoGarbage()

    # Inside a tf.function, error_in_function=True turns the report into a
    # RuntimeError.
    tf_fn_in_this_function = def_function.function(in_this_function)
    with self.assertRaisesRegex(RuntimeError,
                                r'Object was never used.*blah0:0'):
      tf_fn_in_this_function()
    self._assertNoGarbage()

  def _testAddShouldUseWarningWhenUsed(self, fn, name):
    """Asserts that calling `fn` on a wrapped constant counts as a use.

    Args:
      fn: Callable applied to the wrapped constant.
      name: Name given to the underlying constant.
    """
    c = constant_op.constant(0, name=name)
    with reroute_error() as error:
      h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
      fn(h)
      # Drop the wrapper while the logger is still mocked so that any
      # report would be captured here.
      del h
    error.assert_not_called()

  def testAddShouldUseWarningWhenUsedWithAdd(self):
    def add(h):
      _ = h + 1
    self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
    self._assertNoGarbage()

  def testAddShouldUseWarningWhenUsedWithGetShape(self):
    def get_shape(h):
      _ = h.shape
    self._testAddShouldUseWarningWhenUsed(get_shape, name='blah_get_name')
    self._assertNoGarbage()

  def testShouldUseResult(self):
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah2')
    with reroute_error() as error:
      # The decorated result is discarded immediately.
      return_const(0.0)
    msg = self._logged_message(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      self.assertIn('blah2:0', msg)
    self.assertIn('return_const', msg)
    self._assertNoGarbage()

  def testShouldUseResultWhenNotReallyUsed(self):
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah3')
    with reroute_error() as error:
      with self.cached_session():
        return_const(0.0)
        # Creating another op and executing it does not mark the
        # unused op as being "used".
        v = constant_op.constant(1.0, name='meh')
        self.evaluate(v)
    msg = self._logged_message(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      self.assertIn('blah3:0', msg)
    self.assertIn('return_const', msg)
    self._assertNoGarbage()

  # Tests that mark_used is available in the API and that calling it
  # suppresses the "never used" report.
  def testMarkUsed(self):
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah3')

    with reroute_error() as error:
      with self.cached_session():
        return_const(0.0).mark_used()
    error.assert_not_called()
137
def _main():
  """Hands control to `test.main()` when this file is run directly."""
  test.main()


if __name__ == '__main__':
  _main()
140