• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import enum
20import logging
21import os
22import struct
23import time
24
25import click
26
27from bumble.core import (
28    BT_BR_EDR_TRANSPORT,
29    BT_LE_TRANSPORT,
30    BT_L2CAP_PROTOCOL_ID,
31    BT_RFCOMM_PROTOCOL_ID,
32    UUID,
33    CommandTimeoutError,
34)
35from bumble.colors import color
36from bumble.device import Connection, ConnectionParametersPreferences, Device, Peer
37from bumble.gatt import Characteristic, CharacteristicValue, Service
38from bumble.hci import (
39    HCI_LE_1M_PHY,
40    HCI_LE_2M_PHY,
41    HCI_LE_CODED_PHY,
42    HCI_Constant,
43    HCI_Error,
44    HCI_StatusError,
45)
46from bumble.sdp import (
47    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
48    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
49    SDP_PUBLIC_BROWSE_ROOT,
50    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
51    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
52    DataElement,
53    ServiceAttribute,
54)
55from bumble.transport import open_transport_or_link
56import bumble.rfcomm
57import bumble.core
58from bumble.utils import AsyncRunner
59
60
61# -----------------------------------------------------------------------------
62# Logging
63# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------

# Local device identities. The central pair is used when the central device is
# created; the peripheral pair presumably serves the peripheral side.
DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central'
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'

# 128-bit UUIDs of the GATT speed service: TX is written by the client,
# RX is notified by the server.
SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'

# Defaults for the L2CAP connection-oriented channel (client and server).
DEFAULT_L2CAP_PSM = 1234
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1022
DEFAULT_L2CAP_MPS = 1024

# Seconds the central keeps the transport open after its role has finished.
DEFAULT_LINGER_TIME = 1.0

# RFCOMM channel the server listens on and the client connects to.
DEFAULT_RFCOMM_CHANNEL = 8
87
88# -----------------------------------------------------------------------------
89# Utils
90# -----------------------------------------------------------------------------
def parse_packet(packet):
    """Split a raw packet into its PacketType and the remaining payload.

    Returns:
        A ``(PacketType, payload)`` tuple, where ``payload`` is every byte
        after the first.

    Raises:
        ValueError: if the packet is empty or its first byte is not a valid
            PacketType. A red diagnostic is printed before raising.
    """
    if not packet:
        print(
            color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
        )
        raise ValueError('packet too short')

    type_byte = packet[0]
    try:
        kind = PacketType(type_byte)
    except ValueError:
        print(color(f'!!! Invalid packet type 0x{type_byte:02X}', 'red'))
        raise

    return kind, packet[1:]
105
106
def parse_packet_sequence(packet_data):
    """Decode the flags/index header of a SEQUENCE or ACK payload.

    ``packet_data`` is the payload left after the type byte has been removed.
    Its first 5 bytes are a signed flags byte followed by a big-endian
    unsigned 32-bit packet index. Any trailing padding is ignored.

    Returns:
        A ``(flags, index)`` tuple.

    Raises:
        ValueError: if fewer than 5 bytes are available. A red diagnostic is
            printed before raising.
    """
    if len(packet_data) < 5:
        print(
            color(
                # Same '!!! ' prefix as every other diagnostic in this tool.
                f'!!! Packet too short (got {len(packet_data)} bytes, need >= 5)',
                'red',
            )
        )
        raise ValueError('packet too short')
    return struct.unpack_from('>bI', packet_data, 0)
117
118
def le_phy_name(phy_id):
    """Return a short display name for an LE PHY identifier.

    The 1M, 2M and Coded PHYs map to '1M', '2M' and 'CODED'. Any other value
    falls back to ``HCI_Constant.le_phy_name()``, which is only called when
    it is actually needed (the original ``dict.get`` default evaluated it on
    every call).
    """
    short_names = {
        HCI_LE_1M_PHY: '1M',
        HCI_LE_2M_PHY: '2M',
        HCI_LE_CODED_PHY: 'CODED',
    }
    name = short_names.get(phy_id)
    if name is not None:
        return name
    return HCI_Constant.le_phy_name(phy_id)
123
124
def print_connection(connection):
    """Print a one-line summary of a connection's link parameters.

    LE connections show:
      - the connection parameters (interval / peripheral latency /
        supervision timeout),
      - the data length,
      - the RX/TX PHYs.

    Every connection shows the ATT MTU. Fields are separated by exactly one
    space. Previously, a trailing space in the parameters field, plus empty
    fields on non-LE links, produced runs of blanks.
    """
    fields = []

    if connection.transport == BT_LE_TRANSPORT:
        parameters = connection.parameters
        # Interval and timeout are scaled by 1.25 and 10 for display
        # (presumably HCI units converted to milliseconds).
        fields.append(
            'Parameters='
            f'{parameters.connection_interval * 1.25:.2f}/'
            f'{parameters.peripheral_latency}/'
            f'{parameters.supervision_timeout * 10}'
        )
        fields.append(f'DL={connection.data_length}')
        fields.append(
            'PHY='
            f'RX:{le_phy_name(connection.phy.rx_phy)}/'
            f'TX:{le_phy_name(connection.phy.tx_phy)}'
        )

    fields.append(f'MTU={connection.att_mtu}')

    print(f'{color("@@@ Connection:", "yellow")} ' + ' '.join(fields))
