• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Estimator related util."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23from tensorflow.python.platform import test
24from tensorflow.python.util import function_utils
25
26
def silly_example_function():
  """No-op module-level function used as a fixture by the tests in this file."""
29
30
class SillyCallableClass(object):
  """Minimal callable fixture: instances can be called with no arguments."""

  def __call__(self):
    """Does nothing; defined only so that instances are callable."""
35
36
class FnArgsTest(test.TestCase):
  """Tests for `function_utils.fn_args`.

  The expectations below require `fn_args` to return the names of the
  arguments a caller still has to supply: `self` is omitted for bound methods
  and callable instances, and arguments already bound by one or more layers of
  `functools.partial` are dropped.
  """

  def test_simple_function(self):
    def fn(a, b):
      return a + b
    self.assertEqual(('a', 'b'), function_utils.fn_args(fn))

  def test_callable(self):

    class Foo(object):

      def __call__(self, a, b):
        return a + b

    self.assertEqual(('a', 'b'), function_utils.fn_args(Foo()))

  def test_bound_method(self):

    class Foo(object):

      def bar(self, a, b):
        return a + b

    self.assertEqual(('a', 'b'), function_utils.fn_args(Foo().bar))

  def test_bound_method_no_self(self):

    class Foo(object):

      def bar(*args):  # pylint:disable=no-method-argument
        return args[1] + args[2]

    # Only `*args` is declared, so no named arguments are expected.
    self.assertEqual((), function_utils.fn_args(Foo().bar))

  def test_partial_function(self):
    expected_test_arg = 123

    def fn(a, test_arg):
      # Raise (not return) so a wrongly bound argument actually fails the test.
      if test_arg != expected_test_arg:
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)

    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))

    self.assertEqual(3, wrapped_fn(3))
    self.assertEqual(3, wrapped_fn(a=3))

  def test_partial_function_with_positional_args(self):
    expected_test_arg = 123

    def fn(test_arg, a):
      if test_arg != expected_test_arg:
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, expected_test_arg)

    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))

    self.assertEqual(3, wrapped_fn(3))
    self.assertEqual(3, wrapped_fn(a=3))

  def test_double_partial(self):
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(
        wrapped_fn, test_arg1=expected_test_arg1)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))

  def test_double_partial_with_positional_args_in_outer_layer(self):
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, a, test_arg2):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    # The outer positional argument binds to the first free slot, test_arg1.
    double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg1)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))

  def test_double_partial_with_positional_args_in_both_layers(self):
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial fn does not work correctly')
      return a

    # binds to test_arg1
    wrapped_fn = functools.partial(fn, expected_test_arg1)
    # binds to test_arg2
    double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg2)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
