# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for templates module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import imp

import gast

from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test


class TemplatesTest(test.TestCase):
  """Exercises `templates.replace` and `templates.replace_as_expression`.

  Most tests expand a small code template, then either compile the resulting
  AST and call the generated function, or inspect the `ctx` fields of the
  substituted nodes.
  """

  def test_replace_tuple(self):
    """A tuple of names substituted for `b` turns `return b,` into a tuple."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)

    self.assertEqual((2, 3), result.test_fn(2, 3))

  def test_replace_variable(self):
    """A string replacement renames the symbol wherever the template uses it."""
    template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

    node = templates.replace(template, a='b')[0]
    result, _ = compiler.ast_to_object(node)
    # (2 + 1) * 2 + 1 == 7; this only works if `a` became `b` consistently.
    self.assertEqual(7, result.test_fn(2))

  def test_replace_function_name(self):
    """The name of a function definition can itself be a placeholder."""
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(7, result.test_fn(2))

  def test_replace_code_block(self):
    """A placeholder statement can be replaced by a list of statements."""
    template = """
      def test_fn(a):
        block
        return a
    """

    # The list holds the statement `a = a + 1` twice, so test_fn(1) == 3.
    node = templates.replace(
        template,
        block=[
            gast.Assign([
                gast.Name('a', None, None)
            ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
        ] * 2)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(3, result.test_fn(1))

  def test_replace_attribute(self):
    """An attribute name accepts a string replacement but rejects an int."""
    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    # NOTE(review): `imp` is deprecated in Python 3; `types.ModuleType` is the
    # modern way to build a throwaway module object.
    mod = imp.new_module('test')
    mod.b = 3
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)

  def test_replace_attribute_context(self):
    """Only the outermost attribute of an assignment target becomes Store."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(
        template,
        foo=parser.parse_expression('a.b.c'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
    self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)

  def test_replace_list_context(self):
    """A list used as an assignment target is Store, as are its elements."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)

  def test_replace_tuple_context(self):
    """A tuple used as an assignment target is Store, as are its elements."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)

  def test_replace_expression_context(self):
    """Names inside an expression that replaces a bare name stay Load."""
    template = """
      def test_fn():
        foo
    """

    node = templates.replace(
        template, foo=parser.parse_expression('a + 2 * b / -c'))[0]
    # `left` is `a`; `right.left.right` is `b` inside `(2 * b) / -c`.
    self.assertIsInstance(node.body[0].left.ctx, gast.Load)
    self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load)

  def test_replace_complex_context(self):
    """Call arguments nested under a Store target keep Load context."""
    template = """
      def test_fn():
        foo = 0
    """

    node = templates.replace(
        template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    # The single call argument is the tuple `([a, b],)`.
    function_call_arg = node.body[0].targets[0].value.args[0]
    self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
    self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
    self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)

  def test_replace_index(self):
    """A subscript passed as a call argument under a Store target is Load."""
    template = """
      def test_fn():
        foo = 0
    """

    node = templates.replace(
        template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
    # The call argument is the subscript `a[b]`.
    function_call_arg = node.body[0].targets[0].value.args[0]
    self.assertIsInstance(function_call_arg.ctx, gast.Load)
    self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load)

  def test_replace_call_keyword(self):
    """A keyword placeholder accepts a non-empty list of keyword nodes only."""
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # f(1, d=3, f=5) == 1 + 3 + 5.
    self.assertEqual(9, result.test_fn())

    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    # Checked in its own context: inside a shared `assertRaises` block this
    # call would never run, because the first call already raised.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)

  def test_replace_name_with_call(self):
    """A bare name can be replaced by a call expression."""
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # f()(b) == g(5) == 15.
    self.assertEqual(15, result.test_fn())

  def test_replace_name_with_dict(self):
    """A subscripted name can be replaced by a dict literal."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(3, result.test_fn())

  def test_replace_as_expression(self):
    """`replace_as_expression` returns the expression node itself."""
    template = """
      foo(a)
    """

    node = templates.replace_as_expression(template, foo='bar', a='baz')
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    self.assertEqual(node.args[0].id, 'baz')

  def test_replace_as_expression_restrictions(self):
    """`replace_as_expression` rejects a template with two statements."""
    template = """
      foo(a)
      bar(b)
    """
    with self.assertRaises(ValueError):
      templates.replace_as_expression(template)

  def test_function_call_in_list(self):
    """Nested calls inside a list replacement are accepted without error."""
    template = """
        foo(bar)
    """
    source = parser.parse_expression('[a(b(1))]')
    # Smoke test: passes as long as the replacement does not raise.
    templates.replace_as_expression(template, bar=source)

  def test_star_comprehension_in_function_call(self):
    """A starred comprehension argument keeps its own Store/Load contexts."""
    template = """
      a = foo(func, args)
    """
    source = parser.parse_expression('bar(*[i for i in range(j)])')
    node = templates.replace(template, func=source.func, args=source.args)
    # args[1] is the starred node; its `value` is the list comprehension.
    arg_node = node[0].value.args[1].value
    self.assertIsInstance(arg_node.generators[0].target.ctx, gast.Store)
    self.assertIsInstance(arg_node.elt.ctx, gast.Load)

  def test_lambda_in_function_call(self):
    """A lambda inside a replacement keeps Param args and a Load body."""
    template = """
      a = foo(arg)
    """
    source = parser.parse_expression('[lambda i: i]')
    node = templates.replace(template, arg=source)
    lambda_arg = node[0].value.args[0].elts[0]
    self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param)
    self.assertIsInstance(lambda_arg.body.ctx, gast.Load)


if __name__ == '__main__':
  test.main()