• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from __future__ import unicode_literals
2import os
3import time
4import subprocess
5import warnings
6import tempfile
7import pickle
8
9
class WarningTestMixin(object):
    """
    Test-case mixin providing an ``assertWarns`` that records warnings through
    a patched ``showwarning`` hook. The host class must supply ``assertTrue``.
    """
    # Based on https://stackoverflow.com/a/12935176/467366
    class _AssertWarnsContext(warnings.catch_warnings):
        """
        Context manager that records every warning shown inside it and, when
        the body finishes without raising, asserts via ``parent.assertTrue``
        that at least one recorded warning is a subclass of one of the
        expected warning categories.
        """
        def __init__(self, expected_warnings, parent, **kwargs):
            super(WarningTestMixin._AssertWarnsContext, self).__init__(**kwargs)

            self.parent = parent
            try:
                self.expected_warnings = list(expected_warnings)
            except TypeError:
                # A single warning class is not iterable; wrap it in a list.
                self.expected_warnings = [expected_warnings]

            self._warning_log = []

        def __enter__(self, *args, **kwargs):
            rv = super(WarningTestMixin._AssertWarnsContext, self).__enter__(*args, **kwargs)

            # Install the filter only *after* catch_warnings has saved the
            # current filter list, so it is rolled back on exit instead of
            # leaking into the rest of the test run.
            warnings.simplefilter('always')

            # If entering replaced showwarning (e.g. record=True), keep
            # forwarding to whatever hook is now installed.
            if self._showwarning is not self._module.showwarning:
                super_showwarning = self._module.showwarning
            else:
                super_showwarning = None

            def showwarning(*args, **kwargs):
                if super_showwarning is not None:
                    super_showwarning(*args, **kwargs)

                self._warning_log.append(warnings.WarningMessage(*args, **kwargs))

            self._module.showwarning = showwarning
            return rv

        def __exit__(self, *exc_info):
            super(WarningTestMixin._AssertWarnsContext, self).__exit__(*exc_info)

            if exc_info and exc_info[0] is not None:
                # The body raised: let that exception propagate untouched
                # rather than replacing it with an assertion failure.
                return

            found = any(issubclass(item.category, warning)
                        for warning in self.expected_warnings
                        for item in self._warning_log)
            names = ', '.join(getattr(w, '__name__', repr(w))
                              for w in self.expected_warnings)
            self.parent.assertTrue(
                found, 'None of the expected warnings (%s) was emitted' % names)

    def assertWarns(self, warning, callable=None, *args, **kwargs):
        """
        Assert that a warning of category ``warning`` is emitted.

        :param warning: a warning class, or an iterable of warning classes;
            any one of them (or a subclass) satisfies the assertion.
        :param callable: if None, a context manager is returned and the
            assertion is checked when its ``with`` block exits. Otherwise
            ``callable(*args, **kwargs)`` is invoked inside that context.
        """
        context = self.__class__._AssertWarnsContext(warning, self)
        if callable is None:
            return context

        with context:
            callable(*args, **kwargs)
56
57
class PicklableMixin(object):
    """
    Test-case mixin offering ``assertPicklable``, which round-trips an object
    through :mod:`pickle` and checks the restored copy.
    """
    def _get_nobj_bytes(self, obj, dump_kwargs, load_kwargs):
        """
        Round-trip ``obj`` through an in-memory byte string
        (``pickle.dumps`` then ``pickle.loads``).
        """
        return pickle.loads(pickle.dumps(obj, **dump_kwargs), **load_kwargs)

    def _get_nobj_file(self, obj, dump_kwargs, load_kwargs):
        """
        Round-trip ``obj`` through an anonymous temporary file
        (``pickle.dump`` then ``pickle.load``).
        """
        with tempfile.TemporaryFile('w+b') as handle:
            pickle.dump(obj, handle, **dump_kwargs)
            handle.seek(0)  # rewind so load() sees what dump() just wrote
            return pickle.load(handle, **load_kwargs)

    def assertPicklable(self, obj, singleton=False, asfile=False,
                        dump_kwargs=None, load_kwargs=None):
        """
        Assert that ``obj`` survives a pickle round trip.

        The restored object must compare equal to ``obj``. Unless
        ``singleton`` is true, it must also be a different object.

        :param asfile: round-trip through a temporary file instead of bytes.
        :param dump_kwargs: extra keyword arguments for the pickle dump call.
        :param load_kwargs: extra keyword arguments for the pickle load call.
        """
        dump_kwargs = dump_kwargs if dump_kwargs else {}
        load_kwargs = load_kwargs if load_kwargs else {}

        if asfile:
            round_trip = self._get_nobj_file
        else:
            round_trip = self._get_nobj_bytes

        restored = round_trip(obj, dump_kwargs, load_kwargs)

        # Singletons are expected to unpickle to the very same object.
        if not singleton:
            self.assertIsNot(obj, restored)
        self.assertEqual(obj, restored)
