• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import builtins
2import contextlib
3import errno
4import functools
5from importlib import machinery, util, invalidate_caches
6import marshal
7import os
8import os.path
9from test import support
10from test.support import import_helper
11from test.support import is_apple_mobile
12from test.support import os_helper
13import unittest
14import sys
15import tempfile
16import types
17
# The extension module whose on-disk location EXTENSIONS records below.
# Loaded through import_helper.import_module() rather than a plain import,
# presumably so a build without it skips these tests -- TODO confirm.
_testsinglephase = import_helper.import_module(
    "_testsinglephase",
)
19
20
# good_name: a module compiled into this interpreter ('errno'), or None if it
# is not built in here.  bad_name: a module that is NOT built in
# ('importlib'), or None if, unexpectedly, it is.
BUILTINS = types.SimpleNamespace(
    good_name='errno' if 'errno' in sys.builtin_module_names else None,
    bad_name=None if 'importlib' in sys.builtin_module_names else 'importlib',
)
28
if support.is_wasi:
    # dlopen() is a shim for WASI as of WASI SDK which fails by default.
    # We don't provide an implementation, so tests will fail.
    # But we also don't want to turn off dynamic loading for those that provide
    # a working implementation.
    def _extension_details():
        """Mark extension-module details as unavailable on WASI."""
        global EXTENSIONS
        EXTENSIONS = None
else:
    EXTENSIONS = types.SimpleNamespace(
        path=None,
        ext=None,
        filename=None,
        file_path=None,
        name='_testsinglephase',
    )

    def _extension_details():
        """Locate EXTENSIONS.name on sys.path and record where it was found.

        Each sys.path directory is tried with each extension-module suffix,
        in order; the first existing file sets EXTENSIONS.path, .ext,
        .filename and .file_path.  If nothing matches they all stay None.
        """
        # Apple mobile platforms mechanically load .so files,
        # but the findable files are labelled .fwork
        suffixes = [
            suffix.replace(".so", ".fwork") if is_apple_mobile else suffix
            for suffix in machinery.EXTENSION_SUFFIXES
        ]
        for directory in sys.path:
            for suffix in suffixes:
                candidate_name = EXTENSIONS.name + suffix
                candidate_path = os.path.join(directory, candidate_name)
                if not os.path.exists(candidate_path):
                    continue
                EXTENSIONS.path = directory
                EXTENSIONS.ext = suffix
                EXTENSIONS.filename = candidate_name
                EXTENSIONS.file_path = candidate_path
                return

_extension_details()
64
65
def import_importlib(module_name):
    """Import a module from importlib both w/ and w/o _frozen_importlib.

    Returns a dict with two entries: 'Frozen' (a fresh import using the
    frozen bootstrap) and 'Source' (a fresh import with _frozen_importlib
    and _frozen_importlib_external blocked).
    """
    # For a dotted (submodule) name, the 'importlib' package itself is also
    # re-imported fresh for the source flavour -- presumably so the submodule
    # binds to the source-based package rather than a cached frozen one.
    fresh = ('importlib',) if '.' in module_name else ()
    blocked = ('_frozen_importlib', '_frozen_importlib_external')
    return {
        'Frozen': import_helper.import_fresh_module(module_name),
        'Source': import_helper.import_fresh_module(
            module_name, fresh=fresh, blocked=blocked),
    }
73
74
def specialize_class(cls, kind, base=None, **kwargs):
    """Build a '<kind>_<cls name>' subclass of *cls* and a test base class.

    *base* may be None (use unittest.TestCase), a class (used directly), or
    a mapping indexed by *kind*.  Each keyword argument must likewise be a
    mapping indexed by *kind*; the selected value is set as a class
    attribute of the same name.  The new class also records the original
    class name in ``_NAME`` and the flavour in ``_KIND``.
    """
    # XXX Support passing in submodule names--load (and cache) them?
    # That would clean up the test modules a bit more.
    if base is None:
        parent = unittest.TestCase
    elif isinstance(base, type):
        parent = base
    else:
        parent = base[kind]
    specialized = types.new_class(f'{kind}_{cls.__name__}', (cls, parent))
    specialized.__module__ = cls.__module__
    specialized._NAME = cls.__name__
    specialized._KIND = kind
    for attr_name, per_kind in kwargs.items():
        setattr(specialized, attr_name, per_kind[kind])
    return specialized
