1#! /usr/bin/env python 2"""Generate C code from an ASDL description.""" 3 4import os 5import sys 6import textwrap 7 8from argparse import ArgumentParser 9from contextlib import contextmanager 10from pathlib import Path 11 12import asdl 13 14TABSIZE = 4 15MAX_COL = 80 16AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n" 17 18def get_c_type(name): 19 """Return a string for the C name of the type. 20 21 This function special cases the default types provided by asdl. 22 """ 23 if name in asdl.builtin_types: 24 return name 25 else: 26 return "%s_ty" % name 27 28def reflow_lines(s, depth): 29 """Reflow the line s indented depth tabs. 30 31 Return a sequence of lines where no line extends beyond MAX_COL 32 when properly indented. The first line is properly indented based 33 exclusively on depth * TABSIZE. All following lines -- these are 34 the reflowed lines generated by this function -- start at the same 35 column as the first character beyond the opening { in the first 36 line. 37 """ 38 size = MAX_COL - depth * TABSIZE 39 if len(s) < size: 40 return [s] 41 42 lines = [] 43 cur = s 44 padding = "" 45 while len(cur) > size: 46 i = cur.rfind(' ', 0, size) 47 # XXX this should be fixed for real 48 if i == -1 and 'GeneratorExp' in cur: 49 i = size + 3 50 assert i != -1, "Impossible line %d to reflow: %r" % (size, s) 51 lines.append(padding + cur[:i]) 52 if len(lines) == 1: 53 # find new size based on brace 54 j = cur.find('{', 0, i) 55 if j >= 0: 56 j += 2 # account for the brace and the space after it 57 size -= j 58 padding = " " * j 59 else: 60 j = cur.find('(', 0, i) 61 if j >= 0: 62 j += 1 # account for the paren (no space after it) 63 size -= j 64 padding = " " * j 65 cur = cur[i+1:] 66 else: 67 lines.append(padding + cur) 68 return lines 69 70def reflow_c_string(s, depth): 71 return '"%s"' % s.replace('\n', '\\n"\n%s"' % (' ' * depth * TABSIZE)) 72 73def is_simple(sum): 74 """Return True if a sum is a simple. 75 76 A sum is simple if its types have no fields, e.g. 77 unaryop = Invert | Not | UAdd | USub 78 """ 79 for t in sum.types: 80 if t.fields: 81 return False 82 return True 83 84def asdl_of(name, obj): 85 if isinstance(obj, asdl.Product) or isinstance(obj, asdl.Constructor): 86 fields = ", ".join(map(str, obj.fields)) 87 if fields: 88 fields = "({})".format(fields) 89 return "{}{}".format(name, fields) 90 else: 91 if is_simple(obj): 92 types = " | ".join(type.name for type in obj.types) 93 else: 94 sep = "\n{}| ".format(" " * (len(name) + 1)) 95 types = sep.join( 96 asdl_of(type.name, type) for type in obj.types 97 ) 98 return "{} = {}".format(name, types) 99 100class EmitVisitor(asdl.VisitorBase): 101 """Visit that emits lines""" 102 103 def __init__(self, file): 104 self.file = file 105 self.identifiers = set() 106 self.singletons = set() 107 self.types = set() 108 super(EmitVisitor, self).__init__() 109 110 def emit_identifier(self, name): 111 self.identifiers.add(str(name)) 112 113 def emit_singleton(self, name): 114 self.singletons.add(str(name)) 115 116 def emit_type(self, name): 117 self.types.add(str(name)) 118 119 def emit(self, s, depth, reflow=True): 120 # XXX reflow long lines? 121 if reflow: 122 lines = reflow_lines(s, depth) 123 else: 124 lines = [s] 125 for line in lines: 126 if line: 127 line = (" " * TABSIZE * depth) + line 128 self.file.write(line + "\n") 129 130 131class TypeDefVisitor(EmitVisitor): 132 def visitModule(self, mod): 133 for dfn in mod.dfns: 134 self.visit(dfn) 135 136 def visitType(self, type, depth=0): 137 self.visit(type.value, type.name, depth) 138 139 def visitSum(self, sum, name, depth): 140 if is_simple(sum): 141 self.simple_sum(sum, name, depth) 142 else: 143 self.sum_with_constructors(sum, name, depth) 144 145 def simple_sum(self, sum, name, depth): 146 enum = [] 147 for i in range(len(sum.types)): 148 type = sum.types[i] 149 enum.append("%s=%d" % (type.name, i + 1)) 150 enums = ", ".join(enum) 151 ctype = get_c_type(name) 152 s = "typedef enum _%s { %s } %s;" % (name, enums, ctype) 153 self.emit(s, depth) 154 self.emit("", depth) 155 156 def sum_with_constructors(self, sum, name, depth): 157 ctype = get_c_type(name) 158 s = "typedef struct _%(name)s *%(ctype)s;" % locals() 159 self.emit(s, depth) 160 self.emit("", depth) 161 162 def visitProduct(self, product, name, depth): 163 ctype = get_c_type(name) 164 s = "typedef struct _%(name)s *%(ctype)s;" % locals() 165 self.emit(s, depth) 166 self.emit("", depth) 167 168class SequenceDefVisitor(EmitVisitor): 169 def visitModule(self, mod): 170 for dfn in mod.dfns: 171 self.visit(dfn) 172 173 def visitType(self, type, depth=0): 174 self.visit(type.value, type.name, depth) 175 176 def visitSum(self, sum, name, depth): 177 if is_simple(sum): 178 return 179 self.emit_sequence_constructor(name, depth) 180 181 def emit_sequence_constructor(self, name,depth): 182 ctype = get_c_type(name) 183 self.emit("""\ 184typedef struct { 185 _ASDL_SEQ_HEAD 186 %(ctype)s typed_elements[1]; 187} asdl_%(name)s_seq;""" % locals(), reflow=False, depth=depth) 188 self.emit("", depth) 189 self.emit("asdl_%(name)s_seq *_Py_asdl_%(name)s_seq_new(Py_ssize_t size, PyArena *arena);" % locals(), depth) 190 self.emit("", depth) 191 192 def visitProduct(self, product, name, depth): 193 self.emit_sequence_constructor(name, depth) 194 195class StructVisitor(EmitVisitor): 196 """Visitor to generate typedefs for AST.""" 197 198 def visitModule(self, mod): 199 for dfn in mod.dfns: 200 self.visit(dfn) 201 202 def visitType(self, type, depth=0): 203 self.visit(type.value, type.name, depth) 204 205 def visitSum(self, sum, name, depth): 206 if not is_simple(sum): 207 self.sum_with_constructors(sum, name, depth) 208 209 def sum_with_constructors(self, sum, name, depth): 210 def emit(s, depth=depth): 211 self.emit(s % sys._getframe(1).f_locals, depth) 212 enum = [] 213 for i in range(len(sum.types)): 214 type = sum.types[i] 215 enum.append("%s_kind=%d" % (type.name, i + 1)) 216 217 emit("enum _%(name)s_kind {" + ", ".join(enum) + "};") 218 219 emit("struct _%(name)s {") 220 emit("enum _%(name)s_kind kind;", depth + 1) 221 emit("union {", depth + 1) 222 for t in sum.types: 223 self.visit(t, depth + 2) 224 emit("} v;", depth + 1) 225 for field in sum.attributes: 226 # rudimentary attribute handling 227 type = str(field.type) 228 assert type in asdl.builtin_types, type 229 emit("%s %s;" % (type, field.name), depth + 1); 230 emit("};") 231 emit("") 232 233 def visitConstructor(self, cons, depth): 234 if cons.fields: 235 self.emit("struct {", depth) 236 for f in cons.fields: 237 self.visit(f, depth + 1) 238 self.emit("} %s;" % cons.name, depth) 239 self.emit("", depth) 240 241 def visitField(self, field, depth): 242 # XXX need to lookup field.type, because it might be something 243 # like a builtin... 244 ctype = get_c_type(field.type) 245 name = field.name 246 if field.seq: 247 if field.type == 'cmpop': 248 self.emit("asdl_int_seq *%(name)s;" % locals(), depth) 249 else: 250 _type = field.type 251 self.emit("asdl_%(_type)s_seq *%(name)s;" % locals(), depth) 252 else: 253 self.emit("%(ctype)s %(name)s;" % locals(), depth) 254 255 def visitProduct(self, product, name, depth): 256 self.emit("struct _%(name)s {" % locals(), depth) 257 for f in product.fields: 258 self.visit(f, depth + 1) 259 for field in product.attributes: 260 # rudimentary attribute handling 261 type = str(field.type) 262 assert type in asdl.builtin_types, type 263 self.emit("%s %s;" % (type, field.name), depth + 1); 264 self.emit("};", depth) 265 self.emit("", depth) 266 267 268def ast_func_name(name): 269 return f"_PyAST_{name}" 270 271 272class PrototypeVisitor(EmitVisitor): 273 """Generate function prototypes for the .h file""" 274 275 def visitModule(self, mod): 276 for dfn in mod.dfns: 277 self.visit(dfn) 278 279 def visitType(self, type): 280 self.visit(type.value, type.name) 281 282 def visitSum(self, sum, name): 283 if is_simple(sum): 284 pass # XXX 285 else: 286 for t in sum.types: 287 self.visit(t, name, sum.attributes) 288 289 def get_args(self, fields): 290 """Return list of C argument into, one for each field. 291 292 Argument info is 3-tuple of a C type, variable name, and flag 293 that is true if type can be NULL. 294 """ 295 args = [] 296 unnamed = {} 297 for f in fields: 298 if f.name is None: 299 name = f.type 300 c = unnamed[name] = unnamed.get(name, 0) + 1 301 if c > 1: 302 name = "name%d" % (c - 1) 303 else: 304 name = f.name 305 # XXX should extend get_c_type() to handle this 306 if f.seq: 307 if f.type == 'cmpop': 308 ctype = "asdl_int_seq *" 309 else: 310 ctype = f"asdl_{f.type}_seq *" 311 else: 312 ctype = get_c_type(f.type) 313 args.append((ctype, name, f.opt or f.seq)) 314 return args 315 316 def visitConstructor(self, cons, type, attrs): 317 args = self.get_args(cons.fields) 318 attrs = self.get_args(attrs) 319 ctype = get_c_type(type) 320 self.emit_function(cons.name, ctype, args, attrs) 321 322 def emit_function(self, name, ctype, args, attrs, union=True): 323 args = args + attrs 324 if args: 325 argstr = ", ".join(["%s %s" % (atype, aname) 326 for atype, aname, opt in args]) 327 argstr += ", PyArena *arena" 328 else: 329 argstr = "PyArena *arena" 330 self.emit("%s %s(%s);" % (ctype, ast_func_name(name), argstr), False) 331 332 def visitProduct(self, prod, name): 333 self.emit_function(name, get_c_type(name), 334 self.get_args(prod.fields), 335 self.get_args(prod.attributes), 336 union=False) 337 338 339class FunctionVisitor(PrototypeVisitor): 340 """Visitor to generate constructor functions for AST.""" 341 342 def emit_function(self, name, ctype, args, attrs, union=True): 343 def emit(s, depth=0, reflow=True): 344 self.emit(s, depth, reflow) 345 argstr = ", ".join(["%s %s" % (atype, aname) 346 for atype, aname, opt in args + attrs]) 347 if argstr: 348 argstr += ", PyArena *arena" 349 else: 350 argstr = "PyArena *arena" 351 self.emit("%s" % ctype, 0) 352 emit("%s(%s)" % (ast_func_name(name), argstr)) 353 emit("{") 354 emit("%s p;" % ctype, 1) 355 for argtype, argname, opt in args: 356 if not opt and argtype != "int": 357 emit("if (!%s) {" % argname, 1) 358 emit("PyErr_SetString(PyExc_ValueError,", 2) 359 msg = "field '%s' is required for %s" % (argname, name) 360 emit(' "%s");' % msg, 361 2, reflow=False) 362 emit('return NULL;', 2) 363 emit('}', 1) 364 365 emit("p = (%s)_PyArena_Malloc(arena, sizeof(*p));" % ctype, 1); 366 emit("if (!p)", 1) 367 emit("return NULL;", 2) 368 if union: 369 self.emit_body_union(name, args, attrs) 370 else: 371 self.emit_body_struct(name, args, attrs) 372 emit("return p;", 1) 373 emit("}") 374 emit("") 375 376 def emit_body_union(self, name, args, attrs): 377 def emit(s, depth=0, reflow=True): 378 self.emit(s, depth, reflow) 379 emit("p->kind = %s_kind;" % name, 1) 380 for argtype, argname, opt in args: 381 emit("p->v.%s.%s = %s;" % (name, argname, argname), 1) 382 for argtype, argname, opt in attrs: 383 emit("p->%s = %s;" % (argname, argname), 1) 384 385 def emit_body_struct(self, name, args, attrs): 386 def emit(s, depth=0, reflow=True): 387 self.emit(s, depth, reflow) 388 for argtype, argname, opt in args: 389 emit("p->%s = %s;" % (argname, argname), 1) 390 for argtype, argname, opt in attrs: 391 emit("p->%s = %s;" % (argname, argname), 1) 392 393 394class PickleVisitor(EmitVisitor): 395 396 def visitModule(self, mod): 397 for dfn in mod.dfns: 398 self.visit(dfn) 399 400 def visitType(self, type): 401 self.visit(type.value, type.name) 402 403 def visitSum(self, sum, name): 404 pass 405 406 def visitProduct(self, sum, name): 407 pass 408 409 def visitConstructor(self, cons, name): 410 pass 411 412 def visitField(self, sum): 413 pass 414 415 416class Obj2ModPrototypeVisitor(PickleVisitor): 417 def visitProduct(self, prod, name): 418 code = "static int obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena);" 419 self.emit(code % (name, get_c_type(name)), 0) 420 421 visitSum = visitProduct 422 423 424class Obj2ModVisitor(PickleVisitor): 425 @contextmanager 426 def recursive_call(self, node, level): 427 self.emit('if (Py_EnterRecursiveCall(" while traversing \'%s\' node")) {' % node, level, reflow=False) 428 self.emit('goto failed;', level + 1) 429 self.emit('}', level) 430 yield 431 self.emit('Py_LeaveRecursiveCall();', level) 432 433 def funcHeader(self, name): 434 ctype = get_c_type(name) 435 self.emit("int", 0) 436 self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0) 437 self.emit("{", 0) 438 self.emit("int isinstance;", 1) 439 self.emit("", 0) 440 441 def sumTrailer(self, name, add_label=False): 442 self.emit("", 0) 443 # there's really nothing more we can do if this fails ... 444 error = "expected some sort of %s, but got %%R" % name 445 format = "PyErr_Format(PyExc_TypeError, \"%s\", obj);" 446 self.emit(format % error, 1, reflow=False) 447 if add_label: 448 self.emit("failed:", 1) 449 self.emit("Py_XDECREF(tmp);", 1) 450 self.emit("return 1;", 1) 451 self.emit("}", 0) 452 self.emit("", 0) 453 454 def simpleSum(self, sum, name): 455 self.funcHeader(name) 456 for t in sum.types: 457 line = ("isinstance = PyObject_IsInstance(obj, " 458 "state->%s_type);") 459 self.emit(line % (t.name,), 1) 460 self.emit("if (isinstance == -1) {", 1) 461 self.emit("return 1;", 2) 462 self.emit("}", 1) 463 self.emit("if (isinstance) {", 1) 464 self.emit("*out = %s;" % t.name, 2) 465 self.emit("return 0;", 2) 466 self.emit("}", 1) 467 self.sumTrailer(name) 468 469 def buildArgs(self, fields): 470 return ", ".join(fields + ["arena"]) 471 472 def complexSum(self, sum, name): 473 self.funcHeader(name) 474 self.emit("PyObject *tmp = NULL;", 1) 475 self.emit("PyObject *tp;", 1) 476 for a in sum.attributes: 477 self.visitAttributeDeclaration(a, name, sum=sum) 478 self.emit("", 0) 479 # XXX: should we only do this for 'expr'? 480 self.emit("if (obj == Py_None) {", 1) 481 self.emit("*out = NULL;", 2) 482 self.emit("return 0;", 2) 483 self.emit("}", 1) 484 for a in sum.attributes: 485 self.visitField(a, name, sum=sum, depth=1) 486 for t in sum.types: 487 self.emit("tp = state->%s_type;" % (t.name,), 1) 488 self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1) 489 self.emit("if (isinstance == -1) {", 1) 490 self.emit("return 1;", 2) 491 self.emit("}", 1) 492 self.emit("if (isinstance) {", 1) 493 for f in t.fields: 494 self.visitFieldDeclaration(f, t.name, sum=sum, depth=2) 495 self.emit("", 0) 496 for f in t.fields: 497 self.visitField(f, t.name, sum=sum, depth=2) 498 args = [f.name for f in t.fields] + [a.name for a in sum.attributes] 499 self.emit("*out = %s(%s);" % (ast_func_name(t.name), self.buildArgs(args)), 2) 500 self.emit("if (*out == NULL) goto failed;", 2) 501 self.emit("return 0;", 2) 502 self.emit("}", 1) 503 self.sumTrailer(name, True) 504 505 def visitAttributeDeclaration(self, a, name, sum=sum): 506 ctype = get_c_type(a.type) 507 self.emit("%s %s;" % (ctype, a.name), 1) 508 509 def visitSum(self, sum, name): 510 if is_simple(sum): 511 self.simpleSum(sum, name) 512 else: 513 self.complexSum(sum, name) 514 515 def visitProduct(self, prod, name): 516 ctype = get_c_type(name) 517 self.emit("int", 0) 518 self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0) 519 self.emit("{", 0) 520 self.emit("PyObject* tmp = NULL;", 1) 521 for f in prod.fields: 522 self.visitFieldDeclaration(f, name, prod=prod, depth=1) 523 for a in prod.attributes: 524 self.visitFieldDeclaration(a, name, prod=prod, depth=1) 525 self.emit("", 0) 526 for f in prod.fields: 527 self.visitField(f, name, prod=prod, depth=1) 528 for a in prod.attributes: 529 self.visitField(a, name, prod=prod, depth=1) 530 args = [f.name for f in prod.fields] 531 args.extend([a.name for a in prod.attributes]) 532 self.emit("*out = %s(%s);" % (ast_func_name(name), self.buildArgs(args)), 1) 533 self.emit("return 0;", 1) 534 self.emit("failed:", 0) 535 self.emit("Py_XDECREF(tmp);", 1) 536 self.emit("return 1;", 1) 537 self.emit("}", 0) 538 self.emit("", 0) 539 540 def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0): 541 ctype = get_c_type(field.type) 542 if field.seq: 543 if self.isSimpleType(field): 544 self.emit("asdl_int_seq* %s;" % field.name, depth) 545 else: 546 _type = field.type 547 self.emit(f"asdl_{field.type}_seq* {field.name};", depth) 548 else: 549 ctype = get_c_type(field.type) 550 self.emit("%s %s;" % (ctype, field.name), depth) 551 552 def isSimpleSum(self, field): 553 # XXX can the members of this list be determined automatically? 554 return field.type in ('expr_context', 'boolop', 'operator', 555 'unaryop', 'cmpop') 556 557 def isNumeric(self, field): 558 return get_c_type(field.type) in ("int", "bool") 559 560 def isSimpleType(self, field): 561 return self.isSimpleSum(field) or self.isNumeric(field) 562 563 def visitField(self, field, name, sum=None, prod=None, depth=0): 564 ctype = get_c_type(field.type) 565 line = "if (_PyObject_LookupAttr(obj, state->%s, &tmp) < 0) {" 566 self.emit(line % field.name, depth) 567 self.emit("return 1;", depth+1) 568 self.emit("}", depth) 569 if not field.opt: 570 self.emit("if (tmp == NULL) {", depth) 571 message = "required field \\\"%s\\\" missing from %s" % (field.name, name) 572 format = "PyErr_SetString(PyExc_TypeError, \"%s\");" 573 self.emit(format % message, depth+1, reflow=False) 574 self.emit("return 1;", depth+1) 575 else: 576 self.emit("if (tmp == NULL || tmp == Py_None) {", depth) 577 self.emit("Py_CLEAR(tmp);", depth+1) 578 if self.isNumeric(field): 579 self.emit("%s = 0;" % field.name, depth+1) 580 elif not self.isSimpleType(field): 581 self.emit("%s = NULL;" % field.name, depth+1) 582 else: 583 raise TypeError("could not determine the default value for %s" % field.name) 584 self.emit("}", depth) 585 self.emit("else {", depth) 586 587 self.emit("int res;", depth+1) 588 if field.seq: 589 self.emit("Py_ssize_t len;", depth+1) 590 self.emit("Py_ssize_t i;", depth+1) 591 self.emit("if (!PyList_Check(tmp)) {", depth+1) 592 self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must " 593 "be a list, not a %%.200s\", _PyType_Name(Py_TYPE(tmp)));" % 594 (name, field.name), 595 depth+2, reflow=False) 596 self.emit("goto failed;", depth+2) 597 self.emit("}", depth+1) 598 self.emit("len = PyList_GET_SIZE(tmp);", depth+1) 599 if self.isSimpleType(field): 600 self.emit("%s = _Py_asdl_int_seq_new(len, arena);" % field.name, depth+1) 601 else: 602 self.emit("%s = _Py_asdl_%s_seq_new(len, arena);" % (field.name, field.type), depth+1) 603 self.emit("if (%s == NULL) goto failed;" % field.name, depth+1) 604 self.emit("for (i = 0; i < len; i++) {", depth+1) 605 self.emit("%s val;" % ctype, depth+2) 606 self.emit("PyObject *tmp2 = PyList_GET_ITEM(tmp, i);", depth+2) 607 self.emit("Py_INCREF(tmp2);", depth+2) 608 with self.recursive_call(name, depth+2): 609 self.emit("res = obj2ast_%s(state, tmp2, &val, arena);" % 610 field.type, depth+2, reflow=False) 611 self.emit("Py_DECREF(tmp2);", depth+2) 612 self.emit("if (res != 0) goto failed;", depth+2) 613 self.emit("if (len != PyList_GET_SIZE(tmp)) {", depth+2) 614 self.emit("PyErr_SetString(PyExc_RuntimeError, \"%s field \\\"%s\\\" " 615 "changed size during iteration\");" % 616 (name, field.name), 617 depth+3, reflow=False) 618 self.emit("goto failed;", depth+3) 619 self.emit("}", depth+2) 620 self.emit("asdl_seq_SET(%s, i, val);" % field.name, depth+2) 621 self.emit("}", depth+1) 622 else: 623 with self.recursive_call(name, depth+1): 624 self.emit("res = obj2ast_%s(state, tmp, &%s, arena);" % 625 (field.type, field.name), depth+1) 626 self.emit("if (res != 0) goto failed;", depth+1) 627 628 self.emit("Py_CLEAR(tmp);", depth+1) 629 self.emit("}", depth) 630 631 632class SequenceConstructorVisitor(EmitVisitor): 633 def visitModule(self, mod): 634 for dfn in mod.dfns: 635 self.visit(dfn) 636 637 def visitType(self, type): 638 self.visit(type.value, type.name) 639 640 def visitProduct(self, prod, name): 641 self.emit_sequence_constructor(name, get_c_type(name)) 642 643 def visitSum(self, sum, name): 644 if not is_simple(sum): 645 self.emit_sequence_constructor(name, get_c_type(name)) 646 647 def emit_sequence_constructor(self, name, type): 648 self.emit(f"GENERATE_ASDL_SEQ_CONSTRUCTOR({name}, {type})", depth=0) 649 650class PyTypesDeclareVisitor(PickleVisitor): 651 652 def visitProduct(self, prod, name): 653 self.emit_type("%s_type" % name) 654 self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0) 655 if prod.attributes: 656 for a in prod.attributes: 657 self.emit_identifier(a.name) 658 self.emit("static const char * const %s_attributes[] = {" % name, 0) 659 for a in prod.attributes: 660 self.emit('"%s",' % a.name, 1) 661 self.emit("};", 0) 662 if prod.fields: 663 for f in prod.fields: 664 self.emit_identifier(f.name) 665 self.emit("static const char * const %s_fields[]={" % name,0) 666 for f in prod.fields: 667 self.emit('"%s",' % f.name, 1) 668 self.emit("};", 0) 669 670 def visitSum(self, sum, name): 671 self.emit_type("%s_type" % name) 672 if sum.attributes: 673 for a in sum.attributes: 674 self.emit_identifier(a.name) 675 self.emit("static const char * const %s_attributes[] = {" % name, 0) 676 for a in sum.attributes: 677 self.emit('"%s",' % a.name, 1) 678 self.emit("};", 0) 679 ptype = "void*" 680 if is_simple(sum): 681 ptype = get_c_type(name) 682 for t in sum.types: 683 self.emit_singleton("%s_singleton" % t.name) 684 self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0) 685 for t in sum.types: 686 self.visitConstructor(t, name) 687 688 def visitConstructor(self, cons, name): 689 if cons.fields: 690 for t in cons.fields: 691 self.emit_identifier(t.name) 692 self.emit("static const char * const %s_fields[]={" % cons.name, 0) 693 for t in cons.fields: 694 self.emit('"%s",' % t.name, 1) 695 self.emit("};",0) 696 697 698class PyTypesVisitor(PickleVisitor): 699 700 def visitModule(self, mod): 701 self.emit(""" 702 703typedef struct { 704 PyObject_HEAD 705 PyObject *dict; 706} AST_object; 707 708static void 709ast_dealloc(AST_object *self) 710{ 711 /* bpo-31095: UnTrack is needed before calling any callbacks */ 712 PyTypeObject *tp = Py_TYPE(self); 713 PyObject_GC_UnTrack(self); 714 Py_CLEAR(self->dict); 715 freefunc free_func = PyType_GetSlot(tp, Py_tp_free); 716 assert(free_func != NULL); 717 free_func(self); 718 Py_DECREF(tp); 719} 720 721static int 722ast_traverse(AST_object *self, visitproc visit, void *arg) 723{ 724 Py_VISIT(Py_TYPE(self)); 725 Py_VISIT(self->dict); 726 return 0; 727} 728 729static int 730ast_clear(AST_object *self) 731{ 732 Py_CLEAR(self->dict); 733 return 0; 734} 735 736static int 737ast_type_init(PyObject *self, PyObject *args, PyObject *kw) 738{ 739 struct ast_state *state = get_ast_state(); 740 if (state == NULL) { 741 return -1; 742 } 743 744 Py_ssize_t i, numfields = 0; 745 int res = -1; 746 PyObject *key, *value, *fields; 747 if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) { 748 goto cleanup; 749 } 750 if (fields) { 751 numfields = PySequence_Size(fields); 752 if (numfields == -1) { 753 goto cleanup; 754 } 755 } 756 757 res = 0; /* if no error occurs, this stays 0 to the end */ 758 if (numfields < PyTuple_GET_SIZE(args)) { 759 PyErr_Format(PyExc_TypeError, "%.400s constructor takes at most " 760 "%zd positional argument%s", 761 _PyType_Name(Py_TYPE(self)), 762 numfields, numfields == 1 ? "" : "s"); 763 res = -1; 764 goto cleanup; 765 } 766 for (i = 0; i < PyTuple_GET_SIZE(args); i++) { 767 /* cannot be reached when fields is NULL */ 768 PyObject *name = PySequence_GetItem(fields, i); 769 if (!name) { 770 res = -1; 771 goto cleanup; 772 } 773 res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i)); 774 Py_DECREF(name); 775 if (res < 0) { 776 goto cleanup; 777 } 778 } 779 if (kw) { 780 i = 0; /* needed by PyDict_Next */ 781 while (PyDict_Next(kw, &i, &key, &value)) { 782 int contains = PySequence_Contains(fields, key); 783 if (contains == -1) { 784 res = -1; 785 goto cleanup; 786 } else if (contains == 1) { 787 Py_ssize_t p = PySequence_Index(fields, key); 788 if (p == -1) { 789 res = -1; 790 goto cleanup; 791 } 792 if (p < PyTuple_GET_SIZE(args)) { 793 PyErr_Format(PyExc_TypeError, 794 "%.400s got multiple values for argument '%U'", 795 Py_TYPE(self)->tp_name, key); 796 res = -1; 797 goto cleanup; 798 } 799 } 800 res = PyObject_SetAttr(self, key, value); 801 if (res < 0) { 802 goto cleanup; 803 } 804 } 805 } 806 cleanup: 807 Py_XDECREF(fields); 808 return res; 809} 810 811/* Pickling support */ 812static PyObject * 813ast_type_reduce(PyObject *self, PyObject *unused) 814{ 815 struct ast_state *state = get_ast_state(); 816 if (state == NULL) { 817 return NULL; 818 } 819 820 PyObject *dict; 821 if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) { 822 return NULL; 823 } 824 if (dict) { 825 return Py_BuildValue("O()N", Py_TYPE(self), dict); 826 } 827 return Py_BuildValue("O()", Py_TYPE(self)); 828} 829 830static PyMemberDef ast_type_members[] = { 831 {"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY}, 832 {NULL} /* Sentinel */ 833}; 834 835static PyMethodDef ast_type_methods[] = { 836 {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, 837 {NULL} 838}; 839 840static PyGetSetDef ast_type_getsets[] = { 841 {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict}, 842 {NULL} 843}; 844 845static PyType_Slot AST_type_slots[] = { 846 {Py_tp_dealloc, ast_dealloc}, 847 {Py_tp_getattro, PyObject_GenericGetAttr}, 848 {Py_tp_setattro, PyObject_GenericSetAttr}, 849 {Py_tp_traverse, ast_traverse}, 850 {Py_tp_clear, ast_clear}, 851 {Py_tp_members, ast_type_members}, 852 {Py_tp_methods, ast_type_methods}, 853 {Py_tp_getset, ast_type_getsets}, 854 {Py_tp_init, ast_type_init}, 855 {Py_tp_alloc, PyType_GenericAlloc}, 856 {Py_tp_new, PyType_GenericNew}, 857 {Py_tp_free, PyObject_GC_Del}, 858 {0, 0}, 859}; 860 861static PyType_Spec AST_type_spec = { 862 "ast.AST", 863 sizeof(AST_object), 864 0, 865 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, 866 AST_type_slots 867}; 868 869static PyObject * 870make_type(struct ast_state *state, const char *type, PyObject* base, 871 const char* const* fields, int num_fields, const char *doc) 872{ 873 PyObject *fnames, *result; 874 int i; 875 fnames = PyTuple_New(num_fields); 876 if (!fnames) return NULL; 877 for (i = 0; i < num_fields; i++) { 878 PyObject *field = PyUnicode_InternFromString(fields[i]); 879 if (!field) { 880 Py_DECREF(fnames); 881 return NULL; 882 } 883 PyTuple_SET_ITEM(fnames, i, field); 884 } 885 result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOOOs}", 886 type, base, 887 state->_fields, fnames, 888 state->__match_args__, fnames, 889 state->__module__, 890 state->ast, 891 state->__doc__, doc); 892 Py_DECREF(fnames); 893 return result; 894} 895 896static int 897add_attributes(struct ast_state *state, PyObject *type, const char * const *attrs, int num_fields) 898{ 899 int i, result; 900 PyObject *s, *l = PyTuple_New(num_fields); 901 if (!l) 902 return 0; 903 for (i = 0; i < num_fields; i++) { 904 s = PyUnicode_InternFromString(attrs[i]); 905 if (!s) { 906 Py_DECREF(l); 907 return 0; 908 } 909 PyTuple_SET_ITEM(l, i, s); 910 } 911 result = PyObject_SetAttr(type, state->_attributes, l) >= 0; 912 Py_DECREF(l); 913 return result; 914} 915 916/* Conversion AST -> Python */ 917 918static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*)) 919{ 920 Py_ssize_t i, n = asdl_seq_LEN(seq); 921 PyObject *result = PyList_New(n); 922 PyObject *value; 923 if (!result) 924 return NULL; 925 for (i = 0; i < n; i++) { 926 value = func(state, asdl_seq_GET_UNTYPED(seq, i)); 927 if (!value) { 928 Py_DECREF(result); 929 return NULL; 930 } 931 PyList_SET_ITEM(result, i, value); 932 } 933 return result; 934} 935 936static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o) 937{ 938 if (!o) 939 o = Py_None; 940 Py_INCREF((PyObject*)o); 941 return (PyObject*)o; 942} 943#define ast2obj_constant ast2obj_object 944#define ast2obj_identifier ast2obj_object 945#define ast2obj_string ast2obj_object 946 947static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b) 948{ 949 return PyLong_FromLong(b); 950} 951 952/* Conversion Python -> AST */ 953 954static int obj2ast_object(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena) 955{ 956 if (obj == Py_None) 957 obj = NULL; 958 if (obj) { 959 if (_PyArena_AddPyObject(arena, obj) < 0) { 960 *out = NULL; 961 return -1; 962 } 963 Py_INCREF(obj); 964 } 965 *out = obj; 966 return 0; 967} 968 969static int obj2ast_constant(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena) 970{ 971 if (_PyArena_AddPyObject(arena, obj) < 0) { 972 *out = NULL; 973 return -1; 974 } 975 Py_INCREF(obj); 976 *out = obj; 977 return 0; 978} 979 980static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena) 981{ 982 if (!PyUnicode_CheckExact(obj) && obj != Py_None) { 983 PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str"); 984 return 1; 985 } 986 return obj2ast_object(state, obj, out, arena); 987} 988 989static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena) 990{ 991 if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) { 992 PyErr_SetString(PyExc_TypeError, "AST string must be of type str"); 993 return 1; 994 } 995 return obj2ast_object(state, obj, out, arena); 996} 997 998static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena) 999{ 1000 int i; 1001 if (!PyLong_Check(obj)) { 1002 PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj); 1003 return 1; 1004 } 1005 1006 i = _PyLong_AsInt(obj); 1007 if (i == -1 && PyErr_Occurred()) 1008 return 1; 1009 *out = i; 1010 return 0; 1011} 1012 1013static int add_ast_fields(struct ast_state *state) 1014{ 1015 PyObject *empty_tuple; 1016 empty_tuple = PyTuple_New(0); 1017 if (!empty_tuple || 1018 PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 || 1019 PyObject_SetAttrString(state->AST_type, "__match_args__", empty_tuple) < 0 || 1020 PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) { 1021 Py_XDECREF(empty_tuple); 1022 return -1; 1023 } 1024 Py_DECREF(empty_tuple); 1025 return 0; 1026} 1027 1028""", 0, reflow=False) 1029 1030 self.file.write(textwrap.dedent(''' 1031 static int 1032 init_types(struct ast_state *state) 1033 { 1034 // init_types() must not be called after _PyAST_Fini() 1035 // has been called 1036 assert(state->initialized >= 0); 1037 1038 if (state->initialized) { 1039 return 1; 1040 } 1041 if (init_identifiers(state) < 0) { 1042 return 0; 1043 } 1044 state->AST_type = PyType_FromSpec(&AST_type_spec); 1045 if (!state->AST_type) { 1046 return 0; 1047 } 1048 if (add_ast_fields(state) < 0) { 1049 return 0; 1050 } 1051 ''')) 1052 for dfn in mod.dfns: 1053 self.visit(dfn) 1054 self.file.write(textwrap.dedent(''' 1055 state->initialized = 1; 1056 return 1; 1057 } 1058 ''')) 1059 1060 def visitProduct(self, prod, name): 1061 if prod.fields: 1062 fields = name+"_fields" 1063 else: 1064 fields = "NULL" 1065 self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' % 1066 (name, name, fields, len(prod.fields)), 1) 1067 self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False) 1068 self.emit("if (!state->%s_type) return 0;" % name, 1) 1069 self.emit_type("AST_type") 1070 self.emit_type("%s_type" % name) 1071 if prod.attributes: 1072 self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" % 1073 (name, name, len(prod.attributes)), 1) 1074 else: 1075 self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1) 1076 self.emit_defaults(name, prod.fields, 1) 1077 self.emit_defaults(name, prod.attributes, 1) 1078 1079 def visitSum(self, sum, name): 1080 self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' % 1081 (name, name), 1) 1082 self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False) 1083 self.emit_type("%s_type" % name) 1084 self.emit("if (!state->%s_type) return 0;" % name, 1) 1085 if sum.attributes: 1086 self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" % 1087 (name, name, len(sum.attributes)), 1) 1088 else: 1089 self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1) 1090 self.emit_defaults(name, sum.attributes, 1) 1091 simple = is_simple(sum) 1092 for t in sum.types: 1093 self.visitConstructor(t, name, simple) 1094 1095 def visitConstructor(self, cons, name, simple): 1096 if cons.fields: 1097 fields = cons.name+"_fields" 1098 else: 1099 fields = "NULL" 1100 self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' % 1101 (cons.name, cons.name, name, fields, len(cons.fields)), 1) 1102 self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False) 1103 self.emit("if (!state->%s_type) return 0;" % cons.name, 1) 1104 self.emit_type("%s_type" % cons.name) 1105 self.emit_defaults(cons.name, cons.fields, 1) 1106 if simple: 1107 self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)" 1108 "state->%s_type, NULL, NULL);" % 1109 (cons.name, cons.name), 1) 1110 self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1) 1111 1112 def emit_defaults(self, name, fields, depth): 1113 for field in fields: 1114 if field.opt: 1115 self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' % 1116 (name, field.name), depth) 1117 self.emit("return 0;", depth+1) 1118 1119 1120class ASTModuleVisitor(PickleVisitor): 1121 1122 def visitModule(self, mod): 1123 self.emit("static int", 0) 1124 self.emit("astmodule_exec(PyObject *m)", 0) 1125 self.emit("{", 0) 1126 self.emit('struct ast_state *state = get_ast_state();', 1) 1127 self.emit('if (state == NULL) {', 1) 1128 self.emit('return -1;', 2) 1129 self.emit('}', 1) 1130 self.emit('if (PyModule_AddObjectRef(m, "AST", state->AST_type) < 0) {', 1) 1131 self.emit('return -1;', 2) 1132 self.emit('}', 1) 1133 self.emit('if (PyModule_AddIntMacro(m, PyCF_ALLOW_TOP_LEVEL_AWAIT) < 0) {', 1) 1134 self.emit("return -1;", 2) 1135 self.emit('}', 1) 1136 self.emit('if (PyModule_AddIntMacro(m, PyCF_ONLY_AST) < 0) {', 1) 1137 self.emit("return -1;", 2) 1138 self.emit('}', 1) 1139 self.emit('if (PyModule_AddIntMacro(m, PyCF_TYPE_COMMENTS) < 0) {', 1) 1140 self.emit("return -1;", 2) 1141 self.emit('}', 1) 1142 for dfn in mod.dfns: 1143 self.visit(dfn) 1144 self.emit("return 0;", 1) 1145 self.emit("}", 0) 1146 self.emit("", 0) 1147 self.emit(""" 1148static PyModuleDef_Slot astmodule_slots[] = { 1149 {Py_mod_exec, astmodule_exec}, 1150 {0, NULL} 1151}; 1152 1153static struct PyModuleDef _astmodule = { 1154 PyModuleDef_HEAD_INIT, 1155 .m_name = "_ast", 1156 // The _ast module uses a per-interpreter state (PyInterpreterState.ast) 1157 .m_size = 0, 1158 .m_slots = astmodule_slots, 1159}; 1160 1161PyMODINIT_FUNC 1162PyInit__ast(void) 1163{ 1164 return PyModuleDef_Init(&_astmodule); 1165} 1166""".strip(), 0, reflow=False) 1167 1168 def visitProduct(self, prod, name): 1169 self.addObj(name) 1170 1171 def visitSum(self, sum, name): 1172 self.addObj(name) 1173 for t in sum.types: 1174 self.visitConstructor(t, name) 1175 1176 def visitConstructor(self, cons, name): 1177 self.addObj(cons.name) 1178 1179 def addObj(self, name): 1180 self.emit("if (PyModule_AddObjectRef(m, \"%s\", " 1181 "state->%s_type) < 0) {" % (name, name), 1) 1182 self.emit("return -1;", 2) 1183 self.emit('}', 1) 1184 1185 1186class StaticVisitor(PickleVisitor): 1187 CODE = '''Very simple, always emit this static code. Override CODE''' 1188 1189 def visit(self, object): 1190 self.emit(self.CODE, 0, reflow=False) 1191 1192 1193class ObjVisitor(PickleVisitor): 1194 1195 def func_begin(self, name): 1196 ctype = get_c_type(name) 1197 self.emit("PyObject*", 0) 1198 self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0) 1199 self.emit("{", 0) 1200 self.emit("%s o = (%s)_o;" % (ctype, ctype), 1) 1201 self.emit("PyObject *result = NULL, *value = NULL;", 1) 1202 self.emit("PyTypeObject *tp;", 1) 1203 self.emit('if (!o) {', 1) 1204 self.emit("Py_RETURN_NONE;", 2) 1205 self.emit("}", 1) 1206 1207 def func_end(self): 1208 self.emit("return result;", 1) 1209 self.emit("failed:", 0) 1210 self.emit("Py_XDECREF(value);", 1) 1211 self.emit("Py_XDECREF(result);", 1) 1212 self.emit("return NULL;", 1) 1213 self.emit("}", 0) 1214 self.emit("", 0) 1215 1216 def visitSum(self, sum, name): 1217 if is_simple(sum): 1218 self.simpleSum(sum, name) 1219 return 1220 self.func_begin(name) 1221 self.emit("switch (o->kind) {", 1) 1222 for i in range(len(sum.types)): 1223 t = sum.types[i] 1224 self.visitConstructor(t, i + 1, name) 1225 self.emit("}", 1) 1226 for a in sum.attributes: 1227 self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) 1228 self.emit("if (!value) goto failed;", 1) 1229 self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1) 1230 self.emit('goto failed;', 2) 1231 self.emit('Py_DECREF(value);', 1) 1232 self.func_end() 1233 1234 def simpleSum(self, sum, name): 1235 self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0) 1236 self.emit("{", 0) 1237 self.emit("switch(o) {", 1) 1238 for t in sum.types: 1239 self.emit("case %s:" % t.name, 2) 1240 self.emit("Py_INCREF(state->%s_singleton);" % t.name, 3) 1241 self.emit("return state->%s_singleton;" % t.name, 3) 1242 self.emit("}", 1) 1243 self.emit("Py_UNREACHABLE();", 1); 1244 self.emit("}", 0) 1245 1246 def visitProduct(self, prod, name): 1247 self.func_begin(name) 1248 self.emit("tp = (PyTypeObject *)state->%s_type;" % name, 1) 1249 self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 1); 1250 self.emit("if (!result) return NULL;", 1) 1251 for field in prod.fields: 1252 self.visitField(field, name, 1, True) 1253 for a in prod.attributes: 1254 self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) 1255 self.emit("if (!value) goto failed;", 1) 1256 self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1) 1257 self.emit('goto failed;', 2) 1258 self.emit('Py_DECREF(value);', 1) 1259 self.func_end() 1260 1261 def visitConstructor(self, cons, enum, name): 1262 self.emit("case %s_kind:" % cons.name, 1) 1263 self.emit("tp = (PyTypeObject *)state->%s_type;" % cons.name, 2) 1264 self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 2); 1265 self.emit("if (!result) goto failed;", 2) 1266 for f in cons.fields: 1267 self.visitField(f, cons.name, 2, False) 1268 self.emit("break;", 2) 1269 1270 def visitField(self, field, name, depth, product): 1271 def emit(s, d): 1272 self.emit(s, depth + d) 1273 if product: 1274 value = "o->%s" % field.name 1275 else: 1276 value = "o->v.%s.%s" % (name, field.name) 1277 self.set(field, value, depth) 1278 emit("if (!value) goto failed;", 0) 1279 emit("if (PyObject_SetAttr(result, state->%s, value) == -1)" % field.name, 0) 1280 emit("goto failed;", 1) 1281 emit("Py_DECREF(value);", 0) 1282 1283 def set(self, field, value, depth): 1284 if field.seq: 1285 # XXX should really check for is_simple, but that requires a symbol table 1286 if field.type == "cmpop": 1287 # While the sequence elements are stored as void*, 1288 # ast2obj_cmpop expects an enum 1289 self.emit("{", depth) 1290 self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1) 1291 self.emit("value = PyList_New(n);", depth+1) 1292 self.emit("if (!value) goto failed;", depth+1) 1293 self.emit("for(i = 0; i < n; i++)", depth+1) 1294 # This cannot fail, so no need for error handling 1295 self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop(state, (cmpop_ty)asdl_seq_GET(%s, i)));" % value, 1296 depth+2, reflow=False) 1297 self.emit("}", depth) 1298 else: 1299 self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) 1300 else: 1301 self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False) 1302 1303 1304class PartingShots(StaticVisitor): 1305 1306 CODE = """ 1307PyObject* PyAST_mod2obj(mod_ty t) 1308{ 1309 struct ast_state *state = get_ast_state(); 1310 if (state == NULL) { 1311 return NULL; 1312 } 1313 return ast2obj_mod(state, t); 1314} 1315 1316/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */ 1317mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode) 1318{ 1319 const char * const req_name[] = {"Module", "Expression", "Interactive"}; 1320 int isinstance; 1321 1322 if (PySys_Audit("compile", "OO", ast, Py_None) < 0) { 1323 return NULL; 1324 } 1325 1326 struct ast_state *state = get_ast_state(); 1327 if (state == NULL) { 1328 return NULL; 1329 } 1330 1331 PyObject *req_type[3]; 1332 req_type[0] = state->Module_type; 1333 req_type[1] = state->Expression_type; 1334 req_type[2] = state->Interactive_type; 1335 1336 assert(0 <= mode && mode <= 2); 1337 1338 isinstance = PyObject_IsInstance(ast, req_type[mode]); 1339 if (isinstance == -1) 1340 return NULL; 1341 if (!isinstance) { 1342 PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s", 1343 req_name[mode], _PyType_Name(Py_TYPE(ast))); 1344 return NULL; 1345 } 1346 1347 mod_ty res = NULL; 1348 if (obj2ast_mod(state, ast, &res, arena) != 0) 1349 return NULL; 1350 else 1351 return res; 1352} 1353 1354int PyAST_Check(PyObject* obj) 1355{ 1356 struct ast_state *state = get_ast_state(); 1357 if (state == NULL) { 1358 return -1; 1359 } 1360 return PyObject_IsInstance(obj, state->AST_type); 1361} 1362""" 1363 1364class ChainOfVisitors: 1365 def __init__(self, *visitors): 1366 self.visitors = visitors 1367 1368 def visit(self, object): 1369 for v in self.visitors: 1370 v.visit(object) 1371 v.emit("", 0) 1372 1373 1374def generate_ast_state(module_state, f): 1375 f.write('struct ast_state {\n') 1376 f.write(' int initialized;\n') 1377 for s in module_state: 1378 f.write(' PyObject *' + s + ';\n') 1379 f.write('};') 1380 1381 1382def generate_ast_fini(module_state, f): 1383 f.write(textwrap.dedent(""" 1384 void _PyAST_Fini(PyInterpreterState *interp) 1385 { 1386 struct ast_state *state = &interp->ast; 1387 1388 """)) 1389 for s in module_state: 1390 f.write(" Py_CLEAR(state->" + s + ');\n') 1391 f.write(textwrap.dedent(""" 1392 #if !defined(NDEBUG) 1393 state->initialized = -1; 1394 #else 1395 state->initialized = 0; 1396 #endif 1397 } 1398 1399 """)) 1400 1401 1402def generate_module_def(mod, f, internal_h): 1403 # Gather all the data needed for ModuleSpec 1404 visitor_list = set() 1405 with open(os.devnull, "w") as devnull: 1406 visitor = PyTypesDeclareVisitor(devnull) 1407 visitor.visit(mod) 1408 visitor_list.add(visitor) 1409 visitor = PyTypesVisitor(devnull) 1410 visitor.visit(mod) 1411 visitor_list.add(visitor) 1412 1413 state_strings = { 1414 "ast", 1415 "_fields", 1416 "__match_args__", 1417 "__doc__", 1418 "__dict__", 1419 "__module__", 1420 "_attributes", 1421 } 1422 module_state = state_strings.copy() 1423 for visitor in visitor_list: 1424 for identifier in visitor.identifiers: 1425 module_state.add(identifier) 1426 state_strings.add(identifier) 1427 for singleton in visitor.singletons: 1428 module_state.add(singleton) 1429 for tp in visitor.types: 1430 module_state.add(tp) 1431 state_strings = sorted(state_strings) 1432 module_state = sorted(module_state) 1433 1434 generate_ast_state(module_state, internal_h) 1435 1436 print(textwrap.dedent(""" 1437 #include "Python.h" 1438 #include "pycore_ast.h" 1439 #include "pycore_ast_state.h" // struct ast_state 1440 #include "pycore_interp.h" // _PyInterpreterState.ast 1441 #include "pycore_pystate.h" // _PyInterpreterState_GET() 1442 #include "structmember.h" 1443 #include <stddef.h> 1444 1445 // Forward declaration 1446 static int init_types(struct ast_state *state); 1447 1448 static struct ast_state* 1449 get_ast_state(void) 1450 { 1451 PyInterpreterState *interp = _PyInterpreterState_GET(); 1452 struct ast_state *state = &interp->ast; 1453 if (!init_types(state)) { 1454 return NULL; 1455 } 1456 return state; 1457 } 1458 """).strip(), file=f) 1459 1460 generate_ast_fini(module_state, f) 1461 1462 f.write('static int init_identifiers(struct ast_state *state)\n') 1463 f.write('{\n') 1464 for identifier in state_strings: 1465 f.write(' if ((state->' + identifier) 1466 f.write(' = PyUnicode_InternFromString("') 1467 f.write(identifier + '")) == NULL) return 0;\n') 1468 f.write(' return 1;\n') 1469 f.write('};\n\n') 1470 1471def write_header(mod, f): 1472 f.write(textwrap.dedent(""" 1473 #ifndef Py_INTERNAL_AST_H 1474 #define Py_INTERNAL_AST_H 1475 #ifdef __cplusplus 1476 extern "C" { 1477 #endif 1478 1479 #ifndef Py_BUILD_CORE 1480 # error "this header requires Py_BUILD_CORE define" 1481 #endif 1482 1483 #include "pycore_asdl.h" 1484 1485 """).lstrip()) 1486 c = ChainOfVisitors(TypeDefVisitor(f), 1487 SequenceDefVisitor(f), 1488 StructVisitor(f)) 1489 c.visit(mod) 1490 f.write("// Note: these macros affect function definitions, not only call sites.\n") 1491 PrototypeVisitor(f).visit(mod) 1492 f.write(textwrap.dedent(""" 1493 1494 PyObject* PyAST_mod2obj(mod_ty t); 1495 mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode); 1496 int PyAST_Check(PyObject* obj); 1497 1498 extern int _PyAST_Validate(mod_ty); 1499 1500 /* _PyAST_ExprAsUnicode is defined in ast_unparse.c */ 1501 extern PyObject* _PyAST_ExprAsUnicode(expr_ty); 1502 1503 /* Return the borrowed reference to the first literal string in the 1504 sequence of statements or NULL if it doesn't start from a literal string. 1505 Doesn't set exception. */ 1506 extern PyObject* _PyAST_GetDocString(asdl_stmt_seq *); 1507 1508 #ifdef __cplusplus 1509 } 1510 #endif 1511 #endif /* !Py_INTERNAL_AST_H */ 1512 """)) 1513 1514 1515def write_internal_h_header(mod, f): 1516 print(textwrap.dedent(""" 1517 #ifndef Py_INTERNAL_AST_STATE_H 1518 #define Py_INTERNAL_AST_STATE_H 1519 #ifdef __cplusplus 1520 extern "C" { 1521 #endif 1522 1523 #ifndef Py_BUILD_CORE 1524 # error "this header requires Py_BUILD_CORE define" 1525 #endif 1526 """).lstrip(), file=f) 1527 1528 1529def write_internal_h_footer(mod, f): 1530 print(textwrap.dedent(""" 1531 1532 #ifdef __cplusplus 1533 } 1534 #endif 1535 #endif /* !Py_INTERNAL_AST_STATE_H */ 1536 """), file=f) 1537 1538 1539def write_source(mod, f, internal_h_file): 1540 generate_module_def(mod, f, internal_h_file) 1541 1542 v = ChainOfVisitors( 1543 SequenceConstructorVisitor(f), 1544 PyTypesDeclareVisitor(f), 1545 PyTypesVisitor(f), 1546 Obj2ModPrototypeVisitor(f), 1547 FunctionVisitor(f), 1548 ObjVisitor(f), 1549 Obj2ModVisitor(f), 1550 ASTModuleVisitor(f), 1551 PartingShots(f), 1552 ) 1553 v.visit(mod) 1554 1555def main(input_filename, c_filename, h_filename, internal_h_filename, dump_module=False): 1556 auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:])) 1557 mod = asdl.parse(input_filename) 1558 if dump_module: 1559 print('Parsed Module:') 1560 print(mod) 1561 if not asdl.check(mod): 1562 sys.exit(1) 1563 1564 with c_filename.open("w") as c_file, \ 1565 h_filename.open("w") as h_file, \ 1566 internal_h_filename.open("w") as internal_h_file: 1567 c_file.write(auto_gen_msg) 1568 h_file.write(auto_gen_msg) 1569 internal_h_file.write(auto_gen_msg) 1570 1571 write_internal_h_header(mod, internal_h_file) 1572 write_source(mod, c_file, internal_h_file) 1573 write_header(mod, h_file) 1574 write_internal_h_footer(mod, internal_h_file) 1575 1576 print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.") 1577 1578if __name__ == "__main__": 1579 parser = ArgumentParser() 1580 parser.add_argument("input_file", type=Path) 1581 parser.add_argument("-C", "--c-file", type=Path, required=True) 1582 parser.add_argument("-H", "--h-file", type=Path, required=True) 1583 parser.add_argument("-I", "--internal-h-file", type=Path, required=True) 1584 parser.add_argument("-d", "--dump-module", action="store_true") 1585 1586 args = parser.parse_args() 1587 main(args.input_file, args.c_file, args.h_file, 1588 args.internal_h_file, args.dump_module) 1589