#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools to interact with BPF programs."""

import abc
import collections
import struct

# This comes from syscall(2). Most architectures only support passing 6 args to
# syscalls, but ARM supports passing 7.
MAX_SYSCALL_ARGUMENTS = 7

# The following fields were copied from <linux/bpf_common.h>:

# Instruction classes
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# LD/LDX fields.
# Size
BPF_W = 0x00
BPF_H = 0x08
BPF_B = 0x10
# Mode
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# JMP fields.
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Source
BPF_K = 0x00
BPF_X = 0x08

BPF_MAXINSNS = 4096

# The following fields were copied from <linux/seccomp.h>:

SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff


def arg_offset(arg_index, hi=False):
    """Return the BPF_LD|BPF_W|BPF_ABS addressing-friendly register offset.

    Args:
      arg_index: Index of the syscall argument.
      hi: If true, return the offset of the upper 32-bit half of the argument
        instead of the lower half.

    Returns:
      A byte offset into the input data, using the same layout that
      `simulate` packs: u32 nr, u32 arch, u64 instruction pointer, u64 args[].
    """
    # nr (4 bytes) + arch (4 bytes) + instruction pointer (8 bytes).
    offsetof_args = 4 + 4 + 8
    arg_width = 8
    # NOTE(review): putting the high word at +4 assumes a little-endian
    # argument layout.
    return offsetof_args + arg_width * arg_index + (arg_width // 2) * hi


def simulate(instructions, arch, syscall_number, *args):
    """Simulate a BPF program with the given arguments.

    Only the subset of classic BPF that this module emits is supported:
    32-bit absolute loads, BPF_JA/JEQ/JGT/JGE/JSET jumps against constants,
    and BPF_RET.

    Args:
      instructions: Sequence of SockFilter instructions.
      arch: Value stored in the arch field (offset 4).
      syscall_number: Value stored in the nr field (offset 0).
      *args: Syscall arguments. Missing ones are zero-filled and any beyond
        MAX_SYSCALL_ARGUMENTS are dropped.

    Returns:
      (cost, action) where cost is the number of instructions executed, or
      (cost, 'ERRNO', errno) for a SECCOMP_RET_ERRNO return.

    Raises:
      Exception: On an unknown instruction or return value, or if execution
        runs off the end of the program.
    """
    args = ((args + (0, ) *
             (MAX_SYSCALL_ARGUMENTS - len(args)))[:MAX_SYSCALL_ARGUMENTS])
    input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS,
                               syscall_number, arch, 0, *args)

    register = 0
    program_counter = 0
    cost = 0
    while program_counter < len(instructions):
        ins = instructions[program_counter]
        # Jump offsets are relative to the instruction after the current one.
        program_counter += 1
        cost += 1
        if ins.code == BPF_LD | BPF_W | BPF_ABS:
            register = struct.unpack('I', input_memory[ins.k:ins.k + 4])[0]
        elif ins.code == BPF_JMP | BPF_JA | BPF_K:
            program_counter += ins.k
        elif ins.code == BPF_JMP | BPF_JEQ | BPF_K:
            if register == ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGT | BPF_K:
            if register > ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGE | BPF_K:
            if register >= ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JSET | BPF_K:
            if register & ins.k != 0:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_RET:
            if ins.k == SECCOMP_RET_KILL_PROCESS:
                return (cost, 'KILL_PROCESS')
            if ins.k == SECCOMP_RET_KILL_THREAD:
                return (cost, 'KILL_THREAD')
            if ins.k == SECCOMP_RET_TRAP:
                return (cost, 'TRAP')
            if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA)
            if ins.k == SECCOMP_RET_TRACE:
                return (cost, 'TRACE')
            if ins.k == SECCOMP_RET_LOG:
                return (cost, 'LOG')
            if ins.k == SECCOMP_RET_ALLOW:
                return (cost, 'ALLOW')
            raise Exception('unknown return %#x' % ins.k)
        else:
            raise Exception('unknown instruction %r' % (ins, ))
    raise Exception('out-of-bounds')


class SockFilter(
        collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])):
    """A representation of struct sock_filter."""

    __slots__ = ()

    def encode(self):
        """Return an encoded version of the SockFilter."""
        return struct.pack('HBBI', self.code, self.jt, self.jf, self.k)


class AbstractBlock(abc.ABC):
    """A class that implements the visitor pattern."""

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        """Visit this block's successors first, then the block itself."""


class BasicBlock(AbstractBlock):
    """A concrete implementation of AbstractBlock that has been compiled."""

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        if visitor.visited(self):
            return
        visitor.visit(self)

    @property
    def instructions(self):
        """The list of SockFilter instructions in this block."""
        return self._instructions

    @property
    def opcodes(self):
        """The instructions encoded back to back as bytes."""
        return b''.join(i.encode() for i in self._instructions)

    # Defining __eq__ without __hash__ makes BasicBlocks unhashable; the
    # visitors in this module key blocks on id() instead.
    def __eq__(self, o):
        if not isinstance(o, BasicBlock):
            return False
        return self._instructions == o._instructions


