#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools to interact with BPF programs."""

import abc
import collections
import struct

# This comes from syscall(2). Most architectures only support passing 6 args to
# syscalls, but ARM supports passing 7.
MAX_SYSCALL_ARGUMENTS = 7

# The following fields were copied from <linux/bpf_common.h>:

# Instruction classes
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# LD/LDX fields.
# Size
BPF_W = 0x00
BPF_H = 0x08
BPF_B = 0x10
# Mode
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# JMP fields.
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Source
BPF_K = 0x00
BPF_X = 0x08

BPF_MAXINSNS = 4096

# The following fields were copied from <linux/seccomp.h>:

SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff


def arg_offset(arg_index, hi=False):
    """Return the BPF_LD|BPF_W|BPF_ABS addressing-friendly register offset.

    The layout matches the input buffer built by simulate(): two 32-bit words
    (syscall number, arch), one 64-bit word, then one 64-bit slot per syscall
    argument starting at byte 16.

    Args:
      arg_index: Zero-based syscall argument index.
      hi: When true, address the 32-bit half located 4 bytes into the slot
          (LoweringVisitor uses it as the high half of the argument).

    Returns:
      The byte offset of the requested 32-bit word.
    """
    offsetof_args = 4 + 4 + 8
    arg_width = 8
    return offsetof_args + arg_width * arg_index + (arg_width // 2) * hi


def simulate(instructions, arch, syscall_number, *args):
    """Simulate a BPF program with the given arguments.

    Only the instruction subset emitted by this module is supported: absolute
    32-bit loads, unconditional jumps, JEQ/JGT/JGE/JSET against constants, and
    returns.

    Args:
      instructions: Sequence of SockFilter instructions.
      arch: Value stored at offset 4 of the simulated input.
      syscall_number: Value stored at offset 0 of the simulated input.
      *args: Syscall arguments. Missing arguments are zero-filled; arguments
          beyond MAX_SYSCALL_ARGUMENTS are dropped.

    Returns:
      (cost, action_name), or (cost, 'ERRNO', errno) for an errno return,
      where cost is the number of instructions executed.

    Raises:
      Exception: on an unsupported instruction or return value, or when
          execution runs past the last instruction.
    """
    args = ((args + (0, ) *
             (MAX_SYSCALL_ARGUMENTS - len(args)))[:MAX_SYSCALL_ARGUMENTS])
    input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS,
                               syscall_number, arch, 0, *args)

    register = 0
    program_counter = 0
    cost = 0
    while program_counter < len(instructions):
        ins = instructions[program_counter]
        # Jump offsets are relative to the instruction after the current one.
        program_counter += 1
        cost += 1
        if ins.code == BPF_LD | BPF_W | BPF_ABS:
            register = struct.unpack('I', input_memory[ins.k:ins.k + 4])[0]
        elif ins.code == BPF_JMP | BPF_JA | BPF_K:
            program_counter += ins.k
        elif ins.code == BPF_JMP | BPF_JEQ | BPF_K:
            if register == ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGT | BPF_K:
            if register > ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGE | BPF_K:
            if register >= ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JSET | BPF_K:
            if (register & ins.k) != 0:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_RET:
            if ins.k == SECCOMP_RET_KILL_PROCESS:
                return (cost, 'KILL_PROCESS')
            if ins.k == SECCOMP_RET_KILL_THREAD:
                return (cost, 'KILL_THREAD')
            if ins.k == SECCOMP_RET_TRAP:
                return (cost, 'TRAP')
            if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA)
            if ins.k == SECCOMP_RET_TRACE:
                return (cost, 'TRACE')
            if ins.k == SECCOMP_RET_USER_NOTIF:
                return (cost, 'USER_NOTIF')
            if ins.k == SECCOMP_RET_LOG:
                return (cost, 'LOG')
            if ins.k == SECCOMP_RET_ALLOW:
                return (cost, 'ALLOW')
            raise Exception('unknown return %#x' % ins.k)
        else:
            raise Exception('unknown instruction %r' % (ins, ))
    raise Exception('out-of-bounds')


class SockFilter(
        collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])):
    """A representation of struct sock_filter."""

    __slots__ = ()

    def encode(self):
        """Return an encoded version of the SockFilter.

        Packed with native byte order and alignment as u16 code, u8 jt, u8 jf
        and u32 k.
        """
        return struct.pack('HBBI', self.code, self.jt, self.jf, self.k)