145
146
class HasKwargsTest(test.TestCase):
  """Tests for `function_utils.has_kwargs`.

  The expectations below require `has_kwargs` to report whether a callable
  accepts `**kwargs`. This must hold for plain functions, callable instances,
  bound methods and (nested) `functools.partial` wrappers. A TypeError is
  expected for objects that are not function-like.
  """

  def test_simple_function(self):

    fn_has_kwargs = lambda **x: x
    self.assertTrue(function_utils.has_kwargs(fn_has_kwargs))

    fn_has_no_kwargs = lambda x: x
    self.assertFalse(function_utils.has_kwargs(fn_has_no_kwargs))

  def test_callable(self):

    class FooHasKwargs(object):

      def __call__(self, **x):
        del x
    self.assertTrue(function_utils.has_kwargs(FooHasKwargs()))

    class FooHasNoKwargs(object):

      def __call__(self, x):
        del x
    self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs()))

  def test_bound_method(self):

    class FooHasKwargs(object):

      def fn(self, **x):
        del x
    self.assertTrue(function_utils.has_kwargs(FooHasKwargs().fn))

    class FooHasNoKwargs(object):

      def fn(self, x):
        del x
    self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs().fn))

  def test_partial_function(self):
    expected_test_arg = 123

    def fn_has_kwargs(test_arg, **x):
      # Raise (not return) so a wrongly bound argument actually fails the test.
      if test_arg != expected_test_arg:
        raise ValueError('partial fn does not work correctly')
      return x

    wrapped_fn = functools.partial(fn_has_kwargs, test_arg=expected_test_arg)
    self.assertTrue(function_utils.has_kwargs(wrapped_fn))
    some_kwargs = dict(x=1, y=2, z=3)
    self.assertEqual(wrapped_fn(**some_kwargs), some_kwargs)

    def fn_has_no_kwargs(x, test_arg):
      if test_arg != expected_test_arg:
        raise ValueError('partial fn does not work correctly')
      return x

    wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg=expected_test_arg)
    self.assertFalse(function_utils.has_kwargs(wrapped_fn))
    some_arg = 1
    self.assertEqual(wrapped_fn(some_arg), some_arg)

  def test_double_partial(self):
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn_has_kwargs(test_arg1, test_arg2, **x):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial does not work correctly')
      return x

    wrapped_fn = functools.partial(fn_has_kwargs, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(
        wrapped_fn, test_arg1=expected_test_arg1)

    self.assertTrue(function_utils.has_kwargs(double_wrapped_fn))
    some_kwargs = dict(x=1, y=2, z=3)
    self.assertEqual(double_wrapped_fn(**some_kwargs), some_kwargs)

    def fn_has_no_kwargs(x, test_arg1, test_arg2):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial does not work correctly')
      return x

    wrapped_fn = functools.partial(
        fn_has_no_kwargs, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(
        wrapped_fn, test_arg1=expected_test_arg1)

    self.assertFalse(function_utils.has_kwargs(double_wrapped_fn))
    some_arg = 1
    self.assertEqual(double_wrapped_fn(some_arg), some_arg)

  def test_raises_type_error(self):
    with self.assertRaisesRegex(TypeError,
                                'fn should be a function-like object'):
      function_utils.has_kwargs('not a function')
240
241
class GetFuncNameTest(test.TestCase):
  """Checks the display names returned by `function_utils.get_func_name`."""

  def testWithSimpleFunction(self):
    name = function_utils.get_func_name(silly_example_function)
    self.assertEqual('silly_example_function', name)

  def testWithClassMethod(self):
    # The expected name of a bound method is qualified with its class name.
    bound = self.testWithClassMethod
    name = function_utils.get_func_name(bound)
    self.assertEqual('GetFuncNameTest.testWithClassMethod', name)

  def testWithCallableClass(self):
    name = function_utils.get_func_name(SillyCallableClass())
    self.assertRegex(name, '<.*SillyCallableClass.*>')

  def testWithFunctoolsPartial(self):
    wrapped = functools.partial(silly_example_function)
    name = function_utils.get_func_name(wrapped)
    self.assertRegex(name, '<.*functools.partial.*>')

  def testWithLambda(self):
    identity = lambda x: x
    name = function_utils.get_func_name(identity)
    self.assertEqual('<lambda>', name)

  def testRaisesWithNonCallableObject(self):
    self.assertRaises(ValueError, function_utils.get_func_name, None)
272
273
class GetFuncCodeTest(test.TestCase):
  """Checks which code object `function_utils.get_func_code` resolves to."""

  def _assertDefinedInThisFile(self, code):
    """Asserts `code` is not None and its `co_filename` names this test file."""
    self.assertIsNotNone(code)
    self.assertRegex(code.co_filename, 'function_utils_test.py')

  def testWithSimpleFunction(self):
    self._assertDefinedInThisFile(
        function_utils.get_func_code(silly_example_function))

  def testWithClassMethod(self):
    self._assertDefinedInThisFile(
        function_utils.get_func_code(self.testWithClassMethod))

  def testWithCallableClass(self):
    instance = SillyCallableClass()
    self._assertDefinedInThisFile(function_utils.get_func_code(instance))

  def testWithLambda(self):
    identity = lambda x: x
    self._assertDefinedInThisFile(function_utils.get_func_code(identity))

  def testWithFunctoolsPartial(self):
    # A functools.partial wrapper is expected to yield no code object.
    wrapped = functools.partial(silly_example_function)
    self.assertIsNone(function_utils.get_func_code(wrapped))

  def testRaisesWithNonCallableObject(self):
    self.assertRaises(ValueError, function_utils.get_func_code, None)
306
307
if __name__ == '__main__':
  # Executed directly (not imported): hand off to TensorFlow's test runner.
  test.main()
310