• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2012, Google Inc.
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32
33"""WebSocket client utility for testing.
34
35This module contains helper methods for performing handshake, frame
36sending/receiving as a WebSocket client.
37
This code is for testing mod_pywebsocket, so keep it independent of
mod_pywebsocket: don't import, e.g., the Stream class to generate frames for
testing. Using helpers that are not related to protocol processing, such as
util.hexify, is allowed.
42
43Note:
44This code is far from robust, e.g., we cut corners in handshake.
45"""
46
47
48import base64
49import errno
50import logging
51import os
52import random
53import re
54import socket
55import struct
56
57from mod_pywebsocket import util
58
59
DEFAULT_PORT = 80
DEFAULT_SECURE_PORT = 443

# Opcodes introduced in IETF HyBi 01 for the new framing format, listed in
# numeric order: data frames (0x0-0x2) first, then control frames (0x8-0xa).
OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xa

# Header lines sent in the opening handshake. WebSocketHandshake sends the
# lowercase "websocket" token; the HyBi 00 and Hixie 75 handshakes send
# "WebSocket".
_UPGRADE_HEADER = 'Upgrade: websocket\r\n'
_UPGRADE_HEADER_HIXIE75 = 'Upgrade: WebSocket\r\n'
_CONNECTION_HEADER = 'Connection: Upgrade\r\n'

# GUID concatenated with the Sec-WebSocket-Key value to compute the expected
# Sec-WebSocket-Accept value.
WEBSOCKET_ACCEPT_UUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

# Close status codes (see RFC 6455 section 7.4).
STATUS_NORMAL_CLOSURE = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA = 1003
STATUS_NO_STATUS_RECEIVED = 1005
STATUS_ABNORMAL_CLOSURE = 1006
STATUS_INVALID_FRAME_PAYLOAD_DATA = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_MANDATORY_EXT = 1010
STATUS_INTERNAL_SERVER_ERROR = 1011
STATUS_TLS_HANDSHAKE = 1015

# Extension tokens matched against the Sec-WebSocket-Extensions response.
_DEFLATE_STREAM_EXTENSION = 'deflate-stream'
_DEFLATE_FRAME_EXTENSION = 'deflate-frame'
# TODO(bashi): Update after mux implementation finished.
_MUX_EXTENSION = 'mux_DO_NOT_USE'
97
def _method_line(resource):
    """Return the CRLF-terminated HTTP/1.1 GET request line for resource."""
    template = 'GET %s HTTP/1.1\r\n'
    return template % resource
100
101
def _sec_origin_header(origin):
    """Return a Sec-WebSocket-Origin header line carrying origin lowercased."""
    lowered_origin = origin.lower()
    return 'Sec-WebSocket-Origin: %s\r\n' % lowered_origin
104
105
def _origin_header(origin):
    """Return an Origin header line carrying origin lowercased."""
    # 4.1 13. concatenation of the string "Origin:", a U+0020 SPACE character,
    # and the /origin/ value, converted to ASCII lowercase, to /fields/.
    lowered_origin = origin.lower()
    return 'Origin: %s\r\n' % lowered_origin
110
111
def _format_host_header(host, port, secure):
    """Return the Host header line for the opening handshake.

    The host is lowercased. The port is appended only when it differs from
    the default for the scheme: DEFAULT_SECURE_PORT when secure is true,
    DEFAULT_PORT otherwise (4.1 steps 9-12).
    """
    default_port = DEFAULT_SECURE_PORT if secure else DEFAULT_PORT
    hostport = host.lower()
    if port != default_port:
        hostport = hostport + ':' + str(port)
    return 'Host: %s\r\n' % hostport
127
128
129# TODO(tyoshino): Define a base class and move these shared methods to that.
130
131
def receive_bytes(socket, length):
    """Read exactly length bytes from socket and return them as one string.

    recv() may return fewer bytes than requested, so this keeps reading
    until the requested amount has been collected.

    Raises:
        Exception: recv() returned an empty string (the peer closed the
            connection) before length bytes were received.
    """
    chunks = []
    received_total = 0
    while received_total < length:
        chunk = socket.recv(length - received_total)
        if not chunk:
            raise Exception(
                'Connection closed before receiving requested length '
                '(requested %d bytes but received only %d bytes)' %
                (length, received_total))
        chunks.append(chunk)
        received_total += len(chunk)
    return ''.join(chunks)
145
146
# TODO(tyoshino): The WebSocketHandshake class currently reuses these
# methods. It should switch to an HTTP parser as specified in RFC 6455. For
# HyBi 00 and Hixie 75, package these methods into a parser class.
150
151
def _read_fields(socket):
    """Read response header fields until the empty line ending the headers.

    Field names come from _read_name(), which lowercases ASCII A-Z. A name
    may occur more than once, so each name maps to the list of its values in
    the order received. Values are stored as the raw bytes read; they are
    not decoded.

    Reading stops after the CR of the terminating empty line. The LF that
    follows it is left on the socket for the caller to consume.

    Returns:
        dict mapping field name to a list of value strings.

    Raises:
        Exception: a field line is malformed.
    """
    # 4.1 32. let /fields/ be a list of name-value pairs, initially empty.
    fields = {}
    while True:
        # 4.1 33-34. read /name/; None means the empty line was reached.
        name = _read_name(socket)
        if name is None:
            break
        # 4.1 35. read spaces
        # TODO(tyoshino): Skip only one space as described in the spec.
        ch = _skip_spaces(socket)
        # 4.1 36. read /value/ (starting with the first non-space byte)
        value = _read_value(socket, ch)
        # 4.1 37. read a byte from the server; it must be the LF after CR.
        ch = receive_bytes(socket, 1)
        if ch != '\n':  # 0x0A
            raise Exception(
                'Expected LF but found %r while reading value %r for header '
                '%r' % (ch, value, name))
        # 4.1 38. append an entry to the /fields/ list that has the name
        # given by the string obtained by interpreting the /name/ byte
        # array as a UTF-8 stream and the value given by the string
        # obtained by interpreting the /value/ byte array as a UTF-8 byte
        # stream.
        fields.setdefault(name, []).append(value)
        # 4.1 39. return to the "Field" step above
    return fields
