• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Create portable serialized representations of Python objects.
2
3See module cPickle for a (much) faster implementation.
4See module copy_reg for a mechanism for registering custom picklers.
5See module pickletools source for extensive comments.
6
7Classes:
8
9    Pickler
10    Unpickler
11
12Functions:
13
14    dump(object, file)
15    dumps(object) -> string
16    load(file) -> object
17    loads(string) -> object
18
19Misc variables:
20
21    __version__
22    format_version
23    compatible_formats
24
25"""
26
# Version string of this module's code (filled in by the revision-control
# keyword expander).
__version__ = '$Revision$'
28
29from types import *
30from copy_reg import dispatch_table
31from copy_reg import _extension_registry, _inverted_registry, _extension_cache
32import marshal
33import sys
34import struct
35import re
36
# Public API.  The opcode constants are appended further down, once they
# have been defined.
__all__ = [
    "PickleError", "PicklingError", "UnpicklingError",
    "Pickler", "Unpickler",
    "dump", "dumps", "load", "loads",
]
39
# Informational only: no code in this module reads format_version or
# compatible_formats.  format_version is the file format written here;
# compatible_formats lists the older format versions that can still be read.
format_version = '2.0'
compatible_formats = [
    '1.0',  # the original protocol 0
    '1.1',  # protocol 0 extended with INST
    '1.2',  # the original protocol 1
    '1.3',  # protocol 1 extended with BINFLOAT
    '2.0',  # protocol 2
]

# The highest protocol number this module knows how to read.
# Keep in synch with cPickle.
HIGHEST_PROTOCOL = 2

# Pickling packs numbers with struct.pack(), which is about 40% faster than
# marshal.dumps().  Unpickling, however, uses marshal.loads(), which is about
# twice as fast as struct.unpack().
mloads = marshal.loads
57
class PickleError(Exception):
    """Base class shared by all pickling and unpickling exceptions."""
61
class PicklingError(PickleError):
    """Raised when dump() is given an object that cannot be pickled."""
68
class UnpicklingError(PickleError):
    """Raised when a pickle cannot be unpickled, e.g. on a security violation.

    Unpickling may raise other exceptions as well, including (but not
    necessarily limited to) AttributeError, EOFError, ImportError and
    IndexError.
    """
79
# Unpickler.load_stop() raises _Stop on reading the STOP opcode, which is how
# the finished result escapes from the opcode loop in Unpickler.load().
class _Stop(Exception):
    """Control-flow signal that carries the unpickled result in .value."""

    def __init__(self, value):
        self.value = value
85
# Platform compatibility shims.  Both names are set to None when the
# underlying type is unavailable, so later code can test for that.

# Jython only: PyStringMap is a dict subclass with string keys.
try:
    from org.python.core import PyStringMap
except ImportError:
    # Not running on Jython.
    PyStringMap = None

# UnicodeType usually arrives via "from types import *", but it may be
# missing, so fall back to None rather than leaving the name undefined.
try:
    UnicodeType
except NameError:
    UnicodeType = None
97
# Pickle opcodes.  pickletools.py documents every opcode in depth and groups
# them by purpose.  Here they are listed in roughly alphabetical order of
# their one-character codes.

MARK            = "("   # push the special markobject onto the stack
STOP            = "."   # every pickle ends with STOP
POP             = "0"   # discard the topmost stack item
POP_MARK        = "1"   # discard the stack top down through the topmost markobject
DUP             = "2"   # duplicate the top stack item
FLOAT           = "F"   # push a float; argument is a decimal string
INT             = "I"   # push an int or bool; argument is a decimal string
BININT          = "J"   # push a four-byte signed int
BININT1         = "K"   # push a one-byte unsigned int
LONG            = "L"   # push a long; argument is a decimal string
BININT2         = "M"   # push a two-byte unsigned int
NONE            = "N"   # push None
PERSID          = "P"   # push a persistent object; id comes from a string argument
BINPERSID       = "Q"   # push a persistent object; id comes from the stack
REDUCE          = "R"   # apply a callable to an argument tuple, both on the stack
STRING          = "S"   # push a string; newline-terminated string argument
BINSTRING       = "T"   # push a string; length-counted binary argument
SHORT_BINSTRING = "U"   # same as BINSTRING, for strings under 256 bytes
UNICODE         = "V"   # push a Unicode string; raw-unicode-escape'd argument
BINUNICODE      = "X"   # push a Unicode string; length-counted UTF-8 argument
APPEND          = "a"   # append the stack top to the list below it
BUILD           = "b"   # call __setstate__ or __dict__.update()
GLOBAL          = "c"   # push self.find_class(modname, name); two string arguments
DICT            = "d"   # build a dict from stack items
EMPTY_DICT      = "}"   # push an empty dict
APPENDS         = "e"   # extend the list on the stack by the topmost stack slice
GET             = "g"   # push a memo item onto the stack; index is a string argument
BINGET          = "h"   # same as GET, with a one-byte index
INST            = "i"   # build and push a class instance
LONG_BINGET     = "j"   # push a memo item onto the stack; four-byte index
LIST            = "l"   # build a list from the topmost stack items
EMPTY_LIST      = "]"   # push an empty list
OBJ             = "o"   # build and push a class instance
PUT             = "p"   # store the stack top in the memo; index is a string argument
BINPUT          = "q"   # same as PUT, with a one-byte index
LONG_BINPUT     = "r"   # same as PUT, with a four-byte index
SETITEM         = "s"   # add a key+value pair to a dict
TUPLE           = "t"   # build a tuple from the topmost stack items
EMPTY_TUPLE     = ")"   # push an empty tuple
SETITEMS        = "u"   # update a dict with the topmost key+value pairs
BINFLOAT        = "G"   # push a float; argument is an 8-byte float encoding

# TRUE and FALSE are not opcodes.  They are complete INT instructions used to
# encode bools (see the INT documentation in pickletools.py).
TRUE            = "I01\n"
FALSE           = "I00\n"

# Opcodes added by protocol 2.

PROTO           = "\x80"  # identify the pickle protocol
NEWOBJ          = "\x81"  # build an object via cls.__new__(cls, *argtuple)
EXT1            = "\x82"  # push an object from the extension registry; 1-byte index
EXT2            = "\x83"  # same as EXT1, with a 2-byte index
EXT4            = "\x84"  # same as EXT1, with a 4-byte index
TUPLE1          = "\x85"  # build a 1-tuple from the stack top
TUPLE2          = "\x86"  # build a 2-tuple from the two topmost stack items
TUPLE3          = "\x87"  # build a 3-tuple from the three topmost stack items
NEWTRUE         = "\x88"  # push True
NEWFALSE        = "\x89"  # push False
LONG1           = "\x8a"  # push a long encoded in fewer than 256 bytes
LONG4           = "\x8b"  # push a really big long

# Index n gives the opcode that builds an n-tuple (0 <= n <= 3).
_tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3]
163
164
# Export every all-caps name defined above (the opcodes, TRUE/FALSE,
# HIGHEST_PROTOCOL).  The leftmost iterable of a generator expression is
# evaluated immediately in the enclosing scope, so dir() still lists the
# module namespace.  Unlike the previous "[x for x in dir() ...]; del x", no
# loop variable leaks, so nothing needs deleting.  The old "del x" only worked
# because Python 2 list comprehensions leak their variable; elsewhere it
# raises NameError.
__all__.extend(name for name in dir()
               if re.match(r"[A-Z][A-Z0-9_]+$", name))
167
168
169# Pickling machinery
170
class Pickler:
    """Write pickle data streams for Python objects to a file-like object.

    Objects are serialized with dump().  The memo records objects that have
    already been written, so shared and recursive references are pickled by
    reference instead of by value.
    """

    def __init__(self, file, protocol=None):
        """This takes a file-like object for writing a pickle data stream.

        The optional protocol argument tells the pickler to use the
        given protocol; supported protocols are 0, 1, 2.  The default
        protocol is 0, to be backwards compatible.  (Protocol 0 is the
        only protocol that can be written to a file opened in text
        mode and read back successfully.  When using a protocol higher
        than 0, make sure the file is opened in binary mode, both when
        pickling and unpickling.)

        Protocol 1 is more efficient than protocol 0; protocol 2 is
        more efficient than protocol 1.

        Specifying a negative protocol version selects the highest
        protocol version supported.  The higher the protocol used, the
        more recent the version of Python needed to read the pickle
        produced.

        The file parameter must have a write() method that accepts a single
        string argument.  It can thus be an open file object, a StringIO
        object, or any other custom object that meets this interface.

        Raises ValueError if protocol is greater than HIGHEST_PROTOCOL.
        """
        if protocol is None:
            protocol = 0
        if protocol < 0:
            protocol = HIGHEST_PROTOCOL
        elif not 0 <= protocol <= HIGHEST_PROTOCOL:
            raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL)
        self.write = file.write
        self.memo = {}
        self.proto = int(protocol)
        self.bin = protocol >= 1
        self.fast = 0  # when true, memoize() records nothing

    def clear_memo(self):
        """Clears the pickler's "memo".

        The memo is the data structure that remembers which objects the
        pickler has already seen, so that shared or recursive objects are
        pickled by reference and not by value.  This method is useful when
        re-using picklers.

        """
        self.memo.clear()

    def dump(self, obj):
        """Write a pickled representation of obj to the open file."""
        # Protocol 2+ streams start with PROTO and the protocol number.
        if self.proto >= 2:
            self.write(PROTO + chr(self.proto))
        self.save(obj)
        self.write(STOP)

    def memoize(self, obj):
        """Store an object in the memo."""

        # The Pickler memo is a dictionary mapping object ids to 2-tuples
        # that contain the Unpickler memo key and the object being memoized.
        # The memo key is written to the pickle and will become
        # the key in the Unpickler's memo.  The object is stored in the
        # Pickler memo so that transient objects are kept alive during
        # pickling.

        # The use of the Unpickler memo length as the memo key is just a
        # convention.  The only requirement is that the memo values be unique.
        # But there appears no advantage to any other scheme, and this
        # scheme allows the Unpickler memo to be implemented as a plain (but
        # growable) array, indexed by memo key.
        if self.fast:
            return
        assert id(obj) not in self.memo
        memo_len = len(self.memo)
        self.write(self.put(memo_len))
        self.memo[id(obj)] = memo_len, obj

    def put(self, i, pack=struct.pack):
        """Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i."""
        if self.bin:
            if i < 256:
                return BINPUT + chr(i)
            else:
                return LONG_BINPUT + pack("<i", i)

        return PUT + repr(i) + '\n'

    def get(self, i, pack=struct.pack):
        """Return a GET (BINGET, LONG_BINGET) opcode string, with argument i."""
        if self.bin:
            if i < 256:
                return BINGET + chr(i)
            else:
                return LONG_BINGET + pack("<i", i)

        return GET + repr(i) + '\n'

    def save(self, obj):
        """Write obj to the stream, choosing a representation by its type.

        The checks run in this order: persistent id, memo, the class-level
        dispatch table, classes with a custom metaclass,
        copy_reg.dispatch_table, then obj.__reduce_ex__(proto) or
        obj.__reduce__().  Raises PicklingError when none of them applies
        or when reduce() returns something malformed.
        """
        # Check for persistent id (defined by a subclass).  Only None means
        # "no persistent id".  A falsy id such as 0 or '' is still an id and
        # must be written as a persistent reference, not pickled by value.
        pid = self.persistent_id(obj)
        if pid is not None:
            self.save_pers(pid)
            return

        # Check the memo
        x = self.memo.get(id(obj))
        if x:
            self.write(self.get(x[0]))
            return

        # Check the type dispatch table
        t = type(obj)
        f = self.dispatch.get(t)
        if f:
            f(self, obj) # Call unbound method with explicit self
            return

        # Check for a class with a custom metaclass; treat as regular class
        try:
            issc = issubclass(t, TypeType)
        except TypeError: # t is not a class (old Boost; see SF #502085)
            issc = 0
        if issc:
            self.save_global(obj)
            return

        # Check copy_reg.dispatch_table
        reduce = dispatch_table.get(t)
        if reduce:
            rv = reduce(obj)
        else:
            # Check for a __reduce_ex__ method, fall back to __reduce__
            reduce = getattr(obj, "__reduce_ex__", None)
            if reduce:
                rv = reduce(self.proto)
            else:
                reduce = getattr(obj, "__reduce__", None)
                if reduce:
                    rv = reduce()
                else:
                    raise PicklingError("Can't pickle %r object: %r" %
                                        (t.__name__, obj))

        # Check for string returned by reduce(), meaning "save as global"
        if type(rv) is StringType:
            self.save_global(obj, rv)
            return

        # Assert that reduce() returned a tuple
        if type(rv) is not TupleType:
            raise PicklingError("%s must return string or tuple" % reduce)

        # Assert that it returned an appropriately sized tuple
        l = len(rv)
        if not (2 <= l <= 5):
            raise PicklingError("Tuple returned by %s must have "
                                "two to five elements" % reduce)

        # Save the reduce() output and finally memoize the object
        self.save_reduce(obj=obj, *rv)

    def persistent_id(self, obj):
        """Return a persistent id for obj, or None to pickle it normally.

        This exists so a subclass can override it.
        """
        return None

    def save_pers(self, pid):
        """Write a reference to the persistent object identified by pid."""
        if self.bin:
            # Binary protocols pickle the id itself, followed by BINPERSID.
            self.save(pid)
            self.write(BINPERSID)
        else:
            # Protocol 0 writes the id as a single text line, and the
            # unpickler reads it back with readline().  An embedded newline
            # would end the id early and corrupt the rest of the stream.
            text = str(pid)
            if '\n' in text:
                raise PicklingError(
                    "persistent id %r contains a newline and cannot be "
                    "pickled with protocol 0" % (pid,))
            self.write(PERSID + text + '\n')

    def save_reduce(self, func, args, state=None,
                    listitems=None, dictitems=None, obj=None):
        """Write the value described by a reduce() tuple.

        This API is called by some subclasses.  func and args rebuild the
        object at unpickling time; state (if not None) is applied with
        BUILD; listitems/dictitems (if not None) are iterators of list items
        and (key, value) pairs.  When obj is given it is memoized right after
        it is constructed.  Raises PicklingError for a non-tuple args, a
        non-callable func or a bad __newobj__ class.
        """
        # Assert that args is a tuple (None is not accepted)
        if not isinstance(args, TupleType):
            raise PicklingError("args from reduce() should be a tuple")

        # Assert that func is callable
        if not hasattr(func, '__call__'):
            raise PicklingError("func from reduce should be callable")

        save = self.save
        write = self.write

        # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
        if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
            # A __reduce__ implementation can direct protocol 2 to
            # use the more efficient NEWOBJ opcode, while still
            # allowing protocol 0 and 1 to work normally.  For this to
            # work, the function returned by __reduce__ should be
            # called __newobj__, and its first argument should be a
            # new-style class.  The implementation for __newobj__
            # should be as follows, although pickle has no way to
            # verify this:
            #
            # def __newobj__(cls, *args):
            #     return cls.__new__(cls, *args)
            #
            # Protocols 0 and 1 will pickle a reference to __newobj__,
            # while protocol 2 (and above) will pickle a reference to
            # cls, the remaining args tuple, and the NEWOBJ code,
            # which calls cls.__new__(cls, *args) at unpickling time
            # (see load_newobj below).  If __reduce__ returns a
            # three-tuple, the state from the third tuple item will be
            # pickled regardless of the protocol, calling __setstate__
            # at unpickling time (see load_build below).
            #
            # Note that no standard __newobj__ implementation exists;
            # you have to provide your own.  This is to enforce
            # compatibility with Python 2.2 (pickles written using
            # protocol 0 or 1 in Python 2.3 should be unpicklable by
            # Python 2.2).
            cls = args[0]
            if not hasattr(cls, "__new__"):
                raise PicklingError(
                    "args[0] from __newobj__ args has no __new__")
            if obj is not None and cls is not obj.__class__:
                raise PicklingError(
                    "args[0] from __newobj__ args has the wrong class")
            args = args[1:]
            save(cls)
            save(args)
            write(NEWOBJ)
        else:
            save(func)
            save(args)
            write(REDUCE)

        if obj is not None:
            self.memoize(obj)

        # More new special cases (that work with older protocols as
        # well): when __reduce__ returns a tuple with 4 or 5 items,
        # the 4th and 5th item should be iterators that provide list
        # items and dict items (as (key, value) tuples), or None.

        if listitems is not None:
            self._batch_appends(listitems)

        if dictitems is not None:
            self._batch_setitems(dictitems)

        if state is not None:
            save(state)
            write(BUILD)

    # Methods below this point are dispatched through the dispatch table

    dispatch = {}

    def save_none(self, obj):
        self.write(NONE)
    dispatch[NoneType] = save_none

    def save_bool(self, obj):
        # Protocol 2 has dedicated opcodes; older protocols spell bools as
        # special INT instructions.
        if self.proto >= 2:
            self.write(obj and NEWTRUE or NEWFALSE)
        else:
            self.write(obj and TRUE or FALSE)
    dispatch[bool] = save_bool

    def save_int(self, obj, pack=struct.pack):
        if self.bin:
            # If the int is small enough to fit in a signed 4-byte 2's-comp
            # format, we can store it more efficiently than the general
            # case.
            # First one- and two-byte unsigned ints:
            if obj >= 0:
                if obj <= 0xff:
                    self.write(BININT1 + chr(obj))
                    return
                if obj <= 0xffff:
                    self.write("%c%c%c" % (BININT2, obj&0xff, obj>>8))
                    return
            # Next check for 4-byte signed ints:
            high_bits = obj >> 31  # note that Python shift sign-extends
            if high_bits == 0 or high_bits == -1:
                # All high bits are copies of bit 2**31, so the value
                # fits in a 4-byte signed int.
                self.write(BININT + pack("<i", obj))
                return
        # Text pickle, or int too big to fit in signed 4-byte format.
        self.write(INT + repr(obj) + '\n')
    dispatch[IntType] = save_int

    def save_long(self, obj, pack=struct.pack):
        if self.proto >= 2:
            bytes = encode_long(obj)
            n = len(bytes)
            if n < 256:
                self.write(LONG1 + chr(n) + bytes)
            else:
                self.write(LONG4 + pack("<i", n) + bytes)
            return
        self.write(LONG + repr(obj) + '\n')
    dispatch[LongType] = save_long

    def save_float(self, obj, pack=struct.pack):
        if self.bin:
            self.write(BINFLOAT + pack('>d', obj))
        else:
            self.write(FLOAT + repr(obj) + '\n')
    dispatch[FloatType] = save_float

    def save_string(self, obj, pack=struct.pack):
        if self.bin:
            n = len(obj)
            if n < 256:
                self.write(SHORT_BINSTRING + chr(n) + obj)
            else:
                self.write(BINSTRING + pack("<i", n) + obj)
        else:
            self.write(STRING + repr(obj) + '\n')
        self.memoize(obj)
    dispatch[StringType] = save_string

    def save_unicode(self, obj, pack=struct.pack):
        if self.bin:
            encoding = obj.encode('utf-8')
            n = len(encoding)
            self.write(BINUNICODE + pack("<i", n) + encoding)
        else:
            # The raw-unicode-escape codec passes backslash and newline
            # through unchanged.  Escape them by hand so the argument stays
            # on one line.  Use a separate name for the escaped copy: the
            # memo below must record obj itself, otherwise later references
            # to obj would not be found in the memo.
            escaped = obj.replace("\\", "\\u005c")
            escaped = escaped.replace("\n", "\\u000a")
            self.write(UNICODE + escaped.encode('raw-unicode-escape') + '\n')
        self.memoize(obj)
    dispatch[UnicodeType] = save_unicode

    if StringType is UnicodeType:
        # This is true for Jython
        def save_string(self, obj, pack=struct.pack):
            is_unicode = obj.isunicode()
            # Encode/escape into `data` so that obj itself is memoized.
            data = obj

            if self.bin:
                if is_unicode:
                    data = data.encode("utf-8")
                l = len(data)
                if l < 256 and not is_unicode:
                    self.write(SHORT_BINSTRING + chr(l) + data)
                else:
                    s = pack("<i", l)
                    if is_unicode:
                        self.write(BINUNICODE + s + data)
                    else:
                        self.write(BINSTRING + s + data)
            else:
                if is_unicode:
                    data = data.replace("\\", "\\u005c")
                    data = data.replace("\n", "\\u000a")
                    data = data.encode('raw-unicode-escape')
                    self.write(UNICODE + data + '\n')
                else:
                    self.write(STRING + repr(data) + '\n')
            self.memoize(obj)
        dispatch[StringType] = save_string

    def save_tuple(self, obj):
        write = self.write
        proto = self.proto

        n = len(obj)
        if n == 0:
            if proto:
                write(EMPTY_TUPLE)
            else:
                write(MARK + TUPLE)
            return

        save = self.save
        memo = self.memo
        if n <= 3 and proto >= 2:
            for element in obj:
                save(element)
            # Subtle.  Same as in the big comment below.
            if id(obj) in memo:
                get = self.get(memo[id(obj)][0])
                write(POP * n + get)
            else:
                write(_tuplesize2code[n])
                self.memoize(obj)
            return

        # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple
        # has more than 3 elements.
        write(MARK)
        for element in obj:
            save(element)

        if id(obj) in memo:
            # Subtle.  obj was not in memo when we entered save_tuple(), so
            # the process of saving the tuple's elements must have saved
            # the tuple itself:  the tuple is recursive.  The proper action
            # now is to throw away everything we put on the stack, and
            # simply GET the tuple (it's already constructed).  This check
            # could have been done in the "for element" loop instead, but
            # recursive tuples are a rare thing.
            get = self.get(memo[id(obj)][0])
            if proto:
                write(POP_MARK + get)
            else:   # proto 0 -- POP_MARK not available
                write(POP * (n+1) + get)
            return

        # No recursion.
        write(TUPLE)
        self.memoize(obj)

    dispatch[TupleType] = save_tuple

    # save_empty_tuple() isn't used by anything in Python 2.3.  However, I
    # found a Pickler subclass in Zope3 that calls it, so it's not harmless
    # to remove it.
    def save_empty_tuple(self, obj):
        self.write(EMPTY_TUPLE)

    def save_list(self, obj):
        write = self.write

        if self.bin:
            write(EMPTY_LIST)
        else:   # proto 0 -- can't use EMPTY_LIST
            write(MARK + LIST)

        # Memoize before the items so that self-references become GETs.
        self.memoize(obj)
        self._batch_appends(iter(obj))

    dispatch[ListType] = save_list

    # Keep in synch with cPickle's BATCHSIZE.  Nothing will break if it gets
    # out of synch, though.
    _BATCHSIZE = 1000

    def _batch_appends(self, items):
        """Emit APPEND/APPENDS opcodes for every item of the iterator items.

        Protocol 0 writes one APPEND per item.  Binary protocols group up to
        _BATCHSIZE items per MARK ... APPENDS run (a lone item uses APPEND).
        """
        save = self.save
        write = self.write

        if not self.bin:
            for x in items:
                save(x)
                write(APPEND)
            return

        r = xrange(self._BATCHSIZE)
        while items is not None:
            tmp = []
            for i in r:
                try:
                    x = items.next()
                except StopIteration:
                    items = None
                    break
                tmp.append(x)
            n = len(tmp)
            if n > 1:
                write(MARK)
                for x in tmp:
                    save(x)
                write(APPENDS)
            elif n:
                save(tmp[0])
                write(APPEND)
            # else tmp is empty, and we're done

    def save_dict(self, obj):
        write = self.write

        if self.bin:
            write(EMPTY_DICT)
        else:   # proto 0 -- can't use EMPTY_DICT
            write(MARK + DICT)

        # Memoize before the items so that self-references become GETs.
        self.memoize(obj)
        self._batch_setitems(obj.iteritems())

    dispatch[DictionaryType] = save_dict
    if PyStringMap is not None:
        dispatch[PyStringMap] = save_dict

    def _batch_setitems(self, items):
        """Emit SETITEM/SETITEMS opcodes for an iterator of (key, value) pairs.

        Protocol 0 writes one SETITEM per pair.  Binary protocols group up to
        _BATCHSIZE pairs per MARK ... SETITEMS run (a lone pair uses SETITEM).
        """
        save = self.save
        write = self.write

        if not self.bin:
            for k, v in items:
                save(k)
                save(v)
                write(SETITEM)
            return

        r = xrange(self._BATCHSIZE)
        while items is not None:
            tmp = []
            for i in r:
                try:
                    tmp.append(items.next())
                except StopIteration:
                    items = None
                    break
            n = len(tmp)
            if n > 1:
                write(MARK)
                for k, v in tmp:
                    save(k)
                    save(v)
                write(SETITEMS)
            elif n:
                k, v = tmp[0]
                save(k)
                save(v)
                write(SETITEM)
            # else tmp is empty, and we're done

    def save_inst(self, obj):
        """Pickle an old-style class instance with OBJ (binary) or INST (text)."""
        cls = obj.__class__

        memo  = self.memo
        write = self.write
        save  = self.save

        if hasattr(obj, '__getinitargs__'):
            args = obj.__getinitargs__()
            len(args) # XXX Assert it's a sequence
            _keep_alive(args, memo)
        else:
            args = ()

        write(MARK)

        if self.bin:
            save(cls)
            for arg in args:
                save(arg)
            write(OBJ)
        else:
            for arg in args:
                save(arg)
            write(INST + cls.__module__ + '\n' + cls.__name__ + '\n')

        self.memoize(obj)

        try:
            getstate = obj.__getstate__
        except AttributeError:
            stuff = obj.__dict__
        else:
            stuff = getstate()
            _keep_alive(stuff, memo)
        save(stuff)
        write(BUILD)

    dispatch[InstanceType] = save_inst

    def save_global(self, obj, name=None, pack=struct.pack):
        """Pickle obj by reference, as the attribute `name` of its module.

        Raises PicklingError unless importing the module and looking up
        `name` yields obj itself.  Under protocol 2, objects registered in
        the copy_reg extension registry are written as EXT1/EXT2/EXT4 codes.
        """
        write = self.write

        if name is None:
            name = obj.__name__

        module = getattr(obj, "__module__", None)
        if module is None:
            module = whichmodule(obj, name)

        try:
            __import__(module)
            mod = sys.modules[module]
            klass = getattr(mod, name)
        except (ImportError, KeyError, AttributeError):
            raise PicklingError(
                "Can't pickle %r: it's not found as %s.%s" %
                (obj, module, name))
        else:
            if klass is not obj:
                raise PicklingError(
                    "Can't pickle %r: it's not the same object as %s.%s" %
                    (obj, module, name))

        if self.proto >= 2:
            code = _extension_registry.get((module, name))
            if code:
                assert code > 0
                if code <= 0xff:
                    write(EXT1 + chr(code))
                elif code <= 0xffff:
                    write("%c%c%c" % (EXT2, code&0xff, code>>8))
                else:
                    write(EXT4 + pack("<i", code))
                return

        write(GLOBAL + module + '\n' + name + '\n')
        self.memoize(obj)

    dispatch[ClassType] = save_global
    dispatch[FunctionType] = save_global
    dispatch[BuiltinFunctionType] = save_global
    dispatch[TypeType] = save_global
774
775# Pickling helpers
776
def _keep_alive(x, memo):
    """Hold a strong reference to x inside memo so x stays alive.

    The memo identifies objects by id(), so a temporary object must not be
    garbage-collected while pickling is in progress.  References are
    collected in a list stored under the key id(memo), which should only be
    looked up if someone deep-copies the memo itself.
    """
    key = id(memo)
    try:
        kept = memo[key]
    except KeyError:
        # First object kept alive through this memo.
        memo[key] = [x]
    else:
        kept.append(x)
792
793
# Cache used by whichmodule(): maps an object to the name of the module it
# was found in.  Called classmap for backwards compatibility.
classmap = {}

def whichmodule(func, funcname):
    """Return the name of the module in which func is defined.

    func.__module__ wins when present and not None (Python functions always
    get one from their globals).  Otherwise a cached answer in classmap is
    reused, or sys.modules is scanned for a module (other than __main__)
    whose attribute funcname is func itself.  The answer is cached in
    classmap; "__main__" is returned when no module matches.
    """
    modname = getattr(func, "__module__", None)
    if modname is not None:
        return modname
    if func in classmap:
        return classmap[func]

    found = '__main__'
    for candidate, module in sys.modules.items():
        # Skip dummy package entries and the __main__ module itself.
        if module is None or candidate == '__main__':
            continue
        if getattr(module, funcname, None) is func:
            found = candidate
            break
    classmap[func] = found
    return found
823
824
825# Unpickling machinery
826
827class Unpickler:
828
829    def __init__(self, file):
830        """This takes a file-like object for reading a pickle data stream.
831
832        The protocol version of the pickle is detected automatically, so no
833        proto argument is needed.
834
835        The file-like object must have two methods, a read() method that
836        takes an integer argument, and a readline() method that requires no
837        arguments.  Both methods should return a string.  Thus file-like
838        object can be a file object opened for reading, a StringIO object,
839        or any other custom object that meets this interface.
840        """
841        self.readline = file.readline
842        self.read = file.read
843        self.memo = {}
844
845    def load(self):
846        """Read a pickled object representation from the open file.
847
848        Return the reconstituted object hierarchy specified in the file.
849        """
850        self.mark = object() # any new unique object
851        self.stack = []
852        self.append = self.stack.append
853        read = self.read
854        dispatch = self.dispatch
855        try:
856            while 1:
857                key = read(1)
858                dispatch[key](self)
859        except _Stop, stopinst:
860            return stopinst.value
861
862    # Return largest index k such that self.stack[k] is self.mark.
863    # If the stack doesn't contain a mark, eventually raises IndexError.
864    # This could be sped by maintaining another stack, of indices at which
865    # the mark appears.  For that matter, the latter stack would suffice,
866    # and we wouldn't need to push mark objects on self.stack at all.
867    # Doing so is probably a good thing, though, since if the pickle is
868    # corrupt (or hostile) we may get a clue from finding self.mark embedded
869    # in unpickled objects.
870    def marker(self):
871        stack = self.stack
872        mark = self.mark
873        k = len(stack)-1
874        while stack[k] is not mark: k = k-1
875        return k
876
877    dispatch = {}
878
879    def load_eof(self):
880        raise EOFError
881    dispatch[''] = load_eof
882
883    def load_proto(self):
884        proto = ord(self.read(1))
885        if not 0 <= proto <= 2:
886            raise ValueError, "unsupported pickle protocol: %d" % proto
887    dispatch[PROTO] = load_proto
888
889    def load_persid(self):
890        pid = self.readline()[:-1]
891        self.append(self.persistent_load(pid))
892    dispatch[PERSID] = load_persid
893
894    def load_binpersid(self):
895        pid = self.stack.pop()
896        self.append(self.persistent_load(pid))
897    dispatch[BINPERSID] = load_binpersid
898
899    def load_none(self):
900        self.append(None)
901    dispatch[NONE] = load_none
902
903    def load_false(self):
904        self.append(False)
905    dispatch[NEWFALSE] = load_false
906
907    def load_true(self):
908        self.append(True)
909    dispatch[NEWTRUE] = load_true
910
911    def load_int(self):
912        data = self.readline()
913        if data == FALSE[1:]:
914            val = False
915        elif data == TRUE[1:]:
916            val = True
917        else:
918            try:
919                val = int(data)
920            except ValueError:
921                val = long(data)
922        self.append(val)
923    dispatch[INT] = load_int
924
925    def load_binint(self):
926        self.append(mloads('i' + self.read(4)))
927    dispatch[BININT] = load_binint
928
929    def load_binint1(self):
930        self.append(ord(self.read(1)))
931    dispatch[BININT1] = load_binint1
932
933    def load_binint2(self):
934        self.append(mloads('i' + self.read(2) + '\000\000'))
935    dispatch[BININT2] = load_binint2
936
937    def load_long(self):
938        self.append(long(self.readline()[:-1], 0))
939    dispatch[LONG] = load_long
940
941    def load_long1(self):
942        n = ord(self.read(1))
943        bytes = self.read(n)
944        self.append(decode_long(bytes))
945    dispatch[LONG1] = load_long1
946
947    def load_long4(self):
948        n = mloads('i' + self.read(4))
949        bytes = self.read(n)
950        self.append(decode_long(bytes))
951    dispatch[LONG4] = load_long4
952
953    def load_float(self):
954        self.append(float(self.readline()[:-1]))
955    dispatch[FLOAT] = load_float
956
957    def load_binfloat(self, unpack=struct.unpack):
958        self.append(unpack('>d', self.read(8))[0])
959    dispatch[BINFLOAT] = load_binfloat
960
961    def load_string(self):
962        rep = self.readline()[:-1]
963        for q in "\"'": # double or single quote
964            if rep.startswith(q):
965                if not rep.endswith(q):
966                    raise ValueError, "insecure string pickle"
967                rep = rep[len(q):-len(q)]
968                break
969        else:
970            raise ValueError, "insecure string pickle"
971        self.append(rep.decode("string-escape"))
972    dispatch[STRING] = load_string
973
974    def load_binstring(self):
975        len = mloads('i' + self.read(4))
976        self.append(self.read(len))
977    dispatch[BINSTRING] = load_binstring
978
979    def load_unicode(self):
980        self.append(unicode(self.readline()[:-1],'raw-unicode-escape'))
981    dispatch[UNICODE] = load_unicode
982
983    def load_binunicode(self):
984        len = mloads('i' + self.read(4))
985        self.append(unicode(self.read(len),'utf-8'))
986    dispatch[BINUNICODE] = load_binunicode
987
988    def load_short_binstring(self):
989        len = ord(self.read(1))
990        self.append(self.read(len))
991    dispatch[SHORT_BINSTRING] = load_short_binstring
992
993    def load_tuple(self):
994        k = self.marker()
995        self.stack[k:] = [tuple(self.stack[k+1:])]
996    dispatch[TUPLE] = load_tuple
997
998    def load_empty_tuple(self):
999        self.stack.append(())
1000    dispatch[EMPTY_TUPLE] = load_empty_tuple
1001
1002    def load_tuple1(self):
1003        self.stack[-1] = (self.stack[-1],)
1004    dispatch[TUPLE1] = load_tuple1
1005
1006    def load_tuple2(self):
1007        self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
1008    dispatch[TUPLE2] = load_tuple2
1009
1010    def load_tuple3(self):
1011        self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
1012    dispatch[TUPLE3] = load_tuple3
1013
1014    def load_empty_list(self):
1015        self.stack.append([])
1016    dispatch[EMPTY_LIST] = load_empty_list
1017
1018    def load_empty_dictionary(self):
1019        self.stack.append({})
1020    dispatch[EMPTY_DICT] = load_empty_dictionary
1021
1022    def load_list(self):
1023        k = self.marker()
1024        self.stack[k:] = [self.stack[k+1:]]
1025    dispatch[LIST] = load_list
1026
1027    def load_dict(self):
1028        k = self.marker()
1029        d = {}
1030        items = self.stack[k+1:]
1031        for i in range(0, len(items), 2):
1032            key = items[i]
1033            value = items[i+1]
1034            d[key] = value
1035        self.stack[k:] = [d]
1036    dispatch[DICT] = load_dict
1037
1038    # INST and OBJ differ only in how they get a class object.  It's not
1039    # only sensible to do the rest in a common routine, the two routines
1040    # previously diverged and grew different bugs.
1041    # klass is the class to instantiate, and k points to the topmost mark
1042    # object, following which are the arguments for klass.__init__.
1043    def _instantiate(self, klass, k):
1044        args = tuple(self.stack[k+1:])
1045        del self.stack[k:]
1046        instantiated = 0
1047        if (not args and
1048                type(klass) is ClassType and
1049                not hasattr(klass, "__getinitargs__")):
1050            try:
1051                value = _EmptyClass()
1052                value.__class__ = klass
1053                instantiated = 1
1054            except RuntimeError:
1055                # In restricted execution, assignment to inst.__class__ is
1056                # prohibited
1057                pass
1058        if not instantiated:
1059            try:
1060                value = klass(*args)
1061            except TypeError, err:
1062                raise TypeError, "in constructor for %s: %s" % (
1063                    klass.__name__, str(err)), sys.exc_info()[2]
1064        self.append(value)
1065
1066    def load_inst(self):
1067        module = self.readline()[:-1]
1068        name = self.readline()[:-1]
1069        klass = self.find_class(module, name)
1070        self._instantiate(klass, self.marker())
1071    dispatch[INST] = load_inst
1072
1073    def load_obj(self):
1074        # Stack is ... markobject classobject arg1 arg2 ...
1075        k = self.marker()
1076        klass = self.stack.pop(k+1)
1077        self._instantiate(klass, k)
1078    dispatch[OBJ] = load_obj
1079
1080    def load_newobj(self):
1081        args = self.stack.pop()
1082        cls = self.stack[-1]
1083        obj = cls.__new__(cls, *args)
1084        self.stack[-1] = obj
1085    dispatch[NEWOBJ] = load_newobj
1086
1087    def load_global(self):
1088        module = self.readline()[:-1]
1089        name = self.readline()[:-1]
1090        klass = self.find_class(module, name)
1091        self.append(klass)
1092    dispatch[GLOBAL] = load_global
1093
1094    def load_ext1(self):
1095        code = ord(self.read(1))
1096        self.get_extension(code)
1097    dispatch[EXT1] = load_ext1
1098
1099    def load_ext2(self):
1100        code = mloads('i' + self.read(2) + '\000\000')
1101        self.get_extension(code)
1102    dispatch[EXT2] = load_ext2
1103
1104    def load_ext4(self):
1105        code = mloads('i' + self.read(4))
1106        self.get_extension(code)
1107    dispatch[EXT4] = load_ext4
1108
1109    def get_extension(self, code):
1110        nil = []
1111        obj = _extension_cache.get(code, nil)
1112        if obj is not nil:
1113            self.append(obj)
1114            return
1115        key = _inverted_registry.get(code)
1116        if not key:
1117            raise ValueError("unregistered extension code %d" % code)
1118        obj = self.find_class(*key)
1119        _extension_cache[code] = obj
1120        self.append(obj)
1121
1122    def find_class(self, module, name):
1123        # Subclasses may override this
1124        __import__(module)
1125        mod = sys.modules[module]
1126        klass = getattr(mod, name)
1127        return klass
1128
1129    def load_reduce(self):
1130        stack = self.stack
1131        args = stack.pop()
1132        func = stack[-1]
1133        value = func(*args)
1134        stack[-1] = value
1135    dispatch[REDUCE] = load_reduce
1136
1137    def load_pop(self):
1138        del self.stack[-1]
1139    dispatch[POP] = load_pop
1140
1141    def load_pop_mark(self):
1142        k = self.marker()
1143        del self.stack[k:]
1144    dispatch[POP_MARK] = load_pop_mark
1145
1146    def load_dup(self):
1147        self.append(self.stack[-1])
1148    dispatch[DUP] = load_dup
1149
1150    def load_get(self):
1151        self.append(self.memo[self.readline()[:-1]])
1152    dispatch[GET] = load_get
1153
1154    def load_binget(self):
1155        i = ord(self.read(1))
1156        self.append(self.memo[repr(i)])
1157    dispatch[BINGET] = load_binget
1158
1159    def load_long_binget(self):
1160        i = mloads('i' + self.read(4))
1161        self.append(self.memo[repr(i)])
1162    dispatch[LONG_BINGET] = load_long_binget
1163
1164    def load_put(self):
1165        self.memo[self.readline()[:-1]] = self.stack[-1]
1166    dispatch[PUT] = load_put
1167
1168    def load_binput(self):
1169        i = ord(self.read(1))
1170        self.memo[repr(i)] = self.stack[-1]
1171    dispatch[BINPUT] = load_binput
1172
1173    def load_long_binput(self):
1174        i = mloads('i' + self.read(4))
1175        self.memo[repr(i)] = self.stack[-1]
1176    dispatch[LONG_BINPUT] = load_long_binput
1177
1178    def load_append(self):
1179        stack = self.stack
1180        value = stack.pop()
1181        list = stack[-1]
1182        list.append(value)
1183    dispatch[APPEND] = load_append
1184
1185    def load_appends(self):
1186        stack = self.stack
1187        mark = self.marker()
1188        list = stack[mark - 1]
1189        list.extend(stack[mark + 1:])
1190        del stack[mark:]
1191    dispatch[APPENDS] = load_appends
1192
1193    def load_setitem(self):
1194        stack = self.stack
1195        value = stack.pop()
1196        key = stack.pop()
1197        dict = stack[-1]
1198        dict[key] = value
1199    dispatch[SETITEM] = load_setitem
1200
1201    def load_setitems(self):
1202        stack = self.stack
1203        mark = self.marker()
1204        dict = stack[mark - 1]
1205        for i in range(mark + 1, len(stack), 2):
1206            dict[stack[i]] = stack[i + 1]
1207
1208        del stack[mark:]
1209    dispatch[SETITEMS] = load_setitems
1210
1211    def load_build(self):
1212        stack = self.stack
1213        state = stack.pop()
1214        inst = stack[-1]
1215        setstate = getattr(inst, "__setstate__", None)
1216        if setstate:
1217            setstate(state)
1218            return
1219        slotstate = None
1220        if isinstance(state, tuple) and len(state) == 2:
1221            state, slotstate = state
1222        if state:
1223            try:
1224                d = inst.__dict__
1225                try:
1226                    for k, v in state.iteritems():
1227                        d[intern(k)] = v
1228                # keys in state don't have to be strings
1229                # don't blow up, but don't go out of our way
1230                except TypeError:
1231                    d.update(state)
1232
1233            except RuntimeError:
1234                # XXX In restricted execution, the instance's __dict__
1235                # is not accessible.  Use the old way of unpickling
1236                # the instance variables.  This is a semantic
1237                # difference when unpickling in restricted
1238                # vs. unrestricted modes.
1239                # Note, however, that cPickle has never tried to do the
1240                # .update() business, and always uses
1241                #     PyObject_SetItem(inst.__dict__, key, value) in a
1242                # loop over state.items().
1243                for k, v in state.items():
1244                    setattr(inst, k, v)
1245        if slotstate:
1246            for k, v in slotstate.items():
1247                setattr(inst, k, v)
1248    dispatch[BUILD] = load_build
1249
1250    def load_mark(self):
1251        self.append(self.mark)
1252    dispatch[MARK] = load_mark
1253
1254    def load_stop(self):
1255        value = self.stack.pop()
1256        raise _Stop(value)
1257    dispatch[STOP] = load_stop
1258
1259# Helper class for load_inst/load_obj
1260
1261class _EmptyClass:
1262    pass
1263
1264# Encode/decode longs in linear time.
1265
1266import binascii as _binascii
1267
1268def encode_long(x):
1269    r"""Encode a long to a two's complement little-endian binary string.
1270    Note that 0L is a special case, returning an empty string, to save a
1271    byte in the LONG1 pickling context.
1272
1273    >>> encode_long(0L)
1274    ''
1275    >>> encode_long(255L)
1276    '\xff\x00'
1277    >>> encode_long(32767L)
1278    '\xff\x7f'
1279    >>> encode_long(-256L)
1280    '\x00\xff'
1281    >>> encode_long(-32768L)
1282    '\x00\x80'
1283    >>> encode_long(-128L)
1284    '\x80'
1285    >>> encode_long(127L)
1286    '\x7f'
1287    >>>
1288    """
1289
1290    if x == 0:
1291        return ''
1292    if x > 0:
1293        ashex = hex(x)
1294        assert ashex.startswith("0x")
1295        njunkchars = 2 + ashex.endswith('L')
1296        nibbles = len(ashex) - njunkchars
1297        if nibbles & 1:
1298            # need an even # of nibbles for unhexlify
1299            ashex = "0x0" + ashex[2:]
1300        elif int(ashex[2], 16) >= 8:
1301            # "looks negative", so need a byte of sign bits
1302            ashex = "0x00" + ashex[2:]
1303    else:
1304        # Build the 256's-complement:  (1L << nbytes) + x.  The trick is
1305        # to find the number of bytes in linear time (although that should
1306        # really be a constant-time task).
1307        ashex = hex(-x)
1308        assert ashex.startswith("0x")
1309        njunkchars = 2 + ashex.endswith('L')
1310        nibbles = len(ashex) - njunkchars
1311        if nibbles & 1:
1312            # Extend to a full byte.
1313            nibbles += 1
1314        nbits = nibbles * 4
1315        x += 1L << nbits
1316        assert x > 0
1317        ashex = hex(x)
1318        njunkchars = 2 + ashex.endswith('L')
1319        newnibbles = len(ashex) - njunkchars
1320        if newnibbles < nibbles:
1321            ashex = "0x" + "0" * (nibbles - newnibbles) + ashex[2:]
1322        if int(ashex[2], 16) < 8:
1323            # "looks positive", so need a byte of sign bits
1324            ashex = "0xff" + ashex[2:]
1325
1326    if ashex.endswith('L'):
1327        ashex = ashex[2:-1]
1328    else:
1329        ashex = ashex[2:]
1330    assert len(ashex) & 1 == 0, (x, ashex)
1331    binary = _binascii.unhexlify(ashex)
1332    return binary[::-1]
1333
1334def decode_long(data):
1335    r"""Decode a long from a two's complement little-endian binary string.
1336
1337    >>> decode_long('')
1338    0L
1339    >>> decode_long("\xff\x00")
1340    255L
1341    >>> decode_long("\xff\x7f")
1342    32767L
1343    >>> decode_long("\x00\xff")
1344    -256L
1345    >>> decode_long("\x00\x80")
1346    -32768L
1347    >>> decode_long("\x80")
1348    -128L
1349    >>> decode_long("\x7f")
1350    127L
1351    """
1352
1353    nbytes = len(data)
1354    if nbytes == 0:
1355        return 0L
1356    ashex = _binascii.hexlify(data[::-1])
1357    n = long(ashex, 16) # quadratic time before Python 2.3; linear now
1358    if data[-1] >= '\x80':
1359        n -= 1L << (nbytes * 8)
1360    return n
1361
1362# Shorthands
1363
1364try:
1365    from cStringIO import StringIO
1366except ImportError:
1367    from StringIO import StringIO
1368
def dump(obj, file, protocol=None):
    """Pickle obj and write the result to the open file object file.

    protocol is passed unchanged to the Pickler constructor.
    """
    pickler = Pickler(file, protocol)
    pickler.dump(obj)
1371
def dumps(obj, protocol=None):
    """Return the pickled representation of obj as a string.

    protocol is passed unchanged to the Pickler constructor.
    """
    buf = StringIO()
    Pickler(buf, protocol).dump(obj)
    return buf.getvalue()
1376
def load(file):
    """Read one pickled object from the open file object file and return it.

    file must provide read() and readline(), as Unpickler requires.
    """
    unpickler = Unpickler(file)
    return unpickler.load()
1379
def loads(str):
    """Rebuild and return the object pickled in the string str."""
    # The parameter name shadows the builtin but is kept for callers that
    # pass it by keyword.
    return Unpickler(StringIO(str)).load()
1383
1384# Doctest
1385
1386def _test():
1387    import doctest
1388    return doctest.testmod()
1389
if __name__ == '__main__':
    # Executed as a script: run the doctests embedded in this module.
    _test()
1392