#!/usr/bin/env python3
#
# (C) 2012-2013 by Pablo Neira Ayuso <pablo@netfilter.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This software has been sponsored by Sophos Astaro <http://www.sophos.com>
#

import argparse
import os
import subprocess
import sys
from difflib import unified_diff

IPTABLES = "iptables"
IP6TABLES = "ip6tables"
ARPTABLES = "arptables"
EBTABLES = "ebtables"

IPTABLES_SAVE = "iptables-save"
IP6TABLES_SAVE = "ip6tables-save"
ARPTABLES_SAVE = "arptables-save"
EBTABLES_SAVE = "ebtables-save"
#IPTABLES_SAVE = ['xtables-save','-4']
#IP6TABLES_SAVE = ['xtables-save','-6']

# *-save command for each bare front-end name; run_test() uses it to dump
# the ruleset after adding a rule.
_SAVE_COMMANDS = {
    IPTABLES: IPTABLES_SAVE,
    IP6TABLES: IP6TABLES_SAVE,
    ARPTABLES: ARPTABLES_SAVE,
    EBTABLES: EBTABLES_SAVE,
}

EXTENSIONS_PATH = "extensions"
TESTS_PATH = os.path.join(os.path.dirname(sys.argv[0]), "extensions")
LOGFILE = "/tmp/iptables-test.log"
# Log stream shared by all helpers below; opened by main().
log_file = None

# Multi-call binary that every iptables/ip6tables/arptables/ebtables
# invocation is routed through ("xtables-legacy-multi" or
# "xtables-nft-multi"); assigned by main() once per test variant.
EXECUTABLE = None

STDOUT_IS_TTY = sys.stdout.isatty()
STDERR_IS_TTY = sys.stderr.isatty()


def maybe_colored(color, text, isatty):
    '''
    Return @text wrapped in the terminal color sequence for @color ("green"
    or "red") if @isatty is true, @text unchanged otherwise.
    '''
    terminal_sequences = {
        'green': '\033[92m',
        'red': '\033[91m',
    }

    return (
        terminal_sequences[color] + text + '\033[0m' if isatty else text
    )


def print_error(reason, filename=None, lineno=None, log_file=sys.stderr):
    '''
    Prints an error with nice colors, indicating file and line number.

    :param reason: description of what went wrong
    :param filename: test file the error relates to, None if unknown
    :param lineno: line number inside @filename, None if unknown
    :param log_file: stream to print to; "ERROR" is colored only if this
                     stream is a terminal
    '''
    location = filename if filename is not None else "<unknown>"
    header = location + ": " + maybe_colored('red', "ERROR", log_file.isatty())
    if lineno is None:
        print(header + ": (%s)" % reason, file=log_file)
    else:
        print(header + ": line %d (%s)" % (lineno, reason), file=log_file)


def _netns_prefix(cmd, netns):
    '''
    Return shell command @cmd prefixed with "ip netns exec @netns" if @netns
    is set, @cmd unchanged otherwise.
    '''
    if netns:
        return "ip netns exec " + netns + " " + cmd
    return cmd


def _strip_counters(rule):
    '''
    Return @rule with the first -c/--set-counters option and its value(s)
    removed; --delete rejects them.

    The counters are either one "PKTS,BYTES" token or two numeric tokens.
    Tokens are re-joined with single spaces.
    '''
    tokens = rule.split()
    for i, token in enumerate(tokens):
        if token not in ('-c', '--set-counters'):
            continue
        # safe to mutate: we stop iterating right after
        del tokens[i]
        if i < len(tokens):
            value = tokens.pop(i)
            if ',' not in value and i < len(tokens) and tokens[i].isnumeric():
                del tokens[i]
        break
    return " ".join(tokens)


def delete_rule(iptables, rule, filename, lineno, netns = None):
    '''
    Removes an iptables rule

    Remove any --set-counters arguments, --delete rejects them.

    :param iptables: string with the iptables command to execute
    :param rule: iptables arguments of the rule to delete
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param netns: network namespace to run the command in (or None)
    Returns 0 on success, -1 if the delete command failed.
    '''
    rule = _strip_counters(rule)

    cmd = iptables + " -D " + rule
    ret = execute_cmd(cmd, filename, lineno, netns)
    if ret != 0:
        reason = "cannot delete: " + cmd
        print_error(reason, filename, lineno)
        return -1

    return 0


