• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
from __future__ import print_function
import atexit
import filecmp
import glob
import itertools
import os
import shutil
import sys
import sysconfig
import tempfile
import unittest


# Project root: the directory two levels above the one containing this file.
project_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Interpreter for launching child processes; 'python' is the fallback when
# sys.executable is empty or None.
PYTHON = sys.executable if sys.executable else 'python'

# Path of the bro.py script under '<project_dir>/python'.
BRO = os.path.join(project_dir, 'python', 'bro.py')

# Get the platform/version-specific build folder.
# By default, the distutils build base is in the same location as setup.py;
# this suite looks for it under '<project_dir>/bin'.
# NOTE(review): newer setuptools releases may name this folder
# 'lib.<platform>-cpython-<major><minor>' instead of
# 'lib.<platform>-<major>.<minor>' -- confirm which layout the build produces.
platform_lib_name = 'lib.{platform}-{version[0]}.{version[1]}'.format(
    platform=sysconfig.get_platform(), version=sys.version_info)
build_dir = os.path.join(project_dir, 'bin', platform_lib_name)

# Prepend the build folder to sys.path (imports in this process) and to the
# PYTHONPATH of TEST_ENV (presumably the environment passed to subprocesses).
# An unset *or empty* PYTHONPATH yields just build_dir: joining with an empty
# value would leave a trailing separator, i.e. an empty path entry, which the
# interpreter may resolve to the current working directory.
if build_dir not in sys.path:
    sys.path.insert(0, build_dir)
TEST_ENV = os.environ.copy()
TEST_ENV['PYTHONPATH'] = os.pathsep.join(
    entry for entry in (build_dir, TEST_ENV.get('PYTHONPATH')) if entry)

TESTDATA_DIR = os.path.join(project_dir, 'tests', 'testdata')

TESTDATA_FILES = [
    'empty',  # Empty file
    '10x10y',  # Small text
    'alice29.txt',  # Large text
    'random_org_10k.bin',  # Small data
    'mapsdatazrh',  # Large data
]

TESTDATA_PATHS = [os.path.join(TESTDATA_DIR, f) for f in TESTDATA_FILES]

# glob() returns files in directory-listing order, which is unspecified;
# sort so the generated decompression tests are listed deterministically.
TESTDATA_PATHS_FOR_DECOMPRESSION = sorted(glob.glob(
    os.path.join(TESTDATA_DIR, '*.compressed')))

# Scratch directory for test outputs.  mkdtemp() never deletes it, so remove
# it (best effort) when the interpreter exits; otherwise every test run leaves
# a stray directory behind.
TEMP_DIR = tempfile.mkdtemp()
atexit.register(shutil.rmtree, TEMP_DIR, ignore_errors=True)


def get_temp_compressed_name(filename):
    """Return the scratch path for the compressed output of *filename*.

    The path is inside TEMP_DIR and is named after the last path component
    of *filename* with '.bro' appended.
    """
    suffixed = filename + '.bro'
    leaf = os.path.basename(suffixed)
    return os.path.join(TEMP_DIR, leaf)


def get_temp_uncompressed_name(filename):
    """Return the scratch path for the decompressed output of *filename*.

    The path is inside TEMP_DIR and is named after the last path component
    of *filename* with '.unbro' appended.
    """
    suffixed = filename + '.unbro'
    leaf = os.path.basename(suffixed)
    return os.path.join(TEMP_DIR, leaf)


def bind_method_args(method, *args, **kwargs):
    """Pre-apply arguments to *method*, leaving only ``self`` open.

    Returns a one-argument function that calls
    ``method(self, *args, **kwargs)`` and returns its result.
    """
    def bound(self):
        return method(self, *args, **kwargs)
    return bound


def generate_test_methods(test_case_class,
                          for_decompression=False,
                          variants=None):
    """Attach one concrete test method per (template, data file, option set).

    Every attribute of ``test_case_class`` whose name starts with ``_test`` is
    treated as a template called as ``template(self, testdata_path,
    **options)``.  For each template, each test data file and each
    combination of ``variants``, a method named
    ``test_<template>_<options>_<file stem>`` is set on the class.  This makes
    problems with specific compression scenarios easier to identify.

    Args:
      test_case_class: class that receives the generated methods (mutated in
        place).
      for_decompression: if true, use TESTDATA_PATHS_FOR_DECOMPRESSION
        instead of TESTDATA_PATHS.
      variants: optional mapping of keyword-argument name to an iterable of
        values.  One option set is generated per element of the Cartesian
        product of the value lists; e.g. ``{'quality': [1, 2]}`` gives option
        names ``quality_1`` and ``quality_2``.  When empty or None, a single
        option set with no keyword arguments (and an empty name) is used.
    """
    if for_decompression:
        paths = TESTDATA_PATHS_FOR_DECOMPRESSION
    else:
        paths = TESTDATA_PATHS

    # Build [option_name, option_kwargs] pairs.  Each product element picks
    # one value per variant key, in the mapping's iteration order.
    opts = []
    if variants:
        keys = list(variants)
        for values in itertools.product(*[variants[k] for k in keys]):
            pairs = list(zip(keys, values))
            opts_name = '_'.join(str(i) for i in itertools.chain(*pairs))
            opts.append([opts_name, dict(pairs)])
    else:
        opts.append(['', {}])

    templates = [m for m in dir(test_case_class) if m.startswith('_test')]
    for method in templates:
        # Generated names start with 'test_', never '_test', so the template
        # lookup is unaffected by the setattr calls below.
        template = getattr(test_case_class, method)
        for testdata in paths:
            # splitext() strips only the last extension: a file such as
            # 'x.txt.compressed' yields the stem 'x.txt', so the generated
            # method name can contain a dot.
            f = os.path.splitext(os.path.basename(testdata))[0]
            for (opts_name, opts_dict) in opts:
                name = 'test_{method}_{options}_{file}'.format(
                    method=method, options=opts_name, file=f)
                func = bind_method_args(template, testdata, **opts_dict)
                # Report the real test name in tracebacks, not '<lambda>'.
                func.__name__ = name
                setattr(test_case_class, name, func)


class TestCase(unittest.TestCase):
    """Shared base class for this test suite.

    Removes per-test scratch files after each test and provides a
    byte-for-byte file comparison assertion.
    """

    def tearDown(self):
        # Best-effort removal of the scratch names derived from every test
        # data path, including the '*.compressed' inputs used by the
        # decompression tests.  If those files were left behind, a later test
        # could compare against a stale output that an earlier test produced.
        # Missing files are expected, so OSError is ignored.
        for path in itertools.chain(TESTDATA_PATHS,
                                    TESTDATA_PATHS_FOR_DECOMPRESSION):
            for temp_name in (get_temp_compressed_name(path),
                              get_temp_uncompressed_name(path)):
                try:
                    os.unlink(temp_name)
                except OSError:
                    pass

    def assertFilesMatch(self, first, second):
        """Assert that files *first* and *second* have identical contents."""
        self.assertTrue(
            filecmp.cmp(first, second, shallow=False),
            'File {} differs from {}'.format(first, second))
