• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import os
2import operator
3import sys
4import contextlib
5import itertools
6import unittest
7from distutils.errors import DistutilsError, DistutilsOptionError
8from distutils import log
9from unittest import TestLoader
10
11from pkg_resources import (
12    resource_listdir,
13    resource_exists,
14    normalize_path,
15    working_set,
16    evaluate_marker,
17    add_activation_listener,
18    require,
19)
20from .._importlib import metadata
21from setuptools import Command
22from setuptools.extern.more_itertools import unique_everseen
23from setuptools.extern.jaraco.functools import pass_none
24
25
class ScanningLoader(TestLoader):
    """Test loader that also scans packages for test modules.

    On top of the regular ``TestLoader`` behaviour, a package's ``.py``
    submodules and subpackages (found via ``pkg_resources`` resource
    listing) are loaded recursively, and a module-level
    ``additional_tests()`` hook, if present, contributes extra tests.
    Each module is processed at most once per loader instance.
    """

    def __init__(self):
        TestLoader.__init__(self)
        # Modules already processed by this loader; prevents loading (and
        # therefore running) the same tests twice.
        self._visited = set()

    def loadTestsFromModule(self, module, pattern=None):
        """Return a suite of all tests cases contained in the given module

        If the module is a package, load tests from all the modules in it.
        If the module has an ``additional_tests`` function, call it and add
        the return value to the tests.

        ``pattern`` is forwarded to the base loader so a module's
        ``load_tests(loader, tests, pattern)`` hook receives it.

        A module this loader has already visited yields an empty suite
        (never ``None``), so the result can always be added to another
        suite.
        """
        if module in self._visited:
            return self.suiteClass()
        self._visited.add(module)

        tests = [TestLoader.loadTestsFromModule(self, module, pattern=pattern)]

        if hasattr(module, "additional_tests"):
            tests.append(module.additional_tests())

        # Only packages have ``__path__``; scan them for submodules.
        if hasattr(module, '__path__'):
            package = module.__name__
            for entry in resource_listdir(package, ''):
                if entry.endswith('.py') and entry != '__init__.py':
                    submodule = package + '.' + entry[:-3]
                elif resource_exists(package, entry + '/__init__.py'):
                    submodule = package + '.' + entry
                else:
                    # Neither a module file nor a subpackage directory.
                    continue
                tests.append(self.loadTestsFromName(submodule))

        if len(tests) != 1:
            return self.suiteClass(tests)
        return tests[0]  # don't create a nested suite for only one return
63
64
65# adapted from jaraco.classes.properties:NonDataProperty
class NonDataProperty:
    """Property-like descriptor that an instance attribute can shadow.

    Only ``__get__`` is defined (no ``__set__``), so this is a non-data
    descriptor: assigning the same name on an instance stores the value in
    the instance ``__dict__``, and later lookups return that value instead
    of calling ``fget``.
    """

    def __init__(self, fget):
        self.fget = fget

    def __get__(self, obj, objtype=None):
        # Accessed on the class itself -> hand back the descriptor.
        return self if obj is None else self.fget(obj)
