• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2#
3# (C) 2012-2013 by Pablo Neira Ayuso <pablo@netfilter.org>
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This software has been sponsored by Sophos Astaro <http://www.sophos.com>
11#
12
13from __future__ import print_function
14import sys
15import os
16import subprocess
17import argparse
18from difflib import unified_diff
19
# Front-end command names understood by the test runner.
IPTABLES = "iptables"
IP6TABLES = "ip6tables"
ARPTABLES = "arptables"
EBTABLES = "ebtables"

# Dump command belonging to each front-end above.
IPTABLES_SAVE = IPTABLES + "-save"
IP6TABLES_SAVE = IP6TABLES + "-save"
ARPTABLES_SAVE = ARPTABLES + "-save"
EBTABLES_SAVE = EBTABLES + "-save"
#IPTABLES_SAVE = ['xtables-save','-4']
#IP6TABLES_SAVE = ['xtables-save','-6']

# Extension directory: relative to the current directory for
# XTABLES_LIBDIR, relative to the script location for the test files.
EXTENSIONS_PATH = "extensions"
TESTS_PATH = os.path.join(os.path.dirname(sys.argv[0]), EXTENSIONS_PATH)

# Log destination; log_file holds the open handle once main() has set it up.
LOGFILE = "/tmp/iptables-test.log"
log_file = None

# Evaluated once at import time; consulted when deciding whether to colorize.
STDOUT_IS_TTY, STDERR_IS_TTY = sys.stdout.isatty(), sys.stderr.isatty()
39
def maybe_colored(color, text, isatty):
    '''
    Wraps text in a terminal color sequence if the output is a terminal.

    :param color: name of the color to apply. Valid values: "green", "red"
    :param text: string to decorate
    :param isatty: if false, text is returned unchanged (and color is not
                   even looked up)
    '''
    if not isatty:
        return text

    escape_codes = {
        'green': '\033[92m',
        'red': '\033[91m',
    }
    reset = '\033[0m'
    return escape_codes[color] + text + reset
49
50
def print_error(reason, filename=None, lineno=None, log_file=sys.stderr):
    '''
    Prints an error with nice colors, indicating file and line number.

    :param reason: string describing what went wrong
    :param filename: name of the file the error belongs to; the leading
                     "FILE: " part is left out if None
    :param lineno: line number the error belongs to; the "line N" part is
                   left out if None
    :param log_file: stream to print to; "ERROR" is colored only if this
                     stream is a terminal
    '''
    tag = maybe_colored('red', "ERROR", log_file.isatty())

    # Both defaults are None: concatenating or %d-formatting them as-is
    # would raise TypeError, so build the optional parts explicitly.
    prefix = "" if filename is None else filename + ": "
    if lineno is None:
        detail = ": (%s)" % (reason,)
    else:
        detail = ": line %d (%s)" % (lineno, reason)

    print(prefix + tag + detail, file=log_file)
57
58
def delete_rule(iptables, rule, filename, lineno, netns = None):
    '''
    Removes an iptables rule

    Remove any --set-counters arguments, --delete rejects them.

    :param iptables: string with the iptables command to execute
    :param rule: string with the iptables arguments of the rule to delete
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param netns: network namespace to call commands in (or None)
    :return: 0 if the rule was deleted, -1 if the delete command failed
    '''
    delrule = rule.split()
    for pos, token in enumerate(delrule):
        if token not in ('-c', '--set-counters'):
            continue

        # Drop the option itself. The list is only modified right before
        # leaving the loop, so iterating over it stays safe.
        del delrule[pos]

        # An option at the very end of the rule has no value to remove;
        # popping unconditionally would raise IndexError.
        if pos < len(delrule):
            value = delrule.pop(pos)
            # A value containing ',' is taken to hold both counters.
            # Otherwise a following numeric token is removed as well.
            if (',' not in value and pos < len(delrule)
                    and delrule[pos].isnumeric()):
                del delrule[pos]
        break
    rule = " ".join(delrule)

    cmd = iptables + " -D " + rule
    ret = execute_cmd(cmd, filename, lineno, netns)
    if ret != 0:
        reason = "cannot delete: " + iptables + " -I " + rule
        print_error(reason, filename, lineno)
        return -1

    return 0
