• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* AST Optimizer */
2 #include "Python.h"
3 #include "pycore_ast.h"           // _PyAST_GetDocString()
4 #include "pycore_format.h"        // F_LJUST
5 #include "pycore_long.h"          // _PyLong
6 #include "pycore_pystate.h"       // _PyThreadState_GET()
7 #include "pycore_setobject.h"     // _PySet_NextEntry()
8 
9 
/* Per-compilation state passed to every astfold_* visitor. */
typedef struct {
    int optimize;         /* a load of __debug__ folds to the bool !optimize */
    int ff_features;      /* __future__ flags; CO_FUTURE_ANNOTATIONS is tested */

    /* Depth accounting for the recursive tree walk. */
    int recursion_depth;  /* nesting level of the visitor currently running */
    int recursion_limit;  /* exceeding this raises RecursionError */
} _PyASTOptimizeState;
17 
18 
19 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)20 make_const(expr_ty node, PyObject *val, PyArena *arena)
21 {
22     // Even if no new value was calculated, make_const may still
23     // need to clear an error (e.g. for division by zero)
24     if (val == NULL) {
25         if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
26             return 0;
27         }
28         PyErr_Clear();
29         return 1;
30     }
31     if (_PyArena_AddPyObject(arena, val) < 0) {
32         Py_DECREF(val);
33         return 0;
34     }
35     node->kind = Constant_kind;
36     node->v.Constant.kind = NULL;
37     node->v.Constant.value = val;
38     return 1;
39 }
40 
/* Overwrite the expression node TO in place with a bitwise copy of FROM.
   Every pointer that referenced TO then sees FROM's contents. */
#define COPY_NODE(TO, FROM) \
    (memcpy((TO), (FROM), sizeof(struct _expr)))
42 
43 static int
has_starred(asdl_expr_seq * elts)44 has_starred(asdl_expr_seq *elts)
45 {
46     Py_ssize_t n = asdl_seq_LEN(elts);
47     for (Py_ssize_t i = 0; i < n; i++) {
48         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
49         if (e->kind == Starred_kind) {
50             return 1;
51         }
52     }
53     return 0;
54 }
55 
56 
57 static PyObject*
unary_not(PyObject * v)58 unary_not(PyObject *v)
59 {
60     int r = PyObject_IsTrue(v);
61     if (r < 0)
62         return NULL;
63     return PyBool_FromLong(!r);
64 }
65 
66 static int
fold_unaryop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)67 fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
68 {
69     expr_ty arg = node->v.UnaryOp.operand;
70 
71     if (arg->kind != Constant_kind) {
72         /* Fold not into comparison */
73         if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
74                 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
75             /* Eq and NotEq are often implemented in terms of one another, so
76                folding not (self == other) into self != other breaks implementation
77                of !=. Detecting such cases doesn't seem worthwhile.
78                Python uses </> for 'is subset'/'is superset' operations on sets.
79                They don't satisfy not folding laws. */
80             cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
81             switch (op) {
82             case Is:
83                 op = IsNot;
84                 break;
85             case IsNot:
86                 op = Is;
87                 break;
88             case In:
89                 op = NotIn;
90                 break;
91             case NotIn:
92                 op = In;
93                 break;
94             // The remaining comparison operators can't be safely inverted
95             case Eq:
96             case NotEq:
97             case Lt:
98             case LtE:
99             case Gt:
100             case GtE:
101                 op = 0; // The AST enums leave "0" free as an "unused" marker
102                 break;
103             // No default case, so the compiler will emit a warning if new
104             // comparison operators are added without being handled here
105             }
106             if (op) {
107                 asdl_seq_SET(arg->v.Compare.ops, 0, op);
108                 COPY_NODE(node, arg);
109                 return 1;
110             }
111         }
112         return 1;
113     }
114 
115     typedef PyObject *(*unary_op)(PyObject*);
116     static const unary_op ops[] = {
117         [Invert] = PyNumber_Invert,
118         [Not] = unary_not,
119         [UAdd] = PyNumber_Positive,
120         [USub] = PyNumber_Negative,
121     };
122     PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
123     return make_const(node, newval, arena);
124 }
125 
126 /* Check whether a collection doesn't containing too much items (including
127    subcollections).  This protects from creating a constant that needs
128    too much time for calculating a hash.
129    "limit" is the maximal number of items.
130    Returns the negative number if the total number of items exceeds the
131    limit.  Otherwise returns the limit minus the total number of items.
132 */
133 
134 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)135 check_complexity(PyObject *obj, Py_ssize_t limit)
136 {
137     if (PyTuple_Check(obj)) {
138         Py_ssize_t i;
139         limit -= PyTuple_GET_SIZE(obj);
140         for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
141             limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
142         }
143         return limit;
144     }
145     return limit;
146 }
147 
/* Upper bounds on values produced by constant folding.  When a safe_*
   helper finds that a result would exceed one of these, it returns NULL
   and the expression is left for run time. */
