• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# A higher level module for using sockets (or Windows named pipes)
3#
4# multiprocessing/connection.py
5#
6# Copyright (c) 2006-2008, R Oudkerk
7# Licensed to PSF under a Contributor Agreement.
8#
9
10__all__ = [ 'Client', 'Listener', 'Pipe', 'wait' ]
11
12import errno
13import io
14import os
15import sys
16import socket
17import struct
18import time
19import tempfile
20import itertools
21
22
23from . import util
24
25from . import AuthenticationError, BufferTooShort
26from .context import reduction
27_ForkingPickler = reduction.ForkingPickler
28
29try:
30    import _multiprocessing
31    import _winapi
32    from _winapi import WAIT_OBJECT_0, WAIT_ABANDONED_0, WAIT_TIMEOUT, INFINITE
33except ImportError:
34    if sys.platform == 'win32':
35        raise
36    _winapi = None
37
38#
39#
40#
41
BUFSIZE = 8192
# A very generous timeout when it comes to local connections...
CONNECTION_TIMEOUT = 20.

# Source of unique per-process suffixes for generated AF_PIPE addresses.
_mmap_counter = itertools.count()

# Address families available on this platform.  Each later entry, when
# available, supersedes the earlier ones as the default family.
families = ['AF_INET']
if hasattr(socket, 'AF_UNIX'):
    families.append('AF_UNIX')
if sys.platform == 'win32':
    families.append('AF_PIPE')
default_family = families[-1]
59
def _init_timeout(timeout=CONNECTION_TIMEOUT):
    """Return the monotonic-clock deadline lying *timeout* seconds from now.

    The result is meant to be passed to _check_timeout().
    """
    now = time.monotonic()
    deadline = now + timeout
    return deadline
62
def _check_timeout(t):
    """Return True once the monotonic clock has moved past deadline *t*."""
    now = time.monotonic()
    return now > t
65
66#
67#
68#
69
def arbitrary_address(family):
    '''
    Return an arbitrary free address for the given family

    AF_INET yields ('localhost', 0) so the OS picks the port; AF_UNIX yields a
    fresh path in the multiprocessing temp dir; AF_PIPE yields a unique
    Windows pipe name built from the pid and a per-process counter.
    Raises ValueError for any other family.
    '''
    if family == 'AF_INET':
        return ('localhost', 0)
    if family == 'AF_UNIX':
        # mktemp() only picks a name; bind() later fails if the path exists.
        return tempfile.mktemp(prefix='listener-', dir=util.get_temp_dir())
    if family == 'AF_PIPE':
        pipe_prefix = r'\\.\pipe\pyc-%d-%d-' % (os.getpid(),
                                                next(_mmap_counter))
        return tempfile.mktemp(prefix=pipe_prefix, dir="")
    raise ValueError('unrecognized family')
83
def _validate_family(family):
    '''
    Checks if the family is valid for the current environment.

    Returns None when *family* is usable; raises ValueError otherwise.
    '''
    on_windows = sys.platform == 'win32'
    # Named pipes are a Windows-only transport.
    if family == 'AF_PIPE' and not on_windows:
        raise ValueError('Family %s is not recognized.' % family)
    # On Windows, AF_UNIX support depends on the build; double check that
    # the socket module really exposes it.
    if family == 'AF_UNIX' and on_windows and not hasattr(socket, family):
        raise ValueError('Family %s is not recognized.' % family)
95
def address_type(address):
    '''
    Return the type of the address

    This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE':

    * any tuple (including namedtuples) is an AF_INET (host, port) address;
    * a string starting with two backslashes is a Windows named pipe;
    * any other string, or an abstract-namespace socket address, is AF_UNIX.

    Raises ValueError for anything else.
    '''
    if isinstance(address, tuple):
        return 'AF_INET'
    elif isinstance(address, str) and address.startswith('\\\\'):
        return 'AF_PIPE'
    elif isinstance(address, str) or util.is_abstract_socket_namespace(address):
        return 'AF_UNIX'
    else:
        # Wrap in a 1-tuple so %-formatting cannot misinterpret the address
        # itself as the argument tuple.
        raise ValueError('address type of %r unrecognized' % (address,))