155
156
def make_sdp_records(channel):
    """Build the SDP record set that advertises the RFCOMM speed service.

    Returns a dict mapping the single record handle (0x00010001) to its
    attribute list:
      - the record handle,
      - the public browse group,
      - the service class UUID,
      - an L2CAP -> RFCOMM protocol descriptor list carrying ``channel``.
    """
    record_handle = 0x00010001
    service_class_uuid = UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')

    # Protocol stack: L2CAP, then RFCOMM on the given server channel.
    protocol_descriptor_list = DataElement.sequence(
        [
            DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
            DataElement.sequence(
                [
                    DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
                    DataElement.unsigned_integer_8(channel),
                ]
            ),
        ]
    )

    attributes = [
        ServiceAttribute(
            SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
            DataElement.unsigned_integer_32(record_handle),
        ),
        ServiceAttribute(
            SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
        ),
        ServiceAttribute(
            SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(service_class_uuid)]),
        ),
        ServiceAttribute(
            SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
            protocol_descriptor_list,
        ),
    ]

    return {record_handle: attributes}
190
191
class PacketType(enum.IntEnum):
    """Value of the first byte of every packet exchanged by the test roles."""

    RESET = 0  # receiver/responder restarts its sequence tracking
    SEQUENCE = 1  # numbered data packet (flags byte + 32-bit index)
    ACK = 2  # reply echoing the flags and index of a SEQUENCE packet


# Bit in the flags byte marking the final packet of a run.
PACKET_FLAG_LAST = 1
199
200# -----------------------------------------------------------------------------
201# Sender
202# -----------------------------------------------------------------------------
class Sender:
    """Throughput sender: streams numbered SEQUENCE packets, then waits for an ACK.

    ``packet_io`` must expose a ``ready`` event, an async ``send_packet()``
    and a ``packet_listener`` attribute. The attribute is pointed at this
    object so that incoming packets reach on_packet_received().
    """

    def __init__(self, packet_io, start_delay, packet_size, packet_count):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.start_time = 0
        self.bytes_sent = 0
        self.done = asyncio.Event()

    def reset(self):
        pass

    async def run(self):
        """Wait for the I/O to be ready, send RESET + all packets, await the ACK."""
        print(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        print(color('--- Go!', 'blue'))

        if self.tx_start_delay:
            print(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
            await asyncio.sleep(self.tx_start_delay)  # FIXME

        print(color('=== Sending RESET', 'magenta'))
        await self.packet_io.send_packet(bytes([PacketType.RESET]))
        # Monotonic clock: intervals must not be affected by wall-clock jumps.
        self.start_time = time.perf_counter()
        for tx_i in range(self.tx_packet_count):
            packet_flags = PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0
            header = struct.pack('>bbI', PacketType.SEQUENCE, packet_flags, tx_i)
            # Zero-pad up to the requested size. Sizes smaller than the header
            # send the bare header (bytes() rejects negative counts).
            packet = header + bytes(max(0, self.tx_packet_size - len(header)))
            print(color(f'Sending packet {tx_i}: {len(packet)} bytes', 'yellow'))
            self.bytes_sent += len(packet)
            await self.packet_io.send_packet(packet)

        await self.done.wait()
        print(color('=== Done!', 'magenta'))

    def on_packet_received(self, packet):
        """On the final ACK, report the average send rate (bytes/second) and finish."""
        try:
            packet_type, _ = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.ACK:
            elapsed = time.perf_counter() - self.start_time
            average_tx_speed = self.bytes_sent / elapsed if elapsed > 0 else float('inf')
            print(
                color(
                    f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
                    f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)',
                    'green',
                )
            )
            self.done.set()
261
262
263# -----------------------------------------------------------------------------
264# Receiver
265# -----------------------------------------------------------------------------
class Receiver:
    """Throughput receiver: tracks SEQUENCE packets and ACKs the last one.

    For every packet it prints an instantaneous and an average receive rate
    (bytes per second), both measured from the most recent RESET.
    """

    def __init__(self, packet_io):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()

    def reset(self):
        """Clear sequence tracking, timestamps and the received-byte count."""
        self.expected_packet_index = 0
        self.start_timestamp = 0.0
        self.last_timestamp = 0.0
        self.bytes_received = 0

    @staticmethod
    def _rate(byte_count, seconds):
        """Return ``byte_count / seconds``, or infinity if no time has elapsed.

        Several packets can be delivered within a single clock tick, so a
        zero interval is possible and must not crash the receiver.
        """
        return byte_count / seconds if seconds > 0 else float('inf')

    def on_packet_received(self, packet):
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        # Monotonic clock: intervals must not be affected by wall-clock jumps.
        now = time.perf_counter()

        if packet_type == PacketType.RESET:
            print(color('=== Received RESET', 'magenta'))
            self.reset()
            # Start both clocks at the RESET. Otherwise the first packet's
            # "instant" rate would be measured against a timestamp of 0.
            self.start_timestamp = now
            self.last_timestamp = now
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        print(
            f'<<< Received packet {packet_index}: '
            f'flags=0x{packet_flags:02X}, {len(packet)} bytes'
        )

        if packet_index != self.expected_packet_index:
            print(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}'
                )
            )

        self.bytes_received += len(packet)
        instant_rx_speed = self._rate(len(packet), now - self.last_timestamp)
        average_rx_speed = self._rate(self.bytes_received, now - self.start_timestamp)
        print(
            color(
                f'Speed: instant={instant_rx_speed:.4f}, average={average_rx_speed:.4f}',
                'yellow',
            )
        )

        self.last_timestamp = now
        self.expected_packet_index = packet_index + 1

        if packet_flags & PACKET_FLAG_LAST:
            # Echo the flags and index back so the sender can stop its timer.
            AsyncRunner.spawn(
                self.packet_io.send_packet(
                    struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
                )
            )
            print(color('@@@ Received last packet', 'green'))
            self.done.set()

    async def run(self):
        await self.done.wait()
        print(color('=== Done!', 'magenta'))
