• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2010-2017 Benjamin Peterson
2#
3# Permission is hereby granted, free of charge, to any person obtaining a copy
4# of this software and associated documentation files (the "Software"), to deal
5# in the Software without restriction, including without limitation the rights
6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7# copies of the Software, and to permit persons to whom the Software is
8# furnished to do so, subject to the following conditions:
9#
10# The above copyright notice and this permission notice shall be included in all
11# copies or substantial portions of the Software.
12#
13# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19# SOFTWARE.
20
21import operator
22import sys
23import types
24import unittest
25
26import py
27
28import six
29
30
def test_add_doc():
    """_add_doc must replace whatever docstring a function already has."""
    def target():
        """Icky doc"""
    six._add_doc(target, """New doc""")
    assert target.__doc__ == "New doc"
37
38
def test_import_module():
    """_import_module returns the very module object a normal import binds."""
    from logging import handlers as expected_module
    assert six._import_module("logging.handlers") is expected_module
43
44
def test_integer_types():
    """Integers of any sign or magnitude are integer_types; floats are not."""
    for number in (1, -1, six.MAXSIZE + 23):
        assert isinstance(number, six.integer_types)
    assert not isinstance(.1, six.integer_types)
50
51
def test_string_types():
    """Both native strings and six.u text count as string_types."""
    for candidate in ("hi", six.u("hi")):
        assert isinstance(candidate, six.string_types)
    assert issubclass(six.text_type, six.string_types)
56
57
def test_class_types():
    """Classic and new-style classes are class_types; their instances are not."""
    class Classic:
        pass

    class NewStyle(object):
        pass

    for cls in (Classic, NewStyle):
        assert isinstance(cls, six.class_types)
    assert not isinstance(Classic(), six.class_types)
66
67
def test_text_type():
    """six.u produces an object whose exact type is six.text_type."""
    text = six.u("hi")
    assert type(text) is six.text_type
70
71
def test_binary_type():
    """six.b produces an object whose exact type is six.binary_type."""
    raw = six.b("hi")
    assert type(raw) is six.binary_type
74
75
def test_MAXSIZE():
    """MAXSIZE is usable as an index, but MAXSIZE + 1 is too large."""
    as_index = getattr(six.MAXSIZE, "__index__", None)
    if as_index is not None:
        # This shouldn't raise an overflow error. (Before Python 2.6 ints had
        # no __index__, in which case the check is skipped.)
        as_index()
    # Repeating a list MAXSIZE + 1 times must be rejected.
    py.test.raises(
        (ValueError, OverflowError),
        operator.mul, [None], six.MAXSIZE + 1)
86
87
def test_lazy():
    """A six.moves module is only imported when its attribute is first read."""
    html_name = "html.parser" if six.PY3 else "HTMLParser"
    assert html_name not in sys.modules
    loaded = six.moves.html_parser
    assert sys.modules[html_name] is loaded
    assert "htmlparser" not in six._MovedItems.__dict__
97
98
# Availability probes for optional extension modules; test_move_items uses
# these flags to skip moves that cannot load on this interpreter.
try:
    import _tkinter
    have_tkinter = True
except ImportError:
    have_tkinter = False

try:
    # Python 2 module name.
    import gdbm
    have_gdbm = True
except ImportError:
    try:
        # Python 3 module name.
        import dbm.gnu
        have_gdbm = True
    except ImportError:
        have_gdbm = False
114
@py.test.mark.parametrize("item_name",
                          [item.name for item in six._moved_attributes])
def test_move_items(item_name):
    """Ensure that every six.moves entry loads correctly.

    Entries that legitimately cannot load on the running platform or
    interpreter are skipped. Any other failure propagates and fails the test.
    """
    try:
        item = getattr(six.moves, item_name)
        if isinstance(item, types.ModuleType):
            __import__("six.moves." + item_name)
    except AttributeError:
        if item_name == "zip_longest" and sys.version_info < (2, 6):
            py.test.skip("zip_longest only available on 2.6+")
        # Any other missing attribute is a genuine failure: re-raise instead
        # of letting the test fall through and pass vacuously.
        raise
    except ImportError:
        if item_name == "winreg" and not sys.platform.startswith("win"):
            py.test.skip("Windows only module")
        if item_name.startswith("tkinter"):
            if not have_tkinter:
                py.test.skip("requires tkinter")
            if item_name == "tkinter_ttk" and sys.version_info[:2] <= (2, 6):
                py.test.skip("ttk only available on 2.7+")
        if item_name.startswith("dbm_gnu") and not have_gdbm:
            py.test.skip("requires gdbm")
        raise
    if sys.version_info[:2] >= (2, 6):
        assert item_name in dir(six.moves)
