• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The ChromiumOS Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""A BPF compiler for the Minijail policy file."""
6
7import enum
8
9
10try:
11    import parser  # pylint: disable=wrong-import-order
12
13    import bpf
14except ImportError:
15    from minijail import bpf
16    from minijail import parser  # pylint: disable=wrong-import-order
17
18
class OptimizationStrategy(enum.Enum):
    """Strategies for laying out the syscall number checks.

    LINEAR generates a linear chain of syscall number checks. It works best
    for policies with very few syscalls.

    BST generates a binary search tree over the syscalls. It works best for
    policies with a lot of syscalls, where no one syscall dominates.
    """

    LINEAR = "linear"
    BST = "bst"

    def __str__(self):
        # Render a member as its bare string value rather than the default
        # "OptimizationStrategy.NAME" form.
        return str(self.value)
32
33
class SyscallPolicyEntry:
    """One syscall's parsed seccomp policy line.

    Holds the syscall name and number, its frequency, a running
    `accumulated` counter (initially 0) and an optional compiled `filter`
    (initially None).
    """

    def __init__(self, name, number, frequency):
        self.name = name
        self.number = number
        self.frequency = frequency
        # Both of these start out unset; `filter` is assigned later when the
        # entry has an argument filter compiled for it.
        self.accumulated = 0
        self.filter = None

    def __repr__(self):
        instructions = self.filter.instructions if self.filter else None
        template = (
            "SyscallPolicyEntry<name: %s, number: %d, "
            "frequency: %d, filter: %r>"
        )
        return template % (
            self.name,
            self.number,
            self.frequency,
            instructions,
        )

    def simulate(self, arch, syscall_number, *args):
        """Run the entry's filter against the given syscall invocation.

        An entry without a filter reports `(0, "ALLOW")`; otherwise the
        result of `bpf.simulate` on the filter's instructions is returned.
        """
        if self.filter:
            return bpf.simulate(
                self.filter.instructions, arch, syscall_number, *args
            )
        return (0, "ALLOW")
62
63
class SyscallPolicyRange:
    """A run of contiguous SyscallPolicyEntries sharing the same action.

    `numbers` is the half-open interval `(first, last + 1)` of syscall
    numbers, `frequency` is the sum of the entries' frequencies, and `filter`
    is taken from the first entry.
    """

    def __init__(self, *entries):
        first, last = entries[0], entries[-1]
        self.numbers = (first.number, last.number + 1)
        self.frequency = sum(entry.frequency for entry in entries)
        self.accumulated = 0
        self.filter = first.filter

    def __repr__(self):
        instructions = self.filter.instructions if self.filter else None
        template = "SyscallPolicyRange<numbers: %r, frequency: %d, filter: %r>"
        return template % (self.numbers, self.frequency, instructions)

    def simulate(self, arch, syscall_number, *args):
        """Run the range's filter against the given syscall invocation.

        A range without a filter reports `(0, "ALLOW")`; otherwise the call
        is delegated to the filter's own `simulate` method.
        """
        if self.filter:
            return self.filter.simulate(arch, syscall_number, *args)
        return (0, "ALLOW")
85
86
87def _convert_to_ranges(entries):
88    entries = list(sorted(entries, key=lambda r: r.number))
89    lower = 0
90    while lower < len(entries):
91        upper = lower + 1
92        while upper < len(entries):
93            if entries[upper - 1].filter != entries[upper].filter:
94                break
95            if entries[upper - 1].number + 1 != entries[upper].number:
96                break
97            upper += 1
98        yield SyscallPolicyRange(*entries[lower:upper])
99        lower = upper
100
101
def _compile_single_range(
    entry, accept_action, reject_action, lower_bound=0, upper_bound=1e99
):
    # Builds the syscall-number check for one range and returns a
    # (number of comparisons, bpf.SyscallEntry) pair. A matching syscall goes
    # to the range's own filter when it has one, otherwise to
    # |accept_action|; everything else goes to |reject_action|.
    action = entry.filter or accept_action
    first, past_end = entry.numbers

    if past_end - first == 1:
        # Single syscall.
        # Accept if |X == nr|.
        check = bpf.SyscallEntry(
            first, action, reject_action, op=bpf.BPF_JEQ
        )
        return (1, check)

    if first == lower_bound:
        # Syscall range aligned with the lower bound.
        # Accept if |X < nr[1]|.
        check = bpf.SyscallEntry(
            past_end, reject_action, action, op=bpf.BPF_JGE
        )
        return (1, check)

    if past_end == upper_bound:
        # Syscall range aligned with the upper bound.
        # Accept if |X >= nr[0]|.
        check = bpf.SyscallEntry(
            first, action, reject_action, op=bpf.BPF_JGE
        )
        return (1, check)

    # Syscall range in the middle.
    # Accept if |nr[0] <= X < nr[1]|, which takes two chained comparisons.
    below_end = bpf.SyscallEntry(
        past_end, reject_action, action, op=bpf.BPF_JGE
    )
    check = bpf.SyscallEntry(
        first, below_end, reject_action, op=bpf.BPF_JGE
    )
    return (2, check)
146
147
148def _compile_ranges_linear(ranges, accept_action, reject_action):
149    # Compiles the list of ranges into a simple linear list of comparisons. In
150    # order to make the generated code a bit more efficient, we sort the
151    # ranges by frequency, so that the most frequently-called syscalls appear
152    # earlier in the chain.
153    cost = 0
154    accumulated_frequencies = 0
155    next_action = reject_action
156    for entry in sorted(ranges, key=lambda r: r.frequency):
157        current_cost, next_action = _compile_single_range(
158            entry, accept_action, next_action
159        )
160        accumulated_frequencies += entry.frequency
161        cost += accumulated_frequencies * current_cost
162    return (cost, next_action)
163
164
def _compile_entries_linear(entries, accept_action, reject_action):
    # Groups the entries into ranges and compiles them as a linear chain,
    # returning only the head of the chain (the cost is discarded).
    ranges = _convert_to_ranges(entries)
    _cost, chain = _compile_ranges_linear(ranges, accept_action, reject_action)
    return chain
169
170
def _compile_entries_bst(entries, accept_action, reject_action):
    # Instead of generating a linear list of comparisons, this method generates
    # a binary search tree, where some of the leaves can be linear chains of
    # comparisons. It returns the root node of that tree.
    #
    # Even though we are going to perform a binary search over the syscall
    # number, we would still like to rotate some of the internal nodes of the
    # binary search tree so that more frequently-used syscalls can be accessed
    # more cheaply (i.e. fewer internal nodes need to be traversed to reach
    # them).
    #
    # This uses Dynamic Programming to generate all possible BSTs efficiently
    # (in O(n^3)) so that we can get the absolute minimum-cost tree that matches
    # all syscall entries. It does so by considering all of the O(n^2) possible
    # sub-intervals, and for each one of those try all of the O(n) partitions of
    # that sub-interval. At each step, it considers putting the remaining
    # entries in a linear comparison chain as well as another BST, and chooses
    # the option that minimizes the total overall cost.
    #
    # Between every pair of non-contiguous allowed syscalls, there are two
    # locally optimal options as to where to set the partition for the
    # subsequent ranges: aligned to the end of the left subrange or to the
    # beginning of the right subrange. The fact that these two options have
    # slightly different costs, combined with the possibility of a subtree to
    # use the linear chain strategy (which has a completely different cost
    # model), causes the target cost function that we are trying to optimize to
    # not be unimodal / convex. This unfortunately means that more clever
    # techniques like using ternary search (which would reduce the overall
    # complexity to O(n^2 log n)) do not work in all cases.
    ranges = list(_convert_to_ranges(entries))

    # Prefix sums of the frequencies, so that the total frequency of any
    # contiguous slice of ranges can be computed with a single subtraction.
    accumulated = 0
    for entry in ranges:
        accumulated += entry.frequency
        entry.accumulated = accumulated

    # Memoization cache to build the DP table top-down, which is easier to
    # understand. Keys are |bounds| tuples; values are (cost, node) pairs.
    memoized_costs = {}

    def _generate_syscall_bst(ranges, indices, bounds=(0, 2**64 - 1)):
        # Returns the cheapest (cost, node) pair for ranges[indices[0]:
        # indices[1]], given that the syscall number is already known to lie
        # within the half-open interval |bounds|.
        assert bounds[0] <= ranges[indices[0]].numbers[0], (indices, bounds)
        assert ranges[indices[1] - 1].numbers[1] <= bounds[1], (indices, bounds)

        if bounds in memoized_costs:
            return memoized_costs[bounds]
        if indices[1] - indices[0] == 1:
            if bounds == ranges[indices[0]].numbers:
                # If bounds are tight around the syscall, it costs nothing.
                memoized_costs[bounds] = (
                    0,
                    ranges[indices[0]].filter or accept_action,
                )
                return memoized_costs[bounds]
            result = _compile_single_range(
                ranges[indices[0]], accept_action, reject_action
            )
            memoized_costs[bounds] = (
                result[0] * ranges[indices[0]].frequency,
                result[1],
            )
            return memoized_costs[bounds]

        # Try the linear model first and use that as the best estimate so far.
        best_cost = _compile_ranges_linear(
            ranges[slice(*indices)], accept_action, reject_action
        )

        # Now recursively go through all possible partitions of the interval
        # currently being considered. The comparison at the root of a subtree
        # is paid by every syscall in the slice, so it costs the slice's total
        # frequency.
        previous_accumulated = (
            ranges[indices[0]].accumulated - ranges[indices[0]].frequency
        )
        bst_comparison_cost = (
            ranges[indices[1] - 1].accumulated - previous_accumulated
        )
        for i, entry in enumerate(ranges[slice(*indices)]):
            split = i + indices[0]
            # Both halves of the partition must be non-empty. This does not
            # depend on the cutoff, so check it once per split point.
            if not indices[0] < split < indices[1]:
                continue
            # The two candidate cutoffs: the start of the right subrange and
            # the end of the left subrange.
            candidates = [entry.numbers[0], ranges[split - 1].numbers[1]]
            for cutoff_bound in candidates:
                if not bounds[0] < cutoff_bound < bounds[1]:
                    continue
                left_subtree = _generate_syscall_bst(
                    ranges,
                    (indices[0], split),
                    (bounds[0], cutoff_bound),
                )
                right_subtree = _generate_syscall_bst(
                    ranges,
                    (split, indices[1]),
                    (cutoff_bound, bounds[1]),
                )
                candidate = (
                    bst_comparison_cost + left_subtree[0] + right_subtree[0],
                    bpf.SyscallEntry(
                        cutoff_bound,
                        right_subtree[1],
                        left_subtree[1],
                        op=bpf.BPF_JGE,
                    ),
                )
                # Compare on cost only. Comparing the whole tuple would fall
                # through to ordering the action objects whenever two costs
                # tie. On ties the earlier option is kept.
                best_cost = min(
                    best_cost, candidate, key=lambda option: option[0]
                )

        memoized_costs[bounds] = best_cost
        return memoized_costs[bounds]

    return _generate_syscall_bst(ranges, (0, len(ranges)))[1]
285
286
class PolicyCompiler:
    """A compiler from the Minijail seccomp policy file format to BPF."""

    def __init__(self, arch):
        self._arch = arch

    def compile_file(
        self,
        policy_filename,
        *,
        optimization_strategy,
        kill_action,
        include_depth_limit=10,
        override_default_action=None,
        denylist=False,
        ret_log=False,
    ):
        """Return a compiled BPF program from the provided policy file.

        The file is parsed with `parser.PolicyParser`, which receives
        `kill_action`, `include_depth_limit`, `override_default_action`,
        `denylist` and `ret_log` unchanged. Each parsed filter statement is
        compiled into a policy entry, and the entries are then combined into
        syscall number checks using the BST layout when
        `optimization_strategy` is `OptimizationStrategy.BST`, and the linear
        layout otherwise.

        With `denylist` set, listed syscalls lead to `kill_action` and
        everything else is allowed; without it, listed syscalls are allowed
        and everything else gets the parsed policy's default action.
        """
        parsed_policy = parser.PolicyParser(
            self._arch,
            kill_action=kill_action,
            include_depth_limit=include_depth_limit,
            override_default_action=override_default_action,
            denylist=denylist,
            ret_log=ret_log,
        ).parse_file(policy_filename)

        entries = []
        for statement in parsed_policy.filter_statements:
            entries.append(
                self.compile_filter_statement(
                    statement, kill_action=kill_action, denylist=denylist
                )
            )

        visitor = bpf.FlatteningVisitor(
            arch=self._arch, kill_action=kill_action
        )
        if denylist:
            accept_action, reject_action = kill_action, bpf.Allow()
        else:
            accept_action = bpf.Allow()
            reject_action = parsed_policy.default_action

        if not entries:
            # No syscalls are listed, so every syscall gets the reject action.
            reject_action.accept(visitor)
            bpf.ValidateArch(reject_action).accept(visitor)
            return visitor.result

        if optimization_strategy == OptimizationStrategy.BST:
            compile_entries = _compile_entries_bst
        else:
            compile_entries = _compile_entries_linear
        next_action = compile_entries(entries, accept_action, reject_action)

        next_action.accept(bpf.ArgFilterForwardingVisitor(visitor))
        reject_action.accept(visitor)
        accept_action.accept(visitor)
        bpf.ValidateArch(next_action).accept(visitor)
        return visitor.result

    def compile_filter_statement(
        self, filter_statement, *, kill_action, denylist=False
    ):
        """Compile one parser.FilterStatement into BPF.

        Returns a SyscallPolicyEntry for the statement's syscall. Its
        `filter` is left unset when, outside denylist mode, the statement's
        last filter action is an unconditional allow; otherwise it holds the
        flattened argument filter.
        """
        syscall = filter_statement.syscall
        policy_entry = SyscallPolicyEntry(
            syscall.name, syscall.number, filter_statement.frequency
        )

        # At every step, |fallthrough| is the action taken when the immediate
        # boolean condition does not match. It starts out as the action of the
        # last filter, which is what applies when the whole expression fails
        # to match.
        fallthrough = filter_statement.filters[-1].action
        if not denylist and fallthrough == bpf.Allow():
            return policy_entry

        # Walk the remaining filters backwards, since the root of the DAG
        # should be the very first boolean operation in the filter chain.
        for filt in reversed(filter_statement.filters[:-1]):
            for conjunction in filt.expression:
                # Any conjunction that succeeds makes the whole expression
                # succeed, so the last comparison in its chain jumps to the
                # filter's action. Every atom jumps to |fallthrough| when it
                # fails.
                on_match = filt.action
                for atom in conjunction:
                    on_match = bpf.Atom(
                        atom.argument_index,
                        atom.op,
                        atom.value,
                        on_match,
                        fallthrough,
                    )
                # A failure in the next conjunction falls through to this one.
                fallthrough = on_match

        # Lower all Atoms into WideAtoms.
        lowered = bpf.LoweringVisitor(arch=self._arch).process(fallthrough)

        # Flatten the IR DAG into a single BasicBlock.
        flattener = bpf.FlatteningVisitor(
            arch=self._arch, kill_action=kill_action
        )
        lowered.accept(flattener)
        policy_entry.filter = flattener.result
        return policy_entry
397