1 /* AST Optimizer */
2 #include "Python.h"
3 #include "Python-ast.h"
4
5
/* TODO: is_const and get_const_value are copied from Python/compile.c.
   They should be deduplicated in the future.  Maybe we can include this
   file from compile.c?
*/
10 static int
is_const(expr_ty e)11 is_const(expr_ty e)
12 {
13 switch (e->kind) {
14 case Constant_kind:
15 case Num_kind:
16 case Str_kind:
17 case Bytes_kind:
18 case Ellipsis_kind:
19 case NameConstant_kind:
20 return 1;
21 default:
22 return 0;
23 }
24 }
25
26 static PyObject *
get_const_value(expr_ty e)27 get_const_value(expr_ty e)
28 {
29 switch (e->kind) {
30 case Constant_kind:
31 return e->v.Constant.value;
32 case Num_kind:
33 return e->v.Num.n;
34 case Str_kind:
35 return e->v.Str.s;
36 case Bytes_kind:
37 return e->v.Bytes.s;
38 case Ellipsis_kind:
39 return Py_Ellipsis;
40 case NameConstant_kind:
41 return e->v.NameConstant.value;
42 default:
43 Py_UNREACHABLE();
44 }
45 }
46
47 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)48 make_const(expr_ty node, PyObject *val, PyArena *arena)
49 {
50 if (val == NULL) {
51 if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
52 return 0;
53 }
54 PyErr_Clear();
55 return 1;
56 }
57 if (PyArena_AddPyObject(arena, val) < 0) {
58 Py_DECREF(val);
59 return 0;
60 }
61 node->kind = Constant_kind;
62 node->v.Constant.value = val;
63 return 1;
64 }
65
/* Overwrite the expression node TO with a byte copy of the node FROM. */
#define COPY_NODE(TO, FROM) \
    (memcpy((TO), (FROM), sizeof(struct _expr)))
