• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4#
5# Copyright (c) 2022 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import sys
20import argparse
21import textwrap
22import re
23import os
24import stat
25
# Top-level keys accepted in a seccomp policy file.
supported_parse_item = [
    'labelName',
    'priority',
    'allowList',
    'blockList',
    'priorityWithArgs',
    'allowListWithArgs',
    'headFiles',
    'selfDefineSyscall',
    'returnValue',
    'mode',
    'privilegedProcessName',
    'allowBlockList',
]

# Architectures a filter can be generated for.
supported_architecture = ['arm', 'arm64']

# C macro templates for the generated BPF program, filled with str.format.
# Jump templates take (k, jt, jf); statement templates take k.
BPF_JGE = 'BPF_JUMP(BPF_JMP|BPF_JGE|BPF_K, {}, {}, {}),'
BPF_JGT = 'BPF_JUMP(BPF_JMP|BPF_JGT|BPF_K, {}, {}, {}),'
BPF_JEQ = 'BPF_JUMP(BPF_JMP|BPF_JEQ|BPF_K, {}, {}, {}),'
BPF_JSET = 'BPF_JUMP(BPF_JMP|BPF_JSET|BPF_K, {}, {}, {}),'
BPF_JA = 'BPF_JUMP(BPF_JMP|BPF_JA, {}, 0, 0),'
BPF_LOAD = 'BPF_STMT(BPF_LD|BPF_W|BPF_ABS, {}),'
BPF_LOAD_MEM = 'BPF_STMT(BPF_LD|BPF_MEM, {}),'
BPF_ST = 'BPF_STMT(BPF_ST, {}),'
BPF_AND = 'BPF_STMT(BPF_ALU|BPF_AND|BPF_K, {}),'
BPF_RET_VALUE = 'BPF_STMT(BPF_RET|BPF_K, {}),'

# Comparison operators allowed in argument conditions.
operation = ['<', '<=', '!=', '==', '>', '>=', '&']

# Policy return keyword -> seccomp return macro.
ret_str_to_bpf = {
    'KILL_PROCESS': 'SECCOMP_RET_KILL_PROCESS',
    'KILL_THREAD': 'SECCOMP_RET_KILL_THREAD',
    'TRAP': 'SECCOMP_RET_TRAP',
    'ERRNO': 'SECCOMP_RET_ERRNO',
    'USER_NOTIF': 'SECCOMP_RET_USER_NOTIF',
    'TRACE': 'SECCOMP_RET_TRACE',
    'LOG': 'SECCOMP_RET_LOG',
    'ALLOW': 'SECCOMP_RET_ALLOW',
}

# Policy mode name -> numeric mode stored in GenBpfPolicy.gen_mode.
mode_str = {
    'DEFAULT': 0,
    'ONLY_CHECK_ARGS': 1,
}

# Architecture name -> AUDIT_ARCH_* macro compared by add_validate_arch.
architecture_to_number = {
    'arm': 'AUDIT_ARCH_ARM',
    'arm64': 'AUDIT_ARCH_AARCH64',
}
65
66
class ValidateError(Exception):
    """Raised when policy input or generated data fails validation."""

    def __init__(self, msg):
        """Store *msg* as the single exception argument."""
        Exception.__init__(self, msg)
70
71
def print_info(info):
    """Print *info* to stdout with an "[INFO] " prefix.

    Args:
        info: any object; it is rendered with str().
    """
    # Wrap in a 1-tuple so a tuple argument is printed as a value instead of
    # being unpacked by the % operator (which raised TypeError or dropped items).
    print("[INFO] %s" % (info,))
74
75
def is_hex_digit(s):
    """Return True if ``int(s, 16)`` accepts *s*, otherwise False.

    Whatever int() accepts in base 16 passes, including a "0x" prefix, a sign
    and surrounding whitespace.
    """
    try:
        int(s, 16)
    except ValueError:
        return False
    return True
83
84
def str_convert_to_int(s):
    """Convert a decimal or hexadecimal string to an int.

    Decimal wins when *s* consists only of decimal digits ("10" is ten).
    Otherwise anything ``int(s, 16)`` accepts ("ff", "0x1F") is parsed as hex.

    Args:
        s: the string to convert.

    Returns:
        (number, True) on success, or (-1, False) if *s* is neither form.
    """
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscripts ("²") that int() rejects with ValueError.
    if s.isdecimal():
        return int(s), True
    try:
        return int(s, 16), True
    except ValueError:
        return -1, False
98
99
def is_function_name_exist(arch, function_name, func_name_nr_table):
    """Check that *function_name* is a key of *func_name_nr_table*.

    Args:
        arch: architecture name; only used in the error message.
        function_name: syscall name to look up.
        func_name_nr_table: mapping of syscall name to syscall number.

    Returns:
        True if the name is present.

    Raises:
        ValidateError: if the name is not in the table.
    """
    if function_name not in func_name_nr_table:
        raise ValidateError('{} not exist in {} function_name_nr_table Table'.format(function_name, arch))
    return True
105
106
def is_errno_in_valid_range(errno):
    """Validate that the string *errno* is a decimal integer in 1..255.

    Args:
        errno: errno text taken from an ERRNO(n) return value.

    Returns:
        True if valid.

    Raises:
        ValidateError: if *errno* is not a plain decimal string or is out of range.
    """
    # Check the format before int(): otherwise non-numeric text escapes as a
    # bare ValueError instead of a ValidateError.
    if errno.isdecimal() and 0 < int(errno) <= 255:
        return True
    raise ValidateError('{} not within the legal range of errno values.'.format(errno))
