• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2    ast
3    ~~~
4
5    The `ast` module helps Python applications to process trees of the Python
6    abstract syntax grammar.  The abstract syntax itself might change with
7    each Python release; this module helps to find out programmatically what
8    the current grammar looks like and allows modifications of it.
9
10    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
11    a flag to the `compile()` builtin function or by using the `parse()`
12    function from this module.  The result will be a tree of objects whose
13    classes all inherit from `ast.AST`.
14
15    A modified abstract syntax tree can be compiled into a Python code object
16    using the built-in `compile()` function.
17
18    Additionally various helper functions are provided that make working with
19    the trees simpler.  The main intention of the helper functions and this
20    module in general is to provide an easy to use interface for libraries
21    that work tightly with the python syntax (template engines for example).
22
23
24    :copyright: Copyright 2008 by Armin Ronacher.
25    :license: Python License.
26"""
27import sys
28from _ast import *
29from contextlib import contextmanager, nullcontext
30from enum import IntEnum, auto
31
32
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None):
    """
    Parse *source* into an AST node and return it.

    This is equivalent to ``compile(source, filename, mode, PyCF_ONLY_AST)``.
    Pass ``type_comments=True`` to also collect type comments where the
    grammar allows them.  *feature_version* may be a ``(3, minor)`` tuple,
    a bare int giving the 3.x minor version, or None.
    """
    flags = PyCF_ONLY_AST | (PyCF_TYPE_COMMENTS if type_comments else 0)
    if feature_version is None:
        # -1 is passed through when no particular grammar version is requested.
        feature_version = -1
    elif isinstance(feature_version, tuple):
        # Expected to be a (major, minor) pair with major == 3.
        major, minor = feature_version
        assert major == 3
        feature_version = minor
    # Any other value is forwarded unchanged as the 3.x minor version.
    return compile(source, filename, mode, flags,
                   _feature_version=feature_version)
52
53
def literal_eval(node_or_string):
    """
    Safely evaluate an expression node or a string containing a Python
    expression.  The string or node provided may only consist of the following
    Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
    sets, booleans, and None.

    For string input, leading spaces and tabs are stripped before parsing,
    so indented snippets such as ``"  [1, 2]"`` are accepted instead of
    raising IndentationError.

    Raises ValueError (naming the offending line when the node carries a
    ``lineno``) if anything other than the structures above is found.  A
    string that is not valid Python raises SyntaxError from the parser.
    Conversion is recursive, so extremely deeply nested input may exceed
    the recursion limit.
    """
    if isinstance(node_or_string, str):
        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _raise_malformed_node(node):
        msg = "malformed node or string"
        if lno := getattr(node, 'lineno', None):
            msg += f' on line {lno}'
        raise ValueError(msg + f': {node!r}')

    def _convert_num(node):
        # Only int, float and complex constants count as numbers (bool does not).
        if not isinstance(node, Constant) or type(node.value) not in (int, float, complex):
            _raise_malformed_node(node)
        return node.value

    def _convert_signed_num(node):
        # A number, optionally preceded by a unary + or -.
        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
            operand = _convert_num(node.operand)
            if isinstance(node.op, UAdd):
                return + operand
            else:
                return - operand
        return _convert_num(node)

    def _convert(node):
        if isinstance(node, Constant):
            return node.value
        elif isinstance(node, Tuple):
            return tuple(map(_convert, node.elts))
        elif isinstance(node, List):
            return list(map(_convert, node.elts))
        elif isinstance(node, Set):
            return set(map(_convert, node.elts))
        elif (isinstance(node, Call) and isinstance(node.func, Name) and
              node.func.id == 'set' and node.args == node.keywords == []):
            # ``set()`` is the only way to write an empty set ({} is a dict).
            return set()
        elif isinstance(node, Dict):
            if len(node.keys) != len(node.values):
                _raise_malformed_node(node)
            return dict(zip(map(_convert, node.keys),
                            map(_convert, node.values)))
        elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # Binary +/- is only allowed to build complex literals (real +/- imag).
            left = _convert_signed_num(node.left)
            right = _convert_num(node.right)
            if isinstance(left, (int, float)) and isinstance(right, complex):
                if isinstance(node.op, Add):
                    return left + right
                else:
                    return left - right
        return _convert_signed_num(node)

    return _convert(node_or_string)
106
107
def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
    """
    Return a formatted string representation of the tree rooted at *node*,
    intended mainly for debugging.

    annotate_fields -- when true (the default), field values are shown as
        ``name=value``; when false, names are left out where that is
        unambiguous, which gives a more compact result.
    include_attributes -- when true, attributes such as line numbers and
        column offsets are included; they are omitted by default.
    indent -- None (the default) gives a single-line result.  A non-negative
        int or a string pretty-prints the tree, indenting each nesting level
        by that many spaces or by that string.

    Raises TypeError if *node* is not an AST instance.
    """
    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    if indent is not None and not isinstance(indent, str):
        indent = ' ' * indent

    def _render(obj, depth=0):
        # Returns (text, is_simple); "simple" results may share one line.
        if indent is None:
            opener, joiner = '', ', '
        else:
            depth += 1
            opener = '\n' + indent * depth
            joiner = ',\n' + indent * depth

        if isinstance(obj, AST):
            node_cls = type(obj)
            parts = []
            all_simple = True
            use_names = annotate_fields
            for field in obj._fields:
                try:
                    raw = getattr(obj, field)
                except AttributeError:
                    # Once a field is skipped, later ones need their names.
                    use_names = True
                    continue
                if raw is None and getattr(node_cls, field, ...) is None:
                    # Optional field left at its default: omit it.
                    use_names = True
                    continue
                text, simple = _render(raw, depth)
                all_simple = all_simple and simple
                parts.append(f'{field}={text}' if use_names else text)
            if include_attributes and obj._attributes:
                for attr in obj._attributes:
                    try:
                        raw = getattr(obj, attr)
                    except AttributeError:
                        continue
                    if raw is None and getattr(node_cls, attr, ...) is None:
                        continue
                    text, simple = _render(raw, depth)
                    all_simple = all_simple and simple
                    parts.append(f'{attr}={text}')
            cls_name = obj.__class__.__name__
            if all_simple and len(parts) <= 3:
                return f"{cls_name}({', '.join(parts)})", not parts
            return f'{cls_name}({opener}{joiner.join(parts)})', False

        if isinstance(obj, list):
            if not obj:
                return '[]', True
            items = joiner.join(_render(item, depth)[0] for item in obj)
            return f'[{opener}{items}]', False

        return repr(obj), True

    return _render(node)[0]
173
174
def copy_location(new_node, old_node):
    """
    Copy the source location of *old_node* onto *new_node* and return
    *new_node*.

    Only ``lineno``, ``col_offset``, ``end_lineno`` and ``end_col_offset``
    are considered, and only those declared in ``_attributes`` by both
    nodes.  A start position is copied only when it is not None.  An
    ``end_*`` attribute that exists on *old_node* is copied even when it is
    None, because end positions are optional.
    """
    shared = (
        attr
        for attr in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset')
        if attr in old_node._attributes and attr in new_node._attributes
    )
    for attr in shared:
        value = getattr(old_node, attr, None)
        optional = attr.startswith("end_")
        if value is None and not (optional and hasattr(old_node, attr)):
            continue
        setattr(new_node, attr, value)
    return new_node
190
191
def fix_missing_locations(node):
    """
    Fill in missing source locations throughout the tree rooted at *node*.

    ``compile()`` expects ``lineno`` and ``col_offset`` on every node that
    supports them, which is tedious to supply for generated nodes.  This
    helper walks the tree recursively.  For each location attribute a node
    declares in ``_attributes`` but lacks, it copies the value from the
    node's parent, after the parent itself has been filled in.  At the root
    the defaults are line 1, column 0.  ``end_lineno`` and ``end_col_offset``
    also count as missing when they are None.  Returns *node*.
    """
    def _fill(current, inherited):
        declared = current._attributes
        for name in ('lineno', 'end_lineno', 'col_offset', 'end_col_offset'):
            if name not in declared:
                continue
            if name.startswith('end_'):
                missing = getattr(current, name, None) is None
            else:
                missing = not hasattr(current, name)
            if missing:
                setattr(current, name, inherited[name])
            else:
                # Descendants inherit this node's own value from here on;
                # build a new dict so siblings are unaffected.
                inherited = {**inherited, name: getattr(current, name)}
        for child in iter_child_nodes(current):
            _fill(child, inherited)

    _fill(node, {'lineno': 1, 'col_offset': 0,
                 'end_lineno': 1, 'end_col_offset': 0})
    return node
225
226
def increment_lineno(node, n=1):
    """
    Increment the line number and end line number of each node in the tree
    starting at *node* by *n*.  This is useful to "move code" to a different
    location in a file.  Returns *node*.

    ``TypeIgnore`` nodes keep their line number in a regular field rather
    than in ``_attributes``, so they are handled explicitly.  Without that,
    the ``# type: ignore`` markers of a moved module would point at stale
    lines.
    """
    for child in walk(node):
        if isinstance(child, TypeIgnore):
            child.lineno = getattr(child, 'lineno', 0) + n
            continue

        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        if (
            "end_lineno" in child._attributes
            and (end_lineno := getattr(child, "end_lineno", 0)) is not None
        ):
            child.end_lineno = end_lineno + n
    return node
242
243
def iter_fields(node):
    """
    Yield ``(fieldname, value)`` pairs for every name in ``node._fields``
    that is actually set on *node*.  Unset fields are skipped silently.
    """
    for name in node._fields:
        try:
            value = getattr(node, name)
        except AttributeError:
            continue
        yield name, value
254
255
def iter_child_nodes(node):
    """
    Yield the direct children of *node*.  These are every field value that
    is itself an AST node, plus every AST item inside a list-valued field.
    Other values (strings, numbers, None) are skipped.
    """
    for _, value in iter_fields(node):
        if isinstance(value, list):
            yield from (item for item in value if isinstance(item, AST))
        elif isinstance(value, AST):
            yield value
268
269
def get_docstring(node, clean=True):
    """
    Return the docstring of *node*, or None if it has none.

    *node* must be a FunctionDef, AsyncFunctionDef, ClassDef or Module.
    Any other node type raises TypeError.  The docstring is the first
    statement of the body, provided that statement is an expression whose
    value is a string constant.

    If *clean* is true (the default), the text is passed through
    ``inspect.cleandoc``.  That expands tabs to spaces and removes any
    indentation that can be uniformly removed from the second line onwards.
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not (node.body and isinstance(node.body[0], Expr)):
        return None
    value = node.body[0].value
    # The deprecated ``Str`` alias only matches a Constant holding a str, so
    # this single check covers every case without relying on that alias.
    if not (isinstance(value, Constant) and isinstance(value.value, str)):
        return None
    text = value.value
    if clean:
        import inspect
        text = inspect.cleandoc(text)
    return text
