• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# Module which deals with pickling of objects.
3#
4# multiprocessing/reduction.py
5#
6# Copyright (c) 2006-2008, R Oudkerk
7# Licensed to PSF under a Contributor Agreement.
8#
9
10from abc import ABCMeta
11import copyreg
12import functools
13import io
14import os
15import pickle
16import socket
17import sys
18
19from . import context
20
# Public API of this module; the platform-specific section further down
# appends its own names to this list with ``__all__ += [...]``.
__all__ = [
    'send_handle',
    'recv_handle',
    'ForkingPickler',
    'register',
    'dump',
]
22
23
# True when handles / file descriptors can be passed between processes:
# always on Windows; elsewhere only when the socket module exposes the
# SCM_RIGHTS ancillary-data machinery (CMSG_LEN, SCM_RIGHTS, sendmsg).
HAVE_SEND_HANDLE = sys.platform == 'win32' or all((
    hasattr(socket, 'CMSG_LEN'),
    hasattr(socket, 'SCM_RIGHTS'),
    hasattr(socket.socket, 'sendmsg'),
))
28
29#
30# Pickler subclass
31#
32
class ForkingPickler(pickle.Pickler):
    '''Pickler subclass used by multiprocessing.

    Every instance gets its own ``dispatch_table``: a copy of
    ``copyreg.dispatch_table`` overlaid with the reducers added through
    :meth:`register`.  Extra reducers therefore only affect pickling done
    through this class and never modify the global copyreg table.
    '''
    # Reducers added via register(); a single dict stored on the class, so
    # subclasses that do not define their own share (and mutate) this one.
    _extra_reducers = {}
    _copyreg_dispatch_table = copyreg.dispatch_table

    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as well (protocol=, fix_imports=,
        # buffer_callback=, ...) so the constructor accepts everything
        # pickle.Pickler does; previously any keyword raised TypeError.
        super().__init__(*args, **kwargs)
        # Copy first so updating with the extra reducers never touches
        # copyreg.dispatch_table itself.
        self.dispatch_table = self._copyreg_dispatch_table.copy()
        self.dispatch_table.update(self._extra_reducers)

    @classmethod
    def register(cls, type, reduce):
        '''Register a reduce function for a type.

        *reduce* is called with an instance of *type* and must return a
        reduce tuple (callable, args, ...).  It takes effect for every
        ForkingPickler created afterwards.
        '''
        cls._extra_reducers[type] = reduce

    @classmethod
    def dumps(cls, obj, protocol=None):
        '''Pickle *obj* with this class and return the data.

        The result is the BytesIO buffer's memoryview (getbuffer()), not
        a bytes object.
        '''
        buf = io.BytesIO()
        cls(buf, protocol).dump(obj)
        return buf.getbuffer()

    # Unpickling needs no custom machinery: reducers only affect dumping.
    loads = pickle.loads
55
# Module-level shortcut: register(type, reduce) installs a reducer used by
# every ForkingPickler created afterwards.
register = ForkingPickler.register


def dump(obj, file, protocol=None):
    '''Write a pickle of *obj* to the binary *file* using ForkingPickler.

    Drop-in replacement for pickle.dump() that also applies the reducers
    registered on ForkingPickler.
    '''
    pickler = ForkingPickler(file, protocol)
    pickler.dump(obj)
61
62#
63# Platform specific definitions
64#
65
if sys.platform == 'win32':
    # Windows
    __all__ += ['DupHandle', 'duplicate', 'steal_handle']
    import _winapi

    def duplicate(handle, target_process=None, inheritable=False,
                  *, source_process=None):
        '''Duplicate a handle.  (target_process is a handle not a pid!)

        source_process and target_process both default to the current
        process.  The copy keeps the original access rights
        (DUPLICATE_SAME_ACCESS) and is inheritable only if *inheritable*
        is true.  Returns the result of _winapi.DuplicateHandle, i.e. the
        new handle value.
        '''
        current_process = _winapi.GetCurrentProcess()
        if source_process is None:
            source_process = current_process
        if target_process is None:
            target_process = current_process
        return _winapi.DuplicateHandle(
            source_process, handle, target_process,
            0, inheritable, _winapi.DUPLICATE_SAME_ACCESS)

    def steal_handle(source_pid, handle):
        '''Steal a handle from process identified by source_pid.

        *handle* is a handle value valid in the source process.  It is
        duplicated into the current process with DUPLICATE_CLOSE_SOURCE,
        which (per the Win32 DuplicateHandle docs) closes the source
        process's copy.  The temporary process handle is always closed.
        '''
        source_process_handle = _winapi.OpenProcess(
            _winapi.PROCESS_DUP_HANDLE, False, source_pid)
        try:
            return _winapi.DuplicateHandle(
                source_process_handle, handle,
                _winapi.GetCurrentProcess(), 0, False,
                _winapi.DUPLICATE_SAME_ACCESS | _winapi.DUPLICATE_CLOSE_SOURCE)
        finally:
            _winapi.CloseHandle(source_process_handle)

    def send_handle(conn, handle, destination_pid):
        '''Send a handle over a local connection.

        DupHandle duplicates *handle* straight into *destination_pid*, so
        the receiver's detach() finds the handle already in its process.
        The DupHandle object itself is sent with conn.send().
        '''
        dh = DupHandle(handle, _winapi.DUPLICATE_SAME_ACCESS, destination_pid)
        conn.send(dh)

    def recv_handle(conn):
        '''Receive a handle over a local connection.

        Expects the peer to have sent a DupHandle (see send_handle); its
        detach() returns a handle usable in the current process.
        '''
        return conn.recv().detach()

    class DupHandle(object):
        '''Picklable wrapper for a handle.

        On construction the handle is duplicated into process *pid*
        (the current process when pid is None).  The receiving side calls
        detach() once to obtain a handle valid in its own process.
        '''
        def __init__(self, handle, access, pid=None):
            if pid is None:
                # We just duplicate the handle in the current process and
                # let the receiving process steal the handle.
                pid = os.getpid()
            proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False, pid)
            try:
                self._handle = _winapi.DuplicateHandle(
                    _winapi.GetCurrentProcess(),
                    handle, proc, access, False, 0)
            finally:
                _winapi.CloseHandle(proc)
            # Remember which process owns self._handle so detach() can tell
            # whether it has to steal the handle from that process.
            self._access = access
            self._pid = pid

        def detach(self):
            '''Get the handle.  This should only be called once.

            A second call from another process would try to duplicate
            (and close) a handle that the first call already closed.
            '''
            # retrieve handle from process which currently owns it
            if self._pid == os.getpid():
                # The handle has already been duplicated for this process.
                return self._handle
            # We must steal the handle from the process whose pid is self._pid.
            proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False,
                                       self._pid)
            try:
                return _winapi.DuplicateHandle(
                    proc, self._handle, _winapi.GetCurrentProcess(),
                    self._access, False, _winapi.DUPLICATE_CLOSE_SOURCE)
            finally:
                _winapi.CloseHandle(proc)

