# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""This file provides the opening handshake processor for the WebSocket
protocol (RFC 6455).

Specification:
http://tools.ietf.org/html/rfc6455
"""


# Note: request.connection.write is used in this module, even though the
# mod_python documentation says that it should be used only in connection
# handlers. Unfortunately, we have no other options.
For example, request.write is not 42# suitable because it doesn't allow direct raw bytes writing. 43 44 45import base64 46import logging 47import os 48import re 49 50from mod_pywebsocket import common 51from mod_pywebsocket.extensions import get_extension_processor 52from mod_pywebsocket.handshake._base import check_request_line 53from mod_pywebsocket.handshake._base import format_header 54from mod_pywebsocket.handshake._base import get_mandatory_header 55from mod_pywebsocket.handshake._base import HandshakeException 56from mod_pywebsocket.handshake._base import parse_token_list 57from mod_pywebsocket.handshake._base import validate_mandatory_header 58from mod_pywebsocket.handshake._base import validate_subprotocol 59from mod_pywebsocket.handshake._base import VersionException 60from mod_pywebsocket.stream import Stream 61from mod_pywebsocket.stream import StreamOptions 62from mod_pywebsocket import util 63 64 65# Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648 66# disallows non-zero padding, so the character right before == must be any of 67# A, Q, g and w. 68_SEC_WEBSOCKET_KEY_REGEX = re.compile('^[+/0-9A-Za-z]{21}[AQgw]==$') 69 70# Defining aliases for values used frequently. 71_VERSION_HYBI08 = common.VERSION_HYBI08 72_VERSION_HYBI08_STRING = str(_VERSION_HYBI08) 73_VERSION_LATEST = common.VERSION_HYBI_LATEST 74_VERSION_LATEST_STRING = str(_VERSION_LATEST) 75_SUPPORTED_VERSIONS = [ 76 _VERSION_LATEST, 77 _VERSION_HYBI08, 78] 79 80 81def compute_accept(key): 82 """Computes value for the Sec-WebSocket-Accept header from value of the 83 Sec-WebSocket-Key header. 84 """ 85 86 accept_binary = util.sha1_hash( 87 key + common.WEBSOCKET_ACCEPT_UUID).digest() 88 accept = base64.b64encode(accept_binary) 89 90 return (accept, accept_binary) 91 92 93class Handshaker(object): 94 """Opening handshake processor for the WebSocket protocol (RFC 6455).""" 95 96 def __init__(self, request, dispatcher): 97 """Construct an instance. 
98 99 Args: 100 request: mod_python request. 101 dispatcher: Dispatcher (dispatch.Dispatcher). 102 103 Handshaker will add attributes such as ws_resource during handshake. 104 """ 105 106 self._logger = util.get_class_logger(self) 107 108 self._request = request 109 self._dispatcher = dispatcher 110 111 def _validate_connection_header(self): 112 connection = get_mandatory_header( 113 self._request, common.CONNECTION_HEADER) 114 115 try: 116 connection_tokens = parse_token_list(connection) 117 except HandshakeException, e: 118 raise HandshakeException( 119 'Failed to parse %s: %s' % (common.CONNECTION_HEADER, e)) 120 121 connection_is_valid = False 122 for token in connection_tokens: 123 if token.lower() == common.UPGRADE_CONNECTION_TYPE.lower(): 124 connection_is_valid = True 125 break 126 if not connection_is_valid: 127 raise HandshakeException( 128 '%s header doesn\'t contain "%s"' % 129 (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) 130 131 def do_handshake(self): 132 self._request.ws_close_code = None 133 self._request.ws_close_reason = None 134 135 # Parsing. 136 137 check_request_line(self._request) 138 139 validate_mandatory_header( 140 self._request, 141 common.UPGRADE_HEADER, 142 common.WEBSOCKET_UPGRADE_TYPE) 143 144 self._validate_connection_header() 145 146 self._request.ws_resource = self._request.uri 147 148 unused_host = get_mandatory_header(self._request, common.HOST_HEADER) 149 150 self._request.ws_version = self._check_version() 151 152 # This handshake must be based on latest hybi. We are responsible to 153 # fallback to HTTP on handshake failure as latest hybi handshake 154 # specifies. 155 try: 156 self._get_origin() 157 self._set_protocol() 158 self._parse_extensions() 159 160 # Key validation, response generation. 
161 162 key = self._get_key() 163 (accept, accept_binary) = compute_accept(key) 164 self._logger.debug( 165 '%s: %r (%s)', 166 common.SEC_WEBSOCKET_ACCEPT_HEADER, 167 accept, 168 util.hexify(accept_binary)) 169 170 self._logger.debug('Protocol version is RFC 6455') 171 172 # Setup extension processors. 173 174 processors = [] 175 if self._request.ws_requested_extensions is not None: 176 for extension_request in self._request.ws_requested_extensions: 177 processor = get_extension_processor(extension_request) 178 # Unknown extension requests are just ignored. 179 if processor is not None: 180 processors.append(processor) 181 self._request.ws_extension_processors = processors 182 183 # Extra handshake handler may modify/remove processors. 184 self._dispatcher.do_extra_handshake(self._request) 185 processors = filter(lambda processor: processor is not None, 186 self._request.ws_extension_processors) 187 188 accepted_extensions = [] 189 190 # We need to take care of mux extension here. Extensions that 191 # are placed before mux should be applied to logical channels. 192 mux_index = -1 193 for i, processor in enumerate(processors): 194 if processor.name() == common.MUX_EXTENSION: 195 mux_index = i 196 break 197 if mux_index >= 0: 198 mux_processor = processors[mux_index] 199 logical_channel_processors = processors[:mux_index] 200 processors = processors[mux_index+1:] 201 202 for processor in logical_channel_processors: 203 extension_response = processor.get_extension_response() 204 if extension_response is None: 205 # Rejected. 206 continue 207 accepted_extensions.append(extension_response) 208 # Pass a shallow copy of accepted_extensions as extensions for 209 # logical channels. 
210 mux_response = mux_processor.get_extension_response( 211 self._request, accepted_extensions[:]) 212 if mux_response is not None: 213 accepted_extensions.append(mux_response) 214 215 stream_options = StreamOptions() 216 217 # When there is mux extension, here, |processors| contain only 218 # prosessors for extensions placed after mux. 219 for processor in processors: 220 221 extension_response = processor.get_extension_response() 222 if extension_response is None: 223 # Rejected. 224 continue 225 226 accepted_extensions.append(extension_response) 227 228 processor.setup_stream_options(stream_options) 229 230 if len(accepted_extensions) > 0: 231 self._request.ws_extensions = accepted_extensions 232 self._logger.debug( 233 'Extensions accepted: %r', 234 map(common.ExtensionParameter.name, accepted_extensions)) 235 else: 236 self._request.ws_extensions = None 237 238 self._request.ws_stream = self._create_stream(stream_options) 239 240 if self._request.ws_requested_protocols is not None: 241 if self._request.ws_protocol is None: 242 raise HandshakeException( 243 'do_extra_handshake must choose one subprotocol from ' 244 'ws_requested_protocols and set it to ws_protocol') 245 validate_subprotocol(self._request.ws_protocol, hixie=False) 246 247 self._logger.debug( 248 'Subprotocol accepted: %r', 249 self._request.ws_protocol) 250 else: 251 if self._request.ws_protocol is not None: 252 raise HandshakeException( 253 'ws_protocol must be None when the client didn\'t ' 254 'request any subprotocol') 255 256 self._send_handshake(accept) 257 except HandshakeException, e: 258 if not e.status: 259 # Fallback to 400 bad request by default. 
260 e.status = common.HTTP_STATUS_BAD_REQUEST 261 raise e 262 263 def _get_origin(self): 264 if self._request.ws_version is _VERSION_HYBI08: 265 origin_header = common.SEC_WEBSOCKET_ORIGIN_HEADER 266 else: 267 origin_header = common.ORIGIN_HEADER 268 origin = self._request.headers_in.get(origin_header) 269 if origin is None: 270 self._logger.debug('Client request does not have origin header') 271 self._request.ws_origin = origin 272 273 def _check_version(self): 274 version = get_mandatory_header(self._request, 275 common.SEC_WEBSOCKET_VERSION_HEADER) 276 if version == _VERSION_HYBI08_STRING: 277 return _VERSION_HYBI08 278 if version == _VERSION_LATEST_STRING: 279 return _VERSION_LATEST 280 281 if version.find(',') >= 0: 282 raise HandshakeException( 283 'Multiple versions (%r) are not allowed for header %s' % 284 (version, common.SEC_WEBSOCKET_VERSION_HEADER), 285 status=common.HTTP_STATUS_BAD_REQUEST) 286 raise VersionException( 287 'Unsupported version %r for header %s' % 288 (version, common.SEC_WEBSOCKET_VERSION_HEADER), 289 supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS))) 290 291 def _set_protocol(self): 292 self._request.ws_protocol = None 293 294 protocol_header = self._request.headers_in.get( 295 common.SEC_WEBSOCKET_PROTOCOL_HEADER) 296 297 if not protocol_header: 298 self._request.ws_requested_protocols = None 299 return 300 301 self._request.ws_requested_protocols = parse_token_list( 302 protocol_header) 303 self._logger.debug('Subprotocols requested: %r', 304 self._request.ws_requested_protocols) 305 306 def _parse_extensions(self): 307 extensions_header = self._request.headers_in.get( 308 common.SEC_WEBSOCKET_EXTENSIONS_HEADER) 309 if not extensions_header: 310 self._request.ws_requested_extensions = None 311 return 312 313 if self._request.ws_version is common.VERSION_HYBI08: 314 allow_quoted_string=False 315 else: 316 allow_quoted_string=True 317 try: 318 self._request.ws_requested_extensions = common.parse_extensions( 319 
extensions_header, allow_quoted_string=allow_quoted_string) 320 except common.ExtensionParsingException, e: 321 raise HandshakeException( 322 'Failed to parse Sec-WebSocket-Extensions header: %r' % e) 323 324 self._logger.debug( 325 'Extensions requested: %r', 326 map(common.ExtensionParameter.name, 327 self._request.ws_requested_extensions)) 328 329 def _validate_key(self, key): 330 if key.find(',') >= 0: 331 raise HandshakeException('Request has multiple %s header lines or ' 332 'contains illegal character \',\': %r' % 333 (common.SEC_WEBSOCKET_KEY_HEADER, key)) 334 335 # Validate 336 key_is_valid = False 337 try: 338 # Validate key by quick regex match before parsing by base64 339 # module. Because base64 module skips invalid characters, we have 340 # to do this in advance to make this server strictly reject illegal 341 # keys. 342 if _SEC_WEBSOCKET_KEY_REGEX.match(key): 343 decoded_key = base64.b64decode(key) 344 if len(decoded_key) == 16: 345 key_is_valid = True 346 except TypeError, e: 347 pass 348 349 if not key_is_valid: 350 raise HandshakeException( 351 'Illegal value for header %s: %r' % 352 (common.SEC_WEBSOCKET_KEY_HEADER, key)) 353 354 return decoded_key 355 356 def _get_key(self): 357 key = get_mandatory_header( 358 self._request, common.SEC_WEBSOCKET_KEY_HEADER) 359 360 decoded_key = self._validate_key(key) 361 362 self._logger.debug( 363 '%s: %r (%s)', 364 common.SEC_WEBSOCKET_KEY_HEADER, 365 key, 366 util.hexify(decoded_key)) 367 368 return key 369 370 def _create_stream(self, stream_options): 371 return Stream(self._request, stream_options) 372 373 def _create_handshake_response(self, accept): 374 response = [] 375 376 response.append('HTTP/1.1 101 Switching Protocols\r\n') 377 378 response.append(format_header( 379 common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE)) 380 response.append(format_header( 381 common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) 382 response.append(format_header( 383 common.SEC_WEBSOCKET_ACCEPT_HEADER, 
accept)) 384 if self._request.ws_protocol is not None: 385 response.append(format_header( 386 common.SEC_WEBSOCKET_PROTOCOL_HEADER, 387 self._request.ws_protocol)) 388 if (self._request.ws_extensions is not None and 389 len(self._request.ws_extensions) != 0): 390 response.append(format_header( 391 common.SEC_WEBSOCKET_EXTENSIONS_HEADER, 392 common.format_extensions(self._request.ws_extensions))) 393 response.append('\r\n') 394 395 return ''.join(response) 396 397 def _send_handshake(self, accept): 398 raw_response = self._create_handshake_response(accept) 399 self._request.connection.write(raw_response) 400 self._logger.debug('Sent server\'s opening handshake: %r', 401 raw_response) 402 403 404# vi:sts=4 sw=4 et 405