294
295
def _splitlines_no_ff(source):
    r"""Split a string into lines ignoring form feed and other chars.

    This mimics how the Python parser splits source code.  Only ``\r\n``,
    ``\r`` and ``\n`` end a line, unlike ``str.splitlines``, which also
    breaks on ``\f``, ``\v`` and other separators.  Line endings are kept,
    and a trailing fragment without a line ending becomes the final line.
    An empty string yields an empty list.
    """
    import re
    # Each match is one line: a (possibly empty) run of non-newline
    # characters ending in a line terminator, or a non-empty tail at the very
    # end of the string.  One C-level scan replaces the char-by-char loop.
    return re.findall(r'[^\r\n]*(?:\r\n|\r|\n)|[^\r\n]+\Z', source)
319
320
def _pad_whitespace(source):
    r"""Return *source* with every character except '\f' and '\t' replaced by
    a space.  The result has the same length and keeps the tab/form-feed
    characters that affect column alignment."""
    return ''.join(ch if ch in '\f\t' else ' ' for ch in source)
330
331
def get_source_segment(source, node, *, padded=False):
    """Return the part of *source* that produced *node*, or None.

    None is returned when a location attribute (``lineno``, ``end_lineno``,
    ``col_offset``, ``end_col_offset``) is missing, or when an end position
    is None.  Column offsets are treated as byte offsets into the UTF-8
    encoding of their line.

    If *padded* is true and the segment spans several lines, its first line
    is prefixed with whitespace so that it keeps its original column.
    """
    try:
        start_row, stop_row = node.lineno, node.end_lineno
        start_col, stop_col = node.col_offset, node.end_col_offset
    except AttributeError:
        return None
    if stop_row is None or stop_col is None:
        return None

    def _cut(text, start=None, stop=None):
        # Slice by UTF-8 byte offsets, then decode back to str.
        return text.encode()[start:stop].decode()

    lines = _splitlines_no_ff(source)
    first, last = start_row - 1, stop_row - 1
    if first == last:
        return _cut(lines[first], start_col, stop_col)

    padding = _pad_whitespace(_cut(lines[first], None, start_col)) if padded else ''
    head = padding + _cut(lines[first], start_col)
    tail = _cut(lines[last], None, stop_col)
    return ''.join([head, *lines[first + 1:last], tail])
367
368
def walk(node):
    """
    Yield *node* and every node beneath it.  The traversal happens to be
    breadth-first, but the order is not part of the contract.  This is handy
    for in-place modifications that do not depend on context.
    """
    level = [node]
    while level:
        upcoming = []
        for current in level:
            # Capture the children before handing the node to the caller.
            upcoming.extend(iter_child_nodes(current))
            yield current
        level = upcoming
