#!/usr/bin/env python # # Copyright 2012, Google Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above # copyright notice, this list of conditions and the following disclaimer # in the documentation and/or other materials provided with the # distribution. # * Neither the name of Google Inc. nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Tests for mux module.""" import Queue import logging import optparse import unittest import struct import sys import set_sys_path # Update sys.path to locate mod_pywebsocket module. from mod_pywebsocket import common from mod_pywebsocket import mux from mod_pywebsocket._stream_base import ConnectionTerminatedException from mod_pywebsocket._stream_hybi import Stream from mod_pywebsocket._stream_hybi import StreamOptions from mod_pywebsocket._stream_hybi import create_binary_frame from mod_pywebsocket._stream_hybi import parse_frame import mock class _OutgoingChannelData(object): def __init__(self): self.messages = [] self.control_messages = [] self.current_opcode = None self.pending_fragments = [] class _MockMuxConnection(mock.MockBlockingConn): """Mock class of mod_python connection for mux.""" def __init__(self): mock.MockBlockingConn.__init__(self) self._control_blocks = [] self._channel_data = {} self._current_opcode = None self._pending_fragments = [] def write(self, data): """Override MockBlockingConn.write.""" self._current_data = data self._position = 0 def _receive_bytes(length): if self._position + length > len(self._current_data): raise ConnectionTerminatedException( 'Failed to receive %d bytes from encapsulated ' 'frame' % length) data = self._current_data[self._position:self._position+length] self._position += length return data opcode, payload, fin, rsv1, rsv2, rsv3 = ( parse_frame(_receive_bytes, unmask_receive=False)) self._pending_fragments.append(payload) if self._current_opcode is None: if opcode == common.OPCODE_CONTINUATION: raise Exception('Sending invalid continuation opcode') self._current_opcode = opcode else: if opcode != common.OPCODE_CONTINUATION: raise Exception('Sending invalid opcode %d' % opcode) if not fin: return inner_frame_data = ''.join(self._pending_fragments) self._pending_fragments = [] self._current_opcode = None parser = mux._MuxFramePayloadParser(inner_frame_data) channel_id = parser.read_channel_id() if channel_id == mux._CONTROL_CHANNEL_ID: self._control_blocks.append(parser.remaining_data()) return if not channel_id in self._channel_data: self._channel_data[channel_id] = _OutgoingChannelData() channel_data = self._channel_data[channel_id] (inner_fin, inner_rsv1, inner_rsv2, inner_rsv3, inner_opcode, inner_payload) = parser.read_inner_frame() channel_data.pending_fragments.append(inner_payload) if channel_data.current_opcode is None: if inner_opcode == common.OPCODE_CONTINUATION: raise Exception('Sending invalid continuation opcode') channel_data.current_opcode = inner_opcode else: if inner_opcode != common.OPCODE_CONTINUATION: raise Exception('Sending invalid opcode %d' % inner_opcode) if not inner_fin: return message = ''.join(channel_data.pending_fragments) channel_data.pending_fragments = [] if (channel_data.current_opcode == common.OPCODE_TEXT or channel_data.current_opcode == common.OPCODE_BINARY): channel_data.messages.append(message) else: channel_data.control_messages.append( {'opcode': channel_data.current_opcode, 'message': message}) channel_data.current_opcode = None def get_written_control_blocks(self): return self._control_blocks def get_written_messages(self, channel_id): return self._channel_data[channel_id].messages def get_written_control_messages(self, channel_id): return self._channel_data[channel_id].control_messages class _ChannelEvent(object): """A structure that records channel events.""" def __init__(self): self.messages = [] self.exception = None self.client_initiated_closing = False class _MuxMockDispatcher(object): """Mock class of dispatch.Dispatcher for mux.""" def __init__(self): self.channel_events = {} def do_extra_handshake(self, request): pass def _do_echo(self, request, channel_events): while True: message = request.ws_stream.receive_message() if message == None: channel_events.client_initiated_closing = True return if message == 'Goodbye': return channel_events.messages.append(message) # echo back request.ws_stream.send_message(message) def _do_ping(self, request, channel_events): request.ws_stream.send_ping('Ping!') def transfer_data(self, request): self.channel_events[request.channel_id] = _ChannelEvent() try: # Note: more handler will be added. if request.uri.endswith('echo'): self._do_echo(request, self.channel_events[request.channel_id]) elif request.uri.endswith('ping'): self._do_ping(request, self.channel_events[request.channel_id]) else: raise ValueError('Cannot handle path %r' % request.path) if not request.server_terminated: request.ws_stream.close_connection() except ConnectionTerminatedException, e: self.channel_events[request.channel_id].exception = e except Exception, e: self.channel_events[request.channel_id].exception = e raise def _create_mock_request(): headers = {'Host': 'server.example.com', 'Upgrade': 'websocket', 'Connection': 'Upgrade', 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', 'Sec-WebSocket-Version': '13', 'Origin': 'http://example.com'} request = mock.MockRequest(uri='/echo', headers_in=headers, connection=_MockMuxConnection()) request.ws_stream = Stream(request, options=StreamOptions()) request.mux = True request.mux_extensions = [] request.mux_quota = 8 * 1024 return request def _create_add_channel_request_frame(channel_id, encoding, encoded_handshake): if encoding != 0 and encoding != 1: raise ValueError('Invalid encoding') block = mux._create_control_block_length_value( channel_id, mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, encoding, encoded_handshake) payload = mux._encode_channel_id(mux._CONTROL_CHANNEL_ID) + block return create_binary_frame(payload, mask=True) def _create_logical_frame(channel_id, message, opcode=common.OPCODE_BINARY, mask=True): bits = chr(0x80 | opcode) payload = mux._encode_channel_id(channel_id) + bits + message return create_binary_frame(payload, mask=mask) def _create_request_header(path='/echo'): return ( 'GET %s HTTP/1.1\r\n' 'Host: server.example.com\r\n' 'Upgrade: websocket\r\n' 'Connection: Upgrade\r\n' 'Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n' 'Sec-WebSocket-Version: 13\r\n' 'Origin: http://example.com\r\n' '\r\n') % path class MuxTest(unittest.TestCase): """A unittest for mux module.""" def test_channel_id_decode(self): data = '\x00\x01\xbf\xff\xdf\xff\xff\xff\xff\xff\xff' parser = mux._MuxFramePayloadParser(data) channel_id = parser.read_channel_id() self.assertEqual(0, channel_id) channel_id = parser.read_channel_id() self.assertEqual(1, channel_id) channel_id = parser.read_channel_id() self.assertEqual(2 ** 14 - 1, channel_id) channel_id = parser.read_channel_id() self.assertEqual(2 ** 21 - 1, channel_id) channel_id = parser.read_channel_id() self.assertEqual(2 ** 29 - 1, channel_id) self.assertEqual(len(data), parser._read_position) def test_channel_id_encode(self): encoded = mux._encode_channel_id(0) self.assertEqual('\x00', encoded) encoded = mux._encode_channel_id(2 ** 14 - 1) self.assertEqual('\xbf\xff', encoded) encoded = mux._encode_channel_id(2 ** 14) self.assertEqual('\xc0@\x00', encoded) encoded = mux._encode_channel_id(2 ** 21 - 1) self.assertEqual('\xdf\xff\xff', encoded) encoded = mux._encode_channel_id(2 ** 21) self.assertEqual('\xe0 \x00\x00', encoded) encoded = mux._encode_channel_id(2 ** 29 - 1) self.assertEqual('\xff\xff\xff\xff', encoded) # channel_id is too large self.assertRaises(ValueError, mux._encode_channel_id, 2 ** 29) def test_create_control_block_length_value(self): data = 'Hello, world!' block = mux._create_control_block_length_value( channel_id=1, opcode=mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, flags=0x7, value=data) expected = '\x1c\x01\x0dHello, world!' self.assertEqual(expected, block) data = 'a' * (2 ** 8) block = mux._create_control_block_length_value( channel_id=2, opcode=mux._MUX_OPCODE_ADD_CHANNEL_RESPONSE, flags=0x0, value=data) expected = '\x21\x02\x01\x00' + data self.assertEqual(expected, block) data = 'b' * (2 ** 16) block = mux._create_control_block_length_value( channel_id=3, opcode=mux._MUX_OPCODE_DROP_CHANNEL, flags=0x0, value=data) expected = '\x62\x03\x01\x00\x00' + data self.assertEqual(expected, block) def test_read_control_blocks(self): data = ('\x00\x01\00' '\x61\x02\x01\x00%s' '\x0a\x03\x01\x00\x00%s' '\x63\x04\x01\x00\x00\x00%s') % ( 'a' * 0x0100, 'b' * 0x010000, 'c' * 0x01000000) parser = mux._MuxFramePayloadParser(data) blocks = list(parser.read_control_blocks()) self.assertEqual(4, len(blocks)) self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, blocks[0].opcode) self.assertEqual(0, blocks[0].encoding) self.assertEqual(0, len(blocks[0].encoded_handshake)) self.assertEqual(mux._MUX_OPCODE_DROP_CHANNEL, blocks[1].opcode) self.assertEqual(0, blocks[1].mux_error) self.assertEqual(0x0100, len(blocks[1].reason)) self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, blocks[2].opcode) self.assertEqual(2, blocks[2].encoding) self.assertEqual(0x010000, len(blocks[2].encoded_handshake)) self.assertEqual(mux._MUX_OPCODE_DROP_CHANNEL, blocks[3].opcode) self.assertEqual(0, blocks[3].mux_error) self.assertEqual(0x01000000, len(blocks[3].reason)) self.assertEqual(len(data), parser._read_position) def test_create_add_channel_response(self): data = mux._create_add_channel_response(channel_id=1, encoded_handshake='FooBar', encoding=0, rejected=False) self.assertEqual('\x82\x0a\x00\x20\x01\x06FooBar', data) data = mux._create_add_channel_response(channel_id=2, encoded_handshake='Hello', encoding=1, rejected=True) self.assertEqual('\x82\x09\x00\x34\x02\x05Hello', data) def test_drop_channel(self): data = mux._create_drop_channel(channel_id=1, reason='', mux_error=False) self.assertEqual('\x82\x04\x00\x60\x01\x00', data) data = mux._create_drop_channel(channel_id=1, reason='error', mux_error=True) self.assertEqual('\x82\x09\x00\x70\x01\x05error', data) # reason must be empty if mux_error is False. self.assertRaises(ValueError, mux._create_drop_channel, 1, 'FooBar', False) def test_parse_request_text(self): request_text = _create_request_header() command, path, version, headers = mux._parse_request_text(request_text) self.assertEqual('GET', command) self.assertEqual('/echo', path) self.assertEqual('HTTP/1.1', version) self.assertEqual(6, len(headers)) self.assertEqual('server.example.com', headers['Host']) self.assertEqual('websocket', headers['Upgrade']) self.assertEqual('Upgrade', headers['Connection']) self.assertEqual('dGhlIHNhbXBsZSBub25jZQ==', headers['Sec-WebSocket-Key']) self.assertEqual('13', headers['Sec-WebSocket-Version']) self.assertEqual('http://example.com', headers['Origin']) class MuxHandlerTest(unittest.TestCase): def test_add_channel(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) flow_control = mux._create_flow_control(channel_id=2, replenished_quota=5, outer_frame_mask=True) request.connection.put_bytes(flow_control) encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=3, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) flow_control = mux._create_flow_control(channel_id=3, replenished_quota=5, outer_frame_mask=True) request.connection.put_bytes(flow_control) request.connection.put_bytes( _create_logical_frame(channel_id=2, message='Hello')) request.connection.put_bytes( _create_logical_frame(channel_id=3, message='World')) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) request.connection.put_bytes( _create_logical_frame(channel_id=2, message='Goodbye')) request.connection.put_bytes( _create_logical_frame(channel_id=3, message='Goodbye')) mux_handler.wait_until_done(timeout=2) self.assertEqual([], dispatcher.channel_events[1].messages) self.assertEqual(['Hello'], dispatcher.channel_events[2].messages) self.assertEqual(['World'], dispatcher.channel_events[3].messages) # Channel 2 messages = request.connection.get_written_messages(2) self.assertEqual(1, len(messages)) self.assertEqual('Hello', messages[0]) # Channel 3 messages = request.connection.get_written_messages(3) self.assertEqual(1, len(messages)) self.assertEqual('World', messages[0]) control_blocks = request.connection.get_written_control_blocks() # There should be 8 control blocks: # - 1 NewChannelSlot # - 2 AddChannelResponses for channel id 2 and 3 # - 6 FlowControls for channel id 1 (initialize), 'Hello', 'World', # and 3 'Goodbye's self.assertEqual(9, len(control_blocks)) def test_add_channel_incomplete_handshake(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) incomplete_encoded_handshake = 'GET /echo HTTP/1.1' add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=incomplete_encoded_handshake) request.connection.put_bytes(add_channel_request) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) mux_handler.wait_until_done(timeout=2) self.assertTrue(1 in dispatcher.channel_events) self.assertTrue(not 2 in dispatcher.channel_events) def test_add_channel_invalid_version_handshake(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = ( 'GET /echo HTTP/1.1\r\n' 'Host: example.com\r\n' 'Connection: Upgrade\r\n' 'Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n' 'Sec-WebSocket-Protocol: sample\r\n' 'Upgrade: WebSocket\r\n' 'Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n' 'Origin: http://example.com\r\n' '\r\n' '^n:ds[4U') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) mux_handler.wait_until_done(timeout=2) self.assertTrue(1 in dispatcher.channel_events) self.assertTrue(not 2 in dispatcher.channel_events) def test_receive_drop_channel(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) drop_channel = mux._create_drop_channel(channel_id=2, outer_frame_mask=True) request.connection.put_bytes(drop_channel) # Terminate implicitly opened channel. request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) mux_handler.wait_until_done(timeout=2) exception = dispatcher.channel_events[2].exception self.assertTrue(exception.__class__ == ConnectionTerminatedException) def test_receive_ping_frame(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) flow_control = mux._create_flow_control(channel_id=2, replenished_quota=12, outer_frame_mask=True) request.connection.put_bytes(flow_control) ping_frame = _create_logical_frame(channel_id=2, message='Hello World!', opcode=common.OPCODE_PING) request.connection.put_bytes(ping_frame) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) request.connection.put_bytes( _create_logical_frame(channel_id=2, message='Goodbye')) mux_handler.wait_until_done(timeout=2) messages = request.connection.get_written_control_messages(2) self.assertEqual(common.OPCODE_PONG, messages[0]['opcode']) self.assertEqual('Hello World!', messages[0]['message']) def test_send_ping(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = _create_request_header(path='/ping') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) flow_control = mux._create_flow_control(channel_id=2, replenished_quota=5, outer_frame_mask=True) request.connection.put_bytes(flow_control) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) mux_handler.wait_until_done(timeout=2) messages = request.connection.get_written_control_messages(2) self.assertEqual(common.OPCODE_PING, messages[0]['opcode']) self.assertEqual('Ping!', messages[0]['message']) def test_two_flow_control(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) # Replenish 5 bytes. flow_control = mux._create_flow_control(channel_id=2, replenished_quota=5, outer_frame_mask=True) request.connection.put_bytes(flow_control) # Send 10 bytes. The server will try echo back 10 bytes. request.connection.put_bytes( _create_logical_frame(channel_id=2, message='HelloWorld')) # Replenish 5 bytes again. flow_control = mux._create_flow_control(channel_id=2, replenished_quota=5, outer_frame_mask=True) request.connection.put_bytes(flow_control) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) request.connection.put_bytes( _create_logical_frame(channel_id=2, message='Goodbye')) mux_handler.wait_until_done(timeout=2) messages = request.connection.get_written_messages(2) self.assertEqual(['HelloWorld'], messages) def test_no_send_quota_on_server(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) request.connection.put_bytes( _create_logical_frame(channel_id=2, message='HelloWorld')) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) mux_handler.wait_until_done(timeout=1) # No message should be sent on channel 2. self.assertRaises(KeyError, request.connection.get_written_messages, 2) def test_quota_violation_by_client(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, 0) encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) request.connection.put_bytes( _create_logical_frame(channel_id=2, message='HelloWorld')) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) mux_handler.wait_until_done(timeout=2) control_blocks = request.connection.get_written_control_blocks() # The first block is FlowControl for channel id 1. # The next two blocks are NewChannelSlot and AddChannelResponse. # The 4th block or the last block should be DropChannels for channel 2. # (The order can be mixed up) # The remaining block should be FlowControl for 'Goodbye'. self.assertEqual(5, len(control_blocks)) expected_opcode_and_flag = ((mux._MUX_OPCODE_DROP_CHANNEL << 5) | (1 << 4)) self.assertTrue((expected_opcode_and_flag == (ord(control_blocks[3][0]) & 0xf0)) or (expected_opcode_and_flag == (ord(control_blocks[4][0]) & 0xf0))) def test_fragmented_control_message(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = _create_request_header(path='/ping') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) # Replenish total 5 bytes in 3 FlowControls. flow_control = mux._create_flow_control(channel_id=2, replenished_quota=1, outer_frame_mask=True) request.connection.put_bytes(flow_control) flow_control = mux._create_flow_control(channel_id=2, replenished_quota=2, outer_frame_mask=True) request.connection.put_bytes(flow_control) flow_control = mux._create_flow_control(channel_id=2, replenished_quota=2, outer_frame_mask=True) request.connection.put_bytes(flow_control) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) mux_handler.wait_until_done(timeout=2) messages = request.connection.get_written_control_messages(2) self.assertEqual(common.OPCODE_PING, messages[0]['opcode']) self.assertEqual('Ping!', messages[0]['message']) def test_channel_slot_violation_by_client(self): request = _create_mock_request() dispatcher = _MuxMockDispatcher() mux_handler = mux._MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(slots=1, send_quota=mux._INITIAL_QUOTA_FOR_CLIENT) encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=2, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) flow_control = mux._create_flow_control(channel_id=2, replenished_quota=10, outer_frame_mask=True) request.connection.put_bytes(flow_control) request.connection.put_bytes( _create_logical_frame(channel_id=2, message='Hello')) # This request should be rejected. encoded_handshake = _create_request_header(path='/echo') add_channel_request = _create_add_channel_request_frame( channel_id=3, encoding=0, encoded_handshake=encoded_handshake) request.connection.put_bytes(add_channel_request) flow_control = mux._create_flow_control(channel_id=3, replenished_quota=5, outer_frame_mask=True) request.connection.put_bytes(flow_control) request.connection.put_bytes( _create_logical_frame(channel_id=3, message='Hello')) request.connection.put_bytes( _create_logical_frame(channel_id=1, message='Goodbye')) request.connection.put_bytes( _create_logical_frame(channel_id=2, message='Goodbye')) mux_handler.wait_until_done(timeout=2) self.assertEqual(['Hello'], dispatcher.channel_events[2].messages) self.assertFalse(dispatcher.channel_events.has_key(3)) if __name__ == '__main__': unittest.main() # vi:sts=4 sw=4 et