class AbstractBlock(abc.ABC):
    """Base class for filter DAG nodes; implements the visitor pattern.

    Subclasses implement accept() so that visitors see children before the
    block itself (post-order).
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        pass


class BasicBlock(AbstractBlock):
    """A concrete implementation of AbstractBlock that has been compiled.

    It holds a fixed list of SockFilter instructions and has no children.
    """

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        if visitor.visited(self):
            return
        visitor.visit(self)

    @property
    def instructions(self):
        """The SockFilter instructions of this block."""
        return self._instructions

    @property
    def opcodes(self):
        """The concatenated SockFilter.encode() bytes of all instructions."""
        return b''.join(i.encode() for i in self._instructions)

    def __eq__(self, o):
        # Two BasicBlocks are equal when they hold equal instruction lists.
        if not isinstance(o, BasicBlock):
            return False
        return self._instructions == o._instructions


class KillProcess(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_PROCESS."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_PROCESS)])


class KillThread(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_THREAD."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_THREAD)])


class Trap(BasicBlock):
    """A BasicBlock that unconditionally returns TRAP."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRAP)])


class Trace(BasicBlock):
    """A BasicBlock that unconditionally returns TRACE."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRACE)])


class UserNotify(BasicBlock):
    """A BasicBlock that unconditionally returns USER_NOTIF."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_USER_NOTIF)])


class Log(BasicBlock):
    """A BasicBlock that unconditionally returns LOG."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_LOG)])


class ReturnErrno(BasicBlock):
    """A BasicBlock that unconditionally returns the specified errno.

    Only the low 16 bits (SECCOMP_RET_DATA) of |errno| are encoded; the
    original value is kept in |self.errno|.
    """

    def __init__(self, errno):
        super().__init__([
            SockFilter(BPF_RET, 0x00, 0x00,
                       SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA))
        ])
        self.errno = errno


class Allow(BasicBlock):
    """A BasicBlock that unconditionally returns ALLOW."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_ALLOW)])


class ValidateArch(AbstractBlock):
    """An AbstractBlock that validates the architecture.

    When flattened, a mismatching arch runs the kill action; otherwise control
    continues to |next_block|.
    """

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.next_block.accept(visitor)
        visitor.visit(self)