139
140
@py.test.mark.parametrize("item_name",
                          [item.name for item in six._urllib_parse_moved_attributes])
def test_move_items_urllib_parse(item_name):
    """Ensure every six.moves.urllib.parse entry is listed and loadable."""
    if sys.version_info < (2, 5) and item_name == "ParseResult":
        py.test.skip("ParseResult is only found on 2.5+")
    if sys.version_info < (2, 6) and item_name in ("parse_qs", "parse_qsl"):
        py.test.skip("parse_qs[l] is new in 2.6")
    namespace = six.moves.urllib.parse
    if sys.version_info[:2] >= (2, 6):
        assert item_name in dir(namespace)
    getattr(namespace, item_name)
152
153
@py.test.mark.parametrize("item_name",
                          [item.name for item in six._urllib_error_moved_attributes])
def test_move_items_urllib_error(item_name):
    """Ensure every six.moves.urllib.error entry is listed and loadable."""
    namespace = six.moves.urllib.error
    if sys.version_info[:2] >= (2, 6):
        assert item_name in dir(namespace)
    getattr(namespace, item_name)
161
162
@py.test.mark.parametrize("item_name",
                          [item.name for item in six._urllib_request_moved_attributes])
def test_move_items_urllib_request(item_name):
    """Ensure every six.moves.urllib.request entry is listed and loadable."""
    namespace = six.moves.urllib.request
    if sys.version_info[:2] >= (2, 6):
        assert item_name in dir(namespace)
    getattr(namespace, item_name)
170
171
@py.test.mark.parametrize("item_name",
                          [item.name for item in six._urllib_response_moved_attributes])
def test_move_items_urllib_response(item_name):
    """Ensure every six.moves.urllib.response entry is listed and loadable."""
    namespace = six.moves.urllib.response
    if sys.version_info[:2] >= (2, 6):
        assert item_name in dir(namespace)
    getattr(namespace, item_name)
179
180
@py.test.mark.parametrize("item_name",
                          [item.name for item in six._urllib_robotparser_moved_attributes])
def test_move_items_urllib_robotparser(item_name):
    """Ensure every six.moves.urllib.robotparser entry is listed and loadable."""
    namespace = six.moves.urllib.robotparser
    if sys.version_info[:2] >= (2, 6):
        assert item_name in dir(namespace)
    getattr(namespace, item_name)
188
189
def test_import_moves_error_1():
    """Regression: a from-import must not break later attribute access.

    In six 1.4.1 the attribute access raised AttributeError:
    'Module_six_moves_urllib_parse' object has no attribute 'urljoin'.
    """
    from six.moves.urllib.parse import urljoin as imported_urljoin  # noqa: F401
    from six import moves as six_moves
    assert six_moves.urllib.parse.urljoin
195
196
def test_import_moves_error_2():
    """Regression: attribute access must not break a later from-import.

    In six 1.4.1 the import raised ImportError: cannot import name urljoin.
    """
    from six import moves as six_moves
    assert six_moves.urllib.parse.urljoin
    from six.moves.urllib.parse import urljoin as imported_urljoin  # noqa: F401
202
203
def test_import_moves_error_3():
    """Regression: urllib.parse and urllib_parse must both import urljoin.

    In six 1.4.1 the second import raised ImportError: cannot import name
    urljoin.
    """
    from six.moves.urllib.parse import urljoin as via_package  # noqa: F401
    from six.moves.urllib_parse import urljoin as via_flat_name  # noqa: F401
208
209
def test_from_imports():
    """Class objects can be pulled out of six.moves with from-imports."""
    from six.moves.queue import Queue as queue_cls
    assert isinstance(queue_cls, six.class_types)
    from six.moves.configparser import ConfigParser as parser_cls
    assert isinstance(parser_cls, six.class_types)
215
216
def test_filter():
    """six.moves.filter returns an iterator over the matching items."""
    from six.moves import filter as lazy_filter
    odd_numbers = lazy_filter(lambda x: x % 2, range(10))
    assert six.advance_iterator(odd_numbers) == 1
221
222
def test_filter_false():
    """filterfalse yields the items for which the predicate is falsy."""
    from six.moves import filterfalse
    multiples_of_three = filterfalse(lambda x: x % 3, range(10))
    for expected in (0, 3, 6):
        assert six.advance_iterator(multiples_of_three) == expected