67
68 static PyObject*
unary_not(PyObject * v)69 unary_not(PyObject *v)
70 {
71 int r = PyObject_IsTrue(v);
72 if (r < 0)
73 return NULL;
74 return PyBool_FromLong(!r);
75 }
76
77 static int
fold_unaryop(expr_ty node,PyArena * arena,int optimize)78 fold_unaryop(expr_ty node, PyArena *arena, int optimize)
79 {
80 expr_ty arg = node->v.UnaryOp.operand;
81
82 if (!is_const(arg)) {
83 /* Fold not into comparison */
84 if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
85 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
86 /* Eq and NotEq are often implemented in terms of one another, so
87 folding not (self == other) into self != other breaks implementation
88 of !=. Detecting such cases doesn't seem worthwhile.
89 Python uses </> for 'is subset'/'is superset' operations on sets.
90 They don't satisfy not folding laws. */
91 int op = asdl_seq_GET(arg->v.Compare.ops, 0);
92 switch (op) {
93 case Is:
94 op = IsNot;
95 break;
96 case IsNot:
97 op = Is;
98 break;
99 case In:
100 op = NotIn;
101 break;
102 case NotIn:
103 op = In;
104 break;
105 default:
106 op = 0;
107 }
108 if (op) {
109 asdl_seq_SET(arg->v.Compare.ops, 0, op);
110 COPY_NODE(node, arg);
111 return 1;
112 }
113 }
114 return 1;
115 }
116
117 typedef PyObject *(*unary_op)(PyObject*);
118 static const unary_op ops[] = {
119 [Invert] = PyNumber_Invert,
120 [Not] = unary_not,
121 [UAdd] = PyNumber_Positive,
122 [USub] = PyNumber_Negative,
123 };
124 PyObject *newval = ops[node->v.UnaryOp.op](get_const_value(arg));
125 return make_const(node, newval, arena);
126 }
127
/* Check that a collection does not contain too many items (including
   subcollections).  This protects from creating a constant that needs
   too much time for calculating a hash.
   "limit" is the maximal number of items.
   Returns a negative number if the total number of items exceeds the
   limit.  Otherwise returns the limit minus the total number of items.
*/
135
136 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)137 check_complexity(PyObject *obj, Py_ssize_t limit)
138 {
139 if (PyTuple_Check(obj)) {
140 Py_ssize_t i;
141 limit -= PyTuple_GET_SIZE(obj);
142 for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
143 limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
144 }
145 return limit;
146 }
147 else if (PyFrozenSet_Check(obj)) {
148 Py_ssize_t i = 0;
149 PyObject *item;
150 Py_hash_t hash;
151 limit -= PySet_GET_SIZE(obj);
152 while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
153 limit = check_complexity(item, limit);
154 }
155 }
156 return limit;
157 }
158
/* Size limits consulted by the safe_*() helpers before folding. */
#define MAX_INT_SIZE          128   /* bits */
#define MAX_COLLECTION_SIZE   256   /* items */
#define MAX_STR_SIZE          4096  /* characters */
#define MAX_TOTAL_ITEMS       1024  /* including nested collections */
163
164 static PyObject *
safe_multiply(PyObject * v,PyObject * w)165 safe_multiply(PyObject *v, PyObject *w)
166 {
167 if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
168 size_t vbits = _PyLong_NumBits(v);
169 size_t wbits = _PyLong_NumBits(w);
170 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
171 return NULL;
172 }
173 if (vbits + wbits > MAX_INT_SIZE) {
174 return NULL;
175 }
176 }
177 else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
178 Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
179 PySet_GET_SIZE(w);
180 if (size) {
181 long n = PyLong_AsLong(v);
182 if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
183 return NULL;
184 }
185 if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
186 return NULL;
187 }
188 }
189 }
190 else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
191 Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
192 PyBytes_GET_SIZE(w);
193 if (size) {
194 long n = PyLong_AsLong(v);
195 if (n < 0 || n > MAX_STR_SIZE / size) {
196 return NULL;
197 }
198 }
199 }
200 else if (PyLong_Check(w) &&
201 (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
202 PyUnicode_Check(v) || PyBytes_Check(v)))
203 {
204 return safe_multiply(w, v);
205 }
206
207 return PyNumber_Multiply(v, w);
208 }
209
210 static PyObject *
safe_power(PyObject * v,PyObject * w)211 safe_power(PyObject *v, PyObject *w)
212 {
213 if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
214 size_t vbits = _PyLong_NumBits(v);
215 size_t wbits = PyLong_AsSize_t(w);
216 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
217 return NULL;
218 }
219 if (vbits > MAX_INT_SIZE / wbits) {
220 return NULL;
221 }
222 }
223
224 return PyNumber_Power(v, w, Py_None);
225 }
226
227 static PyObject *
safe_lshift(PyObject * v,PyObject * w)228 safe_lshift(PyObject *v, PyObject *w)
229 {
230 if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
231 size_t vbits = _PyLong_NumBits(v);
232 size_t wbits = PyLong_AsSize_t(w);
233 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
234 return NULL;
235 }
236 if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
237 return NULL;
238 }
239 }
240
241 return PyNumber_Lshift(v, w);
242 }
243
244 static PyObject *
safe_mod(PyObject * v,PyObject * w)245 safe_mod(PyObject *v, PyObject *w)
246 {
247 if (PyUnicode_Check(v) || PyBytes_Check(v)) {
248 return NULL;
249 }
250
251 return PyNumber_Remainder(v, w);
252 }
253
254 static int
fold_binop(expr_ty node,PyArena * arena,int optimize)255 fold_binop(expr_ty node, PyArena *arena, int optimize)
256 {
257 expr_ty lhs, rhs;
258 lhs = node->v.BinOp.left;
259 rhs = node->v.BinOp.right;
260 if (!is_const(lhs) || !is_const(rhs)) {
261 return 1;
262 }
263
264 PyObject *lv = get_const_value(lhs);
265 PyObject *rv = get_const_value(rhs);
266 PyObject *newval;
267
268 switch (node->v.BinOp.op) {
269 case Add:
270 newval = PyNumber_Add(lv, rv);
271 break;
272 case Sub:
273 newval = PyNumber_Subtract(lv, rv);
274 break;
275 case Mult:
276 newval = safe_multiply(lv, rv);
277 break;
278 case Div:
279 newval = PyNumber_TrueDivide(lv, rv);
280 break;
281 case FloorDiv:
282 newval = PyNumber_FloorDivide(lv, rv);
283 break;
284 case Mod:
285 newval = safe_mod(lv, rv);
286 break;
287 case Pow:
288 newval = safe_power(lv, rv);
289 break;
290 case LShift:
291 newval = safe_lshift(lv, rv);
292 break;
293 case RShift:
294 newval = PyNumber_Rshift(lv, rv);
295 break;
296 case BitOr:
297 newval = PyNumber_Or(lv, rv);
298 break;
299 case BitXor:
300 newval = PyNumber_Xor(lv, rv);
301 break;
302 case BitAnd:
303 newval = PyNumber_And(lv, rv);
304 break;
305 default: // Unknown operator
306 return 1;
307 }
308
309 return make_const(node, newval, arena);
310 }
311
312 static PyObject*
make_const_tuple(asdl_seq * elts)313 make_const_tuple(asdl_seq *elts)
314 {
315 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
316 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
317 if (!is_const(e)) {
318 return NULL;
319 }
320 }
321
322 PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
323 if (newval == NULL) {
324 return NULL;
325 }
326
327 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
328 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
329 PyObject *v = get_const_value(e);
330 Py_INCREF(v);
331 PyTuple_SET_ITEM(newval, i, v);
332 }
333 return newval;
334 }
335
336 static int
fold_tuple(expr_ty node,PyArena * arena,int optimize)337 fold_tuple(expr_ty node, PyArena *arena, int optimize)
338 {
339 PyObject *newval;
340
341 if (node->v.Tuple.ctx != Load)
342 return 1;
343
344 newval = make_const_tuple(node->v.Tuple.elts);
345 return make_const(node, newval, arena);
346 }
347
348 static int
fold_subscr(expr_ty node,PyArena * arena,int optimize)349 fold_subscr(expr_ty node, PyArena *arena, int optimize)
350 {
351 PyObject *newval;
352 expr_ty arg, idx;
353 slice_ty slice;
354
355 arg = node->v.Subscript.value;
356 slice = node->v.Subscript.slice;
357 if (node->v.Subscript.ctx != Load ||
358 !is_const(arg) ||
359 /* TODO: handle other types of slices */
360 slice->kind != Index_kind ||
361 !is_const(slice->v.Index.value))
362 {
363 return 1;
364 }
365
366 idx = slice->v.Index.value;
367 newval = PyObject_GetItem(get_const_value(arg), get_const_value(idx));
368 return make_const(node, newval, arena);
369 }
370
371 /* Change literal list or set of constants into constant
372 tuple or frozenset respectively.
373 Used for right operand of "in" and "not in" tests and for iterable
374 in "for" loop and comprehensions.
375 */
376 static int
fold_iter(expr_ty arg,PyArena * arena,int optimize)377 fold_iter(expr_ty arg, PyArena *arena, int optimize)
378 {
379 PyObject *newval;
380 if (arg->kind == List_kind) {
381 newval = make_const_tuple(arg->v.List.elts);
382 }
383 else if (arg->kind == Set_kind) {
384 newval = make_const_tuple(arg->v.Set.elts);
385 if (newval) {
386 Py_SETREF(newval, PyFrozenSet_New(newval));
387 }
388 }
389 else {
390 return 1;
391 }
392 return make_const(arg, newval, arena);
393 }
394
395 static int
fold_compare(expr_ty node,PyArena * arena,int optimize)396 fold_compare(expr_ty node, PyArena *arena, int optimize)
397 {
398 asdl_int_seq *ops;
399 asdl_seq *args;
400 Py_ssize_t i;
401
402 ops = node->v.Compare.ops;
403 args = node->v.Compare.comparators;
404 /* TODO: optimize cases with literal arguments. */
405 /* Change literal list or set in 'in' or 'not in' into
406 tuple or frozenset respectively. */
407 i = asdl_seq_LEN(ops) - 1;
408 int op = asdl_seq_GET(ops, i);
409 if (op == In || op == NotIn) {
410 if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, optimize)) {
411 return 0;
412 }
413 }
414 return 1;
415 }
416
417 static int astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_);
418 static int astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_);
419 static int astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_);
420 static int astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_);
421 static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_);
422 static int astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_);
423 static int astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_);
424 static int astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_);
425 static int astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_);
426 static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_);
/* Helper macros for the astfold_*() walkers.

   They expect the enclosing function to have `ctx_` and `optimize_` in
   scope and make that function `return 0` as soon as a callee returns 0.
   Each macro expands to a single statement (do { ... } while (0)), so it
   must be followed by a semicolon and is safe inside an unbraced
   if/else.

   CALL         - call FUNC on ARG.
   CALL_OPT     - call FUNC on ARG unless ARG is NULL.
   CALL_SEQ     - call FUNC on every non-NULL element of the asdl_seq ARG.
   CALL_INT_SEQ - call FUNC on every element of the asdl_int_seq ARG.

   TYPE is the element type used by the *_SEQ variants; CALL and CALL_OPT
   accept it only for symmetry. */
