• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2    ast
3    ~~~
4
5    The `ast` module helps Python applications to process trees of the Python
6    abstract syntax grammar.  The abstract syntax itself might change with
7    each Python release; this module helps to find out programmatically what
8    the current grammar looks like and allows modifications of it.
9
10    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
11    a flag to the `compile()` builtin function or by using the `parse()`
12    function from this module.  The result will be a tree of objects whose
13    classes all inherit from `ast.AST`.
14
15    A modified abstract syntax tree can be compiled into a Python code object
16    using the built-in `compile()` function.
17
18    Additionally various helper functions are provided that make working with
19    the trees simpler.  The main intention of the helper functions and this
20    module in general is to provide an easy to use interface for libraries
21    that work tightly with the python syntax (template engines for example).
22
23
24    :copyright: Copyright 2008 by Armin Ronacher.
25    :license: Python License.
26"""
27import sys
28import re
29from _ast import *
30from contextlib import contextmanager, nullcontext
31from enum import IntEnum, auto, _simple_enum
32
33
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None, optimize=-1):
    """
    Parse the source into an AST node.
    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).

    Keyword-only options:

    * ``type_comments=True`` keeps type comments where the syntax allows.
    * ``feature_version`` selects the grammar: ``None`` for the current one,
      an ``int`` minor version of Python 3, or a ``(3, minor)`` tuple.  A
      tuple with any other major version raises ValueError.
    * ``optimize > 0`` additionally sets ``PyCF_OPTIMIZED_AST``; the value
      itself is always forwarded to ``compile()``.
    """
    compile_flags = PyCF_ONLY_AST
    if optimize > 0:
        compile_flags |= PyCF_OPTIMIZED_AST
    if type_comments:
        compile_flags |= PyCF_TYPE_COMMENTS

    if feature_version is None:
        # -1 is the "no specific version requested" value for compile().
        minor_version = -1
    elif isinstance(feature_version, tuple):
        # The tuple must unpack into exactly (major, minor).
        major, minor_version = feature_version
        if major != 3:
            raise ValueError(f"Unsupported major version: {major}")
    else:
        # Otherwise it is already the minor version of Python 3.x.
        minor_version = feature_version

    return compile(source, filename, mode, compile_flags,
                   _feature_version=minor_version, optimize=optimize)
56
57
def literal_eval(node_or_string):
    """
    Evaluate an expression node or a string containing only a Python
    literal and return the resulting value.

    Accepted structures: strings, bytes, numbers, tuples, lists, dicts,
    sets (including the empty ``set()`` call), booleans and None, plus
    numbers with a single unary sign and ``real +/- imaginary`` complex
    literals.  Anything else raises ValueError("malformed node or string ...").

    Caution: A complex expression can overflow the C stack and cause a crash.
    """
    if isinstance(node_or_string, str):
        # Drop leading spaces/tabs so indented input still parses in 'eval' mode.
        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _raise_malformed_node(node):
        # Always raises; the message includes the line number when known.
        message = "malformed node or string"
        line = getattr(node, 'lineno', None)
        if line:
            message += f' on line {line}'
        raise ValueError(message + f': {node!r}')

    def _convert_num(node):
        # Only exact int/float/complex constants qualify (so bool is rejected).
        if isinstance(node, Constant) and type(node.value) in (int, float, complex):
            return node.value
        _raise_malformed_node(node)

    def _convert_signed_num(node):
        # A number, optionally preceded by one unary + or -.
        if not (isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub))):
            return _convert_num(node)
        operand = _convert_num(node.operand)
        return -operand if isinstance(node.op, USub) else +operand

    def _convert(node):
        if isinstance(node, Constant):
            return node.value
        if isinstance(node, Tuple):
            return tuple(_convert(elt) for elt in node.elts)
        if isinstance(node, List):
            return [_convert(elt) for elt in node.elts]
        if isinstance(node, Set):
            return {_convert(elt) for elt in node.elts}
        if (isinstance(node, Call) and isinstance(node.func, Name)
                and node.func.id == 'set' and node.args == node.keywords == []):
            # ``set()`` is the only literal-like spelling of an empty set.
            return set()
        if isinstance(node, Dict):
            # Reject a Dict whose keys and values lists differ in length.
            if len(node.keys) != len(node.values):
                _raise_malformed_node(node)
            return {_convert(key): _convert(value)
                    for key, value in zip(node.keys, node.values)}
        if isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # Only ``real +/- imaginary`` is accepted; any other binary
            # operation falls through to the rejection below.
            real = _convert_signed_num(node.left)
            imag = _convert_num(node.right)
            if isinstance(real, (int, float)) and isinstance(imag, complex):
                return real + imag if isinstance(node.op, Add) else real - imag
        return _convert_signed_num(node)

    return _convert(node_or_string)
115
116
def dump(
    node, annotate_fields=True, include_attributes=False,
    *,
    indent=None, show_empty=False,
):
    """
    Return a formatted dump of the tree in node.  This is mainly useful for
    debugging purposes.  If annotate_fields is true (by default),
    the returned string will show the names and the values for fields.
    If annotate_fields is false, the result string will be more compact by
    omitting unambiguous field names.  Attributes such as line
    numbers and column offsets are not dumped by default.  If this is wanted,
    include_attributes can be set to true.  If indent is a non-negative
    integer or string, then the tree will be pretty-printed with that indent
    level. None (the default) selects the single line representation.
    If show_empty is False, then empty lists and fields that are None
    will be omitted from the output for better readability.

    Raises TypeError if *node* is not an AST instance.
    """
    def _format(node, level=0):
        # Returns a pair (text, is_simple).  is_simple is True for scalars,
        # empty lists and nodes printed with no arguments; a node whose
        # arguments are all simple and number at most three stays on one line.
        if indent is not None:
            level += 1
            prefix = '\n' + indent * level
            sep = ',\n' + indent * level
        else:
            prefix = ''
            sep = ', '
        if isinstance(node, AST):
            cls = type(node)
            args = []
            # reprs of empty values skipped so far; they are re-inserted only
            # if a later field is printed positionally, so positions stay right.
            args_buffer = []
            allsimple = True
            keywords = annotate_fields
            for name in node._fields:
                try:
                    value = getattr(node, name)
                except AttributeError:
                    # A skipped field breaks positional alignment, so every
                    # following field must be printed with its name.
                    keywords = True
                    continue
                if value is None and getattr(cls, name, ...) is None:
                    # The class itself also defines this field as None
                    # (presumably its default): omit it and switch to keywords.
                    keywords = True
                    continue
                if (
                    not show_empty
                    and (value is None or value == [])
                    # Special cases:
                    # `Constant(value=None)` and `MatchSingleton(value=None)`
                    and not isinstance(node, (Constant, MatchSingleton))
                ):
                    args_buffer.append(repr(value))
                    continue
                elif not keywords:
                    # Positional output: flush the skipped empties first.
                    args.extend(args_buffer)
                    args_buffer = []
                value, simple = _format(value, level)
                allsimple = allsimple and simple
                if keywords:
                    args.append('%s=%s' % (name, value))
                else:
                    args.append(value)
            if include_attributes and node._attributes:
                # Attributes (positions) are always printed as keywords.
                for name in node._attributes:
                    try:
                        value = getattr(node, name)
                    except AttributeError:
                        continue
                    if value is None and getattr(cls, name, ...) is None:
                        continue
                    value, simple = _format(value, level)
                    allsimple = allsimple and simple
                    args.append('%s=%s' % (name, value))
            if allsimple and len(args) <= 3:
                return '%s(%s)' % (node.__class__.__name__, ', '.join(args)), not args
            return '%s(%s%s)' % (node.__class__.__name__, prefix, sep.join(args)), False
        elif isinstance(node, list):
            if not node:
                return '[]', True
            return '[%s%s]' % (prefix, sep.join(_format(x, level)[0] for x in node)), False
        return repr(node), True

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    if indent is not None and not isinstance(indent, str):
        # An integer indent means that many spaces per nesting level.
        indent = ' ' * indent
    return _format(node)[0]
201
202
def copy_location(new_node, old_node):
    """
    Copy the source position of *old_node* onto *new_node* and return it.

    Only ``lineno``, ``col_offset``, ``end_lineno`` and ``end_col_offset``
    are considered, and only those declared in ``_attributes`` of both
    nodes.  A missing or None value is skipped, except that the optional
    ``end_*`` attributes are copied even when None, as long as *old_node*
    actually has them.
    """
    shared = [
        name
        for name in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset')
        if name in old_node._attributes and name in new_node._attributes
    ]
    for name in shared:
        value = getattr(old_node, name, None)
        is_optional = name.startswith("end_")
        if value is None and not (is_optional and hasattr(old_node, name)):
            continue
        setattr(new_node, name, value)
    return new_node
218
219
def fix_missing_locations(node):
    """
    Fill in missing source positions throughout the tree rooted at *node*.

    compile() expects ``lineno`` and ``col_offset`` on every node that
    supports them, which is tedious to set on generated nodes.  A node
    inherits its parent's value for ``lineno``/``col_offset`` when the
    attribute is absent, and for ``end_lineno``/``end_col_offset`` when it
    is absent or None.  The root falls back to line 1, column 0.
    Returns *node*, which is modified in place.
    """
    def _fill(current, lineno, col_offset, end_lineno, end_col_offset):
        declared = current._attributes

        if 'lineno' in declared:
            if hasattr(current, 'lineno'):
                lineno = current.lineno
            else:
                current.lineno = lineno
        if 'end_lineno' in declared:
            if getattr(current, 'end_lineno', None) is not None:
                end_lineno = current.end_lineno
            else:
                current.end_lineno = end_lineno
        if 'col_offset' in declared:
            if hasattr(current, 'col_offset'):
                col_offset = current.col_offset
            else:
                current.col_offset = col_offset
        if 'end_col_offset' in declared:
            if getattr(current, 'end_col_offset', None) is not None:
                end_col_offset = current.end_col_offset
            else:
                current.end_col_offset = end_col_offset

        # Children inherit this node's (possibly just-filled) positions.
        for child in iter_child_nodes(current):
            _fill(child, lineno, col_offset, end_lineno, end_col_offset)

    _fill(node, 1, 0, 1, 0)
    return node
253
254
def increment_lineno(node, n=1):
    """
    Shift every line number in the tree rooted at *node* by *n*.

    Both ``lineno`` and ``end_lineno`` are incremented (a missing value
    counts as 0, while an ``end_lineno`` of None is left alone).  This is
    useful to "move code" to a different location in a file.  Returns
    *node*, modified in place.
    """
    for child in walk(node):
        if isinstance(child, TypeIgnore):
            # TypeIgnore carries lineno as a field, not as an attribute, so
            # the _attributes checks below would not see it.
            child.lineno = getattr(child, 'lineno', 0) + n
            continue

        declared = child._attributes
        if 'lineno' in declared:
            child.lineno = getattr(child, 'lineno', 0) + n
        if "end_lineno" not in declared:
            continue
        end_lineno = getattr(child, "end_lineno", 0)
        if end_lineno is not None:
            child.end_lineno = end_lineno + n
    return node
276
277
def iter_fields(node):
    """
    Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
    that is present on *node*.

    Fields that are not set on *node* (attribute lookup raises
    AttributeError) are skipped.
    """
    absent = object()
    for name in node._fields:
        value = getattr(node, name, absent)
        if value is not absent:
            yield name, value
288
289
def iter_child_nodes(node):
    """
    Yield the direct children of *node*: every field value that is an AST
    node, plus every AST item of fields that hold lists.  Other values and
    non-AST list items are ignored.
    """
    for _name, value in iter_fields(node):
        if isinstance(value, AST):
            yield value
        elif isinstance(value, list):
            yield from (item for item in value if isinstance(item, AST))
302
303
def get_docstring(node, clean=True):
    """
    Return the docstring of *node*, or None if its body does not start with
    a string literal.

    *node* must be a FunctionDef, AsyncFunctionDef, ClassDef or Module;
    anything else raises TypeError.  When *clean* is true (the default) the
    text goes through inspect.cleandoc(), which expands tabs and removes
    the indentation shared by the second and later lines.
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    body = node.body
    if not body or not isinstance(body[0], Expr):
        return None
    first = body[0].value
    if not (isinstance(first, Constant) and isinstance(first.value, str)):
        return None
    if not clean:
        return first.value
    import inspect
    return inspect.cleandoc(first.value)
