# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import collections
import logging
from pyee import EventEmitter
from colors import color

from .hci import *
from .l2cap import *
from .att import *
from .gatt import *
from .smp import *
from .core import ConnectionParameters

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Defaults used until (or unless) the controller reports its buffer sizes
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS     = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH    = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS        = 1


# -----------------------------------------------------------------------------
class Connection:
    """
    Host-side state for a single HCI connection.

    Incoming ACL data packets are fed to an assembler. Each PDU it hands back
    is parsed as an L2CAP PDU and dispatched to the owning Host according to
    its channel ID (ATT, SMP, or any other CID).
    """

    def __init__(self, host, handle, role, peer_address):
        self.host         = host
        self.handle       = handle  # HCI connection handle
        self.role         = role
        self.peer_address = peer_address
        # The assembler calls on_acl_pdu with reassembled PDU bytes
        # (presumably once a complete L2CAP PDU has been received)
        self.assembler    = HCI_AclDataPacketAssembler(self.on_acl_pdu)

    def on_hci_acl_data_packet(self, packet):
        """Feed one ACL data packet (possibly a fragment) to the reassembler."""
        self.assembler.feed_packet(packet)

    def on_acl_pdu(self, pdu):
        """Parse a reassembled PDU as L2CAP and route it to the host by CID."""
        l2cap_pdu = L2CAP_PDU.from_bytes(pdu)

        if l2cap_pdu.cid == ATT_CID:
            self.host.on_gatt_pdu(self, l2cap_pdu.payload)
        elif l2cap_pdu.cid == SMP_CID:
            self.host.on_smp_pdu(self, l2cap_pdu.payload)
        else:
            self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)