182
183
def _read_name(socket):
    """Read one header field name from socket, one byte at a time.

    ASCII uppercase letters (A-Z) are lowercased; all other bytes are kept
    as-is.

    Returns:
        The name without its trailing colon. Returns None if a CR is read
        before any name byte, i.e. the empty line ending the headers.

    Raises:
        Exception: an LF is read before the colon, or a CR is read after
            some name bytes.
    """
    # 4.1 33. let /name/ be an empty byte array
    name = ''
    while True:
        # 4.1 34. read a byte from the server
        ch = receive_bytes(socket, 1)
        if ch == '\r':  # 0x0D
            # A CR is only valid at the start of the (empty) final line;
            # after some name bytes it means the line has no colon.
            if name:
                raise Exception(
                    'Unexpected CR when reading header name %r' % name)
            return None
        elif ch == '\n':  # 0x0A
            raise Exception(
                'Unexpected LF when reading header name %r' % name)
        elif ch == ':':  # 0x3A
            return name
        elif 'A' <= ch <= 'Z':  # range 0x41 to 0x5A
            name += chr(ord(ch) + 0x20)
        else:
            name += ch
202
203
def _skip_spaces(socket):
    """Consume 0x20 bytes from socket and return the first other byte."""
    # 4.1 35. read a byte from the server
    ch = receive_bytes(socket, 1)
    while ch == ' ':  # 0x20
        ch = receive_bytes(socket, 1)
    return ch
211
212
def _read_value(socket, ch):
    """Read a header field value whose first byte, ch, was already read.

    Bytes are collected until a CR. The CR is consumed but not included in
    the result; the byte after it is left on the socket.

    Raises:
        Exception: an LF is found before the CR.
    """
    # 4.1 33. let /value/ be an empty byte array
    collected = []
    # 4.1 36. read a byte from server.
    while ch != '\r':  # 0x0D
        if ch == '\n':  # 0x0A
            raise Exception(
                'Unexpected LF when reading header value %r' %
                ''.join(collected))
        collected.append(ch)
        ch = receive_bytes(socket, 1)
    return ''.join(collected)
226
227
def read_frame_header(socket):
    """Read the header of one frame sent by the server.

    Reads the two fixed header bytes, then the 16-bit or 64-bit extended
    payload length when the 7-bit length field is 126 or 127. Non-minimal
    length encodings are not rejected.

    Returns:
        Tuple (fin, rsv1, rsv2, rsv3, opcode, payload_length).

    Raises:
        Exception: the MASK bit is set, or the 64-bit length has its most
            significant bit set.
    """
    header = receive_bytes(socket, 2)
    first, second = ord(header[0]), ord(header[1])

    fin, rsv1, rsv2, rsv3 = [(first >> bit) & 1 for bit in (7, 6, 5, 4)]
    opcode = first & 0xf

    if (second >> 7) & 1:
        raise Exception(
            'Mask bit must be 0 for frames coming from server')
    payload_length = second & 0x7f

    if payload_length == 127:
        (payload_length,) = struct.unpack('!Q', receive_bytes(socket, 8))
        if payload_length > 0x7FFFFFFFFFFFFFFF:
            raise Exception('Extended payload length >= 2^63')
    elif payload_length == 126:
        (payload_length,) = struct.unpack('!H', receive_bytes(socket, 2))

    return fin, rsv1, rsv2, rsv3, opcode, payload_length
258
259
class _TLSSocket(object):
    """Wrapper for a TLS connection.

    Exposes the part of the socket interface used in this module (send,
    sendall, recv, close) on top of the object returned by socket.ssl().
    """

    def __init__(self, raw_socket):
        # NOTE(review): socket.ssl() is deprecated since Python 2.6 and is
        # gone in Python 3; the ssl module is the replacement.
        self._ssl = socket.ssl(raw_socket)

    def send(self, bytes):
        """Write bytes once and return the result of the TLS write()."""
        return self._ssl.write(bytes)

    def sendall(self, data):
        """Write all of data, issuing further writes after partial ones.

        The handshake classes and WebSocketStream call sendall() on the
        socket they are given, so this wrapper must provide it.

        Raises:
            Exception: a write reported that nothing was written.
        """
        while data:
            written = self._ssl.write(data)
            if not written:
                raise Exception('TLS write made no progress')
            data = data[written:]

    def recv(self, size=-1):
        return self._ssl.read(size)

    def close(self):
        # Nothing to do.
        pass
275
276
class HttpStatusException(Exception):
    """Raised when the handshake response has an unexpected HTTP status code.

    Attributes:
        status: the status code given by the raiser (this module passes
            int(code)).
    """

    def __init__(self, name, status):
        self.status = status
        super(HttpStatusException, self).__init__(name)
