• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19
20import logging
21import asyncio
22import collections
23import dataclasses
24import enum
25from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
26from typing_extensions import Self
27
28from pyee import EventEmitter
29
30from bumble import core
31from bumble import l2cap
32from bumble import sdp
33from .colors import color
34from .core import (
35    UUID,
36    BT_RFCOMM_PROTOCOL_ID,
37    BT_BR_EDR_TRANSPORT,
38    BT_L2CAP_PROTOCOL_ID,
39    InvalidStateError,
40    ProtocolError,
41)
42
43if TYPE_CHECKING:
44    from bumble.device import Device, Connection
45
46# -----------------------------------------------------------------------------
47# Logging
48# -----------------------------------------------------------------------------
# Module-level logger for this RFCOMM implementation.
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off

# L2CAP PSM value used for RFCOMM.
RFCOMM_PSM = 0x0003
# Maximum number of received packets a DLC keeps queued while it has no sink;
# the queue is a bounded deque, so once full the oldest packets are discarded.
DEFAULT_RX_QUEUE_SIZE = 32
59
class FrameType(enum.IntEnum):
    """RFCOMM frame types, given as control-field values with the P/F bit clear.

    Bit 4 of the control field is the P/F (poll/final) bit. It is shown as ``_``
    in the bit lists below and is masked out (``& 0xEF``) when a received
    frame's type is decoded. Every bit list is written least-significant bit
    first, i.e. [b0, b1, b2, b3, P/F, b5, b6, b7].
    """
    SABM = 0x2F  # Control field [1,1,1,1,_,1,0,0] LSB-first
    UA   = 0x63  # Control field [1,1,0,0,_,1,1,0] LSB-first
    DM   = 0x0F  # Control field [1,1,1,1,_,0,0,0] LSB-first
    DISC = 0x43  # Control field [1,1,0,0,_,0,1,0] LSB-first
    UIH  = 0xEF  # Control field [1,1,1,1,_,1,1,1] LSB-first
    UI   = 0x03  # Control field [1,1,0,0,_,0,0,0] LSB-first
67
class MccType(enum.IntEnum):
    """Multiplexer control command types carried in UIH frames on DLCI 0.

    The value is the 6-bit type field, i.e. the command's first octet shifted
    right by 2 (bit 0 is the EA bit and bit 1 the C/R bit).
    """
    PN  = 0x20  # Parameter negotiation (decoded as RFCOMM_MCC_PN)
    MSC = 0x38  # Modem status command (decoded as RFCOMM_MCC_MSC)
71
72
# FCS CRC
# 256-entry CRC lookup table used by compute_fcs(): each step replaces the
# running value r with CRC_TABLE[r ^ next_byte].
CRC_TABLE = bytes([
    0X00, 0X91, 0XE3, 0X72, 0X07, 0X96, 0XE4, 0X75,
    0X0E, 0X9F, 0XED, 0X7C, 0X09, 0X98, 0XEA, 0X7B,
    0X1C, 0X8D, 0XFF, 0X6E, 0X1B, 0X8A, 0XF8, 0X69,
    0X12, 0X83, 0XF1, 0X60, 0X15, 0X84, 0XF6, 0X67,
    0X38, 0XA9, 0XDB, 0X4A, 0X3F, 0XAE, 0XDC, 0X4D,
    0X36, 0XA7, 0XD5, 0X44, 0X31, 0XA0, 0XD2, 0X43,
    0X24, 0XB5, 0XC7, 0X56, 0X23, 0XB2, 0XC0, 0X51,
    0X2A, 0XBB, 0XC9, 0X58, 0X2D, 0XBC, 0XCE, 0X5F,
    0X70, 0XE1, 0X93, 0X02, 0X77, 0XE6, 0X94, 0X05,
    0X7E, 0XEF, 0X9D, 0X0C, 0X79, 0XE8, 0X9A, 0X0B,
    0X6C, 0XFD, 0X8F, 0X1E, 0X6B, 0XFA, 0X88, 0X19,
    0X62, 0XF3, 0X81, 0X10, 0X65, 0XF4, 0X86, 0X17,
    0X48, 0XD9, 0XAB, 0X3A, 0X4F, 0XDE, 0XAC, 0X3D,
    0X46, 0XD7, 0XA5, 0X34, 0X41, 0XD0, 0XA2, 0X33,
    0X54, 0XC5, 0XB7, 0X26, 0X53, 0XC2, 0XB0, 0X21,
    0X5A, 0XCB, 0XB9, 0X28, 0X5D, 0XCC, 0XBE, 0X2F,
    0XE0, 0X71, 0X03, 0X92, 0XE7, 0X76, 0X04, 0X95,
    0XEE, 0X7F, 0X0D, 0X9C, 0XE9, 0X78, 0X0A, 0X9B,
    0XFC, 0X6D, 0X1F, 0X8E, 0XFB, 0X6A, 0X18, 0X89,
    0XF2, 0X63, 0X11, 0X80, 0XF5, 0X64, 0X16, 0X87,
    0XD8, 0X49, 0X3B, 0XAA, 0XDF, 0X4E, 0X3C, 0XAD,
    0XD6, 0X47, 0X35, 0XA4, 0XD1, 0X40, 0X32, 0XA3,
    0XC4, 0X55, 0X27, 0XB6, 0XC3, 0X52, 0X20, 0XB1,
    0XCA, 0X5B, 0X29, 0XB8, 0XCD, 0X5C, 0X2E, 0XBF,
    0X90, 0X01, 0X73, 0XE2, 0X97, 0X06, 0X74, 0XE5,
    0X9E, 0X0F, 0X7D, 0XEC, 0X99, 0X08, 0X7A, 0XEB,
    0X8C, 0X1D, 0X6F, 0XFE, 0X8B, 0X1A, 0X68, 0XF9,
    0X82, 0X13, 0X61, 0XF0, 0X85, 0X14, 0X66, 0XF7,
    0XA8, 0X39, 0X4B, 0XDA, 0XAF, 0X3E, 0X4C, 0XDD,
    0XA6, 0X37, 0X45, 0XD4, 0XA1, 0X30, 0X42, 0XD3,
    0XB4, 0X25, 0X57, 0XC6, 0XB3, 0X22, 0X50, 0XC1,
    0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])

