# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for yapf.pytree_visitor."""

import unittest
from io import StringIO

from yapf.pytree import pytree_utils
from yapf.pytree import pytree_visitor

from yapftests import yapf_test_helper


class _NodeNameCollector(pytree_visitor.PyTreeVisitor):
  """A tree visitor that collects the names of all tree nodes into a list.

  Attributes:
    all_node_names: collected list of the names, available when the traversal
      is over.
    name_node_values: collected list of the values of NAME leaves (in addition
      to their node names going into all_node_names).
  """

  def __init__(self):
    self.all_node_names = []
    self.name_node_values = []

  def DefaultNodeVisit(self, node):
    """Record the name of a non-leaf node, then continue into its children."""
    self.all_node_names.append(pytree_utils.NodeName(node))
    # The base-class implementation performs the descent into the children;
    # the expected name lists below show each parent followed by its children.
    super().DefaultNodeVisit(node)

  def DefaultLeafVisit(self, leaf):
    """Record the name of a leaf node."""
    self.all_node_names.append(pytree_utils.NodeName(leaf))

  def Visit_NAME(self, leaf):
    """Record a NAME leaf's value, plus its node name like any other leaf."""
    self.name_node_values.append(leaf.value)
    # The Visit_NAME hook replaces the default leaf visit (each NAME appears
    # exactly once in the expected lists), so delegate explicitly to keep NAME
    # leaves in all_node_names as well.
    self.DefaultLeafVisit(leaf)


_VISITOR_TEST_SIMPLE_CODE = """\
foo = bar
baz = x
"""

_VISITOR_TEST_NESTED_CODE = """\
if x:
  if y:
    return z
"""


class PytreeVisitorTest(yapf_test_helper.YAPFTest):

  def testCollectAllNodeNamesSimpleCode(self):
    tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    collector = _NodeNameCollector()
    collector.Visit(tree)
    expected_names = [
        'file_input',
        'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
        'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
        'ENDMARKER',
    ]  # yapf: disable
    self.assertEqual(expected_names, collector.all_node_names)

    expected_name_node_values = ['foo', 'bar', 'baz', 'x']
    self.assertEqual(expected_name_node_values, collector.name_node_values)

  def testCollectAllNodeNamesNestedCode(self):
    tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_NESTED_CODE)
    collector = _NodeNameCollector()
    collector.Visit(tree)
    expected_names = [
        'file_input',
        'if_stmt', 'NAME', 'NAME', 'COLON',
        'suite', 'NEWLINE',
        'INDENT', 'if_stmt', 'NAME', 'NAME', 'COLON', 'suite', 'NEWLINE',
        'INDENT', 'simple_stmt', 'return_stmt', 'NAME', 'NAME', 'NEWLINE',
        'DEDENT', 'DEDENT', 'ENDMARKER',
    ]  # yapf: disable
    self.assertEqual(expected_names, collector.all_node_names)

    expected_name_node_values = ['if', 'x', 'if', 'y', 'return', 'z']
    self.assertEqual(expected_name_node_values, collector.name_node_values)

  def _AssertBasicDumpOutput(self, dump_output):
    """Sanity-check a textual dump of _VISITOR_TEST_SIMPLE_CODE.

    Shared by the PyTreeDumper and DumpPyTree tests so both entry points are
    held to exactly the same expectations.

    Args:
      dump_output: (str) the text written to the dumper's target stream.
    """
    # file_input has two simple_stmt children plus the ENDMARKER.
    self.assertIn('file_input [3 children]', dump_output)
    self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
    self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)

  def testDumper(self):
    # PyTreeDumper is mainly a debugging utility, so only do basic sanity
    # checking.
    tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    stream = StringIO()
    pytree_visitor.PyTreeDumper(target_stream=stream).Visit(tree)
    self._AssertBasicDumpOutput(stream.getvalue())

  def testDumpPyTree(self):
    # Similar sanity checking for the convenience wrapper DumpPyTree.
    tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    stream = StringIO()
    pytree_visitor.DumpPyTree(tree, target_stream=stream)
    self._AssertBasicDumpOutput(stream.getvalue())


if __name__ == '__main__':
  unittest.main()