• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2import pickle
3import cPickle
4import StringIO
5import cStringIO
6import pickletools
7import copy_reg
8
9from test.test_support import TestFailed, have_unicode, TESTFN
10
11# Tests that try a number of pickle protocols should have a
12#     for proto in protocols:
13# kind of outer loop.
# The tests in this file assume pickle and cPickle agree on the highest
# protocol and that it is 2.  `protocols` holds every protocol number from 0
# up to and including that maximum.
assert pickle.HIGHEST_PROTOCOL == cPickle.HIGHEST_PROTOCOL == 2
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
16
17# Copy of test.test_support.run_with_locale. This is needed to support Python
18# 2.4, which didn't include it. This is all to support test_xpickle, which
19# bounces pickled objects through older Python versions to test backwards
20# compatibility.
def run_with_locale(catstr, *locales):
    """Return a decorator that runs a function under a temporary locale.

    catstr is the name of a category attribute of the locale module (for
    example 'LC_ALL').  The entries of locales are tried in order until one
    of them can be set; if none can, the function runs under the current
    locale.  The locale that was active before the call is restored
    afterwards, whether the function returns or raises.

    An invalid category name makes the decorated function raise
    AttributeError when it is called.
    """
    def decorator(func):
        def inner(*args, **kwds):
            try:
                import locale
                category = getattr(locale, catstr)
                orig_locale = locale.setlocale(category)
            except AttributeError:
                # if the test author gives us an invalid category string
                raise
            except Exception:
                # cannot retrieve original locale, so do nothing
                locale = orig_locale = None
            else:
                for loc in locales:
                    try:
                        locale.setlocale(category, loc)
                        break
                    except Exception:
                        # this locale is unavailable; try the next one
                        pass

            # now run the function, resetting the locale on exceptions
            try:
                return func(*args, **kwds)
            finally:
                if locale and orig_locale:
                    locale.setlocale(category, orig_locale)
        # __name__ is the portable spelling of func_name.
        inner.__name__ = func.__name__
        inner.__doc__ = func.__doc__
        return inner
    return decorator
52
53
54# Return True if opcode code appears in the pickle, else False.
def opcode_in_pickle(code, pickle):
    """Tell whether any opcode in the pickle string has the given code."""
    return any(opinfo.code == code
               for opinfo, _arg, _pos in pickletools.genops(pickle))
60
61# Return the number of times opcode code appears in pickle.
def count_opcode(code, pickle):
    """Count the opcodes in the pickle string that have the given code."""
    return sum(1 for opinfo, _arg, _pos in pickletools.genops(pickle)
               if opinfo.code == code)
68
69# We can't very well test the extension registry without putting known stuff
70# in it, but we have to be careful to restore its original state.  Code
71# should do this:
72#
73#     e = ExtensionSaver(extension_code)
74#     try:
75#         fiddle w/ the extension registry's stuff for extension_code
76#     finally:
77#         e.restore()
78
class ExtensionSaver:
    # Holds the copy_reg registration of one extension code so that a test
    # can temporarily reuse the code and put the old state back afterwards.

    # Record the (module, name) pair currently registered for code, if any,
    # and unregister it.
    def __init__(self, code):
        self.code = code
        self.pair = copy_reg._inverted_registry.get(code)
        if self.pair is not None:
            module, name = self.pair
            copy_reg.remove_extension(module, name, code)

    # Drop whatever is registered for the code now, then re-register the
    # pair that was recorded by __init__ (if there was one).
    def restore(self):
        code = self.code
        if code in copy_reg._inverted_registry:
            module, name = copy_reg._inverted_registry[code]
            copy_reg.remove_extension(module, name, code)
        if self.pair is not None:
            module, name = self.pair
            copy_reg.add_extension(module, name, code)
99
class C:
    # Plain class whose instances compare by their attribute dictionaries,
    # so that a pickled-and-reloaded copy compares equal to the original.
    def __cmp__(self, other):
        return cmp(self.__dict__, other.__dict__)
103
# Make C reachable as __main__.C and make it report "__main__" as its
# module; the canned pickles in this file refer to the class as '__main__ C'.
import __main__
__main__.C = C
C.__module__ = "__main__"
107
class myint(int):
    # int subclass that also stores str(x) of its constructor argument as an
    # instance attribute.
    def __init__(self, x):
        self.str = str(x)
111
class initarg(C):
    # Subclass of C that takes two constructor arguments and reports them
    # back through __getinitargs__.

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __getinitargs__(self):
        return self.a, self.b
120
class metaclass(type):
    # Metaclass that adds nothing to type.
    pass
123
class use_metaclass(object):
    # New-style class that names `metaclass` as its __metaclass__.
    __metaclass__ = metaclass