# NOTE(review): presumably the MTU requested when opening the RFCOMM L2CAP
# channel; it is not referenced in this section — confirm at the use site.
RFCOMM_DEFAULT_L2CAP_MTU        = 2048
# Credits granted to the peer up front; the PN frame only carries 3 bits of
# credits, so 7 is the largest value that can be negotiated.
RFCOMM_DEFAULT_INITIAL_CREDITS  = 7
# Upper bound a DLC tops its receive credits back up to (DLC.rx_max_credits).
RFCOMM_DEFAULT_MAX_CREDITS      = 32
# A DLC sends more credits once its remaining rx credits fall to this value.
RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2
# NOTE(review): presumably the default max frame size offered in PN; not
# referenced in this section — confirm at the use site.
RFCOMM_DEFAULT_MAX_FRAME_SIZE   = 2000

# Inclusive range of server channel numbers; NOTE(review): presumably used
# when allocating channels dynamically (not visible here) — confirm.
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END   = 30

# fmt: on
119
120
121# -----------------------------------------------------------------------------
def make_service_sdp_records(
    service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[sdp.ServiceAttribute]:
    """
    Build the SDP attributes that advertise an RFCOMM service.

    Every record carries the service record handle, membership in the public
    browse group, and a protocol descriptor list of L2CAP followed by RFCOMM
    on ``channel``. A service class ID list containing ``uuid`` is appended
    only when a UUID is given.
    """
    handle_attribute = sdp.ServiceAttribute(
        sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
        sdp.DataElement.unsigned_integer_32(service_record_handle),
    )

    browse_root = sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)
    browse_attribute = sdp.ServiceAttribute(
        sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
        sdp.DataElement.sequence([browse_root]),
    )

    # [[L2CAP], [RFCOMM, channel]]
    l2cap_descriptor = sdp.DataElement.sequence(
        [sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
    )
    rfcomm_descriptor = sdp.DataElement.sequence(
        [
            sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
            sdp.DataElement.unsigned_integer_8(channel),
        ]
    )
    protocol_attribute = sdp.ServiceAttribute(
        sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
        sdp.DataElement.sequence([l2cap_descriptor, rfcomm_descriptor]),
    )

    attributes = [handle_attribute, browse_attribute, protocol_attribute]
    if uuid:
        class_list = sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)])
        attributes.append(
            sdp.ServiceAttribute(sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, class_list)
        )

    return attributes
167
168
169# -----------------------------------------------------------------------------
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
    """Collect every RFCOMM channel advertised over SDP, with its service classes.

    Args:
        connection: ACL connection on which the SDP search is made.

    Returns:
        Mapping from RFCOMM channel number to the list of service class UUIDs
        of the record advertising it. Records that lack a service class list
        or a (non-zero) channel are logged and skipped; if two records share a
        channel, the later one wins.
    """
    channel_map: Dict[int, List[UUID]] = {}
    async with sdp.Client(connection) as sdp_client:
        search_result = await sdp_client.search_attributes(
            uuids=[core.BT_RFCOMM_PROTOCOL_ID],
            attribute_ids=[
                sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
                sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
            ],
        )
        for attribute_lists in search_result:
            rfcomm_channel: Optional[int] = None
            class_ids: List[UUID] = []
            for attribute in attribute_lists:
                if attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
                    class_ids = [element.value for element in attribute.value.value]
                elif attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
                    # Layout: [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]]
                    rfcomm_descriptor = attribute.value.value[1]
                    rfcomm_channel = rfcomm_descriptor.value[1].value
            if class_ids and rfcomm_channel:
                channel_map[rfcomm_channel] = class_ids
            else:
                logger.warning(f"Bad result {attribute_lists}.")
    return channel_map
206
207
208# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
    connection: Connection, uuid: str | UUID
) -> Optional[int]:
    """Look up the RFCOMM channel whose SDP record lists the given service class.

    Args:
        connection: ACL connection on which the SDP search is made.
        uuid: Service class UUID to look for, as a UUID or its string form.

    Returns:
        The first matching RFCOMM channel number, or None when no record matches.
    """
    wanted = UUID(uuid) if isinstance(uuid, str) else uuid
    channel_map = await find_rfcomm_channels(connection)
    for channel, class_ids in channel_map.items():
        if wanted in class_ids:
            return channel
    return None
233
234
235# -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int:
    """Compute the RFCOMM frame check sequence of ``buffer``.

    Runs the table-driven CRC over every byte starting from 0xFF and returns
    the ones' complement of the final 8-bit value.
    """
    crc = 0xFF
    for octet in buffer:
        crc = CRC_TABLE[crc ^ octet]
    # For an 8-bit value, XOR with 0xFF is the same as 0xFF - crc.
    return crc ^ 0xFF