285
286
class WebSocketHandshake(object):
    """Opening handshake processor for the WebSocket protocol (RFC 6455)."""

    def __init__(self, options):
        self._logger = util.get_class_logger(self)

        self._options = options

    @staticmethod
    def _get_single_field(fields, name, display_name):
        """Return the only value of response header name in fields.

        Args:
            fields: dict mapping lowercased header name to a list of values,
                as built by _read_fields().
            name: lowercased header name to look up.
            display_name: header name used in error messages.

        Raises:
            Exception: the header is missing or occurs more than once.
        """
        values = fields.get(name)
        if not values:
            raise Exception('No %s header found' % display_name)
        if len(values) != 1:
            raise Exception(
                'Multiple %s headers found: %s' % (display_name, values))
        return values[0]

    def handshake(self, socket):
        """Handshake WebSocket.

        Sends the opening handshake request built from self._options, then
        validates the response:
        - the status code must be 101;
        - Upgrade and Connection must each appear exactly once with the
          expected values;
        - Sec-WebSocket-Accept must match the key that was sent;
        - the accepted extensions must agree with the use_deflate_stream,
          use_deflate_frame and use_mux options.

        Args:
            socket: connected socket-like object; sendall() and recv() are
                used.

        Raises:
            HttpStatusException: the status code has three digits but is
                not 101.
            Exception: handshake failed for any other reason.
        """

        self._socket = socket

        request_line = _method_line(self._options.resource)
        self._logger.debug('Opening handshake Request-Line: %r', request_line)
        self._socket.sendall(request_line)

        fields = []
        fields.append(_UPGRADE_HEADER)
        fields.append(_CONNECTION_HEADER)

        fields.append(_format_host_header(
            self._options.server_host,
            self._options.server_port,
            self._options.use_tls))

        # Version 8 sends the origin as Sec-WebSocket-Origin and every other
        # version uses Origin. Compare with == because int identity (`is`)
        # is an interpreter implementation detail.
        if self._options.version == 8:
            fields.append(_sec_origin_header(self._options.origin))
        else:
            fields.append(_origin_header(self._options.origin))

        original_key = os.urandom(16)
        key = base64.b64encode(original_key)
        self._logger.debug(
            'Sec-WebSocket-Key: %s (%s)', key, util.hexify(original_key))
        fields.append('Sec-WebSocket-Key: %s\r\n' % key)

        fields.append('Sec-WebSocket-Version: %d\r\n' % self._options.version)

        # Setting up extensions.
        if self._options.extensions:
            fields.append('Sec-WebSocket-Extensions: %s\r\n' %
                          ', '.join(self._options.extensions))

        self._logger.debug('Opening handshake request headers: %r', fields)

        for field in fields:
            self._socket.sendall(field)
        # Empty line terminating the request headers.
        self._socket.sendall('\r\n')

        self._logger.info('Sent opening handshake request')

        # Read the status line up to and including its LF.
        field = ''
        while True:
            ch = receive_bytes(self._socket, 1)
            field += ch
            if ch == '\n':
                break

        self._logger.debug('Opening handshake Response-Line: %r', field)

        if len(field) < 7 or not field.endswith('\r\n'):
            raise Exception('Wrong status line: %r' % field)
        m = re.match('[^ ]* ([^ ]*) .*', field)
        if m is None:
            raise Exception(
                'No HTTP status code found in status line: %r' % field)
        code = m.group(1)
        # re.match() anchors only at the start; \Z also rejects codes with
        # more than three digits, such as '1010'.
        if not re.match(r'[0-9]{3}\Z', code):
            raise Exception(
                'HTTP status code %r is not three digit in status line: %r' %
                (code, field))
        if code != '101':
            raise HttpStatusException(
                'Expected HTTP status code 101 but found %r in status line: '
                '%r' % (code, field), int(code))
        fields = _read_fields(self._socket)
        # _read_fields() stops after the CR of the empty line ending the
        # headers; consume the LF that follows it.
        ch = receive_bytes(self._socket, 1)
        if ch != '\n':  # 0x0A
            raise Exception('Expected LF but found: %r' % ch)

        self._logger.debug('Opening handshake response headers: %r', fields)

        # Check /fields/
        upgrade = self._get_single_field(fields, 'upgrade', 'Upgrade')
        connection = self._get_single_field(
            fields, 'connection', 'Connection')
        # NOTE(review): this is an exact match, while RFC 6455 section 4.1
        # only requires an ASCII case-insensitive match for "websocket".
        if upgrade != 'websocket':
            raise Exception(
                'Unexpected Upgrade header value: %s' % upgrade)
        if connection.lower() != 'upgrade':
            raise Exception(
                'Unexpected Connection header value: %s' % connection)

        accept = self._get_single_field(
            fields, 'sec-websocket-accept', 'Sec-WebSocket-Accept')

        # Validate
        try:
            decoded_accept = base64.b64decode(accept)
        except (TypeError, ValueError):
            # Python 2 raises TypeError for malformed base64 input; Python 3
            # raises binascii.Error, a subclass of ValueError.
            raise Exception(
                'Illegal value for header Sec-WebSocket-Accept: ' + accept)

        # The accept value is a base64-encoded SHA-1 digest (20 bytes).
        if len(decoded_accept) != 20:
            raise Exception(
                'Decoded value of Sec-WebSocket-Accept is not 20-byte long')

        self._logger.debug('Actual Sec-WebSocket-Accept: %r (%s)',
                           accept, util.hexify(decoded_accept))

        original_expected_accept = util.sha1_hash(
            key + WEBSOCKET_ACCEPT_UUID).digest()
        expected_accept = base64.b64encode(original_expected_accept)

        self._logger.debug('Expected Sec-WebSocket-Accept: %r (%s)',
                           expected_accept,
                           util.hexify(original_expected_accept))

        if accept != expected_accept:
            raise Exception(
                'Invalid Sec-WebSocket-Accept header: %r (expected) != %r '
                '(actual)' % (expected_accept, accept))

        # A missing header, or more than one Sec-WebSocket-Extensions header,
        # is treated as "no extensions accepted".
        server_extensions_header = fields.get('sec-websocket-extensions')
        if (server_extensions_header is None or
            len(server_extensions_header) != 1):
            accepted_extensions = []
        else:
            accepted_extensions = server_extensions_header[0].split(',')
            # TODO(tyoshino): Follow the ABNF in the spec.
            accepted_extensions = [s.strip() for s in accepted_extensions]

        # Scan the accepted extension list for unrecognized extensions and
        # extensions we didn't request, and record which requested ones were
        # accepted.
        deflate_stream_accepted = False
        deflate_frame_accepted = False
        mux_accepted = False
        for extension in accepted_extensions:
            if extension == '':
                continue
            if extension == _DEFLATE_STREAM_EXTENSION:
                if self._options.use_deflate_stream:
                    deflate_stream_accepted = True
                    continue
            if extension == _DEFLATE_FRAME_EXTENSION:
                if self._options.use_deflate_frame:
                    deflate_frame_accepted = True
                    continue
            if extension == _MUX_EXTENSION:
                if self._options.use_mux:
                    mux_accepted = True
                    continue

            raise Exception(
                'Received unrecognized extension: %s' % extension)

        # Every extension we enabled must have been accepted.

        if self._options.use_deflate_stream and not deflate_stream_accepted:
            raise Exception('%s extension not accepted' %
                            _DEFLATE_STREAM_EXTENSION)

        if (self._options.use_deflate_frame and
            not deflate_frame_accepted):
            raise Exception('%s extension not accepted' %
                            _DEFLATE_FRAME_EXTENSION)

        if self._options.use_mux and not mux_accepted:
            raise Exception('%s extension not accepted' % _MUX_EXTENSION)
