• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# A higher level module for using sockets (or Windows named pipes)
3#
4# multiprocessing/connection.py
5#
6# Copyright (c) 2006-2008, R Oudkerk
7# All rights reserved.
8#
9# Redistribution and use in source and binary forms, with or without
10# modification, are permitted provided that the following conditions
11# are met:
12#
13# 1. Redistributions of source code must retain the above copyright
14#    notice, this list of conditions and the following disclaimer.
15# 2. Redistributions in binary form must reproduce the above copyright
16#    notice, this list of conditions and the following disclaimer in the
17#    documentation and/or other materials provided with the distribution.
18# 3. Neither the name of author nor the names of any contributors may be
19#    used to endorse or promote products derived from this software
20#    without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
23# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
25# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
26# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
28# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
29# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
31# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32# SUCH DAMAGE.
33#
34
# Public names exported by ``from multiprocessing.connection import *``.
__all__ = ['Client', 'Listener', 'Pipe']
36
37import os
38import sys
39import socket
40import errno
41import time
42import tempfile
43import itertools
44
45import _multiprocessing
46from multiprocessing import current_process, AuthenticationError
47from multiprocessing.util import get_temp_dir, Finalize, sub_debug, debug
48from multiprocessing.forking import duplicate, close
49
50
51#
52#
53#
54
# Buffer size handed to the Windows named-pipe constructors in this module.
BUFSIZE = 8192
# A very generous timeout when it comes to local connections...
CONNECTION_TIMEOUT = 20.

# Per-process counter used to make generated pipe names unique.
_mmap_counter = itertools.count()

# TCP is always available; upgrade the default to the best local transport
# this platform offers.  A later check overrides an earlier one.
default_family = 'AF_INET'
families = ['AF_INET']

if hasattr(socket, 'AF_UNIX'):
    families.append('AF_UNIX')
    default_family = 'AF_UNIX'

if sys.platform == 'win32':
    families.append('AF_PIPE')
    default_family = 'AF_PIPE'
71
72
# Clock used for connection-retry deadlines.  A monotonic clock is not
# affected by wall-clock adjustments, which could otherwise end a retry
# loop early or stretch it out; fall back to time.time() on interpreters
# without time.monotonic (before Python 3.3).
_deadline_clock = getattr(time, 'monotonic', time.time)

def _init_timeout(timeout=CONNECTION_TIMEOUT):
    '''
    Return a deadline `timeout` seconds from now.

    The result is not a wall-clock timestamp; only compare it via
    `_check_timeout`.
    '''
    return _deadline_clock() + timeout

def _check_timeout(t):
    '''
    Return True once the deadline `t` (from `_init_timeout`) has passed.
    '''
    return _deadline_clock() > t
78
79#
80#
81#
82
def arbitrary_address(family):
    '''
    Return an arbitrary free address for the given family

    'AF_INET' yields ('localhost', 0), i.e. let the OS choose the port at
    bind time; 'AF_UNIX' yields a fresh path name in the process temp dir;
    'AF_PIPE' yields a unique Windows pipe name.  Any other family raises
    ValueError.
    '''
    if family == 'AF_INET':
        return ('localhost', 0)
    elif family == 'AF_UNIX':
        # mktemp only picks a name; the listener's bind() creates the file.
        return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
    elif family == 'AF_PIPE':
        # pid + per-process counter keeps generated pipe names unique.
        # Builtin next() works on 2.6+ and 3.x (iterator.next() is py2-only).
        return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
                               (os.getpid(), next(_mmap_counter)), dir="")
    else:
        raise ValueError('unrecognized family')
96
97
def address_type(address):
    '''
    Return the family of the address

    This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE':
    - any tuple (including tuple subclasses such as namedtuples) is a TCP
      address;
    - a string starting with two backslashes is a Windows named pipe;
    - any other string is a Unix domain socket path.
    Anything else raises ValueError.
    '''
    if isinstance(address, tuple):
        return 'AF_INET'
    elif isinstance(address, str) and address.startswith('\\\\'):
        return 'AF_PIPE'
    elif isinstance(address, str):
        return 'AF_UNIX'
    else:
        # Wrap in a 1-tuple so %-formatting can never unpack the value.
        raise ValueError('address type of %r unrecognized' % (address,))
