#!/usr/bin/env python3
# Copyright 2020 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# This script will take any number of trace files generated by strace(1)
# and output a system call filtering policy suitable for use with Minijail.
8
9"""Tool to generate a minijail seccomp filter from strace or audit output."""
10
11import argparse
12import collections
13import datetime
14import os
15import re
16import sys
17
18
19# auparse may not be installed and is currently optional.
20try:
21    import auparse
22except ImportError:
23    auparse = None
24
25
# Year stamped into the copyright notice of the generated files.
YEAR = datetime.datetime.now().year
# Header written at the top of the generated policy and frequency files.
NOTICE = f"""# Copyright {YEAR} The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""

# Filter expression emitted for syscalls that need no argument inspection.
ALLOW = "1"

# This ignores any leading PID tag and trailing <unfinished ...>, and extracts
# the syscall name and the argument list.
LINE_RE = re.compile(r"^\s*(?:\[[^]]*\]|\d+)?\s*([a-zA-Z0-9_]+)\(([^)<]*)")

# Socket syscalls that are recorded as "socketcall" when the trace is detected
# as coming from 32-bit x86.
SOCKETCALLS = {
    "accept",
    "bind",
    "connect",
    "getpeername",
    "getsockname",
    "getsockopt",
    "listen",
    "recv",
    "recvfrom",
    "recvmsg",
    "send",
    "sendmsg",
    "sendto",
    "setsockopt",
    "shutdown",
    "socket",
    "socketpair",
}

# List of private ARM syscalls. These can be found in any ARM specific unistd.h
# such as Linux's arch/arm/include/uapi/asm/unistd.h.
PRIVATE_ARM_SYSCALLS = {
    983041: "ARM_breakpoint",
    983042: "ARM_cacheflush",
    983043: "ARM_usr26",
    983044: "ARM_usr32",
    983045: "ARM_set_tls",
}

# Describes argument inspection for one syscall: the zero-based index of the
# argument to inspect and the set of values observed for it in the logs.
ArgInspectionEntry = collections.namedtuple(
    "ArgInspectionEntry", ("arg_index", "value_set")
)
71
72
# pylint: disable=too-few-public-methods
class BucketInputFiles(argparse.Action):
    """Buckets input files using simple content based heuristics.

    Each input file is scanned line by line. The first line that looks like
    strace output, or like an audit SYSCALL/SECCOMP record, decides which
    bucket the file lands in. Files with no matching line are treated as
    strace logs.

    Attributes:
        audit_logs: Mutually exclusive list of audit log filenames.
        traces: Mutually exclusive list of strace log filenames.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        audit_logs = []
        traces = []

        strace_line_re = re.compile(r"[a-z]+[0-9]*\(.+\) += ")
        audit_line_re = re.compile(r"type=(SYSCALL|SECCOMP)")

        for filename in values:
            if not os.path.exists(filename):
                parser.error(f"Input file {filename} not found.")
            with open(filename, mode="r", encoding="utf-8") as input_file:
                # Iterate the file lazily: the loop stops at the first
                # matching line, so there is no reason to read every line
                # into memory up front.
                for line in input_file:
                    if strace_line_re.search(line):
                        traces.append(filename)
                        break
                    if audit_line_re.search(line):
                        audit_logs.append(filename)
                        break
                else:
                    # Treat it as an strace log to retain legacy behaviour and
                    # also just in case the strace regex is imperfect.
                    traces.append(filename)

        # Both buckets are always set, even when one of them is empty.
        namespace.audit_logs = audit_logs
        namespace.traces = traces


