#! /usr/bin/env python
"""Generate C code from an ASDL description."""

import os
import sys

from argparse import ArgumentParser
from pathlib import Path

import asdl

TABSIZE = 4
MAX_COL = 80
AUTOGEN_MESSAGE = "/* File automatically generated by {}. */\n\n"


def get_c_type(name):
    """Return a string for the C name of the type.

    This function special cases the default types provided by asdl:
    a builtin type keeps its own name, any other type ``T`` maps to
    the C typedef name ``T_ty``.
    """
    if name in asdl.builtin_types:
        return name
    return "%s_ty" % name


def reflow_lines(s, depth):
    """Reflow the line s indented depth tabs.

    Return a sequence of lines where no line extends beyond MAX_COL
    when properly indented. The first line is properly indented based
    exclusively on depth * TABSIZE. All following lines -- these are
    the reflowed lines generated by this function -- start at the same
    column as the first character beyond the opening { in the first
    line (or, if there is no brace, just past the first opening paren).

    Raises ValueError if a line cannot be broken at a space within the
    available width.
    """
    size = MAX_COL - depth * TABSIZE
    if len(s) < size:
        return [s]

    lines = []
    cur = s
    padding = ""
    while len(cur) > size:
        i = cur.rfind(' ', 0, size)
        # XXX this should be fixed for real: an unbreakable line that
        # mentions GeneratorExp is force-cut at column size + 3 instead of
        # failing.  NOTE(review): the cut ignores spaces and the slice at the
        # bottom of the loop then drops the character at index i --
        # presumably never reached with the current grammar; confirm.
        if i == -1 and 'GeneratorExp' in cur:
            i = size + 3
        # Explicit raise rather than `assert`: under `python -O` an assert is
        # stripped, and with i == -1 `cur` would never shrink (infinite loop).
        if i == -1:
            raise ValueError("Impossible line %d to reflow: %r" % (size, s))
        lines.append(padding + cur[:i])
        if len(lines) == 1:
            # find new size based on brace
            j = cur.find('{', 0, i)
            if j >= 0:
                j += 2  # account for the brace and the space after it
                size -= j
                padding = " " * j
            else:
                j = cur.find('(', 0, i)
                if j >= 0:
                    j += 1  # account for the paren (no space after it)
                    size -= j
                    padding = " " * j
        # Skip the space the line was broken at.
        cur = cur[i+1:]
    else:
        lines.append(padding + cur)
    return lines


def reflow_c_string(s, depth):
    """Return s as a C string literal.

    Every newline in s becomes an escaped ``\\n`` that ends the current
    literal; the text continues in a new adjacent literal on the next line,
    indented depth tabs.
    """
    return '"%s"' % s.replace('\n', '\\n"\n%s"' % (' ' * depth * TABSIZE))


def is_simple(sum):
    """Return True if a sum is simple.

    A sum is simple if its types have no fields, e.g.
    unaryop = Invert | Not | UAdd | USub
    """
    return all(not t.fields for t in sum.types)


def asdl_of(name, obj):
    """Return the ASDL declaration text for the definition `name`.

    Products and constructors render as ``name(field, ...)``, with each field
    rendered by str().  The parenthesised part is omitted when there are no
    fields.  Simple sums render on one line as ``name = A | B | C``.  Other
    sums put one constructor per line, with each ``|`` aligned under the
    ``=``.
    """
    if isinstance(obj, (asdl.Product, asdl.Constructor)):
        fields = ", ".join(map(str, obj.fields))
        if fields:
            fields = "({})".format(fields)
        return "{}{}".format(name, fields)
    if is_simple(obj):
        types = " | ".join(cons.name for cons in obj.types)
    else:
        sep = "\n{}| ".format(" " * (len(name) + 1))
        types = sep.join(asdl_of(cons.name, cons) for cons in obj.types)
    return "{} = {}".format(name, types)


class EmitVisitor(asdl.VisitorBase):
    """Visitor that emits lines of C code to a file.

    It also records identifier, singleton and type names (as strings) in
    self.identifiers, self.singletons and self.types.  Presumably these are
    consumed elsewhere to emit declarations for them.
    """

    def __init__(self, file):
        self.file = file
        self.identifiers = set()
        self.singletons = set()
        self.types = set()
        super().__init__()

    def emit_identifier(self, name):
        self.identifiers.add(str(name))

    def emit_singleton(self, name):
        self.singletons.add(str(name))

    def emit_type(self, name):
        self.types.add(str(name))

    def emit(self, s, depth, reflow=True):
        """Write s to self.file, indented by depth levels of TABSIZE spaces.

        When reflow is true, long lines are split with reflow_lines().
        Empty lines are written without any indentation.
        """
        if reflow:
            lines = reflow_lines(s, depth)
        else:
            lines = [s]
        indent = " " * TABSIZE * depth
        for line in lines:
            if line:
                line = indent + line
            self.file.write(line + "\n")


class TypeDefVisitor(EmitVisitor):
    """Emit a C typedef for every ASDL type.

    Simple sums become C enums; all other sums and products become a
    pointer to an (incomplete) struct.
    """

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if is_simple(sum):
            self.simple_sum(sum, name, depth)
        else:
            self.sum_with_constructors(sum, name, depth)

    def simple_sum(self, sum, name, depth):
        # Enum values start at 1, in declaration order.
        enums = ", ".join("%s=%d" % (cons.name, value)
                          for value, cons in enumerate(sum.types, 1))
        ctype = get_c_type(name)
        s = "typedef enum _%s { %s } %s;" % (name, enums, ctype)
        self.emit(s, depth)
        self.emit("", depth)

    def sum_with_constructors(self, sum, name, depth):
        ctype = get_c_type(name)
        s = "typedef struct _%s *%s;" % (name, ctype)
        self.emit(s, depth)
        self.emit("", depth)

    def visitProduct(self, product, name, depth):
        ctype = get_c_type(name)
        s = "typedef struct _%s *%s;" % (name, ctype)
        self.emit(s, depth)
        self.emit("", depth)