126
127# DATA0 .. DATA2 are the pickles we expect under the various protocols, for
128# the object returned by create_data().
129
130# break into multiple strings to avoid confusing font-lock-mode
# Protocol 0 (text) pickle of the object returned by create_data().
DATA0 = """(lp1
I0
aL1L
aF2
ac__builtin__
complex
p2
""" + \
"""(F3
F0
tRp3
aI1
aI-1
aI255
aI-255
aI-256
aI65535
aI-65535
aI-65536
aI2147483647
aI-2147483647
aI-2147483648
a""" + \
"""(S'abc'
p4
g4
""" + \
"""(i__main__
C
p5
""" + \
"""(dp6
S'foo'
p7
I1
sS'bar'
p8
I2
sbg5
tp9
ag9
aI5
a.
"""
175
176# Disassembly of DATA0.
# Expected pickletools.dis() output for DATA0, one opcode per line.
DATA0_DIS = """\
    0: (    MARK
    1: l        LIST       (MARK at 0)
    2: p    PUT        1
    5: I    INT        0
    8: a    APPEND
    9: L    LONG       1L
   13: a    APPEND
   14: F    FLOAT      2.0
   17: a    APPEND
   18: c    GLOBAL     '__builtin__ complex'
   39: p    PUT        2
   42: (    MARK
   43: F        FLOAT      3.0
   46: F        FLOAT      0.0
   49: t        TUPLE      (MARK at 42)
   50: R    REDUCE
   51: p    PUT        3
   54: a    APPEND
   55: I    INT        1
   58: a    APPEND
   59: I    INT        -1
   63: a    APPEND
   64: I    INT        255
   69: a    APPEND
   70: I    INT        -255
   76: a    APPEND
   77: I    INT        -256
   83: a    APPEND
   84: I    INT        65535
   91: a    APPEND
   92: I    INT        -65535
  100: a    APPEND
  101: I    INT        -65536
  109: a    APPEND
  110: I    INT        2147483647
  122: a    APPEND
  123: I    INT        -2147483647
  136: a    APPEND
  137: I    INT        -2147483648
  150: a    APPEND
  151: (    MARK
  152: S        STRING     'abc'
  159: p        PUT        4
  162: g        GET        4
  165: (        MARK
  166: i            INST       '__main__ C' (MARK at 165)
  178: p        PUT        5
  181: (        MARK
  182: d            DICT       (MARK at 181)
  183: p        PUT        6
  186: S        STRING     'foo'
  193: p        PUT        7
  196: I        INT        1
  199: s        SETITEM
  200: S        STRING     'bar'
  207: p        PUT        8
  210: I        INT        2
  213: s        SETITEM
  214: b        BUILD
  215: g        GET        5
  218: t        TUPLE      (MARK at 151)
  219: p    PUT        9
  222: a    APPEND
  223: g    GET        9
  226: a    APPEND
  227: I    INT        5
  230: a    APPEND
  231: .    STOP
highest protocol among opcodes = 0
"""
248
# Protocol 1 (binary) pickle of the object returned by create_data().
DATA1 = (']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00'
         'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00'
         '\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff'
         '\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff'
         'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00'
         '\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n'
         'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh'
         '\x06tq\nh\nK\x05e.'
        )
258
259# Disassembly of DATA1.
# Expected pickletools.dis() output for DATA1, one opcode per line.
DATA1_DIS = """\
    0: ]    EMPTY_LIST
    1: q    BINPUT     1
    3: (    MARK
    4: K        BININT1    0
    6: L        LONG       1L
   10: G        BINFLOAT   2.0
   19: c        GLOBAL     '__builtin__ complex'
   40: q        BINPUT     2
   42: (        MARK
   43: G            BINFLOAT   3.0
   52: G            BINFLOAT   0.0
   61: t            TUPLE      (MARK at 42)
   62: R        REDUCE
   63: q        BINPUT     3
   65: K        BININT1    1
   67: J        BININT     -1
   72: K        BININT1    255
   74: J        BININT     -255
   79: J        BININT     -256
   84: M        BININT2    65535
   87: J        BININT     -65535
   92: J        BININT     -65536
   97: J        BININT     2147483647
  102: J        BININT     -2147483647
  107: J        BININT     -2147483648
  112: (        MARK
  113: U            SHORT_BINSTRING 'abc'
  118: q            BINPUT     4
  120: h            BINGET     4
  122: (            MARK
  123: c                GLOBAL     '__main__ C'
  135: q                BINPUT     5
  137: o                OBJ        (MARK at 122)
  138: q            BINPUT     6
  140: }            EMPTY_DICT
  141: q            BINPUT     7
  143: (            MARK
  144: U                SHORT_BINSTRING 'foo'
  149: q                BINPUT     8
  151: K                BININT1    1
  153: U                SHORT_BINSTRING 'bar'
  158: q                BINPUT     9
  160: K                BININT1    2
  162: u                SETITEMS   (MARK at 143)
  163: b            BUILD
  164: h            BINGET     6
  166: t            TUPLE      (MARK at 112)
  167: q        BINPUT     10
  169: h        BINGET     10
  171: K        BININT1    5
  173: e        APPENDS    (MARK at 3)
  174: .    STOP
highest protocol among opcodes = 1
"""
315
# Protocol 2 pickle of the object returned by create_data().
DATA2 = ('\x80\x02]q\x01(K\x00\x8a\x01\x01G@\x00\x00\x00\x00\x00\x00\x00'
         'c__builtin__\ncomplex\nq\x02G@\x08\x00\x00\x00\x00\x00\x00G\x00'
         '\x00\x00\x00\x00\x00\x00\x00\x86Rq\x03K\x01J\xff\xff\xff\xffK'
         '\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xff'
         'J\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00'
         '\x80(U\x03abcq\x04h\x04(c__main__\nC\nq\x05oq\x06}q\x07(U\x03foo'
         'q\x08K\x01U\x03barq\tK\x02ubh\x06tq\nh\nK\x05e.')
