• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2011, Google Inc.
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32
33"""Simple WebSocket client, named echo_client for historical reasons.
34
35mod_pywebsocket directory must be in PYTHONPATH.
36
37Example Usage:
38
39# server setup
40 % cd $pywebsocket
41 % PYTHONPATH=$cwd/src python ./mod_pywebsocket/standalone.py -p 8880 \
42    -d $cwd/src/example
43
44# run client
45 % PYTHONPATH=$cwd/src python ./src/example/echo_client.py -p 8880 \
46     -s localhost \
47     -o http://localhost -r /echo -m test
48
49or
50
51# run echo client to test IETF HyBi 00 protocol
52 run with --protocol-version=hybi00
53
54or
55
56# server setup to test Hixie 75 protocol
57 run with --allow-draft75
58
59# run echo client to test Hixie 75 protocol
60 run with --protocol-version=hixie75
61"""
62
63
64import base64
65import codecs
66import logging
67from optparse import OptionParser
68import os
69import random
70import re
71import socket
72import struct
73import sys
74
75from mod_pywebsocket import common
76from mod_pywebsocket.extensions import DeflateFrameExtensionProcessor
77from mod_pywebsocket.stream import Stream
78from mod_pywebsocket.stream import StreamHixie75
79from mod_pywebsocket.stream import StreamOptions
80from mod_pywebsocket import util
81
82
# NOTE(review): usage is not visible in this chunk; presumably the default
# socket timeout in seconds — confirm against the option parser.
_TIMEOUT_SEC = 10
# NOTE(review): presumably a sentinel meaning "no port given on the command
# line" — confirm where it is consumed.
_UNDEFINED_PORT = -1

# Complete header lines (including the trailing CRLF) sent in the client's
# opening handshake. ClientHandshakeProcessor sends the lowercase
# "websocket" form; the HyBi 00 and Hixie 75 processors send "WebSocket".
_UPGRADE_HEADER = 'Upgrade: websocket\r\n'
_UPGRADE_HEADER_HIXIE75 = 'Upgrade: WebSocket\r\n'
_CONNECTION_HEADER = 'Connection: Upgrade\r\n'

# Special message that tells the echo server to start closing handshake
_GOODBYE_MESSAGE = 'Goodbye'

# Values of the protocol version option. EchoClient.run selects the
# handshake processor class by comparing against these strings.
_PROTOCOL_VERSION_HYBI13 = 'hybi13'
_PROTOCOL_VERSION_HYBI08 = 'hybi08'
_PROTOCOL_VERSION_HYBI00 = 'hybi00'
_PROTOCOL_VERSION_HIXIE75 = 'hixie75'
97
98
class ClientHandshakeError(Exception):
    """Raised when the client side of the opening handshake fails."""
101
102
103def _build_method_line(resource):
104    return 'GET %s HTTP/1.1\r\n' % resource
105
106
107def _origin_header(header, origin):
108    # 4.1 13. concatenation of the string "Origin:", a U+0020 SPACE character,
109    # and the /origin/ value, converted to ASCII lowercase, to /fields/.
110    return '%s: %s\r\n' % (header, origin.lower())
111
112
def _format_host_header(host, port, secure):
    """Return the CRLF-terminated Host header line for host and port.

    The ":port" suffix is only added when port differs from the default
    WebSocket port of the scheme: common.DEFAULT_WEB_SOCKET_SECURE_PORT when
    secure is true, common.DEFAULT_WEB_SOCKET_PORT otherwise.
    """
    # 4.1 9-10. /hostport/ starts out as /host/ converted to lowercase.
    hostport = host.lower()

    # 4.1 11. Append ":" and the decimal /port/ unless it is the default
    # port for the (secure or non-secure) scheme.
    if secure:
        default_port = common.DEFAULT_WEB_SOCKET_SECURE_PORT
    else:
        default_port = common.DEFAULT_WEB_SOCKET_PORT
    if port != default_port:
        hostport = hostport + ':' + str(port)

    # 4.1 12. "Host:", a U+0020 SPACE and /hostport/.
    return '%s: %s\r\n' % (common.HOST_HEADER, hostport)
128
129
130def _receive_bytes(socket, length):
131    bytes = []
132    remaining = length
133    while remaining > 0:
134        received_bytes = socket.recv(remaining)
135        if not received_bytes:
136            raise IOError(
137                'Connection closed before receiving requested length '
138                '(requested %d bytes but received only %d bytes)' %
139                (length, length - remaining))
140        bytes.append(received_bytes)
141        remaining -= len(received_bytes)
142    return ''.join(bytes)
143
144
145def _get_mandatory_header(fields, name):
146    """Gets the value of the header specified by name from fields.
147
148    This function expects that there's only one header with the specified name
149    in fields. Otherwise, raises an ClientHandshakeError.
150    """
151
152    values = fields.get(name.lower())
153    if values is None or len(values) == 0:
154        raise ClientHandshakeError(
155            '%s header not found: %r' % (name, values))
156    if len(values) > 1:
157        raise ClientHandshakeError(
158            'Multiple %s headers found: %r' % (name, values))
159    return values[0]
160
161
def _validate_mandatory_header(fields, name,
                               expected_value, case_sensitive=False):
    """Gets and validates the value of the header specified by name from
    fields.

    If expected_value is not None, compares it with the actual value and
    raises a ClientHandshakeError on mismatch. The comparison ignores case
    unless case_sensitive is true. This function expects exactly one header
    with the specified name in fields; otherwise it raises a
    ClientHandshakeError.

    Returns:
        The header value, so callers can validate and use it in one call.

    Raises:
        ClientHandshakeError: header missing, duplicated, or not equal to
            expected_value.
    """
    value = _get_mandatory_header(fields, name)

    # As documented, only compare when an expected value was given.
    # Previously None crashed with AttributeError on .lower().
    if expected_value is None:
        return value

    if case_sensitive:
        matches = value == expected_value
    else:
        matches = value.lower() == expected_value.lower()

    if not matches:
        raise ClientHandshakeError(
            'Illegal value for header %s: %r (expected) vs %r (actual)' %
            (name, expected_value, value))
    return value
