• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* AST Optimizer */
2 #include "Python.h"
3 #include "Python-ast.h"
4 #include "ast.h"
5 
6 
7 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)8 make_const(expr_ty node, PyObject *val, PyArena *arena)
9 {
10     if (val == NULL) {
11         if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
12             return 0;
13         }
14         PyErr_Clear();
15         return 1;
16     }
17     if (PyArena_AddPyObject(arena, val) < 0) {
18         Py_DECREF(val);
19         return 0;
20     }
21     node->kind = Constant_kind;
22     node->v.Constant.value = val;
23     return 1;
24 }
25 
/* Overwrite the expression node TO with a byte-for-byte copy of the whole
   struct _expr pointed to by FROM. */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(*(TO))))

28 static PyObject*
unary_not(PyObject * v)29 unary_not(PyObject *v)
30 {
31     int r = PyObject_IsTrue(v);
32     if (r < 0)
33         return NULL;
34     return PyBool_FromLong(!r);
35 }
36 
37 static int
fold_unaryop(expr_ty node,PyArena * arena,int optimize)38 fold_unaryop(expr_ty node, PyArena *arena, int optimize)
39 {
40     expr_ty arg = node->v.UnaryOp.operand;
41 
42     if (arg->kind != Constant_kind) {
43         /* Fold not into comparison */
44         if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
45                 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
46             /* Eq and NotEq are often implemented in terms of one another, so
47                folding not (self == other) into self != other breaks implementation
48                of !=. Detecting such cases doesn't seem worthwhile.
49                Python uses </> for 'is subset'/'is superset' operations on sets.
50                They don't satisfy not folding laws. */
51             int op = asdl_seq_GET(arg->v.Compare.ops, 0);
52             switch (op) {
53             case Is:
54                 op = IsNot;
55                 break;
56             case IsNot:
57                 op = Is;
58                 break;
59             case In:
60                 op = NotIn;
61                 break;
62             case NotIn:
63                 op = In;
64                 break;
65             default:
66                 op = 0;
67             }
68             if (op) {
69                 asdl_seq_SET(arg->v.Compare.ops, 0, op);
70                 COPY_NODE(node, arg);
71                 return 1;
72             }
73         }
74         return 1;
75     }
76 
77     typedef PyObject *(*unary_op)(PyObject*);
78     static const unary_op ops[] = {
79         [Invert] = PyNumber_Invert,
80         [Not] = unary_not,
81         [UAdd] = PyNumber_Positive,
82         [USub] = PyNumber_Negative,
83     };
84     PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
85     return make_const(node, newval, arena);
86 }
87 
88 /* Check whether a collection doesn't containing too much items (including
89    subcollections).  This protects from creating a constant that needs
90    too much time for calculating a hash.
91    "limit" is the maximal number of items.
92    Returns the negative number if the total number of items exceeds the
93    limit.  Otherwise returns the limit minus the total number of items.
94 */
95 
96 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)97 check_complexity(PyObject *obj, Py_ssize_t limit)
98 {
99     if (PyTuple_Check(obj)) {
100         Py_ssize_t i;
101         limit -= PyTuple_GET_SIZE(obj);
102         for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
103             limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
104         }
105         return limit;
106     }
107     else if (PyFrozenSet_Check(obj)) {
108         Py_ssize_t i = 0;
109         PyObject *item;
110         Py_hash_t hash;
111         limit -= PySet_GET_SIZE(obj);
112         while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
113             limit = check_complexity(item, limit);
114         }
115     }
116     return limit;
117 }
118 
/* Upper bounds on the constants that the safe_* folding helpers below
   agree to produce. */
enum {
    MAX_INT_SIZE        = 128,   /* bits */
    MAX_COLLECTION_SIZE = 256,   /* items */
    MAX_STR_SIZE        = 4096,  /* characters (str) or bytes (bytes) */
    MAX_TOTAL_ITEMS     = 1024,  /* including nested collections */
};