323
324# Disassembly of DATA2.
# Expected pickletools.dis() output for DATA2, one opcode per line.
DATA2_DIS = """\
    0: \x80 PROTO      2
    2: ]    EMPTY_LIST
    3: q    BINPUT     1
    5: (    MARK
    6: K        BININT1    0
    8: \x8a     LONG1      1L
   11: G        BINFLOAT   2.0
   20: c        GLOBAL     '__builtin__ complex'
   41: q        BINPUT     2
   43: G        BINFLOAT   3.0
   52: G        BINFLOAT   0.0
   61: \x86     TUPLE2
   62: R        REDUCE
   63: q        BINPUT     3
   65: K        BININT1    1
   67: J        BININT     -1
   72: K        BININT1    255
   74: J        BININT     -255
   79: J        BININT     -256
   84: M        BININT2    65535
   87: J        BININT     -65535
   92: J        BININT     -65536
   97: J        BININT     2147483647
  102: J        BININT     -2147483647
  107: J        BININT     -2147483648
  112: (        MARK
  113: U            SHORT_BINSTRING 'abc'
  118: q            BINPUT     4
  120: h            BINGET     4
  122: (            MARK
  123: c                GLOBAL     '__main__ C'
  135: q                BINPUT     5
  137: o                OBJ        (MARK at 122)
  138: q            BINPUT     6
  140: }            EMPTY_DICT
  141: q            BINPUT     7
  143: (            MARK
  144: U                SHORT_BINSTRING 'foo'
  149: q                BINPUT     8
  151: K                BININT1    1
  153: U                SHORT_BINSTRING 'bar'
  158: q                BINPUT     9
  160: K                BININT1    2
  162: u                SETITEMS   (MARK at 143)
  163: b            BUILD
  164: h            BINGET     6
  166: t            TUPLE      (MARK at 112)
  167: q        BINPUT     10
  169: h        BINGET     10
  171: K        BININT1    5
  173: e        APPENDS    (MARK at 5)
  174: .    STOP
highest protocol among opcodes = 2
"""
380
381def create_data():
382    c = C()
383    c.foo = 1
384    c.bar = 2
385    x = [0, 1L, 2.0, 3.0+0j]
386    # Append some integer test cases at cPickle.c's internal size
387    # cutoffs.
388    uint1max = 0xff
389    uint2max = 0xffff
390    int4max = 0x7fffffff
391    x.extend([1, -1,
392              uint1max, -uint1max, -uint1max-1,
393              uint2max, -uint2max, -uint2max-1,
394               int4max,  -int4max,  -int4max-1])
395    y = ('abc', 'abc', c, c)
396    x.append(y)
397    x.append(y)
398    x.append(5)
399    return x
400
401class AbstractPickleTests(unittest.TestCase):
402    # Subclass must define self.dumps, self.loads, self.error.
403
404    _testdata = create_data()
405
    def setUp(self):
        # Nothing to prepare for the tests in this class.
        pass
408
409    def test_misc(self):
410        # test various datatypes not tested by testdata
411        for proto in protocols:
412            x = myint(4)
413            s = self.dumps(x, proto)
414            y = self.loads(s)
415            self.assertEqual(x, y)
416
417            x = (1, ())
418            s = self.dumps(x, proto)
419            y = self.loads(s)
420            self.assertEqual(x, y)
421
422            x = initarg(1, x)
423            s = self.dumps(x, proto)
424            y = self.loads(s)
425            self.assertEqual(x, y)
426
427        # XXX test __reduce__ protocol?
428
429    def test_roundtrip_equality(self):
430        expected = self._testdata
431        for proto in protocols:
432            s = self.dumps(expected, proto)
433            got = self.loads(s)
434            self.assertEqual(expected, got)
435
436    def test_load_from_canned_string(self):
437        expected = self._testdata
438        for canned in DATA0, DATA1, DATA2:
439            got = self.loads(canned)
440            self.assertEqual(expected, got)
441
442    # There are gratuitous differences between pickles produced by
443    # pickle and cPickle, largely because cPickle starts PUT indices at
444    # 1 and pickle starts them at 0.  See XXX comment in cPickle's put2() --
445    # there's a comment with an exclamation point there whose meaning
446    # is a mystery.  cPickle also suppresses PUT for objects with a refcount
447    # of 1.
448    def dont_test_disassembly(self):
449        from pickletools import dis
450
451        for proto, expected in (0, DATA0_DIS), (1, DATA1_DIS):
452            s = self.dumps(self._testdata, proto)
453            filelike = cStringIO.StringIO()
454            dis(s, out=filelike)
455            got = filelike.getvalue()
456            self.assertEqual(expected, got)
457
458    def test_recursive_list(self):
459        l = []
460        l.append(l)
461        for proto in protocols:
462            s = self.dumps(l, proto)
463            x = self.loads(s)
464            self.assertEqual(len(x), 1)
465            self.assertTrue(x is x[0])
466
467    def test_recursive_tuple(self):
468        t = ([],)
469        t[0].append(t)
470        for proto in protocols:
471            s = self.dumps(t, proto)
472            x = self.loads(s)
473            self.assertEqual(len(x), 1)
474            self.assertEqual(len(x[0]), 1)
475            self.assertTrue(x is x[0][0])
476
477    def test_recursive_dict(self):
478        d = {}
479        d[1] = d
480        for proto in protocols:
481            s = self.dumps(d, proto)
482            x = self.loads(s)
483            self.assertEqual(x.keys(), [1])
484            self.assertTrue(x[1] is x)
485
486    def test_recursive_inst(self):
487        i = C()
488        i.attr = i
489        for proto in protocols:
490            s = self.dumps(i, 2)
491            x = self.loads(s)
492            self.assertEqual(dir(x), dir(i))
493            self.assertTrue(x.attr is x)
494
495    def test_recursive_multi(self):
496        l = []
497        d = {1:l}
498        i = C()
499        i.attr = d
500        l.append(i)
501        for proto in protocols:
502            s = self.dumps(l, proto)
503            x = self.loads(s)
504            self.assertEqual(len(x), 1)
505            self.assertEqual(dir(x[0]), dir(i))
506            self.assertEqual(x[0].attr.keys(), [1])
507            self.assertTrue(x[0].attr[1] is x)
508
    def test_garyp(self):
        # Loading the string 'garyp' must raise the subclass-defined
        # self.error.
        self.assertRaises(self.error, self.loads, 'garyp')