381
382
class NodeVisitor(object):
    """
    Base class for walking an abstract syntax tree and calling a visitor
    method for every node found.

    Subclasses define methods named ``'visit_'`` + the node's class name.
    For example, a `Name` node is handled by ``visit_Name`` and a `Try` node
    by ``visit_Try``.  `visit` looks that method up, calls it, and returns
    its result.  When no such method exists, `generic_visit` is called
    instead, which visits every child of the node.  Override `visit` to
    change how the method is chosen.

    A visitor method is responsible for its node's children: it must call
    `generic_visit` itself if they should be visited too.

    Use `NodeTransformer` instead when the tree should be modified during
    the traversal.
    """

    def visit(self, node):
        """Visit *node* with its ``visit_<ClassName>`` method or `generic_visit`."""
        handler_name = 'visit_' + node.__class__.__name__
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Visit every direct child of *node*; used when no specific visitor exists."""
        for child in iter_child_nodes(node):
            self.visit(child)

    def visit_Constant(self, node):
        # Backward compatibility: a subclass may still implement the visitor
        # of a deprecated node class (visit_Num, visit_Str, visit_Bytes,
        # visit_NameConstant, visit_Ellipsis).  If it does, route the Constant
        # there with a DeprecationWarning; otherwise fall back to
        # generic_visit.
        value = node.value
        legacy = _const_node_type_names.get(type(value))
        if legacy is None:
            # No exact type match: accept subclasses, checked in table order.
            matches = (name for cls, name in _const_node_type_names.items()
                       if isinstance(value, cls))
            legacy = next(matches, None)
        if legacy is not None:
            method = 'visit_' + legacy
            try:
                visitor = getattr(self, method)
            except AttributeError:
                pass
            else:
                import warnings
                warnings.warn(f"{method} is deprecated; add visit_Constant",
                              DeprecationWarning, 2)
                return visitor(node)
        return self.generic_visit(node)
439
440
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    lets visitor methods replace or remove the nodes they visit.

    The return value of a visitor method decides what happens to its node:

    * ``None`` removes the node.  It is dropped from a list field, or the
      attribute is deleted from the parent for a single-node field.
    * An AST node replaces it.  Returning the original node keeps it.
    * For an item of a list field, any other return value is treated as an
      iterable whose items are spliced in place of the node.  This is
      typically used to turn one statement into several.

    Example: rewrite every name lookup ``foo`` into ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Constant(value=node.id),
                   ctx=node.ctx
               )

    A visitor method for a node with children must either transform those
    children itself or call :meth:`generic_visit` first.

    Typical usage::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        for field, current in iter_fields(node):
            if isinstance(current, AST):
                replacement = self.visit(current)
                if replacement is None:
                    delattr(node, field)
                else:
                    setattr(node, field, replacement)
                continue
            if not isinstance(current, list):
                continue
            kept = []
            for item in current:
                if isinstance(item, AST):
                    result = self.visit(item)
                    if result is None:
                        continue
                    if not isinstance(result, AST):
                        kept.extend(result)
                        continue
                    item = result
                kept.append(item)
            # Mutate the existing list object instead of rebinding the field.
            current[:] = kept
        return node
498
499
# Backward compatibility: expose Constant.value under the legacy attribute
# names ``n`` and ``s`` (the field names of the deprecated Num, Str and Bytes
# classes).  It will be removed in future.  The hasattr() guard installs the
# aliases only once, even if the ast module is loaded more than once.
if not hasattr(Constant, 'n'):

    def _getter(self):
        """Deprecated. Use value instead."""
        return self.value

    def _setter(self, value):
        self.value = value

    for _alias in ('n', 's'):
        # One property object per alias, each reading and writing ``value``.
        setattr(Constant, _alias, property(_getter, _setter))
    del _alias
514
class _ABC(type):
    # Metaclass of the deprecated Constant aliases (Num, Str, ...).  It lets
    # isinstance() accept a plain Constant as an alias instance when the
    # Constant's value has one of the alias's types (see _const_types).

    def __init__(cls, *args):
        cls.__doc__ = """Deprecated AST node class. Use ast.Constant instead"""

    def __instancecheck__(cls, inst):
        if not isinstance(inst, Constant):
            return False
        if cls not in _const_types:
            # Classes using this metaclass without an entry in _const_types
            # (e.g. user subclasses) get ordinary isinstance semantics.
            return type.__instancecheck__(cls, inst)
        try:
            value = inst.value
        except AttributeError:
            return False
        accepted = _const_types[cls]
        rejected = _const_types_not.get(cls, ())
        return isinstance(value, accepted) and not isinstance(value, rejected)
534
def _new(cls, *args, **kwargs):
    # Shared __new__ of the deprecated Constant aliases.  A field passed both
    # positionally and by keyword is rejected, as in a regular call.
    # Keywords that are not fields (arbitrary keyword arguments) are accepted.
    for key in kwargs:
        if key in cls._fields and cls._fields.index(key) < len(args):
            raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}")
    if cls in _const_types:
        # Instantiating one of the aliases directly yields a plain Constant.
        return Constant(*args, **kwargs)
    return Constant.__new__(cls, *args, **kwargs)
546
# Deprecated aliases of Constant.  isinstance() checks against them are
# answered by the _ABC metaclass from the Constant's value type.  Calling
# them goes through _new (or Ellipsis.__new__), which builds a Constant.
class Num(Constant, metaclass=_ABC):
    __new__ = _new
    _fields = ('n',)

class Str(Constant, metaclass=_ABC):
    __new__ = _new
    _fields = ('s',)

class Bytes(Constant, metaclass=_ABC):
    __new__ = _new
    _fields = ('s',)

class NameConstant(Constant, metaclass=_ABC):
    __new__ = _new

class Ellipsis(Constant, metaclass=_ABC):
    _fields = ()

    def __new__(cls, *args, **kwargs):
        # Ellipsis carries no value argument: it always wraps ``...``.
        if cls is not Ellipsis:
            return Constant.__new__(cls, *args, **kwargs)
        return Constant(..., *args, **kwargs)
569
# Value types represented by each deprecated Constant alias; consulted by
# _ABC.__instancecheck__ (so e.g. isinstance(Constant(value=1), Num) is True).
_const_types = {
    Num: (int, float, complex),
    Str: (str,),
    Bytes: (bytes,),
    NameConstant: (type(None), bool),
    Ellipsis: (type(...),),
}
# Value types excluded from an alias even though _const_types would accept
# them: bool is a subclass of int, but booleans are not treated as Num.
_const_types_not = {
    Num: (bool,),
}

# Maps a Constant value's type to the legacy visitor-method suffix
# (visit_Num, visit_Str, ...) used by NodeVisitor.visit_Constant.  That method
# tries an exact type() lookup first and then falls back to isinstance()
# checks in this dict's insertion order.
_const_node_type_names = {
    bool: 'NameConstant',  # should be before int
    type(None): 'NameConstant',
    int: 'Num',
    float: 'Num',
    complex: 'Num',
    str: 'Str',
    bytes: 'Bytes',
    type(...): 'Ellipsis',
}
591
class slice(AST):
    """Deprecated AST node class."""

class Index(slice):
    """Deprecated AST node class. Use the index value directly instead."""
    def __new__(cls, value, **kwargs):
        # No wrapper node anymore: "constructing" an Index returns its value.
        # Extra keyword arguments are accepted and ignored.
        unwrapped = value
        return unwrapped

class ExtSlice(slice):
    """Deprecated AST node class. Use ast.Tuple instead."""
    def __new__(cls, dims=(), **kwargs):
        # The dimensions become the elements of a loading Tuple; any other
        # keyword arguments are forwarded to the Tuple constructor.
        elements = list(dims)
        return Tuple(elements, Load(), **kwargs)
604
# Backward compatibility: ExtSlice (now built as a Tuple, see ExtSlice.__new__)
# took its entries as ``dims``, so Tuple.elts is also exposed under that name.
# It will be removed in future.  The hasattr() guard adds the alias only once,
# even if the ast module is loaded more than once.
if not hasattr(Tuple, 'dims'):

    def _dims_getter(self):
        """Deprecated. Use elts instead."""
        return self.elts

    def _dims_setter(self, value):
        self.elts = value

    Tuple.dims = property(_dims_getter).setter(_dims_setter)
618
# Placeholder classes for node types that, per their docstrings, are unused
# in Python 3.  Presumably kept so that older code referring to these names
# still imports.
class Suite(mod):
    """Deprecated AST node class.  Unused in Python 3."""

class AugLoad(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""

class AugStore(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""

class Param(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""
630
631
# Float and imaginary literals too large to represent become infinities in
# the AST; those infinities are unparsed as _INFSTR, a literal whose exponent
# is one past the largest decimal exponent a finite float can have.
_INFSTR = f"1e{sys.float_info.max_10_exp + 1}"
635
class _Precedence(IntEnum):
    """Precedence levels derived from the Python grammar.

    Members are listed from the loosest-binding construct (TUPLE) to the
    tightest (ATOM); auto() numbers them in that order.  BOR is an alias of
    EXPR, so both share one value.
    """

    TUPLE = auto()
    YIELD = auto()           # 'yield', 'yield from'
    TEST = auto()            # 'if'-'else', 'lambda'
    OR = auto()              # 'or'
    AND = auto()             # 'and'
    NOT = auto()             # 'not'
    CMP = auto()             # '<', '>', '==', '>=', '<=', '!=',
                             # 'in', 'not in', 'is', 'is not'
    EXPR = auto()
    BOR = EXPR               # '|'
    BXOR = auto()            # '^'
    BAND = auto()            # '&'
    SHIFT = auto()           # '<<', '>>'
    ARITH = auto()           # '+', '-'
    TERM = auto()            # '*', '@', '/', '%', '//'
    FACTOR = auto()          # unary '+', '-', '~'
    POWER = auto()           # '**'
    AWAIT = auto()           # 'await'
    ATOM = auto()

    def next(self):
        """Return the level just above this one, or *self* at the top (ATOM)."""
        following = self + 1
        if following > max(self.__class__):
            return self
        return self.__class__(following)
664
665
# String quote delimiters: single-character quotes, triple quotes, and both
# groups combined (single-character ones first).
_SINGLE_QUOTES = ("'", '"')
_MULTI_QUOTES = ('"""', "'''")
_ALL_QUOTES = _SINGLE_QUOTES + _MULTI_QUOTES
669
670class _Unparser(NodeVisitor):
671    """Methods in this class recursively traverse an AST and
672    output source code for the abstract syntax; original formatting
673    is disregarded."""
674
675    def __init__(self, *, _avoid_backslashes=False):
676        self._source = []
677        self._buffer = []
678        self._precedences = {}
679        self._type_ignores = {}
680        self._indent = 0
681        self._avoid_backslashes = _avoid_backslashes
682
683    def interleave(self, inter, f, seq):
684        """Call f on each item in seq, calling inter() in between."""
685        seq = iter(seq)
686        try:
687            f(next(seq))
688        except StopIteration:
689            pass
690        else:
691            for x in seq:
692                inter()
693                f(x)
694
695    def items_view(self, traverser, items):
696        """Traverse and separate the given *items* with a comma and append it to
697        the buffer. If *items* is a single item sequence, a trailing comma
698        will be added."""
699        if len(items) == 1:
700            traverser(items[0])
701            self.write(",")
702        else:
703            self.interleave(lambda: self.write(", "), traverser, items)
704
705    def maybe_newline(self):
706        """Adds a newline if it isn't the start of generated source"""
707        if self._source:
708            self.write("\n")
709
710    def fill(self, text=""):
711        """Indent a piece of text and append it, according to the current
712        indentation level"""
713        self.maybe_newline()
714        self.write("    " * self._indent + text)
715
716    def write(self, text):
717        """Append a piece of text"""
718        self._source.append(text)
719
720    def buffer_writer(self, text):
721        self._buffer.append(text)
722
723    @property
724    def buffer(self):
725        value = "".join(self._buffer)
726        self._buffer.clear()
727        return value
728
729    @contextmanager
730    def block(self, *, extra = None):
731        """A context manager for preparing the source for blocks. It adds
732        the character':', increases the indentation on enter and decreases
733        the indentation on exit. If *extra* is given, it will be directly
734        appended after the colon character.
735        """
736        self.write(":")
737        if extra:
738            self.write(extra)
739        self._indent += 1
740        yield
741        self._indent -= 1
742
743    @contextmanager
744    def delimit(self, start, end):
745        """A context manager for preparing the source for expressions. It adds
746        *start* to the buffer and enters, after exit it adds *end*."""
747
748        self.write(start)
749        yield
750        self.write(end)
751
752    def delimit_if(self, start, end, condition):
753        if condition:
754            return self.delimit(start, end)
755        else:
756            return nullcontext()
757
758    def require_parens(self, precedence, node):
759        """Shortcut to adding precedence related parens"""
760        return self.delimit_if("(", ")", self.get_precedence(node) > precedence)
761
762    def get_precedence(self, node):
763        return self._precedences.get(node, _Precedence.TEST)
764
765    def set_precedence(self, precedence, *nodes):
766        for node in nodes:
767            self._precedences[node] = precedence
768
769    def get_raw_docstring(self, node):
770        """If a docstring node is found in the body of the *node* parameter,
771        return that docstring node, None otherwise.
772
773        Logic mirrored from ``_PyAST_GetDocString``."""
774        if not isinstance(
775            node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
776        ) or len(node.body) < 1:
777            return None
778        node = node.body[0]
779        if not isinstance(node, Expr):
780            return None
781        node = node.value
782        if isinstance(node, Constant) and isinstance(node.value, str):
783            return node
784
785    def get_type_comment(self, node):
786        comment = self._type_ignores.get(node.lineno) or node.type_comment
787        if comment is not None:
788            return f" # type: {comment}"
789
790    def traverse(self, node):
791        if isinstance(node, list):
792            for item in node:
793                self.traverse(item)
794        else:
795            super().visit(node)
796
797    def visit(self, node):
798        """Outputs a source code string that, if converted back to an ast
799        (using ast.parse) will generate an AST equivalent to *node*"""
800        self._source = []
801        self.traverse(node)
802        return "".join(self._source)
803
804    def _write_docstring_and_traverse_body(self, node):
805        if (docstring := self.get_raw_docstring(node)):
806            self._write_docstring(docstring)
807            self.traverse(node.body[1:])
808        else:
809            self.traverse(node.body)
810
811    def visit_Module(self, node):
812        self._type_ignores = {
813            ignore.lineno: f"ignore{ignore.tag}"
814            for ignore in node.type_ignores
815        }
816        self._write_docstring_and_traverse_body(node)
817        self._type_ignores.clear()
818
819    def visit_FunctionType(self, node):
820        with self.delimit("(", ")"):
821            self.interleave(
822                lambda: self.write(", "), self.traverse, node.argtypes
823            )
824
825        self.write(" -> ")
826        self.traverse(node.returns)
827
828    def visit_Expr(self, node):
829        self.fill()
830        self.set_precedence(_Precedence.YIELD, node.value)
831        self.traverse(node.value)
832
833    def visit_NamedExpr(self, node):
834        with self.require_parens(_Precedence.TUPLE, node):
835            self.set_precedence(_Precedence.ATOM, node.target, node.value)
836            self.traverse(node.target)
837            self.write(" := ")
838            self.traverse(node.value)
839
840    def visit_Import(self, node):
841        self.fill("import ")
842        self.interleave(lambda: self.write(", "), self.traverse, node.names)
843
844    def visit_ImportFrom(self, node):
845        self.fill("from ")
846        self.write("." * node.level)
847        if node.module:
848            self.write(node.module)
849        self.write(" import ")
850        self.interleave(lambda: self.write(", "), self.traverse, node.names)
851
852    def visit_Assign(self, node):
853        self.fill()
854        for target in node.targets:
855            self.traverse(target)
856            self.write(" = ")
857        self.traverse(node.value)
858        if type_comment := self.get_type_comment(node):
859            self.write(type_comment)
860
861    def visit_AugAssign(self, node):
862        self.fill()
863        self.traverse(node.target)
864        self.write(" " + self.binop[node.op.__class__.__name__] + "= ")
865        self.traverse(node.value)
866
867    def visit_AnnAssign(self, node):
868        self.fill()
869        with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)):
870            self.traverse(node.target)
871        self.write(": ")
872        self.traverse(node.annotation)
873        if node.value:
874            self.write(" = ")
875            self.traverse(node.value)
876
877    def visit_Return(self, node):
878        self.fill("return")
879        if node.value:
880            self.write(" ")
881            self.traverse(node.value)
882
883    def visit_Pass(self, node):
884        self.fill("pass")
885
886    def visit_Break(self, node):
887        self.fill("break")
888
889    def visit_Continue(self, node):
890        self.fill("continue")
891
892    def visit_Delete(self, node):
893        self.fill("del ")
894        self.interleave(lambda: self.write(", "), self.traverse, node.targets)
895
896    def visit_Assert(self, node):
897        self.fill("assert ")
898        self.traverse(node.test)
899        if node.msg:
900            self.write(", ")
901            self.traverse(node.msg)
902
903    def visit_Global(self, node):
904        self.fill("global ")
905        self.interleave(lambda: self.write(", "), self.write, node.names)
906
907    def visit_Nonlocal(self, node):
908        self.fill("nonlocal ")
909        self.interleave(lambda: self.write(", "), self.write, node.names)
910
911    def visit_Await(self, node):
912        with self.require_parens(_Precedence.AWAIT, node):
913            self.write("await")
914            if node.value:
915                self.write(" ")
916                self.set_precedence(_Precedence.ATOM, node.value)
917                self.traverse(node.value)
918
919    def visit_Yield(self, node):
920        with self.require_parens(_Precedence.YIELD, node):
921            self.write("yield")
922            if node.value:
923                self.write(" ")
924                self.set_precedence(_Precedence.ATOM, node.value)
925                self.traverse(node.value)
926
927    def visit_YieldFrom(self, node):
928        with self.require_parens(_Precedence.YIELD, node):
929            self.write("yield from ")
930            if not node.value:
931                raise ValueError("Node can't be used without a value attribute.")
932            self.set_precedence(_Precedence.ATOM, node.value)
933            self.traverse(node.value)
934
935    def visit_Raise(self, node):
936        self.fill("raise")
937        if not node.exc:
938            if node.cause:
939                raise ValueError(f"Node can't use cause without an exception.")
940            return
941        self.write(" ")
942        self.traverse(node.exc)
943        if node.cause:
944            self.write(" from ")
945            self.traverse(node.cause)
946
947    def visit_Try(self, node):
948        self.fill("try")
949        with self.block():
950            self.traverse(node.body)
951        for ex in node.handlers:
952            self.traverse(ex)
953        if node.orelse:
954            self.fill("else")
955            with self.block():
956                self.traverse(node.orelse)
957        if node.finalbody:
958            self.fill("finally")
959            with self.block():
960                self.traverse(node.finalbody)
961
962    def visit_ExceptHandler(self, node):
963        self.fill("except")
964        if node.type:
965            self.write(" ")
966            self.traverse(node.type)
967        if node.name:
968            self.write(" as ")
969            self.write(node.name)
970        with self.block():
971            self.traverse(node.body)
972
973    def visit_ClassDef(self, node):
974        self.maybe_newline()
975        for deco in node.decorator_list:
976            self.fill("@")
977            self.traverse(deco)
978        self.fill("class " + node.name)
979        with self.delimit_if("(", ")", condition = node.bases or node.keywords):
980            comma = False
981            for e in node.bases:
982                if comma:
983                    self.write(", ")
984                else:
985                    comma = True
986                self.traverse(e)
987            for e in node.keywords:
988                if comma:
989                    self.write(", ")
990                else:
991                    comma = True
992                self.traverse(e)
993
994        with self.block():
995            self._write_docstring_and_traverse_body(node)
996
997    def visit_FunctionDef(self, node):
998        self._function_helper(node, "def")
999
1000    def visit_AsyncFunctionDef(self, node):
1001        self._function_helper(node, "async def")
1002
1003    def _function_helper(self, node, fill_suffix):
1004        self.maybe_newline()
1005        for deco in node.decorator_list:
1006            self.fill("@")
1007            self.traverse(deco)
1008        def_str = fill_suffix + " " + node.name
1009        self.fill(def_str)
1010        with self.delimit("(", ")"):
1011            self.traverse(node.args)
1012        if node.returns:
1013            self.write(" -> ")
1014            self.traverse(node.returns)
1015        with self.block(extra=self.get_type_comment(node)):
1016            self._write_docstring_and_traverse_body(node)
1017
1018    def visit_For(self, node):
1019        self._for_helper("for ", node)
1020
1021    def visit_AsyncFor(self, node):
1022        self._for_helper("async for ", node)
1023
1024    def _for_helper(self, fill, node):
1025        self.fill(fill)
1026        self.traverse(node.target)
1027        self.write(" in ")
1028        self.traverse(node.iter)
1029        with self.block(extra=self.get_type_comment(node)):
1030            self.traverse(node.body)
1031        if node.orelse:
1032            self.fill("else")
1033            with self.block():
1034                self.traverse(node.orelse)
1035
1036    def visit_If(self, node):
1037        self.fill("if ")
1038        self.traverse(node.test)
1039        with self.block():
1040            self.traverse(node.body)
1041        # collapse nested ifs into equivalent elifs.
1042        while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
1043            node = node.orelse[0]
1044            self.fill("elif ")
1045            self.traverse(node.test)
1046            with self.block():
1047                self.traverse(node.body)
1048        # final else
1049        if node.orelse:
1050            self.fill("else")
1051            with self.block():
1052                self.traverse(node.orelse)
1053
1054    def visit_While(self, node):
1055        self.fill("while ")
1056        self.traverse(node.test)
1057        with self.block():
1058            self.traverse(node.body)
1059        if node.orelse:
1060            self.fill("else")
1061            with self.block():
1062                self.traverse(node.orelse)
1063
1064    def visit_With(self, node):
1065        self.fill("with ")
1066        self.interleave(lambda: self.write(", "), self.traverse, node.items)
1067        with self.block(extra=self.get_type_comment(node)):
1068            self.traverse(node.body)
1069
1070    def visit_AsyncWith(self, node):
1071        self.fill("async with ")
1072        self.interleave(lambda: self.write(", "), self.traverse, node.items)
1073        with self.block(extra=self.get_type_comment(node)):
1074            self.traverse(node.body)
1075
    def _str_literal_helper(
        self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False
    ):
        """Helper for writing string literals, minimizing escapes.

        Returns the tuple (string literal body to write, possible quote
        types).  The body does not include the surrounding quotes.  The
        returned quote list holds the candidates from *quote_types* that can
        delimit the body, with the preferred one first (callers use
        element 0).

        Backslashes and non-printable characters are always escaped.
        Newline and tab are left literal unless *escape_special_whitespace*
        is true.  If no candidate quote fits, the body falls back to the
        inside of repr(string).
        """
        def escape_char(c):
            # \n and \t are non-printable, but we only escape them if
            # escape_special_whitespace is True
            if not escape_special_whitespace and c in "\n\t":
                return c
            # Always escape backslashes and other non-printable characters
            if c == "\\" or not c.isprintable():
                return c.encode("unicode_escape").decode("ascii")
            return c

        escaped_string = "".join(map(escape_char, string))
        possible_quotes = quote_types
        # A literal newline can only live inside a multi-line quote type.
        if "\n" in escaped_string:
            possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES]
        # Drop any quote type whose delimiter text occurs in the body.
        possible_quotes = [q for q in possible_quotes if q not in escaped_string]
        if not possible_quotes:
            # If there aren't any possible_quotes, fallback to using repr
            # on the original string. Try to use a quote from quote_types,
            # e.g., so that we use triple quotes for docstrings.
            string = repr(string)
            quote = next((q for q in quote_types if string[0] in q), string[0])
            return string[1:-1], [quote]
        if escaped_string:
            # Sort so that we prefer '''"''' over """\""""
            # (stable sort: quotes whose first char differs from the body's
            # last char come first, keeping their original relative order).
            possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1])
            # If we're using triple quotes and we'd need to escape a final
            # quote, escape it
            if possible_quotes[0][0] == escaped_string[-1]:
                # A single-char quote equal to a body char was already
                # filtered out above, so only a triple quote can reach here.
                assert len(possible_quotes[0]) == 3
                escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1]
        return escaped_string, possible_quotes
