1 /* AST Optimizer */
2 #include "Python.h"
3 #include "pycore_ast.h" // _PyAST_GetDocString()
4 #include "pycore_format.h" // F_LJUST
5 #include "pycore_long.h" // _PyLong
6 #include "pycore_pystate.h" // _PyThreadState_GET()
7 #include "pycore_setobject.h" // _PySet_NextEntry()
8
9
/* State threaded through every astfold_* and fold_* function in this file. */
typedef struct {
    int optimize;     /* folded __debug__ becomes the boolean !optimize */
    int ff_features;  /* flag bits; tested against CO_FUTURE_ANNOTATIONS to
                         decide whether annotations are visited */

    int recursion_depth;            /* current recursion depth */
    int recursion_limit;            /* recursion limit */
} _PyASTOptimizeState;
17
18
19 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)20 make_const(expr_ty node, PyObject *val, PyArena *arena)
21 {
22 // Even if no new value was calculated, make_const may still
23 // need to clear an error (e.g. for division by zero)
24 if (val == NULL) {
25 if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
26 return 0;
27 }
28 PyErr_Clear();
29 return 1;
30 }
31 if (_PyArena_AddPyObject(arena, val) < 0) {
32 Py_DECREF(val);
33 return 0;
34 }
35 node->kind = Constant_kind;
36 node->v.Constant.kind = NULL;
37 node->v.Constant.value = val;
38 return 1;
39 }
40
/* Overwrite the expression node TO in place with a byte-wise copy of FROM
   (sizeof(struct _expr) bytes). */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
42
43 static int
has_starred(asdl_expr_seq * elts)44 has_starred(asdl_expr_seq *elts)
45 {
46 Py_ssize_t n = asdl_seq_LEN(elts);
47 for (Py_ssize_t i = 0; i < n; i++) {
48 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
49 if (e->kind == Starred_kind) {
50 return 1;
51 }
52 }
53 return 0;
54 }
55
56
57 static PyObject*
unary_not(PyObject * v)58 unary_not(PyObject *v)
59 {
60 int r = PyObject_IsTrue(v);
61 if (r < 0)
62 return NULL;
63 return PyBool_FromLong(!r);
64 }
65
66 static int
fold_unaryop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)67 fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
68 {
69 expr_ty arg = node->v.UnaryOp.operand;
70
71 if (arg->kind != Constant_kind) {
72 /* Fold not into comparison */
73 if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
74 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
75 /* Eq and NotEq are often implemented in terms of one another, so
76 folding not (self == other) into self != other breaks implementation
77 of !=. Detecting such cases doesn't seem worthwhile.
78 Python uses </> for 'is subset'/'is superset' operations on sets.
79 They don't satisfy not folding laws. */
80 cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
81 switch (op) {
82 case Is:
83 op = IsNot;
84 break;
85 case IsNot:
86 op = Is;
87 break;
88 case In:
89 op = NotIn;
90 break;
91 case NotIn:
92 op = In;
93 break;
94 // The remaining comparison operators can't be safely inverted
95 case Eq:
96 case NotEq:
97 case Lt:
98 case LtE:
99 case Gt:
100 case GtE:
101 op = 0; // The AST enums leave "0" free as an "unused" marker
102 break;
103 // No default case, so the compiler will emit a warning if new
104 // comparison operators are added without being handled here
105 }
106 if (op) {
107 asdl_seq_SET(arg->v.Compare.ops, 0, op);
108 COPY_NODE(node, arg);
109 return 1;
110 }
111 }
112 return 1;
113 }
114
115 typedef PyObject *(*unary_op)(PyObject*);
116 static const unary_op ops[] = {
117 [Invert] = PyNumber_Invert,
118 [Not] = unary_not,
119 [UAdd] = PyNumber_Positive,
120 [USub] = PyNumber_Negative,
121 };
122 PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
123 return make_const(node, newval, arena);
124 }
125
126 /* Check whether a collection doesn't containing too much items (including
127 subcollections). This protects from creating a constant that needs
128 too much time for calculating a hash.
129 "limit" is the maximal number of items.
130 Returns the negative number if the total number of items exceeds the
131 limit. Otherwise returns the limit minus the total number of items.
132 */
133
134 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)135 check_complexity(PyObject *obj, Py_ssize_t limit)
136 {
137 if (PyTuple_Check(obj)) {
138 Py_ssize_t i;
139 limit -= PyTuple_GET_SIZE(obj);
140 for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
141 limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
142 }
143 return limit;
144 }
145 return limit;
146 }
147
/* Upper bounds checked by safe_multiply(), safe_power() and safe_lshift()
   before an operation is folded into a constant. */