326
327
# One match per line, each keeping its terminator (\r\n, \n or \r); form
# feeds and other characters are not treated as line breaks.
_line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))")


def _splitlines_no_ff(source, maxlines=None):
    r"""Split *source* into lines the way the Python parser does.

    Unlike str.splitlines(), only \r\n, \n and \r end a line, and every
    line keeps its terminator.  When *maxlines* is given, at most that many
    lines are returned.
    """
    matches = _line_pattern.finditer(source)
    if maxlines is None:
        return [match[0] for match in matches]
    lines = []
    for count, match in enumerate(matches, start=1):
        if count > maxlines:
            break
        lines.append(match[0])
    return lines
340
341
def _pad_whitespace(source):
    r"""Replace all chars except '\f\t' in a line with spaces.

    Tabs and form feeds are kept as they are and every other character
    becomes one space, presumably so the padding lines up with the original
    text when rendered.
    """
    # A single str.join pass instead of repeated string concatenation, which
    # is not guaranteed to be linear.
    return ''.join(c if c in '\f\t' else ' ' for c in source)
351
352
353def get_source_segment(source, node, *, padded=False):
354    """Get source code segment of the *source* that generated *node*.
355
356    If some location information (`lineno`, `end_lineno`, `col_offset`,
357    or `end_col_offset`) is missing, return None.
358
359    If *padded* is `True`, the first line of a multi-line statement will
360    be padded with spaces to match its original position.
361    """
362    try:
363        if node.end_lineno is None or node.end_col_offset is None:
364            return None
365        lineno = node.lineno - 1
366        end_lineno = node.end_lineno - 1
367        col_offset = node.col_offset
368        end_col_offset = node.end_col_offset
369    except AttributeError:
370        return None
371
372    lines = _splitlines_no_ff(source, maxlines=end_lineno+1)
373    if end_lineno == lineno:
374        return lines[lineno].encode()[col_offset:end_col_offset].decode()
375
376    if padded:
377        padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode())
378    else:
379        padding = ''
380
381    first = padding + lines[lineno].encode()[col_offset:].decode()
382    last = lines[end_lineno].encode()[:end_col_offset].decode()
383    lines = lines[lineno+1:end_lineno]
384
385    lines.insert(0, first)
386    lines.append(last)
387    return ''.join(lines)
388
389
def walk(node):
    """
    Yield *node* and all of its descendants, breadth first (callers should
    not depend on the order).  Handy for in-place modifications that do not
    need any context.
    """
    from collections import deque
    queue = deque()
    queue.append(node)
    while queue:
        current = queue.popleft()
        # Children are collected before *current* is handed to the caller.
        for child in iter_child_nodes(current):
            queue.append(child)
        yield current
