• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mako/_ast_util.py
2# Copyright 2006-2023 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7"""
8    ast
9    ~~~
10
11    This is a stripped down version of Armin Ronacher's ast module.
12
13    :copyright: Copyright 2008 by Armin Ronacher.
14    :license: Python License.
15"""
16
17
from _ast import Add
from _ast import And
from _ast import AST
from _ast import BitAnd
from _ast import BitOr
from _ast import BitXor
from _ast import Div
from _ast import Eq
from _ast import FloorDiv
from _ast import Gt
from _ast import GtE
from _ast import If
from _ast import In
from _ast import Invert
from _ast import Is
from _ast import IsNot
from _ast import LShift
from _ast import Lt
from _ast import LtE
from _ast import MatMult
from _ast import Mod
from _ast import Mult
from _ast import Name
from _ast import Not
from _ast import NotEq
from _ast import NotIn
from _ast import Or
from _ast import Pow
from _ast import PyCF_ONLY_AST
from _ast import RShift
from _ast import Sub
from _ast import UAdd
from _ast import USub
49
50
# Python spelling of each operator node class.  The generator below looks
# operators up here when turning an AST back into source text, so every
# operator class the parser can produce must have an entry.

# Boolean operators (``BoolOp.op``).
BOOLOP_SYMBOLS = {And: "and", Or: "or"}

# Binary operators (``BinOp.op`` / ``AugAssign.op``).
BINOP_SYMBOLS = {
    Add: "+",
    Sub: "-",
    Mult: "*",
    MatMult: "@",
    Div: "/",
    FloorDiv: "//",
    Mod: "%",
    Pow: "**",
    LShift: "<<",
    RShift: ">>",
    BitOr: "|",
    BitAnd: "&",
    BitXor: "^",
}

# Comparison operators (``Compare.ops``).
CMPOP_SYMBOLS = {
    Eq: "==",
    Gt: ">",
    GtE: ">=",
    In: "in",
    Is: "is",
    IsNot: "is not",
    Lt: "<",
    LtE: "<=",
    NotEq: "!=",
    NotIn: "not in",
}

# Unary operators (``UnaryOp.op``).
UNARYOP_SYMBOLS = {Invert: "~", Not: "not", UAdd: "+", USub: "-"}

# Combined lookup across all four tables; the key classes are disjoint, so
# merge order does not matter.
ALL_SYMBOLS = {
    **BOOLOP_SYMBOLS,
    **BINOP_SYMBOLS,
    **CMPOP_SYMBOLS,
    **UNARYOP_SYMBOLS,
}
87
88
def parse(expr, filename="<unknown>", mode="exec"):
    """
    Compile ``expr`` into an AST instead of a code object.

    ``filename`` is only used in error messages; ``mode`` is passed straight
    to :func:`compile` (``"exec"``, ``"eval"`` or ``"single"``).
    """
    ast_only = PyCF_ONLY_AST
    tree = compile(expr, filename, mode, ast_only)
    return tree
92
93
def iter_fields(node):
    """
    Yield ``(name, value)`` pairs for the fields listed in ``node._fields``.

    Fields that are declared in ``_fields`` but not set on this particular
    instance are skipped instead of raising ``AttributeError``.
    """
    for field in node._fields:
        # Guard only the attribute lookup.  Keeping ``yield`` outside the
        # ``try`` ensures an AttributeError thrown into the generator is not
        # mistaken for a missing field and silently swallowed.
        try:
            value = getattr(node, field)
        except AttributeError:
            continue
        yield field, value
102
103
class NodeVisitor:

    """
    Walk an abstract syntax tree, dispatching every node to a visitor method.

    ``visit`` looks up a method named ``'visit_'`` followed by the node's
    class name (a ``TryFinally`` node goes to ``visit_TryFinally``) and
    returns whatever that method returns.  Override ``get_visitor`` to change
    the lookup.  When no matching method exists, ``generic_visit`` is used
    instead; it simply recurses into the node's children.

    This visitor never writes anything back into the tree.  Use
    `NodeTransformer` when nodes should be replaced or removed while
    traversing.
    """

    def get_visitor(self, node):
        """
        Return the bound ``visit_<ClassName>`` method for ``node``, or
        `None` when this visitor defines no such method (the caller then
        falls back to ``generic_visit``).
        """
        name = "visit_%s" % node.__class__.__name__
        return getattr(self, name, None)

    def visit(self, node):
        """Dispatch ``node`` and return the chosen visitor's result."""
        handler = self.get_visitor(node)
        if handler is None:
            return self.generic_visit(node)
        return handler(node)

    def generic_visit(self, node):
        """Visit every AST child of ``node``; used when no visitor exists."""
        for _, child in iter_fields(node):
            # A field holds either a single value or a list of values; treat
            # both uniformly and skip anything that is not an AST node.
            candidates = child if isinstance(child, list) else [child]
            for candidate in candidates:
                if isinstance(candidate, AST):
                    self.visit(candidate)
