#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools to interact with BPF programs."""

import abc
import collections
import struct

# This comes from syscall(2). Most architectures only support passing 6 args to
# syscalls, but ARM supports passing 7.
MAX_SYSCALL_ARGUMENTS = 7

# The following fields were copied from <linux/bpf_common.h>:

# Instruction classes
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# LD/LDX fields.
# Size
BPF_W = 0x00
BPF_H = 0x08
BPF_B = 0x10
# Mode
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# JMP fields.
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Source
BPF_K = 0x00
BPF_X = 0x08

BPF_MAXINSNS = 4096

# The following fields were copied from <linux/seccomp.h>:

SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff


def arg_offset(arg_index, hi=False):
    """Return the BPF_LD|BPF_W|BPF_ABS addressing-friendly register offset.

    Args:
        arg_index: Index of the syscall argument.
        hi: If true, return the offset of the upper 32 bits of the 64-bit
            argument instead of the lower 32 bits.

    Returns:
        The byte offset of the requested 32-bit word within the filter input
        (the same layout simulate() packs: nr, arch, instruction pointer,
        then the 64-bit arguments).
    """
    # nr (4 bytes) + arch (4 bytes) + instruction pointer (8 bytes).
    offsetof_args = 4 + 4 + 8
    arg_width = 8
    # The upper half is taken to be 4 bytes into each argument, i.e. this
    # assumes a little-endian layout of the 64-bit argument words.
    return offsetof_args + arg_width * arg_index + (arg_width // 2) * hi


def simulate(instructions, arch, syscall_number, *args):
    """Simulate a BPF program with the given arguments.

    Only the instruction subset this module emits is supported:
    BPF_LD|BPF_W|BPF_ABS loads, BPF_JA, the BPF_K forms of
    JEQ/JGT/JGE/JSET, and BPF_RET.

    Args:
        instructions: Sequence of SockFilter instructions.
        arch: Architecture number, stored at input offset 4.
        syscall_number: Syscall number, stored at input offset 0.
        *args: Syscall arguments. Missing ones are zero-filled; any beyond
            MAX_SYSCALL_ARGUMENTS are dropped.

    Returns:
        (cost, action_name), where cost is the number of instructions
        executed, or (cost, 'ERRNO', errno) for a SECCOMP_RET_ERRNO return.

    Raises:
        Exception: On an unsupported instruction, an unknown return value, or
            if execution runs off the end of the program.
    """
    args = ((args + (0, ) *
             (MAX_SYSCALL_ARGUMENTS - len(args)))[:MAX_SYSCALL_ARGUMENTS])
    # Input layout: nr, arch, instruction pointer (always 0 here), args.
    input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS,
                               syscall_number, arch, 0, *args)

    register = 0
    program_counter = 0
    cost = 0
    while program_counter < len(instructions):
        ins = instructions[program_counter]
        # Jumps are relative to the instruction following the current one.
        program_counter += 1
        cost += 1
        if ins.code == BPF_LD | BPF_W | BPF_ABS:
            register = struct.unpack('I', input_memory[ins.k:ins.k + 4])[0]
        elif ins.code == BPF_JMP | BPF_JA | BPF_K:
            program_counter += ins.k
        elif ins.code == BPF_JMP | BPF_JEQ | BPF_K:
            if register == ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGT | BPF_K:
            if register > ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGE | BPF_K:
            if register >= ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JSET | BPF_K:
            if register & ins.k != 0:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_RET:
            if ins.k == SECCOMP_RET_KILL_PROCESS:
                return (cost, 'KILL_PROCESS')
            if ins.k == SECCOMP_RET_KILL_THREAD:
                return (cost, 'KILL_THREAD')
            if ins.k == SECCOMP_RET_TRAP:
                return (cost, 'TRAP')
            if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA)
            if ins.k == SECCOMP_RET_TRACE:
                return (cost, 'TRACE')
            if ins.k == SECCOMP_RET_USER_NOTIF:
                return (cost, 'USER_NOTIF')
            if ins.k == SECCOMP_RET_LOG:
                return (cost, 'LOG')
            if ins.k == SECCOMP_RET_ALLOW:
                return (cost, 'ALLOW')
            raise Exception('unknown return %#x' % ins.k)
        else:
            raise Exception('unknown instruction %r' % (ins, ))
    raise Exception('out-of-bounds')