#define MAX_INT_SIZE           128  /* bits */
#define MAX_COLLECTION_SIZE    256  /* items */
#define MAX_STR_SIZE          4096  /* characters */
#define MAX_TOTAL_ITEMS       1024  /* including nested collections */
152
153 static PyObject *
safe_multiply(PyObject * v,PyObject * w)154 safe_multiply(PyObject *v, PyObject *w)
155 {
156 if (PyLong_Check(v) && PyLong_Check(w) &&
157 !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
158 ) {
159 size_t vbits = _PyLong_NumBits(v);
160 size_t wbits = _PyLong_NumBits(w);
161 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
162 return NULL;
163 }
164 if (vbits + wbits > MAX_INT_SIZE) {
165 return NULL;
166 }
167 }
168 else if (PyLong_Check(v) && PyTuple_Check(w)) {
169 Py_ssize_t size = PyTuple_GET_SIZE(w);
170 if (size) {
171 long n = PyLong_AsLong(v);
172 if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
173 return NULL;
174 }
175 if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
176 return NULL;
177 }
178 }
179 }
180 else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
181 Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
182 PyBytes_GET_SIZE(w);
183 if (size) {
184 long n = PyLong_AsLong(v);
185 if (n < 0 || n > MAX_STR_SIZE / size) {
186 return NULL;
187 }
188 }
189 }
190 else if (PyLong_Check(w) &&
191 (PyTuple_Check(v) || PyUnicode_Check(v) || PyBytes_Check(v)))
192 {
193 return safe_multiply(w, v);
194 }
195
196 return PyNumber_Multiply(v, w);
197 }
198
199 static PyObject *
safe_power(PyObject * v,PyObject * w)200 safe_power(PyObject *v, PyObject *w)
201 {
202 if (PyLong_Check(v) && PyLong_Check(w) &&
203 !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
204 ) {
205 size_t vbits = _PyLong_NumBits(v);
206 size_t wbits = PyLong_AsSize_t(w);
207 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
208 return NULL;
209 }
210 if (vbits > MAX_INT_SIZE / wbits) {
211 return NULL;
212 }
213 }
214
215 return PyNumber_Power(v, w, Py_None);
216 }
217
218 static PyObject *
safe_lshift(PyObject * v,PyObject * w)219 safe_lshift(PyObject *v, PyObject *w)
220 {
221 if (PyLong_Check(v) && PyLong_Check(w) &&
222 !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
223 ) {
224 size_t vbits = _PyLong_NumBits(v);
225 size_t wbits = PyLong_AsSize_t(w);
226 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
227 return NULL;
228 }
229 if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
230 return NULL;
231 }
232 }
233
234 return PyNumber_Lshift(v, w);
235 }
236
237 static PyObject *
safe_mod(PyObject * v,PyObject * w)238 safe_mod(PyObject *v, PyObject *w)
239 {
240 if (PyUnicode_Check(v) || PyBytes_Check(v)) {
241 return NULL;
242 }
243
244 return PyNumber_Remainder(v, w);
245 }
246
247
248 static expr_ty
parse_literal(PyObject * fmt,Py_ssize_t * ppos,PyArena * arena)249 parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
250 {
251 const void *data = PyUnicode_DATA(fmt);
252 int kind = PyUnicode_KIND(fmt);
253 Py_ssize_t size = PyUnicode_GET_LENGTH(fmt);
254 Py_ssize_t start, pos;
255 int has_percents = 0;
256 start = pos = *ppos;
257 while (pos < size) {
258 if (PyUnicode_READ(kind, data, pos) != '%') {
259 pos++;
260 }
261 else if (pos+1 < size && PyUnicode_READ(kind, data, pos+1) == '%') {
262 has_percents = 1;
263 pos += 2;
264 }
265 else {
266 break;
267 }
268 }
269 *ppos = pos;
270 if (pos == start) {
271 return NULL;
272 }
273 PyObject *str = PyUnicode_Substring(fmt, start, pos);
274 /* str = str.replace('%%', '%') */
275 if (str && has_percents) {
276 _Py_DECLARE_STR(dbl_percent, "%%");
277 Py_SETREF(str, PyUnicode_Replace(str, &_Py_STR(dbl_percent),
278 _Py_LATIN1_CHR('%'), -1));
279 }
280 if (!str) {
281 return NULL;
282 }
283
284 if (_PyArena_AddPyObject(arena, str) < 0) {
285 Py_DECREF(str);
286 return NULL;
287 }
288 return _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
289 }
290
/* Bound on the digit count of a width or precision in
   simple_format_arg_parse(); also sizes the format-spec buffer in
   parse_format(). */
