1# Copyright 2021-2022 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# ----------------------------------------------------------------------------- 16# Imports 17# ----------------------------------------------------------------------------- 18import asyncio 19import logging 20import os 21from types import LambdaType 22import pytest 23 24from bumble.core import BT_BR_EDR_TRANSPORT 25from bumble.device import Connection, Device 26from bumble.host import Host 27from bumble.hci import ( 28 HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, 29 HCI_COMMAND_STATUS_PENDING, 30 HCI_CREATE_CONNECTION_COMMAND, 31 HCI_SUCCESS, 32 Address, 33 HCI_Command_Complete_Event, 34 HCI_Command_Status_Event, 35 HCI_Connection_Complete_Event, 36 HCI_Connection_Request_Event, 37 HCI_Packet, 38) 39from bumble.gatt import ( 40 GATT_GENERIC_ACCESS_SERVICE, 41 GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, 42 GATT_DEVICE_NAME_CHARACTERISTIC, 43 GATT_APPEARANCE_CHARACTERISTIC, 44) 45 46# ----------------------------------------------------------------------------- 47# Logging 48# ----------------------------------------------------------------------------- 49logger = logging.getLogger(__name__) 50 51 52# ----------------------------------------------------------------------------- 53class Sink: 54 def __init__(self, flow): 55 self.flow = flow 56 next(self.flow) 57 58 def on_packet(self, packet): 59 self.flow.send(packet) 60 61 62# 
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_device_connect_parallel():
    """Check that ``d0`` can open two BR/EDR connections concurrently.

    ``d0`` connects to ``d1`` and ``d2`` at the same time, while ``d1`` and
    ``d2`` wait in ``accept()``. The hosts are built with ``Host(None, None)``.
    Each host's packet sink is a ``Sink`` wrapping a generator "flow". The flow
    checks the HCI commands the host emits and then injects, through
    ``host.on_hci_packet()``, the HCI events that answer them. Some of those
    events go to the *other* hosts, so the three flows together stand in for
    the controllers and the link between them.
    """
    d0 = Device(host=Host(None, None))
    d1 = Device(host=Host(None, None))
    d2 = Device(host=Host(None, None))

    # enable classic
    d0.classic_enabled = True
    d1.classic_enabled = True
    d2.classic_enabled = True

    # set public addresses
    d0.public_address = Address(
        'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d1.public_address = Address(
        'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d2.public_address = Address(
        'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )

    def d0_flow():
        # First command from d0: Create Connection towards d1.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
        assert packet.bd_addr == d1.public_address

        # Acknowledge the command on d0 with a "pending" Command Status ...
        d0.host.on_hci_packet(
            HCI_Command_Status_Event(
                status=HCI_COMMAND_STATUS_PENDING,
                num_hci_command_packets=1,
                command_opcode=HCI_CREATE_CONNECTION_COMMAND,
            )
        )

        # ... and present the incoming ACL connection request to d1.
        d1.host.on_hci_packet(
            HCI_Connection_Request_Event(
                bd_addr=d0.public_address,
                class_of_device=0,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
            )
        )

        # Second command from d0: Create Connection towards d2.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
        assert packet.bd_addr == d2.public_address

        d0.host.on_hci_packet(
            HCI_Command_Status_Event(
                status=HCI_COMMAND_STATUS_PENDING,
                num_hci_command_packets=1,
                command_opcode=HCI_CREATE_CONNECTION_COMMAND,
            )
        )

        d2.host.on_hci_packet(
            HCI_Connection_Request_Event(
                bd_addr=d0.public_address,
                class_of_device=0,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
            )
        )

        # d0 must not send any further commands.
        assert (yield) is None

    def d1_flow():
        # d1 answers the connection request by accepting it.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        d1.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Report the d0 <-> d1 link as established on both ends, handle 0x100.
        d1.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x100,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x100,
                bd_addr=d1.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # d1 must not send any further commands.
        assert (yield) is None

    def d2_flow():
        # d2 answers the connection request by accepting it.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        d2.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Report the d0 <-> d2 link as established on both ends, handle 0x101.
        d2.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x101,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x101,
                bd_addr=d2.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # d2 must not send any further commands.
        assert (yield) is None

    d0.host.set_packet_sink(Sink(d0_flow()))
    d1.host.set_packet_sink(Sink(d1_flow()))
    d2.host.set_packet_sink(Sink(d2_flow()))

    c01, c02, a10, a20 = await asyncio.gather(
        asyncio.create_task(
            d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)
        ),
        asyncio.create_task(
            d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)
        ),
        asyncio.create_task(d1.accept(peer_address=d0.public_address)),
        # d2 accepts without naming the peer it expects.
        asyncio.create_task(d2.accept()),
    )

    assert isinstance(c01, Connection)
    assert isinstance(c02, Connection)
    assert isinstance(a10, Connection)
    assert isinstance(a20, Connection)

    # Both ends of each link must carry the handle injected by the flows.
    assert c01.handle == a10.handle == 0x100
    assert c02.handle == a20.handle == 0x101


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_flush():
    """Check that ``host.flush()`` cancels work registered via ``abort_on('flush', ...)``."""
    d0 = Device(host=Host(None, None))
    task = d0.abort_on('flush', asyncio.sleep(10000))
    await d0.host.flush()
    # The long sleep must have been cancelled rather than left running.
    with pytest.raises(asyncio.CancelledError):
        await task


# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
    """Check the attributes of a device's default Generic Access Service."""
    device = Device(host=Host(None, None))
    attributes = device.gatt_server.attributes

    # One service declaration plus two characteristics (a declaration and a
    # value attribute each) gives 5 attributes.
    assert len(attributes) == 5
    assert attributes[0].uuid == GATT_GENERIC_ACCESS_SERVICE
    assert attributes[1].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
    assert attributes[2].uuid == GATT_DEVICE_NAME_CHARACTERISTIC
    assert attributes[3].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
    assert attributes[4].uuid == GATT_APPEARANCE_CHARACTERISTIC


# -----------------------------------------------------------------------------
def test_gatt_services_without_gas():
    """Check that disabling the Generic Access Service leaves no attributes."""
    device = Device(host=Host(None, None), generic_access_service=False)

    # there should be no services
    assert len(device.gatt_server.attributes) == 0


# -----------------------------------------------------------------------------
async def run_test_device():
    """Run every test in this module without pytest (see ``__main__`` below)."""
    await test_device_connect_parallel()
    await test_flush()
    # The GATT tests are plain synchronous functions returning None; awaiting
    # their result would raise TypeError, so call them directly.
    test_gatt_services_with_gas()
    test_gatt_services_without_gas()


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
    asyncio.run(run_test_device())