# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
import logging
import enum
import struct

from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING
from typing_extensions import override

from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
from .hci import Address


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: on
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013


class Message:
    """Base class for HIDP messages.

    Every message starts with a one-byte header: the high nibble is the
    message type (``message_type``), the low nibble is a type-specific
    parameter. Subclasses set ``message_type`` and implement ``__bytes__``.
    """

    message_type: MessageType

    # Report types
    class ReportType(enum.IntEnum):
        OTHER_REPORT = 0x00
        INPUT_REPORT = 0x01
        OUTPUT_REPORT = 0x02
        FEATURE_REPORT = 0x03

    # Handshake parameters
    class Handshake(enum.IntEnum):
        SUCCESSFUL = 0x00
        NOT_READY = 0x01
        ERR_INVALID_REPORT_ID = 0x02
        ERR_UNSUPPORTED_REQUEST = 0x03
        ERR_INVALID_PARAMETER = 0x04
        ERR_UNKNOWN = 0x0E
        ERR_FATAL = 0x0F

    # Message Type
    class MessageType(enum.IntEnum):
        HANDSHAKE = 0x00
        CONTROL = 0x01
        GET_REPORT = 0x04
        SET_REPORT = 0x05
        GET_PROTOCOL = 0x06
        SET_PROTOCOL = 0x07
        DATA = 0x0A

    # Protocol modes
    class ProtocolMode(enum.IntEnum):
        BOOT_PROTOCOL = 0x00
        REPORT_PROTOCOL = 0x01

    # Control Operations
    class ControlCommand(enum.IntEnum):
        SUSPEND = 0x03
        EXIT_SUSPEND = 0x04
        VIRTUAL_CABLE_UNPLUG = 0x05

    # Class Method to derive header
    @classmethod
    def header(cls, lower_bits: int = 0x00) -> bytes:
        """Return the one-byte header for this message class.

        Args:
            lower_bits: value for the low nibble (report type, handshake
                result code, control operation, protocol mode, ...).
        """
        return bytes([(cls.message_type << 4) | lower_bits])


# HIDP messages
@dataclass
class GetReportMessage(Message):
    """GET_REPORT request: header, report ID, optional 16-bit buffer size."""

    report_type: int
    report_id: int
    buffer_size: int
    message_type = Message.MessageType.GET_REPORT

    def __bytes__(self) -> bytes:
        payload = bytes([self.report_id])
        if self.buffer_size == 0:
            return self.header(self.report_type) + payload
        # 0x08 in the header's low nibble flags that a little-endian 16-bit
        # BufferSize follows the report ID.
        return (
            self.header(0x08 | self.report_type)
            + payload
            + struct.pack("<H", self.buffer_size)
        )


@dataclass
class SetReportMessage(Message):
    """SET_REPORT request: header (with report type) followed by ``data``."""

    report_type: int
    data: bytes
    message_type = Message.MessageType.SET_REPORT

    def __bytes__(self) -> bytes:
        return self.header(self.report_type) + self.data


@dataclass
class SendControlData(Message):
    """DATA message sent on the control channel (e.g. a GET_REPORT reply)."""

    report_type: int
    data: bytes
    message_type = Message.MessageType.DATA

    def __bytes__(self) -> bytes:
        return self.header(self.report_type) + self.data


@dataclass
class GetProtocolMessage(Message):
    """GET_PROTOCOL request (header only)."""

    message_type = Message.MessageType.GET_PROTOCOL

    def __bytes__(self) -> bytes:
        return self.header()


@dataclass
class SetProtocolMessage(Message):
    """SET_PROTOCOL request; the protocol mode goes in the header's low nibble."""

    protocol_mode: int
    message_type = Message.MessageType.SET_PROTOCOL

    def __bytes__(self) -> bytes:
        return self.header(self.protocol_mode)


@dataclass
class Suspend(Message):
    """CONTROL message carrying the SUSPEND operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        return self.header(Message.ControlCommand.SUSPEND)


@dataclass
class ExitSuspend(Message):
    """CONTROL message carrying the EXIT_SUSPEND operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        return self.header(Message.ControlCommand.EXIT_SUSPEND)


