• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Generate ast module from specification
2
3This script generates the ast module from a simple specification,
4which makes it easy to accommodate changes in the grammar.  This
5approach would be quite reasonable if the grammar changed often.
6Instead, it is rather complex to generate the appropriate code.  And
7the Node interface has changed more often than the grammar.
8"""
9
10import fileinput
11import re
12import sys
13from StringIO import StringIO
14
# Name of the node specification file that main() hands to parse_spec().
SPEC = "ast.txt"
# Separator used when joining argument expressions in generated code.
COMMA = ", "
17
def load_boilerplate(file):
    """Return the (prologue, epilogue) boilerplate text stored in *file*.

    The prologue is the text between the prologue marker and the epilogue
    marker; the epilogue is everything after the epilogue marker.  Both are
    returned with surrounding whitespace stripped.

    Raises ValueError if either marker is missing from the file.
    """
    # The markers are deliberately written as two adjacent string literals.
    # Presumably this keeps find() from matching this function's own source
    # when the script itself is used as the boilerplate file -- keep the
    # split spelling, and never write a marker out in full in this file.
    pro_marker = '### ''PROLOGUE'
    epi_marker = '### ''EPILOGUE'

    # The with-block guarantees the file is closed even if read() fails.
    with open(file) as f:
        buf = f.read()

    i = buf.find(pro_marker)
    j = buf.find(epi_marker)
    if i == -1 or j == -1:
        # Without this check the -1 from find() would be used as a slice
        # index and silently produce garbage boilerplate.
        raise ValueError("boilerplate markers %r and %r not both found in %r"
                         % (pro_marker, epi_marker, file))

    pro = buf[i + len(pro_marker):j].strip()
    epi = buf[j + len(epi_marker):].strip()
    return pro, epi
27
def strip_default(arg):
    """Return the bare argument name from an 'arg = default' string.

    A string without '=' is returned unchanged (it is not stripped);
    otherwise the text before the first '=' is returned with surrounding
    whitespace removed.
    """
    name, equals, _default = arg.partition('=')
    if not equals:
        return arg
    return name.strip()
35
# Argument property codes assigned by NodeInfo.get_argprops().  The numeric
# order matters: NodeInfo tracks the largest value seen ("hardest_arg") and
# compares it against P_NESTED to pick a code-generation strategy.
P_NODE = 1    # no marker: returned by both getChildren and getChildNodes
P_OTHER = 2   # '*' marker: returned by getChildren, left out of getChildNodes
P_NESTED = 3  # '!' marker: value is flattened before being returned
P_NONE = 4    # '&' marker: added to getChildNodes only when it is not None
40
class NodeInfo:
    """Describe one AST node class and generate its Python source.

    Attributes set by the constructor:
      name        -- class name of the node
      args        -- argument spec text (property markers removed)
      argnames    -- argument names without defaults or property markers
      argprops    -- dict mapping each argument name to a P_* constant
      nargs       -- len(argnames)
      hardest_arg -- largest P_* value found among the arguments
      init        -- extra source lines copied into the generated __init__
    """
    def __init__(self, name, args):
        self.name = name
        self.args = args.strip()
        # Order matters: get_argprops() rewrites self.argnames in place.
        self.argnames = self.get_argnames()
        self.argprops = self.get_argprops()
        self.nargs = len(self.argnames)
        self.init = []

    def get_argnames(self):
        """Return the list of argument names found in self.args.

        If the spec contains a '(', only the text between the first '('
        and the last ')' is considered.  Defaults are dropped; property
        markers are still attached at this point.  Fragments that are
        empty or whitespace-only (for example after a trailing comma) are
        skipped so that they never become an empty argument name.
        """
        if '(' in self.args:
            i = self.args.find('(')
            j = self.args.rfind(')')
            args = self.args[i+1:j]
        else:
            args = self.args
        # Strip first, then filter: filtering the raw fragment would let a
        # whitespace-only fragment through as the name ''.
        fragments = [arg.strip() for arg in args.split(',')]
        return [strip_default(arg) for arg in fragments if arg]

    def get_argprops(self):
        """Return a dict mapping each argument name to its P_* property.

        A trailing '*', '!' or '&' on a name selects P_OTHER, P_NESTED or
        P_NONE respectively; names without a marker are P_NODE.  Also sets
        self.hardest_arg to the largest property seen.

        XXX This method modifies the argnames in place (markers are cut
        off), and removes the marker characters from self.args as well.
        """
        markers = {'*': P_OTHER, '!': P_NESTED, '&': P_NONE}
        d = {}
        hardest_arg = P_NODE
        for i, arg in enumerate(self.argnames):
            # arg[-1:] is '' for an empty name, which simply maps to P_NODE.
            prop = markers.get(arg[-1:], P_NODE)
            if prop != P_NODE:
                arg = self.argnames[i] = arg[:-1]
            d[arg] = prop
            hardest_arg = max(hardest_arg, prop)
        self.hardest_arg = hardest_arg

        if hardest_arg > P_NODE:
            # At least one marker was present; drop them from the spec text
            # that is later pasted into the generated __init__ signature.
            self.args = self.args.replace('*', '')
            self.args = self.args.replace('!', '')
            self.args = self.args.replace('&', '')

        return d

    def gen_source(self):
        """Return the complete source text of the generated node class."""
        buf = StringIO()
        print >> buf, "class %s(Node):" % self.name
        self._gen_init(buf)
        print >> buf
        self._gen_getChildren(buf)
        print >> buf
        self._gen_getChildNodes(buf)
        print >> buf
        self._gen_repr(buf)
        return buf.getvalue()

    def _gen_init(self, buf):
        """Write the generated __init__ method to buf."""
        # Bound unconditionally so the attribute loop below never depends on
        # the signature branch having run.
        argtuple = '(' in self.args
        if self.args:
            if argtuple:
                # Parenthesised spec: the names are concatenated into one
                # parameter name and the attributes are read from it by
                # index below.
                args = ''.join(self.argnames)
            else:
                args = self.args
            print >> buf, "    def __init__(self, %s, lineno=None):" % args
        else:
            print >> buf, "    def __init__(self, lineno=None):"
        if self.argnames:
            if argtuple:
                for idx, name in enumerate(self.argnames):
                    print >> buf, "        self.%s = %s[%s]" % (name, args, idx)
            else:
                for name in self.argnames:
                    print >> buf, "        self.%s = %s" % (name, name)
        print >> buf, "        self.lineno = lineno"
        # Copy the lines in self.init, indented four spaces.  The rstrip()
        # business is to get rid of the four spaces if line happens to be
        # empty, so that reindent.py is happy with the output.
        for line in self.init:
            print >> buf, ("    " + line).rstrip()

    def _gen_getChildren(self, buf):
        """Write the generated getChildren method to buf.

        Without nested arguments the attributes are returned as a plain
        tuple; otherwise nested arguments are passed through flatten().
        """
        print >> buf, "    def getChildren(self):"
        if not self.argnames:
            print >> buf, "        return ()"
        elif self.hardest_arg < P_NESTED:
            clist = COMMA.join(["self.%s" % c for c in self.argnames])
            if self.nargs == 1:
                # Trailing comma makes the single value a 1-tuple.
                print >> buf, "        return %s," % clist
            else:
                print >> buf, "        return %s" % clist
        elif self.nargs == 1:
            print >> buf, "        return tuple(flatten(self.%s))" % self.argnames[0]
        else:
            print >> buf, "        children = []"
            template = "        children.%s(%sself.%s%s)"
            for name in self.argnames:
                if self.argprops[name] == P_NESTED:
                    print >> buf, template % ("extend", "flatten(",
                                              name, ")")
                else:
                    print >> buf, template % ("append", "", name, "")
            print >> buf, "        return tuple(children)"

    def _gen_getChildNodes(self, buf):
        """Write the generated getChildNodes method to buf.

        Only P_NODE arguments are returned directly.  When nested or
        optional arguments exist, a list is built instead: P_NONE values
        are added only if not None, P_NESTED values go through
        flatten_nodes(), and P_OTHER values are left out.
        """
        print >> buf, "    def getChildNodes(self):"
        if not self.argnames:
            print >> buf, "        return ()"
        elif self.hardest_arg < P_NESTED:
            clist = ["self.%s" % c
                     for c in self.argnames
                     if self.argprops[c] == P_NODE]
            if not clist:
                print >> buf, "        return ()"
            elif len(clist) == 1:
                print >> buf, "        return %s," % clist[0]
            else:
                print >> buf, "        return %s" % COMMA.join(clist)
        else:
            print >> buf, "        nodelist = []"
            template = "        nodelist.%s(%sself.%s%s)"
            for name in self.argnames:
                prop = self.argprops[name]
                if prop == P_NONE:
                    tmp = ("        if self.%s is not None:\n"
                           "            nodelist.append(self.%s)")
                    print >> buf, tmp % (name, name)
                elif prop == P_NESTED:
                    print >> buf, template % ("extend", "flatten_nodes(",
                                              name, ")")
                elif prop == P_NODE:
                    print >> buf, template % ("append", "", name, "")
            print >> buf, "        return tuple(nodelist)"

    def _gen_repr(self, buf):
        """Write the generated __repr__ method to buf."""
        print >> buf, "    def __repr__(self):"
        if self.argnames:
            fmt = COMMA.join(["%s"] * self.nargs)
            if '(' in self.args:
                fmt = '(%s)' % fmt
            vals = ["repr(self.%s)" % name for name in self.argnames]
            vals = COMMA.join(vals)
            if self.nargs == 1:
                # Keep the right-hand side of the generated % a tuple.
                vals = vals + ","
            print >> buf, '        return "%s(%s)" %% (%s)' % \
                  (self.name, fmt, vals)
        else:
            print >> buf, '        return "%s()"' % self.name
197
# Matches an "init(<name>):" header line in the spec; group 1 is the text
# between the parentheses.  Written as a raw string so that "\(" reaches the
# regex engine as-is instead of being an invalid string escape.
rx_init = re.compile(r'init\((.*)\):')
199
def parse_spec(file):
    """Parse the node specification *file* into a list of NodeInfo objects.

    Lines whose stripped text starts with '#' are ignored.  A line of the
    form 'Name: args' defines a node; lines that do not split into exactly
    two ':'-separated parts are skipped.  A line matching rx_init
    ('init(Name):') selects an already defined node, and every following
    non-comment line is appended to that node's init list.  Note that once
    an init section has started, no further node definitions are read.

    Returns the NodeInfo objects sorted by name.  Raises KeyError if an
    init header names a node that has not been defined.
    """
    classes = {}
    cur = None
    lines = fileinput.input(file)
    try:
        for line in lines:
            if line.strip().startswith('#'):
                continue
            mo = rx_init.search(line)
            if mo is not None:
                # some extra code for a Node's __init__ method
                cur = classes[mo.group(1)]
                continue
            if cur is not None:
                # some code for the __init__ method
                cur.init.append(line)
                continue
            # a normal entry
            try:
                name, args = line.split(':')
            except ValueError:
                continue
            classes[name] = NodeInfo(name, args)
    finally:
        # fileinput keeps module-global state; release it even when parsing
        # fails so that a later fileinput.input() call is not rejected.
        fileinput.close()
    return sorted(classes.values(), key=lambda n: n.name)
224
def main():
    """Write the generated ast module to standard output.

    The boilerplate is loaded from the file named by the last command-line
    argument; the node classes come from the SPEC file.  Output order is:
    prologue, a blank line, one generated class per node, epilogue.
    """
    emit = sys.stdout.write
    prologue, epilogue = load_boilerplate(sys.argv[-1])
    emit(prologue + "\n")
    emit("\n")
    for info in parse_spec(SPEC):
        emit(info.gen_source() + "\n")
    emit(epilogue + "\n")
233
if __name__ == "__main__":
    main()
    # Exit explicitly: everything below this point in the file is template
    # text for the generated module, not code meant to run as part of this
    # script.
    sys.exit(0)
237
238### PROLOGUE
239"""Python abstract syntax node definitions
240
241This file is automatically generated by Tools/compiler/astgen.py
242"""
243from consts import CO_VARARGS, CO_VARKEYWORDS
244
def flatten(seq):
    """Return a flat list of the elements of seq.

    Elements whose type is exactly tuple or list are expanded recursively;
    everything else (including subclasses of tuple/list) is kept as is.
    """
    flat = []
    for item in seq:
        kind = type(item)
        if kind is list or kind is tuple:
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat
255
def flatten_nodes(seq):
    """Return the Node instances found in the flattened form of seq."""
    found = []
    for item in flatten(seq):
        if isinstance(item, Node):
            found.append(item)
    return found
258
# Registry keyed by lower-cased node name.  Expression's class body stores
# its class name here as a string; the loop at the very end of the module
# stores class objects.
nodes = {}
260
class Node:
    """Abstract base class for ast nodes."""

    def getChildren(self):
        """Placeholder returning None; implemented by subclasses."""

    def __iter__(self):
        """Yield each item returned by getChildren()."""
        for child in self.getChildren():
            yield child

    def asList(self):
        """Return getChildren(); kept for backwards compatibility."""
        return self.getChildren()

    def getChildNodes(self):
        """Placeholder returning None; implemented by subclasses."""
272
class EmptyNode(Node):
    """Node subclass that adds nothing to the Node base class."""
    pass
275
class Expression(Node):
    """Artificial node class to support "eval"; wraps a single node."""
    # Runs once, while the class body executes: records the class name
    # (as a string) in the module-level registry.
    nodes["expression"] = "Expression"

    def __init__(self, node):
        self.node = node

    def getChildren(self):
        """Return a 1-tuple holding the wrapped node."""
        return (self.node,)

    def getChildNodes(self):
        """Return a 1-tuple holding the wrapped node."""
        return (self.node,)

    def __repr__(self):
        return "Expression(%s)" % (repr(self.node),)
290
291### EPILOGUE
# Register every Node subclass defined in this module under its lower-cased
# name.  Node is declared without a base class, so under Python 2 it and its
# subclasses are old-style classes whose type is not `type`; accept
# type(Node) as well so they are not all skipped.  Iterate over a snapshot
# of globals() because the loop variables are themselves module globals.
for name, obj in list(globals().items()):
    if isinstance(obj, (type, type(Node))) and issubclass(obj, Node):
        nodes[name.lower()] = obj
295