# Copyright 2020 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Tools to interact with BPF programs."""

import abc
import collections
import struct


# This comes from syscall(2). Most architectures only support passing 6 args to
# syscalls, but ARM supports passing 7.
MAX_SYSCALL_ARGUMENTS = 7

# The following fields were copied from <linux/bpf_common.h>:

# Instruction classes
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# LD/LDX fields.
# Size
BPF_W = 0x00
BPF_H = 0x08
BPF_B = 0x10
# Mode
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xA0

# JMP fields.
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Source
BPF_K = 0x00
BPF_X = 0x08

BPF_MAXINSNS = 4096

# The following fields were copied from <linux/seccomp.h>:

SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7FF00000
SECCOMP_RET_USER_NOTIF = 0x7FC00000
SECCOMP_RET_LOG = 0x7FFC0000
SECCOMP_RET_ALLOW = 0x7FFF0000

SECCOMP_RET_ACTION_FULL = 0xFFFF0000
SECCOMP_RET_DATA = 0x0000FFFF


def arg_offset(arg_index, hi=False):
    """Return the BPF_LD|BPF_W|BPF_ABS addressing-friendly register offset.

    The filter input starts with two 32-bit words (the syscall number at
    offset 0 and the architecture at offset 4) followed by one 64-bit word,
    so the syscall arguments start at byte 16 and are 8 bytes wide each.

    Args:
        arg_index: Zero-based index of the syscall argument.
        hi: If true, return the offset of the upper 32 bits of the argument
            rather than the lower 32 bits.

    Returns:
        The byte offset to use with a 32-bit absolute load.
    """
    # NOTE(review): placing the upper half at +4 assumes the 64-bit argument
    # is stored little-endian.
    offsetof_args = 4 + 4 + 8
    arg_width = 8
    return offsetof_args + arg_width * arg_index + (arg_width // 2) * hi


def simulate(instructions, arch, syscall_number, *args):
    """Simulate a BPF program with the given arguments.

    Only the subset of classic BPF that this module emits is supported:
    32-bit absolute loads, the BPF_K conditional/unconditional jumps and
    BPF_RET.

    Args:
        instructions: Sequence of SockFilter instructions to execute.
        arch: Architecture number, placed at offset 4 of the input.
        syscall_number: Syscall number, placed at offset 0 of the input.
        *args: Syscall arguments. Missing ones are zero-filled and any
            beyond MAX_SYSCALL_ARGUMENTS are dropped.

    Returns:
        (cost, action_name) where cost is the number of instructions
        executed, or (cost, "ERRNO", errno) for SECCOMP_RET_ERRNO.

    Raises:
        Exception: on an unknown instruction or return value, or if
            execution runs past the last instruction.
    """
    args = (args + (0,) * (MAX_SYSCALL_ARGUMENTS - len(args)))[
        :MAX_SYSCALL_ARGUMENTS
    ]
    input_memory = struct.pack(
        "IIQ" + "Q" * MAX_SYSCALL_ARGUMENTS, syscall_number, arch, 0, *args
    )

    register = 0
    program_counter = 0
    cost = 0
    while program_counter < len(instructions):
        ins = instructions[program_counter]
        program_counter += 1
        cost += 1
        if ins.code == BPF_LD | BPF_W | BPF_ABS:
            register = struct.unpack("I", input_memory[ins.k : ins.k + 4])[0]
        elif ins.code == BPF_JMP | BPF_JA | BPF_K:
            program_counter += ins.k
        elif ins.code == BPF_JMP | BPF_JEQ | BPF_K:
            if register == ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGT | BPF_K:
            if register > ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGE | BPF_K:
            if register >= ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JSET | BPF_K:
            if register & ins.k != 0:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_RET:
            if ins.k == SECCOMP_RET_KILL_PROCESS:
                return (cost, "KILL_PROCESS")
            if ins.k == SECCOMP_RET_KILL_THREAD:
                return (cost, "KILL_THREAD")
            if ins.k == SECCOMP_RET_TRAP:
                return (cost, "TRAP")
            if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (cost, "ERRNO", ins.k & SECCOMP_RET_DATA)
            if ins.k == SECCOMP_RET_TRACE:
                return (cost, "TRACE")
            if ins.k == SECCOMP_RET_USER_NOTIF:
                return (cost, "USER_NOTIF")
            if ins.k == SECCOMP_RET_LOG:
                return (cost, "LOG")
            if ins.k == SECCOMP_RET_ALLOW:
                return (cost, "ALLOW")
            raise Exception("unknown return %#x" % ins.k)
        else:
            raise Exception("unknown instruction %r" % (ins,))
    raise Exception("out-of-bounds")


class SockFilter(
    collections.namedtuple("SockFilter", ["code", "jt", "jf", "k"])
):
    """A representation of struct sock_filter."""

    __slots__ = ()

    def encode(self):
        """Return an encoded version of the SockFilter."""
        # jt/jf are packed as unsigned bytes ("B"), which is why jump
        # distances must stay below 0x100.
        return struct.pack("HBBI", self.code, self.jt, self.jf, self.k)


class AbstractBlock(abc.ABC):
    """A class that implements the visitor pattern."""

    @abc.abstractmethod
    def accept(self, visitor):
        pass


class BasicBlock(AbstractBlock):
    """A concrete implementation of AbstractBlock that has been compiled."""

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        if visitor.visited(self):
            return
        visitor.visit(self)

    @property
    def instructions(self):
        return self._instructions

    @property
    def opcodes(self):
        return b"".join(i.encode() for i in self._instructions)

    def __eq__(self, o):
        if not isinstance(o, BasicBlock):
            return False
        return self._instructions == o._instructions


class KillProcess(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_PROCESS."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_PROCESS)]
        )


class KillThread(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_THREAD."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_THREAD)]
        )


class Trap(BasicBlock):
    """A BasicBlock that unconditionally returns TRAP."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRAP)])


class Trace(BasicBlock):
    """A BasicBlock that unconditionally returns TRACE."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRACE)])


class UserNotify(BasicBlock):
    """A BasicBlock that unconditionally returns USER_NOTIF."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_USER_NOTIF)]
        )


class Log(BasicBlock):
    """A BasicBlock that unconditionally returns LOG."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_LOG)])


class ReturnErrno(BasicBlock):
    """A BasicBlock that unconditionally returns the specified errno."""

    def __init__(self, errno):
        super().__init__(
            [
                SockFilter(
                    BPF_RET,
                    0x00,
                    0x00,
                    SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA),
                )
            ]
        )
        self.errno = errno


