• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17"""Tools to interact with BPF programs."""
18
19import abc
20import collections
21import struct
22
# Upper bound on the number of syscall arguments. This comes from syscall(2):
# most architectures only support passing 6 args to syscalls, but ARM supports
# passing 7.
MAX_SYSCALL_ARGUMENTS = 7

# The following fields were copied from <linux/bpf_common.h>. The individual
# fields are OR-ed together to form a complete opcode, e.g.
# BPF_LD | BPF_W | BPF_ABS.

# Instruction classes
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# LD/LDX fields.
# Size
BPF_W = 0x00
BPF_H = 0x08
BPF_B = 0x10
# Mode
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# JMP fields.
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Source
BPF_K = 0x00
BPF_X = 0x08

BPF_MAXINSNS = 4096

# The following fields were copied from <linux/seccomp.h>:

# Return values of a filter program.
SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

# Masks that split a return value into its action (upper 16 bits) and its
# data (lower 16 bits) parts.
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff
77
78
def arg_offset(arg_index, hi=False):
    """Return the byte offset of one 32-bit half of a syscall argument.

    The result is meant to be used as the operand of a
    BPF_LD|BPF_W|BPF_ABS load.

    Args:
        arg_index: zero-based index of the syscall argument.
        hi: when true, address the second 4 bytes of the argument's 8-byte
            slot instead of the first 4.
    """
    # The arguments start after two 4-byte fields and one 8-byte field.
    args_base = 4 + 4 + 8
    slot_size = 8
    half_slot = slot_size // 2
    return args_base + slot_size * arg_index + half_slot * hi
84
85
def simulate(instructions, arch, syscall_number, *args):
    """Run a BPF program against one simulated syscall.

    Args:
        instructions: sequence of SockFilter-like objects (with |code|, |jt|,
            |jf| and |k| attributes) that make up the program.
        arch: architecture number placed in the input buffer.
        syscall_number: syscall number placed in the input buffer.
        *args: syscall arguments; missing ones are treated as 0 and any
            beyond MAX_SYSCALL_ARGUMENTS are dropped.

    Returns:
        A tuple (cost, action), or (cost, 'ERRNO', errno) for ERRNO returns,
        where |cost| is the number of instructions executed.

    Raises:
        Exception: on an unknown instruction, an unknown return value, or when
            execution runs past the end of the program.
    """
    # Normalize the arguments to exactly MAX_SYSCALL_ARGUMENTS entries.
    padding = (0, ) * (MAX_SYSCALL_ARGUMENTS - len(args))
    args = (args + padding)[:MAX_SYSCALL_ARGUMENTS]
    # Layout: syscall number, arch, one zeroed 8-byte field, then the args.
    input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS,
                               syscall_number, arch, 0, *args)

    # Conditional jumps: opcode -> predicate(accumulator, operand).
    conditional_jumps = {
        BPF_JMP | BPF_JEQ | BPF_K: lambda acc, k: acc == k,
        BPF_JMP | BPF_JGT | BPF_K: lambda acc, k: acc > k,
        BPF_JMP | BPF_JGE | BPF_K: lambda acc, k: acc >= k,
        BPF_JMP | BPF_JSET | BPF_K: lambda acc, k: acc & k != 0,
    }
    # Return values that must match exactly (ERRNO carries data, so it is
    # matched through its action mask instead).
    exact_actions = {
        SECCOMP_RET_KILL_PROCESS: 'KILL_PROCESS',
        SECCOMP_RET_KILL_THREAD: 'KILL_THREAD',
        SECCOMP_RET_TRAP: 'TRAP',
        SECCOMP_RET_TRACE: 'TRACE',
        SECCOMP_RET_LOG: 'LOG',
        SECCOMP_RET_ALLOW: 'ALLOW',
    }

    accumulator = 0
    pc = 0
    cost = 0
    while pc < len(instructions):
        ins = instructions[pc]
        # Jump offsets are relative to the instruction after the jump.
        pc += 1
        cost += 1
        opcode = ins.code
        if opcode == BPF_LD | BPF_W | BPF_ABS:
            (accumulator, ) = struct.unpack('I',
                                            input_memory[ins.k:ins.k + 4])
        elif opcode == BPF_JMP | BPF_JA | BPF_K:
            pc += ins.k
        elif opcode in conditional_jumps:
            taken = conditional_jumps[opcode](accumulator, ins.k)
            pc += ins.jt if taken else ins.jf
        elif opcode == BPF_RET:
            if ins.k in exact_actions:
                return (cost, exact_actions[ins.k])
            if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA)
            raise Exception('unknown return %#x' % ins.k)
        else:
            raise Exception('unknown instruction %r' % (ins, ))
    raise Exception('out-of-bounds')