92
93
def split_frozen(cls, base=None, **kwargs):
    """Return the ``(Frozen, Source)`` specializations of *cls*.

    Both classes are built by specialize_class() from the same *base* and
    keyword arguments.
    """
    return tuple(specialize_class(cls, kind, base, **kwargs)
                 for kind in ('Frozen', 'Source'))
98
99
def test_both(test_class, base=None, **kwargs):
    """Create the Frozen and Source variants of *test_class*.

    Thin alias for split_frozen(); returns its ``(Frozen, Source)`` pair.
    """
    return split_frozen(test_class, base=base, **kwargs)
102
103
# Windows is the only OS that is *always* case-insensitive
# (OS X *can* be case-sensitive).  Elsewhere, probe the file system by
# checking whether this file is reachable under a case-flipped name.
if sys.platform in ('win32', 'cygwin'):
    CASE_INSENSITIVE_FS = True
else:
    changed_name = (__file__.lower() if __file__.upper() == __file__
                    else __file__.upper())
    CASE_INSENSITIVE_FS = os.path.exists(changed_name)
113
# The importlib package imported from source, with the frozen bootstrap
# modules blocked (see import_importlib()).
source_importlib = import_importlib('importlib')['Source']
# Per-flavour __import__ implementations.  Wrapped in staticmethod,
# presumably so they are not bound as methods when installed as class
# attributes (e.g. via specialize_class() keyword arguments).
__import__ = dict(
    Frozen=staticmethod(builtins.__import__),
    Source=staticmethod(source_importlib.__import__),
)
117
118
def case_insensitive_tests(test):
    """Decorator that skips *test* (a test class or function) when the file
    system is case-sensitive, i.e. when CASE_INSENSITIVE_FS is false."""
    skip_on_sensitive_fs = unittest.skipIf(
        not CASE_INSENSITIVE_FS, "requires a case-insensitive filesystem")
    return skip_on_sensitive_fs(test)
124
125
def submodule(parent, name, pkg_dir, content=''):
    """Write ``<pkg_dir>/<name>.py`` containing *content*.

    Returns a ``(qualified_name, path)`` pair, where the qualified name is
    ``'<parent>.<name>'`` and *path* is the file just written.
    """
    qualified_name = f'{parent}.{name}'
    module_path = os.path.join(pkg_dir, name + '.py')
    with open(module_path, 'w', encoding='utf-8') as handle:
        handle.write(content)
    return qualified_name, module_path
131
132
def get_code_from_pyc(pyc_path):
    """Reads a pyc file and returns the unmarshalled code object within.

    The first 16 bytes are skipped as the pyc header.  No header validation
    is performed.
    """
    header_size = 16  # magic number, flags and mtime/size or source hash
    with open(pyc_path, 'rb') as stream:
        payload = stream.read()
    return marshal.loads(payload[header_size:])
141
142
@contextlib.contextmanager
def uncache(*names):
    """Uncache a module from sys.modules.

    Each of *names* is removed from sys.modules (missing names are ignored)
    on entry and again on exit.  Modules present before entry are not
    restored.

    A basic sanity check is performed to prevent uncaching modules that either
    cannot/shouldn't be uncached: if any name is 'sys' or 'marshal',
    ValueError is raised before sys.modules is touched at all.

    """
    # Validate every name up front so a rejected name cannot leave
    # sys.modules partially cleared.
    for name in names:
        if name in ('sys', 'marshal'):
            raise ValueError("cannot uncache {}".format(name))
    for name in names:
        sys.modules.pop(name, None)
    try:
        yield
    finally:
        for name in names:
            sys.modules.pop(name, None)
166
167
@contextlib.contextmanager
def temp_module(name, content='', *, pkg=False):
    """Create a throwaway top-level module (or package) importable as *name*.

    Runs inside a temporary working directory (os_helper.temp_cwd) that is
    added to the import path via DirsOnSysPath.  *name* and any
    already-imported submodules of it are uncached, and importlib's caches
    are invalidated.  Yields the module's path without a suffix (for a
    package, its directory).

    With pkg=False, ``<name>.py`` is always written (empty if *content* is
    None).  With pkg=True, a directory is created and its ``__init__.py``
    is written from *content*; content=None leaves it without an
    ``__init__.py`` (a namespace package).
    """
    conflicts = [mod for mod in sys.modules if mod.partition('.')[0] == name]
    with (os_helper.temp_cwd(None) as cwd,
          uncache(name, *conflicts),
          import_helper.DirsOnSysPath(cwd)):
        invalidate_caches()
        location = os.path.join(cwd, name)
        if pkg:
            # Relative to cwd, i.e. created inside the temporary directory.
            os.mkdir(name)
            modpath = os.path.join(location, '__init__.py')
        else:
            modpath = location + '.py'
        # Plain modules always get a file; packages only when given content.
        if content is not None or not pkg:
            with open(modpath, 'w', encoding='utf-8') as modfile:
                modfile.write('' if content is None else content)
        yield location
