• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4#
5# Copyright (c) 2022 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import sys
20import argparse
21import textwrap
22import re
23import os
24import stat
25
# Section names that may follow '@' in a seccomp policy file.
supported_parse_item = [
    'arch',
    'labelName',
    'priority',
    'allowList',
    'blockList',
    'priorityWithArgs',
    'allowListWithArgs',
    'headFiles',
    'selfDefineSyscall',
    'returnValue',
    'mode',
]

# arch name -> {syscall name: syscall nr}. Empty here; presumably filled in
# elsewhere in the script before any policy is parsed (TODO confirm).
function_name_nr_table_dict = {}

# C initializer templates for classic-BPF instructions. Each one renders a
# single element of the generated `struct sock_filter` array.
BPF_JGE = 'BPF_JUMP(BPF_JMP|BPF_JGE|BPF_K, {}, {}, {}),'
BPF_JGT = 'BPF_JUMP(BPF_JMP|BPF_JGT|BPF_K, {}, {}, {}),'
BPF_JEQ = 'BPF_JUMP(BPF_JMP|BPF_JEQ|BPF_K, {}, {}, {}),'
BPF_JSET = 'BPF_JUMP(BPF_JMP|BPF_JSET|BPF_K, {}, {}, {}),'
BPF_JA = 'BPF_JUMP(BPF_JMP|BPF_JA, {}, 0, 0),'
BPF_LOAD = 'BPF_STMT(BPF_LD|BPF_W|BPF_ABS, {}),'
BPF_LOAD_MEM = 'BPF_STMT(BPF_LD|BPF_MEM, {}),'
BPF_ST = 'BPF_STMT(BPF_ST, {}),'
BPF_AND = 'BPF_STMT(BPF_ALU|BPF_AND|BPF_K, {}),'
BPF_RET_VALUE = 'BPF_STMT(BPF_RET|BPF_K, {}),'

# Comparison operators accepted inside argument conditions.
operation = [
    '<', '<=', '!=', '==', '>', '>=', '&',
]

# Policy-file return keyword -> seccomp action constant emitted in C.
ret_str_to_bpf = {
    'KILL_PROCESS': 'SECCOMP_RET_KILL_PROCESS',
    'KILL_THREAD': 'SECCOMP_RET_KILL_THREAD',
    'LOG': 'SECCOMP_RET_LOG',
    'ALLOW': 'SECCOMP_RET_ALLOW',
}

# Policy-file mode keyword -> numeric generation mode.
mode_str = {
    'DEFAULT': 0,
    'ONLY_CHECK_ARGS': 1,
}
55
56
def is_hex_digit(s):
    """Return True when ``int(s, 16)`` accepts *s*, otherwise False."""
    try:
        int(s, 16)
    except ValueError:
        return False
    else:
        return True
64
65
def str_convert_to_int(s):
    """Parse *s* as a decimal or, failing that, a hexadecimal integer.

    Returns a ``(number, digit_flag)`` pair: ``(value, True)`` on success and
    ``(-1, False)`` when *s* is neither form. Pure digit strings are always
    read as decimal, so "10" yields 10, not 16.
    """
    if s.isdigit():
        return int(s), True
    if is_hex_digit(s):
        return int(s, 16), True
    return -1, False