181
182
183class _TLSSocket(object):
184    """Wrapper for a TLS connection."""
185
186    def __init__(self, raw_socket):
187        self._ssl = socket.ssl(raw_socket)
188
189    def send(self, bytes):
190        return self._ssl.write(bytes)
191
192    def recv(self, size=-1):
193        return self._ssl.read(size)
194
195    def close(self):
196        # Nothing to do.
197        pass
198
199
class ClientHandshakeBase(object):
    """A base class for WebSocket opening handshake processors for each
    protocol version.

    Provides the response header-field reader shared by the processors.
    Subclasses must set self._socket before calling _read_fields().
    """

    def __init__(self):
        self._logger = util.get_class_logger(self)

    def _read_fields(self):
        """Read response header fields up to the empty line ending them.

        Stops right after the CR that begins the empty line. The LF that
        follows it is NOT consumed; callers read and check it themselves.

        Returns:
            dict mapping each header name (ASCII A-Z lowercased) to the list
            of its values, in the order they were received.

        Raises:
            ClientHandshakeError: a header line is malformed.
        """
        # 4.1 32. let /fields/ be a list of name-value pairs, initially empty.
        fields = {}
        while True:  # "Field"
            # 4.1 33-34. read /name/ (starting from an empty byte array).
            name = self._read_name()
            if name is None:
                break
            # 4.1 35. read spaces
            # TODO(tyoshino): Skip only one space as described in the spec.
            ch = self._skip_spaces()
            # 4.1 36. read /value/ (starting from an empty byte array).
            value = self._read_value(ch)
            # 4.1 37. read a byte from the server
            ch = _receive_bytes(self._socket, 1)
            if ch != '\n':  # 0x0A
                raise ClientHandshakeError(
                    'Expected LF but found %r while reading value %r for '
                    'header %r' % (ch, value, name))
            self._logger.debug('Received %r header', name)
            # 4.1 38. append an entry to the /fields/ list that has the name
            # given by the string obtained by interpreting the /name/ byte
            # array as a UTF-8 stream and the value given by the string
            # obtained by interpreting the /value/ byte array as a UTF-8 byte
            # stream.
            fields.setdefault(name, []).append(value)
            # 4.1 39. return to the "Field" step above
        return fields

    def _read_name(self):
        """Read a header name up to and including its ':' separator.

        Returns:
            The name with ASCII A-Z converted to lowercase, or None when the
            line starts with CR (the empty line ending the header block).

        Raises:
            ClientHandshakeError: LF inside a name, or CR after a non-empty
                name (a header line without ':').
        """
        # 4.1 33. let /name/ be an empty byte array
        name = ''
        while True:
            # 4.1 34. read a byte from the server
            ch = _receive_bytes(self._socket, 1)
            if ch == '\r':  # 0x0D
                # Only an empty name may be followed by CR. Otherwise the
                # line has no ':' and the algorithm says to fail, rather
                # than treat it as the end of the headers.
                if name:
                    raise ClientHandshakeError(
                        'Unexpected CR when reading header name %r' % name)
                return None
            elif ch == '\n':  # 0x0A
                raise ClientHandshakeError(
                    'Unexpected LF when reading header name %r' % name)
            elif ch == ':':  # 0x3A
                return name
            elif 'A' <= ch <= 'Z':  # Range 0x41 to 0x5A
                name += chr(ord(ch) + 0x20)
            else:
                name += ch

    def _skip_spaces(self):
        """Consume U+0020 bytes and return the first non-space byte read."""
        # 4.1 35. read a byte from the server
        while True:
            ch = _receive_bytes(self._socket, 1)
            if ch == ' ':  # 0x20
                continue
            return ch

    def _read_value(self, ch):
        """Read a header value whose first byte ch was already received.

        Reads up to and including the terminating CR. Returns the value
        without the CR.

        Raises:
            ClientHandshakeError: an LF appears before the CR.
        """
        # 4.1 33. let /value/ be an empty byte array
        value = ''
        # 4.1 36. read a byte from server.
        while True:
            if ch == '\r':  # 0x0D
                return value
            elif ch == '\n':  # 0x0A
                raise ClientHandshakeError(
                    'Unexpected LF when reading header value %r' % value)
            else:
                value += ch
            ch = _receive_bytes(self._socket, 1)
