• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2012, Google Inc.
2# All rights reserved.
3#
4# Redistribution and use in source and binary forms, with or without
5# modification, are permitted provided that the following conditions are
6# met:
7#
8#     * Redistributions of source code must retain the above copyright
9# notice, this list of conditions and the following disclaimer.
10#     * Redistributions in binary form must reproduce the above
11# copyright notice, this list of conditions and the following disclaimer
12# in the documentation and/or other materials provided with the
13# distribution.
14#     * Neither the name of Google Inc. nor the names of its
15# contributors may be used to endorse or promote products derived from
16# this software without specific prior written permission.
17#
18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30
31"""This file provides the opening handshake processor for the WebSocket
32protocol (RFC 6455).
33
34Specification:
35http://tools.ietf.org/html/rfc6455
36"""
37
38
# Note: request.connection.write is used in this module, even though the
# mod_python documentation says it should be used only in connection handlers.
# Unfortunately, we have no other option. For example, request.write is not
# suitable because it doesn't allow writing raw bytes directly.
43
44
45import base64
46import logging
47import os
48import re
49
50from mod_pywebsocket import common
51from mod_pywebsocket.extensions import get_extension_processor
52from mod_pywebsocket.handshake._base import check_request_line
53from mod_pywebsocket.handshake._base import format_header
54from mod_pywebsocket.handshake._base import get_mandatory_header
55from mod_pywebsocket.handshake._base import HandshakeException
56from mod_pywebsocket.handshake._base import parse_token_list
57from mod_pywebsocket.handshake._base import validate_mandatory_header
58from mod_pywebsocket.handshake._base import validate_subprotocol
59from mod_pywebsocket.handshake._base import VersionException
60from mod_pywebsocket.stream import Stream
61from mod_pywebsocket.stream import StreamOptions
62from mod_pywebsocket import util
63
64
65# Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648
66# disallows non-zero padding, so the character right before == must be any of
67# A, Q, g and w.
68_SEC_WEBSOCKET_KEY_REGEX = re.compile('^[+/0-9A-Za-z]{21}[AQgw]==$')
69
# Short aliases for the protocol version constants, together with their string
# forms as compared against the Sec-WebSocket-Version header value.
_VERSION_HYBI08 = common.VERSION_HYBI08
_VERSION_LATEST = common.VERSION_HYBI_LATEST
_VERSION_HYBI08_STRING = str(_VERSION_HYBI08)
_VERSION_LATEST_STRING = str(_VERSION_LATEST)
# Latest version first; reported to clients that ask for something else.
_SUPPORTED_VERSIONS = [_VERSION_LATEST, _VERSION_HYBI08]
79
80
def compute_accept(key):
    """Derive the Sec-WebSocket-Accept value from a Sec-WebSocket-Key value.

    The key is concatenated with common.WEBSOCKET_ACCEPT_UUID and hashed
    with util.sha1_hash.

    Returns:
        A tuple (accept, accept_binary): the base64-encoded digest and the
        raw digest it was encoded from.
    """
    raw_digest = util.sha1_hash(key + common.WEBSOCKET_ACCEPT_UUID).digest()
    return base64.b64encode(raw_digest), raw_digest