229
def test_map():
    """six.moves.map returns an iterator of the mapped values."""
    from six.moves import map as lazy_map
    incremented = lazy_map(lambda x: x + 1, range(2))
    assert six.advance_iterator(incremented) == 1
233
234
def test_getoutput():
    """getoutput returns the command's output with the trailing newline removed."""
    from six.moves import getoutput as capture
    assert capture('echo "foo"') == 'foo'
239
240
def test_zip():
    """six.moves.zip returns an iterator of tuples."""
    from six.moves import zip as lazy_zip
    pairs = lazy_zip(range(2), range(2))
    assert six.advance_iterator(pairs) == (0, 0)
244
245
@py.test.mark.skipif("sys.version_info < (2, 6)")
def test_zip_longest():
    """zip_longest pads the shorter input with None."""
    from six.moves import zip_longest
    pairs = zip_longest(range(2), range(1))
    for expected in ((0, 0), (1, None)):
        assert six.advance_iterator(pairs) == expected
253
254
class TestCustomizedMoves:
    """Tests for MovedAttribute/MovedModule and add_move/remove_move."""

    def teardown_method(self, meth):
        # Remove any "spam" move a test left behind, both from the
        # _MovedItems class and from the six.moves instance dictionary.
        try:
            del six._MovedItems.spam
        except AttributeError:
            pass
        six.moves.__dict__.pop("spam", None)

    def test_moved_attribute(self):
        moved = six.MovedAttribute("spam", "foo", "bar")
        assert moved.mod == ("bar" if six.PY3 else "foo")
        # With no explicit attribute names, the attribute is the move name.
        assert moved.attr == "spam"
        moved = six.MovedAttribute("spam", "foo", "bar", "lemma")
        assert moved.attr == "lemma"
        moved = six.MovedAttribute("spam", "foo", "bar", "lemma", "theorm")
        assert moved.attr == ("theorm" if six.PY3 else "lemma")

    def test_moved_module(self):
        moved = six.MovedModule("spam", "foo")
        # With no new-module name, Python 3 falls back to the move name.
        assert moved.mod == ("spam" if six.PY3 else "foo")
        moved = six.MovedModule("spam", "foo", "bar")
        assert moved.mod == ("bar" if six.PY3 else "foo")

    def test_custom_move_module(self):
        # Adding and then removing a move must leave no trace behind.
        six.add_move(six.MovedModule("spam", "six", "six"))
        six.remove_move("spam")
        assert not hasattr(six.moves, "spam")
        # A freshly added move can be imported through six.moves.
        six.add_move(six.MovedModule("spam", "six", "six"))
        from six.moves import spam as imported
        assert imported is six
        six.remove_move("spam")
        assert not hasattr(six.moves, "spam")

    def test_custom_move_attribute(self):
        # Adding and then removing a move must leave no trace behind.
        six.add_move(six.MovedAttribute("spam", "six", "six", "u", "u"))
        six.remove_move("spam")
        assert not hasattr(six.moves, "spam")
        # A freshly added attribute move resolves to the target object.
        six.add_move(six.MovedAttribute("spam", "six", "six", "u", "u"))
        from six.moves import spam as imported
        assert imported is six.u
        six.remove_move("spam")
        assert not hasattr(six.moves, "spam")

    def test_empty_remove(self):
        # Removing a move that was never added is an error.
        py.test.raises(AttributeError, six.remove_move, "eggs")
325
326
def test_get_unbound_function():
    """get_unbound_function recovers the raw function stored on the class."""
    class Owner(object):
        def m(self):
            pass

    unwrapped = six.get_unbound_function(Owner.m)
    assert unwrapped is Owner.__dict__["m"]
332
333
def test_get_method_self():
    """get_method_self returns the instance a method is bound to."""
    class Owner(object):
        def m(self):
            pass

    owner = Owner()
    assert six.get_method_self(owner.m) is owner
    # A plain int is not a method and has no bound instance.
    py.test.raises(AttributeError, six.get_method_self, 42)
341
342
def test_get_method_function():
    """get_method_function returns the function underlying a bound method."""
    class Owner(object):
        def m(self):
            pass

    bound = Owner().m
    assert six.get_method_function(bound) is Owner.__dict__["m"]
    # Builtins such as hasattr are not methods.
    py.test.raises(AttributeError, six.get_method_function, hasattr)
350
351
def test_get_function_closure():
    """The closure of a nested function contains cell objects."""
    def outer():
        captured = 42

        def inner():
            return captured
        return inner

    first_cell = six.get_function_closure(outer())[0]
    assert type(first_cell).__name__ == "cell"