280
281
class ClientHandshakeProcessor(ClientHandshakeBase):
    """WebSocket opening handshake processor for
    draft-ietf-hybi-thewebsocketprotocol-06 and later.
    """

    def __init__(self, socket, options):
        super(ClientHandshakeProcessor, self).__init__()

        self._socket = socket
        self._options = options
        # self._logger is set by ClientHandshakeBase.__init__.

    def handshake(self):
        """Performs opening handshake on the specified socket.

        Sends the client's opening handshake. Then reads and validates the
        server's response: Status-Line, Upgrade, Connection,
        Sec-WebSocket-Accept and Sec-WebSocket-Extensions.

        Raises:
            ClientHandshakeError: handshake failed.
        """
        self._send_opening_handshake()
        self._read_status_line()

        self._logger.debug('Start reading headers until we see an empty line')

        fields = self._read_fields()

        # _read_fields() stops right after the CR that starts the empty line
        # ending the header block; the matching LF must follow. (The old
        # message here referenced undefined names and raised NameError.)
        ch = _receive_bytes(self._socket, 1)
        if ch != '\n':  # 0x0A
            raise ClientHandshakeError(
                'Expected LF but found %r after the header block' % ch)

        self._logger.debug('Received an empty line')
        self._logger.debug('Server\'s opening handshake headers: %r', fields)

        _validate_mandatory_header(
            fields,
            common.UPGRADE_HEADER,
            common.WEBSOCKET_UPGRADE_TYPE,
            False)

        _validate_mandatory_header(
            fields,
            common.CONNECTION_HEADER,
            common.UPGRADE_CONNECTION_TYPE,
            False)

        self._validate_accept(fields)
        self._process_extensions(fields)

        # TODO(tyoshino): Handle Sec-WebSocket-Protocol
        # TODO(tyoshino): Handle Cookie, etc.

    def _send_opening_handshake(self):
        """Send the Request-Line, the header fields and the final CRLF.

        Generates a random 16-byte Sec-WebSocket-Key and stores its base64
        form in self._key. _validate_accept() uses it later.
        """
        request_line = _build_method_line(self._options.resource)
        self._logger.debug('Client\'s opening handshake Request-Line: %r',
                           request_line)
        self._socket.sendall(request_line)

        fields = []
        fields.append(_format_host_header(
            self._options.server_host,
            self._options.server_port,
            self._options.use_tls))
        fields.append(_UPGRADE_HEADER)
        fields.append(_CONNECTION_HEADER)
        if self._options.origin is not None:
            # HyBi 08 names the origin header Sec-WebSocket-Origin.
            if self._options.protocol_version == _PROTOCOL_VERSION_HYBI08:
                fields.append(_origin_header(
                    common.SEC_WEBSOCKET_ORIGIN_HEADER,
                    self._options.origin))
            else:
                fields.append(_origin_header(common.ORIGIN_HEADER,
                                             self._options.origin))

        original_key = os.urandom(16)
        self._key = base64.b64encode(original_key)
        self._logger.debug(
            '%s: %r (%s)',
            common.SEC_WEBSOCKET_KEY_HEADER,
            self._key,
            util.hexify(original_key))
        fields.append(
            '%s: %s\r\n' % (common.SEC_WEBSOCKET_KEY_HEADER, self._key))

        # An explicit positive version_header option overrides the version
        # implied by protocol_version.
        if self._options.version_header > 0:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          self._options.version_header))
        elif self._options.protocol_version == _PROTOCOL_VERSION_HYBI08:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          common.VERSION_HYBI08))
        else:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          common.VERSION_HYBI_LATEST))

        extensions_to_request = []

        if self._options.deflate_stream:
            extensions_to_request.append(
                common.ExtensionParameter(
                    common.DEFLATE_STREAM_EXTENSION))

        if self._options.deflate_frame:
            extensions_to_request.append(
                common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION))

        if extensions_to_request:
            fields.append(
                '%s: %s\r\n' %
                (common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
                 common.format_extensions(extensions_to_request)))

        for field in fields:
            self._socket.sendall(field)

        self._socket.sendall('\r\n')

        self._logger.debug('Sent client\'s opening handshake headers: %r',
                           fields)

    def _read_status_line(self):
        """Read the server's Status-Line and require status code 101.

        Raises:
            ClientHandshakeError: malformed Status-Line, or a status code
                other than 101.
        """
        self._logger.debug('Start reading Status-Line')

        status_line = ''
        while True:
            ch = _receive_bytes(self._socket, 1)
            status_line += ch
            if ch == '\n':
                break

        m = re.match(r'HTTP/\d+\.\d+ (\d\d\d) .*\r\n', status_line)
        if m is None:
            raise ClientHandshakeError(
                'Wrong status line format: %r' % status_line)
        status_code = m.group(1)
        if status_code != '101':
            # The remaining headers are read only so that they can be logged.
            self._logger.debug('Unexpected status code %s with following '
                               'headers: %r', status_code, self._read_fields())
            raise ClientHandshakeError(
                'Expected HTTP status code 101 but found %r' % status_code)

        self._logger.debug('Received valid Status-Line')

    def _validate_accept(self, fields):
        """Check Sec-WebSocket-Accept against the key sent in the request.

        The expected value is the base64 encoding of
        util.sha1_hash(self._key + common.WEBSOCKET_ACCEPT_UUID).

        Raises:
            ClientHandshakeError: the header is missing, is not valid base64,
                does not decode to 20 bytes, or differs from the expected
                value.
        """
        accept = _get_mandatory_header(
            fields, common.SEC_WEBSOCKET_ACCEPT_HEADER)

        try:
            binary_accept = base64.b64decode(accept)
        except (TypeError, ValueError):
            # Python 2's b64decode reports malformed input as TypeError;
            # Python 3 raises binascii.Error, a ValueError subclass.
            raise ClientHandshakeError(
                'Illegal value for header %s: %r' %
                (common.SEC_WEBSOCKET_ACCEPT_HEADER, accept))

        if len(binary_accept) != 20:
            raise ClientHandshakeError(
                'Decoded value of %s is not 20-byte long' %
                common.SEC_WEBSOCKET_ACCEPT_HEADER)

        self._logger.debug(
            'Response for challenge : %r (%s)',
            accept, util.hexify(binary_accept))

        binary_expected_accept = util.sha1_hash(
            self._key + common.WEBSOCKET_ACCEPT_UUID).digest()
        expected_accept = base64.b64encode(binary_expected_accept)

        self._logger.debug(
            'Expected response for challenge: %r (%s)',
            expected_accept, util.hexify(binary_expected_accept))

        if accept != expected_accept:
            raise ClientHandshakeError(
                'Invalid %s header: %r (expected: %s)' %
                (common.SEC_WEBSOCKET_ACCEPT_HEADER, accept, expected_accept))

    def _process_extensions(self, fields):
        """Check the server's Sec-WebSocket-Extensions response.

        Only the first Sec-WebSocket-Extensions header is parsed. Every
        extension the server lists must be one the client requested. When
        deflate-frame is accepted, self._options.deflate_frame is replaced
        with the DeflateFrameExtensionProcessor built from the response.

        Raises:
            ClientHandshakeError: the server listed an unexpected extension,
                or rejected one the client requested.
        """
        deflate_stream_accepted = False
        deflate_frame_accepted = False

        extensions_header = fields.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER.lower())
        accepted_extensions = []
        if extensions_header is not None and len(extensions_header) != 0:
            accepted_extensions = common.parse_extensions(extensions_header[0])
        # TODO(bashi): Support the new style perframe compression extension.
        for extension in accepted_extensions:
            extension_name = extension.name()
            if (extension_name == common.DEFLATE_STREAM_EXTENSION and
                    len(extension.get_parameter_names()) == 0 and
                    self._options.deflate_stream):
                deflate_stream_accepted = True
                continue

            if (extension_name == common.DEFLATE_FRAME_EXTENSION and
                    self._options.deflate_frame):
                deflate_frame_accepted = True
                processor = DeflateFrameExtensionProcessor(extension)
                # The returned response is unused; the call is kept because it
                # may configure the processor — NOTE(review): confirm.
                unused_extension_response = processor.get_extension_response()
                self._options.deflate_frame = processor
                continue

            raise ClientHandshakeError(
                'Unexpected extension %r' % extension_name)

        if self._options.deflate_stream and not deflate_stream_accepted:
            raise ClientHandshakeError(
                'Requested %s, but the server rejected it' %
                common.DEFLATE_STREAM_EXTENSION)

        if self._options.deflate_frame and not deflate_frame_accepted:
            raise ClientHandshakeError(
                'Requested %s, but the server rejected it' %
                common.DEFLATE_FRAME_EXTENSION)
