"""Loading unittests."""

import os
import re
import sys
import traceback
import types
import functools
import warnings

from fnmatch import fnmatch

from . import case, suite, util

__unittest = True

# what about .pyc (etc)
# we would need to avoid loading the same tests multiple times
# from '.py', *and* '.pyc'
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)


class _FailedTest(case.TestCase):
    """A placeholder test whose only test method re-raises a stored exception.

    The loader uses it to report import and load failures as failing tests
    in the returned suite.
    """
    _testMethodName = None

    def __init__(self, method_name, exception):
        self._exception = exception
        super(_FailedTest, self).__init__(method_name)

    def __getattr__(self, name):
        # Only reached for attributes not found by normal lookup: synthesize
        # the test method on demand and defer everything else to the base.
        if name != self._testMethodName:
            return super(_FailedTest, self).__getattr__(name)
        def testFailure():
            # NOTE: no docstring here on purpose; unittest uses test method
            # docstrings as test descriptions.
            raise self._exception
        return testFailure


def _make_failed_import_test(name, suiteClass):
    """Return (suite, message) describing a module that failed to import.

    The traceback of the exception currently being handled is embedded in
    the message, so call this from inside an ``except`` block.
    """
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())
    return _make_failed_test(name, ImportError(message), suiteClass, message)


def _make_failed_load_tests(name, exception, suiteClass):
    """Return (suite, message) describing a ``load_tests`` hook that raised.

    Like _make_failed_import_test, the message embeds the traceback of the
    exception currently being handled.
    """
    message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),)
    return _make_failed_test(
        name, exception, suiteClass, message)


def _make_failed_test(methodname, exception, suiteClass, message):
    """Wrap *exception* in a one-test suite; return (suite, message)."""
    test = _FailedTest(methodname, exception)
    return suiteClass((test,)), message


def _make_skipped_test(methodname, exception, suiteClass):
    """Return a one-test suite whose test is skipped with str(exception)."""
    @case.skip(str(exception))
    def testSkipped(self):
        pass
    attrs = {methodname: testSkipped}
    TestClass = type("ModuleSkipped", (case.TestCase,), attrs)
    return suiteClass((TestClass(methodname),))


def _jython_aware_splitext(path):
    """Strip the extension from *path*, treating '$py.class' as one suffix."""
    if path.lower().endswith('$py.class'):
        return path[:-9]
    return os.path.splitext(path)[0]