#define CALL(FUNC, TYPE, ARG) \
    do { \
        if (!FUNC((ARG), ctx_, optimize_)) { \
            return 0; \
        } \
    } while (0)

#define CALL_OPT(FUNC, TYPE, ARG) \
    do { \
        if ((ARG) != NULL && !FUNC((ARG), ctx_, optimize_)) { \
            return 0; \
        } \
    } while (0)

#define CALL_SEQ(FUNC, TYPE, ARG) \
    do { \
        Py_ssize_t i; \
        asdl_seq *seq = (ARG); /* avoid variable capture */ \
        for (i = 0; i < asdl_seq_LEN(seq); i++) { \
            TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
            if (elt != NULL && !FUNC(elt, ctx_, optimize_)) { \
                return 0; \
            } \
        } \
    } while (0)

#define CALL_INT_SEQ(FUNC, TYPE, ARG) \
    do { \
        Py_ssize_t i; \
        asdl_int_seq *seq = (ARG); /* avoid variable capture */ \
        for (i = 0; i < asdl_seq_LEN(seq); i++) { \
            TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
            if (!FUNC(elt, ctx_, optimize_)) { \
                return 0; \
            } \
        } \
    } while (0)
454
455 static int
isdocstring(stmt_ty s)456 isdocstring(stmt_ty s)
457 {
458 if (s->kind != Expr_kind)
459 return 0;
460 if (s->v.Expr.value->kind == Str_kind)
461 return 1;
462 if (s->v.Expr.value->kind == Constant_kind)
463 return PyUnicode_CheckExact(s->v.Expr.value->v.Constant.value);
464 return 0;
465 }
466
467 static int
astfold_body(asdl_seq * stmts,PyArena * ctx_,int optimize_)468 astfold_body(asdl_seq *stmts, PyArena *ctx_, int optimize_)
469 {
470 if (!asdl_seq_LEN(stmts)) {
471 return 1;
472 }
473 int docstring = isdocstring((stmt_ty)asdl_seq_GET(stmts, 0));
474 CALL_SEQ(astfold_stmt, stmt_ty, stmts);
475 if (docstring) {
476 return 1;
477 }
478 stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
479 if (isdocstring(st)) {
480 asdl_seq *values = _Py_asdl_seq_new(1, ctx_);
481 if (!values) {
482 return 0;
483 }
484 asdl_seq_SET(values, 0, st->v.Expr.value);
485 expr_ty expr = _Py_JoinedStr(values, st->lineno, st->col_offset, ctx_);
486 if (!expr) {
487 return 0;
488 }
489 st->v.Expr.value = expr;
490 }
491 return 1;
492 }
493
494 static int
astfold_mod(mod_ty node_,PyArena * ctx_,int optimize_)495 astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_)
496 {
497 switch (node_->kind) {
498 case Module_kind:
499 CALL(astfold_body, asdl_seq, node_->v.Module.body);
500 break;
501 case Interactive_kind:
502 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Interactive.body);
503 break;
504 case Expression_kind:
505 CALL(astfold_expr, expr_ty, node_->v.Expression.body);
506 break;
507 case Suite_kind:
508 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Suite.body);
509 break;
510 default:
511 break;
512 }
513 return 1;
514 }
515
516 static int
astfold_expr(expr_ty node_,PyArena * ctx_,int optimize_)517 astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_)
518 {
519 switch (node_->kind) {
520 case BoolOp_kind:
521 CALL_SEQ(astfold_expr, expr_ty, node_->v.BoolOp.values);
522 break;
523 case BinOp_kind:
524 CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
525 CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
526 CALL(fold_binop, expr_ty, node_);
527 break;
528 case UnaryOp_kind:
529 CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
530 CALL(fold_unaryop, expr_ty, node_);
531 break;
532 case Lambda_kind:
533 CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
534 CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
535 break;
536 case IfExp_kind:
537 CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
538 CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
539 CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
540 break;
541 case Dict_kind:
542 CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.keys);
543 CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.values);
544 break;
545 case Set_kind:
546 CALL_SEQ(astfold_expr, expr_ty, node_->v.Set.elts);
547 break;
548 case ListComp_kind:
549 CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
550 CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.ListComp.generators);
551 break;
552 case SetComp_kind:
553 CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
554 CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.SetComp.generators);
555 break;
556 case DictComp_kind:
557 CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
558 CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
559 CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.DictComp.generators);
560 break;
561 case GeneratorExp_kind:
562 CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
563 CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.GeneratorExp.generators);
564 break;
565 case Await_kind:
566 CALL(astfold_expr, expr_ty, node_->v.Await.value);
567 break;
568 case Yield_kind:
569 CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
570 break;
571 case YieldFrom_kind:
572 CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
573 break;
574 case Compare_kind:
575 CALL(astfold_expr, expr_ty, node_->v.Compare.left);
576 CALL_SEQ(astfold_expr, expr_ty, node_->v.Compare.comparators);
577 CALL(fold_compare, expr_ty, node_);
578 break;
579 case Call_kind:
580 CALL(astfold_expr, expr_ty, node_->v.Call.func);
581 CALL_SEQ(astfold_expr, expr_ty, node_->v.Call.args);
582 CALL_SEQ(astfold_keyword, keyword_ty, node_->v.Call.keywords);
583 break;
584 case FormattedValue_kind:
585 CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
586 CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
587 break;
588 case JoinedStr_kind:
589 CALL_SEQ(astfold_expr, expr_ty, node_->v.JoinedStr.values);
590 break;
591 case Attribute_kind:
592 CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
593 break;
594 case Subscript_kind:
595 CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
596 CALL(astfold_slice, slice_ty, node_->v.Subscript.slice);
597 CALL(fold_subscr, expr_ty, node_);
598 break;
599 case Starred_kind:
600 CALL(astfold_expr, expr_ty, node_->v.Starred.value);
601 break;
602 case List_kind:
603 CALL_SEQ(astfold_expr, expr_ty, node_->v.List.elts);
604 break;
605 case Tuple_kind:
606 CALL_SEQ(astfold_expr, expr_ty, node_->v.Tuple.elts);
607 CALL(fold_tuple, expr_ty, node_);
608 break;
609 case Name_kind:
610 if (_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
611 return make_const(node_, PyBool_FromLong(!optimize_), ctx_);
612 }
613 break;
614 default:
615 break;
616 }
617 return 1;
618 }
619
620 static int
astfold_slice(slice_ty node_,PyArena * ctx_,int optimize_)621 astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_)
622 {
623 switch (node_->kind) {
624 case Slice_kind:
625 CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
626 CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
627 CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
628 break;
629 case ExtSlice_kind:
630 CALL_SEQ(astfold_slice, slice_ty, node_->v.ExtSlice.dims);
631 break;
632 case Index_kind:
633 CALL(astfold_expr, expr_ty, node_->v.Index.value);
634 break;
635 default:
636 break;
637 }
638 return 1;
639 }
640
641 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,int optimize_)642 astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_)
643 {
644 CALL(astfold_expr, expr_ty, node_->value);
645 return 1;
646 }
647
648 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,int optimize_)649 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_)
650 {
651 CALL(astfold_expr, expr_ty, node_->target);
652 CALL(astfold_expr, expr_ty, node_->iter);
653 CALL_SEQ(astfold_expr, expr_ty, node_->ifs);
654
655 CALL(fold_iter, expr_ty, node_->iter);
656 return 1;
657 }
658
659 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,int optimize_)660 astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_)
661 {
662 CALL_SEQ(astfold_arg, arg_ty, node_->args);
663 CALL_OPT(astfold_arg, arg_ty, node_->vararg);
664 CALL_SEQ(astfold_arg, arg_ty, node_->kwonlyargs);
665 CALL_SEQ(astfold_expr, expr_ty, node_->kw_defaults);
666 CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
667 CALL_SEQ(astfold_expr, expr_ty, node_->defaults);
668 return 1;
669 }
670
671 static int
astfold_arg(arg_ty node_,PyArena * ctx_,int optimize_)672 astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_)
673 {
674 CALL_OPT(astfold_expr, expr_ty, node_->annotation);
675 return 1;
676 }
677
678 static int
astfold_stmt(stmt_ty node_,PyArena * ctx_,int optimize_)679 astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_)
680 {
681 switch (node_->kind) {
682 case FunctionDef_kind:
683 CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
684 CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
685 CALL_SEQ(astfold_expr, expr_ty, node_->v.FunctionDef.decorator_list);
686 CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
687 break;
688 case AsyncFunctionDef_kind:
689 CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
690 CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
691 CALL_SEQ(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.decorator_list);
692 CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
693 break;
694 case ClassDef_kind:
695 CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.bases);
696 CALL_SEQ(astfold_keyword, keyword_ty, node_->v.ClassDef.keywords);
697 CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
698 CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.decorator_list);
699 break;
700 case Return_kind:
701 CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
702 break;
703 case Delete_kind:
704 CALL_SEQ(astfold_expr, expr_ty, node_->v.Delete.targets);
705 break;
706 case Assign_kind:
707 CALL_SEQ(astfold_expr, expr_ty, node_->v.Assign.targets);
708 CALL(astfold_expr, expr_ty, node_->v.Assign.value);
709 break;
710 case AugAssign_kind:
711 CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
712 CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
713 break;
714 case AnnAssign_kind:
715 CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
716 CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
717 CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
718 break;
719 case For_kind:
720 CALL(astfold_expr, expr_ty, node_->v.For.target);
721 CALL(astfold_expr, expr_ty, node_->v.For.iter);
722 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.body);
723 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.orelse);
724
725 CALL(fold_iter, expr_ty, node_->v.For.iter);
726 break;
727 case AsyncFor_kind:
728 CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
729 CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
730 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.body);
731 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.orelse);
732 break;
733 case While_kind:
734 CALL(astfold_expr, expr_ty, node_->v.While.test);
735 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.body);
736 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.orelse);
737 break;
738 case If_kind:
739 CALL(astfold_expr, expr_ty, node_->v.If.test);
740 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.body);
741 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.orelse);
742 break;
743 case With_kind:
744 CALL_SEQ(astfold_withitem, withitem_ty, node_->v.With.items);
745 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.With.body);
746 break;
747 case AsyncWith_kind:
748 CALL_SEQ(astfold_withitem, withitem_ty, node_->v.AsyncWith.items);
749 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncWith.body);
750 break;
751 case Raise_kind:
752 CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
753 CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
754 break;
755 case Try_kind:
756 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.body);
757 CALL_SEQ(astfold_excepthandler, excepthandler_ty, node_->v.Try.handlers);
758 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.orelse);
759 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.finalbody);
760 break;
761 case Assert_kind:
762 CALL(astfold_expr, expr_ty, node_->v.Assert.test);
763 CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
764 break;
765 case Expr_kind:
766 CALL(astfold_expr, expr_ty, node_->v.Expr.value);
767 break;
768 default:
769 break;
770 }
771 return 1;
772 }
773
774 static int
astfold_excepthandler(excepthandler_ty node_,PyArena * ctx_,int optimize_)775 astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_)
776 {
777 switch (node_->kind) {
778 case ExceptHandler_kind:
779 CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
780 CALL_SEQ(astfold_stmt, stmt_ty, node_->v.ExceptHandler.body);
781 break;
782 default:
783 break;
784 }
785 return 1;
786 }
787
788 static int
astfold_withitem(withitem_ty node_,PyArena * ctx_,int optimize_)789 astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_)
790 {
791 CALL(astfold_expr, expr_ty, node_->context_expr);
792 CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
793 return 1;
794 }
795
796 #undef CALL
797 #undef CALL_OPT
798 #undef CALL_SEQ
799 #undef CALL_INT_SEQ
800
801 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,int optimize)802 _PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize)
803 {
804 int ret = astfold_mod(mod, arena, optimize);
805 assert(ret || PyErr_Occurred());
806 return ret;
807 }
808