1"""A flow graph representation for Python bytecode""" 2 3import dis 4import types 5import sys 6 7from compiler import misc 8from compiler.consts \ 9 import CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS 10 11class FlowGraph: 12 def __init__(self): 13 self.current = self.entry = Block() 14 self.exit = Block("exit") 15 self.blocks = misc.Set() 16 self.blocks.add(self.entry) 17 self.blocks.add(self.exit) 18 19 def startBlock(self, block): 20 if self._debug: 21 if self.current: 22 print "end", repr(self.current) 23 print " next", self.current.next 24 print " prev", self.current.prev 25 print " ", self.current.get_children() 26 print repr(block) 27 self.current = block 28 29 def nextBlock(self, block=None): 30 # XXX think we need to specify when there is implicit transfer 31 # from one block to the next. might be better to represent this 32 # with explicit JUMP_ABSOLUTE instructions that are optimized 33 # out when they are unnecessary. 34 # 35 # I think this strategy works: each block has a child 36 # designated as "next" which is returned as the last of the 37 # children. because the nodes in a graph are emitted in 38 # reverse post order, the "next" block will always be emitted 39 # immediately after its parent. 40 # Worry: maintaining this invariant could be tricky 41 if block is None: 42 block = self.newBlock() 43 44 # Note: If the current block ends with an unconditional control 45 # transfer, then it is techically incorrect to add an implicit 46 # transfer to the block graph. Doing so results in code generation 47 # for unreachable blocks. That doesn't appear to be very common 48 # with Python code and since the built-in compiler doesn't optimize 49 # it out we don't either. 50 self.current.addNext(block) 51 self.startBlock(block) 52 53 def newBlock(self): 54 b = Block() 55 self.blocks.add(b) 56 return b 57 58 def startExitBlock(self): 59 self.startBlock(self.exit) 60 61 _debug = 0 62 63 def _enable_debug(self): 64 self._debug = 1 65 66 def _disable_debug(self): 67 self._debug = 0 68 69 def emit(self, *inst): 70 if self._debug: 71 print "\t", inst 72 if len(inst) == 2 and isinstance(inst[1], Block): 73 self.current.addOutEdge(inst[1]) 74 self.current.emit(inst) 75 76 def getBlocksInOrder(self): 77 """Return the blocks in reverse postorder 78 79 i.e. each node appears before all of its successors 80 """ 81 order = order_blocks(self.entry, self.exit) 82 return order 83 84 def getBlocks(self): 85 return self.blocks.elements() 86 87 def getRoot(self): 88 """Return nodes appropriate for use with dominator""" 89 return self.entry 90 91 def getContainedGraphs(self): 92 l = [] 93 for b in self.getBlocks(): 94 l.extend(b.getContainedGraphs()) 95 return l 96 97 98def order_blocks(start_block, exit_block): 99 """Order blocks so that they are emitted in the right order""" 100 # Rules: 101 # - when a block has a next block, the next block must be emitted just after 102 # - when a block has followers (relative jumps), it must be emitted before 103 # them 104 # - all reachable blocks must be emitted 105 order = [] 106 107 # Find all the blocks to be emitted. 108 remaining = set() 109 todo = [start_block] 110 while todo: 111 b = todo.pop() 112 if b in remaining: 113 continue 114 remaining.add(b) 115 for c in b.get_children(): 116 if c not in remaining: 117 todo.append(c) 118 119 # A block is dominated by another block if that block must be emitted 120 # before it. 121 dominators = {} 122 for b in remaining: 123 if __debug__ and b.next: 124 assert b is b.next[0].prev[0], (b, b.next) 125 # Make sure every block appears in dominators, even if no 126 # other block must precede it. 127 dominators.setdefault(b, set()) 128 # preceding blocks dominate following blocks 129 for c in b.get_followers(): 130 while 1: 131 dominators.setdefault(c, set()).add(b) 132 # Any block that has a next pointer leading to c is also 133 # dominated because the whole chain will be emitted at once. 134 # Walk backwards and add them all. 135 if c.prev and c.prev[0] is not b: 136 c = c.prev[0] 137 else: 138 break 139 140 def find_next(): 141 # Find a block that can be emitted next. 142 for b in remaining: 143 for c in dominators[b]: 144 if c in remaining: 145 break # can't emit yet, dominated by a remaining block 146 else: 147 return b 148 assert 0, 'circular dependency, cannot find next block' 149 150 b = start_block 151 while 1: 152 order.append(b) 153 remaining.discard(b) 154 if b.next: 155 b = b.next[0] 156 continue 157 elif b is not exit_block and not b.has_unconditional_transfer(): 158 order.append(exit_block) 159 if not remaining: 160 break 161 b = find_next() 162 return order 163 164 165class Block: 166 _count = 0 167 168 def __init__(self, label=''): 169 self.insts = [] 170 self.outEdges = set() 171 self.label = label 172 self.bid = Block._count 173 self.next = [] 174 self.prev = [] 175 Block._count = Block._count + 1 176 177 def __repr__(self): 178 if self.label: 179 return "<block %s id=%d>" % (self.label, self.bid) 180 else: 181 return "<block id=%d>" % (self.bid) 182 183 def __str__(self): 184 insts = map(str, self.insts) 185 return "<block %s %d:\n%s>" % (self.label, self.bid, 186 '\n'.join(insts)) 187 188 def emit(self, inst): 189 op = inst[0] 190 self.insts.append(inst) 191 192 def getInstructions(self): 193 return self.insts 194 195 def addOutEdge(self, block): 196 self.outEdges.add(block) 197 198 def addNext(self, block): 199 self.next.append(block) 200 assert len(self.next) == 1, map(str, self.next) 201 block.prev.append(self) 202 assert len(block.prev) == 1, map(str, block.prev) 203 204 _uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS', 205 'JUMP_ABSOLUTE', 'JUMP_FORWARD', 'CONTINUE_LOOP', 206 ) 207 208 def has_unconditional_transfer(self): 209 """Returns True if there is an unconditional transfer to an other block 210 at the end of this block. This means there is no risk for the bytecode 211 executer to go past this block's bytecode.""" 212 try: 213 op, arg = self.insts[-1] 214 except (IndexError, ValueError): 215 return 216 return op in self._uncond_transfer 217 218 def get_children(self): 219 return list(self.outEdges) + self.next 220 221 def get_followers(self): 222 """Get the whole list of followers, including the next block.""" 223 followers = set(self.next) 224 # Blocks that must be emitted *after* this one, because of 225 # bytecode offsets (e.g. relative jumps) pointing to them. 226 for inst in self.insts: 227 if inst[0] in PyFlowGraph.hasjrel: 228 followers.add(inst[1]) 229 return followers 230 231 def getContainedGraphs(self): 232 """Return all graphs contained within this block. 233 234 For example, a MAKE_FUNCTION block will contain a reference to 235 the graph for the function body. 236 """ 237 contained = [] 238 for inst in self.insts: 239 if len(inst) == 1: 240 continue 241 op = inst[1] 242 if hasattr(op, 'graph'): 243 contained.append(op.graph) 244 return contained 245 246# flags for code objects 247 248# the FlowGraph is transformed in place; it exists in one of these states 249RAW = "RAW" 250FLAT = "FLAT" 251CONV = "CONV" 252DONE = "DONE" 253 254class PyFlowGraph(FlowGraph): 255 super_init = FlowGraph.__init__ 256 257 def __init__(self, name, filename, args=(), optimized=0, klass=None): 258 self.super_init() 259 self.name = name 260 self.filename = filename 261 self.docstring = None 262 self.args = args # XXX 263 self.argcount = getArgCount(args) 264 self.klass = klass 265 if optimized: 266 self.flags = CO_OPTIMIZED | CO_NEWLOCALS 267 else: 268 self.flags = 0 269 self.consts = [] 270 self.names = [] 271 # Free variables found by the symbol table scan, including 272 # variables used only in nested scopes, are included here. 273 self.freevars = [] 274 self.cellvars = [] 275 # The closure list is used to track the order of cell 276 # variables and free variables in the resulting code object. 277 # The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both 278 # kinds of variables. 279 self.closure = [] 280 self.varnames = list(args) or [] 281 for i in range(len(self.varnames)): 282 var = self.varnames[i] 283 if isinstance(var, TupleArg): 284 self.varnames[i] = var.getName() 285 self.stage = RAW 286 287 def setDocstring(self, doc): 288 self.docstring = doc 289 290 def setFlag(self, flag): 291 self.flags = self.flags | flag 292 if flag == CO_VARARGS: 293 self.argcount = self.argcount - 1 294 295 def checkFlag(self, flag): 296 if self.flags & flag: 297 return 1 298 299 def setFreeVars(self, names): 300 self.freevars = list(names) 301 302 def setCellVars(self, names): 303 self.cellvars = names 304 305 def getCode(self): 306 """Get a Python code object""" 307 assert self.stage == RAW 308 self.computeStackDepth() 309 self.flattenGraph() 310 assert self.stage == FLAT 311 self.convertArgs() 312 assert self.stage == CONV 313 self.makeByteCode() 314 assert self.stage == DONE 315 return self.newCodeObject() 316 317 def dump(self, io=None): 318 if io: 319 save = sys.stdout 320 sys.stdout = io 321 pc = 0 322 for t in self.insts: 323 opname = t[0] 324 if opname == "SET_LINENO": 325 print 326 if len(t) == 1: 327 print "\t", "%3d" % pc, opname 328 pc = pc + 1 329 else: 330 print "\t", "%3d" % pc, opname, t[1] 331 pc = pc + 3 332 if io: 333 sys.stdout = save 334 335 def computeStackDepth(self): 336 """Compute the max stack depth. 337 338 Approach is to compute the stack effect of each basic block. 339 Then find the path through the code with the largest total 340 effect. 341 """ 342 depth = {} 343 exit = None 344 for b in self.getBlocks(): 345 depth[b] = findDepth(b.getInstructions()) 346 347 seen = {} 348 349 def max_depth(b, d): 350 if b in seen: 351 return d 352 seen[b] = 1 353 d = d + depth[b] 354 children = b.get_children() 355 if children: 356 return max([max_depth(c, d) for c in children]) 357 else: 358 if not b.label == "exit": 359 return max_depth(self.exit, d) 360 else: 361 return d 362 363 self.stacksize = max_depth(self.entry, 0) 364 365 def flattenGraph(self): 366 """Arrange the blocks in order and resolve jumps""" 367 assert self.stage == RAW 368 self.insts = insts = [] 369 pc = 0 370 begin = {} 371 end = {} 372 for b in self.getBlocksInOrder(): 373 begin[b] = pc 374 for inst in b.getInstructions(): 375 insts.append(inst) 376 if len(inst) == 1: 377 pc = pc + 1 378 elif inst[0] != "SET_LINENO": 379 # arg takes 2 bytes 380 pc = pc + 3 381 end[b] = pc 382 pc = 0 383 for i in range(len(insts)): 384 inst = insts[i] 385 if len(inst) == 1: 386 pc = pc + 1 387 elif inst[0] != "SET_LINENO": 388 pc = pc + 3 389 opname = inst[0] 390 if opname in self.hasjrel: 391 oparg = inst[1] 392 offset = begin[oparg] - pc 393 insts[i] = opname, offset 394 elif opname in self.hasjabs: 395 insts[i] = opname, begin[inst[1]] 396 self.stage = FLAT 397 398 hasjrel = set() 399 for i in dis.hasjrel: 400 hasjrel.add(dis.opname[i]) 401 hasjabs = set() 402 for i in dis.hasjabs: 403 hasjabs.add(dis.opname[i]) 404 405 def convertArgs(self): 406 """Convert arguments from symbolic to concrete form""" 407 assert self.stage == FLAT 408 self.consts.insert(0, self.docstring) 409 self.sort_cellvars() 410 for i in range(len(self.insts)): 411 t = self.insts[i] 412 if len(t) == 2: 413 opname, oparg = t 414 conv = self._converters.get(opname, None) 415 if conv: 416 self.insts[i] = opname, conv(self, oparg) 417 self.stage = CONV 418 419 def sort_cellvars(self): 420 """Sort cellvars in the order of varnames and prune from freevars. 421 """ 422 cells = {} 423 for name in self.cellvars: 424 cells[name] = 1 425 self.cellvars = [name for name in self.varnames 426 if name in cells] 427 for name in self.cellvars: 428 del cells[name] 429 self.cellvars = self.cellvars + cells.keys() 430 self.closure = self.cellvars + self.freevars 431 432 def _lookupName(self, name, list): 433 """Return index of name in list, appending if necessary 434 435 This routine uses a list instead of a dictionary, because a 436 dictionary can't store two different keys if the keys have the 437 same value but different types, e.g. 2 and 2L. The compiler 438 must treat these two separately, so it does an explicit type 439 comparison before comparing the values. 440 """ 441 t = type(name) 442 for i in range(len(list)): 443 if t == type(list[i]) and list[i] == name: 444 return i 445 end = len(list) 446 list.append(name) 447 return end 448 449 _converters = {} 450 def _convert_LOAD_CONST(self, arg): 451 if hasattr(arg, 'getCode'): 452 arg = arg.getCode() 453 return self._lookupName(arg, self.consts) 454 455 def _convert_LOAD_FAST(self, arg): 456 self._lookupName(arg, self.names) 457 return self._lookupName(arg, self.varnames) 458 _convert_STORE_FAST = _convert_LOAD_FAST 459 _convert_DELETE_FAST = _convert_LOAD_FAST 460 461 def _convert_LOAD_NAME(self, arg): 462 if self.klass is None: 463 self._lookupName(arg, self.varnames) 464 return self._lookupName(arg, self.names) 465 466 def _convert_NAME(self, arg): 467 if self.klass is None: 468 self._lookupName(arg, self.varnames) 469 return self._lookupName(arg, self.names) 470 _convert_STORE_NAME = _convert_NAME 471 _convert_DELETE_NAME = _convert_NAME 472 _convert_IMPORT_NAME = _convert_NAME 473 _convert_IMPORT_FROM = _convert_NAME 474 _convert_STORE_ATTR = _convert_NAME 475 _convert_LOAD_ATTR = _convert_NAME 476 _convert_DELETE_ATTR = _convert_NAME 477 _convert_LOAD_GLOBAL = _convert_NAME 478 _convert_STORE_GLOBAL = _convert_NAME 479 _convert_DELETE_GLOBAL = _convert_NAME 480 481 def _convert_DEREF(self, arg): 482 self._lookupName(arg, self.names) 483 self._lookupName(arg, self.varnames) 484 return self._lookupName(arg, self.closure) 485 _convert_LOAD_DEREF = _convert_DEREF 486 _convert_STORE_DEREF = _convert_DEREF 487 488 def _convert_LOAD_CLOSURE(self, arg): 489 self._lookupName(arg, self.varnames) 490 return self._lookupName(arg, self.closure) 491 492 _cmp = list(dis.cmp_op) 493 def _convert_COMPARE_OP(self, arg): 494 return self._cmp.index(arg) 495 496 # similarly for other opcodes... 497 498 for name, obj in locals().items(): 499 if name[:9] == "_convert_": 500 opname = name[9:] 501 _converters[opname] = obj 502 del name, obj, opname 503 504 def makeByteCode(self): 505 assert self.stage == CONV 506 self.lnotab = lnotab = LineAddrTable() 507 for t in self.insts: 508 opname = t[0] 509 if len(t) == 1: 510 lnotab.addCode(self.opnum[opname]) 511 else: 512 oparg = t[1] 513 if opname == "SET_LINENO": 514 lnotab.nextLine(oparg) 515 continue 516 hi, lo = twobyte(oparg) 517 try: 518 lnotab.addCode(self.opnum[opname], lo, hi) 519 except ValueError: 520 print opname, oparg 521 print self.opnum[opname], lo, hi 522 raise 523 self.stage = DONE 524 525 opnum = {} 526 for num in range(len(dis.opname)): 527 opnum[dis.opname[num]] = num 528 del num 529 530 def newCodeObject(self): 531 assert self.stage == DONE 532 if (self.flags & CO_NEWLOCALS) == 0: 533 nlocals = 0 534 else: 535 nlocals = len(self.varnames) 536 argcount = self.argcount 537 if self.flags & CO_VARKEYWORDS: 538 argcount = argcount - 1 539 return types.CodeType(argcount, nlocals, self.stacksize, self.flags, 540 self.lnotab.getCode(), self.getConsts(), 541 tuple(self.names), tuple(self.varnames), 542 self.filename, self.name, self.lnotab.firstline, 543 self.lnotab.getTable(), tuple(self.freevars), 544 tuple(self.cellvars)) 545 546 def getConsts(self): 547 """Return a tuple for the const slot of the code object 548 549 Must convert references to code (MAKE_FUNCTION) to code 550 objects recursively. 551 """ 552 l = [] 553 for elt in self.consts: 554 if isinstance(elt, PyFlowGraph): 555 elt = elt.getCode() 556 l.append(elt) 557 return tuple(l) 558 559def isJump(opname): 560 if opname[:4] == 'JUMP': 561 return 1 562 563class TupleArg: 564 """Helper for marking func defs with nested tuples in arglist""" 565 def __init__(self, count, names): 566 self.count = count 567 self.names = names 568 def __repr__(self): 569 return "TupleArg(%s, %s)" % (self.count, self.names) 570 def getName(self): 571 return ".%d" % self.count 572 573def getArgCount(args): 574 argcount = len(args) 575 if args: 576 for arg in args: 577 if isinstance(arg, TupleArg): 578 numNames = len(misc.flatten(arg.names)) 579 argcount = argcount - numNames 580 return argcount 581 582def twobyte(val): 583 """Convert an int argument into high and low bytes""" 584 assert isinstance(val, (int, long)) 585 return divmod(val, 256) 586 587class LineAddrTable: 588 """lnotab 589 590 This class builds the lnotab, which is documented in compile.c. 591 Here's a brief recap: 592 593 For each SET_LINENO instruction after the first one, two bytes are 594 added to lnotab. (In some cases, multiple two-byte entries are 595 added.) The first byte is the distance in bytes between the 596 instruction for the last SET_LINENO and the current SET_LINENO. 597 The second byte is offset in line numbers. If either offset is 598 greater than 255, multiple two-byte entries are added -- see 599 compile.c for the delicate details. 600 """ 601 602 def __init__(self): 603 self.code = [] 604 self.codeOffset = 0 605 self.firstline = 0 606 self.lastline = 0 607 self.lastoff = 0 608 self.lnotab = [] 609 610 def addCode(self, *args): 611 for arg in args: 612 self.code.append(chr(arg)) 613 self.codeOffset = self.codeOffset + len(args) 614 615 def nextLine(self, lineno): 616 if self.firstline == 0: 617 self.firstline = lineno 618 self.lastline = lineno 619 else: 620 # compute deltas 621 addr = self.codeOffset - self.lastoff 622 line = lineno - self.lastline 623 # Python assumes that lineno always increases with 624 # increasing bytecode address (lnotab is unsigned char). 625 # Depending on when SET_LINENO instructions are emitted 626 # this is not always true. Consider the code: 627 # a = (1, 628 # b) 629 # In the bytecode stream, the assignment to "a" occurs 630 # after the loading of "b". This works with the C Python 631 # compiler because it only generates a SET_LINENO instruction 632 # for the assignment. 633 if line >= 0: 634 push = self.lnotab.append 635 while addr > 255: 636 push(255); push(0) 637 addr -= 255 638 while line > 255: 639 push(addr); push(255) 640 line -= 255 641 addr = 0 642 if addr > 0 or line > 0: 643 push(addr); push(line) 644 self.lastline = lineno 645 self.lastoff = self.codeOffset 646 647 def getCode(self): 648 return ''.join(self.code) 649 650 def getTable(self): 651 return ''.join(map(chr, self.lnotab)) 652 653class StackDepthTracker: 654 # XXX 1. need to keep track of stack depth on jumps 655 # XXX 2. at least partly as a result, this code is broken 656 657 def findDepth(self, insts, debug=0): 658 depth = 0 659 maxDepth = 0 660 for i in insts: 661 opname = i[0] 662 if debug: 663 print i, 664 delta = self.effect.get(opname, None) 665 if delta is not None: 666 depth = depth + delta 667 else: 668 # now check patterns 669 for pat, pat_delta in self.patterns: 670 if opname[:len(pat)] == pat: 671 delta = pat_delta 672 depth = depth + delta 673 break 674 # if we still haven't found a match 675 if delta is None: 676 meth = getattr(self, opname, None) 677 if meth is not None: 678 depth = depth + meth(i[1]) 679 if depth > maxDepth: 680 maxDepth = depth 681 if debug: 682 print depth, maxDepth 683 return maxDepth 684 685 effect = { 686 'POP_TOP': -1, 687 'DUP_TOP': 1, 688 'LIST_APPEND': -1, 689 'SET_ADD': -1, 690 'MAP_ADD': -2, 691 'SLICE+1': -1, 692 'SLICE+2': -1, 693 'SLICE+3': -2, 694 'STORE_SLICE+0': -1, 695 'STORE_SLICE+1': -2, 696 'STORE_SLICE+2': -2, 697 'STORE_SLICE+3': -3, 698 'DELETE_SLICE+0': -1, 699 'DELETE_SLICE+1': -2, 700 'DELETE_SLICE+2': -2, 701 'DELETE_SLICE+3': -3, 702 'STORE_SUBSCR': -3, 703 'DELETE_SUBSCR': -2, 704 # PRINT_EXPR? 705 'PRINT_ITEM': -1, 706 'RETURN_VALUE': -1, 707 'YIELD_VALUE': -1, 708 'EXEC_STMT': -3, 709 'BUILD_CLASS': -2, 710 'STORE_NAME': -1, 711 'STORE_ATTR': -2, 712 'DELETE_ATTR': -1, 713 'STORE_GLOBAL': -1, 714 'BUILD_MAP': 1, 715 'COMPARE_OP': -1, 716 'STORE_FAST': -1, 717 'IMPORT_STAR': -1, 718 'IMPORT_NAME': -1, 719 'IMPORT_FROM': 1, 720 'LOAD_ATTR': 0, # unlike other loads 721 # close enough... 722 'SETUP_EXCEPT': 3, 723 'SETUP_FINALLY': 3, 724 'FOR_ITER': 1, 725 'WITH_CLEANUP': -1, 726 } 727 # use pattern match 728 patterns = [ 729 ('BINARY_', -1), 730 ('LOAD_', 1), 731 ] 732 733 def UNPACK_SEQUENCE(self, count): 734 return count-1 735 def BUILD_TUPLE(self, count): 736 return -count+1 737 def BUILD_LIST(self, count): 738 return -count+1 739 def BUILD_SET(self, count): 740 return -count+1 741 def CALL_FUNCTION(self, argc): 742 hi, lo = divmod(argc, 256) 743 return -(lo + hi * 2) 744 def CALL_FUNCTION_VAR(self, argc): 745 return self.CALL_FUNCTION(argc)-1 746 def CALL_FUNCTION_KW(self, argc): 747 return self.CALL_FUNCTION(argc)-1 748 def CALL_FUNCTION_VAR_KW(self, argc): 749 return self.CALL_FUNCTION(argc)-2 750 def MAKE_FUNCTION(self, argc): 751 return -argc 752 def MAKE_CLOSURE(self, argc): 753 # XXX need to account for free variables too! 754 return -argc 755 def BUILD_SLICE(self, argc): 756 if argc == 2: 757 return -1 758 elif argc == 3: 759 return -2 760 def DUP_TOPX(self, argc): 761 return argc 762 763findDepth = StackDepthTracker().findDepth 764