190
191
@contextlib.contextmanager
def import_state(**kwargs):
    """Context manager to manage the various importers and stored state in the
    sys module.

    Temporarily replaces sys.meta_path, sys.path, sys.path_hooks and
    sys.path_importer_cache.  Each may be supplied as a keyword argument;
    any not supplied is set to a fresh empty list (or dict for the cache).
    Any other keyword raises ValueError.  The original objects are put back
    on exit, including when ValueError is raised.

    The 'modules' attribute is not supported as the interpreter state stores a
    pointer to the dict that the interpreter uses internally;
    reassigning to sys.modules does not have the desired effect.

    """
    defaults = (
        ('meta_path', []),
        ('path', []),
        ('path_hooks', []),
        ('path_importer_cache', {}),
    )
    saved = {}
    try:
        for attr, fallback in defaults:
            saved[attr] = getattr(sys, attr)
            setattr(sys, attr, kwargs.pop(attr, fallback))
        if kwargs:
            raise ValueError('unrecognized arguments: {}'.format(kwargs))
        yield
    finally:
        for attr, original in saved.items():
            setattr(sys, attr, original)
220
221
222class _ImporterMock:
223
224    """Base class to help with creating importer mocks."""
225
226    def __init__(self, *names, module_code={}):
227        self.modules = {}
228        self.module_code = {}
229        for name in names:
230            if not name.endswith('.__init__'):
231                import_name = name
232            else:
233                import_name = name[:-len('.__init__')]
234            if '.' not in name:
235                package = None
236            elif import_name == name:
237                package = name.rsplit('.', 1)[0]
238            else:
239                package = import_name
240            module = types.ModuleType(import_name)
241            module.__loader__ = self
242            module.__file__ = '<mock __file__>'
243            module.__package__ = package
244            module.attr = name
245            if import_name != name:
246                module.__path__ = ['<mock __path__>']
247            self.modules[import_name] = module
248            if import_name in module_code:
249                self.module_code[import_name] = module_code[import_name]
250
251    def __getitem__(self, name):
252        return self.modules[name]
253
254    def __enter__(self):
255        self._uncache = uncache(*self.modules.keys())
256        self._uncache.__enter__()
257        return self
258
259    def __exit__(self, *exc_info):
260        self._uncache.__exit__(None, None, None)
261
262
class mock_spec(_ImporterMock):

    """Importer mock using PEP 451 APIs."""

    def find_spec(self, fullname, path=None, parent=None):
        """Return a spec for *fullname* if it is a mocked module, else None.

        Package modules (those with a ``__path__``) get it as their
        submodule_search_locations.
        """
        try:
            module = self.modules[fullname]
        except KeyError:
            return None
        spec = util.spec_from_file_location(
                fullname, module.__file__, loader=self,
                submodule_search_locations=getattr(module, '__path__', None))
        return spec

    def create_module(self, spec):
        """Return the pre-built module object for *spec*."""
        if spec.name not in self.modules:
            raise ImportError
        return self.modules[spec.name]

    def exec_module(self, module):
        """Call the callable registered in module_code for this module, if any.

        Only the lookup is guarded: a KeyError raised *by* the callable
        propagates instead of being mistaken for "no code registered".
        """
        try:
            module_code = self.module_code[module.__spec__.name]
        except KeyError:
            return
        module_code()
287
288
def writes_bytecode_files(fxn):
    """Decorator to protect sys.dont_write_bytecode from mutation and to skip
    tests that require it to be set to False.

    The check happens at decoration time: if bytecode writing is disabled
    then, *fxn* is wrapped with unittest.skip().  Otherwise the returned
    wrapper forces sys.dont_write_bytecode to False while *fxn* runs and
    restores the previous value afterwards, even on error.
    """
    if not sys.dont_write_bytecode:
        @functools.wraps(fxn)
        def wrapper(*args, **kwargs):
            saved = sys.dont_write_bytecode
            sys.dont_write_bytecode = False
            try:
                return fxn(*args, **kwargs)
            finally:
                sys.dont_write_bytecode = saved
        return wrapper
    return unittest.skip("relies on writing bytecode")(fxn)
