• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The ChromiumOS Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Tools to interact with BPF programs."""
6
import abc
import collections
import copy
import struct
10
11
# Number of syscall argument slots modeled by this module.  Per syscall(2),
# most architectures pass at most 6 arguments to a syscall, but ARM can pass 7.
# NOTE(review): the kernel's struct seccomp_data only carries args[6]; confirm
# that a seventh slot is never addressed by generated filters.
MAX_SYSCALL_ARGUMENTS = 7

# Classic BPF opcode fields, mirrored from <linux/bpf_common.h>.

# Instruction class.
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# Operand size for BPF_LD / BPF_LDX.
BPF_W = 0x00  # 32-bit word
BPF_H = 0x08  # 16-bit half-word
BPF_B = 0x10  # 8-bit byte

# Addressing mode for BPF_LD / BPF_LDX.
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# Jump operation for BPF_JMP.
BPF_JA = 0x00  # unconditional
BPF_JEQ = 0x10  # A == operand
BPF_JGT = 0x20  # A > operand
BPF_JGE = 0x30  # A >= operand
BPF_JSET = 0x40  # (A & operand) != 0

# Operand source.
BPF_K = 0x00  # the immediate |k|
BPF_X = 0x08  # the index register X

# Maximum classic BPF program length, in instructions.
BPF_MAXINSNS = 4096

# Seccomp filter return values, mirrored from <linux/seccomp.h>.

SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

# Masks that split a return value into its action (high 16 bits) and the
# action-specific data (low 16 bits, e.g. the errno of SECCOMP_RET_ERRNO).
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff
67
68
def arg_offset(arg_index, hi=False):
    """Return the byte offset of a syscall argument for a 32-bit absolute load.

    The seccomp data block starts with int nr, u32 arch and
    u64 instruction_pointer (16 bytes in total), followed by the 8-byte
    arguments.  The result is suitable as the |k| operand of a
    BPF_LD|BPF_W|BPF_ABS instruction.

    Args:
      arg_index: zero-based index of the syscall argument.
      hi: when true, address the second 32-bit word of the argument instead
        of the first.  NOTE(review): that word is the high half only on
        little-endian targets.

    Returns:
      The byte offset.
    """
    header_size = 4 + 4 + 8  # nr + arch + instruction_pointer
    half_width = 4  # every argument spans two 32-bit words
    low_word = header_size + 2 * half_width * arg_index
    return low_word + half_width * hi
74
75
def simulate(instructions, arch, syscall_number, *args):
    """Interpret a BPF program against a synthetic seccomp data block.

    The input block holds |syscall_number|, |arch|, an instruction pointer
    of 0 and |args| (zero-padded or truncated to MAX_SYSCALL_ARGUMENTS
    64-bit values).  Only the instructions this module emits are understood:
    32-bit absolute loads, unconditional jumps, the four immediate
    conditional jumps and immediate returns.

    Returns:
      (cost, action), where cost counts the executed instructions and action
      is "KILL_PROCESS", "KILL_THREAD", "TRAP", "TRACE", "USER_NOTIF", "LOG"
      or "ALLOW"; an errno return yields (cost, "ERRNO", errno) instead.

    Raises:
      Exception: on an unsupported instruction, an unrecognized return value,
        or when execution runs past the last instruction.
    """
    padded_args = (args + (0,) * MAX_SYSCALL_ARGUMENTS)[
        :MAX_SYSCALL_ARGUMENTS
    ]
    data = struct.pack(
        "IIQ" + "Q" * MAX_SYSCALL_ARGUMENTS,
        syscall_number,
        arch,
        0,  # instruction_pointer
        *padded_args
    )

    # Conditional jump opcode -> predicate over (accumulator, k).
    branch_tests = {
        BPF_JMP | BPF_JEQ | BPF_K: lambda acc, k: acc == k,
        BPF_JMP | BPF_JGT | BPF_K: lambda acc, k: acc > k,
        BPF_JMP | BPF_JGE | BPF_K: lambda acc, k: acc >= k,
        BPF_JMP | BPF_JSET | BPF_K: lambda acc, k: (acc & k) != 0,
    }
    # Return values that name an action outright.  ERRNO is matched by mask
    # below because its low 16 bits carry the errno value.
    named_returns = {
        SECCOMP_RET_KILL_PROCESS: "KILL_PROCESS",
        SECCOMP_RET_KILL_THREAD: "KILL_THREAD",
        SECCOMP_RET_TRAP: "TRAP",
        SECCOMP_RET_TRACE: "TRACE",
        SECCOMP_RET_USER_NOTIF: "USER_NOTIF",
        SECCOMP_RET_LOG: "LOG",
        SECCOMP_RET_ALLOW: "ALLOW",
    }

    accumulator = 0
    pc = 0
    executed = 0
    while pc < len(instructions):
        insn = instructions[pc]
        pc += 1
        executed += 1
        code = insn.code
        if code == BPF_LD | BPF_W | BPF_ABS:
            (accumulator,) = struct.unpack("I", data[insn.k : insn.k + 4])
        elif code == BPF_JMP | BPF_JA | BPF_K:
            pc += insn.k
        elif code in branch_tests:
            taken = branch_tests[code](accumulator, insn.k)
            pc += insn.jt if taken else insn.jf
        elif code == BPF_RET:
            k = insn.k
            if k in named_returns:
                return (executed, named_returns[k])
            if (k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (executed, "ERRNO", k & SECCOMP_RET_DATA)
            raise Exception("unknown return %#x" % k)
        else:
            raise Exception("unknown instruction %r" % (insn,))
    raise Exception("out-of-bounds")
137
138
class SockFilter(
    collections.namedtuple("SockFilter", ("code", "jt", "jf", "k"))
):
    """One classic BPF instruction, mirroring struct sock_filter.

    Fields: |code| is the 16-bit opcode, |jt|/|jf| are the 8-bit jump
    offsets taken when a conditional test is true/false, and |k| is the
    32-bit generic operand.
    """

    __slots__ = ()

    def encode(self):
        """Pack the instruction into native-layout bytes ("HBBI")."""
        code, jt, jf, k = self
        return struct.pack("HBBI", code, jt, jf, k)
149
150
class AbstractBlock(abc.ABC):
    """Base class of every node in the filter DAG.

    Blocks take part in the visitor pattern: traversal is driven by each
    block's accept() method.
    """

    @abc.abstractmethod
    def accept(self, visitor):
        """Let |visitor| process this block (and any blocks it leads to)."""
157
158
class BasicBlock(AbstractBlock):
    """A block whose instructions are already compiled (a list of SockFilter).

    Two BasicBlocks compare equal when their instruction lists are equal,
    regardless of their concrete subclass.
    """

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        # Leaf block: nothing to traverse, just visit once.
        if not visitor.visited(self):
            visitor.visit(self)

    @property
    def instructions(self):
        """The instruction list given at construction (not copied)."""
        return self._instructions

    @property
    def opcodes(self):
        """The binary encodings of all instructions, concatenated."""
        encoded = [insn.encode() for insn in self._instructions]
        return b"".join(encoded)

    def __eq__(self, o):
        return isinstance(o, BasicBlock) and self._instructions == o._instructions
183
184
class KillProcess(BasicBlock):
    """Terminal block: the filter returns SECCOMP_RET_KILL_PROCESS."""

    def __init__(self):
        verdict = SockFilter(
            code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_KILL_PROCESS
        )
        super().__init__([verdict])
192
193
class KillThread(BasicBlock):
    """Terminal block: the filter returns SECCOMP_RET_KILL_THREAD."""

    def __init__(self):
        verdict = SockFilter(
            code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_KILL_THREAD
        )
        super().__init__([verdict])
201
202
class Trap(BasicBlock):
    """Terminal block: the filter returns SECCOMP_RET_TRAP."""

    def __init__(self):
        verdict = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_TRAP)
        super().__init__([verdict])
208
209
class Trace(BasicBlock):
    """Terminal block: the filter returns SECCOMP_RET_TRACE."""

    def __init__(self):
        verdict = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_TRACE)
        super().__init__([verdict])
215
216
class UserNotify(BasicBlock):
    """Terminal block: the filter returns SECCOMP_RET_USER_NOTIF."""

    def __init__(self):
        verdict = SockFilter(
            code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_USER_NOTIF
        )
        super().__init__([verdict])
224
225
class Log(BasicBlock):
    """Terminal block: the filter returns SECCOMP_RET_LOG."""

    def __init__(self):
        verdict = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_LOG)
        super().__init__([verdict])
231
232
class ReturnErrno(BasicBlock):
    """Terminal block: the syscall fails with |errno| (SECCOMP_RET_ERRNO).

    Only the low 16 bits of |errno| fit in the return value; the |errno|
    attribute keeps the value exactly as it was passed in.
    """

    def __init__(self, errno):
        return_value = SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA)
        super().__init__(
            [SockFilter(code=BPF_RET, jt=0, jf=0, k=return_value)]
        )
        self.errno = errno
248
249
class Allow(BasicBlock):
    """Terminal block: the filter returns SECCOMP_RET_ALLOW."""

    def __init__(self):
        verdict = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_ALLOW)
        super().__init__([verdict])
255
256
class ValidateArch(AbstractBlock):
    """A block that checks the architecture before continuing at |next_block|."""

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        if not visitor.visited(self):
            # Successor first, so it is processed before this block.
            self.next_block.accept(visitor)
            visitor.visit(self)
269
270
class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG.

    The value held in the BPF accumulator (the syscall number) is tested
    against |syscall_number| with the BPF jump opcode |op| (BPF_JEQ by
    default).  Control continues at |jt| when the test holds and at |jf|
    otherwise.
    """

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    # Entries have no meaningful ordering.  These comparisons exist only so
    # that tuples containing SyscallEntries can be compared: an entry is
    # never less than or greater than another object.
    def __lt__(self, o):
        return False

    def __gt__(self, o):
        return False

    def accept(self, visitor):
        if visitor.visited(self):
            return
        # Post-order: both successors are handled before this block.
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
303
304
class WideAtom(AbstractBlock):
    """An AbstractBlock that represents a 32-bit wide atom.

    The 32-bit word at byte offset |arg_offset| of the seccomp data is
    compared against |value| with the BPF jump opcode |op|; control
    continues at |jt| when the test holds and at |jf| otherwise.
    """

    def __init__(
        self,
        arg_offset,  # pylint: disable=redefined-outer-name
        op,
        value,
        jt,
        jf,
    ):
        super().__init__()
        self.arg_offset = arg_offset
        self.op = op
        self.value = value
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        if visitor.visited(self):
            return
        # Post-order: both successors are handled before this block.
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
324
325
class Atom(AbstractBlock):
    """An AbstractBlock that represents an atom (a simple comparison operation).

    The atom tests syscall argument |arg_index| against |value| with the
    textual operator |op| ("==", "!=", ">", "<=", ">=", "<", "&" or "in");
    control continues at |jt| when the test holds and at |jf| otherwise.

    The constructor normalizes the comparison into one of the BPF jump
    opcodes BPF_JEQ, BPF_JGT, BPF_JGE or BPF_JSET.  Negated operators are
    expressed as the complementary BPF test with |jt| and |jf| swapped, so
    after construction the |op|, |value|, |jt| and |jf| attributes hold the
    normalized form rather than the arguments as given.

    Raises:
      Exception: if |op| is not one of the operators listed above.
    """

    # Operator -> (BPF jump opcode, swap jt/jf, complement the value).
    _OPERATORS = {
        "==": (BPF_JEQ, False, False),
        "!=": (BPF_JEQ, True, False),
        ">": (BPF_JGT, False, False),
        "<=": (BPF_JGT, True, False),
        ">=": (BPF_JGE, False, False),
        "<": (BPF_JGE, True, False),
        "&": (BPF_JSET, False, False),
        # The mask is negated, so the comparison will be true when the
        # argument includes a flag that wasn't listed in the original
        # (non-negated) mask. This would be the failure case, so |jt| and
        # |jf| are switched.
        "in": (BPF_JSET, True, True),
    }

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        try:
            bpf_op, swap_targets, complement = self._OPERATORS[op]
        except (KeyError, TypeError):
            # TypeError: |op| is not even hashable, so it cannot be valid.
            raise Exception("Unknown operator %s" % op) from None
        if complement:
            value = (~value) & ((1 << 64) - 1)
        if swap_targets:
            jt, jf = jf, jt

        self.arg_index = arg_index
        self.op = bpf_op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        if visitor.visited(self):
            return
        # Post-order: both successors are handled before this block.
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
371
372
class AbstractVisitor(abc.ABC):
    """Base class for visitors over a DAG of blocks.

    Traversal is driven by the blocks: each block's accept() calls visited()
    to skip blocks that were already seen and visit() to process itself.
    visit() forwards the block to the visitXxx() hook matching its most
    specific known type; concrete visitors implement every hook.
    """

    # (block type, hook name) pairs, tried in order.  The terminal action
    # blocks subclass BasicBlock, so they must come before BasicBlock itself.
    _DISPATCH = (
        (KillProcess, "visitKillProcess"),
        (KillThread, "visitKillThread"),
        (Trap, "visitTrap"),
        (ReturnErrno, "visitReturnErrno"),
        (Trace, "visitTrace"),
        (UserNotify, "visitUserNotify"),
        (Log, "visitLog"),
        (Allow, "visitAllow"),
        (BasicBlock, "visitBasicBlock"),
        (ValidateArch, "visitValidateArch"),
        (SyscallEntry, "visitSyscallEntry"),
        (WideAtom, "visitWideAtom"),
        (Atom, "visitAtom"),
    )

    def __init__(self):
        # ids of the blocks seen so far.
        self._visited = set()

    def visited(self, block):
        """Return True if |block| was already seen; otherwise mark it as seen."""
        key = id(block)
        seen_before = key in self._visited
        self._visited.add(key)
        return seen_before

    def process(self, block):
        """Traverse the DAG rooted at |block| and return |block| unchanged."""
        block.accept(self)
        return block

    def visit(self, block):
        """Forward |block| to the visitXxx() hook for its type.

        Raises:
          Exception: if |block| is not of a known block type.
        """
        for block_type, hook_name in self._DISPATCH:
            if isinstance(block, block_type):
                getattr(self, hook_name)(block)
                return
        raise Exception("Unknown block type: %r" % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        """Handle a KillProcess block."""

    @abc.abstractmethod
    def visitKillThread(self, block):
        """Handle a KillThread block."""

    @abc.abstractmethod
    def visitTrap(self, block):
        """Handle a Trap block."""

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        """Handle a ReturnErrno block."""

    @abc.abstractmethod
    def visitTrace(self, block):
        """Handle a Trace block."""

    @abc.abstractmethod
    def visitUserNotify(self, block):
        """Handle a UserNotify block."""

    @abc.abstractmethod
    def visitLog(self, block):
        """Handle a Log block."""

    @abc.abstractmethod
    def visitAllow(self, block):
        """Handle an Allow block."""

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        """Handle any other BasicBlock."""

    @abc.abstractmethod
    def visitValidateArch(self, block):
        """Handle a ValidateArch block."""

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        """Handle a SyscallEntry block."""

    @abc.abstractmethod
    def visitWideAtom(self, block):
        """Handle a WideAtom block."""

    @abc.abstractmethod
    def visitAtom(self, block):
        """Handle an Atom block."""
470
471
class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks.

    process() returns a copy of the DAG rooted at the given block.  Blocks
    are visited successors-first, so when a composite block is visited the
    copies of its successors already exist in |_mapping| (id of an original
    block -> its copy).  A block reachable along several paths is copied
    only once, which preserves sharing in the copy.
    """

    def __init__(self):
        super().__init__()
        self._mapping = {}

    def process(self, block):
        self._mapping = {}
        # Reset the visited set together with the copy table: a block that
        # stayed marked as visited would be skipped, leaving no copy for the
        # lookups below and in the visit hooks (KeyError on reuse).
        self._visited = set()
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitUserNotify(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = UserNotify()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        # The instruction list itself is shared with the original.
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)]
        )

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op,
        )

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset,
            block.op,
            block.value,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
        )

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        # An Atom's |op| is already a normalized BPF opcode (and for negated
        # operators |jt|/|jf| are swapped and, for "in", |value| is
        # complemented), so passing it back through Atom() would be rejected
        # as an unknown operator.  Clone the instance and re-point only its
        # successors at their copies.
        atom = copy.copy(block)
        atom.jt = self._mapping[id(block.jt)]
        atom.jf = self._mapping[id(block.jf)]
        self._mapping[id(block)] = atom
554
555
class LoweringVisitor(CopyingVisitor):
    """A CopyingVisitor that lowers Atoms into WideAtoms.

    Every block is copied unchanged except Atoms, which are rewritten into
    small chains of WideAtoms, each comparing a single 32-bit half of the
    argument.  When |arch.bits| is 32 only the low half is compared.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        assert id(block) not in self._mapping

        on_true = self._mapping[id(block.jt)]
        on_false = self._mapping[id(block.jf)]
        lo_value = block.value & 0xFFFFFFFF
        hi_value = (block.value >> 32) & 0xFFFFFFFF
        hi_offset = arg_offset(block.arg_index, True)

        # Every lowering ends with this comparison of the low halves.
        lo_test = WideAtom(
            arg_offset(block.arg_index, False),
            block.op,
            lo_value,
            on_true,
            on_false,
        )

        if self._bits == 32:
            lowered = lo_test
        elif block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #   <=>  hi_1 > hi_2 || (hi_1 == hi_2 && lo_1 <op> lo_2)
            if hi_value == 0:
                # An unsigned hi_1 that is not > 0 must equal 0, so the
                # equality test is implied by the JGT test failing.
                lowered = WideAtom(
                    hi_offset, BPF_JGT, hi_value, on_true, lo_test
                )
            else:
                hi_equal = WideAtom(
                    hi_offset, BPF_JEQ, hi_value, lo_test, on_false
                )
                lowered = WideAtom(
                    hi_offset, BPF_JGT, hi_value, on_true, hi_equal
                )
        elif block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2  <=>  (hi_1 & hi_2) || (lo_1 & lo_2)
            if hi_value == 0:
                # hi_1 & 0 is never set; only the low halves matter.
                lowered = lo_test
            else:
                lowered = WideAtom(
                    hi_offset, block.op, hi_value, on_true, lo_test
                )
        else:
            assert block.op == BPF_JEQ, block.op
            # hi_1,lo_1 == hi_2,lo_2  <=>  hi_1 == hi_2 && lo_1 == lo_2
            lowered = WideAtom(
                hi_offset, block.op, hi_value, lo_test, on_false
            )

        self._mapping[id(block)] = lowered
641
642
class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects into one program.

    Blocks must be visited successors-first (which their accept() methods
    guarantee).  Each visited block's instructions are *prepended* to the
    program, so every jump target already has a known position when a
    predecessor emits a jump to it.
    """

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # id(block) -> -(program length right after |block| was prepended).
        self._offsets = {}

    @property
    def result(self):
        """The flattened program as a BasicBlock."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return how many instructions must be skipped to reach |block|.

        The distance is measured from the current start of the program, i.e.
        from the instruction right after the block that is being emitted.
        """
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        """Return the instruction loading the 32-bit word at |offset|."""
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Return a conditional jump to |jt_distance| / |jf_distance|.

        Both distances are measured from the end of the returned sequence.
        The jt/jf fields are only 8 bits wide, so far targets are reached
        through BPF_JA trampolines, whose 32-bit |k| has no such limit.
        """
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(
                    BPF_JMP | op | BPF_K, jt_distance, jf_distance, value
                ),
            ]
        if jt_distance + 1 < 0x100:
            # The true branch hops over the trampoline for the false branch.
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            # The false branch hops over the trampoline for the true branch.
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
            ]
        # Both targets are far: one trampoline per branch.
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
        ]

    def _emit_validate_arch(self, block):
        """Return the architecture check that leads to |block.next_block|.

        Layout:
          ld [4]                        ; arch
          jeq #arch_nr, jt=K, jf=0      ; K = len(kill action)
          <kill action, K instructions>
          ld [0]                        ; syscall number
          ja next_block                 ; only if next_block isn't adjacent

        The match branch must land exactly on the syscall-number load, so it
        skips the whole kill action (whatever its length) and then reaches
        |next_block| by falling through or, if needed, through a BPF_JA.
        """
        kill_instructions = self._kill_action.instructions
        next_distance = self._distance(block.next_block)

        instructions = [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
            SockFilter(
                BPF_JMP | BPF_JEQ | BPF_K,
                len(kill_instructions),
                0,
                self._arch.arch_nr,
            ),
        ]
        instructions += kill_instructions
        instructions.append(SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0))
        if next_distance:
            instructions.append(
                SockFilter(BPF_JMP | BPF_JA, 0, 0, next_distance)
            )
        return instructions

    def visited(self, block):
        """Return True if |block| was already seen; otherwise mark it."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        """Prepend the instructions for |block| to the program.

        Raises:
          Exception: if |block| is not of a flattenable block type.
        """
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            instructions = self._emit_validate_arch(block)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(
                block.op,
                block.syscall_number,
                self._distance(block.jt),
                self._distance(block.jf),
            )
        elif isinstance(block, WideAtom):
            instructions = self._emit_load_arg(
                block.arg_offset
            ) + self._emit_jmp(
                block.op,
                block.value,
                self._distance(block.jt),
                self._distance(block.jf),
            )
        else:
            raise Exception("Unknown block type: %r" % block)

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)
736
737
class ArgFilterForwardingVisitor:
    """A visitor that hands every arg-filter block to another visitor.

    Arg filters are the BasicBlocks that are not terminal actions (ALLOW,
    KILL_PROCESS, TRAP, ...); those action blocks are skipped here.
    """

    # Terminal action blocks: BasicBlocks that are not arg filters.
    _ACTION_TYPES = (
        KillProcess,
        KillThread,
        Trap,
        ReturnErrno,
        Trace,
        UserNotify,
        Log,
        Allow,
    )

    def __init__(self, visitor):
        self._visited = set()
        self.visitor = visitor

    def visited(self, block):
        """Return True if |block| was already seen; otherwise mark it."""
        key = id(block)
        seen_before = key in self._visited
        self._visited.add(key)
        return seen_before

    def visit(self, block):
        is_arg_filter = isinstance(block, BasicBlock) and not isinstance(
            block, self._ACTION_TYPES
        )
        if is_arg_filter:
            block.accept(self.visitor)
772