@dataclass
class VirtualCableUnplug(Message):
    """CONTROL message carrying the VIRTUAL_CABLE_UNPLUG operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)


# Device sends input report, host sends output report.
@dataclass
class SendData(Message):
    """DATA message sent on the interrupt channel."""

    data: bytes
    report_type: int
    message_type = Message.MessageType.DATA

    def __bytes__(self) -> bytes:
        return self.header(self.report_type) + self.data


@dataclass
class SendHandshakeMessage(Message):
    """HANDSHAKE message; the result code goes in the header's low nibble."""

    result_code: int
    message_type = Message.MessageType.HANDSHAKE

    def __bytes__(self) -> bytes:
        return self.header(self.result_code)


# -----------------------------------------------------------------------------
class HID(ABC, EventEmitter):
    """Common base for the HID Host and HID Device roles.

    Registers L2CAP servers on the HID control and interrupt PSMs, tracks the
    current ACL connection and the two L2CAP channels, and routes incoming
    PDUs to ``on_ctrl_pdu`` (abstract) and ``on_intr_pdu``.

    Events emitted here: ``interrupt_data`` (raw interrupt-channel PDU).
    """

    l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
    l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
    connection: Optional[device.Connection] = None

    class Role(enum.IntEnum):
        HOST = 0x00
        DEVICE = 0x01

    def __init__(self, device: device.Device, role: Role) -> None:
        super().__init__()
        self.remote_device_bd_address: Optional[Address] = None
        self.device = device
        self.role = role

        # Register ourselves with the L2CAP channel manager
        device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
        device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)

        device.on('connection', self.on_device_connection)

    async def _connect_l2cap_channel(self, psm: int) -> l2cap.ClassicChannel:
        """Open an outgoing L2CAP channel to the connected peer on ``psm``.

        Raises:
            InvalidStateError: if there is no ACL connection to use.
            ProtocolError: if the L2CAP connection fails (logged, re-raised).
        """
        if self.connection is None:
            raise InvalidStateError('not connected')
        try:
            return await self.device.l2cap_channel_manager.connect(
                self.connection, psm
            )
        except ProtocolError:
            logger.exception('L2CAP connection failed.')
            raise

    async def connect_control_channel(self) -> None:
        """Open the HID control channel and route its PDUs to ``on_ctrl_pdu``."""
        channel = await self._connect_l2cap_channel(HID_CONTROL_PSM)
        self.l2cap_ctrl_channel = channel
        # Become a sink for the L2CAP channel
        channel.sink = self.on_ctrl_pdu

    async def connect_interrupt_channel(self) -> None:
        """Open the HID interrupt channel and route its PDUs to ``on_intr_pdu``."""
        channel = await self._connect_l2cap_channel(HID_INTERRUPT_PSM)
        self.l2cap_intr_channel = channel
        # Become a sink for the L2CAP channel
        channel.sink = self.on_intr_pdu

    async def disconnect_interrupt_channel(self) -> None:
        """Close the interrupt channel.

        Raises:
            InvalidStateError: if the channel is not open.
        """
        if self.l2cap_intr_channel is None:
            raise InvalidStateError('invalid state')
        channel = self.l2cap_intr_channel
        # Clear first so the attribute is already None while disconnecting.
        self.l2cap_intr_channel = None
        await channel.disconnect()

    async def disconnect_control_channel(self) -> None:
        """Close the control channel.

        Raises:
            InvalidStateError: if the channel is not open.
        """
        if self.l2cap_ctrl_channel is None:
            raise InvalidStateError('invalid state')
        channel = self.l2cap_ctrl_channel
        # Clear first so the attribute is already None while disconnecting.
        self.l2cap_ctrl_channel = None
        await channel.disconnect()

    def on_device_connection(self, connection: device.Connection) -> None:
        """Remember the new ACL connection and the peer's address."""
        self.connection = connection
        self.remote_device_bd_address = connection.peer_address
        connection.on('disconnection', self.on_device_disconnection)

    def on_device_disconnection(self, reason: int) -> None:
        """Forget the ACL connection (the peer address is kept)."""
        self.connection = None

    def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Hook open/close handlers onto an incoming L2CAP channel."""
        logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
        l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
        l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))

    def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Store an opened channel as control or interrupt channel, by PSM."""
        if l2cap_channel.psm == HID_CONTROL_PSM:
            self.l2cap_ctrl_channel = l2cap_channel
            self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
        else:
            self.l2cap_intr_channel = l2cap_channel
            self.l2cap_intr_channel.sink = self.on_intr_pdu
        logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')

    def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Forget a closed channel, by PSM."""
        if l2cap_channel.psm == HID_CONTROL_PSM:
            self.l2cap_ctrl_channel = None
        else:
            self.l2cap_intr_channel = None
        logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')

    @abstractmethod
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Handle a PDU received on the control channel (role-specific)."""

    def on_intr_pdu(self, pdu: bytes) -> None:
        """Re-emit an interrupt-channel PDU as the ``interrupt_data`` event."""
        logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
        self.emit("interrupt_data", pdu)

    def send_pdu_on_ctrl(self, msg: bytes) -> None:
        """Send raw bytes on the control channel (which must be open)."""
        assert self.l2cap_ctrl_channel
        self.l2cap_ctrl_channel.send_pdu(msg)

    def send_pdu_on_intr(self, msg: bytes) -> None:
        """Send raw bytes on the interrupt channel (which must be open)."""
        assert self.l2cap_intr_channel
        self.l2cap_intr_channel.send_pdu(msg)

    def send_data(self, data: bytes) -> None:
        """Send ``data`` as a DATA message on the interrupt channel.

        The report type is OUTPUT in the host role and INPUT in the device
        role. If the interrupt channel is not open, the data is dropped and a
        warning is logged.
        """
        if self.role == HID.Role.HOST:
            report_type = Message.ReportType.OUTPUT_REPORT
        else:
            report_type = Message.ReportType.INPUT_REPORT
        hid_message = bytes(SendData(data, report_type))
        if self.l2cap_intr_channel is None:
            logger.warning('HID interrupt channel not open, dropping data')
            return
        logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
        self.send_pdu_on_intr(hid_message)

    def virtual_cable_unplug(self) -> None:
        """Send a VIRTUAL_CABLE_UNPLUG control message."""
        hid_message = bytes(VirtualCableUnplug())
        logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)
