• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for tf_should_use."""
16
17# pylint: disable=unused-import
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import contextlib
23import gc
24import sys
25
26from tensorflow.python.eager import context
27from tensorflow.python.eager import def_function
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import test_util
30from tensorflow.python.platform import test
31from tensorflow.python.platform import tf_logging
32from tensorflow.python.util import tf_should_use
33
34
@contextlib.contextmanager
def reroute_error():
  """Replaces `tf_should_use.tf_logging.error` with a mock for the `with` body.

  Yields:
    The mock standing in for `tf_logging.error`. Tests inspect its
    `call_args` or call `assert_not_called()` on it to see what would have
    been logged. The real function is restored when the context exits.
  """
  patcher = test.mock.patch.object(tf_should_use.tf_logging, 'error')
  with patcher as mocked_error:
    yield mocked_error
40
41
class TfShouldUseTest(test.TestCase):
  """Tests for the never-used warnings and errors produced by `tf_should_use`."""

  def _loggedErrorMessage(self, error):
    """Returns the message logged through the mock from `reroute_error()`.

    Args:
      error: The mock yielded by `reroute_error()`.

    Returns:
      The positional arguments of the mock's most recent call, joined by
      newlines.
    """
    # `call_args` is None when the mock was never called; assert first so a
    # missing warning fails with a clear message instead of a TypeError.
    error.assert_called()
    return '\n'.join(error.call_args[0])

  def _assertNoUncollectableGarbage(self):
    """Runs a full collection, then asserts `gc.garbage` is still empty."""
    # `gc.garbage` is only populated by a collection pass, so checking it
    # without collecting first would not see cycles created by this test.
    gc.collect()
    self.assertFalse(gc.garbage)

  def testAddShouldUseWarningWhenNotUsed(self):
    c = constant_op.constant(0, name='blah0')
    def in_this_function():
      h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
      del h
    with reroute_error() as error:
      in_this_function()
    msg = self._loggedErrorMessage(error)
    self.assertIn('Object was never used', msg)
    # The tensor's graph name is only checked when not executing eagerly.
    if not context.executing_eagerly():
      self.assertIn('blah0:0', msg)
    # The message should identify the function that created the object.
    self.assertIn('in_this_function', msg)
    self._assertNoUncollectableGarbage()

  def testAddShouldUseExceptionInEagerAndFunction(self):
    def in_this_function():
      c = constant_op.constant(0, name='blah0')
      h = tf_should_use._add_should_use_warning(
          c, warn_in_eager=True, error_in_function=True)
      del h
    # Run eagerly, an unused object is reported through tf_logging.error...
    if context.executing_eagerly():
      with reroute_error() as error:
        in_this_function()
      msg = self._loggedErrorMessage(error)
      self.assertIn('Object was never used', msg)
      self.assertIn('in_this_function', msg)
      self._assertNoUncollectableGarbage()

    # ...while inside a tf.function, error_in_function=True makes it raise.
    tf_fn_in_this_function = def_function.function(in_this_function)
    with self.assertRaisesRegex(RuntimeError,
                                r'Object was never used.*blah0:0'):
      tf_fn_in_this_function()
    self._assertNoUncollectableGarbage()

  def _testAddShouldUseWarningWhenUsed(self, fn, name):
    """Checks that nothing is logged once `fn` has used the wrapped tensor.

    Args:
      fn: Callable that receives the wrapped tensor and uses it.
      name: Name given to the underlying constant.
    """
    c = constant_op.constant(0, name=name)
    with reroute_error() as error:
      h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
      fn(h)
      del h
    error.assert_not_called()

  def testAddShouldUseWarningWhenUsedWithAdd(self):
    def add(h):
      _ = h + 1
    self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
    self._assertNoUncollectableGarbage()

  def testAddShouldUseWarningWhenUsedWithGetShape(self):
    def get_shape(h):
      _ = h.shape
    self._testAddShouldUseWarningWhenUsed(get_shape, name='blah_get_name')
    self._assertNoUncollectableGarbage()

  def testShouldUseResult(self):
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah2')
    with reroute_error() as error:
      # The wrapped result is discarded immediately without being used.
      return_const(0.0)
    msg = self._loggedErrorMessage(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      self.assertIn('blah2:0', msg)
    self.assertIn('return_const', msg)
    self._assertNoUncollectableGarbage()

  def testShouldUseResultWhenNotReallyUsed(self):
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah3')
    with reroute_error() as error:
      with self.cached_session():
        return_const(0.0)
        # Creating another op and executing it does not mark the
        # unused op as being "used".
        v = constant_op.constant(1.0, name='meh')
        self.evaluate(v)
    msg = self._loggedErrorMessage(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      self.assertIn('blah3:0', msg)
    self.assertIn('return_const', msg)
    self._assertNoUncollectableGarbage()

  def testMarkUsed(self):
    """Tests that `mark_used()` is available and silences the warning."""
    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah3')

    with reroute_error() as error:
      with self.cached_session():
        return_const(0.0).mark_used()
    # Explicitly marking the result as used must not log an error.
    error.assert_not_called()
141
if __name__ == '__main__':
  # Executed only when this file is run directly; hand off to the
  # TensorFlow test entry point.
  test.main()
144