• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19from dataclasses import dataclass
20import logging
21import enum
22import struct
23
24from abc import ABC, abstractmethod
25from pyee import EventEmitter
26from typing import Optional, Callable, TYPE_CHECKING
27from typing_extensions import override
28
29from bumble import l2cap, device
30from bumble.colors import color
31from bumble.core import InvalidStateError, ProtocolError
32from .hci import Address
33
34
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# L2CAP PSMs used for the HID control channel and the HID interrupt channel.
# Both are registered as L2CAP servers and used for outgoing connections.
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013
47
48
class Message:
    """Base class for HIDP messages.

    Each concrete message class sets ``message_type``. :meth:`header` builds
    the one-byte transaction header from it.
    """

    class ReportType(enum.IntEnum):
        """Report type codes."""

        OTHER_REPORT = 0x00
        INPUT_REPORT = 0x01
        OUTPUT_REPORT = 0x02
        FEATURE_REPORT = 0x03

    class Handshake(enum.IntEnum):
        """Result codes carried in the low nibble of a HANDSHAKE header."""

        SUCCESSFUL = 0x00
        NOT_READY = 0x01
        ERR_INVALID_REPORT_ID = 0x02
        ERR_UNSUPPORTED_REQUEST = 0x03
        ERR_INVALID_PARAMETER = 0x04
        ERR_UNKNOWN = 0x0E
        ERR_FATAL = 0x0F

    class MessageType(enum.IntEnum):
        """Transaction types, stored in the high nibble of the header byte."""

        HANDSHAKE = 0x00
        CONTROL = 0x01
        GET_REPORT = 0x04
        SET_REPORT = 0x05
        GET_PROTOCOL = 0x06
        SET_PROTOCOL = 0x07
        DATA = 0x0A

    class ProtocolMode(enum.IntEnum):
        """Protocol modes used by GET/SET_PROTOCOL."""

        BOOT_PROTOCOL = 0x00
        REPORT_PROTOCOL = 0x01

    class ControlCommand(enum.IntEnum):
        """Operations carried in the low nibble of a CONTROL header."""

        SUSPEND = 0x03
        EXIT_SUSPEND = 0x04
        VIRTUAL_CABLE_UNPLUG = 0x05

    # Provided by every concrete message class.
    message_type: MessageType

    @classmethod
    def header(cls, lower_bits: int = 0x00) -> bytes:
        """Return the single header byte for this message class.

        ``message_type`` is shifted into the high nibble and ``lower_bits`` is
        OR-ed in unchanged. It is not masked, so callers must keep it within
        the low nibble.
        """
        header_byte = (cls.message_type << 4) | lower_bits
        return bytes((header_byte,))
94
95
96# HIDP messages
@dataclass
class GetReportMessage(Message):
    """GET_REPORT request.

    Layout: header (``report_type`` in the low bits, plus 0x08 when a buffer
    size is given), the report ID byte, then a little-endian 16-bit
    ``buffer_size`` when it is non-zero.
    """

    report_type: int
    report_id: int
    buffer_size: int
    message_type = Message.MessageType.GET_REPORT

    def __bytes__(self) -> bytes:
        id_byte = bytearray()
        id_byte.append(self.report_id)
        if not self.buffer_size:
            return self.header(self.report_type) + id_byte
        size_field = struct.pack("<H", self.buffer_size)
        return self.header(0x08 | self.report_type) + id_byte + size_field
115
116
@dataclass
class SetReportMessage(Message):
    """SET_REPORT request: header (``report_type`` in the low bits) + ``data``."""

    report_type: int
    data: bytes
    message_type = Message.MessageType.SET_REPORT

    def __bytes__(self) -> bytes:
        prefix = self.header(self.report_type)
        return prefix + self.data
125
126
@dataclass
class SendControlData(Message):
    """DATA message sent on the control channel: header + ``data``."""

    report_type: int
    data: bytes
    message_type = Message.MessageType.DATA

    def __bytes__(self) -> bytes:
        prefix = self.header(lower_bits=self.report_type)
        return prefix + self.data
135
136
@dataclass
class GetProtocolMessage(Message):
    """GET_PROTOCOL request: just the header byte, with no parameter bits."""

    message_type = Message.MessageType.GET_PROTOCOL

    def __bytes__(self) -> bytes:
        return self.header(lower_bits=0x00)