#define MAXDIGITS 3
292
/* Parse one %-format conversion specifier of `fmt`, starting at *ppos.
 * optimize_format() steps past the '%' before this is reached.
 *
 * Grammar accepted: [flags][width][.precision]conversion
 *
 * Outputs:
 *   *flags  OR of F_LJUST / F_SIGN / F_BLANK / F_ALT / F_ZERO
 *   *width  written only if a width is present
 *   *prec   written only if a '.' is present (0 when no digits follow)
 *   *spec   the conversion character
 *   *ppos   advanced past the conversion character (on success only)
 *
 * The caller is expected to pre-initialize *width and *prec; parse_format()
 * uses -1 for "absent".
 *
 * Returns 1 on success.  Returns 0, with *ppos left unchanged, if the
 * string ends before the conversion character or a digit run is too long.
 */
static int
simple_format_arg_parse(PyObject *fmt, Py_ssize_t *ppos,
                        int *spec, int *flags, int *width, int *prec)
{
    Py_ssize_t pos = *ppos, len = PyUnicode_GET_LENGTH(fmt);
    Py_UCS4 ch;

/* Fetch the next character into `ch`; returns 0 from the enclosing function
   when the string is exhausted. */
#define NEXTC do {                                  \
    if (pos >= len) {                               \
        return 0;                                   \
    }                                               \
    ch = PyUnicode_READ_CHAR(fmt, pos);             \
    pos++;                                          \
} while (0)

    *flags = 0;
    /* Consume any run of flag characters.  Note that a leading '0' is taken
       as the zero-padding flag, not as part of the width. */
    while (1) {
        NEXTC;
        switch (ch) {
            case '-': *flags |= F_LJUST; continue;
            case '+': *flags |= F_SIGN; continue;
            case ' ': *flags |= F_BLANK; continue;
            case '#': *flags |= F_ALT; continue;
            case '0': *flags |= F_ZERO; continue;
        }
        break;
    }
    if ('0' <= ch && ch <= '9') {
        *width = 0;
        int digits = 0;
        while ('0' <= ch && ch <= '9') {
            *width = *width * 10 + (ch - '0');
            NEXTC;
            /* The counter is tested after each consumed digit, so a run of
               MAXDIGITS digits is rejected: at most MAXDIGITS - 1 digits
               are accepted. */
            if (++digits >= MAXDIGITS) {
                return 0;
            }
        }
    }

    if (ch == '.') {
        NEXTC;
        *prec = 0;
        if ('0' <= ch && ch <= '9') {
            int digits = 0;
            while ('0' <= ch && ch <= '9') {
                *prec = *prec * 10 + (ch - '0');
                NEXTC;
                /* Same digit limit as for the width above. */
                if (++digits >= MAXDIGITS) {
                    return 0;
                }
            }
        }
    }
    *spec = ch;
    *ppos = pos;
    return 1;