143
144
class SockFilter(
        collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])):
    """A representation of struct sock_filter.

    Fields, in order: |code| (opcode), |jt| / |jf| (jump offsets taken when a
    comparison is true / false) and |k| (generic operand).
    """

    __slots__ = ()

    def encode(self):
        """Return the instruction packed as bytes in native byte order."""
        # The tuple fields are already in the order of the 'HBBI' format.
        return struct.pack('HBBI', *self)
154
155
class AbstractBlock(abc.ABC):
    """Base class of all blocks: the element side of the visitor pattern."""

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        """Let |visitor| traverse this block; subclasses must override."""
165
166
class BasicBlock(AbstractBlock):
    """A leaf block that already holds its compiled instructions."""

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        # A block is handed to a given visitor at most once.
        if not visitor.visited(self):
            visitor.visit(self)

    @property
    def instructions(self):
        """The instructions this block was built with."""
        return self._instructions

    @property
    def opcodes(self):
        """The encoded form of every instruction, concatenated."""
        encoded = [instruction.encode() for instruction in self._instructions]
        return b''.join(encoded)

    def __eq__(self, o):
        # Only another BasicBlock with equal instructions compares equal.
        return (isinstance(o, BasicBlock)
                and self._instructions == o._instructions)
191
192
class KillProcess(BasicBlock):
    """A one-instruction BasicBlock that always returns KILL_PROCESS."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_KILL_PROCESS)
        super().__init__([ret])
199
200
class KillThread(BasicBlock):
    """A one-instruction BasicBlock that always returns KILL_THREAD."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_KILL_THREAD)
        super().__init__([ret])
207
208
class Trap(BasicBlock):
    """A one-instruction BasicBlock that always returns TRAP."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_TRAP)
        super().__init__([ret])
214
215
class Trace(BasicBlock):
    """A one-instruction BasicBlock that always returns TRACE."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_TRACE)
        super().__init__([ret])
221
222
class Log(BasicBlock):
    """A one-instruction BasicBlock that always returns LOG."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_LOG)
        super().__init__([ret])
228
229
class ReturnErrno(BasicBlock):
    """A one-instruction BasicBlock that always returns the given errno."""

    def __init__(self, errno):
        # Only the bits covered by SECCOMP_RET_DATA fit in the return value.
        return_value = SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA)
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=return_value)
        super().__init__([ret])
        # The errno is kept as passed in (unmasked).
        self.errno = errno
239
240
class Allow(BasicBlock):
    """A one-instruction BasicBlock that always returns ALLOW."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_ALLOW)
        super().__init__([ret])
246
247
class ValidateArch(AbstractBlock):
    """An AbstractBlock that checks the architecture before |next_block|."""

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        if not visitor.visited(self):
            # The successor is traversed before this block is visited.
            self.next_block.accept(visitor)
            visitor.visit(self)