486
487
class ClientHandshakeProcessorHybi00(ClientHandshakeBase):
    """WebSocket opening handshake processor for
    draft-ietf-hybi-thewebsocketprotocol-00 (equivalent to
    draft-hixie-thewebsocketprotocol-76).
    """

    def __init__(self, socket, options):
        super(ClientHandshakeProcessorHybi00, self).__init__()

        self._socket = socket
        self._options = options
        # self._logger is set by ClientHandshakeBase.__init__.

    def handshake(self):
        """Performs opening handshake on the specified socket.

        Sends the request line, the header fields in random order, the
        empty line and key3. Then reads the status line and header fields
        and checks the 16-byte challenge response.

        Raises:
            ClientHandshakeError: handshake failed.
        """
        # The origin is mandatory for this protocol version. Check it before
        # writing anything so that a missing --origin does not leave a
        # half-sent request on the socket.
        if not self._options.origin:
            raise ClientHandshakeError(
                'Specify the origin of the connection by --origin flag')

        # 4.1 5. send request line.
        self._socket.sendall(_build_method_line(self._options.resource))
        # 4.1 6. Let /fields/ be an empty list of strings.
        fields = []
        # 4.1 7. Add the string "Upgrade: WebSocket" to /fields/.
        fields.append(_UPGRADE_HEADER_HIXIE75)
        # 4.1 8. Add the string "Connection: Upgrade" to /fields/.
        fields.append(_CONNECTION_HEADER)
        # 4.1 9-12. Add Host: field to /fields/.
        fields.append(_format_host_header(
            self._options.server_host,
            self._options.server_port,
            self._options.use_tls))
        # 4.1 13. Add Origin: field to /fields/.
        fields.append(_origin_header(common.ORIGIN_HEADER,
                                     self._options.origin))
        # TODO: 4.1 14 Add Sec-WebSocket-Protocol: field to /fields/.
        # TODO: 4.1 15 Add cookie headers to /fields/.

        # 4.1 16-23. Add Sec-WebSocket-Key<n> to /fields/.
        self._number1, key1 = self._generate_sec_websocket_key()
        self._logger.debug('Number1: %d', self._number1)
        fields.append('%s: %s\r\n' % (common.SEC_WEBSOCKET_KEY1_HEADER, key1))
        self._number2, key2 = self._generate_sec_websocket_key()
        self._logger.debug('Number2: %d', self._number2)
        fields.append('%s: %s\r\n' % (common.SEC_WEBSOCKET_KEY2_HEADER, key2))

        fields.append('%s: 0\r\n' % common.SEC_WEBSOCKET_DRAFT_HEADER)

        # 4.1 24. For each string in /fields/, in a random order: send the
        # string, encoded as UTF-8, followed by a UTF-8 encoded U+000D CARRIAGE
        # RETURN U+000A LINE FEED character pair (CRLF).
        random.shuffle(fields)
        for field in fields:
            self._socket.sendall(field)
        # 4.1 25. send a UTF-8-encoded U+000D CARRIAGE RETURN U+000A LINE FEED
        # character pair (CRLF).
        self._socket.sendall('\r\n')
        # 4.1 26. let /key3/ be a string consisting of eight random bytes (or
        # equivalently, a random 64 bit integer encoded in a big-endian order).
        self._key3 = self._generate_key3()
        # 4.1 27. send /key3/ to the server.
        self._socket.sendall(self._key3)
        self._logger.debug(
            'Key3: %r (%s)', self._key3, util.hexify(self._key3))

        self._logger.info('Sent handshake')

        # 4.1 28. Read bytes from the server until either the connection
        # closes, or a 0x0A byte is read. let /field/ be these bytes, including
        # the 0x0A bytes.
        field = ''
        while True:
            ch = _receive_bytes(self._socket, 1)
            field += ch
            if ch == '\n':
                break
        # if /field/ is not at least seven bytes long, or if the last
        # two bytes aren't 0x0D and 0x0A respectively, or if it does not
        # contain at least two 0x20 bytes, then fail the WebSocket connection
        # and abort these steps.
        if len(field) < 7 or not field.endswith('\r\n'):
            raise ClientHandshakeError('Wrong status line: %r' % field)
        m = re.match('[^ ]* ([^ ]*) .*', field)
        if m is None:
            raise ClientHandshakeError(
                'No HTTP status code found in status line: %r' % field)
        # 4.1 29. let /code/ be the substring of /field/ that starts from the
        # byte after the first 0x20 byte, and ends with the byte before the
        # second 0x20 byte.
        code = m.group(1)
        # 4.1 30. if /code/ is not three bytes long, or if any of the bytes in
        # /code/ are not in the range 0x30 to 0x39 (ASCII '0'-'9'), then fail
        # the WebSocket connection and abort these steps. \Z anchors the match
        # so that longer codes such as '1011' are rejected here as well.
        if not re.match(r'[0-9]{3}\Z', code):
            raise ClientHandshakeError(
                'HTTP status code %r is not three digit in status line: %r' %
                (code, field))
        # 4.1 31. if /code/, interpreted as UTF-8, is "101", then move to the
        # next step.
        if code != '101':
            raise ClientHandshakeError(
                'Expected HTTP status code 101 but found %r in status line: '
                '%r' % (code, field))
        # 4.1 32-39. read fields into /fields/
        fields = self._read_fields()
        # 4.1 40. _Fields processing_
        # read a byte from server (the LF after the empty line's CR).
        ch = _receive_bytes(self._socket, 1)
        if ch != '\n':  # 0x0A
            raise ClientHandshakeError('Expected LF but found %r' % ch)
        # 4.1 41. check /fields/
        # TODO(ukai): protocol
        # if the entry's name is "upgrade"
        #  if the value is not exactly equal to the string "WebSocket",
        #  then fail the WebSocket connection and abort these steps.
        _validate_mandatory_header(
            fields,
            common.UPGRADE_HEADER,
            common.WEBSOCKET_UPGRADE_TYPE_HIXIE75,
            True)
        # if the entry's name is "connection"
        #  if the value, converted to ASCII lowercase, is not exactly equal
        #  to the string "upgrade", then fail the WebSocket connection and
        #  abort these steps.
        _validate_mandatory_header(
            fields,
            common.CONNECTION_HEADER,
            common.UPGRADE_CONNECTION_TYPE,
            False)

        # Only presence and uniqueness are checked for now.
        # TODO(ukai): check origin, location, cookie, ..
        _get_mandatory_header(fields, common.SEC_WEBSOCKET_ORIGIN_HEADER)
        _get_mandatory_header(fields, common.SEC_WEBSOCKET_LOCATION_HEADER)

        # 4.1 42. let /challenge/ be the concatenation of /number_1/,
        # expressed as a big endian 32 bit integer, /number_2/, expressed
        # as big endian 32 bit integer, and the eight bytes of /key_3/ in the
        # order they were sent on the wire.
        challenge = struct.pack('!I', self._number1)
        challenge += struct.pack('!I', self._number2)
        challenge += self._key3

        self._logger.debug(
            'Challenge: %r (%s)', challenge, util.hexify(challenge))

        # 4.1 43. let /expected/ be the MD5 fingerprint of /challenge/ as a
        # big-endian 128 bit string.
        expected = util.md5_hash(challenge).digest()
        self._logger.debug(
            'Expected challenge response: %r (%s)',
            expected, util.hexify(expected))

        # 4.1 44. read sixteen bytes from the server.
        # let /reply/ be those bytes.
        reply = _receive_bytes(self._socket, 16)
        self._logger.debug(
            'Actual challenge response: %r (%s)', reply, util.hexify(reply))

        # 4.1 45. if /reply/ does not exactly equal /expected/, then fail
        # the WebSocket connection and abort these steps.
        if expected != reply:
            raise ClientHandshakeError(
                'Bad challenge response: %r (expected) != %r (actual)' %
                (expected, reply))
        # 4.1 46. The *WebSocket connection is established*.

    def _generate_sec_websocket_key(self):
        """Generate one Sec-WebSocket-Key<n> value.

        Returns:
            (number, key): the random /number_n/ and the key string that
            encodes number * spaces padded with random characters and spaces.
        """
        # 4.1 16. let /spaces_n/ be a random integer from 1 to 12 inclusive.
        spaces = random.randint(1, 12)
        # 4.1 17. let /max_n/ be the largest integer not greater than
        #  4,294,967,295 divided by /spaces_n/. Floor division keeps this an
        #  integer regardless of Python version.
        maxnum = 4294967295 // spaces
        # 4.1 18. let /number_n/ be a random integer from 0 to /max_n/
        # inclusive.
        number = random.randint(0, maxnum)
        # 4.1 19. let /product_n/ be the result of multiplying /number_n/ and
        # /spaces_n/ together.
        product = number * spaces
        # 4.1 20. let /key_n/ be a string consisting of /product_n/, expressed
        # in base ten using the numerals in the range U+0030 DIGIT ZERO (0) to
        # U+0039 DIGIT NINE (9).
        key = str(product)
        # 4.1 21. insert between one and twelve random characters from the
        # range U+0021 to U+002F and U+003A to U+007E into /key_n/ at random
        # positions.
        available_chars = (list(range(0x21, 0x2f + 1)) +
                           list(range(0x3a, 0x7e + 1)))
        n = random.randint(1, 12)
        for _ in range(n):
            ch = random.choice(available_chars)
            pos = random.randint(0, len(key))
            key = key[0:pos] + chr(ch) + key[pos:]
        # 4.1 22. insert /spaces_n/ U+0020 SPACE characters into /key_n/ at
        # random positions other than start or end of the string.
        for _ in range(spaces):
            pos = random.randint(1, len(key) - 1)
            key = key[0:pos] + ' ' + key[pos:]
        return number, key

    def _generate_key3(self):
        """Return 8 random bytes for /key3/.

        Uses os.urandom, like the Sec-WebSocket-Key generation in
        ClientHandshakeProcessor.
        """
        # 4.1 26. let /key3/ be a string consisting of eight random bytes (or
        # equivalently, a random 64 bit integer encoded in a big-endian order).
        return os.urandom(8)
