• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
import asyncio
import collections
import logging

from pyee import EventEmitter
from colors import color

from .hci import *
from .l2cap import *
from .att import *
from .gatt import *
from .smp import *
from .core import ConnectionParameters
29
30# -----------------------------------------------------------------------------
31# Logging
32# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Initial ACL buffer parameters. Host starts out with these and replaces them
# with the values the controller reports while Host.reset() runs.

# LE: max bytes per outgoing ACL data packet, and max packets in flight.
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH: int = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS: int = 1

# BR/EDR (shared buffer) counterparts, used when the LE values are unavailable.
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH: int = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS: int = 1
43
44
45# -----------------------------------------------------------------------------
class Connection:
    """Per-handle state the Host keeps for one HCI connection.

    Incoming ACL data packets are fed to an assembler; each reassembled ACL PDU
    is parsed as an L2CAP PDU and routed back to the owning host according to
    its channel identifier (CID).
    """

    def __init__(self, host, handle, role, peer_address):
        self.host = host
        self.handle = handle
        self.role = role
        self.peer_address = peer_address
        # Presumably calls on_acl_pdu once a complete PDU has been reassembled
        # from ACL fragments (HCI_AclDataPacketAssembler is defined elsewhere).
        self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)

    def on_hci_acl_data_packet(self, packet):
        """Hand one incoming HCI ACL data packet to the reassembler."""
        self.assembler.feed_packet(packet)

    def on_acl_pdu(self, pdu):
        """Parse a reassembled ACL PDU as L2CAP and forward it to the host.

        ATT traffic goes to Host.on_gatt_pdu, SMP traffic to Host.on_smp_pdu,
        and every other channel to Host.on_l2cap_pdu together with its CID.
        """
        parsed = L2CAP_PDU.from_bytes(pdu)
        channel_id = parsed.cid

        if channel_id == ATT_CID:
            self.host.on_gatt_pdu(self, parsed.payload)
            return

        if channel_id == SMP_CID:
            self.host.on_smp_pdu(self, parsed.payload)
            return

        self.host.on_l2cap_pdu(self, channel_id, parsed.payload)
