1#! /usr/bin/env python 2"""Generate C code from an ASDL description.""" 3 4import sys 5import textwrap 6import types 7 8from argparse import ArgumentParser 9from contextlib import contextmanager 10from pathlib import Path 11 12import asdl 13 14TABSIZE = 4 15MAX_COL = 80 16AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n" 17 18builtin_type_to_c_type = { 19 "identifier": "PyUnicode_Type", 20 "string": "PyUnicode_Type", 21 "int": "PyLong_Type", 22 "constant": "PyBaseObject_Type", 23} 24 25def get_c_type(name): 26 """Return a string for the C name of the type. 27 28 This function special cases the default types provided by asdl. 29 """ 30 if name in asdl.builtin_types: 31 return name 32 else: 33 return "%s_ty" % name 34 35def reflow_lines(s, depth): 36 """Reflow the line s indented depth tabs. 37 38 Return a sequence of lines where no line extends beyond MAX_COL 39 when properly indented. The first line is properly indented based 40 exclusively on depth * TABSIZE. All following lines -- these are 41 the reflowed lines generated by this function -- start at the same 42 column as the first character beyond the opening { in the first 43 line. 44 """ 45 size = MAX_COL - depth * TABSIZE 46 if len(s) < size: 47 return [s] 48 49 lines = [] 50 cur = s 51 padding = "" 52 while len(cur) > size: 53 i = cur.rfind(' ', 0, size) 54 # XXX this should be fixed for real 55 if i == -1 and 'GeneratorExp' in cur: 56 i = size + 3 57 assert i != -1, "Impossible line %d to reflow: %r" % (size, s) 58 lines.append(padding + cur[:i]) 59 if len(lines) == 1: 60 # find new size based on brace 61 j = cur.find('{', 0, i) 62 if j >= 0: 63 j += 2 # account for the brace and the space after it 64 size -= j 65 padding = " " * j 66 else: 67 j = cur.find('(', 0, i) 68 if j >= 0: 69 j += 1 # account for the paren (no space after it) 70 size -= j 71 padding = " " * j 72 cur = cur[i+1:] 73 else: 74 lines.append(padding + cur) 75 return lines 76 77def reflow_c_string(s, depth): 78 return '"%s"' % s.replace('\n', '\\n"\n%s"' % (' ' * depth * TABSIZE)) 79 80def is_simple(sum_type): 81 """Return True if a sum is a simple. 82 83 A sum is simple if its types have no fields and itself 84 doesn't have any attributes. Instances of these types are 85 cached at C level, and they act like singletons when propagating 86 parser generated nodes into Python level, e.g. 87 unaryop = Invert | Not | UAdd | USub 88 """ 89 90 return not ( 91 sum_type.attributes or 92 any(constructor.fields for constructor in sum_type.types) 93 ) 94 95def asdl_of(name, obj): 96 if isinstance(obj, asdl.Product) or isinstance(obj, asdl.Constructor): 97 fields = ", ".join(map(str, obj.fields)) 98 if fields: 99 fields = "({})".format(fields) 100 return "{}{}".format(name, fields) 101 else: 102 if is_simple(obj): 103 types = " | ".join(type.name for type in obj.types) 104 else: 105 sep = "\n{}| ".format(" " * (len(name) + 1)) 106 types = sep.join( 107 asdl_of(type.name, type) for type in obj.types 108 ) 109 return "{} = {}".format(name, types) 110 111class EmitVisitor(asdl.VisitorBase): 112 """Visit that emits lines""" 113 114 def __init__(self, file, metadata = None): 115 self.file = file 116 self._metadata = metadata 117 super(EmitVisitor, self).__init__() 118 119 def emit(self, s, depth, reflow=True): 120 # XXX reflow long lines? 121 if reflow: 122 lines = reflow_lines(s, depth) 123 else: 124 lines = [s] 125 for line in lines: 126 if line: 127 line = (" " * TABSIZE * depth) + line 128 self.file.write(line + "\n") 129 130 @property 131 def metadata(self): 132 if self._metadata is None: 133 raise ValueError( 134 "%s was expecting to be annnotated with metadata" 135 % type(self).__name__ 136 ) 137 return self._metadata 138 139 @metadata.setter 140 def metadata(self, value): 141 self._metadata = value 142 143class MetadataVisitor(asdl.VisitorBase): 144 ROOT_TYPE = "AST" 145 146 def __init__(self, *args, **kwargs): 147 super().__init__(*args, **kwargs) 148 149 # Metadata: 150 # - simple_sums: Tracks the list of compound type 151 # names where all the constructors 152 # belonging to that type lack of any 153 # fields. 154 # - identifiers: All identifiers used in the AST declarations 155 # - singletons: List of all constructors that originates from 156 # simple sums. 157 # - types: List of all top level type names 158 # 159 self.metadata = types.SimpleNamespace( 160 simple_sums=set(), 161 identifiers=set(), 162 singletons=set(), 163 types={self.ROOT_TYPE}, 164 ) 165 166 def visitModule(self, mod): 167 for dfn in mod.dfns: 168 self.visit(dfn) 169 170 def visitType(self, type): 171 self.visit(type.value, type.name) 172 173 def visitSum(self, sum, name): 174 self.metadata.types.add(name) 175 176 simple_sum = is_simple(sum) 177 if simple_sum: 178 self.metadata.simple_sums.add(name) 179 180 for constructor in sum.types: 181 if simple_sum: 182 self.metadata.singletons.add(constructor.name) 183 self.visitConstructor(constructor) 184 self.visitFields(sum.attributes) 185 186 def visitConstructor(self, constructor): 187 self.metadata.types.add(constructor.name) 188 self.visitFields(constructor.fields) 189 190 def visitProduct(self, product, name): 191 self.metadata.types.add(name) 192 self.visitFields(product.attributes) 193 self.visitFields(product.fields) 194 195 def visitFields(self, fields): 196 for field in fields: 197 self.visitField(field) 198 199 def visitField(self, field): 200 self.metadata.identifiers.add(field.name) 201 202 203class TypeDefVisitor(EmitVisitor): 204 def visitModule(self, mod): 205 for dfn in mod.dfns: 206 self.visit(dfn) 207 208 def visitType(self, type, depth=0): 209 self.visit(type.value, type.name, depth) 210 211 def visitSum(self, sum, name, depth): 212 if is_simple(sum): 213 self.simple_sum(sum, name, depth) 214 else: 215 self.sum_with_constructors(sum, name, depth) 216 217 def simple_sum(self, sum, name, depth): 218 enum = [] 219 for i in range(len(sum.types)): 220 type = sum.types[i] 221 enum.append("%s=%d" % (type.name, i + 1)) 222 enums = ", ".join(enum) 223 ctype = get_c_type(name) 224 s = "typedef enum _%s { %s } %s;" % (name, enums, ctype) 225 self.emit(s, depth) 226 self.emit("", depth) 227 228 def sum_with_constructors(self, sum, name, depth): 229 ctype = get_c_type(name) 230 s = "typedef struct _%(name)s *%(ctype)s;" % locals() 231 self.emit(s, depth) 232 self.emit("", depth) 233 234 def visitProduct(self, product, name, depth): 235 ctype = get_c_type(name) 236 s = "typedef struct _%(name)s *%(ctype)s;" % locals() 237 self.emit(s, depth) 238 self.emit("", depth) 239 240class SequenceDefVisitor(EmitVisitor): 241 def visitModule(self, mod): 242 for dfn in mod.dfns: 243 self.visit(dfn) 244 245 def visitType(self, type, depth=0): 246 self.visit(type.value, type.name, depth) 247 248 def visitSum(self, sum, name, depth): 249 if is_simple(sum): 250 return 251 self.emit_sequence_constructor(name, depth) 252 253 def emit_sequence_constructor(self, name,depth): 254 ctype = get_c_type(name) 255 self.emit("""\ 256typedef struct { 257 _ASDL_SEQ_HEAD 258 %(ctype)s typed_elements[1]; 259} asdl_%(name)s_seq;""" % locals(), reflow=False, depth=depth) 260 self.emit("", depth) 261 self.emit("asdl_%(name)s_seq *_Py_asdl_%(name)s_seq_new(Py_ssize_t size, PyArena *arena);" % locals(), depth) 262 self.emit("", depth) 263 264 def visitProduct(self, product, name, depth): 265 self.emit_sequence_constructor(name, depth) 266 267class StructVisitor(EmitVisitor): 268 """Visitor to generate typedefs for AST.""" 269 270 def visitModule(self, mod): 271 for dfn in mod.dfns: 272 self.visit(dfn) 273 274 def visitType(self, type, depth=0): 275 self.visit(type.value, type.name, depth) 276 277 def visitSum(self, sum, name, depth): 278 if not is_simple(sum): 279 self.sum_with_constructors(sum, name, depth) 280 281 def sum_with_constructors(self, sum, name, depth): 282 def emit(s, depth=depth): 283 self.emit(s % sys._getframe(1).f_locals, depth) 284 enum = [] 285 for i in range(len(sum.types)): 286 type = sum.types[i] 287 enum.append("%s_kind=%d" % (type.name, i + 1)) 288 289 emit("enum _%(name)s_kind {" + ", ".join(enum) + "};") 290 291 emit("struct _%(name)s {") 292 emit("enum _%(name)s_kind kind;", depth + 1) 293 emit("union {", depth + 1) 294 for t in sum.types: 295 self.visit(t, depth + 2) 296 emit("} v;", depth + 1) 297 for field in sum.attributes: 298 # rudimentary attribute handling 299 type = str(field.type) 300 assert type in asdl.builtin_types, type 301 emit("%s %s;" % (type, field.name), depth + 1); 302 emit("};") 303 emit("") 304 305 def visitConstructor(self, cons, depth): 306 if cons.fields: 307 self.emit("struct {", depth) 308 for f in cons.fields: 309 self.visit(f, depth + 1) 310 self.emit("} %s;" % cons.name, depth) 311 self.emit("", depth) 312 313 def visitField(self, field, depth): 314 # XXX need to lookup field.type, because it might be something 315 # like a builtin... 316 ctype = get_c_type(field.type) 317 name = field.name 318 if field.seq: 319 if field.type in self.metadata.simple_sums: 320 self.emit("asdl_int_seq *%(name)s;" % locals(), depth) 321 else: 322 _type = field.type 323 self.emit("asdl_%(_type)s_seq *%(name)s;" % locals(), depth) 324 else: 325 self.emit("%(ctype)s %(name)s;" % locals(), depth) 326 327 def visitProduct(self, product, name, depth): 328 self.emit("struct _%(name)s {" % locals(), depth) 329 for f in product.fields: 330 self.visit(f, depth + 1) 331 for field in product.attributes: 332 # rudimentary attribute handling 333 type = str(field.type) 334 assert type in asdl.builtin_types, type 335 self.emit("%s %s;" % (type, field.name), depth + 1); 336 self.emit("};", depth) 337 self.emit("", depth) 338 339 340def ast_func_name(name): 341 return f"_PyAST_{name}" 342 343 344class PrototypeVisitor(EmitVisitor): 345 """Generate function prototypes for the .h file""" 346 347 def visitModule(self, mod): 348 for dfn in mod.dfns: 349 self.visit(dfn) 350 351 def visitType(self, type): 352 self.visit(type.value, type.name) 353 354 def visitSum(self, sum, name): 355 if is_simple(sum): 356 pass # XXX 357 else: 358 for t in sum.types: 359 self.visit(t, name, sum.attributes) 360 361 def get_args(self, fields): 362 """Return list of C argument info, one for each field. 363 364 Argument info is 3-tuple of a C type, variable name, and flag 365 that is true if type can be NULL. 366 """ 367 args = [] 368 unnamed = {} 369 for f in fields: 370 if f.name is None: 371 name = f.type 372 c = unnamed[name] = unnamed.get(name, 0) + 1 373 if c > 1: 374 name = "name%d" % (c - 1) 375 else: 376 name = f.name 377 # XXX should extend get_c_type() to handle this 378 if f.seq: 379 if f.type in self.metadata.simple_sums: 380 ctype = "asdl_int_seq *" 381 else: 382 ctype = f"asdl_{f.type}_seq *" 383 else: 384 ctype = get_c_type(f.type) 385 args.append((ctype, name, f.opt or f.seq)) 386 return args 387 388 def visitConstructor(self, cons, type, attrs): 389 args = self.get_args(cons.fields) 390 attrs = self.get_args(attrs) 391 ctype = get_c_type(type) 392 self.emit_function(cons.name, ctype, args, attrs) 393 394 def emit_function(self, name, ctype, args, attrs, union=True): 395 args = args + attrs 396 if args: 397 argstr = ", ".join(["%s %s" % (atype, aname) 398 for atype, aname, opt in args]) 399 argstr += ", PyArena *arena" 400 else: 401 argstr = "PyArena *arena" 402 self.emit("%s %s(%s);" % (ctype, ast_func_name(name), argstr), False) 403 404 def visitProduct(self, prod, name): 405 self.emit_function(name, get_c_type(name), 406 self.get_args(prod.fields), 407 self.get_args(prod.attributes), 408 union=False) 409 410 411class FunctionVisitor(PrototypeVisitor): 412 """Visitor to generate constructor functions for AST.""" 413 414 def emit_function(self, name, ctype, args, attrs, union=True): 415 def emit(s, depth=0, reflow=True): 416 self.emit(s, depth, reflow) 417 argstr = ", ".join(["%s %s" % (atype, aname) 418 for atype, aname, opt in args + attrs]) 419 if argstr: 420 argstr += ", PyArena *arena" 421 else: 422 argstr = "PyArena *arena" 423 self.emit("%s" % ctype, 0) 424 emit("%s(%s)" % (ast_func_name(name), argstr)) 425 emit("{") 426 emit("%s p;" % ctype, 1) 427 for argtype, argname, opt in args: 428 if not opt and argtype != "int": 429 emit("if (!%s) {" % argname, 1) 430 emit("PyErr_SetString(PyExc_ValueError,", 2) 431 msg = "field '%s' is required for %s" % (argname, name) 432 emit(' "%s");' % msg, 433 2, reflow=False) 434 emit('return NULL;', 2) 435 emit('}', 1) 436 437 emit("p = (%s)_PyArena_Malloc(arena, sizeof(*p));" % ctype, 1); 438 emit("if (!p)", 1) 439 emit("return NULL;", 2) 440 if union: 441 self.emit_body_union(name, args, attrs) 442 else: 443 self.emit_body_struct(name, args, attrs) 444 emit("return p;", 1) 445 emit("}") 446 emit("") 447 448 def emit_body_union(self, name, args, attrs): 449 def emit(s, depth=0, reflow=True): 450 self.emit(s, depth, reflow) 451 emit("p->kind = %s_kind;" % name, 1) 452 for argtype, argname, opt in args: 453 emit("p->v.%s.%s = %s;" % (name, argname, argname), 1) 454 for argtype, argname, opt in attrs: 455 emit("p->%s = %s;" % (argname, argname), 1) 456 457 def emit_body_struct(self, name, args, attrs): 458 def emit(s, depth=0, reflow=True): 459 self.emit(s, depth, reflow) 460 for argtype, argname, opt in args: 461 emit("p->%s = %s;" % (argname, argname), 1) 462 for argtype, argname, opt in attrs: 463 emit("p->%s = %s;" % (argname, argname), 1) 464 465 466class PickleVisitor(EmitVisitor): 467 468 def visitModule(self, mod): 469 for dfn in mod.dfns: 470 self.visit(dfn) 471 472 def visitType(self, type): 473 self.visit(type.value, type.name) 474 475 def visitSum(self, sum, name): 476 pass 477 478 def visitProduct(self, sum, name): 479 pass 480 481 def visitConstructor(self, cons, name): 482 pass 483 484 def visitField(self, sum): 485 pass 486 487 488class Obj2ModPrototypeVisitor(PickleVisitor): 489 def visitProduct(self, prod, name): 490 code = "static int obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena);" 491 self.emit(code % (name, get_c_type(name)), 0) 492 493 visitSum = visitProduct 494 495 496class Obj2ModVisitor(PickleVisitor): 497 498 attribute_special_defaults = { 499 "end_lineno": "lineno", 500 "end_col_offset": "col_offset", 501 } 502 503 @contextmanager 504 def recursive_call(self, node, level): 505 self.emit('if (_Py_EnterRecursiveCall(" while traversing \'%s\' node")) {' % node, level, reflow=False) 506 self.emit('goto failed;', level + 1) 507 self.emit('}', level) 508 yield 509 self.emit('_Py_LeaveRecursiveCall();', level) 510 511 def funcHeader(self, name): 512 ctype = get_c_type(name) 513 self.emit("int", 0) 514 self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0) 515 self.emit("{", 0) 516 self.emit("int isinstance;", 1) 517 self.emit("", 0) 518 519 def sumTrailer(self, name, add_label=False): 520 self.emit("", 0) 521 # there's really nothing more we can do if this fails ... 522 error = "expected some sort of %s, but got %%R" % name 523 format = "PyErr_Format(PyExc_TypeError, \"%s\", obj);" 524 self.emit(format % error, 1, reflow=False) 525 if add_label: 526 self.emit("failed:", 1) 527 self.emit("Py_XDECREF(tmp);", 1) 528 self.emit("return -1;", 1) 529 self.emit("}", 0) 530 self.emit("", 0) 531 532 def simpleSum(self, sum, name): 533 self.funcHeader(name) 534 for t in sum.types: 535 line = ("isinstance = PyObject_IsInstance(obj, " 536 "state->%s_type);") 537 self.emit(line % (t.name,), 1) 538 self.emit("if (isinstance == -1) {", 1) 539 self.emit("return -1;", 2) 540 self.emit("}", 1) 541 self.emit("if (isinstance) {", 1) 542 self.emit("*out = %s;" % t.name, 2) 543 self.emit("return 0;", 2) 544 self.emit("}", 1) 545 self.sumTrailer(name) 546 547 def buildArgs(self, fields): 548 return ", ".join(fields + ["arena"]) 549 550 def complexSum(self, sum, name): 551 self.funcHeader(name) 552 self.emit("PyObject *tmp = NULL;", 1) 553 self.emit("PyObject *tp;", 1) 554 for a in sum.attributes: 555 self.visitAttributeDeclaration(a, name, sum=sum) 556 self.emit("", 0) 557 # XXX: should we only do this for 'expr'? 558 self.emit("if (obj == Py_None) {", 1) 559 self.emit("*out = NULL;", 2) 560 self.emit("return 0;", 2) 561 self.emit("}", 1) 562 for a in sum.attributes: 563 self.visitField(a, name, sum=sum, depth=1) 564 for t in sum.types: 565 self.emit("tp = state->%s_type;" % (t.name,), 1) 566 self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1) 567 self.emit("if (isinstance == -1) {", 1) 568 self.emit("return -1;", 2) 569 self.emit("}", 1) 570 self.emit("if (isinstance) {", 1) 571 for f in t.fields: 572 self.visitFieldDeclaration(f, t.name, sum=sum, depth=2) 573 self.emit("", 0) 574 for f in t.fields: 575 self.visitField(f, t.name, sum=sum, depth=2) 576 args = [f.name for f in t.fields] + [a.name for a in sum.attributes] 577 self.emit("*out = %s(%s);" % (ast_func_name(t.name), self.buildArgs(args)), 2) 578 self.emit("if (*out == NULL) goto failed;", 2) 579 self.emit("return 0;", 2) 580 self.emit("}", 1) 581 self.sumTrailer(name, True) 582 583 def visitAttributeDeclaration(self, a, name, sum=sum): 584 ctype = get_c_type(a.type) 585 self.emit("%s %s;" % (ctype, a.name), 1) 586 587 def visitSum(self, sum, name): 588 if is_simple(sum): 589 self.simpleSum(sum, name) 590 else: 591 self.complexSum(sum, name) 592 593 def visitProduct(self, prod, name): 594 ctype = get_c_type(name) 595 self.emit("int", 0) 596 self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0) 597 self.emit("{", 0) 598 self.emit("PyObject* tmp = NULL;", 1) 599 for f in prod.fields: 600 self.visitFieldDeclaration(f, name, prod=prod, depth=1) 601 for a in prod.attributes: 602 self.visitFieldDeclaration(a, name, prod=prod, depth=1) 603 self.emit("", 0) 604 for f in prod.fields: 605 self.visitField(f, name, prod=prod, depth=1) 606 for a in prod.attributes: 607 self.visitField(a, name, prod=prod, depth=1) 608 args = [f.name for f in prod.fields] 609 args.extend([a.name for a in prod.attributes]) 610 self.emit("*out = %s(%s);" % (ast_func_name(name), self.buildArgs(args)), 1) 611 self.emit("if (*out == NULL) goto failed;", 1) 612 self.emit("return 0;", 1) 613 self.emit("failed:", 0) 614 self.emit("Py_XDECREF(tmp);", 1) 615 self.emit("return -1;", 1) 616 self.emit("}", 0) 617 self.emit("", 0) 618 619 def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0): 620 ctype = get_c_type(field.type) 621 if field.seq: 622 if self.isSimpleType(field): 623 self.emit("asdl_int_seq* %s;" % field.name, depth) 624 else: 625 _type = field.type 626 self.emit(f"asdl_{field.type}_seq* {field.name};", depth) 627 else: 628 ctype = get_c_type(field.type) 629 self.emit("%s %s;" % (ctype, field.name), depth) 630 631 def isNumeric(self, field): 632 return get_c_type(field.type) in ("int", "bool") 633 634 def isSimpleType(self, field): 635 return field.type in self.metadata.simple_sums or self.isNumeric(field) 636 637 def visitField(self, field, name, sum=None, prod=None, depth=0): 638 ctype = get_c_type(field.type) 639 line = "if (PyObject_GetOptionalAttr(obj, state->%s, &tmp) < 0) {" 640 self.emit(line % field.name, depth) 641 self.emit("return -1;", depth+1) 642 self.emit("}", depth) 643 if field.seq: 644 self.emit("if (tmp == NULL) {", depth) 645 self.emit("tmp = PyList_New(0);", depth+1) 646 self.emit("if (tmp == NULL) {", depth+1) 647 self.emit("return -1;", depth+2) 648 self.emit("}", depth+1) 649 self.emit("}", depth) 650 self.emit("{", depth) 651 else: 652 if not field.opt: 653 self.emit("if (tmp == NULL) {", depth) 654 message = "required field \\\"%s\\\" missing from %s" % (field.name, name) 655 format = "PyErr_SetString(PyExc_TypeError, \"%s\");" 656 self.emit(format % message, depth+1, reflow=False) 657 self.emit("return -1;", depth+1) 658 else: 659 self.emit("if (tmp == NULL || tmp == Py_None) {", depth) 660 self.emit("Py_CLEAR(tmp);", depth+1) 661 if self.isNumeric(field): 662 if field.name in self.attribute_special_defaults: 663 self.emit( 664 "%s = %s;" % (field.name, self.attribute_special_defaults[field.name]), 665 depth+1, 666 ) 667 else: 668 self.emit("%s = 0;" % field.name, depth+1) 669 elif not self.isSimpleType(field): 670 self.emit("%s = NULL;" % field.name, depth+1) 671 else: 672 raise TypeError("could not determine the default value for %s" % field.name) 673 self.emit("}", depth) 674 self.emit("else {", depth) 675 676 self.emit("int res;", depth+1) 677 if field.seq: 678 self.emit("Py_ssize_t len;", depth+1) 679 self.emit("Py_ssize_t i;", depth+1) 680 self.emit("if (!PyList_Check(tmp)) {", depth+1) 681 self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must " 682 "be a list, not a %%.200s\", _PyType_Name(Py_TYPE(tmp)));" % 683 (name, field.name), 684 depth+2, reflow=False) 685 self.emit("goto failed;", depth+2) 686 self.emit("}", depth+1) 687 self.emit("len = PyList_GET_SIZE(tmp);", depth+1) 688 if self.isSimpleType(field): 689 self.emit("%s = _Py_asdl_int_seq_new(len, arena);" % field.name, depth+1) 690 else: 691 self.emit("%s = _Py_asdl_%s_seq_new(len, arena);" % (field.name, field.type), depth+1) 692 self.emit("if (%s == NULL) goto failed;" % field.name, depth+1) 693 self.emit("for (i = 0; i < len; i++) {", depth+1) 694 self.emit("%s val;" % ctype, depth+2) 695 self.emit("PyObject *tmp2 = Py_NewRef(PyList_GET_ITEM(tmp, i));", depth+2) 696 with self.recursive_call(name, depth+2): 697 self.emit("res = obj2ast_%s(state, tmp2, &val, arena);" % 698 field.type, depth+2, reflow=False) 699 self.emit("Py_DECREF(tmp2);", depth+2) 700 self.emit("if (res != 0) goto failed;", depth+2) 701 self.emit("if (len != PyList_GET_SIZE(tmp)) {", depth+2) 702 self.emit("PyErr_SetString(PyExc_RuntimeError, \"%s field \\\"%s\\\" " 703 "changed size during iteration\");" % 704 (name, field.name), 705 depth+3, reflow=False) 706 self.emit("goto failed;", depth+3) 707 self.emit("}", depth+2) 708 self.emit("asdl_seq_SET(%s, i, val);" % field.name, depth+2) 709 self.emit("}", depth+1) 710 else: 711 with self.recursive_call(name, depth+1): 712 self.emit("res = obj2ast_%s(state, tmp, &%s, arena);" % 713 (field.type, field.name), depth+1) 714 self.emit("if (res != 0) goto failed;", depth+1) 715 716 self.emit("Py_CLEAR(tmp);", depth+1) 717 self.emit("}", depth) 718 719 720class SequenceConstructorVisitor(EmitVisitor): 721 def visitModule(self, mod): 722 for dfn in mod.dfns: 723 self.visit(dfn) 724 725 def visitType(self, type): 726 self.visit(type.value, type.name) 727 728 def visitProduct(self, prod, name): 729 self.emit_sequence_constructor(name, get_c_type(name)) 730 731 def visitSum(self, sum, name): 732 if not is_simple(sum): 733 self.emit_sequence_constructor(name, get_c_type(name)) 734 735 def emit_sequence_constructor(self, name, type): 736 self.emit(f"GENERATE_ASDL_SEQ_CONSTRUCTOR({name}, {type})", depth=0) 737 738class PyTypesDeclareVisitor(PickleVisitor): 739 740 def visitProduct(self, prod, name): 741 self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0) 742 if prod.attributes: 743 self.emit("static const char * const %s_attributes[] = {" % name, 0) 744 for a in prod.attributes: 745 self.emit('"%s",' % a.name, 1) 746 self.emit("};", 0) 747 if prod.fields: 748 self.emit("static const char * const %s_fields[]={" % name,0) 749 for f in prod.fields: 750 self.emit('"%s",' % f.name, 1) 751 self.emit("};", 0) 752 753 def visitSum(self, sum, name): 754 if sum.attributes: 755 self.emit("static const char * const %s_attributes[] = {" % name, 0) 756 for a in sum.attributes: 757 self.emit('"%s",' % a.name, 1) 758 self.emit("};", 0) 759 ptype = "void*" 760 if is_simple(sum): 761 ptype = get_c_type(name) 762 self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0) 763 for t in sum.types: 764 self.visitConstructor(t, name) 765 766 def visitConstructor(self, cons, name): 767 if cons.fields: 768 self.emit("static const char * const %s_fields[]={" % cons.name, 0) 769 for t in cons.fields: 770 self.emit('"%s",' % t.name, 1) 771 self.emit("};",0) 772 773 774class AnnotationsVisitor(PickleVisitor): 775 def visitModule(self, mod): 776 self.file.write(textwrap.dedent(''' 777 static int 778 add_ast_annotations(struct ast_state *state) 779 { 780 bool cond; 781 ''')) 782 for dfn in mod.dfns: 783 self.visit(dfn) 784 self.file.write(textwrap.dedent(''' 785 return 1; 786 } 787 ''')) 788 789 def visitProduct(self, prod, name): 790 self.emit_annotations(name, prod.fields) 791 792 def visitSum(self, sum, name): 793 for t in sum.types: 794 self.visitConstructor(t, name) 795 796 def visitConstructor(self, cons, name): 797 self.emit_annotations(cons.name, cons.fields) 798 799 def emit_annotations(self, name, fields): 800 self.emit(f"PyObject *{name}_annotations = PyDict_New();", 1) 801 self.emit(f"if (!{name}_annotations) return 0;", 1) 802 for field in fields: 803 self.emit("{", 1) 804 if field.type in builtin_type_to_c_type: 805 self.emit(f"PyObject *type = (PyObject *)&{builtin_type_to_c_type[field.type]};", 2) 806 else: 807 self.emit(f"PyObject *type = state->{field.type}_type;", 2) 808 if field.opt: 809 self.emit("type = _Py_union_type_or(type, Py_None);", 2) 810 self.emit("cond = type != NULL;", 2) 811 self.emit_annotations_error(name, 2) 812 elif field.seq: 813 self.emit("type = Py_GenericAlias((PyObject *)&PyList_Type, type);", 2) 814 self.emit("cond = type != NULL;", 2) 815 self.emit_annotations_error(name, 2) 816 else: 817 self.emit("Py_INCREF(type);", 2) 818 self.emit(f"cond = PyDict_SetItemString({name}_annotations, \"{field.name}\", type) == 0;", 2) 819 self.emit("Py_DECREF(type);", 2) 820 self.emit_annotations_error(name, 2) 821 self.emit("}", 1) 822 self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "_field_types", {name}_annotations) == 0;', 1) 823 self.emit_annotations_error(name, 1) 824 self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "__annotations__", {name}_annotations) == 0;', 1) 825 self.emit_annotations_error(name, 1) 826 self.emit(f"Py_DECREF({name}_annotations);", 1) 827 828 def emit_annotations_error(self, name, depth): 829 self.emit("if (!cond) {", depth) 830 self.emit(f"Py_DECREF({name}_annotations);", depth + 1) 831 self.emit("return 0;", depth + 1) 832 self.emit("}", depth) 833 834 835class PyTypesVisitor(PickleVisitor): 836 837 def visitModule(self, mod): 838 self.emit(""" 839 840typedef struct { 841 PyObject_HEAD 842 PyObject *dict; 843} AST_object; 844 845static void 846ast_dealloc(AST_object *self) 847{ 848 /* bpo-31095: UnTrack is needed before calling any callbacks */ 849 PyTypeObject *tp = Py_TYPE(self); 850 PyObject_GC_UnTrack(self); 851 Py_CLEAR(self->dict); 852 freefunc free_func = PyType_GetSlot(tp, Py_tp_free); 853 assert(free_func != NULL); 854 free_func(self); 855 Py_DECREF(tp); 856} 857 858static int 859ast_traverse(AST_object *self, visitproc visit, void *arg) 860{ 861 Py_VISIT(Py_TYPE(self)); 862 Py_VISIT(self->dict); 863 return 0; 864} 865 866static int 867ast_clear(AST_object *self) 868{ 869 Py_CLEAR(self->dict); 870 return 0; 871} 872 873static int 874ast_type_init(PyObject *self, PyObject *args, PyObject *kw) 875{ 876 struct ast_state *state = get_ast_state(); 877 if (state == NULL) { 878 return -1; 879 } 880 881 Py_ssize_t i, numfields = 0; 882 int res = -1; 883 PyObject *key, *value, *fields, *attributes = NULL, *remaining_fields = NULL; 884 885 fields = PyObject_GetAttr((PyObject*)Py_TYPE(self), state->_fields); 886 if (fields == NULL) { 887 goto cleanup; 888 } 889 890 numfields = PySequence_Size(fields); 891 if (numfields == -1) { 892 goto cleanup; 893 } 894 remaining_fields = PySet_New(fields); 895 if (remaining_fields == NULL) { 896 goto cleanup; 897 } 898 899 res = 0; /* if no error occurs, this stays 0 to the end */ 900 if (numfields < PyTuple_GET_SIZE(args)) { 901 PyErr_Format(PyExc_TypeError, "%.400s constructor takes at most " 902 "%zd positional argument%s", 903 _PyType_Name(Py_TYPE(self)), 904 numfields, numfields == 1 ? "" : "s"); 905 res = -1; 906 goto cleanup; 907 } 908 for (i = 0; i < PyTuple_GET_SIZE(args); i++) { 909 /* cannot be reached when fields is NULL */ 910 PyObject *name = PySequence_GetItem(fields, i); 911 if (!name) { 912 res = -1; 913 goto cleanup; 914 } 915 res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i)); 916 if (PySet_Discard(remaining_fields, name) < 0) { 917 res = -1; 918 Py_DECREF(name); 919 goto cleanup; 920 } 921 Py_DECREF(name); 922 if (res < 0) { 923 goto cleanup; 924 } 925 } 926 if (kw) { 927 i = 0; /* needed by PyDict_Next */ 928 while (PyDict_Next(kw, &i, &key, &value)) { 929 int contains = PySequence_Contains(fields, key); 930 if (contains == -1) { 931 res = -1; 932 goto cleanup; 933 } 934 else if (contains == 1) { 935 int p = PySet_Discard(remaining_fields, key); 936 if (p == -1) { 937 res = -1; 938 goto cleanup; 939 } 940 if (p == 0) { 941 PyErr_Format(PyExc_TypeError, 942 "%.400s got multiple values for argument '%U'", 943 Py_TYPE(self)->tp_name, key); 944 res = -1; 945 goto cleanup; 946 } 947 } 948 else { 949 // Lazily initialize "attributes" 950 if (attributes == NULL) { 951 attributes = PyObject_GetAttr((PyObject*)Py_TYPE(self), state->_attributes); 952 if (attributes == NULL) { 953 res = -1; 954 goto cleanup; 955 } 956 } 957 int contains = PySequence_Contains(attributes, key); 958 if (contains == -1) { 959 res = -1; 960 goto cleanup; 961 } 962 else if (contains == 0) { 963 if (PyErr_WarnFormat( 964 PyExc_DeprecationWarning, 1, 965 "%.400s.__init__ got an unexpected keyword argument '%U'. " 966 "Support for arbitrary keyword arguments is deprecated " 967 "and will be removed in Python 3.15.", 968 Py_TYPE(self)->tp_name, key 969 ) < 0) { 970 res = -1; 971 goto cleanup; 972 } 973 } 974 } 975 res = PyObject_SetAttr(self, key, value); 976 if (res < 0) { 977 goto cleanup; 978 } 979 } 980 } 981 Py_ssize_t size = PySet_Size(remaining_fields); 982 PyObject *field_types = NULL, *remaining_list = NULL; 983 if (size > 0) { 984 if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), &_Py_ID(_field_types), 985 &field_types) < 0) { 986 res = -1; 987 goto cleanup; 988 } 989 if (field_types == NULL) { 990 // Probably a user-defined subclass of AST that lacks _field_types. 991 // This will continue to work as it did before 3.13; i.e., attributes 992 // that are not passed in simply do not exist on the instance. 993 goto cleanup; 994 } 995 remaining_list = PySequence_List(remaining_fields); 996 if (!remaining_list) { 997 goto set_remaining_cleanup; 998 } 999 for (Py_ssize_t i = 0; i < size; i++) { 1000 PyObject *name = PyList_GET_ITEM(remaining_list, i); 1001 PyObject *type = PyDict_GetItemWithError(field_types, name); 1002 if (!type) { 1003 if (PyErr_Occurred()) { 1004 goto set_remaining_cleanup; 1005 } 1006 else { 1007 if (PyErr_WarnFormat( 1008 PyExc_DeprecationWarning, 1, 1009 "Field '%U' is missing from %.400s._field_types. " 1010 "This will become an error in Python 3.15.", 1011 name, Py_TYPE(self)->tp_name 1012 ) < 0) { 1013 goto set_remaining_cleanup; 1014 } 1015 } 1016 } 1017 else if (_PyUnion_Check(type)) { 1018 // optional field 1019 // do nothing, we'll have set a None default on the class 1020 } 1021 else if (Py_IS_TYPE(type, &Py_GenericAliasType)) { 1022 // list field 1023 PyObject *empty = PyList_New(0); 1024 if (!empty) { 1025 goto set_remaining_cleanup; 1026 } 1027 res = PyObject_SetAttr(self, name, empty); 1028 Py_DECREF(empty); 1029 if (res < 0) { 1030 goto set_remaining_cleanup; 1031 } 1032 } 1033 else if (type == state->expr_context_type) { 1034 // special case for expr_context: default to Load() 1035 res = PyObject_SetAttr(self, name, state->Load_singleton); 1036 if (res < 0) { 1037 goto set_remaining_cleanup; 1038 } 1039 } 1040 else { 1041 // simple field (e.g., identifier) 1042 if (PyErr_WarnFormat( 1043 PyExc_DeprecationWarning, 1, 1044 "%.400s.__init__ missing 1 required positional argument: '%U'. " 1045 "This will become an error in Python 3.15.", 1046 Py_TYPE(self)->tp_name, name 1047 ) < 0) { 1048 goto set_remaining_cleanup; 1049 } 1050 } 1051 } 1052 Py_DECREF(remaining_list); 1053 Py_DECREF(field_types); 1054 } 1055 cleanup: 1056 Py_XDECREF(attributes); 1057 Py_XDECREF(fields); 1058 Py_XDECREF(remaining_fields); 1059 return res; 1060 set_remaining_cleanup: 1061 Py_XDECREF(remaining_list); 1062 Py_XDECREF(field_types); 1063 res = -1; 1064 goto cleanup; 1065} 1066 1067/* Pickling support */ 1068static PyObject * 1069ast_type_reduce(PyObject *self, PyObject *unused) 1070{ 1071 struct ast_state *state = get_ast_state(); 1072 if (state == NULL) { 1073 return NULL; 1074 } 1075 1076 PyObject *dict = NULL, *fields = NULL, *positional_args = NULL; 1077 if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) { 1078 return NULL; 1079 } 1080 PyObject *result = NULL; 1081 if (dict) { 1082 // Unpickling (or copying) works as follows: 1083 // - Construct the object with only positional arguments 1084 // - Set the fields from the dict 1085 // We have two constraints: 1086 // - We must set all the required fields in the initial constructor call, 1087 // or the unpickling or deepcopying of the object will trigger DeprecationWarnings. 1088 // - We must not include child nodes in the positional args, because 1089 // that may trigger runaway recursion during copying (gh-120108). 1090 // To satisfy both constraints, we set all the fields to None in the 1091 // initial list of positional args, and then set the fields from the dict. 1092 if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) { 1093 goto cleanup; 1094 } 1095 if (fields) { 1096 Py_ssize_t numfields = PySequence_Size(fields); 1097 if (numfields == -1) { 1098 Py_DECREF(dict); 1099 goto cleanup; 1100 } 1101 positional_args = PyList_New(0); 1102 if (!positional_args) { 1103 goto cleanup; 1104 } 1105 for (Py_ssize_t i = 0; i < numfields; i++) { 1106 PyObject *name = PySequence_GetItem(fields, i); 1107 if (!name) { 1108 goto cleanup; 1109 } 1110 PyObject *value; 1111 int rc = PyDict_GetItemRef(dict, name, &value); 1112 Py_DECREF(name); 1113 if (rc < 0) { 1114 goto cleanup; 1115 } 1116 if (!value) { 1117 break; 1118 } 1119 rc = PyList_Append(positional_args, Py_None); 1120 Py_DECREF(value); 1121 if (rc < 0) { 1122 goto cleanup; 1123 } 1124 } 1125 PyObject *args_tuple = PyList_AsTuple(positional_args); 1126 if (!args_tuple) { 1127 goto cleanup; 1128 } 1129 result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, dict); 1130 } 1131 else { 1132 result = Py_BuildValue("O()N", Py_TYPE(self), dict); 1133 } 1134 } 1135 else { 1136 result = Py_BuildValue("O()", Py_TYPE(self)); 1137 } 1138cleanup: 1139 Py_XDECREF(fields); 1140 Py_XDECREF(positional_args); 1141 return result; 1142} 1143 1144static PyMemberDef ast_type_members[] = { 1145 {"__dictoffset__", Py_T_PYSSIZET, offsetof(AST_object, dict), Py_READONLY}, 1146 {NULL} /* Sentinel */ 1147}; 1148 1149static PyMethodDef ast_type_methods[] = { 1150 {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, 1151 {NULL} 1152}; 1153 1154static PyGetSetDef ast_type_getsets[] = { 1155 {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict}, 1156 {NULL} 1157}; 1158 1159static PyType_Slot AST_type_slots[] = { 1160 {Py_tp_dealloc, ast_dealloc}, 1161 {Py_tp_getattro, PyObject_GenericGetAttr}, 1162 {Py_tp_setattro, PyObject_GenericSetAttr}, 1163 {Py_tp_traverse, ast_traverse}, 1164 {Py_tp_clear, ast_clear}, 1165 {Py_tp_members, ast_type_members}, 1166 {Py_tp_methods, ast_type_methods}, 1167 {Py_tp_getset, ast_type_getsets}, 1168 {Py_tp_init, ast_type_init}, 1169 {Py_tp_alloc, PyType_GenericAlloc}, 1170 {Py_tp_new, PyType_GenericNew}, 1171 {Py_tp_free, PyObject_GC_Del}, 1172 {0, 0}, 1173}; 1174 1175static PyType_Spec AST_type_spec = { 1176 "ast.AST", 1177 sizeof(AST_object), 1178 0, 1179 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, 1180 AST_type_slots 1181}; 1182 1183static PyObject * 1184make_type(struct ast_state *state, const char *type, PyObject* base, 1185 const char* const* fields, int num_fields, const char *doc) 1186{ 1187 PyObject *fnames, *result; 1188 int i; 1189 fnames = PyTuple_New(num_fields); 1190 if (!fnames) return NULL; 1191 for (i = 0; i < num_fields; i++) { 1192 PyObject *field = PyUnicode_InternFromString(fields[i]); 1193 if (!field) { 1194 Py_DECREF(fnames); 1195 return NULL; 1196 } 1197 PyTuple_SET_ITEM(fnames, i, field); 1198 } 1199 result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOOOs}", 1200 type, base, 1201 state->_fields, fnames, 1202 state->__match_args__, fnames, 1203 state->__module__, 1204 state->ast, 1205 state->__doc__, doc); 1206 Py_DECREF(fnames); 1207 return result; 1208} 1209 1210static int 1211add_attributes(struct ast_state *state, PyObject *type, const char * const *attrs, int num_fields) 1212{ 1213 int i, result; 1214 PyObject *s, *l = PyTuple_New(num_fields); 1215 if (!l) 1216 return -1; 1217 for (i = 0; i < num_fields; i++) { 1218 s = PyUnicode_InternFromString(attrs[i]); 1219 if (!s) { 1220 Py_DECREF(l); 1221 return -1; 1222 } 1223 PyTuple_SET_ITEM(l, i, s); 1224 } 1225 result = PyObject_SetAttr(type, state->_attributes, l); 1226 Py_DECREF(l); 1227 return result; 1228} 1229 1230/* Conversion AST -> Python */ 1231 1232static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq, 1233 PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*)) 1234{ 1235 Py_ssize_t i, n = asdl_seq_LEN(seq); 1236 PyObject *result = PyList_New(n); 1237 PyObject *value; 1238 if (!result) 1239 return NULL; 1240 for (i = 0; i < n; i++) { 1241 value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i)); 1242 if (!value) { 1243 Py_DECREF(result); 1244 return NULL; 1245 } 1246 PyList_SET_ITEM(result, i, value); 1247 } 1248 return result; 1249} 1250 1251static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o) 1252{ 1253 PyObject *op = (PyObject*)o; 1254 if (!op) { 1255 op = Py_None; 1256 } 1257 return Py_NewRef(op); 1258} 1259#define ast2obj_constant ast2obj_object 1260#define ast2obj_identifier ast2obj_object 1261#define ast2obj_string ast2obj_object 1262 1263static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b) 1264{ 1265 return PyLong_FromLong(b); 1266} 1267 1268/* Conversion Python -> AST */ 1269 1270static int obj2ast_object(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena) 1271{ 1272 if (obj == Py_None) 1273 obj = NULL; 1274 if (obj) { 1275 if (_PyArena_AddPyObject(arena, obj) < 0) { 1276 *out = NULL; 1277 return -1; 1278 } 1279 *out = Py_NewRef(obj); 1280 } 1281 else { 1282 *out = NULL; 1283 } 1284 return 0; 1285} 1286 1287static int obj2ast_constant(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena) 1288{ 1289 if (_PyArena_AddPyObject(arena, obj) < 0) { 1290 *out = NULL; 1291 return -1; 1292 } 1293 *out = Py_NewRef(obj); 1294 return 0; 1295} 1296 1297static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena) 1298{ 1299 if (!PyUnicode_CheckExact(obj) && obj != Py_None) { 1300 PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str"); 1301 return -1; 1302 } 1303 return obj2ast_object(state, obj, out, arena); 1304} 1305 1306static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena) 1307{ 1308 if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) { 1309 PyErr_SetString(PyExc_TypeError, "AST string must be of type str"); 1310 return -1; 1311 } 1312 return obj2ast_object(state, obj, out, arena); 1313} 1314 1315static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena) 1316{ 1317 int i; 1318 if (!PyLong_Check(obj)) { 1319 PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj); 1320 return -1; 1321 } 1322 1323 i = PyLong_AsInt(obj); 1324 if (i == -1 && PyErr_Occurred()) 1325 return -1; 1326 *out = i; 1327 return 0; 1328} 1329 1330static int add_ast_fields(struct ast_state *state) 1331{ 1332 PyObject *empty_tuple; 1333 empty_tuple = PyTuple_New(0); 1334 if (!empty_tuple || 1335 PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 || 1336 PyObject_SetAttrString(state->AST_type, "__match_args__", empty_tuple) < 0 || 1337 PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) { 1338 Py_XDECREF(empty_tuple); 1339 return -1; 1340 } 1341 Py_DECREF(empty_tuple); 1342 return 0; 1343} 1344 1345""", 0, reflow=False) 1346 1347 self.file.write(textwrap.dedent(''' 1348 static int 1349 init_types(struct ast_state *state) 1350 { 1351 if (init_identifiers(state) < 0) { 1352 return -1; 1353 } 1354 state->AST_type = PyType_FromSpec(&AST_type_spec); 1355 if (!state->AST_type) { 1356 return -1; 1357 } 1358 if (add_ast_fields(state) < 0) { 1359 return -1; 1360 } 1361 ''')) 1362 for dfn in mod.dfns: 1363 self.visit(dfn) 1364 self.file.write(textwrap.dedent(''' 1365 if (!add_ast_annotations(state)) { 1366 return -1; 1367 } 1368 return 0; 1369 } 1370 ''')) 1371 1372 def visitProduct(self, prod, name): 1373 if prod.fields: 1374 fields = name+"_fields" 1375 else: 1376 fields = "NULL" 1377 self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' % 1378 (name, name, fields, len(prod.fields)), 1) 1379 self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False) 1380 self.emit("if (!state->%s_type) return -1;" % name, 1) 1381 if prod.attributes: 1382 self.emit("if (add_attributes(state, state->%s_type, %s_attributes, %d) < 0) return -1;" % 1383 (name, name, len(prod.attributes)), 1) 1384 else: 1385 self.emit("if (add_attributes(state, state->%s_type, NULL, 0) < 0) return -1;" % name, 1) 1386 self.emit_defaults(name, prod.fields, 1) 1387 self.emit_defaults(name, prod.attributes, 1) 1388 1389 def visitSum(self, sum, name): 1390 self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' % 1391 (name, name), 1) 1392 self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False) 1393 self.emit("if (!state->%s_type) return -1;" % name, 1) 1394 if sum.attributes: 1395 self.emit("if (add_attributes(state, state->%s_type, %s_attributes, %d) < 0) return -1;" % 1396 (name, name, len(sum.attributes)), 1) 1397 else: 1398 self.emit("if (add_attributes(state, state->%s_type, NULL, 0) < 0) return -1;" % name, 1) 1399 self.emit_defaults(name, sum.attributes, 1) 1400 simple = is_simple(sum) 1401 for t in sum.types: 1402 self.visitConstructor(t, name, simple) 1403 1404 def visitConstructor(self, cons, name, simple): 1405 if cons.fields: 1406 fields = cons.name+"_fields" 1407 else: 1408 fields = "NULL" 1409 self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' % 1410 (cons.name, cons.name, name, fields, len(cons.fields)), 1) 1411 self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False) 1412 self.emit("if (!state->%s_type) return -1;" % cons.name, 1) 1413 self.emit_defaults(cons.name, cons.fields, 1) 1414 if simple: 1415 self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)" 1416 "state->%s_type, NULL, NULL);" % 1417 (cons.name, cons.name), 1) 1418 self.emit("if (!state->%s_singleton) return -1;" % cons.name, 1) 1419 1420 def emit_defaults(self, name, fields, depth): 1421 for field in fields: 1422 if field.opt: 1423 self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' % 1424 (name, field.name), depth) 1425 self.emit("return -1;", depth+1) 1426 1427 1428class ASTModuleVisitor(PickleVisitor): 1429 1430 def visitModule(self, mod): 1431 self.emit("static int", 0) 1432 self.emit("astmodule_exec(PyObject *m)", 0) 1433 self.emit("{", 0) 1434 self.emit('struct ast_state *state = get_ast_state();', 1) 1435 self.emit('if (state == NULL) {', 1) 1436 self.emit('return -1;', 2) 1437 self.emit('}', 1) 1438 self.emit('if (PyModule_AddObjectRef(m, "AST", state->AST_type) < 0) {', 1) 1439 self.emit('return -1;', 2) 1440 self.emit('}', 1) 1441 self.emit('if (PyModule_AddIntMacro(m, PyCF_ALLOW_TOP_LEVEL_AWAIT) < 0) {', 1) 1442 self.emit("return -1;", 2) 1443 self.emit('}', 1) 1444 self.emit('if (PyModule_AddIntMacro(m, PyCF_ONLY_AST) < 0) {', 1) 1445 self.emit("return -1;", 2) 1446 self.emit('}', 1) 1447 self.emit('if (PyModule_AddIntMacro(m, PyCF_TYPE_COMMENTS) < 0) {', 1) 1448 self.emit("return -1;", 2) 1449 self.emit('}', 1) 1450 self.emit('if (PyModule_AddIntMacro(m, PyCF_OPTIMIZED_AST) < 0) {', 1) 1451 self.emit("return -1;", 2) 1452 self.emit('}', 1) 1453 for dfn in mod.dfns: 1454 self.visit(dfn) 1455 self.emit("return 0;", 1) 1456 self.emit("}", 0) 1457 self.emit("", 0) 1458 self.emit(""" 1459static PyModuleDef_Slot astmodule_slots[] = { 1460 {Py_mod_exec, astmodule_exec}, 1461 {Py_mod_multiple_interpreters, Py_MOD_PER_INTERPRETER_GIL_SUPPORTED}, 1462 {Py_mod_gil, Py_MOD_GIL_NOT_USED}, 1463 {0, NULL} 1464}; 1465 1466static struct PyModuleDef _astmodule = { 1467 PyModuleDef_HEAD_INIT, 1468 .m_name = "_ast", 1469 // The _ast module uses a per-interpreter state (PyInterpreterState.ast) 1470 .m_size = 0, 1471 .m_slots = astmodule_slots, 1472}; 1473 1474PyMODINIT_FUNC 1475PyInit__ast(void) 1476{ 1477 return PyModuleDef_Init(&_astmodule); 1478} 1479""".strip(), 0, reflow=False) 1480 1481 def visitProduct(self, prod, name): 1482 self.addObj(name) 1483 1484 def visitSum(self, sum, name): 1485 self.addObj(name) 1486 for t in sum.types: 1487 self.visitConstructor(t, name) 1488 1489 def visitConstructor(self, cons, name): 1490 self.addObj(cons.name) 1491 1492 def addObj(self, name): 1493 self.emit("if (PyModule_AddObjectRef(m, \"%s\", " 1494 "state->%s_type) < 0) {" % (name, name), 1) 1495 self.emit("return -1;", 2) 1496 self.emit('}', 1) 1497 1498 1499class StaticVisitor(PickleVisitor): 1500 CODE = '''Very simple, always emit this static code. Override CODE''' 1501 1502 def visit(self, object): 1503 self.emit(self.CODE, 0, reflow=False) 1504 1505 1506class ObjVisitor(PickleVisitor): 1507 1508 def func_begin(self, name): 1509 ctype = get_c_type(name) 1510 self.emit("PyObject*", 0) 1511 self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0) 1512 self.emit("{", 0) 1513 self.emit("%s o = (%s)_o;" % (ctype, ctype), 1) 1514 self.emit("PyObject *result = NULL, *value = NULL;", 1) 1515 self.emit("PyTypeObject *tp;", 1) 1516 self.emit('if (!o) {', 1) 1517 self.emit("Py_RETURN_NONE;", 2) 1518 self.emit("}", 1) 1519 self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1) 1520 self.emit("PyErr_SetString(PyExc_RecursionError,", 2) 1521 self.emit('"maximum recursion depth exceeded during ast construction");', 3) 1522 self.emit("return NULL;", 2) 1523 self.emit("}", 1) 1524 1525 def func_end(self): 1526 self.emit("vstate->recursion_depth--;", 1) 1527 self.emit("return result;", 1) 1528 self.emit("failed:", 0) 1529 self.emit("vstate->recursion_depth--;", 1) 1530 self.emit("Py_XDECREF(value);", 1) 1531 self.emit("Py_XDECREF(result);", 1) 1532 self.emit("return NULL;", 1) 1533 self.emit("}", 0) 1534 self.emit("", 0) 1535 1536 def visitSum(self, sum, name): 1537 if is_simple(sum): 1538 self.simpleSum(sum, name) 1539 return 1540 self.func_begin(name) 1541 self.emit("switch (o->kind) {", 1) 1542 for i in range(len(sum.types)): 1543 t = sum.types[i] 1544 self.visitConstructor(t, i + 1, name) 1545 self.emit("}", 1) 1546 for a in sum.attributes: 1547 self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1) 1548 self.emit("if (!value) goto failed;", 1) 1549 self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1) 1550 self.emit('goto failed;', 2) 1551 self.emit('Py_DECREF(value);', 1) 1552 self.func_end() 1553 1554 def simpleSum(self, sum, name): 1555 self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0) 1556 self.emit("{", 0) 1557 self.emit("switch(o) {", 1) 1558 for t in sum.types: 1559 self.emit("case %s:" % t.name, 2) 1560 self.emit("return Py_NewRef(state->%s_singleton);" % t.name, 3) 1561 self.emit("}", 1) 1562 self.emit("Py_UNREACHABLE();", 1); 1563 self.emit("}", 0) 1564 1565 def visitProduct(self, prod, name): 1566 self.func_begin(name) 1567 self.emit("tp = (PyTypeObject *)state->%s_type;" % name, 1) 1568 self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 1); 1569 self.emit("if (!result) return NULL;", 1) 1570 for field in prod.fields: 1571 self.visitField(field, name, 1, True) 1572 for a in prod.attributes: 1573 self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1) 1574 self.emit("if (!value) goto failed;", 1) 1575 self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1) 1576 self.emit('goto failed;', 2) 1577 self.emit('Py_DECREF(value);', 1) 1578 self.func_end() 1579 1580 def visitConstructor(self, cons, enum, name): 1581 self.emit("case %s_kind:" % cons.name, 1) 1582 self.emit("tp = (PyTypeObject *)state->%s_type;" % cons.name, 2) 1583 self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 2); 1584 self.emit("if (!result) goto failed;", 2) 1585 for f in cons.fields: 1586 self.visitField(f, cons.name, 2, False) 1587 self.emit("break;", 2) 1588 1589 def visitField(self, field, name, depth, product): 1590 def emit(s, d): 1591 self.emit(s, depth + d) 1592 if product: 1593 value = "o->%s" % field.name 1594 else: 1595 value = "o->v.%s.%s" % (name, field.name) 1596 self.set(field, value, depth) 1597 emit("if (!value) goto failed;", 0) 1598 emit("if (PyObject_SetAttr(result, state->%s, value) == -1)" % field.name, 0) 1599 emit("goto failed;", 1) 1600 emit("Py_DECREF(value);", 0) 1601 1602 def set(self, field, value, depth): 1603 if field.seq: 1604 if field.type in self.metadata.simple_sums: 1605 # While the sequence elements are stored as void*, 1606 # simple sums expects an enum 1607 self.emit("{", depth) 1608 self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1) 1609 self.emit("value = PyList_New(n);", depth+1) 1610 self.emit("if (!value) goto failed;", depth+1) 1611 self.emit("for(i = 0; i < n; i++)", depth+1) 1612 # This cannot fail, so no need for error handling 1613 self.emit( 1614 "PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format( 1615 field.type, 1616 value 1617 ), 1618 depth + 2, 1619 reflow=False, 1620 ) 1621 self.emit("}", depth) 1622 else: 1623 self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) 1624 else: 1625 self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False) 1626 1627 1628class PartingShots(StaticVisitor): 1629 1630 CODE = """ 1631PyObject* PyAST_mod2obj(mod_ty t) 1632{ 1633 struct ast_state *state = get_ast_state(); 1634 if (state == NULL) { 1635 return NULL; 1636 } 1637 1638 int starting_recursion_depth; 1639 /* Be careful here to prevent overflow. */ 1640 PyThreadState *tstate = _PyThreadState_GET(); 1641 if (!tstate) { 1642 return NULL; 1643 } 1644 struct validator vstate; 1645 vstate.recursion_limit = Py_C_RECURSION_LIMIT; 1646 int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining; 1647 starting_recursion_depth = recursion_depth; 1648 vstate.recursion_depth = starting_recursion_depth; 1649 1650 PyObject *result = ast2obj_mod(state, &vstate, t); 1651 1652 /* Check that the recursion depth counting balanced correctly */ 1653 if (result && vstate.recursion_depth != starting_recursion_depth) { 1654 PyErr_Format(PyExc_SystemError, 1655 "AST constructor recursion depth mismatch (before=%d, after=%d)", 1656 starting_recursion_depth, vstate.recursion_depth); 1657 return NULL; 1658 } 1659 return result; 1660} 1661 1662/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */ 1663mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode) 1664{ 1665 const char * const req_name[] = {"Module", "Expression", "Interactive"}; 1666 int isinstance; 1667 1668 if (PySys_Audit("compile", "OO", ast, Py_None) < 0) { 1669 return NULL; 1670 } 1671 1672 struct ast_state *state = get_ast_state(); 1673 if (state == NULL) { 1674 return NULL; 1675 } 1676 1677 PyObject *req_type[3]; 1678 req_type[0] = state->Module_type; 1679 req_type[1] = state->Expression_type; 1680 req_type[2] = state->Interactive_type; 1681 1682 assert(0 <= mode && mode <= 2); 1683 1684 isinstance = PyObject_IsInstance(ast, req_type[mode]); 1685 if (isinstance == -1) 1686 return NULL; 1687 if (!isinstance) { 1688 PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s", 1689 req_name[mode], _PyType_Name(Py_TYPE(ast))); 1690 return NULL; 1691 } 1692 1693 mod_ty res = NULL; 1694 if (obj2ast_mod(state, ast, &res, arena) != 0) 1695 return NULL; 1696 else 1697 return res; 1698} 1699 1700int PyAST_Check(PyObject* obj) 1701{ 1702 struct ast_state *state = get_ast_state(); 1703 if (state == NULL) { 1704 return -1; 1705 } 1706 return PyObject_IsInstance(obj, state->AST_type); 1707} 1708""" 1709 1710class ChainOfVisitors: 1711 def __init__(self, *visitors, metadata = None): 1712 self.visitors = visitors 1713 self.metadata = metadata 1714 1715 def visit(self, object): 1716 for v in self.visitors: 1717 v.metadata = self.metadata 1718 v.visit(object) 1719 v.emit("", 0) 1720 1721 1722def generate_ast_state(module_state, f): 1723 f.write('struct ast_state {\n') 1724 f.write(' _PyOnceFlag once;\n') 1725 f.write(' int finalized;\n') 1726 for s in module_state: 1727 f.write(' PyObject *' + s + ';\n') 1728 f.write('};') 1729 1730 1731def generate_ast_fini(module_state, f): 1732 f.write(textwrap.dedent(""" 1733 void _PyAST_Fini(PyInterpreterState *interp) 1734 { 1735 struct ast_state *state = &interp->ast; 1736 1737 """)) 1738 for s in module_state: 1739 f.write(" Py_CLEAR(state->" + s + ');\n') 1740 f.write(textwrap.dedent(""" 1741 state->finalized = 1; 1742 state->once = (_PyOnceFlag){0}; 1743 } 1744 1745 """)) 1746 1747 1748def generate_module_def(mod, metadata, f, internal_h): 1749 # Gather all the data needed for ModuleSpec 1750 state_strings = { 1751 "ast", 1752 "_fields", 1753 "__match_args__", 1754 "__doc__", 1755 "__dict__", 1756 "__module__", 1757 "_attributes", 1758 *metadata.identifiers 1759 } 1760 1761 module_state = state_strings.copy() 1762 module_state.update( 1763 "%s_singleton" % singleton 1764 for singleton in metadata.singletons 1765 ) 1766 module_state.update( 1767 "%s_type" % type 1768 for type in metadata.types 1769 ) 1770 1771 state_strings = sorted(state_strings) 1772 module_state = sorted(module_state) 1773 1774 generate_ast_state(module_state, internal_h) 1775 1776 print(textwrap.dedent(""" 1777 #include "Python.h" 1778 #include "pycore_ast.h" 1779 #include "pycore_ast_state.h" // struct ast_state 1780 #include "pycore_ceval.h" // _Py_EnterRecursiveCall 1781 #include "pycore_lock.h" // _PyOnceFlag 1782 #include "pycore_interp.h" // _PyInterpreterState.ast 1783 #include "pycore_pystate.h" // _PyInterpreterState_GET() 1784 #include "pycore_unionobject.h" // _Py_union_type_or 1785 #include "structmember.h" 1786 #include <stddef.h> 1787 1788 struct validator { 1789 int recursion_depth; /* current recursion depth */ 1790 int recursion_limit; /* recursion limit */ 1791 }; 1792 1793 // Forward declaration 1794 static int init_types(struct ast_state *state); 1795 1796 static struct ast_state* 1797 get_ast_state(void) 1798 { 1799 PyInterpreterState *interp = _PyInterpreterState_GET(); 1800 struct ast_state *state = &interp->ast; 1801 assert(!state->finalized); 1802 if (_PyOnceFlag_CallOnce(&state->once, (_Py_once_fn_t *)&init_types, state) < 0) { 1803 return NULL; 1804 } 1805 return state; 1806 } 1807 """).strip(), file=f) 1808 1809 generate_ast_fini(module_state, f) 1810 1811 f.write('static int init_identifiers(struct ast_state *state)\n') 1812 f.write('{\n') 1813 for identifier in state_strings: 1814 f.write(' if ((state->' + identifier) 1815 f.write(' = PyUnicode_InternFromString("') 1816 f.write(identifier + '")) == NULL) return -1;\n') 1817 f.write(' return 0;\n') 1818 f.write('};\n\n') 1819 1820def write_header(mod, metadata, f): 1821 f.write(textwrap.dedent(""" 1822 #ifndef Py_INTERNAL_AST_H 1823 #define Py_INTERNAL_AST_H 1824 #ifdef __cplusplus 1825 extern "C" { 1826 #endif 1827 1828 #ifndef Py_BUILD_CORE 1829 # error "this header requires Py_BUILD_CORE define" 1830 #endif 1831 1832 #include "pycore_asdl.h" // _ASDL_SEQ_HEAD 1833 1834 """).lstrip()) 1835 1836 c = ChainOfVisitors( 1837 TypeDefVisitor(f), 1838 SequenceDefVisitor(f), 1839 StructVisitor(f), 1840 metadata=metadata 1841 ) 1842 c.visit(mod) 1843 1844 f.write("// Note: these macros affect function definitions, not only call sites.\n") 1845 prototype_visitor = PrototypeVisitor(f, metadata=metadata) 1846 prototype_visitor.visit(mod) 1847 1848 f.write(textwrap.dedent(""" 1849 1850 PyObject* PyAST_mod2obj(mod_ty t); 1851 mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode); 1852 int PyAST_Check(PyObject* obj); 1853 1854 extern int _PyAST_Validate(mod_ty); 1855 1856 /* _PyAST_ExprAsUnicode is defined in ast_unparse.c */ 1857 extern PyObject* _PyAST_ExprAsUnicode(expr_ty); 1858 1859 /* Return the borrowed reference to the first literal string in the 1860 sequence of statements or NULL if it doesn't start from a literal string. 1861 Doesn't set exception. */ 1862 extern PyObject* _PyAST_GetDocString(asdl_stmt_seq *); 1863 1864 #ifdef __cplusplus 1865 } 1866 #endif 1867 #endif /* !Py_INTERNAL_AST_H */ 1868 """)) 1869 1870 1871def write_internal_h_header(mod, f): 1872 print(textwrap.dedent(""" 1873 #ifndef Py_INTERNAL_AST_STATE_H 1874 #define Py_INTERNAL_AST_STATE_H 1875 1876 #include "pycore_lock.h" // _PyOnceFlag 1877 1878 #ifdef __cplusplus 1879 extern "C" { 1880 #endif 1881 1882 #ifndef Py_BUILD_CORE 1883 # error "this header requires Py_BUILD_CORE define" 1884 #endif 1885 """).lstrip(), file=f) 1886 1887 1888def write_internal_h_footer(mod, f): 1889 print(textwrap.dedent(""" 1890 1891 #ifdef __cplusplus 1892 } 1893 #endif 1894 #endif /* !Py_INTERNAL_AST_STATE_H */ 1895 """), file=f) 1896 1897def write_source(mod, metadata, f, internal_h_file): 1898 generate_module_def(mod, metadata, f, internal_h_file) 1899 1900 v = ChainOfVisitors( 1901 SequenceConstructorVisitor(f), 1902 PyTypesDeclareVisitor(f), 1903 AnnotationsVisitor(f), 1904 PyTypesVisitor(f), 1905 Obj2ModPrototypeVisitor(f), 1906 FunctionVisitor(f), 1907 ObjVisitor(f), 1908 Obj2ModVisitor(f), 1909 ASTModuleVisitor(f), 1910 PartingShots(f), 1911 metadata=metadata 1912 ) 1913 v.visit(mod) 1914 1915def main(input_filename, c_filename, h_filename, internal_h_filename, dump_module=False): 1916 auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:])) 1917 mod = asdl.parse(input_filename) 1918 if dump_module: 1919 print('Parsed Module:') 1920 print(mod) 1921 if not asdl.check(mod): 1922 sys.exit(1) 1923 1924 metadata_visitor = MetadataVisitor() 1925 metadata_visitor.visit(mod) 1926 metadata = metadata_visitor.metadata 1927 1928 with c_filename.open("w") as c_file, \ 1929 h_filename.open("w") as h_file, \ 1930 internal_h_filename.open("w") as internal_h_file: 1931 c_file.write(auto_gen_msg) 1932 h_file.write(auto_gen_msg) 1933 internal_h_file.write(auto_gen_msg) 1934 1935 write_internal_h_header(mod, internal_h_file) 1936 write_source(mod, metadata, c_file, internal_h_file) 1937 write_header(mod, metadata, h_file) 1938 write_internal_h_footer(mod, internal_h_file) 1939 1940 print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.") 1941 1942if __name__ == "__main__": 1943 parser = ArgumentParser() 1944 parser.add_argument("input_file", type=Path) 1945 parser.add_argument("-C", "--c-file", type=Path, required=True) 1946 parser.add_argument("-H", "--h-file", type=Path, required=True) 1947 parser.add_argument("-I", "--internal-h-file", type=Path, required=True) 1948 parser.add_argument("-d", "--dump-module", action="store_true") 1949 1950 args = parser.parse_args() 1951 main(args.input_file, args.c_file, args.h_file, 1952 args.internal_h_file, args.dump_module) 1953