• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for yapf.pytree_visitor."""
15
16import unittest
17from io import StringIO
18
19from yapf.pytree import pytree_utils
20from yapf.pytree import pytree_visitor
21
22from yapftests import yapf_test_helper
23
24
class _NodeNameCollector(pytree_visitor.PyTreeVisitor):
  """Visitor that records the grammar name of every node it walks over.

  Attributes:
    all_node_names: grammar names of every interior node and leaf, in the
      order they were visited; complete once the traversal has finished.
    name_node_values: string values of the NAME leaves seen, collected in
      addition to their entries in all_node_names.
  """

  def __init__(self):
    self.all_node_names, self.name_node_values = [], []

  def _RecordName(self, pytree_node):
    """Appends the grammar name of `pytree_node` to all_node_names."""
    self.all_node_names.append(pytree_utils.NodeName(pytree_node))

  def DefaultNodeVisit(self, node):
    # Record the interior node first, then let the base class handle the rest
    # of the traversal, so a parent's name precedes its children's names.
    self._RecordName(node)
    super(_NodeNameCollector, self).DefaultNodeVisit(node)

  def DefaultLeafVisit(self, leaf):
    self._RecordName(leaf)

  def Visit_NAME(self, leaf):
    # Capture the NAME's text, then fall back to the generic leaf handling so
    # the leaf still shows up in all_node_names.
    self.name_node_values.append(leaf.value)
    self.DefaultLeafVisit(leaf)
49
50
51_VISITOR_TEST_SIMPLE_CODE = """\
52foo = bar
53baz = x
54"""
55
56_VISITOR_TEST_NESTED_CODE = """\
57if x:
58  if y:
59    return z
60"""
61
62
class PytreeVisitorTest(yapf_test_helper.YAPFTest):
  """Tests for PyTreeVisitor traversal and the PyTreeDumper debug helpers."""

  def _CollectNodeNames(self, code):
    """Parses `code` and walks the resulting tree with a _NodeNameCollector.

    Args:
      code: Python source text to parse.

    Returns:
      The collector after it has visited the whole tree, so callers can
      inspect both all_node_names and name_node_values.
    """
    tree = pytree_utils.ParseCodeToTree(code)
    collector = _NodeNameCollector()
    collector.Visit(tree)
    return collector

  def _AssertSimpleCodeDump(self, dump_output):
    """Sanity-checks the dump text produced for _VISITOR_TEST_SIMPLE_CODE.

    The dumper is mainly a debugging utility, so only a few representative
    lines are checked rather than the exact full output.

    Args:
      dump_output: text the dumper wrote for _VISITOR_TEST_SIMPLE_CODE.
    """
    # Two simple_stmt children plus the trailing ENDMARKER leaf.
    self.assertIn('file_input [3 children]', dump_output)
    self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
    self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)

  def testCollectAllNodeNamesSimpleCode(self):
    collector = self._CollectNodeNames(_VISITOR_TEST_SIMPLE_CODE)
    expected_names = [
        'file_input',
        'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
        'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
        'ENDMARKER',
    ]  # yapf: disable
    self.assertEqual(expected_names, collector.all_node_names)

    expected_name_node_values = ['foo', 'bar', 'baz', 'x']
    self.assertEqual(expected_name_node_values, collector.name_node_values)

  def testCollectAllNodeNamesNestedCode(self):
    collector = self._CollectNodeNames(_VISITOR_TEST_NESTED_CODE)
    expected_names = [
        'file_input',
        'if_stmt', 'NAME', 'NAME', 'COLON',
        'suite', 'NEWLINE',
        'INDENT', 'if_stmt', 'NAME', 'NAME', 'COLON', 'suite', 'NEWLINE',
        'INDENT', 'simple_stmt', 'return_stmt', 'NAME', 'NAME', 'NEWLINE',
        'DEDENT', 'DEDENT', 'ENDMARKER',
    ]  # yapf: disable
    self.assertEqual(expected_names, collector.all_node_names)

    # Keywords ('if', 'return') are NAME leaves too, so they are collected
    # alongside the identifiers.
    expected_name_node_values = ['if', 'x', 'if', 'y', 'return', 'z']
    self.assertEqual(expected_name_node_values, collector.name_node_values)

  def testDumper(self):
    tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    stream = StringIO()
    pytree_visitor.PyTreeDumper(target_stream=stream).Visit(tree)
    self._AssertSimpleCodeDump(stream.getvalue())

  def testDumpPyTree(self):
    # DumpPyTree is a convenience wrapper around the dumper, so its output
    # must pass the same sanity checks.
    tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    stream = StringIO()
    pytree_visitor.DumpPyTree(tree, target_stream=stream)
    self._AssertSimpleCodeDump(stream.getvalue())
119
120
if __name__ == '__main__':
  # Run this module's tests when it is executed directly as a script.
  unittest.main(module='__main__')
123