• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2
3import copy
4import re
5import sys
6import SELinuxNeverallowTestFrame
7
# Command-line help text, shown when the script is invoked with the wrong
# number of arguments.
usage = ("Usage: ./SELinuxNeverallowTestGen.py "
         "<input policy file> <output cts java source>")
9
10
11class NeverallowRule:
12    statement = ''
13    depths = None
14
15    def __init__(self, statement, depths):
16        self.statement = statement
17        self.depths = copy.deepcopy(depths)
18
# Names of conditional policy regions. In the policy file each region is
# delimited by "# BEGIN_<name>" and "# END_<name>" comment lines (regions
# may nest). For every extracted rule the generator records how deeply it is
# nested in each region, and it emits a true/false flag per region into the
# generated Java test. TREBLE_ONLY, for example, marks full-Treble-only rules.
sections = ["TREBLE_ONLY", "COMPATIBLE_PROPERTY_ONLY",
            "LAUNCHING_WITH_R_ONLY", "LAUNCHING_WITH_S_ONLY"]
27
def extract_neverallow_rules(policy_file):
    """Pull the neverallow rules out of an intermediate policy file.

    Each rule is the non-commented text from the 'neverallow' keyword up to
    the terminating ';' (it may span several lines). Comment lines of the
    form "# BEGIN_<section>" / "# END_<section>", for each name in
    ``sections``, are kept as delimiters. Every rule records how many such
    sections of each kind enclose it.

    Args:
        policy_file: path of the policy file to read.

    Returns:
        A list of NeverallowRule objects, in file order.

    Exits the process via sys.exit() with an "ERROR: ..." message if an
    END_<section> has no matching BEGIN_<section>, or if the input ends
    while a section is still open.
    """
    with open(policy_file, 'r') as in_file:
        policy_str = in_file.read()

    # Section delimiters live in comments; strip their leading '#' so they
    # survive the comment removal below.
    section_headers = '|'.join(['BEGIN_%s|END_%s' % (s, s) for s in sections])
    remaining = re.sub(
        r'^\s*#\s*(' + section_headers + ')',
        r'\1',
        policy_str,
        flags=re.M)
    # Remove comments. Use '.*' rather than '.+' so a bare '#' at end of line
    # is removed too. Without re.S, '.' never crosses the newline.
    remaining = re.sub(r'#.*$', r'', remaining, flags=re.M)
    # Match neverallow rules (re.S lets one rule span lines) and the
    # uncommented section headers, in document order.
    lines = re.findall(
        r'^\s*(neverallow\s.+?;|' + section_headers + ')',
        remaining,
        flags=re.M | re.S)

    rules = []
    depths = {section: 0 for section in sections}
    for line in lines:
        is_header = False
        for section in sections:
            if line.startswith("BEGIN_%s" % section):
                depths[section] += 1
                is_header = True
                break
            elif line.startswith("END_%s" % section):
                if depths[section] < 1:
                    # sys.exit, not the interactive-only builtin exit().
                    sys.exit("ERROR: END_%s outside of %s section" % (section, section))
                depths[section] -= 1
                is_header = True
                break
        if not is_header:
            # NeverallowRule deep-copies depths, so later updates don't leak in.
            rules.append(NeverallowRule(line, depths))

    for section in sections:
        if depths[section] != 0:
            sys.exit("ERROR: end of input while inside %s section" % section)

    return rules
78
def neverallow_rule_to_test(rule, test_num):
    """Render one neverallow rule as a CTS Java test method.

    Starts from the SELinuxNeverallowTestFrame.src_method template and
    performs these substitutions, in this order:
      * "testNeverallowRules()" becomes "testNeverallowRules<test_num>()".
      * "$NEVERALLOW_RULE_HERE$" becomes the rule text, with newlines turned
        into spaces.
      * Each "$<SECTION>_BOOL_HERE$" becomes "true" when the rule's depth for
        that section is non-zero, and "false" otherwise.

    Returns:
        The generated Java method source as a string.
    """
    template = SELinuxNeverallowTestFrame.src_method
    substitutions = [
        ("testNeverallowRules()", "testNeverallowRules%s()" % (test_num,)),
        ("$NEVERALLOW_RULE_HERE$", rule.statement.replace("\n", " ")),
    ]
    for section in sections:
        if rule.depths[section]:
            flag = "true"
        else:
            flag = "false"
        substitutions.append(("$%s_BOOL_HERE$" % section, flag))

    for placeholder, text in substitutions:
        template = template.replace(placeholder, text)
    return template
93
if __name__ == "__main__":
    # Script entry point: read a policy file and write a Java source file
    # containing one generated test method per neverallow rule.
    if len(sys.argv) != 3:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        # Unlike the interactive-only builtin exit(), it is always available.
        sys.exit(usage)
    input_file = sys.argv[1]
    output_file = sys.argv[2]

    src_header = SELinuxNeverallowTestFrame.src_header
    src_body = SELinuxNeverallowTestFrame.src_body
    src_footer = SELinuxNeverallowTestFrame.src_footer

    # Grab the neverallow rules from the policy file and number the generated
    # test methods 0, 1, 2, ... in file order.
    neverallow_rules = extract_neverallow_rules(input_file)
    src_body += "".join(
        neverallow_rule_to_test(rule, i)
        for i, rule in enumerate(neverallow_rules))

    with open(output_file, 'w') as out_file:
        out_file.write(src_header)
        out_file.write(src_body)
        out_file.write(src_footer)