class StructVisitor(EmitVisitor):
    """Visitor to generate the struct definitions behind the AST typedefs."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        # Simple sums are plain enums (see TypeDefVisitor); no struct needed.
        if not is_simple(sum):
            self.sum_with_constructors(sum, name, depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit `enum _<name>_kind` and the tagged-union `struct _<name>`.

        The struct holds a `kind` tag and a union `v` with one member per
        constructor that has fields, followed by the sum's attributes.
        """
        def emit(s, depth=depth):
            self.emit(s, depth)

        # Kind values start at 1, in declaration order.
        kinds = ", ".join("%s_kind=%d" % (cons.name, value)
                          for value, cons in enumerate(sum.types, 1))
        emit("enum _%s_kind {%s};" % (name, kinds))

        emit("struct _%s {" % name)
        emit("enum _%s_kind kind;" % name, depth + 1)
        emit("union {", depth + 1)
        for t in sum.types:
            self.visit(t, depth + 2)
        emit("} v;", depth + 1)
        for field in sum.attributes:
            # rudimentary attribute handling: only builtin types are allowed
            attr_type = str(field.type)
            assert attr_type in asdl.builtin_types, attr_type
            emit("%s %s;" % (attr_type, field.name), depth + 1)
        emit("};")
        emit("")

    def visitConstructor(self, cons, depth):
        # Constructors without fields get no union member.
        if cons.fields:
            self.emit("struct {", depth)
            for f in cons.fields:
                self.visit(f, depth + 1)
            self.emit("} %s;" % cons.name, depth)
            self.emit("", depth)

    def visitField(self, field, depth):
        # XXX need to lookup field.type, because it might be something
        # like a builtin...
        ctype = get_c_type(field.type)
        name = field.name
        if field.seq:
            # Sequences of cmpop are stored as ints; everything else as
            # generic pointers.
            if field.type == 'cmpop':
                self.emit("asdl_int_seq *%s;" % name, depth)
            else:
                self.emit("asdl_seq *%s;" % name, depth)
        else:
            self.emit("%s %s;" % (ctype, name), depth)

    def visitProduct(self, product, name, depth):
        self.emit("struct _%s {" % name, depth)
        for f in product.fields:
            self.visit(f, depth + 1)
        for field in product.attributes:
            # rudimentary attribute handling: only builtin types are allowed
            attr_type = str(field.type)
            assert attr_type in asdl.builtin_types, attr_type
            self.emit("%s %s;" % (attr_type, field.name), depth + 1)
        self.emit("};", depth)
        self.emit("", depth)


class PrototypeVisitor(EmitVisitor):
    """Generate function prototypes for the .h file"""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        if is_simple(sum):
            pass  # XXX simple sums get no constructor prototypes
        else:
            for t in sum.types:
                self.visit(t, name, sum.attributes)

    def get_args(self, fields):
        """Return list of C argument info, one for each field.

        Argument info is 3-tuple of a C type, variable name, and flag
        that is true if type can be NULL.  Unnamed fields are named after
        their type; repeated unnamed fields of the same type are named
        "name1", "name2", ... instead.
        """
        args = []
        unnamed = {}
        for f in fields:
            if f.name is None:
                name = f.type
                c = unnamed[name] = unnamed.get(name, 0) + 1
                if c > 1:
                    name = "name%d" % (c - 1)
            else:
                name = f.name
            # XXX should extend get_c_type() to handle this
            if f.seq:
                if f.type == 'cmpop':
                    ctype = "asdl_int_seq *"
                else:
                    ctype = "asdl_seq *"
            else:
                ctype = get_c_type(f.type)
            args.append((ctype, name, f.opt or f.seq))
        return args

    def visitConstructor(self, cons, type, attrs):
        args = self.get_args(cons.fields)
        attrs = self.get_args(attrs)
        ctype = get_c_type(type)
        self.emit_function(cons.name, ctype, args, attrs)

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit a `#define name(...)` macro and the `_Py_name` prototype.

        `union` is accepted for interface parity with
        FunctionVisitor.emit_function and is unused here.
        """
        args = args + attrs
        if args:
            argstr = ", ".join(["%s %s" % (atype, aname)
                                for atype, aname, opt in args])
            argstr += ", PyArena *arena"
        else:
            argstr = "PyArena *arena"
        # One macro parameter per argument plus one for the trailing arena.
        margs = ", ".join("a%d" % i for i in range(len(args) + 1))
        self.emit("#define %s(%s) _Py_%s(%s)" % (name, margs, name, margs), 0,
                  reflow=False)
        self.emit("%s _Py_%s(%s);" % (ctype, name, argstr), 0)

    def visitProduct(self, prod, name):
        self.emit_function(name, get_c_type(name),
                           self.get_args(prod.fields),
                           self.get_args(prod.attributes),
                           union=False)


class FunctionVisitor(PrototypeVisitor):
    """Visitor to generate constructor functions for AST."""

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit the C definition of the constructor function `name`.

        The generated function first NULL-checks each field argument that is
        neither optional nor a sequence nor an int (attributes are not
        checked), raising ValueError.  It then allocates the node with
        PyArena_Malloc and fills it in via emit_body_union() (for sum
        constructors) or emit_body_struct() (for products).
        """
        def emit(s, depth=0, reflow=True):
            self.emit(s, depth, reflow)
        argstr = ", ".join(["%s %s" % (atype, aname)
                            for atype, aname, opt in args + attrs])
        if argstr:
            argstr += ", PyArena *arena"
        else:
            argstr = "PyArena *arena"
        self.emit("%s" % ctype, 0)
        emit("%s(%s)" % (name, argstr))
        emit("{")
        emit("%s p;" % ctype, 1)
        for argtype, argname, opt in args:
            if not opt and argtype != "int":
                emit("if (!%s) {" % argname, 1)
                emit("PyErr_SetString(PyExc_ValueError,", 2)
                msg = "field '%s' is required for %s" % (argname, name)
                emit(' "%s");' % msg,
                     2, reflow=False)
                emit('return NULL;', 2)
                emit('}', 1)

        emit("p = (%s)PyArena_Malloc(arena, sizeof(*p));" % ctype, 1)
        emit("if (!p)", 1)
        emit("return NULL;", 2)
        if union:
            self.emit_body_union(name, args, attrs)
        else:
            self.emit_body_struct(name, args, attrs)
        emit("return p;", 1)
        emit("}")
        emit("")

    def emit_body_union(self, name, args, attrs):
        """Set the kind tag, the union member p->v.<name>, and attributes."""
        def emit(s, depth=0, reflow=True):
            self.emit(s, depth, reflow)
        emit("p->kind = %s_kind;" % name, 1)
        for argtype, argname, opt in args:
            emit("p->v.%s.%s = %s;" % (name, argname, argname), 1)
        for argtype, argname, opt in attrs:
            emit("p->%s = %s;" % (argname, argname), 1)

    def emit_body_struct(self, name, args, attrs):
        """Assign every field and attribute directly on the product struct."""
        def emit(s, depth=0, reflow=True):
            self.emit(s, depth, reflow)
        for argtype, argname, opt in args:
            emit("p->%s = %s;" % (argname, argname), 1)
        for argtype, argname, opt in attrs:
            emit("p->%s = %s;" % (argname, argname), 1)


class PickleVisitor(EmitVisitor):
    """Base visitor: walks module -> type definitions; all hooks are no-ops.

    Subclasses override the hooks they need.
    """

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        pass

    def visitProduct(self, sum, name):
        pass

    def visitConstructor(self, cons, name):
        pass

    def visitField(self, sum):
        pass


