• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import os, sys, unittest, getopt, time
2
# Names of resources enabled on the command line (filled by main()'s -u
# flag); is_resource_enabled() also honours the wildcard "*".
use_resources = []

import ctypes
# Snapshot of every name the ctypes module exposes, consulted by need_symbol().
ctypes_symbols = dir(ctypes)

def need_symbol(name):
    """Return a decorator that skips a test unless ctypes exposes *name*.

    Availability is judged against the ``ctypes_symbols`` snapshot; the skip
    reason reads ``'<repr of name> is required'``.
    """
    present = name in ctypes_symbols
    reason = '{!r} is required'.format(name)
    return unittest.skipUnless(present, reason)
11
12
class ResourceDenied(unittest.SkipTest):
    """Test skipped because it requested a disallowed resource.

    This is raised when a test calls requires() for a resource that
    has not been enabled.  Resources are defined by test modules and are
    enabled through the ``-u`` command-line flag handled by main().
    Because it derives from unittest.SkipTest, the test (or test module,
    when raised at import time) is reported as skipped, not as an error.
    """
19
def is_resource_enabled(resource):
    """Return whether *resource* may be used by the calling code.

    Code whose module runs as __main__ is always allowed.  Otherwise the
    resource must be listed in ``use_resources`` (or that list must hold the
    wildcard "*"); a ``use_resources`` of None enables nothing.  Every
    refusal is remembered as a key of ``_unavail``.
    """
    caller_globals = sys._getframe(1).f_globals
    if caller_globals.get("__name__") == "__main__":
        return True
    if use_resources is None:
        enabled = False
    else:
        enabled = resource in use_resources or "*" in use_resources
    if not enabled:
        # Recorded so TestRunner.run can list the unavailable resources.
        _unavail[resource] = None
    return enabled

# Keys: resources that were requested but not enabled (values unused).
_unavail = {}
def requires(resource, msg=None):
    """Raise ResourceDenied if the specified resource is not available.

    If the caller's module is __main__, return immediately: the resource is
    treated as enabled.  Otherwise ask is_resource_enabled(); when it reports
    the resource as disabled, raise ResourceDenied with *msg*, or with a
    default message naming the resource when *msg* is None.

    Returns None whenever it does not raise.
    """
    # This check must happen here: is_resource_enabled() inspects *its*
    # caller, which would be this module rather than the test requesting
    # the resource.
    if sys._getframe().f_back.f_globals.get("__name__") == "__main__":
        return
    if not is_resource_enabled(resource):
        if msg is None:
            msg = "Use of the `%s' resource not enabled" % resource
        raise ResourceDenied(msg)
45
def find_package_modules(package, mask):
    """Yield dotted names of *package*'s modules whose file names match *mask*.

    *mask* is a case-sensitive fnmatch pattern such as "test_*.py".  When the
    package's __loader__ carries a ``_files`` mapping (presumably a
    zipimporter's index of archive members -- confirm), its keys are matched
    against the package path joined with *mask*.  Otherwise only the first
    directory of ``package.__path__`` is listed.
    """
    import fnmatch
    if not (hasattr(package, "__loader__") and
            hasattr(package.__loader__, '_files')):
        directory = package.__path__[0]
        for filename in os.listdir(directory):
            if fnmatch.fnmatchcase(filename, mask):
                stem = os.path.splitext(filename)[0]
                yield "%s.%s" % (package.__name__, stem)
        return
    prefix = package.__name__.replace(".", os.path.sep)
    pattern = os.path.join(prefix, mask)
    # Iterating the mapping yields its keys, the same as iterkeys().
    for filename in package.__loader__._files:
        if fnmatch.fnmatchcase(filename, pattern):
            yield os.path.splitext(filename)[0].replace(os.path.sep, ".")