def _save_command(iptables):
    '''
    Return the *-save command matching front-end command @iptables, or None
    if it is not recognised.

    @iptables is either a bare front-end name ("iptables", "ebtables", ...)
    or two words whose second one is "-4" or "-6".
    '''
    tokens = iptables.split(" ")
    if len(tokens) == 2:
        return {'-4': IPTABLES_SAVE, '-6': IP6TABLES_SAVE}.get(tokens[1])
    if len(tokens) == 1:
        return _SAVE_COMMANDS.get(tokens[0])
    return None


def run_test(iptables, rule, rule_save, res, filename, lineno, netns, stderr=sys.stderr):
    '''
    Executes an unit test. Returns the output of delete_rule().

    Parameters:
    :param iptables: string with the iptables command to execute
    :param rule: string with iptables arguments for the rule to test
    :param rule_save: string to find the rule in the output of iptables-save
    :param res: expected result of the rule. Valid values: "OK", "FAIL"
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param netns: network namespace to call commands in (or None)
    :param stderr: stream test failures are reported to
    Returns 0 if the test passed, -1 otherwise.
    '''
    cmd = iptables + " -A " + rule
    ret = execute_cmd(cmd, filename, lineno, netns)

    #
    # report failed test
    #
    if ret:
        if res != "FAIL":
            reason = "cannot load: " + cmd
            print_error(reason, filename, lineno, stderr)
            return -1
        # do not report this error
        return 0

    if res == "FAIL":
        reason = "should fail: " + cmd
        print_error(reason, filename, lineno, stderr)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    save_cmd = _save_command(iptables)
    if save_cmd is None:
        reason = "no save command known for: " + iptables
        print_error(reason, filename, lineno, stderr)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    command = _netns_prefix(EXECUTABLE + " " + save_cmd, netns)
    proc = subprocess.Popen(command, shell=True,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    if len(err):
        print(err.decode('utf-8', 'replace'), file=log_file)

    #
    # check for segfaults (negative returncode = killed by that signal)
    #
    if proc.returncode == -11:
        reason = command + " segfaults!"
        print_error(reason, filename, lineno, stderr)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    # find the rule
    matching = out.find("\n-A {}\n".format(rule_save).encode('utf-8'))

    if matching < 0:
        if res == "OK":
            reason = "cannot find: " + cmd
            print_error(reason, filename, lineno, stderr)
            delete_rule(iptables, rule, filename, lineno, netns)
            return -1
        # do not report this error
        # NOTE(review): the rule was loaded but is not deleted here
        return 0

    if res != "OK":
        reason = "should not match: " + cmd
        print_error(reason, filename, lineno, stderr)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    # Test "ip netns del NETNS" path with rules in place
    if netns:
        return 0

    return delete_rule(iptables, rule, filename, lineno)


def execute_cmd(cmd, filename, lineno = 0, netns = None):
    '''
    Executes a command, checking for segfaults and returning the command exit
    code.

    Commands starting with a bare front-end name are run through EXECUTABLE.
    Output (stdout and stderr) goes to the log file.

    :param cmd: string with the command to be executed
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param netns: network namespace to run command in
    '''
    if cmd.startswith(('iptables ', 'ip6tables ', 'ebtables ', 'arptables ')):
        cmd = EXECUTABLE + " " + cmd

    cmd = _netns_prefix(cmd, netns)

    print("command: {}".format(cmd), file=log_file)
    ret = subprocess.call(cmd, shell=True, universal_newlines=True,
                          stderr=subprocess.STDOUT, stdout=log_file)
    log_file.flush()

    # generic check for segfaults
    if ret == -11:
        reason = "command segfaults: " + cmd
        print_error(reason, filename, lineno)
    return ret


def variant_res(res, variant, alt_res=None):
    '''
    Adjust expected result with given variant

    If expected result is scoped to a variant, the other one yields a different
    result. Therefore map @res to itself if given variant is current, use the
    alternate result, @alt_res, if specified, invert @res otherwise.

    :param res: expected result from test spec ("OK", "FAIL" or "NOMATCH")
    :param variant: variant @res is scoped to by test spec ("NFT" or "LEGACY")
    :param alt_res: optional expected result for the alternate variant.
    '''
    variant_executable = {
        "NFT": "xtables-nft-multi",
        "LEGACY": "xtables-legacy-multi"
    }
    res_inverse = {
        "OK": "FAIL",
        "FAIL": "OK",
        "NOMATCH": "OK"
    }

    if variant_executable[variant] == EXECUTABLE:
        return res
    if alt_res is not None:
        return alt_res
    return res_inverse[res]


def fast_run_possible(filename):
    '''
    Return true if fast test run is possible.

    To keep things simple, run only for simple test files:
    - no external commands
    - no multiple tables
    - no variant-specific results

    :param filename: test file to inspect
    '''
    table = None
    rulecount = 0
    with open(filename) as f:
        for line in f:
            if line[0] in ["#", ":"] or len(line.strip()) == 0:
                continue
            if line[0] == "*":
                # one table line only, and only before any rule
                if table or rulecount > 0:
                    return False
                table = line.rstrip()[1:]
            if line[0] in ["@", "%"]:
                return False
            if len(line.split(";")) > 3:
                return False
            rulecount += 1

    return True


def run_test_file_fast(iptables, filename, netns):
    '''
    Run a test file, but fast

    Add all non-failing rules at once by use of iptables-restore, then check
    all rules' listing at once by use of iptables-save.

    :param iptables: string with the iptables command to execute
    :param filename: name of the file with the test rules
    :param netns: network namespace to perform test run in
    Returns the number of tests in @filename, or -1 on any failure.
    '''
    rules = {}
    table = "filter"
    chain_array = []
    tests = 0
    lineno = 0

    with open(filename) as f:
        for lineno, line in enumerate(f):
            if line[0] == "#" or len(line.strip()) == 0:
                continue

            if line[0] == "*":
                table = line.rstrip()[1:]
                continue

            if line[0] == ":":
                chain_array = line.rstrip()[1:].split(",")
                continue

            if len(chain_array) == 0:
                return -1

            tests += 1

            item = line.split(";")
            for chain in chain_array:
                rule = chain + " " + item[0]

                if item[1] == "=":
                    rule_save = chain + " " + item[0]
                else:
                    rule_save = chain + " " + item[1]

                if iptables == EBTABLES and '-j' not in rule_save:
                    rule_save += " -j CONTINUE"

                res = item[2].rstrip()
                if res != "OK":
                    # rules not expected to load/match are tested one by one
                    rule = chain + " -t " + table + " " + item[0]
                    ret = run_test(iptables, rule, rule_save,
                                   res, filename, lineno + 1, netns, log_file)

                    if ret < 0:
                        return -1
                    continue

                rules.setdefault(chain, []).append((rule, rule_save))

    restore_data = ["*" + table]
    out_expect = []
    for chain in ["PREROUTING", "INPUT", "FORWARD", "OUTPUT", "POSTROUTING"]:
        for rule, rule_save in rules.get(chain, []):
            restore_data.append("-A " + rule)
            out_expect.append("-A " + rule_save)
    restore_data.append("COMMIT")

    out_expect = "\n".join(out_expect)

    # load all rules via iptables_restore

    command = _netns_prefix(EXECUTABLE + " " + iptables + "-restore", netns)

    for line in restore_data:
        print(iptables + "-restore: " + line, file=log_file)

    proc = subprocess.Popen(command, shell = True, text = True,
                            stdin = subprocess.PIPE,
                            stdout = subprocess.PIPE,
                            stderr = subprocess.PIPE)
    restore_data = "\n".join(restore_data) + "\n"
    out, err = proc.communicate(input = restore_data)
    if len(err):
        print(err, file=log_file)

    if proc.returncode == -11:
        reason = iptables + "-restore segfaults!"
        print_error(reason, filename, lineno)
        msg = [iptables + "-restore segfault from:"]
        msg.extend(["input: " + l for l in restore_data.split("\n")])
        print("\n".join(msg), file=log_file)
        return -1

    if proc.returncode != 0:
        print("%s-restore returned %d: %s" % (iptables, proc.returncode, err),
              file=log_file)
        return -1

    # find all rules in iptables_save output

    command = _netns_prefix(EXECUTABLE + " " + iptables + "-save", netns)

    proc = subprocess.Popen(command, shell = True,
                            stdin = subprocess.PIPE,
                            stdout = subprocess.PIPE,
                            stderr = subprocess.PIPE)
    out, err = proc.communicate()
    if len(err):
        print(err.decode('utf-8', 'replace'), file=log_file)

    # flush before any early return so loaded rules don't leak into
    # the tests that follow
    cmd = iptables + " -F -t " + table
    execute_cmd(cmd, filename, 0, netns)

    if proc.returncode == -11:
        reason = iptables + "-save segfaults!"
        print_error(reason, filename, lineno)
        return -1

    if proc.returncode != 0:
        print("%s-save returned %d" % (iptables, proc.returncode),
              file=log_file)
        return -1

    out = out.decode('utf-8').rstrip()
    if out_expect not in out:
        print("dumps differ!", file=log_file)
        out_clean = [l for l in out.split("\n")
                     if l and not l.startswith(('*', ':', '#'))]
        diff = unified_diff(out_expect.split("\n"), out_clean,
                            fromfile="expect", tofile="got", lineterm='')
        print("\n".join(diff), file=log_file)
        return -1

    return tests


def _run_test_file(iptables, filename, netns, suffix):
    '''
    Runs a test file

    Tries the fast mode first if possible and falls back to running each
    rule individually if that fails.

    :param iptables: string with the iptables command to execute
    :param filename: name of the file with the test rules
    :param netns: network namespace to perform test run in
    :param suffix: text appended to the "OK" line printed on success
    Returns a (tests, passed) tuple.
    '''
    fast_failed = False
    if fast_run_possible(filename):
        # the fast path runs its commands inside @netns too, so the
        # namespace must exist while it runs
        if netns:
            execute_cmd("ip netns add " + netns, filename)
        tests = run_test_file_fast(iptables, filename, netns)
        if netns:
            execute_cmd("ip netns del " + netns, filename)
        if tests > 0:
            print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY) + suffix)
            return tests, tests
        fast_failed = True

    tests = 0
    passed = 0
    table = ""
    chain_array = []
    total_test_passed = True

    if netns:
        execute_cmd("ip netns add " + netns, filename)

    with open(filename) as f:
        for lineno, line in enumerate(f):
            if line[0] == "#" or len(line.strip()) == 0:
                continue

            if line[0] == ":":
                chain_array = line.rstrip()[1:].split(",")
                continue

            # external command invocation, executed as is.
            # detects iptables commands to prefix with EXECUTABLE automatically
            if line[0] in ["@", "%"]:
                external_cmd = line.rstrip()[1:]
                execute_cmd(external_cmd, filename, lineno + 1, netns)
                continue

            if line[0] == "*":
                table = line.rstrip()[1:]
                continue

            if len(chain_array) == 0:
                print_error("broken test, missing chain",
                            filename = filename, lineno = lineno + 1)
                total_test_passed = False
                break

            test_passed = True
            tests += 1

            item = line.split(";")
            for chain in chain_array:
                if table == "":
                    rule = chain + " " + item[0]
                else:
                    rule = chain + " -t " + table + " " + item[0]

                if item[1] == "=":
                    rule_save = chain + " " + item[0]
                else:
                    rule_save = chain + " " + item[1]

                if iptables == EBTABLES and '-j' not in rule_save:
                    rule_save += " -j CONTINUE"

                res = item[2].rstrip()
                if len(item) > 3:
                    variant = item[3].rstrip()
                    alt_res = item[4].rstrip() if len(item) > 4 else None
                    res = variant_res(res, variant, alt_res)

                ret = run_test(iptables, rule, rule_save,
                               res, filename, lineno + 1, netns)

                if ret < 0:
                    test_passed = False
                    total_test_passed = False
                    break

            if test_passed:
                passed += 1

    if netns:
        execute_cmd("ip netns del " + netns, filename)
    if total_test_passed:
        if fast_failed:
            suffix += maybe_colored('red', " but fast mode failed!", STDOUT_IS_TTY)
        print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY) + suffix)

    return tests, passed