511
512    def test_insecure_strings(self):
513        insecure = ["abc", "2 + 2", # not quoted
514                    #"'abc' + 'def'", # not a single quoted string
515                    "'abc", # quote is not closed
516                    "'abc\"", # open quote and close quote don't match
517                    "'abc'   ?", # junk after close quote
518                    "'\\'", # trailing backslash
519                    # some tests of the quoting rules
520                    #"'abc\"\''",
521                    #"'\\\\a\'\'\'\\\'\\\\\''",
522                    ]
523        for s in insecure:
524            buf = "S" + s + "\012p0\012."
525            self.assertRaises(ValueError, self.loads, buf)
526
527    if have_unicode:
528        def test_unicode(self):
529            endcases = [u'', u'<\\u>', u'<\\\u1234>', u'<\n>',
530                        u'<\\>', u'<\\\U00012345>']
531            for proto in protocols:
532                for u in endcases:
533                    p = self.dumps(u, proto)
534                    u2 = self.loads(p)
535                    self.assertEqual(u2, u)
536
537        def test_unicode_high_plane(self):
538            t = u'\U00012345'
539            for proto in protocols:
540                p = self.dumps(t, proto)
541                t2 = self.loads(p)
542                self.assertEqual(t2, t)
543
544    def test_ints(self):
545        import sys
546        for proto in protocols:
547            n = sys.maxint
548            while n:
549                for expected in (-n, n):
550                    s = self.dumps(expected, proto)
551                    n2 = self.loads(s)
552                    self.assertEqual(expected, n2)
553                n = n >> 1
554
555    def test_maxint64(self):
556        maxint64 = (1L << 63) - 1
557        data = 'I' + str(maxint64) + '\n.'
558        got = self.loads(data)
559        self.assertEqual(got, maxint64)
560
561        # Try too with a bogus literal.
562        data = 'I' + str(maxint64) + 'JUNK\n.'
563        self.assertRaises(ValueError, self.loads, data)
564
565    def test_long(self):
566        for proto in protocols:
567            # 256 bytes is where LONG4 begins.
568            for nbits in 1, 8, 8*254, 8*255, 8*256, 8*257:
569                nbase = 1L << nbits
570                for npos in nbase-1, nbase, nbase+1:
571                    for n in npos, -npos:
572                        pickle = self.dumps(n, proto)
573                        got = self.loads(pickle)
574                        self.assertEqual(n, got)
575        # Try a monster.  This is quadratic-time in protos 0 & 1, so don't
576        # bother with those.
577        nbase = long("deadbeeffeedface", 16)
578        nbase += nbase << 1000000
579        for n in nbase, -nbase:
580            p = self.dumps(n, 2)
581            got = self.loads(p)
582            self.assertEqual(n, got)
583
584    def test_float(self):
585        test_values = [0.0, 4.94e-324, 1e-310, 7e-308, 6.626e-34, 0.1, 0.5,
586                       3.14, 263.44582062374053, 6.022e23, 1e30]
587        test_values = test_values + [-x for x in test_values]
588        for proto in protocols:
589            for value in test_values:
590                pickle = self.dumps(value, proto)
591                got = self.loads(pickle)
592                self.assertEqual(value, got)
593
    @run_with_locale('LC_ALL', 'de_DE', 'fr_FR')
    def test_float_format(self):
        # make sure that floats are formatted locale independent
        # The decorator tries the de_DE and fr_FR locales; dumps is called
        # without a protocol argument and the result must start with 'F1.'.
        self.assertEqual(self.dumps(1.2)[0:3], 'F1.')
598
    def test_reduce(self):
        # Placeholder: no checks are implemented here.
        pass
601
    def test_getinitargs(self):
        # Placeholder: no checks are implemented here.
        pass
