• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
from __future__ import print_function
import atexit
import filecmp
import glob
import itertools
import os
import shutil
import sys
import sysconfig
import tempfile
import unittest
10
11
# Root of the source tree: the directory two levels above the one containing
# this file (presumably <project>/python/tests/ -> <project>; TODO confirm).
project_dir = os.path.abspath(os.path.join(__file__, '..', '..', '..'))

# Both locations can be overridden from the environment.  An unset *or empty*
# variable selects the in-tree default; an empty value would otherwise give a
# relative test-data path or an empty command name.
test_dir = os.getenv("BROTLI_TESTS_PATH")
BRO_ARGS = [os.getenv("BROTLI_WRAPPER")]

# Fallbacks
if not test_dir:
    test_dir = os.path.join(project_dir, 'tests')
if not BRO_ARGS[0]:
    # Run the in-tree bro.py script with the current interpreter.
    python_exe = sys.executable or 'python'
    bro_path = os.path.join(project_dir, 'python', 'bro.py')
    BRO_ARGS = [python_exe, bro_path]

# Get the platform/version-specific build folder.  The build base is
# <project>/bin; the 'lib.<platform>-<major>.<minor>' folder name appears to
# follow the classic distutils naming scheme.
# NOTE(review): newer setuptools releases may name this folder differently
# (e.g. with an implementation tag such as 'cpython-311') -- verify against
# the build configuration.
platform_lib_name = 'lib.{platform}-{version[0]}.{version[1]}'.format(
    platform=sysconfig.get_platform(), version=sys.version_info)
build_dir = os.path.join(project_dir, 'bin', platform_lib_name)

# Prepend the build folder to sys.path (for imports in this process) and to
# the PYTHONPATH of TEST_ENV, a copy of os.environ.
if build_dir not in sys.path:
    sys.path.insert(0, build_dir)
TEST_ENV = os.environ.copy()
if TEST_ENV.get('PYTHONPATH'):
    TEST_ENV['PYTHONPATH'] = build_dir + os.pathsep + TEST_ENV['PYTHONPATH']
else:
    # Unset or empty: avoid a trailing separator, because an empty PYTHONPATH
    # entry may be interpreted as the current directory.
    TEST_ENV['PYTHONPATH'] = build_dir
38
# Directory holding the test corpus.
TESTDATA_DIR = os.path.join(test_dir, 'testdata')

# Uncompressed inputs used by the compression tests.
TESTDATA_FILES = [
    'empty',  # Empty file
    '10x10y',  # Small text
    'alice29.txt',  # Large text
    'random_org_10k.bin',  # Small data
    'mapsdatazrh',  # Large data
    'ukkonooa',  # Poem
    'cp1251-utf16le',  # Codepage 1251 table saved in UTF16-LE encoding
    'cp852-utf8',  # Codepage 852 table saved in UTF8 encoding
]

# Some files might be missing in a lightweight sources pack, so only the
# candidates that actually exist on disk are used.
TESTDATA_PATH_CANDIDATES = [
    os.path.join(TESTDATA_DIR, f) for f in TESTDATA_FILES
]

TESTDATA_PATHS = [
    path for path in TESTDATA_PATH_CANDIDATES if os.path.isfile(path)
]

# Pre-compressed inputs for the decompression tests.  Sorted so the order does
# not depend on the filesystem's directory-listing order.
TESTDATA_PATHS_FOR_DECOMPRESSION = sorted(glob.glob(
    os.path.join(TESTDATA_DIR, '*.compressed')))

# Scratch directory for this process.  It is removed, together with anything
# left inside it, when the interpreter exits, so repeated runs do not leave
# stray temporary directories behind.
TEMP_DIR = tempfile.mkdtemp()
atexit.register(shutil.rmtree, TEMP_DIR, ignore_errors=True)
65
66
def get_temp_compressed_name(filename):
    """Return the scratch path for the compressed copy of ``filename``.

    The path lies inside TEMP_DIR; its name is the base name of ``filename``
    with ``.bro`` appended.
    """
    target = os.path.basename(filename + '.bro')
    return os.path.join(TEMP_DIR, target)
