"""Loading unittests."""

import os
import re
import sys
import traceback
import types
import functools

from fnmatch import fnmatch, fnmatchcase

from . import case, suite, util

__unittest = True

# what about .pyc (etc)
# we would need to avoid loading the same tests multiple times
# from '.py', *and* '.pyc'
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)


class _FailedTest(case.TestCase):
    """Placeholder test whose only test method re-raises a stored exception.

    The loader uses it to report import, attribute-access and load_tests
    failures as a test that fails when run, rather than raising from the
    loader itself.
    """
    # Class-level default so __getattr__ can always read _testMethodName,
    # even before an instance attribute of that name has been assigned.
    _testMethodName = None

    def __init__(self, method_name, exception):
        self._exception = exception
        super(_FailedTest, self).__init__(method_name)

    def __getattr__(self, name):
        # Only the configured test method name is synthesized; any other
        # missing attribute is delegated to the base class lookup.
        if name != self._testMethodName:
            return super(_FailedTest, self).__getattr__(name)
        # Deliberately no docstring: it would become the test's description.
        def testFailure():
            raise self._exception
        return testFailure


def _make_failed_import_test(name, suiteClass):
    """Return ``(suite, message)`` describing a module that failed to import.

    Must be called while the import exception is being handled: the message
    embeds ``traceback.format_exc()`` of the current exception.
    """
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())
    return _make_failed_test(name, ImportError(message), suiteClass, message)

def _make_failed_load_tests(name, exception, suiteClass):
    """Return ``(suite, message)`` describing a failed ``load_tests`` call.

    Must be called while *exception* is being handled so that its traceback
    can be captured in the message.
    """
    message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),)
    return _make_failed_test(
        name, exception, suiteClass, message)

def _make_failed_test(methodname, exception, suiteClass, message):
    """Wrap a _FailedTest re-raising *exception* in a single-test suite.

    Returns ``(suite, message)`` so callers can append *message* to
    ``TestLoader.errors``.
    """
    test = _FailedTest(methodname, exception)
    return suiteClass((test,)), message

def _make_skipped_test(methodname, exception, suiteClass):
    """Return a single-test suite whose test is skipped with ``str(exception)``.

    Used when importing a test module or package raised ``case.SkipTest``.
    """
    # Deliberately no docstring on testSkipped: it would become the
    # test's description.
    @case.skip(str(exception))
    def testSkipped(self):
        pass
    attrs = {methodname: testSkipped}
    TestClass = type("ModuleSkipped", (case.TestCase,), attrs)
    return suiteClass((TestClass(methodname),))

def _splitext(path):
    """Return *path* with its final extension removed."""
    return os.path.splitext(path)[0]


