• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for templates module."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import imp
import types

from absl.testing import parameterized
import gast

from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names as qn
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test
31
32
class _CtxClearer(gast.NodeTransformer):
  """Transformer that resets the `ctx` attribute of every visited node to None."""

  def visit(self, node):
    current = node
    # Let the base transformer walk (and recurse through) the children first;
    # its return value is intentionally ignored so `current` is always returned.
    super(_CtxClearer, self).visit(current)
    carries_ctx = hasattr(current, 'ctx')
    if carries_ctx:
      setattr(current, 'ctx', None)
    return current
40
41
def _parse_with_unset_ctx(expr_source):
  """Parse `expr_source` as an expression and blank out all of its ctx fields.

  Args:
    expr_source: Python source text of a single expression.

  Returns:
    The parsed expression node, with every `ctx` attribute set to None.
  """
  parsed = parser.parse_expression(expr_source)
  clearer = _CtxClearer()
  clearer.visit(parsed)
  return parsed
46
47
class _CtxChecker(gast.NodeTransformer):
  """Checks ctx types: `expected_ctx` on the first node, `gast.Load` after it.

  Each mismatch is reported through `test_instance.assertIsInstance`.
  """

  def __init__(self, test_instance, expected_ctx):
    self.test_instance = test_instance
    self.expected_ctx = expected_ctx
    # Only the very first visited node is checked against the caller's ctx.
    self.at_top_level = True

  def visit(self, node):
    if hasattr(node, 'ctx'):
      self.test_instance.assertIsInstance(node.ctx, self.expected_ctx)
    if self.at_top_level:
      # Once past the root, every nested node carrying a ctx must be a Load.
      self.at_top_level, self.expected_ctx = False, gast.Load
    return super(_CtxChecker, self).visit(node)
