# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transpiler module."""

import threading

import gast

from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.platform import test


class FlipSignTransformer(transformer.Base):
  """AST transformer that rewrites every binary `+` into a binary `-`."""

  def visit_BinOp(self, node):
    # Swap the operator in place, then keep visiting so nested binary
    # operations inside the operands are rewritten as well.
    if isinstance(node.op, gast.Add):
      node.op = gast.Sub()
    return self.generic_visit(node)


class TestTranspiler(transpiler.PyToPy):
  """Minimal `PyToPy` implementation that applies `FlipSignTransformer`."""

  def get_caching_key(self, ctx):
    # The context is ignored: every call yields the same constant key.
    del ctx
    return 0

  def get_extra_locals(self):
    # The generated code needs no additional names.
    return {}

  def transform_ast(self, node, ctx):
    return FlipSignTransformer(ctx).visit(node)


# Module-level names referenced by `test_global` and
# `test_namespace_collisions_avoided` respectively.
global_var_for_test_global = 1
global_var_for_test_namespace_collisions = object()


class PyToPyTest(test.TestCase):
  """Exercises `PyToPy.transform` through the sign-flipping `TestTranspiler`.

  Every expectation is written with `+` replaced by `-`, because that is the
  only rewrite the transpiler under test performs.
  """

  def test_basic(self):
    """A plain function is converted: `a + 1` becomes `a - 1`."""
    def f(a):
      return a + 1

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 0)

  def test_closure(self):
    """The converted function keeps reading its closure variable."""
    b = 1

    def f(a):
      return a + b

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 0)
    # Rebinding the closed-over name after conversion must be observed.
    b = 2
    self.assertEqual(f(1), -1)

  def test_global(self):
    """The converted function keeps reading the module-level global."""
    def f(a):
      return a + global_var_for_test_global

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    global global_var_for_test_global
    saved_value = global_var_for_test_global
    try:
      global_var_for_test_global = 1
      self.assertEqual(f(1), 0)
      global_var_for_test_global = 2
      self.assertEqual(f(1), -1)
    finally:
      # Restore the module global so this test does not leak mutated state
      # into whatever runs after it in the same process.
      global_var_for_test_global = saved_value

  def test_defaults(self):
    """Default argument values survive conversion unchanged."""
    b = 2
    c = 1

    def f(a, d=c + 1):
      return a + b + d

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 1 - 2 - 2)
    c = 0
    self.assertEqual(f(1), 1 - 2 - 2)  # Defaults are evaluated at definition.
    b = 1
    self.assertEqual(f(1), 1 - 2 - 1)

  def test_call_tree(self):
    """Functions called by the converted function are left untouched."""

    def g(a):
      return a + 1

    def f(a):
      return g(a) + 1

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 1 - 1 + 1)  # Only f is converted.

  def test_lambda(self):
    """A lambda is converted and keeps reading its closure variable."""
    b = 2
    f = lambda x: (b + (x if x > 0 else -x))

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 2 - 1)
    self.assertEqual(f(-1), 2 - 1)

    b = 3

    self.assertEqual(f(1), 3 - 1)
    self.assertEqual(f(-1), 3 - 1)

  def test_multiple_lambdas(self):
    """Two lambdas defined on the same line can still be told apart."""
    a, b = 1, 2
    # This can be disambiguated by the argument names.
    f, _ = (lambda x: a + x, lambda y: b * y)

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 1 - 1)

  def test_nested_functions(self):
    """A function defined inside the converted function is rewritten too."""
    b = 2

    def f(x):

      def g(x):
        return b + x

      return g(x)

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 2 - 1)

  def test_nested_lambda(self):
    """A lambda defined inside the converted function is rewritten too."""
    b = 2

    def f(x):
      g = lambda x: b + x
      return g(x)

    tr = TestTranspiler()
    f, _, _ = tr.transform(f, None)

    self.assertEqual(f(1), 2 - 1)

  def test_concurrency(self):
    """Concurrent transforms of one function all yield the same module."""

    def f():
      pass

    outputs = []

    tr = TestTranspiler()
    # Note: this is not a test, it's a required invariant.
    assert tr.get_caching_key(None) == tr.get_caching_key(None)

    def conversion_thread():
      _, mod, _ = tr.transform(f, None)
      outputs.append(mod.__name__)

    threads = tuple(
        threading.Thread(target=conversion_thread) for _ in range(10))
    for t in threads:
      t.start()
    for t in threads:
      t.join()

    # Races would potentially create multiple functions / modules
    # (non-deterministically, but with high likelihood).
    self.assertEqual(len(set(outputs)), 1)

  def test_reentrance(self):
    """`transform` may be called again from inside `transform_ast`."""

    def test_fn():
      return 1 + 1

    class ReentrantTranspiler(transpiler.PyToPy):

      def __init__(self):
        super(ReentrantTranspiler, self).__init__()
        self._recursion_depth = 0

      def get_caching_key(self, ctx):
        del ctx
        return 0

      def get_extra_locals(self):
        return {}

      def transform_ast(self, node, ctx):
        # Re-enter `transform` exactly once, on the first invocation, while
        # the outer transformation of the same function is still running.
        self._recursion_depth += 1
        if self._recursion_depth < 2:
          self.transform(test_fn, None)
        return FlipSignTransformer(ctx).visit(node)

    tr = ReentrantTranspiler()

    f, _, _ = tr.transform(test_fn, None)
    self.assertEqual(f(), 0)

  def test_namespace_collisions_avoided(self):
    """A method named like a module global must still resolve the global."""

    class TestClass(object):

      def global_var_for_test_namespace_collisions(self):
        return global_var_for_test_namespace_collisions

    tr = TestTranspiler()
    obj = TestClass()

    f, _, _ = tr.transform(
        obj.global_var_for_test_namespace_collisions, None)
    # The converted function is called with the instance passed explicitly.
    self.assertIs(f(obj), global_var_for_test_namespace_collisions)


if __name__ == '__main__':
  test.main()