• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* AST Optimizer */
2 #include "Python.h"
3 #include "pycore_ast.h"           // _PyAST_GetDocString()
4 #include "pycore_compile.h"       // _PyASTOptimizeState
5 #include "pycore_pystate.h"       // _PyThreadState_GET()
6 
7 
8 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)9 make_const(expr_ty node, PyObject *val, PyArena *arena)
10 {
11     // Even if no new value was calculated, make_const may still
12     // need to clear an error (e.g. for division by zero)
13     if (val == NULL) {
14         if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
15             return 0;
16         }
17         PyErr_Clear();
18         return 1;
19     }
20     if (_PyArena_AddPyObject(arena, val) < 0) {
21         Py_DECREF(val);
22         return 0;
23     }
24     node->kind = Constant_kind;
25     node->v.Constant.kind = NULL;
26     node->v.Constant.value = val;
27     return 1;
28 }
29 
/* Overwrite the expression node TO with a bitwise copy of every field of
   node FROM; both must point to a struct _expr. */
#define COPY_NODE(TO, FROM) \
    (memcpy((TO), (FROM), sizeof(struct _expr)))
31 
32 static PyObject*
unary_not(PyObject * v)33 unary_not(PyObject *v)
34 {
35     int r = PyObject_IsTrue(v);
36     if (r < 0)
37         return NULL;
38     return PyBool_FromLong(!r);
39 }
40 
41 static int
fold_unaryop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)42 fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
43 {
44     expr_ty arg = node->v.UnaryOp.operand;
45 
46     if (arg->kind != Constant_kind) {
47         /* Fold not into comparison */
48         if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
49                 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
50             /* Eq and NotEq are often implemented in terms of one another, so
51                folding not (self == other) into self != other breaks implementation
52                of !=. Detecting such cases doesn't seem worthwhile.
53                Python uses </> for 'is subset'/'is superset' operations on sets.
54                They don't satisfy not folding laws. */
55             cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
56             switch (op) {
57             case Is:
58                 op = IsNot;
59                 break;
60             case IsNot:
61                 op = Is;
62                 break;
63             case In:
64                 op = NotIn;
65                 break;
66             case NotIn:
67                 op = In;
68                 break;
69             // The remaining comparison operators can't be safely inverted
70             case Eq:
71             case NotEq:
72             case Lt:
73             case LtE:
74             case Gt:
75             case GtE:
76                 op = 0; // The AST enums leave "0" free as an "unused" marker
77                 break;
78             // No default case, so the compiler will emit a warning if new
79             // comparison operators are added without being handled here
80             }
81             if (op) {
82                 asdl_seq_SET(arg->v.Compare.ops, 0, op);
83                 COPY_NODE(node, arg);
84                 return 1;
85             }
86         }
87         return 1;
88     }
89 
90     typedef PyObject *(*unary_op)(PyObject*);
91     static const unary_op ops[] = {
92         [Invert] = PyNumber_Invert,
93         [Not] = unary_not,
94         [UAdd] = PyNumber_Positive,
95         [USub] = PyNumber_Negative,
96     };
97     PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
98     return make_const(node, newval, arena);
99 }
100 
101 /* Check whether a collection doesn't containing too much items (including
102    subcollections).  This protects from creating a constant that needs
103    too much time for calculating a hash.
104    "limit" is the maximal number of items.
105    Returns the negative number if the total number of items exceeds the
106    limit.  Otherwise returns the limit minus the total number of items.
107 */
108 
109 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)110 check_complexity(PyObject *obj, Py_ssize_t limit)
111 {
112     if (PyTuple_Check(obj)) {
113         Py_ssize_t i;
114         limit -= PyTuple_GET_SIZE(obj);
115         for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
116             limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
117         }
118         return limit;
119     }
120     else if (PyFrozenSet_Check(obj)) {
121         Py_ssize_t i = 0;
122         PyObject *item;
123         Py_hash_t hash;
124         limit -= PySet_GET_SIZE(obj);
125         while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
126             limit = check_complexity(item, limit);
127         }
128     }
129     return limit;
130 }
131 
/* Size limits for constants produced by folding.  When a result would
   exceed one of these, the safe_* helpers decline and the expression is
   left unfolded. */