698
699
class ClientHandshakeProcessorHixie75(object):
    """WebSocket opening handshake processor for
    draft-hixie-thewebsocketprotocol-75.
    """

    # The response must start with exactly these bytes; any other headers
    # after them are skipped without inspection.
    _EXPECTED_RESPONSE = (
        'HTTP/1.1 101 Web Socket Protocol Handshake\r\n' +
        _UPGRADE_HEADER_HIXIE75 +
        _CONNECTION_HEADER)

    def __init__(self, socket, options):
        self._socket = socket
        self._options = options

        self._logger = util.get_class_logger(self)

    def _skip_headers(self):
        """Consume bytes up to and including the first CRLF CRLF.

        pos counts how many bytes of the terminator have matched so far. A
        mismatching CR can itself begin a new terminator, so pos restarts
        at 1 in that case.
        """
        terminator = '\r\n\r\n'
        pos = 0
        while pos < len(terminator):
            received = _receive_bytes(self._socket, 1)
            if received == terminator[pos]:
                pos += 1
            elif received == terminator[0]:
                pos = 1
            else:
                pos = 0

    def handshake(self):
        """Performs opening handshake on the specified socket.

        Sends the whole request in a single write. Then requires the response
        to begin with _EXPECTED_RESPONSE byte for byte, and skips the rest of
        the header block.

        Raises:
            ClientHandshakeError: handshake failed.
        """
        # The origin is mandatory; check it before writing anything so that
        # a missing --origin does not leave a half-sent request on the socket.
        if not self._options.origin:
            raise ClientHandshakeError(
                'Specify the origin of the connection by --origin flag')

        request = ''.join([
            _build_method_line(self._options.resource),
            _UPGRADE_HEADER_HIXIE75,
            _CONNECTION_HEADER,
            _format_host_header(
                self._options.server_host,
                self._options.server_port,
                self._options.use_tls),
            _origin_header(common.ORIGIN_HEADER, self._options.origin),
            '\r\n',
        ])
        self._socket.sendall(request)

        self._logger.info('Sent handshake')

        for expected_char in (
                ClientHandshakeProcessorHixie75._EXPECTED_RESPONSE):
            received = _receive_bytes(self._socket, 1)
            if expected_char != received:
                raise ClientHandshakeError('Handshake failure')
        # We cut corners and skip other headers.
        self._skip_headers()