402
403
class NodeVisitor(object):
    """
    Base class for read-only traversals of an abstract syntax tree.

    ``visit(node)`` dispatches to the method named ``'visit_'`` plus the
    node's class name (a `Name` node goes to ``visit_Name``).  When no such
    method exists, `generic_visit` is used, which visits every child node.
    Whatever the chosen method returns is returned by ``visit``.

    Subclass it and add the ``visit_*`` methods you need; override ``visit``
    to change the dispatch rule.  To modify the tree while traversing it,
    use `NodeTransformer` instead.
    """

    def visit(self, node):
        """Visit a node and return the result of its visitor method."""
        method_name = 'visit_' + node.__class__.__name__
        handler = getattr(self, method_name, self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Visit every AST child of *node*; used when no visit_* exists."""
        for _field, value in iter_fields(node):
            if isinstance(value, AST):
                self.visit(value)
            elif isinstance(value, list):
                for child in value:
                    if isinstance(child, AST):
                        self.visit(child)

    def visit_Constant(self, node):
        """Visit a Constant, honouring legacy visit_Num/visit_Str/... methods.

        If the subclass defines the legacy visitor that matches the type of
        ``node.value``, it is called after a DeprecationWarning; otherwise
        the node is visited generically.
        """
        value = node.value
        legacy_name = _const_node_type_names.get(type(value))
        if legacy_name is None:
            # Also map subclasses of the listed types (first match wins).
            legacy_name = next(
                (name for base, name in _const_node_type_names.items()
                 if isinstance(value, base)),
                None,
            )
        if legacy_name is not None:
            method = 'visit_' + legacy_name
            try:
                legacy_visitor = getattr(self, method)
            except AttributeError:
                pass
            else:
                import warnings
                warnings.warn(f"{method} is deprecated; add visit_Constant",
                              DeprecationWarning, 2)
                return legacy_visitor(node)
        return self.generic_visit(node)
460
461
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    allows modification of nodes.

    The value returned by each visitor method decides what happens to the
    visited node: ``None`` removes it from its location, any other value
    replaces it, and returning the original node leaves it in place.

    Here is an example transformer that rewrites all occurrences of name lookups
    (``foo``) to ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Constant(value=node.id),
                   ctx=node.ctx
               )

    A visitor method for a node with children must either transform those
    children itself or call :meth:`generic_visit` on the node first.

    For nodes that live in a list (all statements do), the visitor may also
    return a list of nodes; its items are spliced in place of the original.

    Typical usage::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        for field, old_value in iter_fields(node):
            if isinstance(old_value, AST):
                replacement = self.visit(old_value)
                if replacement is None:
                    delattr(node, field)
                else:
                    setattr(node, field, replacement)
                continue
            if not isinstance(old_value, list):
                continue
            kept = []
            for item in old_value:
                if not isinstance(item, AST):
                    kept.append(item)
                    continue
                result = self.visit(item)
                if result is None:
                    continue
                if isinstance(result, AST):
                    kept.append(result)
                else:
                    # A non-node result is an iterable of replacement nodes.
                    kept.extend(result)
            # Update the list in place so other references see the change.
            old_value[:] = kept
        return node
519
520
# Message templates passed to warnings._deprecated() below.  The {name} and
# {remove} placeholders are presumably filled in by that helper (it lives
# outside this module; verify there).
_DEPRECATED_VALUE_ALIAS_MESSAGE = (
    "{name} is deprecated and will be removed in Python {remove}; "
    "use value instead"
)
_DEPRECATED_CLASS_MESSAGE = (
    "{name} is deprecated and will be removed in Python {remove}; "
    "use ast.Constant instead"
)
528
529
# If the ast module is loaded more than once, only add deprecated methods once
if not hasattr(Constant, 'n'):
    # The following code is for backward compatibility.
    # It will be removed in future.
    #
    # Installs Constant.n and Constant.s as read/write properties that alias
    # Constant.value.  Every get and set first calls warnings._deprecated
    # with _DEPRECATED_VALUE_ALIAS_MESSAGE and remove=(3, 14).
    # NOTE(review): warnings._deprecated is a private helper; presumably it
    # emits a DeprecationWarning and refuses to run once the interpreter
    # reaches the `remove` version -- confirm in the warnings module.  The
    # call is repeated in each accessor rather than factored out, presumably
    # so the warning's stack level points at the accessing code.

    def _n_getter(self):
        """Deprecated. Use value instead."""
        import warnings
        warnings._deprecated(
            "Attribute n", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14)
        )
        return self.value

    def _n_setter(self, value):
        # Writes through to Constant.value.
        import warnings
        warnings._deprecated(
            "Attribute n", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14)
        )
        self.value = value

    def _s_getter(self):
        """Deprecated. Use value instead."""
        import warnings
        warnings._deprecated(
            "Attribute s", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14)
        )
        return self.value

    def _s_setter(self, value):
        # Writes through to Constant.value.
        import warnings
        warnings._deprecated(
            "Attribute s", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14)
        )
        self.value = value

    # property() takes its docstring from the getter.
    Constant.n = property(_n_getter, _n_setter)
    Constant.s = property(_s_getter, _s_setter)
567
class _ABC(type):
    # Metaclass of the deprecated Num/Str/Bytes/NameConstant/Ellipsis
    # classes.  For those classes, isinstance() is answered by looking at
    # the type of Constant.value (see _const_types / _const_types_not).

    def __init__(cls, *args):
        # Every class built with this metaclass gets the same docstring.
        cls.__doc__ = """Deprecated AST node class. Use ast.Constant instead"""

    def __instancecheck__(cls, inst):
        is_legacy = cls in _const_types
        if is_legacy:
            import warnings
            warnings._deprecated(
                f"ast.{cls.__qualname__}",
                message=_DEPRECATED_CLASS_MESSAGE,
                remove=(3, 14),
            )
        if not isinstance(inst, Constant):
            return False
        if not is_legacy:
            # Classes not listed in _const_types: ordinary isinstance check.
            return type.__instancecheck__(cls, inst)
        try:
            value = inst.value
        except AttributeError:
            return False
        accepted = _const_types[cls]
        rejected = _const_types_not.get(cls, ())
        return isinstance(value, accepted) and not isinstance(value, rejected)
594
def _new(cls, *args, **kwargs):
    # Shared __new__ of the deprecated Num/Str/Bytes/NameConstant classes.
    # Reject a keyword that names a field already supplied positionally;
    # keywords that are not fields are accepted and passed through.
    for key in kwargs:
        if key in cls._fields and cls._fields.index(key) < len(args):
            raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}")
    if cls not in _const_types:
        # Not one of the deprecated classes themselves (i.e. a subclass):
        # create an instance of *cls* as usual.
        return Constant.__new__(cls, *args, **kwargs)
    import warnings
    warnings._deprecated(
        f"ast.{cls.__qualname__}", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14)
    )
    # The deprecated classes themselves produce plain Constant nodes.
    return Constant(*args, **kwargs)
610
# Deprecated stand-ins for Constant.  Their metaclass _ABC makes isinstance()
# checks inspect the type of Constant.value, and _new makes instantiating
# them directly return a plain Constant.  Docstrings would be overwritten by
# _ABC.__init__, hence comments only.
class Num(Constant, metaclass=_ABC):
    _fields = ('n',)
    __new__ = _new

class Str(Constant, metaclass=_ABC):
    _fields = ('s',)
    __new__ = _new

class Bytes(Constant, metaclass=_ABC):
    _fields = ('s',)
    __new__ = _new

class NameConstant(Constant, metaclass=_ABC):
    # Keeps Constant's own _fields.
    __new__ = _new
625
class Ellipsis(Constant, metaclass=_ABC):
    _fields = ()

    def __new__(cls, *args, **kwargs):
        if cls is not _ast_Ellipsis:
            # A subclass: create it as an ordinary Constant subclass.
            return Constant.__new__(cls, *args, **kwargs)
        import warnings
        warnings._deprecated(
            "ast.Ellipsis", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14)
        )
        # ast.Ellipsis(...) itself builds a Constant wrapping the ``...`` literal.
        return Constant(..., *args, **kwargs)

# This class shadows the builtin name ``Ellipsis`` here, and that name is
# removed from the module namespace later on, so keep a private reference
# for Ellipsis.__new__ to compare against.
_ast_Ellipsis = Ellipsis
642
# Value types each deprecated class accepts in _ABC.__instancecheck__.
# ``type(...)`` is used because the name Ellipsis is shadowed by the class
# defined above.
_const_types = {
    Num: (int, float, complex),
    Str: (str,),
    Bytes: (bytes,),
    NameConstant: (type(None), bool),
    Ellipsis: (type(...),),
}
# Exclusions applied after _const_types: bool is a subclass of int but must
# not count as a Num.
_const_types_not = {
    Num: (bool,),
}

# Maps a constant's value type to the legacy visitor suffix used by
# NodeVisitor.visit_Constant (e.g. 'Num' -> visit_Num).  Insertion order
# matters: visit_Constant falls back to an isinstance() scan in this order.
_const_node_type_names = {
    bool: 'NameConstant',  # should be before int
    type(None): 'NameConstant',
    int: 'Num',
    float: 'Num',
    complex: 'Num',
    str: 'Str',
    bytes: 'Bytes',
    type(...): 'Ellipsis',
}
664
class slice(AST):
    """Deprecated AST node class; base of Index and ExtSlice."""

class Index(slice):
    """Deprecated AST node class. Use the index value directly instead."""
    def __new__(cls, value, **kwargs):
        # Index(x) evaluates to x itself; extra keywords are ignored.
        return value

class ExtSlice(slice):
    """Deprecated AST node class. Use ast.Tuple instead."""
    def __new__(cls, dims=(), **kwargs):
        # ExtSlice(dims) builds a Load-context Tuple of those elements;
        # extra keywords (e.g. positions) are forwarded to Tuple.
        elements = list(dims)
        return Tuple(elements, Load(), **kwargs)
677
# Install the deprecated alias only once, even if this module is loaded
# more than once.
if not hasattr(Tuple, 'dims'):
    # Backward compatibility: Tuple.dims is a read/write alias of
    # Tuple.elts.  It will be removed in future.

    def _dims_getter(self):
        """Deprecated. Use elts instead."""
        return self.elts

    def _dims_setter(self, value):
        self.elts = value

    Tuple.dims = property(fget=_dims_getter, fset=_dims_setter)
691
# Legacy node classes retained for backward compatibility only; see each
# docstring.
class Suite(mod):
    """Deprecated AST node class.  Unused in Python 3."""

class AugLoad(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""

class AugStore(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""

class Param(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""
703
704
# Float and imaginary literals too large for a float become infinities in
# the AST; the unparser writes those infinities back as _INFSTR.  Its
# exponent is one past the largest finite float exponent (so "1e309" for
# IEEE doubles), which parses back to infinity.
_INFSTR = f"1e{sys.float_info.max_10_exp + 1!r}"
708
@_simple_enum(IntEnum)
class _Precedence:
    """Precedence table that originated from python grammar."""

    # Members run from loosest to tightest binding; auto() assigns
    # increasing integers, so a larger value binds more tightly.

    NAMED_EXPR = auto()      # <target> := <expr1>
    TUPLE = auto()           # <expr1>, <expr2>
    YIELD = auto()           # 'yield', 'yield from'
    TEST = auto()            # 'if'-'else', 'lambda'
    OR = auto()              # 'or'
    AND = auto()             # 'and'
    NOT = auto()             # 'not'
    CMP = auto()             # '<', '>', '==', '>=', '<=', '!=',
                             # 'in', 'not in', 'is', 'is not'
    EXPR = auto()
    BOR = EXPR               # '|'  (alias: same level as EXPR)
    BXOR = auto()            # '^'
    BAND = auto()            # '&'
    SHIFT = auto()           # '<<', '>>'
    ARITH = auto()           # '+', '-'
    TERM = auto()            # '*', '@', '/', '%', '//'
    FACTOR = auto()          # unary '+', '-', '~'
    POWER = auto()           # '**'
    AWAIT = auto()           # 'await'
    ATOM = auto()

    def next(self):
        """Return the member one level tighter, or self if there is none."""
        try:
            return self.__class__(self + 1)
        except ValueError:
            # No member has value self + 1 (i.e. self is ATOM).
            return self
739
740
# String-literal delimiters: single-line quotes, then triple quotes.
_SINGLE_QUOTES = ("'", '"')
_MULTI_QUOTES = ('"""', "'''")
_ALL_QUOTES = _SINGLE_QUOTES + _MULTI_QUOTES
744
745class _Unparser(NodeVisitor):
746    """Methods in this class recursively traverse an AST and
747    output source code for the abstract syntax; original formatting
748    is disregarded."""
749
750    def __init__(self):
751        self._source = []
752        self._precedences = {}
753        self._type_ignores = {}
754        self._indent = 0
755        self._in_try_star = False
756
757    def interleave(self, inter, f, seq):
758        """Call f on each item in seq, calling inter() in between."""
759        seq = iter(seq)
760        try:
761            f(next(seq))
762        except StopIteration:
763            pass
764        else:
765            for x in seq:
766                inter()
767                f(x)
768
769    def items_view(self, traverser, items):
770        """Traverse and separate the given *items* with a comma and append it to
771        the buffer. If *items* is a single item sequence, a trailing comma
772        will be added."""
773        if len(items) == 1:
774            traverser(items[0])
775            self.write(",")
776        else:
777            self.interleave(lambda: self.write(", "), traverser, items)
778
779    def maybe_newline(self):
780        """Adds a newline if it isn't the start of generated source"""
781        if self._source:
782            self.write("\n")
783
784    def fill(self, text=""):
785        """Indent a piece of text and append it, according to the current
786        indentation level"""
787        self.maybe_newline()
788        self.write("    " * self._indent + text)
789
790    def write(self, *text):
791        """Add new source parts"""
792        self._source.extend(text)
793
794    @contextmanager
795    def buffered(self, buffer = None):
796        if buffer is None:
797            buffer = []
798
799        original_source = self._source
800        self._source = buffer
801        yield buffer
802        self._source = original_source
803
804    @contextmanager
805    def block(self, *, extra = None):
806        """A context manager for preparing the source for blocks. It adds
807        the character':', increases the indentation on enter and decreases
808        the indentation on exit. If *extra* is given, it will be directly
809        appended after the colon character.
810        """
811        self.write(":")
812        if extra:
813            self.write(extra)
814        self._indent += 1
815        yield
816        self._indent -= 1
817
    @contextmanager
    def delimit(self, start, end):
        """A context manager for preparing the source for expressions. It adds
        *start* to the buffer and enters, after exit it adds *end*.

        *end* is written only when the body completes normally; if the body
        raises, the closing delimiter is not added."""

        self.write(start)
        yield
        self.write(end)
826
827    def delimit_if(self, start, end, condition):
828        if condition:
829            return self.delimit(start, end)
830        else:
831            return nullcontext()
832
833    def require_parens(self, precedence, node):
834        """Shortcut to adding precedence related parens"""
835        return self.delimit_if("(", ")", self.get_precedence(node) > precedence)
836
837    def get_precedence(self, node):
838        return self._precedences.get(node, _Precedence.TEST)
839
840    def set_precedence(self, precedence, *nodes):
841        for node in nodes:
842            self._precedences[node] = precedence
843
844    def get_raw_docstring(self, node):
845        """If a docstring node is found in the body of the *node* parameter,
846        return that docstring node, None otherwise.
847
848        Logic mirrored from ``_PyAST_GetDocString``."""
849        if not isinstance(
850            node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
851        ) or len(node.body) < 1:
852            return None
853        node = node.body[0]
854        if not isinstance(node, Expr):
855            return None
856        node = node.value
857        if isinstance(node, Constant) and isinstance(node.value, str):
858            return node
859
860    def get_type_comment(self, node):
861        comment = self._type_ignores.get(node.lineno) or node.type_comment
862        if comment is not None:
863            return f" # type: {comment}"
864
865    def traverse(self, node):
866        if isinstance(node, list):
867            for item in node:
868                self.traverse(item)
869        else:
870            super().visit(node)
871
872    # Note: as visit() resets the output text, do NOT rely on
873    # NodeVisitor.generic_visit to handle any nodes (as it calls back in to
874    # the subclass visit() method, which resets self._source to an empty list)
875    def visit(self, node):
876        """Outputs a source code string that, if converted back to an ast
877        (using ast.parse) will generate an AST equivalent to *node*"""
878        self._source = []
879        self.traverse(node)
880        return "".join(self._source)
881
882    def _write_docstring_and_traverse_body(self, node):
883        if (docstring := self.get_raw_docstring(node)):
884            self._write_docstring(docstring)
885            self.traverse(node.body[1:])
886        else:
887            self.traverse(node.body)
888
889    def visit_Module(self, node):
890        self._type_ignores = {
891            ignore.lineno: f"ignore{ignore.tag}"
892            for ignore in node.type_ignores
893        }
894        self._write_docstring_and_traverse_body(node)
895        self._type_ignores.clear()
896
897    def visit_FunctionType(self, node):
898        with self.delimit("(", ")"):
899            self.interleave(
900                lambda: self.write(", "), self.traverse, node.argtypes
901            )
902
903        self.write(" -> ")
904        self.traverse(node.returns)
905
906    def visit_Expr(self, node):
907        self.fill()
908        self.set_precedence(_Precedence.YIELD, node.value)
909        self.traverse(node.value)
910
911    def visit_NamedExpr(self, node):
912        with self.require_parens(_Precedence.NAMED_EXPR, node):
913            self.set_precedence(_Precedence.ATOM, node.target, node.value)
914            self.traverse(node.target)
915            self.write(" := ")
916            self.traverse(node.value)
917
918    def visit_Import(self, node):
919        self.fill("import ")
920        self.interleave(lambda: self.write(", "), self.traverse, node.names)
921
922    def visit_ImportFrom(self, node):
923        self.fill("from ")
924        self.write("." * (node.level or 0))
925        if node.module:
926            self.write(node.module)
927        self.write(" import ")
928        self.interleave(lambda: self.write(", "), self.traverse, node.names)
929
930    def visit_Assign(self, node):
931        self.fill()
932        for target in node.targets:
933            self.set_precedence(_Precedence.TUPLE, target)
934            self.traverse(target)
935            self.write(" = ")
936        self.traverse(node.value)
937        if type_comment := self.get_type_comment(node):
938            self.write(type_comment)
939
940    def visit_AugAssign(self, node):
941        self.fill()
942        self.traverse(node.target)
943        self.write(" " + self.binop[node.op.__class__.__name__] + "= ")
944        self.traverse(node.value)
945
946    def visit_AnnAssign(self, node):
947        self.fill()
948        with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)):
949            self.traverse(node.target)
950        self.write(": ")
951        self.traverse(node.annotation)
952        if node.value:
953            self.write(" = ")
954            self.traverse(node.value)
955
956    def visit_Return(self, node):
957        self.fill("return")
958        if node.value:
959            self.write(" ")
960            self.traverse(node.value)
961
962    def visit_Pass(self, node):
963        self.fill("pass")
964
965    def visit_Break(self, node):
966        self.fill("break")
967
968    def visit_Continue(self, node):
969        self.fill("continue")
970
971    def visit_Delete(self, node):
972        self.fill("del ")
973        self.interleave(lambda: self.write(", "), self.traverse, node.targets)
974
975    def visit_Assert(self, node):
976        self.fill("assert ")
977        self.traverse(node.test)
978        if node.msg:
979            self.write(", ")
980            self.traverse(node.msg)
981
982    def visit_Global(self, node):
983        self.fill("global ")
984        self.interleave(lambda: self.write(", "), self.write, node.names)
985
986    def visit_Nonlocal(self, node):
987        self.fill("nonlocal ")
988        self.interleave(lambda: self.write(", "), self.write, node.names)
989
990    def visit_Await(self, node):
991        with self.require_parens(_Precedence.AWAIT, node):
992            self.write("await")
993            if node.value:
994                self.write(" ")
995                self.set_precedence(_Precedence.ATOM, node.value)
996                self.traverse(node.value)
997
998    def visit_Yield(self, node):
999        with self.require_parens(_Precedence.YIELD, node):
1000            self.write("yield")
1001            if node.value:
1002                self.write(" ")
1003                self.set_precedence(_Precedence.ATOM, node.value)
1004                self.traverse(node.value)
1005
1006    def visit_YieldFrom(self, node):
1007        with self.require_parens(_Precedence.YIELD, node):
1008            self.write("yield from ")
1009            if not node.value:
1010                raise ValueError("Node can't be used without a value attribute.")
1011            self.set_precedence(_Precedence.ATOM, node.value)
1012            self.traverse(node.value)
1013
1014    def visit_Raise(self, node):
1015        self.fill("raise")
1016        if not node.exc:
1017            if node.cause:
1018                raise ValueError(f"Node can't use cause without an exception.")
1019            return
1020        self.write(" ")
1021        self.traverse(node.exc)
1022        if node.cause:
1023            self.write(" from ")
1024            self.traverse(node.cause)
1025
1026    def do_visit_try(self, node):
1027        self.fill("try")
1028        with self.block():
1029            self.traverse(node.body)
1030        for ex in node.handlers:
1031            self.traverse(ex)
1032        if node.orelse:
1033            self.fill("else")
1034            with self.block():
1035                self.traverse(node.orelse)
1036        if node.finalbody:
1037            self.fill("finally")
1038            with self.block():
1039                self.traverse(node.finalbody)
1040
1041    def visit_Try(self, node):
1042        prev_in_try_star = self._in_try_star
1043        try:
1044            self._in_try_star = False
1045            self.do_visit_try(node)
1046        finally:
1047            self._in_try_star = prev_in_try_star
1048
1049    def visit_TryStar(self, node):
1050        prev_in_try_star = self._in_try_star
1051        try:
1052            self._in_try_star = True
1053            self.do_visit_try(node)
1054        finally:
1055            self._in_try_star = prev_in_try_star
1056
1057    def visit_ExceptHandler(self, node):
1058        self.fill("except*" if self._in_try_star else "except")
1059        if node.type:
1060            self.write(" ")
1061            self.traverse(node.type)
1062        if node.name:
1063            self.write(" as ")
1064            self.write(node.name)
1065        with self.block():
1066            self.traverse(node.body)
1067
1068    def visit_ClassDef(self, node):
1069        self.maybe_newline()
1070        for deco in node.decorator_list:
1071            self.fill("@")
1072            self.traverse(deco)
1073        self.fill("class " + node.name)
1074        if hasattr(node, "type_params"):
1075            self._type_params_helper(node.type_params)
1076        with self.delimit_if("(", ")", condition = node.bases or node.keywords):
1077            comma = False
1078            for e in node.bases:
1079                if comma:
1080                    self.write(", ")
1081                else:
1082                    comma = True
1083                self.traverse(e)
1084            for e in node.keywords:
1085                if comma:
1086                    self.write(", ")
1087                else:
1088                    comma = True
1089                self.traverse(e)
1090
1091        with self.block():
1092            self._write_docstring_and_traverse_body(node)
1093
1094    def visit_FunctionDef(self, node):
1095        self._function_helper(node, "def")
1096
1097    def visit_AsyncFunctionDef(self, node):
1098        self._function_helper(node, "async def")
1099
1100    def _function_helper(self, node, fill_suffix):
1101        self.maybe_newline()
1102        for deco in node.decorator_list:
1103            self.fill("@")
1104            self.traverse(deco)
1105        def_str = fill_suffix + " " + node.name
1106        self.fill(def_str)
1107        if hasattr(node, "type_params"):
1108            self._type_params_helper(node.type_params)
1109        with self.delimit("(", ")"):
1110            self.traverse(node.args)
1111        if node.returns:
1112            self.write(" -> ")
1113            self.traverse(node.returns)
1114        with self.block(extra=self.get_type_comment(node)):
1115            self._write_docstring_and_traverse_body(node)
1116
1117    def _type_params_helper(self, type_params):
1118        if type_params is not None and len(type_params) > 0:
1119            with self.delimit("[", "]"):
1120                self.interleave(lambda: self.write(", "), self.traverse, type_params)
1121
1122    def visit_TypeVar(self, node):
1123        self.write(node.name)
1124        if node.bound:
1125            self.write(": ")
1126            self.traverse(node.bound)
1127        if node.default_value:
1128            self.write(" = ")
1129            self.traverse(node.default_value)
1130
1131    def visit_TypeVarTuple(self, node):
1132        self.write("*" + node.name)
1133        if node.default_value:
1134            self.write(" = ")
1135            self.traverse(node.default_value)
1136
1137    def visit_ParamSpec(self, node):
1138        self.write("**" + node.name)
1139        if node.default_value:
1140            self.write(" = ")
1141            self.traverse(node.default_value)
1142
1143    def visit_TypeAlias(self, node):
1144        self.fill("type ")
1145        self.traverse(node.name)
1146        self._type_params_helper(node.type_params)
1147        self.write(" = ")
1148        self.traverse(node.value)
1149
1150    def visit_For(self, node):
1151        self._for_helper("for ", node)
1152
1153    def visit_AsyncFor(self, node):
1154        self._for_helper("async for ", node)
1155
1156    def _for_helper(self, fill, node):
1157        self.fill(fill)
1158        self.set_precedence(_Precedence.TUPLE, node.target)
1159        self.traverse(node.target)
1160        self.write(" in ")
1161        self.traverse(node.iter)
1162        with self.block(extra=self.get_type_comment(node)):
1163            self.traverse(node.body)
1164        if node.orelse:
1165            self.fill("else")
1166            with self.block():
1167                self.traverse(node.orelse)
1168
1169    def visit_If(self, node):
1170        self.fill("if ")
1171        self.traverse(node.test)
1172        with self.block():
1173            self.traverse(node.body)
1174        # collapse nested ifs into equivalent elifs.
1175        while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
1176            node = node.orelse[0]
1177            self.fill("elif ")
1178            self.traverse(node.test)
1179            with self.block():
1180                self.traverse(node.body)
1181        # final else
1182        if node.orelse:
1183            self.fill("else")
1184            with self.block():
1185                self.traverse(node.orelse)
1186
1187    def visit_While(self, node):
1188        self.fill("while ")
1189        self.traverse(node.test)
1190        with self.block():
1191            self.traverse(node.body)
1192        if node.orelse:
1193            self.fill("else")
1194            with self.block():
1195                self.traverse(node.orelse)
1196
1197    def visit_With(self, node):
1198        self.fill("with ")
1199        self.interleave(lambda: self.write(", "), self.traverse, node.items)
1200        with self.block(extra=self.get_type_comment(node)):
1201            self.traverse(node.body)
1202
1203    def visit_AsyncWith(self, node):
1204        self.fill("async with ")
1205        self.interleave(lambda: self.write(", "), self.traverse, node.items)
1206        with self.block(extra=self.get_type_comment(node)):
1207            self.traverse(node.body)
1208
    def _str_literal_helper(
        self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False
    ):
        """Helper for writing string literals, minimizing escapes.

        Returns the tuple (string literal body to write, possible quote
        types). The body excludes the surrounding quotes; callers in this
        class write it between the first entry of the returned quote list.

        *quote_types* are the candidate delimiters. When
        *escape_special_whitespace* is false, newlines and tabs are kept
        literally, and a literal newline restricts the candidates to
        ``_MULTI_QUOTES``. If no candidate can delimit the body, the result
        falls back to the body of ``repr(string)`` with a single quote type.
        """
        def escape_char(c):
            # \n and \t are non-printable, but we only escape them if
            # escape_special_whitespace is True
            if not escape_special_whitespace and c in "\n\t":
                return c
            # Always escape backslashes and other non-printable characters
            if c == "\\" or not c.isprintable():
                return c.encode("unicode_escape").decode("ascii")
            return c

        escaped_string = "".join(map(escape_char, string))
        possible_quotes = quote_types
        if "\n" in escaped_string:
            # A raw newline is only legal inside a triple-quoted literal.
            possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES]
        # Drop any delimiter that occurs inside the body: it would end the
        # literal early.
        possible_quotes = [q for q in possible_quotes if q not in escaped_string]
        if not possible_quotes:
            # If there aren't any possible_quotes, fallback to using repr
            # on the original string. Try to use a quote from quote_types,
            # e.g., so that we use triple quotes for docstrings.
            string = repr(string)
            # string[0] is the quote character repr() chose; it is used as-is
            # when no candidate contains it.
            quote = next((q for q in quote_types if string[0] in q), string[0])
            return string[1:-1], [quote]
        if escaped_string:
            # Sort so that we prefer '''"''' over """\""""
            possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1])
            # If we're using triple quotes and we'd need to escape a final
            # quote, escape it
            if possible_quotes[0][0] == escaped_string[-1]:
                assert len(possible_quotes[0]) == 3
                escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1]
        return escaped_string, possible_quotes