class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attribute names starting with this prefix are collected as test methods.
    testMethodPrefix = 'test'
    # cmp-style function ordering test method names; falsy keeps dir() order.
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # None, or fnmatchcase patterns matched against 'module.Class.method'.
    testNamePatterns = None
    # Callable used to build every suite this loader returns.
    suiteClass = suite.TestSuite
    # Set by discover() for the duration of one discovery run.
    _top_level_dir = None

    def __init__(self):
        super(TestLoader, self).__init__()
        # Error messages for failures that were turned into failing tests.
        self.errors = []
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
        if testCaseClass in (case.TestCase, case.FunctionTestCase):
            # We don't load any tests from base types that should not be loaded.
            testCaseNames = []
        else:
            testCaseNames = self.getTestCaseNames(testCaseClass)
            if not testCaseNames and hasattr(testCaseClass, 'runTest'):
                testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, *, pattern=None):
        """Return a suite of all test cases contained in the given module

        If the module defines ``load_tests``, it is called as
        ``load_tests(self, tests, pattern)`` and its result is returned
        instead; an exception from it is recorded in ``self.errors`` and
        returned as a failing test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if (
                isinstance(obj, type)
                and issubclass(obj, case.TestCase)
                and obj not in (case.TestCase, case.FunctionTestCase)
            ):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Import and attribute-access failures are appended to self.errors and
        returned as a suite containing a failing test. Raises TypeError when
        the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        error_case, error_message = None, None
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
                    break
                except ImportError:
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
                    if not parts_copy:
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
            # __import__('a.b') returns the top-level package 'a', so the
            # remaining parts are resolved by attribute access below.
            parts = parts[1:]
        obj = module
        for part in parts:
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # This is a package (it has a __path__, per the importlib
                    # docs), and we encountered an error importing something.
                    # We cannot tell the difference between
                    # package.WrongNameTestClass and package.wrong_module_name
                    # so we just report the ImportError - it is more
                    # informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
                        part, e, self.suiteClass,
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif (
            isinstance(obj, type)
            and issubclass(obj, case.TestCase)
            and obj not in (case.TestCase, case.FunctionTestCase)
        ):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return the names of the test methods found within testCaseClass.

        The names are sorted with sortTestMethodsUsing when it is set, and
        filtered by testNamePatterns when that is not None.
        """
        def shouldIncludeMethod(attrname):
            # Keep callables whose name has the test prefix and whose dotted
            # full name matches one of testNamePatterns (if any are set).
            if not attrname.startswith(self.testMethodPrefix):
                return False
            testFunc = getattr(testCaseClass, attrname)
            if not callable(testFunc):
                return False
            fullName = '%s.%s.%s' % (
                testCaseClass.__module__, testCaseClass.__qualname__, attrname
            )
            return self.testNamePatterns is None or \
                any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns)
        testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with (loader, tests, pattern) unless
        the package has already had load_tests called from the same discovery
        invocation, in which case the package module object is not scanned for
        tests - this ensures that when a package uses discover to further
        discover child tests that infinite recursion does not happen.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().
        It is stored only for the duration of this call and restored
        afterwards, whether discovery succeeds or raises.

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.
        """
        original_top_level_dir = self._top_level_dir
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)

        self._top_level_dir = top_level_dir
        # Restore the caller's top-level directory on every exit path, so an
        # exception (e.g. a non-importable start directory) does not leave
        # this call's directory behind for a later discover() that omits
        # top_level_dir.
        try:
            is_not_importable = False
            if os.path.isdir(os.path.abspath(start_dir)):
                start_dir = os.path.abspath(start_dir)
                if start_dir != top_level_dir:
                    is_not_importable = not os.path.isfile(
                        os.path.join(start_dir, '__init__.py'))
            else:
                # support for discovery from dotted module names
                try:
                    __import__(start_dir)
                except ImportError:
                    is_not_importable = True
                else:
                    the_module = sys.modules[start_dir]
                    top_part = start_dir.split('.')[0]
                    try:
                        start_dir = os.path.abspath(
                            os.path.dirname((the_module.__file__)))
                    except AttributeError:
                        if the_module.__name__ in sys.builtin_module_names:
                            # builtin module
                            raise TypeError('Can not use builtin modules '
                                            'as dotted module names') from None
                        else:
                            raise TypeError(
                                f"don't know how to discover from {the_module!r}"
                            ) from None

                    if set_implicit_top:
                        # The implicit top (abspath of the dotted name) is not
                        # a real directory; use the one containing the
                        # top-level package instead.
                        self._top_level_dir = self._get_directory_containing_module(top_part)
                        sys.path.remove(top_level_dir)

            if is_not_importable:
                raise ImportError('Start directory is not importable: %r' % start_dir)

            tests = list(self._find_tests(start_dir, pattern))
        finally:
            self._top_level_dir = original_top_level_dir
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory that should serve as the import root for
        the already-imported module *module_name*.

        For a package (its file is an __init__.py*), that is the parent of
        the package directory; for a plain module, its own directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file or directory path under the top-level directory
        into a dotted module name; return '.' for the top-level directory.
        """
        if path == self._top_level_dir:
            return '.'
        path = _splitext(os.path.normpath(path))

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import *name* and return the module object from sys.modules."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        # Handle the __init__ in this package
        name = self._get_name_from_path(start_dir)
        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
        # definition not a package).
        if name != '.' and name not in self._loading_packages:
            # name is in self._loading_packages while we have called into
            # loadTestsFromModule with name.
            tests, should_recurse = self._find_test_path(start_dir, pattern)
            if tests is not None:
                yield tests
            if not should_recurse:
                # Either an error occurred, or load_tests was used by the
                # package.
                return
        # Handle the contents.
        paths = sorted(os.listdir(start_dir))
        for path in paths:
            full_path = os.path.join(start_dir, path)
            tests, should_recurse = self._find_test_path(full_path, pattern)
            if tests is not None:
                yield tests
            if should_recurse:
                # we found a package that didn't use load_tests.
                name = self._get_name_from_path(full_path)
                self._loading_packages.add(name)
                try:
                    yield from self._find_tests(full_path, pattern)
                finally:
                    self._loading_packages.discard(name)

    def _find_test_path(self, full_path, pattern):
        """Used by discovery.

        Loads tests from a single file, or a directory's __init__.py when
        passed the directory.

        Returns a tuple (None_or_tests_from_file, should_recurse).
        """
        basename = os.path.basename(full_path)
        if os.path.isfile(full_path):
            if not VALID_MODULE_NAME.match(basename):
                # valid Python identifiers only
                return None, False
            if not self._match_path(basename, full_path, pattern):
                return None, False
            # if the test file matches, load it
            name = self._get_name_from_path(full_path)
            try:
                module = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Any failure while importing becomes a failing test.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                # Guard against a same-named module being imported from
                # somewhere other than the file we found.
                mod_file = os.path.abspath(
                    getattr(module, '__file__', full_path))
                realpath = _splitext(
                    os.path.realpath(mod_file))
                fullpath_noext = _splitext(
                    os.path.realpath(full_path))
                if realpath.lower() != fullpath_noext.lower():
                    module_dir = os.path.dirname(realpath)
                    mod_name = _splitext(
                        os.path.basename(full_path))
                    expected_dir = os.path.dirname(full_path)
                    msg = ("%r module incorrectly imported from %r. Expected "
                           "%r. Is this module globally installed?")
                    raise ImportError(
                        msg % (mod_name, module_dir, expected_dir))
                return self.loadTestsFromModule(module, pattern=pattern), False
        elif os.path.isdir(full_path):
            if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                return None, False

            load_tests = None
            tests = None
            name = self._get_name_from_path(full_path)
            try:
                package = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Any failure while importing becomes a failing test.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                load_tests = getattr(package, 'load_tests', None)
                # Mark this package as being in load_tests (possibly ;))
                self._loading_packages.add(name)
                try:
                    tests = self.loadTestsFromModule(package, pattern=pattern)
                    if load_tests is not None:
                        # loadTestsFromModule(package) has loaded tests for us.
                        return tests, False
                    return tests, True
                finally:
                    self._loading_packages.discard(name)
        else:
            return None, False


defaultTestLoader = TestLoader()