#define MAX_INT_SIZE           128  /* bits of an int result */
#define MAX_COLLECTION_SIZE    256  /* items of a repeated tuple */
#define MAX_STR_SIZE          4096  /* characters/bytes of a repeated str/bytes */
#define MAX_TOTAL_ITEMS       1024  /* items, including nested collections */
152 
153 static PyObject *
safe_multiply(PyObject * v,PyObject * w)154 safe_multiply(PyObject *v, PyObject *w)
155 {
156     if (PyLong_Check(v) && PyLong_Check(w) &&
157         !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
158     ) {
159         size_t vbits = _PyLong_NumBits(v);
160         size_t wbits = _PyLong_NumBits(w);
161         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
162             return NULL;
163         }
164         if (vbits + wbits > MAX_INT_SIZE) {
165             return NULL;
166         }
167     }
168     else if (PyLong_Check(v) && PyTuple_Check(w)) {
169         Py_ssize_t size = PyTuple_GET_SIZE(w);
170         if (size) {
171             long n = PyLong_AsLong(v);
172             if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
173                 return NULL;
174             }
175             if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
176                 return NULL;
177             }
178         }
179     }
180     else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
181         Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
182                                                PyBytes_GET_SIZE(w);
183         if (size) {
184             long n = PyLong_AsLong(v);
185             if (n < 0 || n > MAX_STR_SIZE / size) {
186                 return NULL;
187             }
188         }
189     }
190     else if (PyLong_Check(w) &&
191              (PyTuple_Check(v) || PyUnicode_Check(v) || PyBytes_Check(v)))
192     {
193         return safe_multiply(w, v);
194     }
195 
196     return PyNumber_Multiply(v, w);
197 }
198 
199 static PyObject *
safe_power(PyObject * v,PyObject * w)200 safe_power(PyObject *v, PyObject *w)
201 {
202     if (PyLong_Check(v) && PyLong_Check(w) &&
203         !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
204     ) {
205         size_t vbits = _PyLong_NumBits(v);
206         size_t wbits = PyLong_AsSize_t(w);
207         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
208             return NULL;
209         }
210         if (vbits > MAX_INT_SIZE / wbits) {
211             return NULL;
212         }
213     }
214 
215     return PyNumber_Power(v, w, Py_None);
216 }
217 
218 static PyObject *
safe_lshift(PyObject * v,PyObject * w)219 safe_lshift(PyObject *v, PyObject *w)
220 {
221     if (PyLong_Check(v) && PyLong_Check(w) &&
222         !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
223     ) {
224         size_t vbits = _PyLong_NumBits(v);
225         size_t wbits = PyLong_AsSize_t(w);
226         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
227             return NULL;
228         }
229         if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
230             return NULL;
231         }
232     }
233 
234     return PyNumber_Lshift(v, w);
235 }
236 
237 static PyObject *
safe_mod(PyObject * v,PyObject * w)238 safe_mod(PyObject *v, PyObject *w)
239 {
240     if (PyUnicode_Check(v) || PyBytes_Check(v)) {
241         return NULL;
242     }
243 
244     return PyNumber_Remainder(v, w);
245 }
246 
247 
248 static expr_ty
parse_literal(PyObject * fmt,Py_ssize_t * ppos,PyArena * arena)249 parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
250 {
251     const void *data = PyUnicode_DATA(fmt);
252     int kind = PyUnicode_KIND(fmt);
253     Py_ssize_t size = PyUnicode_GET_LENGTH(fmt);
254     Py_ssize_t start, pos;
255     int has_percents = 0;
256     start = pos = *ppos;
257     while (pos < size) {
258         if (PyUnicode_READ(kind, data, pos) != '%') {
259             pos++;
260         }
261         else if (pos+1 < size && PyUnicode_READ(kind, data, pos+1) == '%') {
262             has_percents = 1;
263             pos += 2;
264         }
265         else {
266             break;
267         }
268     }
269     *ppos = pos;
270     if (pos == start) {
271         return NULL;
272     }
273     PyObject *str = PyUnicode_Substring(fmt, start, pos);
274     /* str = str.replace('%%', '%') */
275     if (str && has_percents) {
276         _Py_DECLARE_STR(dbl_percent, "%%");
277         Py_SETREF(str, PyUnicode_Replace(str, &_Py_STR(dbl_percent),
278                                          _Py_LATIN1_CHR('%'), -1));
279     }
280     if (!str) {
281         return NULL;
282     }
283 
284     if (_PyArena_AddPyObject(arena, str) < 0) {
285         Py_DECREF(str);
286         return NULL;
287     }
288     return _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
289 }
290 
/* Digit bound for the width and precision of a %-conversion.  It also
   sizes the format-spec buffer built in parse_format. */