112
113#
114# Public functions
115#
116
class Listener(object):
    '''
    Returns a listener object.

    This is a wrapper for a bound socket which is 'listening' for
    connections, or for a Windows named pipe.
    '''
    def __init__(self, address=None, family=None, backlog=1, authkey=None):
        # Validate the key before binding anything, so a bad argument does
        # not leave a listening socket (or an AF_UNIX socket file) behind.
        if authkey is not None and not isinstance(authkey, bytes):
            raise TypeError('authkey should be a byte string')

        family = family or (address and address_type(address)) \
                 or default_family
        address = address or arbitrary_address(family)

        if family == 'AF_PIPE':
            self._listener = PipeListener(address, backlog)
        else:
            self._listener = SocketListener(address, family, backlog)

        self._authkey = authkey

    def accept(self):
        '''
        Accept a connection on the bound socket or named pipe of `self`.

        Returns a `Connection` object.  If an authkey was given, the
        challenge/response handshake is run in both directions first; if it
        fails, the new connection is closed before the error propagates.
        '''
        c = self._listener.accept()
        # NOTE(review): an empty authkey (b'') skips the handshake here,
        # whereas Client() performs it for any non-None key -- confirm this
        # asymmetry is intended.
        if self._authkey:
            try:
                deliver_challenge(c, self._authkey)
                answer_challenge(c, self._authkey)
            except BaseException:
                # Don't leak a connection that failed authentication.
                c.close()
                raise
        return c

    def close(self):
        '''
        Close the bound socket or named pipe of `self`.
        '''
        return self._listener.close()

    # Address actually bound by the underlying listener.
    address = property(lambda self: self._listener._address)
    # Value recorded by the underlying listener's most recent accept().
    last_accepted = property(lambda self: self._listener._last_accepted)
159
160
def Client(address, family=None, authkey=None):
    '''
    Returns a connection to the address of a `Listener`

    `family` defaults to the one implied by `address`.  When `authkey` is
    given, the client first answers the listener's challenge and then issues
    its own; if either step fails the connection is closed and the error is
    re-raised.
    '''
    # Validate before connecting so a bad argument opens no connection.
    if authkey is not None and not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')

    family = family or address_type(address)
    if family == 'AF_PIPE':
        c = PipeClient(address)
    else:
        c = SocketClient(address)

    if authkey is not None:
        try:
            answer_challenge(c, authkey)
            deliver_challenge(c, authkey)
        except BaseException:
            # Don't leak a connection that failed authentication.
            c.close()
            raise

    return c
179
180
if sys.platform != 'win32':

    def Pipe(duplex=True):
        '''
        Returns pair of connection objects at either end of a pipe

        A duplex pipe is built from a connected socket pair.  A one-way pipe
        is an os.pipe(): the first connection can only receive and the
        second can only send.
        '''
        if not duplex:
            read_end, write_end = os.pipe()
            return (_multiprocessing.Connection(read_end, writable=False),
                    _multiprocessing.Connection(write_end, readable=False))

        left, right = socket.socketpair()
        for sock in (left, right):
            sock.setblocking(True)
        # Each Connection owns a duplicated descriptor, so the socket
        # objects can be closed once both ends are wrapped.
        first = _multiprocessing.Connection(os.dup(left.fileno()))
        second = _multiprocessing.Connection(os.dup(right.fileno()))
        left.close()
        right.close()
        return first, second

