1# Copyright 2011, Google Inc. 2# All rights reserved. 3# 4# Redistribution and use in source and binary forms, with or without 5# modification, are permitted provided that the following conditions are 6# met: 7# 8# * Redistributions of source code must retain the above copyright 9# notice, this list of conditions and the following disclaimer. 10# * Redistributions in binary form must reproduce the above 11# copyright notice, this list of conditions and the following disclaimer 12# in the documentation and/or other materials provided with the 13# distribution. 14# * Neither the name of Google Inc. nor the names of its 15# contributors may be used to endorse or promote products derived from 16# this software without specific prior written permission. 17# 18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 30 31"""WebSocket utilities. 32""" 33 34 35import array 36import errno 37 38# Import hash classes from a module available and recommended for each Python 39# version and re-export those symbol. Use sha and md5 module in Python 2.4, and 40# hashlib module in Python 2.6. 
import array
import errno

# Import hash classes from a module available and recommended for each Python
# version and re-export those symbols. Use the sha and md5 modules in Python
# 2.4, and the hashlib module in Python 2.6.
try:
    import hashlib
    md5_hash = hashlib.md5
    sha1_hash = hashlib.sha1
except ImportError:
    import md5
    import sha
    md5_hash = md5.md5
    sha1_hash = sha.sha

import StringIO
import logging
import os
import re
import socket
import traceback
import zlib


def get_stack_trace():
    """Returns the traceback of the exception currently being handled as a
    string (the text traceback.print_exc would print).

    This is needed to support Python 2.3.
    TODO: Remove this when we only support Python 2.4 and above.
          Use traceback.format_exc instead.
    """

    out = StringIO.StringIO()
    traceback.print_exc(file=out)
    return out.getvalue()


def prepend_message_to_exception(message, exc):
    """Prepend message to the exception.

    exc is modified in place: its args become a one-element tuple holding
    message followed by str(exc). Nothing is returned.
    """

    exc.args = (message + str(exc),)


def __translate_interp(interp, cygwin_path):
    """Translate interp program path for Win32 python to run cygwin program
    (e.g. perl). Note that it doesn't support paths that contain spaces,
    which is typically true for Unix, where #!-scripts are written.
    For Win32 python, cygwin_path is a directory of cygwin binaries.

    The directory part of the interpreter path is replaced with cygwin_path,
    e.g. '/usr/bin/perl -wT' becomes os.path.join(cygwin_path, 'perl') +
    ' -wT', and '/usr/bin/perl' becomes os.path.join(cygwin_path, 'perl').

    Args:
      interp: interp command line
      cygwin_path: directory name of cygwin binary, or None
    Returns:
      translated interp command line. interp is returned unchanged when
      cygwin_path is empty or when the command has no directory part that
      the pattern below recognizes.
    """
    if not cygwin_path:
        return interp
    m = re.match('^[^ ]*/([^ ]+)( .*)?', interp)
    if m:
        cmd = os.path.join(cygwin_path, m.group(1))
        # The argument group is optional and is None when the interpreter
        # has no arguments (e.g. "#!/usr/bin/perl"); concatenating None
        # would raise TypeError.
        return cmd + (m.group(2) or '')
    return interp


def get_script_interp(script_path, cygwin_path=None):
    """Gets #!-interpreter command line from the script.

    It also fixes command path. When Cygwin Python is used, e.g. in WebKit,
    it could run "/usr/bin/perl -wT hello.pl".
    When Win32 Python is used, e.g. in Chromium, it couldn't. So, fix
    "/usr/bin/perl" to os.path.join(cygwin_path, "perl").

    Only the first line of the script is read. Its line ending (LF or CRLF)
    is removed before parsing, so scripts saved with Windows line endings
    do not leak a carriage return into the interpreter command line.

    Args:
      script_path: pathname of the script
      cygwin_path: directory name of cygwin binary, or None
    Returns:
      #!-interpreter command line, or None if it is not #!-script.
    Raises:
      IOError: if script_path cannot be opened or read.
    """
    fp = open(script_path)
    try:
        line = fp.readline()
    finally:
        # Close the file even if readline fails.
        fp.close()
    m = re.match('^#!(.*)', line.rstrip('\r\n'))
    if m:
        return __translate_interp(m.group(1), cygwin_path)
    return None