143
144
@dataclass
class SetProtocolMessage(Message):
    """SET_PROTOCOL request: header with ``protocol_mode`` in the low bits."""

    protocol_mode: int
    message_type = Message.MessageType.SET_PROTOCOL

    def __bytes__(self) -> bytes:
        return self.header(lower_bits=self.protocol_mode)
152
153
@dataclass
class Suspend(Message):
    """CONTROL message carrying the SUSPEND operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        operation = Message.ControlCommand.SUSPEND
        return self.header(operation)
160
161
@dataclass
class ExitSuspend(Message):
    """CONTROL message carrying the EXIT_SUSPEND operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        operation = Message.ControlCommand.EXIT_SUSPEND
        return self.header(operation)
168
169
@dataclass
class VirtualCableUnplug(Message):
    """CONTROL message carrying the VIRTUAL_CABLE_UNPLUG operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        operation = Message.ControlCommand.VIRTUAL_CABLE_UNPLUG
        return self.header(operation)
176
177
@dataclass
class SendData(Message):
    """DATA message: header (``report_type`` in the low bits) + ``data``.

    The device sends input reports and the host sends output reports.
    """

    data: bytes
    report_type: int
    message_type = Message.MessageType.DATA

    def __bytes__(self) -> bytes:
        prefix = self.header(lower_bits=self.report_type)
        return prefix + self.data
187
188
@dataclass
class SendHandshakeMessage(Message):
    """HANDSHAKE message: header with ``result_code`` in the low bits."""

    result_code: int
    message_type = Message.MessageType.HANDSHAKE

    def __bytes__(self) -> bytes:
        return self.header(lower_bits=self.result_code)
196
197
198# -----------------------------------------------------------------------------
class HID(ABC, EventEmitter):
    """Shared plumbing for the HID Host and HID Device roles.

    Registers L2CAP servers on the HID control and interrupt PSMs and tracks
    the current ACL connection. Each open L2CAP channel's sink is wired to
    :meth:`on_ctrl_pdu` (role specific) or :meth:`on_intr_pdu`.

    Events emitted here:
        interrupt_data(pdu): a PDU was received on the interrupt channel.
    """

    l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
    l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
    connection: Optional[device.Connection] = None

    class Role(enum.IntEnum):
        HOST = 0x00
        DEVICE = 0x01

    def __init__(self, device: device.Device, role: Role) -> None:
        super().__init__()
        self.remote_device_bd_address: Optional[Address] = None
        self.device = device
        self.role = role

        # Register ourselves with the L2CAP channel manager so that incoming
        # channels on either HID PSM are reported to on_l2cap_connection.
        device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
        device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)

        device.on('connection', self.on_device_connection)

    async def _connect_l2cap_channel(self, psm: int) -> l2cap.ClassicChannel:
        """Open an outgoing L2CAP channel on ``psm`` over the current connection.

        Raises:
            InvalidStateError: there is no current ACL connection.
            ProtocolError: the L2CAP connection failed (logged, then re-raised).
        """
        if self.connection is None:
            raise InvalidStateError('not connected')
        try:
            return await self.device.l2cap_channel_manager.connect(
                self.connection, psm
            )
        except ProtocolError:
            # Log through the module logger, not the root logger.
            logger.exception('L2CAP connection failed.')
            raise

    async def connect_control_channel(self) -> None:
        """Connect the HID control channel and route its PDUs to on_ctrl_pdu."""
        channel = await self._connect_l2cap_channel(HID_CONTROL_PSM)
        self.l2cap_ctrl_channel = channel
        # Become a sink for the L2CAP channel
        channel.sink = self.on_ctrl_pdu

    async def connect_interrupt_channel(self) -> None:
        """Connect the HID interrupt channel and route its PDUs to on_intr_pdu."""
        channel = await self._connect_l2cap_channel(HID_INTERRUPT_PSM)
        self.l2cap_intr_channel = channel
        # Become a sink for the L2CAP channel
        channel.sink = self.on_intr_pdu

    async def disconnect_interrupt_channel(self) -> None:
        """Disconnect the interrupt channel.

        Raises:
            InvalidStateError: the interrupt channel is not open.
        """
        if self.l2cap_intr_channel is None:
            raise InvalidStateError('invalid state')
        channel = self.l2cap_intr_channel
        # Clear the reference before awaiting so no new sends use the channel.
        self.l2cap_intr_channel = None
        await channel.disconnect()

    async def disconnect_control_channel(self) -> None:
        """Disconnect the control channel.

        Raises:
            InvalidStateError: the control channel is not open.
        """
        if self.l2cap_ctrl_channel is None:
            raise InvalidStateError('invalid state')
        channel = self.l2cap_ctrl_channel
        # Clear the reference before awaiting so no new sends use the channel.
        self.l2cap_ctrl_channel = None
        await channel.disconnect()

    def on_device_connection(self, connection: device.Connection) -> None:
        """Remember the latest ACL connection and its peer address."""
        self.connection = connection
        self.remote_device_bd_address = connection.peer_address
        connection.on('disconnection', self.on_device_disconnection)

    def on_device_disconnection(self, reason: int) -> None:
        """Forget the ACL connection (``reason`` is not used)."""
        self.connection = None

    def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Watch an incoming L2CAP channel for its open and close events."""
        logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
        l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
        l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))

    def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Adopt a newly opened channel as the control or interrupt channel."""
        # Only the two HID PSMs are registered, so anything that is not the
        # control PSM is the interrupt channel.
        if l2cap_channel.psm == HID_CONTROL_PSM:
            self.l2cap_ctrl_channel = l2cap_channel
            self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
        else:
            self.l2cap_intr_channel = l2cap_channel
            self.l2cap_intr_channel.sink = self.on_intr_pdu
        logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')

    def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Drop the reference to a closed control or interrupt channel."""
        if l2cap_channel.psm == HID_CONTROL_PSM:
            self.l2cap_ctrl_channel = None
        else:
            self.l2cap_intr_channel = None
        logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')

    @abstractmethod
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Handle a PDU received on the control channel (role specific)."""

    def on_intr_pdu(self, pdu: bytes) -> None:
        """Emit ``interrupt_data`` for a PDU received on the interrupt channel."""
        logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
        self.emit("interrupt_data", pdu)

    def send_pdu_on_ctrl(self, msg: bytes) -> None:
        """Send a raw PDU on the control channel.

        Raises:
            InvalidStateError: the control channel is not open.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if self.l2cap_ctrl_channel is None:
            raise InvalidStateError('control channel not open')
        self.l2cap_ctrl_channel.send_pdu(msg)

    def send_pdu_on_intr(self, msg: bytes) -> None:
        """Send a raw PDU on the interrupt channel.

        Raises:
            InvalidStateError: the interrupt channel is not open.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if self.l2cap_intr_channel is None:
            raise InvalidStateError('interrupt channel not open')
        self.l2cap_intr_channel.send_pdu(msg)

    def send_data(self, data: bytes) -> None:
        """Send ``data`` as a DATA report on the interrupt channel.

        The host sends output reports and the device sends input reports.
        If the interrupt channel is not open, the data is dropped and a
        warning is logged.
        """
        if self.role == HID.Role.HOST:
            report_type = Message.ReportType.OUTPUT_REPORT
        else:
            report_type = Message.ReportType.INPUT_REPORT
        hid_message = bytes(SendData(data, report_type))
        if self.l2cap_intr_channel is None:
            logger.warning('HID interrupt channel not open, dropping data')
            return
        logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
        self.send_pdu_on_intr(hid_message)

    def virtual_cable_unplug(self) -> None:
        """Send a VIRTUAL_CABLE_UNPLUG control message."""
        msg = VirtualCableUnplug()
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)
323
324
325# -----------------------------------------------------------------------------
326
327
class Device(HID):
    """HID Device role.

    Answers GET_REPORT, SET_REPORT, GET_PROTOCOL and SET_PROTOCOL requests
    from the host through application-registered callbacks. When no callback
    is registered it replies with an ERR_UNSUPPORTED_REQUEST handshake.

    Emits ``control_data``, ``suspend``, ``exit_suspend`` and
    ``virtual_cable_unplug`` events for the corresponding control PDUs.
    """

    class GetSetReturn(enum.IntEnum):
        """Status codes the application callbacks report back."""

        FAILURE = 0x00
        REPORT_ID_NOT_FOUND = 0x01
        ERR_UNSUPPORTED_REQUEST = 0x02
        ERR_UNKNOWN = 0x03
        ERR_INVALID_PARAMETER = 0x04
        SUCCESS = 0xFF

    class GetSetStatus:
        """Callback result: a ``GetSetReturn`` status plus optional payload."""

        def __init__(self) -> None:
            # Payload returned to the host for successful GET requests.
            self.data = bytearray()
            # A Device.GetSetReturn value; starts at 0 (FAILURE).
            self.status = 0

    def __init__(self, device: device.Device) -> None:
        super().__init__(device, HID.Role.DEVICE)
        # Application callbacks. They must be instance attributes: the
        # handle_* methods read them and answer ERR_UNSUPPORTED_REQUEST while
        # they are still None.
        self.get_report_cb: Optional[
            Callable[[int, int, int], Device.GetSetStatus]
        ] = None
        self.set_report_cb: Optional[
            Callable[[int, int, int, bytes], Device.GetSetStatus]
        ] = None
        self.get_protocol_cb: Optional[Callable[[], Device.GetSetStatus]] = None
        self.set_protocol_cb: Optional[Callable[[int], Device.GetSetStatus]] = None

    @override
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Dispatch a control-channel PDU by the message type in its header."""
        logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
        if not pdu:
            # No header byte to decode.
            logger.debug('<<< HID CONTROL PDU EMPTY, ignoring')
            return
        param = pdu[0] & 0x0F
        message_type = pdu[0] >> 4

        if message_type == Message.MessageType.GET_REPORT:
            logger.debug('<<< HID GET REPORT')
            self.handle_get_report(pdu)
        elif message_type == Message.MessageType.SET_REPORT:
            logger.debug('<<< HID SET REPORT')
            self.handle_set_report(pdu)
        elif message_type == Message.MessageType.GET_PROTOCOL:
            logger.debug('<<< HID GET PROTOCOL')
            self.handle_get_protocol(pdu)
        elif message_type == Message.MessageType.SET_PROTOCOL:
            logger.debug('<<< HID SET PROTOCOL')
            self.handle_set_protocol(pdu)
        elif message_type == Message.MessageType.DATA:
            logger.debug('<<< HID CONTROL DATA')
            self.emit('control_data', pdu)
        elif message_type == Message.MessageType.CONTROL:
            if param == Message.ControlCommand.SUSPEND:
                logger.debug('<<< HID SUSPEND')
                self.emit('suspend')
            elif param == Message.ControlCommand.EXIT_SUSPEND:
                logger.debug('<<< HID EXIT SUSPEND')
                self.emit('exit_suspend')
            elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
                logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
                self.emit('virtual_cable_unplug')
            else:
                logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
        else:
            logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def send_handshake_message(self, result_code: int) -> None:
        """Send a HANDSHAKE carrying ``result_code`` on the control channel."""
        msg = SendHandshakeMessage(result_code)
        hid_message = bytes(msg)
        logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def send_control_data(self, report_type: int, data: bytes) -> None:
        """Send a DATA message with ``report_type`` on the control channel."""
        msg = SendControlData(report_type=report_type, data=data)
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def handle_get_report(self, pdu: bytes) -> None:
        """Answer a GET_REPORT request through ``get_report_cb``.

        The callback is called as ``cb(report_id, report_type, buffer_size)``.
        ``buffer_size`` is 0 when the request's size flag (0x08) is clear.
        Every outcome results in either a DATA reply or a HANDSHAKE.
        """
        if self.get_report_cb is None:
            logger.debug("GetReport callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        report_type = pdu[0] & 0x03
        buffer_flag = (pdu[0] & 0x08) >> 3
        # Header + report ID, plus a 2-byte buffer size when the flag is set.
        min_length = 4 if buffer_flag == 1 else 2
        if len(pdu) < min_length:
            logger.debug(f"GetReport PDU too short: {pdu.hex()}")
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
            return
        report_id = pdu[1]
        logger.debug(f"buffer_flag: {buffer_flag}")
        if buffer_flag == 1:
            # Little-endian 16-bit buffer size.
            buffer_size = (pdu[3] << 8) | pdu[2]
        else:
            buffer_size = 0

        ret = self.get_report_cb(report_id, report_type, buffer_size)
        assert ret is not None
        if ret.status == self.GetSetReturn.SUCCESS:
            data = bytearray()
            data.append(report_id)
            data.extend(ret.data)
            # The header byte is added on send, so data must stay below the MTU.
            if len(data) < self.l2cap_ctrl_channel.peer_mtu:  # type: ignore[union-attr]
                self.send_control_data(report_type=report_type, data=data)
            else:
                self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
        elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
        else:
            # FAILURE, ERR_UNKNOWN or an unrecognized status: still reply
            # rather than leaving the request unanswered.
            self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)

    def register_get_report_cb(
        self, cb: Callable[[int, int, int], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(report_id, report_type, buffer_size) -> GetSetStatus``."""
        self.get_report_cb = cb
        logger.debug("GetReport callback registered successfully")

    def handle_set_report(self, pdu: bytes) -> None:
        """Answer a SET_REPORT request through ``set_report_cb``.

        The callback is called as
        ``cb(report_id, report_type, report_size, report_data)``, where
        ``report_size`` counts the report ID byte plus ``report_data``.
        """
        if self.set_report_cb is None:
            logger.debug("SetReport callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        if len(pdu) < 2:
            # No report ID byte after the header.
            logger.debug(f"SetReport PDU too short: {pdu.hex()}")
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
            return
        report_type = pdu[0] & 0x03
        report_id = pdu[1]
        report_data = pdu[2:]
        report_size = len(report_data) + 1
        ret = self.set_report_cb(report_id, report_type, report_size, report_data)
        assert ret is not None
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_handshake_message(Message.Handshake.SUCCESSFUL)
        elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
        elif ret.status == self.GetSetReturn.ERR_UNKNOWN:
            self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_set_report_cb(
        self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(report_id, report_type, report_size, data) -> GetSetStatus``."""
        self.set_report_cb = cb
        logger.debug("SetReport callback registered successfully")

    def handle_get_protocol(self, pdu: bytes) -> None:
        """Answer GET_PROTOCOL with ``ret.data`` from ``get_protocol_cb``."""
        if self.get_protocol_cb is None:
            logger.debug("GetProtocol callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        ret = self.get_protocol_cb()
        assert ret is not None
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
        """Register ``cb() -> GetSetStatus`` for GET_PROTOCOL."""
        self.get_protocol_cb = cb
        logger.debug("GetProtocol callback registered successfully")

    def handle_set_protocol(self, pdu: bytes) -> None:
        """Answer SET_PROTOCOL; the requested mode is bit 0 of the header."""
        if self.set_protocol_cb is None:
            logger.debug("SetProtocol callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        ret = self.set_protocol_cb(pdu[0] & 0x01)
        assert ret is not None
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_handshake_message(Message.Handshake.SUCCESSFUL)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_set_protocol_cb(
        self, cb: Callable[[int], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(protocol_mode) -> GetSetStatus`` for SET_PROTOCOL."""
        self.set_protocol_cb = cb
        logger.debug("SetProtocol callback registered successfully")
492
493
494# -----------------------------------------------------------------------------
class Host(HID):
    """HID Host role.

    Sends control requests to a HID device. Replies are surfaced as
    ``handshake``, ``control_data`` and ``virtual_cable_unplug`` events.
    """

    def __init__(self, device: device.Device) -> None:
        super().__init__(device, HID.Role.HOST)

    def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
        """Send GET_REPORT; a ``buffer_size`` of 0 omits the size field."""
        msg = GetReportMessage(
            report_type=report_type, report_id=report_id, buffer_size=buffer_size
        )
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def set_report(self, report_type: int, data: bytes) -> None:
        """Send SET_REPORT with ``data`` appended after the header."""
        msg = SetReportMessage(report_type=report_type, data=data)
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def get_protocol(self) -> None:
        """Send GET_PROTOCOL."""
        msg = GetProtocolMessage()
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def set_protocol(self, protocol_mode: int) -> None:
        """Send SET_PROTOCOL requesting ``protocol_mode``."""
        msg = SetProtocolMessage(protocol_mode=protocol_mode)
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def suspend(self) -> None:
        """Send a SUSPEND control message."""
        msg = Suspend()
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def exit_suspend(self) -> None:
        """Send an EXIT_SUSPEND control message."""
        msg = ExitSuspend()
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    @override
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Dispatch a control-channel PDU received from the device."""
        logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
        if not pdu:
            # No header byte to decode.
            logger.debug('<<< HID CONTROL PDU EMPTY, ignoring')
            return
        param = pdu[0] & 0x0F
        message_type = pdu[0] >> 4
        if message_type == Message.MessageType.HANDSHAKE:
            try:
                handshake = Message.Handshake(param)
            except ValueError:
                # Reserved result code: the enum lookup would raise and
                # abort PDU handling, so log the code and drop the PDU.
                logger.warning(f'<<< HID HANDSHAKE: unknown result code {param:#x}')
                return
            logger.debug(f'<<< HID HANDSHAKE: {handshake.name}')
            self.emit('handshake', handshake)
        elif message_type == Message.MessageType.DATA:
            logger.debug('<<< HID CONTROL DATA')
            self.emit('control_data', pdu)
        elif message_type == Message.MessageType.CONTROL:
            if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
                logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
                self.emit('virtual_cable_unplug')
            else:
                logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
        else:
            logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
556