#undef NEXTC
}
352
353 static expr_ty
parse_format(PyObject * fmt,Py_ssize_t * ppos,expr_ty arg,PyArena * arena)354 parse_format(PyObject *fmt, Py_ssize_t *ppos, expr_ty arg, PyArena *arena)
355 {
356 int spec, flags, width = -1, prec = -1;
357 if (!simple_format_arg_parse(fmt, ppos, &spec, &flags, &width, &prec)) {
358 // Unsupported format.
359 return NULL;
360 }
361 if (spec == 's' || spec == 'r' || spec == 'a') {
362 char buf[1 + MAXDIGITS + 1 + MAXDIGITS + 1], *p = buf;
363 if (!(flags & F_LJUST) && width > 0) {
364 *p++ = '>';
365 }
366 if (width >= 0) {
367 p += snprintf(p, MAXDIGITS + 1, "%d", width);
368 }
369 if (prec >= 0) {
370 p += snprintf(p, MAXDIGITS + 2, ".%d", prec);
371 }
372 expr_ty format_spec = NULL;
373 if (p != buf) {
374 PyObject *str = PyUnicode_FromString(buf);
375 if (str == NULL) {
376 return NULL;
377 }
378 if (_PyArena_AddPyObject(arena, str) < 0) {
379 Py_DECREF(str);
380 return NULL;
381 }
382 format_spec = _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
383 if (format_spec == NULL) {
384 return NULL;
385 }
386 }
387 return _PyAST_FormattedValue(arg, spec, format_spec,
388 arg->lineno, arg->col_offset,
389 arg->end_lineno, arg->end_col_offset,
390 arena);
391 }
392 // Unsupported format.
393 return NULL;
394 }
395
396 static int
optimize_format(expr_ty node,PyObject * fmt,asdl_expr_seq * elts,PyArena * arena)397 optimize_format(expr_ty node, PyObject *fmt, asdl_expr_seq *elts, PyArena *arena)
398 {
399 Py_ssize_t pos = 0;
400 Py_ssize_t cnt = 0;
401 asdl_expr_seq *seq = _Py_asdl_expr_seq_new(asdl_seq_LEN(elts) * 2 + 1, arena);
402 if (!seq) {
403 return 0;
404 }
405 seq->size = 0;
406
407 while (1) {
408 expr_ty lit = parse_literal(fmt, &pos, arena);
409 if (lit) {
410 asdl_seq_SET(seq, seq->size++, lit);
411 }
412 else if (PyErr_Occurred()) {
413 return 0;
414 }
415
416 if (pos >= PyUnicode_GET_LENGTH(fmt)) {
417 break;
418 }
419 if (cnt >= asdl_seq_LEN(elts)) {
420 // More format units than items.
421 return 1;
422 }
423 assert(PyUnicode_READ_CHAR(fmt, pos) == '%');
424 pos++;
425 expr_ty expr = parse_format(fmt, &pos, asdl_seq_GET(elts, cnt), arena);
426 cnt++;
427 if (!expr) {
428 return !PyErr_Occurred();
429 }
430 asdl_seq_SET(seq, seq->size++, expr);
431 }
432 if (cnt < asdl_seq_LEN(elts)) {
433 // More items than format units.
434 return 1;
435 }
436 expr_ty res = _PyAST_JoinedStr(seq,
437 node->lineno, node->col_offset,
438 node->end_lineno, node->end_col_offset,
439 arena);
440 if (!res) {
441 return 0;
442 }
443 COPY_NODE(node, res);
444 // PySys_FormatStderr("format = %R\n", fmt);
445 return 1;
446 }
447
448 static int
fold_binop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)449 fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
450 {
451 expr_ty lhs, rhs;
452 lhs = node->v.BinOp.left;
453 rhs = node->v.BinOp.right;
454 if (lhs->kind != Constant_kind) {
455 return 1;
456 }
457 PyObject *lv = lhs->v.Constant.value;
458
459 if (node->v.BinOp.op == Mod &&
460 rhs->kind == Tuple_kind &&
461 PyUnicode_Check(lv) &&
462 !has_starred(rhs->v.Tuple.elts))
463 {
464 return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
465 }
466
467 if (rhs->kind != Constant_kind) {
468 return 1;
469 }
470
471 PyObject *rv = rhs->v.Constant.value;
472 PyObject *newval = NULL;
473
474 switch (node->v.BinOp.op) {
475 case Add:
476 newval = PyNumber_Add(lv, rv);
477 break;
478 case Sub:
479 newval = PyNumber_Subtract(lv, rv);
480 break;
481 case Mult:
482 newval = safe_multiply(lv, rv);
483 break;
484 case Div:
485 newval = PyNumber_TrueDivide(lv, rv);
486 break;
487 case FloorDiv:
488 newval = PyNumber_FloorDivide(lv, rv);
489 break;
490 case Mod:
491 newval = safe_mod(lv, rv);
492 break;
493 case Pow:
494 newval = safe_power(lv, rv);
495 break;
496 case LShift:
497 newval = safe_lshift(lv, rv);
498 break;
499 case RShift:
500 newval = PyNumber_Rshift(lv, rv);
501 break;
502 case BitOr:
503 newval = PyNumber_Or(lv, rv);
504 break;
505 case BitXor:
506 newval = PyNumber_Xor(lv, rv);
507 break;
508 case BitAnd:
509 newval = PyNumber_And(lv, rv);
510 break;
511 // No builtin constants implement the following operators
512 case MatMult:
513 return 1;
514 // No default case, so the compiler will emit a warning if new binary
515 // operators are added without being handled here
516 }
517
518 return make_const(node, newval, arena);
519 }
520
521 static PyObject*
make_const_tuple(asdl_expr_seq * elts)522 make_const_tuple(asdl_expr_seq *elts)
523 {
524 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
525 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
526 if (e->kind != Constant_kind) {
527 return NULL;
528 }
529 }
530
531 PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
532 if (newval == NULL) {
533 return NULL;
534 }
535
536 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
537 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
538 PyObject *v = e->v.Constant.value;
539 PyTuple_SET_ITEM(newval, i, Py_NewRef(v));
540 }
541 return newval;
542 }
543
544 static int
fold_tuple(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)545 fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
546 {
547 PyObject *newval;
548
549 if (node->v.Tuple.ctx != Load)
550 return 1;
551
552 newval = make_const_tuple(node->v.Tuple.elts);
553 return make_const(node, newval, arena);
554 }
555
556 static int
fold_subscr(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)557 fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
558 {
559 PyObject *newval;
560 expr_ty arg, idx;
561
562 arg = node->v.Subscript.value;
563 idx = node->v.Subscript.slice;
564 if (node->v.Subscript.ctx != Load ||
565 arg->kind != Constant_kind ||
566 idx->kind != Constant_kind)
567 {
568 return 1;
569 }
570
571 newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
572 return make_const(node, newval, arena);
573 }
574
575 /* Change literal list or set of constants into constant
576 tuple or frozenset respectively. Change literal list of
577 non-constants into tuple.
578 Used for right operand of "in" and "not in" tests and for iterable
579 in "for" loop and comprehensions.
580 */
581 static int
fold_iter(expr_ty arg,PyArena * arena,_PyASTOptimizeState * state)582 fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
583 {
584 PyObject *newval;
585 if (arg->kind == List_kind) {
586 /* First change a list into tuple. */
587 asdl_expr_seq *elts = arg->v.List.elts;
588 if (has_starred(elts)) {
589 return 1;
590 }
591 expr_context_ty ctx = arg->v.List.ctx;
592 arg->kind = Tuple_kind;
593 arg->v.Tuple.elts = elts;
594 arg->v.Tuple.ctx = ctx;
595 /* Try to create a constant tuple. */
596 newval = make_const_tuple(elts);
597 }
598 else if (arg->kind == Set_kind) {
599 newval = make_const_tuple(arg->v.Set.elts);
600 if (newval) {
601 Py_SETREF(newval, PyFrozenSet_New(newval));
602 }
603 }
604 else {
605 return 1;
606 }
607 return make_const(arg, newval, arena);
608 }
609
610 static int
fold_compare(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)611 fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
612 {
613 asdl_int_seq *ops;
614 asdl_expr_seq *args;
615 Py_ssize_t i;
616
617 ops = node->v.Compare.ops;
618 args = node->v.Compare.comparators;
619 /* Change literal list or set in 'in' or 'not in' into
620 tuple or frozenset respectively. */
621 i = asdl_seq_LEN(ops) - 1;
622 int op = asdl_seq_GET(ops, i);
623 if (op == In || op == NotIn) {
624 if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
625 return 0;
626 }
627 }
628 return 1;
629 }
630
/* Forward declarations of the recursive AST walkers.  The parameter names
   `ctx_` and `state` are relied upon by the CALL* macros below. */
