• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Create portable serialized representations of Python objects.
2
3See module cPickle for a (much) faster implementation.
4See module copy_reg for a mechanism for registering custom picklers.
5See module pickletools source for extensive comments.
6
7Classes:
8
9    Pickler
10    Unpickler
11
12Functions:
13
14    dump(object, file)
15    dumps(object) -> string
16    load(file) -> object
17    loads(string) -> object
18
19Misc variables:
20
21    __version__
22    format_version
23    compatible_formats
24
25"""
26
27__version__ = "$Revision: 72223 $"       # Code version
28
29from types import *
30from copy_reg import dispatch_table
31from copy_reg import _extension_registry, _inverted_registry, _extension_cache
32import marshal
33import sys
34import struct
35import re
36
# Public API; the opcode constants are appended further down the module.
__all__ = [
    "PickleError", "PicklingError", "UnpicklingError",
    "Pickler", "Unpickler",
    "dump", "dumps", "load", "loads",
]
39
# Informational only: per the original note, no code reads the next two names.
format_version = "2.0"                  # file format version this module writes

# Format versions this module can read:
#   1.0  original protocol 0
#   1.1  protocol 0 with INST added
#   1.2  original protocol 1
#   1.3  protocol 1 with BINFLOAT added
#   2.0  protocol 2
compatible_formats = "1.0 1.1 1.2 1.3 2.0".split()

# The highest protocol number this module knows how to read.  Keep this
# value in step with cPickle.
HIGHEST_PROTOCOL = 2

# Pickling packs numbers with struct.pack(), but unpickling reads them back
# with marshal.loads(): struct.pack() beats marshal.dumps() by about 40%,
# while marshal.loads() runs roughly twice as fast as struct.unpack().
mloads = marshal.loads
57
58class PickleError(Exception):
59    """A common base class for the other pickling exceptions."""
60    pass
61
62class PicklingError(PickleError):
63    """This exception is raised when an unpicklable object is passed to the
64    dump() method.
65
66    """
67    pass
68
69class UnpicklingError(PickleError):
70    """This exception is raised when there is a problem unpickling an object,
71    such as a security violation.
72
73    Note that other exceptions may also be raised during unpickling, including
74    (but not necessarily limited to) AttributeError, EOFError, ImportError,
75    and IndexError.
76
77    """
78    pass
79
80# An instance of _Stop is raised by Unpickler.load_stop() in response to
81# the STOP opcode, passing the object that is the result of unpickling.
82class _Stop(Exception):
83    def __init__(self, value):
84        self.value = value
85
# Jython provides PyStringMap, a dict subclass whose keys are strings.
# Anywhere that import fails, the name is bound to None instead.
try:
    from org.python.core import PyStringMap
except ImportError:
    PyStringMap = None

# UnicodeType normally comes in through "from types import *", but it may
# not be exported; fall back to None when it is absent.
if 'UnicodeType' not in globals():
    UnicodeType = None
97
# Pickle opcodes.  pickletools.py documents each one in depth and groups
# them by purpose; the listing below is roughly in alphabetical order of
# the one-character opcode.

MARK            = '('   # push the special mark object onto the stack
STOP            = '.'   # every pickle ends with this opcode
POP             = '0'   # discard the top stack item
POP_MARK        = '1'   # discard stack items down through the topmost mark
DUP             = '2'   # push a copy of the top stack item
FLOAT           = 'F'   # push a float; argument is a decimal string
INT             = 'I'   # push an int or bool; argument is a decimal string
BININT          = 'J'   # push a four-byte signed int
BININT1         = 'K'   # push a one-byte unsigned int
LONG            = 'L'   # push a long; argument is a decimal string
BININT2         = 'M'   # push a two-byte unsigned int
NONE            = 'N'   # push None
PERSID          = 'P'   # push a persistent object; id comes from a string arg
BINPERSID       = 'Q'   # push a persistent object; id is taken from the stack
REDUCE          = 'R'   # apply a callable to an arg tuple, both on the stack
STRING          = 'S'   # push a string; newline-terminated string argument
BINSTRING       = 'T'   # push a string; length-prefixed binary argument
SHORT_BINSTRING = 'U'   # as BINSTRING, for strings shorter than 256 bytes
UNICODE         = 'V'   # push a unicode string; raw-unicode-escape argument
BINUNICODE      = 'X'   # push a unicode string; length-prefixed UTF-8 argument
APPEND          = 'a'   # append the stack top to the list beneath it
BUILD           = 'b'   # call __setstate__ or __dict__.update()
GLOBAL          = 'c'   # push self.find_class(modname, name); 2 string args
DICT            = 'd'   # build a dict from stack items
EMPTY_DICT      = '}'   # push an empty dict
APPENDS         = 'e'   # extend the list on the stack by the topmost slice
GET             = 'g'   # push a memo item; index is a string argument
BINGET          = 'h'   # push a memo item; index is a 1-byte argument
INST            = 'i'   # build and push a class instance
LONG_BINGET     = 'j'   # push a memo item; index is a 4-byte argument
LIST            = 'l'   # build a list from the topmost stack items
EMPTY_LIST      = ']'   # push an empty list
OBJ             = 'o'   # build and push a class instance
PUT             = 'p'   # store the stack top in the memo; string index arg
BINPUT          = 'q'   # store the stack top in the memo; 1-byte index arg
LONG_BINPUT     = 'r'   # store the stack top in the memo; 4-byte index arg
SETITEM         = 's'   # add one key+value pair to a dict
TUPLE           = 't'   # build a tuple from the topmost stack items
EMPTY_TUPLE     = ')'   # push an empty tuple
SETITEMS        = 'u'   # add the topmost key+value pairs to a dict
BINFLOAT        = 'G'   # push a float; argument is an 8-byte float encoding

# Not opcodes in their own right: INT with argument "01"/"00".  See the INT
# entry in pickletools.py.
TRUE            = 'I01\n'
FALSE           = 'I00\n'

# Opcodes added by protocol 2

PROTO           = '\x80'  # announce the pickle protocol
NEWOBJ          = '\x81'  # build an object via cls.__new__(cls, *argtuple)
EXT1            = '\x82'  # push object from the extension registry; 1-byte index
EXT2            = '\x83'  # likewise, with a 2-byte index
EXT4            = '\x84'  # likewise, with a 4-byte index
TUPLE1          = '\x85'  # build a 1-tuple from the stack top
TUPLE2          = '\x86'  # build a 2-tuple from the two topmost stack items
TUPLE3          = '\x87'  # build a 3-tuple from the three topmost stack items
NEWTRUE         = '\x88'  # push True
NEWFALSE        = '\x89'  # push False
LONG1           = '\x8a'  # push a long encoded in fewer than 256 bytes
LONG4           = '\x8b'  # push a really big long

# Tuple length (0..3) -> opcode that builds a tuple of exactly that length.
_tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3]


# Export every all-caps name defined so far (the opcodes and constants such
# as HIGHEST_PROTOCOL).  filter() avoids leaving a loop variable behind.
__all__.extend(filter(re.compile("[A-Z][A-Z0-9_]+$").match, dir()))
167
168
169# Pickling machinery
170
class Pickler:
    """Write pickled representations of Python objects to a file.

    save() first checks for a persistent id and the memo.  It then looks up
    type(obj) in the class-level ``dispatch`` dict, which maps types to the
    save_* methods below.  Anything else goes through
    copy_reg.dispatch_table or the object's __reduce_ex__/__reduce__ method.
    """

    def __init__(self, file, protocol=None):
        """This takes a file-like object for writing a pickle data stream.

        The optional protocol argument tells the pickler to use the
        given protocol; supported protocols are 0, 1, 2.  The default
        protocol is 0, to be backwards compatible.  (Protocol 0 is the
        only protocol that can be written to a file opened in text
        mode and read back successfully.  When using a protocol higher
        than 0, make sure the file is opened in binary mode, both when
        pickling and unpickling.)

        Protocol 1 is more efficient than protocol 0; protocol 2 is
        more efficient than protocol 1.

        Specifying a negative protocol version selects the highest
        protocol version supported.  The higher the protocol used, the
        more recent the version of Python needed to read the pickle
        produced.

        The file parameter must have a write() method that accepts a single
        string argument.  It can thus be an open file object, a StringIO
        object, or any other custom object that meets this interface.

        Raises ValueError if protocol is greater than HIGHEST_PROTOCOL.
        """
        if protocol is None:
            protocol = 0
        if protocol < 0:
            protocol = HIGHEST_PROTOCOL
        elif not 0 <= protocol <= HIGHEST_PROTOCOL:
            raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL)
        self.write = file.write
        self.memo = {}              # id(obj) -> (memo key, obj); see memoize()
        self.proto = int(protocol)
        self.bin = protocol >= 1    # protocols 1+ use the binary opcodes
        self.fast = 0               # when true, memoize() records nothing

    def clear_memo(self):
        """Clears the pickler's "memo".

        The memo is the data structure that remembers which objects the
        pickler has already seen, so that shared or recursive objects are
        pickled by reference and not by value.  This method is useful when
        re-using picklers.

        """
        self.memo.clear()

    def dump(self, obj):
        """Write a pickled representation of obj to the open file."""
        # Protocol 2+ streams begin with PROTO and the version byte;
        # protocols 0 and 1 have no header.
        if self.proto >= 2:
            self.write(PROTO + chr(self.proto))
        self.save(obj)
        self.write(STOP)

    def memoize(self, obj):
        """Store an object in the memo and emit the matching PUT opcode."""

        # The Pickler memo is a dictionary mapping object ids to 2-tuples
        # that contain the Unpickler memo key and the object being memoized.
        # The memo key is written to the pickle and will become
        # the key in the Unpickler's memo.  The object is stored in the
        # Pickler memo so that transient objects are kept alive during
        # pickling.

        # The use of the Unpickler memo length as the memo key is just a
        # convention.  The only requirement is that the memo values be unique.
        # But there appears no advantage to any other scheme, and this
        # scheme allows the Unpickler memo to be implemented as a plain (but
        # growable) array, indexed by memo key.
        if self.fast:
            return
        assert id(obj) not in self.memo
        memo_len = len(self.memo)
        self.write(self.put(memo_len))
        self.memo[id(obj)] = memo_len, obj

    def put(self, i, pack=struct.pack):
        """Return a PUT opcode string storing the stack top under memo key i.

        Binary protocols use BINPUT (1-byte key) when i < 256, otherwise
        LONG_BINPUT (4-byte little-endian key); protocol 0 uses PUT with a
        decimal key.
        """
        if self.bin:
            if i < 256:
                return BINPUT + chr(i)
            else:
                return LONG_BINPUT + pack("<i", i)

        return PUT + repr(i) + '\n'

    def get(self, i, pack=struct.pack):
        """Return a GET opcode string fetching memo key i (see put())."""
        if self.bin:
            if i < 256:
                return BINGET + chr(i)
            else:
                return LONG_BINGET + pack("<i", i)

        return GET + repr(i) + '\n'

    def save(self, obj):
        """Write obj to the stream, reusing memoized references when possible.

        Raises PicklingError if no way to pickle obj is found, or if its
        reduce method returns something other than a string or a 2- to
        5-tuple.
        """
        # Check for persistent id (defined by a subclass)
        pid = self.persistent_id(obj)
        if pid is not None:
            self.save_pers(pid)
            return

        # Check the memo; an already-saved object is emitted as a GET.
        x = self.memo.get(id(obj))
        if x:
            self.write(self.get(x[0]))
            return

        # Check the type dispatch table
        t = type(obj)
        f = self.dispatch.get(t)
        if f:
            f(self, obj) # Call unbound method with explicit self
            return

        # Check copy_reg.dispatch_table
        reduce = dispatch_table.get(t)
        if reduce:
            rv = reduce(obj)
        else:
            # Check for a class with a custom metaclass; treat as regular class
            try:
                issc = issubclass(t, TypeType)
            except TypeError: # t is not a class (old Boost; see SF #502085)
                issc = 0
            if issc:
                self.save_global(obj)
                return

            # Check for a __reduce_ex__ method, fall back to __reduce__
            reduce = getattr(obj, "__reduce_ex__", None)
            if reduce:
                rv = reduce(self.proto)
            else:
                reduce = getattr(obj, "__reduce__", None)
                if reduce:
                    rv = reduce()
                else:
                    raise PicklingError("Can't pickle %r object: %r" %
                                        (t.__name__, obj))

        # Check for string returned by reduce(), meaning "save as global"
        if type(rv) is StringType:
            self.save_global(obj, rv)
            return

        # Assert that reduce() returned a tuple
        if type(rv) is not TupleType:
            raise PicklingError("%s must return string or tuple" % reduce)

        # Assert that it returned an appropriately sized tuple
        l = len(rv)
        if not (2 <= l <= 5):
            raise PicklingError("Tuple returned by %s must have "
                                "two to five elements" % reduce)

        # Save the reduce() output and finally memoize the object
        self.save_reduce(obj=obj, *rv)

    def persistent_id(self, obj):
        """Return a persistent id for obj, or None to pickle it normally.

        This implementation always returns None; it exists so a subclass
        can override it.
        """
        return None

    def save_pers(self, pid):
        """Save a persistent id reference.

        Binary protocols pickle pid itself and follow it with BINPERSID;
        protocol 0 writes str(pid) as a PERSID line.
        """
        if self.bin:
            self.save(pid)
            self.write(BINPERSID)
        else:
            self.write(PERSID + str(pid) + '\n')

    def save_reduce(self, func, args, state=None,
                    listitems=None, dictitems=None, obj=None):
        """Save the result of a __reduce__-style call.

        func is called with args at unpickling time.  state, when not None,
        is applied with BUILD.  listitems and dictitems, when not None, are
        iterators of list items and of (key, value) pairs.  obj, when given,
        is the object being reduced and is memoized.

        This API is called by some subclasses.

        Raises PicklingError if args is not a tuple, if func is not
        callable, or if a protocol 2 __newobj__ call has an empty or
        mismatched argument tuple.
        """
        # args must be a tuple; None is rejected as well.
        if not isinstance(args, TupleType):
            raise PicklingError("args from reduce() should be a tuple")

        # Assert that func is callable
        if not hasattr(func, '__call__'):
            raise PicklingError("func from reduce should be callable")

        save = self.save
        write = self.write

        # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
        if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
            # A __reduce__ implementation can direct protocol 2 to
            # use the more efficient NEWOBJ opcode, while still
            # allowing protocol 0 and 1 to work normally.  For this to
            # work, the function returned by __reduce__ should be
            # called __newobj__, and its first argument should be a
            # new-style class.  The implementation for __newobj__
            # should be as follows, although pickle has no way to
            # verify this:
            #
            # def __newobj__(cls, *args):
            #     return cls.__new__(cls, *args)
            #
            # Protocols 0 and 1 will pickle a reference to __newobj__,
            # while protocol 2 (and above) will pickle a reference to
            # cls, the remaining args tuple, and the NEWOBJ code,
            # which calls cls.__new__(cls, *args) at unpickling time
            # (see load_newobj below).  If __reduce__ returns a
            # three-tuple, the state from the third tuple item will be
            # pickled regardless of the protocol, calling __setstate__
            # at unpickling time (see load_build below).
            #
            # Note that no standard __newobj__ implementation exists;
            # you have to provide your own.  This is to enforce
            # compatibility with Python 2.2 (pickles written using
            # protocol 0 or 1 in Python 2.3 should be unpicklable by
            # Python 2.2).
            if not args:
                # Report a malformed reduce value instead of an IndexError.
                raise PicklingError("__newobj__ arglist is empty")
            cls = args[0]
            if not hasattr(cls, "__new__"):
                raise PicklingError(
                    "args[0] from __newobj__ args has no __new__")
            if obj is not None and cls is not obj.__class__:
                raise PicklingError(
                    "args[0] from __newobj__ args has the wrong class")
            args = args[1:]
            save(cls)
            save(args)
            write(NEWOBJ)
        else:
            save(func)
            save(args)
            write(REDUCE)

        if obj is not None:
            # If the object is already in the memo, this means it is
            # recursive. In this case, throw away everything we put on the
            # stack, and fetch the object back from the memo.
            if id(obj) in self.memo:
                write(POP + self.get(self.memo[id(obj)][0]))
            else:
                self.memoize(obj)

        # More new special cases (that work with older protocols as
        # well): when __reduce__ returns a tuple with 4 or 5 items,
        # the 4th and 5th item should be iterators that provide list
        # items and dict items (as (key, value) tuples), or None.

        if listitems is not None:
            self._batch_appends(listitems)

        if dictitems is not None:
            self._batch_setitems(dictitems)

        if state is not None:
            save(state)
            write(BUILD)

    # Methods below this point are dispatched through the dispatch table

    dispatch = {}

    def save_none(self, obj):
        """Save None as the NONE opcode."""
        self.write(NONE)
    dispatch[NoneType] = save_none

    def save_bool(self, obj):
        """Save a bool: NEWTRUE/NEWFALSE on protocol 2+, else INT 01/00."""
        # The and/or idiom is safe here because all four strings are non-empty.
        if self.proto >= 2:
            self.write(obj and NEWTRUE or NEWFALSE)
        else:
            self.write(obj and TRUE or FALSE)
    dispatch[bool] = save_bool

    def save_int(self, obj, pack=struct.pack):
        """Save an int using the smallest binary opcode that fits.

        Protocol 0, and ints that do not fit in 4 signed bytes, fall back
        to the INT opcode with a decimal argument.
        """
        if self.bin:
            # If the int is small enough to fit in a signed 4-byte 2's-comp
            # format, we can store it more efficiently than the general
            # case.
            # First one- and two-byte unsigned ints:
            if obj >= 0:
                if obj <= 0xff:
                    self.write(BININT1 + chr(obj))
                    return
                if obj <= 0xffff:
                    # Low byte first, then high byte (little-endian).
                    self.write("%c%c%c" % (BININT2, obj&0xff, obj>>8))
                    return
            # Next check for 4-byte signed ints:
            high_bits = obj >> 31  # note that Python shift sign-extends
            if high_bits == 0 or high_bits == -1:
                # All high bits are copies of bit 2**31, so the value
                # fits in a 4-byte signed int.
                self.write(BININT + pack("<i", obj))
                return
        # Text pickle, or int too big to fit in signed 4-byte format.
        self.write(INT + repr(obj) + '\n')
    dispatch[IntType] = save_int

    def save_long(self, obj, pack=struct.pack):
        """Save a long: LONG1/LONG4 byte encoding on protocol 2+, else LONG."""
        if self.proto >= 2:
            data = encode_long(obj)
            n = len(data)
            if n < 256:
                self.write(LONG1 + chr(n) + data)
            else:
                self.write(LONG4 + pack("<i", n) + data)
            return
        self.write(LONG + repr(obj) + '\n')
    dispatch[LongType] = save_long

    def save_float(self, obj, pack=struct.pack):
        """Save a float: 8-byte big-endian BINFLOAT when binary, else FLOAT."""
        if self.bin:
            self.write(BINFLOAT + pack('>d', obj))
        else:
            self.write(FLOAT + repr(obj) + '\n')
    dispatch[FloatType] = save_float

    def save_string(self, obj, pack=struct.pack):
        """Save a byte string, then memoize it."""
        if self.bin:
            n = len(obj)
            if n < 256:
                self.write(SHORT_BINSTRING + chr(n) + obj)
            else:
                self.write(BINSTRING + pack("<i", n) + obj)
        else:
            self.write(STRING + repr(obj) + '\n')
        self.memoize(obj)
    dispatch[StringType] = save_string

    def save_unicode(self, obj, pack=struct.pack):
        """Save a unicode string, then memoize it.

        Binary protocols write length-prefixed UTF-8 (BINUNICODE).
        Protocol 0 writes a raw-unicode-escape line (UNICODE).  Backslashes
        and newlines are escaped as \\u sequences first so that the line
        round-trips.
        """
        if self.bin:
            encoding = obj.encode('utf-8')
            n = len(encoding)
            self.write(BINUNICODE + pack("<i", n) + encoding)
        else:
            escaped = obj.replace("\\", "\\u005c")
            escaped = escaped.replace("\n", "\\u000a")
            self.write(UNICODE + escaped.encode('raw-unicode-escape') + '\n')
        # Memoize the caller's object, not the escaped copy, so that later
        # references to the same object become GETs.
        self.memoize(obj)
    dispatch[UnicodeType] = save_unicode

    if StringType is UnicodeType:
        # This is true for Jython
        def save_string(self, obj, pack=struct.pack):
            """Save a Jython string as STRING or UNICODE data, then memoize it."""
            is_unicode = obj.isunicode()

            if self.bin:
                data = obj
                if is_unicode:
                    data = obj.encode("utf-8")
                l = len(data)
                if l < 256 and not is_unicode:
                    self.write(SHORT_BINSTRING + chr(l) + data)
                else:
                    s = pack("<i", l)
                    if is_unicode:
                        self.write(BINUNICODE + s + data)
                    else:
                        self.write(BINSTRING + s + data)
            else:
                if is_unicode:
                    data = obj.replace("\\", "\\u005c")
                    data = data.replace("\n", "\\u000a")
                    data = data.encode('raw-unicode-escape')
                    self.write(UNICODE + data + '\n')
                else:
                    self.write(STRING + repr(obj) + '\n')
            # Memoize the original object, not an encoded/escaped copy.
            self.memoize(obj)
        dispatch[StringType] = save_string

    def save_tuple(self, obj):
        """Save a tuple, handling the case where it (indirectly) contains itself."""
        write = self.write
        proto = self.proto

        n = len(obj)
        if n == 0:
            if proto:
                write(EMPTY_TUPLE)
            else:
                write(MARK + TUPLE)
            return

        save = self.save
        memo = self.memo
        if n <= 3 and proto >= 2:
            for element in obj:
                save(element)
            # Subtle.  Same as in the big comment below.
            if id(obj) in memo:
                get = self.get(memo[id(obj)][0])
                write(POP * n + get)
            else:
                write(_tuplesize2code[n])
                self.memoize(obj)
            return

        # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple
        # has more than 3 elements.
        write(MARK)
        for element in obj:
            save(element)

        if id(obj) in memo:
            # Subtle.  obj was not in memo when we entered save_tuple(), so
            # the process of saving the tuple's elements must have saved
            # the tuple itself:  the tuple is recursive.  The proper action
            # now is to throw away everything we put on the stack, and
            # simply GET the tuple (it's already constructed).  This check
            # could have been done in the "for element" loop instead, but
            # recursive tuples are a rare thing.
            get = self.get(memo[id(obj)][0])
            if proto:
                write(POP_MARK + get)
            else:   # proto 0 -- POP_MARK not available
                write(POP * (n+1) + get)
            return

        # No recursion.
        self.write(TUPLE)
        self.memoize(obj)

    dispatch[TupleType] = save_tuple

    # save_empty_tuple() isn't used by anything in Python 2.3.  However, I
    # found a Pickler subclass in Zope3 that calls it, so it's not harmless
    # to remove it.
    def save_empty_tuple(self, obj):
        """Write EMPTY_TUPLE (kept for subclasses that call it)."""
        self.write(EMPTY_TUPLE)

    def save_list(self, obj):
        """Save a list.

        The empty list is memoized before its items are saved, so
        self-referencing lists resolve to a GET.
        """
        write = self.write

        if self.bin:
            write(EMPTY_LIST)
        else:   # proto 0 -- can't use EMPTY_LIST
            write(MARK + LIST)

        self.memoize(obj)
        self._batch_appends(iter(obj))

    dispatch[ListType] = save_list

    # Keep in synch with cPickle's BATCHSIZE.  Nothing will break if it gets
    # out of synch, though.
    _BATCHSIZE = 1000

    def _batch_appends(self, items):
        """Save list items, using APPENDS in batches of _BATCHSIZE when binary.

        items must be an iterator for binary protocols, because its .next()
        method is called directly.  Protocol 0 accepts any iterable and
        emits one APPEND per item.
        """
        save = self.save
        write = self.write

        if not self.bin:
            for x in items:
                save(x)
                write(APPEND)
            return

        r = xrange(self._BATCHSIZE)
        while items is not None:
            tmp = []
            for i in r:
                try:
                    x = items.next()
                    tmp.append(x)
                except StopIteration:
                    items = None
                    break
            n = len(tmp)
            if n > 1:
                write(MARK)
                for x in tmp:
                    save(x)
                write(APPENDS)
            elif n:
                save(tmp[0])
                write(APPEND)
            # else tmp is empty, and we're done

    def save_dict(self, obj):
        """Save a dict (or PyStringMap); memoized before its items are saved."""
        write = self.write

        if self.bin:
            write(EMPTY_DICT)
        else:   # proto 0 -- can't use EMPTY_DICT
            write(MARK + DICT)

        self.memoize(obj)
        self._batch_setitems(obj.iteritems())

    dispatch[DictionaryType] = save_dict
    if PyStringMap is not None:
        dispatch[PyStringMap] = save_dict

    def _batch_setitems(self, items):
        """Save (key, value) pairs as SETITEM/SETITEMS opcodes.

        Protocol 0 emits one SETITEM per pair from any iterable.  Binary
        protocols batch up to _BATCHSIZE pairs per SETITEMS and require
        items to be an iterator, because .next() is called directly.
        """
        save = self.save
        write = self.write

        if not self.bin:
            for k, v in items:
                save(k)
                save(v)
                write(SETITEM)
            return

        r = xrange(self._BATCHSIZE)
        while items is not None:
            tmp = []
            for i in r:
                try:
                    tmp.append(items.next())
                except StopIteration:
                    items = None
                    break
            n = len(tmp)
            if n > 1:
                write(MARK)
                for k, v in tmp:
                    save(k)
                    save(v)
                write(SETITEMS)
            elif n:
                k, v = tmp[0]
                save(k)
                save(v)
                write(SETITEM)
            # else tmp is empty, and we're done

    def save_inst(self, obj):
        """Save a classic (InstanceType) instance.

        Uses __getinitargs__() for constructor arguments when defined.  The
        state comes from __getstate__() when defined, otherwise __dict__,
        and is applied with BUILD.
        """
        cls = obj.__class__

        memo  = self.memo
        write = self.write
        save  = self.save

        if hasattr(obj, '__getinitargs__'):
            args = obj.__getinitargs__()
            len(args) # XXX Assert it's a sequence
            _keep_alive(args, memo)
        else:
            args = ()

        write(MARK)

        if self.bin:
            save(cls)
            for arg in args:
                save(arg)
            write(OBJ)
        else:
            for arg in args:
                save(arg)
            write(INST + cls.__module__ + '\n' + cls.__name__ + '\n')

        self.memoize(obj)

        try:
            getstate = obj.__getstate__
        except AttributeError:
            stuff = obj.__dict__
        else:
            stuff = getstate()
            _keep_alive(stuff, memo)
        save(stuff)
        write(BUILD)

    dispatch[InstanceType] = save_inst

    def save_global(self, obj, name=None, pack=struct.pack):
        """Save obj as a reference to module.name.

        name defaults to obj.__name__.  The module comes from obj.__module__
        or, failing that, from whichmodule().  The reference must import
        back to the very same object, otherwise PicklingError is raised.
        On protocol 2+, names registered in the copy_reg extension registry
        are written as EXT1/EXT2/EXT4 codes and are not memoized.
        """
        write = self.write
        memo = self.memo

        if name is None:
            name = obj.__name__

        module = getattr(obj, "__module__", None)
        if module is None:
            module = whichmodule(obj, name)

        try:
            __import__(module)
            mod = sys.modules[module]
            klass = getattr(mod, name)
        except (ImportError, KeyError, AttributeError):
            raise PicklingError(
                "Can't pickle %r: it's not found as %s.%s" %
                (obj, module, name))
        else:
            if klass is not obj:
                raise PicklingError(
                    "Can't pickle %r: it's not the same object as %s.%s" %
                    (obj, module, name))

        if self.proto >= 2:
            code = _extension_registry.get((module, name))
            if code:
                assert code > 0
                if code <= 0xff:
                    write(EXT1 + chr(code))
                elif code <= 0xffff:
                    write("%c%c%c" % (EXT2, code&0xff, code>>8))
                else:
                    write(EXT4 + pack("<i", code))
                return

        write(GLOBAL + module + '\n' + name + '\n')
        self.memoize(obj)

    dispatch[ClassType] = save_global
    dispatch[FunctionType] = save_global
    dispatch[BuiltinFunctionType] = save_global
    dispatch[TypeType] = save_global
780
781# Pickling helpers
782
783def _keep_alive(x, memo):
784    """Keeps a reference to the object x in the memo.
785
786    Because we remember objects by their id, we have
787    to assure that possibly temporary objects are kept
788    alive by referencing them.
789    We store a reference at the id of the memo, which should
790    normally not be used unless someone tries to deepcopy
791    the memo itself...
792    """
793    try:
794        memo[id(memo)].append(x)
795    except KeyError:
796        # aha, this is the first one :-)
797        memo[id(memo)]=[x]
798
799
# Cache used by whichmodule(): maps an object to the name of the module in
# which it was found.  It is called classmap for backwards compatibility.
classmap = {}


def whichmodule(func, funcname):
    """Return the name of the module in which func lives as funcname.

    func.__module__ wins when it is set.  Otherwise the answer comes from
    the classmap cache.  Failing that, every loaded module except
    __main__ is searched for an attribute funcname that is func itself.
    The result of the search, or "__main__" if nothing matched, is cached
    in classmap.
    """
    # Python functions always get __module__ from their globals.
    owner = getattr(func, "__module__", None)
    if owner is not None:
        return owner
    if func in classmap:
        return classmap[func]

    owner = next((modname for modname, module in sys.modules.items()
                  if module is not None          # dummy package entries
                  and modname != '__main__'
                  and getattr(module, funcname, None) is func),
                 '__main__')
    classmap[func] = owner
    return owner
829
830
831# Unpickling machinery
832
833class Unpickler:
834
835    def __init__(self, file):
836        """This takes a file-like object for reading a pickle data stream.
837
838        The protocol version of the pickle is detected automatically, so no
839        proto argument is needed.
840
841        The file-like object must have two methods, a read() method that
842        takes an integer argument, and a readline() method that requires no
843        arguments.  Both methods should return a string.  Thus file-like
844        object can be a file object opened for reading, a StringIO object,
845        or any other custom object that meets this interface.
846        """
847        self.readline = file.readline
848        self.read = file.read
849        self.memo = {}
850
851    def load(self):
852        """Read a pickled object representation from the open file.
853
854        Return the reconstituted object hierarchy specified in the file.
855        """
856        self.mark = object() # any new unique object
857        self.stack = []
858        self.append = self.stack.append
859        read = self.read
860        dispatch = self.dispatch
861        try:
862            while 1:
863                key = read(1)
864                dispatch[key](self)
865        except _Stop, stopinst:
866            return stopinst.value
867
868    # Return largest index k such that self.stack[k] is self.mark.
869    # If the stack doesn't contain a mark, eventually raises IndexError.
870    # This could be sped by maintaining another stack, of indices at which
871    # the mark appears.  For that matter, the latter stack would suffice,
872    # and we wouldn't need to push mark objects on self.stack at all.
873    # Doing so is probably a good thing, though, since if the pickle is
874    # corrupt (or hostile) we may get a clue from finding self.mark embedded
875    # in unpickled objects.
876    def marker(self):
877        stack = self.stack
878        mark = self.mark
879        k = len(stack)-1
880        while stack[k] is not mark: k = k-1
881        return k
882
883    dispatch = {}
884
885    def load_eof(self):
886        raise EOFError
887    dispatch[''] = load_eof
888
889    def load_proto(self):
890        proto = ord(self.read(1))
891        if not 0 <= proto <= 2:
892            raise ValueError, "unsupported pickle protocol: %d" % proto
893    dispatch[PROTO] = load_proto
894
895    def load_persid(self):
896        pid = self.readline()[:-1]
897        self.append(self.persistent_load(pid))
898    dispatch[PERSID] = load_persid
899
900    def load_binpersid(self):
901        pid = self.stack.pop()
902        self.append(self.persistent_load(pid))
903    dispatch[BINPERSID] = load_binpersid
904
905    def load_none(self):
906        self.append(None)
907    dispatch[NONE] = load_none
908
909    def load_false(self):
910        self.append(False)
911    dispatch[NEWFALSE] = load_false
912
913    def load_true(self):
914        self.append(True)
915    dispatch[NEWTRUE] = load_true
916
917    def load_int(self):
918        data = self.readline()
919        if data == FALSE[1:]:
920            val = False
921        elif data == TRUE[1:]:
922            val = True
923        else:
924            try:
925                val = int(data)
926            except ValueError:
927                val = long(data)
928        self.append(val)
929    dispatch[INT] = load_int
930
931    def load_binint(self):
932        self.append(mloads('i' + self.read(4)))
933    dispatch[BININT] = load_binint
934
935    def load_binint1(self):
936        self.append(ord(self.read(1)))
937    dispatch[BININT1] = load_binint1
938
939    def load_binint2(self):
940        self.append(mloads('i' + self.read(2) + '\000\000'))
941    dispatch[BININT2] = load_binint2
942
943    def load_long(self):
944        self.append(long(self.readline()[:-1], 0))
945    dispatch[LONG] = load_long
946
947    def load_long1(self):
948        n = ord(self.read(1))
949        bytes = self.read(n)
950        self.append(decode_long(bytes))
951    dispatch[LONG1] = load_long1
952
953    def load_long4(self):
954        n = mloads('i' + self.read(4))
955        bytes = self.read(n)
956        self.append(decode_long(bytes))
957    dispatch[LONG4] = load_long4
958
959    def load_float(self):
960        self.append(float(self.readline()[:-1]))
961    dispatch[FLOAT] = load_float
962
963    def load_binfloat(self, unpack=struct.unpack):
964        self.append(unpack('>d', self.read(8))[0])
965    dispatch[BINFLOAT] = load_binfloat
966
967    def load_string(self):
968        rep = self.readline()[:-1]
969        for q in "\"'": # double or single quote
970            if rep.startswith(q):
971                if len(rep) < 2 or not rep.endswith(q):
972                    raise ValueError, "insecure string pickle"
973                rep = rep[len(q):-len(q)]
974                break
975        else:
976            raise ValueError, "insecure string pickle"
977        self.append(rep.decode("string-escape"))
978    dispatch[STRING] = load_string
979
980    def load_binstring(self):
981        len = mloads('i' + self.read(4))
982        self.append(self.read(len))
983    dispatch[BINSTRING] = load_binstring
984
985    def load_unicode(self):
986        self.append(unicode(self.readline()[:-1],'raw-unicode-escape'))
987    dispatch[UNICODE] = load_unicode
988
989    def load_binunicode(self):
990        len = mloads('i' + self.read(4))
991        self.append(unicode(self.read(len),'utf-8'))
992    dispatch[BINUNICODE] = load_binunicode
993
994    def load_short_binstring(self):
995        len = ord(self.read(1))
996        self.append(self.read(len))
997    dispatch[SHORT_BINSTRING] = load_short_binstring
998
999    def load_tuple(self):
1000        k = self.marker()
1001        self.stack[k:] = [tuple(self.stack[k+1:])]
1002    dispatch[TUPLE] = load_tuple
1003
1004    def load_empty_tuple(self):
1005        self.stack.append(())
1006    dispatch[EMPTY_TUPLE] = load_empty_tuple
1007
1008    def load_tuple1(self):
1009        self.stack[-1] = (self.stack[-1],)
1010    dispatch[TUPLE1] = load_tuple1
1011
1012    def load_tuple2(self):
1013        self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
1014    dispatch[TUPLE2] = load_tuple2
1015
1016    def load_tuple3(self):
1017        self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
1018    dispatch[TUPLE3] = load_tuple3
1019
1020    def load_empty_list(self):
1021        self.stack.append([])
1022    dispatch[EMPTY_LIST] = load_empty_list
1023
1024    def load_empty_dictionary(self):
1025        self.stack.append({})
1026    dispatch[EMPTY_DICT] = load_empty_dictionary
1027
1028    def load_list(self):
1029        k = self.marker()
1030        self.stack[k:] = [self.stack[k+1:]]
1031    dispatch[LIST] = load_list
1032
1033    def load_dict(self):
1034        k = self.marker()
1035        d = {}
1036        items = self.stack[k+1:]
1037        for i in range(0, len(items), 2):
1038            key = items[i]
1039            value = items[i+1]
1040            d[key] = value
1041        self.stack[k:] = [d]
1042    dispatch[DICT] = load_dict
1043
1044    # INST and OBJ differ only in how they get a class object.  It's not
1045    # only sensible to do the rest in a common routine, the two routines
1046    # previously diverged and grew different bugs.
1047    # klass is the class to instantiate, and k points to the topmost mark
1048    # object, following which are the arguments for klass.__init__.
1049    def _instantiate(self, klass, k):
1050        args = tuple(self.stack[k+1:])
1051        del self.stack[k:]
1052        instantiated = 0
1053        if (not args and
1054                type(klass) is ClassType and
1055                not hasattr(klass, "__getinitargs__")):
1056            try:
1057                value = _EmptyClass()
1058                value.__class__ = klass
1059                instantiated = 1
1060            except RuntimeError:
1061                # In restricted execution, assignment to inst.__class__ is
1062                # prohibited
1063                pass
1064        if not instantiated:
1065            try:
1066                value = klass(*args)
1067            except TypeError, err:
1068                raise TypeError, "in constructor for %s: %s" % (
1069                    klass.__name__, str(err)), sys.exc_info()[2]
1070        self.append(value)
1071
1072    def load_inst(self):
1073        module = self.readline()[:-1]
1074        name = self.readline()[:-1]
1075        klass = self.find_class(module, name)
1076        self._instantiate(klass, self.marker())
1077    dispatch[INST] = load_inst
1078
1079    def load_obj(self):
1080        # Stack is ... markobject classobject arg1 arg2 ...
1081        k = self.marker()
1082        klass = self.stack.pop(k+1)
1083        self._instantiate(klass, k)
1084    dispatch[OBJ] = load_obj
1085
1086    def load_newobj(self):
1087        args = self.stack.pop()
1088        cls = self.stack[-1]
1089        obj = cls.__new__(cls, *args)
1090        self.stack[-1] = obj
1091    dispatch[NEWOBJ] = load_newobj
1092
1093    def load_global(self):
1094        module = self.readline()[:-1]
1095        name = self.readline()[:-1]
1096        klass = self.find_class(module, name)
1097        self.append(klass)
1098    dispatch[GLOBAL] = load_global
1099
1100    def load_ext1(self):
1101        code = ord(self.read(1))
1102        self.get_extension(code)
1103    dispatch[EXT1] = load_ext1
1104
1105    def load_ext2(self):
1106        code = mloads('i' + self.read(2) + '\000\000')
1107        self.get_extension(code)
1108    dispatch[EXT2] = load_ext2
1109
1110    def load_ext4(self):
1111        code = mloads('i' + self.read(4))
1112        self.get_extension(code)
1113    dispatch[EXT4] = load_ext4
1114
1115    def get_extension(self, code):
1116        nil = []
1117        obj = _extension_cache.get(code, nil)
1118        if obj is not nil:
1119            self.append(obj)
1120            return
1121        key = _inverted_registry.get(code)
1122        if not key:
1123            raise ValueError("unregistered extension code %d" % code)
1124        obj = self.find_class(*key)
1125        _extension_cache[code] = obj
1126        self.append(obj)
1127
1128    def find_class(self, module, name):
1129        # Subclasses may override this
1130        __import__(module)
1131        mod = sys.modules[module]
1132        klass = getattr(mod, name)
1133        return klass
1134
1135    def load_reduce(self):
1136        stack = self.stack
1137        args = stack.pop()
1138        func = stack[-1]
1139        value = func(*args)
1140        stack[-1] = value
1141    dispatch[REDUCE] = load_reduce
1142
1143    def load_pop(self):
1144        del self.stack[-1]
1145    dispatch[POP] = load_pop
1146
1147    def load_pop_mark(self):
1148        k = self.marker()
1149        del self.stack[k:]
1150    dispatch[POP_MARK] = load_pop_mark
1151
1152    def load_dup(self):
1153        self.append(self.stack[-1])
1154    dispatch[DUP] = load_dup
1155
1156    def load_get(self):
1157        self.append(self.memo[self.readline()[:-1]])
1158    dispatch[GET] = load_get
1159
1160    def load_binget(self):
1161        i = ord(self.read(1))
1162        self.append(self.memo[repr(i)])
1163    dispatch[BINGET] = load_binget
1164
1165    def load_long_binget(self):
1166        i = mloads('i' + self.read(4))
1167        self.append(self.memo[repr(i)])
1168    dispatch[LONG_BINGET] = load_long_binget
1169
1170    def load_put(self):
1171        self.memo[self.readline()[:-1]] = self.stack[-1]
1172    dispatch[PUT] = load_put
1173
1174    def load_binput(self):
1175        i = ord(self.read(1))
1176        self.memo[repr(i)] = self.stack[-1]
1177    dispatch[BINPUT] = load_binput
1178
1179    def load_long_binput(self):
1180        i = mloads('i' + self.read(4))
1181        self.memo[repr(i)] = self.stack[-1]
1182    dispatch[LONG_BINPUT] = load_long_binput
1183
1184    def load_append(self):
1185        stack = self.stack
1186        value = stack.pop()
1187        list = stack[-1]
1188        list.append(value)
1189    dispatch[APPEND] = load_append
1190
1191    def load_appends(self):
1192        stack = self.stack
1193        mark = self.marker()
1194        list = stack[mark - 1]
1195        list.extend(stack[mark + 1:])
1196        del stack[mark:]
1197    dispatch[APPENDS] = load_appends
1198
1199    def load_setitem(self):
1200        stack = self.stack
1201        value = stack.pop()
1202        key = stack.pop()
1203        dict = stack[-1]
1204        dict[key] = value
1205    dispatch[SETITEM] = load_setitem
1206
1207    def load_setitems(self):
1208        stack = self.stack
1209        mark = self.marker()
1210        dict = stack[mark - 1]
1211        for i in range(mark + 1, len(stack), 2):
1212            dict[stack[i]] = stack[i + 1]
1213
1214        del stack[mark:]
1215    dispatch[SETITEMS] = load_setitems
1216
1217    def load_build(self):
1218        stack = self.stack
1219        state = stack.pop()
1220        inst = stack[-1]
1221        setstate = getattr(inst, "__setstate__", None)
1222        if setstate:
1223            setstate(state)
1224            return
1225        slotstate = None
1226        if isinstance(state, tuple) and len(state) == 2:
1227            state, slotstate = state
1228        if state:
1229            try:
1230                d = inst.__dict__
1231                try:
1232                    for k, v in state.iteritems():
1233                        d[intern(k)] = v
1234                # keys in state don't have to be strings
1235                # don't blow up, but don't go out of our way
1236                except TypeError:
1237                    d.update(state)
1238
1239            except RuntimeError:
1240                # XXX In restricted execution, the instance's __dict__
1241                # is not accessible.  Use the old way of unpickling
1242                # the instance variables.  This is a semantic
1243                # difference when unpickling in restricted
1244                # vs. unrestricted modes.
1245                # Note, however, that cPickle has never tried to do the
1246                # .update() business, and always uses
1247                #     PyObject_SetItem(inst.__dict__, key, value) in a
1248                # loop over state.items().
1249                for k, v in state.items():
1250                    setattr(inst, k, v)
1251        if slotstate:
1252            for k, v in slotstate.items():
1253                setattr(inst, k, v)
1254    dispatch[BUILD] = load_build
1255
1256    def load_mark(self):
1257        self.append(self.mark)
1258    dispatch[MARK] = load_mark
1259
1260    def load_stop(self):
1261        value = self.stack.pop()
1262        raise _Stop(value)
1263    dispatch[STOP] = load_stop
1264
# Helper for Unpickler._instantiate (used by load_inst/load_obj): an instance
# of this class is created and then has its __class__ reassigned.
class _EmptyClass:
    """Bare classic class used as a placeholder instance while unpickling."""
1269
1270# Encode/decode longs in linear time.
1271
1272import binascii as _binascii
1273
1274def encode_long(x):
1275    r"""Encode a long to a two's complement little-endian binary string.
1276    Note that 0L is a special case, returning an empty string, to save a
1277    byte in the LONG1 pickling context.
1278
1279    >>> encode_long(0L)
1280    ''
1281    >>> encode_long(255L)
1282    '\xff\x00'
1283    >>> encode_long(32767L)
1284    '\xff\x7f'
1285    >>> encode_long(-256L)
1286    '\x00\xff'
1287    >>> encode_long(-32768L)
1288    '\x00\x80'
1289    >>> encode_long(-128L)
1290    '\x80'
1291    >>> encode_long(127L)
1292    '\x7f'
1293    >>>
1294    """
1295
1296    if x == 0:
1297        return ''
1298    if x > 0:
1299        ashex = hex(x)
1300        assert ashex.startswith("0x")
1301        njunkchars = 2 + ashex.endswith('L')
1302        nibbles = len(ashex) - njunkchars
1303        if nibbles & 1:
1304            # need an even # of nibbles for unhexlify
1305            ashex = "0x0" + ashex[2:]
1306        elif int(ashex[2], 16) >= 8:
1307            # "looks negative", so need a byte of sign bits
1308            ashex = "0x00" + ashex[2:]
1309    else:
1310        # Build the 256's-complement:  (1L << nbytes) + x.  The trick is
1311        # to find the number of bytes in linear time (although that should
1312        # really be a constant-time task).
1313        ashex = hex(-x)
1314        assert ashex.startswith("0x")
1315        njunkchars = 2 + ashex.endswith('L')
1316        nibbles = len(ashex) - njunkchars
1317        if nibbles & 1:
1318            # Extend to a full byte.
1319            nibbles += 1
1320        nbits = nibbles * 4
1321        x += 1L << nbits
1322        assert x > 0
1323        ashex = hex(x)
1324        njunkchars = 2 + ashex.endswith('L')
1325        newnibbles = len(ashex) - njunkchars
1326        if newnibbles < nibbles:
1327            ashex = "0x" + "0" * (nibbles - newnibbles) + ashex[2:]
1328        if int(ashex[2], 16) < 8:
1329            # "looks positive", so need a byte of sign bits
1330            ashex = "0xff" + ashex[2:]
1331
1332    if ashex.endswith('L'):
1333        ashex = ashex[2:-1]
1334    else:
1335        ashex = ashex[2:]
1336    assert len(ashex) & 1 == 0, (x, ashex)
1337    binary = _binascii.unhexlify(ashex)
1338    return binary[::-1]
1339
1340def decode_long(data):
1341    r"""Decode a long from a two's complement little-endian binary string.
1342
1343    >>> decode_long('')
1344    0L
1345    >>> decode_long("\xff\x00")
1346    255L
1347    >>> decode_long("\xff\x7f")
1348    32767L
1349    >>> decode_long("\x00\xff")
1350    -256L
1351    >>> decode_long("\x00\x80")
1352    -32768L
1353    >>> decode_long("\x80")
1354    -128L
1355    >>> decode_long("\x7f")
1356    127L
1357    """
1358
1359    nbytes = len(data)
1360    if nbytes == 0:
1361        return 0L
1362    ashex = _binascii.hexlify(data[::-1])
1363    n = long(ashex, 16) # quadratic time before Python 2.3; linear now
1364    if data[-1] >= '\x80':
1365        n -= 1L << (nbytes * 8)
1366    return n
1367
1368# Shorthands
1369
1370try:
1371    from cStringIO import StringIO
1372except ImportError:
1373    from StringIO import StringIO
1374
def dump(obj, file, protocol=None):
    """Write a pickled representation of obj to the open file object file.

    protocol is forwarded unchanged to the Pickler constructor.
    """
    pickler = Pickler(file, protocol)
    pickler.dump(obj)
1377
def dumps(obj, protocol=None):
    """Return the pickled representation of obj as a string.

    protocol is forwarded unchanged to the Pickler constructor.
    """
    buf = StringIO()
    pickler = Pickler(buf, protocol)
    pickler.dump(obj)
    return buf.getvalue()
1382
def load(file):
    """Read one pickled object from the open file object file and return it.

    file must provide read() and readline() as described by Unpickler.
    Never load data from an untrusted source: unpickling can import modules
    and call arbitrary objects named by the stream.
    """
    unpickler = Unpickler(file)
    return unpickler.load()
1385
def loads(str):
    """Return the object hierarchy reconstructed from the pickle string str.

    Never load data from an untrusted source: unpickling can import modules
    and call arbitrary objects named by the stream.
    """
    unpickler = Unpickler(StringIO(str))
    return unpickler.load()
1389
# Doctest

def _test():
    """Run doctest.testmod() on the __main__ module and return its results.

    When this file is executed as a script, __main__ is this module, so the
    examples in its docstrings (e.g. encode_long/decode_long) are checked.
    """
    import doctest
    outcome = doctest.testmod()
    return outcome

if __name__ == "__main__":
    _test()
1398