else:
    from _multiprocessing import win32

    def Pipe(duplex=True):
        '''
        Returns pair of connection objects at either end of a pipe

        Built on a message-mode named pipe: the first connection is the
        server end (CreateNamedPipe), the second the client end (CreateFile).
        With duplex=False the first end only reads and the second only
        writes.
        '''
        address = arbitrary_address('AF_PIPE')
        if not duplex:
            openmode = win32.PIPE_ACCESS_INBOUND
            access = win32.GENERIC_WRITE
            obsize, ibsize = 0, BUFSIZE
        else:
            openmode = win32.PIPE_ACCESS_DUPLEX
            access = win32.GENERIC_READ | win32.GENERIC_WRITE
            obsize, ibsize = BUFSIZE, BUFSIZE

        pipe_mode = (win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
                     win32.PIPE_WAIT)
        server_end = win32.CreateNamedPipe(
            address, openmode, pipe_mode,
            1, obsize, ibsize, win32.NMPWAIT_WAIT_FOREVER, win32.NULL
            )
        client_end = win32.CreateFile(
            address, access, 0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
            )
        win32.SetNamedPipeHandleState(
            client_end, win32.PIPE_READMODE_MESSAGE, None, None
            )

        try:
            win32.ConnectNamedPipe(server_end, win32.NULL)
        except WindowsError as e:
            # The client end was opened above, so "already connected" is
            # the expected outcome rather than a failure.
            if e.args[0] != win32.ERROR_PIPE_CONNECTED:
                raise

        return (_multiprocessing.PipeConnection(server_end, writable=duplex),
                _multiprocessing.PipeConnection(client_end, readable=duplex))
242
243#
244# Definitions for connections based on sockets
245#
246
class SocketListener(object):
    '''
    Representation of a socket which is bound to an address and listening
    '''
    def __init__(self, address, family, backlog=1):
        self._socket = socket.socket(getattr(socket, family))
        try:
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self._socket.setblocking(True)
            self._socket.bind(address)
            self._socket.listen(backlog)
            # The bound address (for port 0 this includes the chosen port).
            self._address = self._socket.getsockname()
        except socket.error:
            # Don't leak the descriptor if setup fails.
            self._socket.close()
            raise
        self._family = family
        self._last_accepted = None

        if family == 'AF_UNIX':
            # Register removal of the socket file bound at `address`;
            # close() runs it explicitly.
            self._unlink = Finalize(
                self, os.unlink, args=(address,), exitpriority=0
                )
        else:
            self._unlink = None

    def accept(self):
        '''
        Wait for a client and return a `_multiprocessing.Connection` to it.

        The peer address is stored in `_last_accepted`.  accept() is retried
        when interrupted by a signal (EINTR).
        '''
        while True:
            try:
                s, self._last_accepted = self._socket.accept()
            except socket.error as e:
                if e.args[0] != errno.EINTR:
                    raise
            else:
                break
        try:
            s.setblocking(True)
            # The Connection owns a duplicate of the descriptor, so the
            # socket object is always closed -- even if wrapping fails.
            fd = duplicate(s.fileno())
            return _multiprocessing.Connection(fd)
        finally:
            s.close()

    def close(self):
        '''
        Close the listening socket and run the AF_UNIX unlink step, if any.

        The unlink step runs at most once, even if close() is called again.
        '''
        try:
            self._socket.close()
        finally:
            unlink = self._unlink
            if unlink is not None:
                self._unlink = None
                unlink()
295
296
def SocketClient(address):
    '''
    Return a connection object connected to the socket given by `address`

    ECONNREFUSED is retried every 10ms until the deadline from
    `_init_timeout()` passes; any other socket error is raised at once.
    '''
    family = getattr(socket, address_type(address))
    t = _init_timeout()

    while True:
        s = socket.socket(family)
        s.setblocking(True)
        try:
            s.connect(address)
        except socket.error as e:
            s.close()
            # Refusal presumably means the listener is not up yet; retry.
            if e.args[0] != errno.ECONNREFUSED or _check_timeout(t):
                debug('failed to connect to address %s', address)
                raise
            time.sleep(0.01)
        else:
            break

    try:
        # The Connection owns a duplicate of the descriptor, so the socket
        # object is always closed -- even if wrapping fails.
        fd = duplicate(s.fileno())
        return _multiprocessing.Connection(fd)
    finally:
        s.close()