360
361
def test_get_function_code():
    """get_function_code returns a code object and rejects builtins."""
    def f():
        pass

    assert isinstance(six.get_function_code(f), types.CodeType)
    # The builtin check is skipped on PyPy, where builtins presumably do
    # expose a code attribute.
    running_pypy = hasattr(sys, "pypy_version_info")
    if not running_pypy:
        py.test.raises(AttributeError, six.get_function_code, hasattr)
368
369
def test_get_function_defaults():
    """get_function_defaults returns the tuple of default parameter values."""
    def f(x, y=3, b=4):
        pass

    defaults = six.get_function_defaults(f)
    assert defaults == (3, 4)
374
375
def test_get_function_globals():
    """get_function_globals returns the globals dict of the defining module."""
    def f():
        pass

    module_namespace = globals()
    assert six.get_function_globals(f) is module_namespace
380
381
def test_dictionary_iterators(monkeypatch):
    """Check six.iterkeys/itervalues/iteritems/iterlists on a dict subclass.

    For each helper, verify that it:
    - returns a non-list iterator with the same contents as the dict's own
      method;
    - is exhausted after one pass;
    - forwards extra keyword arguments to the underlying dict method.
    """
    def stock_method_name(iterwhat):
        """Given a method suffix like "lists" or "values", return the name
        of the dict method that delivers those on the version of Python
        we're running in ("lists" on Python 3, "iterlists" on Python 2)."""
        if six.PY3:
            return iterwhat
        return 'iter' + iterwhat

    # Give the dict subclass a "lists" family of methods, so that
    # six.iterlists can be exercised alongside keys/values/items.
    class MyDict(dict):
        if not six.PY3:
            def lists(self, **kw):
                return [1, 2, 3]
        def iterlists(self, **kw):
            return iter([1, 2, 3])
    # Re-register iterlists under this interpreter's stock name ("lists" on
    # Python 3, "iterlists" again on Python 2).
    f = MyDict.iterlists
    del MyDict.iterlists
    setattr(MyDict, stock_method_name('lists'), f)

    d = MyDict(zip(range(10), reversed(range(10))))
    for name in "keys", "values", "items", "lists":
        meth = getattr(six, "iter" + name)
        it = meth(d)
        assert not isinstance(it, list)
        assert list(it) == list(getattr(d, name)())
        # The iterator must be exhausted after the full pass above.
        py.test.raises(StopIteration, six.advance_iterator, it)
        # Patch the stock method to record the keyword argument it receives.
        # `old` is looked up when with_kw is called, i.e. after the
        # assignment below has run.
        record = []
        def with_kw(*args, **kw):
            record.append(kw["kw"])
            return old(*args)
        old = getattr(MyDict, stock_method_name(name))
        monkeypatch.setattr(MyDict, stock_method_name(name), with_kw)
        meth(d, kw=42)
        assert record == [42]
        # Restore the original method before the next iteration.
        monkeypatch.undo()
417
418
@py.test.mark.skipif("sys.version_info[:2] < (2, 7)",
                     reason="view methods on dictionaries only available on 2.7+")
def test_dictionary_views():
    """six.viewkeys/viewvalues/viewitems expose the dict's contents as views.

    Each helper must return a non-list object whose elements match the
    dict's own keys/values/items.
    """
    # One distinct value per key, so zip() keeps all ten pairs.
    d = dict(zip(range(10), range(10, 20)))
    for name in "keys", "values", "items":
        meth = getattr(six, "view" + name)
        view = meth(d)
        # A view is never a materialized list on either major version.
        assert not isinstance(view, list)
        assert set(view) == set(getattr(d, name)())
435
436
def test_advance_iterator():
    """six.next aliases advance_iterator and behaves like the builtin next."""
    assert six.next is six.advance_iterator
    iterator = iter([1, 2])
    for expected in (1, 2):
        assert six.next(iterator) == expected
    # An exhausted iterator keeps raising StopIteration on every call.
    for _ in range(2):
        py.test.raises(StopIteration, six.next, iterator)
445
446
def test_iterator():
    """six.Iterator subclasses only need __next__, including when overriding it."""
    class YieldsThirteen(six.Iterator):
        def __next__(self):
            return 13

    assert six.advance_iterator(YieldsThirteen()) == 13

    class YieldsFourteen(YieldsThirteen):
        def __next__(self):
            return 14

    assert six.advance_iterator(YieldsFourteen()) == 14
456
457
def test_callable():
    """six.callable accepts classes, functions and methods but not plain data."""
    class Gadget:
        def __call__(self):
            pass

        def method(self):
            pass

    callables = (Gadget, Gadget(), test_callable, hasattr,
                 Gadget.method, Gadget().method)
    for obj in callables:
        assert six.callable(obj)
    for obj in (4, "string"):
        assert not six.callable(obj)
