# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for yapf.pytree_visitor."""

import unittest

from yapf.yapflib import py3compat
from yapf.yapflib import pytree_utils
from yapf.yapflib import pytree_visitor


class _NodeNameCollector(pytree_visitor.PyTreeVisitor):
  """Visitor that records the name of everything it is asked to visit.

  Attributes:
    all_node_names: names of every node and leaf handled by this visitor, in
      the order they were visited. Complete once the traversal has finished.
    name_node_values: the values of the NAME leaves that were visited. Those
      leaves are recorded in all_node_names as well.
  """

  def __init__(self):
    self.all_node_names = []
    self.name_node_values = []

  def DefaultNodeVisit(self, node):
    # Record this node, then delegate to the base class implementation.
    node_name = pytree_utils.NodeName(node)
    self.all_node_names.append(node_name)
    super(_NodeNameCollector, self).DefaultNodeVisit(node)

  def DefaultLeafVisit(self, leaf):
    leaf_name = pytree_utils.NodeName(leaf)
    self.all_node_names.append(leaf_name)

  def Visit_NAME(self, leaf):
    # NAME leaves get their value captured in addition to the regular
    # bookkeeping done for every leaf.
    self.name_node_values.append(leaf.value)
    self.DefaultLeafVisit(leaf)


_VISITOR_TEST_SIMPLE_CODE = r"""
foo = bar
baz = x
"""

_VISITOR_TEST_NESTED_CODE = r"""
if x:
  if y:
    return z
"""


class PytreeVisitorTest(unittest.TestCase):
  """Exercises PyTreeVisitor traversal and the tree dumping helpers."""

  def _CollectFrom(self, code):
    """Parses code and returns a collector that has visited the whole tree."""
    parsed = pytree_utils.ParseCodeToTree(code)
    visitor = _NodeNameCollector()
    visitor.Visit(parsed)
    return visitor

  def _AssertDumpLooksSane(self, dumped):
    """Checks that a dump of the simple code contains the expected pieces."""
    expected_fragments = (
        'file_input [3 children]',
        "NAME(Leaf(NAME, 'foo'))",
        "EQUAL(Leaf(EQUAL, '='))",
    )
    for fragment in expected_fragments:
      self.assertIn(fragment, dumped)

  def testCollectAllNodeNamesSimpleCode(self):
    visitor = self._CollectFrom(_VISITOR_TEST_SIMPLE_CODE)

    expected_names = [
        'file_input',
        'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
        'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
        'ENDMARKER',
    ]  # yapf: disable
    self.assertEqual(expected_names, visitor.all_node_names)

    expected_name_node_values = ['foo', 'bar', 'baz', 'x']
    self.assertEqual(expected_name_node_values, visitor.name_node_values)

  def testCollectAllNodeNamesNestedCode(self):
    visitor = self._CollectFrom(_VISITOR_TEST_NESTED_CODE)

    expected_names = [
        'file_input',
        'if_stmt', 'NAME', 'NAME', 'COLON',
        'suite', 'NEWLINE',
        'INDENT', 'if_stmt', 'NAME', 'NAME', 'COLON', 'suite', 'NEWLINE',
        'INDENT', 'simple_stmt', 'return_stmt', 'NAME', 'NAME', 'NEWLINE',
        'DEDENT', 'DEDENT', 'ENDMARKER',
    ]  # yapf: disable
    self.assertEqual(expected_names, visitor.all_node_names)

    expected_name_node_values = ['if', 'x', 'if', 'y', 'return', 'z']
    self.assertEqual(expected_name_node_values, visitor.name_node_values)

  def testDumper(self):
    # PyTreeDumper is primarily a debugging aid, so only basic sanity checks
    # are performed on its output.
    parsed = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    buf = py3compat.StringIO()
    dumper = pytree_visitor.PyTreeDumper(target_stream=buf)
    dumper.Visit(parsed)

    self._AssertDumpLooksSane(buf.getvalue())

  def testDumpPyTree(self):
    # The same sanity checks, applied to the DumpPyTree convenience wrapper.
    parsed = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    buf = py3compat.StringIO()
    pytree_visitor.DumpPyTree(parsed, target_stream=buf)

    self._AssertDumpLooksSane(buf.getvalue())


if __name__ == '__main__':
  unittest.main()