112
113
def is_return_errno(return_str):
    """Recognize and parse an ``ERRNO(n)`` return value.

    Args:
        return_str: return value text, e.g. "ERRNO(13)" or "ALLOW".

    Returns:
        (True, "ERRNO | n") for a valid ERRNO(n) value, or
        (False, 'not_return_errno') when *return_str* does not start with "ERRNO".

    Raises:
        ValidateError: if an ERRNO value lacks well-ordered parentheses or its
            number is rejected by is_errno_in_valid_range.
    """
    if not return_str.startswith('ERRNO'):
        return False, 'not_return_errno'

    left = return_str.find('(')
    right = return_str.find(')')
    # find() returns -1 for a missing paren; slicing with it would silently
    # pick the wrong text, so reject malformed values explicitly.
    if left == -1 or right < left:
        raise ValidateError('{} is not legal by ERRNO(errno) format'.format(return_str))

    errno_no = return_str[left + 1:right]
    if is_errno_in_valid_range(errno_no):
        return True, 'ERRNO | ' + errno_no
    return False, 'not_return_errno'
123
124
def function_name_to_nr(function_name_list, func_name_nr_table):
    """Translate syscall names to numbers, silently skipping unknown names.

    Args:
        function_name_list: iterable of syscall names.
        func_name_nr_table: mapping of syscall name to syscall number.

    Returns:
        Set of the numbers of the names found in the table.
    """
    numbers = set()
    for name in function_name_list:
        if name in func_name_nr_table:
            numbers.add(func_name_nr_table[name])
    return numbers
128
129
def filter_syscalls_nr(name_to_nr):
    """Keep only syscall-number macros and strip their prefixes.

    Macros named ``__NR_arm_*``, ``__NR_*`` or ``__ARM_NR_*`` are kept under
    the remaining name; every other entry is dropped.  If two macros reduce
    to the same name, the later one wins.

    Args:
        name_to_nr: mapping of macro name to value.

    Returns:
        Mapping of stripped syscall name to value.
    """
    # "__NR_arm_" must be tried before its own prefix "__NR_".
    prefixes = ("__NR_arm_", "__NR_", "__ARM_NR_")
    syscalls = {}
    for macro, nr in name_to_nr.items():
        prefix = next((p for p in prefixes if macro.startswith(p)), None)
        if prefix is not None:
            syscalls[macro[len(prefix):]] = nr
    return syscalls
146
147
def parse_syscall_file(file_name):
    """Read ``#define NAME VALUE`` lines from a syscall header.

    Each VALUE is evaluated after every identifier of two or more characters
    is replaced by the number already parsed for it.  This lets a define build
    on earlier ones, e.g. ``(__NR_SYSCALL_BASE + 5)``.  Lines whose value does
    not evaluate to an int are skipped.

    Args:
        file_name: path of the header to parse.

    Returns:
        {syscall name: number}, as produced by filter_syscalls_nr.
    """
    const_pattern = re.compile(
        r'^\s*#define\s+([A-Za-z_][A-Za-z0-9_]+)\s+(.+)\s*$')
    mark_pattern = re.compile(r'\b[A-Za-z_][A-Za-z0-9_]+\b')
    comment_pattern = re.compile(r'/\*.*?\*/|//.*$')
    # Identifiers are substituted away before eval, so removing builtins does
    # not change results; it only hardens against unexpected input.
    # NOTE(review): this is still eval() of file content -- the header must be
    # a trusted build artifact.
    eval_globals = {'__builtins__': {}}
    name_to_nr = {}
    with open(file_name, encoding='utf-8') as f:
        for line in f:
            k = const_pattern.match(line)
            if k is None:
                continue
            name = k.group(1)
            # A trailing C comment would otherwise make the value a SyntaxError.
            value = comment_pattern.sub('', k.group(2))
            # Unknown identifiers become the literal None.
            expr = mark_pattern.sub(lambda x: str(name_to_nr.get(x.group(0))), value)
            try:
                nr = eval(expr, eval_globals, {})
            except (KeyError, SyntaxError, NameError, TypeError, ValueError, ArithmeticError):
                continue
            # A bare reference to an unknown macro evaluates to None; it must
            # not be recorded as a syscall number.
            if isinstance(nr, int):
                name_to_nr[name] = nr

    return filter_syscalls_nr(name_to_nr)
168
169
def gen_syscall_nr_table(file_name, func_name_nr_table):
    """Parse a ``libsyscall_to_nr_<arch>`` file and store its syscall table.

    The architecture key is the text following ``libsyscall_to_nr_`` in
    *file_name*, up to the next '/'.

    Args:
        file_name: path containing ``libsyscall_to_nr_<arch>``.
        func_name_nr_table: {arch: {syscall name: number}}, updated in place.

    Returns:
        *func_name_nr_table*.

    Raises:
        ValidateError: if *file_name* has no ``libsyscall_to_nr_<arch>`` part,
            or no syscall numbers could be parsed from the file.
    """
    match = re.search(r"libsyscall_to_nr_([^/]+)", file_name)
    # Check before use; a non-matching name used to fail with AttributeError.
    if match is None:
        raise ValidateError("parse syscall file failed: {} is not a libsyscall_to_nr_<arch> file".format(file_name))
    syscall_table = parse_syscall_file(file_name)
    # An empty table means nothing usable was parsed.  Report it here rather
    # than as "not exist" errors for every syscall later.
    if not syscall_table:
        raise ValidateError("parse syscall file failed")
    func_name_nr_table[match.group(1)] = syscall_table
    return func_name_nr_table