60
def get_tests(package, mask, verbosity, exclude=()):
    """Collect the TestCase classes defined by *package*'s test modules.

    Every module name produced by find_package_modules(package, mask) is
    imported unless its last dotted component is listed in *exclude*.
    Excluded modules, and modules whose import raises unittest.SkipTest
    (ResourceDenied included), are recorded as skipped; with
    verbosity > 1 the reason is written to stderr.

    Returns a ``(skipped, tests)`` pair: the list of skipped module names and
    the list of TestCase subclasses found in the imported modules.  No
    de-duplication is done, so a class visible in several modules is
    collected once per module.
    """
    tests = []
    skipped = []
    for modname in find_package_modules(package, mask):
        if modname.split(".")[-1] in exclude:
            skipped.append(modname)
            if verbosity > 1:
                sys.stderr.write("Skipped %s: excluded\n" % modname)
            continue
        try:
            mod = __import__(modname, globals(), locals(), ['*'])
        except (ResourceDenied, unittest.SkipTest) as detail:
            skipped.append(modname)
            if verbosity > 1:
                sys.stderr.write("Skipped %s: %s\n" % (modname, detail))
            continue
        for name in dir(mod):
            if name.startswith("_"):
                continue
            o = getattr(mod, name)
            # isinstance(o, type) also accepts TestCase subclasses built with
            # a custom metaclass, which an exact ``type(o) is type`` check
            # would silently leave out of the run.
            if isinstance(o, type) and issubclass(o, unittest.TestCase):
                tests.append(o)
    return skipped, tests
85
def usage():
    """Print command-line help to stdout and return 1 for use as an exit status.

    The module docstring is printed when there is one; otherwise a built-in
    summary of the options that main() accepts is printed instead.
    """
    text = __doc__
    if not text:
        # Without a module docstring, printing __doc__ alone would emit the
        # useless line "None".
        text = ("Usage: [-q] [-v] [-r] [-u resources] [-x tests] [mask]\n"
                "\n"
                "Run the test modules of a package and print a summary.\n"
                "  -q      decrease verbosity (may be repeated)\n"
                "  -v      increase verbosity (may be repeated)\n"
                "  -r      re-run each test case to look for reference leaks\n"
                "          (requires a Python debug build)\n"
                "  -u LIST comma-separated resources to enable; '*' enables all\n"
                "  -x LIST comma-separated test module names to exclude\n"
                "  mask    file-name pattern of test modules"
                " (default: test_*.py)")
    sys.stdout.write("%s\n" % text)
    return 1
89
90def test_with_refcounts(runner, verbosity, testcase):
91    """Run testcase several times, tracking reference counts."""
92    import gc
93    import ctypes
94    ptc = ctypes._pointer_type_cache.copy()
95    cfc = ctypes._c_functype_cache.copy()
96    wfc = ctypes._win_functype_cache.copy()
97
98    # when searching for refcount leaks, we have to manually reset any
99    # caches that ctypes has.
100    def cleanup():
101        ctypes._pointer_type_cache = ptc.copy()
102        ctypes._c_functype_cache = cfc.copy()
103        ctypes._win_functype_cache = wfc.copy()
104        gc.collect()
105
106    test = unittest.makeSuite(testcase)
107    for i in range(5):
108        rc = sys.gettotalrefcount()
109        runner.run(test)
110        cleanup()
111    COUNT = 5
112    refcounts = [None] * COUNT
113    for i in range(COUNT):
114        rc = sys.gettotalrefcount()
115        runner.run(test)
116        cleanup()
117        refcounts[i] = sys.gettotalrefcount() - rc
118    if filter(None, refcounts):
119        print "%s leaks:\n\t" % testcase, refcounts
120    elif verbosity:
121        print "%s: ok." % testcase
122
class TestRunner(unittest.TextTestRunner):
    """TextTestRunner whose summary also covers skipped modules and resources."""

    def run(self, test, skipped=()):
        """Run the given test case or test suite.

        Same as unittest.TextTestRunner.run, except that the summary line
        also reports how many modules in *skipped* were not run, and lists
        the resources recorded as refused in the module-level ``_unavail``.
        *skipped* defaults to empty so the method can also be called like
        the base class's ``run(test)``.

        Returns the TestResult.
        """
        result = self._makeResult()
        start_time = time.time()
        test(result)
        time_taken = time.time() - start_time
        result.printErrors()
        self.stream.writeln(result.separator2)
        run = result.testsRun
        plural = "s" if run != 1 else ""
        if skipped or _unavail:
            # Report skipped modules whenever there are any (e.g. -x
            # exclusions), not only when some resource was refused.
            self.stream.writeln("Ran %d test%s in %.3fs (%s module%s skipped)" %
                                (run, plural, time_taken,
                                 len(skipped),
                                 "s" if len(skipped) != 1 else ""))
            if _unavail:
                self.stream.writeln("Unavailable resources: %s" %
                                    ", ".join(sorted(_unavail)))
        else:
            self.stream.writeln("Ran %d test%s in %.3fs" %
                                (run, plural, time_taken))
        self.stream.writeln()
        if not result.wasSuccessful():
            self.stream.write("FAILED (")
            failed = len(result.failures)
            errored = len(result.errors)
            if failed:
                self.stream.write("failures=%d" % failed)
            if errored:
                if failed:
                    self.stream.write(", ")
                self.stream.write("errors=%d" % errored)
            self.stream.writeln(")")
        else:
            self.stream.writeln("OK")
        return result
