• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# (C) 2012-2013 by Pablo Neira Ayuso <pablo@netfilter.org>
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This software has been sponsored by Sophos Astaro <http://www.sophos.com>
11#
12
13import sys
14import os
15import subprocess
16import argparse
17
# Commands used to append (-A) and delete (-D) the rules under test.
# Swap in the commented xtables variants to run the suite with xtables.
IPTABLES = "iptables"
IP6TABLES = "ip6tables"
#IPTABLES = "xtables -4"
#IP6TABLES = "xtables -6"

# Commands whose output is searched for the expected rule text.
IPTABLES_SAVE = "iptables-save"
IP6TABLES_SAVE = "ip6tables-save"
#IPTABLES_SAVE = ['xtables-save','-4']
#IP6TABLES_SAVE = ['xtables-save','-6']

# Directory scanned for lib*.c sources and their .t test files.
EXTENSIONS_PATH = "extensions"
# Destination for executed commands and their combined stdout/stderr.
LOGFILE = "/tmp/iptables-test.log"
# File object for LOGFILE; stays None until main() opens it.
log_file = None
31
32
class Colors:
    '''
    ANSI escape sequences used to colorize the test report.
    '''
    # bright foreground colors
    HEADER = "\033[95m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    # resets all text attributes back to the terminal default
    ENDC = "\033[0m"
40
41
def print_error(reason, filename=None, lineno=None):
    '''
    Prints an error with nice colors, indicating file and line number.

    :param reason: description of the failure
    :param filename: test file the error refers to; omitted when None
    :param lineno: line number inside filename; omitted when None
    '''
    # The parameters default to None, so build each optional part separately
    # instead of concatenating/formatting None (which raises TypeError).
    prefix = "" if filename is None else filename + ": "
    where = "" if lineno is None else "line %d " % lineno
    print(prefix + Colors.RED + "ERROR" + Colors.ENDC +
          ": %s(%s)" % (where, reason))
48
49
def delete_rule(iptables, rule, filename, lineno):
    '''
    Removes an iptables rule

    :param iptables: string with the iptables command to execute
    :param rule: string with iptables arguments identifying the rule
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :returns: 0 on success, -1 (after reporting an error) if the delete
              command exits with a non-zero status
    '''
    cmd = iptables + " -D " + rule
    ret = execute_cmd(cmd, filename, lineno)
    # Any non-zero status means the rule was not removed: iptables uses 1 and
    # 2 (parameter problem) as error codes, and a signal yields a negative
    # value. Checking only for 1 let the other failures pass silently.
    if ret != 0:
        # report the command that actually failed (-D, not -I)
        print_error("cannot delete: " + cmd, filename, lineno)
        return -1

    return 0
62
63
def _save_command(iptables):
    '''
    Returns the *-save command matching the iptables command string used for
    a test, or None if the string is not one of the configured commands.

    Accepts both the single-word form (IPTABLES / IP6TABLES) and the
    two-word form whose second word is "-4" or "-6".
    '''
    parts = iptables.split(" ")
    if len(parts) == 2:
        if parts[1] == '-4':
            return IPTABLES_SAVE
        if parts[1] == '-6':
            return IP6TABLES_SAVE
    elif len(parts) == 1:
        if parts[0] == IPTABLES:
            return IPTABLES_SAVE
        if parts[0] == IP6TABLES:
            return IP6TABLES_SAVE
    return None


def run_test(iptables, rule, rule_save, res, filename, lineno):
    '''
    Executes an unit test.

    Parameters:
    :param  iptables: string with the iptables command to execute
    :param rule: string with iptables arguments for the rule to test
    :param rule_save: string to find the rule in the output of iptables -save
    :param res: expected result of the rule. Valid values: "OK", "FAIL"
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :returns: 0 if the test passed, -1 if it failed. When the rule got
              loaded, the result of the final delete_rule() is returned.
    '''
    cmd = iptables + " -A " + rule
    ret = execute_cmd(cmd, filename, lineno)

    #
    # report failed test
    #
    if ret:
        if res == "OK":
            print_error("cannot load: " + cmd, filename, lineno)
            return -1
        # the rule was expected not to load: do not report this error
        return 0
    if res == "FAIL":
        print_error("should fail: " + cmd, filename, lineno)
        delete_rule(iptables, rule, filename, lineno)
        return -1

    # Previously an unrecognised iptables string left the save command
    # unbound and the script died with UnboundLocalError.
    command = _save_command(iptables)
    if command is None:
        print_error("cannot find save command for: " + iptables,
                    filename, lineno)
        delete_rule(iptables, rule, filename, lineno)
        return -1

    # universal_newlines=True makes `out` text (not bytes on Python 3) so it
    # can be searched for rule_save below.
    proc = subprocess.Popen(command, stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                            universal_newlines=True)
    out, _ = proc.communicate()

    #
    # check for segfaults
    #
    if proc.returncode == -11:
        print_error("iptables-save segfaults: " + cmd, filename, lineno)
        delete_rule(iptables, rule, filename, lineno)
        return -1

    # find the rule (plain substring search over the whole dump)
    if rule_save not in out:
        # report the command that was actually run (-A)
        print_error("cannot find: " + cmd, filename, lineno)
        delete_rule(iptables, rule, filename, lineno)
        return -1

    return delete_rule(iptables, rule, filename, lineno)