241
242
243# -----------------------------------------------------------------------------
class RFCOMM_Frame:
    """A single RFCOMM frame.

    Serialized by ``__bytes__`` as:
    address (1 byte) | control (1 byte) | length indicator (1 or 2 bytes)
    | information | FCS (1 byte).
    """

    def __init__(
        self,
        frame_type: FrameType,
        c_r: int,
        dlci: int,
        p_f: int,
        information: bytes = b'',
        with_credits: bool = False,
    ) -> None:
        """Build a frame and compute its length indicator and FCS.

        Args:
            frame_type: Frame type (control field with P/F clear).
            c_r: Command/response bit (0 or 1).
            dlci: Data link connection identifier.
            p_f: Poll/final bit (0 or 1).
            information: Information field. When ``with_credits`` is set, its
                first byte is the credits field.
            with_credits: The first information byte is a credits field, which
                the length indicator does not count.
        """
        self.type = frame_type
        self.c_r = c_r
        self.dlci = dlci
        self.p_f = p_f
        self.information = information
        length = len(information)
        if with_credits and length > 0:
            # The credits byte is not counted by the length indicator.
            length -= 1
        if length > 0x7F:
            # 2-byte length indicator (E/A bit clear in the first byte)
            self.length = bytes([(length & 0x7F) << 1, (length >> 7) & 0xFF])
        else:
            # 1-byte length indicator (E/A bit set)
            self.length = bytes([(length << 1) | 1])
        self.address = (dlci << 2) | (c_r << 1) | 1
        self.control = frame_type | (p_f << 4)
        if frame_type == FrameType.UIH:
            # The FCS of a UIH frame covers only the address and control fields.
            self.fcs = compute_fcs(bytes([self.address, self.control]))
        else:
            self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)

    @staticmethod
    def parse_mcc(data) -> Tuple[int, bool, bytes]:
        """Split a multiplexer control command into ``(type, c_r, value)``.

        ``data`` starts with the type octet (EA bit, C/R bit, 6-bit type),
        followed by a 1- or 2-byte length indicator and then the value. The
        returned value is exactly ``length`` bytes long; trailing bytes are
        ignored.
        """
        mcc_type = data[0] >> 2
        c_r = bool((data[0] >> 1) & 1)
        if data[1] & 1:
            # 1-byte length indicator
            length = data[1] >> 1
            value = data[2 : 2 + length]
        else:
            # 2-byte length indicator: low 7 bits in data[1], high bits in data[2]
            length = (data[2] << 7) | (data[1] >> 1)
            value = data[3 : 3 + length]

        return (mcc_type, c_r, value)

    @staticmethod
    def make_mcc(mcc_type: int, c_r: int, data: bytes) -> bytes:
        """Build a multiplexer control command (inverse of ``parse_mcc``).

        A 1-byte length indicator is used for values of up to 127 bytes, and a
        2-byte one for longer values.
        """
        type_octet = (mcc_type << 2 | c_r << 1 | 1) & 0xFF
        length = len(data)
        if length > 0x7F:
            length_indicator = bytes([(length & 0x7F) << 1, (length >> 7) & 0xFF])
        else:
            length_indicator = bytes([length << 1 | 1])
        return bytes([type_octet]) + length_indicator + data

    @staticmethod
    def sabm(c_r: int, dlci: int):
        """SABM frame (P/F set) for ``dlci``."""
        return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1)

    @staticmethod
    def ua(c_r: int, dlci: int):
        """UA frame (P/F set) for ``dlci``."""
        return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1)

    @staticmethod
    def dm(c_r: int, dlci: int):
        """DM frame (P/F set) for ``dlci``."""
        return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1)

    @staticmethod
    def disc(c_r: int, dlci: int):
        """DISC frame (P/F set) for ``dlci``."""
        return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1)

    @staticmethod
    def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
        """UIH frame; with ``p_f == 1`` the first information byte is a credits field."""
        return RFCOMM_Frame(
            FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1)
        )

    @staticmethod
    def from_bytes(data: bytes) -> RFCOMM_Frame:
        """Parse a received frame and verify its FCS.

        For UIH frames with P/F set, the returned ``information`` still starts
        with the credits byte.

        Raises:
            ValueError: the FCS does not match, or the control field is not a
                known FrameType.
            IndexError: ``data`` is too short to contain a frame header.
        """
        # Extract fields
        dlci = (data[0] >> 2) & 0x3F
        c_r = (data[0] >> 1) & 0x01
        frame_type = FrameType(data[1] & 0xEF)
        p_f = (data[1] >> 4) & 0x01
        if data[2] & 0x01:
            # 1-byte length indicator
            information = data[3:-1]
        else:
            # 2-byte length indicator
            information = data[4:-1]
        fcs = data[-1]

        # Construct the frame and check the CRC. A UIH frame with P/F set
        # carries a credits byte that its length indicator does not count.
        frame = RFCOMM_Frame(
            frame_type,
            c_r,
            dlci,
            p_f,
            information,
            with_credits=(frame_type == FrameType.UIH and p_f == 1),
        )
        if frame.fcs != fcs:
            logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
            raise ValueError('fcs mismatch')

        return frame

    def __bytes__(self) -> bytes:
        return (
            bytes([self.address, self.control])
            + self.length
            + self.information
            + bytes([self.fcs])
        )

    def __str__(self) -> str:
        return (
            f'{color(self.type.name, "yellow")}'
            f'(c/r={self.c_r},'
            f'dlci={self.dlci},'
            f'p/f={self.p_f},'
            f'length={len(self.information)},'
            f'fcs=0x{self.fcs:02X})'
        )
359
360
361# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_PN:
    """Value of a PN (parameter negotiation) multiplexer command: 8 bytes."""

    dlci: int
    cl: int
    priority: int
    ack_timer: int
    max_frame_size: int
    max_retransmissions: int
    initial_credits: int

    def __post_init__(self) -> None:
        # Out-of-range credits are only reported; the value is kept as given.
        if not 1 <= self.initial_credits <= 7:
            logger.warning(
                f'Initial credits {self.initial_credits} is out of range [1, 7].'
            )

    @staticmethod
    def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
        """Decode the 8-byte PN value (max frame size is little-endian)."""
        frame_size = data[5] << 8 | data[4]
        return RFCOMM_MCC_PN(
            data[0],
            data[1],
            data[2],
            data[3],
            frame_size,
            data[6],
            # Only the 3 low bits carry the credits.
            data[7] & 0x07,
        )

    def __bytes__(self) -> bytes:
        """Encode as 8 bytes, truncating each field to its wire width."""
        octets = (
            self.dlci,
            self.cl,
            self.priority,
            self.ack_timer,
            self.max_frame_size,
            self.max_frame_size >> 8,
            self.max_retransmissions,
        )
        # Only 3 bits of the credits byte are meaningful.
        return bytes(octet & 0xFF for octet in octets) + bytes(
            [self.initial_credits & 0x07]
        )