110
111#
112# Connection classes
113#
114
class _ConnectionBase:
    """Shared front end for the concrete connection classes.

    Subclasses provide the raw primitives ``_close()``, ``_send_bytes(buf)``,
    ``_recv_bytes(maxsize=None)`` and ``_poll(timeout)``.  This class adds
    argument validation, pickling, and the closed/readable/writable checks.
    """
    # Class-level default so __del__ is safe even if __init__ raised before
    # the instance attribute was assigned.
    _handle = None

    def __init__(self, handle, readable=True, writable=True):
        handle = handle.__index__()
        if handle < 0:
            raise ValueError("invalid handle")
        if not readable and not writable:
            raise ValueError(
                "at least one of `readable` and `writable` must be True")
        self._handle = handle
        self._readable = readable
        self._writable = writable

    # XXX should we use util.Finalize instead of a __del__?

    def __del__(self):
        if self._handle is not None:
            self._close()

    def _check_closed(self):
        if self._handle is None:
            raise OSError("handle is closed")

    def _check_readable(self):
        if not self._readable:
            raise OSError("connection is write-only")

    def _check_writable(self):
        if not self._writable:
            raise OSError("connection is read-only")

    def _bad_message_length(self):
        # Once a message of unexpected length has been seen, the read side
        # can no longer be trusted: stop reading (or close if nothing else
        # can be done with the connection) and report the error.
        if self._writable:
            self._readable = False
        else:
            self.close()
        raise OSError("bad message length")

    @property
    def closed(self):
        """True if the connection is closed"""
        return self._handle is None

    @property
    def readable(self):
        """True if the connection is readable"""
        return self._readable

    @property
    def writable(self):
        """True if the connection is writable"""
        return self._writable

    def fileno(self):
        """File descriptor or handle of the connection"""
        self._check_closed()
        return self._handle

    def close(self):
        """Close the connection"""
        if self._handle is not None:
            try:
                self._close()
            finally:
                self._handle = None

    def send_bytes(self, buf, offset=0, size=None):
        """Send the bytes data from a bytes-like object

        *offset* and *size* are measured in bytes, whatever the item size of
        *buf*'s format.
        """
        self._check_closed()
        self._check_writable()
        m = memoryview(buf)
        if m.itemsize > 1:
            m = m.cast('B')
        n = m.nbytes
        if offset < 0:
            raise ValueError("offset is negative")
        if n < offset:
            raise ValueError("buffer length < offset")
        if size is None:
            size = n - offset
        elif size < 0:
            raise ValueError("size is negative")
        elif offset + size > n:
            raise ValueError("buffer length < offset + size")
        self._send_bytes(m[offset:offset + size])

    def send(self, obj):
        """Send a (picklable) object"""
        self._check_closed()
        self._check_writable()
        self._send_bytes(_ForkingPickler.dumps(obj))

    def recv_bytes(self, maxlength=None):
        """
        Receive bytes data as a bytes object.

        If the incoming message is longer than *maxlength*, the connection
        is marked unreadable (or closed) and OSError is raised.
        """
        self._check_closed()
        self._check_readable()
        if maxlength is not None and maxlength < 0:
            raise ValueError("negative maxlength")
        buf = self._recv_bytes(maxlength)
        if buf is None:
            self._bad_message_length()
        return buf.getvalue()

    def recv_bytes_into(self, buf, offset=0):
        """
        Receive bytes data into a writeable bytes-like object.
        Return the number of bytes read.

        *offset* is a byte offset into *buf*, regardless of the item size of
        the buffer's format.  If the message does not fit after *offset*,
        BufferTooShort is raised carrying the complete message.
        """
        self._check_closed()
        self._check_readable()
        # View the destination as flat unsigned bytes (mirroring send_bytes)
        # so offset and size are honoured exactly: slicing in item units
        # would misplace data for unaligned offsets and drop the tail of
        # messages whose length is not a multiple of the item size.
        with memoryview(buf) as view, view.cast('B') as m:
            bytesize = m.nbytes
            if offset < 0:
                raise ValueError("negative offset")
            elif offset > bytesize:
                raise ValueError("offset too large")
            result = self._recv_bytes()
            size = result.tell()
            if bytesize < offset + size:
                raise BufferTooShort(result.getvalue())
            # Message can fit in dest
            result.seek(0)
            result.readinto(m[offset:offset + size])
            return size

    def recv(self):
        """Receive a (picklable) object"""
        self._check_closed()
        self._check_readable()
        buf = self._recv_bytes()
        # SECURITY NOTE: unpickling can execute arbitrary code; only use this
        # with a trusted peer (e.g. one authenticated with an authkey).
        return _ForkingPickler.loads(buf.getbuffer())

    def poll(self, timeout=0.0):
        """Whether there is any input available to be read"""
        self._check_closed()
        self._check_readable()
        return self._poll(timeout)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
264
265
if _winapi:

    class PipeConnection(_ConnectionBase):
        """
        Connection class based on a Windows named pipe.
        Overlapped I/O is used, so the handles must have been created
        with FILE_FLAG_OVERLAPPED.

        No length prefix is written: the code relies on the pipe being in
        message mode (as created by Pipe() and PipeListener in this module),
        so each send maps to one pipe message and reads report
        ERROR_MORE_DATA when a message exceeds the read buffer.
        """
        # NOTE(review): never set to True inside this class; presumably set
        # by code elsewhere in the module that consumed a zero-length
        # message (e.g. wait()) -- verify.
        _got_empty_message = False
        # Overlapped object of the in-flight write, so _close() can cancel it.
        _send_ov = None

        def _close(self, _CloseHandle=_winapi.CloseHandle):
            ov = self._send_ov
            if ov is not None:
                # Interrupt WaitForMultipleObjects() in _send_bytes()
                ov.cancel()
            _CloseHandle(self._handle)

        def _send_bytes(self, buf):
            # Write *buf* as one message with an overlapped WriteFile and
            # block until it completes (or is cancelled by _close()).
            if self._send_ov is not None:
                # A connection should only be used by a single thread
                raise ValueError("concurrent send_bytes() calls "
                                 "are not supported")
            ov, err = _winapi.WriteFile(self._handle, buf, overlapped=True)
            self._send_ov = ov
            try:
                if err == _winapi.ERROR_IO_PENDING:
                    waitres = _winapi.WaitForMultipleObjects(
                        [ov.event], False, INFINITE)
                    assert waitres == WAIT_OBJECT_0
            except:
                ov.cancel()
                raise
            finally:
                self._send_ov = None
                # Always reap the overlapped result, even after cancel().
                nwritten, err = ov.GetOverlappedResult(True)
            if err == _winapi.ERROR_OPERATION_ABORTED:
                # close() was called by another thread while
                # WaitForMultipleObjects() was waiting for the overlapped
                # operation.
                raise OSError(errno.EPIPE, "handle is closed")
            assert err == 0
            assert nwritten == len(buf)

        def _recv_bytes(self, maxsize=None):
            # Read one message: first up to 128 bytes (or *maxsize* if
            # smaller); if the message is longer, ERROR_MORE_DATA is reported
            # and _get_more_data() fetches the rest.  A broken pipe is
            # reported as EOFError.
            if self._got_empty_message:
                self._got_empty_message = False
                return io.BytesIO()
            else:
                bsize = 128 if maxsize is None else min(maxsize, 128)
                try:
                    ov, err = _winapi.ReadFile(self._handle, bsize,
                                                overlapped=True)
                    try:
                        if err == _winapi.ERROR_IO_PENDING:
                            waitres = _winapi.WaitForMultipleObjects(
                                [ov.event], False, INFINITE)
                            assert waitres == WAIT_OBJECT_0
                    except:
                        ov.cancel()
                        raise
                    finally:
                        # NOTE: the `return`s below sit inside `finally`, so
                        # if the read completed anyway, any exception raised
                        # while waiting (e.g. KeyboardInterrupt) is discarded
                        # and the data is returned instead.
                        nread, err = ov.GetOverlappedResult(True)
                        if err == 0:
                            f = io.BytesIO()
                            f.write(ov.getbuffer())
                            return f
                        elif err == _winapi.ERROR_MORE_DATA:
                            return self._get_more_data(ov, maxsize)
                except OSError as e:
                    if e.winerror == _winapi.ERROR_BROKEN_PIPE:
                        raise EOFError
                    else:
                        raise
            # Reached only if the read finished with an error code other
            # than 0 / ERROR_MORE_DATA and nothing was raised; the message
            # shows the authors consider this path unreachable.
            raise RuntimeError("shouldn't get here; expected KeyboardInterrupt")

        def _poll(self, timeout):
            # Ready if an empty message is pending or PeekNamedPipe reports
            # a non-zero first field (presumably bytes available -- see the
            # _winapi.PeekNamedPipe docs); otherwise fall back to wait().
            if (self._got_empty_message or
                        _winapi.PeekNamedPipe(self._handle)[0] != 0):
                return True
            return bool(wait([self], timeout))

        def _get_more_data(self, ov, maxsize):
            # Complete a message whose first chunk is already in *ov*: ask
            # PeekNamedPipe how much is left, enforce *maxsize*, then read
            # the remainder with one more overlapped ReadFile.
            buf = ov.getbuffer()
            f = io.BytesIO()
            f.write(buf)
            left = _winapi.PeekNamedPipe(self._handle)[1]
            assert left > 0
            if maxsize is not None and len(buf) + left > maxsize:
                # Always raises OSError("bad message length").
                self._bad_message_length()
            ov, err = _winapi.ReadFile(self._handle, left, overlapped=True)
            rbytes, err = ov.GetOverlappedResult(True)
            assert err == 0
            assert rbytes == left
            f.write(ov.getbuffer())
            return f