176
177
class SeccompPolicyParam:
    """Collect and validate the seccomp policy items of one architecture.

    Values read from the policy file are dispatched through ``value_function``
    (policy key -> handler).  They are merged into the ``final_*`` sets by
    update_final_list and cross-checked against the block list by
    check_all_allow_list.
    """

    def __init__(self, arch, function_name_nr_table, is_debug):
        """Create an empty parameter set.

        Args:
            arch: architecture name; used in error messages.
            function_name_nr_table: mapping of syscall name to syscall number
                for *arch*.
            is_debug: debug setting; the string 'false' rejects the LOG
                return value.
        """
        self.arch = arch
        # priority / allowList (with and without args) are emptied by
        # clear_list after each update_final_list; the other sets accumulate.
        self.priority = set()
        self.allow_list = set()
        self.blocklist = set()
        self.priority_with_args = set()
        self.allow_list_with_args = set()
        self.head_files = set()
        self.self_define_syscall = set()
        # Merged, de-duplicated results of update_final_list.
        self.final_allow_list = set()
        self.final_priority = set()
        self.final_priority_with_args = set()
        self.final_allow_list_with_args = set()
        self.return_value = ''
        self.mode = 'DEFAULT'
        self.is_debug = is_debug
        self.function_name_nr_table = function_name_nr_table
        # Policy-file key -> handler that validates and stores one value.
        self.value_function = {
            'priority': self.update_priority,
            'allowList': self.update_allow_list,
            'blockList': self.update_blocklist,
            'allowListWithArgs': self.update_allow_list_with_args,
            'priorityWithArgs': self.update_priority_with_args,
            'headFiles': self.update_head_files,
            'selfDefineSyscall': self.update_self_define_syscall,
            'returnValue': self.update_return_value,
            'mode': self.update_mode
        }

    @staticmethod
    def _function_name_of(item):
        """Return the syscall name of a policy item.

        For "name: condition" items this is the stripped text before the first
        ':'; for plain names it is the stripped name itself.
        """
        return item.partition(':')[0].strip()

    def clear_list(self):
        """Empty the pending input sets.

        In ONLY_CHECK_ARGS mode, also empty the unconditional final lists.
        """
        self.priority.clear()
        self.allow_list.clear()
        self.allow_list_with_args.clear()
        self.priority_with_args.clear()
        if self.mode == 'ONLY_CHECK_ARGS':
            self.final_allow_list.clear()
            self.final_priority.clear()

    def update_list(self, function_name, to_update_list):
        """Add *function_name* to *to_update_list* if it is a known syscall.

        Returns:
            True when added.

        Raises:
            ValidateError: if the syscall is unknown for this architecture.
        """
        if is_function_name_exist(self.arch, function_name, self.function_name_nr_table):
            to_update_list.add(function_name)
            return True
        return False

    def update_priority(self, function_name):
        return self.update_list(function_name, self.priority)

    def update_allow_list(self, function_name):
        return self.update_list(function_name, self.allow_list)

    def update_blocklist(self, function_name):
        return self.update_list(function_name, self.blocklist)

    def _update_list_with_args(self, function_name_with_args, to_update_list):
        """Validate a "name: condition" item and add it verbatim to *to_update_list*.

        Raises:
            ValidateError: if the item has no ':' or its name is unknown.
        """
        # Without ':' the old find()-based slice silently dropped the last
        # character of the name.
        if ':' not in function_name_with_args:
            raise ValidateError('{} is not legal by function_name:args format'.format(function_name_with_args))
        function_name = self._function_name_of(function_name_with_args)
        if is_function_name_exist(self.arch, function_name, self.function_name_nr_table):
            to_update_list.add(function_name_with_args)
            return True
        return False

    def update_priority_with_args(self, function_name_with_args):
        return self._update_list_with_args(function_name_with_args, self.priority_with_args)

    def update_allow_list_with_args(self, function_name_with_args):
        return self._update_list_with_args(function_name_with_args, self.allow_list_with_args)

    def update_head_files(self, head_files):
        """Record a header to include, written as "name" or <name>.

        Raises:
            ValidateError: for any other format.
        """
        # The length check must guard both delimiter styles.  The former
        # `a and b or c` guarded only the quoted form, so '' raised IndexError
        # and '<>' was accepted.
        if len(head_files) > 2 and (head_files[0], head_files[-1]) in (('"', '"'), ('<', '>')):
            self.head_files.add(head_files)
            return True

        raise ValidateError('{} is not legal by headFiles format'.format(head_files))

    def update_self_define_syscall(self, self_define_syscall):
        """Record a custom syscall number given in decimal or hex.

        Raises:
            ValidateError: if the value is not a number or equals a syscall
                number already present in the table.
        """
        nr, digit_flag = str_convert_to_int(self_define_syscall)
        if digit_flag and nr not in self.function_name_nr_table.values():
            self.self_define_syscall.add(nr)
            return True

        raise ValidateError('{} is not a number or {} is already used by other syscall'.format(
            self_define_syscall, self_define_syscall))

    def update_return_value(self, return_str):
        """Set the default return value: "ERRNO(n)" or a key of ret_str_to_bpf.

        Raises:
            ValidateError: for unknown values, or LOG when is_debug is 'false'.
        """
        is_ret_errno, return_string = is_return_errno(return_str)
        if is_ret_errno:
            self.return_value = return_string
            return True
        if return_str in ret_str_to_bpf:
            if self.is_debug == 'false' and return_str == 'LOG':
                raise ValidateError("LOG return value is not allowed in user mode")
            self.return_value = return_str
            return True

        raise ValidateError('{} not in {}'.format(return_str, ret_str_to_bpf.keys()))

    def update_mode(self, mode):
        """Set the mode to one of the keys of mode_str.

        Raises:
            ValidateError: for an unknown mode name.
        """
        if mode in mode_str:
            self.mode = mode
            return True
        raise ValidateError('{} not in [DEFAULT, ONLY_CHECK_ARGS]'.format(mode))

    def check_allow_list(self, allow_list):
        """Ensure no item of *allow_list* names a block-listed syscall.

        Returns:
            True when no conflict is found.

        Raises:
            ValidateError: on the first conflicting syscall.
        """
        for item in allow_list:
            syscall = self._function_name_of(item)
            if syscall in self.blocklist:
                raise ValidateError('{} of allow list is in block list'.format(syscall))
        return True

    def check_all_allow_list(self):
        """Check every final allow list against the block list.

        Returns:
            False if a self-defined syscall number belongs to a block-listed
            syscall, otherwise True.

        Raises:
            ValidateError: if an allow-list item is block-listed.
        """
        flag = all(self.check_allow_list(items) for items in (
            self.final_allow_list,
            self.final_priority,
            self.final_priority_with_args,
            self.final_allow_list_with_args,
        ))
        block_nr_list = function_name_to_nr(self.blocklist, self.function_name_nr_table)
        if any(nr in block_nr_list for nr in self.self_define_syscall):
            return False
        return flag

    def update_final_list(self):
        """Merge the pending items into the ``final_*`` sets and de-duplicate.

        A syscall listed with arguments is removed from both plain lists, and
        a syscall in priority is removed from the plain allow list.  The
        pending sets are then emptied via clear_list.
        """
        self.final_allow_list |= self.allow_list
        self.final_priority |= self.priority
        self.final_allow_list_with_args |= self.allow_list_with_args
        self.final_priority_with_args |= self.priority_with_args
        # Compare bare, stripped names: "ioctl : if ..." must yield "ioctl".
        priority_args_names = {self._function_name_of(item) for item in self.final_priority_with_args}
        allow_args_names = {self._function_name_of(item) for item in self.final_allow_list_with_args}
        with_args_names = priority_args_names | allow_args_names
        self.final_allow_list = self.final_allow_list - self.final_priority - with_args_names
        self.final_priority = self.final_priority - with_args_names
        self.clear_list()