#define MAXDIGITS 3
292 
293 static int
simple_format_arg_parse(PyObject * fmt,Py_ssize_t * ppos,int * spec,int * flags,int * width,int * prec)294 simple_format_arg_parse(PyObject *fmt, Py_ssize_t *ppos,
295                         int *spec, int *flags, int *width, int *prec)
296 {
297     Py_ssize_t pos = *ppos, len = PyUnicode_GET_LENGTH(fmt);
298     Py_UCS4 ch;
299 
300 #define NEXTC do {                      \
301     if (pos >= len) {                   \
302         return 0;                       \
303     }                                   \
304     ch = PyUnicode_READ_CHAR(fmt, pos); \
305     pos++;                              \
306 } while (0)
307 
308     *flags = 0;
309     while (1) {
310         NEXTC;
311         switch (ch) {
312             case '-': *flags |= F_LJUST; continue;
313             case '+': *flags |= F_SIGN; continue;
314             case ' ': *flags |= F_BLANK; continue;
315             case '#': *flags |= F_ALT; continue;
316             case '0': *flags |= F_ZERO; continue;
317         }
318         break;
319     }
320     if ('0' <= ch && ch <= '9') {
321         *width = 0;
322         int digits = 0;
323         while ('0' <= ch && ch <= '9') {
324             *width = *width * 10 + (ch - '0');
325             NEXTC;
326             if (++digits >= MAXDIGITS) {
327                 return 0;
328             }
329         }
330     }
331 
332     if (ch == '.') {
333         NEXTC;
334         *prec = 0;
335         if ('0' <= ch && ch <= '9') {
336             int digits = 0;
337             while ('0' <= ch && ch <= '9') {
338                 *prec = *prec * 10 + (ch - '0');
339                 NEXTC;
340                 if (++digits >= MAXDIGITS) {
341                     return 0;
342                 }
343             }
344         }
345     }
346     *spec = ch;
347     *ppos = pos;
348     return 1;
349 
350 #undef NEXTC
351 }
352 
353 static expr_ty
parse_format(PyObject * fmt,Py_ssize_t * ppos,expr_ty arg,PyArena * arena)354 parse_format(PyObject *fmt, Py_ssize_t *ppos, expr_ty arg, PyArena *arena)
355 {
356     int spec, flags, width = -1, prec = -1;
357     if (!simple_format_arg_parse(fmt, ppos, &spec, &flags, &width, &prec)) {
358         // Unsupported format.
359         return NULL;
360     }
361     if (spec == 's' || spec == 'r' || spec == 'a') {
362         char buf[1 + MAXDIGITS + 1 + MAXDIGITS + 1], *p = buf;
363         if (!(flags & F_LJUST) && width > 0) {
364             *p++ = '>';
365         }
366         if (width >= 0) {
367             p += snprintf(p, MAXDIGITS + 1, "%d", width);
368         }
369         if (prec >= 0) {
370             p += snprintf(p, MAXDIGITS + 2, ".%d", prec);
371         }
372         expr_ty format_spec = NULL;
373         if (p != buf) {
374             PyObject *str = PyUnicode_FromString(buf);
375             if (str == NULL) {
376                 return NULL;
377             }
378             if (_PyArena_AddPyObject(arena, str) < 0) {
379                 Py_DECREF(str);
380                 return NULL;
381             }
382             format_spec = _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
383             if (format_spec == NULL) {
384                 return NULL;
385             }
386         }
387         return _PyAST_FormattedValue(arg, spec, format_spec,
388                                      arg->lineno, arg->col_offset,
389                                      arg->end_lineno, arg->end_col_offset,
390                                      arena);
391     }
392     // Unsupported format.
393     return NULL;
394 }
395 
396 static int
optimize_format(expr_ty node,PyObject * fmt,asdl_expr_seq * elts,PyArena * arena)397 optimize_format(expr_ty node, PyObject *fmt, asdl_expr_seq *elts, PyArena *arena)
398 {
399     Py_ssize_t pos = 0;
400     Py_ssize_t cnt = 0;
401     asdl_expr_seq *seq = _Py_asdl_expr_seq_new(asdl_seq_LEN(elts) * 2 + 1, arena);
402     if (!seq) {
403         return 0;
404     }
405     seq->size = 0;
406 
407     while (1) {
408         expr_ty lit = parse_literal(fmt, &pos, arena);
409         if (lit) {
410             asdl_seq_SET(seq, seq->size++, lit);
411         }
412         else if (PyErr_Occurred()) {
413             return 0;
414         }
415 
416         if (pos >= PyUnicode_GET_LENGTH(fmt)) {
417             break;
418         }
419         if (cnt >= asdl_seq_LEN(elts)) {
420             // More format units than items.
421             return 1;
422         }
423         assert(PyUnicode_READ_CHAR(fmt, pos) == '%');
424         pos++;
425         expr_ty expr = parse_format(fmt, &pos, asdl_seq_GET(elts, cnt), arena);
426         cnt++;
427         if (!expr) {
428             return !PyErr_Occurred();
429         }
430         asdl_seq_SET(seq, seq->size++, expr);
431     }
432     if (cnt < asdl_seq_LEN(elts)) {
433         // More items than format units.
434         return 1;
435     }
436     expr_ty res = _PyAST_JoinedStr(seq,
437                                    node->lineno, node->col_offset,
438                                    node->end_lineno, node->end_col_offset,
439                                    arena);
440     if (!res) {
441         return 0;
442     }
443     COPY_NODE(node, res);
444 //     PySys_FormatStderr("format = %R\n", fmt);
445     return 1;
446 }
447 
448 static int
fold_binop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)449 fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
450 {
451     expr_ty lhs, rhs;
452     lhs = node->v.BinOp.left;
453     rhs = node->v.BinOp.right;
454     if (lhs->kind != Constant_kind) {
455         return 1;
456     }
457     PyObject *lv = lhs->v.Constant.value;
458 
459     if (node->v.BinOp.op == Mod &&
460         rhs->kind == Tuple_kind &&
461         PyUnicode_Check(lv) &&
462         !has_starred(rhs->v.Tuple.elts))
463     {
464         return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
465     }
466 
467     if (rhs->kind != Constant_kind) {
468         return 1;
469     }
470 
471     PyObject *rv = rhs->v.Constant.value;
472     PyObject *newval = NULL;
473 
474     switch (node->v.BinOp.op) {
475     case Add:
476         newval = PyNumber_Add(lv, rv);
477         break;
478     case Sub:
479         newval = PyNumber_Subtract(lv, rv);
480         break;
481     case Mult:
482         newval = safe_multiply(lv, rv);
483         break;
484     case Div:
485         newval = PyNumber_TrueDivide(lv, rv);
486         break;
487     case FloorDiv:
488         newval = PyNumber_FloorDivide(lv, rv);
489         break;
490     case Mod:
491         newval = safe_mod(lv, rv);
492         break;
493     case Pow:
494         newval = safe_power(lv, rv);
495         break;
496     case LShift:
497         newval = safe_lshift(lv, rv);
498         break;
499     case RShift:
500         newval = PyNumber_Rshift(lv, rv);
501         break;
502     case BitOr:
503         newval = PyNumber_Or(lv, rv);
504         break;
505     case BitXor:
506         newval = PyNumber_Xor(lv, rv);
507         break;
508     case BitAnd:
509         newval = PyNumber_And(lv, rv);
510         break;
511     // No builtin constants implement the following operators
512     case MatMult:
513         return 1;
514     // No default case, so the compiler will emit a warning if new binary
515     // operators are added without being handled here
516     }
517 
518     return make_const(node, newval, arena);
519 }
520 
521 static PyObject*
make_const_tuple(asdl_expr_seq * elts)522 make_const_tuple(asdl_expr_seq *elts)
523 {
524     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
525         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
526         if (e->kind != Constant_kind) {
527             return NULL;
528         }
529     }
530 
531     PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
532     if (newval == NULL) {
533         return NULL;
534     }
535 
536     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
537         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
538         PyObject *v = e->v.Constant.value;
539         PyTuple_SET_ITEM(newval, i, Py_NewRef(v));
540     }
541     return newval;
542 }
543 
544 static int
fold_tuple(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)545 fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
546 {
547     PyObject *newval;
548 
549     if (node->v.Tuple.ctx != Load)
550         return 1;
551 
552     newval = make_const_tuple(node->v.Tuple.elts);
553     return make_const(node, newval, arena);
554 }
555 
556 static int
fold_subscr(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)557 fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
558 {
559     PyObject *newval;
560     expr_ty arg, idx;
561 
562     arg = node->v.Subscript.value;
563     idx = node->v.Subscript.slice;
564     if (node->v.Subscript.ctx != Load ||
565             arg->kind != Constant_kind ||
566             idx->kind != Constant_kind)
567     {
568         return 1;
569     }
570 
571     newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
572     return make_const(node, newval, arena);
573 }
574 
575 /* Change literal list or set of constants into constant
576    tuple or frozenset respectively.  Change literal list of
577    non-constants into tuple.
578    Used for right operand of "in" and "not in" tests and for iterable
579    in "for" loop and comprehensions.
580 */
581 static int
fold_iter(expr_ty arg,PyArena * arena,_PyASTOptimizeState * state)582 fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
583 {
584     PyObject *newval;
585     if (arg->kind == List_kind) {
586         /* First change a list into tuple. */
587         asdl_expr_seq *elts = arg->v.List.elts;
588         if (has_starred(elts)) {
589             return 1;
590         }
591         expr_context_ty ctx = arg->v.List.ctx;
592         arg->kind = Tuple_kind;
593         arg->v.Tuple.elts = elts;
594         arg->v.Tuple.ctx = ctx;
595         /* Try to create a constant tuple. */
596         newval = make_const_tuple(elts);
597     }
598     else if (arg->kind == Set_kind) {
599         newval = make_const_tuple(arg->v.Set.elts);
600         if (newval) {
601             Py_SETREF(newval, PyFrozenSet_New(newval));
602         }
603     }
604     else {
605         return 1;
606     }
607     return make_const(arg, newval, arena);
608 }
609 
610 static int
fold_compare(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)611 fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
612 {
613     asdl_int_seq *ops;
614     asdl_expr_seq *args;
615     Py_ssize_t i;
616 
617     ops = node->v.Compare.ops;
618     args = node->v.Compare.comparators;
619     /* Change literal list or set in 'in' or 'not in' into
620        tuple or frozenset respectively. */
621     i = asdl_seq_LEN(ops) - 1;
622     int op = asdl_seq_GET(ops, i);
623     if (op == In || op == NotIn) {
624         if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
625             return 0;
626         }
627     }
628     return 1;
629 }
630 
631 static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
632 static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
633 static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
634 static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
635 static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
636 static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
637 static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
638 static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
639 static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
640 static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
641 static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
642 static int astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
643 
644 #define CALL(FUNC, TYPE, ARG) \
645     if (!FUNC((ARG), ctx_, state)) \
646         return 0;
647 
648 #define CALL_OPT(FUNC, TYPE, ARG) \
649     if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) \
650         return 0;
651 
652 #define CALL_SEQ(FUNC, TYPE, ARG) { \
653     int i; \
654     asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
655     for (i = 0; i < asdl_seq_LEN(seq); i++) { \
656         TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
657         if (elt != NULL && !FUNC(elt, ctx_, state)) \
658             return 0; \
659     } \
660 }
661 
662 
663 static int
astfold_body(asdl_stmt_seq * stmts,PyArena * ctx_,_PyASTOptimizeState * state)664 astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
665 {
666     int docstring = _PyAST_GetDocString(stmts) != NULL;
667     CALL_SEQ(astfold_stmt, stmt, stmts);
668     if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
669         stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
670         asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
671         if (!values) {
672             return 0;
673         }
674         asdl_seq_SET(values, 0, st->v.Expr.value);
675         expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
676                                         st->end_lineno, st->end_col_offset,
677                                         ctx_);
678         if (!expr) {
679             return 0;
680         }
681         st->v.Expr.value = expr;
682     }
683     return 1;
684 }
685 
686 static int
astfold_mod(mod_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)687 astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
688 {
689     switch (node_->kind) {
690     case Module_kind:
691         CALL(astfold_body, asdl_seq, node_->v.Module.body);
692         break;
693     case Interactive_kind:
694         CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
695         break;
696     case Expression_kind:
697         CALL(astfold_expr, expr_ty, node_->v.Expression.body);
698         break;
699     // The following top level nodes don't participate in constant folding
700     case FunctionType_kind:
701         break;
702     // No default case, so the compiler will emit a warning if new top level
703     // compilation nodes are added without being handled here
704     }
705     return 1;
706 }
707 
708 static int
astfold_expr(expr_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)709 astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
710 {
711     if (++state->recursion_depth > state->recursion_limit) {
712         PyErr_SetString(PyExc_RecursionError,
713                         "maximum recursion depth exceeded during compilation");
714         return 0;
715     }
716     switch (node_->kind) {
717     case BoolOp_kind:
718         CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
719         break;
720     case BinOp_kind:
721         CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
722         CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
723         CALL(fold_binop, expr_ty, node_);
724         break;
725     case UnaryOp_kind:
726         CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
727         CALL(fold_unaryop, expr_ty, node_);
728         break;
729     case Lambda_kind:
730         CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
731         CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
732         break;
733     case IfExp_kind:
734         CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
735         CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
736         CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
737         break;
738     case Dict_kind:
739         CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
740         CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
741         break;
742     case Set_kind:
743         CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
744         break;
745     case ListComp_kind:
746         CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
747         CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
748         break;
749     case SetComp_kind:
750         CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
751         CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
752         break;
753     case DictComp_kind:
754         CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
755         CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
756         CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
757         break;
758     case GeneratorExp_kind:
759         CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
760         CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
761         break;
762     case Await_kind:
763         CALL(astfold_expr, expr_ty, node_->v.Await.value);
764         break;
765     case Yield_kind:
766         CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
767         break;
768     case YieldFrom_kind:
769         CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
770         break;
771     case Compare_kind:
772         CALL(astfold_expr, expr_ty, node_->v.Compare.left);
773         CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
774         CALL(fold_compare, expr_ty, node_);
775         break;
776     case Call_kind:
777         CALL(astfold_expr, expr_ty, node_->v.Call.func);
778         CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
779         CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
780         break;
781     case FormattedValue_kind:
782         CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
783         CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
784         break;
785     case JoinedStr_kind:
786         CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
787         break;
788     case Attribute_kind:
789         CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
790         break;
791     case Subscript_kind:
792         CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
793         CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
794         CALL(fold_subscr, expr_ty, node_);
795         break;
796     case Starred_kind:
797         CALL(astfold_expr, expr_ty, node_->v.Starred.value);
798         break;
799     case Slice_kind:
800         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
801         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
802         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
803         break;
804     case List_kind:
805         CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
806         break;
807     case Tuple_kind:
808         CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
809         CALL(fold_tuple, expr_ty, node_);
810         break;
811     case Name_kind:
812         if (node_->v.Name.ctx == Load &&
813                 _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
814             state->recursion_depth--;
815             return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
816         }
817         break;
818     case NamedExpr_kind:
819         CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
820         break;
821     case Constant_kind:
822         // Already a constant, nothing further to do
823         break;
824     // No default case, so the compiler will emit a warning if new expression
825     // kinds are added without being handled here
826     }
827     state->recursion_depth--;
828     return 1;
829 }
830 
831 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)832 astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
833 {
834     CALL(astfold_expr, expr_ty, node_->value);
835     return 1;
836 }
837 
838 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)839 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
840 {
841     CALL(astfold_expr, expr_ty, node_->target);
842     CALL(astfold_expr, expr_ty, node_->iter);
843     CALL_SEQ(astfold_expr, expr, node_->ifs);
844 
845     CALL(fold_iter, expr_ty, node_->iter);
846     return 1;
847 }
848 
849 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)850 astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
851 {
852     CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
853     CALL_SEQ(astfold_arg, arg, node_->args);
854     CALL_OPT(astfold_arg, arg_ty, node_->vararg);
855     CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
856     CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
857     CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
858     CALL_SEQ(astfold_expr, expr, node_->defaults);
859     return 1;
860 }
861 
862 static int
astfold_arg(arg_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)863 astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
864 {
865     if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
866         CALL_OPT(astfold_expr, expr_ty, node_->annotation);
867     }
868     return 1;
869 }
870 
/* Recursively fold all subexpressions and nested statements of one statement.

   Recursion is bounded by state->recursion_limit. The counter is incremented
   on entry and decremented only on the success path at the end. An early
   failure return (through the CALL* macros) skips the decrement, which is
   acceptable because _PyAST_Optimize only checks the depth balance when
   folding succeeded.
   Returns 1 on success, 0 on failure (with an exception set, e.g.
   RecursionError below). */
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case FunctionDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
        CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
        /* Definition bodies go through astfold_body rather than per-statement folding. */
        CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
        /* Return annotations are skipped under CO_FUTURE_ANNOTATIONS. */
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
        }
        break;
    case AsyncFunctionDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.AsyncFunctionDef.type_params);
        CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
        }
        break;
    case ClassDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.ClassDef.type_params);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
        CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
        CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
        break;
    case Return_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
        break;
    case Delete_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
        break;
    case Assign_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
        CALL(astfold_expr, expr_ty, node_->v.Assign.value);
        break;
    case AugAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
        break;
    case AnnAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
        /* Variable annotations are skipped under CO_FUTURE_ANNOTATIONS. */
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
        }
        CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
        break;
    case TypeAlias_kind:
        CALL(astfold_expr, expr_ty, node_->v.TypeAlias.name);
        CALL_SEQ(astfold_type_param, type_param, node_->v.TypeAlias.type_params);
        CALL(astfold_expr, expr_ty, node_->v.TypeAlias.value);
        break;
    case For_kind:
        CALL(astfold_expr, expr_ty, node_->v.For.target);
        CALL(astfold_expr, expr_ty, node_->v.For.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);

        /* Applied last, to the iterable that was already folded above. */
        CALL(fold_iter, expr_ty, node_->v.For.iter);
        break;
    case AsyncFor_kind:
        /* NOTE(review): unlike For, fold_iter is not applied to the async iterable. */
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
        break;
    case While_kind:
        CALL(astfold_expr, expr_ty, node_->v.While.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
        break;
    case If_kind:
        CALL(astfold_expr, expr_ty, node_->v.If.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
        break;
    case With_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
        break;
    case AsyncWith_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
        break;
    case Raise_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
        break;
    case Try_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
        break;
    case TryStar_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.TryStar.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.finalbody);
        break;
    case Assert_kind:
        CALL(astfold_expr, expr_ty, node_->v.Assert.test);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
        break;
    case Expr_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expr.value);
        break;
    case Match_kind:
        CALL(astfold_expr, expr_ty, node_->v.Match.subject);
        CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
        break;
    // The following statements don't contain any subexpressions to be folded
    case Import_kind:
    case ImportFrom_kind:
    case Global_kind:
    case Nonlocal_kind:
    case Pass_kind:
    case Break_kind:
    case Continue_kind:
        break;
    // No default case, so the compiler will emit a warning if new statement
    // kinds are added without being handled here
    }
    /* Balanced only on this success path; see the header comment. */
    state->recursion_depth--;
    return 1;
}
1005 
/* Fold one "except" clause: its optional exception-type expression and
   every statement in its body.
   Returns 1 on success, 0 on failure. */
