1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Estimator related util.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22 23from tensorflow.python.platform import test 24from tensorflow.python.util import function_utils 25 26 27def silly_example_function(): 28 pass 29 30 31class SillyCallableClass(object): 32 33 def __call__(self): 34 pass 35 36 37class FnArgsTest(test.TestCase): 38 39 def test_simple_function(self): 40 def fn(a, b): 41 return a + b 42 self.assertEqual(('a', 'b'), function_utils.fn_args(fn)) 43 44 def test_callable(self): 45 46 class Foo(object): 47 48 def __call__(self, a, b): 49 return a + b 50 51 self.assertEqual(('a', 'b'), function_utils.fn_args(Foo())) 52 53 def test_bound_method(self): 54 55 class Foo(object): 56 57 def bar(self, a, b): 58 return a + b 59 60 self.assertEqual(('a', 'b'), function_utils.fn_args(Foo().bar)) 61 62 def test_bound_method_no_self(self): 63 64 class Foo(object): 65 66 def bar(*args): # pylint:disable=no-method-argument 67 return args[1] + args[2] 68 69 self.assertEqual((), function_utils.fn_args(Foo().bar)) 70 71 def test_partial_function(self): 72 expected_test_arg = 123 73 74 def fn(a, test_arg): 75 if test_arg != expected_test_arg: 76 return ValueError('partial fn does not work correctly') 77 return a 78 79 wrapped_fn = functools.partial(fn, test_arg=123) 80 81 self.assertEqual(('a',), function_utils.fn_args(wrapped_fn)) 82 83 def test_partial_function_with_positional_args(self): 84 expected_test_arg = 123 85 86 def fn(test_arg, a): 87 if test_arg != expected_test_arg: 88 return ValueError('partial fn does not work correctly') 89 return a 90 91 wrapped_fn = functools.partial(fn, 123) 92 93 self.assertEqual(('a',), function_utils.fn_args(wrapped_fn)) 94 95 self.assertEqual(3, wrapped_fn(3)) 96 self.assertEqual(3, wrapped_fn(a=3)) 97 98 def test_double_partial(self): 99 expected_test_arg1 = 123 100 expected_test_arg2 = 456 101 102 def fn(a, test_arg1, test_arg2): 103 if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: 104 return ValueError('partial does not work correctly') 105 return a 106 107 wrapped_fn = functools.partial(fn, test_arg2=456) 108 double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123) 109 110 self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn)) 111 112 def test_double_partial_with_positional_args_in_outer_layer(self): 113 expected_test_arg1 = 123 114 expected_test_arg2 = 456 115 116 def fn(test_arg1, a, test_arg2): 117 if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: 118 return ValueError('partial fn does not work correctly') 119 return a 120 121 wrapped_fn = functools.partial(fn, test_arg2=456) 122 double_wrapped_fn = functools.partial(wrapped_fn, 123) 123 124 self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn)) 125 126 self.assertEqual(3, double_wrapped_fn(3)) 127 self.assertEqual(3, double_wrapped_fn(a=3)) 128 129 def test_double_partial_with_positional_args_in_both_layers(self): 130 expected_test_arg1 = 123 131 expected_test_arg2 = 456 132 133 def fn(test_arg1, test_arg2, a): 134 if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: 135 return ValueError('partial fn does not work correctly') 136 return a 137 138 wrapped_fn = functools.partial(fn, 123) # binds to test_arg1 139 double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2 140 141 self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn)) 142 143 self.assertEqual(3, double_wrapped_fn(3)) 144 self.assertEqual(3, double_wrapped_fn(a=3)) 145 146 147class HasKwargsTest(test.TestCase): 148 149 def test_simple_function(self): 150 151 fn_has_kwargs = lambda **x: x 152 self.assertTrue(function_utils.has_kwargs(fn_has_kwargs)) 153 154 fn_has_no_kwargs = lambda x: x 155 self.assertFalse(function_utils.has_kwargs(fn_has_no_kwargs)) 156 157 def test_callable(self): 158 159 class FooHasKwargs(object): 160 161 def __call__(self, **x): 162 del x 163 self.assertTrue(function_utils.has_kwargs(FooHasKwargs())) 164 165 class FooHasNoKwargs(object): 166 167 def __call__(self, x): 168 del x 169 self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs())) 170 171 def test_bound_method(self): 172 173 class FooHasKwargs(object): 174 175 def fn(self, **x): 176 del x 177 self.assertTrue(function_utils.has_kwargs(FooHasKwargs().fn)) 178 179 class FooHasNoKwargs(object): 180 181 def fn(self, x): 182 del x 183 self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs().fn)) 184 185 def test_partial_function(self): 186 expected_test_arg = 123 187 188 def fn_has_kwargs(test_arg, **x): 189 if test_arg != expected_test_arg: 190 return ValueError('partial fn does not work correctly') 191 return x 192 193 wrapped_fn = functools.partial(fn_has_kwargs, test_arg=123) 194 self.assertTrue(function_utils.has_kwargs(wrapped_fn)) 195 some_kwargs = dict(x=1, y=2, z=3) 196 self.assertEqual(wrapped_fn(**some_kwargs), some_kwargs) 197 198 def fn_has_no_kwargs(x, test_arg): 199 if test_arg != expected_test_arg: 200 return ValueError('partial fn does not work correctly') 201 return x 202 203 wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg=123) 204 self.assertFalse(function_utils.has_kwargs(wrapped_fn)) 205 some_arg = 1 206 self.assertEqual(wrapped_fn(some_arg), some_arg) 207 208 def test_double_partial(self): 209 expected_test_arg1 = 123 210 expected_test_arg2 = 456 211 212 def fn_has_kwargs(test_arg1, test_arg2, **x): 213 if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: 214 return ValueError('partial does not work correctly') 215 return x 216 217 wrapped_fn = functools.partial(fn_has_kwargs, test_arg2=456) 218 double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123) 219 220 self.assertTrue(function_utils.has_kwargs(double_wrapped_fn)) 221 some_kwargs = dict(x=1, y=2, z=3) 222 self.assertEqual(double_wrapped_fn(**some_kwargs), some_kwargs) 223 224 def fn_has_no_kwargs(x, test_arg1, test_arg2): 225 if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: 226 return ValueError('partial does not work correctly') 227 return x 228 229 wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg2=456) 230 double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123) 231 232 self.assertFalse(function_utils.has_kwargs(double_wrapped_fn)) 233 some_arg = 1 234 self.assertEqual(double_wrapped_fn(some_arg), some_arg) 235 236 def test_raises_type_error(self): 237 with self.assertRaisesRegex(TypeError, 238 'fn should be a function-like object'): 239 function_utils.has_kwargs('not a function') 240 241 242class GetFuncNameTest(test.TestCase): 243 244 def testWithSimpleFunction(self): 245 self.assertEqual( 246 'silly_example_function', 247 function_utils.get_func_name(silly_example_function)) 248 249 def testWithClassMethod(self): 250 self.assertEqual( 251 'GetFuncNameTest.testWithClassMethod', 252 function_utils.get_func_name(self.testWithClassMethod)) 253 254 def testWithCallableClass(self): 255 callable_instance = SillyCallableClass() 256 self.assertRegex( 257 function_utils.get_func_name(callable_instance), 258 '<.*SillyCallableClass.*>') 259 260 def testWithFunctoolsPartial(self): 261 partial = functools.partial(silly_example_function) 262 self.assertRegex( 263 function_utils.get_func_name(partial), '<.*functools.partial.*>') 264 265 def testWithLambda(self): 266 anon_fn = lambda x: x 267 self.assertEqual('<lambda>', function_utils.get_func_name(anon_fn)) 268 269 def testRaisesWithNonCallableObject(self): 270 with self.assertRaises(ValueError): 271 function_utils.get_func_name(None) 272 273 274class GetFuncCodeTest(test.TestCase): 275 276 def testWithSimpleFunction(self): 277 code = function_utils.get_func_code(silly_example_function) 278 self.assertIsNotNone(code) 279 self.assertRegex(code.co_filename, 'function_utils_test.py') 280 281 def testWithClassMethod(self): 282 code = function_utils.get_func_code(self.testWithClassMethod) 283 self.assertIsNotNone(code) 284 self.assertRegex(code.co_filename, 'function_utils_test.py') 285 286 def testWithCallableClass(self): 287 callable_instance = SillyCallableClass() 288 code = function_utils.get_func_code(callable_instance) 289 self.assertIsNotNone(code) 290 self.assertRegex(code.co_filename, 'function_utils_test.py') 291 292 def testWithLambda(self): 293 anon_fn = lambda x: x 294 code = function_utils.get_func_code(anon_fn) 295 self.assertIsNotNone(code) 296 self.assertRegex(code.co_filename, 'function_utils_test.py') 297 298 def testWithFunctoolsPartial(self): 299 partial = functools.partial(silly_example_function) 300 code = function_utils.get_func_code(partial) 301 self.assertIsNone(code) 302 303 def testRaisesWithNonCallableObject(self): 304 with self.assertRaises(ValueError): 305 function_utils.get_func_code(None) 306 307 308if __name__ == '__main__': 309 test.main() 310