320
321
322class GenBpfPolicy:
323    def __init__(self):
324        self.arch = ''
325        self.syscall_nr_range = []
326        self.bpf_policy = []
327        self.syscall_nr_policy_list = []
328        self.function_name_nr_table_dict = {}
329        self.gen_mode = 0
330        self.flag = True
331        self.return_value = ''
332        self.operate_func_table = {
333            '<' : self.gen_bpf_lt,
334            '<=': self.gen_bpf_le,
335            '==': self.gen_bpf_eq,
336            '!=': self.gen_bpf_ne,
337            '>' : self.gen_bpf_gt,
338            '>=': self.gen_bpf_ge,
339            '&' : self.gen_bpf_set,
340        }
341
342    def update_arch(self, arch):
343        self.arch = arch
344        self.syscall_nr_range = []
345        self.syscall_nr_policy_list = []
346
347    def update_function_name_nr_table(self, func_name_nr_table):
348        self.function_name_nr_table_dict = func_name_nr_table
349
350    def clear_bpf_policy(self):
351        self.bpf_policy.clear()
352
353    def get_gen_flag(self):
354        return self.flag
355
356    def set_gen_flag(self, flag):
357        if flag:
358            self.flag = True
359        else:
360            self.flag = False
361
362    def set_gen_mode(self, mode):
363        self.gen_mode = mode_str.get(mode)
364
365    def set_return_value(self, return_value):
366        is_ret_errno, return_string = is_return_errno(return_value)
367        if is_ret_errno == True:
368            self.return_value = return_string
369            return
370        if return_value not in ret_str_to_bpf:
371            self.set_gen_mode(False)
372            return
373
374        self.return_value = return_value
375
376    @staticmethod
377    def gen_bpf_eq32(const_str, jt, jf):
378        bpf_policy = []
379        bpf_policy.append(BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf))
380        return bpf_policy
381
382    @staticmethod
383    def gen_bpf_eq64(const_str, jt, jf):
384        bpf_policy = []
385        bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
386        bpf_policy.append(BPF_LOAD_MEM.format(0))
387        bpf_policy.append(BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf))
388        return bpf_policy
389
390    def gen_bpf_eq(self, const_str, jt, jf):
391        if self.arch == 'arm':
392            return self.gen_bpf_eq32(const_str, jt, jf)
393        elif self.arch == 'arm64':
394            return self.gen_bpf_eq64(const_str, jt, jf)
395        return []
396
397    def gen_bpf_ne(self, const_str, jt, jf):
398        return self.gen_bpf_eq(const_str, jf, jt)
399
400    @staticmethod
401    def gen_bpf_gt32(const_str, jt, jf):
402        bpf_policy = []
403        bpf_policy.append(BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf))
404        return bpf_policy
405
406    @staticmethod
407    def gen_bpf_gt64(const_str, jt, jf):
408        bpf_policy = []
409        number, digit_flag = str_convert_to_int(const_str)
410
411        hight = int(number / (2**32))
412        low = number & 0xffffffff
413
414        if digit_flag and hight == 0:
415            bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
416        else:
417            bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 3, 0))
418            bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
419
420        bpf_policy.append(BPF_LOAD_MEM.format(0))
421        bpf_policy.append(BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf))
422
423        return bpf_policy
424
425    def gen_bpf_gt(self, const_str, jt, jf):
426        if self.arch == 'arm':
427            return self.gen_bpf_gt32(const_str, jt, jf)
428        elif self.arch == 'arm64':
429            return self.gen_bpf_gt64(const_str, jt, jf)
430        return []
431
432    def gen_bpf_le(self, const_str, jt, jf):
433        return self.gen_bpf_gt(const_str, jf, jt)
434
435    @staticmethod
436    def gen_bpf_ge32(const_str, jt, jf):
437        bpf_policy = []
438        bpf_policy.append(BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf))
439        return bpf_policy
440
441    @staticmethod
442    def gen_bpf_ge64(const_str, jt, jf):
443        bpf_policy = []
444        number, digit_flag = str_convert_to_int(const_str)
445
446        hight = int(number / (2**32))
447        low = number & 0xffffffff
448
449        if digit_flag and hight == 0:
450            bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
451        else:
452            bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 3, 0))
453            bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
454        bpf_policy.append(BPF_LOAD_MEM.format(0))
455        bpf_policy.append(BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf))
456        return bpf_policy
457
458    def gen_bpf_ge(self, const_str, jt, jf):
459        if self.arch == 'arm':
460            return self.gen_bpf_ge32(const_str, jt, jf)
461        elif self.arch == 'arm64':
462            return self.gen_bpf_ge64(const_str, jt, jf)
463        return []
464
465    def gen_bpf_lt(self, const_str, jt, jf):
466        return self.gen_bpf_ge(const_str, jf, jt)
467
468    @staticmethod
469    def gen_bpf_set32(const_str, jt, jf):
470        bpf_policy = []
471        bpf_policy.append(BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf))
472        return bpf_policy
473
474    @staticmethod
475    def gen_bpf_set64(const_str, jt, jf):
476        bpf_policy = []
477        bpf_policy.append(BPF_JSET.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
478        bpf_policy.append(BPF_LOAD_MEM.format(0))
479        bpf_policy.append(BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf))
480        return bpf_policy
481
482    def gen_bpf_set(self, const_str, jt, jf):
483        if self.arch == 'arm':
484            return self.gen_bpf_set32(const_str, jt, jf)
485        elif self.arch == 'arm64':
486            return self.gen_bpf_set64(const_str, jt, jf)
487        return []
488
489    @staticmethod
490    def gen_bpf_valid_syscall_nr(syscall_nr, cur_size):
491        bpf_policy = []
492        bpf_policy.append(BPF_LOAD.format(0))
493        bpf_policy.append(BPF_JEQ.format(syscall_nr, 0, cur_size))
494        return bpf_policy
495
496    def gen_range_list(self, syscall_nr_list):
497        if len(syscall_nr_list) == 0:
498            return
499        self.syscall_nr_range.clear()
500
501        syscall_nr_list_order = sorted(list(syscall_nr_list))
502        range_temp = [syscall_nr_list_order[0], syscall_nr_list_order[0]]
503
504        for i in range(len(syscall_nr_list_order) - 1):
505            if syscall_nr_list_order[i + 1] != syscall_nr_list_order[i] + 1:
506                range_temp[1] = syscall_nr_list_order[i]
507                self.syscall_nr_range.append(range_temp)
508                range_temp = [syscall_nr_list_order[i + 1], syscall_nr_list_order[i + 1]]
509
510        range_temp[1] = syscall_nr_list_order[-1]
511        self.syscall_nr_range.append(range_temp)
512
513    def gen_policy_syscall_nr(self, min_index, max_index, cur_syscall_nr_range):
514        middle_index = (int)((min_index + max_index + 1) / 2)
515
516        if middle_index == min_index:
517            self.syscall_nr_policy_list.append(cur_syscall_nr_range[middle_index][1] + 1)
518            return
519        else:
520            self.syscall_nr_policy_list.append(cur_syscall_nr_range[middle_index][0])
521
522        self.gen_policy_syscall_nr(min_index, middle_index - 1, cur_syscall_nr_range)
523        self.gen_policy_syscall_nr(middle_index, max_index, cur_syscall_nr_range)
524
525    def gen_policy_syscall_nr_list(self, cur_syscall_nr_range):
526        if not cur_syscall_nr_range:
527            return
528        self.syscall_nr_policy_list.clear()
529        self.syscall_nr_policy_list.append(cur_syscall_nr_range[0][0])
530        self.gen_policy_syscall_nr(0, len(cur_syscall_nr_range) - 1, cur_syscall_nr_range)
531
532    def calculate_step(self, index):
533        for i in range(index + 1, len(self.syscall_nr_policy_list)):
534            if self.syscall_nr_policy_list[index] < self.syscall_nr_policy_list[i]:
535                step = i - index
536                break
537        return step - 1
538
539    def nr_range_to_bpf_policy(self, cur_syscall_nr_range):
540        self.gen_policy_syscall_nr_list(cur_syscall_nr_range)
541        syscall_list_len  = len(self.syscall_nr_policy_list)
542
543        if syscall_list_len == 0:
544            return
545
546        self.bpf_policy.append(BPF_JGE.format(self.syscall_nr_policy_list[0], 0, syscall_list_len))
547
548        range_max_list = [k[1] for k in cur_syscall_nr_range]
549
550        for i in range(1, syscall_list_len):
551            if self.syscall_nr_policy_list[i] - 1 in range_max_list:
552                self.bpf_policy.append(BPF_JGE.format(self.syscall_nr_policy_list[i], \
553                                        syscall_list_len - i, syscall_list_len - i - 1))
554            else:
555                step = self.calculate_step(i)
556                self.bpf_policy.append(BPF_JGE.format(self.syscall_nr_policy_list[i], step, 0))
557
558        if self.syscall_nr_policy_list:
559            self.bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_ALLOW'))
560
561    def count_alone_range(self):
562        cnt = 0
563        for item in self.syscall_nr_range:
564            if item[0] == item[1]:
565                cnt = cnt + 1
566        return cnt
567
568    def gen_transverse_bpf_policy(self):
569        if not self.syscall_nr_range:
570            return
571        cnt = self.count_alone_range()
572        total_instruction_num = cnt + (len(self.syscall_nr_range) - cnt) * 2
573        i = 0
574        for item in self.syscall_nr_range:
575            if item[0] == item[1]:
576                if i == total_instruction_num - 1:
577                    self.bpf_policy.append(BPF_JEQ.format(item[0], total_instruction_num - i - 1, 1))
578                else:
579                    self.bpf_policy.append(BPF_JEQ.format(item[0], total_instruction_num - i - 1, 0))
580                i += 1
581            else:
582                self.bpf_policy.append(BPF_JGE.format(item[0], 0, total_instruction_num - i))
583                i += 1
584                if i == total_instruction_num - 1:
585                    self.bpf_policy.append(BPF_JGE.format(item[1] + 1, 1, total_instruction_num - i - 1))
586                else:
587                    self.bpf_policy.append(BPF_JGE.format(item[1] + 1, 0, total_instruction_num - i - 1))
588                i += 1
589
590        self.bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_ALLOW'))
591
592    def gen_bpf_policy(self, syscall_nr_list):
593        self.gen_range_list(syscall_nr_list)
594        range_size = (int)((len(self.syscall_nr_range) - 1) / 127) + 1
595        alone_range_cnt = self.count_alone_range()
596        if alone_range_cnt == len(self.syscall_nr_range):
597            #Scattered distribution
598            self.gen_transverse_bpf_policy()
599            return
600
601        if range_size == 1:
602            self.nr_range_to_bpf_policy(self.syscall_nr_range)
603        else:
604            for i in range(0, range_size):
605                if i == 0:
606                    self.nr_range_to_bpf_policy(self.syscall_nr_range[-127 * (i + 1):])
607                elif i == range_size - 1:
608                    self.nr_range_to_bpf_policy(self.syscall_nr_range[:-127 * i])
609                else:
610                    self.nr_range_to_bpf_policy(self.syscall_nr_range[-127 * (i + 1): -127 * i])
611
612    def load_arg(self, arg_id):
613        # little endian
614        bpf_policy = []
615        if self.arch == 'arm':
616            bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
617        elif self.arch == 'arm64':
618            #low 4 bytes
619            bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
620            bpf_policy.append(BPF_ST.format(0))
621            #high 4 bytes
622            bpf_policy.append(BPF_LOAD.format(20 + arg_id * 8))
623            bpf_policy.append(BPF_ST.format(1))
624
625        return bpf_policy
626
627    def compile_atom(self, atom, cur_size):
628        bpf_policy = []
629        if len(atom) < 6:
630            raise ValidateError('{} format ERROR '.format(atom))
631
632        if atom[0] == '(':
633            bpf_policy += self.compile_mask_equal_atom(atom, cur_size)
634        else:
635            bpf_policy += self.compile_single_operation_atom(atom, cur_size)
636
637        return bpf_policy
638
639    @staticmethod
640    def check_arg_str(arg_atom):
641        arg_str = arg_atom[0:3]
642        if arg_str != 'arg':
643            raise ValidateError('format ERROR, {} is not equal to arg'.format(arg_atom))
644
645        arg_id = int(arg_atom[3])
646        if arg_id not in range(6):
647            raise ValidateError('arg num out of the scope 0~5')
648
649        return arg_id, True
650
651    @staticmethod
652    def check_operation_str(operation_atom):
653        operation_str = operation_atom
654        if operation_str not in operation:
655            operation_str = operation_atom[0]
656            if operation_str not in operation:
657                raise ValidateError('operation not in [<, <=, !=, ==, >, >=, &]')
658        return operation_str, True
659
660    #gen bpf (argn & mask) == value
661    @staticmethod
662    def gen_mask_equal_bpf(arg_id, mask, value, cur_size):
663        bpf_policy = []
664        #high 4 bytes
665        bpf_policy.append(BPF_LOAD.format(20 + arg_id * 8))
666        bpf_policy.append(BPF_AND.format('((uint64_t)' + mask + ') >> 32'))
667        bpf_policy.append(BPF_JEQ.format('((uint64_t)' + value + ') >> 32', 0, cur_size + 3))
668
669        #low 4 bytes
670        bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
671        bpf_policy.append(BPF_AND.format(mask))
672        bpf_policy.append(BPF_JEQ.format(value, cur_size, cur_size + 1))
673
674        return bpf_policy
675
676    def compile_mask_equal_atom(self, atom, cur_size):
677        bpf_policy = []
678        left_brace_pos = atom.find('(')
679        right_brace_pos = atom.rfind(')')
680        inside_brace_content = atom[left_brace_pos + 1: right_brace_pos]
681        outside_brace_content = atom[right_brace_pos + 1:]
682
683        arg_res = self.check_arg_str(inside_brace_content[0:4])
684        if not arg_res[1]:
685            return bpf_policy
686
687        operation_res_inside = self.check_operation_str(inside_brace_content[4:6])
688        if operation_res_inside[0] != '&' or not operation_res_inside[1]:
689            return bpf_policy
690
691        mask = inside_brace_content[4 + len(operation_res_inside[0]):]
692
693        operation_res_outside = self.check_operation_str(outside_brace_content[0:2])
694        if operation_res_outside[0] != '==' or not operation_res_outside[1]:
695            return bpf_policy
696
697        value = outside_brace_content[len(operation_res_outside[0]):]
698
699        return self.gen_mask_equal_bpf(arg_res[0], mask, value, cur_size)
700
701    def compile_single_operation_atom(self, atom, cur_size):
702        bpf_policy = []
703        arg_res = self.check_arg_str(atom[0:4])
704        if not arg_res[1]:
705            return bpf_policy
706
707        operation_res = self.check_operation_str(atom[4:6])
708        if not operation_res[1]:
709            return bpf_policy
710
711        const_str = atom[4 + len(operation_res[0]):]
712
713        if not const_str:
714            return bpf_policy
715
716        bpf_policy += self.load_arg(arg_res[0])
717        bpf_policy += self.operate_func_table.get(operation_res[0])(const_str, 0, cur_size + 1)
718
719        return bpf_policy
720
721    def parse_args_with_condition(self, group):
722        #the priority of && higher than ||
723        atoms = group.split('&&')
724        bpf_policy = []
725        for atom in reversed(atoms):
726            bpf_policy = self.compile_atom(atom, len(bpf_policy)) + bpf_policy
727        return bpf_policy
728
729    def parse_sub_group(self, group):
730        bpf_policy = []
731        group_info = group.split(';')
732        operation_part = group_info[0]
733        return_part = group_info[1]
734        if not return_part.startswith('return'):
735            raise ValidateError('allow list with args do not have return part')
736
737        self.set_return_value(return_part[len('return'):])
738        and_cond_groups = operation_part.split('||')
739        for and_condition_group in and_cond_groups:
740            bpf_policy += self.parse_args_with_condition(and_condition_group)
741            bpf_policy.append(BPF_RET_VALUE.format(ret_str_to_bpf.get(self.return_value)))
742        return bpf_policy
743
744    def parse_else_part(self, else_part):
745        return_value = else_part.split(';')[0][else_part.find('return') + len('return'):]
746        self.set_return_value(return_value)
747
748    def parse_args(self, function_name, line, skip):
749        bpf_policy = []
750        group_info  = line.split('else')
751        else_part = group_info[-1]
752        group = group_info[0].split('elif')
753        for sub_group in group:
754            bpf_policy += self.parse_sub_group(sub_group)
755        self.parse_else_part(else_part)
756        if self.return_value[0:len('ERRNO')] == 'ERRNO':
757            bpf_policy.append(BPF_RET_VALUE.format(self.return_value.replace('ERRNO', ret_str_to_bpf.get('ERRNO'))))
758        else:
759            bpf_policy.append(BPF_RET_VALUE.format(ret_str_to_bpf.get(self.return_value)))
760        syscall_nr = self.function_name_nr_table_dict.get(self.arch).get(function_name)
761        #load syscall nr
762        bpf_policy = self.gen_bpf_valid_syscall_nr(syscall_nr, len(bpf_policy) - skip)  + bpf_policy
763        return bpf_policy
764
765    def gen_bpf_policy_with_args(self, allow_list_with_args, mode, return_value):
766        self.set_gen_mode(mode)
767        skip = 0
768        for line in allow_list_with_args:
769            if self.gen_mode == 1 and line == list(allow_list_with_args)[-1]:
770                skip = 2
771            line = line.replace(' ', '')
772            pos = line.find(':')
773            function_name = line[:pos]
774
775            left_line = line[pos + 1:]
776            if not left_line.startswith('if'):
777                continue
778
779            self.bpf_policy += self.parse_args(function_name, left_line[2:], skip)
780
781    def add_load_syscall_nr(self):
782        self.bpf_policy.append(BPF_LOAD.format(0))
783
784    def add_return_value(self, return_value):
785        if return_value[0:len('ERRNO')] == 'ERRNO':
786            self.bpf_policy.append(BPF_RET_VALUE.format(return_value.replace('ERRNO', ret_str_to_bpf.get('ERRNO'))))
787        else:
788            self.bpf_policy.append(BPF_RET_VALUE.format(ret_str_to_bpf.get(return_value)))
789
790    def add_validate_arch(self, arches, skip_step):
791        if not self.bpf_policy or not self.flag:
792            return
793        bpf_policy = []
794        #load arch
795        bpf_policy.append(BPF_LOAD.format(4))
796        if len(arches) == 2:
797            bpf_policy.append(BPF_JEQ.format(architecture_to_number.get(arches[0]), 3, 0))
798            bpf_policy.append(BPF_JEQ.format(architecture_to_number.get(arches[1]), 0, 1))
799            bpf_policy.append(BPF_JA.format(skip_step))
800            bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_TRAP'))
801        elif len(arches) == 1:
802            bpf_policy.append(BPF_JEQ.format(architecture_to_number.get(arches[0]), 1, 0))
803            bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_TRAP'))
804        else:
805            self.bpf_policy = []
806
807        self.bpf_policy = bpf_policy + self.bpf_policy
808
809
810class AllowBlockList:
811    def __init__(self, filter_name, arch, function_name_nr_table):
812        self.is_valid = False
813        self.arch = arch
814        self.filter_name = filter_name
815        self.reduced_block_list = set()
816        self.function_name_nr_table = function_name_nr_table
817        self.value_function = {
818            'privilegedProcessName': self.update_flag,
819            'allowBlockList': self.update_reduced_block_list,
820        }
821
822    def update_flag(self, name):
823        if self.filter_name == name:
824            self.is_valid = True
825        else:
826            self.is_valid = False
827
828    def update_reduced_block_list(self, function_name):
829        if self.is_valid and is_function_name_exist(self.arch, function_name, self.function_name_nr_table):
830            self.reduced_block_list.add(function_name)
831            return True
832        return False
833
834
835class SeccompPolicyParser:
836    def __init__(self):
837        self.cur_parse_item = ''
838        self.arches = set()
839        self.bpf_generator = GenBpfPolicy()
840        self.seccomp_policy_param = dict()
841        self.reduced_block_list_parm = dict()
842        self.key_process_flag = False
843        self.is_debug = False
844
845    def update_is_debug(self, is_debug):
846        if is_debug == 'false':
847            self.is_debug = False
848        else:
849            self.is_debug = True
850
851    def update_arch(self, target_cpu):
852        if target_cpu == "arm":
853            self.arches.add(target_cpu)
854        elif target_cpu == "arm64":
855            self.arches.add("arm")
856            self.arches.add(target_cpu)
857
858    def update_block_list(self):
859        for arch in supported_architecture:
860            self.seccomp_policy_param.get(arch).blocklist -= self.reduced_block_list_parm.get(arch).reduced_block_list
861
862    def update_parse_item(self, line):
863        item = line[1:]
864        if item in supported_parse_item:
865            self.cur_parse_item = item
866            print_info('start deal with {}'.format(self.cur_parse_item))
867
868    def check_allow_list(self):
869        for arch in self.arches:
870            if not self.seccomp_policy_param.get(arch).check_all_allow_list():
871                self.bpf_generator.set_gen_flag(False)
872
873    def clear_file_syscall_list(self):
874        for arch in self.arches:
875            self.seccomp_policy_param.get(arch).update_final_list()
876        self.cur_parse_item = ''
877        self.cur_arch = ''
878
879    def parse_line(self, line):
880        if not self.cur_parse_item :
881            return
882        line = line.replace(' ', '')
883        pos = line.rfind(';')
884        if pos < 0:
885            for arch in self.arches:
886                if self.key_process_flag:
887                    self.reduced_block_list_parm.get(arch).value_function.get(self.cur_parse_item)(line)
888                else:
889                    self.seccomp_policy_param.get(arch).value_function.get(self.cur_parse_item)(line)
890        else:
891            arches = line[pos + 1:].split(',')
892            if arches[0] == 'all':
893                arches = supported_architecture
894            for arch in arches:
895                if self.key_process_flag:
896                    self.reduced_block_list_parm.get(arch).value_function.get(self.cur_parse_item)(line[:pos])
897                else:
898                    self.seccomp_policy_param.get(arch).value_function.get(self.cur_parse_item)(line[:pos])
899
900    def parse_open_file(self, fp):
901        for line in fp:
902            line = line.strip()
903            if not line:
904                continue
905            if line[0] == '#':
906                continue
907            if line[0] == '@':
908                self.update_parse_item(line)
909                continue
910            if line[0] != '@' and self.cur_parse_item == '':
911                continue
912            self.parse_line(line)
913        self.clear_file_syscall_list()
914        self.check_allow_list()
915
916    def parse_file(self, file_path):
917        with open(file_path) as fp:
918            self.parse_open_file(fp)
919
920    def gen_seccomp_policy_of_arch(self, arch):
921        cur_policy_param = self.seccomp_policy_param.get(arch)
922
923        if not cur_policy_param.return_value:
924            raise ValidateError('return value not defined')
925
926        #get final allow_list
927        syscall_nr_allow_list = function_name_to_nr(cur_policy_param.final_allow_list, \
928                                                    cur_policy_param.function_name_nr_table) \
929                                                    | cur_policy_param.self_define_syscall
930        syscall_nr_priority = function_name_to_nr(cur_policy_param.final_priority, \
931                                                  cur_policy_param.function_name_nr_table)
932        self.bpf_generator.update_arch(arch)
933
934        #load syscall nr
935        if syscall_nr_allow_list or syscall_nr_priority:
936            self.bpf_generator.add_load_syscall_nr()
937        self.bpf_generator.gen_bpf_policy(syscall_nr_priority)
938        self.bpf_generator.gen_bpf_policy_with_args(sorted(list(cur_policy_param.final_priority_with_args)), \
939            cur_policy_param.mode, cur_policy_param.return_value)
940        self.bpf_generator.gen_bpf_policy(syscall_nr_allow_list)
941        self.bpf_generator.gen_bpf_policy_with_args(sorted(list(cur_policy_param.final_allow_list_with_args)), \
942            cur_policy_param.mode, cur_policy_param.return_value)
943
944        self.bpf_generator.add_return_value(cur_policy_param.return_value)
945        for line in self.bpf_generator.bpf_policy:
946            if 'SECCOMP_RET_LOG' in line and self.is_debug == False:
947                raise ValidateError("LOG return value is not allowed in user mode")
948
949    def gen_seccomp_policy(self):
950        arches = sorted(list(self.arches))
951        if not arches:
952            return
953        self.gen_seccomp_policy_of_arch(arches[0])
954        skip_step = len(self.bpf_generator.bpf_policy) + 1
955        if len(arches) == 2:
956            self.gen_seccomp_policy_of_arch(arches[1])
957
958        self.bpf_generator.add_validate_arch(arches, skip_step)
959
960    def gen_output_file(self, args):
961        if not self.bpf_generator.bpf_policy:
962            raise ValidateError("bpf_policy is empty!")
963
964        header = textwrap.dedent('''\
965
966            #include <linux/filter.h>
967            #include <stddef.h>
968            #include <linux/seccomp.h>
969            #include <linux/audit.h>
970            ''')
971        extra_header = set()
972        for arch in self.arches:
973            extra_header |= self.seccomp_policy_param.get(arch).head_files
974        extra_header_list =  ['#include ' + i for i in sorted(list(extra_header))]
975        filter_name = 'g_' + args.filter_name + 'SeccompFilter'
976
977        array_name = textwrap.dedent('''
978
979            const struct sock_filter {}[] = {{
980            ''').format(filter_name)
981
982        footer = textwrap.dedent('''\
983
984            }};
985
986            const size_t {} = sizeof({}) / sizeof(struct sock_filter);
987            ''').format(filter_name + 'Size', filter_name)
988
989        content = header + '\n'.join(extra_header_list) + array_name + \
990            '    ' + '\n    '.join(self.bpf_generator.bpf_policy) + footer
991
992        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
993        modes = stat.S_IWUSR | stat.S_IRUSR | stat.S_IWGRP | stat.S_IRGRP
994        with os.fdopen(os.open(args.dst_file, flags, modes), 'w') as output_file:
995            output_file.write(content)
996
997    def gen_seccomp_policy_code(self, args):
998        if args.target_cpu not in supported_architecture:
999            raise ValidateError('target cpu not supported')
1000        function_name_nr_table_dict = {}
1001        for file_name in args.src_files:
1002            file_name_tmp = file_name.split('/')[-1]
1003            if not file_name_tmp.lower().startswith('libsyscall_to_nr_'):
1004                continue
1005            function_name_nr_table_dict = gen_syscall_nr_table(file_name, function_name_nr_table_dict)
1006
1007
1008        for arch in supported_architecture:
1009            self.seccomp_policy_param.update(
1010                {arch: SeccompPolicyParam(arch, function_name_nr_table_dict.get(arch), args.is_debug)})
1011            self.reduced_block_list_parm.update(
1012                {arch: AllowBlockList(args.filter_name, arch, function_name_nr_table_dict.get(arch))})
1013
1014        self.bpf_generator.update_function_name_nr_table(function_name_nr_table_dict)
1015
1016        self.update_arch(args.target_cpu)
1017        self.update_is_debug(args.is_debug)
1018
1019        for file_name in args.blocklist_file:
1020            if file_name.lower().endswith('blocklist.seccomp.policy'):
1021                self.parse_file(file_name)
1022
1023        for file_name in args.keyprocess_file:
1024            if file_name.lower().endswith('privileged_process.seccomp.policy'):
1025                self.key_process_flag = True
1026                self.parse_file(file_name)
1027                self.key_process_flag = False
1028
1029        self.update_block_list()
1030
1031        for file_name in args.src_files:
1032            if file_name.lower().endswith('.policy'):
1033                self.parse_file(file_name)
1034
1035        if self.bpf_generator.get_gen_flag():
1036            self.gen_seccomp_policy()
1037
1038        if self.bpf_generator.get_gen_flag():
1039            self.gen_output_file(args)
1040
1041
1042def main():
1043    parser = argparse.ArgumentParser(
1044      description='Generates a seccomp-bpf policy')
1045    parser.add_argument('--src-files', type=str, action='append',
1046                        help=('The input files\n'))
1047
1048    parser.add_argument('--blocklist-file', type=str, action='append',
1049                        help=('input basic blocklist file(s)\n'))
1050
1051    parser.add_argument('--keyprocess-file', type=str, action='append',
1052                        help=('input key process file(s)\n'))
1053
1054    parser.add_argument('--dst-file',
1055                        help='The output path for the policy files')
1056
1057    parser.add_argument('--filter-name',  type=str,
1058                        help='Name of seccomp bpf array generated by this script')
1059
1060    parser.add_argument('--target-cpu', type=str,
1061                        help=('please input target cpu arm or arm64\n'))
1062
1063    parser.add_argument('--is-debug', type=str,
1064                        help=('please input is_debug true or false\n'))
1065
1066    args = parser.parse_args()
1067
1068    generator = SeccompPolicyParser()
1069    generator.gen_seccomp_policy_code(args)
1070
1071
1072if __name__ == '__main__':
1073    sys.exit(main())
1074