84
85
def run_test(iptables, rule, rule_save, res, filename, lineno, netns, stderr=sys.stderr):
    '''
    Executes an unit test.

    The rule is appended with "-A", looked up in the output of the matching
    *-save command and, if everything went as expected, removed again with
    delete_rule().

    Parameters:
    :param iptables: string with the iptables command to execute
    :param rule: string with iptables arguments for the rule to test
    :param rule_save: string to find the rule in the output of iptables-save
    :param res: expected result of the rule. Valid values: "OK", "FAIL"
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param netns: network namespace to call commands in (or None)
    :param stderr: stream used for the errors reported directly by this
                   function
    :return: 0 if the outcome matches res, -1 otherwise. For a successful
             "OK" test outside a netns this is the result of delete_rule().
    '''
    cmd = iptables + " -A " + rule
    ret = execute_cmd(cmd, filename, lineno, netns)

    #
    # report failed test
    #
    if ret:
        if res != "FAIL":
            reason = "cannot load: " + cmd
            print_error(reason, filename, lineno, stderr)
            return -1
        # the rule was expected to be rejected, do not report this error
        return 0

    if res == "FAIL":
        reason = "should fail: " + cmd
        print_error(reason, filename, lineno, stderr)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    # Pick the *-save command: either from a "-4"/"-6" second word or from
    # the command name itself.
    tokens = iptables.split(" ")
    command = None
    if len(tokens) == 2:
        command = {
            '-4': IPTABLES_SAVE,
            '-6': IP6TABLES_SAVE,
        }.get(tokens[1])
    elif len(tokens) == 1:
        command = {
            IPTABLES: IPTABLES_SAVE,
            IP6TABLES: IP6TABLES_SAVE,
            ARPTABLES: ARPTABLES_SAVE,
            EBTABLES: EBTABLES_SAVE,
        }.get(tokens[0])

    # Without this check an unrecognized command would leave "command"
    # unbound and crash with UnboundLocalError below.
    if command is None:
        reason = "unknown command: " + iptables
        print_error(reason, filename, lineno, stderr)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    command = EXECUTABLE + " " + command

    if netns:
        command = "ip netns exec " + netns + " " + command

    proc = subprocess.Popen(command, shell=True,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    if len(err):
        # communicate() returns bytes here; decode so the log does not
        # contain a b'...' repr.
        print(err.decode('utf-8', 'replace'), file=log_file)

    #
    # check for segfaults
    #
    if proc.returncode == -11:
        reason = command + " segfaults!"
        print_error(reason, filename, lineno, stderr)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    # find the rule
    matching = out.find("\n-A {}\n".format(rule_save).encode('utf-8'))

    if matching < 0:
        if res == "OK":
            reason = "cannot find: " + iptables + " -I " + rule
            print_error(reason, filename, lineno, stderr)
            delete_rule(iptables, rule, filename, lineno, netns)
            return -1
        # do not report this error
        # NOTE(review): the rule is not deleted on this path - confirm
        # that leaving it loaded is intended.
        return 0

    if res != "OK":
        reason = "should not match: " + cmd
        print_error(reason, filename, lineno, stderr)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    # Test "ip netns del NETNS" path with rules in place
    if netns:
        return 0

    return delete_rule(iptables, rule, filename, lineno)
185
def execute_cmd(cmd, filename, lineno = 0, netns = None):
    '''
    Runs a command through the shell, logging it and its output to the
    global log file, and reports segfaults.

    :param cmd: string with the command to be executed; commands starting
                with an xtables tool name get prefixed with EXECUTABLE
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param netns: network namespace to run command in
    :return: exit code of the command
    '''
    xtables_prefixes = ('iptables ', 'ip6tables ', 'ebtables ', 'arptables ')
    if cmd.startswith(xtables_prefixes):
        cmd = EXECUTABLE + " " + cmd

    if netns:
        cmd = "ip netns exec " + netns + " " + cmd

    print("command: {}".format(cmd), file=log_file)
    status = subprocess.call(cmd, shell=True, universal_newlines=True,
                             stderr=subprocess.STDOUT, stdout=log_file)
    log_file.flush()

    # generic check for segfaults
    if status == -11:
        print_error("command segfaults: " + cmd, filename, lineno)

    return status
213
214
def variant_res(res, variant, alt_res=None):
    '''
    Translate a variant-scoped expected result for the running variant

    A result scoped to one variant implies a different result for the other
    one. If @variant is the one currently under test, @res applies as is.
    Otherwise @alt_res is used when given, else the inverse of @res.

    :param res: expected result from test spec ("OK", "FAIL" or "NOMATCH")
    :param variant: variant @res is scoped to by test spec ("NFT" or "LEGACY")
    :param alt_res: optional expected result for the alternate variant.
    '''
    executables = {
        "NFT": "xtables-nft-multi",
        "LEGACY": "xtables-legacy-multi"
    }
    scoped_executable = executables[variant]

    if scoped_executable == EXECUTABLE:
        return res

    if alt_res is None:
        inverted = {
            "OK": "FAIL",
            "FAIL": "OK",
            "NOMATCH": "OK"
        }
        return inverted[res]

    return alt_res
242
def fast_run_possible(filename):
    '''
    Return true if fast test run is possible.

    To keep things simple, run only for simple test files:
    - no external commands
    - no multiple tables
    - no variant-specific results

    :param filename: test file to inspect
    :return: False as soon as a disqualifying line is found, True otherwise
    '''
    table = None
    rulecount = 0
    # Use a context manager so the file is closed on every return path.
    with open(filename) as testfile:
        for line in testfile:
            if line[0] in ["#", ":"] or len(line.strip()) == 0:
                continue
            if line[0] == "*":
                # a second table, or a table selected after rules were seen
                if table or rulecount > 0:
                    return False
                table = line.rstrip()[1:]
            if line[0] in ["@", "%"]:
                return False
            if len(line.split(";")) > 3:
                return False
            # Note: the table line itself falls through and is counted too.
            rulecount += 1

    return True
270
def run_test_file_fast(iptables, filename, netns):
    '''
    Run a test file, but fast

    Add all non-failing rules at once by use of iptables-restore, then check
    all rules' listing at once by use of iptables-save. Rules not expected
    to yield "OK" are still run one by one via run_test().

    :param iptables: string with the iptables command to execute
    :param filename: name of the file with the test rules
    :param netns: network namespace to perform test run in
    :return: number of tests found in the file, or -1 on any failure
    '''

    rules = {}
    table = "filter"
    chain_array = []
    tests = 0
    # Keep lineno defined for the error reports further down even if the
    # file is empty and the loop body never runs.
    lineno = 0

    # The context manager closes the file on the early returns as well.
    with open(filename) as f:
        for lineno, line in enumerate(f):
            if line[0] == "#" or len(line.strip()) == 0:
                continue

            if line[0] == "*":
                table = line.rstrip()[1:]
                continue

            if line[0] == ":":
                chain_array = line.rstrip()[1:].split(",")
                continue

            if len(chain_array) == 0:
                return -1

            tests += 1

            # Fields are "rule;expected save output;expected result".
            item = line.split(";")
            res = item[2].rstrip()

            for chain in chain_array:
                rule = chain + " " + item[0]

                if item[1] == "=":
                    rule_save = chain + " " + item[0]
                else:
                    rule_save = chain + " " + item[1]

                if iptables == EBTABLES and rule_save.find('-j') < 0:
                    rule_save += " -j CONTINUE"

                if res != "OK":
                    # Not expected to load cleanly: run it individually and
                    # send its error reports to the log file.
                    rule = chain + " -t " + table + " " + item[0]
                    ret = run_test(iptables, rule, rule_save,
                                   res, filename, lineno + 1, netns, log_file)

                    if ret < 0:
                        return -1
                    continue

                rules.setdefault(chain, []).append((rule, rule_save))

    restore_data = ["*" + table]
    out_expect = []
    for chain in ["PREROUTING", "INPUT", "FORWARD", "OUTPUT", "POSTROUTING"]:
        if chain not in rules:
            continue
        for rule in rules[chain]:
            restore_data.append("-A " + rule[0])
            out_expect.append("-A " + rule[1])
    restore_data.append("COMMIT")

    out_expect = "\n".join(out_expect)

    # load all rules via iptables_restore

    command = EXECUTABLE + " " + iptables + "-restore"
    if netns:
        command = "ip netns exec " + netns + " " + command

    for line in restore_data:
        print(iptables + "-restore: " + line, file=log_file)

    proc = subprocess.Popen(command, shell = True, text = True,
                            stdin = subprocess.PIPE,
                            stdout = subprocess.PIPE,
                            stderr = subprocess.PIPE)
    restore_data = "\n".join(restore_data) + "\n"
    out, err = proc.communicate(input = restore_data)
    if len(err):
        print(err, file=log_file)

    if proc.returncode == -11:
        reason = iptables + "-restore segfaults!"
        print_error(reason, filename, lineno)
        msg = [iptables + "-restore segfault from:"]
        msg.extend(["input: " + l for l in restore_data.split("\n")])
        print("\n".join(msg), file=log_file)
        return -1

    if proc.returncode != 0:
        print("%s-restore returned %d: %s" % (iptables, proc.returncode, err),
              file=log_file)
        return -1

    # find all rules in iptables_save output

    command = EXECUTABLE + " " + iptables + "-save"
    if netns:
        command = "ip netns exec " + netns + " " + command

    proc = subprocess.Popen(command, shell = True,
                            stdin = subprocess.PIPE,
                            stdout = subprocess.PIPE,
                            stderr = subprocess.PIPE)
    out, err = proc.communicate()
    if len(err):
        # bytes in this call (no text mode); decode to keep the log readable
        print(err.decode('utf-8', 'replace'), file=log_file)

    if proc.returncode == -11:
        reason = iptables + "-save segfaults!"
        print_error(reason, filename, lineno)
        return -1

    # flush the table before comparing, so it is emptied on mismatch too
    cmd = iptables + " -F -t " + table
    execute_cmd(cmd, filename, 0, netns)

    out = out.decode('utf-8').rstrip()
    if out.find(out_expect) < 0:
        print("dumps differ!", file=log_file)
        # startswith() copes with empty lines (e.g. an empty dump), where
        # indexing l[0] would raise IndexError.
        out_clean = [ l for l in out.split("\n")
                        if not l.startswith(('*', ':', '#'))]
        diff = unified_diff(out_expect.split("\n"), out_clean,
                            fromfile="expect", tofile="got", lineterm='')
        print("\n".join(diff), file=log_file)
        return -1

    return tests
408
def _run_test_file(iptables, filename, netns, suffix):
    '''
    Runs a test file

    If fast_run_possible() allows it, run_test_file_fast() is tried first.
    Unless that reports at least one test, every rule of the file is run
    individually via run_test().

    :param iptables: string with the iptables command to execute
    :param filename: name of the file with the test rules
    :param netns: network namespace to perform test run in
    :param suffix: string appended to the "OK" line printed for this file
    :return: tuple of (number of tests, number of passed tests)
    '''

    fast_failed = False
    if fast_run_possible(filename):
        tests = run_test_file_fast(iptables, filename, netns)
        if tests > 0:
            print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY) + suffix)
            return tests, tests
        fast_failed = True

    tests = 0
    passed = 0
    table = ""
    chain_array = []
    total_test_passed = True

    # The context manager closes the file even if a test raises.
    with open(filename) as f:
        if netns:
            execute_cmd("ip netns add " + netns, filename)

        # Count lines from 1 so that every report (missing chain, external
        # command, rule test) refers to the same, editor-style line number.
        for lineno, line in enumerate(f, 1):
            if line[0] == "#" or len(line.strip()) == 0:
                continue

            if line[0] == ":":
                chain_array = line.rstrip()[1:].split(",")
                continue

            # external command invocation, executed as is.
            # detects iptables commands to prefix with EXECUTABLE automatically
            if line[0] in ["@", "%"]:
                external_cmd = line.rstrip()[1:]
                execute_cmd(external_cmd, filename, lineno, netns)
                continue

            if line[0] == "*":
                table = line.rstrip()[1:]
                continue

            if len(chain_array) == 0:
                print_error("broken test, missing chain",
                            filename = filename, lineno = lineno)
                total_test_passed = False
                break

            test_passed = True
            tests += 1

            # Fields: rule;save output;result[;variant[;alternate result]]
            item = line.split(";")
            res = item[2].rstrip()
            if len(item) > 3:
                variant = item[3].rstrip()
                if len(item) > 4:
                    alt_res = item[4].rstrip()
                else:
                    alt_res = None
                res = variant_res(res, variant, alt_res)

            for chain in chain_array:
                if table == "":
                    rule = chain + " " + item[0]
                else:
                    rule = chain + " -t " + table + " " + item[0]

                if item[1] == "=":
                    rule_save = chain + " " + item[0]
                else:
                    rule_save = chain + " " + item[1]

                if iptables == EBTABLES and rule_save.find('-j') < 0:
                    rule_save += " -j CONTINUE"

                ret = run_test(iptables, rule, rule_save,
                               res, filename, lineno, netns)

                if ret < 0:
                    test_passed = False
                    total_test_passed = False
                    break

            if test_passed:
                passed += 1

    if netns:
        execute_cmd("ip netns del " + netns, filename)
    if total_test_passed:
        if fast_failed:
            suffix += maybe_colored('red', " but fast mode failed!", STDOUT_IS_TTY)
        print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY) + suffix)

    return tests, passed