362
363
class Connection(_ConnectionBase):
    """
    Connection class based on an arbitrary file descriptor (Unix only), or
    a socket handle (Windows).

    Each message is framed with a signed 4-byte big-endian length.  Lengths
    above 0x7fffffff are sent as -1 followed by an unsigned 8-byte length.
    """

    # The raw I/O primitives are chosen once, when the class is created:
    # socket helpers from _multiprocessing on Windows, plain os.* elsewhere.
    if _winapi:
        def _close(self, _close=_multiprocessing.closesocket):
            _close(self._handle)
        _write = _multiprocessing.send
        _read = _multiprocessing.recv
    else:
        def _close(self, _close=os.close):
            _close(self._handle)
        _write = os.write
        _read = os.read

    def _send(self, buf, write=_write):
        """Write all of *buf* to the handle, retrying after partial writes."""
        # Slicing a memoryview is zero-copy, so a run of partial writes does
        # not re-copy the unsent tail of a large message on every iteration
        # (slicing bytes did, making the loop quadratic).
        view = memoryview(buf)
        remaining = len(view)
        while True:
            n = write(self._handle, view)
            remaining -= n
            if remaining == 0:
                break
            view = view[n:]

    def _recv(self, size, read=_read):
        """Read exactly *size* bytes into a BytesIO.

        Raises EOFError if the peer closed before any byte arrived, or
        OSError if it closed part-way through.
        """
        buf = io.BytesIO()
        handle = self._handle
        remaining = size
        while remaining > 0:
            chunk = read(handle, remaining)
            n = len(chunk)
            if n == 0:
                if remaining == size:
                    raise EOFError
                else:
                    raise OSError("got end of file during message")
            buf.write(chunk)
            remaining -= n
        return buf

    def _send_bytes(self, buf):
        n = len(buf)
        if n > 0x7fffffff:
            pre_header = struct.pack("!i", -1)
            header = struct.pack("!Q", n)
            self._send(pre_header)
            self._send(header)
            self._send(buf)
        else:
            # For wire compatibility with 3.7 and lower
            header = struct.pack("!i", n)
            if n > 16384:
                # The payload is large so Nagle's algorithm won't be triggered
                # and we'd better avoid the cost of concatenation.
                self._send(header)
                self._send(buf)
            else:
                # Issue #20540: concatenate before sending, to avoid delays due
                # to Nagle's algorithm on a TCP socket.
                # Also note we want to avoid sending a 0-length buffer separately,
                # to avoid "broken pipe" errors if the other end closed the pipe.
                self._send(header + buf)

    def _recv_bytes(self, maxsize=None):
        """Read one framed message; return None if it exceeds *maxsize*."""
        buf = self._recv(4)
        size, = struct.unpack("!i", buf.getvalue())
        if size == -1:
            # Large message: the real length follows as an unsigned 64-bit int.
            buf = self._recv(8)
            size, = struct.unpack("!Q", buf.getvalue())
        if maxsize is not None and size > maxsize:
            return None
        return self._recv(size)

    def _poll(self, timeout):
        r = wait([self], timeout)
        return bool(r)
442
443
444#
445# Public functions
446#
447
class Listener(object):
    '''
    Returns a listener object.

    This is a wrapper for a bound socket which is 'listening' for
    connections, or for a Windows named pipe.
    '''
    def __init__(self, address=None, family=None, backlog=1, authkey=None):
        # Validate cheap arguments before creating any OS resource, so a bad
        # authkey cannot leave a bound socket (or AF_UNIX socket file) behind.
        if authkey is not None and not isinstance(authkey, bytes):
            raise TypeError('authkey should be a byte string')

        family = family or (address and address_type(address)) \
                 or default_family
        address = address or arbitrary_address(family)

        _validate_family(family)
        if family == 'AF_PIPE':
            self._listener = PipeListener(address, backlog)
        else:
            self._listener = SocketListener(address, family, backlog)

        self._authkey = authkey

    def accept(self):
        '''
        Accept a connection on the bound socket or named pipe of `self`.

        Returns a `Connection` object.  When an authkey was given, the
        challenge handshake must succeed first; otherwise the accepted
        connection is closed and the error is re-raised.
        '''
        if self._listener is None:
            raise OSError('listener is closed')

        c = self._listener.accept()
        if self._authkey is not None:
            try:
                deliver_challenge(c, self._authkey)
                answer_challenge(c, self._authkey)
            except BaseException:
                # The caller never receives this connection, so nobody else
                # could close it; don't leave it open until GC.
                c.close()
                raise
        return c

    def close(self):
        '''
        Close the bound socket or named pipe of `self`.
        '''
        listener = self._listener
        if listener is not None:
            self._listener = None
            listener.close()

    @property
    def address(self):
        return self._listener._address

    @property
    def last_accepted(self):
        return self._listener._last_accepted

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
508
509
def Client(address, family=None, authkey=None):
    '''
    Returns a connection to the address of a `Listener`

    When *authkey* (bytes) is given, the challenge handshake is performed;
    on failure the new connection is closed and the error re-raised.
    '''
    # Validate before connecting so a bad authkey does not leak a connection.
    if authkey is not None and not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')

    family = family or address_type(address)
    _validate_family(family)
    if family == 'AF_PIPE':
        c = PipeClient(address)
    else:
        c = SocketClient(address)

    if authkey is not None:
        try:
            answer_challenge(c, authkey)
            deliver_challenge(c, authkey)
        except BaseException:
            # The caller never receives this connection; close it now.
            c.close()
            raise

    return c