1113
1114    def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
1115        """Write string literal value with a best effort attempt to avoid backslashes."""
1116        string, quote_types = self._str_literal_helper(string, quote_types=quote_types)
1117        quote_type = quote_types[0]
1118        self.write(f"{quote_type}{string}{quote_type}")
1119
    def visit_JoinedStr(self, node):
        """Write an f-string (JoinedStr) literal.

        Each part is first rendered as raw f-string source through
        buffer_writer.  The quote type is then chosen so that the whole
        literal can be delimited without conflicting with any part.
        """
        self.write("f")
        if self._avoid_backslashes:
            # Nested inside another f-string expression: render everything
            # and quote it in one go, avoiding backslashes throughout.
            self._fstring_JoinedStr(node, self.buffer_writer)
            self._write_str_avoiding_backslashes(self.buffer)
            return

        # If we don't need to avoid backslashes globally (i.e., we only need
        # to avoid them inside FormattedValues), it's cosmetically preferred
        # to use escaped whitespace. That is, it's preferred to use backslashes
        # for cases like: f"{x}\n". To accomplish this, we keep track of what
        # in our buffer corresponds to FormattedValues and what corresponds to
        # Constant parts of the f-string, and allow escapes accordingly.
        buffer = []
        for value in node.values:
            meth = getattr(self, "_fstring_" + type(value).__name__)
            meth(value, self.buffer_writer)
            # NOTE(review): assumes reading self.buffer returns the text
            # accumulated by buffer_writer for this part and resets it —
            # the property is defined outside this view; confirm.
            buffer.append((self.buffer, isinstance(value, Constant)))
        new_buffer = []
        quote_types = _ALL_QUOTES
        for value, is_constant in buffer:
            # Repeatedly narrow down the list of possible quote_types
            value, quote_types = self._str_literal_helper(
                value, quote_types=quote_types,
                escape_special_whitespace=is_constant
            )
            new_buffer.append(value)
        value = "".join(new_buffer)
        # After narrowing, the first remaining quote type suits every part.
        quote_type = quote_types[0]
        self.write(f"{quote_type}{value}{quote_type}")
