# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""WebSocket extension processors.

Each processor class handles one extension offered in the opening handshake.
get_extension_response() decides whether to accept the offer and builds the
parameter to put in the handshake response; setup_stream_options() installs
whatever stream options / frame filters / message filters implement the
extension. get_extension_processor() maps an offered extension to a
processor instance by the extension's name.
"""


from mod_pywebsocket import common
from mod_pywebsocket import util
from mod_pywebsocket.http_header_util import quote_if_necessary


# Extension name -> processor class. Filled in right after each processor
# class definition below and consulted by get_extension_processor().
_available_processors = {}


class ExtensionProcessorInterface(object):
    """Base interface for extension processors.

    The default methods describe a processor that declines every offer and
    installs nothing; concrete processors override them.
    """

    def name(self):
        """Returns the name of the extension handled by this processor."""
        return None

    def get_extension_response(self):
        """Returns a common.ExtensionParameter for the handshake response,
        or None to decline the extension.
        """
        return None

    def setup_stream_options(self, stream_options):
        """Updates stream_options for an accepted extension."""
        pass


class DeflateStreamExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket DEFLATE stream extension processor."""

    def __init__(self, request):
        # request is the client's extension offer (presumably a
        # common.ExtensionParameter -- it is used through name() /
        # get_parameter_names()).
        self._logger = util.get_class_logger(self)

        self._request = request

    def name(self):
        return common.DEFLATE_STREAM_EXTENSION

    def get_extension_response(self):
        """Accepts the offer only if it carries no parameters."""
        if self._request.get_parameter_names():
            return None

        self._logger.debug(
            'Enable %s extension', common.DEFLATE_STREAM_EXTENSION)

        return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION)

    def setup_stream_options(self, stream_options):
        stream_options.deflate_stream = True


_available_processors[common.DEFLATE_STREAM_EXTENSION] = (
    DeflateStreamExtensionProcessor)


def _log_compression_ratio(logger, original_bytes, total_original_bytes,
                           filtered_bytes, total_filtered_bytes):
    """Logs compressed/original size for the latest outgoing payload and
    the running average over all outgoing payloads.

    Args:
        logger: logger to write the debug line to.
        original_bytes: size of the latest payload before compression.
        total_original_bytes: running total of pre-compression sizes.
        filtered_bytes: size of the latest payload after compression.
        total_filtered_bytes: running total of post-compression sizes.

    A ratio whose denominator is zero is reported as inf.
    """
    ratio = float('inf')
    average_ratio = float('inf')
    if original_bytes != 0:
        ratio = float(filtered_bytes) / original_bytes
    if total_original_bytes != 0:
        average_ratio = float(total_filtered_bytes) / total_original_bytes
    logger.debug('Outgoing compress ratio: %f (average: %f)',
                 ratio, average_ratio)


def _log_decompression_ratio(logger, received_bytes, total_received_bytes,
                             filtered_bytes, total_filtered_bytes):
    """Logs received/decompressed size for the latest incoming payload and
    the running average over all incoming payloads.

    Args:
        logger: logger to write the debug line to.
        received_bytes: size of the latest payload as received.
        total_received_bytes: running total of received sizes.
        filtered_bytes: size of the latest payload after decompression.
        total_filtered_bytes: running total of decompressed sizes.

    A ratio whose denominator is zero is reported as inf.
    """
    ratio = float('inf')
    average_ratio = float('inf')
    # The denominator is the decompressed size, so that is what must be
    # non-zero; guarding on received_bytes instead let an empty
    # decompressed payload raise ZeroDivisionError.
    if filtered_bytes != 0:
        ratio = float(received_bytes) / filtered_bytes
    if total_filtered_bytes != 0:
        average_ratio = float(total_received_bytes) / total_filtered_bytes
    logger.debug('Incoming compress ratio: %f (average: %f)',
                 ratio, average_ratio)