else:
    # Unix
    __all__ += ['DupFd', 'sendfds', 'recvfds']
    import array

    # On MacOSX we should acknowledge receipt of fds -- see Issue14669
    ACKNOWLEDGE = sys.platform == 'darwin'

    def sendfds(sock, fds):
        '''Send an array of fds over an AF_UNIX socket.

        The fds travel as SCM_RIGHTS ancillary data next to a one-byte
        payload holding len(fds) % 256, which recvfds() checks.  When
        ACKNOWLEDGE is set, waits for one byte from the peer and raises
        RuntimeError unless it is b'A'.
        '''
        fds = array.array('i', fds)
        msg = bytes([len(fds) % 256])
        sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)])
        if ACKNOWLEDGE and sock.recv(1) != b'A':
            raise RuntimeError('did not receive acknowledgement of fd')

    def recvfds(sock, size):
        '''Receive an array of fds over an AF_UNIX socket.

        *size* is the number of fds expected; it sizes the ancillary-data
        buffer passed to recvmsg().  Returns the fds as a list of ints.

        Raises EOFError if nothing at all was received, AssertionError if
        the fd count disagrees with the count byte written by sendfds(),
        and RuntimeError for any other malformed message.
        '''
        a = array.array('i')
        bytes_size = a.itemsize * size
        msg, ancdata, flags, addr = sock.recvmsg(1, socket.CMSG_SPACE(bytes_size))
        if not msg and not ancdata:
            raise EOFError
        try:
            # The ack is sent before any validation, so a sendfds() caller
            # waiting for it is released even if the data is rejected below.
            if ACKNOWLEDGE:
                sock.send(b'A')
            if len(ancdata) != 1:
                raise RuntimeError('received %d items of ancdata' %
                                   len(ancdata))
            cmsg_level, cmsg_type, cmsg_data = ancdata[0]
            if (cmsg_level == socket.SOL_SOCKET and
                cmsg_type == socket.SCM_RIGHTS):
                # The payload must consist of whole C ints.
                if len(cmsg_data) % a.itemsize != 0:
                    raise ValueError
                a.frombytes(cmsg_data)
                # sendfds() sends len(fds) % 256 as the single payload byte.
                if len(a) % 256 != msg[0]:
                    raise AssertionError(
                        "Len is {0:n} but msg[0] is {1!r}".format(
                            len(a), msg[0]))
                return list(a)
        except (ValueError, IndexError):
            # IndexError: msg was empty, so there is no count byte to check.
            pass
        raise RuntimeError('Invalid data received')

    def send_handle(conn, handle, destination_pid):
        '''Send a handle over a local connection.

        On Unix *handle* is a file descriptor and *destination_pid* is not
        used.  The fd is passed with sendfds() over a socket object made
        from conn's descriptor.  socket.fromfd() duplicates that
        descriptor, so closing `s` leaves conn open.
        '''
        with socket.fromfd(conn.fileno(), socket.AF_UNIX, socket.SOCK_STREAM) as s:
            sendfds(s, [handle])

    def recv_handle(conn):
        '''Receive a handle over a local connection.

        Returns the single file descriptor sent by send_handle().
        '''
        with socket.fromfd(conn.fileno(), socket.AF_UNIX, socket.SOCK_STREAM) as s:
            return recvfds(s, 1)[0]

    def DupFd(fd):
        '''Return a wrapper for an fd.

        If context.get_spawning_popen() returns a Popen object (presumably
        while a child process is being spawned), the fd is duplicated for
        that child through it.  Otherwise, when fd passing is available
        (HAVE_SEND_HANDLE), resource_sharer.DupFd is used.  Raises
        ValueError if neither option exists.  Receivers obtain the fd by
        calling the wrapper's detach().
        '''
        popen_obj = context.get_spawning_popen()
        if popen_obj is not None:
            return popen_obj.DupFd(popen_obj.duplicate_for_child(fd))
        elif HAVE_SEND_HANDLE:
            from . import resource_sharer
            return resource_sharer.DupFd(fd)
        else:
            raise ValueError('SCM_RIGHTS appears not to be available')
