• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for call_trees module."""
16
import imp
import types

from tensorflow.python.autograph.converters import call_trees
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
23
24
class MockConvertedCall(object):
  """Callable stand-in for `converted_call` that logs what it was given.

  Every invocation stores the `(args, kwargs)` pair, exactly as received, in
  `self.calls` and then forwards the call to the target function.
  """

  def __init__(self):
    # One `(args, kwargs)` tuple per invocation, in call order.
    self.calls = []

  def __call__(self, f, args, kwargs, caller_fn_scope=None, options=None):
    """Records the arguments, then calls `f` with them.

    Args:
      f: The callable to invoke.
      args: Iterable of positional arguments for `f`.
      kwargs: Mapping of keyword arguments for `f`, or a falsy value (such as
        None) when there are none. It is recorded as received.
      caller_fn_scope: Accepted for signature compatibility; ignored.
      options: Accepted for signature compatibility; ignored.

    Returns:
      Whatever `f(*args, **kwargs)` returns.
    """
    del caller_fn_scope, options  # Unused.
    recorded = (args, kwargs)
    self.calls.append(recorded)
    # The record above keeps the raw value (possibly None); only the actual
    # invocation needs a real mapping to unpack.
    keywords = kwargs if kwargs else {}
    return f(*args, **keywords)
35
36
class CallTreesTest(converter_testing.TestCase):
  """Tests for the `call_trees` converter.

  Each test converts a small function with the `functions` and `call_trees`
  converters, substituting a `MockConvertedCall` for `converted_call`, and then
  checks both the converted function's result and (where relevant) the
  `(args, kwargs)` pairs the mock recorded.

  NOTE: the inner functions passed to the converter are deliberately left
  undocumented and unmodified; presumably the converter works from their
  source, so their bodies are part of what is being tested.
  """

  def _transform_with_mock(self, f):
    """Converts `f` with a recording mock installed as `converted_call`.

    Args:
      f: The function (or method) to convert.

    Returns:
      A `(converted, mock)` tuple: the result of `self.transform` and the
      `MockConvertedCall` instance that was supplied as the `converted_call`
      override.
    """
    mock = MockConvertedCall()
    tr = self.transform(
        f, (functions, call_trees),
        ag_overrides={'converted_call': mock})
    return tr, mock

  def test_function_no_args(self):
    """A call without arguments is recorded as `((), None)`."""

    def f(f):
      return f() + 20

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda: 1), 21)
    self.assertListEqual(mock.calls, [((), None)])

  def test_function_with_expression_in_argument(self):
    """A call nested in an argument expression is recorded before the outer."""

    def f(f, g):
      return f(g() + 20) + 4000

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda x: x + 300, lambda: 1), 4321)
    self.assertListEqual(mock.calls, [
        ((), None),
        ((21,), None),
    ])

  def test_function_with_call_in_argument(self):
    """A call used directly as an argument is recorded before the outer call."""

    def f(f, g):
      return f(g()) + 300

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda x: x + 20, lambda: 1), 321)
    self.assertListEqual(mock.calls, [
        ((), None),
        ((1,), None),
    ])

  def test_function_chaining(self):
    """Both links of a chained call (`a().b(x)`) are recorded, in order."""

    def get_one():
      return 1

    def f():
      return get_one().__add__(20)

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(), 21)
    self.assertListEqual(mock.calls, [
        ((), None),
        ((20,), None),
    ])

  def test_function_with_single_arg(self):
    """A single positional argument is recorded as a one-element tuple."""

    def f(f, a):
      return f(a) + 20

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda a: a, 1), 21)
    self.assertListEqual(mock.calls, [((1,), None)])

  def test_function_with_args_only(self):
    """Multiple positional arguments are recorded with no kwargs (None)."""

    def f(f, a, b):
      return f(a, b) + 300

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda a, b: a + b, 1, 20), 321)
    self.assertListEqual(mock.calls, [((1, 20), None)])

  def test_function_with_kwarg(self):
    """An explicit keyword argument is recorded in the kwargs dict."""

    def f(f, a, b):
      return f(a, c=b) + 300

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(lambda a, c: a + c, 1, 20), 321)
    self.assertListEqual(mock.calls, [((1,), {'c': 20})])

  def test_function_with_kwargs_starargs(self):
    """`*args` and `**kwargs` are expanded into the recorded args and kwargs."""

    def f(f, a, *args, **kwargs):
      return f(a, *args, **kwargs) + 5

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(
        tr(lambda *args, **kwargs: 7, 1, *[2, 3], **{
            'b': 4,
            'c': 5
        }), 12)
    self.assertListEqual(mock.calls, [((1, 2, 3), {'b': 4, 'c': 5})])

  def test_function_with_starargs_only(self):
    """A lone `*args` list is recorded as a tuple of its elements."""

    def g(*args):
      return sum(args)

    def f():
      args = [1, 20, 300]
      return g(*args) + 4000

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(), 4321)
    self.assertListEqual(mock.calls, [((1, 20, 300), None)])

  def test_function_with_starargs_mixed(self):
    """Interleaved starred and plain positionals keep their call-site order."""

    def g(a, b, c, d):
      return a * 1000 + b * 100 + c * 10 + d

    def f():
      args1 = (1,)
      args2 = [3]
      return g(*args1, 2, *args2, 4)

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(tr(), 1234)
    self.assertListEqual(mock.calls, [((1, 2, 3, 4), None)])

  def test_function_with_kwargs_keywords(self):
    """An explicit keyword and a `**kwargs` dict are merged in the record."""

    def f(f, a, b, **kwargs):
      return f(a, b=b, **kwargs) + 5

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(
        tr(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
    self.assertListEqual(mock.calls, [((1,), {'b': 2, 'c': 3})])

  def test_function_with_multiple_kwargs(self):
    """Several keywords and several `**` dicts are merged into one record."""

    def f(f, a, b, c, kwargs1, kwargs2):
      return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5

    tr, mock = self._transform_with_mock(f)

    self.assertEqual(
        tr(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, {'e': 5}), 12)
    self.assertListEqual(mock.calls, [((1,), {
        'b': 2,
        'c': 3,
        'd': 4,
        'e': 5
    })])

  def test_function_with_call_in_lambda_argument(self):
    """A call inside a lambda passed as an argument still yields the result.

    Only the return value is checked here; the recorded calls are ignored.
    """

    def h(l, a):
      return l(a) + 4000

    def g(a, *args):
      return a + sum(args)

    def f(h, g, a, *args):
      return h(lambda x: g(x, *args), a)

    tr, _ = self._transform_with_mock(f)

    self.assertEqual(tr(h, g, 1, *(20, 300)), 4321)

  def test_debugger_set_trace(self):
    """A `pdb.set_trace()`-style call runs exactly once after conversion.

    A fake module named `fake_pdb` stands in for the debugger; its `set_trace`
    appends to a list so the test can observe that it was invoked.
    """

    tracking_list = []

    # `imp.new_module` is deprecated and the `imp` module was removed in
    # Python 3.12; `types.ModuleType` creates the same kind of empty module.
    pdb = types.ModuleType('fake_pdb')
    pdb.set_trace = lambda: tracking_list.append(1)

    def f():
      return pdb.set_trace()

    tr, _ = self._transform_with_mock(f)

    tr()
    self.assertListEqual(tracking_list, [1])

  def test_class_method(self):
    """Converts a method accessed on the class; `self` is passed explicitly."""

    class TestClass(object):

      def other_method(self, x):
        return x + 20

      def test_method(self, a):
        return self.other_method(a) + 300

    tc = TestClass()
    tr, mock = self._transform_with_mock(TestClass.test_method)

    # (actual, expected) order, matching the other tests in this class.
    self.assertEqual(tr(tc, 1), 321)
    self.assertListEqual(mock.calls, [((1,), None)])

  def test_object_method(self):
    """Converts a method accessed on an instance.

    The converted callable is still invoked with the instance as its first
    argument.
    """

    class TestClass(object):

      def other_method(self, x):
        return x + 20

      def test_method(self, a):
        return self.other_method(a) + 300

    tc = TestClass()
    tr, mock = self._transform_with_mock(tc.test_method)

    # (actual, expected) order, matching the other tests in this class.
    self.assertEqual(tr(tc, 1), 321)
    self.assertListEqual(mock.calls, [((1,), None)])
259
260
# When executed as a script (rather than imported), hand control to the
# `main()` entry point of `tensorflow.python.platform.test`.
if __name__ == '__main__':
  test.main()
263