529
530
if sys.platform != 'win32':

    def Pipe(duplex=True):
        '''
        Returns pair of connection objects at either end of a pipe
        '''
        if not duplex:
            # One-way: an anonymous OS pipe.  The first connection can only
            # read, the second can only write.
            read_fd, write_fd = os.pipe()
            return (Connection(read_fd, writable=False),
                    Connection(write_fd, readable=False))

        # Two-way: a connected socket pair.  detach() transfers ownership of
        # the raw descriptors to the Connection objects.
        left, right = socket.socketpair()
        for sock in (left, right):
            sock.setblocking(True)
        return Connection(left.detach()), Connection(right.detach())

else:

    def Pipe(duplex=True):
        '''
        Returns pair of connection objects at either end of a pipe
        '''
        address = arbitrary_address('AF_PIPE')
        if duplex:
            open_mode = _winapi.PIPE_ACCESS_DUPLEX
            client_access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE
            out_bufsize = in_bufsize = BUFSIZE
        else:
            # Inbound-only pipe: the server end (first connection) reads and
            # the client end (second connection) writes.
            open_mode = _winapi.PIPE_ACCESS_INBOUND
            client_access = _winapi.GENERIC_WRITE
            out_bufsize, in_bufsize = 0, BUFSIZE

        server_flags = (open_mode | _winapi.FILE_FLAG_OVERLAPPED |
                        _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE)
        pipe_mode = (_winapi.PIPE_TYPE_MESSAGE |
                     _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT)
        server_handle = _winapi.CreateNamedPipe(
            address, server_flags, pipe_mode,
            1, out_bufsize, in_bufsize, _winapi.NMPWAIT_WAIT_FOREVER,
            # default security descriptor: the handle cannot be inherited
            _winapi.NULL
            )
        client_handle = _winapi.CreateFile(
            address, client_access, 0, _winapi.NULL, _winapi.OPEN_EXISTING,
            _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL
            )
        _winapi.SetNamedPipeHandleState(
            client_handle, _winapi.PIPE_READMODE_MESSAGE, None, None
            )

        # Complete the server side of the connection the client just opened.
        connect_ov = _winapi.ConnectNamedPipe(server_handle, overlapped=True)
        _, err = connect_ov.GetOverlappedResult(True)
        assert err == 0

        return (PipeConnection(server_handle, writable=duplex),
                PipeConnection(client_handle, readable=duplex))
591
592#
593# Definitions for connections based on sockets
594#
595
class SocketListener(object):
    '''
    Representation of a socket which is bound to an address and listening

    For non-abstract AF_UNIX addresses, the socket file is unlinked on
    close() (or by a util.Finalize at exit).
    '''
    def __init__(self, address, family, backlog=1):
        self._socket = socket.socket(getattr(socket, family))
        try:
            # SO_REUSEADDR has different semantics on Windows (issue #2550).
            if os.name == 'posix':
                self._socket.setsockopt(socket.SOL_SOCKET,
                                        socket.SO_REUSEADDR, 1)
            self._socket.setblocking(True)
            self._socket.bind(address)
            self._socket.listen(backlog)
            self._address = self._socket.getsockname()
        except BaseException:
            # Close on *any* failure, not only OSError: a malformed address
            # makes bind() raise TypeError/OverflowError, and the caller
            # never gets an object through which the socket could be closed.
            self._socket.close()
            raise
        self._family = family
        self._last_accepted = None

        if family == 'AF_UNIX' and not util.is_abstract_socket_namespace(address):
            # Linux abstract socket namespaces do not need to be explicitly unlinked
            self._unlink = util.Finalize(
                self, os.unlink, args=(address,), exitpriority=0
                )
        else:
            self._unlink = None

    def accept(self):
        s, self._last_accepted = self._socket.accept()
        s.setblocking(True)
        return Connection(s.detach())

    def close(self):
        try:
            self._socket.close()
        finally:
            unlink = self._unlink
            if unlink is not None:
                self._unlink = None
                unlink()
638
639
def SocketClient(address):
    '''
    Return a connection object connected to the socket given by `address`

    The socket family is derived from the address via address_type().
    '''
    sock_family = getattr(socket, address_type(address))
    with socket.socket(sock_family) as sock:
        sock.setblocking(True)
        sock.connect(address)
        # detach() hands the raw descriptor over; closing the detached
        # socket object when the with-block exits is then a no-op.
        fd = sock.detach()
    return Connection(fd)