604
605    def test_metaclass(self):
606        a = use_metaclass()
607        for proto in protocols:
608            s = self.dumps(a, proto)
609            b = self.loads(s)
610            self.assertEqual(a.__class__, b.__class__)
611
612    def test_structseq(self):
613        import time
614        import os
615
616        t = time.localtime()
617        for proto in protocols:
618            s = self.dumps(t, proto)
619            u = self.loads(s)
620            self.assertEqual(t, u)
621            if hasattr(os, "stat"):
622                t = os.stat(os.curdir)
623                s = self.dumps(t, proto)
624                u = self.loads(s)
625                self.assertEqual(t, u)
626            if hasattr(os, "statvfs"):
627                t = os.statvfs(os.curdir)
628                s = self.dumps(t, proto)
629                u = self.loads(s)
630                self.assertEqual(t, u)
631
632    # Tests for protocol 2
633
634    def test_proto(self):
635        build_none = pickle.NONE + pickle.STOP
636        for proto in protocols:
637            expected = build_none
638            if proto >= 2:
639                expected = pickle.PROTO + chr(proto) + expected
640            p = self.dumps(None, proto)
641            self.assertEqual(p, expected)
642
643        oob = protocols[-1] + 1     # a future protocol
644        badpickle = pickle.PROTO + chr(oob) + build_none
645        try:
646            self.loads(badpickle)
647        except ValueError, detail:
648            self.assertTrue(str(detail).startswith(
649                                            "unsupported pickle protocol"))
650        else:
651            self.fail("expected bad protocol number to raise ValueError")
652
653    def test_long1(self):
654        x = 12345678910111213141516178920L
655        for proto in protocols:
656            s = self.dumps(x, proto)
657            y = self.loads(s)
658            self.assertEqual(x, y)
659            self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2)
660
661    def test_long4(self):
662        x = 12345678910111213141516178920L << (256*8)
663        for proto in protocols:
664            s = self.dumps(x, proto)
665            y = self.loads(s)
666            self.assertEqual(x, y)
667            self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2)
668
669    def test_short_tuples(self):
670        # Map (proto, len(tuple)) to expected opcode.
671        expected_opcode = {(0, 0): pickle.TUPLE,
672                           (0, 1): pickle.TUPLE,
673                           (0, 2): pickle.TUPLE,
674                           (0, 3): pickle.TUPLE,
675                           (0, 4): pickle.TUPLE,
676
677                           (1, 0): pickle.EMPTY_TUPLE,
678                           (1, 1): pickle.TUPLE,
679                           (1, 2): pickle.TUPLE,
680                           (1, 3): pickle.TUPLE,
681                           (1, 4): pickle.TUPLE,
682
683                           (2, 0): pickle.EMPTY_TUPLE,
684                           (2, 1): pickle.TUPLE1,
685                           (2, 2): pickle.TUPLE2,
686                           (2, 3): pickle.TUPLE3,
687                           (2, 4): pickle.TUPLE,
688                          }
689        a = ()
690        b = (1,)
691        c = (1, 2)
692        d = (1, 2, 3)
693        e = (1, 2, 3, 4)
694        for proto in protocols:
695            for x in a, b, c, d, e:
696                s = self.dumps(x, proto)
697                y = self.loads(s)
698                self.assertEqual(x, y, (proto, x, s, y))
699                expected = expected_opcode[proto, len(x)]
700                self.assertEqual(opcode_in_pickle(expected, s), True)
701
702    def test_singletons(self):
703        # Map (proto, singleton) to expected opcode.
704        expected_opcode = {(0, None): pickle.NONE,
705                           (1, None): pickle.NONE,
706                           (2, None): pickle.NONE,
707
708                           (0, True): pickle.INT,
709                           (1, True): pickle.INT,
710                           (2, True): pickle.NEWTRUE,
711
712                           (0, False): pickle.INT,
713                           (1, False): pickle.INT,
714                           (2, False): pickle.NEWFALSE,
715                          }
716        for proto in protocols:
717            for x in None, False, True:
718                s = self.dumps(x, proto)
719                y = self.loads(s)
720                self.assertTrue(x is y, (proto, x, s, y))
721                expected = expected_opcode[proto, x]
722                self.assertEqual(opcode_in_pickle(expected, s), True)
723
724    def test_newobj_tuple(self):
725        x = MyTuple([1, 2, 3])
726        x.foo = 42
727        x.bar = "hello"
728        for proto in protocols:
729            s = self.dumps(x, proto)
730            y = self.loads(s)
731            self.assertEqual(tuple(x), tuple(y))
732            self.assertEqual(x.__dict__, y.__dict__)
733
734    def test_newobj_list(self):
735        x = MyList([1, 2, 3])
736        x.foo = 42
737        x.bar = "hello"
738        for proto in protocols:
739            s = self.dumps(x, proto)
740            y = self.loads(s)
741            self.assertEqual(list(x), list(y))
742            self.assertEqual(x.__dict__, y.__dict__)
743
744    def test_newobj_generic(self):
745        for proto in protocols:
746            for C in myclasses:
747                B = C.__base__
748                x = C(C.sample)
749                x.foo = 42
750                s = self.dumps(x, proto)
751                y = self.loads(s)
752                detail = (proto, C, B, x, y, type(y))
753                self.assertEqual(B(x), B(y), detail)
754                self.assertEqual(x.__dict__, y.__dict__, detail)
755
756    # Register a type with copy_reg, with extension code extcode.  Pickle
757    # an object of that type.  Check that the resulting pickle uses opcode
758    # (EXT[124]) under proto 2, and not in proto 1.
759
760    def produce_global_ext(self, extcode, opcode):
761        e = ExtensionSaver(extcode)
762        try:
763            copy_reg.add_extension(__name__, "MyList", extcode)
764            x = MyList([1, 2, 3])
765            x.foo = 42
766            x.bar = "hello"
767
768            # Dump using protocol 1 for comparison.
769            s1 = self.dumps(x, 1)
770            self.assertIn(__name__, s1)
771            self.assertIn("MyList", s1)
772            self.assertEqual(opcode_in_pickle(opcode, s1), False)
773
774            y = self.loads(s1)
775            self.assertEqual(list(x), list(y))
776            self.assertEqual(x.__dict__, y.__dict__)
777
778            # Dump using protocol 2 for test.
779            s2 = self.dumps(x, 2)
780            self.assertNotIn(__name__, s2)
781            self.assertNotIn("MyList", s2)
782            self.assertEqual(opcode_in_pickle(opcode, s2), True)
783
784            y = self.loads(s2)
785            self.assertEqual(list(x), list(y))
786            self.assertEqual(x.__dict__, y.__dict__)
787
788        finally:
789            e.restore()
790
791    def test_global_ext1(self):
792        self.produce_global_ext(0x00000001, pickle.EXT1)  # smallest EXT1 code
793        self.produce_global_ext(0x000000ff, pickle.EXT1)  # largest EXT1 code
794
795    def test_global_ext2(self):
796        self.produce_global_ext(0x00000100, pickle.EXT2)  # smallest EXT2 code
797        self.produce_global_ext(0x0000ffff, pickle.EXT2)  # largest EXT2 code
798        self.produce_global_ext(0x0000abcd, pickle.EXT2)  # check endianness
799
800    def test_global_ext4(self):
801        self.produce_global_ext(0x00010000, pickle.EXT4)  # smallest EXT4 code
802        self.produce_global_ext(0x7fffffff, pickle.EXT4)  # largest EXT4 code
803        self.produce_global_ext(0x12abcdef, pickle.EXT4)  # check endianness
804
805    def test_list_chunking(self):
806        n = 10  # too small to chunk
807        x = range(n)
808        for proto in protocols:
809            s = self.dumps(x, proto)
810            y = self.loads(s)
811            self.assertEqual(x, y)
812            num_appends = count_opcode(pickle.APPENDS, s)
813            self.assertEqual(num_appends, proto > 0)
814
815        n = 2500  # expect at least two chunks when proto > 0
816        x = range(n)
817        for proto in protocols:
818            s = self.dumps(x, proto)
819            y = self.loads(s)
820            self.assertEqual(x, y)
821            num_appends = count_opcode(pickle.APPENDS, s)
822            if proto == 0:
823                self.assertEqual(num_appends, 0)
824            else:
825                self.assertTrue(num_appends >= 2)
826
827    def test_dict_chunking(self):
828        n = 10  # too small to chunk
829        x = dict.fromkeys(range(n))
830        for proto in protocols:
831            s = self.dumps(x, proto)
832            y = self.loads(s)
833            self.assertEqual(x, y)
834            num_setitems = count_opcode(pickle.SETITEMS, s)
835            self.assertEqual(num_setitems, proto > 0)
836
837        n = 2500  # expect at least two chunks when proto > 0
838        x = dict.fromkeys(range(n))
839        for proto in protocols:
840            s = self.dumps(x, proto)
841            y = self.loads(s)
842            self.assertEqual(x, y)
843            num_setitems = count_opcode(pickle.SETITEMS, s)
844            if proto == 0:
845                self.assertEqual(num_setitems, 0)
846            else:
847                self.assertTrue(num_setitems >= 2)
848
849    def test_simple_newobj(self):
850        x = object.__new__(SimpleNewObj)  # avoid __init__
851        x.abc = 666
852        for proto in protocols:
853            s = self.dumps(x, proto)
854            self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2)
855            y = self.loads(s)   # will raise TypeError if __init__ called
856            self.assertEqual(y.abc, 666)
857            self.assertEqual(x.__dict__, y.__dict__)
858
859    def test_newobj_list_slots(self):
860        x = SlotList([1, 2, 3])
861        x.foo = 42
862        x.bar = "hello"
863        s = self.dumps(x, 2)
864        y = self.loads(s)
865        self.assertEqual(list(x), list(y))
866        self.assertEqual(x.__dict__, y.__dict__)
867        self.assertEqual(x.foo, y.foo)
868        self.assertEqual(x.bar, y.bar)
869
870    def test_reduce_overrides_default_reduce_ex(self):
871        for proto in protocols:
872            x = REX_one()
873            self.assertEqual(x._reduce_called, 0)
874            s = self.dumps(x, proto)
875            self.assertEqual(x._reduce_called, 1)
876            y = self.loads(s)
877            self.assertEqual(y._reduce_called, 0)
878
879    def test_reduce_ex_called(self):
880        for proto in protocols:
881            x = REX_two()
882            self.assertEqual(x._proto, None)
883            s = self.dumps(x, proto)
884            self.assertEqual(x._proto, proto)
885            y = self.loads(s)
886            self.assertEqual(y._proto, None)
887
888    def test_reduce_ex_overrides_reduce(self):
889        for proto in protocols:
890            x = REX_three()
891            self.assertEqual(x._proto, None)
892            s = self.dumps(x, proto)
893            self.assertEqual(x._proto, proto)
894            y = self.loads(s)
895            self.assertEqual(y._proto, None)
896
897    def test_reduce_ex_calls_base(self):
898        for proto in protocols:
899            x = REX_four()
900            self.assertEqual(x._proto, None)
901            s = self.dumps(x, proto)
902            self.assertEqual(x._proto, proto)
903            y = self.loads(s)
904            self.assertEqual(y._proto, proto)
905
906    def test_reduce_calls_base(self):
907        for proto in protocols:
908            x = REX_five()
909            self.assertEqual(x._reduce_called, 0)
910            s = self.dumps(x, proto)
911            self.assertEqual(x._reduce_called, 1)
912            y = self.loads(s)
913            self.assertEqual(y._reduce_called, 1)
914
915    def test_reduce_bad_iterator(self):
916        # Issue4176: crash when 4th and 5th items of __reduce__()
917        # are not iterators
918        class C(object):
919            def __reduce__(self):
920                # 4th item is not an iterator
921                return list, (), None, [], None
922        class D(object):
923            def __reduce__(self):
924                # 5th item is not an iterator
925                return dict, (), None, None, []
926
927        # Protocol 0 is less strict and also accept iterables.
928        for proto in protocols:
929            try:
930                self.dumps(C(), proto)
931            except (AttributeError, pickle.PickleError, cPickle.PickleError):
932                pass
933            try:
934                self.dumps(D(), proto)
935            except (AttributeError, pickle.PickleError, cPickle.PickleError):
936                pass
937
938    def test_many_puts_and_gets(self):
939        # Test that internal data structures correctly deal with lots of
940        # puts/gets.
941        keys = ("aaa" + str(i) for i in xrange(100))
942        large_dict = dict((k, [4, 5, 6]) for k in keys)
943        obj = [dict(large_dict), dict(large_dict), dict(large_dict)]
944
945        for proto in protocols:
946            dumped = self.dumps(obj, proto)
947            loaded = self.loads(dumped)
948            self.assertEqual(loaded, obj,
949                             "Failed protocol %d: %r != %r"
950                             % (proto, obj, loaded))
951
952    def test_attribute_name_interning(self):
953        # Test that attribute names of pickled objects are interned when
954        # unpickling.
955        for proto in protocols:
956            x = C()
957            x.foo = 42
958            x.bar = "hello"
959            s = self.dumps(x, proto)
960            y = self.loads(s)
961            x_keys = sorted(x.__dict__)
962            y_keys = sorted(y.__dict__)
963            for x_key, y_key in zip(x_keys, y_keys):
964                self.assertIs(x_key, y_key)
965
966
967# Test classes for reduce_ex
968
class REX_one(object):
    """Defines only __reduce__; __reduce_ex__ is inherited from object."""

    _reduce_called = 0

    def __reduce__(self):
        # Record the call on the instance; the object is rebuilt by a
        # plain REX_one() call.
        self._reduce_called = 1
        return (REX_one, ())