324
325#
326# Definitions for connections based on named pipes
327#
328
if sys.platform == 'win32':

    class PipeListener(object):
        '''
        Representation of a named pipe
        '''
        def __init__(self, address, backlog=None):
            self._address = address
            handle = win32.CreateNamedPipe(
                address, win32.PIPE_ACCESS_DUPLEX,
                win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
                win32.PIPE_WAIT,
                win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
                win32.NMPWAIT_WAIT_FOREVER, win32.NULL
                )
            # Pipe instances not yet handed out.  accept() appends a fresh
            # instance before connecting the oldest, so one is always waiting.
            self._handle_queue = [handle]
            self._last_accepted = None

            sub_debug('listener created with address=%r', self._address)

            self.close = Finalize(
                self, PipeListener._finalize_pipe_listener,
                args=(self._handle_queue, self._address), exitpriority=0
                )

        def accept(self):
            '''
            Wait for a client on the oldest pipe instance and return a
            `_multiprocessing.PipeConnection` for it.
            '''
            newhandle = win32.CreateNamedPipe(
                self._address, win32.PIPE_ACCESS_DUPLEX,
                win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
                win32.PIPE_WAIT,
                win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
                win32.NMPWAIT_WAIT_FOREVER, win32.NULL
                )
            self._handle_queue.append(newhandle)
            handle = self._handle_queue.pop(0)
            try:
                win32.ConnectNamedPipe(handle, win32.NULL)
            except WindowsError as e:
                # ERROR_NO_DATA can occur if a client has already connected,
                # written data and then disconnected -- see Issue 14725.
                if e.args[0] not in (win32.ERROR_PIPE_CONNECTED,
                                     win32.ERROR_NO_DATA):
                    # `handle` has left the queue, so the finalizer would
                    # never close it; do it here before propagating.
                    close(handle)
                    raise
            return _multiprocessing.PipeConnection(handle)

        @staticmethod
        def _finalize_pipe_listener(queue, address):
            sub_debug('closing listener with address=%r', address)
            for handle in queue:
                close(handle)

    def PipeClient(address):
        '''
        Return a connection object connected to the pipe given by `address`

        ERROR_SEM_TIMEOUT (WaitNamedPipe timed out) and ERROR_PIPE_BUSY are
        retried until the deadline from `_init_timeout()` passes; other
        errors are raised at once.
        '''
        t = _init_timeout()
        while True:
            try:
                win32.WaitNamedPipe(address, 1000)
                h = win32.CreateFile(
                    address, win32.GENERIC_READ | win32.GENERIC_WRITE,
                    0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
                    )
            except WindowsError as e:
                if e.args[0] not in (win32.ERROR_SEM_TIMEOUT,
                                     win32.ERROR_PIPE_BUSY) or _check_timeout(t):
                    raise
            else:
                break

        win32.SetNamedPipeHandleState(
            h, win32.PIPE_READMODE_MESSAGE, None, None
            )
        return _multiprocessing.PipeConnection(h)
405
406#
407# Authentication stuff
408#
409
# Number of random bytes in each authentication challenge.
MESSAGE_LENGTH = 20

CHALLENGE = b'#CHALLENGE#'
WELCOME = b'#WELCOME#'
FAILURE = b'#FAILURE#'

def _auth_digest(authkey, message):
    '''Return the HMAC-MD5 digest of `message` keyed with `authkey`.'''
    import hashlib
    import hmac
    # MD5 is what hmac.new() used by default on Python 2.  Naming it keeps
    # the wire protocol unchanged and also works where digestmod is
    # mandatory (Python 3.8+).
    return hmac.new(authkey, message, hashlib.md5).digest()

def _digests_equal(a, b):
    '''Compare two digests, in constant time when hmac supports it.'''
    import hmac
    compare = getattr(hmac, 'compare_digest', None)
    if compare is None:
        return a == b
    return compare(a, b)