class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG.

    Compares the currently loaded syscall number with |syscall_number| using
    |op| and continues at |jt| (true) or |jf| (false).
    """

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    def __lt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def __gt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class WideAtom(AbstractBlock):
    """An AbstractBlock that compares one 32-bit word of the input.

    When flattened it loads the 32-bit word at byte offset |arg_offset|, then
    compares it with |value| using the BPF jump opcode |op|.
    """

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.arg_offset = arg_offset
        self.op = op
        self.value = value
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class Atom(AbstractBlock):
    """An AbstractBlock that represents an atom (a simple comparison operation).

    |op| is one of '==', '!=', '>', '<=', '>=', '<', '&' or 'in'. It is
    translated into a BPF jump opcode; negated operators swap |jt| and |jf|.
    'in' also complements |value| within 64 bits.

    Raises:
      Exception: if |op| is not a recognized operator string.
    """

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        if op == '==':
            op = BPF_JEQ
        elif op == '!=':
            op = BPF_JEQ
            jt, jf = jf, jt
        elif op == '>':
            op = BPF_JGT
        elif op == '<=':
            op = BPF_JGT
            jt, jf = jf, jt
        elif op == '>=':
            op = BPF_JGE
        elif op == '<':
            op = BPF_JGE
            jt, jf = jf, jt
        elif op == '&':
            op = BPF_JSET
        elif op == 'in':
            op = BPF_JSET
            # The mask is negated, so the comparison will be true when the
            # argument includes a flag that wasn't listed in the original
            # (non-negated) mask. This would be the failure case, so we switch
            # |jt| and |jf|.
            value = (~value) & ((1 << 64) - 1)
            jt, jf = jf, jt
        else:
            raise Exception('Unknown operator %s' % op)

        self.arg_index = arg_index
        self.op = op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class AbstractVisitor(abc.ABC):
    """An abstract visitor that dispatches each block to a typed visit method."""

    def __init__(self):
        self._visited = set()

    def visited(self, block):
        """Mark |block| as visited; return True if it already was."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def process(self, block):
        """Visit the DAG rooted at |block| and return |block|."""
        block.accept(self)
        return block

    def visit(self, block):
        """Dispatch |block| to the matching visitXxx method.

        The specific BasicBlock subclasses are tested before BasicBlock itself,
        so the order of these checks matters.
        """
        if isinstance(block, KillProcess):
            self.visitKillProcess(block)
        elif isinstance(block, KillThread):
            self.visitKillThread(block)
        elif isinstance(block, Trap):
            self.visitTrap(block)
        elif isinstance(block, ReturnErrno):
            self.visitReturnErrno(block)
        elif isinstance(block, Trace):
            self.visitTrace(block)
        elif isinstance(block, UserNotify):
            self.visitUserNotify(block)
        elif isinstance(block, Log):
            self.visitLog(block)
        elif isinstance(block, Allow):
            self.visitAllow(block)
        elif isinstance(block, BasicBlock):
            self.visitBasicBlock(block)
        elif isinstance(block, ValidateArch):
            self.visitValidateArch(block)
        elif isinstance(block, SyscallEntry):
            self.visitSyscallEntry(block)
        elif isinstance(block, WideAtom):
            self.visitWideAtom(block)
        elif isinstance(block, Atom):
            self.visitAtom(block)
        else:
            raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        pass

    @abc.abstractmethod
    def visitKillThread(self, block):
        pass

    @abc.abstractmethod
    def visitTrap(self, block):
        pass

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        pass

    @abc.abstractmethod
    def visitTrace(self, block):
        pass

    @abc.abstractmethod
    def visitUserNotify(self, block):
        pass

    @abc.abstractmethod
    def visitLog(self, block):
        pass

    @abc.abstractmethod
    def visitAllow(self, block):
        pass

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        pass

    @abc.abstractmethod
    def visitValidateArch(self, block):
        pass

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        pass

    @abc.abstractmethod
    def visitWideAtom(self, block):
        pass

    @abc.abstractmethod
    def visitAtom(self, block):
        pass


class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks.

    Shared sub-blocks in the input DAG stay shared in the copy, because each
    original block is copied exactly once (keyed by id()).
    """

    # Maps a lowered Atom opcode back to the operator string that Atom's
    # constructor turns into that same opcode without swapping |jt|/|jf| or
    # transforming |value|, so an already-built Atom can be copied verbatim.
    _ATOM_OPERATORS = {
        BPF_JEQ: '==',
        BPF_JGT: '>',
        BPF_JGE: '>=',
        BPF_JSET: '&',
    }

    def __init__(self):
        super().__init__()
        self._mapping = {}

    def process(self, block):
        """Return a copy of the DAG rooted at |block|.

        Both the visited set and the mapping are reset, so the same visitor
        can be reused for several DAGs, even ones that share blocks.
        """
        self._visited = set()
        self._mapping = {}
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitUserNotify(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = UserNotify()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        # ValidateArch only carries |next_block|.
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)])

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op)

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset, block.op, block.value,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        # |block.op| is already a BPF opcode, and |jt|/|jf|/|value| already
        # reflect any operator negation; map it back to a non-transforming
        # operator string so Atom() reproduces the same atom.
        self._mapping[id(block)] = Atom(block.arg_index,
                                        self._ATOM_OPERATORS[block.op],
                                        block.value,
                                        self._mapping[id(block.jt)],
                                        self._mapping[id(block.jf)])


class LoweringVisitor(CopyingVisitor):
    """A visitor that lowers Atoms into WideAtoms.

    On 32-bit architectures only the low word is compared. On other
    architectures a 64-bit comparison is split into comparisons of the high
    and low 32-bit words.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        assert id(block) not in self._mapping

        lo = block.value & 0xFFFFFFFF
        hi = (block.value >> 32) & 0xFFFFFFFF

        lo_block = WideAtom(
            arg_offset(block.arg_index, False), block.op, lo,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

        if self._bits == 32:
            self._mapping[id(block)] = lo_block
            return

        if block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #
            # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2
            if hi == 0:
                # Special case: it's not needed to check whether |hi_1 == hi_2|,
                # because it's true iff the JGT test fails.
                self._mapping[id(block)] = WideAtom(
                    arg_offset(block.arg_index, True), BPF_JGT, hi,
                    self._mapping[id(block.jt)], lo_block)
                return
            hi_eq_block = WideAtom(
                arg_offset(block.arg_index, True), BPF_JEQ, hi, lo_block,
                self._mapping[id(block.jf)])
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True), BPF_JGT, hi,
                self._mapping[id(block.jt)], hi_eq_block)
            return
        if block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2
            #
            # hi_1 & hi_2 || lo_1 & lo_2
            if hi == 0:
                # Special case: |hi_1 & hi_2| will never be True, so jump
                # directly into the |lo_1 & lo_2| case.
                self._mapping[id(block)] = lo_block
                return
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True), block.op, hi,
                self._mapping[id(block.jt)], lo_block)
            return

        assert block.op == BPF_JEQ, block.op

        # hi_1,lo_1 == hi_2,lo_2
        #
        # hi_1 == hi_2 && lo_1 == lo_2
        self._mapping[id(block)] = WideAtom(
            arg_offset(block.arg_index, True), block.op, hi, lo_block,
            self._mapping[id(block.jf)])