# pylint: enable=too-few-public-methods
110
111
def parse_args(argv):
    """Returns the parsed CLI arguments for this tool.

    Args:
        argv: List of command line arguments, without the program name.

    Returns:
        The argparse namespace. Besides the declared options it carries the
        `traces` and `audit_logs` lists populated by BucketInputFiles.

    Exits via parser.error() when audit logs are given but the auparse
    bindings are missing, or when --audit-comm and the presence of audit logs
    do not agree.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="output informational messages to stderr",
    )
    parser.add_argument(
        "--frequency", type=argparse.FileType("w"), help="frequency file"
    )
    parser.add_argument(
        "--policy",
        type=argparse.FileType("w"),
        default=sys.stdout,
        help="policy file",
    )
    parser.add_argument(
        "input-logs",
        action=BucketInputFiles,
        help="strace and/or audit logs",
        nargs="+",
    )
    parser.add_argument(
        "--audit-comm",
        type=str,
        metavar="PROCESS_NAME",
        help="relevant process name from the audit.log files",
    )
    opts = parser.parse_args(argv)

    # auparse is an optional import; it is only required once an input file
    # has actually been classified as an audit log.
    if opts.audit_logs and not auparse:
        parser.error(
            "Python bindings for the audit subsystem were not found.\n"
            "Please install the python3-audit (sometimes python-audit)"
            " package for your distro to process audit logs: "
            f"{opts.audit_logs}"
        )

    if opts.audit_logs and not opts.audit_comm:
        parser.error(
            "--audit-comm is required when using audit logs as input:"
            f" {opts.audit_logs}"
        )

    if not opts.audit_logs and opts.audit_comm:
        parser.error(
            "--audit-comm was specified yet none of the input files "
            "matched our heuristic for an audit log"
        )

    return opts
164
165
def get_seccomp_bpf_filter(syscall, entry):
    """Returns a minijail seccomp-bpf filter expression for the syscall.

    Args:
        syscall: Name of the syscall the filter is for.
        entry: Object with an `arg_index` (index of the inspected argument)
            and a `value_set` (values observed for that argument). It is not
            modified.

    Returns:
        The atoms joined with " || ". Observed values are emitted in sorted
        order so the generated policy is stable from run to run.
    """
    arg_index = entry.arg_index
    arg_values = entry.value_set
    atoms = []
    if syscall in ("mmap", "mmap2", "mprotect") and arg_index == 2:
        # See if there is at least one instance of any of these syscalls trying
        # to map memory with both PROT_EXEC and PROT_WRITE. If there isn't, we
        # can craft a concise expression to forbid this.
        write_and_exec = {"PROT_EXEC", "PROT_WRITE"}
        maps_write_and_exec = any(
            write_and_exec.issubset(p.strip() for p in arg_value.split("|"))
            for arg_value in arg_values
        )
        if not maps_write_and_exec:
            atoms.extend(["arg2 in ~PROT_EXEC", "arg2 in ~PROT_WRITE"])
            # The concise expression replaces the per-value comparisons.
            arg_values = ()
    # Sets iterate in an arbitrary order; sort for reproducible output.
    atoms.extend(
        f"arg{arg_index} == {arg_value}" for arg_value in sorted(arg_values)
    )
    return " || ".join(atoms)
186
187
def parse_trace_file(trace_filename, syscalls, arg_inspection):
    """Parses one file produced by strace.

    Args:
        trace_filename: Path of the strace log. The name is also used to guess
            whether the trace comes from 32-bit x86, in which case socket
            syscalls are recorded as "socketcall".
        syscalls: Mapping of syscall name to occurrence count; updated in
            place.
        arg_inspection: Mapping of syscall name to ArgInspectionEntry. For
            those syscalls the observed value of the inspected argument is
            added to the entry's value_set.
    """
    uses_socketcall = "i386" in trace_filename or (
        "x86" in trace_filename and "64" not in trace_filename
    )
    # Built once instead of scanning the dict values for every trace line.
    private_arm_names = frozenset(PRIVATE_ARM_SYSCALLS.values())

    with open(trace_filename, encoding="utf-8") as trace_file:
        for line in trace_file:
            matches = LINE_RE.match(line)
            if not matches:
                continue

            syscall, args = matches.groups()
            if uses_socketcall and syscall in SOCKETCALLS:
                syscall = "socketcall"

            # strace omits the 'ARM_' prefix on all private ARM syscalls. Add
            # it manually here as a workaround. These syscalls are exclusive
            # to ARM so we don't need to predicate this on a trace_filename
            # based heuristic for the arch.
            arm_name = f"ARM_{syscall}"
            if arm_name in private_arm_names:
                syscall = arm_name

            syscalls[syscall] += 1

            entry = arg_inspection.get(syscall)
            if entry is None:
                continue

            # Only split the argument list for syscalls that need inspection.
            # NOTE(review): a line with fewer arguments than the inspected
            # index (e.g. a truncated log) raises IndexError here.
            arg_list = [arg.strip() for arg in args.split(",")]
            entry.value_set.add(arg_list[entry.arg_index])
218
219
def parse_audit_log(audit_log, audit_comm, syscalls, arg_inspection):
    """Parses one audit.log file generated by the Linux audit subsystem.

    Args:
        audit_log: Filename of the audit log to parse.
        audit_comm: Process name to match against each record's "comm" field.
        syscalls: Mapping of syscall name to occurrence count; updated in
            place.
        arg_inspection: Mapping of syscall name to ArgInspectionEntry. For
            those syscalls the interpreted value of the inspected argument is
            added to the entry's value_set.

    Raises:
        ValueError: If the file yields no first record, or a relevant record
            lacks the "syscall" field or the inspected argument field.
    """

    unknown_syscall_re = re.compile(r"unknown-syscall\((?P<syscall_num>\d+)\)")
    # The comm field is searched for and compared including its double quotes.
    quoted_comm = f'"{audit_comm}"'

    au = auparse.AuParser(auparse.AUSOURCE_FILE, audit_log)
    # Quick validity check for whether this parses as a valid audit log. The
    # first event should have at least one record.
    if not au.first_record():
        # audit_log is a filename string, so it is formatted directly.
        raise ValueError(f"Unable to parse audit log file {audit_log}")

    # Iterate through events where _any_ contained record matches
    # ((type == SECCOMP || type == SYSCALL) && comm == audit_comm).
    au.search_add_item("type", "=", "SECCOMP", auparse.AUSEARCH_RULE_CLEAR)
    au.search_add_item("type", "=", "SYSCALL", auparse.AUSEARCH_RULE_OR)
    au.search_add_item("comm", "=", quoted_comm, auparse.AUSEARCH_RULE_AND)

    # auparse_find_field(3) will ignore preceding fields in the record and
    # at the same time happily cross record boundaries when looking for the
    # field. This helper method always seeks the cursor back to the first
    # field in the record and stops searching before crossing over to the
    # next record; making the search far less error prone.
    # Also implicitly seeks the internal 'cursor' to the matching field
    # for any subsequent calls like auparse_interpret_field.
    def _find_field_in_current_record(name):
        au.first_field()
        while True:
            if au.get_field_name() == name:
                return au.get_field_str()
            if not au.next_field():
                return None

    while au.search_next_event():
        # The event may have multiple records. Loop through all.
        au.first_record()
        for _ in range(au.get_num_records()):
            event_type = _find_field_in_current_record("type")
            comm = _find_field_in_current_record("comm")
            # Some of the records in this event may not be relevant
            # despite the event-specific search filter. Skip those.
            if event_type not in ("SECCOMP", "SYSCALL") or comm != quoted_comm:
                au.next_record()
                continue

            if not _find_field_in_current_record("syscall"):
                raise ValueError(
                    f'Could not find field "syscall" in event of '
                    f"type {event_type}"
                )
            # Interpret the syscall field that's under our 'cursor' following
            # the find. Interpreting fields yields human friendly names
            # instead of integers. E.g '16' -> 'ioctl'.
            syscall = au.interpret_field()

            # TODO(crbug/1172449): Add these syscalls to upstream
            # audit-userspace and remove this workaround.
            # This is redundant but safe for non-ARM architectures due to the
            # disjoint set of private syscall numbers.
            match = unknown_syscall_re.match(syscall)
            if match:
                syscall_num = int(match.group("syscall_num"))
                syscall = PRIVATE_ARM_SYSCALLS.get(syscall_num, syscall)

            needs_inspection = syscall in arg_inspection
            if (needs_inspection and event_type == "SECCOMP") or (
                not needs_inspection and event_type == "SYSCALL"
            ):
                # Skip SECCOMP records for syscalls that require argument
                # inspection. Similarly, skip SYSCALL records for syscalls
                # that do not require argument inspection. Technically such
                # records wouldn't exist per our setup instructions but audit
                # sometimes lets a few records slip through.
                au.next_record()
                continue

            if event_type == "SYSCALL":
                arg_field_name = f"a{arg_inspection[syscall].arg_index}"
                if not _find_field_in_current_record(arg_field_name):
                    raise ValueError(
                        f'Could not find field "{arg_field_name}" '
                        f"in event of type {event_type}"
                    )
                # Interpret the arg field that's under our 'cursor' following
                # the find. This may yield a more human friendly name.
                # E.g '5401' -> 'TCGETS'.
                arg_inspection[syscall].value_set.add(au.interpret_field())

            syscalls[syscall] += 1
            au.next_record()
312
313
def main(argv=None):
    """Main entrypoint.

    Parses the input logs, then writes the seccomp policy to opts.policy and,
    when requested, the per-syscall counts to opts.frequency.

    Args:
        argv: Command line arguments without the program name. Defaults to
            sys.argv[1:].
    """

    if argv is None:
        argv = sys.argv[1:]

    opts = parse_args(argv)

    syscalls = collections.defaultdict(int)

    # Syscalls whose policy line filters on one argument instead of allowing
    # the call outright.
    arg_inspection = {
        "socket": ArgInspectionEntry(0, set()),  # int domain
        "ioctl": ArgInspectionEntry(1, set()),  # int request
        "prctl": ArgInspectionEntry(0, set()),  # int option
        "mmap": ArgInspectionEntry(2, set()),  # int prot
        "mmap2": ArgInspectionEntry(2, set()),  # int prot
        "mprotect": ArgInspectionEntry(2, set()),  # int prot
    }

    if opts.verbose:
        # Print an informational message to stderr in case the filetype
        # detection heuristics are wonky.
        print(
            "Generating a seccomp policy using these input files:",
            file=sys.stderr,
        )
        print(f"Strace logs: {opts.traces}", file=sys.stderr)
        print(f"Audit logs: {opts.audit_logs}", file=sys.stderr)

    for trace_filename in opts.traces:
        parse_trace_file(trace_filename, syscalls, arg_inspection)

    for audit_log in opts.audit_logs:
        parse_audit_log(audit_log, opts.audit_comm, syscalls, arg_inspection)

    # Add the basic set if they are not yet present.
    for basic_syscall in (
        "restart_syscall",
        "exit",
        "exit_group",
        "rt_sigreturn",
    ):
        syscalls.setdefault(basic_syscall, 1)

    # If a frequency file isn't used then sort the syscalls based on frequency
    # to make the common case fast (by checking frequent calls earlier).
    # Otherwise, sort alphabetically to make it easier for humans to see which
    # calls are in use (and if necessary manually add a new syscall to the
    # list).
    if opts.frequency is None:
        # The sort is stable, so equally frequent syscalls keep the order in
        # which they were first seen.
        sorted_syscalls = sorted(syscalls, key=syscalls.get, reverse=True)
    else:
        sorted_syscalls = sorted(syscalls)

    print(NOTICE, file=opts.policy)
    if opts.frequency is not None:
        print(NOTICE, file=opts.frequency)

    for syscall in sorted_syscalls:
        if syscall in arg_inspection:
            arg_filter = get_seccomp_bpf_filter(
                syscall, arg_inspection[syscall]
            )
        else:
            arg_filter = ALLOW
        print(f"{syscall}: {arg_filter}", file=opts.policy)
        if opts.frequency is not None:
            print(f"{syscall}: {syscalls[syscall]}", file=opts.frequency)
391
392
# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
395