1150
1151    def visit_FormattedValue(self, node):
1152        self.write("f")
1153        self._fstring_FormattedValue(node, self.buffer_writer)
1154        self._write_str_avoiding_backslashes(self.buffer)
1155
1156    def _fstring_JoinedStr(self, node, write):
1157        for value in node.values:
1158            meth = getattr(self, "_fstring_" + type(value).__name__)
1159            meth(value, write)
1160
1161    def _fstring_Constant(self, node, write):
1162        if not isinstance(node.value, str):
1163            raise ValueError("Constants inside JoinedStr should be a string.")
1164        value = node.value.replace("{", "{{").replace("}", "}}")
1165        write(value)
1166
1167    def _fstring_FormattedValue(self, node, write):
1168        write("{")
1169        unparser = type(self)(_avoid_backslashes=True)
1170        unparser.set_precedence(_Precedence.TEST.next(), node.value)
1171        expr = unparser.visit(node.value)
1172        if expr.startswith("{"):
1173            write(" ")  # Separate pair of opening brackets as "{ {"
1174        if "\\" in expr:
1175            raise ValueError("Unable to avoid backslash in f-string expression part")
1176        write(expr)
1177        if node.conversion != -1:
1178            conversion = chr(node.conversion)
1179            if conversion not in "sra":
1180                raise ValueError("Unknown f-string conversion.")
1181            write(f"!{conversion}")
1182        if node.format_spec:
1183            write(":")
1184            meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
1185            meth(node.format_spec, write)
1186        write("}")
1187
1188    def visit_Name(self, node):
1189        self.write(node.id)
1190
1191    def _write_docstring(self, node):
1192        self.fill()
1193        if node.kind == "u":
1194            self.write("u")
1195        self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
1196
1197    def _write_constant(self, value):
1198        if isinstance(value, (float, complex)):
1199            # Substitute overflowing decimal literal for AST infinities.
1200            self.write(repr(value).replace("inf", _INFSTR))
1201        elif self._avoid_backslashes and isinstance(value, str):
1202            self._write_str_avoiding_backslashes(value)
1203        else:
1204            self.write(repr(value))
1205
1206    def visit_Constant(self, node):
1207        value = node.value
1208        if isinstance(value, tuple):
1209            with self.delimit("(", ")"):
1210                self.items_view(self._write_constant, value)
1211        elif value is ...:
1212            self.write("...")
1213        else:
1214            if node.kind == "u":
1215                self.write("u")
1216            self._write_constant(node.value)
1217
1218    def visit_List(self, node):
1219        with self.delimit("[", "]"):
1220            self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1221
1222    def visit_ListComp(self, node):
1223        with self.delimit("[", "]"):
1224            self.traverse(node.elt)
1225            for gen in node.generators:
1226                self.traverse(gen)
1227
1228    def visit_GeneratorExp(self, node):
1229        with self.delimit("(", ")"):
1230            self.traverse(node.elt)
1231            for gen in node.generators:
1232                self.traverse(gen)
1233
1234    def visit_SetComp(self, node):
1235        with self.delimit("{", "}"):
1236            self.traverse(node.elt)
1237            for gen in node.generators:
1238                self.traverse(gen)
1239
1240    def visit_DictComp(self, node):
1241        with self.delimit("{", "}"):
1242            self.traverse(node.key)
1243            self.write(": ")
1244            self.traverse(node.value)
1245            for gen in node.generators:
1246                self.traverse(gen)
1247
1248    def visit_comprehension(self, node):
1249        if node.is_async:
1250            self.write(" async for ")
1251        else:
1252            self.write(" for ")
1253        self.set_precedence(_Precedence.TUPLE, node.target)
1254        self.traverse(node.target)
1255        self.write(" in ")
1256        self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
1257        self.traverse(node.iter)
1258        for if_clause in node.ifs:
1259            self.write(" if ")
1260            self.traverse(if_clause)
1261
1262    def visit_IfExp(self, node):
1263        with self.require_parens(_Precedence.TEST, node):
1264            self.set_precedence(_Precedence.TEST.next(), node.body, node.test)
1265            self.traverse(node.body)
1266            self.write(" if ")
1267            self.traverse(node.test)
1268            self.write(" else ")
1269            self.set_precedence(_Precedence.TEST, node.orelse)
1270            self.traverse(node.orelse)
1271
1272    def visit_Set(self, node):
1273        if not node.elts:
1274            raise ValueError("Set node should have at least one item")
1275        with self.delimit("{", "}"):
1276            self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1277
1278    def visit_Dict(self, node):
1279        def write_key_value_pair(k, v):
1280            self.traverse(k)
1281            self.write(": ")
1282            self.traverse(v)
1283
1284        def write_item(item):
1285            k, v = item
1286            if k is None:
1287                # for dictionary unpacking operator in dicts {**{'y': 2}}
1288                # see PEP 448 for details
1289                self.write("**")
1290                self.set_precedence(_Precedence.EXPR, v)
1291                self.traverse(v)
1292            else:
1293                write_key_value_pair(k, v)
1294
1295        with self.delimit("{", "}"):
1296            self.interleave(
1297                lambda: self.write(", "), write_item, zip(node.keys, node.values)
1298            )
1299
1300    def visit_Tuple(self, node):
1301        with self.delimit("(", ")"):
1302            self.items_view(self.traverse, node.elts)
1303
    # Unary operator node class name -> operator text written to the source.
    unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
    # Operator text -> precedence, used by visit_UnaryOp for parenthesization
    # and for deciding whether a space follows the operator.
    unop_precedence = {
        "not": _Precedence.NOT,
        "~": _Precedence.FACTOR,
        "+": _Precedence.FACTOR,
        "-": _Precedence.FACTOR,
    }