class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket Per-frame DEFLATE extension processor.

    Once accepted, compresses the payload of each outgoing non-control frame
    (setting RSV1) while outgoing compression is enabled, and decompresses
    each incoming non-control frame whose RSV1 bit is set.
    """

    _WINDOW_BITS_PARAM = 'max_window_bits'
    _NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover'

    def __init__(self, request):
        self._logger = util.get_class_logger(self)

        self._request = request

        # Parameters advertised in the handshake response; configured via
        # set_response_window_bits() / set_response_no_context_takeover().
        self._response_window_bits = None
        self._response_no_context_takeover = False

        # Compression state. Set up by get_extension_response() when the
        # offer is accepted.
        self._deflater = None
        self._inflater = None
        self._compress_outgoing = False

        # Counters for statistics.

        # Total number of outgoing bytes supplied to this filter.
        self._total_outgoing_payload_bytes = 0
        # Total number of bytes sent to the network after applying this filter.
        self._total_filtered_outgoing_payload_bytes = 0

        # Total number of bytes received from the network.
        self._total_incoming_payload_bytes = 0
        # Total number of incoming bytes obtained after applying this filter.
        self._total_filtered_incoming_payload_bytes = 0

    def name(self):
        return common.DEFLATE_FRAME_EXTENSION

    def get_extension_response(self):
        """Accepts the offer unless one of its parameters is malformed.

        Returns None if no_context_takeover carries a value or if
        max_window_bits is not an integer in [8, 15]. Otherwise creates the
        deflater/inflater, enables outgoing compression and returns the
        response parameter (named after the request).
        """
        # Any unknown parameter will be just ignored.

        window_bits = self._request.get_parameter_value(
            self._WINDOW_BITS_PARAM)
        no_context_takeover = self._request.has_parameter(
            self._NO_CONTEXT_TAKEOVER_PARAM)
        # no_context_takeover is a flag; reject it if it has a value.
        if (no_context_takeover and
                self._request.get_parameter_value(
                    self._NO_CONTEXT_TAKEOVER_PARAM) is not None):
            return None

        if window_bits is not None:
            try:
                window_bits = int(window_bits)
            except ValueError:
                return None
            if window_bits < 8 or window_bits > 15:
                return None

        self._deflater = util._RFC1979Deflater(
            window_bits, no_context_takeover)

        self._inflater = util._RFC1979Inflater()

        self._compress_outgoing = True

        response = common.ExtensionParameter(self._request.name())

        if self._response_window_bits is not None:
            response.add_parameter(
                self._WINDOW_BITS_PARAM, str(self._response_window_bits))
        if self._response_no_context_takeover:
            response.add_parameter(
                self._NO_CONTEXT_TAKEOVER_PARAM, None)

        self._logger.debug(
            'Enable %s extension ('
            'request: window_bits=%s; no_context_takeover=%r, '
            'response: window_wbits=%s; no_context_takeover=%r)',
            self._request.name(),
            window_bits,
            no_context_takeover,
            self._response_window_bits,
            self._response_no_context_takeover)

        return response

    def setup_stream_options(self, stream_options):
        """Registers this processor's outgoing and incoming frame filters."""

        class _OutgoingFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._outgoing_filter(frame)

        class _IncomingFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._incoming_filter(frame)

        stream_options.outgoing_frame_filters.append(
            _OutgoingFilter(self))
        # Inserted at the front of the list, ahead of any incoming frame
        # filters registered earlier (presumably so decompression runs
        # before them -- depends on the stream applying filters in order).
        stream_options.incoming_frame_filters.insert(
            0, _IncomingFilter(self))

    def set_response_window_bits(self, value):
        self._response_window_bits = value

    def set_response_no_context_takeover(self, value):
        self._response_no_context_takeover = value

    def enable_outgoing_compression(self):
        self._compress_outgoing = True

    def disable_outgoing_compression(self):
        self._compress_outgoing = False

    def _outgoing_filter(self, frame):
        """Transform outgoing frames. This method is called only by
        an _OutgoingFilter instance.

        Compresses frame.payload in place and sets frame.rsv1 unless
        outgoing compression is disabled or the frame is a control frame.
        """

        original_payload_size = len(frame.payload)
        self._total_outgoing_payload_bytes += original_payload_size

        if (not self._compress_outgoing or
                common.is_control_opcode(frame.opcode)):
            self._total_filtered_outgoing_payload_bytes += (
                original_payload_size)
            return

        frame.payload = self._deflater.filter(frame.payload)
        frame.rsv1 = 1

        filtered_payload_size = len(frame.payload)
        self._total_filtered_outgoing_payload_bytes += filtered_payload_size

        _log_compression_ratio(self._logger, original_payload_size,
                               self._total_outgoing_payload_bytes,
                               filtered_payload_size,
                               self._total_filtered_outgoing_payload_bytes)

    def _incoming_filter(self, frame):
        """Transform incoming frames. This method is called only by
        an _IncomingFilter instance.

        Decompresses frame.payload in place and clears frame.rsv1 for
        non-control frames that arrive with RSV1 set.
        """

        received_payload_size = len(frame.payload)
        self._total_incoming_payload_bytes += received_payload_size

        if frame.rsv1 != 1 or common.is_control_opcode(frame.opcode):
            self._total_filtered_incoming_payload_bytes += (
                received_payload_size)
            return

        frame.payload = self._inflater.filter(frame.payload)
        frame.rsv1 = 0

        filtered_payload_size = len(frame.payload)
        self._total_filtered_incoming_payload_bytes += filtered_payload_size

        _log_decompression_ratio(self._logger, received_payload_size,
                                 self._total_incoming_payload_bytes,
                                 filtered_payload_size,
                                 self._total_filtered_incoming_payload_bytes)


