• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for yapf.pytree_visitor."""
15
16import unittest
17
18from yapf.yapflib import py3compat
19from yapf.yapflib import pytree_utils
20from yapf.yapflib import pytree_visitor
21
22
class _NodeNameCollector(pytree_visitor.PyTreeVisitor):
  """A tree visitor that collects the names of all tree nodes into a list.

  Every visited interior node and leaf contributes its name (as reported by
  pytree_utils.NodeName) to all_node_names, in traversal order. NAME leaves
  additionally contribute their source text to name_node_values.

  Attributes:
    all_node_names: collected list of the names, available when the traversal
      is over.
    name_node_values: collects the values of NAME leaves (in addition to their
      names going into all_node_names).
  """

  def __init__(self):
    self.all_node_names = []
    self.name_node_values = []

  def DefaultNodeVisit(self, node):
    # Record the interior node before handing it to the base class, so a
    # parent's name precedes the names of its descendants.
    self.all_node_names.append(pytree_utils.NodeName(node))
    super(_NodeNameCollector, self).DefaultNodeVisit(node)

  def DefaultLeafVisit(self, leaf):
    self.all_node_names.append(pytree_utils.NodeName(leaf))

  def Visit_NAME(self, leaf):
    # NAME leaves are handled here rather than by DefaultLeafVisit, so
    # delegate to it explicitly to keep all_node_names complete.
    self.name_node_values.append(leaf.value)
    self.DefaultLeafVisit(leaf)


# Two flat assignment statements: the tree is file_input -> simple_stmt ->
# expr_stmt -> NAME/EQUAL/NAME leaves, with no nested suites.
_VISITOR_TEST_SIMPLE_CODE = r"""
foo = bar
baz = x
"""

# Nested if-blocks (2-space indent) so the tree contains nested suites with
# INDENT/DEDENT leaves.
_VISITOR_TEST_NESTED_CODE = r"""
if x:
  if y:
    return z
"""


class PytreeVisitorTest(unittest.TestCase):
  """Tests for PyTreeVisitor traversal, PyTreeDumper and DumpPyTree."""

  def _CollectNames(self, code):
    """Parse code and run a fresh _NodeNameCollector over the resulting tree.

    Arguments:
      code: Python source text to parse.

    Returns:
      The collector, with all_node_names and name_node_values populated.
    """
    tree = pytree_utils.ParseCodeToTree(code)
    collector = _NodeNameCollector()
    collector.Visit(tree)
    return collector

  def _AssertDumpLooksSane(self, dump_output):
    """Basic sanity checks on a dump of _VISITOR_TEST_SIMPLE_CODE's tree."""
    # Two simple_stmt children plus the ENDMARKER leaf.
    self.assertIn('file_input [3 children]', dump_output)
    self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
    self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)

  def testCollectAllNodeNamesSimpleCode(self):
    collector = self._CollectNames(_VISITOR_TEST_SIMPLE_CODE)
    expected_names = [
        'file_input',
        'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
        'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
        'ENDMARKER',
    ]  # yapf: disable
    self.assertEqual(expected_names, collector.all_node_names)

    expected_name_node_values = ['foo', 'bar', 'baz', 'x']
    self.assertEqual(expected_name_node_values, collector.name_node_values)

  def testCollectAllNodeNamesNestedCode(self):
    collector = self._CollectNames(_VISITOR_TEST_NESTED_CODE)
    expected_names = [
        'file_input',
        'if_stmt', 'NAME', 'NAME', 'COLON',
        'suite', 'NEWLINE',
        'INDENT', 'if_stmt', 'NAME', 'NAME', 'COLON', 'suite', 'NEWLINE',
        'INDENT', 'simple_stmt', 'return_stmt', 'NAME', 'NAME', 'NEWLINE',
        'DEDENT', 'DEDENT', 'ENDMARKER',
    ]  # yapf: disable
    self.assertEqual(expected_names, collector.all_node_names)

    # Keywords such as 'if' and 'return' are NAME leaves too.
    expected_name_node_values = ['if', 'x', 'if', 'y', 'return', 'z']
    self.assertEqual(expected_name_node_values, collector.name_node_values)

  def testDumper(self):
    # PyTreeDumper is mainly a debugging utility, so only do basic sanity
    # checking.
    tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    stream = py3compat.StringIO()
    pytree_visitor.PyTreeDumper(target_stream=stream).Visit(tree)
    self._AssertDumpLooksSane(stream.getvalue())

  def testDumpPyTree(self):
    # Same sanity checking for the convenience wrapper DumpPyTree.
    tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    stream = py3compat.StringIO()
    pytree_visitor.DumpPyTree(tree, target_stream=stream)
    self._AssertDumpLooksSane(stream.getvalue())


if __name__ == '__main__':
  unittest.main()