1311
1312    def visit_UnaryOp(self, node):
1313        operator = self.unop[node.op.__class__.__name__]
1314        operator_precedence = self.unop_precedence[operator]
1315        with self.require_parens(operator_precedence, node):
1316            self.write(operator)
1317            # factor prefixes (+, -, ~) shouldn't be seperated
1318            # from the value they belong, (e.g: +1 instead of + 1)
1319            if operator_precedence is not _Precedence.FACTOR:
1320                self.write(" ")
1321            self.set_precedence(operator_precedence, node.operand)
1322            self.traverse(node.operand)
1323
    # Binary operator node class name -> operator text.  Also used by
    # visit_AugAssign, which appends "=" to form e.g. "+=".
    binop = {
        "Add": "+",
        "Sub": "-",
        "Mult": "*",
        "MatMult": "@",
        "Div": "/",
        "Mod": "%",
        "LShift": "<<",
        "RShift": ">>",
        "BitOr": "|",
        "BitXor": "^",
        "BitAnd": "&",
        "FloorDiv": "//",
        "Pow": "**",
    }
1339
    # Operator text -> precedence level, consulted by visit_BinOp to decide
    # which operands need parentheses.
    binop_precedence = {
        "+": _Precedence.ARITH,
        "-": _Precedence.ARITH,
        "*": _Precedence.TERM,
        "@": _Precedence.TERM,
        "/": _Precedence.TERM,
        "%": _Precedence.TERM,
        "<<": _Precedence.SHIFT,
        ">>": _Precedence.SHIFT,
        "|": _Precedence.BOR,
        "^": _Precedence.BXOR,
        "&": _Precedence.BAND,
        "//": _Precedence.TERM,
        "**": _Precedence.POWER,
    }
