• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# ----------------------------------------------------------------------------
16# Imports
17# ----------------------------------------------------------------------------
18import sys
19import logging
20import json
21import asyncio
22import argparse
23import uuid
24import os
25from urllib.parse import urlparse
26import websockets
27
28from bumble.colors import color
29
30# -----------------------------------------------------------------------------
31# Logging
32# -----------------------------------------------------------------------------
# Module-wide logger, named after this module
logger = logging.getLogger(name=__name__)
34
35
36# ----------------------------------------------------------------------------
37# Constants
38# ----------------------------------------------------------------------------
# TCP port the relay listens on when --port is not given on the command line
DEFAULT_RELAY_PORT: int = 10723
40
41
42# ----------------------------------------------------------------------------
43# Utils
44# ----------------------------------------------------------------------------
def error_to_json(error):
    """Serialize an error as a JSON object of the form ``{"error": ...}``.

    Exception instances are not JSON-serializable, so they are replaced by
    their string form. Any other value (typically a ``str``) is encoded as-is.

    :param error: the error description, or an exception instance.
    :return: the JSON-encoded string.
    """
    if isinstance(error, BaseException):
        # json.dumps() raises TypeError on exception objects, which would turn
        # an error report into a second, unhandled error.
        error = str(error)
    return json.dumps({'error': error})
47
48
def error_to_result(error):
    """Build a ``result:`` response line whose body is the JSON-encoded error."""
    payload = error_to_json(error)
    return 'result:' + payload
51
52
async def broadcast_message(message, connections):
    """Send ``message`` concurrently to every connection in ``connections``.

    Each connection's ``send_message`` coroutine is awaited through
    ``asyncio.gather``; nothing is scheduled when ``connections`` is empty.
    """
    coroutines = [target.send_message(message) for target in connections]
    if not coroutines:
        return
    await asyncio.gather(*coroutines)
58
59
60# ----------------------------------------------------------------------------
61# Connection class
62# ----------------------------------------------------------------------------
class Connection:
    """
    A Connection represents a client connected to the relay over a websocket.

    A connection is identified by an address, initially a random UUID string,
    which the client can change later (see ``set_address``).
    """

    def __init__(self, room, websocket):
        # Room this connection belongs to; a falsy value (e.g. '') means the
        # connection is not attached to any room.
        self.room = room
        self.websocket = websocket
        self.address = str(uuid.uuid4())

    async def send_message(self, message):
        """Send a text message to the client.

        If the websocket raises a ``WebSocketException``, the failure is logged
        and the connection is removed from its room instead of raising.
        """
        try:
            logger.debug(color(f'->{self.address}: {message}', 'yellow'))
            return await self.websocket.send(message)
        except websockets.exceptions.WebSocketException as error:
            logger.info(color(f'! client "{self}" disconnected: {error}', 'red'))
            await self.cleanup()

    async def send_error(self, error):
        """Send ``error`` to the client as a ``result:{"error": ...}`` line."""
        return await self.send_message(error_to_result(error))

    async def receive_message(self):
        """Wait for and return the next message from the client.

        Returns ``None`` (after removing the connection from its room) when the
        websocket raises a ``WebSocketException``, e.g. on disconnection.
        """
        try:
            message = await self.websocket.recv()
            logger.debug(color(f'<-{self.address}: {message}', 'blue'))
            return message
        except websockets.exceptions.WebSocketException as error:
            logger.info(color(f'! client "{self}" disconnected: {error}', 'red'))
            await self.cleanup()
            return None

    async def cleanup(self):
        """Detach this connection from its room, if it has one."""
        if self.room:
            await self.room.remove_connection(self)

    def set_address(self, address):
        """Change the address other participants use to reach this connection."""
        logger.info(f'Connection address changed: {self.address} -> {address}')
        self.address = address

    def __str__(self):
        remote = self.websocket.remote_address
        if isinstance(remote, (tuple, list)) and len(remote) >= 2:
            client = f'{remote[0]}:{remote[1]}'
        else:
            # remote_address may be None (e.g. socket not open) or not a
            # (host, port) tuple; __str__ is used while logging disconnections,
            # so it must not raise here.
            client = str(remote)
        return f'Connection(address="{self.address}", client={client})'
