1"""Loading unittests.""" 2 3import os 4import re 5import sys 6import traceback 7import types 8import unittest 9 10from fnmatch import fnmatch 11 12from unittest2 import case, suite 13 14try: 15 from os.path import relpath 16except ImportError: 17 from unittest2.compatibility import relpath 18 19__unittest = True 20 21 22def _CmpToKey(mycmp): 23 'Convert a cmp= function into a key= function' 24 class K(object): 25 def __init__(self, obj): 26 self.obj = obj 27 def __lt__(self, other): 28 return mycmp(self.obj, other.obj) == -1 29 return K 30 31 32# what about .pyc or .pyo (etc) 33# we would need to avoid loading the same tests multiple times 34# from '.py', '.pyc' *and* '.pyo' 35VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE) 36 37 38def _make_failed_import_test(name, suiteClass): 39 message = 'Failed to import test module: %s' % name 40 if hasattr(traceback, 'format_exc'): 41 # Python 2.3 compatibility 42 # format_exc returns two frames of discover.py as well 43 message += '\n%s' % traceback.format_exc() 44 return _make_failed_test('ModuleImportFailure', name, ImportError(message), 45 suiteClass) 46 47def _make_failed_load_tests(name, exception, suiteClass): 48 return _make_failed_test('LoadTestsFailure', name, exception, suiteClass) 49 50def _make_failed_test(classname, methodname, exception, suiteClass): 51 def testFailure(self): 52 raise exception 53 attrs = {methodname: testFailure} 54 TestClass = type(classname, (case.TestCase,), attrs) 55 return suiteClass((TestClass(methodname),)) 56 57 58class TestLoader(unittest.TestLoader): 59 """ 60 This class is responsible for loading tests according to various criteria 61 and returning them wrapped in a TestSuite 62 """ 63 testMethodPrefix = 'test' 64 sortTestMethodsUsing = cmp 65 suiteClass = suite.TestSuite 66 _top_level_dir = None 67 68 def loadTestsFromTestCase(self, testCaseClass): 69 """Return a suite of all tests cases contained in testCaseClass""" 70 if issubclass(testCaseClass, suite.TestSuite): 71 raise TypeError("Test cases should not be derived from TestSuite." 72 " Maybe you meant to derive from TestCase?") 73 testCaseNames = self.getTestCaseNames(testCaseClass) 74 if not testCaseNames and hasattr(testCaseClass, 'runTest'): 75 testCaseNames = ['runTest'] 76 loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames)) 77 return loaded_suite 78 79 def loadTestsFromModule(self, module, use_load_tests=True): 80 """Return a suite of all tests cases contained in the given module""" 81 tests = [] 82 for name in dir(module): 83 obj = getattr(module, name) 84 if isinstance(obj, type) and issubclass(obj, unittest.TestCase): 85 tests.append(self.loadTestsFromTestCase(obj)) 86 87 load_tests = getattr(module, 'load_tests', None) 88 tests = self.suiteClass(tests) 89 if use_load_tests and load_tests is not None: 90 try: 91 return load_tests(self, tests, None) 92 except Exception, e: 93 return _make_failed_load_tests(module.__name__, e, 94 self.suiteClass) 95 return tests 96 97 def loadTestsFromName(self, name, module=None): 98 """Return a suite of all tests cases given a string specifier. 99 100 The name may resolve either to a module, a test case class, a 101 test method within a test case class, or a callable object which 102 returns a TestCase or TestSuite instance. 103 104 The method optionally resolves the names relative to a given module. 105 """ 106 parts = name.split('.') 107 if module is None: 108 parts_copy = parts[:] 109 while parts_copy: 110 try: 111 module = __import__('.'.join(parts_copy)) 112 break 113 except ImportError: 114 del parts_copy[-1] 115 if not parts_copy: 116 raise 117 parts = parts[1:] 118 obj = module 119 for part in parts: 120 parent, obj = obj, getattr(obj, part) 121 122 if isinstance(obj, types.ModuleType): 123 return self.loadTestsFromModule(obj) 124 elif isinstance(obj, type) and issubclass(obj, unittest.TestCase): 125 return self.loadTestsFromTestCase(obj) 126 elif (isinstance(obj, types.UnboundMethodType) and 127 isinstance(parent, type) and 128 issubclass(parent, case.TestCase)): 129 return self.suiteClass([parent(obj.__name__)]) 130 elif isinstance(obj, unittest.TestSuite): 131 return obj 132 elif hasattr(obj, '__call__'): 133 test = obj() 134 if isinstance(test, unittest.TestSuite): 135 return test 136 elif isinstance(test, unittest.TestCase): 137 return self.suiteClass([test]) 138 else: 139 raise TypeError("calling %s returned %s, not a test" % 140 (obj, test)) 141 else: 142 raise TypeError("don't know how to make test from: %s" % obj) 143 144 def loadTestsFromNames(self, names, module=None): 145 """Return a suite of all tests cases found using the given sequence 146 of string specifiers. See 'loadTestsFromName()'. 147 """ 148 suites = [self.loadTestsFromName(name, module) for name in names] 149 return self.suiteClass(suites) 150 151 def getTestCaseNames(self, testCaseClass): 152 """Return a sorted sequence of method names found within testCaseClass 153 """ 154 def isTestMethod(attrname, testCaseClass=testCaseClass, 155 prefix=self.testMethodPrefix): 156 return attrname.startswith(prefix) and \ 157 hasattr(getattr(testCaseClass, attrname), '__call__') 158 testFnNames = filter(isTestMethod, dir(testCaseClass)) 159 if self.sortTestMethodsUsing: 160 testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing)) 161 return testFnNames 162 163 def discover(self, start_dir, pattern='test*.py', top_level_dir=None): 164 """Find and return all test modules from the specified start 165 directory, recursing into subdirectories to find them. Only test files 166 that match the pattern will be loaded. (Using shell style pattern 167 matching.) 168 169 All test modules must be importable from the top level of the project. 170 If the start directory is not the top level directory then the top 171 level directory must be specified separately. 172 173 If a test package name (directory with '__init__.py') matches the 174 pattern then the package will be checked for a 'load_tests' function. If 175 this exists then it will be called with loader, tests, pattern. 176 177 If load_tests exists then discovery does *not* recurse into the package, 178 load_tests is responsible for loading all tests in the package. 179 180 The pattern is deliberately not stored as a loader attribute so that 181 packages can continue discovery themselves. top_level_dir is stored so 182 load_tests does not need to pass this argument in to loader.discover(). 183 """ 184 set_implicit_top = False 185 if top_level_dir is None and self._top_level_dir is not None: 186 # make top_level_dir optional if called from load_tests in a package 187 top_level_dir = self._top_level_dir 188 elif top_level_dir is None: 189 set_implicit_top = True 190 top_level_dir = start_dir 191 192 top_level_dir = os.path.abspath(top_level_dir) 193 194 if not top_level_dir in sys.path: 195 # all test modules must be importable from the top level directory 196 # should we *unconditionally* put the start directory in first 197 # in sys.path to minimise likelihood of conflicts between installed 198 # modules and development versions? 199 sys.path.insert(0, top_level_dir) 200 self._top_level_dir = top_level_dir 201 202 is_not_importable = False 203 if os.path.isdir(os.path.abspath(start_dir)): 204 start_dir = os.path.abspath(start_dir) 205 if start_dir != top_level_dir: 206 is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py')) 207 else: 208 # support for discovery from dotted module names 209 try: 210 __import__(start_dir) 211 except ImportError: 212 is_not_importable = True 213 else: 214 the_module = sys.modules[start_dir] 215 top_part = start_dir.split('.')[0] 216 start_dir = os.path.abspath(os.path.dirname((the_module.__file__))) 217 if set_implicit_top: 218 self._top_level_dir = os.path.abspath(os.path.dirname(os.path.dirname(sys.modules[top_part].__file__))) 219 sys.path.remove(top_level_dir) 220 221 if is_not_importable: 222 raise ImportError('Start directory is not importable: %r' % start_dir) 223 224 tests = list(self._find_tests(start_dir, pattern)) 225 return self.suiteClass(tests) 226 227 def _get_name_from_path(self, path): 228 path = os.path.splitext(os.path.normpath(path))[0] 229 230 _relpath = relpath(path, self._top_level_dir) 231 assert not os.path.isabs(_relpath), "Path must be within the project" 232 assert not _relpath.startswith('..'), "Path must be within the project" 233 234 name = _relpath.replace(os.path.sep, '.') 235 return name 236 237 def _get_module_from_name(self, name): 238 __import__(name) 239 return sys.modules[name] 240 241 def _match_path(self, path, full_path, pattern): 242 # override this method to use alternative matching strategy 243 return fnmatch(path, pattern) 244 245 def _find_tests(self, start_dir, pattern): 246 """Used by discovery. Yields test suites it loads.""" 247 paths = os.listdir(start_dir) 248 249 for path in paths: 250 full_path = os.path.join(start_dir, path) 251 if os.path.isfile(full_path): 252 if not VALID_MODULE_NAME.match(path): 253 # valid Python identifiers only 254 continue 255 if not self._match_path(path, full_path, pattern): 256 continue 257 # if the test file matches, load it 258 name = self._get_name_from_path(full_path) 259 try: 260 module = self._get_module_from_name(name) 261 except: 262 yield _make_failed_import_test(name, self.suiteClass) 263 else: 264 mod_file = os.path.abspath(getattr(module, '__file__', full_path)) 265 realpath = os.path.splitext(mod_file)[0] 266 fullpath_noext = os.path.splitext(full_path)[0] 267 if realpath.lower() != fullpath_noext.lower(): 268 module_dir = os.path.dirname(realpath) 269 mod_name = os.path.splitext(os.path.basename(full_path))[0] 270 expected_dir = os.path.dirname(full_path) 271 msg = ("%r module incorrectly imported from %r. Expected %r. " 272 "Is this module globally installed?") 273 raise ImportError(msg % (mod_name, module_dir, expected_dir)) 274 yield self.loadTestsFromModule(module) 275 elif os.path.isdir(full_path): 276 if not os.path.isfile(os.path.join(full_path, '__init__.py')): 277 continue 278 279 load_tests = None 280 tests = None 281 if fnmatch(path, pattern): 282 # only check load_tests if the package directory itself matches the filter 283 name = self._get_name_from_path(full_path) 284 package = self._get_module_from_name(name) 285 load_tests = getattr(package, 'load_tests', None) 286 tests = self.loadTestsFromModule(package, use_load_tests=False) 287 288 if load_tests is None: 289 if tests is not None: 290 # tests loaded from package file 291 yield tests 292 # recurse into the package 293 for test in self._find_tests(full_path, pattern): 294 yield test 295 else: 296 try: 297 yield load_tests(self, tests, pattern) 298 except Exception, e: 299 yield _make_failed_load_tests(package.__name__, e, 300 self.suiteClass) 301 302defaultTestLoader = TestLoader() 303 304 305def _makeLoader(prefix, sortUsing, suiteClass=None): 306 loader = TestLoader() 307 loader.sortTestMethodsUsing = sortUsing 308 loader.testMethodPrefix = prefix 309 if suiteClass: 310 loader.suiteClass = suiteClass 311 return loader 312 313def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp): 314 return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass) 315 316def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, 317 suiteClass=suite.TestSuite): 318 return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass) 319 320def findTestCases(module, prefix='test', sortUsing=cmp, 321 suiteClass=suite.TestSuite): 322 return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module) 323