1246
1247    def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
1248        """Write string literal value with a best effort attempt to avoid backslashes."""
1249        string, quote_types = self._str_literal_helper(string, quote_types=quote_types)
1250        quote_type = quote_types[0]
1251        self.write(f"{quote_type}{string}{quote_type}")
1252
    def visit_JoinedStr(self, node):
        """Emit an f-string literal for a ``JoinedStr`` node.

        Each part is first rendered into its own buffer. Constant parts are
        then re-escaped with ``_str_literal_helper`` while narrowing the set
        of quote types down to one that works for every part. If no common
        quote type exists, constant parts are re-rendered from ``repr`` and
        the whole f-string is delimited with ``'''``.
        """
        self.write("f")

        # (rendered text, is_constant) for every part, in order
        fstring_parts = []
        for value in node.values:
            with self.buffered() as buffer:
                self._write_fstring_inner(value)
            fstring_parts.append(
                ("".join(buffer), isinstance(value, Constant))
            )

        new_fstring_parts = []
        quote_types = list(_ALL_QUOTES)
        fallback_to_repr = False
        for value, is_constant in fstring_parts:
            if is_constant:
                value, new_quote_types = self._str_literal_helper(
                    value,
                    quote_types=quote_types,
                    escape_special_whitespace=True,
                )
                # Disjoint only when the helper fell back to repr() with a
                # quote character that is not among the remaining candidates.
                if set(new_quote_types).isdisjoint(quote_types):
                    fallback_to_repr = True
                    break
                quote_types = new_quote_types
            elif "\n" in value:
                # A replacement field spanning lines needs a triple-quoted
                # delimiter.
                quote_types = [q for q in quote_types if q in _MULTI_QUOTES]
                assert quote_types
            new_fstring_parts.append(value)

        if fallback_to_repr:
            # If we weren't able to find a quote type that works for all parts
            # of the JoinedStr, fallback to using repr and triple single quotes.
            quote_types = ["'''"]
            new_fstring_parts.clear()
            for value, is_constant in fstring_parts:
                if is_constant:
                    value = repr('"' + value)  # force repr to use single quotes
                    expected_prefix = "'\""
                    assert value.startswith(expected_prefix), repr(value)
                    # strip the added '"' plus repr's opening and closing quotes
                    value = value[len(expected_prefix):-1]
                new_fstring_parts.append(value)

        value = "".join(new_fstring_parts)
        quote_type = quote_types[0]
        self.write(f"{quote_type}{value}{quote_type}")
