• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2012, Google Inc.
2# All rights reserved.
3#
4# Redistribution and use in source and binary forms, with or without
5# modification, are permitted provided that the following conditions are
6# met:
7#
8#     * Redistributions of source code must retain the above copyright
9# notice, this list of conditions and the following disclaimer.
10#     * Redistributions in binary form must reproduce the above
11# copyright notice, this list of conditions and the following disclaimer
12# in the documentation and/or other materials provided with the
13# distribution.
14#     * Neither the name of Google Inc. nor the names of its
15# contributors may be used to endorse or promote products derived from
16# this software without specific prior written permission.
17#
18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30
31from mod_pywebsocket import common
32from mod_pywebsocket import util
33from mod_pywebsocket.http_header_util import quote_if_necessary
34
35
# Registry mapping an extension name to the processor class handling it.
# Entries are added right after each processor class is defined, and
# get_extension_processor() consults this dict to pick a class.
_available_processors = {}
37
38
class ExtensionProcessorInterface(object):
    """Interface shared by the extension processors in this module.

    Every method here is a placeholder: name() and
    get_extension_response() return None, and setup_stream_options() does
    nothing. Subclasses override them.
    """

    def name(self):
        """Returns the extension name; the base class has none."""
        return None

    def get_extension_response(self):
        """Returns the extension response; the base class has none."""
        return None

    def setup_stream_options(self, stream_options):
        """Adjusts stream_options for the extension; a no-op here."""
        return None
49
50
class DeflateStreamExtensionProcessor(ExtensionProcessorInterface):
    """Processor for the WebSocket DEFLATE stream extension."""

    def __init__(self, request):
        """Keeps the extension request and sets up a class logger."""
        self._logger = util.get_class_logger(self)

        self._request = request

    def name(self):
        """Returns the name of the DEFLATE stream extension."""
        return common.DEFLATE_STREAM_EXTENSION

    def get_extension_response(self):
        """Returns the response accepting this extension, or None when the
        request carries any parameter.
        """

        parameter_names = self._request.get_parameter_names()
        if len(parameter_names) > 0:
            # A request with any parameter is rejected.
            return None

        extension_name = common.DEFLATE_STREAM_EXTENSION
        self._logger.debug('Enable %s extension', extension_name)

        return common.ExtensionParameter(extension_name)

    def setup_stream_options(self, stream_options):
        """Turns on the deflate_stream flag of stream_options."""
        stream_options.deflate_stream = True
73
74
# Register the DEFLATE stream processor under its extension name.
_available_processors[common.DEFLATE_STREAM_EXTENSION] = (
    DeflateStreamExtensionProcessor)