def run_test_file(filename, netns):
    '''
    Runs a test file

    The front-end(s) to test with are chosen from the file name prefix.

    :param filename: name of the file with the test rules
    :param netns: network namespace to perform test run in
    Returns a (tests, passed) tuple; (0, 0) for skipped files.
    '''
    #
    # if this is not a test file, skip.
    #
    if not filename.endswith(".t"):
        return 0, 0

    if "libipt_" in filename:
        xtables = [ IPTABLES ]
    elif "libip6t_" in filename:
        xtables = [ IP6TABLES ]
    elif "libxt_" in filename:
        xtables = [ IPTABLES, IP6TABLES ]
    elif "libarpt_" in filename:
        # only supported with nf_tables backend
        if EXECUTABLE != "xtables-nft-multi":
            return 0, 0
        xtables = [ ARPTABLES ]
    elif "libebt_" in filename:
        # only supported with nf_tables backend
        if EXECUTABLE != "xtables-nft-multi":
            return 0, 0
        xtables = [ EBTABLES ]
    else:
        # default to iptables if not known prefix
        xtables = [ IPTABLES ]

    tests = 0
    passed = 0
    suffix = ""
    for iptables in xtables:
        if len(xtables) > 1:
            suffix = "({})".format(iptables)

        file_tests, file_passed = _run_test_file(iptables, filename, netns, suffix)
        if file_tests:
            tests += file_tests
            passed += file_passed

    return tests, passed