195            return popen_obj.DupFd(popen_obj.duplicate_for_child(fd))
196        elif HAVE_SEND_HANDLE:
197            from . import resource_sharer
198            return resource_sharer.DupFd(fd)
199        else:
200            raise ValueError('SCM_RIGHTS appears not to be available')
201
202#
203# Try making some callable types picklable
204#
205
def _reduce_method(m):
    '''Reduce a bound method object to ``getattr(owner, name)``.

    In Python 3 a bound method's ``__self__`` is never None (unbound
    methods are plain functions, and types.MethodType rejects None), so
    the former ``__self__ is None`` branch was unreachable.  It was also
    wrong: it returned ``getattr(<method type>, name)``, which could not
    be unpickled.  The method is rebuilt by looking its function name up on
    the object it is bound to.  For classmethods, that object is the class.
    '''
    return getattr, (m.__self__, m.__func__.__name__)
class _C:
    '''Throwaway class: only the type of its bound method is used.'''

    def f(self):
        '''Placeholder method; this module never calls it.'''


# Bound methods (e.g. obj.method) are pickled via _reduce_method.
register(type(_C().f), _reduce_method)
215
216
def _reduce_method_descriptor(m):
    '''Reduce a method descriptor to ``getattr(defining_class, name)``.'''
    owner = m.__objclass__
    attr_name = m.__name__
    return getattr, (owner, attr_name)