649
650#
651# Definitions for connections based on named pipes
652#
653
if sys.platform == 'win32':

    class PipeListener(object):
        '''
        Representation of a named pipe

        One unconnected pipe instance is always kept waiting in
        _handle_queue.  accept() creates the next instance before handing
        out the oldest one, so the pipe name stays available to clients.
        '''
        def __init__(self, address, backlog=None):
            # NOTE(review): *backlog* is accepted for interface parity with
            # SocketListener but is not used here.
            self._address = address
            self._handle_queue = [self._new_handle(first=True)]

            self._last_accepted = None
            util.sub_debug('listener created with address=%r', self._address)
            # close() is a Finalize object that closes every queued handle;
            # it is also registered to run at exit.
            self.close = util.Finalize(
                self, PipeListener._finalize_pipe_listener,
                args=(self._handle_queue, self._address), exitpriority=0
                )

        def _new_handle(self, first=False):
            # Create one duplex, overlapped, message-mode pipe instance.
            # FILE_FLAG_FIRST_PIPE_INSTANCE is only requested for the first.
            flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED
            if first:
                flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
            return _winapi.CreateNamedPipe(
                self._address, flags,
                _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
                _winapi.PIPE_WAIT,
                _winapi.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
                _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL
                )

        def accept(self):
            # Queue a fresh instance, then wait for a client on the oldest.
            self._handle_queue.append(self._new_handle())
            handle = self._handle_queue.pop(0)
            try:
                ov = _winapi.ConnectNamedPipe(handle, overlapped=True)
            except OSError as e:
                # NOTE(review): for errors other than ERROR_NO_DATA the
                # popped handle is neither closed nor back in the queue.
                if e.winerror != _winapi.ERROR_NO_DATA:
                    raise
                # ERROR_NO_DATA can occur if a client has already connected,
                # written data and then disconnected -- see Issue 14725.
            else:
                try:
                    # `res` is not used; the wait only blocks until the
                    # overlapped connect signals its event.
                    res = _winapi.WaitForMultipleObjects(
                        [ov.event], False, INFINITE)
                except:
                    ov.cancel()
                    _winapi.CloseHandle(handle)
                    raise
                finally:
                    _, err = ov.GetOverlappedResult(True)
                    assert err == 0
            return PipeConnection(handle)

        @staticmethod
        def _finalize_pipe_listener(queue, address):
            # Close every pipe instance still waiting in the queue.
            util.sub_debug('closing listener with address=%r', address)
            for handle in queue:
                _winapi.CloseHandle(handle)

    def PipeClient(address):
        '''
        Return a connection object connected to the pipe given by `address`

        Busy or timed-out attempts are retried until CONNECTION_TIMEOUT
        seconds have elapsed; any other OSError is raised immediately.
        '''
        t = _init_timeout()
        while 1:
            try:
                _winapi.WaitNamedPipe(address, 1000)
                h = _winapi.CreateFile(
                    address, _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
                    0, _winapi.NULL, _winapi.OPEN_EXISTING,
                    _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL
                    )
            except OSError as e:
                if e.winerror not in (_winapi.ERROR_SEM_TIMEOUT,
                                      _winapi.ERROR_PIPE_BUSY) or _check_timeout(t):
                    raise
            else:
                break
        else:
            # Dead code: `while 1` only exits via `break`, so this
            # while-else branch can never run.
            raise

        # Switch the client end to message-read mode to match the server.
        _winapi.SetNamedPipeHandleState(
            h, _winapi.PIPE_READMODE_MESSAGE, None, None
            )
        return PipeConnection(h)
738
739#
740# Authentication stuff
741#
742
743MESSAGE_LENGTH = 40  # MUST be > 20
744
745_CHALLENGE = b'#CHALLENGE#'
746_WELCOME = b'#WELCOME#'
747_FAILURE = b'#FAILURE#'
748
749# multiprocessing.connection Authentication Handshake Protocol Description
750# (as documented for reference after reading the existing code)
751# =============================================================================
752#
753# On Windows: native pipes with "overlapped IO" are used to send the bytes,
754# instead of the length prefix SIZE scheme described below. (ie: the OS deals
755# with message sizes for us)
756#
757# Protocol error behaviors:
758#
# On POSIX, any failure to receive the length prefix into SIZE, a SIZE greater
# than the requested maxsize to receive, or receiving fewer than SIZE bytes
# results in the connection being closed and authentication failing.
762#
# On Windows, receiving too few bytes is never a low-level _recv_bytes read
# error; receiving too many will trigger an error only if the receive maxsize
# value was larger than 128 OR if the data arrived in smaller pieces.
766#
767#      Serving side                           Client side
768#     ------------------------------  ---------------------------------------
769# 0.                                  Open a connection on the pipe.
770# 1.  Accept connection.
771# 2.  Random 20+ bytes -> MESSAGE
772#     Modern servers always send
773#     more than 20 bytes and include
774#     a {digest} prefix on it with
775#     their preferred HMAC digest.
776#     Legacy ones send ==20 bytes.
777# 3.  send 4 byte length (net order)
778#     prefix followed by:
779#       b'#CHALLENGE#' + MESSAGE
780# 4.                                  Receive 4 bytes, parse as network byte
781#                                     order integer. If it is -1, receive an
782#                                     additional 8 bytes, parse that as network
783#                                     byte order. The result is the length of
784#                                     the data that follows -> SIZE.
785# 5.                                  Receive min(SIZE, 256) bytes -> M1
786# 6.                                  Assert that M1 starts with:
787#                                       b'#CHALLENGE#'
788# 7.                                  Strip that prefix from M1 into -> M2
789# 7.1.                                Parse M2: if it is exactly 20 bytes in
790#                                     length this indicates a legacy server
791#                                     supporting only HMAC-MD5. Otherwise the
792# 7.2.                                preferred digest is looked up from an
793#                                     expected "{digest}" prefix on M2. No prefix
794#                                     or unsupported digest? <- AuthenticationError
795# 7.3.                                Put divined algorithm name in -> D_NAME
796# 8.                                  Compute HMAC-D_NAME of AUTHKEY, M2 -> C_DIGEST
797# 9.                                  Send 4 byte length prefix (net order)
798#                                     followed by C_DIGEST bytes.
799# 10. Receive 4 or 4+8 byte length
800#     prefix (#4 dance) -> SIZE.
801# 11. Receive min(SIZE, 256) -> C_D.
802# 11.1. Parse C_D: legacy servers
803#     accept it as is, "md5" -> D_NAME
804# 11.2. modern servers check the length
805#     of C_D, IF it is 16 bytes?
806# 11.2.1. "md5" -> D_NAME
807#         and skip to step 12.
808# 11.3. longer? expect and parse a "{digest}"
809#     prefix into -> D_NAME.
810#     Strip the prefix and store remaining
811#     bytes in -> C_D.
812# 11.4. Don't like D_NAME? <- AuthenticationError
813# 12. Compute HMAC-D_NAME of AUTHKEY,
814#     MESSAGE into -> M_DIGEST.
815# 13. Compare M_DIGEST == C_D:
816# 14a: Match? Send length prefix &
817#       b'#WELCOME#'
818#    <- RETURN
819# 14b: Mismatch? Send len prefix &
820#       b'#FAILURE#'
821#    <- CLOSE & AuthenticationError
822# 15.                                 Receive 4 or 4+8 byte length prefix (net
823#                                     order) again as in #4 into -> SIZE.
824# 16.                                 Receive min(SIZE, 256) bytes -> M3.
825# 17.                                 Compare M3 == b'#WELCOME#':
826# 17a.                                Match? <- RETURN
827# 17b.                                Mismatch? <- CLOSE & AuthenticationError
828#
829# If this RETURNed, the connection remains open: it has been authenticated.
830#
831# Length prefixes are used consistently. Even on the legacy protocol, this
832# was good fortune and allowed us to evolve the protocol by using the length
833# of the opening challenge or length of the returned digest as a signal as
834# to which protocol the other end supports.
835
# Digest names a peer may announce in a b'{digest}' prefix.
_ALLOWED_DIGESTS = frozenset((
        b'md5', b'sha256', b'sha384', b'sha3_256', b'sha3_384'))