79
80
class SeccompPolicyParam:
    """Syscall lists collected for one architecture from seccomp policy files.

    While a file is parsed, the raw sets (``priority``, ``allow_list``,
    ``blocklist`` and the ``*_with_args`` variants) are filled through the
    handlers in ``value_function``. ``update_final_list`` then merges them
    into the ``final_*`` sets and clears the raw sets for the next file.
    """

    def __init__(self, arch):
        self.arch = arch
        self.priority = set()
        self.allow_list = set()
        self.blocklist = set()
        self.priority_with_args = set()
        self.allow_list_with_args = set()
        self.head_files = set()
        self.self_define_syscall = set()
        self.final_allow_list = set()
        self.final_priority = set()
        self.final_priority_with_args = set()
        self.final_allow_list_with_args = set()
        self.return_value = ''
        self.mode = 'DEFAULT'
        # NOTE(review): None if the arch has no table in function_name_nr_table_dict.
        self.function_name_nr_table = function_name_nr_table_dict.get(arch)
        # Policy-file section name -> handler validating and recording one line.
        self.value_function = {
            'priority': self.update_priority,
            'allowList': self.update_allow_list,
            'blockList': self.update_blocklist,
            'allowListWithArgs': self.update_allow_list_with_args,
            'priorityWithArgs': self.update_priority_with_args,
            'headFiles': self.update_head_files,
            'selfDefineSyscall': self.update_self_define_syscall,
            'returnValue': self.update_return_value,
            'mode': self.update_mode
        }

    @staticmethod
    def _syscall_name(entry):
        """Return the stripped syscall name before ':' in a "...WithArgs" entry."""
        return entry[:entry.find(':')].strip()

    def clear_list(self):
        """Reset the per-file raw sets (and, in ONLY_CHECK_ARGS mode, the plain final sets)."""
        self.priority.clear()
        self.allow_list.clear()
        self.blocklist.clear()
        self.allow_list_with_args.clear()
        self.priority_with_args.clear()
        if self.mode == 'ONLY_CHECK_ARGS':
            self.final_allow_list.clear()
            self.final_priority.clear()

    def function_name_to_nr(self, function_name_list):
        """Map syscall names to numbers, silently skipping unknown names."""
        table = self.function_name_nr_table
        return {table[name] for name in function_name_list if name in table}

    def is_function_name_exist(self, function_name):
        """Return True if *function_name* is a known syscall for this arch; print an error otherwise."""
        if function_name in self.function_name_nr_table:
            return True
        print('[ERROR] {} does not exist in {} function_name_nr_table Table'.format(function_name, self.arch))
        return False

    def update_priority(self, function_name):
        if self.is_function_name_exist(function_name):
            self.priority.add(function_name)
            return True
        return False

    def update_allow_list(self, function_name):
        if self.is_function_name_exist(function_name):
            self.allow_list.add(function_name)
            return True
        return False

    def update_blocklist(self, function_name):
        if self.is_function_name_exist(function_name):
            self.blocklist.add(function_name)
            return True
        return False

    def update_priority_with_args(self, function_name_with_args):
        """Record a "name: if ..." entry after validating the syscall name."""
        if self.is_function_name_exist(self._syscall_name(function_name_with_args)):
            self.priority_with_args.add(function_name_with_args)
            return True
        return False

    def update_allow_list_with_args(self, function_name_with_args):
        """Record a "name: if ..." entry after validating the syscall name."""
        if self.is_function_name_exist(self._syscall_name(function_name_with_args)):
            self.allow_list_with_args.add(function_name_with_args)
            return True
        return False

    def update_head_files(self, head_files):
        """Accept a non-empty "file.h" or <file.h> include target."""
        quoted = head_files.startswith('"') and head_files.endswith('"')
        angled = head_files.startswith('<') and head_files.endswith('>')
        # The length check must apply to both forms (previously `<>` passed).
        if len(head_files) > 2 and (quoted or angled):
            self.head_files.add(head_files)
            return True

        print('[ERROR] {} is not legal by headFiles format'.format(head_files))
        return False

    def update_self_define_syscall(self, self_define_syscall):
        """Record a custom syscall number that does not clash with a known syscall."""
        nr, digit_flag = str_convert_to_int(self_define_syscall)
        if digit_flag and nr not in self.function_name_nr_table.values():
            self.self_define_syscall.add(nr)
            return True

        print('[ERROR] {} is not a number or is already used by another syscall'.format(self_define_syscall))
        return False

    def update_return_value(self, return_str):
        if return_str in ret_str_to_bpf:
            self.return_value = return_str
            return True

        print('[ERROR] {} not in [{}]'.format(return_str, ', '.join(ret_str_to_bpf)))
        return False

    def update_mode(self, mode):
        if mode in mode_str:
            self.mode = mode
            return True
        print('[ERROR] {} not in [{}]'.format(mode, ', '.join(mode_str)))
        return False

    def update_final_list(self):
        """Merge this file's raw sets into the final sets, then clear them.

        Precedence: blockList beats everything; an argument-checked entry
        shadows the plain entry for the same syscall; priorityWithArgs shadows
        allowListWithArgs; priority shadows allowList.
        """
        name_of = self._syscall_name
        self.priority_with_args = {item for item in self.priority_with_args
                                   if name_of(item) not in self.blocklist}
        priority_names_with_args = {name_of(item) for item in self.priority_with_args}
        self.allow_list_with_args = {item for item in self.allow_list_with_args
                                     if name_of(item) not in self.blocklist and
                                     name_of(item) not in priority_names_with_args}
        allow_names_with_args = {name_of(item) for item in self.allow_list_with_args}
        names_with_args = priority_names_with_args | allow_names_with_args

        self.final_allow_list |= self.allow_list - self.blocklist - self.priority - names_with_args
        self.final_priority |= self.priority - self.blocklist - names_with_args
        self.final_allow_list_with_args |= self.allow_list_with_args
        self.final_priority_with_args |= self.priority_with_args
        block_nr_list = self.function_name_to_nr(self.blocklist)
        self.self_define_syscall = self.self_define_syscall - block_nr_list
        self.clear_list()