static int
astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    switch (node_->kind) {
    case ExceptHandler_kind:
        /* A bare "except:" has no type expression, hence CALL_OPT. */
        CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
        CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
        break;
    // No default case, so the compiler will emit a warning if new handler
    // kinds are added without being handled here
    }
    return 1;
}
1019 
/* Fold one item of a with/async with statement: the context expression and,
   when present ("as <target>"), the optional target expression.
   Returns 1 on success, 0 on failure. */
static int
astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    CALL(astfold_expr, expr_ty, node_->context_expr);
    CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
    return 1;
}
1027 
/* Recursively fold the expressions embedded in a match-statement pattern.

   Like astfold_stmt, this guards recursion with state->recursion_depth and
   decrements the counter only on the success path.
   Returns 1 on success, 0 on failure (RecursionError is set here when the
   depth limit is exceeded). */
static int
astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    // Currently, this is really only used to form complex/negative numeric
    // constants in MatchValue and MatchMapping nodes
    // We still recurse into all subexpressions and subpatterns anyway
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
        case MatchValue_kind:
            CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
            break;
        case MatchSingleton_kind:
            break;
        case MatchSequence_kind:
            CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
            break;
        case MatchMapping_kind:
            /* Keys are expressions; the values are sub-patterns. */
            CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
            CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
            break;
        case MatchClass_kind:
            CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
            CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
            CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
            break;
        case MatchStar_kind:
            break;
        case MatchAs_kind:
            /* The sub-pattern is optional (NULL-checked before recursing). */
            if (node_->v.MatchAs.pattern) {
                CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
            }
            break;
        case MatchOr_kind:
            CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
            break;
    // No default case, so the compiler will emit a warning if new pattern
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
1073 
1074 static int
astfold_match_case(match_case_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1075 astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1076 {
1077     CALL(astfold_pattern, expr_ty, node_->pattern);
1078     CALL_OPT(astfold_expr, expr_ty, node_->guard);
1079     CALL_SEQ(astfold_stmt, stmt, node_->body);
1080     return 1;
1081 }
1082 
/* Fold the optional expressions attached to a PEP 695 type parameter:
   the bound of a TypeVar and the default value of any type-parameter kind.
   Returns 1 on success, 0 on failure. */
static int
astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    switch (node_->kind) {
        case TypeVar_kind:
            CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
            CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.default_value);
            break;
        case ParamSpec_kind:
            CALL_OPT(astfold_expr, expr_ty, node_->v.ParamSpec.default_value);
            break;
        case TypeVarTuple_kind:
            CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVarTuple.default_value);
            break;
    // No default case, so the compiler will emit a warning if new type
    // parameter kinds are added without being handled here
    }
    return 1;
}
1100 
/* The CALL* traversal macros are only meant for the astfold_* functions
   above; undefine them so they cannot leak into the code that follows. */