def show_missing():
    '''
    Show the list of missing test files
    '''
    file_list = os.listdir(TESTS_PATH)
    testfiles = [i for i in file_list if i.endswith('.t')]
    libfiles = [i for i in file_list
                if i.startswith('lib') and i.endswith('.c')]

    def test_name(x):
        return x[0:-2] + '.t'
    missing = [test_name(i) for i in libfiles
               if not test_name(i) in testfiles]

    print('\n'.join(missing))


def spawn_netns():
    '''
    Move this process into a new network namespace.

    Returns True if the "unshare" python module did it in-process. Otherwise
    tries to re-execute the script under "unshare -n" (os.execv does not
    return on success); returns False if neither worked.
    '''
    # prefer unshare module
    try:
        import unshare
        unshare.unshare(unshare.CLONE_NEWNET)
        return True
    except Exception:
        # best effort: fall through to the unshare(1) approach
        pass

    # sledgehammer style:
    # - call ourselves prefixed by 'unshare -n' if found
    # - pass extra --no-netns parameter to avoid another recursion
    try:
        import shutil

        unshare = shutil.which("unshare")
        if unshare is None:
            return False

        sys.argv.append("--no-netns")
        os.execv(unshare, [unshare, "-n", sys.executable] + sys.argv)
    except Exception:
        # best effort: caller warns that connectivity might break
        pass

    return False