1355
    # Right-associative operators: `a ** b ** c` means `a ** (b ** c)`, so
    # visit_BinOp requires the tighter binding on the left operand instead.
    binop_rassoc = frozenset(("**",))

1357    def visit_BinOp(self, node):
1358        operator = self.binop[node.op.__class__.__name__]
1359        operator_precedence = self.binop_precedence[operator]
1360        with self.require_parens(operator_precedence, node):
1361            if operator in self.binop_rassoc:
1362                left_precedence = operator_precedence.next()
1363                right_precedence = operator_precedence
1364            else:
1365                left_precedence = operator_precedence
1366                right_precedence = operator_precedence.next()
1367
1368            self.set_precedence(left_precedence, node.left)
1369            self.traverse(node.left)
1370            self.write(f" {operator} ")
1371            self.set_precedence(right_precedence, node.right)
1372            self.traverse(node.right)
1373
    # Comparison operator node class name -> operator text (used by
    # visit_Compare; multi-word operators keep their inner space).
    cmpops = {
        "Eq": "==",
        "NotEq": "!=",
        "Lt": "<",
        "LtE": "<=",
        "Gt": ">",
        "GtE": ">=",
        "Is": "is",
        "IsNot": "is not",
        "In": "in",
        "NotIn": "not in",
    }
1386
1387    def visit_Compare(self, node):
1388        with self.require_parens(_Precedence.CMP, node):
1389            self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators)
1390            self.traverse(node.left)
1391            for o, e in zip(node.ops, node.comparators):
1392                self.write(" " + self.cmpops[o.__class__.__name__] + " ")
1393                self.traverse(e)
1394
    # Boolean operator node class name -> keyword, and keyword -> precedence
    # (both consulted by visit_BoolOp).
    boolops = {"And": "and", "Or": "or"}
    boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR}
