• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* AST Optimizer */
2 #include "Python.h"
3 #include "Python-ast.h"
4 #include "ast.h"
5 
6 
7 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)8 make_const(expr_ty node, PyObject *val, PyArena *arena)
9 {
10     if (val == NULL) {
11         if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
12             return 0;
13         }
14         PyErr_Clear();
15         return 1;
16     }
17     if (PyArena_AddPyObject(arena, val) < 0) {
18         Py_DECREF(val);
19         return 0;
20     }
21     node->kind = Constant_kind;
22     node->v.Constant.value = val;
23     return 1;
24 }
25 
/* Overwrite the expression node *TO with a byte copy of the node *FROM. */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(*(TO))))
27 
28 static PyObject*
unary_not(PyObject * v)29 unary_not(PyObject *v)
30 {
31     int r = PyObject_IsTrue(v);
32     if (r < 0)
33         return NULL;
34     return PyBool_FromLong(!r);
35 }
36 
37 static int
fold_unaryop(expr_ty node,PyArena * arena,int optimize)38 fold_unaryop(expr_ty node, PyArena *arena, int optimize)
39 {
40     expr_ty arg = node->v.UnaryOp.operand;
41 
42     if (arg->kind != Constant_kind) {
43         /* Fold not into comparison */
44         if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
45                 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
46             /* Eq and NotEq are often implemented in terms of one another, so
47                folding not (self == other) into self != other breaks implementation
48                of !=. Detecting such cases doesn't seem worthwhile.
49                Python uses </> for 'is subset'/'is superset' operations on sets.
50                They don't satisfy not folding laws. */
51             int op = asdl_seq_GET(arg->v.Compare.ops, 0);
52             switch (op) {
53             case Is:
54                 op = IsNot;
55                 break;
56             case IsNot:
57                 op = Is;
58                 break;
59             case In:
60                 op = NotIn;
61                 break;
62             case NotIn:
63                 op = In;
64                 break;
65             default:
66                 op = 0;
67             }
68             if (op) {
69                 asdl_seq_SET(arg->v.Compare.ops, 0, op);
70                 COPY_NODE(node, arg);
71                 return 1;
72             }
73         }
74         return 1;
75     }
76 
77     typedef PyObject *(*unary_op)(PyObject*);
78     static const unary_op ops[] = {
79         [Invert] = PyNumber_Invert,
80         [Not] = unary_not,
81         [UAdd] = PyNumber_Positive,
82         [USub] = PyNumber_Negative,
83     };
84     PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
85     return make_const(node, newval, arena);
86 }
87 
88 /* Check whether a collection doesn't containing too much items (including
89    subcollections).  This protects from creating a constant that needs
90    too much time for calculating a hash.
91    "limit" is the maximal number of items.
92    Returns the negative number if the total number of items exceeds the
93    limit.  Otherwise returns the limit minus the total number of items.
94 */
95 
96 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)97 check_complexity(PyObject *obj, Py_ssize_t limit)
98 {
99     if (PyTuple_Check(obj)) {
100         Py_ssize_t i;
101         limit -= PyTuple_GET_SIZE(obj);
102         for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
103             limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
104         }
105         return limit;
106     }
107     else if (PyFrozenSet_Check(obj)) {
108         Py_ssize_t i = 0;
109         PyObject *item;
110         Py_hash_t hash;
111         limit -= PySet_GET_SIZE(obj);
112         while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
113             limit = check_complexity(item, limit);
114         }
115     }
116     return limit;
117 }
118 
/* Upper bounds on the size of constants produced by folding. */
enum {
    MAX_INT_SIZE        = 128,   /* bits */
    MAX_COLLECTION_SIZE = 256,   /* items */
    MAX_STR_SIZE        = 4096,  /* characters */
    MAX_TOTAL_ITEMS     = 1024,  /* including nested collections */
};
123 
124 static PyObject *
safe_multiply(PyObject * v,PyObject * w)125 safe_multiply(PyObject *v, PyObject *w)
126 {
127     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
128         size_t vbits = _PyLong_NumBits(v);
129         size_t wbits = _PyLong_NumBits(w);
130         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
131             return NULL;
132         }
133         if (vbits + wbits > MAX_INT_SIZE) {
134             return NULL;
135         }
136     }
137     else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
138         Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
139                                              PySet_GET_SIZE(w);
140         if (size) {
141             long n = PyLong_AsLong(v);
142             if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
143                 return NULL;
144             }
145             if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
146                 return NULL;
147             }
148         }
149     }
150     else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
151         Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
152                                                PyBytes_GET_SIZE(w);
153         if (size) {
154             long n = PyLong_AsLong(v);
155             if (n < 0 || n > MAX_STR_SIZE / size) {
156                 return NULL;
157             }
158         }
159     }
160     else if (PyLong_Check(w) &&
161              (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
162               PyUnicode_Check(v) || PyBytes_Check(v)))
163     {
164         return safe_multiply(w, v);
165     }
166 
167     return PyNumber_Multiply(v, w);
168 }
169 
170 static PyObject *
safe_power(PyObject * v,PyObject * w)171 safe_power(PyObject *v, PyObject *w)
172 {
173     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
174         size_t vbits = _PyLong_NumBits(v);
175         size_t wbits = PyLong_AsSize_t(w);
176         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
177             return NULL;
178         }
179         if (vbits > MAX_INT_SIZE / wbits) {
180             return NULL;
181         }
182     }
183 
184     return PyNumber_Power(v, w, Py_None);
185 }
186 
187 static PyObject *
safe_lshift(PyObject * v,PyObject * w)188 safe_lshift(PyObject *v, PyObject *w)
189 {
190     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
191         size_t vbits = _PyLong_NumBits(v);
192         size_t wbits = PyLong_AsSize_t(w);
193         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
194             return NULL;
195         }
196         if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
197             return NULL;
198         }
199     }
200 
201     return PyNumber_Lshift(v, w);
202 }
203 
204 static PyObject *
safe_mod(PyObject * v,PyObject * w)205 safe_mod(PyObject *v, PyObject *w)
206 {
207     if (PyUnicode_Check(v) || PyBytes_Check(v)) {
208         return NULL;
209     }
210 
211     return PyNumber_Remainder(v, w);
212 }
213 
214 static int
fold_binop(expr_ty node,PyArena * arena,int optimize)215 fold_binop(expr_ty node, PyArena *arena, int optimize)
216 {
217     expr_ty lhs, rhs;
218     lhs = node->v.BinOp.left;
219     rhs = node->v.BinOp.right;
220     if (lhs->kind != Constant_kind || rhs->kind != Constant_kind) {
221         return 1;
222     }
223 
224     PyObject *lv = lhs->v.Constant.value;
225     PyObject *rv = rhs->v.Constant.value;
226     PyObject *newval;
227 
228     switch (node->v.BinOp.op) {
229     case Add:
230         newval = PyNumber_Add(lv, rv);
231         break;
232     case Sub:
233         newval = PyNumber_Subtract(lv, rv);
234         break;
235     case Mult:
236         newval = safe_multiply(lv, rv);
237         break;
238     case Div:
239         newval = PyNumber_TrueDivide(lv, rv);
240         break;
241     case FloorDiv:
242         newval = PyNumber_FloorDivide(lv, rv);
243         break;
244     case Mod:
245         newval = safe_mod(lv, rv);
246         break;
247     case Pow:
248         newval = safe_power(lv, rv);
249         break;
250     case LShift:
251         newval = safe_lshift(lv, rv);
252         break;
253     case RShift:
254         newval = PyNumber_Rshift(lv, rv);
255         break;
256     case BitOr:
257         newval = PyNumber_Or(lv, rv);
258         break;
259     case BitXor:
260         newval = PyNumber_Xor(lv, rv);
261         break;
262     case BitAnd:
263         newval = PyNumber_And(lv, rv);
264         break;
265     default: // Unknown operator
266         return 1;
267     }
268 
269     return make_const(node, newval, arena);
270 }
271 
272 static PyObject*
make_const_tuple(asdl_seq * elts)273 make_const_tuple(asdl_seq *elts)
274 {
275     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
276         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
277         if (e->kind != Constant_kind) {
278             return NULL;
279         }
280     }
281 
282     PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
283     if (newval == NULL) {
284         return NULL;
285     }
286 
287     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
288         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
289         PyObject *v = e->v.Constant.value;
290         Py_INCREF(v);
291         PyTuple_SET_ITEM(newval, i, v);
292     }
293     return newval;
294 }
295 
296 static int
fold_tuple(expr_ty node,PyArena * arena,int optimize)297 fold_tuple(expr_ty node, PyArena *arena, int optimize)
298 {
299     PyObject *newval;
300 
301     if (node->v.Tuple.ctx != Load)
302         return 1;
303 
304     newval = make_const_tuple(node->v.Tuple.elts);
305     return make_const(node, newval, arena);
306 }
307 
308 static int
fold_subscr(expr_ty node,PyArena * arena,int optimize)309 fold_subscr(expr_ty node, PyArena *arena, int optimize)
310 {
311     PyObject *newval;
312     expr_ty arg, idx;
313     slice_ty slice;
314 
315     arg = node->v.Subscript.value;
316     slice = node->v.Subscript.slice;
317     if (node->v.Subscript.ctx != Load ||
318             arg->kind != Constant_kind ||
319             /* TODO: handle other types of slices */
320             slice->kind != Index_kind ||
321             slice->v.Index.value->kind != Constant_kind)
322     {
323         return 1;
324     }
325 
326     idx = slice->v.Index.value;
327     newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
328     return make_const(node, newval, arena);
329 }
330 
331 /* Change literal list or set of constants into constant
332    tuple or frozenset respectively.  Change literal list of
333    non-constants into tuple.
334    Used for right operand of "in" and "not in" tests and for iterable
335    in "for" loop and comprehensions.
336 */
337 static int
fold_iter(expr_ty arg,PyArena * arena,int optimize)338 fold_iter(expr_ty arg, PyArena *arena, int optimize)
339 {
340     PyObject *newval;
341     if (arg->kind == List_kind) {
342         /* First change a list into tuple. */
343         asdl_seq *elts = arg->v.List.elts;
344         Py_ssize_t n = asdl_seq_LEN(elts);
345         for (Py_ssize_t i = 0; i < n; i++) {
346             expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
347             if (e->kind == Starred_kind) {
348                 return 1;
349             }
350         }
351         expr_context_ty ctx = arg->v.List.ctx;
352         arg->kind = Tuple_kind;
353         arg->v.Tuple.elts = elts;
354         arg->v.Tuple.ctx = ctx;
355         /* Try to create a constant tuple. */
356         newval = make_const_tuple(elts);
357     }
358     else if (arg->kind == Set_kind) {
359         newval = make_const_tuple(arg->v.Set.elts);
360         if (newval) {
361             Py_SETREF(newval, PyFrozenSet_New(newval));
362         }
363     }
364     else {
365         return 1;
366     }
367     return make_const(arg, newval, arena);
368 }
369 
370 static int
fold_compare(expr_ty node,PyArena * arena,int optimize)371 fold_compare(expr_ty node, PyArena *arena, int optimize)
372 {
373     asdl_int_seq *ops;
374     asdl_seq *args;
375     Py_ssize_t i;
376 
377     ops = node->v.Compare.ops;
378     args = node->v.Compare.comparators;
379     /* TODO: optimize cases with literal arguments. */
380     /* Change literal list or set in 'in' or 'not in' into
381        tuple or frozenset respectively. */
382     i = asdl_seq_LEN(ops) - 1;
383     int op = asdl_seq_GET(ops, i);
384     if (op == In || op == NotIn) {
385         if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, optimize)) {
386             return 0;
387         }
388     }
389     return 1;
390 }
391 
392 static int astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_);
393 static int astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_);
394 static int astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_);
395 static int astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_);
396 static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_);
397 static int astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_);
398 static int astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_);
399 static int astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_);
400 static int astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_);
401 static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_);
402 #define CALL(FUNC, TYPE, ARG) \
403     if (!FUNC((ARG), ctx_, optimize_)) \
404         return 0;
405 
406 #define CALL_OPT(FUNC, TYPE, ARG) \
407     if ((ARG) != NULL && !FUNC((ARG), ctx_, optimize_)) \
408         return 0;
409 
410 #define CALL_SEQ(FUNC, TYPE, ARG) { \
411     int i; \
412     asdl_seq *seq = (ARG); /* avoid variable capture */ \
413     for (i = 0; i < asdl_seq_LEN(seq); i++) { \
414         TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
415         if (elt != NULL && !FUNC(elt, ctx_, optimize_)) \
416             return 0; \
417     } \
418 }
419 
420 #define CALL_INT_SEQ(FUNC, TYPE, ARG) { \
421     int i; \
422     asdl_int_seq *seq = (ARG); /* avoid variable capture */ \
423     for (i = 0; i < asdl_seq_LEN(seq); i++) { \
424         TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
425         if (!FUNC(elt, ctx_, optimize_)) \
426             return 0; \
427     } \
428 }
429 
430 static int
astfold_body(asdl_seq * stmts,PyArena * ctx_,int optimize_)431 astfold_body(asdl_seq *stmts, PyArena *ctx_, int optimize_)
432 {
433     int docstring = _PyAST_GetDocString(stmts) != NULL;
434     CALL_SEQ(astfold_stmt, stmt_ty, stmts);
435     if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
436         stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
437         asdl_seq *values = _Py_asdl_seq_new(1, ctx_);
438         if (!values) {
439             return 0;
440         }
441         asdl_seq_SET(values, 0, st->v.Expr.value);
442         expr_ty expr = JoinedStr(values, st->lineno, st->col_offset,
443                                  st->end_lineno, st->end_col_offset, ctx_);
444         if (!expr) {
445             return 0;
446         }
447         st->v.Expr.value = expr;
448     }
449     return 1;
450 }
451 
452 static int
astfold_mod(mod_ty node_,PyArena * ctx_,int optimize_)453 astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_)
454 {
455     switch (node_->kind) {
456     case Module_kind:
457         CALL(astfold_body, asdl_seq, node_->v.Module.body);
458         break;
459     case Interactive_kind:
460         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Interactive.body);
461         break;
462     case Expression_kind:
463         CALL(astfold_expr, expr_ty, node_->v.Expression.body);
464         break;
465     case Suite_kind:
466         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Suite.body);
467         break;
468     default:
469         break;
470     }
471     return 1;
472 }
473 
474 static int
astfold_expr(expr_ty node_,PyArena * ctx_,int optimize_)475 astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_)
476 {
477     switch (node_->kind) {
478     case BoolOp_kind:
479         CALL_SEQ(astfold_expr, expr_ty, node_->v.BoolOp.values);
480         break;
481     case BinOp_kind:
482         CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
483         CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
484         CALL(fold_binop, expr_ty, node_);
485         break;
486     case UnaryOp_kind:
487         CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
488         CALL(fold_unaryop, expr_ty, node_);
489         break;
490     case Lambda_kind:
491         CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
492         CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
493         break;
494     case IfExp_kind:
495         CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
496         CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
497         CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
498         break;
499     case Dict_kind:
500         CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.keys);
501         CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.values);
502         break;
503     case Set_kind:
504         CALL_SEQ(astfold_expr, expr_ty, node_->v.Set.elts);
505         break;
506     case ListComp_kind:
507         CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
508         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.ListComp.generators);
509         break;
510     case SetComp_kind:
511         CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
512         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.SetComp.generators);
513         break;
514     case DictComp_kind:
515         CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
516         CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
517         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.DictComp.generators);
518         break;
519     case GeneratorExp_kind:
520         CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
521         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.GeneratorExp.generators);
522         break;
523     case Await_kind:
524         CALL(astfold_expr, expr_ty, node_->v.Await.value);
525         break;
526     case Yield_kind:
527         CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
528         break;
529     case YieldFrom_kind:
530         CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
531         break;
532     case Compare_kind:
533         CALL(astfold_expr, expr_ty, node_->v.Compare.left);
534         CALL_SEQ(astfold_expr, expr_ty, node_->v.Compare.comparators);
535         CALL(fold_compare, expr_ty, node_);
536         break;
537     case Call_kind:
538         CALL(astfold_expr, expr_ty, node_->v.Call.func);
539         CALL_SEQ(astfold_expr, expr_ty, node_->v.Call.args);
540         CALL_SEQ(astfold_keyword, keyword_ty, node_->v.Call.keywords);
541         break;
542     case FormattedValue_kind:
543         CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
544         CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
545         break;
546     case JoinedStr_kind:
547         CALL_SEQ(astfold_expr, expr_ty, node_->v.JoinedStr.values);
548         break;
549     case Attribute_kind:
550         CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
551         break;
552     case Subscript_kind:
553         CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
554         CALL(astfold_slice, slice_ty, node_->v.Subscript.slice);
555         CALL(fold_subscr, expr_ty, node_);
556         break;
557     case Starred_kind:
558         CALL(astfold_expr, expr_ty, node_->v.Starred.value);
559         break;
560     case List_kind:
561         CALL_SEQ(astfold_expr, expr_ty, node_->v.List.elts);
562         break;
563     case Tuple_kind:
564         CALL_SEQ(astfold_expr, expr_ty, node_->v.Tuple.elts);
565         CALL(fold_tuple, expr_ty, node_);
566         break;
567     case Name_kind:
568         if (_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
569             return make_const(node_, PyBool_FromLong(!optimize_), ctx_);
570         }
571         break;
572     default:
573         break;
574     }
575     return 1;
576 }
577 
578 static int
astfold_slice(slice_ty node_,PyArena * ctx_,int optimize_)579 astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_)
580 {
581     switch (node_->kind) {
582     case Slice_kind:
583         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
584         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
585         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
586         break;
587     case ExtSlice_kind:
588         CALL_SEQ(astfold_slice, slice_ty, node_->v.ExtSlice.dims);
589         break;
590     case Index_kind:
591         CALL(astfold_expr, expr_ty, node_->v.Index.value);
592         break;
593     default:
594         break;
595     }
596     return 1;
597 }
598 
599 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,int optimize_)600 astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_)
601 {
602     CALL(astfold_expr, expr_ty, node_->value);
603     return 1;
604 }
605 
606 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,int optimize_)607 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_)
608 {
609     CALL(astfold_expr, expr_ty, node_->target);
610     CALL(astfold_expr, expr_ty, node_->iter);
611     CALL_SEQ(astfold_expr, expr_ty, node_->ifs);
612 
613     CALL(fold_iter, expr_ty, node_->iter);
614     return 1;
615 }
616 
617 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,int optimize_)618 astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_)
619 {
620     CALL_SEQ(astfold_arg, arg_ty, node_->args);
621     CALL_OPT(astfold_arg, arg_ty, node_->vararg);
622     CALL_SEQ(astfold_arg, arg_ty, node_->kwonlyargs);
623     CALL_SEQ(astfold_expr, expr_ty, node_->kw_defaults);
624     CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
625     CALL_SEQ(astfold_expr, expr_ty, node_->defaults);
626     return 1;
627 }
628 
629 static int
astfold_arg(arg_ty node_,PyArena * ctx_,int optimize_)630 astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_)
631 {
632     CALL_OPT(astfold_expr, expr_ty, node_->annotation);
633     return 1;
634 }
635 
636 static int
astfold_stmt(stmt_ty node_,PyArena * ctx_,int optimize_)637 astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_)
638 {
639     switch (node_->kind) {
640     case FunctionDef_kind:
641         CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
642         CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
643         CALL_SEQ(astfold_expr, expr_ty, node_->v.FunctionDef.decorator_list);
644         CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
645         break;
646     case AsyncFunctionDef_kind:
647         CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
648         CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
649         CALL_SEQ(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.decorator_list);
650         CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
651         break;
652     case ClassDef_kind:
653         CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.bases);
654         CALL_SEQ(astfold_keyword, keyword_ty, node_->v.ClassDef.keywords);
655         CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
656         CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.decorator_list);
657         break;
658     case Return_kind:
659         CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
660         break;
661     case Delete_kind:
662         CALL_SEQ(astfold_expr, expr_ty, node_->v.Delete.targets);
663         break;
664     case Assign_kind:
665         CALL_SEQ(astfold_expr, expr_ty, node_->v.Assign.targets);
666         CALL(astfold_expr, expr_ty, node_->v.Assign.value);
667         break;
668     case AugAssign_kind:
669         CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
670         CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
671         break;
672     case AnnAssign_kind:
673         CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
674         CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
675         CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
676         break;
677     case For_kind:
678         CALL(astfold_expr, expr_ty, node_->v.For.target);
679         CALL(astfold_expr, expr_ty, node_->v.For.iter);
680         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.body);
681         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.orelse);
682 
683         CALL(fold_iter, expr_ty, node_->v.For.iter);
684         break;
685     case AsyncFor_kind:
686         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
687         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
688         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.body);
689         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.orelse);
690         break;
691     case While_kind:
692         CALL(astfold_expr, expr_ty, node_->v.While.test);
693         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.body);
694         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.orelse);
695         break;
696     case If_kind:
697         CALL(astfold_expr, expr_ty, node_->v.If.test);
698         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.body);
699         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.orelse);
700         break;
701     case With_kind:
702         CALL_SEQ(astfold_withitem, withitem_ty, node_->v.With.items);
703         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.With.body);
704         break;
705     case AsyncWith_kind:
706         CALL_SEQ(astfold_withitem, withitem_ty, node_->v.AsyncWith.items);
707         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncWith.body);
708         break;
709     case Raise_kind:
710         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
711         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
712         break;
713     case Try_kind:
714         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.body);
715         CALL_SEQ(astfold_excepthandler, excepthandler_ty, node_->v.Try.handlers);
716         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.orelse);
717         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.finalbody);
718         break;
719     case Assert_kind:
720         CALL(astfold_expr, expr_ty, node_->v.Assert.test);
721         CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
722         break;
723     case Expr_kind:
724         CALL(astfold_expr, expr_ty, node_->v.Expr.value);
725         break;
726     default:
727         break;
728     }
729     return 1;
730 }
731 
732 static int
astfold_excepthandler(excepthandler_ty node_,PyArena * ctx_,int optimize_)733 astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_)
734 {
735     switch (node_->kind) {
736     case ExceptHandler_kind:
737         CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
738         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.ExceptHandler.body);
739         break;
740     default:
741         break;
742     }
743     return 1;
744 }
745 
746 static int
astfold_withitem(withitem_ty node_,PyArena * ctx_,int optimize_)747 astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_)
748 {
749     CALL(astfold_expr, expr_ty, node_->context_expr);
750     CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
751     return 1;
752 }
753 
/* The traversal helper macros are private to the astfold_* functions. */
#undef CALL_INT_SEQ
#undef CALL_SEQ
#undef CALL_OPT
#undef CALL
758 
759 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,int optimize)760 _PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize)
761 {
762     int ret = astfold_mod(mod, arena, optimize);
763     assert(ret || PyErr_Occurred());
764     return ret;
765 }
766