472
473
def test_create_bound_method():
    """create_bound_method binds a plain function to the given instance."""
    class Host(object):
        pass

    def identity(self):
        return self

    instance = Host()
    bound = six.create_bound_method(identity, instance)
    assert isinstance(bound, types.MethodType)
    assert bound() is instance
483
484
def test_create_unbound_method():
    """create_unbound_method yields a callable that needs an explicit instance.

    Calling it with no arguments raises TypeError. Calling it with an
    instance of the class forwards to the wrapped function.
    """
    class X(object):
        pass

    def f(self):
        return self
    u = six.create_unbound_method(f, X)
    py.test.raises(TypeError, u)
    if six.PY2:
        assert isinstance(u, types.MethodType)
    x = X()
    # Exercise the unbound method under test, not the raw function it wraps.
    assert u(x) is x
497
498
if six.PY3:

    def test_b():
        """On Python 3, six.b returns a single byte for a one-character string."""
        data = six.b("\xff")
        assert isinstance(data, bytes) and len(data) == 1
        assert data == bytes([255])

    def test_u():
        """On Python 3, six.u returns a str equal to its argument."""
        text = six.u("hi \u0439 \U00000439 \\ \\\\ \n")
        assert isinstance(text, str)
        assert text == "hi \u0439 \U00000439 \\ \\\\ \n"

else:

    def test_b():
        """On Python 2, six.b returns the native one-byte str unchanged."""
        data = six.b("\xff")
        assert isinstance(data, str) and len(data) == 1
        assert data == "\xff"

    def test_u():
        """On Python 2, six.u decodes unicode escape sequences into unicode."""
        text = six.u("hi \u0439 \U00000439 \\ \\\\ \n")
        assert isinstance(text, unicode)
        assert text == "hi \xd0\xb9 \xd0\xb9 \\ \\\\ \n".decode("utf8")
526
527
def test_u_escapes():
    """A unicode escape passed to six.u becomes exactly one code point."""
    assert len(six.u("\u1234")) == 1
531
532
def test_unichr():
    """unichr matches the equivalent six.u escape in both value and type."""
    via_escape = six.u("\u1234")
    via_unichr = six.unichr(0x1234)
    assert via_escape == via_unichr
    assert type(via_escape) is type(via_unichr)
536
537
def test_int2byte():
    """int2byte builds a one-byte string and rejects values that do not fit."""
    expected = six.b("\x03")
    assert six.int2byte(3) == expected
    py.test.raises(Exception, six.int2byte, 256)
541
542
def test_byte2int():
    """byte2int reads only the first byte and fails on empty input."""
    for raw in ("\x03", "\x03\x04"):
        assert six.byte2int(six.b(raw)) == 3
    py.test.raises(IndexError, six.byte2int, six.b(""))
547
548
def test_bytesindex():
    """indexbytes returns the integer value of the byte at an index."""
    data = six.b("hello")
    assert six.indexbytes(data, 3) == ord("l")
551
552
def test_bytesiter():
    """iterbytes yields each byte as an integer, then stops."""
    byte_iter = six.iterbytes(six.b("hi"))
    for char in "hi":
        assert six.next(byte_iter) == ord(char)
    py.test.raises(StopIteration, six.next, byte_iter)
558
559
def test_StringIO():
    """six.StringIO accumulates written text."""
    payload = six.u("hello")
    buf = six.StringIO()
    buf.write(payload)
    assert buf.getvalue() == payload
564
565
def test_BytesIO():
    """six.BytesIO accumulates written bytes."""
    payload = six.b("hello")
    buf = six.BytesIO()
    buf.write(payload)
    assert buf.getvalue() == payload
570
571
def test_exec_():
    """six.exec_ runs code in the caller's scope or in explicit namespaces."""
    # With no namespace arguments, the snippet sees the calling function's
    # locals. It mutates (rather than rebinds) `l`, so the change is
    # visible after exec_ returns. The name `l` must match the string.
    def f():
        l = []
        six.exec_("l.append(1)")
        assert l == [1]
    f()
    # With one namespace, assignments land in that dict.
    ns = {}
    six.exec_("x = 42", ns)
    assert ns["x"] == 42
    # With separate globals and locals, `global y` writes to globals while
    # plain assignments write to locals.
    glob = {}
    loc = {}
    six.exec_("global y; y = 42; x = 12", glob, loc)
    assert glob["y"] == 42
    assert "x" not in glob
    assert loc["x"] == 12
    assert "y" not in loc