260
261
class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG.

    Attributes:
        syscall_number: the value the syscall number is compared against.
        op: the BPF jump comparison to use (BPF_JEQ by default).
        jt: block to continue with when the comparison is true.
        jf: block to continue with when the comparison is false.
    """

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    def __lt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        # No entry ever orders before another one.
        return False

    def __gt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        # No entry ever orders after another one.
        return False

    def accept(self, visitor):
        if visitor.visited(self):
            return
        # Both successors are traversed before this block is visited.
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
294
295
class WideAtom(AbstractBlock):
    """An AbstractBlock that represents a comparison on one 32-bit word.

    Attributes:
        arg_offset: offset of the 32-bit word to load before comparing.
        op: the BPF jump comparison to apply.
        value: the 32-bit operand of the comparison.
        jt: block to continue with when the comparison is true.
        jf: block to continue with when the comparison is false.
    """

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.arg_offset, self.op, self.value = arg_offset, op, value
        self.jt, self.jf = jt, jf

    def accept(self, visitor):
        if not visitor.visited(self):
            # Both successors are traversed before this block is visited.
            self.jt.accept(visitor)
            self.jf.accept(visitor)
            visitor.visit(self)
313
314
class Atom(AbstractBlock):
    """An AbstractBlock that represents an atom (a simple comparison).

    The textual operator given to the constructor is translated into a BPF
    jump comparison; operators with no direct BPF equivalent are expressed
    as the opposite comparison with |jt| and |jf| swapped.
    """

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        # (operator, BPF comparison, swap branches?, negate the mask?)
        translations = (
            ('==', BPF_JEQ, False, False),
            ('!=', BPF_JEQ, True, False),
            ('>', BPF_JGT, False, False),
            ('<=', BPF_JGT, True, False),
            ('>=', BPF_JGE, False, False),
            ('<', BPF_JGE, True, False),
            ('&', BPF_JSET, False, False),
            # The mask is negated, so the comparison will be true when the
            # argument includes a flag that wasn't listed in the original
            # (non-negated) mask. This would be the failure case, so |jt| and
            # |jf| are switched.
            ('in', BPF_JSET, True, True),
        )
        for symbol, bpf_op, swap_branches, negate_mask in translations:
            if op == symbol:
                break
        else:
            raise Exception('Unknown operator %s' % op)

        if negate_mask:
            value = (~value) & ((1 << 64) - 1)
        if swap_branches:
            jt, jf = jf, jt

        self.arg_index = arg_index
        self.op = bpf_op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        if not visitor.visited(self):
            # Both successors are traversed before this block is visited.
            self.jt.accept(visitor)
            self.jf.accept(visitor)
            visitor.visit(self)
360
361
class AbstractVisitor(abc.ABC):
    """An abstract visitor.

    Keeps track of which blocks were already reached (by identity) and routes
    visit() calls to the visitXxx() method matching the block's type.
    """

    def __init__(self):
        self._visited = set()

    def visited(self, block):
        """Return whether |block| was seen before, marking it as seen."""
        key = id(block)
        seen_before = key in self._visited
        if not seen_before:
            self._visited.add(key)
        return seen_before

    def process(self, block):
        """Traverse |block| with this visitor and return |block|."""
        block.accept(self)
        return block

    def visit(self, block):
        """Dispatch |block| to the handler for its type."""
        # Order matters: the specialized BasicBlock subclasses must be
        # matched before BasicBlock itself.
        dispatch = (
            (KillProcess, 'visitKillProcess'),
            (KillThread, 'visitKillThread'),
            (Trap, 'visitTrap'),
            (ReturnErrno, 'visitReturnErrno'),
            (Trace, 'visitTrace'),
            (Log, 'visitLog'),
            (Allow, 'visitAllow'),
            (BasicBlock, 'visitBasicBlock'),
            (ValidateArch, 'visitValidateArch'),
            (SyscallEntry, 'visitSyscallEntry'),
            (WideAtom, 'visitWideAtom'),
            (Atom, 'visitAtom'),
        )
        for block_type, handler_name in dispatch:
            if isinstance(block, block_type):
                getattr(self, handler_name)(block)
                return
        raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        """Handle a KillProcess block."""

    @abc.abstractmethod
    def visitKillThread(self, block):
        """Handle a KillThread block."""

    @abc.abstractmethod
    def visitTrap(self, block):
        """Handle a Trap block."""

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        """Handle a ReturnErrno block."""

    @abc.abstractmethod
    def visitTrace(self, block):
        """Handle a Trace block."""

    @abc.abstractmethod
    def visitLog(self, block):
        """Handle a Log block."""

    @abc.abstractmethod
    def visitAllow(self, block):
        """Handle an Allow block."""

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        """Handle a plain BasicBlock."""

    @abc.abstractmethod
    def visitValidateArch(self, block):
        """Handle a ValidateArch block."""

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        """Handle a SyscallEntry block."""

    @abc.abstractmethod
    def visitWideAtom(self, block):
        """Handle a WideAtom block."""

    @abc.abstractmethod
    def visitAtom(self, block):
        """Handle an Atom block."""
453
454
class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks.

    Blocks hand themselves to the visitor only after their successors, so
    when a block is copied the copies of its successors can already be looked
    up in |self._mapping|, which is keyed by id() of the original block.
    """

    def __init__(self):
        super().__init__()
        self._mapping = {}

    def process(self, block):
        """Copy everything reachable from |block|; return |block|'s copy."""
        # Start every run from a clean slate. Leftover entries in the visited
        # set would make accept() skip blocks, leaving them out of the fresh
        # mapping and turning the lookup below into a KeyError.
        self._visited = set()
        self._mapping = {}
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        # ValidateArch only takes (and only stores) its successor block.
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)])

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op)

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset, block.op, block.value,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        # |block.op|, |block.value| and the branches were already translated
        # by Atom.__init__, which only accepts textual operators. Build the
        # copy with '==' -- the one operator that neither swaps the branches
        # nor alters the value -- and then restore the translated opcode.
        atom = Atom(block.arg_index, '==', block.value,
                    self._mapping[id(block.jt)],
                    self._mapping[id(block.jf)])
        atom.op = block.op
        self._mapping[id(block)] = atom