758
759
class ClientConnection(object):
    """Adapts a socket-like object to the mp_conn interface.

    mod_pywebsocket is designed to work on Apache mod_python's mp_conn
    object. This adapter offers write(), read() and remote_addr on top of a
    socket's sendall(), recv() and getpeername().
    """

    def __init__(self, socket):
        self._socket = socket

    def write(self, data):
        """Send all of data through the wrapped socket."""
        sock = self._socket
        sock.sendall(data)

    def read(self, n):
        """Return the result of one recv(n) on the wrapped socket."""
        received = self._socket.recv(n)
        return received

    def get_remote_addr(self):
        """Return the peer address reported by the wrapped socket."""
        peer = self._socket.getpeername()
        return peer

    remote_addr = property(get_remote_addr)
778
779
class ClientRequest(object):
    """Stand-in for an mp_request object built around a client socket.

    Lets a plain socket be passed to functions that expect an mp_request;
    the socket is exposed through a ClientConnection as self.connection.
    """

    def __init__(self, socket):
        self._logger = util.get_class_logger(self)

        self._socket = socket
        self.connection = ClientConnection(socket)

    def _drain_received_data(self):
        """Drain unread data from the receive buffer and log what was drained."""
        leftover = util.drain_received_data(self._socket)
        if not leftover:
            return
        self._logger.debug(
            'Drained data following close frame: %r', leftover)