# The types of list.append and int.__add__ (descriptor types for methods
# of built-in classes) are pickled via _reduce_method_descriptor.
for _descriptor_type in (type(list.append), type(int.__add__)):
    register(_descriptor_type, _reduce_method_descriptor)
del _descriptor_type
221
222
def _reduce_partial(p):
    '''Reduce a functools.partial to (func, positional args, keywords).'''
    keywords = p.keywords or {}
    return _rebuild_partial, (p.func, p.args, keywords)


def _rebuild_partial(func, args, keywords):
    '''Recreate the partial described by _reduce_partial's output.'''
    rebuilt = functools.partial(func, *args, **keywords)
    return rebuilt
# Pickle functools.partial objects via _reduce_partial / _rebuild_partial.
register(functools.partial, _reduce_partial)
228
229#
230# Make sockets picklable
231#
232
if sys.platform == 'win32':
    def _reduce_socket(s):
        '''Reduce a socket to a DupSocket that the receiver can detach.'''
        from .resource_sharer import DupSocket
        shared = DupSocket(s)
        return _rebuild_socket, (shared,)

    def _rebuild_socket(ds):
        '''Return the socket carried by a DupSocket.'''
        return ds.detach()

else:
    def _reduce_socket(s):
        '''Reduce a socket to a DupFd of its descriptor plus family/type/proto.'''
        dup = DupFd(s.fileno())
        return _rebuild_socket, (dup, s.family, s.type, s.proto)

    def _rebuild_socket(df, family, type, proto):
        '''Recreate a socket around the fd obtained from df.detach().'''
        return socket.socket(family, type, proto, fileno=df.detach())

# On either platform, sockets are pickled through that platform's
# _reduce_socket.
register(socket.socket, _reduce_socket)
249
250
class AbstractReducer(metaclass=ABCMeta):
    '''Abstract base class for use in implementing a Reduction class
    suitable for use in replacing the standard reduction mechanism
    used in multiprocessing.

    This module's pickling helpers and reducers are re-exported as class
    attributes.  Instantiating the class re-registers the default reducers
    through the module-level register().
    '''
    ForkingPickler = ForkingPickler
    register = register
    dump = dump
    send_handle = send_handle
    recv_handle = recv_handle

    if sys.platform == 'win32':
        steal_handle = steal_handle
        duplicate = duplicate
        DupHandle = DupHandle
    else:
        sendfds = sendfds
        recvfds = recvfds
        DupFd = DupFd

    _reduce_method = _reduce_method
    _reduce_method_descriptor = _reduce_method_descriptor
    _rebuild_partial = _rebuild_partial
    _reduce_socket = _reduce_socket
    _rebuild_socket = _rebuild_socket

    def __init__(self, *args):
        # (type, reducer) pairs matching the registrations this module
        # performs at import time; names resolve to module globals here.
        defaults = (
            (type(_C().f), _reduce_method),
            (type(list.append), _reduce_method_descriptor),
            (type(int.__add__), _reduce_method_descriptor),
            (functools.partial, _reduce_partial),
            (socket.socket, _reduce_socket),
        )
        for obj_type, reducer in defaults:
            register(obj_type, reducer)
282