1299
1300    def _write_fstring_inner(self, node, is_format_spec=False):
1301        if isinstance(node, JoinedStr):
1302            # for both the f-string itself, and format_spec
1303            for value in node.values:
1304                self._write_fstring_inner(value, is_format_spec=is_format_spec)
1305        elif isinstance(node, Constant) and isinstance(node.value, str):
1306            value = node.value.replace("{", "{{").replace("}", "}}")
1307
1308            if is_format_spec:
1309                value = value.replace("\\", "\\\\")
1310                value = value.replace("'", "\\'")
1311                value = value.replace('"', '\\"')
1312                value = value.replace("\n", "\\n")
1313            self.write(value)
1314        elif isinstance(node, FormattedValue):
1315            self.visit_FormattedValue(node)
1316        else:
1317            raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
1318
1319    def visit_FormattedValue(self, node):
1320        def unparse_inner(inner):
1321            unparser = type(self)()
1322            unparser.set_precedence(_Precedence.TEST.next(), inner)
1323            return unparser.visit(inner)
1324
1325        with self.delimit("{", "}"):
1326            expr = unparse_inner(node.value)
1327            if expr.startswith("{"):
1328                # Separate pair of opening brackets as "{ {"
1329                self.write(" ")
1330            self.write(expr)
1331            if node.conversion != -1:
1332                self.write(f"!{chr(node.conversion)}")
1333            if node.format_spec:
1334                self.write(":")
1335                self._write_fstring_inner(node.format_spec, is_format_spec=True)
1336
1337    def visit_Name(self, node):
1338        self.write(node.id)
1339
1340    def _write_docstring(self, node):
1341        self.fill()
1342        if node.kind == "u":
1343            self.write("u")
1344        self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
1345
1346    def _write_constant(self, value):
1347        if isinstance(value, (float, complex)):
1348            # Substitute overflowing decimal literal for AST infinities,
1349            # and inf - inf for NaNs.
1350            self.write(
1351                repr(value)
1352                .replace("inf", _INFSTR)
1353                .replace("nan", f"({_INFSTR}-{_INFSTR})")
1354            )
1355        else:
1356            self.write(repr(value))
1357
1358    def visit_Constant(self, node):
1359        value = node.value
1360        if isinstance(value, tuple):
1361            with self.delimit("(", ")"):
1362                self.items_view(self._write_constant, value)
1363        elif value is ...:
1364            self.write("...")
1365        else:
1366            if node.kind == "u":
1367                self.write("u")
1368            self._write_constant(node.value)
1369
1370    def visit_List(self, node):
1371        with self.delimit("[", "]"):
1372            self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1373
1374    def visit_ListComp(self, node):
1375        with self.delimit("[", "]"):
1376            self.traverse(node.elt)
1377            for gen in node.generators:
1378                self.traverse(gen)
1379
1380    def visit_GeneratorExp(self, node):
1381        with self.delimit("(", ")"):
1382            self.traverse(node.elt)
1383            for gen in node.generators:
1384                self.traverse(gen)
1385
1386    def visit_SetComp(self, node):
1387        with self.delimit("{", "}"):
1388            self.traverse(node.elt)
1389            for gen in node.generators:
1390                self.traverse(gen)
1391
1392    def visit_DictComp(self, node):
1393        with self.delimit("{", "}"):
1394            self.traverse(node.key)
1395            self.write(": ")
1396            self.traverse(node.value)
1397            for gen in node.generators:
1398                self.traverse(gen)
1399
1400    def visit_comprehension(self, node):
1401        if node.is_async:
1402            self.write(" async for ")
1403        else:
1404            self.write(" for ")
1405        self.set_precedence(_Precedence.TUPLE, node.target)
1406        self.traverse(node.target)
1407        self.write(" in ")
1408        self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
1409        self.traverse(node.iter)
1410        for if_clause in node.ifs:
1411            self.write(" if ")
1412            self.traverse(if_clause)
1413
1414    def visit_IfExp(self, node):
1415        with self.require_parens(_Precedence.TEST, node):
1416            self.set_precedence(_Precedence.TEST.next(), node.body, node.test)
1417            self.traverse(node.body)
1418            self.write(" if ")
1419            self.traverse(node.test)
1420            self.write(" else ")
1421            self.set_precedence(_Precedence.TEST, node.orelse)
1422            self.traverse(node.orelse)
1423
1424    def visit_Set(self, node):
1425        if node.elts:
1426            with self.delimit("{", "}"):
1427                self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1428        else:
1429            # `{}` would be interpreted as a dictionary literal, and
1430            # `set` might be shadowed. Thus:
1431            self.write('{*()}')
1432
1433    def visit_Dict(self, node):
1434        def write_key_value_pair(k, v):
1435            self.traverse(k)
1436            self.write(": ")
1437            self.traverse(v)
1438
1439        def write_item(item):
1440            k, v = item
1441            if k is None:
1442                # for dictionary unpacking operator in dicts {**{'y': 2}}
1443                # see PEP 448 for details
1444                self.write("**")
1445                self.set_precedence(_Precedence.EXPR, v)
1446                self.traverse(v)
1447            else:
1448                write_key_value_pair(k, v)
1449
1450        with self.delimit("{", "}"):
1451            self.interleave(
1452                lambda: self.write(", "), write_item, zip(node.keys, node.values)
1453            )
1454
1455    def visit_Tuple(self, node):
1456        with self.delimit_if(
1457            "(",
1458            ")",
1459            len(node.elts) == 0 or self.get_precedence(node) > _Precedence.TUPLE
1460        ):
1461            self.items_view(self.traverse, node.elts)
1462
    # AST unary-operator class name -> source token.
    unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
    # Source token -> precedence. visit_UnaryOp uses it both to decide
    # whether the UnaryOp itself needs parentheses and as the context
    # precedence of its operand.
    unop_precedence = {
        "not": _Precedence.NOT,
        "~": _Precedence.FACTOR,
        "+": _Precedence.FACTOR,
        "-": _Precedence.FACTOR,
    }