74
75
class test(Command):
    """Command to run unit tests after in-place build"""

    description = "run unit tests after in-place build (deprecated)"

    user_options = [
        ('test-module=', 'm', "Run 'test_suite' in specified module"),
        (
            'test-suite=',
            's',
            "Run single test, case or suite (e.g. 'module.test_suite')",
        ),
        ('test-runner=', 'r', "Test runner to use"),
    ]

    def initialize_options(self):
        """Reset every option to ``None``; ``finalize_options`` fills defaults."""
        self.test_suite = None
        self.test_module = None
        self.test_loader = None
        self.test_runner = None

    def finalize_options(self):
        """Validate options and fill unset ones from the distribution.

        ``test_suite`` defaults to ``<test_module>.test_suite`` when a module
        was given, otherwise to the distribution's ``test_suite``.
        ``test_loader`` defaults to the distribution's value, then to
        ``ScanningLoader``; ``test_runner`` to the distribution's value.

        Raises:
            DistutilsOptionError: if both a test suite and a test module
                were specified.
        """
        if self.test_suite and self.test_module:
            msg = "You may specify a module or a suite, but not both"
            raise DistutilsOptionError(msg)

        if self.test_suite is None:
            if self.test_module is None:
                self.test_suite = self.distribution.test_suite
            else:
                self.test_suite = self.test_module + ".test_suite"

        if self.test_loader is None:
            self.test_loader = getattr(self.distribution, 'test_loader', None)
        if self.test_loader is None:
            self.test_loader = "setuptools.command.test:ScanningLoader"
        if self.test_runner is None:
            self.test_runner = getattr(self.distribution, 'test_runner', None)

    @NonDataProperty
    def test_args(self):
        """Arguments for ``unittest.main`` following the program name.

        Defined as a non-data property so an instance attribute of the same
        name can override it.
        """
        return list(self._test_args())

    def _test_args(self):
        """Yield the unittest command-line arguments for this run."""
        # With no explicit suite, fall back to unittest test discovery.
        if not self.test_suite:
            yield 'discover'
        if self.verbose:
            yield '--verbose'
        if self.test_suite:
            yield self.test_suite

    def with_project_on_sys_path(self, func):
        """
        Backward compatibility for project_on_sys_path context.

        Calls ``func()`` (no arguments) inside ``project_on_sys_path()``.
        """
        with self.project_on_sys_path():
            func()

    @contextlib.contextmanager
    def project_on_sys_path(self, include_dists=()):
        """Context in which the freshly built project is importable.

        Runs ``egg_info`` and an in-place ``build_ext``, prepends the egg
        base directory to ``sys.path`` (and to ``PYTHONPATH`` for
        subprocesses), re-initialises the global ``working_set`` and
        requires the project's own distribution. ``sys.path``,
        ``sys.modules`` and ``working_set`` are restored on exit.

        ``include_dists`` is accepted for backward compatibility only; it is
        not used.
        """
        self.run_command('egg_info')

        # Build extensions in-place
        self.reinitialize_command('build_ext', inplace=1)
        self.run_command('build_ext')

        ei_cmd = self.get_finalized_command("egg_info")

        old_path = sys.path[:]
        old_modules = sys.modules.copy()

        try:
            project_path = normalize_path(ei_cmd.egg_base)
            sys.path.insert(0, project_path)
            # Re-initialise the global working set (presumably so it
            # re-scans the modified sys.path) and activate every
            # distribution added to it from now on.
            working_set.__init__()
            add_activation_listener(lambda dist: dist.activate())
            require('%s==%s' % (ei_cmd.egg_name, ei_cmd.egg_version))
            with self.paths_on_pythonpath([project_path]):
                yield
        finally:
            sys.path[:] = old_path
            sys.modules.clear()
            sys.modules.update(old_modules)
            working_set.__init__()

    @staticmethod
    @contextlib.contextmanager
    def paths_on_pythonpath(paths):
        """
        Add the indicated paths to the head of the PYTHONPATH environment
        variable so that subprocesses will also see the packages at
        these paths.

        Do this in a context that restores the value on exit.
        """
        # Sentinel distinguishing "PYTHONPATH unset" from "set but empty",
        # so the exact original state can be restored.
        nothing = object()
        orig_pythonpath = os.environ.get('PYTHONPATH', nothing)
        current_pythonpath = os.environ.get('PYTHONPATH', '')
        try:
            prefix = os.pathsep.join(unique_everseen(paths))
            to_join = filter(None, [prefix, current_pythonpath])
            new_path = os.pathsep.join(to_join)
            if new_path:
                os.environ['PYTHONPATH'] = new_path
            yield
        finally:
            if orig_pythonpath is nothing:
                os.environ.pop('PYTHONPATH', None)
            else:
                os.environ['PYTHONPATH'] = orig_pythonpath

    @staticmethod
    def install_dists(dist):
        """
        Install the requirements indicated by *dist* and
        return an iterable of the dists that were built.

        Covers ``install_requires``, ``tests_require`` and every
        ``extras_require`` entry keyed ``:<marker>`` whose environment
        marker evaluates true.
        """
        ir_d = dist.fetch_build_eggs(dist.install_requires)
        tr_d = dist.fetch_build_eggs(dist.tests_require or [])
        er_d = dist.fetch_build_eggs(
            v
            for k, v in dist.extras_require.items()
            if k.startswith(':') and evaluate_marker(k[1:])
        )
        return itertools.chain(ir_d, tr_d, er_d)

    def run(self):
        """Fetch test requirements, build in place and run the tests."""
        self.announce(
            "WARNING: Testing via this command is deprecated and will be "
            "removed in a future version. Users looking for a generic test "
            "entry point independent of test runner are encouraged to use "
            "tox.",
            log.WARN,
        )

        installed_dists = self.install_dists(self.distribution)

        cmd = ' '.join(self._argv)
        if self.dry_run:
            self.announce('skipping "%s" (dry run)' % cmd)
            return

        self.announce('running "%s"' % cmd)

        # Make the fetched requirement eggs visible to subprocesses too.
        paths = map(operator.attrgetter('location'), installed_dists)
        with self.paths_on_pythonpath(paths), self.project_on_sys_path():
            self.run_tests()

    def run_tests(self):
        """Run ``unittest.main`` with the configured loader and runner.

        Raises:
            DistutilsError: if the test run was not successful.
        """
        program = unittest.main(
            None,
            None,
            self._argv,
            testLoader=self._resolve_as_ep(self.test_loader),
            testRunner=self._resolve_as_ep(self.test_runner),
            exit=False,
        )
        if not program.result.wasSuccessful():
            msg = 'Test failed: %s' % program.result
            self.announce(msg, log.ERROR)
            raise DistutilsError(msg)

    @property
    def _argv(self):
        """Full argv for ``unittest.main``: program name plus ``test_args``."""
        return ['unittest'] + self.test_args

    @staticmethod
    @pass_none
    def _resolve_as_ep(val):
        """
        Load the object referenced by *val* as if it were specified as an
        entry point (``'module:attr'``), call it with no arguments and
        return the result. Thanks to ``pass_none``, a ``None`` value is
        returned as ``None`` without loading anything.
        """
        return metadata.EntryPoint(value=val, name=None, group=None).load()()
252