471
472
class WebSocketHybi00Handshake(object):
    """Opening handshake processor for the WebSocket protocol version HyBi 00.

    The "4.1 N." comments refer to the numbered client handshake steps in
    section 4.1 of the HyBi 00 draft.
    """

    def __init__(self, options, draft_field):
        self._logger = util.get_class_logger(self)

        self._options = options
        # Value sent in the Sec-WebSocket-Draft request header.
        self._draft_field = draft_field

    @staticmethod
    def _get_single_field(fields, name, display_name):
        """Return the only value of response header name in fields.

        Args:
            fields: dict mapping lowercased header name to a list of values,
                as built by _read_fields().
            name: lowercased header name to look up.
            display_name: header name used in error messages.

        Raises:
            Exception: the header is missing or occurs more than once.
        """
        values = fields.get(name)
        if not values:
            raise Exception('No %s header found' % display_name)
        if len(values) != 1:
            raise Exception(
                'Multiple %s headers found: %s' % (display_name, values))
        return values[0]

    def handshake(self, socket):
        """Handshake WebSocket.

        Sends the request (including the Sec-WebSocket-Key1/Key2 fields and
        key3), validates the status line and header fields, and compares the
        16-byte challenge response with the MD5 digest of the challenge.

        Args:
            socket: connected socket-like object; sendall() and recv() are
                used.

        Raises:
            HttpStatusException: the status code has three digits but is
                not 101.
            Exception: handshake failed for any other reason.
        """

        self._socket = socket

        # 4.1 5. send request line.
        request_line = _method_line(self._options.resource)
        self._logger.debug('Opening handshake Request-Line: %r', request_line)
        self._socket.sendall(request_line)
        # 4.1 6. Let /fields/ be an empty list of strings.
        fields = []
        # 4.1 7. Add the string "Upgrade: WebSocket" to /fields/.
        fields.append(_UPGRADE_HEADER_HIXIE75)
        # 4.1 8. Add the string "Connection: Upgrade" to /fields/.
        fields.append(_CONNECTION_HEADER)
        # 4.1 9-12. Add Host: field to /fields/.
        fields.append(_format_host_header(
            self._options.server_host,
            self._options.server_port,
            self._options.use_tls))
        # 4.1 13. Add Origin: field to /fields/.
        fields.append(_origin_header(self._options.origin))
        # TODO: 4.1 14 Add Sec-WebSocket-Protocol: field to /fields/.
        # TODO: 4.1 15 Add cookie headers to /fields/.

        # 4.1 16-23. Add Sec-WebSocket-Key<n> to /fields/.
        self._number1, key1 = self._generate_sec_websocket_key()
        self._logger.debug('Number1: %d', self._number1)
        fields.append('Sec-WebSocket-Key1: %s\r\n' % key1)
        self._number2, key2 = self._generate_sec_websocket_key()
        self._logger.debug('Number2: %d', self._number2)
        fields.append('Sec-WebSocket-Key2: %s\r\n' % key2)

        fields.append('Sec-WebSocket-Draft: %s\r\n' % self._draft_field)

        # 4.1 24. For each string in /fields/, in a random order: send the
        # string, encoded as UTF-8, followed by a UTF-8 encoded U+000D CARRIAGE
        # RETURN U+000A LINE FEED character pair (CRLF).
        random.shuffle(fields)

        self._logger.debug('Opening handshake request headers: %r', fields)
        for field in fields:
            self._socket.sendall(field)

        # 4.1 25. send a UTF-8-encoded U+000D CARRIAGE RETURN U+000A LINE FEED
        # character pair (CRLF).
        self._socket.sendall('\r\n')
        # 4.1 26. let /key3/ be a string consisting of eight random bytes (or
        # equivalently, a random 64 bit integer encoded in a big-endian order).
        self._key3 = self._generate_key3()
        # 4.1 27. send /key3/ to the server.
        self._socket.sendall(self._key3)
        self._logger.debug(
            'Key3: %r (%s)', self._key3, util.hexify(self._key3))

        self._logger.info('Sent opening handshake request')

        # 4.1 28. Read bytes from the server until either the connection
        # closes, or a 0x0A byte is read. let /field/ be these bytes, including
        # the 0x0A bytes.
        field = ''
        while True:
            ch = receive_bytes(self._socket, 1)
            field += ch
            if ch == '\n':
                break

        self._logger.debug('Opening handshake Response-Line: %r', field)

        # if /field/ is not at least seven bytes long, or if the last
        # two bytes aren't 0x0D and 0x0A respectively, or if it does not
        # contain at least two 0x20 bytes, then fail the WebSocket connection
        # and abort these steps.
        if len(field) < 7 or not field.endswith('\r\n'):
            raise Exception('Wrong status line: %r' % field)
        m = re.match('[^ ]* ([^ ]*) .*', field)
        if m is None:
            raise Exception('No code found in status line: %r' % field)
        # 4.1 29. let /code/ be the substring of /field/ that starts from the
        # byte after the first 0x20 byte, and ends with the byte before the
        # second 0x20 byte.
        code = m.group(1)
        # 4.1 30. if /code/ is not three bytes long, or if any of the bytes in
        # /code/ are not in the range 0x30 to 0x39, then fail the WebSocket
        # connection and abort these steps. (\Z enforces the length.)
        if not re.match(r'[0-9]{3}\Z', code):
            raise Exception(
                'HTTP status code %r is not three digit in status line: %r' %
                (code, field))
        # 4.1 31. if /code/, interpreted as UTF-8, is "101", then move to the
        # next step.
        if code != '101':
            raise HttpStatusException(
                'Expected HTTP status code 101 but found %r in status line: '
                '%r' % (code, field), int(code))
        # 4.1 32-39. read fields into /fields/
        fields = _read_fields(self._socket)

        self._logger.debug('Opening handshake response headers: %r', fields)

        # 4.1 40. _Fields processing_
        # read a byte from server
        ch = receive_bytes(self._socket, 1)
        if ch != '\n':  # 0x0A
            raise Exception('Expected LF but found %r' % ch)
        # 4.1 41. check /fields/
        upgrade = self._get_single_field(fields, 'upgrade', 'Upgrade')
        connection = self._get_single_field(
            fields, 'connection', 'Connection')
        self._get_single_field(
            fields, 'sec-websocket-origin', 'Sec-WebSocket-Origin')
        self._get_single_field(
            fields, 'sec-websocket-location', 'Sec-WebSocket-Location')
        # TODO(ukai): protocol
        # if the entry's name is "upgrade"
        #  if the value is not exactly equal to the string "WebSocket",
        #  then fail the WebSocket connection and abort these steps.
        if upgrade != 'WebSocket':
            raise Exception(
                'Unexpected Upgrade header value: %s' % upgrade)
        # if the entry's name is "connection"
        #  if the value, converted to ASCII lowercase, is not exactly equal
        #  to the string "upgrade", then fail the WebSocket connection and
        #  abort these steps.
        if connection.lower() != 'upgrade':
            raise Exception(
                'Unexpected Connection header value: %s' % connection)
        # TODO(ukai): check origin, location, cookie, ..

        # 4.1 42. let /challenge/ be the concatenation of /number_1/,
        # expressed as a big endian 32 bit integer, /number_2/, expressed
        # as big endian 32 bit integer, and the eight bytes of /key_3/ in the
        # order they were sent on the wire.
        challenge = struct.pack('!I', self._number1)
        challenge += struct.pack('!I', self._number2)
        challenge += self._key3

        self._logger.debug(
            'Challenge: %r (%s)', challenge, util.hexify(challenge))

        # 4.1 43. let /expected/ be the MD5 fingerprint of /challenge/ as a
        # big-endian 128 bit string.
        expected = util.md5_hash(challenge).digest()
        self._logger.debug(
            'Expected challenge response: %r (%s)',
            expected, util.hexify(expected))

        # 4.1 44. read sixteen bytes from the server.
        # let /reply/ be those bytes.
        reply = receive_bytes(self._socket, 16)
        self._logger.debug(
            'Actual challenge response: %r (%s)', reply, util.hexify(reply))

        # 4.1 45. if /reply/ does not exactly equal /expected/, then fail
        # the WebSocket connection and abort these steps.
        if expected != reply:
            raise Exception(
                'Bad challenge response: %r (expected) != %r (actual)' %
                (expected, reply))
        # 4.1 46. The *WebSocket connection is established*.

    def _generate_sec_websocket_key(self):
        """Return (number, key) for one Sec-WebSocket-Key<n> field.

        key is number * spaces in decimal, with random non-digit characters
        and spaces spaces inserted (4.1 steps 16-22).
        """
        # 4.1 16. let /spaces_n/ be a random integer from 1 to 12 inclusive.
        spaces = random.randint(1, 12)
        # 4.1 17. let /max_n/ be the largest integer not greater than
        #  4,294,967,295 divided by /spaces_n/. Floor division keeps this an
        #  int on Python 3 as well.
        maxnum = 4294967295 // spaces
        # 4.1 18. let /number_n/ be a random integer from 0 to /max_n/
        # inclusive.
        number = random.randint(0, maxnum)
        # 4.1 19. let /product_n/ be the result of multiplying /number_n/ and
        # /spaces_n/ together.
        product = number * spaces
        # 4.1 20. let /key_n/ be a string consisting of /product_n/, expressed
        # in base ten using the numerals in the range U+0030 DIGIT ZERO (0) to
        # U+0039 DIGIT NINE (9).
        key = str(product)
        # 4.1 21. insert between one and twelve random characters from the
        # range U+0021 to U+002F and U+003A to U+007E into /key_n/ at random
        # positions.
        available_chars = (list(range(0x21, 0x2f + 1)) +
                           list(range(0x3a, 0x7e + 1)))
        n = random.randint(1, 12)
        for _ in range(n):
            ch = random.choice(available_chars)
            pos = random.randint(0, len(key))
            key = key[0:pos] + chr(ch) + key[pos:]
        # 4.1 22. insert /spaces_n/ U+0020 SPACE characters into /key_n/ at
        # random positions other than start or end of the string.
        for _ in range(spaces):
            pos = random.randint(1, len(key) - 1)
            key = key[0:pos] + ' ' + key[pos:]
        return number, key

    def _generate_key3(self):
        # 4.1 26. let /key3/ be a string consisting of eight random bytes (or
        # equivalently, a random 64 bit integer encoded in a big-endian order).
        return ''.join([chr(random.randint(0, 255)) for _ in range(8)])