class Obj2ModPrototypeVisitor(PickleVisitor):
    """Emit a static prototype for each obj2ast_<name> converter."""

    def visitProduct(self, prod, name):
        code = "static int obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena);"
        self.emit(code % (name, get_c_type(name)), 0)

    visitSum = visitProduct


class Obj2ModVisitor(PickleVisitor):
    """Emit the obj2ast_<name> converters (Python object -> C AST value).

    Each generated converter returns 0 on success and 1 on failure.
    """

    def funcHeader(self, name):
        """Emit the signature and opening of obj2ast_<name> for a sum."""
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)
        self.emit("int isinstance;", 1)
        self.emit("", 0)

    def sumTrailer(self, name, add_label=False):
        """Emit the no-constructor-matched TypeError and close the function.

        With add_label, also emit the `failed:` label that releases `tmp`.
        """
        self.emit("", 0)
        # there's really nothing more we can do if this fails ...
        error = "expected some sort of %s, but got %%R" % name
        fmt = "PyErr_Format(PyExc_TypeError, \"%s\", obj);"
        self.emit(fmt % error, 1, reflow=False)
        if add_label:
            self.emit("failed:", 1)
            self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def simpleSum(self, sum, name):
        """Map an instance of one of the sum's classes to its enum value."""
        self.funcHeader(name)
        for t in sum.types:
            line = ("isinstance = PyObject_IsInstance(obj, "
                    "state->%s_type);")
            self.emit(line % (t.name,), 1)
            self.emit("if (isinstance == -1) {", 1)
            self.emit("return 1;", 2)
            self.emit("}", 1)
            self.emit("if (isinstance) {", 1)
            self.emit("*out = %s;" % t.name, 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name)

    def buildArgs(self, fields):
        """Return the C argument list: the given names followed by arena."""
        return ", ".join(fields + ["arena"])

    def complexSum(self, sum, name):
        """Try each constructor class in turn and build the matching node."""
        self.funcHeader(name)
        self.emit("PyObject *tmp = NULL;", 1)
        self.emit("PyObject *tp;", 1)
        for a in sum.attributes:
            self.visitAttributeDeclaration(a, name, sum=sum)
        self.emit("", 0)
        # XXX: should we only do this for 'expr'?
        self.emit("if (obj == Py_None) {", 1)
        self.emit("*out = NULL;", 2)
        self.emit("return 0;", 2)
        self.emit("}", 1)
        for a in sum.attributes:
            self.visitField(a, name, sum=sum, depth=1)
        for t in sum.types:
            self.emit("tp = state->%s_type;" % (t.name,), 1)
            self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1)
            self.emit("if (isinstance == -1) {", 1)
            self.emit("return 1;", 2)
            self.emit("}", 1)
            self.emit("if (isinstance) {", 1)
            for f in t.fields:
                self.visitFieldDeclaration(f, t.name, sum=sum, depth=2)
            self.emit("", 0)
            for f in t.fields:
                self.visitField(f, t.name, sum=sum, depth=2)
            args = [f.name for f in t.fields] + [a.name for a in sum.attributes]
            self.emit("*out = %s(%s);" % (t.name, self.buildArgs(args)), 2)
            self.emit("if (*out == NULL) goto failed;", 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name, True)

    def visitAttributeDeclaration(self, a, name, sum=None):
        # `sum` is accepted for call-site symmetry and is unused.
        ctype = get_c_type(a.type)
        self.emit("%s %s;" % (ctype, a.name), 1)

    def visitSum(self, sum, name):
        if is_simple(sum):
            self.simpleSum(sum, name)
        else:
            self.complexSum(sum, name)

    def visitProduct(self, prod, name):
        """Emit obj2ast_<name> for a product: convert each field, then build."""
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)
        self.emit("PyObject* tmp = NULL;", 1)
        for f in prod.fields:
            self.visitFieldDeclaration(f, name, prod=prod, depth=1)
        for a in prod.attributes:
            self.visitFieldDeclaration(a, name, prod=prod, depth=1)
        self.emit("", 0)
        for f in prod.fields:
            self.visitField(f, name, prod=prod, depth=1)
        for a in prod.attributes:
            self.visitField(a, name, prod=prod, depth=1)
        args = [f.name for f in prod.fields]
        args.extend([a.name for a in prod.attributes])
        self.emit("*out = %s(%s);" % (name, self.buildArgs(args)), 1)
        self.emit("return 0;", 1)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0):
        """Declare the C local that will receive the converted field."""
        if field.seq:
            if self.isSimpleType(field):
                self.emit("asdl_int_seq* %s;" % field.name, depth)
            else:
                self.emit("asdl_seq* %s;" % field.name, depth)
        else:
            ctype = get_c_type(field.type)
            self.emit("%s %s;" % (ctype, field.name), depth)

    def isSimpleSum(self, field):
        # XXX can the members of this list be determined automatically?
        # Hard-coded names of the simple (enum) sums; must be kept in sync
        # with the grammar by hand.
        return field.type in ('expr_context', 'boolop', 'operator',
                              'unaryop', 'cmpop')

    def isNumeric(self, field):
        return get_c_type(field.type) in ("int", "bool")

    def isSimpleType(self, field):
        return self.isSimpleSum(field) or self.isNumeric(field)

    def visitField(self, field, name, sum=None, prod=None, depth=0):
        """Emit code that reads attribute `field.name` of obj into a C local.

        A missing required field raises TypeError.  A missing or None
        optional field gets a 0/NULL default.  Sequence fields must be
        lists and are converted element by element.
        """
        ctype = get_c_type(field.type)
        line = "if (_PyObject_LookupAttr(obj, state->%s, &tmp) < 0) {"
        self.emit(line % field.name, depth)
        self.emit("return 1;", depth+1)
        self.emit("}", depth)
        if not field.opt:
            self.emit("if (tmp == NULL) {", depth)
            message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
            fmt = "PyErr_SetString(PyExc_TypeError, \"%s\");"
            self.emit(fmt % message, depth+1, reflow=False)
            self.emit("return 1;", depth+1)
        else:
            self.emit("if (tmp == NULL || tmp == Py_None) {", depth)
            self.emit("Py_CLEAR(tmp);", depth+1)
            if self.isNumeric(field):
                self.emit("%s = 0;" % field.name, depth+1)
            elif not self.isSimpleType(field):
                self.emit("%s = NULL;" % field.name, depth+1)
            else:
                raise TypeError("could not determine the default value for %s" % field.name)
        self.emit("}", depth)
        self.emit("else {", depth)

        self.emit("int res;", depth+1)
        if field.seq:
            self.emit("Py_ssize_t len;", depth+1)
            self.emit("Py_ssize_t i;", depth+1)
            self.emit("if (!PyList_Check(tmp)) {", depth+1)
            self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must "
                      "be a list, not a %%.200s\", _PyType_Name(Py_TYPE(tmp)));" %
                      (name, field.name),
                      depth+2, reflow=False)
            self.emit("goto failed;", depth+2)
            self.emit("}", depth+1)
            self.emit("len = PyList_GET_SIZE(tmp);", depth+1)
            if self.isSimpleType(field):
                self.emit("%s = _Py_asdl_int_seq_new(len, arena);" % field.name, depth+1)
            else:
                self.emit("%s = _Py_asdl_seq_new(len, arena);" % field.name, depth+1)
            self.emit("if (%s == NULL) goto failed;" % field.name, depth+1)
            self.emit("for (i = 0; i < len; i++) {", depth+1)
            self.emit("%s val;" % ctype, depth+2)
            self.emit("PyObject *tmp2 = PyList_GET_ITEM(tmp, i);", depth+2)
            self.emit("Py_INCREF(tmp2);", depth+2)
            self.emit("res = obj2ast_%s(state, tmp2, &val, arena);" %
                      field.type, depth+2, reflow=False)
            self.emit("Py_DECREF(tmp2);", depth+2)
            self.emit("if (res != 0) goto failed;", depth+2)
            # The conversion can run Python code that mutates the list.
            self.emit("if (len != PyList_GET_SIZE(tmp)) {", depth+2)
            self.emit("PyErr_SetString(PyExc_RuntimeError, \"%s field \\\"%s\\\" "
                      "changed size during iteration\");" %
                      (name, field.name),
                      depth+3, reflow=False)
            self.emit("goto failed;", depth+3)
            self.emit("}", depth+2)
            self.emit("asdl_seq_SET(%s, i, val);" % field.name, depth+2)
            self.emit("}", depth+1)
        else:
            self.emit("res = obj2ast_%s(state, tmp, &%s, arena);" %
                      (field.type, field.name), depth+1)
            self.emit("if (res != 0) goto failed;", depth+1)

        self.emit("Py_CLEAR(tmp);", depth+1)
        self.emit("}", depth)