91
92
class Handshaker(object):
    """Opening handshake processor for the WebSocket protocol (RFC 6455)."""

    def __init__(self, request, dispatcher):
        """Construct an instance.

        Args:
            request: mod_python request.
            dispatcher: Dispatcher (dispatch.Dispatcher).

        Handshaker will add attributes such as ws_resource during handshake.
        """

        self._logger = util.get_class_logger(self)

        self._request = request
        self._dispatcher = dispatcher

    def _validate_connection_header(self):
        """Check that the Connection header lists the upgrade token.

        The comparison with common.UPGRADE_CONNECTION_TYPE is
        case-insensitive.

        Raises:
            HandshakeException: if the header value cannot be parsed as a
                token list or does not contain the upgrade token.
        """
        connection = get_mandatory_header(
            self._request, common.CONNECTION_HEADER)

        try:
            connection_tokens = parse_token_list(connection)
        except HandshakeException as e:
            raise HandshakeException(
                'Failed to parse %s: %s' % (common.CONNECTION_HEADER, e))

        expected_token = common.UPGRADE_CONNECTION_TYPE.lower()
        if not any(token.lower() == expected_token
                   for token in connection_tokens):
            raise HandshakeException(
                '%s header doesn\'t contain "%s"' %
                (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))

    def do_handshake(self):
        """Process the client's opening handshake and send the response.

        The steps run in this order:
        - Validate the request line and the Upgrade, Connection, Host and
          Sec-WebSocket-Version headers.
        - Read the origin, requested subprotocols and extensions.
        - Validate the key.
        - Run the dispatcher's do_extra_handshake hook.
        - Settle the extensions and create the stream.
        - Check the chosen subprotocol.
        - Write the 101 response.

        Raises:
            HandshakeException: on validation failure. Exceptions raised
                after the version check get status 400 (Bad Request) when
                they carry no status of their own.
            VersionException: from _check_version when the requested
                version is not supported.
        """
        self._request.ws_close_code = None
        self._request.ws_close_reason = None

        # Parsing.

        check_request_line(self._request)

        validate_mandatory_header(
            self._request,
            common.UPGRADE_HEADER,
            common.WEBSOCKET_UPGRADE_TYPE)

        self._validate_connection_header()

        self._request.ws_resource = self._request.uri

        # The Host header must be present, but its value is not used here.
        get_mandatory_header(self._request, common.HOST_HEADER)

        self._request.ws_version = self._check_version()

        # This handshake must be based on the latest HyBi draft. We are
        # responsible for falling back to HTTP on handshake failure, as the
        # latest HyBi handshake specifies.
        try:
            self._get_origin()
            self._set_protocol()
            self._parse_extensions()

            # Key validation, response generation.

            key = self._get_key()
            accept, accept_binary = compute_accept(key)
            self._logger.debug(
                '%s: %r (%s)',
                common.SEC_WEBSOCKET_ACCEPT_HEADER,
                accept,
                util.hexify(accept_binary))

            self._logger.debug('Protocol version is RFC 6455')

            self._setup_extension_processors()

            # Extra handshake handler may modify/remove processors.
            self._dispatcher.do_extra_handshake(self._request)

            stream_options = self._accept_extensions()

            self._request.ws_stream = self._create_stream(stream_options)

            self._validate_accepted_protocol()

            self._send_handshake(accept)
        except HandshakeException as e:
            if not e.status:
                # Fallback to 400 bad request by default.
                e.status = common.HTTP_STATUS_BAD_REQUEST
            # A bare raise keeps the original traceback (raise e would not on
            # Python 2).
            raise

    def _setup_extension_processors(self):
        """Create a processor for every requested extension that is known.

        The result is stored in request.ws_extension_processors so that the
        dispatcher's do_extra_handshake can modify or remove entries.
        """
        processors = []
        requested_extensions = self._request.ws_requested_extensions
        if requested_extensions is not None:
            for extension_request in requested_extensions:
                processor = get_extension_processor(extension_request)
                # Unknown extension requests are just ignored.
                if processor is not None:
                    processors.append(processor)
        self._request.ws_extension_processors = processors

    def _accept_extensions(self):
        """Decide which extensions to accept and build the stream options.

        Sets request.ws_extensions to the list of accepted extension
        responses, or to None when nothing was accepted.

        Returns:
            A StreamOptions configured by every accepted processor that
            applies to this stream (those placed after mux, or all of them
            when there is no mux).
        """
        # Entries set to None by do_extra_handshake are dropped. A real list
        # is required because it is indexed and sliced below.
        processors = [processor
                      for processor in self._request.ws_extension_processors
                      if processor is not None]

        accepted_extensions = []

        # We need to take care of mux extension here. Extensions that
        # are placed before mux should be applied to logical channels.
        mux_index = -1
        for i, processor in enumerate(processors):
            if processor.name() == common.MUX_EXTENSION:
                mux_index = i
                break
        if mux_index >= 0:
            mux_processor = processors[mux_index]
            logical_channel_processors = processors[:mux_index]
            processors = processors[mux_index + 1:]

            for processor in logical_channel_processors:
                extension_response = processor.get_extension_response()
                if extension_response is None:
                    # Rejected.
                    continue
                accepted_extensions.append(extension_response)
            # Pass a shallow copy of accepted_extensions as extensions for
            # logical channels.
            mux_response = mux_processor.get_extension_response(
                self._request, accepted_extensions[:])
            if mux_response is not None:
                accepted_extensions.append(mux_response)

        stream_options = StreamOptions()

        # When there is mux extension, here, |processors| contains only
        # processors for extensions placed after mux.
        for processor in processors:
            extension_response = processor.get_extension_response()
            if extension_response is None:
                # Rejected.
                continue
            accepted_extensions.append(extension_response)
            processor.setup_stream_options(stream_options)

        if accepted_extensions:
            self._request.ws_extensions = accepted_extensions
            self._logger.debug(
                'Extensions accepted: %r',
                [common.ExtensionParameter.name(extension)
                 for extension in accepted_extensions])
        else:
            self._request.ws_extensions = None

        return stream_options

    def _validate_accepted_protocol(self):
        """Check the subprotocol left in request.ws_protocol.

        Raises:
            HandshakeException: if the client requested subprotocols but
                none was chosen, or none was requested but one was set.
                validate_subprotocol may additionally reject the chosen
                value.
        """
        if self._request.ws_requested_protocols is not None:
            if self._request.ws_protocol is None:
                raise HandshakeException(
                    'do_extra_handshake must choose one subprotocol from '
                    'ws_requested_protocols and set it to ws_protocol')
            validate_subprotocol(self._request.ws_protocol, hixie=False)

            self._logger.debug(
                'Subprotocol accepted: %r',
                self._request.ws_protocol)
        elif self._request.ws_protocol is not None:
            raise HandshakeException(
                'ws_protocol must be None when the client didn\'t '
                'request any subprotocol')

    def _get_origin(self):
        """Store the client's origin in request.ws_origin (None if absent).

        The common.SEC_WEBSOCKET_ORIGIN_HEADER header is read for HyBi 08.
        The common.ORIGIN_HEADER header is read otherwise.
        """
        # Compare versions by value; identity of ints is an implementation
        # detail.
        if self._request.ws_version == _VERSION_HYBI08:
            origin_header = common.SEC_WEBSOCKET_ORIGIN_HEADER
        else:
            origin_header = common.ORIGIN_HEADER
        origin = self._request.headers_in.get(origin_header)
        if origin is None:
            self._logger.debug('Client request does not have origin header')
        self._request.ws_origin = origin

    def _check_version(self):
        """Return the version constant named by Sec-WebSocket-Version.

        Raises:
            HandshakeException: with status 400 when the header lists
                several comma-separated versions.
            VersionException: when the version is not supported; it carries
                the list of supported versions.
        """
        version = get_mandatory_header(self._request,
                                       common.SEC_WEBSOCKET_VERSION_HEADER)
        if version == _VERSION_HYBI08_STRING:
            return _VERSION_HYBI08
        if version == _VERSION_LATEST_STRING:
            return _VERSION_LATEST

        if ',' in version:
            raise HandshakeException(
                'Multiple versions (%r) are not allowed for header %s' %
                (version, common.SEC_WEBSOCKET_VERSION_HEADER),
                status=common.HTTP_STATUS_BAD_REQUEST)
        raise VersionException(
            'Unsupported version %r for header %s' %
            (version, common.SEC_WEBSOCKET_VERSION_HEADER),
            supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS)))

    def _set_protocol(self):
        """Read the requested subprotocols.

        Resets request.ws_protocol to None (do_extra_handshake is expected
        to choose one). Stores the parsed Sec-WebSocket-Protocol tokens in
        request.ws_requested_protocols, or None when the header is absent
        or empty.
        """
        self._request.ws_protocol = None

        protocol_header = self._request.headers_in.get(
            common.SEC_WEBSOCKET_PROTOCOL_HEADER)

        if not protocol_header:
            self._request.ws_requested_protocols = None
            return

        self._request.ws_requested_protocols = parse_token_list(
            protocol_header)
        self._logger.debug('Subprotocols requested: %r',
                           self._request.ws_requested_protocols)

    def _parse_extensions(self):
        """Parse Sec-WebSocket-Extensions into request.ws_requested_extensions.

        The attribute is set to None when the header is absent or empty.
        Quoted-string parameter values are allowed for every version except
        HyBi 08.

        Raises:
            HandshakeException: if the header cannot be parsed.
        """
        extensions_header = self._request.headers_in.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
        if not extensions_header:
            self._request.ws_requested_extensions = None
            return

        allow_quoted_string = self._request.ws_version != _VERSION_HYBI08
        try:
            self._request.ws_requested_extensions = common.parse_extensions(
                extensions_header, allow_quoted_string=allow_quoted_string)
        except common.ExtensionParsingException as e:
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: %r' % e)

        self._logger.debug(
            'Extensions requested: %r',
            [common.ExtensionParameter.name(extension)
             for extension in self._request.ws_requested_extensions])

    def _validate_key(self, key):
        """Validate a Sec-WebSocket-Key value and return its decoded bytes.

        Raises:
            HandshakeException: if the value contains ',' (multiple header
                lines or an illegal character), or is not the strict base64
                encoding of exactly 16 bytes.
        """
        if ',' in key:
            raise HandshakeException('Request has multiple %s header lines or '
                                     'contains illegal character \',\': %r' %
                                     (common.SEC_WEBSOCKET_KEY_HEADER, key))

        decoded_key = None
        try:
            # Validate key by quick regex match before parsing by base64
            # module. Because base64 module skips invalid characters, we have
            # to do this in advance to make this server strictly reject
            # illegal keys.
            if _SEC_WEBSOCKET_KEY_REGEX.match(key):
                decoded_key = base64.b64decode(key)
        except (TypeError, ValueError):
            # b64decode raises TypeError on Python 2 and binascii.Error (a
            # ValueError subclass) on Python 3 for malformed input.
            decoded_key = None

        if decoded_key is None or len(decoded_key) != 16:
            raise HandshakeException(
                'Illegal value for header %s: %r' %
                (common.SEC_WEBSOCKET_KEY_HEADER, key))

        return decoded_key

    def _get_key(self):
        """Read and validate Sec-WebSocket-Key.

        Returns:
            The key exactly as received (still base64-encoded), which is the
            input compute_accept expects.
        """
        key = get_mandatory_header(
            self._request, common.SEC_WEBSOCKET_KEY_HEADER)

        decoded_key = self._validate_key(key)

        self._logger.debug(
            '%s: %r (%s)',
            common.SEC_WEBSOCKET_KEY_HEADER,
            key,
            util.hexify(decoded_key))

        return key

    def _create_stream(self, stream_options):
        """Create the Stream for this request with the given options."""
        return Stream(self._request, stream_options)

    def _create_handshake_response(self, accept):
        """Build the raw '101 Switching Protocols' response text.

        The response always includes the Upgrade, Connection and
        Sec-WebSocket-Accept headers. It also includes Sec-WebSocket-Protocol
        when a subprotocol was chosen, and Sec-WebSocket-Extensions when any
        extension was accepted.
        """
        response = []

        response.append('HTTP/1.1 101 Switching Protocols\r\n')

        response.append(format_header(
            common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE))
        response.append(format_header(
            common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))
        response.append(format_header(
            common.SEC_WEBSOCKET_ACCEPT_HEADER, accept))
        if self._request.ws_protocol is not None:
            response.append(format_header(
                common.SEC_WEBSOCKET_PROTOCOL_HEADER,
                self._request.ws_protocol))
        if self._request.ws_extensions:
            response.append(format_header(
                common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
                common.format_extensions(self._request.ws_extensions)))
        # Blank line terminating the header section.
        response.append('\r\n')

        return ''.join(response)

    def _send_handshake(self, accept):
        """Write the opening handshake response via request.connection."""
        raw_response = self._create_handshake_response(accept)
        self._request.connection.write(raw_response)
        self._logger.debug('Sent server\'s opening handshake: %r',
                           raw_response)
396
397    def _send_handshake(self, accept):
398        raw_response = self._create_handshake_response(accept)
399        self._request.connection.write(raw_response)
400        self._logger.debug('Sent server\'s opening handshake: %r',
401                           raw_response)
402
403
404# vi:sts=4 sw=4 et
405