692
693
class WebSocketHixie75Handshake(object):
    """WebSocket handshake processor for IETF Hixie 75."""

    # The response must begin with exactly these bytes; any header lines
    # after them are skipped without being checked.
    _EXPECTED_RESPONSE = ''.join([
        'HTTP/1.1 101 Web Socket Protocol Handshake\r\n',
        _UPGRADE_HEADER_HIXIE75,
        _CONNECTION_HEADER,
    ])

    def __init__(self, options):
        self._logger = util.get_class_logger(self)
        self._options = options

    def _skip_headers(self):
        """Consume bytes from the socket up to and including CRLF CRLF."""
        terminator = '\r\n\r\n'
        matched = 0
        while matched != len(terminator):
            ch = receive_bytes(self._socket, 1)
            if ch == terminator[matched]:
                matched += 1
            else:
                # After a mismatch, only a CR can begin a new partial match.
                matched = 1 if ch == terminator[0] else 0

    def handshake(self, socket):
        """Send a Hixie 75 opening handshake and check the server's reply.

        Raises:
            Exception: the reply does not start with _EXPECTED_RESPONSE, or
                the connection closes early.
        """
        self._socket = socket

        request_line = _method_line(self._options.resource)
        self._logger.debug('Opening handshake Request-Line: %r', request_line)
        self._socket.sendall(request_line)

        host_header = _format_host_header(
            self._options.server_host,
            self._options.server_port,
            self._options.use_tls)
        headers = ''.join([
            _UPGRADE_HEADER_HIXIE75,
            _CONNECTION_HEADER,
            host_header,
            _origin_header(self._options.origin),
        ])
        self._logger.debug('Opening handshake request headers: %r', headers)
        self._socket.sendall(headers)
        self._socket.sendall('\r\n')

        self._logger.info('Sent opening handshake request')

        for expected_char in WebSocketHixie75Handshake._EXPECTED_RESPONSE:
            if receive_bytes(self._socket, 1) != expected_char:
                raise Exception('Handshake failure')
        # We cut corners and skip other headers.
        self._skip_headers()