215        self.final_allow_list_with_args |= self.allow_list_with_args
216        self.final_priority_with_args |= self.priority_with_args
217        block_nr_list = self.function_name_to_nr(self.blocklist)
218        self.self_define_syscall = self.self_define_syscall - block_nr_list
219        self.clear_list()
220
221
class GenBpfPolicy:
    """Generate classic-BPF seccomp filter instructions as C initializer strings.

    Instructions accumulate in ``self.bpf_policy`` using the module-level
    ``BPF_*`` templates. ``self.flag`` is cleared whenever a policy line fails
    to compile. Jump offsets (jt/jf) are classic-BPF relative offsets,
    counted from the instruction after the jump.
    """

    def __init__(self):
        self.arch = ''
        self.syscall_nr_range = []
        self.bpf_policy = []
        self.syscall_nr_policy_list = []
        self.function_name_nr_table = {}
        self.gen_mode = 0
        self.flag = True
        self.return_value = ''
        # Condition operator token -> instruction generator.
        self.operate_func_table = {
            '<': self.gen_bpf_lt,
            '<=': self.gen_bpf_le,
            '==': self.gen_bpf_eq,
            '!=': self.gen_bpf_ne,
            '>': self.gen_bpf_gt,
            '>=': self.gen_bpf_ge,
            '&': self.gen_bpf_set,
        }

    def update_arch(self, arch):
        """Switch to *arch* and reset the per-arch range bookkeeping."""
        self.arch = arch
        self.function_name_nr_table = function_name_nr_table_dict.get(arch)
        self.syscall_nr_range = []
        self.syscall_nr_policy_list = []

    def clear_bpf_policy(self):
        self.bpf_policy.clear()

    def get_gen_flag(self):
        return self.flag

    def set_gen_flag(self, flag):
        self.flag = bool(flag)

    def set_gen_mode(self, mode):
        """Set gen_mode from a mode name (None when the name is unknown)."""
        self.gen_mode = mode_str.get(mode)

    def set_return_value(self, return_value):
        """Set the current return action; an unknown name marks generation as failed."""
        if return_value not in ret_str_to_bpf:
            # Previously this reset gen_mode instead of the error flag, so a
            # bad action silently reused the previous return value.
            print('[ERROR] {} not in [{}]'.format(return_value, ', '.join(ret_str_to_bpf)))
            self.set_gen_flag(False)
            return

        self.return_value = return_value

    def _gen_by_arch(self, gen32, gen64, const_str, jt, jf):
        """Dispatch to the 'arm' (32-bit) or 'arm64' (64-bit) generator; [] for any other arch."""
        if self.arch == 'arm':
            return gen32(const_str, jt, jf)
        if self.arch == 'arm64':
            return gen64(const_str, jt, jf)
        return []

    @staticmethod
    def gen_bpf_eq32(const_str, jt, jf):
        return [BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf)]

    @staticmethod
    def gen_bpf_eq64(const_str, jt, jf):
        """64-bit equality: high word first, then reload the low word from M[0].

        Expects the accumulator to hold the high word, as left by load_arg.
        ``jf + 2`` makes a high-word mismatch land on the final jump's false target.
        """
        return [
            BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2),
            BPF_LOAD_MEM.format(0),
            BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf),
        ]

    def gen_bpf_eq(self, const_str, jt, jf):
        return self._gen_by_arch(self.gen_bpf_eq32, self.gen_bpf_eq64, const_str, jt, jf)

    def gen_bpf_ne(self, const_str, jt, jf):
        return self.gen_bpf_eq(const_str, jf, jt)

    @staticmethod
    def _gen_high_word_cmp64(const_str, jt, jf):
        """High-word prefix shared by the 64-bit '>' and '>=' generators.

        jt/jf are relative to the final low-word jump, which follows one
        BPF_LOAD_MEM. When the constant is a literal whose high word is zero,
        the high-word equality test is redundant and is omitted. ``>> 32``
        (rather than float division) keeps negative literals such as "-1" on
        the general path.
        """
        high_const = '((unsigned long)' + const_str + ') >> 32'
        number, digit_flag = str_convert_to_int(const_str)
        if digit_flag and (number >> 32) == 0:
            return [BPF_JGT.format(high_const, jt + 2, 0)]
        return [
            BPF_JGT.format(high_const, jt + 3, 0),
            BPF_JEQ.format(high_const, 0, jf + 2),
        ]

    @staticmethod
    def gen_bpf_gt32(const_str, jt, jf):
        return [BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf)]

    @staticmethod
    def gen_bpf_gt64(const_str, jt, jf):
        bpf_policy = GenBpfPolicy._gen_high_word_cmp64(const_str, jt, jf)
        bpf_policy.append(BPF_LOAD_MEM.format(0))
        bpf_policy.append(BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf))
        return bpf_policy

    def gen_bpf_gt(self, const_str, jt, jf):
        return self._gen_by_arch(self.gen_bpf_gt32, self.gen_bpf_gt64, const_str, jt, jf)

    def gen_bpf_le(self, const_str, jt, jf):
        return self.gen_bpf_gt(const_str, jf, jt)

    @staticmethod
    def gen_bpf_ge32(const_str, jt, jf):
        return [BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf)]

    @staticmethod
    def gen_bpf_ge64(const_str, jt, jf):
        bpf_policy = GenBpfPolicy._gen_high_word_cmp64(const_str, jt, jf)
        bpf_policy.append(BPF_LOAD_MEM.format(0))
        bpf_policy.append(BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf))
        return bpf_policy

    def gen_bpf_ge(self, const_str, jt, jf):
        return self._gen_by_arch(self.gen_bpf_ge32, self.gen_bpf_ge64, const_str, jt, jf)

    def gen_bpf_lt(self, const_str, jt, jf):
        return self.gen_bpf_ge(const_str, jf, jt)

    @staticmethod
    def gen_bpf_set32(const_str, jt, jf):
        return [BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf)]

    @staticmethod
    def gen_bpf_set64(const_str, jt, jf):
        return [
            BPF_JSET.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0),
            BPF_LOAD_MEM.format(0),
            BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf),
        ]

    def gen_bpf_set(self, const_str, jt, jf):
        return self._gen_by_arch(self.gen_bpf_set32, self.gen_bpf_set64, const_str, jt, jf)

    @staticmethod
    def gen_bpf_valid_syscall_nr(syscall_nr, cur_size):
        """Load the syscall nr and skip the next *cur_size* instructions unless it equals *syscall_nr*."""
        return [
            BPF_LOAD.format(0),
            BPF_JEQ.format(syscall_nr, 0, cur_size),
        ]

    def gen_range_list(self, syscall_nr_list):
        """Collapse syscall numbers into sorted inclusive ``[start, end]`` runs."""
        # Always reset, so an empty input does not leave the previous
        # call's ranges behind (they would be emitted a second time).
        self.syscall_nr_range.clear()
        for nr in sorted(syscall_nr_list):
            if self.syscall_nr_range and self.syscall_nr_range[-1][1] + 1 == nr:
                self.syscall_nr_range[-1][1] = nr
            else:
                self.syscall_nr_range.append([nr, nr])

    def gen_policy_syscall_nr(self, min_index, max_index, cur_syscall_nr_range):
        """Append the pre-order pivot values of a binary split over ranges[min_index..max_index]."""
        middle_index = (min_index + max_index + 1) // 2

        if middle_index == min_index:
            self.syscall_nr_policy_list.append(cur_syscall_nr_range[middle_index][1] + 1)
            return
        self.syscall_nr_policy_list.append(cur_syscall_nr_range[middle_index][0])

        self.gen_policy_syscall_nr(min_index, middle_index - 1, cur_syscall_nr_range)
        self.gen_policy_syscall_nr(middle_index, max_index, cur_syscall_nr_range)

    def gen_policy_syscall_nr_list(self, cur_syscall_nr_range):
        self.syscall_nr_policy_list.clear()
        if not cur_syscall_nr_range:
            return
        self.syscall_nr_policy_list.append(cur_syscall_nr_range[0][0])
        self.gen_policy_syscall_nr(0, len(cur_syscall_nr_range) - 1, cur_syscall_nr_range)

    def calculate_step(self, index):
        """Return (distance to the next larger pivot after *index*) - 1."""
        pivot = self.syscall_nr_policy_list[index]
        for offset, value in enumerate(self.syscall_nr_policy_list[index + 1:], start=1):
            if value > pivot:
                return offset - 1
        raise ValueError('no larger syscall nr after policy index {}'.format(index))

    def nr_range_to_bpf_policy(self, cur_syscall_nr_range):
        """Emit a JGE decision tree returning ALLOW for nrs inside *cur_syscall_nr_range*."""
        self.gen_policy_syscall_nr_list(cur_syscall_nr_range)
        policy_list = self.syscall_nr_policy_list
        syscall_list_len = len(policy_list)

        if syscall_list_len == 0:
            return

        self.bpf_policy.append(BPF_JGE.format(policy_list[0], 0, syscall_list_len))

        range_ends = {end for _, end in cur_syscall_nr_range}

        for i in range(1, syscall_list_len):
            if policy_list[i] - 1 in range_ends:
                # Pivot is one past a range end.
                self.bpf_policy.append(BPF_JGE.format(policy_list[i],
                                                      syscall_list_len - i, syscall_list_len - i - 1))
            else:
                step = self.calculate_step(i)
                self.bpf_policy.append(BPF_JGE.format(policy_list[i], step, 0))

        self.bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_ALLOW'))

    def count_alone_range(self):
        """Count ranges consisting of a single syscall number."""
        return sum(1 for start, end in self.syscall_nr_range if start == end)

    def gen_transverse_bpf_policy(self):
        """Emit a linear scan over the ranges followed by a SECCOMP_RET_ALLOW return."""
        if not self.syscall_nr_range:
            return
        cnt = self.count_alone_range()
        total_instruction_num = cnt + (len(self.syscall_nr_range) - cnt) * 2
        i = 0
        for item in self.syscall_nr_range:
            if item[0] == item[1]:
                if i == total_instruction_num - 1:
                    self.bpf_policy.append(BPF_JEQ.format(item[0], total_instruction_num - i - 1, 1))
                else:
                    self.bpf_policy.append(BPF_JEQ.format(item[0], total_instruction_num - i - 1, 0))
                i += 1
            else:
                self.bpf_policy.append(BPF_JGE.format(item[0], 0, total_instruction_num - i))
                i += 1
                if i == total_instruction_num - 1:
                    self.bpf_policy.append(BPF_JGE.format(item[1] + 1, 1, total_instruction_num - i - 1))
                else:
                    self.bpf_policy.append(BPF_JGE.format(item[1] + 1, 0, total_instruction_num - i - 1))
                i += 1

        self.bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_ALLOW'))

    def gen_bpf_policy(self, syscall_nr_list):
        """Append instructions that return ALLOW for every nr in *syscall_nr_list*."""
        self.gen_range_list(syscall_nr_list)
        ranges = self.syscall_nr_range
        if not ranges:
            return
        if self.count_alone_range() == len(ranges):
            # Scattered distribution: only isolated numbers, scan linearly.
            self.gen_transverse_bpf_policy()
            return

        # At most 127 ranges per tree, highest ranges first; presumably this
        # keeps jump offsets within BPF's 8-bit jt/jf fields (TODO confirm).
        chunk = 127
        range_size = (len(ranges) - 1) // chunk + 1
        for i in range(range_size):
            if i == 0:
                self.nr_range_to_bpf_policy(ranges[-chunk:])
            elif i == range_size - 1:
                self.nr_range_to_bpf_policy(ranges[:-chunk * i])
            else:
                self.nr_range_to_bpf_policy(ranges[-chunk * (i + 1): -chunk * i])

    def load_arg(self, arg_id):
        """Load syscall argument *arg_id*, read as little endian.

        On arm64 the low word is stored in M[0] and the high word in M[1],
        leaving the high word in the accumulator.
        """
        bpf_policy = []
        if self.arch == 'arm':
            bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
        elif self.arch == 'arm64':
            # low 4 bytes
            bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
            bpf_policy.append(BPF_ST.format(0))
            # high 4 bytes
            bpf_policy.append(BPF_LOAD.format(20 + arg_id * 8))
            bpf_policy.append(BPF_ST.format(1))

        return bpf_policy

    def compile_atom(self, atom, cur_size):
        """Compile one '&&' operand; *cur_size* is the instruction count that follows it in its group."""
        bpf_policy = []
        if len(atom) < 6:
            print('[ERROR] {} format ERROR '.format(atom))
            self.flag = False
            return bpf_policy

        if atom[0] == '(':
            bpf_policy += self.compile_mask_equal_atom(atom, cur_size)
        else:
            bpf_policy += self.compile_single_operation_atom(atom, cur_size)

        return bpf_policy

    def check_arg_str(self, arg_atom):
        """Parse a leading 'argN' (N in 0..5); return (N, True) or (-1, False)."""
        if arg_atom[0:3] != 'arg':
            print('[ERROR] format ERROR, {} is not equal to arg'.format(arg_atom))
            self.flag = False
            return -1, False

        arg_id_str = arg_atom[3:4]
        if not arg_id_str.isdecimal():
            # Previously int() raised ValueError here and aborted the script.
            print('[ERROR] format ERROR, {} has no numeric arg index'.format(arg_atom))
            self.flag = False
            return -1, False

        arg_id = int(arg_id_str)
        if arg_id not in range(6):
            print('[ERROR] arg num out of the scope 0~5')
            self.flag = False
            return -1, False

        return arg_id, True

    def check_operation_str(self, operation_atom):
        """Return the two- or one-character operator at the start of *operation_atom*."""
        operation_str = operation_atom
        if operation_str not in operation:
            operation_str = operation_atom[0]
            if operation_str not in operation:
                print('[ERROR] operation not in [<, <=, !=, ==, >, >=, &]')
                self.flag = False
                return '', False
        return operation_str, True

    @staticmethod
    def gen_mask_equal_bpf(arg_id, mask, value, cur_size):
        """Emit ``(argN & mask) == value`` for argument *arg_id*.

        Uses the same convention as compile_single_operation_atom: on success
        fall through to the next atom; on failure skip the *cur_size* following
        instructions plus the group's return.
        """
        bpf_policy = []
        # high 4 bytes: a mismatch must reach the same false target as the
        # final low-word jump, three instructions later (hence +4, not +3).
        bpf_policy.append(BPF_LOAD.format(20 + arg_id * 8))
        bpf_policy.append(BPF_AND.format('((uint64_t)' + mask + ') >> 32'))
        bpf_policy.append(BPF_JEQ.format('((uint64_t)' + value + ') >> 32', 0, cur_size + 4))

        # low 4 bytes: on a match fall through to the next '&&' atom.
        bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
        bpf_policy.append(BPF_AND.format(mask))
        bpf_policy.append(BPF_JEQ.format(value, 0, cur_size + 1))

        return bpf_policy

    def compile_mask_equal_atom(self, atom, cur_size):
        """Compile an atom of the form ``(argN&mask)==value``."""
        bpf_policy = []
        left_brace_pos = atom.find('(')
        right_brace_pos = atom.rfind(')')
        inside_brace_content = atom[left_brace_pos + 1: right_brace_pos]
        outside_brace_content = atom[right_brace_pos + 1:]

        arg_res = self.check_arg_str(inside_brace_content[0:4])
        if not arg_res[1]:
            return bpf_policy

        operation_res_inside = self.check_operation_str(inside_brace_content[4:6])
        if operation_res_inside[0] != '&' or not operation_res_inside[1]:
            return bpf_policy

        mask = inside_brace_content[4 + len(operation_res_inside[0]):]

        operation_res_outside = self.check_operation_str(outside_brace_content[0:2])
        if operation_res_outside[0] != '==' or not operation_res_outside[1]:
            return bpf_policy

        value = outside_brace_content[len(operation_res_outside[0]):]

        return self.gen_mask_equal_bpf(arg_res[0], mask, value, cur_size)

    def compile_single_operation_atom(self, atom, cur_size):
        """Compile an atom of the form ``argN<op>const``."""
        bpf_policy = []
        arg_res = self.check_arg_str(atom[0:4])
        if not arg_res[1]:
            return bpf_policy

        operation_res = self.check_operation_str(atom[4:6])
        if not operation_res[1]:
            return bpf_policy

        const_str = atom[4 + len(operation_res[0]):]

        if not const_str:
            return bpf_policy

        bpf_policy += self.load_arg(arg_res[0])
        # Success falls through (jt=0); failure skips the rest of the group and its return.
        bpf_policy += self.operate_func_table.get(operation_res[0])(const_str, 0, cur_size + 1)

        return bpf_policy

    def parse_args_with_condition(self, group):
        """Compile an '&&' chain; '&&' binds tighter than '||'."""
        atoms = group.split('&&')
        bpf_policy = []
        # Compile back to front so each atom knows how many instructions follow it.
        for atom in reversed(atoms):
            bpf_policy = self.compile_atom(atom, len(bpf_policy)) + bpf_policy
        return bpf_policy

    def parse_sub_group(self, group):
        """Compile one ``cond;returnX`` branch; each '||' alternative gets its own return."""
        bpf_policy = []
        group_info = group.split(';')
        if len(group_info) < 2:
            print('[ERROR] {} has no ";return" part'.format(group))
            self.set_gen_flag(False)
            return bpf_policy
        operation_part = group_info[0]
        return_part = group_info[1]
        if not return_part.startswith('return'):
            self.set_gen_flag(False)
            return bpf_policy
        self.set_return_value(return_part[len('return'):])
        for and_condition_group in operation_part.split('||'):
            bpf_policy += self.parse_args_with_condition(and_condition_group)
            bpf_policy.append(BPF_RET_VALUE.format(ret_str_to_bpf.get(self.return_value)))
        return bpf_policy

    def parse_else_part(self, else_part):
        return_value = else_part.split(';')[0][else_part.find('return') + len('return'):]
        self.set_return_value(return_value)

    def parse_args(self, function_name, line, skip):
        """Compile ``cond;returnX elif ... else returnY`` for one syscall."""
        bpf_policy = []
        group_info = line.split('else')
        else_part = group_info[-1]
        for sub_group in group_info[0].split('elif'):
            bpf_policy += self.parse_sub_group(sub_group)
        self.parse_else_part(else_part)
        bpf_policy.append(BPF_RET_VALUE.format(ret_str_to_bpf.get(self.return_value)))
        syscall_nr = self.function_name_nr_table.get(function_name)
        # Prefix: load the syscall nr and jump past this body when it differs.
        return self.gen_bpf_valid_syscall_nr(syscall_nr, len(bpf_policy) - skip) + bpf_policy

    def gen_bpf_policy_with_args(self, allow_list_with_args, mode, return_value):
        """Append argument-checking programs for each "name: if ..." entry.

        *return_value* is unused; it is kept for interface compatibility.
        """
        self.set_gen_mode(mode)
        lines = list(allow_list_with_args)
        last_line = lines[-1] if lines else None
        skip = 0
        for line in lines:
            # In ONLY_CHECK_ARGS mode the last entry's syscall-mismatch jump is
            # shortened by two instructions (see parse_args). NOTE(review): the
            # intent is not visible from this file.
            if self.gen_mode == 1 and line == last_line:
                skip = 2
            line = line.replace(' ', '')
            pos = line.find(':')
            function_name = line[:pos]

            left_line = line[pos + 1:]
            if not left_line.startswith('if'):
                continue

            self.bpf_policy += self.parse_args(function_name, left_line[2:], skip)

    def add_load_syscall_nr(self):
        self.bpf_policy.append(BPF_LOAD.format(0))

    def add_return_value(self, return_value):
        self.bpf_policy.append(BPF_RET_VALUE.format(ret_str_to_bpf.get(return_value)))

    def add_validate_arch(self, arches, skip_step):
        """Prepend the audit-arch check; unknown arches are killed.

        With both arches, *skip_step* is the JA distance from the arm64 section
        to the arm section.
        """
        if not self.bpf_policy or not self.flag:
            return
        bpf_policy = []
        # load arch
        bpf_policy.append(BPF_LOAD.format(4))
        if len(arches) == 2:
            bpf_policy.append(BPF_JEQ.format('AUDIT_ARCH_AARCH64', 3, 0))
            bpf_policy.append(BPF_JEQ.format('AUDIT_ARCH_ARM', 0, 1))
            bpf_policy.append(BPF_JA.format(skip_step))
            bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_KILL_PROCESS'))
        elif 'arm' in arches:
            bpf_policy.append(BPF_JEQ.format('AUDIT_ARCH_ARM', 1, 0))
            bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_KILL_PROCESS'))
        elif 'arm64' in arches:
            bpf_policy.append(BPF_JEQ.format('AUDIT_ARCH_AARCH64', 1, 0))
            bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_KILL_PROCESS'))
        else:
            # No supported arch: drop everything instead of leaving a lone
            # load instruction with no return.
            self.bpf_policy = []
            return

        self.bpf_policy = bpf_policy + self.bpf_policy
