• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for transpiler module."""
16
17import threading
18
19import gast
20
21from tensorflow.python.autograph.pyct import transformer
22from tensorflow.python.autograph.pyct import transpiler
23from tensorflow.python.platform import test
24
25
class FlipSignTransformer(transformer.Base):
  """AST transformer that turns every binary `+` into a binary `-`.

  Only the operator node is replaced; the operands are kept and visited
  recursively, so additions nested inside other expressions are flipped too.
  """

  def visit_BinOp(self, node):
    # Non-addition operators are left alone, but their children still need
    # visiting so that nested additions get rewritten.
    if not isinstance(node.op, gast.Add):
      return self.generic_visit(node)
    node.op = gast.Sub()
    return self.generic_visit(node)
32
33
class TestTranspiler(transpiler.PyToPy):
  """Minimal PyToPy transpiler used by the tests: applies FlipSignTransformer."""

  def get_caching_key(self, ctx):
    """Ignores `ctx` and always returns the constant key 0."""
    del ctx  # Unused; every call yields the same key.
    return 0

  def get_extra_locals(self):
    """Returns an empty mapping: no extra locals are supplied."""
    return dict()

  def transform_ast(self, node, ctx):
    """Rewrites `node` with a FlipSignTransformer built from `ctx`."""
    flipper = FlipSignTransformer(ctx)
    return flipper.visit(node)
45
46
# Module-level values referenced by PyToPyTest:
#   * global_var_for_test_global is read by the converted function in
#     test_global, which rebinds it to check that globals are looked up live.
#   * global_var_for_test_namespace_collisions is a unique sentinel object
#     compared by identity in test_namespace_collisions_avoided.
global_var_for_test_global, global_var_for_test_namespace_collisions = (
    1, object())
49
50
class PyToPyTest(test.TestCase):
  """Tests for transpiler.PyToPy conversion.

  Most tests convert a small function with a transpiler that applies
  FlipSignTransformer (every binary `+` becomes `-`). They then check the
  converted function's results against the expected "flipped" arithmetic.
  """

  def test_basic(self):
    def f(a):
      return a + 1

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 0)

  def test_closure(self):
    b = 1

    def f(a):
      return a + b

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 0)
    # Rebinding the closure variable must be visible to the converted function.
    b = 2
    self.assertEqual(f(1), -1)

  def test_global(self):
    global global_var_for_test_global

    def f(a):
      return a + global_var_for_test_global

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    # The module-level value is rebound below; restore it afterwards so this
    # test does not leak state into other tests.
    saved_value = global_var_for_test_global
    try:
      global_var_for_test_global = 1
      self.assertEqual(f(1), 0)
      # Rebinding the global must be visible to the converted function.
      global_var_for_test_global = 2
      self.assertEqual(f(1), -1)
    finally:
      global_var_for_test_global = saved_value

  def test_defaults(self):
    b = 2
    c = 1

    def f(a, d=c + 1):
      return a + b + d

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 1 - 2 - 2)
    c = 0
    self.assertEqual(f(1), 1 - 2 - 2)  # Defaults are evaluated at definition.
    b = 1
    self.assertEqual(f(1), 1 - 2 - 1)

  def test_call_tree(self):

    def g(a):
      return a + 1

    def f(a):
      return g(a) + 1

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    # Only f is converted: g still adds (1 + 1), while f's own `+ 1`
    # has been flipped to `- 1`.
    self.assertEqual(f(1), (1 + 1) - 1)

  def test_lambda(self):
    b = 2
    f = lambda x: (b + (x if x > 0 else -x))

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    # Unary minus is not a BinOp, so only the outer `+` is flipped.
    self.assertEqual(f(1), 2 - 1)
    self.assertEqual(f(-1), 2 - 1)

    b = 3

    self.assertEqual(f(1), 3 - 1)
    self.assertEqual(f(-1), 3 - 1)

  def test_multiple_lambdas(self):
    a, b = 1, 2
    # This can be disambiguated by the argument names.
    f, _ = (lambda x: a + x, lambda y: b * y)

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 1 - 1)

  def test_nested_functions(self):
    b = 2

    def f(x):

      def g(x):
        return b + x

      return g(x)

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    # g is part of f's body, so its addition is flipped as well.
    self.assertEqual(f(1), 2 - 1)

  def test_nested_lambda(self):
    b = 2

    def f(x):
      g = lambda x: b + x
      return g(x)

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 2 - 1)

  def test_concurrency(self):

    def f():
      pass

    outputs = []

    tr = TestTranspiler()
    # Note: this is not a test, it's a required invariant.
    assert tr.get_caching_key(None) == tr.get_caching_key(None)

    def conversion_thread():
      _, mod, _ = tr.transform(f, None)
      outputs.append(mod.__name__)

    threads = tuple(
        threading.Thread(target=conversion_thread) for _ in range(10))
    for t in threads:
      t.start()
    for t in threads:
      t.join()

    # An exception inside a worker thread does not propagate to join();
    # such a thread would simply leave no output. Require every conversion
    # to have completed so failures cannot hide behind the set check below.
    self.assertEqual(len(outputs), len(threads))
    # Races would potentially create multiple functions / modules
    # (non-deterministically, but with high likelihood).
    self.assertEqual(len(set(outputs)), 1)

  def test_reentrance(self):

    def test_fn():
      return 1 + 1

    class ReentrantTranspiler(transpiler.PyToPy):
      """Transpiler whose transform_ast recursively calls transform once."""

      def __init__(self):
        super(ReentrantTranspiler, self).__init__()
        self._recursion_depth = 0

      def get_caching_key(self, ctx):
        del ctx
        return 0

      def get_extra_locals(self):
        return {}

      def transform_ast(self, node, ctx):
        self._recursion_depth += 1
        # Start a nested conversion while the outer one is still in progress.
        if self._recursion_depth < 2:
          self.transform(test_fn, None)
        return FlipSignTransformer(ctx).visit(node)

    tr = ReentrantTranspiler()

    f, _, _ = tr.transform(test_fn, None)
    self.assertEqual(f(), 0)

  def test_namespace_collisions_avoided(self):

    class TestClass(object):

      # The method deliberately shares its name with a module-level global
      # that its body reads.
      def global_var_for_test_namespace_collisions(self):
        return global_var_for_test_namespace_collisions

    tr = TestTranspiler()
    obj = TestClass()

    f, _, _ = tr.transform(
        obj.global_var_for_test_namespace_collisions, None)
    # The converted function is called with `obj` passed explicitly, and must
    # still resolve the name to the global rather than the method.
    self.assertIs(f(obj), global_var_for_test_namespace_collisions)
238
if __name__ == '__main__':
  # Run the test cases in this module via tensorflow.python.platform.test.
  test.main()
241