static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
643
/* Traversal helpers.  Each expands inside a function that has `ctx_` and
   `state` in scope and returns int; on failure they `return 0` from that
   enclosing function. */

/* Apply FUNC to ARG.  TYPE is not used by the expansion. */
#define CALL(FUNC, TYPE, ARG) \
    if (!FUNC((ARG), ctx_, state)) \
        return 0;

/* Like CALL, but does nothing when ARG is NULL.  TYPE is not used by the
   expansion. */
#define CALL_OPT(FUNC, TYPE, ARG) \
    if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) \
        return 0;

/* Apply FUNC to every non-NULL element of the asdl_<TYPE>_seq ARG, in
   order. */
#define CALL_SEQ(FUNC, TYPE, ARG) { \
    int i; \
    asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
    for (i = 0; i < asdl_seq_LEN(seq); i++) { \
        TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
        if (elt != NULL && !FUNC(elt, ctx_, state)) \
            return 0; \
    } \
}
661
662
663 static int
astfold_body(asdl_stmt_seq * stmts,PyArena * ctx_,_PyASTOptimizeState * state)664 astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
665 {
666 int docstring = _PyAST_GetDocString(stmts) != NULL;
667 CALL_SEQ(astfold_stmt, stmt, stmts);
668 if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
669 stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
670 asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
671 if (!values) {
672 return 0;
673 }
674 asdl_seq_SET(values, 0, st->v.Expr.value);
675 expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
676 st->end_lineno, st->end_col_offset,
677 ctx_);
678 if (!expr) {
679 return 0;
680 }
681 st->v.Expr.value = expr;
682 }
683 return 1;
684 }
685
686 static int
astfold_mod(mod_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)687 astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
688 {
689 switch (node_->kind) {
690 case Module_kind:
691 CALL(astfold_body, asdl_seq, node_->v.Module.body);
692 break;
693 case Interactive_kind:
694 CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
695 break;
696 case Expression_kind:
697 CALL(astfold_expr, expr_ty, node_->v.Expression.body);
698 break;
699 // The following top level nodes don't participate in constant folding
700 case FunctionType_kind:
701 break;
702 // No default case, so the compiler will emit a warning if new top level
703 // compilation nodes are added without being handled here
704 }
705 return 1;
706 }
707
/* Recursively fold an expression node in place.
 *
 * Children are folded first; then, for BinOp, UnaryOp, Compare, Subscript
 * and Tuple nodes, the node itself is offered to the matching fold_*
 * function.  A Name node "__debug__" in Load context is replaced by a bool
 * constant.
 *
 * Returns 1 on success and 0 on error.  A RecursionError is set when the
 * nesting depth exceeds state->recursion_limit.
 *
 * Note: the CALL* macros return 0 straight from this function, so
 * state->recursion_depth is not decremented on failure paths.
 */
