# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
# Standard library.
import asyncio
import functools
import logging
import os
from types import LambdaType  # NOTE(review): appears unused in this file
from unittest import mock

# Third-party.
import pytest

# First-party (bumble) and local test helpers.
from bumble.core import (
    BT_BR_EDR_TRANSPORT,
    BT_LE_TRANSPORT,
    BT_PERIPHERAL_ROLE,
    ConnectionParameters,
)
from bumble.device import AdvertisingParameters, Connection, Device
from bumble.host import AclPacketQueue, Host
from bumble.hci import (
    HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
    HCI_COMMAND_STATUS_PENDING,
    HCI_CREATE_CONNECTION_COMMAND,
    HCI_SUCCESS,
    HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
    Address,
    OwnAddressType,
    HCI_Command_Complete_Event,
    HCI_Command_Status_Event,
    HCI_Connection_Complete_Event,
    HCI_Connection_Request_Event,
    HCI_Error,
    HCI_Packet,
)
from bumble.gatt import (
    GATT_GENERIC_ACCESS_SERVICE,
    GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
    GATT_DEVICE_NAME_CHARACTERISTIC,
    GATT_APPEARANCE_CHARACTERISTIC,
)

from .test_utils import TwoDevices, async_barrier

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Timeout, in seconds, passed to asyncio.wait_for() for operations that are
# expected to complete (or fail) almost immediately.
_TIMEOUT = 0.1
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
class Sink:
    """HCI packet sink that feeds every packet into a generator-based "flow".

    The flow is a generator that receives packets through ``(yield)``
    expressions. It is primed on construction, so the first ``on_packet`` call is
    delivered to its first ``yield``. Any exception raised inside the flow (for
    example a failing ``assert``) propagates out of ``on_packet``.
    """

    def __init__(self, flow):
        self.flow = flow
        # Advance the generator to its first `yield` so that it can accept send().
        next(self.flow)

    def on_packet(self, packet):
        self.flow.send(packet)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_device_connect_parallel():
    """Connect d0 to two peers (d1, d2) over BR/EDR concurrently.

    Each device's controller is simulated by a generator flow. The flow checks
    the HCI commands the host sends and injects the matching HCI events.
    """
    d0 = Device(host=Host(None, None))
    d1 = Device(host=Host(None, None))
    d2 = Device(host=Host(None, None))

    def _send(packet):
        # Outgoing ACL data is not inspected by this test.
        pass

    for device in (d0, d1, d2):
        device.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
        # enable classic
        device.classic_enabled = True

    # set public addresses
    d0.public_address = Address(
        'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d1.public_address = Address(
        'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d2.public_address = Address(
        'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )

    def d0_flow():
        # d0 must first ask to connect to d1, then to d2. Each request is
        # acknowledged with a "pending" command status, and the peer is then
        # shown an incoming connection request from d0.
        for peer in (d1, d2):
            packet = HCI_Packet.from_bytes((yield))
            assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
            assert packet.bd_addr == peer.public_address

            d0.host.on_hci_packet(
                HCI_Command_Status_Event(
                    status=HCI_COMMAND_STATUS_PENDING,
                    num_hci_command_packets=1,
                    command_opcode=HCI_CREATE_CONNECTION_COMMAND,
                )
            )

            peer.host.on_hci_packet(
                HCI_Connection_Request_Event(
                    bd_addr=d0.public_address,
                    class_of_device=0,
                    link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                )
            )

        # Any further packet sent by d0 is unexpected and fails here.
        assert (yield) is None

    def accept_flow(device, connection_handle):
        # Accepting side: acknowledge the accept command, then report the
        # connection as complete on both `device` and d0, using the same handle.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        device.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        device.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=connection_handle,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=connection_handle,
                bd_addr=device.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # Any further packet sent by `device` is unexpected and fails here.
        assert (yield) is None

    d0.host.set_packet_sink(Sink(d0_flow()))
    d1.host.set_packet_sink(Sink(accept_flow(d1, 0x100)))
    d2.host.set_packet_sink(Sink(accept_flow(d2, 0x101)))

    d1_accept_task = asyncio.create_task(d1.accept(peer_address=d0.public_address))
    d2_accept_task = asyncio.create_task(d2.accept())

    # Ensure that the accept tasks have started.
    await async_barrier()

    [c01, c02, a10, a20] = await asyncio.gather(
        asyncio.create_task(
            d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)
        ),
        asyncio.create_task(
            d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)
        ),
        d1_accept_task,
        d2_accept_task,
    )

    for connection in (c01, c02, a10, a20):
        assert isinstance(connection, Connection)

    assert c01.handle == a10.handle == 0x100
    assert c02.handle == a20.handle == 0x101


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_flush():
    """Check that flushing the host cancels work registered with abort_on('flush')."""
    d0 = Device(host=Host(None, None))
    task = d0.abort_on('flush', asyncio.sleep(10000))
    await d0.host.flush()
    with pytest.raises(asyncio.CancelledError):
        await task


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_legacy_advertising():
    """Start and then stop legacy advertising against a mocked host."""
    device = Device(host=mock.AsyncMock(spec=Host))

    # Start advertising
    await device.start_advertising()
    assert device.is_advertising

    # Stop advertising
    await device.stop_advertising()
    assert not device.is_advertising
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'own_address_type,',
    (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_legacy_advertising_connection(own_address_type):
    """Check the local address of a connection accepted during legacy advertising.

    The connection's ``self_address`` must be the public address when advertising
    with ``OwnAddressType.PUBLIC``, and the random address otherwise.
    """
    device = Device(host=mock.AsyncMock(spec=Host))
    peer_address = Address('F0:F1:F2:F3:F4:F5')

    # Start advertising with the address type under test. Without passing it,
    # both parametrized cases would advertise with the same default type.
    await device.start_advertising(own_address_type=own_address_type)
    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer_address,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )

    connection = device.lookup_connection(0x0001)
    if own_address_type == OwnAddressType.PUBLIC:
        assert connection.self_address == device.public_address
    else:
        assert connection.self_address == device.random_address

    # For an unknown reason, read_phy() in on_connection() would be killed at the
    # end of the test, so force scheduling here to avoid a warning.
    await asyncio.sleep(0.0001)


# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'auto_restart,',
    (True, False),
)
@pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart):
    """Check that advertising resumes after a disconnection only with auto_restart."""
    device = Device(host=mock.AsyncMock(spec=Host))
    peer_address = Address('F0:F1:F2:F3:F4:F5')
    await device.start_advertising(auto_restart=auto_restart)
    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer_address,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )

    device.on_advertising_set_termination(
        HCI_SUCCESS, device.legacy_advertising_set.advertising_handle, 0x0001, 0
    )

    device.on_disconnection(0x0001, 0)
    # Yield to the event loop twice so that work scheduled by the disconnection
    # (presumably the advertising restart) has run before the state is checked.
    await async_barrier()
    await async_barrier()

    if auto_restart:
        assert device.is_advertising
    else:
        assert not device.is_advertising


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extended_advertising():
    """Create an extended advertising set, then stop it."""
    device = Device(host=mock.AsyncMock(spec=Host))

    # Start advertising
    advertising_set = await device.create_advertising_set()
    assert device.extended_advertising_sets
    assert advertising_set.enabled

    # Stop advertising
    await advertising_set.stop()
    assert not advertising_set.enabled


# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'own_address_type,',
    (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type):
    """Check the local address of a connection accepted through an advertising set.

    The set is created with the parametrized own-address type. The connection's
    ``self_address`` must match it: public for ``OwnAddressType.PUBLIC``, random
    otherwise.
    """
    device = Device(host=mock.AsyncMock(spec=Host))
    peer_address = Address('F0:F1:F2:F3:F4:F5')
    advertising_set = await device.create_advertising_set(
        advertising_parameters=AdvertisingParameters(own_address_type=own_address_type)
    )
    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer_address,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )
    device.on_advertising_set_termination(
        HCI_SUCCESS,
        advertising_set.advertising_handle,
        0x0001,
        0,
    )

    connection = device.lookup_connection(0x0001)
    if own_address_type == OwnAddressType.PUBLIC:
        assert connection.self_address == device.public_address
    else:
        assert connection.self_address == device.random_address

    # For an unknown reason, read_phy() in on_connection() would be killed at the
    # end of the test, so force scheduling here to avoid a warning.
    await asyncio.sleep(0.0001)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features():
    """Read the remote LE features over an established connection."""
    devices = TwoDevices()
    await devices.setup_connection()

    assert (await devices.connections[0].get_remote_le_features()) is not None


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features_failed():
    """A failed remote-features read must surface to the caller as HCI_Error."""
    devices = TwoDevices()
    await devices.setup_connection()

    def on_hci_le_read_remote_features_complete_event(event):
        # Report every completion as a failure, whatever its actual status.
        devices[0].host.emit(
            'le_remote_features_failure',
            event.connection_handle,
            HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
        )

    # Replace the host's handler for this event on the first device only.
    devices[0].host.on_hci_le_read_remote_features_complete_event = (
        on_hci_le_read_remote_features_complete_event
    )

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(
            devices.connections[0].get_remote_le_features(), _TIMEOUT
        )


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis():
    """Set up a CIG with two CISes and check that the peripheral accepts both.

    The CISes are then disconnected.
    """
    devices = TwoDevices()
    await devices.setup_connection()

    # CIS handle -> future resolved when the peripheral sees that CIS established.
    peripheral_cis_futures = {}

    def on_cis_request(
        acl_connection: Connection,
        cis_handle: int,
        _cig_id: int,
        _cis_id: int,
    ):
        # Accept the request, and abort the acceptance if the ACL link drops.
        acl_connection.abort_on(
            'disconnection', devices[1].accept_cis_request(cis_handle)
        )
        peripheral_cis_futures[cis_handle] = asyncio.get_running_loop().create_future()

    devices[1].on('cis_request', on_cis_request)
    devices[1].on(
        'cis_establishment',
        lambda cis_link: peripheral_cis_futures[cis_link.handle].set_result(None),
    )

    cis_handles = await devices[0].setup_cig(
        cig_id=1,
        cis_id=[2, 3],
        sdu_interval=(0, 0),
        framing=0,
        max_sdu=(0, 0),
        retransmission_number=0,
        max_transport_latency=(0, 0),
    )
    assert len(cis_handles) == 2
    cis_links = await devices[0].create_cis(
        [
            (cis_handles[0], devices.connections[0].handle),
            (cis_handles[1], devices.connections[0].handle),
        ]
    )
    await asyncio.gather(*peripheral_cis_futures.values())
    assert len(cis_links) == 2

    await cis_links[0].disconnect()
    await cis_links[1].disconnect()


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis_setup_failure():
    """A failed CIS establishment must raise HCI_Error on both sides.

    The failure is raised from the peripheral's accept_cis_request() and from
    the central's create_cis().
    """
    devices = TwoDevices()
    await devices.setup_connection()

    cis_requests = asyncio.Queue()

    def on_cis_request(
        acl_connection: Connection,
        cis_handle: int,
        cig_id: int,
        cis_id: int,
    ):
        del acl_connection, cig_id, cis_id
        cis_requests.put_nowait(cis_handle)

    devices[1].on('cis_request', on_cis_request)

    cis_handles = await devices[0].setup_cig(
        cig_id=1,
        cis_id=[2],
        sdu_interval=(0, 0),
        framing=0,
        max_sdu=(0, 0),
        retransmission_number=0,
        max_transport_latency=(0, 0),
    )
    assert len(cis_handles) == 1

    cis_create_task = asyncio.create_task(
        devices[0].create_cis(
            [
                (cis_handles[0], devices.connections[0].handle),
            ]
        )
    )

    def on_hci_le_cis_established_event(host, event):
        # Turn every CIS-established event into an establishment failure.
        host.emit(
            'cis_establishment_failure',
            event.connection_handle,
            HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
        )

    for device in devices:
        device.host.on_hci_le_cis_established_event = functools.partial(
            on_hci_le_cis_established_event, device.host
        )

    cis_request = await asyncio.wait_for(cis_requests.get(), _TIMEOUT)

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(devices[1].accept_cis_request(cis_request), _TIMEOUT)

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(cis_create_task, _TIMEOUT)


# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
    """By default the GATT server exposes only the Generic Access Service."""
    device = Device(host=Host(None, None))

    # there should be one service and two chars, therefore 5 attributes
    assert len(device.gatt_server.attributes) == 5
    assert device.gatt_server.attributes[0].uuid == GATT_GENERIC_ACCESS_SERVICE
    assert device.gatt_server.attributes[1].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
    assert device.gatt_server.attributes[2].uuid == GATT_DEVICE_NAME_CHARACTERISTIC
    assert device.gatt_server.attributes[3].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
    assert device.gatt_server.attributes[4].uuid == GATT_APPEARANCE_CHARACTERISTIC


# -----------------------------------------------------------------------------
def test_gatt_services_without_gas():
    """With generic_access_service=False the GATT server starts empty."""
    device = Device(host=Host(None, None), generic_access_service=False)

    # there should be no services
    assert len(device.gatt_server.attributes) == 0


# -----------------------------------------------------------------------------
async def run_test_device():
    """Run a subset of the tests in this module without pytest."""
    await test_device_connect_parallel()
    await test_flush()
    # The GATT tests are synchronous. Awaiting their `None` return value would
    # raise TypeError, so call them directly.
    test_gatt_services_with_gas()
    test_gatt_services_without_gas()


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
    asyncio.run(run_test_device())