147
148
class NodeTransformer(NodeVisitor):

    """
    Walk the abstract syntax tree and rewrite it in place.

    The value returned by each visitor method replaces the node that was
    visited:

    * returning the node itself leaves it where it is;
    * returning a different node substitutes it;
    * returning `None` removes it -- it is dropped from a list field, or the
      attribute is deleted for a single-node field;
    * for nodes stored in a list field (statement bodies, for instance),
      returning a non-node iterable splices all of its items in.

    Example that rewrites every name ``foo`` into ``data['foo']``::

        class RewriteName(NodeTransformer):

            def visit_Name(self, node):
                return ast.copy_location(ast.Subscript(
                    value=ast.Name(id='data', ctx=ast.Load()),
                    slice=ast.Constant(value=node.id),
                    ctx=node.ctx
                ), node)

    A visitor method that handles a node with children must either transform
    those children itself or call ``generic_visit`` on the node first.

    Typical use::

        node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """Replace each AST child of ``node`` with its visited result."""
        for name, current in iter_fields(node):
            if isinstance(current, AST):
                replacement = self.visit(current)
                if replacement is None:
                    delattr(node, name)
                else:
                    setattr(node, name, replacement)
            elif isinstance(current, list):
                rebuilt = []
                for entry in current:
                    if not isinstance(entry, AST):
                        # Non-node list entries are carried over unchanged.
                        rebuilt.append(entry)
                        continue
                    outcome = self.visit(entry)
                    if outcome is None:
                        continue
                    if isinstance(outcome, AST):
                        rebuilt.append(outcome)
                    else:
                        rebuilt.extend(outcome)
                # Mutate the existing list so other references to it see
                # the new contents.
                current[:] = rebuilt
        return node
207
208
class SourceGenerator(NodeVisitor):

    """
    Turn a well-formed syntax tree back into Python source code.

    Create an instance with the string used for one level of indentation,
    call ``visit`` on the root node, then join the fragments collected in
    ``result`` (``"".join(generator.result)``) to obtain the source text.

    Compound expressions (binary, boolean, comparison, unary and conditional
    expressions) are wrapped in parentheses so the output never depends on
    operator precedence.
    """

    def __init__(self, indent_with):
        # Generated source fragments, in output order.
        self.result = []
        # String emitted once per indentation level.
        self.indent_with = indent_with
        # Current nesting depth.
        self.indentation = 0
        # Number of line breaks to emit before the next fragment.
        self.new_lines = 0

    def write(self, x):
        """Append ``x``, first flushing any pending newlines and indentation."""
        if self.new_lines:
            # No leading blank lines at the very start of the output.
            if self.result:
                self.result.append("\n" * self.new_lines)
            self.result.append(self.indent_with * self.indentation)
            self.new_lines = 0
        self.result.append(x)

    def newline(self, n=1):
        """Request at least ``n`` line breaks before the next fragment."""
        self.new_lines = max(self.new_lines, n)

    def body(self, statements):
        """Emit ``statements`` one indentation level deeper."""
        # The indented block always starts on its own line.
        self.newline()
        self.indentation += 1
        for stmt in statements:
            self.visit(stmt)
        self.indentation -= 1

    def body_or_else(self, node):
        """Emit ``node.body`` and, when present, an ``else:`` clause."""
        self.body(node.body)
        if node.orelse:
            self.newline()
            self.write("else:")
            self.body(node.orelse)

    def signature(self, node):
        """
        Emit the parameter list described by an ``arguments`` node.

        Handles positional-only parameters (followed by ``/``), regular
        parameters with defaults, ``*args`` or a bare ``*`` separator,
        keyword-only parameters and ``**kwargs``.  Annotations are not
        emitted.
        """
        want_comma = []

        def write_comma():
            if want_comma:
                self.write(", ")
            else:
                want_comma.append(True)

        # ``defaults`` applies to the trailing entries of posonlyargs + args.
        posonly = list(getattr(node, "posonlyargs", ()))
        positional = posonly + list(node.args)
        defaults = list(node.defaults)
        padding = [None] * (len(positional) - len(defaults))
        for idx, (arg, default) in enumerate(zip(positional, padding + defaults)):
            write_comma()
            self.visit(arg)
            if default is not None:
                self.write("=")
                self.visit(default)
            if idx == len(posonly) - 1:
                write_comma()
                self.write("/")

        kwonly = list(getattr(node, "kwonlyargs", ()))
        if node.vararg is not None:
            write_comma()
            self.write("*" + node.vararg.arg)
        elif kwonly:
            # Keyword-only parameters without *args need a bare star.
            write_comma()
            self.write("*")
        # ``kw_defaults`` is aligned with ``kwonlyargs``; None = no default.
        for arg, default in zip(kwonly, getattr(node, "kw_defaults", ())):
            write_comma()
            self.visit(arg)
            if default is not None:
                self.write("=")
                self.visit(default)

        if node.kwarg is not None:
            write_comma()
            self.write("**" + node.kwarg.arg)

    def decorators(self, node):
        """Emit one ``@decorator`` line per entry in ``node.decorator_list``."""
        for decorator in node.decorator_list:
            self.newline()
            self.write("@")
            self.visit(decorator)

    # Statements

    def visit_Assign(self, node):
        self.newline()
        # ``a = b = value`` has one target per ``=``.
        for target in node.targets:
            self.visit(target)
            self.write(" = ")
        self.visit(node.value)

    def visit_AugAssign(self, node):
        self.newline()
        self.visit(node.target)
        self.write(BINOP_SYMBOLS[type(node.op)] + "=")
        self.visit(node.value)

    def visit_ImportFrom(self, node):
        self.newline()
        # ``module`` is None for ``from . import x``.
        self.write(
            "from %s%s import " % ("." * (node.level or 0), node.module or "")
        )
        for idx, item in enumerate(node.names):
            if idx:
                self.write(", ")
            self.visit(item)

    def visit_Import(self, node):
        self.newline()
        self.write("import ")
        for idx, item in enumerate(node.names):
            if idx:
                self.write(", ")
            self.visit(item)

    def visit_Expr(self, node):
        self.newline()
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.newline(n=2)
        self.decorators(node)
        self.newline()
        self.write("def %s(" % node.name)
        self.signature(node.args)
        self.write("):")
        self.body(node.body)

    def visit_ClassDef(self, node):
        have_args = []

        def paren_or_comma():
            if have_args:
                self.write(", ")
            else:
                have_args.append(True)
                self.write("(")

        self.newline(n=3)
        self.decorators(node)
        self.newline()
        self.write("class %s" % node.name)
        for base in node.bases:
            paren_or_comma()
            self.visit(base)
        # XXX: the if here is used to keep this module compatible
        #      with python 2.6.
        if hasattr(node, "keywords"):
            for keyword in node.keywords:
                paren_or_comma()
                self.visit(keyword)
            if getattr(node, "starargs", None):
                paren_or_comma()
                self.write("*")
                self.visit(node.starargs)
            if getattr(node, "kwargs", None):
                paren_or_comma()
                self.write("**")
                self.visit(node.kwargs)
        self.write("):" if have_args else ":")
        self.body(node.body)

    def visit_If(self, node):
        self.newline()
        self.write("if ")
        self.visit(node.test)
        self.write(":")
        self.body(node.body)
        # Fold a lone nested ``If`` in ``orelse`` into an ``elif`` chain.
        while True:
            else_ = node.orelse
            if len(else_) == 1 and isinstance(else_[0], If):
                node = else_[0]
                self.newline()
                self.write("elif ")
                self.visit(node.test)
                self.write(":")
                self.body(node.body)
            else:
                # Only emit ``else:`` when there is something in it; an
                # empty ``else:`` block is a syntax error.
                if else_:
                    self.newline()
                    self.write("else:")
                    self.body(else_)
                break

    def visit_For(self, node):
        self.newline()
        self.write("for ")
        self.visit(node.target)
        self.write(" in ")
        self.visit(node.iter)
        self.write(":")
        self.body_or_else(node)

    def visit_While(self, node):
        self.newline()
        self.write("while ")
        self.visit(node.test)
        self.write(":")
        self.body_or_else(node)

    def visit_With(self, node):
        self.newline()
        self.write("with ")
        items = getattr(node, "items", None)
        if items is None:
            # Old-style node carrying a single context manager directly.
            self.visit_withitem(node)
        else:
            for idx, item in enumerate(items):
                if idx:
                    self.write(", ")
                self.visit(item)
        self.write(":")
        self.body(node.body)

    def visit_Pass(self, node):
        self.newline()
        self.write("pass")

    def visit_Print(self, node):
        # XXX: python 2.6 only
        self.newline()
        self.write("print ")
        want_comma = False
        if node.dest is not None:
            self.write(" >> ")
            self.visit(node.dest)
            want_comma = True
        for value in node.values:
            if want_comma:
                self.write(", ")
            self.visit(value)
            want_comma = True
        if not node.nl:
            self.write(",")

    def visit_Delete(self, node):
        self.newline()
        self.write("del ")
        for idx, target in enumerate(node.targets):
            if idx:
                self.write(", ")
            self.visit(target)

    def visit_Try(self, node):
        """
        Emit a ``try`` statement with its handlers and optional ``else`` and
        ``finally`` clauses.

        Also serves the older ``TryExcept``/``TryFinally`` node classes, which
        only carry a subset of these fields.
        """
        self.newline()
        self.write("try:")
        self.body(node.body)
        for handler in getattr(node, "handlers", ()):
            self.visit(handler)
        orelse = getattr(node, "orelse", None)
        if orelse:
            self.newline()
            self.write("else:")
            self.body(orelse)
        finalbody = getattr(node, "finalbody", None)
        if finalbody:
            self.newline()
            self.write("finally:")
            self.body(finalbody)

    visit_TryExcept = visit_Try
    visit_TryFinally = visit_Try

    def visit_Global(self, node):
        self.newline()
        self.write("global " + ", ".join(node.names))

    def visit_Nonlocal(self, node):
        self.newline()
        self.write("nonlocal " + ", ".join(node.names))

    def visit_Return(self, node):
        self.newline()
        self.write("return")
        # A bare ``return`` has no value node.
        if node.value is not None:
            self.write(" ")
            self.visit(node.value)

    def visit_Break(self, node):
        self.newline()
        self.write("break")

    def visit_Continue(self, node):
        self.newline()
        self.write("continue")

    def visit_Raise(self, node):
        # XXX: Python 2.6 / 3.0 compatibility
        self.newline()
        self.write("raise")
        if hasattr(node, "exc") and node.exc is not None:
            self.write(" ")
            self.visit(node.exc)
            if node.cause is not None:
                self.write(" from ")
                self.visit(node.cause)
        elif hasattr(node, "type") and node.type is not None:
            self.write(" ")
            self.visit(node.type)
            if node.inst is not None:
                self.write(", ")
                self.visit(node.inst)
            if node.tback is not None:
                self.write(", ")
                self.visit(node.tback)

    # Expressions

    def visit_Attribute(self, node):
        self.visit(node.value)
        self.write("." + node.attr)

    def visit_Call(self, node):
        want_comma = []

        def write_comma():
            if want_comma:
                self.write(", ")
            else:
                want_comma.append(True)

        self.visit(node.func)
        self.write("(")
        for arg in node.args:
            write_comma()
            self.visit(arg)
        for keyword in node.keywords:
            write_comma()
            self.visit(keyword)
        if getattr(node, "starargs", None):
            write_comma()
            self.write("*")
            self.visit(node.starargs)
        if getattr(node, "kwargs", None):
            write_comma()
            self.write("**")
            self.visit(node.kwargs)
        self.write(")")

    def visit_Name(self, node):
        self.write(node.id)

    def visit_NameConstant(self, node):
        self.write(str(node.value))

    def visit_arg(self, node):
        self.write(node.arg)

    def visit_Str(self, node):
        self.write(repr(node.s))

    def visit_Bytes(self, node):
        self.write(repr(node.s))

    def visit_Num(self, node):
        self.write(repr(node.n))

    # newly needed in Python 3.8
    def visit_Constant(self, node):
        self.write(repr(node.value))

    def visit_Tuple(self, node):
        self.write("(")
        for idx, item in enumerate(node.elts):
            if idx:
                self.write(", ")
            self.visit(item)
        # A one-element tuple needs its trailing comma.
        self.write(",)" if len(node.elts) == 1 else ")")

    def sequence_visit(left, right):
        # Build a visitor that writes ``elts`` comma-separated between the
        # given delimiters.
        def visit(self, node):
            self.write(left)
            for idx, item in enumerate(node.elts):
                if idx:
                    self.write(", ")
                self.visit(item)
            self.write(right)

        return visit

    visit_List = sequence_visit("[", "]")
    visit_Set = sequence_visit("{", "}")
    del sequence_visit

    def visit_Dict(self, node):
        self.write("{")
        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
            if idx:
                self.write(", ")
            if key is None:
                # ``{**mapping}`` unpacking has no key.
                self.write("**")
                self.visit(value)
            else:
                self.visit(key)
                self.write(": ")
                self.visit(value)
        self.write("}")

    def visit_BinOp(self, node):
        self.write("(")
        self.visit(node.left)
        self.write(" %s " % BINOP_SYMBOLS[type(node.op)])
        self.visit(node.right)
        self.write(")")

    def visit_BoolOp(self, node):
        self.write("(")
        for idx, value in enumerate(node.values):
            if idx:
                self.write(" %s " % BOOLOP_SYMBOLS[type(node.op)])
            self.visit(value)
        self.write(")")

    def visit_Compare(self, node):
        self.write("(")
        self.visit(node.left)
        for op, right in zip(node.ops, node.comparators):
            self.write(" %s " % CMPOP_SYMBOLS[type(op)])
            self.visit(right)
        self.write(")")

    def visit_UnaryOp(self, node):
        self.write("(")
        op = UNARYOP_SYMBOLS[type(node.op)]
        self.write(op)
        if op == "not":
            self.write(" ")
        self.visit(node.operand)
        self.write(")")

    def visit_Subscript(self, node):
        # NOTE(review): when the slice is a Tuple (multi-dimensional
        # subscripts on newer Pythons), visit_Tuple wraps it in parentheses,
        # which is invalid if an element is a Slice such as ``a[1:2, 3]``.
        self.visit(node.value)
        self.write("[")
        self.visit(node.slice)
        self.write("]")

    def visit_Slice(self, node):
        if node.lower is not None:
            self.visit(node.lower)
        self.write(":")
        if node.upper is not None:
            self.visit(node.upper)
        if node.step is not None:
            self.write(":")
            if not (isinstance(node.step, Name) and node.step.id == "None"):
                self.visit(node.step)

    def visit_ExtSlice(self, node):
        for idx, item in enumerate(node.dims):
            if idx:
                self.write(", ")
            self.visit(item)

    def visit_Yield(self, node):
        self.write("yield")
        # A bare ``yield`` has no value node.
        if node.value is not None:
            self.write(" ")
            self.visit(node.value)

    def visit_Lambda(self, node):
        self.write("lambda ")
        self.signature(node.args)
        self.write(": ")
        self.visit(node.body)

    def visit_Ellipsis(self, node):
        self.write("Ellipsis")

    def generator_visit(left, right):
        # Build a visitor for comprehensions: element, then each
        # ``for ... in ... if ...`` clause, between the given delimiters.
        def visit(self, node):
            self.write(left)
            self.visit(node.elt)
            for comprehension in node.generators:
                self.visit(comprehension)
            self.write(right)

        return visit

    visit_ListComp = generator_visit("[", "]")
    visit_GeneratorExp = generator_visit("(", ")")
    visit_SetComp = generator_visit("{", "}")
    del generator_visit

    def visit_DictComp(self, node):
        self.write("{")
        self.visit(node.key)
        self.write(": ")
        self.visit(node.value)
        for comprehension in node.generators:
            self.visit(comprehension)
        self.write("}")

    def visit_IfExp(self, node):
        # Parenthesised like the other compound expressions so that e.g.
        # ``(a if b else c) + 1`` keeps its meaning when regenerated.
        self.write("(")
        self.visit(node.body)
        self.write(" if ")
        self.visit(node.test)
        self.write(" else ")
        self.visit(node.orelse)
        self.write(")")

    def visit_Starred(self, node):
        self.write("*")
        self.visit(node.value)

    def visit_Repr(self, node):
        # XXX: python 2.6 only
        self.write("`")
        self.visit(node.value)
        self.write("`")

    # Helper Nodes

    def visit_alias(self, node):
        self.write(node.name)
        if node.asname is not None:
            self.write(" as " + node.asname)

    def visit_keyword(self, node):
        # ``arg`` is None for ``**mapping`` unpacking in a call.
        if node.arg is None:
            self.write("**")
        else:
            self.write(node.arg + "=")
        self.visit(node.value)

    def visit_withitem(self, node):
        self.visit(node.context_expr)
        if node.optional_vars is not None:
            self.write(" as ")
            self.visit(node.optional_vars)

    def visit_comprehension(self, node):
        self.write(" for ")
        self.visit(node.target)
        self.write(" in ")
        self.visit(node.iter)
        if node.ifs:
            for if_ in node.ifs:
                self.write(" if ")
                self.visit(if_)

    def visit_excepthandler(self, node):
        self.newline()
        self.write("except")
        if node.type is not None:
            self.write(" ")
            self.visit(node.type)
            if node.name is not None:
                self.write(" as ")
                # The bound name is a plain string on Python 3 and a Name
                # node on Python 2.
                if isinstance(node.name, str):
                    self.write(node.name)
                else:
                    self.visit(node.name)
        self.write(":")
        self.body(node.body)

    # Python 3 names the handler node class ``ExceptHandler``.
    visit_ExceptHandler = visit_excepthandler
714