62
63
class TemplatesTest(test.TestCase, parameterized.TestCase):
  """Tests for `templates.replace` and `templates.replace_as_expression`."""

  def assertExpectedCtxSet(self, node, ctx):
    """Assert that node has ctx=ctx at top and ctx=gast.Load everywhere else."""
    checker = _CtxChecker(self, ctx)
    checker.visit(node)

  def test_replace_tuple(self):
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _, _ = loader.load_ast(node)

    self.assertEqual((2, 3), result.test_fn(2, 3))

  def test_replace_variable(self):
    template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

    # Every `a` (including the parameter) is renamed to `b`, so the function
    # computes 2 * (a + 1) + 1 and returns it.
    node = templates.replace(template, a='b')[0]
    result, _, _ = loader.load_ast(node)
    self.assertEqual(7, result.test_fn(2))

  def test_replace_function_name(self):
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result, _, _ = loader.load_ast(node)
    self.assertEqual(7, result.test_fn(2))

  def test_replace_code_block(self):
    template = """
      def test_fn(a):
        block
        return a
    """

    # Placeholder ctx value; the test only passes if the replaced nodes end up
    # with a ctx the loader accepts.
    class ShouldBeReplaced(object):
      pass

    # The same `a = a + 1` statement is spliced in twice, so test_fn(1) == 3.
    node = templates.replace(
        template,
        block=[
            gast.Assign(
                [
                    gast.Name(
                        'a',
                        ctx=ShouldBeReplaced,
                        annotation=None,
                        type_comment=None)
                ],
                gast.BinOp(
                    gast.Name(
                        'a',
                        ctx=ShouldBeReplaced,
                        annotation=None,
                        type_comment=None), gast.Add(),
                    gast.Constant(1, kind=None)),
            ),
        ] * 2)[0]
    result, _, _ = loader.load_ast(node)
    self.assertEqual(3, result.test_fn(1))

  def test_replace_attribute(self):
    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _, _ = loader.load_ast(node)
    # types.ModuleType replaces the deprecated imp.new_module (the imp module
    # is removed in Python 3.12).
    mod = types.ModuleType('test')
    mod.b = 3
    self.assertEqual(3, result.test_fn(mod))

    # A non-string replacement for an attribute name is rejected.
    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)

  def test_replace_attribute_context(self):
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(
        template,
        foo=parser.parse_expression('a.b.c'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
    self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)

  def test_replace_list_context(self):
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)

  def test_replace_tuple_context(self):
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)

  def test_replace_expression_context(self):
    template = """
      def test_fn():
        foo
    """

    node = templates.replace(
        template, foo=parser.parse_expression('a + 2 * b / -c'))[0]
    self.assertIsInstance(node.body[0].left.ctx, gast.Load)
    self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load)

  def test_replace_complex_context(self):
    template = """
      def test_fn():
        foo = 0
    """

    node = templates.replace(
        template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    # Only the outermost attribute is a store target; everything inside the
    # call arguments stays a Load.
    function_call_arg = node.body[0].targets[0].value.args[0]
    self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
    self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
    self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)

  def test_replace_index(self):
    template = """
      def test_fn():
        foo = 0
    """

    node = templates.replace(
        template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
    function_call_arg = node.body[0].targets[0].value.args[0]
    self.assertIsInstance(function_call_arg.ctx, gast.Load)
    self.assertIsInstance(function_call_arg.slice.ctx, gast.Load)

  def test_replace_call_keyword(self):
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    # The keyword list of a parsed call replaces the `kws=None` keyword.
    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _, _ = loader.load_ast(node)
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement gets its own assertRaises; inside a single
    # block the second call would never run once the first one raised.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)

  def test_replace_name_with_call(self):
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _, _ = loader.load_ast(node)
    self.assertEqual(15, result.test_fn())

  def test_replace_name_with_dict(self):
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _, _ = loader.load_ast(node)
    self.assertEqual(3, result.test_fn())

  def test_replace_as_expression(self):
    template = """
      foo(a)
    """

    node = templates.replace_as_expression(template, foo='bar', a='baz')
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    self.assertEqual(node.args[0].id, 'baz')

  def test_replace_as_expression_restrictions(self):
    template = """
      foo(a)
      bar(b)
    """
    # A two-statement template cannot be returned as a single expression.
    with self.assertRaises(ValueError):
      templates.replace_as_expression(template)

  def test_function_call_in_list(self):
    template = """
        foo(bar)
    """
    source = parser.parse_expression('[a(b(1))]')
    node = templates.replace_as_expression(template, bar=source)
    # The list literal (and the call nested in it) must be spliced in as the
    # argument of the outer call.
    self.assertIsInstance(node, gast.Call)
    self.assertIsInstance(node.args[0], gast.List)
    self.assertIsInstance(node.args[0].elts[0], gast.Call)

  def test_star_comprehension_in_function_call(self):
    template = """
      a = foo(func, args)
    """
    source = parser.parse_expression('bar(*[i for i in range(j)])')
    node = templates.replace(template, func=source.func, args=source.args)
    arg_node = node[0].value.args[1].value
    self.assertIsInstance(arg_node.generators[0].target.ctx, gast.Store)
    self.assertIsInstance(arg_node.elt.ctx, gast.Load)

  def test_lambda_in_function_call(self):
    template = """
      a = foo(arg)
    """
    source = parser.parse_expression('[lambda i: i]')
    node = templates.replace(template, arg=source)
    lambda_arg = node[0].value.args[0].elts[0]
    self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param)
    self.assertIsInstance(lambda_arg.body.ctx, gast.Load)

  def test_replace_name_with_subscript(self):
    template = """
        foo = bar
    """
    replacement = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))

    node = templates.replace(template, foo=replacement)[0].targets[0]
    self.assertIsInstance(node.ctx, gast.Store)
    self.assertIsInstance(node.value.ctx, gast.Load)

  @parameterized.named_parameters([
      ('mixed_attr_subscript', 'a.b["c"]'),
      ('mixed_subscript_attr', 'a[b.c]'),
      ('nested_subscript', 'a[b[c]]'),
      ('repeated_subscript', 'a[b][c]'),
  ])
  def test_replace_name_mixed_attr_subscript(self, expression_source):
    template = 'foo = bar'
    # Start from an AST with no ctx at all, so every ctx seen afterwards must
    # have been set by templates.replace.
    replacement = _parse_with_unset_ctx(expression_source)

    target_node = templates.replace(template, foo=replacement)[0].targets[0]
    self.assertExpectedCtxSet(target_node, gast.Store)

    value_node = templates.replace(template, bar=replacement)[0].value
    self.assertExpectedCtxSet(value_node, gast.Load)
340
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test entry point.
  test.main()
343