• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Print an AST tree in a form more readable than ast.dump."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22import gast
23import six
24import termcolor
25
26
class PrettyPrinter(gast.NodeVisitor):
  """Print AST nodes.

  Walks an AST and accumulates an indented, one-field-per-line rendering of
  it in `self.result`.

  Attributes:
    indent_lvl: int, current nesting depth; each level renders as '| '.
    result: str, the text printed so far; every printed line ends in '\\n'.
    color: bool, whether output is decorated with termcolor escape codes.
    noanno: bool, whether to omit fields whose names start with '__'.
  """

  def __init__(self, color, noanno):
    self.indent_lvl = 0
    self.result = ''
    self.color = color
    self.noanno = noanno

  def _color(self, string, color, attrs=None):
    """Wraps `string` in terminal color codes unless coloring is disabled."""
    if self.color:
      return termcolor.colored(string, color, attrs=attrs)
    return string

  def _type(self, node):
    """Returns the node's class name, rendered bold."""
    return self._color(node.__class__.__name__, None, ['bold'])

  def _field(self, name):
    """Renders a field name (blue)."""
    return self._color(name, 'blue')

  def _value(self, name):
    """Renders a leaf value (magenta)."""
    return self._color(name, 'magenta')

  def _warning(self, name):
    """Renders a warning message (red)."""
    return self._color(name, 'red')

  def _indent(self):
    """Returns the '| ' prefix for the current nesting depth."""
    return self._color('| ' * self.indent_lvl, None, ['dark'])

  def _print(self, s):
    """Appends `s` plus a newline to the accumulated result."""
    self.result += s
    self.result += '\n'

  def _print_sequence(self, field, seq, opening, closing):
    """Prints a list or tuple field, one element per indented line.

    Args:
      field: str, name of the field holding `seq`.
      seq: list or tuple, the field's value.
      opening: str, the opening bracket to print, e.g. '['.
      closing: str, the matching closing bracket, e.g. ']'.
    """
    if not seq:
      self._print('%s%s=%s%s' % (self._indent(), self._field(field), opening,
                                 closing))
      return
    self._print('%s%s=%s' % (self._indent(), self._field(field), opening))
    self.indent_lvl += 1
    for item in seq:
      self.generic_visit(item)
    self.indent_lvl -= 1
    self._print('%s%s' % (self._indent(), closing))

  def generic_visit(self, node, name=None):
    """Prints `node` and, recursively, all of its fields.

    Args:
      node: the AST node to print. Strings and other values without
        `_fields` (which can appear as elements of list fields) are printed
        as leaves.
      name: optional str, the name of the parent field that holds `node`.
    """
    # In very rare instances, a list can contain something other than a Node.
    # e.g. Global contains a list of strings.
    if isinstance(node, six.string_types):
      if name:
        self._print('%s%s="%s"' % (self._indent(), name, node))
      else:
        self._print('%s"%s"' % (self._indent(), node))
      return

    # Lists may also hold non-node scalars such as None (e.g. Dict.keys for a
    # `**` entry, or arguments.kw_defaults for a keyword-only argument without
    # a default). They have no `_fields`, so print them as plain values
    # instead of crashing on the attribute lookup below.
    if not hasattr(node, '_fields'):
      if name:
        self._print('%s%s=%s' % (self._indent(), self._field(name),
                                 self._value(node)))
      else:
        self._print('%s%s' % (self._indent(), self._value(node)))
      return

    cont = ':' if node._fields else '()'

    if name:
      self._print('%s%s=%s%s' % (self._indent(), self._field(name),
                                 self._type(node), cont))
    else:
      self._print('%s%s%s' % (self._indent(), self._type(node), cont))

    self.indent_lvl += 1
    for f in node._fields:
      if self.noanno and f.startswith('__'):
        continue
      if not hasattr(node, f):
        self._print('%s%s' % (self._indent(), self._warning('%s=<unset>' % f)))
        continue
      v = getattr(node, f)
      if isinstance(v, list):
        self._print_sequence(f, v, '[', ']')
      elif isinstance(v, tuple):
        self._print_sequence(f, v, '(', ')')
      elif isinstance(v, gast.AST):
        self.generic_visit(v, f)
      elif isinstance(v, six.binary_type):
        # On Python 3, '%s' % b'x' yields "b'x'"; keep only the escaped body
        # of the repr so the output reads b"x" rather than b"b'x'".
        body = repr(v)[2:-1] if six.PY3 else v
        self._print('%s%s=%s' % (self._indent(), self._field(f),
                                 self._value('b"%s"' % body)))
      elif isinstance(v, six.text_type):
        self._print('%s%s=%s' % (self._indent(), self._field(f),
                                 self._value('u"%s"' % v)))
      else:
        self._print('%s%s=%s' % (self._indent(), self._field(f),
                                 self._value(v)))
    self.indent_lvl -= 1
121
122
def fmt(node, color=True, noanno=False):
  """Formats an AST node, or a list/tuple of nodes, as an indented tree.

  Args:
    node: an AST node, or a list or tuple of nodes printed one after another.
    color: bool, whether to decorate the output with terminal color codes.
    noanno: bool, whether to omit fields whose names start with '__'.

  Returns:
    str, the rendered tree; each line is terminated by a newline.
  """
  printer = PrettyPrinter(color, noanno)
  roots = node if isinstance(node, (list, tuple)) else (node,)
  for root in roots:
    printer.visit(root)
  return printer.result
131