134
135
def execute_cmd(cmd, filename, lineno):
    '''
    Executes a command, checking for segfaults and returning the command exit
    code.

    The command line is written to the global log_file, and the command's
    stdout and stderr are redirected into the same file.

    :param cmd: string with the command to be executed (run via the shell)
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :returns: the exit status reported by subprocess.call()
    '''
    log_file.write("command: %s\n" % cmd)
    # The child writes straight to the underlying file descriptor, bypassing
    # Python's buffer; flush first so the "command:" line precedes its output.
    log_file.flush()
    ret = subprocess.call(cmd, shell=True,
                          stderr=subprocess.STDOUT, stdout=log_file)
    log_file.flush()

    # generic check for segfaults: subprocess reports a signal as -11, but
    # because cmd runs through /bin/sh the shell may instead exit with
    # 128 + 11 when it does not exec the command itself.
    if ret in (-11, 128 + 11):
        reason = "command segfaults: " + cmd
        print_error(reason, filename, lineno)
    return ret
156
157
158def run_test_file(filename):
159    '''
160    Runs a test file
161
162    :param filename: name of the file with the test rules
163    '''
164    #
165    # if this is not a test file, skip.
166    #
167    if not filename.endswith(".t"):
168        return 0, 0
169
170    if "libipt_" in filename:
171        iptables = IPTABLES
172    elif "libip6t_" in filename:
173        iptables = IP6TABLES
174    elif "libxt_"  in filename:
175        iptables = IPTABLES
176    else:
177        # default to iptables if not known prefix
178        iptables = IPTABLES
179
180    f = open(filename)
181
182    tests = 0
183    passed = 0
184    table = ""
185    total_test_passed = True
186
187    for lineno, line in enumerate(f):
188        if line[0] == "#":
189            continue
190
191        if line[0] == ":":
192            chain_array = line.rstrip()[1:].split(",")
193            continue
194
195        # external non-iptables invocation, executed as is.
196        if line[0] == "@":
197            external_cmd = line.rstrip()[1:]
198            execute_cmd(external_cmd, filename, lineno)
199            continue
200
201        if line[0] == "*":
202            table = line.rstrip()[1:]
203            continue
204
205        if len(chain_array) == 0:
206            print "broken test, missing chain, leaving"
207            sys.exit()
208
209        test_passed = True
210        tests += 1
211
212        for chain in chain_array:
213            item = line.split(";")
214            if table == "":
215                rule = chain + " " + item[0]
216            else:
217                rule = chain + " -t " + table + " " + item[0]
218
219            if item[1] == "=":
220                rule_save = chain + " " + item[0]
221            else:
222                rule_save = chain + " " + item[1]
223
224            res = item[2].rstrip()
225
226            ret = run_test(iptables, rule, rule_save,
227                           res, filename, lineno + 1)
228            if ret < 0:
229                test_passed = False
230                total_test_passed = False
231                break
232
233        if test_passed:
234            passed += 1
235
236    if total_test_passed:
237        print filename + ": " + Colors.GREEN + "OK" + Colors.ENDC
238
239    f.close()
240    return tests, passed
241
242
def show_missing(path=None):
    '''
    Show the list of missing test files

    For every lib*.c source in the directory without a matching .t file,
    prints the expected .t name, one per line, in sorted order.

    :param path: directory to scan; defaults to EXTENSIONS_PATH
    '''
    if path is None:
        path = EXTENSIONS_PATH
    file_list = os.listdir(path)
    # a set gives O(1) membership tests instead of a list scan per file
    testfiles = set(i for i in file_list if i.endswith('.t'))

    def test_name(x):
        # libxt_foo.c -> libxt_foo.t
        return x[0:-2] + '.t'

    # sorted: os.listdir() returns entries in arbitrary order
    missing = sorted(test_name(i) for i in file_list
                     if i.startswith('lib') and i.endswith('.c')
                     and test_name(i) not in testfiles)

    print('\n'.join(missing))
258
259
260#
261# main
262#
def main():
    '''
    Command-line entry point: runs the .t test files and prints a summary.

    With -m/--missing, only lists lib*.c sources lacking a .t file. Otherwise
    requires root, logs every executed command to LOGFILE, and runs either
    the given file or every file in EXTENSIONS_PATH.
    '''
    global log_file

    parser = argparse.ArgumentParser(description='Run iptables tests')
    parser.add_argument('filename', nargs='?',
                        metavar='path/to/file.t',
                        help='Run only this test')
    parser.add_argument('-m', '--missing', action='store_true',
                        help='Check for missing tests')
    args = parser.parse_args()

    #
    # show list of missing test files
    #
    if args.missing:
        show_missing()
        return

    if os.getuid() != 0:
        print("You need to be root to run this, sorry")
        return

    # Only scan EXTENSIONS_PATH when no single file was requested; sorted so
    # runs (and the log) are reproducible.
    if args.filename:
        file_list = [args.filename]
    else:
        file_list = [os.path.join(EXTENSIONS_PATH, i)
                     for i in sorted(os.listdir(EXTENSIONS_PATH))]

    # setup global var log file
    try:
        log_file = open(LOGFILE, 'w')
    except IOError:
        print("Couldn't open log file %s" % LOGFILE)
        return

    test_files = 0
    tests = 0
    passed = 0
    try:
        for filename in file_list:
            file_tests, file_passed = run_test_file(filename)
            if file_tests:
                tests += file_tests
                passed += file_passed
                test_files += 1
    finally:
        # also closes the log when a test file aborts via sys.exit()
        log_file.close()

    print("%d test files, %d unit tests, %d passed" %
          (test_files, tests, passed))
308
309
if __name__ == "__main__":
    # run the suite only when executed as a script, not when imported
    main()
312