1470
1471    def visit_UnaryOp(self, node):
1472        operator = self.unop[node.op.__class__.__name__]
1473        operator_precedence = self.unop_precedence[operator]
1474        with self.require_parens(operator_precedence, node):
1475            self.write(operator)
1476            # factor prefixes (+, -, ~) shouldn't be separated
1477            # from the value they belong, (e.g: +1 instead of + 1)
1478            if operator_precedence is not _Precedence.FACTOR:
1479                self.write(" ")
1480            self.set_precedence(operator_precedence, node.operand)
1481            self.traverse(node.operand)
1482
    # AST binary-operator class name -> source token. Also used by
    # visit_AugAssign, which appends "=" to form e.g. "+=".
    binop = {
        "Add": "+",
        "Sub": "-",
        "Mult": "*",
        "MatMult": "@",
        "Div": "/",
        "Mod": "%",
        "LShift": "<<",
        "RShift": ">>",
        "BitOr": "|",
        "BitXor": "^",
        "BitAnd": "&",
        "FloorDiv": "//",
        "Pow": "**",
    }
1498
    # Binary-operator source token -> precedence of the operation.
    binop_precedence = {
        "+": _Precedence.ARITH,
        "-": _Precedence.ARITH,
        "*": _Precedence.TERM,
        "@": _Precedence.TERM,
        "/": _Precedence.TERM,
        "%": _Precedence.TERM,
        "<<": _Precedence.SHIFT,
        ">>": _Precedence.SHIFT,
        "|": _Precedence.BOR,
        "^": _Precedence.BXOR,
        "&": _Precedence.BAND,
        "//": _Precedence.TERM,
        "**": _Precedence.POWER,
    }

    # Right-associative operators. For these, visit_BinOp gives the *left*
    # operand the stricter precedence, so ``a ** b ** c`` needs no parens.
    binop_rassoc = frozenset(("**",))