# Longest allowed digest name; bounds the search for the closing '}'.
_MAX_DIGEST_LEN = max(map(len, _ALLOWED_DIGESTS))

# Old hmac-md5 only server versions from Python <=3.11 sent a message of this
# length. It happens to not match the length of any supported digest so we can
# use a message of this length to indicate that we should work in backwards
# compatible md5-only mode without a {digest_name} prefix on our response.
_MD5ONLY_MESSAGE_LENGTH = 20
# Size of a raw (unprefixed) HMAC-MD5 reply from a legacy client.
_MD5_DIGEST_LEN = 16
_LEGACY_LENGTHS = (_MD5ONLY_MESSAGE_LENGTH, _MD5_DIGEST_LEN)
847
848
def _get_digest_name_and_payload(message: bytes) -> (str, bytes):
    """Split a challenge or response *message* into (digest_name, payload).

    A message whose length is one of _LEGACY_LENGTHS belongs to the legacy
    protocol.  The digest name returned is then '' (meaning HMAC-MD5 with no
    prefix) and the payload is the whole message.  Otherwise the message must
    start with b'{digest}' naming a member of _ALLOWED_DIGESTS.  The decoded
    name and the bytes after the closing brace are returned.

    Raises AuthenticationError for any other message.
    """
    # modern message format: b"{digest}payload" longer than 20 bytes
    # legacy message format: 16 or 20 byte b"payload"
    if len(message) in _LEGACY_LENGTHS:
        # A legacy server challenge, or an unprefixed 16-byte HMAC-MD5
        # reply from a legacy client; modern messages are always longer.
        return '', message
    if message.startswith(b'{'):
        # Only look for '}' where a name of at most _MAX_DIGEST_LEN bytes
        # could end.
        close_at = message.find(b'}', 1, _MAX_DIGEST_LEN + 2)
        if close_at > 0:
            name = message[1:close_at]
            if name in _ALLOWED_DIGESTS:
                return name.decode('ascii'), message[close_at + 1:]
    raise AuthenticationError(
            'unsupported message length, missing digest prefix, '
            f'or unsupported digest: {message=}')
873
874
def _create_response(authkey, message):
    """Compute the reply MAC for the challenge *message*.

    A legacy challenge (no '{digest_name}' prefix) is answered with a raw,
    unprefixed HMAC-MD5.  If HMAC-MD5 cannot be created (possibly FIPS
    mode), or the challenge names a digest, the reply is the HMAC using
    that digest (sha256 in the fallback case), prefixed with
    '{digest_name}', e.g. b'{sha256}abcdefg...'.

    The MAC is always computed over the whole message, including any
    '{digest_name}' prefix.
    """
    import hmac
    name, _payload = _get_digest_name_and_payload(message)
    if not name:
        # Unprefixed challenge from a legacy server: reply with bare MD5.
        try:
            return hmac.new(authkey, message, 'md5').digest()
        except ValueError:
            # MD5 unavailable.  Use the modern prefixed form instead; a
            # legacy server will most likely reject it anyway.
            name = 'sha256'
    mac = hmac.new(authkey, message, name).digest()
    return b''.join((b'{', name.encode('ascii'), b'}', mac))
901
902
def _verify_challenge(authkey, message, response):
    """Check that *response* is a valid MAC of our challenge *message*.

    An unprefixed response is treated as HMAC-MD5.  A '{digest_name}'
    prefixed response is checked with that digest, so a peer answering an
    unprefixed challenge may pick a stronger digest from _ALLOWED_DIGESTS.
    A prefixed challenge cannot be silently downgraded, because the MAC
    covers the entire message including its '{digest_name}' prefix.

    Raises AuthenticationError if the digest is unusable, the MAC length is
    wrong, or the MAC does not match.
    """
    import hmac
    name, received = _get_digest_name_and_payload(response)
    response_digest = name if name else 'md5'
    try:
        expected = hmac.new(authkey, message, response_digest).digest()
    except ValueError:
        raise AuthenticationError(f'{response_digest=} unsupported')
    if len(received) != len(expected):
        raise AuthenticationError(
                f'expected {response_digest!r} of length {len(expected)} '
                f'got {len(received)}')
    # Constant-time comparison to avoid leaking how many bytes matched.
    if hmac.compare_digest(expected, received):
        return
    raise AuthenticationError('digest received was wrong')
926
927
def deliver_challenge(connection, authkey: bytes, digest_name='sha256'):
    """Authenticate the peer on *connection* by sending it a challenge.

    Sends _CHALLENGE followed by b'{digest_name}' and MESSAGE_LENGTH random
    bytes, reads the reply (at most 256 bytes), and verifies it against
    *authkey*.  On success _WELCOME is sent.  On failure _FAILURE is sent
    and the AuthenticationError is re-raised.

    Raises ValueError if *authkey* is not bytes.
    """
    if not isinstance(authkey, bytes):
        raise ValueError(
            f"Authkey must be bytes, not {type(authkey)!s}")
    assert MESSAGE_LENGTH > _MD5ONLY_MESSAGE_LENGTH, "protocol constraint"
    nonce = os.urandom(MESSAGE_LENGTH)
    message = b'{%s}%s' % (digest_name.encode('ascii'), nonce)
    # A legacy client that does not understand digest prefixes treats the
    # whole prefixed message as the challenge and answers with raw HMAC-MD5.
    connection.send_bytes(_CHALLENGE + message)
    response = connection.recv_bytes(256)        # reject large message
    try:
        _verify_challenge(authkey, message, response)
    except AuthenticationError:
        connection.send_bytes(_FAILURE)
        raise
    connection.send_bytes(_WELCOME)