799
800
class EchoClient(object):
    """WebSocket echo client.

    Connects to the server described by the options, performs the opening
    handshake for the requested protocol version, sends each comma-separated
    message from options.message and waits for its echo, then performs the
    closing handshake (skipped for Hixie 75).
    """

    def __init__(self, options):
        """Creates a client.

        Args:
            options: parsed command-line options (see main()).
        """
        self._options = options
        self._socket = None
        self._handshake = None
        self._stream = None

        self._logger = util.get_class_logger(self)

    def run(self):
        """Run the client.

        Shake hands and then repeat sending message and receiving its echo.
        The socket is closed before returning, whether or not an error
        occurred.

        Raises:
            ValueError: options.protocol_version is not a supported version.
        """

        self._socket = socket.socket()
        self._socket.settimeout(self._options.socket_timeout)
        try:
            self._socket.connect((self._options.server_host,
                                  self._options.server_port))
            if self._options.use_tls:
                self._socket = _TLSSocket(self._socket)

            version = self._options.protocol_version

            self._handshake = self._create_handshake_processor(version)
            self._handshake.handshake()

            self._logger.info('Connection established')

            self._stream = self._create_stream(version)

            for line in self._options.message.split(','):
                self._send_and_receive(line)

            if version != _PROTOCOL_VERSION_HIXIE75:
                self._do_closing_handshake()
        finally:
            self._socket.close()

    def _create_handshake_processor(self, version):
        """Returns the opening-handshake processor for the protocol version.

        Raises:
            ValueError: version is not a supported protocol version.
        """

        if version in (_PROTOCOL_VERSION_HYBI08, _PROTOCOL_VERSION_HYBI13):
            processor_class = ClientHandshakeProcessor
        elif version == _PROTOCOL_VERSION_HYBI00:
            processor_class = ClientHandshakeProcessorHybi00
        elif version == _PROTOCOL_VERSION_HIXIE75:
            processor_class = ClientHandshakeProcessorHixie75
        else:
            raise ValueError(
                'Invalid --protocol-version flag: %r' % version)
        return processor_class(self._socket, self._options)

    def _create_stream(self, version):
        """Wraps the handshaken socket in a stream for the protocol version.

        version must already have been validated by
        _create_handshake_processor().
        """

        request = ClientRequest(self._socket)

        version_map = {
            _PROTOCOL_VERSION_HYBI08: common.VERSION_HYBI08,
            _PROTOCOL_VERSION_HYBI13: common.VERSION_HYBI13,
            _PROTOCOL_VERSION_HYBI00: common.VERSION_HYBI00,
            _PROTOCOL_VERSION_HIXIE75: common.VERSION_HIXIE75}
        request.ws_version = version_map[version]

        if version in (_PROTOCOL_VERSION_HYBI08, _PROTOCOL_VERSION_HYBI13):
            stream_option = StreamOptions()
            # RFC 6455: frames a client sends are masked, frames it receives
            # from the server are not.
            stream_option.mask_send = True
            stream_option.unmask_receive = False

            if self._options.deflate_stream:
                stream_option.deflate_stream = True

            # NOTE(review): main() stores a bool in deflate_frame; presumably
            # the opening handshake replaces it with an extension processor
            # once deflate-frame is negotiated -- confirm.
            if self._options.deflate_frame is not False:
                processor = self._options.deflate_frame
                processor.setup_stream_options(stream_option)

            return Stream(request, stream_option)
        if version == _PROTOCOL_VERSION_HYBI00:
            return StreamHixie75(request, True)
        return StreamHixie75(request)

    def _send_and_receive(self, line):
        """Sends one message and waits for its echo.

        Progress is printed only in verbose mode. Any error raised while
        receiving is reported (in verbose mode) and re-raised.
        """

        self._stream.send_message(line)
        if self._options.verbose:
            print('Send: %s' % line)
        try:
            received = self._stream.receive_message()

            if self._options.verbose:
                print('Recv: %s' % received)
        except Exception as e:
            if self._options.verbose:
                print('Error: %s' % e)
            raise

    def _do_closing_handshake(self):
        """Perform closing handshake using the specified closing frame.

        If the last message sent was _GOODBYE_MESSAGE, the server is expected
        to start the closing handshake, so wait for its close first. Otherwise,
        or if something other than a close arrives, the client starts it.
        Progress is printed only in verbose mode, matching run().
        """

        verbose = self._options.verbose

        if self._options.message.split(',')[-1] == _GOODBYE_MESSAGE:
            # requested server initiated closing handshake, so
            # expecting closing handshake message from server.
            self._logger.info('Wait for server-initiated closing handshake')
            message = self._stream.receive_message()
            if message is None:
                if verbose:
                    print('Recv close')
                    print('Send ack')
                self._logger.info(
                    'Received closing handshake and sent ack')
                return
        if verbose:
            print('Send close')
        self._stream.close_connection()
        self._logger.info('Sent closing handshake')
        if verbose:
            print('Recv ack')
        self._logger.info('Received ack')