745
746
class WebSocketStream(object):
    """Frame processor for the WebSocket protocol (RFC 6455).

    Sends client frames (masked by default) and asserts on frames received
    from the server. When the handshake options enable deflate-stream the
    socket is wrapped in util.DeflateSocket; when they enable deflate-frame
    the application data of data frames goes through
    util._RFC1979Deflater / util._RFC1979Inflater and RSV1 is set.
    """

    def __init__(self, socket, handshake):
        """Initializes the stream.

        Args:
            socket: connected socket-like object frames are written to and
                read from.
            handshake: handshake object whose _options (presumably the
                ClientOptions used for the connection) select deflate-stream
                / deflate-frame processing.
        """
        self._handshake = handshake
        if self._handshake._options.use_deflate_stream:
            self._socket = util.DeflateSocket(socket)
        else:
            self._socket = socket

        # Filters applied to application data part of data frames.
        self._outgoing_frame_filter = None
        self._incoming_frame_filter = None

        if self._handshake._options.use_deflate_frame:
            self._outgoing_frame_filter = (
                util._RFC1979Deflater(None, False))
            self._incoming_frame_filter = util._RFC1979Inflater()

        # True while a fragmented message is being sent (the previous frame
        # had FIN=0), so the next frame must use OPCODE_CONTINUATION.
        self._fragmented = False

    def _mask_hybi(self, s):
        """Masks s with a fresh random 4-byte key.

        Returns the masking key followed by the masked bytes, i.e. exactly
        what follows the length field of a masked frame.
        """
        # TODO(tyoshino): os.urandom does open/read/close for every call. If
        # performance matters, change this to some library call that generates
        # cryptographically secure pseudo random number sequence.
        masking_nonce = os.urandom(4)
        result = [masking_nonce]
        count = 0
        for c in s:
            result.append(chr(ord(c) ^ ord(masking_nonce[count])))
            count = (count + 1) % len(masking_nonce)
        return ''.join(result)

    def send_frame_of_arbitrary_bytes(self, header, body):
        """Sends header verbatim, then the masking key and the masked body."""
        self._socket.sendall(header + self._mask_hybi(body))

    def send_data(self, payload, frame_type, end=True, mask=True):
        """Sends payload as a single data frame.

        Args:
            payload: application data bytes (already encoded).
            frame_type: opcode for the first frame of a message. While a
                fragmented message is in progress (a previous call passed
                end=False), OPCODE_CONTINUATION is used instead.
            end: whether this frame finishes the message (FIN bit).
            mask: whether to mask the payload.

        Raises:
            Exception: if the (filtered) payload is 2**63 bytes or longer.
        """
        if self._outgoing_frame_filter is not None:
            payload = self._outgoing_frame_filter.filter(payload)

        if self._fragmented:
            opcode = OPCODE_CONTINUATION
        else:
            opcode = frame_type

        if end:
            self._fragmented = False
            fin = 1
        else:
            self._fragmented = True
            fin = 0

        # With deflate-frame enabled every frame sent here carries RSV1.
        rsv1 = 0
        if self._handshake._options.use_deflate_frame:
            rsv1 = 1

        if mask:
            mask_bit = 1 << 7
        else:
            mask_bit = 0

        header = chr(fin << 7 | rsv1 << 6 | opcode)
        payload_length = len(payload)
        # Length encoding: 7-bit value, or 126 + 16-bit, or 127 + 64-bit.
        if payload_length <= 125:
            header += chr(mask_bit | payload_length)
        elif payload_length < 1 << 16:
            header += chr(mask_bit | 126) + struct.pack('!H', payload_length)
        elif payload_length < 1 << 63:
            header += chr(mask_bit | 127) + struct.pack('!Q', payload_length)
        else:
            raise Exception('Too long payload (%d byte)' % payload_length)
        if mask:
            payload = self._mask_hybi(payload)
        self._socket.sendall(header + payload)

    def send_binary(self, payload, end=True, mask=True):
        """Sends payload bytes as a binary frame."""
        self.send_data(payload, OPCODE_BINARY, end, mask)

    def send_text(self, payload, end=True, mask=True):
        """Sends payload, UTF-8 encoded, as a text frame."""
        self.send_data(payload.encode('utf-8'), OPCODE_TEXT, end, mask)

    def _assert_receive_data(self, payload, opcode, fin, rsv1, rsv2, rsv3):
        """Reads one frame and checks its header bits and payload.

        An rsv1/rsv2/rsv3 of None selects the default: RSV1 is expected to
        be 1 when deflate-frame is enabled and 0 otherwise; RSV2 and RSV3
        default to 0. The received payload goes through the incoming frame
        filter (if any) before it is compared with payload.

        Raises:
            Exception: on the first mismatch found.
        """
        (actual_fin, actual_rsv1, actual_rsv2, actual_rsv3, actual_opcode,
         payload_length) = read_frame_header(self._socket)

        if actual_opcode != opcode:
            raise Exception(
                'Unexpected opcode: %d (expected) vs %d (actual)' %
                (opcode, actual_opcode))

        if actual_fin != fin:
            raise Exception(
                'Unexpected fin: %d (expected) vs %d (actual)' %
                (fin, actual_fin))

        if rsv1 is None:
            rsv1 = 0
            if self._handshake._options.use_deflate_frame:
                rsv1 = 1

        if rsv2 is None:
            rsv2 = 0

        if rsv3 is None:
            rsv3 = 0

        if actual_rsv1 != rsv1:
            raise Exception(
                'Unexpected rsv1: %r (expected) vs %r (actual)' %
                (rsv1, actual_rsv1))

        if actual_rsv2 != rsv2:
            raise Exception(
                'Unexpected rsv2: %r (expected) vs %r (actual)' %
                (rsv2, actual_rsv2))

        if actual_rsv3 != rsv3:
            raise Exception(
                'Unexpected rsv3: %r (expected) vs %r (actual)' %
                (rsv3, actual_rsv3))

        received = receive_bytes(self._socket, payload_length)

        if self._incoming_frame_filter is not None:
            received = self._incoming_frame_filter.filter(received)

        if len(received) != len(payload):
            raise Exception(
                'Unexpected payload length: %d (expected) vs %d (actual)' %
                (len(payload), len(received)))

        if payload != received:
            raise Exception(
                'Unexpected payload: %r (expected) vs %r (actual)' %
                (payload, received))

    def assert_receive_binary(self, payload, opcode=OPCODE_BINARY, fin=1,
                              rsv1=None, rsv2=None, rsv3=None):
        """Asserts that the next frame carries payload bytes."""
        self._assert_receive_data(payload, opcode, fin, rsv1, rsv2, rsv3)

    def assert_receive_text(self, payload, opcode=OPCODE_TEXT, fin=1,
                            rsv1=None, rsv2=None, rsv3=None):
        """Asserts that the next frame carries payload encoded as UTF-8."""
        self._assert_receive_data(payload.encode('utf-8'), opcode, fin, rsv1,
                                  rsv2, rsv3)

    def _build_close_frame(self, code, reason, mask):
        """Builds a close frame (FIN set, opcode OPCODE_CLOSE).

        Args:
            code: status code, or None for an empty body (reason is then
                ignored).
            reason: text appended after the code, UTF-8 encoded.
            mask: whether to mask the body.

        Raises:
            Exception: if the body exceeds 125 bytes. That is the limit for
                control frames (RFC 6455 section 5.5), and the single length
                byte built here cannot describe anything longer.
        """
        frame = chr(1 << 7 | OPCODE_CLOSE)

        if code is not None:
            body = struct.pack('!H', code) + reason.encode('utf-8')
        else:
            body = ''
        # Without this check a longer body would silently corrupt the length
        # byte (126/127 are extended-length markers, >127 hits the mask bit).
        if len(body) > 125:
            raise Exception(
                'Too long close frame body (%d byte)' % len(body))
        if mask:
            frame += chr(1 << 7 | len(body)) + self._mask_hybi(body)
        else:
            frame += chr(len(body)) + body
        return frame

    def send_close(self, code, reason):
        """Sends a masked close frame with the given code and reason."""
        self._socket.sendall(
            self._build_close_frame(code, reason, True))

    def assert_receive_close(self, code, reason):
        """Asserts that the next bytes are an unmasked close frame.

        The frame must match the given code and reason byte for byte.
        """
        expected_frame = self._build_close_frame(code, reason, False)
        actual_frame = receive_bytes(self._socket, len(expected_frame))
        if actual_frame != expected_frame:
            raise Exception(
                'Unexpected close frame: %r (expected) vs %r (actual)' %
                (expected_frame, actual_frame))