77
78
79def _log_compression_ratio(logger, original_bytes, total_original_bytes,
80                           filtered_bytes, total_filtered_bytes):
81    # Print inf when ratio is not available.
82    ratio = float('inf')
83    average_ratio = float('inf')
84    if original_bytes != 0:
85        ratio = float(filtered_bytes) / original_bytes
86    if total_original_bytes != 0:
87        average_ratio = (
88            float(total_filtered_bytes) / total_original_bytes)
89    logger.debug('Outgoing compress ratio: %f (average: %f)' %
90        (ratio, average_ratio))
91
92
93def _log_decompression_ratio(logger, received_bytes, total_received_bytes,
94                             filtered_bytes, total_filtered_bytes):
95    # Print inf when ratio is not available.
96    ratio = float('inf')
97    average_ratio = float('inf')
98    if received_bytes != 0:
99        ratio = float(received_bytes) / filtered_bytes
100    if total_filtered_bytes != 0:
101        average_ratio = (
102            float(total_received_bytes) / total_filtered_bytes)
103    logger.debug('Incoming compress ratio: %f (average: %f)' %
104        (ratio, average_ratio))
105
106
class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket Per-frame DEFLATE extension processor."""

    _WINDOW_BITS_PARAM = 'max_window_bits'
    _NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover'

    def __init__(self, request):
        """Keeps the extension request and resets response settings and
        statistics counters.
        """

        self._logger = util.get_class_logger(self)

        self._request = request

        self._response_window_bits = None
        self._response_no_context_takeover = False

        # Counters for statistics.

        # Total number of outgoing bytes supplied to this filter.
        self._total_outgoing_payload_bytes = 0
        # Total number of bytes sent to the network after applying this filter.
        self._total_filtered_outgoing_payload_bytes = 0

        # Total number of bytes received from the network.
        self._total_incoming_payload_bytes = 0
        # Total number of incoming bytes obtained after applying this filter.
        self._total_filtered_incoming_payload_bytes = 0

    def name(self):
        """Returns the name of the deflate-frame extension."""
        return common.DEFLATE_FRAME_EXTENSION

    def get_extension_response(self):
        """Validates the request parameters and builds the response.

        On success the deflater and inflater are created and outgoing
        compression is switched on.

        Returns:
            The extension response named after the request, or None when
            no_context_takeover carries a value or max_window_bits is not
            an integer in the range 8 to 15.
        """

        # Any unknown parameter will be just ignored.

        window_bits = self._request.get_parameter_value(
            self._WINDOW_BITS_PARAM)
        no_context_takeover = self._request.has_parameter(
            self._NO_CONTEXT_TAKEOVER_PARAM)
        # no_context_takeover is a flag; it must not come with a value.
        if (no_context_takeover and
            self._request.get_parameter_value(
                self._NO_CONTEXT_TAKEOVER_PARAM) is not None):
            return None

        if window_bits is not None:
            try:
                window_bits = int(window_bits)
            except ValueError:
                return None
            if window_bits < 8 or window_bits > 15:
                return None

        self._deflater = util._RFC1979Deflater(
            window_bits, no_context_takeover)

        self._inflater = util._RFC1979Inflater()

        self._compress_outgoing = True

        response = common.ExtensionParameter(self._request.name())

        # Response parameters come from the set_response_*() setters, not
        # from the request.
        if self._response_window_bits is not None:
            response.add_parameter(
                self._WINDOW_BITS_PARAM, str(self._response_window_bits))
        if self._response_no_context_takeover:
            response.add_parameter(
                self._NO_CONTEXT_TAKEOVER_PARAM, None)

        self._logger.debug(
            'Enable %s extension ('
            'request: window_bits=%s; no_context_takeover=%r, '
            'response: window_bits=%s; no_context_takeover=%r)' %
            (self._request.name(),
             window_bits,
             no_context_takeover,
             self._response_window_bits,
             self._response_no_context_takeover))

        return response

    def setup_stream_options(self, stream_options):
        """Installs frame filters that delegate to this processor.

        The outgoing filter is appended to the outgoing frame filter list
        and the incoming filter is inserted at the front of the incoming
        frame filter list.
        """

        class _OutgoingFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._outgoing_filter(frame)

        class _IncomingFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._incoming_filter(frame)

        stream_options.outgoing_frame_filters.append(
            _OutgoingFilter(self))
        stream_options.incoming_frame_filters.insert(
            0, _IncomingFilter(self))

    def set_response_window_bits(self, value):
        """Sets the max_window_bits value to put in the response."""
        self._response_window_bits = value

    def set_response_no_context_takeover(self, value):
        """Sets whether the response carries no_context_takeover."""
        self._response_no_context_takeover = value

    def enable_outgoing_compression(self):
        """Makes the outgoing filter compress subsequent frames."""
        self._compress_outgoing = True

    def disable_outgoing_compression(self):
        """Makes the outgoing filter pass subsequent frames unchanged."""
        self._compress_outgoing = False

    def _outgoing_filter(self, frame):
        """Transform outgoing frames. This method is called only by
        an _OutgoingFilter instance.

        Control frames, and all frames while outgoing compression is
        disabled, are only counted. Other frames get their payload
        deflated and their RSV1 bit set.
        """

        original_payload_size = len(frame.payload)
        self._total_outgoing_payload_bytes += original_payload_size

        if (not self._compress_outgoing or
            common.is_control_opcode(frame.opcode)):
            self._total_filtered_outgoing_payload_bytes += (
                original_payload_size)
            return

        frame.payload = self._deflater.filter(frame.payload)
        frame.rsv1 = 1

        filtered_payload_size = len(frame.payload)
        self._total_filtered_outgoing_payload_bytes += filtered_payload_size

        _log_compression_ratio(self._logger, original_payload_size,
                               self._total_outgoing_payload_bytes,
                               filtered_payload_size,
                               self._total_filtered_outgoing_payload_bytes)

    def _incoming_filter(self, frame):
        """Transform incoming frames. This method is called only by
        an _IncomingFilter instance.

        Control frames and frames without the RSV1 bit are only counted.
        Other frames get their payload inflated and their RSV1 bit
        cleared.
        """

        received_payload_size = len(frame.payload)
        self._total_incoming_payload_bytes += received_payload_size

        if frame.rsv1 != 1 or common.is_control_opcode(frame.opcode):
            self._total_filtered_incoming_payload_bytes += (
                received_payload_size)
            return

        frame.payload = self._inflater.filter(frame.payload)
        frame.rsv1 = 0

        filtered_payload_size = len(frame.payload)
        self._total_filtered_incoming_payload_bytes += filtered_payload_size

        _log_decompression_ratio(self._logger, received_payload_size,
                                 self._total_incoming_payload_bytes,
                                 filtered_payload_size,
                                 self._total_filtered_incoming_payload_bytes)
267
268
# Register the per-frame DEFLATE processor under its extension name.
_available_processors[common.DEFLATE_FRAME_EXTENSION] = (
    DeflateFrameExtensionProcessor)
271
272
# Adding vendor-prefixed deflate-frame extension. The same processor class
# is registered under the prefixed name.
# TODO(bashi): Remove this after WebKit stops using vendor prefix.
_available_processors[common.X_WEBKIT_DEFLATE_FRAME_EXTENSION] = (
    DeflateFrameExtensionProcessor)
277
278
def _parse_compression_method(data):
    """Parses the value of the "method" extension parameter.

    The value is handed to common.parse_extensions with quoted strings
    allowed, and whatever that returns is passed back unchanged.
    """

    parsed_methods = common.parse_extensions(data, allow_quoted_string=True)
    return parsed_methods
283
284
def _create_accepted_method_desc(method_name, method_params):
    """Builds the accepted-method-desc string for a method.

    An extension parameter object named method_name is filled with the
    (name, value) pairs of method_params and then formatted with
    common.format_extension.
    """

    accepted = common.ExtensionParameter(method_name)
    for param in method_params:
        param_name, param_value = param
        accepted.add_parameter(param_name, param_value)
    return common.format_extension(accepted)
292
293
class CompressionExtensionProcessorBase(ExtensionProcessorInterface):
    """Base class for Per-frame and Per-message compression extension."""

    _METHOD_PARAM = 'method'

    def __init__(self, request):
        """Keeps the extension request; no method is selected yet."""
        self._logger = util.get_class_logger(self)
        self._request = request
        self._compression_method_name = None
        self._compression_processor = None

    def name(self):
        """Returns an empty name; subclasses return the real one."""
        return ''

    def _lookup_compression_processor(self, method_desc):
        """Returns a processor for method_desc, or None if unsupported.

        The base class supports no method; subclasses override this.
        """
        return None

    def _get_compression_processor_response(self):
        """Looks up the compression processor based on the self._request and
           returns the compression processor's response.

        Returns:
            The selected processor's response, or None when the "method"
            parameter is missing or unparsable, when no listed method is
            supported, or when the selected processor rejects the request.
        """

        method_list = self._request.get_parameter_value(self._METHOD_PARAM)
        if method_list is None:
            return None
        methods = _parse_compression_method(method_list)
        if methods is None:
            return None
        # Must be bound before the loop: with an empty method list the loop
        # body never runs and the check below would otherwise fail with
        # UnboundLocalError.
        compression_processor = None
        # The current implementation tries only the first method that matches
        # supported algorithm. Following methods aren't tried even if the
        # first one is rejected.
        # TODO(bashi): Need to clarify this behavior.
        for method_desc in methods:
            compression_processor = self._lookup_compression_processor(
                method_desc)
            if compression_processor is not None:
                self._compression_method_name = method_desc.name()
                break
        if compression_processor is None:
            return None
        processor_response = compression_processor.get_extension_response()
        if processor_response is None:
            return None
        # Only a processor that accepted the request is remembered.
        self._compression_processor = compression_processor
        return processor_response

    def get_extension_response(self):
        """Builds the extension response wrapping the accepted method.

        Returns:
            A response named after the request whose "method" parameter
            holds the accepted method description, or None when no
            compression processor accepted the request.
        """

        processor_response = self._get_compression_processor_response()
        if processor_response is None:
            return None

        response = common.ExtensionParameter(self._request.name())
        accepted_method_desc = _create_accepted_method_desc(
                                   self._compression_method_name,
                                   processor_response.get_parameters())
        response.add_parameter(self._METHOD_PARAM, accepted_method_desc)
        self._logger.debug(
            'Enable %s extension (method: %s)' %
            (self._request.name(), self._compression_method_name))
        return response

    def setup_stream_options(self, stream_options):
        """Delegates to the selected compression processor, if any."""
        if self._compression_processor is None:
            return
        self._compression_processor.setup_stream_options(stream_options)

    def get_compression_processor(self):
        """Returns the selected compression processor, or None."""
        return self._compression_processor
363
364
class PerFrameCompressionExtensionProcessor(CompressionExtensionProcessorBase):
    """Processor for the WebSocket Per-frame compression extension."""

    _DEFLATE_METHOD = 'deflate'

    def __init__(self, request):
        """Initializes the base class with the extension request."""
        CompressionExtensionProcessorBase.__init__(self, request)

    def name(self):
        """Returns the name of the per-frame compression extension."""
        return common.PERFRAME_COMPRESSION_EXTENSION

    def _lookup_compression_processor(self, method_desc):
        """Returns a DeflateFrameExtensionProcessor for the 'deflate'
        method and None for any other method.
        """
        if method_desc.name() != self._DEFLATE_METHOD:
            return None
        return DeflateFrameExtensionProcessor(method_desc)
379
380
# Register the per-frame compression processor under its extension name.
_available_processors[common.PERFRAME_COMPRESSION_EXTENSION] = (
    PerFrameCompressionExtensionProcessor)
383
384
class DeflateMessageProcessor(ExtensionProcessorInterface):
    """Per-message deflate processor."""

    _S2C_MAX_WINDOW_BITS_PARAM = 's2c_max_window_bits'
    _S2C_NO_CONTEXT_TAKEOVER_PARAM = 's2c_no_context_takeover'
    _C2S_MAX_WINDOW_BITS_PARAM = 'c2s_max_window_bits'
    _C2S_NO_CONTEXT_TAKEOVER_PARAM = 'c2s_no_context_takeover'

    def __init__(self, request):
        """Keeps the method request and resets response settings and
        statistics counters. Outgoing compression starts disabled.
        """

        self._request = request
        self._logger = util.get_class_logger(self)

        self._c2s_max_window_bits = None
        self._c2s_no_context_takeover = False

        self._compress_outgoing = False

        # Counters for statistics.

        # Total number of outgoing bytes supplied to this filter.
        self._total_outgoing_payload_bytes = 0
        # Total number of bytes sent to the network after applying this filter.
        self._total_filtered_outgoing_payload_bytes = 0

        # Total number of bytes received from the network.
        self._total_incoming_payload_bytes = 0
        # Total number of incoming bytes obtained after applying this filter.
        self._total_filtered_incoming_payload_bytes = 0

    def name(self):
        """Returns the method name handled by this processor."""
        return 'deflate'

    def get_extension_response(self):
        """Validates the request parameters and builds the response.

        On success the deflater and inflater are created and outgoing
        compression is switched on.

        Returns:
            The response named after the request, or None when
            s2c_max_window_bits is not an integer in the range 8 to 15 or
            s2c_no_context_takeover carries a value.
        """

        # Any unknown parameter will be just ignored.

        s2c_max_window_bits = self._request.get_parameter_value(
            self._S2C_MAX_WINDOW_BITS_PARAM)
        if s2c_max_window_bits is not None:
            try:
                s2c_max_window_bits = int(s2c_max_window_bits)
            except ValueError:
                return None
            if s2c_max_window_bits < 8 or s2c_max_window_bits > 15:
                return None

        s2c_no_context_takeover = self._request.has_parameter(
            self._S2C_NO_CONTEXT_TAKEOVER_PARAM)
        # s2c_no_context_takeover is a flag; it must not come with a value.
        if (s2c_no_context_takeover and
            self._request.get_parameter_value(
                self._S2C_NO_CONTEXT_TAKEOVER_PARAM) is not None):
            return None

        self._deflater = util._RFC1979Deflater(
            s2c_max_window_bits, s2c_no_context_takeover)

        self._inflater = util._RFC1979Inflater()

        self._compress_outgoing = True

        response = common.ExtensionParameter(self._request.name())

        # The s2c_* parameters are echoed back only when they were requested.
        if s2c_max_window_bits is not None:
            response.add_parameter(
                self._S2C_MAX_WINDOW_BITS_PARAM, str(s2c_max_window_bits))

        # has_parameter() is used as a truth value above, so test it the
        # same way here; an "is not None" test would always succeed.
        if s2c_no_context_takeover:
            response.add_parameter(
                self._S2C_NO_CONTEXT_TAKEOVER_PARAM, None)

        # The c2s_* parameters come from the set_c2s_*() setters.
        if self._c2s_max_window_bits is not None:
            response.add_parameter(
                self._C2S_MAX_WINDOW_BITS_PARAM,
                str(self._c2s_max_window_bits))
        if self._c2s_no_context_takeover:
            response.add_parameter(
                self._C2S_NO_CONTEXT_TAKEOVER_PARAM, None)

        self._logger.debug(
            'Enable %s extension ('
            'request: s2c_max_window_bits=%s; s2c_no_context_takeover=%r, '
            'response: c2s_max_window_bits=%s; c2s_no_context_takeover=%r)' %
            (self._request.name(),
             s2c_max_window_bits,
             s2c_no_context_takeover,
             self._c2s_max_window_bits,
             self._c2s_no_context_takeover))

        return response

    def setup_stream_options(self, stream_options):
        """Installs message and frame filters that delegate to this
        processor, and turns off UTF-8 encoding of text messages in
        stream_options (the outgoing message filter encodes them itself).
        """

        class _OutgoingMessageFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, message, end=True, binary=False):
                return self._parent._process_outgoing_message(
                    message, end, binary)

        class _IncomingMessageFilter(object):

            def __init__(self, parent):
                self._parent = parent
                self._decompress_next_message = False

            def decompress_next_message(self):
                self._decompress_next_message = True

            def filter(self, message):
                message = self._parent._process_incoming_message(
                    message, self._decompress_next_message)
                # The flag applies to one message only.
                self._decompress_next_message = False
                return message

        self._outgoing_message_filter = _OutgoingMessageFilter(self)
        self._incoming_message_filter = _IncomingMessageFilter(self)
        stream_options.outgoing_message_filters.append(
            self._outgoing_message_filter)
        stream_options.incoming_message_filters.append(
            self._incoming_message_filter)

        class _OutgoingFrameFilter(object):

            def __init__(self, parent):
                self._parent = parent
                self._set_compression_bit = False

            def set_compression_bit(self):
                self._set_compression_bit = True

            def filter(self, frame):
                self._parent._process_outgoing_frame(
                    frame, self._set_compression_bit)
                # The flag applies to one frame only.
                self._set_compression_bit = False

        class _IncomingFrameFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._process_incoming_frame(frame)

        self._outgoing_frame_filter = _OutgoingFrameFilter(self)
        self._incoming_frame_filter = _IncomingFrameFilter(self)
        stream_options.outgoing_frame_filters.append(
            self._outgoing_frame_filter)
        stream_options.incoming_frame_filters.append(
            self._incoming_frame_filter)

        stream_options.encode_text_message_to_utf8 = False

    def set_c2s_window_bits(self, value):
        """Sets the c2s_max_window_bits value to put in the response."""
        self._c2s_max_window_bits = value

    def set_c2s_no_context_takeover(self, value):
        """Sets whether the response carries c2s_no_context_takeover."""
        self._c2s_no_context_takeover = value

    def enable_outgoing_compression(self):
        """Makes the outgoing filter compress subsequent messages."""
        self._compress_outgoing = True

    def disable_outgoing_compression(self):
        """Makes the outgoing filter stop compressing messages."""
        self._compress_outgoing = False

    def _process_incoming_message(self, message, decompress):
        """Returns message inflated when decompress is true, otherwise
        returns it unchanged. Updates the incoming statistics when
        inflating.
        """

        if not decompress:
            return message

        received_payload_size = len(message)
        self._total_incoming_payload_bytes += received_payload_size

        message = self._inflater.filter(message)

        filtered_payload_size = len(message)
        self._total_filtered_incoming_payload_bytes += filtered_payload_size

        _log_decompression_ratio(self._logger, received_payload_size,
                                 self._total_incoming_payload_bytes,
                                 filtered_payload_size,
                                 self._total_filtered_incoming_payload_bytes)

        return message

    def _process_outgoing_message(self, message, end, binary):
        """Returns the payload to send for message.

        Non-binary messages are encoded to UTF-8 first. When outgoing
        compression is enabled the payload is deflated and the outgoing
        frame filter is told to set the compression bit. The end argument
        is not used.
        """

        if not binary:
            message = message.encode('utf-8')

        if not self._compress_outgoing:
            return message

        original_payload_size = len(message)
        self._total_outgoing_payload_bytes += original_payload_size

        message = self._deflater.filter(message)

        filtered_payload_size = len(message)
        self._total_filtered_outgoing_payload_bytes += filtered_payload_size

        _log_compression_ratio(self._logger, original_payload_size,
                               self._total_outgoing_payload_bytes,
                               filtered_payload_size,
                               self._total_filtered_outgoing_payload_bytes)

        self._outgoing_frame_filter.set_compression_bit()
        return message

    def _process_incoming_frame(self, frame):
        """Marks the next message for decompression and clears RSV1 when a
        non-control frame arrives with RSV1 set.
        """

        if frame.rsv1 == 1 and not common.is_control_opcode(frame.opcode):
            self._incoming_message_filter.decompress_next_message()
            frame.rsv1 = 0

    def _process_outgoing_frame(self, frame, compression_bit):
        """Sets RSV1 on a non-control frame when compression_bit is true."""

        if (not compression_bit or
            common.is_control_opcode(frame.opcode)):
            return

        frame.rsv1 = 1
602
603
class PerMessageCompressionExtensionProcessor(
    CompressionExtensionProcessorBase):
    """Processor for the WebSocket Per-message compression extension."""

    _DEFLATE_METHOD = 'deflate'

    def __init__(self, request):
        """Initializes the base class with the extension request."""
        CompressionExtensionProcessorBase.__init__(self, request)

    def name(self):
        """Returns the name of the per-message compression extension."""
        return common.PERMESSAGE_COMPRESSION_EXTENSION

    def _lookup_compression_processor(self, method_desc):
        """Returns a DeflateMessageProcessor for the 'deflate' method and
        None for any other method.
        """
        if method_desc.name() != self._DEFLATE_METHOD:
            return None
        return DeflateMessageProcessor(method_desc)
619
620
# Register the per-message compression processor under its extension name.
# This must be the per-message class; registering the per-frame class here
# would make per-message requests be handled with per-frame filters.
_available_processors[common.PERMESSAGE_COMPRESSION_EXTENSION] = (
    PerMessageCompressionExtensionProcessor)
623
624
class MuxExtensionProcessor(ExtensionProcessorInterface):
    """Processor for the WebSocket multiplexing extension."""

    _QUOTA_PARAM = 'quota'

    def __init__(self, request):
        """Keeps the extension request."""
        self._request = request

    def name(self):
        """Returns the name of the multiplexing extension."""
        return common.MUX_EXTENSION

    def get_extension_response(self, ws_request,
                               logical_channel_extensions):
        """Accepts or rejects the multiplexing request.

        On success ws_request gets mux_quota, mux and mux_extensions set,
        and the response for the mux extension is returned. None is
        returned when a frame-oriented compression extension is among
        logical_channel_extensions, or when the quota parameter is not an
        integer in the range 0 to 2 ** 32 - 1.
        """

        # Mux extension cannot be used after extensions that depend on
        # frame boundary, extension data field, or any reserved bits
        # which are attributed to each frame.
        frame_dependent_names = (
            common.PERFRAME_COMPRESSION_EXTENSION,
            common.DEFLATE_FRAME_EXTENSION,
            common.X_WEBKIT_DEFLATE_FRAME_EXTENSION)
        for extension in logical_channel_extensions:
            if extension.name() in frame_dependent_names:
                return None

        requested_quota = self._request.get_parameter_value(
            self._QUOTA_PARAM)
        if requested_quota is None:
            # A missing quota parameter is treated as zero.
            accepted_quota = 0
        else:
            try:
                accepted_quota = int(requested_quota)
            except ValueError:
                return None
            if not (0 <= accepted_quota < 2 ** 32):
                return None

        ws_request.mux_quota = accepted_quota
        ws_request.mux = True
        ws_request.mux_extensions = logical_channel_extensions
        return common.ExtensionParameter(common.MUX_EXTENSION)

    def setup_stream_options(self, stream_options):
        """Leaves stream_options untouched."""
        return None
666
667
# Register the multiplexing processor under its extension name.
_available_processors[common.MUX_EXTENSION] = MuxExtensionProcessor
669
670
def get_extension_processor(extension_request):
    """Returns a processor for extension_request, or None.

    The processor class is looked up in _available_processors by the
    request's name and instantiated with the request. None is returned
    when no class is registered under that name.
    """

    factory = _available_processors.get(extension_request.name())
    if factory is not None:
        return factory(extension_request)
    return None
677
678
679# vi:sts=4 sw=4 et
680