#undef CALL
#undef CALL_OPT
#undef CALL_SEQ
1104 
1105 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,int optimize,int ff_features)1106 _PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize, int ff_features)
1107 {
1108     PyThreadState *tstate;
1109     int starting_recursion_depth;
1110 
1111     _PyASTOptimizeState state;
1112     state.optimize = optimize;
1113     state.ff_features = ff_features;
1114 
1115     /* Setup recursion depth check counters */
1116     tstate = _PyThreadState_GET();
1117     if (!tstate) {
1118         return 0;
1119     }
1120     /* Be careful here to prevent overflow. */
1121     int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
1122     starting_recursion_depth = recursion_depth;
1123     state.recursion_depth = starting_recursion_depth;
1124     state.recursion_limit = Py_C_RECURSION_LIMIT;
1125 
1126     int ret = astfold_mod(mod, arena, &state);
1127     assert(ret || PyErr_Occurred());
1128 
1129     /* Check that the recursion depth counting balanced correctly */
1130     if (ret && state.recursion_depth != starting_recursion_depth) {
1131         PyErr_Format(PyExc_SystemError,
1132             "AST optimizer recursion depth mismatch (before=%d, after=%d)",
1133             starting_recursion_depth, state.recursion_depth);
1134         return 0;
1135     }
1136 
1137     return ret;
1138 }
1139