#
# main
#
def main():
    '''
    Parse the command line, run the selected test files for every selected
    variant and print a summary.

    Returns passed - total unit tests (0 if all passed), 77 when not run as
    root, 1 if the log file cannot be opened, None after --missing.
    '''
    global EXECUTABLE, log_file

    parser = argparse.ArgumentParser(description='Run iptables tests')
    parser.add_argument('filename', nargs='*',
                        metavar='path/to/file.t',
                        help='Run only this test')
    parser.add_argument('-H', '--host', action='store_true',
                        help='Run tests against installed binaries')
    parser.add_argument('-l', '--legacy', action='store_true',
                        help='Test iptables-legacy')
    parser.add_argument('-m', '--missing', action='store_true',
                        help='Check for missing tests')
    parser.add_argument('-n', '--nftables', action='store_true',
                        help='Test iptables-over-nftables')
    parser.add_argument('-N', '--netns', action='store_const',
                        const='____iptables-container-test',
                        help='Test netnamespace path')
    parser.add_argument('--no-netns', action='store_true',
                        help='Do not run testsuite in own network namespace')
    args = parser.parse_args()

    #
    # show list of missing test files
    #
    if args.missing:
        show_missing()
        return

    variants = []
    if args.legacy:
        variants.append("legacy")
    if args.nftables:
        variants.append("nft")
    if len(variants) == 0:
        variants = [ "legacy", "nft" ]

    if os.getuid() != 0:
        print("You need to be root to run this, sorry", file=sys.stderr)
        return 77

    if not args.netns and not args.no_netns and not spawn_netns():
        print("Cannot run in own namespace, connectivity might break",
              file=sys.stderr)

    if not args.host:
        os.putenv("XTABLES_LIBDIR", os.path.abspath(EXTENSIONS_PATH))
        os.putenv("PATH", "%s/iptables:%s" % (os.path.abspath(os.path.curdir),
                                              os.getenv("PATH")))

    # setup global var log file; opened once for all variants so a later
    # variant does not truncate the log of an earlier one
    try:
        log_file = open(LOGFILE, 'w')
    except OSError:
        print("Couldn't open log file %s" % LOGFILE, file=sys.stderr)
        # no test could run: do not report success
        return 1

    total_test_files = 0
    total_passed = 0
    total_tests = 0
    with log_file:
        for variant in variants:
            EXECUTABLE = "xtables-" + variant + "-multi"

            test_files = 0
            tests = 0
            passed = 0

            if args.filename:
                file_list = args.filename
            else:
                file_list = sorted(os.path.join(TESTS_PATH, i)
                                   for i in os.listdir(TESTS_PATH)
                                   if i.endswith('.t'))

            for filename in file_list:
                file_tests, file_passed = run_test_file(filename, args.netns)
                if file_tests:
                    tests += file_tests
                    passed += file_passed
                    test_files += 1

            print("%s: %d test files, %d unit tests, %d passed"
                  % (variant, test_files, tests, passed))

            total_passed += passed
            total_tests += tests
            total_test_files = max(total_test_files, test_files)

    if len(variants) > 1:
        print("total: %d test files, %d unit tests, %d passed"
              % (total_test_files, total_tests, total_passed))
    return total_passed - total_tests

if __name__ == '__main__':
    sys.exit(main())