160
161
def main(*packages):
    """Command-line entry point: run the tests of every package in *packages*.

    Options are read from sys.argv[1:]:
      -q / -v   decrease / increase verbosity (starts at 1, may be repeated)
      -r        also hunt for reference leaks (needs a debug build)
      -u LIST   comma-separated resources to enable
      -x LIST   comma-separated test module names to exclude
    The first positional argument, if any, replaces the default file-name
    mask "test_*.py".

    Returns an exit status: 0 when every package ran cleanly, 1 when any
    run_tests() call reported a problem, usage()'s return value for a bad
    option, and -1 when -r is given on a non-debug build.
    """
    try:
        opts, args = getopt.getopt(sys.argv[1:], "rqvu:x:")
    except getopt.error:
        return usage()

    verbosity = 1
    search_leaks = False
    exclude = []
    for flag, value in opts:
        if flag == "-q":
            verbosity -= 1
        elif flag == "-v":
            verbosity += 1
        elif flag == "-r":
            # Leak hunting needs sys.gettotalrefcount (debug builds only).
            if not hasattr(sys, "gettotalrefcount"):
                sys.stderr.write("-r flag requires Python debug build\n")
                return -1
            search_leaks = True
        elif flag == "-u":
            use_resources.extend(value.split(","))
        elif flag == "-x":
            exclude.extend(value.split(","))

    mask = args[0] if args else "test_*.py"

    any_failed = False
    for package in packages:
        # Keep running the remaining packages after a failure, but remember
        # it so the returned exit status reflects it.
        if run_tests(package, mask, verbosity, search_leaks, exclude):
            any_failed = True
    return 1 if any_failed else 0
194
195
def run_tests(package, mask, verbosity, search_leaks, exclude):
    """Run *package*'s test modules matching *mask* and print a summary.

    Test classes are collected by get_tests() and run as one suite by
    TestRunner, which also reports the modules that were skipped.  With
    *search_leaks*, every collected class is then re-run through
    test_with_refcounts() using a BasicTestRunner.

    Returns True if the run was not successful (any failure or error),
    False otherwise.
    """
    skipped, testcases = get_tests(package, mask, verbosity, exclude)
    runner = TestRunner(verbosity=verbosity)

    # loadTestsFromTestCase is what unittest.makeSuite wrapped; makeSuite is
    # deprecated and no longer exists in recent Python 3 releases.
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite(
        [loader.loadTestsFromTestCase(o) for o in testcases])
    result = runner.run(suite, skipped)

    if search_leaks:
        # hunt for refcount leaks
        leak_runner = BasicTestRunner()
        for t in testcases:
            test_with_refcounts(leak_runner, verbosity, t)

    # wasSuccessful() accounts for failed assertions as well as errors;
    # looking at result.errors alone reported failing runs as clean.
    return not result.wasSuccessful()
211
212class BasicTestRunner:
213    def run(self, test):
214        result = unittest.TestResult()
215        test(result)
216        return result
217