class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects.

    Blocks are visited children-first, and each block's instructions are
    prepended to the program built so far. Every emitted jump therefore
    targets code that was already generated and lies after the jump.
    """

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # id(block) -> negated program length right after |block| was
        # emitted; see _distance().
        self._offsets = {}

    @property
    def result(self):
        """The flattened program as a BasicBlock."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return the number of instructions from the program start to |block|."""
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        """Return a 32-bit absolute load of the input word at |offset|."""
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Return a conditional jump, using BPF_JA for targets beyond 8 bits."""
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance,
                           value),
            ]
        if jt_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA | BPF_K, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA | BPF_K, 0, 0, jt_distance),
            ]
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA | BPF_K, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA | BPF_K, 0, 0, jf_distance),
        ]

    def _emit_validate_arch(self, block):
        """Return the instructions for a ValidateArch block.

        The emitted code loads the arch word (offset 4). On a mismatch it falls
        through into the kill action. On a match it skips the kill action,
        loads the syscall number (offset 0) and continues at
        |block.next_block|.
        """
        kill_instructions = self._kill_action.instructions
        next_distance = self._distance(block.next_block)
        instructions = [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
            # Skip exactly the kill action, so that multi-instruction kill
            # actions are handled too.
            SockFilter(BPF_JMP | BPF_JEQ | BPF_K, len(kill_instructions), 0,
                       self._arch.arch_nr),
        ] + kill_instructions + [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
        ]
        if next_distance:
            # |next_block| was emitted earlier and does not directly follow
            # this block, so jump to it explicitly.
            instructions.append(
                SockFilter(BPF_JMP | BPF_JA | BPF_K, 0, 0, next_distance))
        return instructions

    def visited(self, block):
        """Mark |block| as visited; return True if it already was."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        """Prepend |block|'s instructions to the program built so far."""
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            instructions = self._emit_validate_arch(block)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(block.op, block.syscall_number,
                                          self._distance(block.jt),
                                          self._distance(block.jf))
        elif isinstance(block, WideAtom):
            instructions = (
                self._emit_load_arg(block.arg_offset) + self._emit_jmp(
                    block.op, block.value, self._distance(block.jt),
                    self._distance(block.jf)))
        else:
            raise Exception('Unknown block type: %r' % block)

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)


class ArgFilterForwardingVisitor:
    """A visitor that forwards visitation to all arg filters."""

    # Terminal action blocks. They are BasicBlocks too, but must not be
    # forwarded yet.
    _ACTION_BLOCK_TYPES = (KillProcess, KillThread, Trap, ReturnErrno, Trace,
                           UserNotify, Log, Allow)

    def __init__(self, visitor):
        self._visited = set()
        self.visitor = visitor

    def visited(self, block):
        """Mark |block| as visited; return True if it already was."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        """Let |self.visitor| walk |block| if it is a non-action BasicBlock."""
        # All arg filters are BasicBlocks.
        if not isinstance(block, BasicBlock):
            return
        # But the ALLOW, KILL_PROCESS, TRAP, etc. actions are too and we don't
        # want to visit them just yet.
        if isinstance(block, self._ACTION_BLOCK_TYPES):
            return
        block.accept(self.visitor)