class SockFilter(
        collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])):
    """A representation of struct sock_filter."""

    __slots__ = ()

    def encode(self):
        """Return an encoded version of the SockFilter."""
        return struct.pack('HBBI', self.code, self.jt, self.jf, self.k)


class AbstractBlock(abc.ABC):
    """A class that implements the visitor pattern."""

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        pass


class BasicBlock(AbstractBlock):
    """A concrete implementation of AbstractBlock that has been compiled."""

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        if visitor.visited(self):
            return
        visitor.visit(self)

    @property
    def instructions(self):
        return self._instructions

    @property
    def opcodes(self):
        """Return the concatenated encoding of all instructions."""
        return b''.join(i.encode() for i in self._instructions)

    def __eq__(self, o):
        if not isinstance(o, BasicBlock):
            return False
        return self._instructions == o._instructions


class KillProcess(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_PROCESS."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_PROCESS)])


class KillThread(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_THREAD."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_THREAD)])


class Trap(BasicBlock):
    """A BasicBlock that unconditionally returns TRAP."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRAP)])


class Trace(BasicBlock):
    """A BasicBlock that unconditionally returns TRACE."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRACE)])


class UserNotify(BasicBlock):
    """A BasicBlock that unconditionally returns USER_NOTIF."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_USER_NOTIF)])


class Log(BasicBlock):
    """A BasicBlock that unconditionally returns LOG."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_LOG)])


class ReturnErrno(BasicBlock):
    """A BasicBlock that unconditionally returns the specified errno."""

    def __init__(self, errno):
        super().__init__([
            SockFilter(BPF_RET, 0x00, 0x00,
                       SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA))
        ])
        # The unmasked value is kept so CopyingVisitor can rebuild the block.
        self.errno = errno


class Allow(BasicBlock):
    """A BasicBlock that unconditionally returns ALLOW."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_ALLOW)])


class ValidateArch(AbstractBlock):
    """An AbstractBlock that validates the architecture."""

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.next_block.accept(visitor)
        visitor.visit(self)


class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG."""

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    def __lt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def __gt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class WideAtom(AbstractBlock):
    """An AbstractBlock that represents a 32-bit wide atom.

    It compares a single 32-bit word of the input (at |arg_offset|) against
    |value| using the BPF jump opcode |op|.
    """

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.arg_offset = arg_offset
        self.op = op
        self.value = value
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class Atom(AbstractBlock):
    """A BasicBlock that represents an atom (a simple comparison operation).

    The textual operator is translated into a BPF jump opcode. Operators that
    have no direct BPF equivalent ('!=', '<=', '<', 'in') are expressed as the
    complementary opcode with |jt| and |jf| swapped.
    """

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        if op == '==':
            op = BPF_JEQ
        elif op == '!=':
            op = BPF_JEQ
            jt, jf = jf, jt
        elif op == '>':
            op = BPF_JGT
        elif op == '<=':
            op = BPF_JGT
            jt, jf = jf, jt
        elif op == '>=':
            op = BPF_JGE
        elif op == '<':
            op = BPF_JGE
            jt, jf = jf, jt
        elif op == '&':
            op = BPF_JSET
        elif op == 'in':
            op = BPF_JSET
            # The mask is negated, so the comparison will be true when the
            # argument includes a flag that wasn't listed in the original
            # (non-negated) mask. This would be the failure case, so we switch
            # |jt| and |jf|.
            value = (~value) & ((1 << 64) - 1)
            jt, jf = jf, jt
        else:
            raise Exception('Unknown operator %s' % op)

        self.arg_index = arg_index
        self.op = op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class AbstractVisitor(abc.ABC):
    """An abstract visitor."""

    def __init__(self):
        self._visited = set()

    def visited(self, block):
        """Return True if |block| was already seen; otherwise mark it seen."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def process(self, block):
        block.accept(self)
        return block

    def visit(self, block):
        # Subclasses of BasicBlock must be tested before BasicBlock itself.
        if isinstance(block, KillProcess):
            self.visitKillProcess(block)
        elif isinstance(block, KillThread):
            self.visitKillThread(block)
        elif isinstance(block, Trap):
            self.visitTrap(block)
        elif isinstance(block, ReturnErrno):
            self.visitReturnErrno(block)
        elif isinstance(block, Trace):
            self.visitTrace(block)
        elif isinstance(block, UserNotify):
            self.visitUserNotify(block)
        elif isinstance(block, Log):
            self.visitLog(block)
        elif isinstance(block, Allow):
            self.visitAllow(block)
        elif isinstance(block, BasicBlock):
            self.visitBasicBlock(block)
        elif isinstance(block, ValidateArch):
            self.visitValidateArch(block)
        elif isinstance(block, SyscallEntry):
            self.visitSyscallEntry(block)
        elif isinstance(block, WideAtom):
            self.visitWideAtom(block)
        elif isinstance(block, Atom):
            self.visitAtom(block)
        else:
            raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        pass

    @abc.abstractmethod
    def visitKillThread(self, block):
        pass

    @abc.abstractmethod
    def visitTrap(self, block):
        pass

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        pass

    @abc.abstractmethod
    def visitTrace(self, block):
        pass

    @abc.abstractmethod
    def visitUserNotify(self, block):
        pass

    @abc.abstractmethod
    def visitLog(self, block):
        pass

    @abc.abstractmethod
    def visitAllow(self, block):
        pass

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        pass

    @abc.abstractmethod
    def visitValidateArch(self, block):
        pass

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        pass

    @abc.abstractmethod
    def visitWideAtom(self, block):
        pass

    @abc.abstractmethod
    def visitAtom(self, block):
        pass


class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks.

    Blocks are visited children-first, so each visit method can look up the
    already-made copies of its successors in |self._mapping| (keyed by id()).
    """

    def __init__(self):
        super().__init__()
        self._mapping = {}

    def process(self, block):
        """Return a copy of the DAG rooted at |block|."""
        self._mapping = {}
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitUserNotify(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = UserNotify()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        # ValidateArch only carries |next_block|; the architecture number is
        # supplied later by FlatteningVisitor.
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)])

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op)

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset, block.op, block.value,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        # |block.op| is already a BPF opcode here; for the comparison
        # operators Atom() maps, passing an opcode back in would raise, so
        # this relies on subclasses (e.g. LoweringVisitor) overriding it.
        # NOTE(review): confirm plain CopyingVisitor is never used on Atoms.
        self._mapping[id(block)] = Atom(block.arg_index, block.op, block.value,
                                        self._mapping[id(block.jt)],
                                        self._mapping[id(block.jf)])