class KillProcess(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_PROCESS."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_PROCESS)])


class KillThread(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_THREAD."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_THREAD)])


class Trap(BasicBlock):
    """A BasicBlock that unconditionally returns TRAP."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRAP)])


class Trace(BasicBlock):
    """A BasicBlock that unconditionally returns TRACE."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRACE)])


class Log(BasicBlock):
    """A BasicBlock that unconditionally returns LOG."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_LOG)])


class ReturnErrno(BasicBlock):
    """A BasicBlock that unconditionally returns the specified errno."""

    def __init__(self, errno):
        # Only the low 16 bits (SECCOMP_RET_DATA) of errno are encoded.
        super().__init__([
            SockFilter(BPF_RET, 0x00, 0x00,
                       SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA))
        ])
        self.errno = errno


class Allow(BasicBlock):
    """A BasicBlock that unconditionally returns ALLOW."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_ALLOW)])


class ValidateArch(AbstractBlock):
    """An AbstractBlock that validates the architecture."""

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.next_block.accept(visitor)
        visitor.visit(self)


class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG."""

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)

    def __lt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def __gt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False


class WideAtom(AbstractBlock):
    """An AbstractBlock that represents a 32-bit wide atom.

    It compares the 32-bit word at |arg_offset| against |value| using the BPF
    jump opcode |op|, and continues at |jt| or |jf|.
    """

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.arg_offset = arg_offset
        self.op = op
        self.value = value
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class Atom(AbstractBlock):
    """An AbstractBlock that represents an atom (a simple comparison).

    |op| is one of '==', '!=', '>', '<=', '>=', '<', '&', or 'in'. It is
    normalized to one of BPF_JEQ/BPF_JGT/BPF_JGE/BPF_JSET, swapping |jt| and
    |jf| for the negated forms.

    Raises:
      Exception: If |op| is not a known operator.
    """

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        if op == '==':
            op = BPF_JEQ
        elif op == '!=':
            op = BPF_JEQ
            jt, jf = jf, jt
        elif op == '>':
            op = BPF_JGT
        elif op == '<=':
            op = BPF_JGT
            jt, jf = jf, jt
        elif op == '>=':
            op = BPF_JGE
        elif op == '<':
            op = BPF_JGE
            jt, jf = jf, jt
        elif op == '&':
            op = BPF_JSET
        elif op == 'in':
            op = BPF_JSET
            # The mask is negated, so the comparison will be true when the
            # argument includes a flag that wasn't listed in the original
            # (non-negated) mask. This would be the failure case, so we switch
            # |jt| and |jf|.
            value = (~value) & ((1 << 64) - 1)
            jt, jf = jf, jt
        else:
            raise Exception('Unknown operator %s' % op)

        self.arg_index = arg_index
        self.op = op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class AbstractVisitor(abc.ABC):
    """An abstract visitor that dispatches on the concrete block type."""

    def __init__(self):
        # ids of blocks already visited, so shared DAG nodes are visited once.
        self._visited = set()

    def visited(self, block):
        """Return True if |block| was seen before; otherwise mark it seen."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def process(self, block):
        """Visit the DAG rooted at |block| and return |block|."""
        block.accept(self)
        return block

    def visit(self, block):
        """Dispatch |block| to the matching visitXxx method.

        The action subclasses are checked before BasicBlock because they are
        also BasicBlocks.
        """
        if isinstance(block, KillProcess):
            self.visitKillProcess(block)
        elif isinstance(block, KillThread):
            self.visitKillThread(block)
        elif isinstance(block, Trap):
            self.visitTrap(block)
        elif isinstance(block, ReturnErrno):
            self.visitReturnErrno(block)
        elif isinstance(block, Trace):
            self.visitTrace(block)
        elif isinstance(block, Log):
            self.visitLog(block)
        elif isinstance(block, Allow):
            self.visitAllow(block)
        elif isinstance(block, BasicBlock):
            self.visitBasicBlock(block)
        elif isinstance(block, ValidateArch):
            self.visitValidateArch(block)
        elif isinstance(block, SyscallEntry):
            self.visitSyscallEntry(block)
        elif isinstance(block, WideAtom):
            self.visitWideAtom(block)
        elif isinstance(block, Atom):
            self.visitAtom(block)
        else:
            raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        pass

    @abc.abstractmethod
    def visitKillThread(self, block):
        pass

    @abc.abstractmethod
    def visitTrap(self, block):
        pass

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        pass

    @abc.abstractmethod
    def visitTrace(self, block):
        pass

    @abc.abstractmethod
    def visitLog(self, block):
        pass

    @abc.abstractmethod
    def visitAllow(self, block):
        pass

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        pass

    @abc.abstractmethod
    def visitValidateArch(self, block):
        pass

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        pass

    @abc.abstractmethod
    def visitWideAtom(self, block):
        pass

    @abc.abstractmethod
    def visitAtom(self, block):
        pass


class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks."""

    def __init__(self):
        super().__init__()
        # Maps id(original block) -> its copy.
        self._mapping = {}

    def process(self, block):
        """Return a copy of the DAG rooted at |block|.

        Both the copy mapping and the visited set are reset, so the same
        visitor can be reused for several DAGs.
        """
        self._mapping = {}
        self._visited = set()
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        # ValidateArch only carries |next_block|; the arch itself is supplied
        # later, when the DAG is flattened.
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)])

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op)

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset, block.op, block.value,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Atom(block.arg_index, block.op, block.value,
                                        self._mapping[id(block.jt)],
                                        self._mapping[id(block.jf)])


class LoweringVisitor(CopyingVisitor):
    """A visitor that lowers Atoms into WideAtoms."""

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        """Lower a (possibly 64-bit) Atom into one or two 32-bit WideAtoms."""
        assert id(block) not in self._mapping

        lo = block.value & 0xFFFFFFFF
        hi = (block.value >> 32) & 0xFFFFFFFF

        lo_block = WideAtom(
            arg_offset(block.arg_index, False), block.op, lo,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

        if self._bits == 32:
            self._mapping[id(block)] = lo_block
            return

        if block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #
            # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2
            if hi == 0:
                # Special case: it's not needed to check whether |hi_1 == hi_2|,
                # because it's true iff the JGT test fails.
                self._mapping[id(block)] = WideAtom(
                    arg_offset(block.arg_index, True), BPF_JGT, hi,
                    self._mapping[id(block.jt)], lo_block)
                return
            hi_eq_block = WideAtom(
                arg_offset(block.arg_index, True), BPF_JEQ, hi, lo_block,
                self._mapping[id(block.jf)])
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True), BPF_JGT, hi,
                self._mapping[id(block.jt)], hi_eq_block)
            return
        if block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2
            #
            # hi_1 & hi_2 || lo_1 & lo_2
            if hi == 0:
                # Special case: |hi_1 & hi_2| will never be True, so jump
                # directly into the |lo_1 & lo_2| case.
                self._mapping[id(block)] = lo_block
                return
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True), block.op, hi,
                self._mapping[id(block.jt)], lo_block)
            return

        assert block.op == BPF_JEQ, block.op

        # hi_1,lo_1 == hi_2,lo_2
        #
        # hi_1 == hi_2 && lo_1 == lo_2
        self._mapping[id(block)] = WideAtom(
            arg_offset(block.arg_index, True), block.op, hi, lo_block,
            self._mapping[id(block.jf)])


class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects.

    Blocks accept() their successors before visiting themselves, and each
    visited block's instructions are prepended to the program. As a result,
    every jump target has already been emitted and all jumps go forward.
    """

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # Maps id(block) -> -(program length right after the block was
        # prepended), so the block's current position can be recovered.
        self._offsets = {}

    @property
    def result(self):
        """The flattened program as a BasicBlock."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return how many instructions precede |block| in the current program.

        This is the jump offset needed by a block that is about to be
        prepended.
        """
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Emit a conditional jump, using BPF_JA for targets past 255."""
        # jt/jf are 8-bit fields; BPF_JA takes a 32-bit offset in k.
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance,
                           value),
            ]
        if jt_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
            ]
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
        ]

    def visited(self, block):
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def _emit_validate_arch(self, block):
        """Emit the arch check, the kill action, and the syscall-number load.

        Layout: load arch; if it equals arch_nr, jump over the kill action to
        the nr load; kill action; load nr; and, if |next_block| is not
        immediately next, an unconditional jump to it.
        """
        kill_instructions = list(self._kill_action.instructions)
        instructions = [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
            SockFilter(BPF_JMP | BPF_JEQ | BPF_K, len(kill_instructions), 0,
                       self._arch.arch_nr),
        ] + kill_instructions + [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
        ]
        next_distance = self._distance(block.next_block)
        if next_distance:
            # |next_block| was flattened earlier and is not adjacent; without
            # this jump, execution would fall into unrelated code.
            instructions.append(
                SockFilter(BPF_JMP | BPF_JA, 0, 0, next_distance))
        return instructions

    def visit(self, block):
        """Prepend the instructions for |block| to the program."""
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            instructions = self._emit_validate_arch(block)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(block.op, block.syscall_number,
                                          self._distance(block.jt),
                                          self._distance(block.jf))
        elif isinstance(block, WideAtom):
            instructions = (
                self._emit_load_arg(block.arg_offset) + self._emit_jmp(
                    block.op, block.value, self._distance(block.jt),
                    self._distance(block.jf)))
        else:
            raise Exception('Unknown block type: %r' % block)

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)
        return


class ArgFilterForwardingVisitor:
    """A visitor that forwards visitation to all arg filters."""

    def __init__(self, visitor):
        self._visited = set()
        self.visitor = visitor

    def visited(self, block):
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        # All arg filters are BasicBlocks.
        if not isinstance(block, BasicBlock):
            return
        # But the ALLOW, KILL_PROCESS, TRAP, etc. actions are too and we don't
        # want to visit them just yet.
        if isinstance(block, (KillProcess, KillThread, Trap, ReturnErrno,
                              Trace, Log, Allow)):
            return
        block.accept(self.visitor)