_available_processors[common.DEFLATE_FRAME_EXTENSION] = (
    DeflateFrameExtensionProcessor)


# Adding vendor-prefixed deflate-frame extension.
# TODO(bashi): Remove this after WebKit stops using vender prefix.
_available_processors[common.X_WEBKIT_DEFLATE_FRAME_EXTENSION] = (
    DeflateFrameExtensionProcessor)


def _parse_compression_method(data):
    """Parses the value of "method" extension parameter."""

    return common.parse_extensions(data, allow_quoted_string=True)


def _create_accepted_method_desc(method_name, method_params):
    """Creates accepted-method-desc from given method name and parameters.

    method_params is an iterable of (name, value) pairs; the result is the
    formatted string produced by common.format_extension().
    """

    extension = common.ExtensionParameter(method_name)
    for name, value in method_params:
        extension.add_parameter(name, value)
    return common.format_extension(extension)


class CompressionExtensionProcessorBase(ExtensionProcessorInterface):
    """Base class for Per-frame and Per-message compression extension.

    The offer's "method" parameter lists compression methods; subclasses
    map a method description to a concrete processor through
    _lookup_compression_processor(), and this class delegates to it.
    """

    _METHOD_PARAM = 'method'

    def __init__(self, request):
        self._logger = util.get_class_logger(self)
        self._request = request
        self._compression_method_name = None
        self._compression_processor = None

    def name(self):
        return ''

    def _lookup_compression_processor(self, method_desc):
        """Returns a processor for method_desc, or None if unsupported."""
        return None

    def _get_compression_processor_response(self):
        """Looks up the compression processor based on the self._request and
        returns the compression processor's response.

        Returns None when the method parameter is missing or unparsable,
        when no listed method is supported, or when the chosen processor
        declines. On success the chosen processor is remembered in
        self._compression_processor.
        """

        method_list = self._request.get_parameter_value(self._METHOD_PARAM)
        if method_list is None:
            return None
        methods = _parse_compression_method(method_list)
        if methods is None:
            return None
        # Must be bound before the loop: an empty method list would
        # otherwise leave it undefined for the check below.
        compression_processor = None
        # The current implementation tries only the first method that matches
        # supported algorithm. Following methods aren't tried even if the
        # first one is rejected.
        # TODO(bashi): Need to clarify this behavior.
        for method_desc in methods:
            compression_processor = self._lookup_compression_processor(
                method_desc)
            if compression_processor is not None:
                self._compression_method_name = method_desc.name()
                break
        if compression_processor is None:
            return None
        processor_response = compression_processor.get_extension_response()
        if processor_response is None:
            return None
        self._compression_processor = compression_processor
        return processor_response

    def get_extension_response(self):
        """Returns a response whose method parameter carries the accepted
        method and its parameters, or None to decline.
        """
        processor_response = self._get_compression_processor_response()
        if processor_response is None:
            return None

        response = common.ExtensionParameter(self._request.name())
        accepted_method_desc = _create_accepted_method_desc(
            self._compression_method_name,
            processor_response.get_parameters())
        response.add_parameter(self._METHOD_PARAM, accepted_method_desc)
        self._logger.debug(
            'Enable %s extension (method: %s)',
            self._request.name(), self._compression_method_name)
        return response

    def setup_stream_options(self, stream_options):
        if self._compression_processor is None:
            return
        self._compression_processor.setup_stream_options(stream_options)

    def get_compression_processor(self):
        return self._compression_processor


