• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for compiler module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import textwrap
22
23import gast
24
25from tensorflow.python.autograph.pyct import compiler
26from tensorflow.python.autograph.pyct import parser
27from tensorflow.python.platform import test
28from tensorflow.python.util import tf_inspect
29
30
class CompilerTest(test.TestCase):
  """Round-trip tests between gast ASTs, source text and loaded modules."""

  def test_parser_compile_idempotent(self):
    """Parsing then recompiling a function reproduces its source text.

    NOTE: the nested ``test_fn`` is compared textually against source that is
    regenerated from its AST, so it must not gain comments or formatting that
    an AST round-trip cannot reproduce.
    """

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    _, _, parsed_nodes = parser.parse_entity(test_fn)
    recompiled_module = compiler.ast_to_object(parsed_nodes)[0]

    original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
    regenerated_source = tf_inspect.getsource(recompiled_module.test_fn)
    self.assertEqual(original_source, regenerated_source)

  def test_ast_to_source(self):
    """An If/else node renders as source using the requested indentation."""
    then_assign = gast.Assign(
        targets=[gast.Name('a', gast.Store(), None)],
        value=gast.Name('b', gast.Load(), None))
    else_assign = gast.Assign(
        targets=[gast.Name('a', gast.Store(), None)],
        value=gast.Str('c'))
    if_node = gast.If(
        test=gast.Num(1), body=[then_assign], orelse=[else_assign])

    rendered = compiler.ast_to_source(if_node, indentation='  ')

    expected = textwrap.dedent("""
            if 1:
              a = b
            else:
              a = 'c'
        """).strip()
    self.assertEqual(expected, rendered.strip())

  def test_ast_to_object(self):
    """A FunctionDef node compiles into a module whose function is callable.

    Also checks that the returned source and the file backing the generated
    module both contain the expected function text.
    """
    return_stmt = gast.Return(
        gast.BinOp(
            op=gast.Add(),
            left=gast.Name('a', gast.Load(), None),
            right=gast.Num(1)))
    func_node = gast.FunctionDef(
        name='f',
        args=gast.arguments(
            args=[gast.Name('a', gast.Param(), None)],
            vararg=None,
            kwonlyargs=[],
            kwarg=None,
            defaults=[],
            kw_defaults=[]),
        body=[return_stmt],
        decorator_list=[],
        returns=None)

    generated_module, generated_source = compiler.ast_to_object(func_node)

    expected_source = """
      def f(a):
        return a + 1
    """
    # Dedent once; the same normalized text is compared twice below.
    expected = textwrap.dedent(expected_source).strip()

    self.assertEqual(expected, generated_source.strip())
    self.assertEqual(2, generated_module.f(1))
    # The module is loaded from a file; its on-disk contents must match too.
    with open(generated_module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
107
if __name__ == '__main__':
  # When this file is executed directly, run its test cases through the
  # TensorFlow test entry point imported from tensorflow.python.platform.
  test.main()
110