"""Loading unittests."""

import os
import re
import sys
import traceback
import types

from functools import cmp_to_key as _CmpToKey
from fnmatch import fnmatch

from . import case, suite

# NOTE(review): presumably a marker the result/traceback machinery looks for
# in a frame's globals to hide this module's frames -- confirm in result.py.
__unittest = True

# Only plain '.py' source files whose names are valid Python identifiers are
# candidates for discovery. What about .pyc or .pyo (etc)? Accepting them as
# well would load the same tests multiple times from '.py', '.pyc' *and* '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)


def _make_failed_import_test(name, suiteClass):
    """Return a suite with one test that reports a failed module import.

    Must be called while an exception is being handled: the current
    traceback (``traceback.format_exc()``) is embedded in the message of
    the ImportError the generated test raises.
    """
    message = 'Failed to import test module: %s\n%s' % (name, traceback.format_exc())
    return _make_failed_test('ModuleImportFailure', name, ImportError(message),
                             suiteClass)


def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite with one test that re-raises a ``load_tests`` error."""
    return _make_failed_test('LoadTestsFailure', name, exception, suiteClass)


def _make_failed_test(classname, methodname, exception, suiteClass):
    """Build a suite holding a single synthetic test that raises `exception`.

    A throw-away TestCase subclass named `classname` is created with one
    method named `methodname`, so the failure shows up in the test report
    under a recognisable name instead of aborting the run.
    """
    def testFailure(self):
        raise exception
    attrs = {methodname: testFailure}
    TestClass = type(classname, (case.TestCase,), attrs)
    return suiteClass((TestClass(methodname),))


class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attribute-name prefix that marks a TestCase method as a test.
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names (wrapped with
    # cmp_to_key in getTestCaseNames); a false value disables sorting.
    sortTestMethodsUsing = cmp
    # Callable used to wrap collected tests into a suite.
    suiteClass = suite.TestSuite
    # Set by discover(); reused when discover() is called again (e.g. from a
    # package's load_tests) without an explicit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        Falls back to a single 'runTest' test when the class defines no
        prefixed test methods but does define runTest.

        Raises:
            TypeError: if testCaseClass is a TestSuite subclass.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite." \
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        If `use_load_tests` is true and the module defines ``load_tests``,
        its result (called as ``load_tests(loader, tests, None)``) is
        returned instead; an exception raised by it is turned into a
        'LoadTestsFailure' test rather than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises:
            ImportError: if no leading prefix of `name` can be imported.
            AttributeError: if a trailing component does not exist.
            TypeError: if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of `name`.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__('a.b') returns the top-level package 'a', so the
            # remaining components are resolved below via getattr.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies when it starts with ``testMethodPrefix`` and the
        attribute is callable. Sorting uses ``sortTestMethodsUsing`` unless
        that is false, in which case dir() order is kept.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        `start_dir` may also be a dotted module name; discovery then starts
        from the directory containing that module.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises:
            ImportError: if the start directory is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Remember whether *this call* put top_level_dir on sys.path, so the
        # dotted-name branch below only undoes its own change and never
        # removes an entry the caller had already configured.
        inserted_top_level_dir = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
            inserted_top_level_dir = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
                if set_implicit_top:
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    # The implicit top level was derived from the dotted name,
                    # which is not an existing directory; drop the sys.path
                    # entry added for it (only if this call added it).
                    if inserted_top_level_dir and top_level_dir in sys.path:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which `module_name` (already imported)
        is importable: the parent of a package directory, or the directory
        holding a plain module."""
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        # startswith (not ==) so '__init__.pyc' / '__init__.pyo' also match.
        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file/directory path under the top level directory into a
        dotted module name (extension stripped).

        NOTE(review): the containment checks are asserts, so they are skipped
        under ``python -O``.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module `name` and return the module itself
        (not the top-level package that __import__ returns)."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files are loaded when their name is a valid module name and matches
        `pattern`. Directories are only considered when they contain an
        '__init__.py'. A package whose own name matches `pattern` is imported
        and, if it defines ``load_tests``, that function takes over loading
        for the whole package; otherwise discovery recurses into it.

        Raises:
            ImportError: if a matching module was imported from a location
                other than the file that was discovered.
        """
        # os.listdir returns entries in arbitrary order; sort so the order in
        # which tests are discovered is deterministic across platforms.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # Deliberately broad: any import-time failure is reported
                    # as a failed test instead of aborting discovery.
                    # NOTE(review): this also captures KeyboardInterrupt.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    # Guard against a same-named module elsewhere on sys.path
                    # (e.g. an installed copy) shadowing the discovered file.
                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
                    realpath = os.path.splitext(os.path.realpath(mod_file))[0]
                    fullpath_noext = os.path.splitext(os.path.realpath(full_path))[0]
                    if realpath.lower() != fullpath_noext.lower():
                        module_dir = os.path.dirname(realpath)
                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Report a broken package like a broken test module
                        # (see above) rather than aborting the whole run; its
                        # submodules could not be imported either, so do not
                        # recurse into it.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)

defaultTestLoader = TestLoader()


def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given method prefix,
    sort function and (if truthy) suite class."""
    loader = TestLoader()
    loader.sortTestMethodsUsing = sortUsing
    loader.testMethodPrefix = prefix
    if suiteClass:
        loader.suiteClass = suiteClass
    return loader

def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass (see
    TestLoader.getTestCaseNames)."""
    return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)

def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in testCaseClass (see
    TestLoader.loadTestsFromTestCase)."""
    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)

def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests in module (see
    TestLoader.loadTestsFromModule)."""
    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)