• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19
20from bumble.device import Connection
21try:
22    from packets import avdtp as av
23    from packets.avdtp import *
24except ImportError:
25    from .packets import avdtp as av
26    from .packets.avdtp import *
27from pyee import EventEmitter
28from typing import List, Literal, Union
29
30import asyncio
31import bumble.avdtp as avdtp
32import bumble.l2cap as l2cap
33import logging
34
35# -----------------------------------------------------------------------------
36# Logging
37# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


def _debug_print(*args, **kwargs):
    """Stand-in for `print` that writes the space-joined arguments to the debug log.

    Keyword arguments (`sep`, `end`, `file`, ...) are accepted and ignored.
    """
    logger.debug(" ".join(str(arg) for arg in args))


# Install the stand-in as the `print` attribute of the packets module, so
# anything that module prints ends up in this module's debug log.
av.print = _debug_print
41
42
class Any:
    """Wildcard placeholder that compares equal to every other value.

    Put an instance of this class in a field of an expected packet to accept
    whatever value the AVDTP signaling returned for that field.
    """

    # Spelled out for readers: defining __eq__ without __hash__ already
    # leaves instances unhashable.
    __hash__ = None

    def show(self, prefix: str = "") -> str:
        """Return the wildcard marker "_" preceded by *prefix*."""
        return "".join((prefix, "_"))

    def __len__(self) -> int:
        # Always reports a length of one.
        return 1

    def __format__(self, format_spec: str) -> str:
        # The format specification is ignored; the marker is always "_".
        return "_"

    def __eq__(self, other) -> bool:
        # Equal to anything, whichever side of `==` the wildcard is on.
        return True