337
338
339# -----------------------------------------------------------------------------
340# Ping
341# -----------------------------------------------------------------------------
class Ping:
    """Latency tester: sends one SEQUENCE packet at a time, timing each ACK.

    The next ping is sent only after the ACK for the previous one arrives.
    The run ends when an ACK carrying PACKET_FLAG_LAST is received, and the
    average round-trip latency (ms) is then printed.
    """

    def __init__(self, packet_io, start_delay, packet_size, packet_count):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()
        self.current_packet_index = 0
        self.ping_sent_time = 0.0
        self.latencies = []

    def reset(self):
        pass

    async def run(self):
        print(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        print(color('--- Go!', 'blue'))

        if self.tx_start_delay:
            print(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
            await asyncio.sleep(self.tx_start_delay)  # FIXME

        print(color('=== Sending RESET', 'magenta'))
        await self.packet_io.send_packet(bytes([PacketType.RESET]))

        await self.send_next_ping()

        await self.done.wait()
        # Guard against finishing without any measured round trip.
        if self.latencies:
            average_latency = sum(self.latencies) / len(self.latencies)
            print(color(f'@@@ Average latency: {average_latency:.2f}'))
        print(color('=== Done!', 'magenta'))

    async def send_next_ping(self):
        """Send the SEQUENCE packet for ``current_packet_index`` and start its timer."""
        packet_flags = (
            PACKET_FLAG_LAST
            if self.current_packet_index == self.tx_packet_count - 1
            else 0
        )
        header = struct.pack(
            '>bbI', PacketType.SEQUENCE, packet_flags, self.current_packet_index
        )
        # Zero-pad up to the requested size, never with a negative count.
        packet = header + bytes(max(0, self.tx_packet_size - len(header)))
        print(color(f'Sending packet {self.current_packet_index}', 'yellow'))
        # Monotonic clock: latencies must not be affected by wall-clock jumps.
        self.ping_sent_time = time.perf_counter()
        await self.packet_io.send_packet(packet)

    def on_packet_received(self, packet):
        elapsed = time.perf_counter() - self.ping_sent_time

        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type != PacketType.ACK:
            # Only ACKs drive the ping loop. Any other packet must neither end
            # the run nor trigger an additional ping.
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return

        latency = elapsed * 1000
        self.latencies.append(latency)
        print(
            color(
                f'<<< Received ACK [{packet_index}], latency={latency:.2f}ms',
                'green',
            )
        )

        if packet_index == self.current_packet_index:
            self.current_packet_index += 1
        else:
            print(
                color(
                    f'!!! Unexpected packet, expected {self.current_packet_index} '
                    f'but received {packet_index}'
                )
            )

        if packet_flags & PACKET_FLAG_LAST:
            self.done.set()
            return

        AsyncRunner.spawn(self.send_next_ping())
427
428
429# -----------------------------------------------------------------------------
430# Pong
431# -----------------------------------------------------------------------------
class Pong:
    """Responder for the latency test: ACKs every SEQUENCE packet it receives.

    Each ACK echoes the received flags and index. The run finishes once a
    packet flagged PACKET_FLAG_LAST has been acknowledged.
    """

    def __init__(self, packet_io):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()

    def reset(self):
        self.expected_packet_index = 0

    def on_packet_received(self, packet):
        try:
            kind, payload = parse_packet(packet)
        except ValueError:
            return

        if kind == PacketType.RESET:
            print(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            flags, index = parse_packet_sequence(payload)
        except ValueError:
            return

        print(
            color(
                f'<<< Received packet {index}: '
                f'flags=0x{flags:02X}, {len(packet)} bytes',
                'green',
            )
        )

        expected = self.expected_packet_index
        if index != expected:
            print(
                color(
                    f'!!! Unexpected packet, expected {expected} '
                    f'but received {index}'
                )
            )
        self.expected_packet_index = index + 1

        # Reply immediately, echoing flags and index.
        ack = struct.pack('>bbI', PacketType.ACK, flags, index)
        AsyncRunner.spawn(self.packet_io.send_packet(ack))

        if flags & PACKET_FLAG_LAST:
            self.done.set()

    async def run(self):
        await self.done.wait()
        print(color('=== Done!', 'magenta'))
487
488
489# -----------------------------------------------------------------------------
490# GattClient
491# -----------------------------------------------------------------------------
class GattClient:
    """Packet I/O over GATT, client side.

    Outgoing packets are written to the peer's speed TX characteristic.
    Incoming packets arrive as notifications on its speed RX characteristic.
    """

    def __init__(self, _device, att_mtu=None):
        self.att_mtu = att_mtu
        self.speed_rx = None
        self.speed_tx = None
        self.packet_listener = None
        self.ready = asyncio.Event()

    @staticmethod
    def _first_characteristic(service, uuid, missing_message):
        """Return the first characteristic of ``service`` with ``uuid``.

        If there is none, print ``missing_message`` in red and return None.
        """
        matches = service.get_characteristics_by_uuid(uuid)
        if not matches:
            print(color(missing_message, 'red'))
            return None
        return matches[0]

    async def on_connection(self, connection):
        """Discover the speed service, subscribe to RX, then set ``ready``.

        A missing service or characteristic is reported and aborts the
        setup, in which case ``ready`` is left unset.
        """
        peer = Peer(connection)

        if self.att_mtu:
            print(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
            await peer.request_mtu(self.att_mtu)

        print(color('*** Discovering services...', 'blue'))
        await peer.discover_services()

        services = peer.get_services_by_uuid(SPEED_SERVICE_UUID)
        if not services:
            print(color('!!! Speed Service not found', 'red'))
            return
        service = services[0]

        print(color('*** Discovering characteristics...', 'blue'))
        await service.discover_characteristics()

        tx_characteristic = self._first_characteristic(
            service, SPEED_TX_UUID, '!!! Speed TX not found'
        )
        if tx_characteristic is None:
            return
        self.speed_tx = tx_characteristic

        rx_characteristic = self._first_characteristic(
            service, SPEED_RX_UUID, '!!! Speed RX not found'
        )
        if rx_characteristic is None:
            return
        self.speed_rx = rx_characteristic

        print(color('*** Subscribing to RX', 'blue'))
        await self.speed_rx.subscribe(self.on_packet_received)

        print(color('*** Discovery complete', 'blue'))

        connection.on('disconnection', self.on_disconnection)
        self.ready.set()

    def on_disconnection(self, _):
        self.ready.clear()

    def on_packet_received(self, packet):
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(packet)

    async def send_packet(self, packet):
        await self.speed_tx.write_value(packet)
547
548
549# -----------------------------------------------------------------------------
550# GattServer
551# -----------------------------------------------------------------------------
class GattServer:
    """Packet I/O over GATT, server side.

    Incoming packets are writes to the TX characteristic. Outgoing packets
    are notifications on the RX characteristic. ``ready`` is set while a
    client has notifications enabled on RX.
    """

    def __init__(self, device):
        self.device = device
        self.packet_listener = None
        self.ready = asyncio.Event()

        # TX: written by the peer; each write is forwarded to on_tx_write().
        self.speed_tx = Characteristic(
            SPEED_TX_UUID,
            Characteristic.WRITE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=self.on_tx_write),
        )
        # RX: notify-only, with the permissions argument left at 0.
        self.speed_rx = Characteristic(SPEED_RX_UUID, Characteristic.NOTIFY, 0)

        device.add_services(
            [Service(SPEED_SERVICE_UUID, [self.speed_tx, self.speed_rx])]
        )

        self.speed_rx.on('subscription', self.on_rx_subscription)

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        self.ready.clear()

    def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled):
        if not notify_enabled:
            print(color('*** RX un-subscription', 'blue'))
            self.ready.clear()
            return
        print(color('*** RX subscription', 'blue'))
        self.ready.set()

    def on_tx_write(self, _, value):
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(value)

    async def send_packet(self, packet):
        await self.device.notify_subscribers(self.speed_rx, packet)
595
596
597# -----------------------------------------------------------------------------
598# StreamedPacketIO
599# -----------------------------------------------------------------------------
class StreamedPacketIO:
    """Packet framing over a byte stream (L2CAP CoC or RFCOMM).

    Each packet is preceded by a 2-byte big-endian length. Incoming bytes are
    reassembled across arbitrary chunk boundaries, and each complete packet
    is passed to ``packet_listener.on_packet_received()``. Outgoing packets
    are framed and handed to ``io_sink``.
    """

    def __init__(self):
        self.packet_listener = None
        self.io_sink = None
        self.rx_packet = b''
        self.rx_packet_header = b''
        self.rx_packet_need = 0

    def _deliver_rx_packet(self):
        """Hand the reassembled packet to the listener and reset the buffer."""
        # Reset state before delivery so that an exception raised by the
        # listener cannot leave the reassembler in a corrupted state.
        packet, self.rx_packet = self.rx_packet, b''
        if self.packet_listener:
            self.packet_listener.on_packet_received(packet)

    def on_packet(self, packet):
        """Feed a chunk of stream bytes into the reassembler."""
        offset = 0
        size = len(packet)
        # Walk an offset instead of re-slicing the remaining data each time.
        while offset < size:
            if self.rx_packet_need:
                # Collecting the body of the current packet.
                chunk = packet[offset : offset + self.rx_packet_need]
                self.rx_packet += chunk
                offset += len(chunk)
                self.rx_packet_need -= len(chunk)
                if not self.rx_packet_need:
                    self._deliver_rx_packet()
            else:
                # Collecting the 2-byte length header of the next packet.
                header_bytes_needed = 2 - len(self.rx_packet_header)
                header_bytes = packet[offset : offset + header_bytes_needed]
                self.rx_packet_header += header_bytes
                offset += len(header_bytes)
                if len(self.rx_packet_header) < 2:
                    return  # header split across chunks; wait for more data
                self.rx_packet_need = struct.unpack('>H', self.rx_packet_header)[0]
                self.rx_packet_header = b''
                if not self.rx_packet_need:
                    # A zero-length frame is complete as soon as its header
                    # is. Previously this case looped forever.
                    self._deliver_rx_packet()

    async def send_packet(self, packet):
        """Prefix ``packet`` with its 2-byte big-endian length and write it."""
        if not self.io_sink:
            print(color('!!! No sink, dropping packet', 'red'))
            return

        # pylint: disable-next=not-callable
        self.io_sink(struct.pack('>H', len(packet)) + packet)
639
640
641# -----------------------------------------------------------------------------
642# L2capClient
643# -----------------------------------------------------------------------------
class L2capClient(StreamedPacketIO):
    """Packet I/O over an outgoing L2CAP connection-oriented channel."""

    def __init__(
        self,
        _device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.psm, self.max_credits, self.mtu, self.mps = psm, max_credits, mtu, mps
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        """Open the L2CAP channel and wire it into the stream framing.

        Any failure while opening is reported and leaves ``ready`` unset.
        """
        connection.on('disconnection', self.on_disconnection)

        print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
        try:
            channel = await connection.open_l2cap_channel(
                psm=self.psm,
                max_credits=self.max_credits,
                mtu=self.mtu,
                mps=self.mps,
            )
            print(color('*** L2CAP channel:', 'cyan'), channel)
        except Exception as error:  # pylint: disable=broad-except
            print(color(f'!!! Connection failed: {error}', 'red'))
            return

        # Inbound bytes go to the reassembler; outbound frames go to the channel.
        channel.sink = self.on_packet
        channel.on('close', self.on_l2cap_close)
        self.io_sink = channel.write

        self.ready.set()

    def on_disconnection(self, _):
        pass

    def on_l2cap_close(self):
        print(color('*** L2CAP channel closed', 'red'))
688
689
690# -----------------------------------------------------------------------------
691# L2capServer
692# -----------------------------------------------------------------------------
class L2capServer(StreamedPacketIO):
    """Packet I/O over an incoming L2CAP connection-oriented channel.

    Registers an L2CAP channel server on ``psm`` at construction time. The
    most recently accepted channel is wired into the stream framing.
    """

    def __init__(
        self,
        device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        # Currently open channel, or None when no channel is open.
        self.l2cap_channel = None
        self.ready = asyncio.Event()

        # Listen for incoming L2CAP CoC connections
        device.register_l2cap_channel_server(
            psm=psm,
            server=self.on_l2cap_channel,
            max_credits=max_credits,
            mtu=mtu,
            mps=mps,
        )
        print(color(f'### Listening for CoC connection on PSM {psm}', 'yellow'))

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_l2cap_channel(self, l2cap_channel):
        """Adopt a newly accepted channel and mark the I/O as ready."""
        print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)

        # Track the open channel. Previously the attribute was declared and
        # cleared on close but never actually assigned.
        self.l2cap_channel = l2cap_channel
        self.io_sink = l2cap_channel.write
        l2cap_channel.on('close', self.on_l2cap_close)
        l2cap_channel.sink = self.on_packet

        self.ready.set()

    def on_l2cap_close(self):
        print(color('*** L2CAP channel closed', 'red'))
        self.l2cap_channel = None
734
735
736# -----------------------------------------------------------------------------
737# RfcommClient
738# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
    """Packet I/O over an outgoing RFCOMM DLC on DEFAULT_RFCOMM_CHANNEL."""

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        """Start the RFCOMM multiplexer, open the DLC and set ``ready``.

        If the DLC cannot be opened, the multiplexer is disconnected and
        ``ready`` stays unset.
        """
        connection.on('disconnection', self.on_disconnection)

        print(color('*** Starting RFCOMM client...', 'blue'))
        client = bumble.rfcomm.Client(self.device, connection)
        mux = await client.start()
        print(color('*** Started', 'blue'))

        channel = DEFAULT_RFCOMM_CHANNEL
        print(color(f'### Opening session for channel {channel}...', 'yellow'))
        try:
            session = await mux.open_dlc(channel)
            print(color('### Session open', 'yellow'), session)
        except bumble.core.ConnectionError as error:
            print(color(f'!!! Session open failed: {error}', 'red'))
            await mux.disconnect()
            return

        session.sink = self.on_packet
        self.io_sink = session.write

        self.ready.set()

    def on_disconnection(self, _):
        pass
771
772
773# -----------------------------------------------------------------------------
774# RfcommServer
775# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
    """Packet I/O over an incoming RFCOMM DLC.

    At construction time it listens on DEFAULT_RFCOMM_CHANNEL and publishes a
    matching SDP record. ``ready`` is set once a DLC connects.
    """

    def __init__(self, device):
        super().__init__()
        self.ready = asyncio.Event()

        # Create and register a server
        rfcomm_server = bumble.rfcomm.Server(device)

        # Listen for incoming DLC connections
        channel_number = rfcomm_server.listen(self.on_dlc, DEFAULT_RFCOMM_CHANNEL)

        # Setup the SDP to advertise this channel
        device.sdp_service_records = make_sdp_records(channel_number)

        print(
            color(
                f'### Listening for RFComm connection on channel {channel_number}',
                'yellow',
            )
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_dlc(self, dlc):
        """Wire a newly connected DLC into the stream framing."""
        print(color('*** DLC connected:', 'blue'), dlc)
        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        # Signal readiness, as every other transport does once its channel is
        # up. Without this, a role on this side that awaits
        # ``packet_io.ready`` (Sender, Ping) would block forever.
        self.ready.set()
807
808
809# -----------------------------------------------------------------------------
810# Central
811# -----------------------------------------------------------------------------
class Central(Connection.Listener):
    """Central side of the benchmark: connects to the peripheral and runs a role.

    Construction parameters:
      - ``mode_factory(device)`` builds the packet I/O mode once the device
        exists.
      - ``role_factory(mode)`` builds the traffic role on top of that mode.
      - ``phy`` (one of '1m', '2m', 'coded'), if given, is requested after
        connecting.
      - ``connection_interval``, if given, becomes both the minimum and
        maximum preferred connection interval.
    """

    def __init__(
        self,
        transport,
        peripheral_address,
        classic,
        role_factory,
        mode_factory,
        connection_interval,
        phy,
    ):
        super().__init__()
        self.transport = transport
        self.peripheral_address = peripheral_address
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        self.device = None
        self.connection = None

        if phy:
            self.phy = {
                '1m': HCI_LE_1M_PHY,
                '2m': HCI_LE_2M_PHY,
                'coded': HCI_LE_CODED_PHY,
            }[phy]
        else:
            self.phy = None

        if connection_interval:
            connection_parameter_preferences = ConnectionParametersPreferences()
            connection_parameter_preferences.connection_interval_min = (
                connection_interval
            )
            connection_parameter_preferences.connection_interval_max = (
                connection_interval
            )

            # Preferences for the 1M PHY are always set.
            self.connection_parameter_preferences = {
                HCI_LE_1M_PHY: connection_parameter_preferences,
            }

            if self.phy not in (None, HCI_LE_1M_PHY):
                # Add a connection parameters entry for the requested PHY too.
                self.connection_parameter_preferences[
                    self.phy
                ] = connection_parameter_preferences
        else:
            self.connection_parameter_preferences = None

    async def run(self):
        """Open the HCI transport, connect, set up the mode and run the role."""
        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            central_address = DEFAULT_CENTRAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink
            )
            mode = self.mode_factory(self.device)
            role = self.role_factory(mode)
            self.device.classic_enabled = self.classic

            await self.device.power_on()

            print(color(f'### Connecting to {self.peripheral_address}...', 'cyan'))
            try:
                self.connection = await self.device.connect(
                    self.peripheral_address,
                    connection_parameters_preferences=self.connection_parameter_preferences,
                    transport=BT_BR_EDR_TRANSPORT if self.classic else BT_LE_TRANSPORT,
                )
            except CommandTimeoutError:
                print(color('!!! Connection timed out', 'red'))
                return
            except bumble.core.ConnectionError as error:
                print(color(f'!!! Connection error: {error}', 'red'))
                return
            except HCI_StatusError as error:
                # Red like every other connection failure message.
                print(color(f'!!! Connection failed: {error.error_name}', 'red'))
                return
            print(color('### Connected', 'cyan'))
            self.connection.listener = self
            print_connection(self.connection)

            await mode.on_connection(self.connection)

            # Set the PHY if requested
            if self.phy is not None:
                try:
                    await self.connection.set_phy(
                        tx_phys=[self.phy], rx_phys=[self.phy]
                    )
                except HCI_Error as error:
                    print(
                        color(
                            f'!!! Unable to set the PHY: {error.error_name}', 'yellow'
                        )
                    )

            await role.run()
            # Keep the link up briefly so in-flight packets can complete.
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_disconnection(self, reason):
        print(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)
934
935
936# -----------------------------------------------------------------------------
937# Peripheral
938# -----------------------------------------------------------------------------
class Peripheral(Device.Listener, Connection.Listener):
    """Benchmark peripheral: waits for a central to connect, then runs a role.

    The instance is registered as the device listener (to be told about the
    incoming connection) and as that connection's listener (to report
    parameter/PHY/MTU/data-length changes and disconnections).
    """

    def __init__(self, transport, classic, role_factory, mode_factory):
        self.transport = transport
        self.classic = classic
        self.role_factory = role_factory
        self.role = None
        self.mode_factory = mode_factory
        self.mode = None
        self.device = None
        self.connection = None
        # Created in run(), i.e. inside the running event loop.
        # On Python < 3.10, asyncio.Event() binds to the current event loop
        # when it is constructed. This object is built before asyncio.run()
        # creates its loop, so waiting on an Event created here fails with
        # "attached to a different loop".
        self.connected = None

    async def run(self):
        """Open the transport, wait for a central to connect, then run the role.

        BR/EDR (``self.classic``): the device is made discoverable and
        connectable. LE: the device advertises with ``auto_restart=True``.

        After a connection is signalled through ``on_connection``, the mode is
        notified, the role is run, and the method sleeps for
        ``DEFAULT_LINGER_TIME`` seconds before the transport is closed.
        """
        # Bind the event to the loop that will actually wait on it.
        self.connected = asyncio.Event()

        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            peripheral_address = DEFAULT_PERIPHERAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_PERIPHERAL_NAME, peripheral_address, hci_source, hci_sink
            )
            self.device.listener = self
            self.mode = self.mode_factory(self.device)
            self.role = self.role_factory(self.mode)
            self.device.classic_enabled = self.classic

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(True)
                await self.device.set_connectable(True)
                # Classic centrals reach us on the controller's public address.
                listening_address = self.device.public_address
            else:
                await self.device.start_advertising(auto_restart=True)
                listening_address = peripheral_address

            print(
                color(
                    f'### Waiting for connection on {listening_address}...',
                    'cyan',
                )
            )
            await self.connected.wait()
            print(color('### Connected', 'cyan'))

            await self.mode.on_connection(self.connection)
            await self.role.run()
            # Keep the link up briefly after the role finishes.
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_connection(self, connection):
        """Adopt the new connection and wake up run()."""
        connection.listener = self
        self.connection = connection
        self.connected.set()

    def on_disconnection(self, reason):
        """Report the disconnection, forget the connection and reset the role."""
        print(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None
        self.role.reset()

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)
1019
1020
1021# -----------------------------------------------------------------------------
def create_mode_factory(ctx, default_mode):
    """Return a callable that builds the benchmark mode object for a device.

    The mode name is ``ctx.obj['mode']`` (the ``--mode`` option), or
    ``default_mode`` when that option was not given. The returned callable
    takes a device and returns the matching mode instance. It raises
    ``ValueError('invalid mode')`` when the name matches no known mode.
    """
    requested_mode = ctx.obj['mode']
    selected_mode = default_mode if requested_mode is None else requested_mode

    def create_mode(device):
        # Builders are lambdas, so only the selected constructor runs and
        # ctx.obj['att_mtu'] is read only when a GATT client is built.
        candidates = (
            ('gatt-client', lambda: GattClient(device, att_mtu=ctx.obj['att_mtu'])),
            ('gatt-server', lambda: GattServer(device)),
            ('l2cap-client', lambda: L2capClient(device)),
            ('l2cap-server', lambda: L2capServer(device)),
            ('rfcomm-client', lambda: RfcommClient(device)),
            ('rfcomm-server', lambda: RfcommServer(device)),
        )
        for name, build in candidates:
            if selected_mode == name:
                return build()

        raise ValueError('invalid mode')

    return create_mode
1049
1050
1051# -----------------------------------------------------------------------------
def create_role_factory(ctx, default_role):
    """Return a callable that builds the benchmark role object for a packet I/O.

    The role name is ``ctx.obj['role']`` (the ``--role`` option), or
    ``default_role`` when that option was not given. The ``sender`` and
    ``ping`` roles also receive the start delay, packet size and packet count
    from ``ctx.obj``. The returned callable raises
    ``ValueError('invalid role')`` when the name matches no known role.
    """
    requested_role = ctx.obj['role']
    selected_role = default_role if requested_role is None else requested_role

    def create_role(packet_io):
        def packet_options():
            # Read when the role is built, not when the factory is created.
            return {
                'start_delay': ctx.obj['start_delay'],
                'packet_size': ctx.obj['packet_size'],
                'packet_count': ctx.obj['packet_count'],
            }

        candidates = (
            ('sender', lambda: Sender(packet_io, **packet_options())),
            ('receiver', lambda: Receiver(packet_io)),
            ('ping', lambda: Ping(packet_io, **packet_options())),
            ('pong', lambda: Pong(packet_io)),
        )
        for name, build in candidates:
            if selected_role == name:
                return build()

        raise ValueError('invalid role')

    return create_role
1083
1084
1085# -----------------------------------------------------------------------------
1086# Main
1087# -----------------------------------------------------------------------------
@click.group()
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.option(
    '--role',
    type=click.Choice(['sender', 'receiver', 'ping', 'pong']),
    help='Role (default: sender for central, receiver for peripheral)',
)
@click.option(
    '--mode',
    type=click.Choice(
        [
            'gatt-client',
            'gatt-server',
            'l2cap-client',
            'l2cap-server',
            'rfcomm-client',
            'rfcomm-server',
        ]
    ),
    help='Mode (default: gatt-client for central, gatt-server for peripheral)',
)
@click.option(
    '--att-mtu',
    metavar='MTU',
    type=click.IntRange(23, 517),
    help='GATT MTU (gatt-client mode)',
)
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(8, 4096),
    default=500,
    help='Packet size (sender or ping role)',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    type=int,
    default=10,
    help='Packet count (sender or ping role)',
)
@click.option(
    '--start-delay',
    '-sd',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Start delay (sender or ping role)',
)
@click.pass_context
def bench(
    ctx, device_config, role, mode, att_mtu, packet_size, packet_count, start_delay
):
    """Run a Bumble benchmark as a central or a peripheral."""
    # Share the group-level options with the subcommands through ctx.obj.
    ctx.ensure_object(dict)
    # NOTE(review): device_config is stored but not read by any code visible
    # here; confirm it is consumed elsewhere.
    ctx.obj['device_config'] = device_config
    ctx.obj['role'] = role
    ctx.obj['mode'] = mode
    ctx.obj['att_mtu'] = att_mtu
    ctx.obj['packet_size'] = packet_size
    ctx.obj['packet_count'] = packet_count
    ctx.obj['start_delay'] = start_delay

    # RFCOMM runs over BR/EDR (Classic); every other mode uses LE.
    ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server')
1148
1149
@bench.command()
@click.argument('transport')
@click.option(
    '--peripheral',
    'peripheral_address',
    metavar='ADDRESS_OR_NAME',
    default=DEFAULT_PERIPHERAL_ADDRESS,
    help='Address or name to connect to',
)
@click.option(
    '--connection-interval',
    '--ci',
    metavar='CONNECTION_INTERVAL',
    type=int,
    help='Connection interval (in ms)',
)
@click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use')
@click.pass_context
def central(ctx, transport, peripheral_address, connection_interval, phy):
    """Run as a central (initiates the connection)"""
    # Without --role/--mode, the central sends data as a GATT client.
    make_role = create_role_factory(ctx, 'sender')
    make_mode = create_mode_factory(ctx, 'gatt-client')
    use_classic = ctx.obj['classic']

    bench_central = Central(
        transport,
        peripheral_address,
        use_classic,
        make_role,
        make_mode,
        connection_interval,
        phy,
    )
    asyncio.run(bench_central.run())
1185
1186
@bench.command()
@click.argument('transport')
@click.pass_context
def peripheral(ctx, transport):
    """Run as a peripheral (waits for a connection)"""
    # Without --role/--mode, the peripheral receives data as a GATT server.
    make_role = create_role_factory(ctx, 'receiver')
    make_mode = create_mode_factory(ctx, 'gatt-server')

    bench_peripheral = Peripheral(transport, ctx.obj['classic'], make_role, make_mode)
    asyncio.run(bench_peripheral.run())
1198
1199
def main():
    """Entry point: configure logging, then run the ``bench`` command group."""
    # BUMBLE_LOGLEVEL (default INFO) selects the level, case-insensitively.
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)
    # click supplies the group's parameters from the command line.
    bench()  # pylint: disable=no-value-for-parameter
1203
1204
1205# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Allow running this module directly as a script.
    main()
1208