class LoweringVisitor(CopyingVisitor):
    """A visitor that lowers Atoms into WideAtoms.

    Each 64-bit Atom comparison becomes one or two 32-bit WideAtom
    comparisons. When |arch.bits| is 32 only the low word is compared.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        assert id(block) not in self._mapping

        lo = block.value & 0xFFFFFFFF
        hi = (block.value >> 32) & 0xFFFFFFFF

        lo_block = WideAtom(
            arg_offset(block.arg_index, False), block.op, lo,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

        if self._bits == 32:
            self._mapping[id(block)] = lo_block
            return

        if block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #
            # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2
            if hi == 0:
                # Special case: it's not needed to check whether |hi_1 == hi_2|,
                # because it's true iff the JGT test fails.
                self._mapping[id(block)] = WideAtom(
                    arg_offset(block.arg_index, True), BPF_JGT, hi,
                    self._mapping[id(block.jt)], lo_block)
                return
            hi_eq_block = WideAtom(
                arg_offset(block.arg_index, True), BPF_JEQ, hi, lo_block,
                self._mapping[id(block.jf)])
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True), BPF_JGT, hi,
                self._mapping[id(block.jt)], hi_eq_block)
            return
        if block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2
            #
            # hi_1 & hi_2 || lo_1 & lo_2
            if hi == 0:
                # Special case: |hi_1 & hi_2| will never be True, so jump
                # directly into the |lo_1 & lo_2| case.
                self._mapping[id(block)] = lo_block
                return
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True), block.op, hi,
                self._mapping[id(block.jt)], lo_block)
            return

        assert block.op == BPF_JEQ, block.op

        # hi_1,lo_1 == hi_2,lo_2
        #
        # hi_1 == hi_2 && lo_1 == lo_2
        self._mapping[id(block)] = WideAtom(
            arg_offset(block.arg_index, True), block.op, hi, lo_block,
            self._mapping[id(block.jf)])


class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects.

    Blocks are visited children-first and each block's instructions are
    prepended to the output, so every jump target is already placed (later
    in the program) when a jump to it is emitted.
    """

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # id(block) -> -(program length right after |block| was prepended).
        self._offsets = {}

    @property
    def result(self):
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return how many instructions lie between the current front of the
        program and the start of |block|."""
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Emit a conditional jump, adding BPF_JA trampolines when a target is
        too far for the 8-bit |jt|/|jf| fields."""
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance,
                           value),
            ]
        if jt_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
            ]
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
        ]

    def visited(self, block):
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def _emit_validate_arch(self, block):
        """Emit: load arch; on mismatch run the kill action; on match load the
        syscall number and continue at |block.next_block|."""
        kill_instructions = self._kill_action.instructions
        next_distance = self._distance(block.next_block)
        instructions = [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
            # On a match, skip exactly the kill action and land on the
            # syscall-number load below.
            SockFilter(BPF_JMP | BPF_JEQ | BPF_K, len(kill_instructions), 0,
                       self._arch.arch_nr),
        ] + kill_instructions + [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
        ]
        if next_distance:
            # |next_block| was already placed elsewhere (shared in the DAG),
            # so it does not directly follow; jump to it explicitly.
            instructions.append(
                SockFilter(BPF_JMP | BPF_JA, 0, 0, next_distance))
        return instructions

    def visit(self, block):
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            instructions = self._emit_validate_arch(block)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(block.op, block.syscall_number,
                                          self._distance(block.jt),
                                          self._distance(block.jf))
        elif isinstance(block, WideAtom):
            instructions = (
                self._emit_load_arg(block.arg_offset) + self._emit_jmp(
                    block.op, block.value, self._distance(block.jt),
                    self._distance(block.jf)))
        else:
            raise Exception('Unknown block type: %r' % block)

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)


class ArgFilterForwardingVisitor:
    """A visitor that forwards visitation to all arg filters."""

    # Terminal action blocks: BasicBlocks that are not arg filters.
    _ACTION_TYPES = (KillProcess, KillThread, Trap, ReturnErrno, Trace,
                     UserNotify, Log, Allow)

    def __init__(self, visitor):
        self._visited = set()
        self.visitor = visitor

    def visited(self, block):
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        # All arg filters are BasicBlocks.
        if not isinstance(block, BasicBlock):
            return
        # But the ALLOW, KILL_PROCESS, TRAP, etc. actions are too and we don't
        # want to visit them just yet.
        if isinstance(block, self._ACTION_TYPES):
            return
        block.accept(self.visitor)