class PerFrameCompressionExtensionProcessor(CompressionExtensionProcessorBase):
    """WebSocket Per-frame compression extension processor."""

    _DEFLATE_METHOD = 'deflate'

    def __init__(self, request):
        CompressionExtensionProcessorBase.__init__(self, request)

    def name(self):
        return common.PERFRAME_COMPRESSION_EXTENSION

    def _lookup_compression_processor(self, method_desc):
        if method_desc.name() == self._DEFLATE_METHOD:
            return DeflateFrameExtensionProcessor(method_desc)
        return None


_available_processors[common.PERFRAME_COMPRESSION_EXTENSION] = (
    PerFrameCompressionExtensionProcessor)


class DeflateMessageProcessor(ExtensionProcessorInterface):
    """Per-message deflate processor.

    Outgoing messages are (UTF-8 encoded if text and) compressed as a whole;
    RSV1 is set on the next outgoing non-control frame. An incoming
    non-control frame with RSV1 set marks the next incoming message for
    decompression.
    """

    _S2C_MAX_WINDOW_BITS_PARAM = 's2c_max_window_bits'
    _S2C_NO_CONTEXT_TAKEOVER_PARAM = 's2c_no_context_takeover'
    _C2S_MAX_WINDOW_BITS_PARAM = 'c2s_max_window_bits'
    _C2S_NO_CONTEXT_TAKEOVER_PARAM = 'c2s_no_context_takeover'

    def __init__(self, request):
        self._request = request
        self._logger = util.get_class_logger(self)

        # Parameters advertised in the handshake response; configured via
        # set_c2s_window_bits() / set_c2s_no_context_takeover().
        self._c2s_max_window_bits = None
        self._c2s_no_context_takeover = False

        # Compression state. Set up by get_extension_response() when the
        # offer is accepted.
        self._deflater = None
        self._inflater = None
        self._compress_outgoing = False

        # Counters for statistics.

        # Total number of outgoing bytes supplied to this filter.
        self._total_outgoing_payload_bytes = 0
        # Total number of bytes sent to the network after applying this filter.
        self._total_filtered_outgoing_payload_bytes = 0

        # Total number of bytes received from the network.
        self._total_incoming_payload_bytes = 0
        # Total number of incoming bytes obtained after applying this filter.
        self._total_filtered_incoming_payload_bytes = 0

    def name(self):
        return 'deflate'

    def get_extension_response(self):
        """Accepts the offer unless one of its parameters is malformed.

        Returns None if s2c_max_window_bits is not an integer in [8, 15] or
        if s2c_no_context_takeover carries a value. Otherwise creates the
        deflater/inflater, enables outgoing compression and returns the
        response parameter (named after the request).
        """
        # Any unknown parameter will be just ignored.

        s2c_max_window_bits = self._request.get_parameter_value(
            self._S2C_MAX_WINDOW_BITS_PARAM)
        if s2c_max_window_bits is not None:
            try:
                s2c_max_window_bits = int(s2c_max_window_bits)
            except ValueError:
                return None
            if s2c_max_window_bits < 8 or s2c_max_window_bits > 15:
                return None

        s2c_no_context_takeover = self._request.has_parameter(
            self._S2C_NO_CONTEXT_TAKEOVER_PARAM)
        # s2c_no_context_takeover is a flag; reject it if it has a value.
        if (s2c_no_context_takeover and
                self._request.get_parameter_value(
                    self._S2C_NO_CONTEXT_TAKEOVER_PARAM) is not None):
            return None

        self._deflater = util._RFC1979Deflater(
            s2c_max_window_bits, s2c_no_context_takeover)

        self._inflater = util._RFC1979Inflater()

        self._compress_outgoing = True

        response = common.ExtensionParameter(self._request.name())

        if s2c_max_window_bits is not None:
            response.add_parameter(
                self._S2C_MAX_WINDOW_BITS_PARAM, str(s2c_max_window_bits))

        # has_parameter() yields a bool, so test its truth value: comparing
        # against None would always advertise the flag.
        if s2c_no_context_takeover:
            response.add_parameter(
                self._S2C_NO_CONTEXT_TAKEOVER_PARAM, None)

        if self._c2s_max_window_bits is not None:
            response.add_parameter(
                self._C2S_MAX_WINDOW_BITS_PARAM,
                str(self._c2s_max_window_bits))
        if self._c2s_no_context_takeover:
            response.add_parameter(
                self._C2S_NO_CONTEXT_TAKEOVER_PARAM, None)

        self._logger.debug(
            'Enable %s extension ('
            'request: s2c_max_window_bits=%s; s2c_no_context_takeover=%r, '
            'response: c2s_max_window_bits=%s; c2s_no_context_takeover=%r)',
            self._request.name(),
            s2c_max_window_bits,
            s2c_no_context_takeover,
            self._c2s_max_window_bits,
            self._c2s_no_context_takeover)

        return response

    def setup_stream_options(self, stream_options):
        """Registers message and frame filters for per-message deflate."""

        class _OutgoingMessageFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, message, end=True, binary=False):
                return self._parent._process_outgoing_message(
                    message, end, binary)

        class _IncomingMessageFilter(object):

            def __init__(self, parent):
                self._parent = parent
                # Set by the incoming frame filter when the frame carried
                # RSV1; consumed (and reset) by the next filter() call.
                self._decompress_next_message = False

            def decompress_next_message(self):
                self._decompress_next_message = True

            def filter(self, message):
                message = self._parent._process_incoming_message(
                    message, self._decompress_next_message)
                self._decompress_next_message = False
                return message

        self._outgoing_message_filter = _OutgoingMessageFilter(self)
        self._incoming_message_filter = _IncomingMessageFilter(self)
        stream_options.outgoing_message_filters.append(
            self._outgoing_message_filter)
        stream_options.incoming_message_filters.append(
            self._incoming_message_filter)

        class _OutgoingFrameFilter(object):

            def __init__(self, parent):
                self._parent = parent
                # Set after a message is compressed; applies to the next
                # frame only, then resets.
                self._set_compression_bit = False

            def set_compression_bit(self):
                self._set_compression_bit = True

            def filter(self, frame):
                self._parent._process_outgoing_frame(
                    frame, self._set_compression_bit)
                self._set_compression_bit = False

        class _IncomingFrameFilter(object):

            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._process_incoming_frame(frame)

        self._outgoing_frame_filter = _OutgoingFrameFilter(self)
        self._incoming_frame_filter = _IncomingFrameFilter(self)
        stream_options.outgoing_frame_filters.append(
            self._outgoing_frame_filter)
        stream_options.incoming_frame_filters.append(
            self._incoming_frame_filter)

        # _process_outgoing_message() UTF-8 encodes text messages itself
        # before compressing; presumably this stops the stream from
        # encoding them a second time.
        stream_options.encode_text_message_to_utf8 = False

    def set_c2s_window_bits(self, value):
        self._c2s_max_window_bits = value

    def set_c2s_no_context_takeover(self, value):
        self._c2s_no_context_takeover = value

    def enable_outgoing_compression(self):
        self._compress_outgoing = True

    def disable_outgoing_compression(self):
        self._compress_outgoing = False

    def _process_incoming_message(self, message, decompress):
        """Returns message, inflated when decompress is true."""
        if not decompress:
            return message

        received_payload_size = len(message)
        self._total_incoming_payload_bytes += received_payload_size

        message = self._inflater.filter(message)

        filtered_payload_size = len(message)
        self._total_filtered_incoming_payload_bytes += filtered_payload_size

        _log_decompression_ratio(self._logger, received_payload_size,
                                 self._total_incoming_payload_bytes,
                                 filtered_payload_size,
                                 self._total_filtered_incoming_payload_bytes)

        return message

    def _process_outgoing_message(self, message, end, binary):
        """UTF-8 encodes text messages, then compresses the message when
        outgoing compression is enabled and flags the next outgoing frame
        for RSV1.
        """
        if not binary:
            message = message.encode('utf-8')

        if not self._compress_outgoing:
            return message

        original_payload_size = len(message)
        self._total_outgoing_payload_bytes += original_payload_size

        message = self._deflater.filter(message)

        filtered_payload_size = len(message)
        self._total_filtered_outgoing_payload_bytes += filtered_payload_size

        _log_compression_ratio(self._logger, original_payload_size,
                               self._total_outgoing_payload_bytes,
                               filtered_payload_size,
                               self._total_filtered_outgoing_payload_bytes)

        self._outgoing_frame_filter.set_compression_bit()
        return message

    def _process_incoming_frame(self, frame):
        """Clears RSV1 on a non-control frame that has it and marks the
        next incoming message for decompression.
        """
        if frame.rsv1 == 1 and not common.is_control_opcode(frame.opcode):
            self._incoming_message_filter.decompress_next_message()
            frame.rsv1 = 0

    def _process_outgoing_frame(self, frame, compression_bit):
        """Sets RSV1 on a non-control frame when compression_bit is true."""
        if (not compression_bit or
                common.is_control_opcode(frame.opcode)):
            return

        frame.rsv1 = 1