107
108
109# ----------------------------------------------------------------------------
110# Room class
111# ----------------------------------------------------------------------------
class Room:
    """
    A Room is a collection of bridged connections.

    Participants send either targeted messages (``@<target> <payload>``, where
    ``<target>`` is a connection address or ``*`` for everyone else in the
    room) or RPC requests (``/<command> [params]``), which are dispatched to a
    method named ``on_<command>_command``.
    """

    def __init__(self, relay, name):
        self.relay = relay
        self.name = name
        self.observers = []
        self.connections = []

    async def add_connection(self, connection):
        """Add ``connection`` and announce it to the other participants."""
        logger.info(f'New participant in {self.name}: {connection}')
        self.connections.append(connection)
        await self.broadcast_message(connection, f'joined:{connection.address}')

    async def remove_connection(self, connection):
        """Remove ``connection`` (if present) and announce its departure.

        When the last participant leaves, the room is also dropped from its
        relay's room table so abandoned room names do not accumulate.
        """
        if connection not in self.connections:
            return
        self.connections.remove(connection)

        rooms = getattr(self.relay, 'rooms', None)
        if not self.connections and rooms is not None and rooms.get(self.name) is self:
            del rooms[self.name]

        await self.broadcast_message(connection, f'left:{connection.address}')

    def find_connections_by_address(self, address):
        """Return the connections in this room whose address is ``address``."""
        return [c for c in self.connections if c.address == address]

    async def bridge_connection(self, connection):
        """Process messages from ``connection`` until it disconnects."""
        while True:
            # Wait for a message; None means the client went away
            message = await connection.receive_message()
            if message is None:
                return

            # The protocol is text-only. Binary frames arrive as bytes, and
            # bytes.startswith('@') would raise TypeError and kill this loop.
            if not isinstance(message, str):
                await connection.send_message(
                    error_to_result('error: invalid message')
                )
                continue

            if message.startswith('@'):
                # This is a targeted message
                await self.on_targeted_message(connection, message)
            elif message.startswith('/'):
                # This is an RPC request
                await self.on_rpc_request(connection, message)
            else:
                await connection.send_message(
                    error_to_result('error: invalid message')
                )

    async def broadcast_message(self, sender, message):
        '''
        Send to all connections in the room except back to the sender
        '''
        await broadcast_message(message, [c for c in self.connections if c != sender])

    async def on_rpc_request(self, connection, message):
        """Dispatch ``/<command> [params]`` to ``on_<command>_command`` and reply.

        The command name is lower-cased and dashes become underscores
        (``/set-address`` -> ``on_set_address_command``). The handler's result,
        or an error result, is sent back; a falsy result is sent as ``result:{}``.
        """
        command, *params = message.split(' ', 1)
        handler_name = f'on_{command[1:].lower().replace("-", "_")}_command'
        handler = getattr(self, handler_name, None)
        if handler is None:
            result = error_to_result('unknown command')
        else:
            try:
                result = await handler(connection, params)
            except Exception as error:  # pylint: disable=broad-except
                # RPC boundary: report the failure to the caller instead of
                # tearing down the connection. Exception objects are not JSON
                # serializable, so send their string form.
                logger.warning(f'RPC {command} failed: {error}')
                result = error_to_result(str(error))

        await connection.send_message(result or 'result:{}')

    async def on_targeted_message(self, connection, message):
        """Relay ``@<target> <payload>`` as ``message:<sender>/<payload>``.

        ``<target>`` is either ``*`` (every other connection in the room) or a
        connection address. If no connection has that address, the sender is
        sent ``unreachable:<target>``.
        """
        target, *payload = message.split(' ', 1)
        if not payload:
            # Tell the sender: bridge_connection ignores return values, so only
            # returning the error would silently drop it.
            error = error_to_json('missing arguments')
            await connection.send_message(f'result:{error}')
            return error
        payload = payload[0]
        target = target[1:]

        # Determine what targets to send to
        if target == '*':
            # Everyone in the room except the connection the message came from
            connections = [c for c in self.connections if c != connection]
        else:
            connections = self.find_connections_by_address(target)
            if not connections:
                # Unicast with no recipient, let the sender know
                await connection.send_message(f'unreachable:{target}')
                return None

        # Send to targets
        await broadcast_message(f'message:{connection.address}/{payload}', connections)
        return None

    async def on_set_address_command(self, connection, params):
        """Handle ``/set-address <address>``: rename ``connection``, notify peers.

        NOTE(review): the new address is not validated (it may be '*', contain
        spaces, or collide with another participant) — confirm this is intended.
        """
        if not params:
            return error_to_result('missing address')

        current_address = connection.address
        new_address = params[0]
        connection.set_address(new_address)
        await self.broadcast_message(
            connection, f'address-changed:from={current_address},to={new_address}'
        )
        return None