static int
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case BoolOp_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
        break;
    case BinOp_kind:
        /* Operands first, so that fold_binop() sees folded constants. */
        CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
        CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
        CALL(fold_binop, expr_ty, node_);
        break;
    case UnaryOp_kind:
        CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
        CALL(fold_unaryop, expr_ty, node_);
        break;
    case Lambda_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
        CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
        break;
    case IfExp_kind:
        CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
        CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
        CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
        break;
    case Dict_kind:
        /* CALL_SEQ skips NULL entries of the key sequence. */
        CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
        CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
        break;
    case Set_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
        break;
    case ListComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
        break;
    case SetComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
        break;
    case DictComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
        CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
        break;
    case GeneratorExp_kind:
        CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
        break;
    case Await_kind:
        CALL(astfold_expr, expr_ty, node_->v.Await.value);
        break;
    case Yield_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
        break;
    case YieldFrom_kind:
        CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
        break;
    case Compare_kind:
        CALL(astfold_expr, expr_ty, node_->v.Compare.left);
        CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
        CALL(fold_compare, expr_ty, node_);
        break;
    case Call_kind:
        CALL(astfold_expr, expr_ty, node_->v.Call.func);
        CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
        CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
        break;
    case FormattedValue_kind:
        CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
        CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
        break;
    case JoinedStr_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
        break;
    case Attribute_kind:
        CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
        break;
    case Subscript_kind:
        CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
        CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
        CALL(fold_subscr, expr_ty, node_);
        break;
    case Starred_kind:
        CALL(astfold_expr, expr_ty, node_->v.Starred.value);
        break;
    case Slice_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
        break;
    case List_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
        break;
    case Tuple_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
        CALL(fold_tuple, expr_ty, node_);
        break;
    case Name_kind:
        /* __debug__ folds to True unless state->optimize is non-zero. */
        if (node_->v.Name.ctx == Load &&
                _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
            /* Early return: undo the depth increment made on entry. */
            state->recursion_depth--;
            return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
        }
        break;
    case NamedExpr_kind:
        /* Only the value is folded; the target is not visited. */
        CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
        break;
    case Constant_kind:
        // Already a constant, nothing further to do
        break;
    // No default case, so the compiler will emit a warning if new expression
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
830
831 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)832 astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
833 {
834 CALL(astfold_expr, expr_ty, node_->value);
835 return 1;
836 }
837
838 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)839 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
840 {
841 CALL(astfold_expr, expr_ty, node_->target);
842 CALL(astfold_expr, expr_ty, node_->iter);
843 CALL_SEQ(astfold_expr, expr, node_->ifs);
844
845 CALL(fold_iter, expr_ty, node_->iter);
846 return 1;
847 }
848
/* Fold everything foldable in a function's parameter list: each parameter
 * (through astfold_arg()) and each default value expression.
 * Returns 1 on success, 0 on error.
 */