# -----------------------------------------------------------------------------
class Host(EventEmitter):
    """
    HCI host: sends commands/data to a controller and turns controller events
    into emitted events ('connection', 'disconnection', 'gatt_pdu', ...).

    Outgoing packets go to the packet sink (the controller); incoming packets
    arrive through on_packet(). Outgoing ACL data is flow-controlled using the
    buffer counts read from the controller during reset().
    """

    def __init__(self, controller_source = None, controller_sink = None):
        super().__init__()

        self.hci_sink                         = None
        self.ready                            = False  # True when we can accept incoming packets
        self.reset_done                       = False  # Set to True at the end of reset()
        self.connections                      = {}     # Connections, by connection handle
        self.pending_command                  = None
        self.pending_response                 = None
        self.hc_le_acl_data_packet_length     = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
        self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
        self.hc_acl_data_packet_length        = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
        self.hc_total_num_acl_data_packets    = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
        self.acl_packet_queue                 = collections.deque()
        self.acl_packets_in_flight            = 0
        self.local_supported_commands         = bytes(64)
        self.command_semaphore                = asyncio.Semaphore(1)
        self.long_term_key_provider           = None
        self.link_key_provider                = None
        self.pairing_io_capability_provider   = None  # Classic only
        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so unreferenced tasks could be garbage-collected
        # before they complete.
        self._background_tasks                = set()

        # Connect to the source and sink if specified
        if controller_source:
            controller_source.set_packet_sink(self)
        if controller_sink:
            self.set_packet_sink(controller_sink)

    async def reset(self):
        """
        Reset the controller and read/configure its basic parameters.

        Sends HCI_Reset, then marks the host ready, so on_packet() stops
        dropping packets other than the reset's Command Complete. After that it
        reads the supported-commands bitmap, sets the event masks, reads the
        local version information, enables LE host support, and reads the ACL
        buffer sizes used for host-side flow control.

        Raises:
            ProtocolError: if HCI_Read_Local_Supported_Commands does not
                return HCI_SUCCESS.
        """
        await self.send_command(HCI_Reset_Command())
        self.ready = True

        response = await self.send_command(HCI_Read_Local_Supported_Commands_Command())
        if response.return_parameters.status != HCI_SUCCESS:
            raise ProtocolError(response.return_parameters.status, 'hci')
        self.local_supported_commands = response.return_parameters.supported_commands

        await self.send_command(HCI_Set_Event_Mask_Command(event_mask = bytes.fromhex('FFFFFFFFFFFFFFFF')))
        await self.send_command(HCI_LE_Set_Event_Mask_Command(le_event_mask = bytes.fromhex('FFFFF00000000000')))
        await self.send_command(HCI_Read_Local_Version_Information_Command())
        await self.send_command(HCI_Write_LE_Host_Support_Command(le_supported_host = 1, simultaneous_le_host = 0))

        await self._read_buffer_sizes()

        self.reset_done = True

    async def _read_buffer_sizes(self):
        """
        Read the controller's ACL buffer sizes used for outgoing flow control.

        The LE values come from HCI_LE_Read_Buffer_Size. If either LE value is
        zero, the shared values from HCI_Read_Buffer_Size are read and used for
        any LE value that is still zero. If a command fails, the previously
        stored values (initially the HOST_* defaults) are kept.
        """
        response = await self.send_command(HCI_LE_Read_Buffer_Size_Command())
        if response.return_parameters.status == HCI_SUCCESS:
            self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
            self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets
            logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={response.return_parameters.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={response.return_parameters.hc_total_num_le_acl_data_packets}')
        else:
            logger.warning(f'HCI_LE_Read_Buffer_Size_Command failed: {response.return_parameters.status}')

        # Check the stored values, not the response: a failed response is not
        # guaranteed to carry the buffer size fields.
        if self.hc_le_acl_data_packet_length == 0 or self.hc_total_num_le_acl_data_packets == 0:
            # Read the non-LE-specific values
            response = await self.send_command(HCI_Read_Buffer_Size_Command())
            if response.return_parameters.status == HCI_SUCCESS:
                # HCI_Read_Buffer_Size reports the BR/EDR (shared) ACL fields,
                # not the LE ones.
                self.hc_acl_data_packet_length = response.return_parameters.hc_acl_data_packet_length
                self.hc_le_acl_data_packet_length = self.hc_le_acl_data_packet_length or self.hc_acl_data_packet_length
                self.hc_total_num_acl_data_packets = response.return_parameters.hc_total_num_acl_data_packets
                self.hc_total_num_le_acl_data_packets = self.hc_total_num_le_acl_data_packets or self.hc_total_num_acl_data_packets
                logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}')
            else:
                logger.warning(f'HCI_Read_Buffer_Size_Command failed: {response.return_parameters.status}')

    @property
    def controller(self):
        """The packet sink that outgoing HCI packets are sent to."""
        return self.hci_sink

    @controller.setter
    def controller(self, controller):
        # Wire both directions: host -> controller and controller -> host
        self.set_packet_sink(controller)
        if controller:
            controller.set_packet_sink(self)

    def set_packet_sink(self, sink):
        """Set the object whose on_packet() receives outgoing HCI packet bytes."""
        self.hci_sink = sink

    def send_hci_packet(self, packet):
        """Serialize an HCI packet and hand the bytes to the packet sink."""
        self.hci_sink.on_packet(packet.to_bytes())

    def _create_background_task(self, coroutine):
        """
        Schedule coroutine as a task, keeping a strong reference until it is done.

        Must be called while an event loop is running (asyncio.create_task).
        """
        task = asyncio.create_task(coroutine)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def send_command(self, command):
        """
        Send an HCI command and wait for its Command Complete/Status event.

        Commands are serialized by command_semaphore, so only one command is
        outstanding at a time.

        Returns:
            The Command Complete or Command Status event passed to
            on_command_processed(). Returns None if an exception occurred while
            sending or waiting; the exception is logged, not propagated.
        """
        logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')

        # Wait until we can send (only one pending command at a time)
        async with self.command_semaphore:
            assert(self.pending_command is None)
            assert(self.pending_response is None)

            # Create a future value to hold the eventual response
            self.pending_response = asyncio.get_running_loop().create_future()
            self.pending_command = command

            try:
                self.send_hci_packet(command)
                response = await self.pending_response
                # TODO: check error values
                return response
            except Exception as error:
                # Deliberately best-effort: log and fall through (returns None)
                logger.warning(f'{color("!!! Exception while sending HCI packet:", "red")} {error}')
            finally:
                self.pending_command = None
                self.pending_response = None

    # Use this method to send a command from a task
    def send_command_sync(self, command):
        """Schedule send_command(command) in the background without awaiting it."""
        self._create_background_task(self.send_command(command))

    def send_l2cap_pdu(self, connection_handle, cid, pdu):
        """
        Wrap pdu in an L2CAP frame for cid and queue it as ACL data packets.

        The frame is split into fragments of at most hc_le_acl_data_packet_length
        bytes. The first fragment is sent with pb_flag 0 and the following ones
        with pb_flag 1.
        """
        l2cap_pdu = L2CAP_PDU(cid, pdu).to_bytes()

        # Send the data to the controller via ACL packets
        bytes_remaining = len(l2cap_pdu)
        offset = 0
        pb_flag = 0
        while bytes_remaining:
            data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
            acl_packet = HCI_AclDataPacket(
                connection_handle = connection_handle,
                pb_flag           = pb_flag,
                bc_flag           = 0,
                data_total_length = data_total_length,
                data              = l2cap_pdu[offset:offset + data_total_length]
            )
            logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}')
            self.queue_acl_packet(acl_packet)
            pb_flag = 1
            offset += data_total_length
            bytes_remaining -= data_total_length

    def queue_acl_packet(self, acl_packet):
        """Queue an ACL packet and send as many queued packets as flow control allows."""
        # appendleft here + pop() in check_acl_packet_queue = FIFO order
        self.acl_packet_queue.appendleft(acl_packet)
        self.check_acl_packet_queue()

        if len(self.acl_packet_queue):
            logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue')

    def check_acl_packet_queue(self):
        """Send queued ACL packets while fewer than hc_total_num_le_acl_data_packets are in flight."""
        # Send all we can
        while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets:
            packet = self.acl_packet_queue.pop()
            self.send_hci_packet(packet)
            self.acl_packets_in_flight += 1

    # Packet Sink protocol (packets coming from the controller via HCI)
    def on_packet(self, packet):
        """
        Entry point for raw HCI packet bytes coming from the controller.

        Until reset() sets ready, the only packet processed is the Command
        Complete event for HCI_Reset; all other packets are dropped.
        """
        hci_packet = HCI_Packet.from_bytes(packet)
        if self.ready or (
            hci_packet.hci_packet_type == HCI_EVENT_PACKET and
            hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and
            hci_packet.command_opcode == HCI_RESET_COMMAND
        ):
            self.on_hci_packet(hci_packet)
        else:
            logger.debug('reset not done, ignoring packet from controller')

    def on_hci_packet(self, packet):
        """Dispatch a parsed HCI packet by its packet type."""
        logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')

        # If the packet is a command, invoke the handler for this packet
        if packet.hci_packet_type == HCI_COMMAND_PACKET:
            self.on_hci_command_packet(packet)
        elif packet.hci_packet_type == HCI_EVENT_PACKET:
            self.on_hci_event_packet(packet)
        elif packet.hci_packet_type == HCI_ACL_DATA_PACKET:
            self.on_hci_acl_data_packet(packet)
        else:
            logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')

    def on_hci_command_packet(self, command):
        """Commands are not expected from the controller; log and ignore."""
        logger.warning(f'!!! unexpected command packet: {command}')

    def on_hci_event_packet(self, event):
        """Dispatch an event to on_<event name, lowercased>, or on_hci_event if none exists."""
        handler_name = f'on_{event.name.lower()}'
        handler = getattr(self, handler_name, self.on_hci_event)
        handler(event)

    def on_hci_acl_data_packet(self, packet):
        """Route an ACL data packet to its connection; packets for unknown handles are dropped."""
        # Look for the connection to which this data belongs
        if connection := self.connections.get(packet.connection_handle):
            connection.on_hci_acl_data_packet(packet)

    def on_gatt_pdu(self, connection, pdu):
        """Emit 'gatt_pdu' (connection handle, payload) for a PDU received on ATT_CID."""
        self.emit('gatt_pdu', connection.handle, pdu)

    def on_smp_pdu(self, connection, pdu):
        """Emit 'smp_pdu' (connection handle, payload) for a PDU received on SMP_CID."""
        self.emit('smp_pdu', connection.handle, pdu)

    def on_l2cap_pdu(self, connection, cid, pdu):
        """Emit 'l2cap_pdu' (connection handle, cid, payload) for any other CID."""
        self.emit('l2cap_pdu', connection.handle, cid, pdu)

    def on_command_processed(self, event):
        """
        Resolve the pending send_command() future with a Command Complete/Status event.

        An opcode mismatch is logged, but the future is still resolved with the event.
        """
        if self.pending_response:
            # Check that it is what we were expecting
            if self.pending_command.op_code != event.command_opcode:
                logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}')

            self.pending_response.set_result(event)
        else:
            logger.warning('!!! no pending response future to set')

    ############################################################
    # HCI handlers
    ############################################################
    def on_hci_event(self, event):
        """Fallback for events without a dedicated handler: log and ignore."""
        logger.warning(f'{color(f"--- Ignoring event {event}", "red")}')

    def on_hci_command_complete_event(self, event):
        """Complete the pending command, unless this is an opcode-0 (no-command) event."""
        if event.command_opcode == 0:
            # This is used just for the Num_HCI_Command_Packets field, not related to an actual command
            logger.debug('no-command event')
        else:
            return self.on_command_processed(event)

    def on_hci_command_status_event(self, event):
        """A Command Status event also completes the pending command."""
        return self.on_command_processed(event)

    def on_hci_number_of_completed_packets_event(self, event):
        """Release flow-control credits for completed ACL packets and send queued ones."""
        total_packets = sum(event.num_completed_packets)
        if total_packets <= self.acl_packets_in_flight:
            self.acl_packets_in_flight -= total_packets
        else:
            logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight'))
            self.acl_packets_in_flight = 0

        # Credits were freed in both branches, so queued packets can go out now
        self.check_acl_packet_queue()

    # Classic only
    def on_hci_connection_request_event(self, event):
        """Accept every incoming BR/EDR connection request."""
        # For now, just accept everything
        # TODO: delegate the decision
        self.send_command_sync(
            HCI_Accept_Connection_Request_Command(
                bd_addr = event.bd_addr,
                role    = 0x01  # Remain the peripheral
            )
        )

    def on_hci_le_connection_complete_event(self, event):
        """
        Create/update the Connection and emit 'connection' on success.

        On failure, emit 'connection_failure' with the status.
        """
        # A non-success status means the connection attempt did not complete
        if event.status == HCI_SUCCESS:
            # Create/update the connection
            logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}')

            connection = self.connections.get(event.connection_handle)
            if connection is None:
                connection = Connection(self, event.connection_handle, event.role, event.peer_address)
                self.connections[event.connection_handle] = connection

            # Notify the client
            connection_parameters = ConnectionParameters(
                event.conn_interval,
                event.conn_latency,
                event.supervision_timeout
            )
            self.emit(
                'connection',
                event.connection_handle,
                BT_LE_TRANSPORT,
                event.peer_address,
                None,
                event.role,
                connection_parameters
            )
        else:
            logger.debug(f'### CONNECTION FAILED: {event.status}')

            # Notify the listeners
            # NOTE(review): the BR/EDR handler emits (connection_handle, status)
            # for the same event name — confirm which signature listeners expect.
            self.emit('connection_failure', event.status)

    def on_hci_le_enhanced_connection_complete_event(self, event):
        """Handle the enhanced LE connection complete event like the legacy one."""
        # Just use the same implementation as for the non-enhanced event for now
        self.on_hci_le_connection_complete_event(event)

    def on_hci_connection_complete_event(self, event):
        """
        Create/update a BR/EDR Connection (recorded with role BT_CENTRAL_ROLE)
        and emit 'connection' on success, or 'connection_failure' on failure.
        """
        if event.status == HCI_SUCCESS:
            # Create/update the connection
            logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}')

            connection = self.connections.get(event.connection_handle)
            if connection is None:
                connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr)
                self.connections[event.connection_handle] = connection

            # Notify the client
            self.emit(
                'connection',
                event.connection_handle,
                BT_BR_EDR_TRANSPORT,
                event.bd_addr,
                None,
                BT_CENTRAL_ROLE,
                None
            )
        else:
            logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')

            # Notify the client
            # NOTE(review): the LE handler emits only (status) for the same
            # event name — confirm which signature listeners expect.
            self.emit('connection_failure', event.connection_handle, event.status)

    def on_hci_disconnection_complete_event(self, event):
        """Forget the connection and emit 'disconnection', or 'disconnection_failure' on error."""
        # Find the connection
        if (connection := self.connections.get(event.connection_handle)) is None:
            logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
            return

        if event.status == HCI_SUCCESS:
            logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}')
            del self.connections[event.connection_handle]

            # Notify the listeners
            self.emit('disconnection', event.connection_handle, event.reason)
        else:
            logger.debug(f'### DISCONNECTION FAILED: {event.status}')

            # Notify the listeners
            self.emit('disconnection_failure', event.status)

    def on_hci_le_connection_update_complete_event(self, event):
        """Emit the new connection parameters, or the failure status."""
        if (connection := self.connections.get(event.connection_handle)) is None:
            logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle')
            return

        # Notify the client
        if event.status == HCI_SUCCESS:
            connection_parameters = ConnectionParameters(
                event.conn_interval,
                event.conn_latency,
                event.supervision_timeout
            )
            self.emit('connection_parameters_update', connection.handle, connection_parameters)
        else:
            self.emit('connection_parameters_update_failure', connection.handle, event.status)

    def on_hci_le_phy_update_complete_event(self, event):
        """Emit the new TX/RX PHY pair, or the failure status."""
        if (connection := self.connections.get(event.connection_handle)) is None:
            logger.warning('!!! CONNECTION PHY UPDATE COMPLETE: unknown handle')
            return

        # Notify the client
        if event.status == HCI_SUCCESS:
            # NOTE(review): ConnectionPHY is not imported explicitly here;
            # presumably it comes from one of the star imports — verify.
            connection_phy = ConnectionPHY(event.tx_phy, event.rx_phy)
            self.emit('connection_phy_update', connection.handle, connection_phy)
        else:
            self.emit('connection_phy_update_failure', connection.handle, event.status)

    def on_hci_le_advertising_report_event(self, event):
        """Emit one 'advertising_report' per report contained in the event."""
        for report in event.reports:
            self.emit(
                'advertising_report',
                report.address,
                report.data,
                report.rssi,
                report.event_type
            )

    def on_hci_le_remote_connection_parameter_request_event(self, event):
        """Accept the peer's requested connection parameters unchanged."""
        if event.connection_handle not in self.connections:
            logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle')
            return

        # For now, just accept everything
        # TODO: delegate the decision
        self.send_command_sync(
            HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
                connection_handle = event.connection_handle,
                interval_min      = event.interval_min,
                interval_max      = event.interval_max,
                latency           = event.latency,
                timeout           = event.timeout,
                minimum_ce_length = 0,
                maximum_ce_length = 0
            )
        )

    def on_hci_le_long_term_key_request_event(self, event):
        """
        Answer an LE LTK request using long_term_key_provider.

        The provider is awaited with (connection handle, random number,
        encryption diversifier). A falsy key, or no provider at all, results
        in a negative reply.
        """
        if (connection := self.connections.get(event.connection_handle)) is None:
            logger.warning('!!! LE LONG TERM KEY REQUEST: unknown handle')
            return

        async def send_long_term_key():
            if self.long_term_key_provider is None:
                logger.debug('no long term key provider')
                long_term_key = None
            else:
                long_term_key = await self.long_term_key_provider(
                    connection.handle,
                    event.random_number,
                    event.encryption_diversifier
                )
            if long_term_key:
                response = HCI_LE_Long_Term_Key_Request_Reply_Command(
                    connection_handle = event.connection_handle,
                    long_term_key     = long_term_key
                )
            else:
                response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(
                    connection_handle = event.connection_handle
                )

            await self.send_command(response)

        self._create_background_task(send_long_term_key())

    def on_hci_synchronous_connection_complete_event(self, event):
        pass

    def on_hci_synchronous_connection_changed_event(self, event):
        pass

    def on_hci_role_change_event(self, event):
        """Log a role change; the stored Connection role is not updated yet."""
        if event.status == HCI_SUCCESS:
            logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}')
            # TODO: lookup the connection and update the role
        else:
            logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}')

    def on_hci_le_data_length_change_event(self, event):
        """Emit the new maximum TX/RX octets and times for the connection."""
        self.emit(
            'connection_data_length_change',
            event.connection_handle,
            event.max_tx_octets,
            event.max_tx_time,
            event.max_rx_octets,
            event.max_rx_time
        )

    def on_hci_authentication_complete_event(self, event):
        """Emit authentication success or failure for the connection."""
        # Notify the client
        if event.status == HCI_SUCCESS:
            self.emit('connection_authentication', event.connection_handle)
        else:
            self.emit('connection_authentication_failure', event.connection_handle, event.status)

    def on_hci_encryption_change_event(self, event):
        """Emit the new encryption state, or the failure status."""
        # Notify the client
        if event.status == HCI_SUCCESS:
            self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled)
        else:
            self.emit('connection_encryption_failure', event.connection_handle, event.status)

    def on_hci_encryption_key_refresh_complete_event(self, event):
        """Emit encryption key refresh success or failure."""
        # Notify the client
        if event.status == HCI_SUCCESS:
            self.emit('connection_encryption_key_refresh', event.connection_handle)
        else:
            self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status)

    def on_hci_link_supervision_timeout_changed_event(self, event):
        pass

    def on_hci_max_slots_change_event(self, event):
        pass

    def on_hci_page_scan_repetition_mode_change_event(self, event):
        pass

    def on_hci_link_key_notification_event(self, event):
        """Emit 'link_key' with the address, key bytes and key type."""
        logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}')
        self.emit('link_key', event.bd_addr, event.link_key, event.key_type)

    def on_hci_simple_pairing_complete_event(self, event):
        logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}')

    def on_hci_pin_code_request_event(self, event):
        """Refuse every PIN code request."""
        # For now, just refuse all requests
        # TODO: delegate the decision
        self.send_command_sync(
            HCI_PIN_Code_Request_Negative_Reply_Command(
                bd_addr = event.bd_addr
            )
        )

    def on_hci_link_key_request_event(self, event):
        """
        Answer a link key request using link_key_provider (awaited with the
        peer address). A falsy key, or no provider, results in a negative reply.
        """
        async def send_link_key():
            if self.link_key_provider is None:
                logger.debug('no link key provider')
                link_key = None
            else:
                link_key = await self.link_key_provider(event.bd_addr)
            if link_key:
                response = HCI_Link_Key_Request_Reply_Command(
                    bd_addr  = event.bd_addr,
                    link_key = link_key
                )
            else:
                response = HCI_Link_Key_Request_Negative_Reply_Command(
                    bd_addr = event.bd_addr
                )

            await self.send_command(response)

        self._create_background_task(send_link_key())

    def on_hci_io_capability_request_event(self, event):
        self.emit('authentication_io_capability_request', event.bd_addr)

    def on_hci_io_capability_response_event(self, event):
        pass

    def on_hci_user_confirmation_request_event(self, event):
        self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value)

    def on_hci_user_passkey_request_event(self, event):
        self.emit('authentication_user_passkey_request', event.bd_addr)

    def on_hci_inquiry_complete_event(self, event):
        self.emit('inquiry_complete')

    def on_hci_inquiry_result_with_rssi_event(self, event):
        """Emit one 'inquiry_result' per response (no extended data, so b'' is used)."""
        for response in event.responses:
            self.emit(
                'inquiry_result',
                response.bd_addr,
                response.class_of_device,
                b'',
                response.rssi
            )

    def on_hci_extended_inquiry_result_event(self, event):
        """Emit 'inquiry_result' including the extended inquiry response data."""
        self.emit(
            'inquiry_result',
            event.bd_addr,
            event.class_of_device,
            event.extended_inquiry_response,
            event.rssi
        )

    def on_hci_remote_name_request_complete_event(self, event):
        """Emit the resolved remote name, or the failure status."""
        if event.status != HCI_SUCCESS:
            self.emit('remote_name_failure', event.bd_addr, event.status)
        else:
            self.emit('remote_name', event.bd_addr, event.remote_name)