588
589
def test_reraise():
    """six.reraise re-raises with the given type, value and traceback.

    Covers reraising with the original traceback, with no traceback, with a
    newer traceback, and with value=None (which instantiates the type).
    """
    def get_next(tb):
        # Skip the frames stacked on top of the traceback passed to reraise:
        # two levels on Python 3, one on Python 2.
        if six.PY3:
            return tb.tb_next.tb_next
        else:
            return tb.tb_next
    e = Exception("blah")
    try:
        raise e
    except Exception:
        tp, val, tb = sys.exc_info()
    # Explicit traceback: the original tb must be chained beneath the new one.
    try:
        six.reraise(tp, val, tb)
    except Exception:
        tp2, value2, tb2 = sys.exc_info()
        assert tp2 is Exception
        assert value2 is e
        assert tb is get_next(tb2)
    # No traceback: a fresh traceback is produced.
    try:
        six.reraise(tp, val)
    except Exception:
        tp2, value2, tb2 = sys.exc_info()
        assert tp2 is Exception
        assert value2 is e
        assert tb2 is not tb
    # Reraising with the traceback captured just above.
    try:
        six.reraise(tp, val, tb2)
    except Exception:
        tp2, value2, tb3 = sys.exc_info()
        assert tp2 is Exception
        assert value2 is e
        assert get_next(tb3) is tb2
    # value=None: a new instance of tp is raised instead of `val`.
    try:
        six.reraise(tp, None, tb)
    except Exception:
        tp2, value2, tb2 = sys.exc_info()
        assert tp2 is Exception
        assert value2 is not val
        assert isinstance(value2, Exception)
        assert tb is get_next(tb2)
630
631
def test_raise_from():
    """raise_from(exc, None) behaves like ``raise exc from None``."""
    try:
        try:
            raise Exception("blah")
        except Exception:
            ctx = sys.exc_info()[1]
            six.raise_from(Exception("foo"), None)
    except Exception:
        val = sys.exc_info()[1]
    version = sys.version_info[:2]
    if version > (3, 0):
        # No explicit cause, but the implicit context is still recorded.
        assert val.__cause__ is None
        assert val.__context__ is ctx
    if version >= (3, 3):
        # ...and display of that context is suppressed.
        assert val.__suppress_context__
    # On every version the new exception must have been raised.
    assert str(val) == "foo"
651
652
def test_print_():
    """six.print_ supports the Python 3 print() keywords: file, sep, end, flush."""
    def render(*args, **kwargs):
        # Print into a fresh six.StringIO and return what was written.
        buf = six.StringIO()
        six.print_(*args, file=buf, **kwargs)
        return buf.getvalue()

    # Without file=, output goes to sys.stdout.
    saved_stdout = sys.stdout
    captured = sys.stdout = six.moves.StringIO()
    try:
        six.print_("Hello,", "person!")
    finally:
        sys.stdout = saved_stdout
    assert captured.getvalue() == "Hello, person!\n"

    assert render("Hello,", "person!") == "Hello, person!\n"
    assert render("Hello,", "person!", end="") == "Hello, person!"
    assert render("Hello,", "person!", sep="X") == "Hello,Xperson!\n"
    text = render(six.u("Hello,"), six.u("person!"))
    assert isinstance(text, six.text_type)
    assert text == six.u("Hello, person!\n")
    six.print_("Hello", file=None)  # Passing file=None must not fail.
    assert render(None) == "None\n"

    class FlushableStringIO(six.StringIO):
        def __init__(self):
            six.StringIO.__init__(self)
            self.flushed = False

        def flush(self):
            self.flushed = True

    sink = FlushableStringIO()
    six.print_("Hello", file=sink)
    assert not sink.flushed
    six.print_("Hello", file=sink, flush=True)
    assert sink.flushed
690
691
@py.test.mark.skipif("sys.version_info[:2] >= (2, 6)")
def test_print_encoding(monkeypatch):
    """On pre-2.6 interpreters, print_ encodes text with the file's encoding/errors."""
    # Fool the type checking in print_.
    monkeypatch.setattr(six, "file", six.BytesIO, raising=False)
    char = six.u("\u053c")

    utf8_sink = six.BytesIO()
    utf8_sink.encoding = "utf-8"
    utf8_sink.errors = None
    six.print_(char, end="", file=utf8_sink)
    assert utf8_sink.getvalue() == six.b("\xd4\xbc")

    ascii_sink = six.BytesIO()
    ascii_sink.encoding = "ascii"
    ascii_sink.errors = "strict"
    py.test.raises(UnicodeEncodeError, six.print_, char, file=ascii_sink)
    ascii_sink.errors = "backslashreplace"
    six.print_(char, end="", file=ascii_sink)
    assert ascii_sink.getvalue() == six.b("\\u053c")
