1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for compiler module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import textwrap 22 23import gast 24 25from tensorflow.python.autograph.pyct import compiler 26from tensorflow.python.autograph.pyct import parser 27from tensorflow.python.platform import test 28from tensorflow.python.util import tf_inspect 29 30 31class CompilerTest(test.TestCase): 32 33 def test_parser_compile_idempotent(self): 34 35 def test_fn(x): 36 a = True 37 b = '' 38 if a: 39 b = x + 1 40 return b 41 42 _, _, all_nodes = parser.parse_entity(test_fn) 43 44 self.assertEqual( 45 textwrap.dedent(tf_inspect.getsource(test_fn)), 46 tf_inspect.getsource( 47 compiler.ast_to_object(all_nodes)[0].test_fn)) 48 49 def test_ast_to_source(self): 50 node = gast.If( 51 test=gast.Num(1), 52 body=[ 53 gast.Assign( 54 targets=[gast.Name('a', gast.Store(), None)], 55 value=gast.Name('b', gast.Load(), None)) 56 ], 57 orelse=[ 58 gast.Assign( 59 targets=[gast.Name('a', gast.Store(), None)], 60 value=gast.Str('c')) 61 ]) 62 63 source = compiler.ast_to_source(node, indentation=' ') 64 self.assertEqual( 65 textwrap.dedent(""" 66 if 1: 67 a = b 68 else: 69 a = 'c' 70 """).strip(), source.strip()) 71 72 def test_ast_to_object(self): 73 node = gast.FunctionDef( 74 name='f', 75 args=gast.arguments( 76 args=[gast.Name('a', gast.Param(), None)], 77 vararg=None, 78 kwonlyargs=[], 79 kwarg=None, 80 defaults=[], 81 kw_defaults=[]), 82 body=[ 83 gast.Return( 84 gast.BinOp( 85 op=gast.Add(), 86 left=gast.Name('a', gast.Load(), None), 87 right=gast.Num(1))) 88 ], 89 decorator_list=[], 90 returns=None) 91 92 module, source = compiler.ast_to_object(node) 93 94 expected_source = """ 95 def f(a): 96 return a + 1 97 """ 98 self.assertEqual( 99 textwrap.dedent(expected_source).strip(), 100 source.strip()) 101 self.assertEqual(2, module.f(1)) 102 with open(module.__file__, 'r') as temp_output: 103 self.assertEqual( 104 textwrap.dedent(expected_source).strip(), 105 temp_output.read().strip()) 106 107 108if __name__ == '__main__': 109 test.main() 110