523
524
class LoweringVisitor(CopyingVisitor):
    """A copying visitor that lowers Atoms into WideAtoms.

    A WideAtom compares a single 32-bit word, so on 64-bit architectures a
    comparison is split into tests on the high and the low word.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        assert id(block) not in self._mapping

        low_half = block.value & 0xFFFFFFFF
        high_half = (block.value >> 32) & 0xFFFFFFFF

        low_offset = arg_offset(block.arg_index, False)
        on_true = self._mapping[id(block.jt)]
        on_false = self._mapping[id(block.jf)]
        low_test = WideAtom(low_offset, block.op, low_half, on_true, on_false)

        if self._bits == 32:
            # Only the low word is tested.
            self._mapping[id(block)] = low_test
            return

        high_offset = arg_offset(block.arg_index, True)
        if block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #
            # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2
            if high_half == 0:
                # Special case: it's not needed to check whether
                # |hi_1 == hi_2|, because it's true iff the JGT test fails.
                lowered = WideAtom(high_offset, BPF_JGT, high_half, on_true,
                                   low_test)
            else:
                high_equal = WideAtom(high_offset, BPF_JEQ, high_half,
                                      low_test, on_false)
                lowered = WideAtom(high_offset, BPF_JGT, high_half, on_true,
                                   high_equal)
        elif block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2
            #
            # hi_1 & hi_2 || lo_1 & lo_2
            if high_half == 0:
                # Special case: |hi_1 & hi_2| will never be True, so jump
                # directly into the |lo_1 & lo_2| case.
                lowered = low_test
            else:
                lowered = WideAtom(high_offset, block.op, high_half, on_true,
                                   low_test)
        else:
            assert block.op == BPF_JEQ, block.op
            # hi_1,lo_1 == hi_2,lo_2
            #
            # hi_1 == hi_2 && lo_1 == lo_2
            lowered = WideAtom(high_offset, block.op, high_half, low_test,
                               on_false)

        self._mapping[id(block)] = lowered
586
587
class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects.

    Every visited block has its instructions prepended to the program built
    so far, so a block's jump targets are always already in the program.
    """

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # id(block) -> minus the program length right after the block was
        # prepended, i.e. its start position counted from the end.
        self._offsets = {}

    @property
    def result(self):
        """The program flattened so far, wrapped in a BasicBlock."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return how far |block| starts from the front of the program."""
        hops = len(self._instructions) + self._offsets[id(block)]
        assert hops >= 0
        return hops

    def _emit_load_arg(self, offset):
        load_word = BPF_LD | BPF_W | BPF_ABS
        return [SockFilter(load_word, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Emit a conditional jump, using BPF_JA for far-away targets.

        Distances of 0x100 or more do not go into the jt/jf fields; they are
        carried in the operand of an extra unconditional jump instead.
        """
        conditional = BPF_JMP | op | BPF_K
        unconditional = BPF_JMP | BPF_JA
        short_limit = 0x100

        if jt_distance < short_limit and jf_distance < short_limit:
            # Both targets are within reach of a single instruction.
            return [SockFilter(conditional, jt_distance, jf_distance, value)]
        if jt_distance + 1 < short_limit:
            # True branch hops over the trampoline; false falls into it.
            return [
                SockFilter(conditional, jt_distance + 1, 0, value),
                SockFilter(unconditional, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < short_limit:
            # False branch hops over the trampoline; true falls into it.
            return [
                SockFilter(conditional, 0, jf_distance + 1, value),
                SockFilter(unconditional, 0, 0, jt_distance),
            ]
        # Both targets are far away: one trampoline per branch.
        return [
            SockFilter(conditional, 0, 1, value),
            SockFilter(unconditional, 0, 0, jt_distance + 1),
            SockFilter(unconditional, 0, 0, jf_distance),
        ]

    def visited(self, block):
        """Return whether |block| was seen before, marking it as seen."""
        key = id(block)
        if key in self._visited:
            return True
        self._visited.add(key)
        return False

    def visit(self, block):
        """Prepend the instructions for |block| to the program."""
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            load_word = BPF_LD | BPF_W | BPF_ABS
            # NOTE(review): the +1 hops over the kill action, which looks like
            # it assumes a single-instruction kill action -- confirm.
            arch_check = [
                SockFilter(load_word, 0, 0, 4),
                SockFilter(BPF_JMP | BPF_JEQ | BPF_K,
                           self._distance(block.next_block) + 1, 0,
                           self._arch.arch_nr),
            ]
            load_syscall_number = [SockFilter(load_word, 0, 0, 0)]
            instructions = (arch_check + self._kill_action.instructions +
                            load_syscall_number)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(block.op, block.syscall_number,
                                          self._distance(block.jt),
                                          self._distance(block.jf))
        elif isinstance(block, WideAtom):
            load = self._emit_load_arg(block.arg_offset)
            jump = self._emit_jmp(block.op, block.value,
                                  self._distance(block.jt),
                                  self._distance(block.jf))
            instructions = load + jump
        else:
            raise Exception('Unknown block type: %r' % block)

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)
667
668
class ArgFilterForwardingVisitor:
    """A visitor that forwards visitation to all arg filters."""

    def __init__(self, visitor):
        self._visited = set()
        self.visitor = visitor

    def visited(self, block):
        """Return whether |block| was seen before, marking it as seen."""
        key = id(block)
        if key in self._visited:
            return True
        self._visited.add(key)
        return False

    def visit(self, block):
        """Forward |block| to the wrapped visitor if it is an arg filter."""
        # All arg filters are BasicBlocks, but so are the ALLOW, KILL_PROCESS,
        # TRAP, etc. actions, and we don't want to visit those just yet.
        terminal_actions = (KillProcess, KillThread, Trap, ReturnErrno, Trace,
                            Log, Allow)
        is_arg_filter = (isinstance(block, BasicBlock)
                         and not isinstance(block, terminal_actions))
        if is_arg_filter:
            block.accept(self.visitor)
694