1# Copyright 2024 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# ----------------------------------------------------------------------------- 16# Imports 17# ----------------------------------------------------------------------------- 18from __future__ import annotations 19 20from bumble.device import Connection 21try: 22 from packets import avdtp as av 23 from packets.avdtp import * 24except ImportError: 25 from .packets import avdtp as av 26 from .packets.avdtp import * 27from pyee import EventEmitter 28from typing import List, Literal, Union 29 30import asyncio 31import bumble.avdtp as avdtp 32import bumble.l2cap as l2cap 33import logging 34 35# ----------------------------------------------------------------------------- 36# Logging 37# ----------------------------------------------------------------------------- 38logger = logging.getLogger(__name__) 39 40av.print = lambda *args, **kwargs: logger.debug(" ".join(map(str, args))) 41 42 43class Any: 44 """Helper class that will match all other values. 45 Use an element of this class in expected packets to match any value 46 returned by the AVDTP signaling.""" 47 48 def __eq__(self, other) -> bool: 49 return True 50 51 def __format__(self, format_spec: str) -> str: 52 return "_" 53 54 def __len__(self) -> int: 55 return 1 56 57 def show(self, prefix: str = "") -> str: 58 return prefix + "_" 59 60 61RoleType = Optional[Literal["acceptor", "initiator"]] 62 63 64class SignalingChannel(EventEmitter): 65 connection: Connection 66 signaling_channel: Optional[l2cap.ClassicChannel] = None 67 transport_channel: Optional[l2cap.ClassicChannel] = None 68 avdtp_server: Optional[l2cap.ClassicChannelServer] = None 69 role: RoleType = None 70 any: Any = Any() 71 acp_seid: int = 0 72 int_seid: int = 0 73 74 def __init__(self, connection: Connection): 75 super().__init__() 76 self.connection = connection 77 self.signaling_queue = asyncio.Queue() 78 self.transport_queue = asyncio.Queue() 79 80 @classmethod 81 async def initiate(cls, connection: Connection) -> SignalingChannel: 82 channel = cls(connection) 83 await channel._initiate_signaling_channel() 84 return channel 85 86 @classmethod 87 def accept(cls, connection: Connection) -> SignalingChannel: 88 channel = cls(connection) 89 channel._accept_signaling_channel() 90 return channel 91 92 async def disconnect(self): 93 if not self.signaling_channel: 94 raise ValueError("No connected signaling channel") 95 await self.signaling_channel.disconnect() 96 self.signaling_channel = None 97 98 async def initiate_transport_channel(self): 99 if self.transport_channel: 100 raise ValueError("RTP L2CAP channel already exists") 101 self.transport_channel = await self.connection.create_l2cap_channel( 102 l2cap.ClassicChannelSpec(psm=avdtp.AVDTP_PSM)) 103 104 async def disconnect_transport_channel(self): 105 if not self.transport_channel: 106 raise ValueError("No connected RTP channel") 107 await self.transport_channel.disconnect() 108 self.transport_channel = None 109 110 async def expect_signal(self, expected_sig: Union[SignalingPacket, type], timeout: float = 3) -> SignalingPacket: 111 packet = await asyncio.wait_for(self.signaling_queue.get(), timeout=timeout) 112 sig = SignalingPacket.parse_all(packet) 113 114 if isinstance(expected_sig, type) and not isinstance(sig, expected_sig): 115 logger.error("Received unexpected signal") 116 logger.error(f"Expected signal: {expected_sig.__class__.__name__}") 117 logger.error("Received signal:") 118 sig.show() 119 raise ValueError(f"Received unexpected signal") 120 121 if isinstance(expected_sig, SignalingPacket) and sig != expected_sig: 122 logger.error("Received unexpected signal") 123 logger.error("Expected signal:") 124 expected_sig.show() 125 logger.error("Received signal:") 126 sig.show() 127 raise ValueError(f"Received unexpected signal") 128 129 logger.debug(f"<<< {self.connection.self_address} {self.role} received signal: <<<") 130 sig.show() 131 return sig 132 133 async def expect_media(self, timeout: float = 3.0) -> bytes: 134 packet = await asyncio.wait_for(self.transport_queue.get(), timeout=timeout) 135 logger.debug(f"<<< {self.connection.self_address} {self.role} received media <<<") 136 logger.debug(f"RTP Packet: {packet.hex()}") 137 return packet 138 139 def send_signal(self, packet: SignalingPacket): 140 logger.debug(f">>> {self.connection.self_address} {self.role} sending signal: >>>") 141 packet.show() 142 self.signaling_channel.send_pdu(packet.serialize()) 143 144 def send_media(self, packet: bytes): 145 logger.debug(f">>> {self.connection.self_address} {self.role} sending media >>>") 146 self.transport_channel.send_pdu(packet) 147 148 async def _initiate_signaling_channel(self): 149 if self.signaling_channel: 150 raise ValueError("Signaling L2CAP channel already exists") 151 self.role = "initiator" 152 self.signaling_channel = await self.connection.create_l2cap_channel(spec=l2cap.ClassicChannelSpec( 153 psm=avdtp.AVDTP_PSM)) 154 # Register to receive PDUs from the channel 155 self.signaling_channel.sink = self._on_pdu 156 157 def _accept_signaling_channel(self): 158 if self.avdtp_server: 159 raise ValueError("L2CAP server already exists") 160 self.role = "acceptor" 161 avdtp_server = self.connection.device.l2cap_channel_manager.servers.get(avdtp.AVDTP_PSM) 162 if not avdtp_server: 163 self.avdtp_server = self.connection.device.create_l2cap_server(spec=l2cap.ClassicChannelSpec( 164 psm=avdtp.AVDTP_PSM)) 165 else: 166 self.avdtp_server = avdtp_server 167 self.avdtp_server.on('connection', self._on_l2cap_connection) 168 169 def _on_l2cap_connection(self, channel: l2cap.ClassicChannel): 170 logger.info(f"Incoming L2CAP channel: {channel}") 171 172 if not self.signaling_channel: 173 174 def _on_channel_open(): 175 logger.info(f"Signaling opened on channel {self.signaling_channel}") 176 # Register to receive PDUs from the channel 177 self.signaling_channel.sink = self._on_pdu 178 self.emit('connection') 179 180 def _on_channel_close(): 181 logger.info("Signaling channel closed") 182 self.signaling_channel = None 183 184 self.signaling_channel = channel 185 self.signaling_channel.on('open', _on_channel_open) 186 self.signaling_channel.on('close', _on_channel_close) 187 elif not self.transport_channel: 188 189 def _on_channel_open(): 190 logger.info(f"RTP opened on channel {self.transport_channel}") 191 # Register to receive PDUs from the channel 192 self.transport_channel.sink = self._on_avdtp_packet 193 194 def _on_channel_close(): 195 logger.info('RTP channel closed') 196 self.transport_channel = None 197 198 self.transport_channel = channel 199 self.transport_channel.on('open', _on_channel_open) 200 self.transport_channel.on('close', _on_channel_close) 201 202 def _on_pdu(self, pdu: bytes): 203 self.signaling_queue.put_nowait(pdu) 204 205 def _on_avdtp_packet(self, packet): 206 self.transport_queue.put_nowait(packet) 207 208 async def accept_discover(self, seid_information: List[av.SeidInformation]): 209 cmd = await self.expect_signal(av.DiscoverCommand(transaction_label=self.any)) 210 self.send_signal(av.DiscoverResponse(transaction_label=cmd.transaction_label, 211 seid_information=seid_information)) 212 213 async def accept_get_all_capabilities(self, service_capabilities: List[ServiceCapability]): 214 cmd = await self.expect_signal(av.GetAllCapabilitiesCommand(acp_seid=self.any, transaction_label=self.any)) 215 self.send_signal( 216 av.GetAllCapabilitiesResponse(transaction_label=cmd.transaction_label, 217 service_capabilities=service_capabilities)) 218 219 async def accept_set_configuration(self, expected_configuration: List[ServiceCapability]): 220 cmd = await self.expect_signal( 221 av.SetConfigurationCommand(transaction_label=self.any, 222 acp_seid=self.any, 223 int_seid=self.any, 224 service_capabilities=expected_configuration)) 225 self.acp_seid = cmd.acp_seid 226 self.int_seid = cmd.int_seid 227 self.send_signal(SetConfigurationResponse(transaction_label=cmd.transaction_label)) 228 229 async def accept_open(self, timeout: float = 3.0): 230 cmd = await self.expect_signal(av.OpenCommand(transaction_label=self.any, acp_seid=self.any), timeout=timeout) 231 self.send_signal(av.OpenResponse(transaction_label=cmd.transaction_label)) 232 233 async def accept_start(self, timeout: float = 3.0): 234 cmd = await self.expect_signal(av.StartCommand(transaction_label=self.any, acp_seid=self.any), timeout=timeout) 235 self.send_signal(av.StartResponse(transaction_label=cmd.transaction_label)) 236 237 async def accept_suspend(self, timeout: float = 3.0): 238 cmd = await self.expect_signal(av.SuspendCommand(transaction_label=self.any, acp_seid=self.any), 239 timeout=timeout) 240 self.send_signal(av.SuspendResponse(transaction_label=cmd.transaction_label)) 241 242 async def accept_close(self, timeout: float = 3.0): 243 cmd = await self.expect_signal(av.CloseCommand(transaction_label=self.any, acp_seid=self.any), timeout=timeout) 244 self.send_signal(av.CloseResponse(transaction_label=cmd.transaction_label)) 245 246 async def accept_open_stream(self, 247 seid_information: List[av.SeidInformation], 248 service_capabilities: List[ServiceCapability], 249 timeout: float = 10.0): 250 avdtp_future = asyncio.get_running_loop().create_future() 251 252 def on_avdtp_connection(): 253 logger.info(f"AVDTP Opened") 254 nonlocal avdtp_future 255 avdtp_future.set_result(None) 256 257 self.on('connection', on_avdtp_connection) 258 259 expected_configuration: List[ServiceCapability] = [] 260 for capability in service_capabilities: 261 if isinstance(capability, av.MediaTransportCapability) or isinstance(capability, 262 av.DelayReportingCapability): 263 expected_configuration.append(capability) 264 else: 265 expected_configuration.append(self.any) 266 267 await self.accept_discover(seid_information) 268 await self.accept_get_all_capabilities(service_capabilities) 269 await self.accept_set_configuration(expected_configuration) 270 await self.accept_open() 271 272 await asyncio.wait_for(avdtp_future, timeout=timeout) 273 274 async def initiate_delay_report(self, delay_ms: int = 100, timeout: float = 3.0): 275 delay_one_tenth = delay_ms * 10 276 delay_msb = (delay_one_tenth >> 8) & 0xff 277 delay_lsb = delay_one_tenth & 0xff 278 self.send_signal( 279 av.DelayReportCommand(transaction_label=0x01, 280 acp_seid=self.acp_seid, 281 delay_msb=delay_msb, 282 delay_lsb=delay_lsb)) 283 await self.expect_signal(av.DelayReportResponse(transaction_label=self.any), timeout=timeout) 284