1397
1398    def visit_BoolOp(self, node):
1399        operator = self.boolops[node.op.__class__.__name__]
1400        operator_precedence = self.boolop_precedence[operator]
1401
1402        def increasing_level_traverse(node):
1403            nonlocal operator_precedence
1404            operator_precedence = operator_precedence.next()
1405            self.set_precedence(operator_precedence, node)
1406            self.traverse(node)
1407
1408        with self.require_parens(operator_precedence, node):
1409            s = f" {operator} "
1410            self.interleave(lambda: self.write(s), increasing_level_traverse, node.values)
1411
1412    def visit_Attribute(self, node):
1413        self.set_precedence(_Precedence.ATOM, node.value)
1414        self.traverse(node.value)
1415        # Special case: 3.__abs__() is a syntax error, so if node.value
1416        # is an integer literal then we need to either parenthesize
1417        # it or add an extra space to get 3 .__abs__().
1418        if isinstance(node.value, Constant) and isinstance(node.value.value, int):
1419            self.write(" ")
1420        self.write(".")
1421        self.write(node.attr)
1422
1423    def visit_Call(self, node):
1424        self.set_precedence(_Precedence.ATOM, node.func)
1425        self.traverse(node.func)
1426        with self.delimit("(", ")"):
1427            comma = False
1428            for e in node.args:
1429                if comma:
1430                    self.write(", ")
1431                else:
1432                    comma = True
1433                self.traverse(e)
1434            for e in node.keywords:
1435                if comma:
1436                    self.write(", ")
1437                else:
1438                    comma = True
1439                self.traverse(e)
1440
1441    def visit_Subscript(self, node):
1442        def is_simple_tuple(slice_value):
1443            # when unparsing a non-empty tuple, the parantheses can be safely
1444            # omitted if there aren't any elements that explicitly requires
1445            # parantheses (such as starred expressions).
1446            return (
1447                isinstance(slice_value, Tuple)
1448                and slice_value.elts
1449                and not any(isinstance(elt, Starred) for elt in slice_value.elts)
1450            )
1451
1452        self.set_precedence(_Precedence.ATOM, node.value)
1453        self.traverse(node.value)
1454        with self.delimit("[", "]"):
1455            if is_simple_tuple(node.slice):
1456                self.items_view(self.traverse, node.slice.elts)
1457            else:
1458                self.traverse(node.slice)
1459
1460    def visit_Starred(self, node):
1461        self.write("*")
1462        self.set_precedence(_Precedence.EXPR, node.value)
1463        self.traverse(node.value)
1464
1465    def visit_Ellipsis(self, node):
1466        self.write("...")
1467
1468    def visit_Slice(self, node):
1469        if node.lower:
1470            self.traverse(node.lower)
1471        self.write(":")
1472        if node.upper:
1473            self.traverse(node.upper)
1474        if node.step:
1475            self.write(":")
1476            self.traverse(node.step)
1477
1478    def visit_arg(self, node):
1479        self.write(node.arg)
1480        if node.annotation:
1481            self.write(": ")
1482            self.traverse(node.annotation)
1483
1484    def visit_arguments(self, node):
1485        first = True
1486        # normal arguments
1487        all_args = node.posonlyargs + node.args
1488        defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults
1489        for index, elements in enumerate(zip(all_args, defaults), 1):
1490            a, d = elements
1491            if first:
1492                first = False
1493            else:
1494                self.write(", ")
1495            self.traverse(a)
1496            if d:
1497                self.write("=")
1498                self.traverse(d)
1499            if index == len(node.posonlyargs):
1500                self.write(", /")
1501
1502        # varargs, or bare '*' if no varargs but keyword-only arguments present
1503        if node.vararg or node.kwonlyargs:
1504            if first:
1505                first = False
1506            else:
1507                self.write(", ")
1508            self.write("*")
1509            if node.vararg:
1510                self.write(node.vararg.arg)
1511                if node.vararg.annotation:
1512                    self.write(": ")
1513                    self.traverse(node.vararg.annotation)
1514
1515        # keyword-only arguments
1516        if node.kwonlyargs:
1517            for a, d in zip(node.kwonlyargs, node.kw_defaults):
1518                self.write(", ")
1519                self.traverse(a)
1520                if d:
1521                    self.write("=")
1522                    self.traverse(d)
1523
1524        # kwargs
1525        if node.kwarg:
1526            if first:
1527                first = False
1528            else:
1529                self.write(", ")
1530            self.write("**" + node.kwarg.arg)
1531            if node.kwarg.annotation:
1532                self.write(": ")
1533                self.traverse(node.kwarg.annotation)
1534
1535    def visit_keyword(self, node):
1536        if node.arg is None:
1537            self.write("**")
1538        else:
1539            self.write(node.arg)
1540            self.write("=")
1541        self.traverse(node.value)
1542
1543    def visit_Lambda(self, node):
1544        with self.require_parens(_Precedence.TEST, node):
1545            self.write("lambda ")
1546            self.traverse(node.args)
1547            self.write(": ")
1548            self.set_precedence(_Precedence.TEST, node.body)
1549            self.traverse(node.body)
1550
1551    def visit_alias(self, node):
1552        self.write(node.name)
1553        if node.asname:
1554            self.write(" as " + node.asname)
1555
1556    def visit_withitem(self, node):
1557        self.traverse(node.context_expr)
1558        if node.optional_vars:
1559            self.write(" as ")
1560            self.traverse(node.optional_vars)
1561
def unparse(ast_obj):
    """Return source code for *ast_obj*.

    Parsing the returned text with ast.parse yields an AST equivalent to
    *ast_obj*.
    """
    return _Unparser().visit(ast_obj)
1565
1566
def main():
    """Command-line entry point: parse a file (or stdin) and print its AST dump."""
    import argparse

    cli = argparse.ArgumentParser(prog='python -m ast')
    cli.add_argument('infile', type=argparse.FileType(mode='rb'), nargs='?',
                     default='-',
                     help='the file to parse; defaults to stdin')
    cli.add_argument('-m', '--mode', default='exec',
                     choices=('exec', 'single', 'eval', 'func_type'),
                     help='specify what kind of code must be parsed')
    # store_false: type comments are collected unless this flag is given.
    cli.add_argument('--no-type-comments', default=True, action='store_false',
                     help="don't add information about type comments")
    cli.add_argument('-a', '--include-attributes', action='store_true',
                     help='include attributes such as line numbers and '
                          'column offsets')
    cli.add_argument('-i', '--indent', type=int, default=3,
                     help='indentation of nodes (number of spaces)')
    options = cli.parse_args()

    with options.infile as stream:
        source = stream.read()
    tree = parse(source, options.infile.name, options.mode,
                 type_comments=options.no_type_comments)
    print(dump(tree, include_attributes=options.include_attributes,
               indent=options.indent))

if __name__ == '__main__':
    main()
1593