class Allow(BasicBlock):
    """A BasicBlock that unconditionally returns ALLOW."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_ALLOW)])


class ValidateArch(AbstractBlock):
    """An AbstractBlock that validates the architecture."""

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.next_block.accept(visitor)
        visitor.visit(self)


class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG."""

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)

    def __lt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def __gt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False


class WideAtom(AbstractBlock):
    """An AbstractBlock comparing the 32-bit word at |arg_offset| to |value|."""

    def __init__(
        self, arg_offset, op, value, jt, jf
    ):  # pylint: disable=redefined-outer-name
        super().__init__()
        self.arg_offset = arg_offset
        self.op = op
        self.value = value
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class Atom(AbstractBlock):
    """A BasicBlock that represents an atom (a simple comparison operation).

    |op| is either a comparison operator string ("==", "!=", ">", "<=", ">=",
    "<", "&", "in"), which is lowered to a BPF jump opcode (swapping |jt| and
    |jf| and/or negating |value| where needed), or an already-lowered BPF
    jump opcode (BPF_JEQ, BPF_JGT, BPF_JGE, BPF_JSET), which is stored as-is.
    """

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        if op in (BPF_JEQ, BPF_JGT, BPF_JGE, BPF_JSET):
            # Already lowered, e.g. when CopyingVisitor clones an Atom. The
            # original constructor already swapped |jt|/|jf| and negated
            # |value| as needed, so they must not be transformed again.
            pass
        elif op == "==":
            op = BPF_JEQ
        elif op == "!=":
            op = BPF_JEQ
            jt, jf = jf, jt
        elif op == ">":
            op = BPF_JGT
        elif op == "<=":
            op = BPF_JGT
            jt, jf = jf, jt
        elif op == ">=":
            op = BPF_JGE
        elif op == "<":
            op = BPF_JGE
            jt, jf = jf, jt
        elif op == "&":
            op = BPF_JSET
        elif op == "in":
            op = BPF_JSET
            # The mask is negated, so the comparison will be true when the
            # argument includes a flag that wasn't listed in the original
            # (non-negated) mask. This would be the failure case, so we switch
            # |jt| and |jf|.
            value = (~value) & ((1 << 64) - 1)
            jt, jf = jf, jt
        else:
            raise Exception("Unknown operator %s" % op)

        self.arg_index = arg_index
        self.op = op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class AbstractVisitor(abc.ABC):
    """An abstract visitor.

    Blocks call visited() before traversing, so each block (by identity) is
    visited at most once; visit() dispatches on the concrete block type.
    """

    def __init__(self):
        self._visited = set()

    def visited(self, block):
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def process(self, block):
        block.accept(self)
        return block

    def visit(self, block):
        # The specific BasicBlock subclasses must be tested before the
        # generic BasicBlock case.
        if isinstance(block, KillProcess):
            self.visitKillProcess(block)
        elif isinstance(block, KillThread):
            self.visitKillThread(block)
        elif isinstance(block, Trap):
            self.visitTrap(block)
        elif isinstance(block, ReturnErrno):
            self.visitReturnErrno(block)
        elif isinstance(block, Trace):
            self.visitTrace(block)
        elif isinstance(block, UserNotify):
            self.visitUserNotify(block)
        elif isinstance(block, Log):
            self.visitLog(block)
        elif isinstance(block, Allow):
            self.visitAllow(block)
        elif isinstance(block, BasicBlock):
            self.visitBasicBlock(block)
        elif isinstance(block, ValidateArch):
            self.visitValidateArch(block)
        elif isinstance(block, SyscallEntry):
            self.visitSyscallEntry(block)
        elif isinstance(block, WideAtom):
            self.visitWideAtom(block)
        elif isinstance(block, Atom):
            self.visitAtom(block)
        else:
            raise Exception("Unknown block type: %r" % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        pass

    @abc.abstractmethod
    def visitKillThread(self, block):
        pass

    @abc.abstractmethod
    def visitTrap(self, block):
        pass

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        pass

    @abc.abstractmethod
    def visitTrace(self, block):
        pass

    @abc.abstractmethod
    def visitUserNotify(self, block):
        pass

    @abc.abstractmethod
    def visitLog(self, block):
        pass

    @abc.abstractmethod
    def visitAllow(self, block):
        pass

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        pass

    @abc.abstractmethod
    def visitValidateArch(self, block):
        pass

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        pass

    @abc.abstractmethod
    def visitWideAtom(self, block):
        pass

    @abc.abstractmethod
    def visitAtom(self, block):
        pass


class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks.

    Children are visited before their parents, so a parent's copy can look up
    the copies of its successors in |_mapping| (keyed by id of the original).
    """

    def __init__(self):
        super().__init__()
        self._mapping = {}

    def process(self, block):
        """Return a deep copy of the DAG rooted at |block|.

        The visitor can be reused: both the copy mapping and the visited set
        are reset, otherwise a second call would skip the traversal and fail
        to find |block| in the mapping.
        """
        self._mapping = {}
        self._visited = set()
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitUserNotify(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = UserNotify()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)]
        )

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op,
        )

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset,
            block.op,
            block.value,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
        )

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        # block.op is already a BPF opcode, which Atom accepts unchanged.
        self._mapping[id(block)] = Atom(
            block.arg_index,
            block.op,
            block.value,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
        )


class LoweringVisitor(CopyingVisitor):
    """A visitor that lowers Atoms into WideAtoms.

    On 32-bit architectures only the low 32 bits of the argument are
    compared; otherwise each 64-bit comparison becomes a small chain of
    32-bit comparisons on the high and low halves.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        assert id(block) not in self._mapping

        lo = block.value & 0xFFFFFFFF
        hi = (block.value >> 32) & 0xFFFFFFFF

        lo_block = WideAtom(
            arg_offset(block.arg_index, False),
            block.op,
            lo,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
        )

        if self._bits == 32:
            self._mapping[id(block)] = lo_block
            return

        if block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #
            # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2
            if hi == 0:
                # Special case: it's not needed to check whether |hi_1 == hi_2|,
                # because it's true iff the JGT test fails.
                self._mapping[id(block)] = WideAtom(
                    arg_offset(block.arg_index, True),
                    BPF_JGT,
                    hi,
                    self._mapping[id(block.jt)],
                    lo_block,
                )
                return
            hi_eq_block = WideAtom(
                arg_offset(block.arg_index, True),
                BPF_JEQ,
                hi,
                lo_block,
                self._mapping[id(block.jf)],
            )
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True),
                BPF_JGT,
                hi,
                self._mapping[id(block.jt)],
                hi_eq_block,
            )
            return
        if block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2
            #
            # hi_1 & hi_2 || lo_1 & lo_2
            if hi == 0:
                # Special case: |hi_1 & hi_2| will never be True, so jump
                # directly into the |lo_1 & lo_2| case.
                self._mapping[id(block)] = lo_block
                return
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True),
                block.op,
                hi,
                self._mapping[id(block.jt)],
                lo_block,
            )
            return

        assert block.op == BPF_JEQ, block.op

        # hi_1,lo_1 == hi_2,lo_2
        #
        # hi_1 == hi_2 && lo_1 == lo_2
        self._mapping[id(block)] = WideAtom(
            arg_offset(block.arg_index, True),
            block.op,
            hi,
            lo_block,
            self._mapping[id(block.jf)],
        )


