# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import websockets

from .common import Transport, ParserSource, PumpedPacketSink

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
async def open_ws_server_transport(spec):
    '''
    Open a WebSocket server transport.
    The parameter string has this syntax:
    <local-host>:<local-port>
    Where <local-host> may be the address of a local network interface, or '_'
    to accept connections on all local network interfaces.

    Example: _:9001

    The spec is split on its last ':' so that a <local-host> that itself
    contains colons is kept intact.

    Returns the transport once the WebSocket server is listening. Raises
    ValueError if the spec has no ':' separator or the port is not an integer.
    '''

    class WsServerTransport(Transport):
        '''
        Transport backed by a WebSocket server that serves one client at a time.

        `self.connection` is a future that resolves to the currently attached
        client connection; it is replaced by a fresh pending future whenever
        that client goes away, so that senders wait for the next client.
        '''

        def __init__(self):
            source = ParserSource()
            sink = PumpedPacketSink(self.send_packet)
            self.connection = asyncio.get_running_loop().create_future()
            self.server = None  # Set by serve()

            super().__init__(source, sink)

        async def serve(self, local_host, local_port):
            '''Start the packet sink and begin listening for WebSocket clients.'''
            self.sink.start()
            # The handler is passed positionally: its keyword name differs
            # between websockets releases (`ws_handler` vs `handler`).
            self.server = await websockets.serve(
                self.on_connection,
                host=local_host if local_host != '_' else None,
                port=int(local_port),
            )
            logger.debug(f'websocket server ready on port {local_port}')

        async def on_connection(self, connection):
            '''
            Handle one client: feed its BINARY frames to the packet parser
            until the connection ends.
            '''
            logger.debug(
                f'new connection on {connection.local_address} '
                f'from {connection.remote_address}'
            )

            if self.connection.done():
                # A client is already attached. Resolving the future a second
                # time would raise InvalidStateError, so turn this client away;
                # returning from the handler ends its connection.
                logger.warning('rejecting connection: a client is already connected')
                return

            self.connection.set_result(connection)
            try:
                async for packet in connection:
                    if isinstance(packet, bytes):
                        self.source.parser.feed_data(packet)
                    else:
                        logger.warning('discarding packet: not a BINARY frame')
            except websockets.WebSocketException as error:
                logger.debug(f'exception while receiving packet: {error}')
            finally:
                # Wait for a new connection. This runs on every exit path so
                # that an unexpected error cannot leave a stale, resolved
                # future behind and break all subsequent connections.
                self.connection = asyncio.get_running_loop().create_future()

        async def send_packet(self, packet):
            '''Send a packet to the attached client, waiting for one if needed.'''
            connection = await self.connection
            return await connection.send(packet)

    local_host, local_port = spec.rsplit(':', 1)
    transport = WsServerTransport()
    await transport.serve(local_host, local_port)
    return transport