• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Core data structures for compiled code templates."""
2
3import dataclasses
4import enum
5import sys
6import typing
7
8import _schema
9
10
@enum.unique
class HoleValue(enum.Enum):
    """
    Different "base" values that can be patched into holes (usually combined with the
    address of a symbol and/or an addend).

    Member names are significant: symbol_to_value() maps a "_JIT_<NAME>" symbol to
    the member called <NAME>, and _HOLE_EXPRS translates members to C expressions.
    """

    # The base address of the machine code for the current uop (exposed as _JIT_ENTRY):
    CODE = enum.auto()
    # The base address of the machine code for the next uop (exposed as _JIT_CONTINUE):
    CONTINUE = enum.auto()
    # The base address of the read-only data for this uop:
    DATA = enum.auto()
    # The address of the current executor (exposed as _JIT_EXECUTOR):
    EXECUTOR = enum.auto()
    # The base address of the "global" offset table located in the read-only data.
    # Shouldn't be present in the final stencils, since these are all replaced with
    # equivalent DATA values:
    GOT = enum.auto()
    # The current uop's oparg (exposed as _JIT_OPARG):
    OPARG = enum.auto()
    # The current uop's operand on 64-bit platforms (exposed as _JIT_OPERAND):
    OPERAND = enum.auto()
    # The current uop's operand on 32-bit platforms (exposed as _JIT_OPERAND_HI/LO):
    OPERAND_HI = enum.auto()
    OPERAND_LO = enum.auto()
    # The current uop's target (exposed as _JIT_TARGET):
    TARGET = enum.auto()
    # The base address of the machine code for the jump target (exposed as
    # _JIT_JUMP_TARGET):
    JUMP_TARGET = enum.auto()
    # The base address of the machine code for the error jump target (exposed as
    # _JIT_ERROR_TARGET):
    ERROR_TARGET = enum.auto()
    # The index of the exit to be jumped through (exposed as _JIT_EXIT_INDEX):
    EXIT_INDEX = enum.auto()
    # The base address of the machine code for the first uop (exposed as _JIT_TOP):
    TOP = enum.auto()
    # A hardcoded value of zero (used for symbol lookups):
    ZERO = enum.auto()
49
50
# Map relocation types to our JIT's patch functions. "r" suffixes indicate that
# the patch function is relative. "x" suffixes indicate that they are "relaxing"
# (see comments in jit.c for more info).
#
# Keys are relocation type names, grouped below by the target that produces them;
# values are the names of the patch_* functions to call. Hole.__post_init__ looks
# up Hole.kind here, so a relocation type missing from this table raises KeyError
# when the Hole is constructed. "patch_aarch64_33rx" is intentionally absent: it is
# only ever assigned by Hole.fold when two adjacent holes are combined.
_PATCH_FUNCS = {
    # aarch64-apple-darwin:
    "ARM64_RELOC_BRANCH26": "patch_aarch64_26r",
    "ARM64_RELOC_GOT_LOAD_PAGE21": "patch_aarch64_21rx",
    "ARM64_RELOC_GOT_LOAD_PAGEOFF12": "patch_aarch64_12x",
    "ARM64_RELOC_PAGE21": "patch_aarch64_21r",
    "ARM64_RELOC_PAGEOFF12": "patch_aarch64_12",
    "ARM64_RELOC_UNSIGNED": "patch_64",
    # x86_64-pc-windows-msvc:
    "IMAGE_REL_AMD64_REL32": "patch_x86_64_32rx",
    # aarch64-pc-windows-msvc:
    "IMAGE_REL_ARM64_BRANCH26": "patch_aarch64_26r",
    "IMAGE_REL_ARM64_PAGEBASE_REL21": "patch_aarch64_21rx",
    "IMAGE_REL_ARM64_PAGEOFFSET_12A": "patch_aarch64_12",
    "IMAGE_REL_ARM64_PAGEOFFSET_12L": "patch_aarch64_12x",
    # i686-pc-windows-msvc:
    "IMAGE_REL_I386_DIR32": "patch_32",
    "IMAGE_REL_I386_REL32": "patch_x86_64_32rx",
    # aarch64-unknown-linux-gnu:
    "R_AARCH64_ABS64": "patch_64",
    "R_AARCH64_ADD_ABS_LO12_NC": "patch_aarch64_12",
    "R_AARCH64_ADR_GOT_PAGE": "patch_aarch64_21rx",
    "R_AARCH64_ADR_PREL_PG_HI21": "patch_aarch64_21r",
    "R_AARCH64_CALL26": "patch_aarch64_26r",
    "R_AARCH64_JUMP26": "patch_aarch64_26r",
    "R_AARCH64_LD64_GOT_LO12_NC": "patch_aarch64_12x",
    # The four MOVW relocations below are the ones emitted for the trampoline
    # built by Stencil.emit_aarch64_trampoline:
    "R_AARCH64_MOVW_UABS_G0_NC": "patch_aarch64_16a",
    "R_AARCH64_MOVW_UABS_G1_NC": "patch_aarch64_16b",
    "R_AARCH64_MOVW_UABS_G2_NC": "patch_aarch64_16c",
    "R_AARCH64_MOVW_UABS_G3": "patch_aarch64_16d",
    # x86_64-unknown-linux-gnu:
    "R_X86_64_64": "patch_64",
    "R_X86_64_GOTPCREL": "patch_32r",
    "R_X86_64_GOTPCRELX": "patch_x86_64_32rx",
    "R_X86_64_PC32": "patch_32r",
    "R_X86_64_REX_GOTPCRELX": "patch_x86_64_32rx",
    # x86_64-apple-darwin:
    "X86_64_RELOC_BRANCH": "patch_32r",
    "X86_64_RELOC_GOT": "patch_x86_64_32rx",
    "X86_64_RELOC_GOT_LOAD": "patch_x86_64_32rx",
    "X86_64_RELOC_SIGNED": "patch_32r",
    "X86_64_RELOC_UNSIGNED": "patch_64",
}
# Translate HoleValues to C expressions. Hole.as_c uses the expression as the base
# of the value argument and appends the symbol address and addend to it; an empty
# string means "no base term", so only the symbol and/or addend are emitted:
_HOLE_EXPRS = {
    HoleValue.CODE: "(uintptr_t)code",
    HoleValue.CONTINUE: "(uintptr_t)code + sizeof(code_body)",
    HoleValue.DATA: "(uintptr_t)data",
    HoleValue.EXECUTOR: "(uintptr_t)executor",
    # These should all have been turned into DATA values by process_relocations
    # (so looking up a leftover GOT hole here raises KeyError):
    # HoleValue.GOT: "",
    HoleValue.OPARG: "instruction->oparg",
    HoleValue.OPERAND: "instruction->operand",
    HoleValue.OPERAND_HI: "(instruction->operand >> 32)",
    HoleValue.OPERAND_LO: "(instruction->operand & UINT32_MAX)",
    HoleValue.TARGET: "instruction->target",
    HoleValue.JUMP_TARGET: "instruction_starts[instruction->jump_target]",
    HoleValue.ERROR_TARGET: "instruction_starts[instruction->error_target]",
    HoleValue.EXIT_INDEX: "instruction->exit_index",
    HoleValue.TOP: "instruction_starts[1]",
    # ZERO contributes nothing to the patched value:
    HoleValue.ZERO: "",
}
116
117
@dataclasses.dataclass
class Hole:
    """
    A "hole" in the stencil to be patched with a computed runtime value.

    Analogous to relocation records in an object file.
    """

    # Byte offset of the hole within the owning stencil's body:
    offset: int
    # Relocation type; selects the patch function via _PATCH_FUNCS:
    kind: _schema.HoleKind
    # Patch with this base value:
    value: HoleValue
    # ...plus the address of this symbol:
    symbol: str | None
    # ...plus this addend:
    addend: int
    # Name of the patch_* function to call. Derived from kind in __post_init__
    # (and overridden by fold for combined holes), so it is not an __init__ argument:
    func: str = dataclasses.field(init=False)
    # Convenience method:
    replace = dataclasses.replace

    def __post_init__(self) -> None:
        """Look up the patch function for this hole's relocation type.

        Raises KeyError if the relocation type has no entry in _PATCH_FUNCS.
        """
        self.func = _PATCH_FUNCS[self.kind]

    def fold(self, other: typing.Self) -> typing.Self | None:
        """Combine two holes into a single hole, if possible.

        Returns a new hole covering both self and other when other immediately
        follows self (4 bytes later), both patch the same value, and they form a
        patch_aarch64_21rx / patch_aarch64_12x pair. Returns None otherwise.
        Neither self nor other is modified.
        """
        if (
            self.offset + 4 == other.offset
            and self.value == other.value
            and self.symbol == other.symbol
            and self.addend == other.addend
            and self.func == "patch_aarch64_21rx"
            and other.func == "patch_aarch64_12x"
        ):
            # These can *only* be properly relaxed when they appear together and
            # patch the same value:
            folded = self.replace()
            # replace() re-runs __post_init__, so func must be set afterwards:
            folded.func = "patch_aarch64_33rx"
            return folded
        return None

    def as_c(self, where: str) -> str:
        """Dump this hole as a call to a patch_* function.

        where is the C expression for the start of the stencil being patched; the
        hole's offset is added to it to form the patch location. The value argument
        is the sum of the base value expression, the symbol address, and the addend,
        omitting any terms that contribute nothing.
        """
        location = f"{where} + {self.offset:#x}"
        value = _HOLE_EXPRS[self.value]
        if self.symbol:
            if value:
                value += " + "
            value += f"(uintptr_t)&{self.symbol}"
        addend = _signed(self.addend)
        # Emit the addend when it is nonzero, and also when nothing else has been
        # emitted: otherwise a hole with no base, no symbol, and a zero addend would
        # produce an empty argument ("patch_x(location, );"), which is not valid C.
        if addend or not value:
            if value:
                value += " + "
            value += f"{addend:#x}"
        return f"{self.func}({location}, {value});"
171
172
173@dataclasses.dataclass
174class Stencil:
175    """
176    A contiguous block of machine code or data to be copied-and-patched.
177
178    Analogous to a section or segment in an object file.
179    """
180
181    body: bytearray = dataclasses.field(default_factory=bytearray, init=False)
182    holes: list[Hole] = dataclasses.field(default_factory=list, init=False)
183    disassembly: list[str] = dataclasses.field(default_factory=list, init=False)
184
185    def pad(self, alignment: int) -> None:
186        """Pad the stencil to the given alignment."""
187        offset = len(self.body)
188        padding = -offset % alignment
189        self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
190        self.body.extend([0] * padding)
191
192    def emit_aarch64_trampoline(self, hole: Hole) -> None:
193        """Even with the large code model, AArch64 Linux insists on 28-bit jumps."""
194        base = len(self.body)
195        where = slice(hole.offset, hole.offset + 4)
196        instruction = int.from_bytes(self.body[where], sys.byteorder)
197        instruction &= 0xFC000000
198        instruction |= ((base - hole.offset) >> 2) & 0x03FFFFFF
199        self.body[where] = instruction.to_bytes(4, sys.byteorder)
200        self.disassembly += [
201            f"{base + 4 * 0:x}: d2800008      mov     x8, #0x0",
202            f"{base + 4 * 0:016x}:  R_AARCH64_MOVW_UABS_G0_NC    {hole.symbol}",
203            f"{base + 4 * 1:x}: f2a00008      movk    x8, #0x0, lsl #16",
204            f"{base + 4 * 1:016x}:  R_AARCH64_MOVW_UABS_G1_NC    {hole.symbol}",
205            f"{base + 4 * 2:x}: f2c00008      movk    x8, #0x0, lsl #32",
206            f"{base + 4 * 2:016x}:  R_AARCH64_MOVW_UABS_G2_NC    {hole.symbol}",
207            f"{base + 4 * 3:x}: f2e00008      movk    x8, #0x0, lsl #48",
208            f"{base + 4 * 3:016x}:  R_AARCH64_MOVW_UABS_G3       {hole.symbol}",
209            f"{base + 4 * 4:x}: d61f0100      br      x8",
210        ]
211        for code in [
212            0xD2800008.to_bytes(4, sys.byteorder),
213            0xF2A00008.to_bytes(4, sys.byteorder),
214            0xF2C00008.to_bytes(4, sys.byteorder),
215            0xF2E00008.to_bytes(4, sys.byteorder),
216            0xD61F0100.to_bytes(4, sys.byteorder),
217        ]:
218            self.body.extend(code)
219        for i, kind in enumerate(
220            [
221                "R_AARCH64_MOVW_UABS_G0_NC",
222                "R_AARCH64_MOVW_UABS_G1_NC",
223                "R_AARCH64_MOVW_UABS_G2_NC",
224                "R_AARCH64_MOVW_UABS_G3",
225            ]
226        ):
227            self.holes.append(hole.replace(offset=base + 4 * i, kind=kind))
228
229    def remove_jump(self, *, alignment: int = 1) -> None:
230        """Remove a zero-length continuation jump, if it exists."""
231        hole = max(self.holes, key=lambda hole: hole.offset)
232        match hole:
233            case Hole(
234                offset=offset,
235                kind="IMAGE_REL_AMD64_REL32",
236                value=HoleValue.GOT,
237                symbol="_JIT_CONTINUE",
238                addend=-4,
239            ) as hole:
240                # jmp qword ptr [rip]
241                jump = b"\x48\xFF\x25\x00\x00\x00\x00"
242                offset -= 3
243            case Hole(
244                offset=offset,
245                kind="IMAGE_REL_I386_REL32" | "X86_64_RELOC_BRANCH",
246                value=HoleValue.CONTINUE,
247                symbol=None,
248                addend=-4,
249            ) as hole:
250                # jmp 5
251                jump = b"\xE9\x00\x00\x00\x00"
252                offset -= 1
253            case Hole(
254                offset=offset,
255                kind="R_AARCH64_JUMP26",
256                value=HoleValue.CONTINUE,
257                symbol=None,
258                addend=0,
259            ) as hole:
260                # b #4
261                jump = b"\x00\x00\x00\x14"
262            case Hole(
263                offset=offset,
264                kind="R_X86_64_GOTPCRELX",
265                value=HoleValue.GOT,
266                symbol="_JIT_CONTINUE",
267                addend=addend,
268            ) as hole:
269                assert _signed(addend) == -4
270                # jmp qword ptr [rip]
271                jump = b"\xFF\x25\x00\x00\x00\x00"
272                offset -= 2
273            case _:
274                return
275        if self.body[offset:] == jump and offset % alignment == 0:
276            self.body = self.body[:offset]
277            self.holes.remove(hole)
278
279
@dataclasses.dataclass
class StencilGroup:
    """
    Code and data corresponding to a given micro-opcode.

    Analogous to an entire object file.
    """

    # Machine code for the micro-opcode:
    code: Stencil = dataclasses.field(default_factory=Stencil, init=False)
    # Data referenced by the code (the global offset table is appended here):
    data: Stencil = dataclasses.field(default_factory=Stencil, init=False)
    # Locally-defined symbols, mapped to the (base value, addend) locating them:
    symbols: dict[int | str, tuple[HoleValue, int]] = dataclasses.field(
        default_factory=dict, init=False
    )
    # Global offset table slots: symbol name -> byte offset within the table:
    _got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)

    def process_relocations(self, *, alignment: int = 1) -> None:
        """Fix up all GOT and internal relocations for this stencil group."""
        code, data = self.code, self.data
        branch_kinds = (
            "R_AARCH64_CALL26",
            "R_AARCH64_JUMP26",
            "ARM64_RELOC_BRANCH26",
        )
        # Route AArch64 branches to external symbols through trampolines. Iterate
        # over a snapshot, since trampolines add holes and the original is removed:
        for branch in list(code.holes):
            if branch.value is not HoleValue.ZERO:
                continue
            if branch.kind not in branch_kinds:
                continue
            code.pad(alignment)
            code.emit_aarch64_trampoline(branch)
            code.holes.remove(branch)
        code.remove_jump(alignment=alignment)
        code.pad(alignment)
        data.pad(8)
        for stencil in (code, data):
            for hole in stencil.holes:
                if hole.value is HoleValue.GOT:
                    # Point GOT references at the symbol's slot in the data:
                    assert hole.symbol is not None
                    hole.value = HoleValue.DATA
                    hole.addend += self._global_offset_table_lookup(hole.symbol)
                    hole.symbol = None
                    continue
                if hole.symbol in self.symbols:
                    # Resolve references to symbols defined in this group:
                    base, extra = self.symbols[hole.symbol]
                    hole.value = base
                    hole.addend += extra
                    hole.symbol = None
                    continue
                if (
                    hole.kind == "IMAGE_REL_AMD64_REL32"
                    and hole.value is HoleValue.ZERO
                ):
                    raise ValueError(
                        f"Add PyAPI_FUNC(...) or PyAPI_DATA(...) to declaration of {hole.symbol}!"
                    )
        self._emit_global_offset_table()
        for stencil in (code, data):
            stencil.holes.sort(key=lambda entry: entry.offset)

    def _global_offset_table_lookup(self, symbol: str) -> int:
        """Return the data offset of symbol's GOT slot, allocating one if needed."""
        # The table starts at the current end of the data; slots are 8 bytes each:
        table_start = len(self.data.body)
        slot = self._got.setdefault(symbol, 8 * len(self._got))
        return table_start + slot

    def _emit_global_offset_table(self) -> None:
        """Append the GOT to the data: 8 zero bytes and one hole per slot."""
        table_start = len(self.data.body)
        for name, slot in self._got.items():
            if name not in self.symbols:
                value, symbol = symbol_to_value(name)
                addend = 0
            else:
                value, addend = self.symbols[name]
                symbol = None
            self.data.holes.append(
                Hole(
                    offset=table_start + slot,
                    kind="R_X86_64_64",
                    value=value,
                    symbol=symbol,
                    addend=addend,
                )
            )
            # Describe the slot's contents as "BASE+&symbol+addend", leaving out
            # the terms that aren't needed:
            base = "" if value is HoleValue.ZERO else value.name
            if base and not symbol and not addend:
                description = base
            else:
                sign = "" if symbol is None else "+"
                symbol_text = f"&{symbol}" if symbol else ""
                joiner = "+" if base else ""
                description = (
                    f"{base}{joiner}{symbol_text}{_signed(addend):{sign}#x}"
                )
            self.data.disassembly.append(f"{len(self.data.body):x}: {description}")
            self.data.body.extend(bytes(8))

    def as_c(self, opname: str) -> str:
        """Dump this stencil group as a C initializer."""
        code_size = len(self.code.body)
        data_size = len(self.data.body)
        return f"{{emit_{opname}, {code_size}, {data_size}}}"
363
364
def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
    """
    Split a symbol name into a HoleValue and a (possibly absent) symbol name.

    A symbol spelled "_JIT_<NAME>", where <NAME> is the name of a HoleValue member,
    becomes that member with no symbol. Anything else is returned unchanged,
    paired with HoleValue.ZERO.
    """
    prefix = "_JIT_"
    if symbol.startswith(prefix):
        member = symbol[len(prefix) :]
        if member in HoleValue.__members__:
            return HoleValue[member], None
    return HoleValue.ZERO, symbol
378
379
def _signed(value: int) -> int:
    """Reinterpret the low 64 bits of value as a two's-complement signed integer."""
    wrapped = value & ((1 << 64) - 1)
    # A set top bit means the 64-bit pattern encodes a negative number:
    return wrapped - (1 << 64) if wrapped >> 63 else wrapped
385