59
60
# Role played on the signaling channel: "initiator", "acceptor", or None while
# no role has been assigned yet.
# Written as Union[..., None] (equivalent to Optional[...]) because this line is
# evaluated at import time and `Optional` is not among the names this module
# imports from `typing`; it would only resolve if a wildcard import leaked it.
RoleType = Union[Literal["acceptor", "initiator"], None]
62
63
64class SignalingChannel(EventEmitter):
65    connection: Connection
66    signaling_channel: Optional[l2cap.ClassicChannel] = None
67    transport_channel: Optional[l2cap.ClassicChannel] = None
68    avdtp_server: Optional[l2cap.ClassicChannelServer] = None
69    role: RoleType = None
70    any: Any = Any()
71    acp_seid: int = 0
72    int_seid: int = 0
73
74    def __init__(self, connection: Connection):
75        super().__init__()
76        self.connection = connection
77        self.signaling_queue = asyncio.Queue()
78        self.transport_queue = asyncio.Queue()
79
80    @classmethod
81    async def initiate(cls, connection: Connection) -> SignalingChannel:
82        channel = cls(connection)
83        await channel._initiate_signaling_channel()
84        return channel
85
86    @classmethod
87    def accept(cls, connection: Connection) -> SignalingChannel:
88        channel = cls(connection)
89        channel._accept_signaling_channel()
90        return channel
91
92    async def disconnect(self):
93        if not self.signaling_channel:
94            raise ValueError("No connected signaling channel")
95        await self.signaling_channel.disconnect()
96        self.signaling_channel = None
97
98    async def initiate_transport_channel(self):
99        if self.transport_channel:
100            raise ValueError("RTP L2CAP channel already exists")
101        self.transport_channel = await self.connection.create_l2cap_channel(
102            l2cap.ClassicChannelSpec(psm=avdtp.AVDTP_PSM))
103
104    async def disconnect_transport_channel(self):
105        if not self.transport_channel:
106            raise ValueError("No connected RTP channel")
107        await self.transport_channel.disconnect()
108        self.transport_channel = None
109
110    async def expect_signal(self, expected_sig: Union[SignalingPacket, type], timeout: float = 3) -> SignalingPacket:
111        packet = await asyncio.wait_for(self.signaling_queue.get(), timeout=timeout)
112        sig = SignalingPacket.parse_all(packet)
113
114        if isinstance(expected_sig, type) and not isinstance(sig, expected_sig):
115            logger.error("Received unexpected signal")
116            logger.error(f"Expected signal: {expected_sig.__class__.__name__}")
117            logger.error("Received signal:")
118            sig.show()
119            raise ValueError(f"Received unexpected signal")
120
121        if isinstance(expected_sig, SignalingPacket) and sig != expected_sig:
122            logger.error("Received unexpected signal")
123            logger.error("Expected signal:")
124            expected_sig.show()
125            logger.error("Received signal:")
126            sig.show()
127            raise ValueError(f"Received unexpected signal")
128
129        logger.debug(f"<<< {self.connection.self_address} {self.role} received signal: <<<")
130        sig.show()
131        return sig
132
133    async def expect_media(self, timeout: float = 3.0) -> bytes:
134        packet = await asyncio.wait_for(self.transport_queue.get(), timeout=timeout)
135        logger.debug(f"<<< {self.connection.self_address} {self.role} received media <<<")
136        logger.debug(f"RTP Packet: {packet.hex()}")
137        return packet
138
139    def send_signal(self, packet: SignalingPacket):
140        logger.debug(f">>> {self.connection.self_address} {self.role} sending signal: >>>")
141        packet.show()
142        self.signaling_channel.send_pdu(packet.serialize())
143
144    def send_media(self, packet: bytes):
145        logger.debug(f">>> {self.connection.self_address} {self.role} sending media >>>")
146        self.transport_channel.send_pdu(packet)
147
148    async def _initiate_signaling_channel(self):
149        if self.signaling_channel:
150            raise ValueError("Signaling L2CAP channel already exists")
151        self.role = "initiator"
152        self.signaling_channel = await self.connection.create_l2cap_channel(spec=l2cap.ClassicChannelSpec(
153            psm=avdtp.AVDTP_PSM))
154        # Register to receive PDUs from the channel
155        self.signaling_channel.sink = self._on_pdu
156
157    def _accept_signaling_channel(self):
158        if self.avdtp_server:
159            raise ValueError("L2CAP server already exists")
160        self.role = "acceptor"
161        avdtp_server = self.connection.device.l2cap_channel_manager.servers.get(avdtp.AVDTP_PSM)
162        if not avdtp_server:
163            self.avdtp_server = self.connection.device.create_l2cap_server(spec=l2cap.ClassicChannelSpec(
164                psm=avdtp.AVDTP_PSM))
165        else:
166            self.avdtp_server = avdtp_server
167        self.avdtp_server.on('connection', self._on_l2cap_connection)
168
169    def _on_l2cap_connection(self, channel: l2cap.ClassicChannel):
170        logger.info(f"Incoming L2CAP channel: {channel}")
171
172        if not self.signaling_channel:
173
174            def _on_channel_open():
175                logger.info(f"Signaling opened on channel {self.signaling_channel}")
176                # Register to receive PDUs from the channel
177                self.signaling_channel.sink = self._on_pdu
178                self.emit('connection')
179
180            def _on_channel_close():
181                logger.info("Signaling channel closed")
182                self.signaling_channel = None
183
184            self.signaling_channel = channel
185            self.signaling_channel.on('open', _on_channel_open)
186            self.signaling_channel.on('close', _on_channel_close)
187        elif not self.transport_channel:
188
189            def _on_channel_open():
190                logger.info(f"RTP opened on channel {self.transport_channel}")
191                # Register to receive PDUs from the channel
192                self.transport_channel.sink = self._on_avdtp_packet
193
194            def _on_channel_close():
195                logger.info('RTP channel closed')
196                self.transport_channel = None
197
198            self.transport_channel = channel
199            self.transport_channel.on('open', _on_channel_open)
200            self.transport_channel.on('close', _on_channel_close)
201
202    def _on_pdu(self, pdu: bytes):
203        self.signaling_queue.put_nowait(pdu)
204
205    def _on_avdtp_packet(self, packet):
206        self.transport_queue.put_nowait(packet)
207
208    async def accept_discover(self, seid_information: List[av.SeidInformation]):
209        cmd = await self.expect_signal(av.DiscoverCommand(transaction_label=self.any))
210        self.send_signal(av.DiscoverResponse(transaction_label=cmd.transaction_label,
211                                             seid_information=seid_information))
212
213    async def accept_get_all_capabilities(self, service_capabilities: List[ServiceCapability]):
214        cmd = await self.expect_signal(av.GetAllCapabilitiesCommand(acp_seid=self.any, transaction_label=self.any))
215        self.send_signal(
216            av.GetAllCapabilitiesResponse(transaction_label=cmd.transaction_label,
217                                          service_capabilities=service_capabilities))
218
219    async def accept_set_configuration(self, expected_configuration: List[ServiceCapability]):
220        cmd = await self.expect_signal(
221            av.SetConfigurationCommand(transaction_label=self.any,
222                                       acp_seid=self.any,
223                                       int_seid=self.any,
224                                       service_capabilities=expected_configuration))
225        self.acp_seid = cmd.acp_seid
226        self.int_seid = cmd.int_seid
227        self.send_signal(SetConfigurationResponse(transaction_label=cmd.transaction_label))
228
229    async def accept_open(self, timeout: float = 3.0):
230        cmd = await self.expect_signal(av.OpenCommand(transaction_label=self.any, acp_seid=self.any), timeout=timeout)
231        self.send_signal(av.OpenResponse(transaction_label=cmd.transaction_label))
232
233    async def accept_start(self, timeout: float = 3.0):
234        cmd = await self.expect_signal(av.StartCommand(transaction_label=self.any, acp_seid=self.any), timeout=timeout)
235        self.send_signal(av.StartResponse(transaction_label=cmd.transaction_label))
236
237    async def accept_suspend(self, timeout: float = 3.0):
238        cmd = await self.expect_signal(av.SuspendCommand(transaction_label=self.any, acp_seid=self.any),
239                                       timeout=timeout)
240        self.send_signal(av.SuspendResponse(transaction_label=cmd.transaction_label))
241
242    async def accept_close(self, timeout: float = 3.0):
243        cmd = await self.expect_signal(av.CloseCommand(transaction_label=self.any, acp_seid=self.any), timeout=timeout)
244        self.send_signal(av.CloseResponse(transaction_label=cmd.transaction_label))
245
246    async def accept_open_stream(self,
247                                 seid_information: List[av.SeidInformation],
248                                 service_capabilities: List[ServiceCapability],
249                                 timeout: float = 10.0):
250        avdtp_future = asyncio.get_running_loop().create_future()
251
252        def on_avdtp_connection():
253            logger.info(f"AVDTP Opened")
254            nonlocal avdtp_future
255            avdtp_future.set_result(None)
256
257        self.on('connection', on_avdtp_connection)
258
259        expected_configuration: List[ServiceCapability] = []
260        for capability in service_capabilities:
261            if isinstance(capability, av.MediaTransportCapability) or isinstance(capability,
262                                                                                 av.DelayReportingCapability):
263                expected_configuration.append(capability)
264            else:
265                expected_configuration.append(self.any)
266
267        await self.accept_discover(seid_information)
268        await self.accept_get_all_capabilities(service_capabilities)
269        await self.accept_set_configuration(expected_configuration)
270        await self.accept_open()
271
272        await asyncio.wait_for(avdtp_future, timeout=timeout)
273
274    async def initiate_delay_report(self, delay_ms: int = 100, timeout: float = 3.0):
275        delay_one_tenth = delay_ms * 10
276        delay_msb = (delay_one_tenth >> 8) & 0xff
277        delay_lsb = delay_one_tenth & 0xff
278        self.send_signal(
279            av.DelayReportCommand(transaction_label=0x01,
280                                  acp_seid=self.acp_seid,
281                                  delay_msb=delay_msb,
282                                  delay_lsb=delay_lsb))
283        await self.expect_signal(av.DelayReportResponse(transaction_label=self.any), timeout=timeout)
284