• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
"""bytecode_helper - support tools for testing correct bytecode generation.

Defines BytecodeTestCase, a unittest.TestCase subclass whose assertions
check whether particular instructions appear in compiled code.
"""

import dis
import io
import unittest

# Sentinel default for the optional ``argval`` parameters.  A private object
# is used instead of None so that callers can still ask for an instruction
# whose argval is literally None.
_UNSPECIFIED: object = object()
8
class BytecodeTestCase(unittest.TestCase):
    """Custom assertion methods for inspecting bytecode.

    Each assertion takes ``x``, which is passed straight to
    ``dis.get_instructions``, an ``opname`` to look for, and an optional
    ``argval`` that must also compare equal to ``Instruction.argval``.
    On failure the full disassembly of ``x`` is included in the message.
    """

    def get_disassembly_as_string(self, co):
        """Return the text that ``dis.dis(co)`` would print."""
        s = io.StringIO()
        dis.dis(co, file=s)
        return s.getvalue()

    def assertInBytecode(self, x, opname, argval=_UNSPECIFIED):
        """Return the first instruction of *x* matching *opname* (and *argval*).

        If *argval* is omitted, any instruction whose ``opname`` equals
        *opname* matches; otherwise its ``argval`` must also compare equal.

        Raises:
            self.failureException (AssertionError by default): if no
                instruction matches.
        """
        for instr in dis.get_instructions(x):
            if instr.opname == opname:
                if argval is _UNSPECIFIED or instr.argval == argval:
                    return instr
        disassembly = self.get_disassembly_as_string(x)
        if argval is _UNSPECIFIED:
            msg = '%s not found in bytecode:\n%s' % (opname, disassembly)
        else:
            msg = '(%s,%r) not found in bytecode:\n%s'
            msg = msg % (opname, argval, disassembly)
        self.fail(msg)

    def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED):
        """Fail if *x* contains an instruction matching *opname* (and *argval*).

        If *argval* is omitted, any instruction whose ``opname`` equals
        *opname* is a violation; otherwise only one whose ``argval`` also
        compares equal is.

        Raises:
            self.failureException (AssertionError by default): on the
                first matching instruction.
        """
        for instr in dis.get_instructions(x):
            if instr.opname != opname:
                continue
            # Disassemble only once we know we are about to fail: an opname
            # match with a different argval is not a violation, and
            # rendering the whole disassembly for it would be wasted work.
            if argval is _UNSPECIFIED:
                disassembly = self.get_disassembly_as_string(x)
                msg = '%s occurs in bytecode:\n%s' % (opname, disassembly)
                self.fail(msg)
            elif instr.argval == argval:
                disassembly = self.get_disassembly_as_string(x)
                msg = '(%s,%r) occurs in bytecode:\n%s'
                msg = msg % (opname, argval, disassembly)
                self.fail(msg)
43