1516    def visit_BinOp(self, node):
1517        operator = self.binop[node.op.__class__.__name__]
1518        operator_precedence = self.binop_precedence[operator]
1519        with self.require_parens(operator_precedence, node):
1520            if operator in self.binop_rassoc:
1521                left_precedence = operator_precedence.next()
1522                right_precedence = operator_precedence
1523            else:
1524                left_precedence = operator_precedence
1525                right_precedence = operator_precedence.next()
1526
1527            self.set_precedence(left_precedence, node.left)
1528            self.traverse(node.left)
1529            self.write(f" {operator} ")
1530            self.set_precedence(right_precedence, node.right)
1531            self.traverse(node.right)
1532
    # AST comparison-operator class name -> source token.
    cmpops = {
        "Eq": "==",
        "NotEq": "!=",
        "Lt": "<",
        "LtE": "<=",
        "Gt": ">",
        "GtE": ">=",
        "Is": "is",
        "IsNot": "is not",
        "In": "in",
        "NotIn": "not in",
    }
1545
1546    def visit_Compare(self, node):
1547        with self.require_parens(_Precedence.CMP, node):
1548            self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators)
1549            self.traverse(node.left)
1550            for o, e in zip(node.ops, node.comparators):
1551                self.write(" " + self.cmpops[o.__class__.__name__] + " ")
1552                self.traverse(e)
1553
    # AST boolean-operator class name -> keyword, and keyword -> precedence.
    boolops = {"And": "and", "Or": "or"}
    boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR}
1556
    def visit_BoolOp(self, node):
        """Emit ``v1 and v2 and ...`` or ``v1 or v2 or ...`` for a BoolOp."""
        operator = self.boolops[node.op.__class__.__name__]
        operator_precedence = self.boolop_precedence[operator]

        def increasing_level_traverse(node):
            # Each operand gets a strictly higher context precedence than the
            # previous one, starting one level above the operator itself, so a
            # nested BoolOp of the same kind is parenthesized in any position.
            # NOTE(review): operands after the first end up several levels
            # above the operator. Assuming _Precedence orders OR < AND < NOT <
            # CMP, e.g. ``a or b or not c`` comes out as ``a or b or (not c)``.
            # The output stays equivalent; confirm the redundant parens are
            # acceptable.
            nonlocal operator_precedence
            operator_precedence = operator_precedence.next()
            self.set_precedence(operator_precedence, node)
            self.traverse(node)

        # require_parens receives the operator's own precedence: this argument
        # is evaluated before interleave runs the mutating helper above.
        with self.require_parens(operator_precedence, node):
            s = f" {operator} "
            self.interleave(lambda: self.write(s), increasing_level_traverse, node.values)
1570
1571    def visit_Attribute(self, node):
1572        self.set_precedence(_Precedence.ATOM, node.value)
1573        self.traverse(node.value)
1574        # Special case: 3.__abs__() is a syntax error, so if node.value
1575        # is an integer literal then we need to either parenthesize
1576        # it or add an extra space to get 3 .__abs__().
1577        if isinstance(node.value, Constant) and isinstance(node.value.value, int):
1578            self.write(" ")
1579        self.write(".")
1580        self.write(node.attr)
1581
1582    def visit_Call(self, node):
1583        self.set_precedence(_Precedence.ATOM, node.func)
1584        self.traverse(node.func)
1585        with self.delimit("(", ")"):
1586            comma = False
1587            for e in node.args:
1588                if comma:
1589                    self.write(", ")
1590                else:
1591                    comma = True
1592                self.traverse(e)
1593            for e in node.keywords:
1594                if comma:
1595                    self.write(", ")
1596                else:
1597                    comma = True
1598                self.traverse(e)
1599
1600    def visit_Subscript(self, node):
1601        def is_non_empty_tuple(slice_value):
1602            return (
1603                isinstance(slice_value, Tuple)
1604                and slice_value.elts
1605            )
1606
1607        self.set_precedence(_Precedence.ATOM, node.value)
1608        self.traverse(node.value)
1609        with self.delimit("[", "]"):
1610            if is_non_empty_tuple(node.slice):
1611                # parentheses can be omitted if the tuple isn't empty
1612                self.items_view(self.traverse, node.slice.elts)
1613            else:
1614                self.traverse(node.slice)
1615
1616    def visit_Starred(self, node):
1617        self.write("*")
1618        self.set_precedence(_Precedence.EXPR, node.value)
1619        self.traverse(node.value)
1620
1621    def visit_Ellipsis(self, node):
1622        self.write("...")
1623
1624    def visit_Slice(self, node):
1625        if node.lower:
1626            self.traverse(node.lower)
1627        self.write(":")
1628        if node.upper:
1629            self.traverse(node.upper)
1630        if node.step:
1631            self.write(":")
1632            self.traverse(node.step)
1633
1634    def visit_Match(self, node):
1635        self.fill("match ")
1636        self.traverse(node.subject)
1637        with self.block():
1638            for case in node.cases:
1639                self.traverse(case)
1640
1641    def visit_arg(self, node):
1642        self.write(node.arg)
1643        if node.annotation:
1644            self.write(": ")
1645            self.traverse(node.annotation)
1646
    def visit_arguments(self, node):
        """Emit a parameter list, without surrounding parentheses.

        Order: positional-only parameters followed by ``/``, regular
        parameters, ``*args`` (or a bare ``*`` when only keyword-only
        parameters follow), keyword-only parameters, then ``**kwargs``.
        """
        first = True  # no ", " separator before the first emitted item
        # normal arguments
        all_args = node.posonlyargs + node.args
        # node.defaults belong to the *last* positional parameters, so pad
        # the front with None to line them up with all_args.
        defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults
        for index, elements in enumerate(zip(all_args, defaults), 1):
            a, d = elements
            if first:
                first = False
            else:
                self.write(", ")
            self.traverse(a)
            if d:
                self.write("=")
                self.traverse(d)
            # "/" goes right after the last positional-only parameter; with
            # no positional-only parameters this never matches (index >= 1).
            if index == len(node.posonlyargs):
                self.write(", /")

        # varargs, or bare '*' if no varargs but keyword-only arguments present
        if node.vararg or node.kwonlyargs:
            if first:
                first = False
            else:
                self.write(", ")
            self.write("*")
            if node.vararg:
                self.write(node.vararg.arg)
                if node.vararg.annotation:
                    self.write(": ")
                    self.traverse(node.vararg.annotation)

        # keyword-only arguments
        if node.kwonlyargs:
            # Always preceded by the "*" / "*args" written above, so a
            # separator is always needed. kw_defaults holds None for
            # parameters without a default.
            for a, d in zip(node.kwonlyargs, node.kw_defaults):
                self.write(", ")
                self.traverse(a)
                if d:
                    self.write("=")
                    self.traverse(d)

        # kwargs
        if node.kwarg:
            if first:
                first = False
            else:
                self.write(", ")
            self.write("**" + node.kwarg.arg)
            if node.kwarg.annotation:
                self.write(": ")
                self.traverse(node.kwarg.annotation)