93
94
class TZContextBase(object):
    """
    Base class for a context manager which allows changing of time zones.

    Subclasses may define a guard variable to either block or allow time
    zone changes by redefining ``_guard_var_name`` and ``_guard_allows_change``.
    The default is that the guard variable must be affirmatively set.

    The guard variable counts as "set" whenever it is present in the
    environment with a non-empty value. Strings such as ``"0"`` or
    ``"False"`` are therefore also treated as set.

    Subclasses must define ``get_current_tz`` and ``set_current_tz``.
    """
    _guard_var_name = "DATEUTIL_MAY_CHANGE_TZ"
    _guard_allows_change = True

    def __init__(self, tzval):
        self.tzval = tzval
        self._old_tz = None

    @classmethod
    def tz_change_allowed(cls):
        """
        Class method used to query whether or not this class allows time zone
        changes.
        """
        guard = bool(os.environ.get(cls._guard_var_name, False))

        # _guard_allows_change gives the "default" behavior - if True, the
        # guard is overcoming a block. If false, the guard is causing a block.
        # Whether tz_change is allowed is therefore the XNOR of the two.
        return guard == cls._guard_allows_change

    @classmethod
    def tz_change_disallowed_message(cls):
        """ Generate instructions on how to allow tz changes """
        if cls._guard_allows_change:
            msg = ('Changing time zone not allowed. Set {envar} to {gval} '
                   'if you would like to allow this behavior')
        else:
            # Any non-empty value (including the string "False") makes the
            # guard truthy, so a blocking guard has to be removed instead.
            msg = ('Changing time zone not allowed. Unset {envar} (or set it '
                   'to an empty string) if you would like to allow this '
                   'behavior')

        return msg.format(envar=cls._guard_var_name,
                          gval=cls._guard_allows_change)

    def __enter__(self):
        """
        Switch to ``self.tzval``, remembering the previous time zone.

        :raises ValueError: if the guard variable forbids time zone changes.
        :return: this context manager.
        """
        if not self.tz_change_allowed():
            raise ValueError(self.tz_change_disallowed_message())

        self._old_tz = self.get_current_tz()
        self.set_current_tz(self.tzval)
        return self

    def __exit__(self, type, value, traceback):
        """Restore the time zone saved by ``__enter__``; exceptions propagate."""
        if self._old_tz is not None:
            self.set_current_tz(self._old_tz)

        self._old_tz = None

    def get_current_tz(self):
        """Return the current time zone; subclasses must implement this."""
        raise NotImplementedError

    def set_current_tz(self, tzval):
        """Make ``tzval`` the current time zone; subclasses must implement this."""
        raise NotImplementedError
152
153
class TZEnvContext(TZContextBase):
    """
    Context manager that temporarily sets the `TZ` variable (for use on
    *nix-like systems). Because the effect is local to this process (and any
    children it spawns) anyway, this will apply *unless* a guard is set.

    If you do not want the TZ environment variable set, you may set the
    ``DATEUTIL_MAY_NOT_CHANGE_TZ_VAR`` variable to a truthy value.
    """
    _guard_var_name = "DATEUTIL_MAY_NOT_CHANGE_TZ_VAR"
    _guard_allows_change = False

    def get_current_tz(self):
        """Return the current ``TZ`` value, or ``UnsetTz`` if it is not set."""
        return os.environ.get('TZ', UnsetTz)

    def set_current_tz(self, tzval):
        """
        Set ``TZ`` to ``tzval`` (or remove it when ``tzval`` is ``UnsetTz``),
        then call ``time.tzset()`` so the C library picks up the change.
        """
        if tzval is UnsetTz:
            # Removing an already-absent variable must be a no-op; the
            # sentinel itself must never be written into os.environ.
            os.environ.pop('TZ', None)
        else:
            os.environ['TZ'] = tzval

        time.tzset()