static int
astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
    CALL_SEQ(astfold_arg, arg, node_->args);
    CALL_OPT(astfold_arg, arg_ty, node_->vararg);
    CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
    /* CALL_SEQ skips NULL entries of kw_defaults. */
    CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
    CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
    CALL_SEQ(astfold_expr, expr, node_->defaults);
    return 1;
}
861
862 static int
astfold_arg(arg_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)863 astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
864 {
865 if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
866 CALL_OPT(astfold_expr, expr_ty, node_->annotation);
867 }
868 return 1;
869 }
870
/* Recursively fold a statement node in place: every sub-expression and
 * nested statement is visited.
 *
 * Function and class bodies go through astfold_body() (docstring-aware).
 * Annotations (return annotations, AnnAssign annotations) are not visited
 * when CO_FUTURE_ANNOTATIONS is set in state->ff_features.
 *
 * Returns 1 on success and 0 on error.  A RecursionError is set when the
 * nesting depth exceeds state->recursion_limit.
 *
 * Note: the CALL* macros return 0 straight from this function, so
 * state->recursion_depth is not decremented on failure paths.
 */
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case FunctionDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
        CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
        }
        break;
    case AsyncFunctionDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.AsyncFunctionDef.type_params);
        CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
        }
        break;
    case ClassDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.ClassDef.type_params);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
        CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
        CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
        break;
    case Return_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
        break;
    case Delete_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
        break;
    case Assign_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
        CALL(astfold_expr, expr_ty, node_->v.Assign.value);
        break;
    case AugAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
        break;
    case AnnAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
        }
        CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
        break;
    case TypeAlias_kind:
        CALL(astfold_expr, expr_ty, node_->v.TypeAlias.name);
        CALL_SEQ(astfold_type_param, type_param, node_->v.TypeAlias.type_params);
        CALL(astfold_expr, expr_ty, node_->v.TypeAlias.value);
        break;
    case For_kind:
        CALL(astfold_expr, expr_ty, node_->v.For.target);
        CALL(astfold_expr, expr_ty, node_->v.For.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);

        /* A literal list/set iterable becomes a tuple/frozenset. */
        CALL(fold_iter, expr_ty, node_->v.For.iter);
        break;
    case AsyncFor_kind:
        /* Unlike For, the iterable is not passed to fold_iter(). */
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
        break;
    case While_kind:
        CALL(astfold_expr, expr_ty, node_->v.While.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
        break;
    case If_kind:
        CALL(astfold_expr, expr_ty, node_->v.If.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
        break;
    case With_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
        break;
    case AsyncWith_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
        break;
    case Raise_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
        break;
    case Try_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
        break;
    case TryStar_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.TryStar.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.finalbody);
        break;
    case Assert_kind:
        CALL(astfold_expr, expr_ty, node_->v.Assert.test);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
        break;
    case Expr_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expr.value);
        break;
    case Match_kind:
        CALL(astfold_expr, expr_ty, node_->v.Match.subject);
        CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
        break;
    // The following statements don't contain any subexpressions to be folded
    case Import_kind:
    case ImportFrom_kind:
    case Global_kind:
    case Nonlocal_kind:
    case Pass_kind:
    case Break_kind:
    case Continue_kind:
        break;
    // No default case, so the compiler will emit a warning if new statement
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
1005
/* Fold an exception handler: its optional type expression and every
 * statement of its body.  Returns 1 on success, 0 on error.
 */