208
209
210# ----------------------------------------------------------------------------
class Relay:
    """
    A relay accepts connections with the following url: ws://<hostname>/<room>.
    Participants in a room can communicate with each other.

    A connection to the root path ``/`` is handled as a controller client
    instead of joining a room.
    """

    def __init__(self, port):
        self.port = port
        # Rooms keyed by room name
        self.rooms = {}
        self.observers = []

    def start(self):
        """Create the websockets server for this relay, bound to all interfaces.

        The caller awaits the returned ``websockets.serve(...)`` object (or uses
        it as an async context manager) to start listening.
        """
        logger.info(f'Starting Relay on port {self.port}')

        # pylint: disable-next=no-member
        return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)

    async def serve_as_controller(self, connection):
        """Handle a controller client; currently a placeholder that does nothing."""

    async def serve(self, websocket, path=None):
        """Handle one websocket client for its whole lifetime.

        ``path`` is optional so that both the two-argument ``(websocket, path)``
        and the one-argument ``(websocket)`` handler conventions of the
        websockets library are accepted.
        """
        if path is None:
            # One-argument convention: the request path lives on the
            # connection (``request.path`` on newer implementations,
            # ``path`` on the legacy one).
            request = getattr(websocket, 'request', None)
            path = request.path if request is not None else websocket.path

        logger.debug(f'New connection with path {path}')

        # Parse the path
        parsed = urlparse(path)

        # Check if this is a controller client
        if parsed.path == '/':
            return await self.serve_as_controller(Connection('', websocket))

        # Find or create a room for this connection
        room_name = parsed.path[1:].split('/')[0]
        room = self.rooms.get(room_name)
        if room is None:
            room = self.rooms[room_name] = Room(self, room_name)

        connection = Connection(room, websocket)
        try:
            # Add the connection to the room, then bridge until it is closed
            await room.add_connection(connection)
            await room.bridge_connection(connection)
        finally:
            # Always leave the room, however bridging ended (normal close,
            # unexpected exception, cancellation); removal is idempotent.
            await connection.cleanup()
        return None
253
254
255# ----------------------------------------------------------------------------
def main():
    """Command-line entry point: configure logging and run the relay forever."""
    # Check the Python version. This module uses assignment expressions
    # (``:=``), which require Python 3.8.
    if sys.version_info < (3, 8):
        print('ERROR: Python 3.8 or higher is required')
        sys.exit(1)

    # Parse arguments
    arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
    arg_parser.add_argument(
        '--log-level',
        help='logger level (defaults to $BUMBLE_LOGLEVEL, or INFO if unset)',
    )
    arg_parser.add_argument(
        '--log-config',
        help='logger config file (logging.config.fileConfig / INI format)',
    )
    arg_parser.add_argument(
        '--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
    )
    args = arg_parser.parse_args()

    # Configure logging exactly once, after parsing: logging.basicConfig() is a
    # no-op when the root logger already has handlers, so configuring it
    # earlier would make --log-level ineffective.
    if args.log_config:
        from logging import config  # pylint: disable=import-outside-toplevel

        config.fileConfig(args.log_config)
    else:
        log_level = args.log_level or os.environ.get('BUMBLE_LOGLEVEL', 'INFO')
        logging.basicConfig(level=log_level.upper())

    # Start a relay and serve until the process is interrupted
    relay = Relay(args.port)

    async def run_relay():
        # The server is created inside the running loop and closed on exit
        async with relay.start():
            await asyncio.get_running_loop().create_future()  # never completes

    asyncio.run(run_relay())
285
286
287# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Executed as a script (not imported): run the relay.
    main()
290