# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for templates module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import imp
import types

import gast

from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test


class TemplatesTest(test.TestCase):
  """Tests for `templates.replace` and `templates.replace_as_expression`."""

  def test_replace_tuple(self):
    """A tuple of names can stand in for a single placeholder name."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)

    self.assertEqual((2, 3), result.test_fn(2, 3))

  def test_replace_variable(self):
    """A string replacement renames the placeholder variable."""
    template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

    node = templates.replace(template, a='b')[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(7, result.test_fn(2))

  def test_replace_function_name(self):
    """A string replacement can rename the templated function itself."""
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(7, result.test_fn(2))

  def test_replace_code_block(self):
    """A list of statement nodes can replace a placeholder statement."""
    template = """
      def test_fn(a):
        block
        return a
    """

    # The replacement is the statement `a = a + 1`, repeated twice.
    node = templates.replace(
        template,
        block=[
            gast.Assign([
                gast.Name('a', None, None)
            ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
        ] * 2)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(3, result.test_fn(1))

  def test_replace_attribute(self):
    """An attribute name can be replaced by a string, but not by an int."""
    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    # Any object exposing attribute `b` works; an empty module is convenient.
    # types.ModuleType is the non-deprecated equivalent of imp.new_module.
    mod = types.ModuleType('test')
    mod.b = 3
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)

  def test_replace_attribute_context(self):
    """Only the outermost attribute of an assignment target gets Store ctx."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(
        template,
        foo=parser.parse_expression('a.b.c'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
    self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)

  def test_replace_list_context(self):
    """A list used as assignment target has Store ctx on it and its elements."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)

  def test_replace_tuple_context(self):
    """A tuple used as assignment target has Store ctx on it and its elements."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)

  def test_replace_expression_context(self):
    """Names inside an expression used as a statement keep Load ctx."""
    template = """
      def test_fn():
        foo
    """

    node = templates.replace(
        template, foo=parser.parse_expression('a + 2 * b / -c'))[0]
    self.assertIsInstance(node.body[0].left.ctx, gast.Load)
    self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load)

  def test_replace_complex_context(self):
    """Call arguments nested in an assignment target keep Load ctx."""
    template = """
      def test_fn():
        foo = 0
    """

    node = templates.replace(
        template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    function_call_arg = node.body[0].targets[0].value.args[0]
    self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
    self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
    self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)

  def test_replace_index(self):
    """A subscript passed as a call argument keeps Load ctx throughout."""
    template = """
      def test_fn():
        foo = 0
    """

    node = templates.replace(
        template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
    function_call_arg = node.body[0].targets[0].value.args[0]
    self.assertIsInstance(function_call_arg.ctx, gast.Load)
    self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load)

  def test_replace_call_keyword(self):
    """A keyword placeholder accepts a non-empty list of keyword nodes only."""
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement gets its own assertRaises block: inside a single
    # block the first raise would exit the `with` and the second call would
    # never be exercised.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)

  def test_replace_name_with_call(self):
    """A name placeholder can be replaced by a (chained) call expression."""
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(15, result.test_fn())

  def test_replace_name_with_dict(self):
    """A name placeholder can be replaced by a dict literal."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(3, result.test_fn())

  def test_replace_as_expression(self):
    """replace_as_expression returns the expression node, here a Call."""
    template = """
      foo(a)
    """

    node = templates.replace_as_expression(template, foo='bar', a='baz')
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    self.assertEqual(node.args[0].id, 'baz')

  def test_replace_as_expression_restrictions(self):
    """replace_as_expression rejects a template with two statements."""
    template = """
      foo(a)
      bar(b)
    """
    with self.assertRaises(ValueError):
      templates.replace_as_expression(template)

  def test_function_call_in_list(self):
    """Replacing with a list of nested calls completes without raising."""
    template = """
        foo(bar)
    """
    source = parser.parse_expression('[a(b(1))]')
    templates.replace_as_expression(template, bar=source)

  def test_star_comprehension_in_function_call(self):
    """A starred comprehension argument keeps its own Store/Load contexts."""
    template = """
      a = foo(func, args)
    """
    source = parser.parse_expression('bar(*[i for i in range(j)])')
    node = templates.replace(template, func=source.func, args=source.args)
    arg_node = node[0].value.args[1].value
    self.assertIsInstance(arg_node.generators[0].target.ctx, gast.Store)
    self.assertIsInstance(arg_node.elt.ctx, gast.Load)

  def test_lambda_in_function_call(self):
    """A lambda inside a replacement keeps Param ctx on args, Load on body."""
    template = """
      a = foo(arg)
    """
    source = parser.parse_expression('[lambda i: i]')
    node = templates.replace(template, arg=source)
    lambda_arg = node[0].value.args[0].elts[0]
    self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param)
    self.assertIsInstance(lambda_arg.body.ctx, gast.Load)


# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  test.main()