• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Loading unittests."""
2
3import os
4import re
5import sys
6import traceback
7import types
8import functools
9
10from fnmatch import fnmatch, fnmatchcase
11
12from . import case, suite, util
13
# Module-level marker flag.  NOTE(review): presumably this is the flag the
# unittest machinery looks for in frame globals to hide its own frames from
# failure tracebacks -- confirm against unittest.result.
__unittest = True

# Accepted basenames for candidate test modules: an identifier-like stem
# followed by ".py", matched case-insensitively.  Compiled files (.pyc etc.)
# are deliberately not accepted; supporting them would require making sure
# the same tests are not loaded twice, once from the '.py' file and once
# from its compiled sibling.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
20
21
class _FailedTest(case.TestCase):
    """Placeholder test case that re-raises a stored exception when run.

    The loader builds these (via ``_make_failed_test``) so that import
    failures, failing ``load_tests`` hooks and unresolvable attribute names
    show up as ordinary test errors instead of aborting test loading.
    """
    _testMethodName = None

    def __init__(self, method_name, exception):
        # Stored before TestCase.__init__ runs; NOTE(review): TestCase.__init__
        # presumably resolves method_name with getattr, which lands in
        # __getattr__ below because no such method exists on the class.
        self._exception = exception
        super(_FailedTest, self).__init__(method_name)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.  The "test method"
        # does not exist on the class, so it is synthesised on demand.
        if name != self._testMethodName:
            # Neither TestCase nor object defines __getattr__, so delegating
            # to super() would raise a misleading "'super' object has no
            # attribute '__getattr__'" error; report the real missing name.
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')

        def testFailure():
            raise self._exception
        return testFailure
35
36
def _make_failed_import_test(name, suiteClass):
    """Turn the exception currently being handled into a failing import test.

    Intended to be called from inside an ``except`` block: the active
    traceback is captured with ``traceback.format_exc()`` and embedded in the
    message.  Returns a ``(suite, message)`` pair.
    """
    details = traceback.format_exc()
    message = 'Failed to import test module: %s\n%s' % (name, details)
    failure = ImportError(message)
    return _make_failed_test(name, failure, suiteClass, message)
41
def _make_failed_load_tests(name, exception, suiteClass):
    """Report a ``load_tests`` hook that raised *exception* as a failing test.

    Intended to be called from the ``except`` handler around the hook call,
    so the active traceback can be included in the message.  Returns a
    ``(suite, message)`` pair.
    """
    trace = traceback.format_exc()
    return _make_failed_test(
        name, exception, suiteClass,
        'Failed to call load_tests:\n%s' % (trace,))
46
def _make_failed_test(methodname, exception, suiteClass, message):
    """Return ``(suite, message)``; *suite* holds one test raising *exception*."""
    failing_case = _FailedTest(methodname, exception)
    wrapped = suiteClass((failing_case,))
    return wrapped, message
50
def _make_skipped_test(methodname, exception, suiteClass):
    """Build a suite with one test that is skipped, citing *exception*.

    Used when importing a test module or package raises ``SkipTest``.  A
    throwaway ``ModuleSkipped`` TestCase subclass is created on the fly to
    carry the skipped method named *methodname*.
    """
    reason = str(exception)

    @case.skip(reason)
    def testSkipped(self):
        pass

    skipped_class = type("ModuleSkipped", (case.TestCase,),
                         {methodname: testSkipped})
    return suiteClass((skipped_class(methodname),))
58
def _splitext(path):
    """Return *path* with its final extension (if any) removed."""
    root, _ext = os.path.splitext(path)
    return root
61
62
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite.

    Customisation points (class attributes, overridable per instance):

    * ``testMethodPrefix`` -- only attributes whose names start with this
      prefix are collected as test methods.
    * ``sortTestMethodsUsing`` -- ``cmp``-style function used to order the
      collected method names; a false value disables sorting.
    * ``testNamePatterns`` -- ``None`` (no filtering) or an iterable of
      ``fnmatchcase`` patterns matched against ``module.Class.method``.
    * ``suiteClass`` -- callable used to wrap collections of tests.

    Most loading problems (failed imports, failing ``load_tests`` hooks,
    unresolvable names) do not abort loading: the formatted message is
    appended to ``self.errors`` and a placeholder test that fails with the
    original exception is returned in place of the real tests.
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    testNamePatterns = None
    suiteClass = suite.TestSuite
    _top_level_dir = None

    def __init__(self):
        super(TestLoader, self).__init__()
        # Formatted error messages collected while loading (see class docs).
        self.errors = []
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass.

        One instance is created per name returned by ``getTestCaseNames``;
        if there are none but the class defines ``runTest``, a single
        ``runTest`` instance is created instead.  The TestCase and
        FunctionTestCase base classes themselves yield an empty suite.

        Raises TypeError if *testCaseClass* derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
        if testCaseClass in (case.TestCase, case.FunctionTestCase):
            # We don't load any tests from base types that should not be loaded.
            testCaseNames = []
        else:
            testCaseNames = self.getTestCaseNames(testCaseClass)
            if not testCaseNames and hasattr(testCaseClass, 'runTest'):
                testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, *, pattern=None):
        """Return a suite of all test cases contained in the given module.

        Every TestCase subclass among the module's attributes (other than the
        TestCase and FunctionTestCase base classes) is loaded with
        ``loadTestsFromTestCase``.  If the module defines a
        ``load_tests(loader, tests, pattern)`` hook, the collected suite is
        passed to it and the hook's return value is returned instead.  An
        exception raised by the hook is recorded in ``self.errors`` and
        replaced by a failing placeholder test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if (
                isinstance(obj, type)
                and issubclass(obj, case.TestCase)
                and obj not in (case.TestCase, case.FunctionTestCase)
            ):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Import and attribute-lookup failures are recorded in ``self.errors``
        and returned as a suite holding a failing placeholder test.  Raises
        TypeError when the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        error_case, error_message = None, None
        if module is None:
            # Import the longest importable dotted prefix of *name*.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
                    break
                except ImportError:
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
                    if not parts_copy:
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
            # __import__ returns the top-level package, so the remaining
            # components (submodules included) are walked as attributes.
            parts = parts[1:]
        obj = module
        # Bound up front so the checks below never hit an unbound name when
        # no attribute lookup takes place (e.g. __import__ returned an
        # object that is not a module).
        parent = None
        for part in parts:
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # obj is a package (only packages have a __path__) and
                    # an earlier import attempt failed. We cannot tell the
                    # difference between package.WrongNameTestClass and
                    # package.wrong_module_name so we just report the
                    # ImportError - it is more informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
                        part, e, self.suiteClass,
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif (
            isinstance(obj, type)
            and issubclass(obj, case.TestCase)
            and obj not in (case.TestCase, case.FunctionTestCase)
        ):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path: looked up on the
            # instance they stay plain functions, so they fall through to
            # the callable() branch below instead of being wrapped here.
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name is included when it starts with ``testMethodPrefix``, the
        attribute is callable, and -- if ``testNamePatterns`` is set -- the
        fully qualified ``module.Class.name`` matches at least one pattern.
        """
        def shouldIncludeMethod(attrname):
            if not attrname.startswith(self.testMethodPrefix):
                return False
            testFunc = getattr(testCaseClass, attrname)
            if not callable(testFunc):
                return False
            fullName = (f'{testCaseClass.__module__}.'
                        f'{testCaseClass.__qualname__}.{attrname}')
            return self.testNamePatterns is None or \
                any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns)
        testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with (loader, tests, pattern) unless
        the package has already had load_tests called from the same discovery
        invocation, in which case the package module object is not scanned for
        tests - this ensures that when a package uses discover to further
        discover child tests that infinite recursion does not happen.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.

        *start_dir* may also be a dotted module name; discovery then starts
        from the directory containing that module.  Raises ImportError if the
        start directory is not importable, and TypeError if a dotted name
        refers to a module without a ``__file__`` (e.g. a builtin module).
        """
        original_top_level_dir = self._top_level_dir
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # All test modules must be importable from the top level directory.
        # Remember whether the entry is ours, so the implicit-top path below
        # never removes an entry the user already had on sys.path.
        # (Should we *unconditionally* put the start directory first in
        # sys.path to minimise the likelihood of conflicts between installed
        # modules and development versions?)
        added_to_sys_path = top_level_dir not in sys.path
        if added_to_sys_path:
            sys.path.insert(0, top_level_dir)
        # NOTE(review): only restored on the success path at the end; if
        # discovery raises, the loader keeps this value for later calls.
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                # A sub-directory must be a regular package to be importable.
                is_not_importable = not os.path.isfile(
                    os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                try:
                    start_dir = os.path.abspath(
                        os.path.dirname(the_module.__file__))
                except AttributeError:
                    if the_module.__name__ in sys.builtin_module_names:
                        # builtin module
                        raise TypeError('Can not use builtin modules '
                                        'as dotted module names') from None
                    else:
                        raise TypeError(
                            f"don't know how to discover from {the_module!r}"
                            ) from None

                if set_implicit_top:
                    # top_level_dir was derived from a module *name*, not a
                    # path; use the directory containing the top-level
                    # package instead and drop the bogus sys.path entry.
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    if added_to_sys_path:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        self._top_level_dir = original_top_level_dir
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which *module_name* can be imported.

        For a package (whose ``__file__`` is an ``__init__.py*`` file) this is
        the parent of the package directory; for a plain module it is the
        directory holding the module file.  The module must already be
        present in ``sys.modules`` (a KeyError is raised otherwise).
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert *path* below the top-level directory into a dotted name.

        Returns '.' for the top-level directory itself.  The extension is
        stripped, so ``<top>/pkg/test_x.py`` becomes ``pkg.test_x``.  Paths
        outside the top-level directory fail an assertion (only while
        assertions are enabled).
        """
        if path == self._top_level_dir:
            return '.'
        path = _splitext(os.path.normpath(path))

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import *name* and return that (sub)module from ``sys.modules``.

        ``__import__`` returns the top-level package for dotted names, hence
        the ``sys.modules`` lookup to get the innermost module.
        """
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        """Return True if the basename *path* matches the shell *pattern*."""
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        # Handle the __init__ in this package
        name = self._get_name_from_path(start_dir)
        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
        # definition not a package).
        if name != '.' and name not in self._loading_packages:
            # name is in self._loading_packages while we have called into
            # loadTestsFromModule with name.
            tests, should_recurse = self._find_test_path(start_dir, pattern)
            if tests is not None:
                yield tests
            if not should_recurse:
                # Either an error occurred, or load_tests was used by the
                # package.
                return
        # Handle the contents.
        paths = sorted(os.listdir(start_dir))
        for path in paths:
            full_path = os.path.join(start_dir, path)
            tests, should_recurse = self._find_test_path(full_path, pattern)
            if tests is not None:
                yield tests
            if should_recurse:
                # we found a package that didn't use load_tests.
                name = self._get_name_from_path(full_path)
                self._loading_packages.add(name)
                try:
                    yield from self._find_tests(full_path, pattern)
                finally:
                    self._loading_packages.discard(name)

    def _find_test_path(self, full_path, pattern):
        """Used by discovery.

        Loads tests from a single file, or a directories' __init__.py when
        passed the directory.

        Returns a tuple (None_or_tests_from_file, should_recurse).

        Import failures are recorded in ``self.errors`` and returned as a
        failing placeholder test; a ``SkipTest`` raised on import becomes a
        skipped test.  Raises ImportError if a matching file imports as a
        module located somewhere else (e.g. a globally installed copy).
        """
        basename = os.path.basename(full_path)
        if os.path.isfile(full_path):
            if not VALID_MODULE_NAME.match(basename):
                # valid Python identifiers only
                return None, False
            if not self._match_path(basename, full_path, pattern):
                return None, False
            # if the test file matches, load it
            name = self._get_name_from_path(full_path)
            try:
                module = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                # Guard against importing a same-named module from a
                # different location than the file that was discovered.
                mod_file = os.path.abspath(
                    getattr(module, '__file__', full_path))
                realpath = _splitext(
                    os.path.realpath(mod_file))
                fullpath_noext = _splitext(
                    os.path.realpath(full_path))
                if realpath.lower() != fullpath_noext.lower():
                    module_dir = os.path.dirname(realpath)
                    mod_name = _splitext(
                        os.path.basename(full_path))
                    expected_dir = os.path.dirname(full_path)
                    msg = ("%r module incorrectly imported from %r. Expected "
                           "%r. Is this module globally installed?")
                    raise ImportError(
                        msg % (mod_name, module_dir, expected_dir))
                return self.loadTestsFromModule(module, pattern=pattern), False
        elif os.path.isdir(full_path):
            if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                return None, False

            name = self._get_name_from_path(full_path)
            try:
                package = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                load_tests = getattr(package, 'load_tests', None)
                # Mark this package as being in load_tests (possibly ;))
                self._loading_packages.add(name)
                try:
                    tests = self.loadTestsFromModule(package, pattern=pattern)
                    if load_tests is not None:
                        # loadTestsFromModule(package) has loaded tests for us.
                        return tests, False
                    return tests, True
                finally:
                    self._loading_packages.discard(name)
        else:
            return None, False
451
452
# Module-level TestLoader instance created at import time with the default
# class attributes.  NOTE(review): presumably the loader other parts of the
# package fall back to when none is supplied -- confirm at the call sites.
defaultTestLoader = TestLoader()
454