1697
1698    def visit_keyword(self, node):
1699        if node.arg is None:
1700            self.write("**")
1701        else:
1702            self.write(node.arg)
1703            self.write("=")
1704        self.traverse(node.value)
1705
1706    def visit_Lambda(self, node):
1707        with self.require_parens(_Precedence.TEST, node):
1708            self.write("lambda")
1709            with self.buffered() as buffer:
1710                self.traverse(node.args)
1711            if buffer:
1712                self.write(" ", *buffer)
1713            self.write(": ")
1714            self.set_precedence(_Precedence.TEST, node.body)
1715            self.traverse(node.body)
1716
1717    def visit_alias(self, node):
1718        self.write(node.name)
1719        if node.asname:
1720            self.write(" as " + node.asname)
1721
1722    def visit_withitem(self, node):
1723        self.traverse(node.context_expr)
1724        if node.optional_vars:
1725            self.write(" as ")
1726            self.traverse(node.optional_vars)
1727
1728    def visit_match_case(self, node):
1729        self.fill("case ")
1730        self.traverse(node.pattern)
1731        if node.guard:
1732            self.write(" if ")
1733            self.traverse(node.guard)
1734        with self.block():
1735            self.traverse(node.body)
1736
1737    def visit_MatchValue(self, node):
1738        self.traverse(node.value)
1739
1740    def visit_MatchSingleton(self, node):
1741        self._write_constant(node.value)
1742
1743    def visit_MatchSequence(self, node):
1744        with self.delimit("[", "]"):
1745            self.interleave(
1746                lambda: self.write(", "), self.traverse, node.patterns
1747            )
1748
1749    def visit_MatchStar(self, node):
1750        name = node.name
1751        if name is None:
1752            name = "_"
1753        self.write(f"*{name}")
1754
1755    def visit_MatchMapping(self, node):
1756        def write_key_pattern_pair(pair):
1757            k, p = pair
1758            self.traverse(k)
1759            self.write(": ")
1760            self.traverse(p)
1761
1762        with self.delimit("{", "}"):
1763            keys = node.keys
1764            self.interleave(
1765                lambda: self.write(", "),
1766                write_key_pattern_pair,
1767                zip(keys, node.patterns, strict=True),
1768            )
1769            rest = node.rest
1770            if rest is not None:
1771                if keys:
1772                    self.write(", ")
1773                self.write(f"**{rest}")
1774
1775    def visit_MatchClass(self, node):
1776        self.set_precedence(_Precedence.ATOM, node.cls)
1777        self.traverse(node.cls)
1778        with self.delimit("(", ")"):
1779            patterns = node.patterns
1780            self.interleave(
1781                lambda: self.write(", "), self.traverse, patterns
1782            )
1783            attrs = node.kwd_attrs
1784            if attrs:
1785                def write_attr_pattern(pair):
1786                    attr, pattern = pair
1787                    self.write(f"{attr}=")
1788                    self.traverse(pattern)
1789
1790                if patterns:
1791                    self.write(", ")
1792                self.interleave(
1793                    lambda: self.write(", "),
1794                    write_attr_pattern,
1795                    zip(attrs, node.kwd_patterns, strict=True),
1796                )
1797
1798    def visit_MatchAs(self, node):
1799        name = node.name
1800        pattern = node.pattern
1801        if name is None:
1802            self.write("_")
1803        elif pattern is None:
1804            self.write(node.name)
1805        else:
1806            with self.require_parens(_Precedence.TEST, node):
1807                self.set_precedence(_Precedence.BOR, node.pattern)
1808                self.traverse(node.pattern)
1809                self.write(f" as {node.name}")
1810
1811    def visit_MatchOr(self, node):
1812        with self.require_parens(_Precedence.BOR, node):
1813            self.set_precedence(_Precedence.BOR.next(), *node.patterns)
1814            self.interleave(lambda: self.write(" | "), self.traverse, node.patterns)
1815
1816def unparse(ast_obj):
1817    unparser = _Unparser()
1818    return unparser.visit(ast_obj)
1819
1820
# At import time, move the deprecated node classes out of the module
# namespace and keep them in this dict instead. Plain attribute lookup then
# misses them and falls through to the module-level __getattr__ below, which
# re-publishes the class and emits a deprecation warning.
# globals().pop(name) has no default, so each listed name must already be
# defined earlier in this module; a missing one raises KeyError on import.
_deprecated_globals = {
    name: globals().pop(name)
    for name in ('Num', 'Str', 'Bytes', 'NameConstant', 'Ellipsis')
}
1825
def __getattr__(name):
    """Module-level attribute hook (PEP 562) for deprecated node classes.

    Only names held in ``_deprecated_globals`` are served. On first access
    the class is stored back into the module globals, so later lookups find
    it directly and skip this hook. A deprecation warning is then issued via
    ``warnings._deprecated``. Any other name raises AttributeError.
    """
    if name not in _deprecated_globals:
        raise AttributeError(f"module 'ast' has no attribute '{name}'")
    value = _deprecated_globals[name]
    globals()[name] = value
    import warnings
    warnings._deprecated(
        f"ast.{name}", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14)
    )
    return value
1835
1836
def main(args=None):
    """Command-line entry point used by ``python -m ast``.

    Reads source from *infile* (``-``, the default, means stdin), parses it
    with ``parse()`` in the selected mode and prints ``dump()`` of the
    resulting tree.

    *args* is the argument list to parse. ``None`` (the default) makes
    argparse fall back to ``sys.argv[1:]``, exactly as before. An explicit
    list lets tests and other callers drive the CLI without touching
    ``sys.argv``.
    """
    import argparse

    parser = argparse.ArgumentParser(prog='python -m ast')
    parser.add_argument('infile', nargs='?', default='-',
                        help='the file to parse; defaults to stdin')
    parser.add_argument('-m', '--mode', default='exec',
                        choices=('exec', 'single', 'eval', 'func_type'),
                        help='specify what kind of code must be parsed')
    # The flag *disables* type comments. It is stored under a positive name
    # so the value reads naturally when passed as parse(type_comments=...).
    parser.add_argument('--no-type-comments', dest='type_comments',
                        default=True, action='store_false',
                        help="don't add information about type comments")
    parser.add_argument('-a', '--include-attributes', action='store_true',
                        help='include attributes such as line numbers and '
                             'column offsets')
    parser.add_argument('-i', '--indent', type=int, default=3,
                        help='indentation of nodes (number of spaces)')
    args = parser.parse_args(args)

    if args.infile == '-':
        name = '<stdin>'
        # Read raw bytes; parse() accepts bytes as well as str.
        source = sys.stdin.buffer.read()
    else:
        name = args.infile
        with open(args.infile, 'rb') as infile:
            source = infile.read()
    tree = parse(source, name, args.mode, type_comments=args.type_comments)
    print(dump(tree, include_attributes=args.include_attributes,
               indent=args.indent))
1864
# Entry point when the module is run as a script (``python -m ast``).
if __name__ == '__main__':
    main()
1867