509
def run_test_file(filename, netns):
    '''
    Runs a test file with every command its extension prefix calls for

    :param filename: name of the file with the test rules
    :param netns: network namespace to perform test run in
    :return: tuple of (number of tests, number of passed tests), summed
             over all commands the file was run with
    '''
    #
    # if this is not a test file, skip.
    #
    if not filename.endswith(".t"):
        return 0, 0

    # Checked in order; the first marker contained in filename wins.
    families = (
        ("libipt_", [ IPTABLES ]),
        ("libip6t_", [ IP6TABLES ]),
        ("libxt_", [ IPTABLES, IP6TABLES ]),
        ("libarpt_", [ ARPTABLES ]),
        ("libebt_", [ EBTABLES ]),
    )

    # default to iptables if not known prefix
    xtables = [ IPTABLES ]
    for marker, commands in families:
        if marker in filename:
            xtables = commands
            break

    # arptables and ebtables are only supported with nf_tables backend
    if xtables[0] in (ARPTABLES, EBTABLES) and EXECUTABLE != "xtables-nft-multi":
        return 0, 0

    tests = 0
    passed = 0
    suffix = ""
    tag_command = len(xtables) > 1
    for iptables in xtables:
        if tag_command:
            suffix = "({})".format(iptables)

        file_tests, file_passed = _run_test_file(iptables, filename, netns, suffix)
        if not file_tests:
            continue
        tests += file_tests
        passed += file_passed

    return tests, passed
