1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for templates module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import imp 22 23from absl.testing import parameterized 24import gast 25 26from tensorflow.python.autograph.pyct import loader 27from tensorflow.python.autograph.pyct import parser 28from tensorflow.python.autograph.pyct import qual_names as qn 29from tensorflow.python.autograph.pyct import templates 30from tensorflow.python.platform import test 31 32 33class _CtxClearer(gast.NodeTransformer): 34 35 def visit(self, node): 36 super(_CtxClearer, self).visit(node) 37 if hasattr(node, 'ctx'): 38 node.ctx = None 39 return node 40 41 42def _parse_with_unset_ctx(expr_source): 43 ast_node = parser.parse_expression(expr_source) 44 _CtxClearer().visit(ast_node) 45 return ast_node 46 47 48class _CtxChecker(gast.NodeTransformer): 49 50 def __init__(self, test_instance, expected_ctx): 51 self.at_top_level = True 52 self.test_instance = test_instance 53 self.expected_ctx = expected_ctx 54 55 def visit(self, node): 56 if hasattr(node, 'ctx'): 57 self.test_instance.assertIsInstance(node.ctx, self.expected_ctx) 58 if self.at_top_level: 59 self.at_top_level = False 60 self.expected_ctx = gast.Load 61 return super(_CtxChecker, self).visit(node) 62 63 64class TemplatesTest(test.TestCase, parameterized.TestCase): 65 66 def assertExpectedCtxSet(self, node, ctx): 67 """Assert that node has ctx=ctx at top and ctx=gast.Load everywhere else.""" 68 checker = _CtxChecker(self, ctx) 69 checker.visit(node) 70 71 def test_replace_tuple(self): 72 template = """ 73 def test_fn(a, c): 74 return b, 75 """ 76 77 node = templates.replace(template, b=('a', 'c'))[0] 78 result, _, _ = loader.load_ast(node) 79 80 self.assertEqual((2, 3), result.test_fn(2, 3)) 81 82 def test_replace_variable(self): 83 template = """ 84 def test_fn(a): 85 a += 1 86 a = 2 * a + 1 87 return b 88 """ 89 90 node = templates.replace(template, a='b')[0] 91 result, _, _ = loader.load_ast(node) 92 self.assertEqual(7, result.test_fn(2)) 93 94 def test_replace_function_name(self): 95 template = """ 96 def fname(a): 97 a += 1 98 a = 2 * a + 1 99 return a 100 """ 101 102 node = templates.replace(template, fname='test_fn')[0] 103 result, _, _ = loader.load_ast(node) 104 self.assertEqual(7, result.test_fn(2)) 105 106 def test_replace_code_block(self): 107 template = """ 108 def test_fn(a): 109 block 110 return a 111 """ 112 113 class ShouldBeReplaced(object): 114 pass 115 116 node = templates.replace( 117 template, 118 block=[ 119 gast.Assign( 120 [ 121 gast.Name( 122 'a', 123 ctx=ShouldBeReplaced, 124 annotation=None, 125 type_comment=None) 126 ], 127 gast.BinOp( 128 gast.Name( 129 'a', 130 ctx=ShouldBeReplaced, 131 annotation=None, 132 type_comment=None), gast.Add(), 133 gast.Constant(1, kind=None)), 134 ), 135 ] * 2)[0] 136 result, _, _ = loader.load_ast(node) 137 self.assertEqual(3, result.test_fn(1)) 138 139 def test_replace_attribute(self): 140 template = """ 141 def test_fn(a): 142 return a.foo 143 """ 144 145 node = templates.replace(template, foo='b')[0] 146 result, _, _ = loader.load_ast(node) 147 mod = imp.new_module('test') 148 mod.b = 3 149 self.assertEqual(3, result.test_fn(mod)) 150 151 with self.assertRaises(ValueError): 152 templates.replace(template, foo=1) 153 154 def test_replace_attribute_context(self): 155 template = """ 156 def test_fn(foo): 157 foo = 0 158 """ 159 160 node = templates.replace( 161 template, 162 foo=parser.parse_expression('a.b.c'))[0] 163 self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) 164 self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load) 165 self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load) 166 167 def test_replace_list_context(self): 168 template = """ 169 def test_fn(foo): 170 foo = 0 171 """ 172 173 node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0] 174 self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) 175 self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store) 176 self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store) 177 178 def test_replace_tuple_context(self): 179 template = """ 180 def test_fn(foo): 181 foo = 0 182 """ 183 184 node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0] 185 self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) 186 self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store) 187 self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store) 188 189 def test_replace_expression_context(self): 190 template = """ 191 def test_fn(): 192 foo 193 """ 194 195 node = templates.replace( 196 template, foo=parser.parse_expression('a + 2 * b / -c'))[0] 197 self.assertIsInstance(node.body[0].left.ctx, gast.Load) 198 self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load) 199 200 def test_replace_complex_context(self): 201 template = """ 202 def test_fn(): 203 foo = 0 204 """ 205 206 node = templates.replace( 207 template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0] 208 self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) 209 function_call_arg = node.body[0].targets[0].value.args[0] 210 self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load) 211 self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load) 212 self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load) 213 214 def test_replace_index(self): 215 template = """ 216 def test_fn(): 217 foo = 0 218 """ 219 220 node = templates.replace( 221 template, foo=parser.parse_expression('foo(a[b]).bar'))[0] 222 function_call_arg = node.body[0].targets[0].value.args[0] 223 self.assertIsInstance(function_call_arg.ctx, gast.Load) 224 self.assertIsInstance(function_call_arg.slice.ctx, gast.Load) 225 226 def test_replace_call_keyword(self): 227 template = """ 228 def test_fn(): 229 def f(a, d, f): 230 return a + d + f 231 return f(1, kws=None) 232 """ 233 234 source = parser.parse_expression('f(d=3, f=5)') 235 node = templates.replace(template, kws=source.keywords)[0] 236 result, _, _ = loader.load_ast(node) 237 self.assertEqual(9, result.test_fn()) 238 239 with self.assertRaises(ValueError): 240 templates.replace(template, kws=[]) 241 templates.replace(template, kws=1) 242 243 def test_replace_name_with_call(self): 244 template = """ 245 def test_fn(): 246 b = 5 247 def g(a): 248 return 3 * a 249 def f(): 250 return g 251 return foo 252 """ 253 254 source = parser.parse_expression('f()(b)') 255 node = templates.replace(template, foo=source)[0] 256 result, _, _ = loader.load_ast(node) 257 self.assertEqual(15, result.test_fn()) 258 259 def test_replace_name_with_dict(self): 260 template = """ 261 def test_fn(): 262 return foo['bar'] 263 """ 264 265 source = parser.parse_expression('{\'bar\': 3}') 266 node = templates.replace(template, foo=source)[0] 267 result, _, _ = loader.load_ast(node) 268 self.assertEqual(3, result.test_fn()) 269 270 def test_replace_as_expression(self): 271 template = """ 272 foo(a) 273 """ 274 275 node = templates.replace_as_expression(template, foo='bar', a='baz') 276 self.assertIsInstance(node, gast.Call) 277 self.assertEqual(node.func.id, 'bar') 278 self.assertEqual(node.args[0].id, 'baz') 279 280 def test_replace_as_expression_restrictions(self): 281 template = """ 282 foo(a) 283 bar(b) 284 """ 285 with self.assertRaises(ValueError): 286 templates.replace_as_expression(template) 287 288 def test_function_call_in_list(self): 289 template = """ 290 foo(bar) 291 """ 292 source = parser.parse_expression('[a(b(1))]') 293 templates.replace_as_expression(template, bar=source) 294 295 def test_star_comprehension_in_function_call(self): 296 template = """ 297 a = foo(func, args) 298 """ 299 source = parser.parse_expression('bar(*[i for i in range(j)])') 300 node = templates.replace(template, func=source.func, args=source.args) 301 arg_node = node[0].value.args[1].value 302 self.assertIsInstance(arg_node.generators[0].target.ctx, gast.Store) 303 self.assertIsInstance(arg_node.elt.ctx, gast.Load) 304 305 def test_lambda_in_function_call(self): 306 template = """ 307 a = foo(arg) 308 """ 309 source = parser.parse_expression('[lambda i: i]') 310 node = templates.replace(template, arg=source) 311 lambda_arg = node[0].value.args[0].elts[0] 312 self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param) 313 self.assertIsInstance(lambda_arg.body.ctx, gast.Load) 314 315 def test_replace_name_with_subscript(self): 316 template = """ 317 foo = bar 318 """ 319 replacement = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key')) 320 321 node = templates.replace(template, foo=replacement)[0].targets[0] 322 self.assertIsInstance(node.ctx, gast.Store) 323 self.assertIsInstance(node.value.ctx, gast.Load) 324 325 @parameterized.named_parameters([ 326 ('mixed_attr_subscript', 'a.b["c"]'), 327 ('mixed_subscript_attr', 'a[b.c]'), 328 ('nested_subscript', 'a[b[c]]'), 329 ('repeated_subscript', 'a[b][c]'), 330 ]) 331 def test_replace_name_mixed_attr_subscript(self, expression_source): 332 template = 'foo = bar' 333 replacement = _parse_with_unset_ctx(expression_source) 334 335 target_node = templates.replace(template, foo=replacement)[0].targets[0] 336 self.assertExpectedCtxSet(target_node, gast.Store) 337 338 value_node = templates.replace(template, bar=replacement)[0].value 339 self.assertExpectedCtxSet(value_node, gast.Load) 340 341if __name__ == '__main__': 342 test.main() 343