910
911
def main():
    """Parses command-line flags and runs an EchoClient with them."""

    import locale

    sys.stdout = codecs.getwriter('utf-8')(sys.stdout)

    parser = OptionParser()
    # We accept --command_line_flag style flags which is the same as Google
    # gflags in addition to common --command-line-flag style flags.
    parser.add_option('-s', '--server-host', '--server_host',
                      dest='server_host', type='string',
                      default='localhost', help='server host')
    parser.add_option('-p', '--server-port', '--server_port',
                      dest='server_port', type='int',
                      default=_UNDEFINED_PORT, help='server port')
    parser.add_option('-o', '--origin', dest='origin', type='string',
                      default=None, help='origin')
    parser.add_option('-r', '--resource', dest='resource', type='string',
                      default='/echo', help='resource path')
    parser.add_option('-m', '--message', dest='message', type='string',
                      help=('comma-separated messages to send. '
                            '%s will force close the connection from server.' %
                            _GOODBYE_MESSAGE))
    parser.add_option('-q', '--quiet', dest='verbose', action='store_false',
                      default=True, help='suppress messages')
    parser.add_option('-t', '--tls', dest='use_tls', action='store_true',
                      default=False, help='use TLS (wss://)')
    # float so that sub-second timeouts can be given; integers still parse.
    parser.add_option('-k', '--socket-timeout', '--socket_timeout',
                      dest='socket_timeout', type='float',
                      default=_TIMEOUT_SEC,
                      help='Timeout(sec) for sockets')
    parser.add_option('--draft75', dest='draft75',
                      action='store_true', default=False,
                      help='use the Hixie 75 protocol. This overrides '
                      'protocol-version flag')
    parser.add_option('--protocol-version', '--protocol_version',
                      dest='protocol_version',
                      type='string', default=_PROTOCOL_VERSION_HYBI13,
                      help='WebSocket protocol version to use. One of \'' +
                      _PROTOCOL_VERSION_HYBI13 + '\', \'' +
                      _PROTOCOL_VERSION_HYBI08 + '\', \'' +
                      _PROTOCOL_VERSION_HYBI00 + '\', \'' +
                      _PROTOCOL_VERSION_HIXIE75 + '\'')
    parser.add_option('--version-header', '--version_header',
                      dest='version_header',
                      type='int', default=-1,
                      help='specify Sec-WebSocket-Version header value')
    parser.add_option('--deflate-stream', '--deflate_stream',
                      dest='deflate_stream',
                      action='store_true', default=False,
                      help='use deflate-stream extension. This value will be '
                      'ignored if used with protocol version that doesn\'t '
                      'support deflate-stream.')
    parser.add_option('--deflate-frame', '--deflate_frame',
                      dest='deflate_frame',
                      action='store_true', default=False,
                      help='use deflate-frame extension. This value will be '
                      'ignored if used with protocol version that doesn\'t '
                      'support deflate-frame.')
    parser.add_option('--log-level', '--log_level', type='choice',
                      dest='log_level', default='warn',
                      choices=['debug', 'info', 'warn', 'error', 'critical'],
                      help='Log level.')

    (options, unused_args) = parser.parse_args()

    logging.basicConfig(level=logging.getLevelName(options.log_level.upper()))

    if options.draft75:
        options.protocol_version = _PROTOCOL_VERSION_HIXIE75

    # Default port number depends on whether TLS is used.
    if options.server_port == _UNDEFINED_PORT:
        if options.use_tls:
            options.server_port = common.DEFAULT_WEB_SOCKET_SECURE_PORT
        else:
            options.server_port = common.DEFAULT_WEB_SOCKET_PORT

    # optparse doesn't seem to handle non-ascii default values.
    # Set default message here.
    if not options.message:
        options.message = u'Hello,\u65e5\u672c'   # "Japan" in Japanese
    elif isinstance(options.message, bytes):
        # Under Python 2, command-line arguments arrive as byte strings.
        # Writing a non-ASCII byte string through the UTF-8 stdout writer
        # installed above triggers an implicit ASCII decode and crashes, so
        # decode the message to text once here using the locale encoding.
        encoding = locale.getpreferredencoding() or 'utf-8'
        options.message = options.message.decode(encoding)

    EchoClient(options).run()
992
993
# Run the echo client only when this file is executed as a script.
if __name__ == '__main__':
    main()
996
997
998# vi:sts=4 sw=4 et
999