"""Loading unittests."""

import os
import re
import sys
import traceback
import types

from functools import cmp_to_key as _CmpToKey
from fnmatch import fnmatch

from . import case, suite

__unittest = True

# what about .pyc or .pyo (etc)
# we would need to avoid loading the same tests multiple times
# from '.py', '.pyc' *and* '.pyo'
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)


def _make_failed_import_test(name, suiteClass):
    """Return a suite with one test that fails with the current import error.

    Must be called while an exception is being handled: the failure message
    embeds ``traceback.format_exc()``.
    """
    message = 'Failed to import test module: %s\n%s' % (name, traceback.format_exc())
    return _make_failed_test('ModuleImportFailure', name, ImportError(message),
                             suiteClass)


def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite with one test that re-raises *exception* from load_tests."""
    return _make_failed_test('LoadTestsFailure', name, exception, suiteClass)


def _make_failed_test(classname, methodname, exception, suiteClass):
    """Build a one-test suite whose only test raises *exception*.

    A throwaway ``case.TestCase`` subclass called *classname* is created with
    a single method called *methodname*, so the failure is reported under
    those names.
    """
    def testFailure(self):
        raise exception
    attrs = {methodname: testFailure}
    TestClass = type(classname, (case.TestCase,), attrs)
    return suiteClass((TestClass(methodname),))


def _paths_match(first, second):
    """Return True if two extension-less paths refer to the same location.

    The paths are compared case-insensitively, first literally and then with
    symlinks resolved, so that discovering through a symlinked directory is
    not reported as a module imported from the wrong place.
    """
    if first.lower() == second.lower():
        return True
    return (os.path.realpath(first).lower() ==
            os.path.realpath(second).lower())


class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = cmp
    suiteClass = suite.TestSuite
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass"""
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            # a TestCase with no prefixed methods but a runTest() is one test
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module

        If *use_load_tests* is true and the module defines ``load_tests``, it
        is called as ``load_tests(self, tests, None)`` and its result is
        returned instead. An exception raised by it is turned into a failing
        test rather than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.
        """
        parts = name.split('.')
        if module is None:
            # import the longest importable dotted prefix of the name
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so every component
            # after the first is resolved by attribute lookup below
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            # Use the attribute name that was looked up rather than
            # obj.__name__, so an aliased method (test_b = test_a) runs
            # under the name the caller asked for.
            return self.suiteClass([parent(parts[-1])])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = filter(isTestMethod, dir(testCaseClass))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # remember whether the entry below is ours, so it is only ever
        # removed again if this call added it
        added_to_path = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
            added_to_path = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
                if set_implicit_top:
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    # the implicit top derived from a dotted name is not a
                    # real directory; undo our insertion, but never remove a
                    # sys.path entry the caller already had
                    if added_to_path:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory to treat as top level for *module_name*.

        For a package (``__file__`` is its ``__init__`` file) this is the
        package directory's parent; for a plain module it is the module's
        own directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file path under the top level directory to a dotted name."""
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module *name* and return the module object."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        paths = os.listdir(start_dir)

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # deliberately broad: any failure while importing a test
                    # module is reported as a failing test, not a crash
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
                    mod_noext = os.path.splitext(mod_file)[0]
                    fullpath_noext = os.path.splitext(full_path)[0]
                    if not _paths_match(mod_noext, fullpath_noext):
                        module_dir = os.path.dirname(mod_noext)
                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # same policy as for plain test modules: report the
                        # broken package as a failing test and skip its
                        # contents, which could not be imported either
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)

defaultTestLoader = TestLoader()


def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a fresh TestLoader configured with the given options."""
    loader = TestLoader()
    loader.sortTestMethodsUsing = sortUsing
    loader.testMethodPrefix = prefix
    if suiteClass:
        loader.suiteClass = suiteClass
    return loader


def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the sorted test method names of *testCaseClass*."""
    return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)


def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in *testCaseClass*."""
    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)


def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of all TestCase tests found in *module*."""
    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)