• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import builtins
2import locale
3import os
4import sys
5import threading
6
7from test import support
8from test.support import os_helper
9
10from .utils import print_warning
11
12
class SkipTestEnvironment(Exception):
    """Signal that a resource cannot be saved because its module is not loaded.

    Raised by saved_test_environment.try_get_module(); during __enter__ the
    affected resource is simply left untracked.
    """
15
16
# Unit tests are supposed to leave the execution environment unchanged
# once they complete.  But sometimes tests have bugs, especially when
# tests fail, and the changes to the environment go on to mess up other
# tests.  This can cause issues with buildbot stability, since tests
# are run in random order and so problems may appear to come and go.
# There are a few things we can save and restore to mitigate this, and
# the following context manager handles this task.
24
class saved_test_environment:
    """Save bits of the test environment and restore them at block exit.

        with saved_test_environment(test_name, verbose, quiet, pgo=pgo):
            # stuff

    On exit, every saved resource is compared with its current value.  For
    each one that differs, support.environment_altered is set to True and
    the saved value is restored.  Unless quiet or pgo is true, a warning is
    printed to stderr naming the resource and showing its value before and
    after the test.

    Parameters:
        test_name: name of the test, used in the warning message.
        verbose: stored as self.verbose; the checks in this class do not
            consult it.
        quiet: if true, suppress the warning for altered resources.
        pgo: keyword-only; if true, also suppress the warning.
    """

    def __init__(self, test_name, verbose, quiet, *, pgo):
        self.test_name = test_name
        self.verbose = verbose
        self.quiet = quiet
        self.pgo = pgo

    # To add things to save and restore, add a name XXX to the resources list
    # and add corresponding get_XXX/restore_XXX functions.  get_XXX should
    # return the value to be saved and compared against a second call to the
    # get function when test execution completes.  restore_XXX should accept
    # the saved value and restore the resource using it.  It will be called if
    # and only if a change in the value is detected.
    #
    # Note: XXX will have any '.' replaced with '_' characters when determining
    # the corresponding method names.

    resources = ('sys.argv', 'cwd', 'sys.stdin', 'sys.stdout', 'sys.stderr',
                 'os.environ', 'sys.path', 'sys.path_hooks', '__import__',
                 'warnings.filters', 'asyncore.socket_map',
                 'logging._handlers', 'logging._handlerList', 'sys.gettrace',
                 'sys.warnoptions',
                 # multiprocessing.process._cleanup() may release ref
                 # to a thread, so check processes first.
                 'multiprocessing.process._dangling', 'threading._dangling',
                 'sysconfig._CONFIG_VARS', 'sysconfig._INSTALL_SCHEMES',
                 'files', 'locale', 'warnings.showwarning',
                 'shutil_archive_formats', 'shutil_unpack_formats',
                 'asyncio.events._event_loop_policy',
                 'urllib.requests._url_tempfiles', 'urllib.requests._opener',
                )

    def get_module(self, name):
        """Return the already-imported module *name* (used by restore_XXX).

        Raises KeyError if *name* is not present in sys.modules.
        """
        return sys.modules[name]

    def try_get_module(self, name):
        """Return the already-imported module *name* (used by get_XXX).

        Raises SkipTestEnvironment if the module is not in sys.modules; when
        that happens during __enter__, the resource is not tracked.
        """
        try:
            return self.get_module(name)
        except KeyError:
            raise SkipTestEnvironment

    def get_urllib_requests__url_tempfiles(self):
        urllib_request = self.try_get_module('urllib.request')
        return list(urllib_request._url_tempfiles)
    def restore_urllib_requests__url_tempfiles(self, tempfiles):
        for filename in tempfiles:
            os_helper.unlink(filename)

    def get_urllib_requests__opener(self):
        urllib_request = self.try_get_module('urllib.request')
        return urllib_request._opener
    def restore_urllib_requests__opener(self, opener):
        urllib_request = self.get_module('urllib.request')
        urllib_request._opener = opener

    def get_asyncio_events__event_loop_policy(self):
        self.try_get_module('asyncio')
        return support.maybe_get_event_loop_policy()
    def restore_asyncio_events__event_loop_policy(self, policy):
        asyncio = self.get_module('asyncio')
        asyncio.set_event_loop_policy(policy)

    # Mutable containers below are saved as (id, object, copy) so that both
    # rebinding the attribute and mutating the object in place are detected.
    def get_sys_argv(self):
        return id(sys.argv), sys.argv, sys.argv[:]
    def restore_sys_argv(self, saved_argv):
        sys.argv = saved_argv[1]
        sys.argv[:] = saved_argv[2]

    def get_cwd(self):
        return os.getcwd()
    def restore_cwd(self, saved_cwd):
        os.chdir(saved_cwd)

    def get_sys_stdout(self):
        return sys.stdout
    def restore_sys_stdout(self, saved_stdout):
        sys.stdout = saved_stdout

    def get_sys_stderr(self):
        return sys.stderr
    def restore_sys_stderr(self, saved_stderr):
        sys.stderr = saved_stderr

    def get_sys_stdin(self):
        return sys.stdin
    def restore_sys_stdin(self, saved_stdin):
        sys.stdin = saved_stdin

    def get_os_environ(self):
        return id(os.environ), os.environ, dict(os.environ)
    def restore_os_environ(self, saved_environ):
        os.environ = saved_environ[1]
        os.environ.clear()
        os.environ.update(saved_environ[2])

    def get_sys_path(self):
        return id(sys.path), sys.path, sys.path[:]
    def restore_sys_path(self, saved_path):
        sys.path = saved_path[1]
        sys.path[:] = saved_path[2]

    def get_sys_path_hooks(self):
        return id(sys.path_hooks), sys.path_hooks, sys.path_hooks[:]
    def restore_sys_path_hooks(self, saved_hooks):
        sys.path_hooks = saved_hooks[1]
        sys.path_hooks[:] = saved_hooks[2]

    def get_sys_gettrace(self):
        return sys.gettrace()
    def restore_sys_gettrace(self, trace_fxn):
        sys.settrace(trace_fxn)

    def get___import__(self):
        return builtins.__import__
    def restore___import__(self, import_):
        builtins.__import__ = import_

    def get_warnings_filters(self):
        warnings = self.try_get_module('warnings')
        return id(warnings.filters), warnings.filters, warnings.filters[:]
    def restore_warnings_filters(self, saved_filters):
        warnings = self.get_module('warnings')
        warnings.filters = saved_filters[1]
        warnings.filters[:] = saved_filters[2]

    def get_asyncore_socket_map(self):
        asyncore = sys.modules.get('test.support.asyncore')
        # XXX Making a copy keeps objects alive until __exit__ gets called.
        # An absent module (or a None entry in sys.modules) yields {}.
        return asyncore.socket_map.copy() if asyncore is not None else {}
    def restore_asyncore_socket_map(self, saved_map):
        asyncore = sys.modules.get('test.support.asyncore')
        if asyncore is not None:
            asyncore.close_all(ignore_all=True)
            asyncore.socket_map.update(saved_map)

    def get_shutil_archive_formats(self):
        shutil = self.try_get_module('shutil')
        # we could call get_archives_formats() but that only returns the
        # registry keys; we want to check the values too (the functions that
        # are registered)
        return shutil._ARCHIVE_FORMATS, shutil._ARCHIVE_FORMATS.copy()
    def restore_shutil_archive_formats(self, saved):
        shutil = self.get_module('shutil')
        shutil._ARCHIVE_FORMATS = saved[0]
        shutil._ARCHIVE_FORMATS.clear()
        shutil._ARCHIVE_FORMATS.update(saved[1])

    def get_shutil_unpack_formats(self):
        shutil = self.try_get_module('shutil')
        return shutil._UNPACK_FORMATS, shutil._UNPACK_FORMATS.copy()
    def restore_shutil_unpack_formats(self, saved):
        shutil = self.get_module('shutil')
        shutil._UNPACK_FORMATS = saved[0]
        shutil._UNPACK_FORMATS.clear()
        shutil._UNPACK_FORMATS.update(saved[1])

    def get_logging__handlers(self):
        logging = self.try_get_module('logging')
        # _handlers is a WeakValueDictionary
        return id(logging._handlers), logging._handlers, logging._handlers.copy()
    def restore_logging__handlers(self, saved_handlers):
        # Can't easily revert the logging state
        pass

    def get_logging__handlerList(self):
        logging = self.try_get_module('logging')
        # _handlerList is a list of weakrefs to handlers
        return id(logging._handlerList), logging._handlerList, logging._handlerList[:]
    def restore_logging__handlerList(self, saved_handlerList):
        # Can't easily revert the logging state
        pass

    def get_sys_warnoptions(self):
        return id(sys.warnoptions), sys.warnoptions, sys.warnoptions[:]
    def restore_sys_warnoptions(self, saved_options):
        sys.warnoptions = saved_options[1]
        sys.warnoptions[:] = saved_options[2]

    # Controlling dangling references to Thread objects can make it easier
    # to track reference leaks.
    def get_threading__dangling(self):
        # This copies the weakrefs without making any strong reference
        return threading._dangling.copy()
    def restore_threading__dangling(self, saved):
        threading._dangling.clear()
        threading._dangling.update(saved)

    # Same for Process objects
    def get_multiprocessing_process__dangling(self):
        multiprocessing_process = self.try_get_module('multiprocessing.process')
        # Unjoined process objects can survive after process exits
        multiprocessing_process._cleanup()
        # This copies the weakrefs without making any strong reference
        return multiprocessing_process._dangling.copy()
    def restore_multiprocessing_process__dangling(self, saved):
        multiprocessing_process = self.get_module('multiprocessing.process')
        multiprocessing_process._dangling.clear()
        multiprocessing_process._dangling.update(saved)

    def get_sysconfig__CONFIG_VARS(self):
        # make sure the dict is initialized
        sysconfig = self.try_get_module('sysconfig')
        sysconfig.get_config_var('prefix')
        return (id(sysconfig._CONFIG_VARS), sysconfig._CONFIG_VARS,
                dict(sysconfig._CONFIG_VARS))
    def restore_sysconfig__CONFIG_VARS(self, saved):
        sysconfig = self.get_module('sysconfig')
        sysconfig._CONFIG_VARS = saved[1]
        sysconfig._CONFIG_VARS.clear()
        sysconfig._CONFIG_VARS.update(saved[2])

    def get_sysconfig__INSTALL_SCHEMES(self):
        sysconfig = self.try_get_module('sysconfig')
        return (id(sysconfig._INSTALL_SCHEMES), sysconfig._INSTALL_SCHEMES,
                sysconfig._INSTALL_SCHEMES.copy())
    def restore_sysconfig__INSTALL_SCHEMES(self, saved):
        sysconfig = self.get_module('sysconfig')
        sysconfig._INSTALL_SCHEMES = saved[1]
        sysconfig._INSTALL_SCHEMES.clear()
        sysconfig._INSTALL_SCHEMES.update(saved[2])

    def get_files(self):
        # Sorted listing of the current directory; directories are marked
        # with a trailing '/'.
        # XXX: Maybe add an allow-list here?
        return sorted(fn + ('/' if os.path.isdir(fn) else '')
                      for fn in os.listdir()
                      if not fn.startswith(".hypothesis"))
    def restore_files(self, saved_value):
        # Only os_helper.TESTFN is cleaned up, and only if it did not exist
        # (as a file or a directory) before the test ran.
        fn = os_helper.TESTFN
        if fn not in saved_value and (fn + '/') not in saved_value:
            if os.path.isfile(fn):
                os_helper.unlink(fn)
            elif os.path.isdir(fn):
                os_helper.rmtree(fn)

    _lc = [getattr(locale, lc) for lc in dir(locale)
           if lc.startswith('LC_')]
    def get_locale(self):
        pairings = []
        for lc in self._lc:
            try:
                pairings.append((lc, locale.setlocale(lc, None)))
            except (TypeError, ValueError):
                # Skip categories that cannot be queried on this platform.
                continue
        return pairings
    def restore_locale(self, saved):
        for lc, setting in saved:
            locale.setlocale(lc, setting)

    def get_warnings_showwarning(self):
        warnings = self.try_get_module('warnings')
        return warnings.showwarning
    def restore_warnings_showwarning(self, fxn):
        warnings = self.get_module('warnings')
        warnings.showwarning = fxn

    def resource_info(self):
        """Yield (name, get_method, restore_method) for each resource.

        The method names are 'get_' / 'restore_' followed by the resource
        name with every '.' replaced by '_'.
        """
        for name in self.resources:
            method_suffix = name.replace('.', '_')
            get_name = 'get_' + method_suffix
            restore_name = 'restore_' + method_suffix
            yield name, getattr(self, get_name), getattr(self, restore_name)

    def __enter__(self):
        """Snapshot every resource whose getter does not skip it."""
        self.saved_values = []
        for name, get, restore in self.resource_info():
            try:
                original = get()
            except SkipTestEnvironment:
                # The providing module is not imported: nothing to track.
                continue

            self.saved_values.append((name, get, restore, original))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Compare, report and restore every saved resource.

        Returns False so that an exception raised in the with-block
        propagates to the caller.
        """
        saved_values = self.saved_values
        self.saved_values = None

        # Some resources use weak references
        support.gc_collect()

        for name, get, restore, original in saved_values:
            # NOTE(review): SkipTestEnvironment is not caught here, so a
            # module removed from sys.modules during the test would make
            # get() raise out of __exit__ -- confirm this is intended.
            current = get()
            # Check for changes to the resource's value
            if current != original:
                support.environment_altered = True
                restore(original)
                if not self.quiet and not self.pgo:
                    print_warning(
                        f"{name} was modified by {self.test_name}\n"
                        f"  Before: {original}\n"
                        f"  After:  {current} ")
        return False
332