705
706
707class SeccompPolicyParser:
708    def __init__(self):
709        self.cur_parse_item = ''
710        self.cur_arch = ''
711        self.arches = set()
712        self.bpf_generator = GenBpfPolicy()
713        self.seccomp_policy_param_arm = None
714        self.seccomp_policy_param_arm64 = None
715        self.cur_policy_param = None
716
717    def update_arch(self, arch):
718        if arch in ['arm', 'arm64'] :
719            self.cur_arch = arch
720            print("[INFO] start deal with {} scope".format(self.cur_arch))
721            self.arches.add(arch)
722            if self.cur_arch == 'arm':
723                self.cur_policy_param = self.seccomp_policy_param_arm
724            elif self.cur_arch == 'arm64':
725                self.cur_policy_param = self.seccomp_policy_param_arm64
726        else:
727            print('[ERROR] {} not in [arm arm64]'.format(arch))
728            self.bpf_generator.set_gen_flag(False)
729
730    def update_parse_item(self, line):
731        item = line[1:]
732        if item in supported_parse_item:
733            self.cur_parse_item = item
734            print('[INFO] start deal with {}'.format(self.cur_parse_item))
735
736    def clear_file_syscall_list(self):
737        self.seccomp_policy_param_arm.update_final_list()
738        self.seccomp_policy_param_arm64.update_final_list()
739        self.cur_parse_item = ''
740        self.cur_arch = ''
741
742    def parse_line(self, line):
743        if not self.cur_parse_item :
744            return
745        if self.cur_parse_item == 'arch':
746            self.update_arch(line)
747        elif not self.cur_arch :
748            return
749        else:
750            if not self.cur_policy_param.value_function.get(self.cur_parse_item)(line):
751                self.bpf_generator.set_gen_flag(False)
752
753    def parse_open_file(self, fp):
754        for line in fp:
755            line = line.strip()
756            if not line:
757                continue
758            if line[0] == '#':
759                continue
760            if line[0] == '@':
761                self.update_parse_item(line)
762                continue
763            if line[0] != '@' and self.cur_arch == '' and self.cur_parse_item == '':
764                continue
765            self.parse_line(line)
766        self.clear_file_syscall_list()
767
768    def parse_file(self, file_path):
769        with open(file_path) as fp:
770            self.parse_open_file(fp)
771
772    def gen_seccomp_policy_of_arch(self):
773        if not self.cur_policy_param.return_value:
774            print('[ERROR] return value not defined')
775            self.bpf_generator.set_gen_flag(False)
776            return
777
778        #get final allow_list
779        syscall_nr_allow_list = self.cur_policy_param.function_name_to_nr(self.cur_policy_param.final_allow_list) | \
780            self.cur_policy_param.self_define_syscall
781        syscall_nr_priority = self.cur_policy_param.function_name_to_nr(self.cur_policy_param.final_priority)
782        self.bpf_generator.update_arch(self.cur_arch)
783
784        #load syscall nr
785        if syscall_nr_allow_list or syscall_nr_priority:
786            self.bpf_generator.add_load_syscall_nr()
787        self.bpf_generator.gen_bpf_policy(syscall_nr_priority)
788        self.bpf_generator.gen_bpf_policy_with_args(self.cur_policy_param.final_priority_with_args, \
789            self.cur_policy_param.mode, self.cur_policy_param.return_value)
790        self.bpf_generator.gen_bpf_policy(syscall_nr_allow_list)
791        self.bpf_generator.gen_bpf_policy_with_args(self.cur_policy_param.final_allow_list_with_args, \
792            self.cur_policy_param.mode, self.cur_policy_param.return_value)
793
794        self.bpf_generator.add_return_value(self.cur_policy_param.return_value)
795
796    def gen_seccomp_policy(self):
797        if 'arm64' in self.arches:
798            self.update_arch('arm64')
799            self.gen_seccomp_policy_of_arch()
800
801        skip_step = len(self.bpf_generator.bpf_policy) + 1
802
803        if 'arm' in self.arches:
804            self.update_arch('arm')
805            self.gen_seccomp_policy_of_arch()
806        self.bpf_generator.add_validate_arch(self.arches, skip_step)
807
808    def gen_output_file(self, args):
809        if not self.bpf_generator.bpf_policy:
810            print("[ERROR] bpf_policy is empty!")
811            return
812
813        header = textwrap.dedent('''\
814
815            #include <linux/filter.h>
816            #include <stddef.h>
817            #include <linux/seccomp.h>
818            #include <linux/audit.h>
819            ''')
820        extra_header = set()
821        extra_header = self.seccomp_policy_param_arm.head_files | self.seccomp_policy_param_arm64.head_files
822        extra_header_list =  ['#include ' + i for i in extra_header]
823
824        array_name = textwrap.dedent('''
825
826            const struct sock_filter {}[] = {{
827            ''').format(args.bpfArrayName)
828
829        footer = textwrap.dedent('''\
830
831            }};
832
833            const size_t {} = sizeof({}) / sizeof(struct sock_filter);
834            ''').format(args.bpfArrayName + 'Size', args.bpfArrayName)
835
836        content = header + '\n'.join(extra_header_list) + array_name + \
837            '    ' + '\n    '.join(self.bpf_generator.bpf_policy) + footer
838
839
840        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
841        modes = stat.S_IWUSR | stat.S_IRUSR | stat.S_IWGRP | stat.S_IRGRP
842        with os.fdopen(os.open(args.dstfile, flags, modes), 'w') as output_file:
843            output_file.write(content)
844
845    @staticmethod
846    def filter_syscalls_nr(name_to_nr):
847        syscalls = {}
848        for syscall_name, nr in name_to_nr.items():
849            if not syscall_name.startswith("__NR_") and not syscall_name.startswith("__ARM_NR_"):
850                continue
851
852            if syscall_name.startswith("__NR_arm_"):
853                syscall_name = syscall_name[len("__NR_arm_"):]
854            elif syscall_name.startswith("__NR_"):
855                syscall_name = syscall_name[len("__NR_"):]
856            elif syscall_name.startswith("__ARM_NR_"):
857                syscall_name = syscall_name[len("__ARM_NR_"):]
858
859            syscalls[syscall_name] = nr
860
861        return syscalls
862
863    def parse_syscall_file(self, file_name):
864        const_pattern = re.compile(
865            r'^\s*#define\s+([A-Za-z_][A-Za-z0-9_]+)\s+(.+)\s*$')
866        mark_pattern = re.compile(r'\b[A-Za-z_][A-Za-z0-9_]+\b')
867        name_to_nr = {}
868        with open(file_name) as f:
869            for line in f:
870                k = const_pattern.match(line)
871                if k is None:
872                    continue
873                try:
874                    name = k.group(1)
875                    nr = eval(mark_pattern.sub(lambda x: str(name_to_nr.get(x.group(0))),
876                                            k.group(2)))
877
878                    name_to_nr[name] = nr
879                except(KeyError, SyntaxError, NameError, TypeError):
880                    continue
881
882        return self.filter_syscalls_nr(name_to_nr)
883
884    def gen_syscall_nr_table(self, file_name):
885        s = re.search(r"libsyscall_to_nr_([^/]+)", file_name)
886        function_name_nr_table_dict[str(s.group(1))] = self.parse_syscall_file(file_name)
887        if str(s.group(1)) not in function_name_nr_table_dict.keys():
888            return False
889        return True
890
891    def gen_seccomp_policy_code(self, args):
892        for file_name in args.srcfiles:
893            file_name_tmp = file_name.split('/')[-1]
894            if not file_name_tmp.lower().startswith('libsyscall_to_nr_'):
895                continue
896            if not self.gen_syscall_nr_table(file_name):
897                return
898        self.seccomp_policy_param_arm = SeccompPolicyParam('arm')
899        self.seccomp_policy_param_arm64 = SeccompPolicyParam('arm64')
900        for file_name in args.srcfiles:
901            if file_name.lower().endswith('.policy'):
902                self.parse_file(file_name)
903        self.gen_seccomp_policy()
904        if self.bpf_generator.get_gen_flag():
905            self.gen_output_file(args)
906
907
def main():
    """Command-line entry point: parse arguments and generate the policy source.

    ``--srcfiles`` (repeatable), ``--dstfile`` and ``--bpfArrayName`` are all
    mandatory. Without them the generator would fail later with an unrelated
    TypeError, so argparse now reports a usage error and exits with status 2.
    """
    parser = argparse.ArgumentParser(
        description='Generates a seccomp-bpf policy')
    parser.add_argument('--srcfiles', type=str, action='append', required=True,
                        help=('The input files\n'))
    parser.add_argument('--dstfile', required=True,
                        help='The output path for the policy files')

    parser.add_argument('--bpfArrayName', type=str, required=True,
                        help='Name of seccomp bpf array generated by this script')

    args = parser.parse_args()

    generator = SeccompPolicyParser()
    generator.gen_seccomp_policy_code(args)
923
924
if __name__ == '__main__':
    # main()'s return value becomes the exit status (None means status 0).
    raise SystemExit(main())
927