176
177
class TZWinContext(TZContextBase):
    """
    Context manager for changing local time zone on Windows.

    Because the effect of this is system-wide and global, it may have
    unintended side effects. Set the ``DATEUTIL_MAY_CHANGE_TZ`` environment
    variable to a truthy value before using this context manager.

    The time zone is read and written with the ``tzutil`` command-line tool.
    """
    @staticmethod
    def _failure_detail(out, err):
        """
        Build a readable error detail from captured ``tzutil`` output.

        ``out``/``err`` are the byte strings returned by
        ``Popen.communicate``. The first non-blank one is used; otherwise
        ``'Unknown error.'`` is returned.
        """
        for stream in (err, out):
            if stream:
                text = stream.decode(errors='replace').strip()
                if text:
                    return text
        return 'Unknown error.'

    def get_current_tz(self):
        """
        Return the current Windows time zone name (``tzutil /g``).

        :raises OSError: if ``tzutil`` exits with a non-zero status.
        """
        p = subprocess.Popen(['tzutil', '/g'],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = p.communicate()

        if p.returncode:
            raise OSError('Failed to get current time zone: ' +
                          self._failure_detail(out, err))

        # communicate() returns bytes; decode, and strip defensively in case
        # the tool appends a line terminator.
        return out.decode().strip()

    def set_current_tz(self, tzname):
        """
        Set the Windows time zone to ``tzname`` (``tzutil /s``).

        :raises OSError: if ``tzutil`` exits with a non-zero status.
        """
        # Passing an argument list lets subprocess quote ``tzname`` as needed
        # instead of relying on hand-built quoting in a command string.
        p = subprocess.Popen(['tzutil', '/s', tzname],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = p.communicate()

        if p.returncode:
            raise OSError('Failed to set current time zone: ' +
                          self._failure_detail(out, err))
205
206
207###
208# Utility classes
class NotAValueClass(object):
    """
    NaN-like sentinel with operators defined against any type.

    Arithmetic (``+ - * / //``, including Python 2's classic division) and
    the rich comparisons ``< > == <= >=`` all evaluate to the sentinel
    instance itself. This holds whichever side of the operator the sentinel
    is on, provided the other operand defers to it.
    """
    def _op(self, other):
        # The result of any operation involving the sentinel is the sentinel.
        return self

    def _cmp(self, other):
        return False

    # Arithmetic operators and their reflected forms.
    __add__ = _op
    __radd__ = _op
    __sub__ = _op
    __rsub__ = _op
    __mul__ = _op
    __rmul__ = _op
    __div__ = _op           # Python 2 classic division
    __rdiv__ = _op
    __truediv__ = _op
    __rtruediv__ = _op
    __floordiv__ = _op
    __rfloordiv__ = _op

    # Rich comparisons. Python never calls the ``__r*__`` comparison names;
    # it reflects ``a < b`` into ``b.__gt__(a)``. They are kept so the class
    # keeps the same attributes. ``!=`` is not overridden: on Python 3 the
    # inherited ``__ne__`` negates ``__eq__``'s (truthy) result and so
    # evaluates to False.
    __lt__ = _op
    __rlt__ = _op
    __gt__ = _op
    __rgt__ = _op
    __eq__ = _op
    __req__ = _op
    __le__ = _op
    __rle__ = _op
    __ge__ = _op
    __rge__ = _op


NotAValue = NotAValueClass()
234
235
class ComparesEqualClass(object):
    """
    Sentinel that claims equality with any object.

    ``==``, ``<=`` and ``>=`` evaluate to True, while ``!=``, ``<`` and ``>``
    evaluate to False, regardless of the other operand.
    """
    def _always_true(self, other):
        return True

    def _always_false(self, other):
        return False

    # The ``__r*__`` names are never looked up by Python, which reflects
    # comparisons onto the mirrored operator instead. They are kept so the
    # class exposes the same attributes.
    __eq__ = __le__ = __ge__ = _always_true
    __req__ = __rle__ = __rge__ = _always_true
    __ne__ = __lt__ = __gt__ = _always_false
    __rne__ = __rlt__ = __rgt__ = _always_false

    # Only the operator names above belong on the class.
    del _always_true, _always_false


ComparesEqual = ComparesEqualClass()
268
269
class UnsetTzClass(object):
    """ Sentinel class for unset time zone variable """
    def __repr__(self):
        # Readable in assertion messages and debugging output, instead of the
        # default ``<... object at 0x...>`` form.
        return 'UnsetTz'


UnsetTz = UnsetTzClass()
276