404
405
406# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_MSC:
    """Value of an MSC (modem status) multiplexer command: DLCI octet + signals octet."""

    dlci: int
    fc: int
    rtc: int
    rtr: int
    ic: int
    dv: int

    @staticmethod
    def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
        """Decode the first two bytes of an MSC value."""
        address, signals = data[0], data[1]
        return RFCOMM_MCC_MSC(
            dlci=address >> 2,
            fc=(signals >> 1) & 1,
            rtc=(signals >> 2) & 1,
            rtr=(signals >> 3) & 1,
            ic=(signals >> 6) & 1,
            dv=(signals >> 7) & 1,
        )

    def __bytes__(self) -> bytes:
        """Encode as two bytes: DLCI octet, then the signals octet."""
        # Signals octet: EA (bit 0) always set; FC, RTC, RTR at bits 1-3 and
        # IC, DV at bits 6-7.
        signals = 1
        for shift, flag in (
            (1, self.fc),
            (2, self.rtc),
            (3, self.rtr),
            (6, self.ic),
            (7, self.dv),
        ):
            signals |= flag << shift
        # DLCI octet: DLCI in bits 2-7 with bits 0 and 1 set.
        return bytes([(self.dlci << 2) | 3, signals])
439
440
441# -----------------------------------------------------------------------------
class DLC(EventEmitter):
    """An RFCOMM data link connection (DLC) multiplexed over a ``Multiplexer``.

    Events:
        'open': a peer-initiated connection completed (SABM received after
            ``accept()``).
        'close': a disconnection we initiated was acknowledged by the peer,
            or the DLC was aborted.

    Received payloads are delivered to ``sink``. While no sink is set they are
    queued in a deque bounded to ``DEFAULT_RX_QUEUE_SIZE`` entries, so the
    oldest packets are dropped once it is full. Outgoing data accumulates in
    ``tx_buffer`` and is sent by ``process_tx`` under credit-based flow
    control.
    """

    class State(enum.IntEnum):
        INIT = 0x00
        CONNECTING = 0x01
        CONNECTED = 0x02
        DISCONNECTING = 0x03
        DISCONNECTED = 0x04
        RESET = 0x05

    def __init__(
        self,
        multiplexer: Multiplexer,
        dlci: int,
        tx_max_frame_size: int,
        tx_initial_credits: int,
        rx_max_frame_size: int,
        rx_initial_credits: int,
    ) -> None:
        """Create a DLC in the INIT state.

        Args:
            multiplexer: Multiplexer whose L2CAP channel carries this DLC.
            dlci: Data link connection identifier.
            tx_max_frame_size: Largest frame we may send; caps ``mtu``.
            tx_initial_credits: Credits the peer granted us up front.
            rx_max_frame_size: Largest frame we accept (advertised by ``accept``).
            rx_initial_credits: Credits we grant the peer up front.
        """
        super().__init__()
        self.multiplexer = multiplexer
        self.dlci = dlci
        self.rx_max_frame_size = rx_max_frame_size
        self.rx_initial_credits = rx_initial_credits
        self.rx_max_credits = RFCOMM_DEFAULT_MAX_CREDITS
        self.rx_credits = rx_initial_credits
        self.rx_credits_threshold = RFCOMM_DEFAULT_CREDIT_THRESHOLD
        self.tx_max_frame_size = tx_max_frame_size
        self.tx_credits = tx_initial_credits
        self.tx_buffer = b''
        self.state = DLC.State.INIT
        self.role = multiplexer.role
        # C/R bit used for the command frames this side sends.
        self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
        self.connection_result: Optional[asyncio.Future] = None
        self.disconnection_result: Optional[asyncio.Future] = None
        # Set whenever tx_buffer has been completely sent.
        self.drained = asyncio.Event()
        self.drained.set()
        # Queued packets when sink is not set (oldest dropped once full).
        self._enqueued_rx_packets: collections.deque[bytes] = collections.deque(
            maxlen=DEFAULT_RX_QUEUE_SIZE
        )
        self._sink: Optional[Callable[[bytes], None]] = None

        # Compute the MTU: the largest information field we send, bounded by
        # the negotiated frame size and by what fits in one L2CAP PDU.
        max_overhead = 4 + 1  # header with 2-byte length + fcs
        self.mtu = min(
            tx_max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
        )

    @property
    def sink(self) -> Optional[Callable[[bytes], None]]:
        """Callback receiving each inbound payload, or None to queue them."""
        return self._sink

    @sink.setter
    def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
        self._sink = sink
        # Dump queued packets to sink
        if sink:
            for packet in self._enqueued_rx_packets:
                sink(packet)  # pylint: disable=not-callable
            self._enqueued_rx_packets.clear()

    def change_state(self, new_state: State) -> None:
        """Record a state transition, logging it at debug level."""
        logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}')
        self.state = new_state

    def send_frame(self, frame: RFCOMM_Frame) -> None:
        """Send ``frame`` through the multiplexer."""
        self.multiplexer.send_frame(frame)

    def on_frame(self, frame: RFCOMM_Frame) -> None:
        """Dispatch ``frame`` to the matching ``on_<type>_frame`` handler."""
        handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
        handler(frame)

    def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
        """Complete a peer-initiated connection (only valid after ``accept()``).

        Replies with UA, sends our modem status on DLCI 0, moves to CONNECTED
        and emits 'open'.
        """
        if self.state != DLC.State.CONNECTING:
            logger.warning(
                color('!!! received SABM when not in CONNECTING state', 'red')
            )
            return

        self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))

        # Exchange the modem status with the peer
        msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
        mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
        logger.debug(f'>>> MCC MSC Command: {msc}')
        self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))

        self.change_state(DLC.State.CONNECTED)
        self.emit('open')

    def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
        """Handle the peer's acknowledgement of our SABM or DISC.

        In CONNECTING, sends our modem status, moves to CONNECTED and resolves
        ``connection_result``. In DISCONNECTING, moves to DISCONNECTED,
        resolves ``disconnection_result`` and emits 'close'.
        """
        if self.state == DLC.State.CONNECTING:
            # Exchange the modem status with the peer
            msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
            mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
            logger.debug(f'>>> MCC MSC Command: {msc}')
            self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))

            self.change_state(DLC.State.CONNECTED)
            if self.connection_result:
                self.connection_result.set_result(None)
                self.connection_result = None
            self.multiplexer.on_dlc_open_complete(self)
        elif self.state == DLC.State.DISCONNECTING:
            self.change_state(DLC.State.DISCONNECTED)
            if self.disconnection_result:
                self.disconnection_result.set_result(None)
                self.disconnection_result = None
            self.multiplexer.on_dlc_disconnection(self)
            self.emit('close')
        else:
            logger.warning(
                color(
                    (
                        '!!! received UA frame when not in '
                        'CONNECTING or DISCONNECTING state'
                    ),
                    'red',
                )
            )

    def on_dm_frame(self, frame: RFCOMM_Frame) -> None:
        # TODO: handle all states
        pass

    def on_disc_frame(self, _frame: RFCOMM_Frame) -> None:
        """Acknowledge a peer DISC with UA (state is not updated yet)."""
        # TODO: handle all states
        self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))

    def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
        """Handle an inbound UIH frame.

        Absorbs any credits the peer granted (P/F set), delivers the payload to
        the sink or the queue, consumes one rx credit per data-carrying frame,
        and then lets ``process_tx`` send pending data and/or credits.
        """
        data = frame.information
        if frame.p_f == 1:
            # With credits: the first information byte is the credit count.
            if not data:
                logger.warning(
                    color(
                        '!!! received UIH frame with P/F set but no credit field',
                        'red',
                    )
                )
                return
            received_credits = data[0]
            self.tx_credits += received_credits

            logger.debug(
                f'<<< Credits [{self.dlci}]: '
                f'received {received_credits}, total={self.tx_credits}'
            )
            data = data[1:]

        logger.debug(
            f'{color("<<< Data", "yellow")} '
            f'[{self.dlci}] {len(data)} bytes, '
            f'rx_credits={self.rx_credits}: {data.hex()}'
        )
        if data:
            if self._sink:
                self._sink(data)  # pylint: disable=not-callable
            else:
                self._enqueued_rx_packets.append(data)
            if (
                self._enqueued_rx_packets.maxlen
                and len(self._enqueued_rx_packets) >= self._enqueued_rx_packets.maxlen
            ):
                logger.warning(f'DLC [{self.dlci}] received packet queue is full')

            # Update the credits
            if self.rx_credits > 0:
                self.rx_credits -= 1
            else:
                logger.warning(color('!!! received frame with no rx credits', 'red'))

        # Check if there's anything to send (including credits)
        self.process_tx()

    def on_ui_frame(self, frame: RFCOMM_Frame) -> None:
        pass

    def on_mcc_msc(self, c_r: bool, msc: RFCOMM_MCC_MSC) -> None:
        """Handle an MSC for this DLC: answer a command with our own modem status.

        NOTE(review): presumably invoked by the multiplexer's MSC dispatch,
        which is outside this section — confirm.
        """
        if c_r:
            # Command
            logger.debug(f'<<< MCC MSC Command: {msc}')
            msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
            mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=0, data=bytes(msc))
            logger.debug(f'>>> MCC MSC Response: {msc}')
            self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
        else:
            # Response
            logger.debug(f'<<< MCC MSC Response: {msc}')

    def connect(self) -> None:
        """Start an outgoing connection by sending SABM.

        ``connection_result`` is resolved when the peer's UA arrives.

        Raises:
            InvalidStateError: the DLC is not in the INIT state.
        """
        if self.state != DLC.State.INIT:
            raise InvalidStateError('invalid state')

        self.change_state(DLC.State.CONNECTING)
        self.connection_result = asyncio.get_running_loop().create_future()
        self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))

    async def disconnect(self) -> None:
        """Send DISC and wait until the peer acknowledges it with UA.

        Raises:
            InvalidStateError: the DLC is not CONNECTED.
        """
        if self.state != DLC.State.CONNECTED:
            raise InvalidStateError('invalid state')

        self.disconnection_result = asyncio.get_running_loop().create_future()
        self.change_state(DLC.State.DISCONNECTING)
        self.send_frame(
            RFCOMM_Frame.disc(
                c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0, dlci=self.dlci
            )
        )
        await self.disconnection_result

    def accept(self) -> None:
        """Answer the peer's PN command with our rx parameters, then await its SABM.

        Raises:
            InvalidStateError: the DLC is not in the INIT state.
        """
        if self.state != DLC.State.INIT:
            raise InvalidStateError('invalid state')

        pn = RFCOMM_MCC_PN(
            dlci=self.dlci,
            cl=0xE0,
            priority=7,
            ack_timer=0,
            max_frame_size=self.rx_max_frame_size,
            max_retransmissions=0,
            initial_credits=self.rx_initial_credits,
        )
        mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn))
        logger.debug(f'>>> PN Response: {pn}')
        self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
        self.change_state(DLC.State.CONNECTING)

    def rx_credits_needed(self) -> int:
        """Credits to grant the peer now.

        Once ``rx_credits`` has dropped to ``rx_credits_threshold`` or below,
        returns enough to top it back up to ``rx_max_credits``; otherwise 0.
        """
        if self.rx_credits <= self.rx_credits_threshold:
            return self.rx_max_credits - self.rx_credits

        return 0

    def process_tx(self) -> None:
        """Send buffered data the peer has credits for, piggybacking rx credits.

        Each frame carrying data consumes one tx credit; data is never sent
        while ``tx_credits`` is 0. When rx credits need replenishing, the first
        frame carries a credits field and is sent even without data (a
        credits-only frame, which consumes no tx credit).
        """
        # Send anything we can (or an empty frame if we need to send rx credits)
        rx_credits_needed = self.rx_credits_needed()
        while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
            # Get the next chunk, up to MTU size
            if rx_credits_needed > 0:
                # Only attach data if the peer has granted us a credit for it.
                if self.tx_credits > 0:
                    payload = self.tx_buffer[: self.mtu - 1]
                else:
                    payload = b''
                chunk = bytes([rx_credits_needed]) + payload
                self.tx_buffer = self.tx_buffer[len(payload) :]
                self.rx_credits += rx_credits_needed
                tx_credit_spent = len(payload) > 0
            else:
                chunk = self.tx_buffer[: self.mtu]
                self.tx_buffer = self.tx_buffer[len(chunk) :]
                tx_credit_spent = True

            # Update the tx credits
            # (no tx credit spent for empty frames that only contain rx credits)
            if tx_credit_spent:
                self.tx_credits -= 1

            # Send the frame
            logger.debug(
                f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, '
                f'rx_credits={self.rx_credits}, '
                f'tx_credits={self.tx_credits}'
            )
            self.send_frame(
                RFCOMM_Frame.uih(
                    c_r=self.c_r,
                    dlci=self.dlci,
                    information=chunk,
                    p_f=1 if rx_credits_needed > 0 else 0,
                )
            )

            rx_credits_needed = 0
            if not self.tx_buffer:
                self.drained.set()

    # Stream protocol
    def write(self, data: Union[bytes, str]) -> None:
        """Queue ``data`` for sending (strings are UTF-8 encoded) and try to send.

        Raises:
            ValueError: ``data`` is neither bytes nor str.
        """
        # We can only send bytes
        if not isinstance(data, bytes):
            if isinstance(data, str):
                # Automatically convert strings to bytes using UTF-8
                data = data.encode('utf-8')
            else:
                raise ValueError('write only accept bytes or strings')

        self.tx_buffer += data
        self.drained.clear()
        self.process_tx()

    async def drain(self) -> None:
        """Wait until everything written so far has been sent."""
        await self.drained.wait()

    def abort(self) -> None:
        """Cancel pending connect/disconnect futures, move to RESET, emit 'close'."""
        logger.debug(f'aborting DLC: {self}')
        if self.connection_result:
            self.connection_result.cancel()
            self.connection_result = None
        if self.disconnection_result:
            self.disconnection_result.cancel()
            self.disconnection_result = None
        self.change_state(DLC.State.RESET)
        self.emit('close')

    def __str__(self) -> str:
        return f'DLC(dlci={self.dlci},state={self.state.name})'