class MarshalPrototypeVisitor(PickleVisitor):
    """Emit a static marshal_write_<name> prototype for every type."""

    def prototype(self, sum, name):
        ctype = get_c_type(name)
        self.emit("static int marshal_write_%s(PyObject **, int *, %s);"
                  % (name, ctype), 0)

    visitProduct = visitSum = prototype
602 603class PyTypesDeclareVisitor(PickleVisitor): 604 605 def visitProduct(self, prod, name): 606 self.emit_type("%s_type" % name) 607 self.emit("static PyObject* ast2obj_%s(astmodulestate *state, void*);" % name, 0) 608 if prod.attributes: 609 for a in prod.attributes: 610 self.emit_identifier(a.name) 611 self.emit("static const char * const %s_attributes[] = {" % name, 0) 612 for a in prod.attributes: 613 self.emit('"%s",' % a.name, 1) 614 self.emit("};", 0) 615 if prod.fields: 616 for f in prod.fields: 617 self.emit_identifier(f.name) 618 self.emit("static const char * const %s_fields[]={" % name,0) 619 for f in prod.fields: 620 self.emit('"%s",' % f.name, 1) 621 self.emit("};", 0) 622 623 def visitSum(self, sum, name): 624 self.emit_type("%s_type" % name) 625 if sum.attributes: 626 for a in sum.attributes: 627 self.emit_identifier(a.name) 628 self.emit("static const char * const %s_attributes[] = {" % name, 0) 629 for a in sum.attributes: 630 self.emit('"%s",' % a.name, 1) 631 self.emit("};", 0) 632 ptype = "void*" 633 if is_simple(sum): 634 ptype = get_c_type(name) 635 for t in sum.types: 636 self.emit_singleton("%s_singleton" % t.name) 637 self.emit("static PyObject* ast2obj_%s(astmodulestate *state, %s);" % (name, ptype), 0) 638 for t in sum.types: 639 self.visitConstructor(t, name) 640 641 def visitConstructor(self, cons, name): 642 if cons.fields: 643 for t in cons.fields: 644 self.emit_identifier(t.name) 645 self.emit("static const char * const %s_fields[]={" % cons.name, 0) 646 for t in cons.fields: 647 self.emit('"%s",' % t.name, 1) 648 self.emit("};",0) 649 650class PyTypesVisitor(PickleVisitor): 651 652 def visitModule(self, mod): 653 self.emit(""" 654 655typedef struct { 656 PyObject_HEAD 657 PyObject *dict; 658} AST_object; 659 660static void 661ast_dealloc(AST_object *self) 662{ 663 /* bpo-31095: UnTrack is needed before calling any callbacks */ 664 PyTypeObject *tp = Py_TYPE(self); 665 PyObject_GC_UnTrack(self); 666 Py_CLEAR(self->dict); 667 
freefunc free_func = PyType_GetSlot(tp, Py_tp_free); 668 assert(free_func != NULL); 669 free_func(self); 670 Py_DECREF(tp); 671} 672 673static int 674ast_traverse(AST_object *self, visitproc visit, void *arg) 675{ 676 Py_VISIT(Py_TYPE(self)); 677 Py_VISIT(self->dict); 678 return 0; 679} 680 681static int 682ast_clear(AST_object *self) 683{ 684 Py_CLEAR(self->dict); 685 return 0; 686} 687 688static int 689ast_type_init(PyObject *self, PyObject *args, PyObject *kw) 690{ 691 astmodulestate *state = get_global_ast_state(); 692 if (state == NULL) { 693 return -1; 694 } 695 696 Py_ssize_t i, numfields = 0; 697 int res = -1; 698 PyObject *key, *value, *fields; 699 if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) { 700 goto cleanup; 701 } 702 if (fields) { 703 numfields = PySequence_Size(fields); 704 if (numfields == -1) { 705 goto cleanup; 706 } 707 } 708 709 res = 0; /* if no error occurs, this stays 0 to the end */ 710 if (numfields < PyTuple_GET_SIZE(args)) { 711 PyErr_Format(PyExc_TypeError, "%.400s constructor takes at most " 712 "%zd positional argument%s", 713 _PyType_Name(Py_TYPE(self)), 714 numfields, numfields == 1 ? 
"" : "s"); 715 res = -1; 716 goto cleanup; 717 } 718 for (i = 0; i < PyTuple_GET_SIZE(args); i++) { 719 /* cannot be reached when fields is NULL */ 720 PyObject *name = PySequence_GetItem(fields, i); 721 if (!name) { 722 res = -1; 723 goto cleanup; 724 } 725 res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i)); 726 Py_DECREF(name); 727 if (res < 0) { 728 goto cleanup; 729 } 730 } 731 if (kw) { 732 i = 0; /* needed by PyDict_Next */ 733 while (PyDict_Next(kw, &i, &key, &value)) { 734 int contains = PySequence_Contains(fields, key); 735 if (contains == -1) { 736 res = -1; 737 goto cleanup; 738 } else if (contains == 1) { 739 Py_ssize_t p = PySequence_Index(fields, key); 740 if (p == -1) { 741 res = -1; 742 goto cleanup; 743 } 744 if (p < PyTuple_GET_SIZE(args)) { 745 PyErr_Format(PyExc_TypeError, 746 "%.400s got multiple values for argument '%U'", 747 Py_TYPE(self)->tp_name, key); 748 res = -1; 749 goto cleanup; 750 } 751 } 752 res = PyObject_SetAttr(self, key, value); 753 if (res < 0) { 754 goto cleanup; 755 } 756 } 757 } 758 cleanup: 759 Py_XDECREF(fields); 760 return res; 761} 762 763/* Pickling support */ 764static PyObject * 765ast_type_reduce(PyObject *self, PyObject *unused) 766{ 767 astmodulestate *state = get_global_ast_state(); 768 if (state == NULL) { 769 return NULL; 770 } 771 772 PyObject *dict; 773 if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) { 774 return NULL; 775 } 776 if (dict) { 777 return Py_BuildValue("O()N", Py_TYPE(self), dict); 778 } 779 return Py_BuildValue("O()", Py_TYPE(self)); 780} 781 782static PyMemberDef ast_type_members[] = { 783 {"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY}, 784 {NULL} /* Sentinel */ 785}; 786 787static PyMethodDef ast_type_methods[] = { 788 {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, 789 {NULL} 790}; 791 792static PyGetSetDef ast_type_getsets[] = { 793 {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict}, 794 {NULL} 795}; 796 797static PyType_Slot 
AST_type_slots[] = { 798 {Py_tp_dealloc, ast_dealloc}, 799 {Py_tp_getattro, PyObject_GenericGetAttr}, 800 {Py_tp_setattro, PyObject_GenericSetAttr}, 801 {Py_tp_traverse, ast_traverse}, 802 {Py_tp_clear, ast_clear}, 803 {Py_tp_members, ast_type_members}, 804 {Py_tp_methods, ast_type_methods}, 805 {Py_tp_getset, ast_type_getsets}, 806 {Py_tp_init, ast_type_init}, 807 {Py_tp_alloc, PyType_GenericAlloc}, 808 {Py_tp_new, PyType_GenericNew}, 809 {Py_tp_free, PyObject_GC_Del}, 810 {0, 0}, 811}; 812 813static PyType_Spec AST_type_spec = { 814 "ast.AST", 815 sizeof(AST_object), 816 0, 817 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, 818 AST_type_slots 819}; 820 821static PyObject * 822make_type(astmodulestate *state, const char *type, PyObject* base, 823 const char* const* fields, int num_fields, const char *doc) 824{ 825 PyObject *fnames, *result; 826 int i; 827 fnames = PyTuple_New(num_fields); 828 if (!fnames) return NULL; 829 for (i = 0; i < num_fields; i++) { 830 PyObject *field = PyUnicode_InternFromString(fields[i]); 831 if (!field) { 832 Py_DECREF(fnames); 833 return NULL; 834 } 835 PyTuple_SET_ITEM(fnames, i, field); 836 } 837 result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOs}", 838 type, base, 839 state->_fields, fnames, 840 state->__module__, 841 state->ast, 842 state->__doc__, doc); 843 Py_DECREF(fnames); 844 return result; 845} 846 847static int 848add_attributes(astmodulestate *state, PyObject *type, const char * const *attrs, int num_fields) 849{ 850 int i, result; 851 PyObject *s, *l = PyTuple_New(num_fields); 852 if (!l) 853 return 0; 854 for (i = 0; i < num_fields; i++) { 855 s = PyUnicode_InternFromString(attrs[i]); 856 if (!s) { 857 Py_DECREF(l); 858 return 0; 859 } 860 PyTuple_SET_ITEM(l, i, s); 861 } 862 result = PyObject_SetAttr(type, state->_attributes, l) >= 0; 863 Py_DECREF(l); 864 return result; 865} 866 867/* Conversion AST -> Python */ 868 869static PyObject* ast2obj_list(astmodulestate *state, asdl_seq *seq, 
PyObject* (*func)(astmodulestate *state, void*)) 870{ 871 Py_ssize_t i, n = asdl_seq_LEN(seq); 872 PyObject *result = PyList_New(n); 873 PyObject *value; 874 if (!result) 875 return NULL; 876 for (i = 0; i < n; i++) { 877 value = func(state, asdl_seq_GET(seq, i)); 878 if (!value) { 879 Py_DECREF(result); 880 return NULL; 881 } 882 PyList_SET_ITEM(result, i, value); 883 } 884 return result; 885} 886 887static PyObject* ast2obj_object(astmodulestate *Py_UNUSED(state), void *o) 888{ 889 if (!o) 890 o = Py_None; 891 Py_INCREF((PyObject*)o); 892 return (PyObject*)o; 893} 894#define ast2obj_constant ast2obj_object 895#define ast2obj_identifier ast2obj_object 896#define ast2obj_string ast2obj_object 897 898static PyObject* ast2obj_int(astmodulestate *Py_UNUSED(state), long b) 899{ 900 return PyLong_FromLong(b); 901} 902 903/* Conversion Python -> AST */ 904 905static int obj2ast_object(astmodulestate *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena) 906{ 907 if (obj == Py_None) 908 obj = NULL; 909 if (obj) { 910 if (PyArena_AddPyObject(arena, obj) < 0) { 911 *out = NULL; 912 return -1; 913 } 914 Py_INCREF(obj); 915 } 916 *out = obj; 917 return 0; 918} 919 920static int obj2ast_constant(astmodulestate *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena) 921{ 922 if (PyArena_AddPyObject(arena, obj) < 0) { 923 *out = NULL; 924 return -1; 925 } 926 Py_INCREF(obj); 927 *out = obj; 928 return 0; 929} 930 931static int obj2ast_identifier(astmodulestate *state, PyObject* obj, PyObject** out, PyArena* arena) 932{ 933 if (!PyUnicode_CheckExact(obj) && obj != Py_None) { 934 PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str"); 935 return 1; 936 } 937 return obj2ast_object(state, obj, out, arena); 938} 939 940static int obj2ast_string(astmodulestate *state, PyObject* obj, PyObject** out, PyArena* arena) 941{ 942 if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) { 943 PyErr_SetString(PyExc_TypeError, "AST string must be of type 
str"); 944 return 1; 945 } 946 return obj2ast_object(state, obj, out, arena); 947} 948 949static int obj2ast_int(astmodulestate* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena) 950{ 951 int i; 952 if (!PyLong_Check(obj)) { 953 PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj); 954 return 1; 955 } 956 957 i = _PyLong_AsInt(obj); 958 if (i == -1 && PyErr_Occurred()) 959 return 1; 960 *out = i; 961 return 0; 962} 963 964static int add_ast_fields(astmodulestate *state) 965{ 966 PyObject *empty_tuple; 967 empty_tuple = PyTuple_New(0); 968 if (!empty_tuple || 969 PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 || 970 PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) { 971 Py_XDECREF(empty_tuple); 972 return -1; 973 } 974 Py_DECREF(empty_tuple); 975 return 0; 976} 977 978""", 0, reflow=False) 979 980 self.emit("static int init_types(astmodulestate *state)",0) 981 self.emit("{", 0) 982 self.emit("if (state->initialized) return 1;", 1) 983 self.emit("if (init_identifiers(state) < 0) return 0;", 1) 984 self.emit("state->AST_type = PyType_FromSpec(&AST_type_spec);", 1) 985 self.emit("if (!state->AST_type) return 0;", 1) 986 self.emit("if (add_ast_fields(state) < 0) return 0;", 1) 987 for dfn in mod.dfns: 988 self.visit(dfn) 989 self.emit("state->initialized = 1;", 1) 990 self.emit("return 1;", 1); 991 self.emit("}", 0) 992 993 def visitProduct(self, prod, name): 994 if prod.fields: 995 fields = name+"_fields" 996 else: 997 fields = "NULL" 998 self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' % 999 (name, name, fields, len(prod.fields)), 1) 1000 self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False) 1001 self.emit("if (!state->%s_type) return 0;" % name, 1) 1002 self.emit_type("AST_type") 1003 self.emit_type("%s_type" % name) 1004 if prod.attributes: 1005 self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" % 1006 (name, name, 
len(prod.attributes)), 1) 1007 else: 1008 self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1) 1009 self.emit_defaults(name, prod.fields, 1) 1010 self.emit_defaults(name, prod.attributes, 1) 1011 1012 def visitSum(self, sum, name): 1013 self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' % 1014 (name, name), 1) 1015 self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False) 1016 self.emit_type("%s_type" % name) 1017 self.emit("if (!state->%s_type) return 0;" % name, 1) 1018 if sum.attributes: 1019 self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" % 1020 (name, name, len(sum.attributes)), 1) 1021 else: 1022 self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1) 1023 self.emit_defaults(name, sum.attributes, 1) 1024 simple = is_simple(sum) 1025 for t in sum.types: 1026 self.visitConstructor(t, name, simple) 1027 1028 def visitConstructor(self, cons, name, simple): 1029 if cons.fields: 1030 fields = cons.name+"_fields" 1031 else: 1032 fields = "NULL" 1033 self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' % 1034 (cons.name, cons.name, name, fields, len(cons.fields)), 1) 1035 self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False) 1036 self.emit("if (!state->%s_type) return 0;" % cons.name, 1) 1037 self.emit_type("%s_type" % cons.name) 1038 self.emit_defaults(cons.name, cons.fields, 1) 1039 if simple: 1040 self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)" 1041 "state->%s_type, NULL, NULL);" % 1042 (cons.name, cons.name), 1) 1043 self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1) 1044 1045 def emit_defaults(self, name, fields, depth): 1046 for field in fields: 1047 if field.opt: 1048 self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' % 1049 (name, field.name), depth) 1050 self.emit("return 0;", depth+1) 1051 1052 1053class 
ASTModuleVisitor(PickleVisitor): 1054 1055 def visitModule(self, mod): 1056 self.emit("static int", 0) 1057 self.emit("astmodule_exec(PyObject *m)", 0) 1058 self.emit("{", 0) 1059 self.emit('astmodulestate *state = get_ast_state(m);', 1) 1060 self.emit("", 0) 1061 1062 self.emit("if (!init_types(state)) {", 1) 1063 self.emit("return -1;", 2) 1064 self.emit("}", 1) 1065 self.emit('if (PyModule_AddObject(m, "AST", state->AST_type) < 0) {', 1) 1066 self.emit('return -1;', 2) 1067 self.emit('}', 1) 1068 self.emit('Py_INCREF(state->AST_type);', 1) 1069 self.emit('if (PyModule_AddIntMacro(m, PyCF_ALLOW_TOP_LEVEL_AWAIT) < 0) {', 1) 1070 self.emit("return -1;", 2) 1071 self.emit('}', 1) 1072 self.emit('if (PyModule_AddIntMacro(m, PyCF_ONLY_AST) < 0) {', 1) 1073 self.emit("return -1;", 2) 1074 self.emit('}', 1) 1075 self.emit('if (PyModule_AddIntMacro(m, PyCF_TYPE_COMMENTS) < 0) {', 1) 1076 self.emit("return -1;", 2) 1077 self.emit('}', 1) 1078 for dfn in mod.dfns: 1079 self.visit(dfn) 1080 self.emit("return 0;", 1) 1081 self.emit("}", 0) 1082 self.emit("", 0) 1083 self.emit(""" 1084static PyModuleDef_Slot astmodule_slots[] = { 1085 {Py_mod_exec, astmodule_exec}, 1086 {0, NULL} 1087}; 1088 1089static struct PyModuleDef _astmodule = { 1090 PyModuleDef_HEAD_INIT, 1091 .m_name = "_ast", 1092 // The _ast module uses a global state (global_ast_state). 
1093 .m_size = 0, 1094 .m_slots = astmodule_slots, 1095}; 1096 1097PyMODINIT_FUNC 1098PyInit__ast(void) 1099{ 1100 return PyModuleDef_Init(&_astmodule); 1101} 1102""".strip(), 0, reflow=False) 1103 1104 def visitProduct(self, prod, name): 1105 self.addObj(name) 1106 1107 def visitSum(self, sum, name): 1108 self.addObj(name) 1109 for t in sum.types: 1110 self.visitConstructor(t, name) 1111 1112 def visitConstructor(self, cons, name): 1113 self.addObj(cons.name) 1114 1115 def addObj(self, name): 1116 self.emit("if (PyModule_AddObject(m, \"%s\", " 1117 "state->%s_type) < 0) {" % (name, name), 1) 1118 self.emit("return -1;", 2) 1119 self.emit('}', 1) 1120 self.emit("Py_INCREF(state->%s_type);" % name, 1) 1121 1122 1123_SPECIALIZED_SEQUENCES = ('stmt', 'expr') 1124 1125def find_sequence(fields, doing_specialization): 1126 """Return True if any field uses a sequence.""" 1127 for f in fields: 1128 if f.seq: 1129 if not doing_specialization: 1130 return True 1131 if str(f.type) not in _SPECIALIZED_SEQUENCES: 1132 return True 1133 return False 1134 1135def has_sequence(types, doing_specialization): 1136 for t in types: 1137 if find_sequence(t.fields, doing_specialization): 1138 return True 1139 return False 1140 1141 1142class StaticVisitor(PickleVisitor): 1143 CODE = '''Very simple, always emit this static code. 
Override CODE''' 1144 1145 def visit(self, object): 1146 self.emit(self.CODE, 0, reflow=False) 1147 1148 1149class ObjVisitor(PickleVisitor): 1150 1151 def func_begin(self, name): 1152 ctype = get_c_type(name) 1153 self.emit("PyObject*", 0) 1154 self.emit("ast2obj_%s(astmodulestate *state, void* _o)" % (name), 0) 1155 self.emit("{", 0) 1156 self.emit("%s o = (%s)_o;" % (ctype, ctype), 1) 1157 self.emit("PyObject *result = NULL, *value = NULL;", 1) 1158 self.emit("PyTypeObject *tp;", 1) 1159 self.emit('if (!o) {', 1) 1160 self.emit("Py_RETURN_NONE;", 2) 1161 self.emit("}", 1) 1162 1163 def func_end(self): 1164 self.emit("return result;", 1) 1165 self.emit("failed:", 0) 1166 self.emit("Py_XDECREF(value);", 1) 1167 self.emit("Py_XDECREF(result);", 1) 1168 self.emit("return NULL;", 1) 1169 self.emit("}", 0) 1170 self.emit("", 0) 1171 1172 def visitSum(self, sum, name): 1173 if is_simple(sum): 1174 self.simpleSum(sum, name) 1175 return 1176 self.func_begin(name) 1177 self.emit("switch (o->kind) {", 1) 1178 for i in range(len(sum.types)): 1179 t = sum.types[i] 1180 self.visitConstructor(t, i + 1, name) 1181 self.emit("}", 1) 1182 for a in sum.attributes: 1183 self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) 1184 self.emit("if (!value) goto failed;", 1) 1185 self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1) 1186 self.emit('goto failed;', 2) 1187 self.emit('Py_DECREF(value);', 1) 1188 self.func_end() 1189 1190 def simpleSum(self, sum, name): 1191 self.emit("PyObject* ast2obj_%s(astmodulestate *state, %s_ty o)" % (name, name), 0) 1192 self.emit("{", 0) 1193 self.emit("switch(o) {", 1) 1194 for t in sum.types: 1195 self.emit("case %s:" % t.name, 2) 1196 self.emit("Py_INCREF(state->%s_singleton);" % t.name, 3) 1197 self.emit("return state->%s_singleton;" % t.name, 3) 1198 self.emit("}", 1) 1199 self.emit("Py_UNREACHABLE();", 1); 1200 self.emit("}", 0) 1201 1202 def visitProduct(self, prod, name): 1203 self.func_begin(name) 1204 
self.emit("tp = (PyTypeObject *)state->%s_type;" % name, 1) 1205 self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 1); 1206 self.emit("if (!result) return NULL;", 1) 1207 for field in prod.fields: 1208 self.visitField(field, name, 1, True) 1209 for a in prod.attributes: 1210 self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) 1211 self.emit("if (!value) goto failed;", 1) 1212 self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1) 1213 self.emit('goto failed;', 2) 1214 self.emit('Py_DECREF(value);', 1) 1215 self.func_end() 1216 1217 def visitConstructor(self, cons, enum, name): 1218 self.emit("case %s_kind:" % cons.name, 1) 1219 self.emit("tp = (PyTypeObject *)state->%s_type;" % cons.name, 2) 1220 self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 2); 1221 self.emit("if (!result) goto failed;", 2) 1222 for f in cons.fields: 1223 self.visitField(f, cons.name, 2, False) 1224 self.emit("break;", 2) 1225 1226 def visitField(self, field, name, depth, product): 1227 def emit(s, d): 1228 self.emit(s, depth + d) 1229 if product: 1230 value = "o->%s" % field.name 1231 else: 1232 value = "o->v.%s.%s" % (name, field.name) 1233 self.set(field, value, depth) 1234 emit("if (!value) goto failed;", 0) 1235 emit("if (PyObject_SetAttr(result, state->%s, value) == -1)" % field.name, 0) 1236 emit("goto failed;", 1) 1237 emit("Py_DECREF(value);", 0) 1238 1239 def emitSeq(self, field, value, depth, emit): 1240 emit("seq = %s;" % value, 0) 1241 emit("n = asdl_seq_LEN(seq);", 0) 1242 emit("value = PyList_New(n);", 0) 1243 emit("if (!value) goto failed;", 0) 1244 emit("for (i = 0; i < n; i++) {", 0) 1245 self.set("value", field, "asdl_seq_GET(seq, i)", depth + 1) 1246 emit("if (!value1) goto failed;", 1) 1247 emit("PyList_SET_ITEM(value, i, value1);", 1) 1248 emit("value1 = NULL;", 1) 1249 emit("}", 0) 1250 1251 def set(self, field, value, depth): 1252 if field.seq: 1253 # XXX should really check for is_simple, but that requires a symbol 
table 1254 if field.type == "cmpop": 1255 # While the sequence elements are stored as void*, 1256 # ast2obj_cmpop expects an enum 1257 self.emit("{", depth) 1258 self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1) 1259 self.emit("value = PyList_New(n);", depth+1) 1260 self.emit("if (!value) goto failed;", depth+1) 1261 self.emit("for(i = 0; i < n; i++)", depth+1) 1262 # This cannot fail, so no need for error handling 1263 self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop(state, (cmpop_ty)asdl_seq_GET(%s, i)));" % value, 1264 depth+2, reflow=False) 1265 self.emit("}", depth) 1266 else: 1267 self.emit("value = ast2obj_list(state, %s, ast2obj_%s);" % (value, field.type), depth) 1268 else: 1269 ctype = get_c_type(field.type) 1270 self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False) 1271 1272 1273class PartingShots(StaticVisitor): 1274 1275 CODE = """ 1276PyObject* PyAST_mod2obj(mod_ty t) 1277{ 1278 astmodulestate *state = get_global_ast_state(); 1279 if (state == NULL) { 1280 return NULL; 1281 } 1282 return ast2obj_mod(state, t); 1283} 1284 1285/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */ 1286mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode) 1287{ 1288 const char * const req_name[] = {"Module", "Expression", "Interactive"}; 1289 int isinstance; 1290 1291 if (PySys_Audit("compile", "OO", ast, Py_None) < 0) { 1292 return NULL; 1293 } 1294 1295 astmodulestate *state = get_global_ast_state(); 1296 PyObject *req_type[3]; 1297 req_type[0] = state->Module_type; 1298 req_type[1] = state->Expression_type; 1299 req_type[2] = state->Interactive_type; 1300 1301 assert(0 <= mode && mode <= 2); 1302 1303 isinstance = PyObject_IsInstance(ast, req_type[mode]); 1304 if (isinstance == -1) 1305 return NULL; 1306 if (!isinstance) { 1307 PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s", 1308 req_name[mode], _PyType_Name(Py_TYPE(ast))); 1309 return NULL; 1310 } 1311 1312 mod_ty res = NULL; 1313 if 
(obj2ast_mod(state, ast, &res, arena) != 0) 1314 return NULL; 1315 else 1316 return res; 1317} 1318 1319int PyAST_Check(PyObject* obj) 1320{ 1321 astmodulestate *state = get_global_ast_state(); 1322 if (state == NULL) { 1323 return -1; 1324 } 1325 return PyObject_IsInstance(obj, state->AST_type); 1326} 1327""" 1328 1329class ChainOfVisitors: 1330 def __init__(self, *visitors): 1331 self.visitors = visitors 1332 1333 def visit(self, object): 1334 for v in self.visitors: 1335 v.visit(object) 1336 v.emit("", 0) 1337 1338 1339def generate_module_def(f, mod): 1340 # Gather all the data needed for ModuleSpec 1341 visitor_list = set() 1342 with open(os.devnull, "w") as devnull: 1343 visitor = PyTypesDeclareVisitor(devnull) 1344 visitor.visit(mod) 1345 visitor_list.add(visitor) 1346 visitor = PyTypesVisitor(devnull) 1347 visitor.visit(mod) 1348 visitor_list.add(visitor) 1349 1350 state_strings = { 1351 "ast", 1352 "_fields", 1353 "__doc__", 1354 "__dict__", 1355 "__module__", 1356 "_attributes", 1357 } 1358 module_state = state_strings.copy() 1359 for visitor in visitor_list: 1360 for identifier in visitor.identifiers: 1361 module_state.add(identifier) 1362 state_strings.add(identifier) 1363 for singleton in visitor.singletons: 1364 module_state.add(singleton) 1365 for tp in visitor.types: 1366 module_state.add(tp) 1367 state_strings = sorted(state_strings) 1368 module_state = sorted(module_state) 1369 f.write('typedef struct {\n') 1370 f.write(' int initialized;\n') 1371 for s in module_state: 1372 f.write(' PyObject *' + s + ';\n') 1373 f.write('} astmodulestate;\n\n') 1374 f.write(""" 1375// Forward declaration 1376static int init_types(astmodulestate *state); 1377 1378// bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state. 
1379static astmodulestate global_ast_state = {0}; 1380 1381static astmodulestate* 1382get_global_ast_state(void) 1383{ 1384 astmodulestate* state = &global_ast_state; 1385 if (!init_types(state)) { 1386 return NULL; 1387 } 1388 return state; 1389} 1390 1391static astmodulestate* 1392get_ast_state(PyObject* Py_UNUSED(module)) 1393{ 1394 astmodulestate* state = get_global_ast_state(); 1395 // get_ast_state() must only be called after _ast module is imported, 1396 // and astmodule_exec() calls init_types() 1397 assert(state != NULL); 1398 return state; 1399} 1400 1401void _PyAST_Fini() 1402{ 1403 astmodulestate* state = &global_ast_state; 1404""") 1405 for s in module_state: 1406 f.write(" Py_CLEAR(state->" + s + ');\n') 1407 f.write(""" 1408 state->initialized = 0; 1409} 1410 1411""") 1412 f.write('static int init_identifiers(astmodulestate *state)\n') 1413 f.write('{\n') 1414 for identifier in state_strings: 1415 f.write(' if ((state->' + identifier) 1416 f.write(' = PyUnicode_InternFromString("') 1417 f.write(identifier + '")) == NULL) return 0;\n') 1418 f.write(' return 1;\n') 1419 f.write('};\n\n') 1420 1421def write_header(f, mod): 1422 f.write('#ifndef Py_PYTHON_AST_H\n') 1423 f.write('#define Py_PYTHON_AST_H\n') 1424 f.write('#ifdef __cplusplus\n') 1425 f.write('extern "C" {\n') 1426 f.write('#endif\n') 1427 f.write('\n') 1428 f.write('#ifndef Py_LIMITED_API\n') 1429 f.write('#include "asdl.h"\n') 1430 f.write('\n') 1431 f.write('#undef Yield /* undefine macro conflicting with <winbase.h> */\n') 1432 f.write('\n') 1433 c = ChainOfVisitors(TypeDefVisitor(f), 1434 StructVisitor(f)) 1435 c.visit(mod) 1436 f.write("// Note: these macros affect function definitions, not only call sites.\n") 1437 PrototypeVisitor(f).visit(mod) 1438 f.write("\n") 1439 f.write("PyObject* PyAST_mod2obj(mod_ty t);\n") 1440 f.write("mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode);\n") 1441 f.write("int PyAST_Check(PyObject* obj);\n") 1442 f.write("#endif /* !Py_LIMITED_API 
*/\n") 1443 f.write('\n') 1444 f.write('#ifdef __cplusplus\n') 1445 f.write('}\n') 1446 f.write('#endif\n') 1447 f.write('#endif /* !Py_PYTHON_AST_H */\n') 1448 1449def write_source(f, mod): 1450 f.write('#include <stddef.h>\n') 1451 f.write('\n') 1452 f.write('#include "Python.h"\n') 1453 f.write('#include "%s-ast.h"\n' % mod.name) 1454 f.write('#include "structmember.h" // PyMemberDef\n') 1455 f.write('\n') 1456 1457 generate_module_def(f, mod) 1458 1459 v = ChainOfVisitors( 1460 PyTypesDeclareVisitor(f), 1461 PyTypesVisitor(f), 1462 Obj2ModPrototypeVisitor(f), 1463 FunctionVisitor(f), 1464 ObjVisitor(f), 1465 Obj2ModVisitor(f), 1466 ASTModuleVisitor(f), 1467 PartingShots(f), 1468 ) 1469 v.visit(mod) 1470 1471def main(input_file, c_file, h_file, dump_module=False): 1472 auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:])) 1473 mod = asdl.parse(input_file) 1474 if dump_module: 1475 print('Parsed Module:') 1476 print(mod) 1477 if not asdl.check(mod): 1478 sys.exit(1) 1479 for file, writer in (c_file, write_source), (h_file, write_header): 1480 if file is not None: 1481 with file.open("w") as f: 1482 f.write(auto_gen_msg) 1483 writer(f, mod) 1484 print(file, "regenerated.") 1485 1486if __name__ == "__main__": 1487 parser = ArgumentParser() 1488 parser.add_argument("input_file", type=Path) 1489 parser.add_argument("-C", "--c-file", type=Path, default=None) 1490 parser.add_argument("-H", "--h-file", type=Path, default=None) 1491 parser.add_argument("-d", "--dump-module", action="store_true") 1492 1493 options = parser.parse_args() 1494 main(**vars(options)) 1495