947
948
def answer_challenge(connection, authkey: bytes):
    """Answer a challenge sent by deliver_challenge() on *connection*.

    Reads the challenge (at most 256 bytes), replies with the MAC computed
    by _create_response(), then expects _WELCOME back.

    Raises ValueError if *authkey* is not bytes, and AuthenticationError on
    a malformed or too-short challenge or if the peer rejects the reply.
    """
    if not isinstance(authkey, bytes):
        raise ValueError(
            f"Authkey must be bytes, not {type(authkey)!s}")
    message = connection.recv_bytes(256)         # reject large message
    if not message.startswith(_CHALLENGE):
        raise AuthenticationError(
                f'Protocol error, expected challenge: {message=}')
    challenge = message[len(_CHALLENGE):]
    if len(challenge) < _MD5ONLY_MESSAGE_LENGTH:
        raise AuthenticationError(
                f'challenge too short: {len(challenge)} bytes')
    connection.send_bytes(_create_response(authkey, challenge))
    verdict = connection.recv_bytes(256)         # reject large message
    if verdict != _WELCOME:
        raise AuthenticationError('digest sent was rejected')
965
966#
967# Support for using xmlrpclib for serialization
968#
969
class ConnectionWrapper:
    """Wrap a connection so send()/recv() use custom (de)serializers.

    *dumps* turns an object into bytes for send(), and *loads* turns
    received bytes back into an object for recv().  The wrapped
    connection's fileno, close, poll, recv_bytes and send_bytes methods are
    exposed unchanged as attributes of the wrapper.
    """

    def __init__(self, conn, dumps, loads):
        self._conn = conn
        self._dumps = dumps
        self._loads = loads
        # Re-export the byte-level API of the wrapped connection as-is.
        for name in ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes'):
            setattr(self, name, getattr(conn, name))

    def send(self, obj):
        """Serialize *obj* with the dumps callable and send the bytes."""
        self._conn.send_bytes(self._dumps(obj))

    def recv(self):
        """Receive one message and deserialize it with the loads callable."""
        return self._loads(self._conn.recv_bytes())
984
def _xml_dumps(obj):
    """Serialize *obj* as UTF-8 encoded XML-RPC <params> data.

    None is permitted (allow_none=True), so it encodes as <nil/>.
    """
    # Import locally: otherwise this helper depends on the module-global
    # ``xmlrpclib`` having been bound by an earlier XmlListener.accept() or
    # XmlClient() call, and raises NameError when used before either.
    import xmlrpc.client as xmlrpclib
    return xmlrpclib.dumps((obj,), allow_none=True).encode('utf-8')
987
def _xml_loads(s):
    """Decode UTF-8 XML-RPC <params> bytes holding exactly one value.

    Returns that single value.  A payload with a different number of
    params raises ValueError from the tuple unpacking.
    """
    # Import locally: otherwise this helper depends on the module-global
    # ``xmlrpclib`` having been bound by an earlier XmlListener.accept() or
    # XmlClient() call, and raises NameError when used before either.
    import xmlrpc.client as xmlrpclib
    (obj,), _method = xmlrpclib.loads(s.decode('utf-8'))
    return obj
991
class XmlListener(Listener):
    """Listener whose accepted connections exchange XML-RPC encoded objects."""

    def accept(self):
        """Accept a connection and wrap it with the XML (de)serializers."""
        # Binds the module-level ``xmlrpclib`` name as a side effect.
        global xmlrpclib
        import xmlrpc.client as xmlrpclib
        raw_conn = Listener.accept(self)
        return ConnectionWrapper(raw_conn, _xml_dumps, _xml_loads)
998
def XmlClient(*args, **kwds):
    """Open a Client(*args, **kwds) connection that exchanges XML-RPC data."""
    # Binds the module-level ``xmlrpclib`` name as a side effect.
    global xmlrpclib
    import xmlrpc.client as xmlrpclib
    raw_conn = Client(*args, **kwds)
    return ConnectionWrapper(raw_conn, _xml_dumps, _xml_loads)