975
class REX_two(object):
    """Defines only __reduce_ex__; __reduce__ is inherited from object."""

    _proto = None

    def __reduce_ex__(self, proto):
        # Remember which protocol was requested; the object is rebuilt by
        # a plain REX_two() call.
        self._proto = proto
        return (REX_two, ())
982
class REX_three(object):
    """Defines both __reduce_ex__ and __reduce__; only the former may run."""

    _proto = None

    def __reduce_ex__(self, proto):
        # Remember which protocol was requested and rebuild as REX_three
        # (this previously named REX_two, so the pickle came back as an
        # instance of the wrong class).
        self._proto = proto
        return REX_three, ()

    def __reduce__(self):
        # __reduce_ex__ takes precedence, so reaching this is a failure.
        raise TestFailed("This __reduce__ shouldn't be called")
990
class REX_four(object):
    """Overrides __reduce_ex__ but delegates to object's implementation."""

    _proto = None

    def __reduce_ex__(self, proto):
        # Record the protocol first, then hand over to the base class;
        # calling the base class method should succeed.
        self._proto = proto
        reduced = object.__reduce_ex__(self, proto)
        return reduced
997
class REX_five(object):
    """Overrides __reduce__ but delegates to object's implementation."""

    _reduce_called = 0

    def __reduce__(self):
        # Record the call, then hand over to the base class.  This one
        # used to fail with infinite recursion.
        self._reduce_called = 1
        reduced = object.__reduce__(self)
        return reduced