916
917
class WebSocketStreamHixie75(object):
    """Frame processor for the WebSocket protocol version Hixie 75 and HyBi 00.

    Text frames are written as 0x00, the UTF-8 data, then 0xff. The closing
    handshake is the two bytes 0xff 0x00. Binary frames are not supported.
    The end/mask/frame-type arguments exist only so the method signatures
    line up with WebSocketStream; they are ignored.
    """

    _CLOSE_FRAME = '\xff\x00'

    def __init__(self, socket, unused_handshake):
        self._socket = socket

    def send_frame_of_arbitrary_bytes(self, header, body):
        """Sends header followed by body verbatim (no masking)."""
        self._socket.sendall(header + body)

    def send_data(self, payload, unused_frame_typem, unused_end=True,
                  unused_mask=True):
        """Sends payload bytes, as given (no encoding), as one text frame."""
        frame = ''.join(['\x00', payload, '\xff'])
        self._socket.sendall(frame)

    def send_binary(self, unused_payload, unused_end=True, unused_mask=True):
        """Rejects binary sends, which this protocol version cannot express.

        Raises:
            Exception: always, like assert_receive_binary, instead of
                silently sending nothing.
        """
        raise Exception('Binary frame is not supported in hixie75')

    def send_text(self, payload, unused_end=True, unused_mask=True):
        """Sends payload, UTF-8 encoded, as one text frame."""
        self.send_data(payload.encode('utf-8'), None)

    def assert_receive_binary(self, payload, opcode=OPCODE_BINARY, fin=1,
                              rsv1=0, rsv2=0, rsv3=0):
        """Always raises: binary frames are not supported in hixie75."""
        raise Exception('Binary frame is not supported in hixie75')

    def assert_receive_text(self, payload):
        """Asserts that the next frame is a text frame carrying payload.

        A unicode payload is compared in its UTF-8 encoded form, matching
        how send_text puts text on the wire. A byte string is compared as-is.

        Raises:
            Exception: on a wrong frame type byte, a missing 0xff
                terminator, or a payload mismatch.
        """
        received = receive_bytes(self._socket, 1)

        if received != '\x00':
            raise Exception(
                'Unexpected frame type: %d (expected) vs %d (actual)' %
                (0, ord(received)))

        # The byte count must come from the encoded form; counting unicode
        # characters would under-read any non-ASCII payload.
        if isinstance(payload, unicode):
            payload = payload.encode('utf-8')

        received = receive_bytes(self._socket, len(payload) + 1)
        if received[-1] != '\xff':
            raise Exception(
                'Termination expected: 0xff (expected) vs %r (actual)' %
                received)

        if received[0:-1] != payload:
            raise Exception(
                'Unexpected payload: %r (expected) vs %r (actual)' %
                (payload, received[0:-1]))

    def send_close(self, code, reason):
        """Sends the closing handshake; code and reason are not used."""
        self._socket.sendall(self._CLOSE_FRAME)

    def assert_receive_close(self, unused_code, unused_reason):
        """Asserts that the next bytes are the closing handshake."""
        closing = receive_bytes(self._socket, len(self._CLOSE_FRAME))
        if closing != self._CLOSE_FRAME:
            raise Exception('Didn\'t receive closing handshake')