304
305
def ensure_bytecode_path(bytecode_path):
    """Ensure that the __pycache__ directory for PEP 3147 pyc file exists.

    Only the immediate parent directory of *bytecode_path* is created.  An
    already-existing directory is accepted; any other OSError (for example
    a missing grandparent directory) propagates.

    :param bytecode_path: File system path to PEP 3147 pyc file.
    """
    # FileExistsError is exactly the OSError subclass raised for EEXIST.
    with contextlib.suppress(FileExistsError):
        os.mkdir(os.path.dirname(bytecode_path))
316
317
@contextlib.contextmanager
def temporary_pycache_prefix(prefix):
    """Set sys.pycache_prefix to *prefix* for the duration of the block.

    The previous value is restored on exit, including when the block raises.
    """
    saved, sys.pycache_prefix = sys.pycache_prefix, prefix
    try:
        yield
    finally:
        sys.pycache_prefix = saved
327
328
@contextlib.contextmanager
def create_modules(*names):
    """Temporarily create each named module with an attribute (named 'attr')
    that contains the name passed into the context manager that caused the
    creation of the module.

    All files are created in a fresh directory from tempfile.mkdtemp().
    While the context is active, sys.path is *replaced* by a list holding
    only that directory, and sys.meta_path, sys.path_hooks and
    sys.path_importer_cache are reset to empty (see import_state()).  The
    created modules are uncached from sys.modules on entry and exit.  On
    exit the import state is restored and the whole directory (source and
    any bytecode) is deleted.

    Yields a dict mapping each name to its source file path; the special
    key '.root' maps to the temporary directory.

    No magic is performed when creating packages! This means that if you create
    a module within a package you must also create the package's __init__ as
    well.

    """
    source = 'attr = {0!r}'
    # Created outside the cleanup scope: if mkdtemp() fails there is
    # nothing to remove.
    temp_dir = tempfile.mkdtemp()
    with contextlib.ExitStack() as cleanup:
        # Cleanups run in reverse registration order, so the directory is
        # removed last, after the import state and sys.modules are restored.
        cleanup.callback(os_helper.rmtree, temp_dir)
        mapping = {'.root': temp_dir}
        import_names = set()
        for name in names:
            # 'pkg.__init__' is imported as 'pkg' (same rule as _ImporterMock).
            if name.endswith('.__init__'):
                import_name = name[:-len('.__init__')]
            else:
                import_name = name
            import_names.add(import_name)
            name_parts = name.split('.')
            package_dir = os.path.join(temp_dir, *name_parts[:-1])
            os.makedirs(package_dir, exist_ok=True)
            file_path = os.path.join(package_dir, name_parts[-1] + '.py')
            with open(file_path, 'w', encoding='utf-8') as file:
                file.write(source.format(name))
            mapping[name] = file_path
        # uncache() validates names (e.g. refuses 'sys') before touching
        # sys.modules, so nothing is removed from it before this point.
        cleanup.enter_context(uncache(*import_names))
        cleanup.enter_context(import_state(path=[temp_dir]))
        yield mapping
385
386
def mock_path_hook(*entries, importer):
    """A mock sys.path_hooks entry.

    The returned hook hands back *importer* for any path entry in *entries*
    and raises ImportError for every other entry.
    """
    def hook(entry):
        if entry in entries:
            return importer
        raise ImportError
    return hook
394
395
class CASEOKTestBase:

    def caseok_env_changed(self, *, should_exist):
        """Skip the test unless importlib's view of PYTHONCASEOK matches.

        Checks for PYTHONCASEOK (as a bytes or str key) in
        ``self.importlib._bootstrap_external._os.environ``.  If its presence
        differs from *should_exist*, calls self.skipTest().  Assumes the
        concrete test class provides ``importlib`` and ``skipTest``.
        """
        environ = self.importlib._bootstrap_external._os.environ
        found = any(key in environ for key in (b'PYTHONCASEOK', 'PYTHONCASEOK'))
        if found != should_exist:
            self.skipTest('os.environ changes not reflected in _os.environ')
403