708
709
def test_print_exceptions():
    """print_ rejects unknown keywords and non-string sep/end values."""
    for bad_kwargs in ({"x": 3}, {"end": 3}, {"sep": 42}):
        py.test.raises(TypeError, six.print_, **bad_kwargs)
714
715
def test_with_metaclass():
    """Classes built via six.with_metaclass get the metaclass and exact bases."""
    class Meta(type):
        pass
    # No explicit bases: the class still derives from object.
    class X(six.with_metaclass(Meta)):
        pass
    assert type(X) is Meta
    assert issubclass(X, object)
    # One explicit base.
    class Base(object):
        pass
    class X(six.with_metaclass(Meta, Base)):
        pass
    assert type(X) is Meta
    assert issubclass(X, Base)
    # Several bases: the MRO must list exactly the declared bases, i.e.
    # nothing extra sits between X and Base/Base2.
    class Base2(object):
        pass
    class X(six.with_metaclass(Meta, Base, Base2)):
        pass
    assert type(X) is Meta
    assert issubclass(X, Base)
    assert issubclass(X, Base2)
    assert X.__mro__ == (X, Base, Base2, object)
    # A metaclass subclass applied on top of a with_metaclass class.
    class X(six.with_metaclass(Meta)):
        pass
    class MetaSub(Meta):
        pass
    class Y(six.with_metaclass(MetaSub, X)):
        pass
    assert type(Y) is MetaSub
    assert Y.__mro__ == (Y, X, object)
745
746
@py.test.mark.skipif("sys.version_info[:2] < (3, 0)")
def test_with_metaclass_prepare():
    """Test that with_metaclass causes Meta.__prepare__ to be called with the correct arguments."""

    class RecordingDict(dict):
        pass

    class Meta(type):

        @classmethod
        def __prepare__(cls, name, bases):
            # Stash the arguments in the class namespace so the test can
            # inspect them on the resulting class.
            ns = RecordingDict(super().__prepare__(name, bases), cls=cls, bases=bases)
            ns['namespace'] = ns
            return ns

    class Parent(object):
        pass

    bases = (Parent,)

    class Child(six.with_metaclass(Meta, *bases)):
        pass

    assert getattr(Child, 'cls', type) is Meta
    assert getattr(Child, 'bases', ()) == bases
    assert isinstance(getattr(Child, 'namespace', {}), RecordingDict)
773
774
def test_wraps():
    """six.wraps sets __wrapped__ and honours custom assigned/updated lists."""
    def decorate(g):
        @six.wraps(g)
        def w():
            return 42
        return w

    def k():
        pass

    original_k = k
    unwrapped = decorate(decorate(k))
    # Peel both decorator layers back off through __wrapped__.
    for _ in range(2):
        assert hasattr(unwrapped, '__wrapped__')
        unwrapped = unwrapped.__wrapped__
    assert unwrapped is original_k
    assert not hasattr(unwrapped, '__wrapped__')

    def wrap_with(g, assign, update):
        def w():
            return 42
        w.glue = {"foo" : "bar"}
        return six.wraps(g, assign, update)(w)

    original_k.glue = {"melon" : "egg"}
    original_k.turnip = 43
    # Only "turnip" is copied and only "glue" is merged, so __name__ stays "w".
    result = wrap_with(original_k, ["turnip"], ["glue"])
    assert result.__name__ == "w"
    assert result.turnip == 43
    assert result.glue == {"melon" : "egg", "foo" : "bar"}