def wrap_popen3_for_win(cygwin_path):
    """Wrap popen3 to support #!-script on Windows.

    Replaces os.popen3 with a wrapper. The wrapper treats the first
    space-separated word of the command as a script path. If that script
    starts with a #! line, the interpreter command line (see
    get_script_interp) is prepended to the command. The wrapper then calls
    the original os.popen3.

    NOTE(review): get_script_interp opens the first word of the command, so a
    command whose first word is not a readable file raises IOError.

    Args:
      cygwin_path: path for cygwin binary if command path is needed to be
          translated. None if no translation required.
    """

    __orig_popen3 = os.popen3

    def __wrap_popen3(cmd, mode='t', bufsize=-1):
        cmdline = cmd.split(' ')
        interp = get_script_interp(cmdline[0], cygwin_path)
        if interp:
            cmd = interp + ' ' + cmd
        return __orig_popen3(cmd, mode, bufsize)

    os.popen3 = __wrap_popen3


def hexify(s):
    """Returns the octets of s as space-separated two-digit lowercase hex
    values; e.g. the two octets 0x01 0xff give '01 ff'.
    """
    return ' '.join(['%02x' % ord(c) for c in s])


def get_class_logger(o):
    """Returns the logger named '<module>.<ClassName>' for the class of o."""
    return logging.getLogger(
        '%s.%s' % (o.__class__.__module__, o.__class__.__name__))


class NoopMasker(object):
    """A masking object that has the same interface as RepeatedXorMasker but
    just returns the string passed in without making any change.
    """

    def __init__(self):
        pass

    def mask(self, s):
        return s


class RepeatedXorMasker(object):
    """A masking object that applies XOR on the string given to mask method
    with the masking bytes given to the constructor repeatedly. This object
    remembers the position in the masking bytes the last mask method call
    ended and resumes from that point on the next mask method call.
    """

    def __init__(self, mask):
        # Masking octets as a list of ints.
        self._mask = map(ord, mask)
        self._mask_size = len(self._mask)
        # Position in self._mask where the next mask() call resumes.
        self._count = 0

    def mask(self, s):
        """Returns s with each octet XOR-ed with the next masking octet."""
        result = array.array('B')
        result.fromstring(s)
        # Use temporary local variables to eliminate the cost to access
        # attributes
        count = self._count
        mask = self._mask
        mask_size = self._mask_size
        for i in xrange(len(result)):
            result[i] ^= mask[count]
            count = (count + 1) % mask_size
        self._count = count

        return result.tostring()


class DeflateRequest(object):
    """A wrapper class for request object to intercept send and recv to perform
    deflate compression and decompression transparently.

    The connection attribute is a DeflateConnection that wraps
    request.connection. All other attribute reads and writes are forwarded to
    the wrapped request.
    """

    def __init__(self, request):
        self._request = request
        self.connection = DeflateConnection(request.connection)

    def __getattribute__(self, name):
        if name in ('_request', 'connection'):
            return object.__getattribute__(self, name)
        return self._request.__getattribute__(name)

    def __setattr__(self, name, value):
        if name in ('_request', 'connection'):
            return object.__setattr__(self, name, value)
        return self._request.__setattr__(name, value)