#define MAX_INT_SIZE          (128)   /* bits in an int result */
#define MAX_COLLECTION_SIZE   (256)   /* items in a repeated tuple/frozenset */
#define MAX_STR_SIZE         (4096)   /* characters/bytes in a repeated str/bytes */
#define MAX_TOTAL_ITEMS      (1024)   /* items, including nested collections */
136 
137 static PyObject *
safe_multiply(PyObject * v,PyObject * w)138 safe_multiply(PyObject *v, PyObject *w)
139 {
140     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
141         size_t vbits = _PyLong_NumBits(v);
142         size_t wbits = _PyLong_NumBits(w);
143         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
144             return NULL;
145         }
146         if (vbits + wbits > MAX_INT_SIZE) {
147             return NULL;
148         }
149     }
150     else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
151         Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
152                                              PySet_GET_SIZE(w);
153         if (size) {
154             long n = PyLong_AsLong(v);
155             if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
156                 return NULL;
157             }
158             if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
159                 return NULL;
160             }
161         }
162     }
163     else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
164         Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
165                                                PyBytes_GET_SIZE(w);
166         if (size) {
167             long n = PyLong_AsLong(v);
168             if (n < 0 || n > MAX_STR_SIZE / size) {
169                 return NULL;
170             }
171         }
172     }
173     else if (PyLong_Check(w) &&
174              (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
175               PyUnicode_Check(v) || PyBytes_Check(v)))
176     {
177         return safe_multiply(w, v);
178     }
179 
180     return PyNumber_Multiply(v, w);
181 }
182 
183 static PyObject *
safe_power(PyObject * v,PyObject * w)184 safe_power(PyObject *v, PyObject *w)
185 {
186     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
187         size_t vbits = _PyLong_NumBits(v);
188         size_t wbits = PyLong_AsSize_t(w);
189         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
190             return NULL;
191         }
192         if (vbits > MAX_INT_SIZE / wbits) {
193             return NULL;
194         }
195     }
196 
197     return PyNumber_Power(v, w, Py_None);
198 }
199 
200 static PyObject *
safe_lshift(PyObject * v,PyObject * w)201 safe_lshift(PyObject *v, PyObject *w)
202 {
203     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
204         size_t vbits = _PyLong_NumBits(v);
205         size_t wbits = PyLong_AsSize_t(w);
206         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
207             return NULL;
208         }
209         if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
210             return NULL;
211         }
212     }
213 
214     return PyNumber_Lshift(v, w);
215 }
216 
217 static PyObject *
safe_mod(PyObject * v,PyObject * w)218 safe_mod(PyObject *v, PyObject *w)
219 {
220     if (PyUnicode_Check(v) || PyBytes_Check(v)) {
221         return NULL;
222     }
223 
224     return PyNumber_Remainder(v, w);
225 }
226 
227 static int
fold_binop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)228 fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
229 {
230     expr_ty lhs, rhs;
231     lhs = node->v.BinOp.left;
232     rhs = node->v.BinOp.right;
233     if (lhs->kind != Constant_kind || rhs->kind != Constant_kind) {
234         return 1;
235     }
236 
237     PyObject *lv = lhs->v.Constant.value;
238     PyObject *rv = rhs->v.Constant.value;
239     PyObject *newval = NULL;
240 
241     switch (node->v.BinOp.op) {
242     case Add:
243         newval = PyNumber_Add(lv, rv);
244         break;
245     case Sub:
246         newval = PyNumber_Subtract(lv, rv);
247         break;
248     case Mult:
249         newval = safe_multiply(lv, rv);
250         break;
251     case Div:
252         newval = PyNumber_TrueDivide(lv, rv);
253         break;
254     case FloorDiv:
255         newval = PyNumber_FloorDivide(lv, rv);
256         break;
257     case Mod:
258         newval = safe_mod(lv, rv);
259         break;
260     case Pow:
261         newval = safe_power(lv, rv);
262         break;
263     case LShift:
264         newval = safe_lshift(lv, rv);
265         break;
266     case RShift:
267         newval = PyNumber_Rshift(lv, rv);
268         break;
269     case BitOr:
270         newval = PyNumber_Or(lv, rv);
271         break;
272     case BitXor:
273         newval = PyNumber_Xor(lv, rv);
274         break;
275     case BitAnd:
276         newval = PyNumber_And(lv, rv);
277         break;
278     // No builtin constants implement the following operators
279     case MatMult:
280         return 1;
281     // No default case, so the compiler will emit a warning if new binary
282     // operators are added without being handled here
283     }
284 
285     return make_const(node, newval, arena);
286 }
287 
288 static PyObject*
make_const_tuple(asdl_expr_seq * elts)289 make_const_tuple(asdl_expr_seq *elts)
290 {
291     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
292         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
293         if (e->kind != Constant_kind) {
294             return NULL;
295         }
296     }
297 
298     PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
299     if (newval == NULL) {
300         return NULL;
301     }
302 
303     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
304         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
305         PyObject *v = e->v.Constant.value;
306         Py_INCREF(v);
307         PyTuple_SET_ITEM(newval, i, v);
308     }
309     return newval;
310 }
311 
312 static int
fold_tuple(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)313 fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
314 {
315     PyObject *newval;
316 
317     if (node->v.Tuple.ctx != Load)
318         return 1;
319 
320     newval = make_const_tuple(node->v.Tuple.elts);
321     return make_const(node, newval, arena);
322 }
323 
324 static int
fold_subscr(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)325 fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
326 {
327     PyObject *newval;
328     expr_ty arg, idx;
329 
330     arg = node->v.Subscript.value;
331     idx = node->v.Subscript.slice;
332     if (node->v.Subscript.ctx != Load ||
333             arg->kind != Constant_kind ||
334             idx->kind != Constant_kind)
335     {
336         return 1;
337     }
338 
339     newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
340     return make_const(node, newval, arena);
341 }
342 
343 /* Change literal list or set of constants into constant
344    tuple or frozenset respectively.  Change literal list of
345    non-constants into tuple.
346    Used for right operand of "in" and "not in" tests and for iterable
347    in "for" loop and comprehensions.
348 */
349 static int
fold_iter(expr_ty arg,PyArena * arena,_PyASTOptimizeState * state)350 fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
351 {
352     PyObject *newval;
353     if (arg->kind == List_kind) {
354         /* First change a list into tuple. */
355         asdl_expr_seq *elts = arg->v.List.elts;
356         Py_ssize_t n = asdl_seq_LEN(elts);
357         for (Py_ssize_t i = 0; i < n; i++) {
358             expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
359             if (e->kind == Starred_kind) {
360                 return 1;
361             }
362         }
363         expr_context_ty ctx = arg->v.List.ctx;
364         arg->kind = Tuple_kind;
365         arg->v.Tuple.elts = elts;
366         arg->v.Tuple.ctx = ctx;
367         /* Try to create a constant tuple. */
368         newval = make_const_tuple(elts);
369     }
370     else if (arg->kind == Set_kind) {
371         newval = make_const_tuple(arg->v.Set.elts);
372         if (newval) {
373             Py_SETREF(newval, PyFrozenSet_New(newval));
374         }
375     }
376     else {
377         return 1;
378     }
379     return make_const(arg, newval, arena);
380 }
381 
382 static int
fold_compare(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)383 fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
384 {
385     asdl_int_seq *ops;
386     asdl_expr_seq *args;
387     Py_ssize_t i;
388 
389     ops = node->v.Compare.ops;
390     args = node->v.Compare.comparators;
391     /* TODO: optimize cases with literal arguments. */
392     /* Change literal list or set in 'in' or 'not in' into
393        tuple or frozenset respectively. */
394     i = asdl_seq_LEN(ops) - 1;
395     int op = asdl_seq_GET(ops, i);
396     if (op == In || op == NotIn) {
397         if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
398             return 0;
399         }
400     }
401     return 1;
402 }
403 
404 static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
405 static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
406 static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
407 static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
408 static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
409 static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
410 static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
411 static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
412 static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
413 static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
414 static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
415 
/* Visitor helper macros.

   Each macro expects locals named ctx_ (the arena) and state in the calling
   function.  As soon as a visit returns 0, the macro makes that function
   return 0.  For CALL and CALL_OPT, TYPE is documentation only.  For
   CALL_SEQ, TYPE names the sequence and element types (TYPE = expr means
   asdl_expr_seq / expr_ty).  For CALL_INT_SEQ, TYPE is the element cast.

   The do { } while (0) wrappers make every macro a single statement, so
   "if (x) CALL(...); else ..." parses as intended.  Locals carry a trailing
   underscore so ARG cannot accidentally refer to them. */