# -----------------------------------------------------------------------------


class Device(HID):
    """HID Device role.

    Answers GET_REPORT, SET_REPORT, GET_PROTOCOL and SET_PROTOCOL requests by
    delegating to application callbacks registered with the ``register_*_cb``
    methods. Each callback returns a :class:`Device.GetSetStatus` whose
    ``status`` is a :class:`Device.GetSetReturn` value and, for GET requests,
    whose ``data`` is the payload to send back. Requests without a registered
    callback are answered with an ERR_UNSUPPORTED_REQUEST handshake.

    Events emitted: ``control_data``, ``suspend``, ``exit_suspend``,
    ``virtual_cable_unplug`` (plus ``interrupt_data`` from the base class).
    """

    class GetSetReturn(enum.IntEnum):
        FAILURE = 0x00
        REPORT_ID_NOT_FOUND = 0x01
        ERR_UNSUPPORTED_REQUEST = 0x02
        ERR_UNKNOWN = 0x03
        ERR_INVALID_PARAMETER = 0x04
        SUCCESS = 0xFF

    class GetSetStatus:
        """Result object returned by the application callbacks."""

        def __init__(self) -> None:
            self.data = bytearray()
            self.status = 0

    def __init__(self, device: device.Device) -> None:
        super().__init__(device, HID.Role.DEVICE)
        # These must be *instance attributes*: the handle_* methods read them
        # through ``self`` to decide whether a callback has been registered.
        self.get_report_cb: Optional[
            Callable[[int, int, int], Device.GetSetStatus]
        ] = None
        self.set_report_cb: Optional[
            Callable[[int, int, int, bytes], Device.GetSetStatus]
        ] = None
        self.get_protocol_cb: Optional[Callable[[], Device.GetSetStatus]] = None
        self.set_protocol_cb: Optional[Callable[[int], Device.GetSetStatus]] = None

    @override
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Dispatch a control-channel PDU on its message type (header high nibble)."""
        logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
        if not pdu:
            logger.debug('<<< HID CONTROL PDU EMPTY, ignored')
            return
        param = pdu[0] & 0x0F
        message_type = pdu[0] >> 4

        if message_type == Message.MessageType.GET_REPORT:
            logger.debug('<<< HID GET REPORT')
            self.handle_get_report(pdu)
        elif message_type == Message.MessageType.SET_REPORT:
            logger.debug('<<< HID SET REPORT')
            self.handle_set_report(pdu)
        elif message_type == Message.MessageType.GET_PROTOCOL:
            logger.debug('<<< HID GET PROTOCOL')
            self.handle_get_protocol(pdu)
        elif message_type == Message.MessageType.SET_PROTOCOL:
            logger.debug('<<< HID SET PROTOCOL')
            self.handle_set_protocol(pdu)
        elif message_type == Message.MessageType.DATA:
            logger.debug('<<< HID CONTROL DATA')
            self.emit('control_data', pdu)
        elif message_type == Message.MessageType.CONTROL:
            if param == Message.ControlCommand.SUSPEND:
                logger.debug('<<< HID SUSPEND')
                self.emit('suspend')
            elif param == Message.ControlCommand.EXIT_SUSPEND:
                logger.debug('<<< HID EXIT SUSPEND')
                self.emit('exit_suspend')
            elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
                logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
                self.emit('virtual_cable_unplug')
            else:
                logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
        else:
            logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def send_handshake_message(self, result_code: int) -> None:
        """Send a HANDSHAKE message with ``result_code`` on the control channel."""
        hid_message = bytes(SendHandshakeMessage(result_code))
        logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def send_control_data(self, report_type: int, data: bytes) -> None:
        """Send a DATA message with ``report_type`` on the control channel."""
        hid_message = bytes(SendControlData(report_type=report_type, data=data))
        logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def handle_get_report(self, pdu: bytes) -> None:
        """Answer a GET_REPORT request using ``get_report_cb``.

        PDU layout read here: header (report type in bits 0-1, size flag in
        bit 3), report ID, then a little-endian 16-bit buffer size when the
        size flag is set. Every request gets either a DATA reply or a
        HANDSHAKE error.
        """
        if self.get_report_cb is None:
            logger.debug("GetReport callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        report_type = pdu[0] & 0x03
        buffer_flag = (pdu[0] & 0x08) >> 3
        logger.debug(f"buffer_flag: {buffer_flag}")
        # Header + report ID, plus 2 bytes of buffer size if flagged.
        if len(pdu) < (4 if buffer_flag else 2):
            logger.debug("GetReport PDU too short")
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
            return
        report_id = pdu[1]
        buffer_size = ((pdu[3] << 8) | pdu[2]) if buffer_flag else 0

        ret = self.get_report_cb(report_id, report_type, buffer_size)
        assert ret is not None
        if ret.status == self.GetSetReturn.SUCCESS:
            data = bytearray([report_id])
            data.extend(ret.data)
            # The reply is a 1-byte header plus ``data``; it must fit in one
            # control-channel PDU.
            if len(data) < self.l2cap_ctrl_channel.peer_mtu:  # type: ignore[union-attr]
                self.send_control_data(report_type=report_type, data=data)
            else:
                self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
        elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
        else:
            # FAILURE, ERR_UNKNOWN or an unrecognized status: still reply, so
            # the request never goes unanswered.
            self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)

    def register_get_report_cb(
        self, cb: Callable[[int, int, int], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(report_id, report_type, buffer_size) -> GetSetStatus``."""
        self.get_report_cb = cb
        logger.debug("GetReport callback registered successfully")

    def handle_set_report(self, pdu: bytes) -> None:
        """Answer a SET_REPORT request using ``set_report_cb``.

        PDU layout read here: header (report type in bits 0-1), report ID,
        report data. The size passed to the callback counts the report ID.
        """
        if self.set_report_cb is None:
            logger.debug("SetReport callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        if len(pdu) < 2:
            logger.debug("SetReport PDU too short")
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
            return
        report_type = pdu[0] & 0x03
        report_id = pdu[1]
        report_data = pdu[2:]
        report_size = len(report_data) + 1
        ret = self.set_report_cb(report_id, report_type, report_size, report_data)
        assert ret is not None
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_handshake_message(Message.Handshake.SUCCESSFUL)
        elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_set_report_cb(
        self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(report_id, report_type, report_size, data) -> GetSetStatus``."""
        self.set_report_cb = cb
        logger.debug("SetReport callback registered successfully")

    def handle_get_protocol(self, pdu: bytes) -> None:
        """Answer a GET_PROTOCOL request with the callback's ``data``."""
        if self.get_protocol_cb is None:
            logger.debug("GetProtocol callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        ret = self.get_protocol_cb()
        assert ret is not None
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_get_protocol_cb(
        self, cb: Callable[[], Device.GetSetStatus]
    ) -> None:
        """Register ``cb() -> GetSetStatus`` for GET_PROTOCOL requests."""
        self.get_protocol_cb = cb
        logger.debug("GetProtocol callback registered successfully")

    def handle_set_protocol(self, pdu: bytes) -> None:
        """Answer a SET_PROTOCOL request; the mode is bit 0 of the header."""
        if self.set_protocol_cb is None:
            logger.debug("SetProtocol callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        ret = self.set_protocol_cb(pdu[0] & 0x01)
        assert ret is not None
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_handshake_message(Message.Handshake.SUCCESSFUL)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_set_protocol_cb(
        self, cb: Callable[[int], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(protocol_mode) -> GetSetStatus`` for SET_PROTOCOL."""
        self.set_protocol_cb = cb
        logger.debug("SetProtocol callback registered successfully")


# -----------------------------------------------------------------------------
class Host(HID):
    """HID Host role.

    Sends control requests to the device. Replies are surfaced by
    ``on_ctrl_pdu`` as ``handshake`` / ``control_data`` events, plus
    ``virtual_cable_unplug`` (and ``interrupt_data`` from the base class).
    """

    def __init__(self, device: device.Device) -> None:
        super().__init__(device, HID.Role.HOST)

    def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
        """Send GET_REPORT; ``buffer_size`` 0 omits the buffer-size field."""
        hid_message = bytes(
            GetReportMessage(
                report_type=report_type, report_id=report_id, buffer_size=buffer_size
            )
        )
        logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def set_report(self, report_type: int, data: bytes) -> None:
        """Send SET_REPORT with ``data`` (report ID included by the caller)."""
        hid_message = bytes(SetReportMessage(report_type=report_type, data=data))
        logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def get_protocol(self) -> None:
        """Send GET_PROTOCOL."""
        hid_message = bytes(GetProtocolMessage())
        logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def set_protocol(self, protocol_mode: int) -> None:
        """Send SET_PROTOCOL with ``protocol_mode`` (see Message.ProtocolMode)."""
        hid_message = bytes(SetProtocolMessage(protocol_mode=protocol_mode))
        logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def suspend(self) -> None:
        """Send a SUSPEND control operation."""
        hid_message = bytes(Suspend())
        logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def exit_suspend(self) -> None:
        """Send an EXIT_SUSPEND control operation."""
        hid_message = bytes(ExitSuspend())
        logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    @override
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Dispatch a control-channel PDU on its message type (header high nibble)."""
        logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
        if not pdu:
            logger.debug('<<< HID CONTROL PDU EMPTY, ignored')
            return
        param = pdu[0] & 0x0F
        message_type = pdu[0] >> 4
        if message_type == Message.MessageType.HANDSHAKE:
            # Not every 4-bit value is a defined result code; don't let an
            # unknown one raise out of the channel sink.
            try:
                handshake = Message.Handshake(param)
            except ValueError:
                logger.debug(f'<<< HID HANDSHAKE: unknown result code {param:#x}')
                return
            logger.debug(f'<<< HID HANDSHAKE: {handshake.name}')
            self.emit('handshake', handshake)
        elif message_type == Message.MessageType.DATA:
            logger.debug('<<< HID CONTROL DATA')
            self.emit('control_data', pdu)
        elif message_type == Message.MessageType.CONTROL:
            if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
                logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
                self.emit('virtual_cable_unplug')
            else:
                logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
        else:
            logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')