class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    suiteClass = suite.TestSuite
    _top_level_dir = None

    def __init__(self):
        super(TestLoader, self).__init__()
        # Messages for tests that could not be imported/loaded; each one is
        # also returned to the caller as a failing test in the suite.
        self.errors = []
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass"""
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    # XXX After Python 3.5, remove backward compatibility hacks for
    # use_load_tests deprecation via *args and **kws.  See issue 16662.
    def loadTestsFromModule(self, module, *args, pattern=None, **kws):
        """Return a suite of all test cases contained in the given module.

        If the module defines ``load_tests``, it is called as
        ``load_tests(loader, tests, pattern)`` with the suite built from the
        module's TestCase classes and its result is returned instead.  If
        that call raises, the error is recorded in ``self.errors`` and a
        suite containing one failing test is returned.
        """
        # This method used to take an undocumented and unofficial
        # use_load_tests argument.  For backward compatibility, we still
        # accept the argument (which can also be the first position) but we
        # ignore it and issue a deprecation warning if it's present.
        if len(args) > 0 or 'use_load_tests' in kws:
            warnings.warn('use_load_tests is deprecated and ignored',
                          DeprecationWarning)
            kws.pop('use_load_tests', None)
        if len(args) > 1:
            # Complain about the number of arguments, but don't forget the
            # required `module` argument.
            complaint = len(args) + 1
            raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
        if len(kws) != 0:
            # Since the keyword arguments are unsorted (see PEP 468), just
            # pick the alphabetically sorted first argument to complain about,
            # if multiple were given.  At least the error message will be
            # predictable.
            complaint = sorted(kws)[0]
            raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.
        """
        parts = name.split('.')
        error_case, error_message = None, None
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
                    break
                except ImportError:
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
                    if not parts_copy:
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
            # __import__ returned the top-level package, so walk the rest.
            parts = parts[1:]
        obj = module
        for part in parts:
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # obj is a package (per the importlib docs, only packages
                    # have a __path__) and importing a submodule under this
                    # name failed earlier.  We cannot tell the difference
                    # between package.WrongNameTestClass and
                    # package.wrong_module_name so we just report the
                    # ImportError - it is more informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
                        part, e, self.suiteClass,
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                callable(getattr(testCaseClass, attrname))
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        start_dir may also be a dotted module name.  If it names a namespace
        package, each of its search locations (restricted to those under
        top_level_dir when one was given) is searched.

        Every test package (directory with '__init__.py') found during the
        walk is imported and checked for a 'load_tests' function, whether or
        not its name matches the pattern.  If this exists then it will be
        called with (loader, tests, pattern) unless the package has already
        had load_tests called from the same discovery invocation, in which
        case the package module object is not scanned for tests - this
        ensures that when a package uses discover to further discover child
        tests that infinite recursion does not happen.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover();
        it is restored to its previous value when this call returns, so
        separate discover() calls on the same loader do not share it.

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.
        """
        # Nested discover() calls made from a package's load_tests inherit
        # the top level directory while this call runs; unrelated later
        # calls on the same loader must not.
        original_top_level_dir = self._top_level_dir
        try:
            set_implicit_top = False
            if top_level_dir is None and self._top_level_dir is not None:
                # make top_level_dir optional if called from load_tests in a package
                top_level_dir = self._top_level_dir
            elif top_level_dir is None:
                set_implicit_top = True
                top_level_dir = start_dir

            top_level_dir = os.path.abspath(top_level_dir)

            # Remember whether the entry is ours, so we never remove a
            # sys.path entry the caller already had.
            inserted_top_level_dir = top_level_dir not in sys.path
            if inserted_top_level_dir:
                # all test modules must be importable from the top level directory
                # should we *unconditionally* put the start directory in first
                # in sys.path to minimise likelihood of conflicts between installed
                # modules and development versions?
                sys.path.insert(0, top_level_dir)
            self._top_level_dir = top_level_dir

            is_not_importable = False
            is_namespace = False
            tests = []
            if os.path.isdir(os.path.abspath(start_dir)):
                start_dir = os.path.abspath(start_dir)
                if start_dir != top_level_dir:
                    is_not_importable = not os.path.isfile(
                        os.path.join(start_dir, '__init__.py'))
            else:
                # support for discovery from dotted module names
                try:
                    __import__(start_dir)
                except ImportError:
                    is_not_importable = True
                else:
                    the_module = sys.modules[start_dir]
                    top_part = start_dir.split('.')[0]
                    # __file__ may be missing (built-in modules, namespace
                    # packages before Python 3.7) or None (namespace packages
                    # from Python 3.7 on).
                    module_file = getattr(the_module, '__file__', None)
                    if module_file is not None:
                        start_dir = os.path.abspath(
                            os.path.dirname(module_file))
                    else:
                        spec = getattr(the_module, '__spec__', None)
                        # A namespace package has search locations but no
                        # concrete loader/origin: spec.loader is None before
                        # Python 3.7, and spec.origin is None from 3.7 on
                        # (where the loader is a namespace loader instead).
                        if (spec is not None
                                and spec.submodule_search_locations is not None
                                and (spec.loader is None
                                     or spec.origin is None)):
                            is_namespace = True

                            for path in the_module.__path__:
                                if (not set_implicit_top and
                                        not path.startswith(top_level_dir)):
                                    continue
                                self._top_level_dir = \
                                    (path.split(the_module.__name__
                                                .replace(".", os.path.sep))[0])
                                tests.extend(self._find_tests(path,
                                                              pattern,
                                                              namespace=True))
                        elif the_module.__name__ in sys.builtin_module_names:
                            # builtin module
                            raise TypeError('Can not use builtin modules '
                                            'as dotted module names')
                        else:
                            raise TypeError(
                                'don\'t know how to discover from {!r}'
                                .format(the_module))

                    if set_implicit_top:
                        if not is_namespace:
                            self._top_level_dir = \
                                self._get_directory_containing_module(top_part)
                        if inserted_top_level_dir:
                            sys.path.remove(top_level_dir)

            if is_not_importable:
                raise ImportError('Start directory is not importable: %r' % start_dir)

            if not is_namespace:
                tests = list(self._find_tests(start_dir, pattern))
            return self.suiteClass(tests)
        finally:
            self._top_level_dir = original_top_level_dir

    def _get_directory_containing_module(self, module_name):
        """Return the directory an already-imported module was imported from.

        For a package this is the parent of the package directory; for a
        plain module it is the module's own directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a path under the top level directory to a dotted name.

        Returns '.' for the top level directory itself.
        """
        if path == self._top_level_dir:
            return '.'
        path = _jython_aware_splitext(os.path.normpath(path))

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted *name* and return that module object."""
        # __import__ returns the top-level package, so fetch the leaf module
        # from sys.modules.
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern, namespace=False):
        """Used by discovery. Yields test suites it loads."""
        # Handle the __init__ in this package
        name = self._get_name_from_path(start_dir)
        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
        # definition not a package).
        if name != '.' and name not in self._loading_packages:
            # name is in self._loading_packages while we have called into
            # loadTestsFromModule with name.
            tests, should_recurse = self._find_test_path(
                start_dir, pattern, namespace)
            if tests is not None:
                yield tests
            if not should_recurse:
                # Either an error occurred, or load_tests was used by the
                # package.
                return
        # Handle the contents.
        paths = sorted(os.listdir(start_dir))
        for path in paths:
            full_path = os.path.join(start_dir, path)
            tests, should_recurse = self._find_test_path(
                full_path, pattern, namespace)
            if tests is not None:
                yield tests
            if should_recurse:
                # we found a package that didn't use load_tests.
                name = self._get_name_from_path(full_path)
                self._loading_packages.add(name)
                try:
                    yield from self._find_tests(full_path, pattern, namespace)
                finally:
                    self._loading_packages.discard(name)

    def _find_test_path(self, full_path, pattern, namespace=False):
        """Used by discovery.

        Loads tests from a single file, or a directory's __init__.py when
        passed the directory.

        Returns a tuple (None_or_tests_from_file, should_recurse).
        """
        basename = os.path.basename(full_path)
        if os.path.isfile(full_path):
            if not VALID_MODULE_NAME.match(basename):
                # valid Python identifiers only
                return None, False
            if not self._match_path(basename, full_path, pattern):
                return None, False
            # if the test file matches, load it
            name = self._get_name_from_path(full_path)
            try:
                module = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Bare except: whatever the import raises is reported as a
                # failing test instead of aborting discovery.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                mod_file = os.path.abspath(
                    getattr(module, '__file__', full_path))
                realpath = _jython_aware_splitext(
                    os.path.realpath(mod_file))
                fullpath_noext = _jython_aware_splitext(
                    os.path.realpath(full_path))
                if realpath.lower() != fullpath_noext.lower():
                    # The name resolved to a module at another location
                    # (presumably an installed copy shadowing this file).
                    module_dir = os.path.dirname(realpath)
                    mod_name = _jython_aware_splitext(
                        os.path.basename(full_path))
                    expected_dir = os.path.dirname(full_path)
                    msg = ("%r module incorrectly imported from %r. Expected "
                           "%r. Is this module globally installed?")
                    raise ImportError(
                        msg % (mod_name, module_dir, expected_dir))
                return self.loadTestsFromModule(module, pattern=pattern), False
        elif os.path.isdir(full_path):
            if (not namespace and
                    not os.path.isfile(os.path.join(full_path, '__init__.py'))):
                return None, False

            load_tests = None
            tests = None
            name = self._get_name_from_path(full_path)
            try:
                package = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Bare except: see the file branch above.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                load_tests = getattr(package, 'load_tests', None)
                # Mark this package as being in load_tests (possibly ;))
                self._loading_packages.add(name)
                try:
                    tests = self.loadTestsFromModule(package, pattern=pattern)
                    if load_tests is not None:
                        # loadTestsFromModule(package) has loaded tests for us.
                        return tests, False
                    return tests, True
                finally:
                    self._loading_packages.discard(name)
        else:
            return None, False


defaultTestLoader = TestLoader()


def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given options."""
    loader = TestLoader()
    loader.sortTestMethodsUsing = sortUsing
    loader.testMethodPrefix = prefix
    if suiteClass:
        loader.suiteClass = suiteClass
    return loader


def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of testCaseClass starting with prefix.

    The names are sorted with sortUsing unless it is falsy.
    """
    return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)


def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in testCaseClass (see getTestCaseNames)."""
    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(
        testCaseClass)


def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in module via loadTestsFromModule."""
    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(
        module)