/* Visit ARG, which must not be NULL. */
#define CALL(FUNC, TYPE, ARG) \
    do { \
        if (!FUNC((ARG), ctx_, state)) { \
            return 0; \
        } \
    } while (0)

/* Visit ARG only if it is not NULL. */
#define CALL_OPT(FUNC, TYPE, ARG) \
    do { \
        if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) { \
            return 0; \
        } \
    } while (0)

/* Visit every non-NULL element of an asdl_TYPE_seq. */
#define CALL_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_ ## TYPE ## _seq *seq_ = (ARG); /* evaluate ARG once */ \
        for (Py_ssize_t i_ = 0; i_ < asdl_seq_LEN(seq_); i_++) { \
            TYPE ## _ty elt_ = (TYPE ## _ty)asdl_seq_GET(seq_, i_); \
            if (elt_ != NULL && !FUNC(elt_, ctx_, state)) { \
                return 0; \
            } \
        } \
    } while (0)

/* Visit every element of an asdl_int_seq, casting each one to TYPE. */
#define CALL_INT_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_int_seq *seq_ = (ARG); /* evaluate ARG once */ \
        for (Py_ssize_t i_ = 0; i_ < asdl_seq_LEN(seq_); i_++) { \
            TYPE elt_ = (TYPE)asdl_seq_GET(seq_, i_); \
            if (!FUNC(elt_, ctx_, state)) { \
                return 0; \
            } \
        } \
    } while (0)
