# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ----------------------------------------------------------------------------
# Imports
# ----------------------------------------------------------------------------
import sys
import logging
import json
import asyncio
import argparse
import uuid
import os
from urllib.parse import urlparse
import websockets

from bumble.colors import color

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# ----------------------------------------------------------------------------
# Constants
# ----------------------------------------------------------------------------
DEFAULT_RELAY_PORT = 10723


# ----------------------------------------------------------------------------
# Utils
# ----------------------------------------------------------------------------
def error_to_json(error):
    """
    Serialize an error as a JSON object of the form {"error": <error>}.

    Exception instances are not JSON-serializable, so they are converted to
    their string form first; any other value is serialized as-is.
    """
    if isinstance(error, BaseException):
        error = str(error)
    return json.dumps({'error': error})


def error_to_result(error):
    """
    Format an error as an RPC result message: 'result:{"error": ...}'.
    """
    return f'result:{error_to_json(error)}'


async def broadcast_message(message, connections):
    """
    Send `message` to every object in `connections`, concurrently.

    Each element must provide an async `send_message(message)` method. Errors
    are not caught here; each connection's `send_message` is expected to
    handle its own transport failures.
    """
    tasks = [connection.send_message(message) for connection in connections]
    if tasks:
        await asyncio.gather(*tasks)
# ----------------------------------------------------------------------------
# Connection class
# ----------------------------------------------------------------------------
class Connection:
    """
    A Connection represents a client connected to the relay over a websocket
    """

    def __init__(self, room, websocket):
        # Room this connection belongs to ('' for controller connections, which
        # makes cleanup() a no-op).
        self.room = room
        self.websocket = websocket
        # Random initial address; clients can change it with /set-address.
        self.address = str(uuid.uuid4())

    async def send_message(self, message):
        """
        Send a message to the client.

        On a websocket error the failure is logged, the connection is removed
        from its room, and None is returned instead of raising.
        """
        try:
            logger.debug(color(f'->{self.address}: {message}', 'yellow'))
            return await self.websocket.send(message)
        except websockets.exceptions.WebSocketException as error:
            logger.info(f'! client "{self}" disconnected: {error}')
            await self.cleanup()

    async def send_error(self, error):
        """Send an error to the client as an RPC result message."""
        return await self.send_message(error_to_result(error))

    async def receive_message(self):
        """
        Wait for the next message from the client.

        Returns None (after removing the connection from its room) when the
        websocket fails or is closed.
        """
        try:
            message = await self.websocket.recv()
            logger.debug(color(f'<-{self.address}: {message}', 'blue'))
            return message
        except websockets.exceptions.WebSocketException as error:
            logger.info(color(f'! client "{self}" disconnected: {error}', 'red'))
            await self.cleanup()

    async def cleanup(self):
        """Remove this connection from its room, if it has one."""
        if self.room:
            await self.room.remove_connection(self)

    def set_address(self, address):
        logger.info(f'Connection address changed: {self.address} -> {address}')
        self.address = address

    def __str__(self):
        # remote_address can be None when the underlying transport is gone;
        # don't let a log statement crash in that case.
        remote_address = self.websocket.remote_address
        if remote_address is None:
            client = 'unknown'
        else:
            client = f'{remote_address[0]}:{remote_address[1]}'
        return f'Connection(address="{self.address}", client={client})'