1003
1004#
1005# Wait
1006#
1007
1008if sys.platform == 'win32':
1009
1010    def _exhaustive_wait(handles, timeout):
1011        # Return ALL handles which are currently signalled.  (Only
1012        # returning the first signalled might create starvation issues.)
1013        L = list(handles)
1014        ready = []
1015        # Windows limits WaitForMultipleObjects at 64 handles, and we use a
1016        # few for synchronisation, so we switch to batched waits at 60.
1017        if len(L) > 60:
1018            try:
1019                res = _winapi.BatchedWaitForMultipleObjects(L, False, timeout)
1020            except TimeoutError:
1021                return []
1022            ready.extend(L[i] for i in res)
1023            if res:
1024                L = [h for i, h in enumerate(L) if i > res[0] & i not in res]
1025            timeout = 0
1026        while L:
1027            short_L = L[:60] if len(L) > 60 else L
1028            res = _winapi.WaitForMultipleObjects(short_L, False, timeout)
1029            if res == WAIT_TIMEOUT:
1030                break
1031            elif WAIT_OBJECT_0 <= res < WAIT_OBJECT_0 + len(L):
1032                res -= WAIT_OBJECT_0
1033            elif WAIT_ABANDONED_0 <= res < WAIT_ABANDONED_0 + len(L):
1034                res -= WAIT_ABANDONED_0
1035            else:
1036                raise RuntimeError('Should not get here')
1037            ready.append(L[res])
1038            L = L[res+1:]
1039            timeout = 0
1040        return ready
1041
    # winerror codes from the zero-length overlapped ReadFile /
    # GetOverlappedResult calls in wait() that are not raised.  The object
    # is reported as ready instead, presumably so the caller's next read
    # surfaces the broken pipe / deleted network name.
    _ready_errors = {_winapi.ERROR_BROKEN_PIPE, _winapi.ERROR_NETNAME_DELETED}

    def wait(object_list, timeout=None):
        '''
        Wait till an object in object_list is ready/readable.

        Returns list of those objects in object_list which are ready/readable.

        Each object either has a fileno() method returning a handle, on
        which a zero-length overlapped ReadFile is started, or is usable as
        a waitable handle via __index__().  timeout is in seconds: None
        waits forever, and negative values are treated as 0.  The result
        preserves the order of object_list.
        '''
        # Convert seconds to the integer milliseconds used by the Win32
        # wait functions, rounding to the nearest millisecond.
        if timeout is None:
            timeout = INFINITE
        elif timeout < 0:
            timeout = 0
        else:
            timeout = int(timeout * 1000 + 0.5)

        object_list = list(object_list)
        # Maps each handle/event actually waited on back to its object.
        waithandle_to_obj = {}
        # Overlapped reads still pending; they must be cancelled and reaped.
        ov_list = []
        ready_objects = set()
        ready_handles = set()

        try:
            for o in object_list:
                try:
                    fileno = getattr(o, 'fileno')
                except AttributeError:
                    # No fileno(): wait on the object itself as a handle.
                    waithandle_to_obj[o.__index__()] = o
                else:
                    # start an overlapped read of length zero
                    try:
                        ov, err = _winapi.ReadFile(fileno(), 0, True)
                    except OSError as e:
                        ov, err = None, e.winerror
                        if err not in _ready_errors:
                            raise
                    if err == _winapi.ERROR_IO_PENDING:
                        # Read is in flight: wait on its completion event.
                        ov_list.append(ov)
                        waithandle_to_obj[ov.event] = o
                    else:
                        # If o.fileno() is an overlapped pipe handle and
                        # err == 0 then there is a zero length message
                        # in the pipe, but it HAS NOT been consumed...
                        if ov and sys.getwindowsversion()[:2] >= (6, 2):
                            # ... except on Windows 8 and later, where
                            # the message HAS been consumed.
                            try:
                                _, err = ov.GetOverlappedResult(False)
                            except OSError as e:
                                err = e.winerror
                            if not err and hasattr(o, '_got_empty_message'):
                                o._got_empty_message = True
                        ready_objects.add(o)
                        # Something is already ready, so the remaining
                        # handles are only polled, not waited on.
                        timeout = 0

            ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout)
        finally:
            # Runs even if an error escaped above, so that no overlapped
            # read is left outstanding.
            # request that overlapped reads stop
            for ov in ov_list:
                ov.cancel()

            # wait for all overlapped reads to stop
            for ov in ov_list:
                try:
                    _, err = ov.GetOverlappedResult(True)
                except OSError as e:
                    err = e.winerror
                    if err not in _ready_errors:
                        raise
                # A read that finished rather than being aborted by cancel()
                # means its object is ready.
                if err != _winapi.ERROR_OPERATION_ABORTED:
                    o = waithandle_to_obj[ov.event]
                    ready_objects.add(o)
                    if err == 0:
                        # If o.fileno() is an overlapped pipe handle then
                        # a zero length message HAS been consumed.
                        if hasattr(o, '_got_empty_message'):
                            o._got_empty_message = True

        ready_objects.update(waithandle_to_obj[h] for h in ready_handles)
        # Report ready objects in the caller's original order.
        return [o for o in object_list if o in ready_objects]
1121
1122else:
1123
1124    import selectors
1125
1126    # poll/select have the advantage of not requiring any extra file
1127    # descriptor, contrarily to epoll/kqueue (also, they require a single
1128    # syscall).
1129    if hasattr(selectors, 'PollSelector'):
1130        _WaitSelector = selectors.PollSelector
1131    else:
1132        _WaitSelector = selectors.SelectSelector
1133
1134    def wait(object_list, timeout=None):
1135        '''
1136        Wait till an object in object_list is ready/readable.
1137
1138        Returns list of those objects in object_list which are ready/readable.
1139        '''
1140        with _WaitSelector() as selector:
1141            for obj in object_list:
1142                selector.register(obj, selectors.EVENT_READ)
1143
1144            if timeout is not None:
1145                deadline = time.monotonic() + timeout
1146
1147            while True:
1148                ready = selector.select(timeout)
1149                if ready:
1150                    return [key.fileobj for (key, events) in ready]
1151                else:
1152                    if timeout is not None:
1153                        timeout = deadline - time.monotonic()
1154                        if timeout < 0:
1155                            return ready
1156
1157#
1158# Make connection and socket objects shareable if possible
1159#
1160
1161if sys.platform == 'win32':
1162    def reduce_connection(conn):
1163        handle = conn.fileno()
1164        with socket.fromfd(handle, socket.AF_INET, socket.SOCK_STREAM) as s:
1165            from . import resource_sharer
1166            ds = resource_sharer.DupSocket(s)
1167            return rebuild_connection, (ds, conn.readable, conn.writable)
1168    def rebuild_connection(ds, readable, writable):
1169        sock = ds.detach()
1170        return Connection(sock.detach(), readable, writable)
1171    reduction.register(Connection, reduce_connection)
1172
1173    def reduce_pipe_connection(conn):
1174        access = ((_winapi.FILE_GENERIC_READ if conn.readable else 0) |
1175                  (_winapi.FILE_GENERIC_WRITE if conn.writable else 0))
1176        dh = reduction.DupHandle(conn.fileno(), access)
1177        return rebuild_pipe_connection, (dh, conn.readable, conn.writable)
1178    def rebuild_pipe_connection(dh, readable, writable):
1179        handle = dh.detach()
1180        return PipeConnection(handle, readable, writable)
1181    reduction.register(PipeConnection, reduce_pipe_connection)
1182
1183else:
1184    def reduce_connection(conn):
1185        df = reduction.DupFd(conn.fileno())
1186        return rebuild_connection, (df, conn.readable, conn.writable)
1187    def rebuild_connection(df, readable, writable):
1188        fd = df.detach()
1189        return Connection(fd, readable, writable)
1190    reduction.register(Connection, reduce_connection)
1191