def deliver_challenge(connection, authkey):
    '''
    Challenge the peer on `connection` to prove it knows `authkey`.

    Sends CHALLENGE plus MESSAGE_LENGTH random bytes and expects the HMAC of
    those bytes back.  Replies WELCOME on success; otherwise replies FAILURE
    and raises AuthenticationError.  Raises TypeError if `authkey` is not a
    byte string.
    '''
    if not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')
    message = os.urandom(MESSAGE_LENGTH)
    connection.send_bytes(CHALLENGE + message)
    digest = _auth_digest(authkey, message)
    response = connection.recv_bytes(256)        # reject large message
    # Constant-time comparison does not reveal how many leading bytes match.
    if _digests_equal(response, digest):
        connection.send_bytes(WELCOME)
    else:
        connection.send_bytes(FAILURE)
        raise AuthenticationError('digest received was wrong')

def answer_challenge(connection, authkey):
    '''
    Answer a challenge sent by `deliver_challenge` on the other end.

    Replies with the HMAC of the challenge bytes.  Raises
    AuthenticationError if the first message is not a challenge, or if the
    peer does not reply WELCOME.  Raises TypeError if `authkey` is not a
    byte string.
    '''
    if not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')
    message = connection.recv_bytes(256)         # reject large message
    # An explicit check (not assert) so it still runs under `python -O`.
    if message[:len(CHALLENGE)] != CHALLENGE:
        raise AuthenticationError(
            'Protocol error, expected challenge: %r' % (message,))
    message = message[len(CHALLENGE):]
    connection.send_bytes(_auth_digest(authkey, message))
    response = connection.recv_bytes(256)        # reject large message
    if response != WELCOME:
        raise AuthenticationError('digest sent was rejected')
440
441#
442# Support for using xmlrpclib for serialization
443#
444
class ConnectionWrapper(object):
    '''
    Wrap a connection so that send()/recv() use custom serializers.

    `dumps` turns an object into the payload passed to `conn.send_bytes`;
    `loads` turns a payload from `conn.recv_bytes` back into an object.
    The raw methods fileno, close, poll, recv_bytes and send_bytes of `conn`
    are exposed unchanged as instance attributes.
    '''
    def __init__(self, conn, dumps, loads):
        self._conn, self._dumps, self._loads = conn, dumps, loads
        # Re-export the wrapped connection's raw methods directly.
        for name in ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes'):
            setattr(self, name, getattr(conn, name))

    def send(self, obj):
        '''Serialize `obj` with the dumps function and send it.'''
        self._conn.send_bytes(self._dumps(obj))

    def recv(self):
        '''Receive one message and deserialize it with the loads function.'''
        return self._loads(self._conn.recv_bytes())
459
def _import_xmlrpclib():
    '''Return the XML-RPC marshalling module for the running Python.'''
    try:
        import xmlrpclib                        # Python 2
    except ImportError:
        import xmlrpc.client as xmlrpclib       # Python 3
    return xmlrpclib

def _xml_dumps(obj):
    '''
    Serialize `obj` as a single-parameter XML-RPC document.

    The trailing positional 1 is `allow_none`, so None can be sent.
    '''
    # Import here rather than relying on a global injected by XmlListener /
    # XmlClient, so the helpers also work when called directly.
    return _import_xmlrpclib().dumps((obj,), None, None, None, 1)

def _xml_loads(s):
    '''Inverse of `_xml_dumps`: return the single value encoded in `s`.'''
    (obj,), method = _import_xmlrpclib().loads(s)
    return obj
466
class XmlListener(Listener):
    '''
    A `Listener` whose accepted connections exchange objects as XML-RPC
    documents via `ConnectionWrapper`.
    '''
    def accept(self):
        # Import lazily and publish `xmlrpclib` as a module global.
        global xmlrpclib
        import xmlrpclib
        return ConnectionWrapper(Listener.accept(self), _xml_dumps, _xml_loads)
473
def XmlClient(*args, **kwds):
    '''
    Like `Client`, but the returned connection exchanges objects as
    XML-RPC documents via `ConnectionWrapper`.
    '''
    # Import lazily and publish `xmlrpclib` as a module global.
    global xmlrpclib
    import xmlrpclib
    conn = Client(*args, **kwds)
    return ConnectionWrapper(conn, _xml_dumps, _xml_loads)
478