124 static PyObject *
safe_multiply(PyObject * v,PyObject * w)125 safe_multiply(PyObject *v, PyObject *w)
126 {
127     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
128         size_t vbits = _PyLong_NumBits(v);
129         size_t wbits = _PyLong_NumBits(w);
130         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
131             return NULL;
132         }
133         if (vbits + wbits > MAX_INT_SIZE) {
134             return NULL;
135         }
136     }
137     else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
138         Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
139                                              PySet_GET_SIZE(w);
140         if (size) {
141             long n = PyLong_AsLong(v);
142             if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
143                 return NULL;
144             }
145             if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
146                 return NULL;
147             }
148         }
149     }
150     else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
151         Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
152                                                PyBytes_GET_SIZE(w);
153         if (size) {
154             long n = PyLong_AsLong(v);
155             if (n < 0 || n > MAX_STR_SIZE / size) {
156                 return NULL;
157             }
158         }
159     }
160     else if (PyLong_Check(w) &&
161              (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
162               PyUnicode_Check(v) || PyBytes_Check(v)))
163     {
164         return safe_multiply(w, v);
165     }
166 
167     return PyNumber_Multiply(v, w);
168 }
169 
170 static PyObject *
safe_power(PyObject * v,PyObject * w)171 safe_power(PyObject *v, PyObject *w)
172 {
173     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
174         size_t vbits = _PyLong_NumBits(v);
175         size_t wbits = PyLong_AsSize_t(w);
176         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
177             return NULL;
178         }
179         if (vbits > MAX_INT_SIZE / wbits) {
180             return NULL;
181         }
182     }
183 
184     return PyNumber_Power(v, w, Py_None);
185 }
186 
187 static PyObject *
safe_lshift(PyObject * v,PyObject * w)188 safe_lshift(PyObject *v, PyObject *w)
189 {
190     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
191         size_t vbits = _PyLong_NumBits(v);
192         size_t wbits = PyLong_AsSize_t(w);
193         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
194             return NULL;
195         }
196         if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
197             return NULL;
198         }
199     }
200 
201     return PyNumber_Lshift(v, w);
202 }
203 
204 static PyObject *
safe_mod(PyObject * v,PyObject * w)205 safe_mod(PyObject *v, PyObject *w)
206 {
207     if (PyUnicode_Check(v) || PyBytes_Check(v)) {
208         return NULL;
209     }
210 
211     return PyNumber_Remainder(v, w);
212 }
213 
214 static int
fold_binop(expr_ty node,PyArena * arena,int optimize)215 fold_binop(expr_ty node, PyArena *arena, int optimize)
216 {
217     expr_ty lhs, rhs;
218     lhs = node->v.BinOp.left;
219     rhs = node->v.BinOp.right;
220     if (lhs->kind != Constant_kind || rhs->kind != Constant_kind) {
221         return 1;
222     }
223 
224     PyObject *lv = lhs->v.Constant.value;
225     PyObject *rv = rhs->v.Constant.value;
226     PyObject *newval;
227 
228     switch (node->v.BinOp.op) {
229     case Add:
230         newval = PyNumber_Add(lv, rv);
231         break;
232     case Sub:
233         newval = PyNumber_Subtract(lv, rv);
234         break;
235     case Mult:
236         newval = safe_multiply(lv, rv);
237         break;
238     case Div:
239         newval = PyNumber_TrueDivide(lv, rv);
240         break;
241     case FloorDiv:
242         newval = PyNumber_FloorDivide(lv, rv);
243         break;
244     case Mod:
245         newval = safe_mod(lv, rv);
246         break;
247     case Pow:
248         newval = safe_power(lv, rv);
249         break;
250     case LShift:
251         newval = safe_lshift(lv, rv);
252         break;
253     case RShift:
254         newval = PyNumber_Rshift(lv, rv);
255         break;
256     case BitOr:
257         newval = PyNumber_Or(lv, rv);
258         break;
259     case BitXor:
260         newval = PyNumber_Xor(lv, rv);
261         break;
262     case BitAnd:
263         newval = PyNumber_And(lv, rv);
264         break;
265     default: // Unknown operator
266         return 1;
267     }
268 
269     return make_const(node, newval, arena);
270 }
271 
272 static PyObject*
make_const_tuple(asdl_seq * elts)273 make_const_tuple(asdl_seq *elts)
274 {
275     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
276         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
277         if (e->kind != Constant_kind) {
278             return NULL;
279         }
280     }
281 
282     PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
283     if (newval == NULL) {
284         return NULL;
285     }
286 
287     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
288         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
289         PyObject *v = e->v.Constant.value;
290         Py_INCREF(v);
291         PyTuple_SET_ITEM(newval, i, v);
292     }
293     return newval;
294 }
295 
296 static int
fold_tuple(expr_ty node,PyArena * arena,int optimize)297 fold_tuple(expr_ty node, PyArena *arena, int optimize)
298 {
299     PyObject *newval;
300 
301     if (node->v.Tuple.ctx != Load)
302         return 1;
303 
304     newval = make_const_tuple(node->v.Tuple.elts);
305     return make_const(node, newval, arena);
306 }
307 
308 static int
fold_subscr(expr_ty node,PyArena * arena,int optimize)309 fold_subscr(expr_ty node, PyArena *arena, int optimize)
310 {
311     PyObject *newval;
312     expr_ty arg, idx;
313     slice_ty slice;
314 
315     arg = node->v.Subscript.value;
316     slice = node->v.Subscript.slice;
317     if (node->v.Subscript.ctx != Load ||
318             arg->kind != Constant_kind ||
319             /* TODO: handle other types of slices */
320             slice->kind != Index_kind ||
321             slice->v.Index.value->kind != Constant_kind)
322     {
323         return 1;
324     }
325 
326     idx = slice->v.Index.value;
327     newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
328     return make_const(node, newval, arena);
329 }
330 
331 /* Change literal list or set of constants into constant
332    tuple or frozenset respectively.  Change literal list of
333    non-constants into tuple.
334    Used for right operand of "in" and "not in" tests and for iterable
335    in "for" loop and comprehensions.
336 */
337 static int
fold_iter(expr_ty arg,PyArena * arena,int optimize)338 fold_iter(expr_ty arg, PyArena *arena, int optimize)
339 {
340     PyObject *newval;
341     if (arg->kind == List_kind) {
342         /* First change a list into tuple. */
343         asdl_seq *elts = arg->v.List.elts;
344         Py_ssize_t n = asdl_seq_LEN(elts);
345         for (Py_ssize_t i = 0; i < n; i++) {
346             expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
347             if (e->kind == Starred_kind) {
348                 return 1;
349             }
350         }
351         expr_context_ty ctx = arg->v.List.ctx;
352         arg->kind = Tuple_kind;
353         arg->v.Tuple.elts = elts;
354         arg->v.Tuple.ctx = ctx;
355         /* Try to create a constant tuple. */
356         newval = make_const_tuple(elts);
357     }
358     else if (arg->kind == Set_kind) {
359         newval = make_const_tuple(arg->v.Set.elts);
360         if (newval) {
361             Py_SETREF(newval, PyFrozenSet_New(newval));
362         }
363     }
364     else {
365         return 1;
366     }
367     return make_const(arg, newval, arena);
368 }
369 
370 static int
fold_compare(expr_ty node,PyArena * arena,int optimize)371 fold_compare(expr_ty node, PyArena *arena, int optimize)
372 {
373     asdl_int_seq *ops;
374     asdl_seq *args;
375     Py_ssize_t i;
376 
377     ops = node->v.Compare.ops;
378     args = node->v.Compare.comparators;
379     /* TODO: optimize cases with literal arguments. */
380     /* Change literal list or set in 'in' or 'not in' into
381        tuple or frozenset respectively. */
382     i = asdl_seq_LEN(ops) - 1;
383     int op = asdl_seq_GET(ops, i);
384     if (op == In || op == NotIn) {
385         if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, optimize)) {
386             return 0;
387         }
388     }
389     return 1;
390 }
391 
392 static int astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_);
393 static int astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_);
394 static int astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_);
395 static int astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_);
396 static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_);
397 static int astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_);
398 static int astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_);
399 static int astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_);
400 static int astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_);
401 static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_);
402 #define CALL(FUNC, TYPE, ARG) \
403     if (!FUNC((ARG), ctx_, optimize_)) \
404         return 0;
405 
406 #define CALL_OPT(FUNC, TYPE, ARG) \
407     if ((ARG) != NULL && !FUNC((ARG), ctx_, optimize_)) \
408         return 0;
409 
410 #define CALL_SEQ(FUNC, TYPE, ARG) { \
411     int i; \
412     asdl_seq *seq = (ARG); /* avoid variable capture */ \
413     for (i = 0; i < asdl_seq_LEN(seq); i++) { \
414         TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
415         if (elt != NULL && !FUNC(elt, ctx_, optimize_)) \
416             return 0; \
417     } \
418 }
419 
420 #define CALL_INT_SEQ(FUNC, TYPE, ARG) { \
421     int i; \
422     asdl_int_seq *seq = (ARG); /* avoid variable capture */ \
423     for (i = 0; i < asdl_seq_LEN(seq); i++) { \
424         TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
425         if (!FUNC(elt, ctx_, optimize_)) \
426             return 0; \
427     } \
428 }
429 
430 static int
astfold_body(asdl_seq * stmts,PyArena * ctx_,int optimize_)431 astfold_body(asdl_seq *stmts, PyArena *ctx_, int optimize_)
432 {
433     int docstring = _PyAST_GetDocString(stmts) != NULL;
434     CALL_SEQ(astfold_stmt, stmt_ty, stmts);
435     if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
436         stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
437         asdl_seq *values = _Py_asdl_seq_new(1, ctx_);
438         if (!values) {
439             return 0;
440         }
441         asdl_seq_SET(values, 0, st->v.Expr.value);
442         expr_ty expr = JoinedStr(values, st->lineno, st->col_offset,
443                                  st->end_lineno, st->end_col_offset, ctx_);
444         if (!expr) {
445             return 0;
446         }
447         st->v.Expr.value = expr;
448     }
449     return 1;
450 }
451 
452 static int
astfold_mod(mod_ty node_,PyArena * ctx_,int optimize_)453 astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_)
454 {
455     switch (node_->kind) {
456     case Module_kind:
457         CALL(astfold_body, asdl_seq, node_->v.Module.body);
458         break;
459     case Interactive_kind:
460         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Interactive.body);
461         break;
462     case Expression_kind:
463         CALL(astfold_expr, expr_ty, node_->v.Expression.body);
464         break;
465     case Suite_kind:
466         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Suite.body);
467         break;
468     default:
469         break;
470     }
471     return 1;
472 }
473 
474 static int
astfold_expr(expr_ty node_,PyArena * ctx_,int optimize_)475 astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_)
476 {
477     switch (node_->kind) {
478     case BoolOp_kind:
479         CALL_SEQ(astfold_expr, expr_ty, node_->v.BoolOp.values);
480         break;
481     case BinOp_kind:
482         CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
483         CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
484         CALL(fold_binop, expr_ty, node_);
485         break;
486     case UnaryOp_kind:
487         CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
488         CALL(fold_unaryop, expr_ty, node_);
489         break;
490     case Lambda_kind:
491         CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
492         CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
493         break;
494     case IfExp_kind:
495         CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
496         CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
497         CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
498         break;
499     case Dict_kind:
500         CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.keys);
501         CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.values);
502         break;
503     case Set_kind:
504         CALL_SEQ(astfold_expr, expr_ty, node_->v.Set.elts);
505         break;
506     case ListComp_kind:
507         CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
508         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.ListComp.generators);
509         break;
510     case SetComp_kind:
511         CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
512         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.SetComp.generators);
513         break;
514     case DictComp_kind:
515         CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
516         CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
517         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.DictComp.generators);
518         break;
519     case GeneratorExp_kind:
520         CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
521         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.GeneratorExp.generators);
522         break;
523     case Await_kind:
524         CALL(astfold_expr, expr_ty, node_->v.Await.value);
525         break;
526     case Yield_kind:
527         CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
528         break;
529     case YieldFrom_kind:
530         CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
531         break;
532     case Compare_kind:
533         CALL(astfold_expr, expr_ty, node_->v.Compare.left);
534         CALL_SEQ(astfold_expr, expr_ty, node_->v.Compare.comparators);
535         CALL(fold_compare, expr_ty, node_);
536         break;
537     case Call_kind:
538         CALL(astfold_expr, expr_ty, node_->v.Call.func);
539         CALL_SEQ(astfold_expr, expr_ty, node_->v.Call.args);
540         CALL_SEQ(astfold_keyword, keyword_ty, node_->v.Call.keywords);
541         break;
542     case FormattedValue_kind:
543         CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
544         CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
545         break;
546     case JoinedStr_kind:
547         CALL_SEQ(astfold_expr, expr_ty, node_->v.JoinedStr.values);
548         break;
549     case Attribute_kind:
550         CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
551         break;
552     case Subscript_kind:
553         CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
554         CALL(astfold_slice, slice_ty, node_->v.Subscript.slice);
555         CALL(fold_subscr, expr_ty, node_);
556         break;
557     case Starred_kind:
558         CALL(astfold_expr, expr_ty, node_->v.Starred.value);
559         break;
560     case List_kind:
561         CALL_SEQ(astfold_expr, expr_ty, node_->v.List.elts);
562         break;
563     case Tuple_kind:
564         CALL_SEQ(astfold_expr, expr_ty, node_->v.Tuple.elts);
565         CALL(fold_tuple, expr_ty, node_);
566         break;
567     case Name_kind:
568         if (_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
569             return make_const(node_, PyBool_FromLong(!optimize_), ctx_);
570         }
571         break;
572     default:
573         break;
574     }
575     return 1;
576 }
577 
578 static int
astfold_slice(slice_ty node_,PyArena * ctx_,int optimize_)579 astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_)
580 {
581     switch (node_->kind) {
582     case Slice_kind:
583         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
584         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
585         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
586         break;
587     case ExtSlice_kind:
588         CALL_SEQ(astfold_slice, slice_ty, node_->v.ExtSlice.dims);
589         break;
590     case Index_kind:
591         CALL(astfold_expr, expr_ty, node_->v.Index.value);
592         break;
593     default:
594         break;
595     }
596     return 1;
597 }
598 
599 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,int optimize_)600 astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_)
601 {
602     CALL(astfold_expr, expr_ty, node_->value);
603     return 1;
604 }
605 
606 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,int optimize_)607 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_)
608 {
609     CALL(astfold_expr, expr_ty, node_->target);
610     CALL(astfold_expr, expr_ty, node_->iter);
611     CALL_SEQ(astfold_expr, expr_ty, node_->ifs);
612 
613     CALL(fold_iter, expr_ty, node_->iter);
614     return 1;
615 }
616 
617 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,int optimize_)618 astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_)
619 {
620     CALL_SEQ(astfold_arg, arg_ty, node_->posonlyargs);
621     CALL_SEQ(astfold_arg, arg_ty, node_->args);
622     CALL_OPT(astfold_arg, arg_ty, node_->vararg);
623     CALL_SEQ(astfold_arg, arg_ty, node_->kwonlyargs);
624     CALL_SEQ(astfold_expr, expr_ty, node_->kw_defaults);
625     CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
626     CALL_SEQ(astfold_expr, expr_ty, node_->defaults);
627     return 1;
628 }
629 
630 static int
astfold_arg(arg_ty node_,PyArena * ctx_,int optimize_)631 astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_)
632 {
633     CALL_OPT(astfold_expr, expr_ty, node_->annotation);
634     return 1;
635 }
636 
637 static int
astfold_stmt(stmt_ty node_,PyArena * ctx_,int optimize_)638 astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_)
639 {
640     switch (node_->kind) {
641     case FunctionDef_kind:
642         CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
643         CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
644         CALL_SEQ(astfold_expr, expr_ty, node_->v.FunctionDef.decorator_list);
645         CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
646         break;
647     case AsyncFunctionDef_kind:
648         CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
649         CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
650         CALL_SEQ(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.decorator_list);
651         CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
652         break;
653     case ClassDef_kind:
654         CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.bases);
655         CALL_SEQ(astfold_keyword, keyword_ty, node_->v.ClassDef.keywords);
656         CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
657         CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.decorator_list);
658         break;
659     case Return_kind:
660         CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
661         break;
662     case Delete_kind:
663         CALL_SEQ(astfold_expr, expr_ty, node_->v.Delete.targets);
664         break;
665     case Assign_kind:
666         CALL_SEQ(astfold_expr, expr_ty, node_->v.Assign.targets);
667         CALL(astfold_expr, expr_ty, node_->v.Assign.value);
668         break;
669     case AugAssign_kind:
670         CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
671         CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
672         break;
673     case AnnAssign_kind:
674         CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
675         CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
676         CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
677         break;
678     case For_kind:
679         CALL(astfold_expr, expr_ty, node_->v.For.target);
680         CALL(astfold_expr, expr_ty, node_->v.For.iter);
681         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.body);
682         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.orelse);
683 
684         CALL(fold_iter, expr_ty, node_->v.For.iter);
685         break;
686     case AsyncFor_kind:
687         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
688         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
689         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.body);
690         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.orelse);
691         break;
692     case While_kind:
693         CALL(astfold_expr, expr_ty, node_->v.While.test);
694         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.body);
695         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.orelse);
696         break;
697     case If_kind:
698         CALL(astfold_expr, expr_ty, node_->v.If.test);
699         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.body);
700         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.orelse);
701         break;
702     case With_kind:
703         CALL_SEQ(astfold_withitem, withitem_ty, node_->v.With.items);
704         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.With.body);
705         break;
706     case AsyncWith_kind:
707         CALL_SEQ(astfold_withitem, withitem_ty, node_->v.AsyncWith.items);
708         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncWith.body);
709         break;
710     case Raise_kind:
711         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
712         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
713         break;
714     case Try_kind:
715         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.body);
716         CALL_SEQ(astfold_excepthandler, excepthandler_ty, node_->v.Try.handlers);
717         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.orelse);
718         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.finalbody);
719         break;
720     case Assert_kind:
721         CALL(astfold_expr, expr_ty, node_->v.Assert.test);
722         CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
723         break;
724     case Expr_kind:
725         CALL(astfold_expr, expr_ty, node_->v.Expr.value);
726         break;
727     default:
728         break;
729     }
730     return 1;
731 }
732 
733 static int
astfold_excepthandler(excepthandler_ty node_,PyArena * ctx_,int optimize_)734 astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_)
735 {
736     switch (node_->kind) {
737     case ExceptHandler_kind:
738         CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
739         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.ExceptHandler.body);
740         break;
741     default:
742         break;
743     }
744     return 1;
745 }
746 
747 static int
astfold_withitem(withitem_ty node_,PyArena * ctx_,int optimize_)748 astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_)
749 {
750     CALL(astfold_expr, expr_ty, node_->context_expr);
751     CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
752     return 1;
753 }
754 
/* The traversal helper macros are private to the visitors above. */
#undef CALL_INT_SEQ
#undef CALL_SEQ
#undef CALL_OPT
#undef CALL

760 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,int optimize)761 _PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize)
762 {
763     int ret = astfold_mod(mod, arena, optimize);
764     assert(ret || PyErr_Occurred());
765     return ret;
766 }
767