557
def show_missing():
    '''
    Prints, one per line, the test file name expected for every lib*.c file
    in TESTS_PATH that has no such .t file next to it
    '''
    entries = os.listdir(TESTS_PATH)
    existing_tests = [name for name in entries if name.endswith('.t')]

    missing = []
    for name in entries:
        if not (name.startswith('lib') and name.endswith('.c')):
            continue
        # swap the ".c" suffix for ".t"
        expected = name[:-2] + '.t'
        if expected not in existing_tests:
            missing.append(expected)

    print('\n'.join(missing))
573
def spawn_netns():
    '''
    Tries to move the test run into a network namespace of its own

    :return: True if the current process was moved into a new network
             namespace, False if that was not possible. When falling back
             to the unshare binary and exec succeeds, this does not return.
    '''
    # prefer unshare module
    try:
        import unshare
        unshare.unshare(unshare.CLONE_NEWNET)
        return True
    except Exception:
        # Best effort: module missing or call refused. Unlike a bare
        # except, this lets KeyboardInterrupt and SystemExit propagate.
        pass

    # sledgehammer style:
    # - call ourselves prefixed by 'unshare -n' if found
    # - pass extra --no-netns parameter to avoid another recursion
    import shutil

    unshare_bin = shutil.which("unshare")
    if unshare_bin is None:
        return False

    # Build the new argument vector separately so that sys.argv stays
    # untouched if exec fails and the run continues in this process.
    argv = [unshare_bin, "-n", sys.executable] + sys.argv + ["--no-netns"]
    try:
        os.execv(unshare_bin, argv)
    except (OSError, ValueError):
        pass

    return False
