# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import random
import pytest

from bumble.controller import Controller
from bumble.link import LocalLink
from bumble.device import Device
from bumble.host import Host
from bumble.transport import AsyncPipeSink
from bumble.core import ProtocolError
from bumble.l2cap import L2CAP_Connection_Request


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
class TwoDevices:
    """Two bumble devices, each with its own controller, on one shared link.

    ``connections[i]`` and ``paired[i]`` record what the ``on_connection`` /
    ``on_paired`` callbacks reported for device *i* (``None`` until then).
    """

    def __init__(self):
        self.connections = [None, None]

        # Both controllers are attached to the same LocalLink so the two
        # devices can reach each other in-process.
        self.link = LocalLink()
        self.controllers = [
            Controller('C1', link=self.link),
            Controller('C2', link=self.link),
        ]
        self.devices = [
            Device(
                address='F0:F1:F2:F3:F4:F5',
                host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
            ),
            Device(
                address='F5:F4:F3:F2:F1:F0',
                host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
            ),
        ]

        self.paired = [None, None]

    def on_connection(self, which, connection):
        """Record the connection object reported to device index ``which``."""
        self.connections[which] = connection

    def on_paired(self, which, keys):
        """Record the pairing keys reported to device index ``which``."""
        self.paired[which] = keys


# -----------------------------------------------------------------------------
async def setup_connection():
    """Create a TwoDevices fixture, power both devices on and connect them.

    Device 0 initiates the connection to device 1's random address. Both
    sides are asserted to have seen the connection before returning.

    Returns:
        The connected ``TwoDevices`` instance.
    """
    # Create two devices, each with a controller, attached to the same link
    two_devices = TwoDevices()

    # Attach listeners
    two_devices.devices[0].on(
        'connection', lambda connection: two_devices.on_connection(0, connection)
    )
    two_devices.devices[1].on(
        'connection', lambda connection: two_devices.on_connection(1, connection)
    )

    # Start
    await two_devices.devices[0].power_on()
    await two_devices.devices[1].power_on()

    # Connect the two devices
    await two_devices.devices[0].connect(two_devices.devices[1].random_address)

    # Check the post conditions
    assert two_devices.connections[0] is not None
    assert two_devices.connections[1] is not None

    return two_devices


# -----------------------------------------------------------------------------
def test_helpers():
    """Check PSM (de)serialization and an L2CAP_Connection_Request round trip.

    The expectations below show PSMs are encoded little-endian with a
    variable length (2 bytes for 0x01 and 0x1023, 3 bytes for 0x242311), and
    that ``parse_psm(data, offset)`` returns ``(offset_after_psm, psm)``.
    """
    psm = L2CAP_Connection_Request.serialize_psm(0x01)
    assert psm == bytes([0x01, 0x00])

    psm = L2CAP_Connection_Request.serialize_psm(0x1023)
    assert psm == bytes([0x23, 0x10])

    psm = L2CAP_Connection_Request.serialize_psm(0x242311)
    assert psm == bytes([0x11, 0x23, 0x24])

    (offset, psm) = L2CAP_Connection_Request.parse_psm(
        bytes([0x00, 0x01, 0x00, 0x44]), 1
    )
    assert offset == 3
    assert psm == 0x01

    (offset, psm) = L2CAP_Connection_Request.parse_psm(
        bytes([0x00, 0x23, 0x10, 0x44]), 1
    )
    assert offset == 3
    assert psm == 0x1023

    (offset, psm) = L2CAP_Connection_Request.parse_psm(
        bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1
    )
    assert offset == 4
    assert psm == 0x242311

    rq = L2CAP_Connection_Request(psm=0x01, source_cid=0x44)
    brq = bytes(rq)
    srq = L2CAP_Connection_Request.from_bytes(brq)
    assert srq.psm == rq.psm
    assert srq.source_cid == rq.source_cid


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
    """Open a channel, send data through it, then close it from the client.

    Also checks that opening a channel on a PSM with no registered server
    raises ``ProtocolError``.
    """
    devices = await setup_connection()
    psm = 1234

    # Check that if there's no one listening, we can't connect
    with pytest.raises(ProtocolError):
        await devices.connections[0].open_l2cap_channel(psm)

    # Now add a listener
    incoming_channel = None
    received = []

    def on_coc(channel):
        nonlocal incoming_channel
        incoming_channel = channel

        def on_data(data):
            received.append(data)

        channel.sink = on_data

    devices.devices[1].register_l2cap_channel_server(psm, on_coc)
    l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
    # The server callback must have run for the channel to be usable below.
    assert incoming_channel is not None

    messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000))
    for message in messages:
        l2cap_channel.write(message)
        await asyncio.sleep(0)

    await l2cap_channel.drain()

    # Test closing
    closed = [False, False]
    closed_event = asyncio.Event()

    def on_close(which, event):
        closed[which] = True
        if event:
            event.set()

    l2cap_channel.on('close', lambda: on_close(0, None))
    incoming_channel.on('close', lambda: on_close(1, closed_event))
    await l2cap_channel.disconnect()
    # Wait until the remote side has reported its close before checking both
    # flags; otherwise the assertion depends on the remote 'close' happening
    # to fire before disconnect() returns.
    await closed_event.wait()
    assert closed == [True, True]

    sent_bytes = b''.join(messages)
    received_bytes = b''.join(received)
    assert sent_bytes == received_bytes


