# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for call_trees module."""

import types

from tensorflow.python.autograph.converters import call_trees
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test


class MockConvertedCall(object):
  """Stand-in for `converted_call` that records every call routed to it.

  Each invocation appends the `(args, kwargs)` pair exactly as received from
  the transformed code, then calls `f` directly. `kwargs` is recorded
  unchanged, so it is `None` when the transformed code passed `None`; the
  tests below expect this for call sites without keyword arguments.
  """

  def __init__(self):
    # One (args, kwargs) tuple per intercepted call, in call order.
    self.calls = []

  def __call__(self, f, args, kwargs, caller_fn_scope=None, options=None):
    del caller_fn_scope, options  # Not needed by the mock.
    self.calls.append((args, kwargs))
    kwargs = kwargs or {}
    return f(*args, **kwargs)


class CallTreesTest(converter_testing.TestCase):
  """Checks which calls the transformed code routes through `converted_call`."""

  def _transform_with_mock(self, f):
    """Converts `f` with the functions and call_trees converters.

    Args:
      f: The function to transform.

    Returns:
      A tuple `(tr, mock)`. `tr` is the transformed function. `mock` is the
      `MockConvertedCall` passed as the `converted_call` override; its `calls`
      attribute records each call routed through it.
    """
    mock = MockConvertedCall()
    tr = self.transform(
        f, (functions, call_trees),
        ag_overrides={'converted_call': mock})
    return tr, mock

  def test_function_no_args(self):

    def f(f):
      return f() + 20

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda: 1), 21)
    self.assertListEqual(mock.calls, [((), None)])

  def test_function_with_expression_in_argument(self):

    def f(f, g):
      return f(g() + 20) + 4000

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda x: x + 300, lambda: 1), 4321)
    self.assertListEqual(mock.calls, [
        ((), None),
        ((21,), None),
    ])

  def test_function_with_call_in_argument(self):

    def f(f, g):
      return f(g()) + 300

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda x: x + 20, lambda: 1), 321)
    self.assertListEqual(mock.calls, [
        ((), None),
        ((1,), None),
    ])

  def test_function_chaining(self):

    def get_one():
      return 1

    def f():
      return get_one().__add__(20)

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(), 21)
    # Both the inner call and the chained method call are intercepted.
    self.assertListEqual(mock.calls, [
        ((), None),
        ((20,), None),
    ])

  def test_function_with_single_arg(self):

    def f(f, a):
      return f(a) + 20

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda a: a, 1), 21)
    self.assertListEqual(mock.calls, [((1,), None)])

  def test_function_with_args_only(self):

    def f(f, a, b):
      return f(a, b) + 300

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda a, b: a + b, 1, 20), 321)
    self.assertListEqual(mock.calls, [((1, 20), None)])

  def test_function_with_kwarg(self):

    def f(f, a, b):
      return f(a, c=b) + 300

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda a, c: a + c, 1, 20), 321)
    self.assertListEqual(mock.calls, [((1,), {'c': 20})])

  def test_function_with_kwargs_starargs(self):

    def f(f, a, *args, **kwargs):
      return f(a, *args, **kwargs) + 5

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(
        tr(lambda *args, **kwargs: 7, 1, *[2, 3], **{
            'b': 4,
            'c': 5
        }), 12)
    self.assertListEqual(mock.calls, [((1, 2, 3), {'b': 4, 'c': 5})])

  def test_function_with_starargs_only(self):

    def g(*args):
      return sum(args)

    def f():
      args = [1, 20, 300]
      return g(*args) + 4000

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(), 4321)
    self.assertListEqual(mock.calls, [((1, 20, 300), None)])

  def test_function_with_starargs_mixed(self):

    def g(a, b, c, d):
      return a * 1000 + b * 100 + c * 10 + d

    def f():
      args1 = (1,)
      args2 = [3]
      return g(*args1, 2, *args2, 4)

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(), 1234)
    self.assertListEqual(mock.calls, [((1, 2, 3, 4), None)])

  def test_function_with_kwargs_keywords(self):

    def f(f, a, b, **kwargs):
      return f(a, b=b, **kwargs) + 5

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(
        tr(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
    self.assertListEqual(mock.calls, [((1,), {'b': 2, 'c': 3})])

  def test_function_with_multiple_kwargs(self):

    def f(f, a, b, c, kwargs1, kwargs2):
      return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(
        tr(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, {'e': 5}), 12)
    self.assertListEqual(mock.calls, [((1,), {
        'b': 2,
        'c': 3,
        'd': 4,
        'e': 5
    })])

  def test_function_with_call_in_lambda_argument(self):

    def h(l, a):
      return l(a) + 4000

    def g(a, *args):
      return a + sum(args)

    def f(h, g, a, *args):
      return h(lambda x: g(x, *args), a)

    tr, _ = self._transform_with_mock(f)

    self.assertEqual(tr(h, g, 1, *(20, 300)), 4321)

  def test_debugger_set_trace(self):

    tracking_list = []

    # A fake module named `pdb` stands in for the real debugger, so the test
    # can observe that `pdb.set_trace()` still runs in the converted code.
    # types.ModuleType replaces imp.new_module: the `imp` module was
    # deprecated in Python 3.4 and removed in Python 3.12.
    pdb = types.ModuleType('fake_pdb')
    pdb.set_trace = lambda: tracking_list.append(1)

    def f():
      return pdb.set_trace()

    tr, _ = self._transform_with_mock(f)

    tr()
    self.assertListEqual(tracking_list, [1])

  def test_class_method(self):

    class TestClass(object):

      def other_method(self, x):
        return x + 20

      def test_method(self, a):
        return self.other_method(a) + 300

    tc = TestClass()
    tr, mock = self._transform_with_mock(TestClass.test_method)

    self.assertEqual(321, tr(tc, 1))
    self.assertListEqual(mock.calls, [((1,), None)])

  def test_object_method(self):

    class TestClass(object):

      def other_method(self, x):
        return x + 20

      def test_method(self, a):
        return self.other_method(a) + 300

    tc = TestClass()
    tr, mock = self._transform_with_mock(tc.test_method)

    self.assertEqual(321, tr(tc, 1))
    self.assertListEqual(mock.calls, [((1,), None)])


if __name__ == '__main__':
  test.main()