972
973
class ClientOptions(object):
    """Holds option values to configure the Client object.

    Turn extensions on through the enable_* methods. Each one sets the
    matching use_* flag and appends the extension token to `extensions`.
    """

    def __init__(self):
        # Presumably the protocol version sent in the handshake (13 is the
        # RFC 6455 value) -- confirm against the handshake classes.
        self.version = 13

        # Where to connect. socket_timeout is handed to socket.settimeout
        # by Client.connect, so it is in seconds.
        self.server_host = ''
        self.server_port = -1
        self.use_tls = False
        self.socket_timeout = 1000

        # Request target details.
        self.origin = ''
        self.resource = ''

        # Extension tokens to request, plus the flags mirroring them.
        self.extensions = []
        self.use_deflate_stream = False  # deflate-stream
        self.use_deflate_frame = False  # deflate-application-data
        self.use_mux = False  # mux

    def _enable_extension(self, flag_name, extension):
        """Sets the boolean attribute flag_name and requests extension."""
        setattr(self, flag_name, True)
        self.extensions.append(extension)

    def enable_deflate_stream(self):
        self._enable_extension('use_deflate_stream',
                               _DEFLATE_STREAM_EXTENSION)

    def enable_deflate_frame(self):
        self._enable_extension('use_deflate_frame',
                               _DEFLATE_FRAME_EXTENSION)

    def enable_mux(self):
        self._enable_extension('use_mux', _MUX_EXTENSION)
1004
1005
1006class Client(object):
1007    """WebSocket client."""
1008
1009    def __init__(self, options, handshake, stream_class):
1010        self._logger = util.get_class_logger(self)
1011
1012        self._options = options
1013        self._socket = None
1014
1015        self._handshake = handshake
1016        self._stream_class = stream_class
1017
1018    def connect(self):
1019        self._socket = socket.socket()
1020        self._socket.settimeout(self._options.socket_timeout)
1021
1022        self._socket.connect((self._options.server_host,
1023                              self._options.server_port))
1024        if self._options.use_tls:
1025            self._socket = _TLSSocket(self._socket)
1026
1027        self._handshake.handshake(self._socket)
1028
1029        self._stream = self._stream_class(self._socket, self._handshake)
1030
1031        self._logger.info('Connection established')
1032
1033    def send_frame_of_arbitrary_bytes(self, header, body):
1034        self._stream.send_frame_of_arbitrary_bytes(header, body)
1035
1036    def send_message(self, message, end=True, binary=False, raw=False,
1037                     mask=True):
1038        if binary:
1039            self._stream.send_binary(message, end, mask)
1040        elif raw:
1041            self._stream.send_data(message, OPCODE_TEXT, end, mask)
1042        else:
1043            self._stream.send_text(message, end, mask)
1044
1045    def assert_receive(self, payload, binary=False):
1046        if binary:
1047            self._stream.assert_receive_binary(payload)
1048        else:
1049            self._stream.assert_receive_text(payload)
1050
1051    def send_close(self, code=STATUS_NORMAL_CLOSURE, reason=''):
1052        self._stream.send_close(code, reason)
1053
1054    def assert_receive_close(self, code=STATUS_NORMAL_CLOSURE, reason=''):
1055        self._stream.assert_receive_close(code, reason)
1056
1057    def close_socket(self):
1058        self._socket.close()
1059
1060    def assert_connection_closed(self):
1061        try:
1062            read_data = receive_bytes(self._socket, 1)
1063        except Exception, e:
1064            if str(e).find(
1065                'Connection closed before receiving requested length ') == 0:
1066                return
1067            try:
1068                error_number, message = e
1069                for error_name in ['ECONNRESET', 'WSAECONNRESET']:
1070                    if (error_name in dir(errno) and
1071                        error_number == getattr(errno, error_name)):
1072                        return
1073            except:
1074                raise e
1075            raise e
1076
1077        raise Exception('Connection is not closed (Read: %r)' % read_data)
1078
1079
def create_client(options):
    """Returns a Client using WebSocketHandshake and WebSocketStream."""
    handshake = WebSocketHandshake(options)
    return Client(options, handshake, WebSocketStream)
1083
1084
def create_client_hybi00(options):
    """Returns a HyBi 00 Client; frames use the Hixie 75 stream format."""
    # NOTE(review): the meaning of '0' is defined by
    # WebSocketHybi00Handshake; confirm there before changing it.
    handshake = WebSocketHybi00Handshake(options, '0')
    return Client(options, handshake, WebSocketStreamHixie75)
1090
1091
def create_client_hixie75(options):
    """Returns a Client using the Hixie 75 handshake and stream."""
    handshake = WebSocketHixie75Handshake(options)
    return Client(options, handshake, WebSocketStreamHixie75)
1095
1096
1097# vi:sts=4 sw=4 et
1098