738
739
740# -----------------------------------------------------------------------------
741class Multiplexer(EventEmitter):
    class Role(enum.IntEnum):
        """Local role on this multiplexer session.

        DLCs take their C/R bit from this role: 1 for INITIATOR, 0 for
        RESPONDER (responses such as UA use the complement).
        """

        INITIATOR = 0x00
        RESPONDER = 0x01

    class State(enum.IntEnum):
        """State of the multiplexer (DLCI 0) session."""

        INIT = 0x00
        CONNECTING = 0x01
        CONNECTED = 0x02
        # An outgoing DLC open is pending; a DM received now fails open_result.
        OPENING = 0x03
        DISCONNECTING = 0x04
        DISCONNECTED = 0x05
        RESET = 0x06

    # Resolved (with 0) when the peer's UA acknowledges our DLCI-0 SABM.
    connection_result: Optional[asyncio.Future]
    # Resolved (with None) when the peer's UA acknowledges our DLCI-0 DISC.
    disconnection_result: Optional[asyncio.Future]
    # Pending DLC open; failed with a CONNECTION_REFUSED ConnectionError on DM.
    open_result: Optional[asyncio.Future]
    # Called with the channel number of an incoming PN command; when it returns
    # a tuple, item 0 becomes the new DLC's rx_max_frame_size.
    # NOTE(review): item 1 is presumably the rx initial credits, and a None
    # return presumably rejects the channel — confirm in on_mcc_pn.
    acceptor: Optional[Callable[[int], Optional[Tuple[int, int]]]]
    # DLCs, by DLCI.
    dlcs: Dict[int, DLC]