599
600#
601# main
602#
def main():
    '''
    Parses the command line and runs the requested test files once per
    selected variant ("legacy" and/or "nft").

    :return: value meant for sys.exit(): None after --missing, 77 if not
             run as root, 1 if the log file cannot be opened, otherwise
             passed tests minus total tests (0 if everything passed)
    '''
    global EXECUTABLE
    global log_file

    parser = argparse.ArgumentParser(description='Run iptables tests')
    parser.add_argument('filename', nargs='*',
                        metavar='path/to/file.t',
                        help='Run only this test')
    parser.add_argument('-H', '--host', action='store_true',
                        help='Run tests against installed binaries')
    parser.add_argument('-l', '--legacy', action='store_true',
                        help='Test iptables-legacy')
    parser.add_argument('-m', '--missing', action='store_true',
                        help='Check for missing tests')
    parser.add_argument('-n', '--nftables', action='store_true',
                        help='Test iptables-over-nftables')
    parser.add_argument('-N', '--netns', action='store_const',
                        const='____iptables-container-test',
                        help='Test netnamespace path')
    parser.add_argument('--no-netns', action='store_true',
                        help='Do not run testsuite in own network namespace')
    args = parser.parse_args()

    #
    # show list of missing test files
    #
    if args.missing:
        show_missing()
        return

    variants = []
    if args.legacy:
        variants.append("legacy")
    if args.nftables:
        variants.append("nft")
    if len(variants) == 0:
        variants = [ "legacy", "nft" ]

    if os.getuid() != 0:
        print("You need to be root to run this, sorry", file=sys.stderr)
        return 77

    if not args.netns and not args.no_netns and not spawn_netns():
        print("Cannot run in own namespace, connectivity might break",
              file=sys.stderr)

    if not args.host:
        # point the tools at the freshly built extensions and binaries
        os.putenv("XTABLES_LIBDIR", os.path.abspath(EXTENSIONS_PATH))
        os.putenv("PATH", "%s/iptables:%s" % (os.path.abspath(os.path.curdir),
                                              os.getenv("PATH")))

    # The list of test files does not depend on the variant.
    if args.filename:
        file_list = args.filename
    else:
        file_list = [os.path.join(TESTS_PATH, i)
                     for i in os.listdir(TESTS_PATH)
                     if i.endswith('.t')]
        file_list.sort()

    # setup global var log file
    # Opened once for the whole run: reopening it with 'w' for every variant
    # would truncate the previous variant's log and leak the old handle.
    try:
        log_file = open(LOGFILE, 'w')
    except IOError:
        print("Couldn't open log file %s" % LOGFILE, file=sys.stderr)
        # non-zero, so that a run without any test is not taken for success
        return 1

    total_test_files = 0
    total_passed = 0
    total_tests = 0
    try:
        for variant in variants:
            EXECUTABLE = "xtables-" + variant + "-multi"

            test_files = 0
            tests = 0
            passed = 0

            for filename in file_list:
                file_tests, file_passed = run_test_file(filename, args.netns)
                if file_tests:
                    tests += file_tests
                    passed += file_passed
                    test_files += 1

            print("%s: %d test files, %d unit tests, %d passed"
                  % (variant, test_files, tests, passed))

            total_passed += passed
            total_tests += tests
            total_test_files = max(total_test_files, test_files)
    finally:
        log_file.close()

    if len(variants) > 1:
        print("total: %d test files, %d unit tests, %d passed"
              % (total_test_files, total_tests, total_passed))
    return total_passed - total_tests
696
if __name__ == '__main__':
    # hand main()'s return value to the interpreter as exit status
    exit_status = main()
    sys.exit(exit_status)
699