# By making wbits option negative, we can suppress CMF/FLG (2 octet) and
# ADLER32 (4 octet) fields of zlib so that we can use zlib module just as
# deflate library. DICTID won't be added as far as we don't set dictionary.
# LZ77 window of 32K will be used for both compression and decompression.
# For decompression, we can just use 32K to cover any windows size. For
# compression, we use 32K so receivers must use 32K.
#
# Compression level is Z_DEFAULT_COMPRESSION. We don't have to match level
# to decode.
#
# See zconf.h, deflate.cc, inflate.cc of zlib library, and zlibmodule.c of
# Python. See also RFC1950 (ZLIB 3.3).
225 226 227class _Deflater(object): 228 229 def __init__(self, window_bits): 230 self._logger = get_class_logger(self) 231 232 self._compress = zlib.compressobj( 233 zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits) 234 235 def compress_and_flush(self, bytes): 236 compressed_bytes = self._compress.compress(bytes) 237 compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH) 238 self._logger.debug('Compress input %r', bytes) 239 self._logger.debug('Compress result %r', compressed_bytes) 240 return compressed_bytes 241 242 243class _Inflater(object): 244 245 def __init__(self): 246 self._logger = get_class_logger(self) 247 248 self._unconsumed = '' 249 250 self.reset() 251 252 def decompress(self, size): 253 if not (size == -1 or size > 0): 254 raise Exception('size must be -1 or positive') 255 256 data = '' 257 258 while True: 259 if size == -1: 260 data += self._decompress.decompress(self._unconsumed) 261 # See Python bug http://bugs.python.org/issue12050 to 262 # understand why the same code cannot be used for updating 263 # self._unconsumed for here and else block. 264 self._unconsumed = '' 265 else: 266 data += self._decompress.decompress( 267 self._unconsumed, size - len(data)) 268 self._unconsumed = self._decompress.unconsumed_tail 269 if self._decompress.unused_data: 270 # Encountered a last block (i.e. a block with BFINAL = 1) and 271 # found a new stream (unused_data). We cannot use the same 272 # zlib.Decompress object for the new stream. Create a new 273 # Decompress object to decompress the new one. 274 # 275 # It's fine to ignore unconsumed_tail if unused_data is not 276 # empty. 277 self._unconsumed = self._decompress.unused_data 278 self.reset() 279 if size >= 0 and len(data) == size: 280 # data is filled. Don't call decompress again. 281 break 282 else: 283 # Re-invoke Decompress.decompress to try to decompress all 284 # available bytes before invoking read which blocks until 285 # any new byte is available. 
286 continue 287 else: 288 # Here, since unused_data is empty, even if unconsumed_tail is 289 # not empty, bytes of requested length are already in data. We 290 # don't have to "continue" here. 291 break 292 293 if data: 294 self._logger.debug('Decompressed %r', data) 295 return data 296 297 def append(self, data): 298 self._logger.debug('Appended %r', data) 299 self._unconsumed += data 300 301 def reset(self): 302 self._logger.debug('Reset') 303 self._decompress = zlib.decompressobj(-zlib.MAX_WBITS) 304 305 306# Compresses/decompresses given octets using the method introduced in RFC1979. 307 308 309class _RFC1979Deflater(object): 310 """A compressor class that applies DEFLATE to given byte sequence and 311 flushes using the algorithm described in the RFC1979 section 2.1. 312 """ 313 314 def __init__(self, window_bits, no_context_takeover): 315 self._deflater = None 316 if window_bits is None: 317 window_bits = zlib.MAX_WBITS 318 self._window_bits = window_bits 319 self._no_context_takeover = no_context_takeover 320 321 def filter(self, bytes): 322 if self._deflater is None or self._no_context_takeover: 323 self._deflater = _Deflater(self._window_bits) 324 325 # Strip last 4 octets which is LEN and NLEN field of a non-compressed 326 # block added for Z_SYNC_FLUSH. 327 return self._deflater.compress_and_flush(bytes)[:-4] 328 329 330class _RFC1979Inflater(object): 331 """A decompressor class for byte sequence compressed and flushed following 332 the algorithm described in the RFC1979 section 2.1. 333 """ 334 335 def __init__(self): 336 self._inflater = _Inflater() 337 338 def filter(self, bytes): 339 # Restore stripped LEN and NLEN field of a non-compressed block added 340 # for Z_SYNC_FLUSH. 341 self._inflater.append(bytes + '\x00\x00\xff\xff') 342 return self._inflater.decompress(-1) 343 344 345class DeflateSocket(object): 346 """A wrapper class for socket object to intercept send and recv to perform 347 deflate compression and decompression transparently. 
348 """ 349 350 # Size of the buffer passed to recv to receive compressed data. 351 _RECV_SIZE = 4096 352 353 def __init__(self, socket): 354 self._socket = socket 355 356 self._logger = get_class_logger(self) 357 358 self._deflater = _Deflater(zlib.MAX_WBITS) 359 self._inflater = _Inflater() 360 361 def recv(self, size): 362 """Receives data from the socket specified on the construction up 363 to the specified size. Once any data is available, returns it even 364 if it's smaller than the specified size. 365 """ 366 367 # TODO(tyoshino): Allow call with size=0. It should block until any 368 # decompressed data is available. 369 if size <= 0: 370 raise Exception('Non-positive size passed') 371 while True: 372 data = self._inflater.decompress(size) 373 if len(data) != 0: 374 return data 375 376 read_data = self._socket.recv(DeflateSocket._RECV_SIZE) 377 if not read_data: 378 return '' 379 self._inflater.append(read_data) 380 381 def sendall(self, bytes): 382 self.send(bytes) 383 384 def send(self, bytes): 385 self._socket.sendall(self._deflater.compress_and_flush(bytes)) 386 return len(bytes) 387 388 389class DeflateConnection(object): 390 """A wrapper class for request object to intercept write and read to 391 perform deflate compression and decompression transparently. 392 """ 393 394 def __init__(self, connection): 395 self._connection = connection 396 397 self._logger = get_class_logger(self) 398 399 self._deflater = _Deflater(zlib.MAX_WBITS) 400 self._inflater = _Inflater() 401 402 def get_remote_addr(self): 403 return self._connection.remote_addr 404 remote_addr = property(get_remote_addr) 405 406 def put_bytes(self, bytes): 407 self.write(bytes) 408 409 def read(self, size=-1): 410 """Reads at most size bytes. Blocks until there's at least one byte 411 available. 412 """ 413 414 # TODO(tyoshino): Allow call with size=0. 
415 if not (size == -1 or size > 0): 416 raise Exception('size must be -1 or positive') 417 418 data = '' 419 while True: 420 if size == -1: 421 data += self._inflater.decompress(-1) 422 else: 423 data += self._inflater.decompress(size - len(data)) 424 425 if size >= 0 and len(data) != 0: 426 break 427 428 # TODO(tyoshino): Make this read efficient by some workaround. 429 # 430 # In 3.0.3 and prior of mod_python, read blocks until length bytes 431 # was read. We don't know the exact size to read while using 432 # deflate, so read byte-by-byte. 433 # 434 # _StandaloneRequest.read that ultimately performs 435 # socket._fileobject.read also blocks until length bytes was read 436 read_data = self._connection.read(1) 437 if not read_data: 438 break 439 self._inflater.append(read_data) 440 return data 441 442 def write(self, bytes): 443 self._connection.write(self._deflater.compress_and_flush(bytes)) 444 445 446def _is_ewouldblock_errno(error_number): 447 """Returns True iff error_number indicates that receive operation would 448 block. To make this portable, we check availability of errno and then 449 compare them. 450 """ 451 452 for error_name in ['WSAEWOULDBLOCK', 'EWOULDBLOCK', 'EAGAIN']: 453 if (error_name in dir(errno) and 454 error_number == getattr(errno, error_name)): 455 return True 456 return False 457 458 459def drain_received_data(raw_socket): 460 # Set the socket non-blocking. 461 original_timeout = raw_socket.gettimeout() 462 raw_socket.settimeout(0.0) 463 464 drained_data = [] 465 466 # Drain until the socket is closed or no data is immediately 467 # available for read. 468 while True: 469 try: 470 data = raw_socket.recv(1) 471 if not data: 472 break 473 drained_data.append(data) 474 except socket.error, e: 475 # e can be either a pair (errno, string) or just a string (or 476 # something else) telling what went wrong. We suppress only 477 # the errors that indicates that the socket blocks. Those 478 # exceptions can be parsed as a pair (errno, string). 
479 try: 480 error_number, message = e 481 except: 482 # Failed to parse socket.error. 483 raise e 484 485 if _is_ewouldblock_errno(error_number): 486 break 487 else: 488 raise e 489 490 # Rollback timeout value. 491 raw_socket.settimeout(original_timeout) 492 493 return ''.join(drained_data) 494 495 496# vi:sts=4 sw=4 et 497