# Copyright 2009, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""Web Socket handshaking.

Note: request.connection.write/read are used in this module, even though
mod_python document says that they should be used only in connection handlers.
Unfortunately, we have no other options. For example, request.write/read are
not suitable because they don't allow direct raw bytes writing/reading.
"""


import re

import util


_DEFAULT_WEB_SOCKET_PORT = 80
_DEFAULT_WEB_SOCKET_SECURE_PORT = 443
_WEB_SOCKET_SCHEME = 'ws'
_WEB_SOCKET_SECURE_SCHEME = 'wss'

# Headers that every handshake request must carry.  A None expected value
# means "any non-empty value is accepted".
_MANDATORY_HEADERS = [
    # key, expected value or None
    ['Upgrade', 'WebSocket'],
    ['Connection', 'Upgrade'],
    ['Host', None],
    ['Origin', None],
]

# Strict mode: the first five raw request lines must appear in exactly this
# order.  Materialised as a list (not a lazy map object) because it is passed
# to len() and iterated again on every strict handshake.
_FIRST_FIVE_LINES = list(map(re.compile, [
    r'^GET /[\S]* HTTP/1.1\r\n$',
    r'^Upgrade: WebSocket\r\n$',
    r'^Connection: Upgrade\r\n$',
    r'^Host: [\S]+\r\n$',
    r'^Origin: [\S]+\r\n$',
]))

# Strict mode: everything after the first five lines, joined into one string,
# must start with an optional WebSocket-Protocol line, then cookie lines, then
# the blank line that terminates the header section.
_SIXTH_AND_LATER = re.compile(
    r'^'
    r'(WebSocket-Protocol: [\x20-\x7e]+\r\n)?'
    r'(Cookie: [^\r]*\r\n)*'
    r'(Cookie2: [^\r]*\r\n)?'
    r'(Cookie: [^\r]*\r\n)*'
    r'\r\n')


def _default_port(is_secure):
    """Return the default Web Socket port.

    Args:
        is_secure: true to get the port for the secure scheme (wss),
            false to get the port for the plain scheme (ws).
    """

    if is_secure:
        return _DEFAULT_WEB_SOCKET_SECURE_PORT
    return _DEFAULT_WEB_SOCKET_PORT


class HandshakeError(Exception):
    """Exception in Web Socket Handshake."""

    pass


def _validate_protocol(protocol):
    """Validate WebSocket-Protocol string.

    Args:
        protocol: value of the WebSocket-Protocol header.

    Raises:
        HandshakeError: if protocol is empty or contains a character
            outside the range 0x20-0x7e.
    """

    if not protocol:
        raise HandshakeError('Invalid WebSocket-Protocol: empty')
    for c in protocol:
        if not 0x20 <= ord(c) <= 0x7e:
            raise HandshakeError('Illegal character in protocol: %r' % c)


class Handshaker(object):
    """This class performs Web Socket handshake."""

    def __init__(self, request, dispatcher, strict=False):
        """Construct an instance.

        Args:
            request: mod_python request.
            dispatcher: Dispatcher (dispatch.Dispatcher).
            strict: Strictly check handshake request. Default: False.
                If True, request.connection must provide get_memorized_lines
                method.

        Handshaker will add attributes such as ws_resource in performing
        handshake.
        """

        self._request = request
        self._dispatcher = dispatcher
        self._strict = strict

    def do_handshake(self):
        """Perform Web Socket Handshake.

        Validates the request headers, stores ws_resource, ws_origin,
        ws_location and ws_protocol on the request, lets the dispatcher run
        its extra handshake step and finally writes the handshake response
        to the connection.

        Raises:
            HandshakeError: if the request is not an acceptable handshake.
        """

        self._check_header_lines()
        self._set_resource()
        self._set_origin()
        self._set_location()
        self._set_protocol()
        self._dispatcher.do_extra_handshake(self._request)
        self._send_handshake()

    def _set_resource(self):
        """Store the request URI on the request as ws_resource."""

        self._request.ws_resource = self._request.uri

    def _set_origin(self):
        """Store the Origin request header on the request as ws_origin."""

        self._request.ws_origin = self._request.headers_in['Origin']

    def _set_location(self):
        """Build the ws:// or wss:// URL and store it as ws_location.

        Raises:
            HandshakeError: if the port given in the Host header differs
                from the port in request.connection.local_addr, or if the
                Host header carries a malformed port.
        """

        is_secure = self._request.is_https()
        if is_secure:
            scheme = _WEB_SOCKET_SECURE_SCHEME
        else:
            scheme = _WEB_SOCKET_SCHEME
        host, port = self._parse_host_header()
        # local_addr[1] is taken to be the port this connection arrived on;
        # a Host header naming a different port is rejected.
        connection_port = self._request.connection.local_addr[1]
        if port != connection_port:
            raise HandshakeError('Header/connection port mismatch: %d/%d' %
                                 (port, connection_port))
        location_parts = [scheme, '://', host]
        # The port is spelled out only when it is not the scheme's default.
        if port != _default_port(is_secure):
            location_parts.append(':')
            location_parts.append(str(port))
        location_parts.append(self._request.uri)
        self._request.ws_location = ''.join(location_parts)

    def _parse_host_header(self):
        """Split the Host header into a (host, port) pair.

        Returns:
            A tuple of the host string and the port as an int.  When the
            header has no ':port' part, the default port for the request's
            scheme is used.

        Raises:
            HandshakeError: if the port part is not a valid integer.
        """

        fields = self._request.headers_in['Host'].split(':', 1)
        if len(fields) == 1:
            return fields[0], _default_port(self._request.is_https())
        try:
            return fields[0], int(fields[1])
        except ValueError as e:
            raise HandshakeError('Invalid port number format: %r' % e)

    def _set_protocol(self):
        """Validate the WebSocket-Protocol header and store it as ws_protocol.

        ws_protocol is always set; it is None when the header is absent.

        Raises:
            HandshakeError: if the header is present but invalid.
        """

        protocol = self._request.headers_in.get('WebSocket-Protocol')
        if protocol is not None:
            _validate_protocol(protocol)
        self._request.ws_protocol = protocol

    def _send_handshake(self):
        """Write the handshake response to the connection.

        The WebSocket-Protocol line is sent only when ws_protocol is set to
        a non-empty value.
        """

        request = self._request
        pieces = [
            'HTTP/1.1 101 Web Socket Protocol Handshake\r\n',
            'Upgrade: WebSocket\r\n',
            'Connection: Upgrade\r\n',
            'WebSocket-Origin: ',
            request.ws_origin,
            '\r\n',
            'WebSocket-Location: ',
            request.ws_location,
            '\r\n',
        ]
        if request.ws_protocol:
            pieces.extend([
                'WebSocket-Protocol: ',
                request.ws_protocol,
                '\r\n',
            ])
        # Blank line terminating the response header section.
        pieces.append('\r\n')
        for piece in pieces:
            request.connection.write(piece)

    def _check_header_lines(self):
        """Check the mandatory headers and, in strict mode, the raw lines.

        Raises:
            HandshakeError: if a mandatory header is missing or empty, has
                an unexpected value, or (strict mode) the raw header lines
                do not match the expected layout.
            AttributeError: in strict mode, if request.connection does not
                provide get_memorized_lines().
        """

        for key, expected_value in _MANDATORY_HEADERS:
            actual_value = self._request.headers_in.get(key)
            if not actual_value:
                raise HandshakeError('Header %s is not defined' % key)
            if expected_value and actual_value != expected_value:
                raise HandshakeError('Illegal value for header %s: %s' %
                                     (key, actual_value))
        if not self._strict:
            return
        try:
            lines = self._request.connection.get_memorized_lines()
        except AttributeError as e:
            # Add an explanatory message to the exception, then let the
            # same AttributeError propagate.
            util.prepend_message_to_exception(
                'Strict handshake is specified but the connection '
                'doesn\'t provide get_memorized_lines()', e)
            raise
        self._check_first_lines(lines)

    def _check_first_lines(self, lines):
        """Check raw request lines against the strict handshake layout.

        Args:
            lines: raw request lines, each including its trailing CRLF.

        Raises:
            HandshakeError: if there are too few lines, if one of the
                leading lines does not match its expected pattern, or if
                the remaining lines do not match _SIXTH_AND_LATER.
        """

        fixed_count = len(_FIRST_FIVE_LINES)
        if len(lines) < fixed_count:
            raise HandshakeError('Too few header lines: %d' % len(lines))
        for line, regexp in zip(lines, _FIRST_FIVE_LINES):
            if not regexp.search(line):
                raise HandshakeError('Unexpected header: %r doesn\'t match %r'
                                     % (line, regexp.pattern))
        # Whatever follows the fixed leading lines is matched as one string.
        sixth_and_later = ''.join(lines[fixed_count:])
        if not _SIXTH_AND_LATER.search(sixth_and_later):
            raise HandshakeError('Unexpected header: %r doesn\'t match %r'
                                 % (sixth_and_later,
                                    _SIXTH_AND_LATER.pattern))


# vi:sts=4 sw=4 et