class PerMessageCompressionExtensionProcessor(
    CompressionExtensionProcessorBase):
    """WebSocket Per-message compression extension processor."""

    _DEFLATE_METHOD = 'deflate'

    def __init__(self, request):
        CompressionExtensionProcessorBase.__init__(self, request)

    def name(self):
        return common.PERMESSAGE_COMPRESSION_EXTENSION

    def _lookup_compression_processor(self, method_desc):
        if method_desc.name() == self._DEFLATE_METHOD:
            return DeflateMessageProcessor(method_desc)
        return None


# Per-message compression must map to the per-message processor; mapping it
# to the per-frame one (as before) left this class unreachable.
_available_processors[common.PERMESSAGE_COMPRESSION_EXTENSION] = (
    PerMessageCompressionExtensionProcessor)


class MuxExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket multiplexing extension processor."""

    _QUOTA_PARAM = 'quota'

    def __init__(self, request):
        self._request = request

    def name(self):
        return common.MUX_EXTENSION

    def get_extension_response(self, ws_request,
                               logical_channel_extensions):
        """Accepts the mux offer and records the result on ws_request.

        Unlike the base interface, this takes the handshake request and the
        list of extensions negotiated for logical channels.

        Returns None if any of logical_channel_extensions is a per-frame
        extension, or if quota is not an integer in [0, 2**32). Otherwise
        sets ws_request.mux, ws_request.mux_quota (0 when quota is absent)
        and ws_request.mux_extensions, and returns the response parameter.
        """
        # Mux extension cannot be used after extensions that depend on
        # frame boundary, extension data field, or any reserved bits
        # which are attributed to each frame.
        frame_dependent_extensions = (
            common.PERFRAME_COMPRESSION_EXTENSION,
            common.DEFLATE_FRAME_EXTENSION,
            common.X_WEBKIT_DEFLATE_FRAME_EXTENSION)
        for extension in logical_channel_extensions:
            if extension.name() in frame_dependent_extensions:
                return None

        quota = self._request.get_parameter_value(self._QUOTA_PARAM)
        if quota is None:
            ws_request.mux_quota = 0
        else:
            try:
                quota = int(quota)
            except ValueError:
                return None
            if quota < 0 or quota >= 2 ** 32:
                return None
            ws_request.mux_quota = quota

        ws_request.mux = True
        ws_request.mux_extensions = logical_channel_extensions
        return common.ExtensionParameter(common.MUX_EXTENSION)

    def setup_stream_options(self, stream_options):
        pass


_available_processors[common.MUX_EXTENSION] = MuxExtensionProcessor


def get_extension_processor(extension_request):
    """Returns a processor for extension_request, or None if no processor
    is registered under extension_request.name().
    """
    processor_class = _available_processors.get(extension_request.name())
    if processor_class is None:
        return None
    return processor_class(extension_request)


# vi:sts=4 sw=4 et