1004
1005# Test classes for newobj
1006
class MyInt(int):
    # int subclass; `sample` is a plain int value.
    # NOTE(review): presumably the value the newobj tests build an
    # instance from -- the consumer is outside this excerpt.
    sample = 1
1009
class MyLong(long):
    # long subclass; `sample` is a plain long value.
    sample = 1L
1012
class MyFloat(float):
    # float subclass; `sample` is a plain float value.
    sample = 1.0
1015
class MyComplex(complex):
    # complex subclass; `sample` is a plain complex value.
    sample = 1.0 + 0.0j
1018
class MyStr(str):
    # str subclass; `sample` is a plain str value.
    sample = "hello"
1021
class MyUnicode(unicode):
    # unicode subclass; `sample` is a unicode value that includes a
    # non-ASCII character (U+1234).
    sample = u"hello \u1234"
1024
class MyTuple(tuple):
    # tuple subclass; `sample` is a plain tuple value.
    sample = (1, 2, 3)
1027
class MyList(list):
    # list subclass; `sample` is a plain list value.
    sample = [1, 2, 3]
1030
class MyDict(dict):
    # dict subclass; `sample` is a plain dict value.
    sample = {"a": 1, "b": 2}
1033
# Every builtin-type subclass defined above, gathered in one list.
# NOTE(review): presumably iterated by the newobj tests -- the consumer
# is outside this excerpt.
myclasses = [MyInt, MyLong, MyFloat,
             MyComplex,
             MyStr, MyUnicode,
             MyTuple, MyList, MyDict]
1038
1039
class SlotList(MyList):
    # MyList subclass that additionally declares a single slot, "foo".
    __slots__ = ["foo"]
1042
class SimpleNewObj(object):
    """Object whose __init__ always raises TypeError when called."""

    def __init__(self, a, b, c):
        # raise an error, to make sure this isn't called
        message = "SimpleNewObj.__init__() didn't expect to get called"
        raise TypeError(message)