# -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps):
    """Send several payloads over a channel with the given flow-control setup.

    Args:
        max_credits: ``max_credits`` passed to the channel server registration.
        mtu: ``mtu`` passed to the channel server registration.
        mps: ``mps`` passed to the channel server registration.

    Drains are randomly interleaved with writes (via ``random``) to vary the
    timing. Asserts the server received exactly the bytes that were sent.
    """
    devices = await setup_connection()

    received = []

    def on_coc(channel):
        def on_data(data):
            received.append(data)

        channel.sink = on_data

    # psm=0 here; the value actually returned by the registration is what the
    # client connects to (presumably a dynamically allocated PSM — confirm).
    psm = devices.devices[1].register_l2cap_channel_server(
        psm=0, server=on_coc, max_credits=max_credits, mtu=mtu, mps=mps
    )
    l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)

    messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)]
    for message in messages:
        l2cap_channel.write(message)
        await asyncio.sleep(0)
        if random.randint(0, 5) == 1:
            await l2cap_channel.drain()

    await l2cap_channel.drain()
    await l2cap_channel.disconnect()

    sent_bytes = b''.join(messages)
    received_bytes = b''.join(received)
    assert sent_bytes == received_bytes


@pytest.mark.asyncio
async def test_transfer():
    """Run transfer_payload over a grid of credit / MTU / MPS settings."""
    for max_credits in (1, 10, 100, 10000):
        for mtu in (50, 255, 256, 1000):
            for mps in (50, 255, 256, 1000):
                await transfer_payload(max_credits, mtu, mps)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bidirectional_transfer():
    """Echo each message client->server and server->client on one channel."""
    devices = await setup_connection()

    client_received = []
    server_received = []
    server_channel = None

    def on_server_coc(channel):
        nonlocal server_channel
        server_channel = channel

        def on_server_data(data):
            server_received.append(data)

        channel.sink = on_server_data

    def on_client_data(data):
        client_received.append(data)

    psm = devices.devices[1].register_l2cap_channel_server(psm=0, server=on_server_coc)
    client_channel = await devices.connections[0].open_l2cap_channel(psm)
    client_channel.sink = on_client_data
    # The server callback must have run for server_channel to be writable.
    assert server_channel is not None

    messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100)]
    for message in messages:
        client_channel.write(message)
        await client_channel.drain()
        await asyncio.sleep(0)
        server_channel.write(message)
        await server_channel.drain()

    await client_channel.disconnect()

    message_bytes = b''.join(messages)
    client_received_bytes = b''.join(client_received)
    server_received_bytes = b''.join(server_received)
    assert client_received_bytes == message_bytes
    assert server_received_bytes == message_bytes


# -----------------------------------------------------------------------------
async def run():
    """Run all tests sequentially without pytest (used by ``__main__``)."""
    test_helpers()
    await test_basic_connection()
    await test_transfer()
    await test_bidirectional_transfer()


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
    asyncio.run(run())