static int
astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    switch (node_->kind) {
    case ExceptHandler_kind:
        /* The type expression may be NULL, hence CALL_OPT. */
        CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
        CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
        break;
    // No default case, so the compiler will emit a warning if new handler
    // kinds are added without being handled here
    }
    return 1;
}
1019
1020 static int
astfold_withitem(withitem_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1021 astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1022 {
1023 CALL(astfold_expr, expr_ty, node_->context_expr);
1024 CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
1025 return 1;
1026 }
1027
1028 static int
astfold_pattern(pattern_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1029 astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1030 {
1031 // Currently, this is really only used to form complex/negative numeric
1032 // constants in MatchValue and MatchMapping nodes
1033 // We still recurse into all subexpressions and subpatterns anyway
1034 if (++state->recursion_depth > state->recursion_limit) {
1035 PyErr_SetString(PyExc_RecursionError,
1036 "maximum recursion depth exceeded during compilation");
1037 return 0;
1038 }
1039 switch (node_->kind) {
1040 case MatchValue_kind:
1041 CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
1042 break;
1043 case MatchSingleton_kind:
1044 break;
1045 case MatchSequence_kind:
1046 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
1047 break;
1048 case MatchMapping_kind:
1049 CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
1050 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
1051 break;
1052 case MatchClass_kind:
1053 CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
1054 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
1055 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
1056 break;
1057 case MatchStar_kind:
1058 break;
1059 case MatchAs_kind:
1060 if (node_->v.MatchAs.pattern) {
1061 CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
1062 }
1063 break;
1064 case MatchOr_kind:
1065 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
1066 break;
1067 // No default case, so the compiler will emit a warning if new pattern
1068 // kinds are added without being handled here
1069 }
1070 state->recursion_depth--;
1071 return 1;
1072 }
1073
1074 static int
astfold_match_case(match_case_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1075 astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1076 {
1077 CALL(astfold_pattern, expr_ty, node_->pattern);
1078 CALL_OPT(astfold_expr, expr_ty, node_->guard);
1079 CALL_SEQ(astfold_stmt, stmt, node_->body);
1080 return 1;
1081 }
1082
1083 static int
astfold_type_param(type_param_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1084 astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1085 {
1086 switch (node_->kind) {
1087 case TypeVar_kind:
1088 CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
1089 CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.default_value);
1090 break;
1091 case ParamSpec_kind:
1092 CALL_OPT(astfold_expr, expr_ty, node_->v.ParamSpec.default_value);
1093 break;
1094 case TypeVarTuple_kind:
1095 CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVarTuple.default_value);
1096 break;
1097 }
1098 return 1;
1099 }
1100
1101 #undef CALL
1102 #undef CALL_OPT
1103 #undef CALL_SEQ
1104
1105 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,int optimize,int ff_features)1106 _PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize, int ff_features)
1107 {
1108 PyThreadState *tstate;
1109 int starting_recursion_depth;
1110
1111 _PyASTOptimizeState state;
1112 state.optimize = optimize;
1113 state.ff_features = ff_features;
1114
1115 /* Setup recursion depth check counters */
1116 tstate = _PyThreadState_GET();
1117 if (!tstate) {
1118 return 0;
1119 }
1120 /* Be careful here to prevent overflow. */
1121 int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
1122 starting_recursion_depth = recursion_depth;
1123 state.recursion_depth = starting_recursion_depth;
1124 state.recursion_limit = Py_C_RECURSION_LIMIT;
1125
1126 int ret = astfold_mod(mod, arena, &state);
1127 assert(ret || PyErr_Occurred());
1128
1129 /* Check that the recursion depth counting balanced correctly */
1130 if (ret && state.recursion_depth != starting_recursion_depth) {
1131 PyErr_Format(PyExc_SystemError,
1132 "AST optimizer recursion depth mismatch (before=%d, after=%d)",
1133 starting_recursion_depth, state.recursion_depth);
1134 return 0;
1135 }
1136
1137 return ret;
1138 }
1139