# ----------------------------------------------------------------------------
# Room class
# ----------------------------------------------------------------------------
class Room:
    """
    A Room is a collection of bridged connections
    """

    def __init__(self, relay, name):
        self.relay = relay
        self.name = name
        self.observers = []  # not used by the code in this module
        self.connections = []

    async def add_connection(self, connection):
        """Add a connection and announce it to the other participants."""
        logger.info(f'New participant in {self.name}: {connection}')
        self.connections.append(connection)
        await self.broadcast_message(connection, f'joined:{connection.address}')

    async def remove_connection(self, connection):
        """
        Remove a connection (no-op if absent) and announce its departure.

        When the last participant leaves, the room is also dropped from the
        relay so that rooms do not accumulate forever.
        """
        if connection not in self.connections:
            return
        self.connections.remove(connection)
        if (
            not self.connections
            and self.relay is not None
            and self.relay.rooms.get(self.name) is self
        ):
            del self.relay.rooms[self.name]
        await self.broadcast_message(connection, f'left:{connection.address}')

    def find_connections_by_address(self, address):
        return [c for c in self.connections if c.address == address]

    async def bridge_connection(self, connection):
        """
        Read and dispatch messages from `connection` until it disconnects.

        '@...' messages are targeted messages and '/...' messages are RPC
        requests. Anything else, including binary frames, is answered with an
        error result.
        """
        while True:
            message = await connection.receive_message()

            # None means the websocket closed (the connection was cleaned up)
            if message is None:
                return

            if isinstance(message, str) and message.startswith('@'):
                await self.on_targeted_message(connection, message)
            elif isinstance(message, str) and message.startswith('/'):
                await self.on_rpc_request(connection, message)
            else:
                await connection.send_message(
                    error_to_result('error: invalid message')
                )

    async def broadcast_message(self, sender, message):
        '''
        Send to all connections in the room except back to the sender
        '''
        await broadcast_message(message, [c for c in self.connections if c != sender])

    async def on_rpc_request(self, connection, message):
        """
        Dispatch '/<command> [params]' to the matching on_<command>_command.

        Dashes in the command name map to underscores. The handler's return
        value (or 'result:{}' if it returns nothing) is sent back to the caller.
        """
        command, *params = message.split(' ', 1)
        handler_name = f'on_{command[1:].lower().replace("-", "_")}_command'
        if handler := getattr(self, handler_name, None):
            try:
                result = await handler(connection, params)
            except Exception as error:  # report any handler failure to the caller
                logger.warning(f'RPC command {command} failed: {error}')
                # Exceptions are not JSON-serializable: send their text instead
                result = error_to_result(str(error))
        else:
            result = error_to_result('unknown command')

        await connection.send_message(result or 'result:{}')

    async def on_targeted_message(self, connection, message):
        """
        Forward '@<target> <payload>' as 'message:<sender address>/<payload>'.

        A target of '*' means every other connection in the room; otherwise all
        connections whose address equals <target> receive it, and the sender is
        told 'unreachable:<target>' when there are none.
        """
        target, *payload = message.split(' ', 1)
        if not payload:
            # Let the sender know instead of silently dropping the message
            await connection.send_message(error_to_result('missing arguments'))
            return
        payload = payload[0]
        target = target[1:]

        # Determine what targets to send to
        if target == '*':
            # Send to all connections in the room except the connection from which the
            # message was received
            connections = [c for c in self.connections if c != connection]
        else:
            connections = self.find_connections_by_address(target)
            if not connections:
                # Unicast with no recipient, let the sender know
                await connection.send_message(f'unreachable:{target}')

        # Send to targets
        await broadcast_message(f'message:{connection.address}/{payload}', connections)

    async def on_set_address_command(self, connection, params):
        """Handle '/set-address <address>' and announce the change."""
        if not params or not params[0]:
            return error_to_result('missing address')

        current_address = connection.address
        new_address = params[0]
        connection.set_address(new_address)
        await self.broadcast_message(
            connection, f'address-changed:from={current_address},to={new_address}'
        )


# ----------------------------------------------------------------------------
class Relay:
    """
    A relay accepts connections with the following url: ws://<hostname>/<room>.
    Participants in a room can communicate with each other
    """

    def __init__(self, port):
        self.port = port
        self.rooms = {}  # room name -> Room
        self.observers = []  # not used by the code in this module

    def start(self):
        """Return the websockets server object, to be awaited by the caller."""
        logger.info(f'Starting Relay on port {self.port}')

        # pylint: disable-next=no-member
        return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)

    async def serve_as_controller(self, connection):
        # Not implemented: controller connections are closed right away
        pass

    async def serve(self, websocket, path=None):
        """
        Websocket handler: join the room named by the first path segment.

        `path` is optional because websockets >= 10.1 calls single-argument
        handlers without it; the path is then read from the connection.
        """
        if path is None:
            # New asyncio server: websocket.request.path; legacy: websocket.path
            request = getattr(websocket, 'request', None)
            path = getattr(request, 'path', None) or getattr(websocket, 'path', '/')

        logger.debug(f'New connection with path {path}')

        # Parse the path
        parsed = urlparse(path)

        # Check if this is a controller client
        if parsed.path == '/':
            return await self.serve_as_controller(Connection('', websocket))

        # Find or create a room for this connection
        room_name = parsed.path[1:].split('/')[0]
        if room_name not in self.rooms:
            self.rooms[room_name] = Room(self, room_name)
        room = self.rooms[room_name]

        # Add the connection to the room
        connection = Connection(room, websocket)
        await room.add_connection(connection)

        # Bridge until the connection is closed; always leave the room, even if
        # bridging ends with an exception or cancellation.
        try:
            await room.bridge_connection(connection)
        finally:
            await connection.cleanup()


# ----------------------------------------------------------------------------
def main():
    # The module uses assignment expressions (:=), which need Python 3.8
    if sys.version_info < (3, 8):
        print('ERROR: Python 3.8 or higher is required')
        sys.exit(1)

    # Parse arguments
    arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
    arg_parser.add_argument(
        '--log-level',
        default=os.environ.get('BUMBLE_LOGLEVEL', 'INFO'),
        help='logger level (default: $BUMBLE_LOGLEVEL, or INFO)',
    )
    arg_parser.add_argument(
        '--log-config',
        help='logger config file (logging.config.fileConfig INI format)',
    )
    arg_parser.add_argument(
        '--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
    )
    args = arg_parser.parse_args()

    # Setup logger (configured once, so --log-level is actually honored)
    if args.log_config:
        from logging import config  # pylint: disable=import-outside-toplevel

        # Keep this module's logger (created at import time) enabled
        config.fileConfig(args.log_config, disable_existing_loggers=False)
    else:
        logging.basicConfig(level=args.log_level.upper())

    # Start a relay on an explicitly created event loop
    # (asyncio.get_event_loop() without a running loop is deprecated).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    relay = Relay(args.port)
    loop.run_until_complete(relay.start())
    loop.run_forever()


# ----------------------------------------------------------------------------
if __name__ == '__main__':
    main()