66
67
68# -----------------------------------------------------------------------------
class Host(EventEmitter):
    """HCI host: talks to a controller and turns HCI traffic into host events.

    Outgoing packets (commands and ACL data) are handed to the packet sink,
    i.e. the controller side. Incoming packets arrive through `on_packet`.
    HCI events are routed to the `on_<event name>` method matching the event
    (falling back to `on_hci_event`), and most of those re-emit a higher-level
    event ('connection', 'disconnection', 'gatt_pdu', ...) for listeners
    registered through the EventEmitter API.
    """

    def __init__(self, controller_source = None, controller_sink = None):
        super().__init__()

        self.hci_sink                         = None
        self.ready                            = False  # True when we can accept incoming packets
        self.reset_done                       = False  # Set to True at the end of reset()
        self.connections                      = {}     # Connections, by connection handle
        self.pending_command                  = None   # Command currently awaiting its response event
        self.pending_response                 = None   # Future resolved with that response event
        self.hc_le_acl_data_packet_length     = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
        self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
        self.hc_acl_data_packet_length        = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
        self.hc_total_num_acl_data_packets    = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
        self.acl_packet_queue                 = collections.deque()  # FIFO: appendleft() in, pop() out
        self.acl_packets_in_flight            = 0      # Sent, not yet reported as completed
        self.local_supported_commands         = bytes(64)
        self.command_semaphore                = asyncio.Semaphore(1)  # Only one outstanding command
        self.long_term_key_provider           = None
        self.link_key_provider                = None
        self.pairing_io_capability_provider   = None  # Classic only
        self._background_tasks                = set()  # Strong references to fire-and-forget tasks

        # Connect to the source and sink if specified
        if controller_source:
            controller_source.set_packet_sink(self)
        if controller_sink:
            self.set_packet_sink(controller_sink)

    async def reset(self):
        """Reset the controller and configure the host for use.

        Sends HCI_Reset, reads the supported-commands bitmap, enables all
        events, enables LE host support, and reads the ACL buffer sizes that
        drive outgoing flow control.

        Raises:
            ProtocolError: if HCI_Read_Local_Supported_Commands does not succeed.
        """
        await self.send_command(HCI_Reset_Command())
        # From now on every incoming packet is processed (see on_packet)
        self.ready = True

        response = await self.send_command(HCI_Read_Local_Supported_Commands_Command())
        if response.return_parameters.status != HCI_SUCCESS:
            raise ProtocolError(response.return_parameters.status, 'hci')
        self.local_supported_commands = response.return_parameters.supported_commands

        await self.send_command(HCI_Set_Event_Mask_Command(event_mask = bytes.fromhex('FFFFFFFFFFFFFFFF')))
        await self.send_command(HCI_LE_Set_Event_Mask_Command(le_event_mask = bytes.fromhex('FFFFF00000000000')))
        await self.send_command(HCI_Read_Local_Version_Information_Command())
        await self.send_command(HCI_Write_LE_Host_Support_Command(le_supported_host = 1, simultaneous_le_host = 0))

        await self._read_acl_buffer_sizes()

        self.reset_done = True

    async def _read_acl_buffer_sizes(self):
        """Query the controller's ACL buffer sizes and update the flow-control limits."""
        response = await self.send_command(HCI_LE_Read_Buffer_Size_Command())
        status = self._command_status(response)
        if status == HCI_SUCCESS:
            self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
            self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets
            logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}')
        else:
            logger.warning(f'HCI_LE_Read_Buffer_Size_Command failed: {status}')
            # LE values unknown: let the non-LE values (or the defaults) apply
            self.hc_le_acl_data_packet_length = 0
            self.hc_total_num_le_acl_data_packets = 0

        if self.hc_le_acl_data_packet_length == 0 or self.hc_total_num_le_acl_data_packets == 0:
            # Read the non-LE-specific values. Note that this command's return
            # parameters use the non-LE field names.
            response = await self.send_command(HCI_Read_Buffer_Size_Command())
            status = self._command_status(response)
            if status == HCI_SUCCESS:
                self.hc_acl_data_packet_length        = response.return_parameters.hc_acl_data_packet_length
                self.hc_le_acl_data_packet_length     = self.hc_le_acl_data_packet_length or self.hc_acl_data_packet_length
                self.hc_total_num_acl_data_packets    = response.return_parameters.hc_total_num_acl_data_packets
                self.hc_total_num_le_acl_data_packets = self.hc_total_num_le_acl_data_packets or self.hc_total_num_acl_data_packets
                logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}')
            else:
                logger.warning(f'HCI_Read_Buffer_Size_Command failed: {status}')

        # Never keep a zero limit: send_l2cap_pdu would loop forever on a zero
        # packet length, and check_acl_packet_queue would never send anything
        # with a zero packet count.
        if not self.hc_le_acl_data_packet_length:
            self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
        if not self.hc_total_num_le_acl_data_packets:
            self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS

    @staticmethod
    def _command_status(response):
        """Return the status of a command response event, or None when it has none.

        send_command returns None when sending or waiting failed, so `response`
        may be None.
        """
        return_parameters = getattr(response, 'return_parameters', None)
        return getattr(return_parameters, 'status', None)

    @property
    def controller(self):
        return self.hci_sink

    @controller.setter
    def controller(self, controller):
        # Wire both directions: host -> controller and controller -> host
        self.set_packet_sink(controller)
        if controller:
            controller.set_packet_sink(self)

    def set_packet_sink(self, sink):
        self.hci_sink = sink

    def send_hci_packet(self, packet):
        """Serialize `packet` and hand it to the controller-side packet sink."""
        self.hci_sink.on_packet(packet.to_bytes())

    async def send_command(self, command):
        """Send an HCI command and wait for its Command Complete/Status event.

        Commands are serialized: only one may be outstanding at a time.

        Returns:
            The response event, or None if an exception occurred while sending
            or waiting. The exception is logged rather than propagated.
        """
        logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')

        # Wait until we can send (only one pending command at a time)
        async with self.command_semaphore:
            assert(self.pending_command is None)
            assert(self.pending_response is None)

            # Create a future value to hold the eventual response
            self.pending_response = asyncio.get_running_loop().create_future()
            self.pending_command = command

            try:
                self.send_hci_packet(command)
                response = await self.pending_response
                # TODO: check error values
                return response
            except Exception as error:
                # Deliberately best-effort: log and fall through to return None
                logger.warning(f'{color("!!! Exception while sending HCI packet:", "red")} {error}')
            finally:
                self.pending_command = None
                self.pending_response = None

    def _spawn(self, coroutine):
        """Run `coroutine` as a task, keeping a strong reference until it is done.

        The event loop only keeps weak references to tasks, so a task nobody
        holds on to can be garbage-collected before it completes.
        """
        task = asyncio.create_task(coroutine)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def send_command_sync(self, command):
        """Schedule `command` to be sent without awaiting its response.

        Intended for synchronous code (such as event handlers) running inside
        the event loop. Returns the scheduled task.
        """
        return self._spawn(self.send_command(command))

    def send_l2cap_pdu(self, connection_handle, cid, pdu):
        """Wrap `pdu` in an L2CAP frame and queue it as one or more ACL packets.

        The frame is split into fragments of at most
        `hc_le_acl_data_packet_length` bytes. The first fragment uses
        pb_flag 0 and the continuations use pb_flag 1.
        """
        l2cap_pdu = L2CAP_PDU(cid, pdu).to_bytes()

        # Send the data to the controller via ACL packets
        bytes_remaining = len(l2cap_pdu)
        offset = 0
        pb_flag = 0
        while bytes_remaining:
            data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
            acl_packet = HCI_AclDataPacket(
                connection_handle  = connection_handle,
                pb_flag            = pb_flag,
                bc_flag            = 0,
                data_total_length  = data_total_length,
                data               = l2cap_pdu[offset:offset + data_total_length]
            )
            logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}')
            self.queue_acl_packet(acl_packet)
            pb_flag = 1
            offset += data_total_length
            bytes_remaining -= data_total_length

    def queue_acl_packet(self, acl_packet):
        """Enqueue an ACL packet and send as many queued packets as flow control allows."""
        self.acl_packet_queue.appendleft(acl_packet)
        self.check_acl_packet_queue()

        if len(self.acl_packet_queue):
            logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue')

    def check_acl_packet_queue(self):
        """Send queued ACL packets, oldest first, while the in-flight limit allows."""
        # Send all we can
        while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets:
            packet = self.acl_packet_queue.pop()
            self.send_hci_packet(packet)
            self.acl_packets_in_flight += 1

    # Packet Sink protocol (packets coming from the controller via HCI)
    def on_packet(self, packet):
        """Parse raw bytes from the controller and process them.

        Until reset() has sent HCI_Reset, only the Command Complete for that
        reset is processed; everything else is dropped.
        """
        hci_packet = HCI_Packet.from_bytes(packet)
        if self.ready or (
            hci_packet.hci_packet_type == HCI_EVENT_PACKET and
            hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and
            hci_packet.command_opcode == HCI_RESET_COMMAND
        ):
            self.on_hci_packet(hci_packet)
        else:
            logger.debug('reset not done, ignoring packet from controller')

    def on_hci_packet(self, packet):
        """Dispatch a parsed HCI packet according to its packet type."""
        logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')

        # If the packet is a command, invoke the handler for this packet
        if packet.hci_packet_type == HCI_COMMAND_PACKET:
            self.on_hci_command_packet(packet)
        elif packet.hci_packet_type == HCI_EVENT_PACKET:
            self.on_hci_event_packet(packet)
        elif packet.hci_packet_type == HCI_ACL_DATA_PACKET:
            self.on_hci_acl_data_packet(packet)
        else:
            logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')

    def on_hci_command_packet(self, command):
        logger.warning(f'!!! unexpected command packet: {command}')

    def on_hci_event_packet(self, event):
        """Call the `on_<event name>` handler for `event`, or `on_hci_event` if none exists."""
        handler_name = f'on_{event.name.lower()}'
        handler = getattr(self, handler_name, self.on_hci_event)
        handler(event)

    def on_hci_acl_data_packet(self, packet):
        # Look for the connection to which this data belongs (unknown handles are dropped)
        if connection := self.connections.get(packet.connection_handle):
            connection.on_hci_acl_data_packet(packet)

    def on_gatt_pdu(self, connection, pdu):
        self.emit('gatt_pdu', connection.handle, pdu)

    def on_smp_pdu(self, connection, pdu):
        self.emit('smp_pdu', connection.handle, pdu)

    def on_l2cap_pdu(self, connection, cid, pdu):
        self.emit('l2cap_pdu', connection.handle, cid, pdu)

    def on_command_processed(self, event):
        """Resolve the pending command's future with its response event."""
        if self.pending_response is None:
            logger.warning('!!! no pending response future to set')
            return

        # Check that it is what we were expecting
        if self.pending_command.op_code != event.command_opcode:
            logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}')

        if self.pending_response.done():
            # The waiter was cancelled (or already answered) but has not
            # cleaned up yet; set_result() would raise InvalidStateError here.
            logger.warning('!!! pending response already completed')
            return

        self.pending_response.set_result(event)

    ############################################################
    # HCI handlers
    ############################################################
    def on_hci_event(self, event):
        logger.warning(f'{color(f"--- Ignoring event {event}", "red")}')

    def on_hci_command_complete_event(self, event):
        if event.command_opcode == 0:
            # This is used just for the Num_HCI_Command_Packets field, not related to an actual command
            logger.debug('no-command event')
        else:
            return self.on_command_processed(event)

    def on_hci_command_status_event(self, event):
        if event.command_opcode == 0:
            # As with Command Complete, opcode 0 only refreshes
            # Num_HCI_Command_Packets and answers no command
            logger.debug('no-command event')
        else:
            return self.on_command_processed(event)

    def on_hci_number_of_completed_packets_event(self, event):
        total_packets = sum(event.num_completed_packets)
        if total_packets <= self.acl_packets_in_flight:
            self.acl_packets_in_flight -= total_packets
        else:
            logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight'))
            self.acl_packets_in_flight = 0

        # Either way controller buffers were freed: send whatever is queued
        self.check_acl_packet_queue()

    # Classic only
    def on_hci_connection_request_event(self, event):
        # For now, just accept everything
        # TODO: delegate the decision
        self.send_command_sync(
            HCI_Accept_Connection_Request_Command(
                bd_addr = event.bd_addr,
                role    = 0x01  # Remain the peripheral
            )
        )

    def on_hci_le_connection_complete_event(self, event):
        # Check if this is a cancellation
        if event.status == HCI_SUCCESS:
            # Create/update the connection
            logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}')

            connection = self.connections.get(event.connection_handle)
            if connection is None:
                connection = Connection(self, event.connection_handle, event.role, event.peer_address)
                self.connections[event.connection_handle] = connection

            # Notify the client
            connection_parameters = ConnectionParameters(
                event.conn_interval,
                event.conn_latency,
                event.supervision_timeout
            )
            self.emit(
                'connection',
                event.connection_handle,
                BT_LE_TRANSPORT,
                event.peer_address,
                None,
                event.role,
                connection_parameters
            )
        else:
            logger.debug(f'### CONNECTION FAILED: {event.status}')

            # Notify the listeners
            self.emit('connection_failure', event.status)

    def on_hci_le_enhanced_connection_complete_event(self, event):
        # Just use the same implementation as for the non-enhanced event for now
        self.on_hci_le_connection_complete_event(event)

    def on_hci_connection_complete_event(self, event):
        if event.status == HCI_SUCCESS:
            # Create/update the connection
            logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}')

            connection = self.connections.get(event.connection_handle)
            if connection is None:
                connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr)
                self.connections[event.connection_handle] = connection

            # Notify the client
            self.emit(
                'connection',
                event.connection_handle,
                BT_BR_EDR_TRANSPORT,
                event.bd_addr,
                None,
                BT_CENTRAL_ROLE,
                None
            )
        else:
            logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')

            # Notify the client
            # NOTE(review): the LE path emits 'connection_failure' with only the
            # status, while this path also passes the handle; a listener can only
            # match one of the two signatures — confirm which one is intended.
            self.emit('connection_failure', event.connection_handle, event.status)

    def on_hci_disconnection_complete_event(self, event):
        # Find the connection
        if (connection := self.connections.get(event.connection_handle)) is None:
            logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
            return

        if event.status == HCI_SUCCESS:
            logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}')
            del self.connections[event.connection_handle]

            # Notify the listeners
            self.emit('disconnection', event.connection_handle, event.reason)
        else:
            logger.debug(f'### DISCONNECTION FAILED: {event.status}')

            # Notify the listeners
            self.emit('disconnection_failure', event.status)

    def on_hci_le_connection_update_complete_event(self, event):
        if (connection := self.connections.get(event.connection_handle)) is None:
            logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle')
            return

        # Notify the client
        if event.status == HCI_SUCCESS:
            connection_parameters = ConnectionParameters(
                event.conn_interval,
                event.conn_latency,
                event.supervision_timeout
            )
            self.emit('connection_parameters_update', connection.handle, connection_parameters)
        else:
            self.emit('connection_parameters_update_failure', connection.handle, event.status)

    def on_hci_le_phy_update_complete_event(self, event):
        if (connection := self.connections.get(event.connection_handle)) is None:
            logger.warning('!!! CONNECTION PHY UPDATE COMPLETE: unknown handle')
            return

        # Notify the client
        if event.status == HCI_SUCCESS:
            # NOTE(review): ConnectionPHY is not imported explicitly by this
            # module; it must come from one of the wildcard imports — confirm.
            connection_phy = ConnectionPHY(event.tx_phy, event.rx_phy)
            self.emit('connection_phy_update', connection.handle, connection_phy)
        else:
            self.emit('connection_phy_update_failure', connection.handle, event.status)

    def on_hci_le_advertising_report_event(self, event):
        for report in event.reports:
            self.emit(
                'advertising_report',
                report.address,
                report.data,
                report.rssi,
                report.event_type
            )

    def on_hci_le_remote_connection_parameter_request_event(self, event):
        if event.connection_handle not in self.connections:
            logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle')
            return

        # For now, just accept everything
        # TODO: delegate the decision
        self.send_command_sync(
            HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
                connection_handle = event.connection_handle,
                interval_min      = event.interval_min,
                interval_max      = event.interval_max,
                latency           = event.latency,
                timeout           = event.timeout,
                minimum_ce_length = 0,
                maximum_ce_length = 0
            )
        )

    def on_hci_le_long_term_key_request_event(self, event):
        if (connection := self.connections.get(event.connection_handle)) is None:
            logger.warning('!!! LE LONG TERM KEY REQUEST: unknown handle')
            return

        async def send_long_term_key():
            # Ask the provider (if any) for a key; reply negatively when there is none
            if self.long_term_key_provider is None:
                logger.debug('no long term key provider')
                long_term_key = None
            else:
                long_term_key = await self.long_term_key_provider(
                    connection.handle,
                    event.random_number,
                    event.encryption_diversifier
                )
            if long_term_key:
                response = HCI_LE_Long_Term_Key_Request_Reply_Command(
                    connection_handle = event.connection_handle,
                    long_term_key     = long_term_key
                )
            else:
                response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(
                    connection_handle = event.connection_handle
                )

            await self.send_command(response)

        self._spawn(send_long_term_key())

    def on_hci_synchronous_connection_complete_event(self, event):
        pass

    def on_hci_synchronous_connection_changed_event(self, event):
        pass

    def on_hci_role_change_event(self, event):
        if event.status == HCI_SUCCESS:
            logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}')
            # TODO: lookup the connection and update the role
        else:
            logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}')

    def on_hci_le_data_length_change_event(self, event):
        self.emit(
            'connection_data_length_change',
            event.connection_handle,
            event.max_tx_octets,
            event.max_tx_time,
            event.max_rx_octets,
            event.max_rx_time
        )

    def on_hci_authentication_complete_event(self, event):
        # Notify the client
        if event.status == HCI_SUCCESS:
            self.emit('connection_authentication', event.connection_handle)
        else:
            self.emit('connection_authentication_failure', event.connection_handle, event.status)

    def on_hci_encryption_change_event(self, event):
        # Notify the client
        if event.status == HCI_SUCCESS:
            self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled)
        else:
            self.emit('connection_encryption_failure', event.connection_handle, event.status)

    def on_hci_encryption_key_refresh_complete_event(self, event):
        # Notify the client
        if event.status == HCI_SUCCESS:
            self.emit('connection_encryption_key_refresh', event.connection_handle)
        else:
            self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status)

    def on_hci_link_supervision_timeout_changed_event(self, event):
        pass

    def on_hci_max_slots_change_event(self, event):
        pass

    def on_hci_page_scan_repetition_mode_change_event(self, event):
        pass

    def on_hci_link_key_notification_event(self, event):
        logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}')
        self.emit('link_key', event.bd_addr, event.link_key, event.key_type)

    def on_hci_simple_pairing_complete_event(self, event):
        logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}')

    def on_hci_pin_code_request_event(self, event):
        # For now, just refuse all requests
        # TODO: delegate the decision
        self.send_command_sync(
            HCI_PIN_Code_Request_Negative_Reply_Command(
                bd_addr = event.bd_addr
            )
        )

    def on_hci_link_key_request_event(self, event):
        async def send_link_key():
            # Ask the provider (if any) for a key; reply negatively when there is none
            if self.link_key_provider is None:
                logger.debug('no link key provider')
                link_key = None
            else:
                link_key = await self.link_key_provider(event.bd_addr)
            if link_key:
                response = HCI_Link_Key_Request_Reply_Command(
                    bd_addr  = event.bd_addr,
                    link_key = link_key
                )
            else:
                response = HCI_Link_Key_Request_Negative_Reply_Command(
                    bd_addr = event.bd_addr
                )

            await self.send_command(response)

        self._spawn(send_link_key())

    def on_hci_io_capability_request_event(self, event):
        self.emit('authentication_io_capability_request', event.bd_addr)

    def on_hci_io_capability_response_event(self, event):
        pass

    def on_hci_user_confirmation_request_event(self, event):
        self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value)

    def on_hci_user_passkey_request_event(self, event):
        self.emit('authentication_user_passkey_request', event.bd_addr)

    def on_hci_inquiry_complete_event(self, event):
        self.emit('inquiry_complete')

    def on_hci_inquiry_result_with_rssi_event(self, event):
        for response in event.responses:
            self.emit(
                'inquiry_result',
                response.bd_addr,
                response.class_of_device,
                b'',
                response.rssi
            )

    def on_hci_extended_inquiry_result_event(self, event):
        self.emit(
            'inquiry_result',
            event.bd_addr,
            event.class_of_device,
            event.extended_inquiry_response,
            event.rssi
        )

    def on_hci_remote_name_request_complete_event(self, event):
        if event.status != HCI_SUCCESS:
            self.emit('remote_name_failure', event.bd_addr, event.status)
        else:
            self.emit('remote_name', event.bd_addr, event.remote_name)
602