443 
444 static int
astfold_body(asdl_stmt_seq * stmts,PyArena * ctx_,_PyASTOptimizeState * state)445 astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
446 {
447     int docstring = _PyAST_GetDocString(stmts) != NULL;
448     CALL_SEQ(astfold_stmt, stmt, stmts);
449     if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
450         stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
451         asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
452         if (!values) {
453             return 0;
454         }
455         asdl_seq_SET(values, 0, st->v.Expr.value);
456         expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
457                                         st->end_lineno, st->end_col_offset,
458                                         ctx_);
459         if (!expr) {
460             return 0;
461         }
462         st->v.Expr.value = expr;
463     }
464     return 1;
465 }
466 
467 static int
astfold_mod(mod_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)468 astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
469 {
470     switch (node_->kind) {
471     case Module_kind:
472         CALL(astfold_body, asdl_seq, node_->v.Module.body);
473         break;
474     case Interactive_kind:
475         CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
476         break;
477     case Expression_kind:
478         CALL(astfold_expr, expr_ty, node_->v.Expression.body);
479         break;
480     // The following top level nodes don't participate in constant folding
481     case FunctionType_kind:
482         break;
483     // No default case, so the compiler will emit a warning if new top level
484     // compilation nodes are added without being handled here
485     }
486     return 1;
487 }
488 
489 static int
astfold_expr(expr_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)490 astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
491 {
492     if (++state->recursion_depth > state->recursion_limit) {
493         PyErr_SetString(PyExc_RecursionError,
494                         "maximum recursion depth exceeded during compilation");
495         return 0;
496     }
497     switch (node_->kind) {
498     case BoolOp_kind:
499         CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
500         break;
501     case BinOp_kind:
502         CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
503         CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
504         CALL(fold_binop, expr_ty, node_);
505         break;
506     case UnaryOp_kind:
507         CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
508         CALL(fold_unaryop, expr_ty, node_);
509         break;
510     case Lambda_kind:
511         CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
512         CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
513         break;
514     case IfExp_kind:
515         CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
516         CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
517         CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
518         break;
519     case Dict_kind:
520         CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
521         CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
522         break;
523     case Set_kind:
524         CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
525         break;
526     case ListComp_kind:
527         CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
528         CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
529         break;
530     case SetComp_kind:
531         CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
532         CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
533         break;
534     case DictComp_kind:
535         CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
536         CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
537         CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
538         break;
539     case GeneratorExp_kind:
540         CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
541         CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
542         break;
543     case Await_kind:
544         CALL(astfold_expr, expr_ty, node_->v.Await.value);
545         break;
546     case Yield_kind:
547         CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
548         break;
549     case YieldFrom_kind:
550         CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
551         break;
552     case Compare_kind:
553         CALL(astfold_expr, expr_ty, node_->v.Compare.left);
554         CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
555         CALL(fold_compare, expr_ty, node_);
556         break;
557     case Call_kind:
558         CALL(astfold_expr, expr_ty, node_->v.Call.func);
559         CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
560         CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
561         break;
562     case FormattedValue_kind:
563         CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
564         CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
565         break;
566     case JoinedStr_kind:
567         CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
568         break;
569     case Attribute_kind:
570         CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
571         break;
572     case Subscript_kind:
573         CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
574         CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
575         CALL(fold_subscr, expr_ty, node_);
576         break;
577     case Starred_kind:
578         CALL(astfold_expr, expr_ty, node_->v.Starred.value);
579         break;
580     case Slice_kind:
581         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
582         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
583         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
584         break;
585     case List_kind:
586         CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
587         break;
588     case Tuple_kind:
589         CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
590         CALL(fold_tuple, expr_ty, node_);
591         break;
592     case Name_kind:
593         if (node_->v.Name.ctx == Load &&
594                 _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
595             state->recursion_depth--;
596             return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
597         }
598         break;
599     case NamedExpr_kind:
600         CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
601         break;
602     case Constant_kind:
603         // Already a constant, nothing further to do
604         break;
605     // No default case, so the compiler will emit a warning if new expression
606     // kinds are added without being handled here
607     }
608     state->recursion_depth--;
609     return 1;
610 }
611 
612 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)613 astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
614 {
615     CALL(astfold_expr, expr_ty, node_->value);
616     return 1;
617 }
618 
619 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)620 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
621 {
622     CALL(astfold_expr, expr_ty, node_->target);
623     CALL(astfold_expr, expr_ty, node_->iter);
624     CALL_SEQ(astfold_expr, expr, node_->ifs);
625 
626     CALL(fold_iter, expr_ty, node_->iter);
627     return 1;
628 }
629 
630 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)631 astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
632 {
633     CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
634     CALL_SEQ(astfold_arg, arg, node_->args);
635     CALL_OPT(astfold_arg, arg_ty, node_->vararg);
636     CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
637     CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
638     CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
639     CALL_SEQ(astfold_expr, expr, node_->defaults);
640     return 1;
641 }
642 
643 static int
astfold_arg(arg_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)644 astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
645 {
646     if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
647         CALL_OPT(astfold_expr, expr_ty, node_->annotation);
648     }
649     return 1;
650 }
651 
652 static int
astfold_stmt(stmt_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)653 astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
654 {
655     if (++state->recursion_depth > state->recursion_limit) {
656         PyErr_SetString(PyExc_RecursionError,
657                         "maximum recursion depth exceeded during compilation");
658         return 0;
659     }
660     switch (node_->kind) {
661     case FunctionDef_kind:
662         CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
663         CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
664         CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
665         if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
666             CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
667         }
668         break;
669     case AsyncFunctionDef_kind:
670         CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
671         CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
672         CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
673         if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
674             CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
675         }
676         break;
677     case ClassDef_kind:
678         CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
679         CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
680         CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
681         CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
682         break;
683     case Return_kind:
684         CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
685         break;
686     case Delete_kind:
687         CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
688         break;
689     case Assign_kind:
690         CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
691         CALL(astfold_expr, expr_ty, node_->v.Assign.value);
692         break;
693     case AugAssign_kind:
694         CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
695         CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
696         break;
697     case AnnAssign_kind:
698         CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
699         if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
700             CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
701         }
702         CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
703         break;
704     case For_kind:
705         CALL(astfold_expr, expr_ty, node_->v.For.target);
706         CALL(astfold_expr, expr_ty, node_->v.For.iter);
707         CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
708         CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);
709 
710         CALL(fold_iter, expr_ty, node_->v.For.iter);
711         break;
712     case AsyncFor_kind:
713         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
714         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
715         CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
716         CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
717         break;
718     case While_kind:
719         CALL(astfold_expr, expr_ty, node_->v.While.test);
720         CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
721         CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
722         break;
723     case If_kind:
724         CALL(astfold_expr, expr_ty, node_->v.If.test);
725         CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
726         CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
727         break;
728     case With_kind:
729         CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
730         CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
731         break;
732     case AsyncWith_kind:
733         CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
734         CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
735         break;
736     case Raise_kind:
737         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
738         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
739         break;
740     case Try_kind:
741         CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
742         CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
743         CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
744         CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
745         break;
746     case Assert_kind:
747         CALL(astfold_expr, expr_ty, node_->v.Assert.test);
748         CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
749         break;
750     case Expr_kind:
751         CALL(astfold_expr, expr_ty, node_->v.Expr.value);
752         break;
753     case Match_kind:
754         CALL(astfold_expr, expr_ty, node_->v.Match.subject);
755         CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
756         break;
757     // The following statements don't contain any subexpressions to be folded
758     case Import_kind:
759     case ImportFrom_kind:
760     case Global_kind:
761     case Nonlocal_kind:
762     case Pass_kind:
763     case Break_kind:
764     case Continue_kind:
765         break;
766     // No default case, so the compiler will emit a warning if new statement
767     // kinds are added without being handled here
768     }
769     state->recursion_depth--;
770     return 1;
771 }
772 
773 static int
astfold_excepthandler(excepthandler_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)774 astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
775 {
776     switch (node_->kind) {
777     case ExceptHandler_kind:
778         CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
779         CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
780         break;
781     // No default case, so the compiler will emit a warning if new handler
782     // kinds are added without being handled here
783     }
784     return 1;
785 }
786 
787 static int
astfold_withitem(withitem_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)788 astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
789 {
790     CALL(astfold_expr, expr_ty, node_->context_expr);
791     CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
792     return 1;
793 }
794 
795 static int
astfold_pattern(pattern_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)796 astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
797 {
798     // Currently, this is really only used to form complex/negative numeric
799     // constants in MatchValue and MatchMapping nodes
800     // We still recurse into all subexpressions and subpatterns anyway
801     if (++state->recursion_depth > state->recursion_limit) {
802         PyErr_SetString(PyExc_RecursionError,
803                         "maximum recursion depth exceeded during compilation");
804         return 0;
805     }
806     switch (node_->kind) {
807         case MatchValue_kind:
808             CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
809             break;
810         case MatchSingleton_kind:
811             break;
812         case MatchSequence_kind:
813             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
814             break;
815         case MatchMapping_kind:
816             CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
817             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
818             break;
819         case MatchClass_kind:
820             CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
821             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
822             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
823             break;
824         case MatchStar_kind:
825             break;
826         case MatchAs_kind:
827             if (node_->v.MatchAs.pattern) {
828                 CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
829             }
830             break;
831         case MatchOr_kind:
832             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
833             break;
834     // No default case, so the compiler will emit a warning if new pattern
835     // kinds are added without being handled here
836     }
837     state->recursion_depth--;
838     return 1;
839 }
840 
841 static int
astfold_match_case(match_case_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)842 astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
843 {
844     CALL(astfold_pattern, expr_ty, node_->pattern);
845     CALL_OPT(astfold_expr, expr_ty, node_->guard);
846     CALL_SEQ(astfold_stmt, stmt, node_->body);
847     return 1;
848 }
849 
/* The visitor helper macros are private to the astfold_* functions above. */
#undef CALL_INT_SEQ
#undef CALL_SEQ
#undef CALL_OPT
#undef CALL
854 
/* Factor applied to the interpreter's recursion depth and limit to obtain
   the optimizer's own depth counter and limit.  See the comments in
   symtable.c for the rationale. */
#define COMPILER_STACK_FRAME_SCALE (3)
857 
858 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,_PyASTOptimizeState * state)859 _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
860 {
861     PyThreadState *tstate;
862     int recursion_limit = Py_GetRecursionLimit();
863     int starting_recursion_depth;
864 
865     /* Setup recursion depth check counters */
866     tstate = _PyThreadState_GET();
867     if (!tstate) {
868         return 0;
869     }
870     /* Be careful here to prevent overflow. */
871     starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
872         tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth;
873     state->recursion_depth = starting_recursion_depth;
874     state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
875         recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
876 
877     int ret = astfold_mod(mod, arena, state);
878     assert(ret || PyErr_Occurred());
879 
880     /* Check that the recursion depth counting balanced correctly */
881     if (ret && state->recursion_depth != starting_recursion_depth) {
882         PyErr_Format(PyExc_SystemError,
883             "AST optimizer recursion depth mismatch (before=%d, after=%d)",
884             starting_recursion_depth, state->recursion_depth);
885         return 0;
886     }
887 
888     return ret;
889 }
890