69
70
def get_temp_uncompressed_name(filename):
    """Return the scratch path for the decompressed copy of ``filename``.

    The path lies inside TEMP_DIR; its name is the base name of ``filename``
    with ``.unbro`` appended.
    """
    unbro_name = filename + '.unbro'
    return os.path.join(TEMP_DIR, os.path.basename(unbro_name))
73
74
def bind_method_args(method, *args, **kwargs):
    """Return a function of ``self`` that calls ``method(self, *args, **kwargs)``.

    The extra positional and keyword arguments are captured when this is
    called, so each generated test keeps its own arguments instead of
    late-binding to loop variables.
    """
    def _invoke(self):
        return method(self, *args, **kwargs)
    return _invoke
77
78
def generate_test_methods(test_case_class,
                          for_decompression=False,
                          variants=None):
    """Attach one concrete ``test_*`` method per (template, file, options) combo.

    Every attribute of ``test_case_class`` whose name starts with ``_test`` is
    treated as a template.  For each template, each test data path and each
    combination of ``variants`` values, a method taking only ``self`` is added
    that calls the template with the data path and the chosen keyword options.
    One method per scenario makes failures in specific compression scenarios
    easier to identify.

    Args:
      test_case_class: class to which the generated methods are added.
      for_decompression: if true, use TESTDATA_PATHS_FOR_DECOMPRESSION instead
        of TESTDATA_PATHS as the list of data files.
      variants: optional mapping of keyword-argument name -> iterable of
        values.  The cartesian product over all keys is generated, and each
        combination is encoded in the method name as its key/value pairs
        joined by '_'.

    Generated names have the form ``test_<template>_<options>_<file>``, where
    ``<file>`` is the data file's base name with its last extension removed.
    """
    if for_decompression:
        paths = TESTDATA_PATHS_FOR_DECOMPRESSION
    else:
        paths = TESTDATA_PATHS

    opts = []
    if variants:
        # One list of (key, value) pairs per variant key, e.g.
        # {'quality': [1, 11]} -> [[('quality', 1), ('quality', 11)]].
        opts_list = [list(itertools.product([k], v))
                     for k, v in variants.items()]
        for o in itertools.product(*opts_list):
            opts_name = '_'.join(str(i) for i in itertools.chain(*o))
            opts.append([opts_name, dict(o)])
    else:
        opts.append(['', {}])

    # Snapshot template names before new attributes are added to the class.
    templates = [m for m in dir(test_case_class) if m.startswith('_test')]
    for method in templates:
        template = getattr(test_case_class, method)
        for testdata in paths:
            f = os.path.splitext(os.path.basename(testdata))[0]
            for opts_name, opts_dict in opts:
                name = 'test_{method}_{options}_{file}'.format(
                    method=method, options=opts_name, file=f)
                func = bind_method_args(template, testdata, **opts_dict)
                setattr(test_case_class, name, func)
108
109
class TestCase(unittest.TestCase):
    """Base test case adding scratch-file cleanup and a file-equality assert."""

    def tearDown(self):
        """Delete the per-test scratch files from TEMP_DIR.

        Scratch names are derived from both the compression inputs and the
        ``*.compressed`` decompression inputs, because generate_test_methods
        passes either list to the test templates.  Errors from unlink (for
        example, a scratch file that was never written) are ignored.
        """
        for f in itertools.chain(TESTDATA_PATHS,
                                 TESTDATA_PATHS_FOR_DECOMPRESSION):
            for temp_name in (get_temp_compressed_name(f),
                              get_temp_uncompressed_name(f)):
                try:
                    os.unlink(temp_name)
                except OSError:
                    pass

    def assertFilesMatch(self, first, second):
        """Assert that files ``first`` and ``second`` have identical contents.

        Raises:
          AssertionError: if the file contents differ.
        """
        # filecmp caches outcomes keyed on each file's (type, size, mtime).
        # Scratch paths do not depend on test options, so several tests may
        # rewrite the same path, possibly within the filesystem's mtime
        # resolution; clear the cache so a stale result is never reused.
        # clear_cache() is only available on Python 3.4+.
        clear_cache = getattr(filecmp, 'clear_cache', None)
        if clear_cache is not None:
            clear_cache()
        self.assertTrue(
            filecmp.cmp(first, second, shallow=False),
            'File {} differs from {}'.format(first, second))
127