760
761    def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
762        super().__init__()
763        self.role = role
764        self.l2cap_channel = l2cap_channel
765        self.state = Multiplexer.State.INIT
766        self.dlcs = {}  # DLCs, by DLCI
767        self.connection_result = None
768        self.disconnection_result = None
769        self.open_result = None
770        self.open_pn: Optional[RFCOMM_MCC_PN] = None
771        self.open_rx_max_credits = 0
772        self.acceptor = None
773
774        # Become a sink for the L2CAP channel
775        l2cap_channel.sink = self.on_pdu
776
777        l2cap_channel.on('close', self.on_l2cap_channel_close)
778
779    def change_state(self, new_state: State) -> None:
780        logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
781        self.state = new_state
782
783    def send_frame(self, frame: RFCOMM_Frame) -> None:
784        logger.debug(f'>>> Multiplexer sending {frame}')
785        self.l2cap_channel.send_pdu(frame)
786
787    def on_pdu(self, pdu: bytes) -> None:
788        frame = RFCOMM_Frame.from_bytes(pdu)
789        logger.debug(f'<<< Multiplexer received {frame}')
790
791        # Dispatch to this multiplexer or to a dlc, depending on the address
792        if frame.dlci == 0:
793            self.on_frame(frame)
794        else:
795            if frame.type == FrameType.DM:
796                # DM responses are for a DLCI, but since we only create the dlc when we
797                # receive a PN response (because we need the parameters), we handle DM
798                # frames at the Multiplexer level
799                self.on_dm_frame(frame)
800            else:
801                dlc = self.dlcs.get(frame.dlci)
802                if dlc is None:
803                    logger.warning(f'no dlc for DLCI {frame.dlci}')
804                    return
805                dlc.on_frame(frame)
806
807    def on_frame(self, frame: RFCOMM_Frame) -> None:
808        handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
809        handler(frame)
810
811    def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
812        if self.state != Multiplexer.State.INIT:
813            logger.debug('not in INIT state, ignoring SABM')
814            return
815        self.change_state(Multiplexer.State.CONNECTED)
816        self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
817
818    def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
819        if self.state == Multiplexer.State.CONNECTING:
820            self.change_state(Multiplexer.State.CONNECTED)
821            if self.connection_result:
822                self.connection_result.set_result(0)
823                self.connection_result = None
824        elif self.state == Multiplexer.State.DISCONNECTING:
825            self.change_state(Multiplexer.State.DISCONNECTED)
826            if self.disconnection_result:
827                self.disconnection_result.set_result(None)
828                self.disconnection_result = None
829
830    def on_dm_frame(self, _frame: RFCOMM_Frame) -> None:
831        if self.state == Multiplexer.State.OPENING:
832            self.change_state(Multiplexer.State.CONNECTED)
833            if self.open_result:
834                self.open_result.set_exception(
835                    core.ConnectionError(
836                        core.ConnectionError.CONNECTION_REFUSED,
837                        BT_BR_EDR_TRANSPORT,
838                        self.l2cap_channel.connection.peer_address,
839                        'rfcomm',
840                    )
841                )
842                self.open_result = None
843        else:
844            logger.warning(f'unexpected state for DM: {self}')
845
846    def on_disc_frame(self, _frame: RFCOMM_Frame) -> None:
847        self.change_state(Multiplexer.State.DISCONNECTED)
848        self.send_frame(
849            RFCOMM_Frame.ua(
850                c_r=0 if self.role == Multiplexer.Role.INITIATOR else 1, dlci=0
851            )
852        )
853
854    def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
855        (mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
856
857        if mcc_type == MccType.PN:
858            pn = RFCOMM_MCC_PN.from_bytes(value)
859            self.on_mcc_pn(c_r, pn)
860        elif mcc_type == MccType.MSC:
861            mcs = RFCOMM_MCC_MSC.from_bytes(value)
862            self.on_mcc_msc(c_r, mcs)
863
864    def on_ui_frame(self, frame: RFCOMM_Frame) -> None:
865        pass
866
867    def on_mcc_pn(self, c_r: bool, pn: RFCOMM_MCC_PN) -> None:
868        if c_r:
869            # Command
870            logger.debug(f'<<< PN Command: {pn}')
871
872            # Check with the multiplexer if there's an acceptor for this channel
873            if pn.dlci & 1:
874                # Not expected, this is an initiator-side number
875                # TODO: error out
876                logger.warning(f'invalid DLCI: {pn.dlci}')
877            else:
878                if self.acceptor:
879                    channel_number = pn.dlci >> 1
880                    if dlc_params := self.acceptor(channel_number):
881                        # Create a new DLC
882                        dlc = DLC(
883                            self,
884                            dlci=pn.dlci,
885                            tx_max_frame_size=pn.max_frame_size,
886                            tx_initial_credits=pn.initial_credits,
887                            rx_max_frame_size=dlc_params[0],
888                            rx_initial_credits=dlc_params[1],
889                        )
890                        self.dlcs[pn.dlci] = dlc
891
892                        # Re-emit the handshake completion event
893                        dlc.on('open', lambda: self.emit('dlc', dlc))
894
895                        # Respond to complete the handshake
896                        dlc.accept()
897                    else:
898                        # No acceptor, we're in Disconnected Mode
899                        self.send_frame(RFCOMM_Frame.dm(c_r=1, dlci=pn.dlci))
900                else:
901                    # No acceptor?? shouldn't happen
902                    logger.warning(color('!!! no acceptor registered', 'red'))
903        else:
904            # Response
905            logger.debug(f'>>> PN Response: {pn}')
906            if self.state == Multiplexer.State.OPENING:
907                assert self.open_pn
908                dlc = DLC(
909                    self,
910                    dlci=pn.dlci,
911                    tx_max_frame_size=pn.max_frame_size,
912                    tx_initial_credits=pn.initial_credits,
913                    rx_max_frame_size=self.open_pn.max_frame_size,
914                    rx_initial_credits=self.open_pn.initial_credits,
915                )
916                self.dlcs[pn.dlci] = dlc
917                self.open_pn = None
918                dlc.connect()
919            else:
920                logger.warning('ignoring PN response')
921
922    def on_mcc_msc(self, c_r: bool, msc: RFCOMM_MCC_MSC) -> None:
923        dlc = self.dlcs.get(msc.dlci)
924        if dlc is None:
925            logger.warning(f'no dlc for DLCI {msc.dlci}')
926            return
927        dlc.on_mcc_msc(c_r, msc)
928
929    async def connect(self) -> None:
930        if self.state != Multiplexer.State.INIT:
931            raise InvalidStateError('invalid state')
932
933        self.change_state(Multiplexer.State.CONNECTING)
934        self.connection_result = asyncio.get_running_loop().create_future()
935        self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
936        return await self.connection_result
937
938    async def disconnect(self) -> None:
939        if self.state != Multiplexer.State.CONNECTED:
940            return
941
942        self.disconnection_result = asyncio.get_running_loop().create_future()
943        self.change_state(Multiplexer.State.DISCONNECTING)
944        self.send_frame(
945            RFCOMM_Frame.disc(
946                c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0, dlci=0
947            )
948        )
949        await self.disconnection_result
950
951    async def open_dlc(
952        self,
953        channel: int,
954        max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
955        initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
956    ) -> DLC:
957        if self.state != Multiplexer.State.CONNECTED:
958            if self.state == Multiplexer.State.OPENING:
959                raise InvalidStateError('open already in progress')
960
961            raise InvalidStateError('not connected')
962
963        self.open_pn = RFCOMM_MCC_PN(
964            dlci=channel << 1,
965            cl=0xF0,
966            priority=7,
967            ack_timer=0,
968            max_frame_size=max_frame_size,
969            max_retransmissions=0,
970            initial_credits=initial_credits,
971        )
972        mcc = RFCOMM_Frame.make_mcc(
973            mcc_type=MccType.PN, c_r=1, data=bytes(self.open_pn)
974        )
975        logger.debug(f'>>> Sending MCC: {self.open_pn}')
976        self.open_result = asyncio.get_running_loop().create_future()
977        self.change_state(Multiplexer.State.OPENING)
978        self.send_frame(
979            RFCOMM_Frame.uih(
980                c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0,
981                dlci=0,
982                information=mcc,
983            )
984        )
985        return await self.open_result
986
987    def on_dlc_open_complete(self, dlc: DLC) -> None:
988        logger.debug(f'DLC [{dlc.dlci}] open complete')
989
990        self.change_state(Multiplexer.State.CONNECTED)
991
992        if self.open_result:
993            self.open_result.set_result(dlc)
994            self.open_result = None
995
996    def on_dlc_disconnection(self, dlc: DLC) -> None:
997        logger.debug(f'DLC [{dlc.dlci}] disconnection')
998        self.dlcs.pop(dlc.dlci, None)
999
1000    def on_l2cap_channel_close(self) -> None:
1001        logger.debug('L2CAP channel closed, cleaning up')
1002        if self.open_result:
1003            self.open_result.cancel()
1004            self.open_result = None
1005        if self.disconnection_result:
1006            self.disconnection_result.cancel()
1007            self.disconnection_result = None
1008        for dlc in self.dlcs.values():
1009            dlc.abort()
1010
1011    def __str__(self) -> str:
1012        return f'Multiplexer(state={self.state.name})'
1013
1014
1015# -----------------------------------------------------------------------------
class Client:
    """RFCOMM client for one ACL connection.

    start() opens an L2CAP channel on the RFCOMM PSM and connects an
    initiator-role Multiplexer over it; shutdown() undoes both. Used as an
    async context manager, it yields the connected Multiplexer and shuts down
    on exit.
    """

    multiplexer: Optional[Multiplexer]
    l2cap_channel: Optional[l2cap.ClassicChannel]

    def __init__(
        self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
    ) -> None:
        self.connection = connection
        self.l2cap_mtu = l2cap_mtu
        # Both are populated by start() and cleared by shutdown()
        self.l2cap_channel = None
        self.multiplexer = None

    async def start(self) -> Multiplexer:
        """Open the RFCOMM L2CAP channel and connect a multiplexer over it.

        Returns:
            The connected initiator-role Multiplexer.

        Raises:
            ProtocolError: if the L2CAP channel cannot be created (it is
                logged, then re-raised).
        """
        channel_spec = l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu)
        try:
            self.l2cap_channel = await self.connection.create_l2cap_channel(
                spec=channel_spec
            )
        except ProtocolError as error:
            logger.warning(f'L2CAP connection failed: {error}')
            raise

        assert self.l2cap_channel is not None
        # This side opened the channel, so it runs the initiator role
        self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.Role.INITIATOR)
        await self.multiplexer.connect()
        return self.multiplexer

    async def shutdown(self) -> None:
        """Disconnect the multiplexer, then close the L2CAP channel.

        Does nothing when there is no multiplexer (start() not run, or
        shutdown() already done).
        """
        active = self.multiplexer
        if active is None:
            return
        await active.disconnect()
        self.multiplexer = None

        channel = self.l2cap_channel
        if channel:
            await channel.disconnect()
            self.l2cap_channel = None

    async def __aenter__(self) -> Multiplexer:
        return await self.start()

    async def __aexit__(self, *args) -> None:
        await self.shutdown()