class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects.

    Blocks are visited children-first and each block's instructions are
    prepended to the output, so every jump target has already been emitted
    and only forward jumps are generated.
    """

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # Maps id(block) to -len(self._instructions) right after the block
        # was prepended; see _distance().
        self._offsets = {}

    @property
    def result(self):
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return the number of instructions preceding |block| so far.

        This is the jump offset needed by an instruction placed immediately
        before the current start of the output to reach |block|.
        """
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Emit a conditional jump whose last instruction ends the chunk.

        jt/jf are 8-bit fields, so distances of 0x100 or more are reached
        through unconditional BPF_JA trampolines (which take a 32-bit k).
        """
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(
                    BPF_JMP | op | BPF_K, jt_distance, jf_distance, value
                ),
            ]
        if jt_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
            ]
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
        ]

    def _emit_validate_arch(self, block):
        """Emit the architecture check followed by the syscall number load.

        Layout: load arch; if it matches, skip over the kill action; kill
        action; load syscall number; and, when |block.next_block| is not
        emitted directly after this block, an unconditional jump to it.
        """
        kill_instructions = list(self._kill_action.instructions)
        assert len(kill_instructions) < 0x100, len(kill_instructions)
        instructions = (
            [
                SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
                SockFilter(
                    BPF_JMP | BPF_JEQ | BPF_K,
                    len(kill_instructions),
                    0,
                    self._arch.arch_nr,
                ),
            ]
            + kill_instructions
            + [
                SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
            ]
        )
        next_distance = self._distance(block.next_block)
        if next_distance:
            # The JA is the last instruction of this chunk, so its offset is
            # exactly the distance to the next block.
            instructions.append(
                SockFilter(BPF_JMP | BPF_JA | BPF_K, 0, 0, next_distance)
            )
        return instructions

    def visited(self, block):
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            instructions = self._emit_validate_arch(block)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(
                block.op,
                block.syscall_number,
                self._distance(block.jt),
                self._distance(block.jf),
            )
        elif isinstance(block, WideAtom):
            instructions = self._emit_load_arg(
                block.arg_offset
            ) + self._emit_jmp(
                block.op,
                block.value,
                self._distance(block.jt),
                self._distance(block.jf),
            )
        else:
            raise Exception("Unknown block type: %r" % block)

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)


class ArgFilterForwardingVisitor:
    """A visitor that forwards visitation to all arg filters."""

    def __init__(self, visitor):
        self._visited = set()
        self.visitor = visitor

    def visited(self, block):
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        # All arg filters are BasicBlocks.
        if not isinstance(block, BasicBlock):
            return
        # But the ALLOW, KILL_PROCESS, TRAP, etc. actions are too and we don't
        # want to visit them just yet.
        if isinstance(
            block,
            (
                KillProcess,
                KillThread,
                Trap,
                ReturnErrno,
                Trace,
                UserNotify,
                Log,
                Allow,
            ),
        ):
            return
        block.accept(self.visitor)