803
804
def test_add_metaclass():
    """six.add_metaclass rebuilds a class with a new metaclass.

    The rebuilt class must keep its bases, module, docstring, attributes
    and __slots__ layout.
    """
    class Meta(type):
        pass
    # No explicit bases; module and docstring must be preserved.
    class X:
        "success"
    X = six.add_metaclass(Meta)(X)
    assert type(X) is Meta
    assert issubclass(X, object)
    assert X.__module__ == __name__
    assert X.__doc__ == "success"
    # One explicit base.
    class Base(object):
        pass
    class X(Base):
        pass
    X = six.add_metaclass(Meta)(X)
    assert type(X) is Meta
    assert issubclass(X, Base)
    # Several bases.
    class Base2(object):
        pass
    class X(Base, Base2):
        pass
    X = six.add_metaclass(Meta)(X)
    assert type(X) is Meta
    assert issubclass(X, Base)
    assert issubclass(X, Base2)

    # Test a second-generation subclass of a type: the base already has a
    # metaclass and the subclass gets a sub-metaclass.
    class Meta1(type):
        m1 = "m1"
    class Meta2(Meta1):
        m2 = "m2"
    class Base:
        b = "b"
    Base = six.add_metaclass(Meta1)(Base)
    class X(Base):
        x = "x"
    X = six.add_metaclass(Meta2)(X)
    assert type(X) is Meta2
    assert issubclass(X, Base)
    assert type(Base) is Meta1
    # The rebuilt class must not carry a copied "__dict__" entry; instances
    # must still get a working per-instance __dict__.
    assert "__dict__" not in vars(X)
    instance = X()
    instance.attr = "test"
    assert vars(instance) == {"attr": "test"}
    assert instance.b == Base.b
    assert instance.x == X.x

    # Test a class with slots: the slot restriction must survive rebuilding.
    class MySlots(object):
        __slots__ = ["a", "b"]
    MySlots = six.add_metaclass(Meta1)(MySlots)

    assert MySlots.__slots__ == ["a", "b"]
    instance = MySlots()
    instance.a = "foo"
    py.test.raises(AttributeError, setattr, instance, "c", "baz")

    # Test a class with string for slots: "ab" is a single slot named "ab",
    # not two slots "a" and "b".
    class MyStringSlots(object):
        __slots__ = "ab"
    MyStringSlots = six.add_metaclass(Meta1)(MyStringSlots)
    assert MyStringSlots.__slots__ == "ab"
    instance = MyStringSlots()
    instance.ab = "foo"
    py.test.raises(AttributeError, setattr, instance, "a", "baz")
    py.test.raises(AttributeError, setattr, instance, "b", "baz")

    # A class whose only slot is __weakref__ must also be rebuildable.
    class MySlotsWeakref(object):
        __slots__ = "__weakref__",
    MySlotsWeakref = six.add_metaclass(Meta)(MySlotsWeakref)
    assert type(MySlotsWeakref) is Meta
876
877
@py.test.mark.skipif("sys.version_info[:2] < (2, 7) or sys.version_info[:2] in ((3, 0), (3, 1))")
def test_assertCountEqual():
    """six.assertCountEqual fails on different elements and ignores order."""
    class CountEqualCase(unittest.TestCase):
        def test(self):
            with self.assertRaises(AssertionError):
                six.assertCountEqual(self, (1, 2), [3, 4, 5])
            six.assertCountEqual(self, (1, 2), [2, 1])

    case = CountEqualCase('test')
    case.test()
888
889
@py.test.mark.skipif("sys.version_info[:2] < (2, 7)")
def test_assertRegex():
    """six.assertRegex fails on a non-matching pattern and passes on a match."""
    class RegexCase(unittest.TestCase):
        def test(self):
            with self.assertRaises(AssertionError):
                six.assertRegex(self, 'test', r'^a')
            six.assertRegex(self, 'test', r'^t')

    case = RegexCase('test')
    case.test()
900
901
@py.test.mark.skipif("sys.version_info[:2] < (2, 7)")
def test_assertRaisesRegex():
    """six.assertRaisesRegex checks both the exception type and its message."""
    class RaisesRegexCase(unittest.TestCase):
        def test(self):
            # Matching message: the context manager swallows the exception.
            with six.assertRaisesRegex(self, AssertionError, '^Foo'):
                raise AssertionError('Foo')
            # Non-matching message: the check itself fails.
            with self.assertRaises(AssertionError):
                with six.assertRaisesRegex(self, AssertionError, r'^Foo'):
                    raise AssertionError('Bar')

    case = RaisesRegexCase('test')
    case.test()
914
915
def test_python_2_unicode_compatible():
    """The decorator yields the right str/unicode/bytes conversions per version."""
    @six.python_2_unicode_compatible
    class MyTest(object):
        def __str__(self):
            return six.u('hello')

        def __bytes__(self):
            return six.b('hello')

    obj = MyTest()
    if six.PY2:
        assert str(obj) == six.b("hello")
        assert unicode(obj) == six.u("hello")
    elif six.PY3:
        assert bytes(obj) == six.b("hello")
        assert str(obj) == six.u("hello")

    # Use the builtin bytes when available, falling back to str on Python 2.
    to_bytes = getattr(six.moves.builtins, 'bytes', str)
    assert to_bytes(obj) == six.b("hello")
935