1064
1065
1066# -----------------------------------------------------------------------------
class Server(EventEmitter):
    """RFCOMM server for a device.

    Registers an L2CAP server on the RFCOMM PSM. Each opened L2CAP channel
    gets a responder-role Multiplexer, announced with a ``'start'`` event.
    Each DLC that opens on a channel registered with listen() is handed to
    that channel's acceptor. Used as a context manager, it closes the L2CAP
    server on exit.
    """

    def __init__(
        self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
    ) -> None:
        super().__init__()
        self.device = device
        # Keyed by server channel number: the callback receiving opened DLCs,
        # and the (max_frame_size, initial_credits) RX parameters for it.
        self.acceptors: Dict[int, Callable[[DLC], None]] = {}
        self.dlc_configs: Dict[int, Tuple[int, int]] = {}

        # Register ourselves with the L2CAP channel manager
        rfcomm_spec = l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu)
        self.l2cap_server = device.create_l2cap_server(
            spec=rfcomm_spec, handler=self.on_connection
        )

    def listen(
        self,
        acceptor: Callable[[DLC], None],
        channel: int = 0,
        max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
        initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
    ) -> int:
        """Register ``acceptor`` for a server channel.

        Args:
            acceptor: called with each DLC that opens on the channel.
            channel: channel number to use, or 0 to take the first free one
                in the dynamic channel range.
            max_frame_size: RX max frame size offered for DLCs on the channel.
            initial_credits: RX initial credits offered for DLCs on the channel.

        Returns:
            The registered channel number, or 0 if the requested channel is
            already taken or no dynamic channel is free.
        """
        if not channel:
            # Find a free channel number
            free_channels = (
                candidate
                for candidate in range(
                    RFCOMM_DYNAMIC_CHANNEL_NUMBER_START,
                    RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1,
                )
                if candidate not in self.acceptors
            )
            channel = next(free_channels, 0)
            if channel == 0:
                # All channels used...
                return 0
        elif channel in self.acceptors:
            # Busy
            return 0

        self.acceptors[channel] = acceptor
        self.dlc_configs[channel] = (max_frame_size, initial_credits)
        return channel

    def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Wait for a newly created L2CAP channel to open before using it."""
        logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')

        def handle_open():
            return self.on_l2cap_channel_open(l2cap_channel)

        l2cap_channel.on('open', handle_open)

    def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Attach a responder multiplexer to an opened channel and emit 'start'."""
        logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')

        responder = Multiplexer(l2cap_channel, Multiplexer.Role.RESPONDER)
        responder.acceptor = self.accept_dlc
        responder.on('dlc', self.on_dlc)

        # Notify
        self.emit('start', responder)

    def accept_dlc(self, channel_number: int) -> Optional[Tuple[int, int]]:
        """Return the RX parameters registered for a channel, or None if unserved."""
        registered = self.dlc_configs.get(channel_number)
        return registered

    def on_dlc(self, dlc: DLC) -> None:
        """Pass a newly opened DLC to the acceptor registered for its channel."""
        logger.debug(f'@@@ new DLC connected: {dlc}')

        # Let the acceptor know
        acceptor = self.acceptors.get(dlc.dlci >> 1)
        if acceptor:
            acceptor(dlc)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args) -> None:
        self.l2cap_server.close()
1142