1047
1048class AbstractPickleModuleTests(unittest.TestCase):
1049
1050    def test_dump_closed_file(self):
1051        import os
1052        f = open(TESTFN, "w")
1053        try:
1054            f.close()
1055            self.assertRaises(ValueError, self.module.dump, 123, f)
1056        finally:
1057            os.remove(TESTFN)
1058
1059    def test_load_closed_file(self):
1060        import os
1061        f = open(TESTFN, "w")
1062        try:
1063            f.close()
1064            self.assertRaises(ValueError, self.module.dump, 123, f)
1065        finally:
1066            os.remove(TESTFN)
1067
1068    def test_load_from_and_dump_to_file(self):
1069        stream = cStringIO.StringIO()
1070        data = [123, {}, 124]
1071        self.module.dump(data, stream)
1072        stream.seek(0)
1073        unpickled = self.module.load(stream)
1074        self.assertEqual(unpickled, data)
1075
1076    def test_highest_protocol(self):
1077        # Of course this needs to be changed when HIGHEST_PROTOCOL changes.
1078        self.assertEqual(self.module.HIGHEST_PROTOCOL, 2)
1079
1080    def test_callapi(self):
1081        f = cStringIO.StringIO()
1082        # With and without keyword arguments
1083        self.module.dump(123, f, -1)
1084        self.module.dump(123, file=f, protocol=-1)
1085        self.module.dumps(123, -1)
1086        self.module.dumps(123, protocol=-1)
1087        self.module.Pickler(f, -1)
1088        self.module.Pickler(f, protocol=-1)
1089
1090    def test_incomplete_input(self):
1091        s = StringIO.StringIO("X''.")
1092        self.assertRaises(EOFError, self.module.load, s)
1093
1094    def test_restricted(self):
1095        # issue7128: cPickle failed in restricted mode
1096        builtins = {self.module.__name__: self.module,
1097                    '__import__': __import__}
1098        d = {}
1099        teststr = "def f(): {0}.dumps(0)".format(self.module.__name__)
1100        exec teststr in {'__builtins__': builtins}, d
1101        d['f']()
1102
1103    def test_bad_input(self):
1104        # Test issue4298
1105        s = '\x58\0\0\0\x54'
1106        self.assertRaises(EOFError, self.module.loads, s)
1107        # Test issue7455
1108        s = '0'
1109        # XXX Why doesn't pickle raise UnpicklingError?
1110        self.assertRaises((IndexError, cPickle.UnpicklingError),
1111                          self.module.loads, s)
1112
1113class AbstractPersistentPicklerTests(unittest.TestCase):
1114
1115    # This class defines persistent_id() and persistent_load()
1116    # functions that should be used by the pickler.  All even integers
1117    # are pickled using persistent ids.
1118
1119    def persistent_id(self, object):
1120        if isinstance(object, int) and object % 2 == 0:
1121            self.id_count += 1
1122            return str(object)
1123        else:
1124            return None
1125
1126    def persistent_load(self, oid):
1127        self.load_count += 1
1128        object = int(oid)
1129        assert object % 2 == 0
1130        return object
1131
1132    def test_persistence(self):
1133        self.id_count = 0
1134        self.load_count = 0
1135        L = range(10)
1136        self.assertEqual(self.loads(self.dumps(L)), L)
1137        self.assertEqual(self.id_count, 5)
1138        self.assertEqual(self.load_count, 5)
1139
1140    def test_bin_persistence(self):
1141        self.id_count = 0
1142        self.load_count = 0
1143        L = range(10)
1144        self.assertEqual(self.loads(self.dumps(L, 1)), L)
1145        self.assertEqual(self.id_count, 5)
1146        self.assertEqual(self.load_count, 5)
1147
1148class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
1149
1150    pickler_class = None
1151    unpickler_class = None
1152
1153    def setUp(self):
1154        assert self.pickler_class
1155        assert self.unpickler_class
1156
1157    def test_clear_pickler_memo(self):
1158        # To test whether clear_memo() has any effect, we pickle an object,
1159        # then pickle it again without clearing the memo; the two serialized
1160        # forms should be different. If we clear_memo() and then pickle the
1161        # object again, the third serialized form should be identical to the
1162        # first one we obtained.
1163        data = ["abcdefg", "abcdefg", 44]
1164        f = cStringIO.StringIO()
1165        pickler = self.pickler_class(f)
1166
1167        pickler.dump(data)
1168        first_pickled = f.getvalue()
1169
1170        # Reset StringIO object.
1171        f.seek(0)
1172        f.truncate()
1173
1174        pickler.dump(data)
1175        second_pickled = f.getvalue()
1176
1177        # Reset the Pickler and StringIO objects.
1178        pickler.clear_memo()
1179        f.seek(0)
1180        f.truncate()
1181
1182        pickler.dump(data)
1183        third_pickled = f.getvalue()
1184
1185        self.assertNotEqual(first_pickled, second_pickled)
1186        self.assertEqual(first_pickled, third_pickled)
1187
1188    def test_priming_pickler_memo(self):
1189        # Verify that we can set the Pickler's memo attribute.
1190        data = ["abcdefg", "abcdefg", 44]
1191        f = cStringIO.StringIO()
1192        pickler = self.pickler_class(f)
1193
1194        pickler.dump(data)
1195        first_pickled = f.getvalue()
1196
1197        f = cStringIO.StringIO()
1198        primed = self.pickler_class(f)
1199        primed.memo = pickler.memo
1200
1201        primed.dump(data)
1202        primed_pickled = f.getvalue()
1203
1204        self.assertNotEqual(first_pickled, primed_pickled)
1205
1206    def test_priming_unpickler_memo(self):
1207        # Verify that we can set the Unpickler's memo attribute.
1208        data = ["abcdefg", "abcdefg", 44]
1209        f = cStringIO.StringIO()
1210        pickler = self.pickler_class(f)
1211
1212        pickler.dump(data)
1213        first_pickled = f.getvalue()
1214
1215        f = cStringIO.StringIO()
1216        primed = self.pickler_class(f)
1217        primed.memo = pickler.memo
1218
1219        primed.dump(data)
1220        primed_pickled = f.getvalue()
1221
1222        unpickler = self.unpickler_class(cStringIO.StringIO(first_pickled))
1223        unpickled_data1 = unpickler.load()
1224
1225        self.assertEqual(unpickled_data1, data)
1226
1227        primed = self.unpickler_class(cStringIO.StringIO(primed_pickled))
1228        primed.memo = unpickler.memo
1229        unpickled_data2 = primed.load()
1230
1231        primed.memo.clear()
1232
1233        self.assertEqual(unpickled_data2, data)
1234        self.assertTrue(unpickled_data2 is unpickled_data1)
1235
1236    def test_reusing_unpickler_objects(self):
1237        data1 = ["abcdefg", "abcdefg", 44]
1238        f = cStringIO.StringIO()
1239        pickler = self.pickler_class(f)
1240        pickler.dump(data1)
1241        pickled1 = f.getvalue()
1242
1243        data2 = ["abcdefg", 44, 44]
1244        f = cStringIO.StringIO()
1245        pickler = self.pickler_class(f)
1246        pickler.dump(data2)
1247        pickled2 = f.getvalue()
1248
1249        f = cStringIO.StringIO()
1250        f.write(pickled1)
1251        f.seek(0)
1252        unpickler = self.unpickler_class(f)
1253        self.assertEqual(unpickler.load(), data1)
1254
1255        f.seek(0)
1256        f.truncate()
1257        f.write(pickled2)
1258        f.seek(0)
1259        self.assertEqual(unpickler.load(), data2)
1260