/*
 * This file exposes the _PyAST_Validate() interface, which checks the
 * integrity of a given abstract syntax tree (potentially constructed manually).
 */
5 #include "Python.h"
6 #include "pycore_ast.h" // asdl_stmt_seq
7 #include "pycore_pystate.h" // _PyThreadState_GET()
8
9 #include <assert.h>
10 #include <stdbool.h>
11
/* Recursion bookkeeping threaded through every validate_* helper. */
struct validator {
    int recursion_depth;  /* nesting level currently reached */
    int recursion_limit;  /* level beyond which RecursionError is raised */
};
16
17 static int validate_stmts(struct validator *, asdl_stmt_seq *);
18 static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
19 static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
20 static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
21 static int validate_stmt(struct validator *, stmt_ty);
22 static int validate_expr(struct validator *, expr_ty, expr_context_ty);
23 static int validate_pattern(struct validator *, pattern_ty, int);
24
25 static int
validate_name(PyObject * name)26 validate_name(PyObject *name)
27 {
28 assert(PyUnicode_Check(name));
29 static const char * const forbidden[] = {
30 "None",
31 "True",
32 "False",
33 NULL
34 };
35 for (int i = 0; forbidden[i] != NULL; i++) {
36 if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
37 PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
38 return 0;
39 }
40 }
41 return 1;
42 }
43
44 static int
validate_comprehension(struct validator * state,asdl_comprehension_seq * gens)45 validate_comprehension(struct validator *state, asdl_comprehension_seq *gens)
46 {
47 Py_ssize_t i;
48 if (!asdl_seq_LEN(gens)) {
49 PyErr_SetString(PyExc_ValueError, "comprehension with no generators");
50 return 0;
51 }
52 for (i = 0; i < asdl_seq_LEN(gens); i++) {
53 comprehension_ty comp = asdl_seq_GET(gens, i);
54 if (!validate_expr(state, comp->target, Store) ||
55 !validate_expr(state, comp->iter, Load) ||
56 !validate_exprs(state, comp->ifs, Load, 0))
57 return 0;
58 }
59 return 1;
60 }
61
62 static int
validate_keywords(struct validator * state,asdl_keyword_seq * keywords)63 validate_keywords(struct validator *state, asdl_keyword_seq *keywords)
64 {
65 Py_ssize_t i;
66 for (i = 0; i < asdl_seq_LEN(keywords); i++)
67 if (!validate_expr(state, (asdl_seq_GET(keywords, i))->value, Load))
68 return 0;
69 return 1;
70 }
71
72 static int
validate_args(struct validator * state,asdl_arg_seq * args)73 validate_args(struct validator *state, asdl_arg_seq *args)
74 {
75 Py_ssize_t i;
76 for (i = 0; i < asdl_seq_LEN(args); i++) {
77 arg_ty arg = asdl_seq_GET(args, i);
78 if (arg->annotation && !validate_expr(state, arg->annotation, Load))
79 return 0;
80 }
81 return 1;
82 }
83
84 static const char *
expr_context_name(expr_context_ty ctx)85 expr_context_name(expr_context_ty ctx)
86 {
87 switch (ctx) {
88 case Load:
89 return "Load";
90 case Store:
91 return "Store";
92 case Del:
93 return "Del";
94 // No default case so compiler emits warning for unhandled cases
95 }
96 Py_UNREACHABLE();
97 }
98
99 static int
validate_arguments(struct validator * state,arguments_ty args)100 validate_arguments(struct validator *state, arguments_ty args)
101 {
102 if (!validate_args(state, args->posonlyargs) || !validate_args(state, args->args)) {
103 return 0;
104 }
105 if (args->vararg && args->vararg->annotation
106 && !validate_expr(state, args->vararg->annotation, Load)) {
107 return 0;
108 }
109 if (!validate_args(state, args->kwonlyargs))
110 return 0;
111 if (args->kwarg && args->kwarg->annotation
112 && !validate_expr(state, args->kwarg->annotation, Load)) {
113 return 0;
114 }
115 if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) {
116 PyErr_SetString(PyExc_ValueError, "more positional defaults than args on arguments");
117 return 0;
118 }
119 if (asdl_seq_LEN(args->kw_defaults) != asdl_seq_LEN(args->kwonlyargs)) {
120 PyErr_SetString(PyExc_ValueError, "length of kwonlyargs is not the same as "
121 "kw_defaults on arguments");
122 return 0;
123 }
124 return validate_exprs(state, args->defaults, Load, 0) && validate_exprs(state, args->kw_defaults, Load, 1);
125 }
126
127 static int
validate_constant(struct validator * state,PyObject * value)128 validate_constant(struct validator *state, PyObject *value)
129 {
130 if (value == Py_None || value == Py_Ellipsis)
131 return 1;
132
133 if (PyLong_CheckExact(value)
134 || PyFloat_CheckExact(value)
135 || PyComplex_CheckExact(value)
136 || PyBool_Check(value)
137 || PyUnicode_CheckExact(value)
138 || PyBytes_CheckExact(value))
139 return 1;
140
141 if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) {
142 if (++state->recursion_depth > state->recursion_limit) {
143 PyErr_SetString(PyExc_RecursionError,
144 "maximum recursion depth exceeded during compilation");
145 return 0;
146 }
147
148 PyObject *it = PyObject_GetIter(value);
149 if (it == NULL)
150 return 0;
151
152 while (1) {
153 PyObject *item = PyIter_Next(it);
154 if (item == NULL) {
155 if (PyErr_Occurred()) {
156 Py_DECREF(it);
157 return 0;
158 }
159 break;
160 }
161
162 if (!validate_constant(state, item)) {
163 Py_DECREF(it);
164 Py_DECREF(item);
165 return 0;
166 }
167 Py_DECREF(item);
168 }
169
170 Py_DECREF(it);
171 --state->recursion_depth;
172 return 1;
173 }
174
175 if (!PyErr_Occurred()) {
176 PyErr_Format(PyExc_TypeError,
177 "got an invalid type in Constant: %s",
178 _PyType_Name(Py_TYPE(value)));
179 }
180 return 0;
181 }
182
183 static int
validate_expr(struct validator * state,expr_ty exp,expr_context_ty ctx)184 validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
185 {
186 int ret = -1;
187 if (++state->recursion_depth > state->recursion_limit) {
188 PyErr_SetString(PyExc_RecursionError,
189 "maximum recursion depth exceeded during compilation");
190 return 0;
191 }
192 int check_ctx = 1;
193 expr_context_ty actual_ctx;
194
195 /* First check expression context. */
196 switch (exp->kind) {
197 case Attribute_kind:
198 actual_ctx = exp->v.Attribute.ctx;
199 break;
200 case Subscript_kind:
201 actual_ctx = exp->v.Subscript.ctx;
202 break;
203 case Starred_kind:
204 actual_ctx = exp->v.Starred.ctx;
205 break;
206 case Name_kind:
207 if (!validate_name(exp->v.Name.id)) {
208 return 0;
209 }
210 actual_ctx = exp->v.Name.ctx;
211 break;
212 case List_kind:
213 actual_ctx = exp->v.List.ctx;
214 break;
215 case Tuple_kind:
216 actual_ctx = exp->v.Tuple.ctx;
217 break;
218 default:
219 if (ctx != Load) {
220 PyErr_Format(PyExc_ValueError, "expression which can't be "
221 "assigned to in %s context", expr_context_name(ctx));
222 return 0;
223 }
224 check_ctx = 0;
225 /* set actual_ctx to prevent gcc warning */
226 actual_ctx = 0;
227 }
228 if (check_ctx && actual_ctx != ctx) {
229 PyErr_Format(PyExc_ValueError, "expression must have %s context but has %s instead",
230 expr_context_name(ctx), expr_context_name(actual_ctx));
231 return 0;
232 }
233
234 /* Now validate expression. */
235 switch (exp->kind) {
236 case BoolOp_kind:
237 if (asdl_seq_LEN(exp->v.BoolOp.values) < 2) {
238 PyErr_SetString(PyExc_ValueError, "BoolOp with less than 2 values");
239 return 0;
240 }
241 ret = validate_exprs(state, exp->v.BoolOp.values, Load, 0);
242 break;
243 case BinOp_kind:
244 ret = validate_expr(state, exp->v.BinOp.left, Load) &&
245 validate_expr(state, exp->v.BinOp.right, Load);
246 break;
247 case UnaryOp_kind:
248 ret = validate_expr(state, exp->v.UnaryOp.operand, Load);
249 break;
250 case Lambda_kind:
251 ret = validate_arguments(state, exp->v.Lambda.args) &&
252 validate_expr(state, exp->v.Lambda.body, Load);
253 break;
254 case IfExp_kind:
255 ret = validate_expr(state, exp->v.IfExp.test, Load) &&
256 validate_expr(state, exp->v.IfExp.body, Load) &&
257 validate_expr(state, exp->v.IfExp.orelse, Load);
258 break;
259 case Dict_kind:
260 if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) {
261 PyErr_SetString(PyExc_ValueError,
262 "Dict doesn't have the same number of keys as values");
263 return 0;
264 }
265 /* null_ok=1 for keys expressions to allow dict unpacking to work in
266 dict literals, i.e. ``{**{a:b}}`` */
267 ret = validate_exprs(state, exp->v.Dict.keys, Load, /*null_ok=*/ 1) &&
268 validate_exprs(state, exp->v.Dict.values, Load, /*null_ok=*/ 0);
269 break;
270 case Set_kind:
271 ret = validate_exprs(state, exp->v.Set.elts, Load, 0);
272 break;
273 #define COMP(NAME) \
274 case NAME ## _kind: \
275 ret = validate_comprehension(state, exp->v.NAME.generators) && \
276 validate_expr(state, exp->v.NAME.elt, Load); \
277 break;
278 COMP(ListComp)
279 COMP(SetComp)
280 COMP(GeneratorExp)
281 #undef COMP
282 case DictComp_kind:
283 ret = validate_comprehension(state, exp->v.DictComp.generators) &&
284 validate_expr(state, exp->v.DictComp.key, Load) &&
285 validate_expr(state, exp->v.DictComp.value, Load);
286 break;
287 case Yield_kind:
288 ret = !exp->v.Yield.value || validate_expr(state, exp->v.Yield.value, Load);
289 break;
290 case YieldFrom_kind:
291 ret = validate_expr(state, exp->v.YieldFrom.value, Load);
292 break;
293 case Await_kind:
294 ret = validate_expr(state, exp->v.Await.value, Load);
295 break;
296 case Compare_kind:
297 if (!asdl_seq_LEN(exp->v.Compare.comparators)) {
298 PyErr_SetString(PyExc_ValueError, "Compare with no comparators");
299 return 0;
300 }
301 if (asdl_seq_LEN(exp->v.Compare.comparators) !=
302 asdl_seq_LEN(exp->v.Compare.ops)) {
303 PyErr_SetString(PyExc_ValueError, "Compare has a different number "
304 "of comparators and operands");
305 return 0;
306 }
307 ret = validate_exprs(state, exp->v.Compare.comparators, Load, 0) &&
308 validate_expr(state, exp->v.Compare.left, Load);
309 break;
310 case Call_kind:
311 ret = validate_expr(state, exp->v.Call.func, Load) &&
312 validate_exprs(state, exp->v.Call.args, Load, 0) &&
313 validate_keywords(state, exp->v.Call.keywords);
314 break;
315 case Constant_kind:
316 if (!validate_constant(state, exp->v.Constant.value)) {
317 return 0;
318 }
319 ret = 1;
320 break;
321 case JoinedStr_kind:
322 ret = validate_exprs(state, exp->v.JoinedStr.values, Load, 0);
323 break;
324 case FormattedValue_kind:
325 if (validate_expr(state, exp->v.FormattedValue.value, Load) == 0)
326 return 0;
327 if (exp->v.FormattedValue.format_spec) {
328 ret = validate_expr(state, exp->v.FormattedValue.format_spec, Load);
329 break;
330 }
331 ret = 1;
332 break;
333 case Attribute_kind:
334 ret = validate_expr(state, exp->v.Attribute.value, Load);
335 break;
336 case Subscript_kind:
337 ret = validate_expr(state, exp->v.Subscript.slice, Load) &&
338 validate_expr(state, exp->v.Subscript.value, Load);
339 break;
340 case Starred_kind:
341 ret = validate_expr(state, exp->v.Starred.value, ctx);
342 break;
343 case Slice_kind:
344 ret = (!exp->v.Slice.lower || validate_expr(state, exp->v.Slice.lower, Load)) &&
345 (!exp->v.Slice.upper || validate_expr(state, exp->v.Slice.upper, Load)) &&
346 (!exp->v.Slice.step || validate_expr(state, exp->v.Slice.step, Load));
347 break;
348 case List_kind:
349 ret = validate_exprs(state, exp->v.List.elts, ctx, 0);
350 break;
351 case Tuple_kind:
352 ret = validate_exprs(state, exp->v.Tuple.elts, ctx, 0);
353 break;
354 case NamedExpr_kind:
355 ret = validate_expr(state, exp->v.NamedExpr.value, Load);
356 break;
357 /* This last case doesn't have any checking. */
358 case Name_kind:
359 ret = 1;
360 break;
361 // No default case so compiler emits warning for unhandled cases
362 }
363 if (ret < 0) {
364 PyErr_SetString(PyExc_SystemError, "unexpected expression");
365 ret = 0;
366 }
367 state->recursion_depth--;
368 return ret;
369 }
370
371
372 // Note: the ensure_literal_* functions are only used to validate a restricted
373 // set of non-recursive literals that have already been checked with
374 // validate_expr, so they don't accept the validator state
375 static int
ensure_literal_number(expr_ty exp,bool allow_real,bool allow_imaginary)376 ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary)
377 {
378 assert(exp->kind == Constant_kind);
379 PyObject *value = exp->v.Constant.value;
380 return (allow_real && PyFloat_CheckExact(value)) ||
381 (allow_real && PyLong_CheckExact(value)) ||
382 (allow_imaginary && PyComplex_CheckExact(value));
383 }
384
385 static int
ensure_literal_negative(expr_ty exp,bool allow_real,bool allow_imaginary)386 ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary)
387 {
388 assert(exp->kind == UnaryOp_kind);
389 // Must be negation ...
390 if (exp->v.UnaryOp.op != USub) {
391 return 0;
392 }
393 // ... of a constant ...
394 expr_ty operand = exp->v.UnaryOp.operand;
395 if (operand->kind != Constant_kind) {
396 return 0;
397 }
398 // ... number
399 return ensure_literal_number(operand, allow_real, allow_imaginary);
400 }
401
402 static int
ensure_literal_complex(expr_ty exp)403 ensure_literal_complex(expr_ty exp)
404 {
405 assert(exp->kind == BinOp_kind);
406 expr_ty left = exp->v.BinOp.left;
407 expr_ty right = exp->v.BinOp.right;
408 // Ensure op is addition or subtraction
409 if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) {
410 return 0;
411 }
412 // Check LHS is a real number (potentially signed)
413 switch (left->kind)
414 {
415 case Constant_kind:
416 if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) {
417 return 0;
418 }
419 break;
420 case UnaryOp_kind:
421 if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) {
422 return 0;
423 }
424 break;
425 default:
426 return 0;
427 }
428 // Check RHS is an imaginary number (no separate sign allowed)
429 switch (right->kind)
430 {
431 case Constant_kind:
432 if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) {
433 return 0;
434 }
435 break;
436 default:
437 return 0;
438 }
439 return 1;
440 }
441
442 static int
validate_pattern_match_value(struct validator * state,expr_ty exp)443 validate_pattern_match_value(struct validator *state, expr_ty exp)
444 {
445 if (!validate_expr(state, exp, Load)) {
446 return 0;
447 }
448
449 switch (exp->kind)
450 {
451 case Constant_kind:
452 /* Ellipsis and immutable sequences are not allowed.
453 For True, False and None, MatchSingleton() should
454 be used */
455 if (!validate_expr(state, exp, Load)) {
456 return 0;
457 }
458 PyObject *literal = exp->v.Constant.value;
459 if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
460 PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
461 PyUnicode_CheckExact(literal)) {
462 return 1;
463 }
464 PyErr_SetString(PyExc_ValueError,
465 "unexpected constant inside of a literal pattern");
466 return 0;
467 case Attribute_kind:
468 // Constants and attribute lookups are always permitted
469 return 1;
470 case UnaryOp_kind:
471 // Negated numbers are permitted (whether real or imaginary)
472 // Compiler will complain if AST folding doesn't create a constant
473 if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) {
474 return 1;
475 }
476 break;
477 case BinOp_kind:
478 // Complex literals are permitted
479 // Compiler will complain if AST folding doesn't create a constant
480 if (ensure_literal_complex(exp)) {
481 return 1;
482 }
483 break;
484 case JoinedStr_kind:
485 // Handled in the later stages
486 return 1;
487 default:
488 break;
489 }
490 PyErr_SetString(PyExc_ValueError,
491 "patterns may only match literals and attribute lookups");
492 return 0;
493 }
494
495 static int
validate_capture(PyObject * name)496 validate_capture(PyObject *name)
497 {
498 if (_PyUnicode_EqualToASCIIString(name, "_")) {
499 PyErr_Format(PyExc_ValueError, "can't capture name '_' in patterns");
500 return 0;
501 }
502 return validate_name(name);
503 }
504
505 static int
validate_pattern(struct validator * state,pattern_ty p,int star_ok)506 validate_pattern(struct validator *state, pattern_ty p, int star_ok)
507 {
508 int ret = -1;
509 if (++state->recursion_depth > state->recursion_limit) {
510 PyErr_SetString(PyExc_RecursionError,
511 "maximum recursion depth exceeded during compilation");
512 return 0;
513 }
514 switch (p->kind) {
515 case MatchValue_kind:
516 ret = validate_pattern_match_value(state, p->v.MatchValue.value);
517 break;
518 case MatchSingleton_kind:
519 ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
520 if (!ret) {
521 PyErr_SetString(PyExc_ValueError,
522 "MatchSingleton can only contain True, False and None");
523 }
524 break;
525 case MatchSequence_kind:
526 ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
527 break;
528 case MatchMapping_kind:
529 if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
530 PyErr_SetString(PyExc_ValueError,
531 "MatchMapping doesn't have the same number of keys as patterns");
532 ret = 0;
533 break;
534 }
535
536 if (p->v.MatchMapping.rest && !validate_capture(p->v.MatchMapping.rest)) {
537 ret = 0;
538 break;
539 }
540
541 asdl_expr_seq *keys = p->v.MatchMapping.keys;
542 for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
543 expr_ty key = asdl_seq_GET(keys, i);
544 if (key->kind == Constant_kind) {
545 PyObject *literal = key->v.Constant.value;
546 if (literal == Py_None || PyBool_Check(literal)) {
547 /* validate_pattern_match_value will ensure the key
548 doesn't contain True, False and None but it is
549 syntactically valid, so we will pass those on in
550 a special case. */
551 continue;
552 }
553 }
554 if (!validate_pattern_match_value(state, key)) {
555 ret = 0;
556 break;
557 }
558 }
559
560 ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
561 break;
562 case MatchClass_kind:
563 if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
564 PyErr_SetString(PyExc_ValueError,
565 "MatchClass doesn't have the same number of keyword attributes as patterns");
566 ret = 0;
567 break;
568 }
569 if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
570 ret = 0;
571 break;
572 }
573
574 expr_ty cls = p->v.MatchClass.cls;
575 while (1) {
576 if (cls->kind == Name_kind) {
577 break;
578 }
579 else if (cls->kind == Attribute_kind) {
580 cls = cls->v.Attribute.value;
581 continue;
582 }
583 else {
584 PyErr_SetString(PyExc_ValueError,
585 "MatchClass cls field can only contain Name or Attribute nodes.");
586 ret = 0;
587 break;
588 }
589 }
590
591 for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
592 PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
593 if (!validate_name(identifier)) {
594 ret = 0;
595 break;
596 }
597 }
598
599 if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
600 ret = 0;
601 break;
602 }
603
604 ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
605 break;
606 case MatchStar_kind:
607 if (!star_ok) {
608 PyErr_SetString(PyExc_ValueError, "can't use MatchStar here");
609 ret = 0;
610 break;
611 }
612 ret = p->v.MatchStar.name == NULL || validate_capture(p->v.MatchStar.name);
613 break;
614 case MatchAs_kind:
615 if (p->v.MatchAs.name && !validate_capture(p->v.MatchAs.name)) {
616 ret = 0;
617 break;
618 }
619 if (p->v.MatchAs.pattern == NULL) {
620 ret = 1;
621 }
622 else if (p->v.MatchAs.name == NULL) {
623 PyErr_SetString(PyExc_ValueError,
624 "MatchAs must specify a target name if a pattern is given");
625 ret = 0;
626 }
627 else {
628 ret = validate_pattern(state, p->v.MatchAs.pattern, /*star_ok=*/0);
629 }
630 break;
631 case MatchOr_kind:
632 if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
633 PyErr_SetString(PyExc_ValueError,
634 "MatchOr requires at least 2 patterns");
635 ret = 0;
636 break;
637 }
638 ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
639 break;
640 // No default case, so the compiler will emit a warning if new pattern
641 // kinds are added without being handled here
642 }
643 if (ret < 0) {
644 PyErr_SetString(PyExc_SystemError, "unexpected pattern");
645 ret = 0;
646 }
647 state->recursion_depth--;
648 return ret;
649 }
650
651 static int
_validate_nonempty_seq(asdl_seq * seq,const char * what,const char * owner)652 _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner)
653 {
654 if (asdl_seq_LEN(seq))
655 return 1;
656 PyErr_Format(PyExc_ValueError, "empty %s on %s", what, owner);
657 return 0;
658 }
659 #define validate_nonempty_seq(seq, what, owner) _validate_nonempty_seq((asdl_seq*)seq, what, owner)
660
661 static int
validate_assignlist(struct validator * state,asdl_expr_seq * targets,expr_context_ty ctx)662 validate_assignlist(struct validator *state, asdl_expr_seq *targets, expr_context_ty ctx)
663 {
664 return validate_nonempty_seq(targets, "targets", ctx == Del ? "Delete" : "Assign") &&
665 validate_exprs(state, targets, ctx, 0);
666 }
667
668 static int
validate_body(struct validator * state,asdl_stmt_seq * body,const char * owner)669 validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner)
670 {
671 return validate_nonempty_seq(body, "body", owner) && validate_stmts(state, body);
672 }
673
674 static int
validate_stmt(struct validator * state,stmt_ty stmt)675 validate_stmt(struct validator *state, stmt_ty stmt)
676 {
677 int ret = -1;
678 Py_ssize_t i;
679 if (++state->recursion_depth > state->recursion_limit) {
680 PyErr_SetString(PyExc_RecursionError,
681 "maximum recursion depth exceeded during compilation");
682 return 0;
683 }
684 switch (stmt->kind) {
685 case FunctionDef_kind:
686 ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") &&
687 validate_arguments(state, stmt->v.FunctionDef.args) &&
688 validate_exprs(state, stmt->v.FunctionDef.decorator_list, Load, 0) &&
689 (!stmt->v.FunctionDef.returns ||
690 validate_expr(state, stmt->v.FunctionDef.returns, Load));
691 break;
692 case ClassDef_kind:
693 ret = validate_body(state, stmt->v.ClassDef.body, "ClassDef") &&
694 validate_exprs(state, stmt->v.ClassDef.bases, Load, 0) &&
695 validate_keywords(state, stmt->v.ClassDef.keywords) &&
696 validate_exprs(state, stmt->v.ClassDef.decorator_list, Load, 0);
697 break;
698 case Return_kind:
699 ret = !stmt->v.Return.value || validate_expr(state, stmt->v.Return.value, Load);
700 break;
701 case Delete_kind:
702 ret = validate_assignlist(state, stmt->v.Delete.targets, Del);
703 break;
704 case Assign_kind:
705 ret = validate_assignlist(state, stmt->v.Assign.targets, Store) &&
706 validate_expr(state, stmt->v.Assign.value, Load);
707 break;
708 case AugAssign_kind:
709 ret = validate_expr(state, stmt->v.AugAssign.target, Store) &&
710 validate_expr(state, stmt->v.AugAssign.value, Load);
711 break;
712 case AnnAssign_kind:
713 if (stmt->v.AnnAssign.target->kind != Name_kind &&
714 stmt->v.AnnAssign.simple) {
715 PyErr_SetString(PyExc_TypeError,
716 "AnnAssign with simple non-Name target");
717 return 0;
718 }
719 ret = validate_expr(state, stmt->v.AnnAssign.target, Store) &&
720 (!stmt->v.AnnAssign.value ||
721 validate_expr(state, stmt->v.AnnAssign.value, Load)) &&
722 validate_expr(state, stmt->v.AnnAssign.annotation, Load);
723 break;
724 case For_kind:
725 ret = validate_expr(state, stmt->v.For.target, Store) &&
726 validate_expr(state, stmt->v.For.iter, Load) &&
727 validate_body(state, stmt->v.For.body, "For") &&
728 validate_stmts(state, stmt->v.For.orelse);
729 break;
730 case AsyncFor_kind:
731 ret = validate_expr(state, stmt->v.AsyncFor.target, Store) &&
732 validate_expr(state, stmt->v.AsyncFor.iter, Load) &&
733 validate_body(state, stmt->v.AsyncFor.body, "AsyncFor") &&
734 validate_stmts(state, stmt->v.AsyncFor.orelse);
735 break;
736 case While_kind:
737 ret = validate_expr(state, stmt->v.While.test, Load) &&
738 validate_body(state, stmt->v.While.body, "While") &&
739 validate_stmts(state, stmt->v.While.orelse);
740 break;
741 case If_kind:
742 ret = validate_expr(state, stmt->v.If.test, Load) &&
743 validate_body(state, stmt->v.If.body, "If") &&
744 validate_stmts(state, stmt->v.If.orelse);
745 break;
746 case With_kind:
747 if (!validate_nonempty_seq(stmt->v.With.items, "items", "With"))
748 return 0;
749 for (i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) {
750 withitem_ty item = asdl_seq_GET(stmt->v.With.items, i);
751 if (!validate_expr(state, item->context_expr, Load) ||
752 (item->optional_vars && !validate_expr(state, item->optional_vars, Store)))
753 return 0;
754 }
755 ret = validate_body(state, stmt->v.With.body, "With");
756 break;
757 case AsyncWith_kind:
758 if (!validate_nonempty_seq(stmt->v.AsyncWith.items, "items", "AsyncWith"))
759 return 0;
760 for (i = 0; i < asdl_seq_LEN(stmt->v.AsyncWith.items); i++) {
761 withitem_ty item = asdl_seq_GET(stmt->v.AsyncWith.items, i);
762 if (!validate_expr(state, item->context_expr, Load) ||
763 (item->optional_vars && !validate_expr(state, item->optional_vars, Store)))
764 return 0;
765 }
766 ret = validate_body(state, stmt->v.AsyncWith.body, "AsyncWith");
767 break;
768 case Match_kind:
769 if (!validate_expr(state, stmt->v.Match.subject, Load)
770 || !validate_nonempty_seq(stmt->v.Match.cases, "cases", "Match")) {
771 return 0;
772 }
773 for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) {
774 match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i);
775 if (!validate_pattern(state, m->pattern, /*star_ok=*/0)
776 || (m->guard && !validate_expr(state, m->guard, Load))
777 || !validate_body(state, m->body, "match_case")) {
778 return 0;
779 }
780 }
781 ret = 1;
782 break;
783 case Raise_kind:
784 if (stmt->v.Raise.exc) {
785 ret = validate_expr(state, stmt->v.Raise.exc, Load) &&
786 (!stmt->v.Raise.cause || validate_expr(state, stmt->v.Raise.cause, Load));
787 break;
788 }
789 if (stmt->v.Raise.cause) {
790 PyErr_SetString(PyExc_ValueError, "Raise with cause but no exception");
791 return 0;
792 }
793 ret = 1;
794 break;
795 case Try_kind:
796 if (!validate_body(state, stmt->v.Try.body, "Try"))
797 return 0;
798 if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
799 !asdl_seq_LEN(stmt->v.Try.finalbody)) {
800 PyErr_SetString(PyExc_ValueError, "Try has neither except handlers nor finalbody");
801 return 0;
802 }
803 if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
804 asdl_seq_LEN(stmt->v.Try.orelse)) {
805 PyErr_SetString(PyExc_ValueError, "Try has orelse but no except handlers");
806 return 0;
807 }
808 for (i = 0; i < asdl_seq_LEN(stmt->v.Try.handlers); i++) {
809 excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i);
810 if ((handler->v.ExceptHandler.type &&
811 !validate_expr(state, handler->v.ExceptHandler.type, Load)) ||
812 !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler"))
813 return 0;
814 }
815 ret = (!asdl_seq_LEN(stmt->v.Try.finalbody) ||
816 validate_stmts(state, stmt->v.Try.finalbody)) &&
817 (!asdl_seq_LEN(stmt->v.Try.orelse) ||
818 validate_stmts(state, stmt->v.Try.orelse));
819 break;
820 case Assert_kind:
821 ret = validate_expr(state, stmt->v.Assert.test, Load) &&
822 (!stmt->v.Assert.msg || validate_expr(state, stmt->v.Assert.msg, Load));
823 break;
824 case Import_kind:
825 ret = validate_nonempty_seq(stmt->v.Import.names, "names", "Import");
826 break;
827 case ImportFrom_kind:
828 if (stmt->v.ImportFrom.level < 0) {
829 PyErr_SetString(PyExc_ValueError, "Negative ImportFrom level");
830 return 0;
831 }
832 ret = validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom");
833 break;
834 case Global_kind:
835 ret = validate_nonempty_seq(stmt->v.Global.names, "names", "Global");
836 break;
837 case Nonlocal_kind:
838 ret = validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal");
839 break;
840 case Expr_kind:
841 ret = validate_expr(state, stmt->v.Expr.value, Load);
842 break;
843 case AsyncFunctionDef_kind:
844 ret = validate_body(state, stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") &&
845 validate_arguments(state, stmt->v.AsyncFunctionDef.args) &&
846 validate_exprs(state, stmt->v.AsyncFunctionDef.decorator_list, Load, 0) &&
847 (!stmt->v.AsyncFunctionDef.returns ||
848 validate_expr(state, stmt->v.AsyncFunctionDef.returns, Load));
849 break;
850 case Pass_kind:
851 case Break_kind:
852 case Continue_kind:
853 ret = 1;
854 break;
855 // No default case so compiler emits warning for unhandled cases
856 }
857 if (ret < 0) {
858 PyErr_SetString(PyExc_SystemError, "unexpected statement");
859 ret = 0;
860 }
861 state->recursion_depth--;
862 return ret;
863 }
864
865 static int
validate_stmts(struct validator * state,asdl_stmt_seq * seq)866 validate_stmts(struct validator *state, asdl_stmt_seq *seq)
867 {
868 Py_ssize_t i;
869 for (i = 0; i < asdl_seq_LEN(seq); i++) {
870 stmt_ty stmt = asdl_seq_GET(seq, i);
871 if (stmt) {
872 if (!validate_stmt(state, stmt))
873 return 0;
874 }
875 else {
876 PyErr_SetString(PyExc_ValueError,
877 "None disallowed in statement list");
878 return 0;
879 }
880 }
881 return 1;
882 }
883
884 static int
validate_exprs(struct validator * state,asdl_expr_seq * exprs,expr_context_ty ctx,int null_ok)885 validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok)
886 {
887 Py_ssize_t i;
888 for (i = 0; i < asdl_seq_LEN(exprs); i++) {
889 expr_ty expr = asdl_seq_GET(exprs, i);
890 if (expr) {
891 if (!validate_expr(state, expr, ctx))
892 return 0;
893 }
894 else if (!null_ok) {
895 PyErr_SetString(PyExc_ValueError,
896 "None disallowed in expression list");
897 return 0;
898 }
899
900 }
901 return 1;
902 }
903
904 static int
validate_patterns(struct validator * state,asdl_pattern_seq * patterns,int star_ok)905 validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
906 {
907 Py_ssize_t i;
908 for (i = 0; i < asdl_seq_LEN(patterns); i++) {
909 pattern_ty pattern = asdl_seq_GET(patterns, i);
910 if (!validate_pattern(state, pattern, star_ok)) {
911 return 0;
912 }
913 }
914 return 1;
915 }
916
917
918 /* See comments in symtable.c. */
919 #define COMPILER_STACK_FRAME_SCALE 3
920
921 int
_PyAST_Validate(mod_ty mod)922 _PyAST_Validate(mod_ty mod)
923 {
924 int res = -1;
925 struct validator state;
926 PyThreadState *tstate;
927 int recursion_limit = Py_GetRecursionLimit();
928 int starting_recursion_depth;
929
930 /* Setup recursion depth check counters */
931 tstate = _PyThreadState_GET();
932 if (!tstate) {
933 return 0;
934 }
935 /* Be careful here to prevent overflow. */
936 starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
937 tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth;
938 state.recursion_depth = starting_recursion_depth;
939 state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
940 recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
941
942 switch (mod->kind) {
943 case Module_kind:
944 res = validate_stmts(&state, mod->v.Module.body);
945 break;
946 case Interactive_kind:
947 res = validate_stmts(&state, mod->v.Interactive.body);
948 break;
949 case Expression_kind:
950 res = validate_expr(&state, mod->v.Expression.body, Load);
951 break;
952 case FunctionType_kind:
953 res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) &&
954 validate_expr(&state, mod->v.FunctionType.returns, Load);
955 break;
956 // No default case so compiler emits warning for unhandled cases
957 }
958
959 if (res < 0) {
960 PyErr_SetString(PyExc_SystemError, "impossible module node");
961 return 0;
962 }
963
964 /* Check that the recursion depth counting balanced correctly */
965 if (res && state.recursion_depth != starting_recursion_depth) {
966 PyErr_Format(PyExc_SystemError,
967 "AST validator recursion depth mismatch (before=%d, after=%d)",
968 starting_recursion_depth, state.recursion_depth);
969 return 0;
970 }
971 return res;
972 }
973
974 PyObject *
_PyAST_GetDocString(asdl_stmt_seq * body)975 _PyAST_GetDocString(asdl_stmt_seq *body)
976 {
977 if (!asdl_seq_LEN(body)) {
978 return NULL;
979 }
980 stmt_ty st = asdl_seq_GET(body, 0);
981 if (st->kind != Expr_kind) {
982 return NULL;
983 }
984 expr_ty e = st->v.Expr.value;
985 if (e->kind == Constant_kind && PyUnicode_CheckExact(e->v.Constant.value)) {
986 return e->v.Constant.value;
987 }
988 return NULL;
989 }
990