• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import enum
20import logging
21import os
22import struct
23import time
24
25import click
26
27from bumble import l2cap
28from bumble.core import (
29    BT_BR_EDR_TRANSPORT,
30    BT_LE_TRANSPORT,
31    BT_L2CAP_PROTOCOL_ID,
32    BT_RFCOMM_PROTOCOL_ID,
33    UUID,
34    CommandTimeoutError,
35)
36from bumble.colors import color
37from bumble.device import Connection, ConnectionParametersPreferences, Device, Peer
38from bumble.gatt import Characteristic, CharacteristicValue, Service
39from bumble.hci import (
40    HCI_LE_1M_PHY,
41    HCI_LE_2M_PHY,
42    HCI_LE_CODED_PHY,
43    HCI_Constant,
44    HCI_Error,
45    HCI_StatusError,
46)
47from bumble.sdp import (
48    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
49    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
50    SDP_PUBLIC_BROWSE_ROOT,
51    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
52    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
53    DataElement,
54    ServiceAttribute,
55)
56from bumble.transport import open_transport_or_link
57import bumble.rfcomm
58import bumble.core
59from bumble.utils import AsyncRunner
60
61
62# -----------------------------------------------------------------------------
63# Logging
64# -----------------------------------------------------------------------------
# Module logger. NOTE(review): the helpers and classes visible in this file
# call the root `logging.info` directly rather than this logger.
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Default Bluetooth addresses and local names for the two roles of the test.
# NOTE(review): the code that consumes these is outside this view.
DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central'
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'

# GATT speed service: the client writes packets to the TX characteristic and
# receives packets as notifications on the RX characteristic
# (see GattClient / GattServer).
SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'

# Service class UUID published in the RFCOMM SDP record (make_sdp_records).
DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
# Defaults for the LE credit-based L2CAP channel spec used by
# L2capClient / L2capServer.
DEFAULT_L2CAP_PSM = 128
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1024

# NOTE(review): presumably durations in seconds; their consumers are outside
# this view — confirm.
DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0

# RFCOMM defaults. NOTE(review): their consumers are outside this view.
DEFAULT_RFCOMM_CHANNEL = 8
DEFAULT_RFCOMM_MTU = 2048
91
92
93# -----------------------------------------------------------------------------
94# Utils
95# -----------------------------------------------------------------------------
def parse_packet(packet):
    """Split a raw speed-test packet into its type tag and payload.

    The first byte must be a ``PacketType`` value; the remaining bytes are
    returned untouched as the payload.

    Returns:
        A ``(PacketType, payload)`` tuple.

    Raises:
        ValueError: if the packet is empty or its first byte is not a known
            ``PacketType``. The problem is logged before raising.
    """
    if len(packet) == 0:
        logging.info(
            color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
        )
        raise ValueError('packet too short')

    type_byte = packet[0]
    try:
        kind = PacketType(type_byte)
    except ValueError:
        logging.info(color(f'!!! Invalid packet type 0x{type_byte:02X}', 'red'))
        raise

    payload = packet[1:]
    return kind, payload
110
111
def parse_packet_sequence(packet_data):
    """Decode the header of a SEQUENCE or ACK payload.

    The payload (the packet without its type byte) starts with a signed flags
    byte followed by a big-endian unsigned 32-bit packet index. Any trailing
    bytes are padding and are ignored.

    Returns:
        A ``(flags, index)`` tuple.

    Raises:
        ValueError: if fewer than 5 bytes are available. The problem is logged
            before raising.
    """
    header_size = struct.calcsize('>bI')
    if len(packet_data) >= header_size:
        return struct.unpack_from('>bI', packet_data, 0)

    logging.info(
        color(
            f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)',
            'red',
        )
    )
    raise ValueError('packet too short')
122
123
def le_phy_name(phy_id):
    """Return a short display name for an LE PHY identifier.

    The three standard LE PHYs map to '1M', '2M' and 'CODED'. Any other value
    is named by ``HCI_Constant.le_phy_name``. That fallback is computed on
    every call, even when a short name exists.
    """
    fallback = HCI_Constant.le_phy_name(phy_id)
    short_names = {
        HCI_LE_1M_PHY: '1M',
        HCI_LE_2M_PHY: '2M',
        HCI_LE_CODED_PHY: 'CODED',
    }
    return short_names.get(phy_id, fallback)
128
129
def print_connection(connection):
    """Log a one-line summary of *connection*.

    For LE connections the summary contains:
    - the connection parameters (interval * 1.25, peripheral latency,
      supervision timeout * 10 — presumably both converted to milliseconds);
    - the four ``connection.data_length`` values, grouped as TX and RX pairs;
    - the TX/RX PHY names.

    Other transports log empty placeholders for those fields. The ATT MTU is
    always reported.
    """
    if connection.transport != BT_LE_TRANSPORT:
        connection_parameters = data_length = phy_state = ''
    else:
        phy = connection.phy
        phy_state = (
            f'PHY=TX:{le_phy_name(phy.tx_phy)}/RX:{le_phy_name(phy.rx_phy)}'
        )

        lengths = connection.data_length
        data_length = f'DL=(TX:{lengths[0]}/{lengths[1]},RX:{lengths[2]}/{lengths[3]})'

        params = connection.parameters
        # The trailing space is part of the historical output format.
        connection_parameters = (
            f'Parameters={params.connection_interval * 1.25:.2f}/'
            f'{params.peripheral_latency}/'
            f'{params.supervision_timeout * 10} '
        )

    fields = [
        color('@@@ Connection:', 'yellow'),
        connection_parameters,
        data_length,
        phy_state,
        f'MTU={connection.att_mtu}',
    ]
    logging.info(' '.join(fields))
165
166
def make_sdp_records(channel):
    """Build the SDP database advertising the RFCOMM speed-test service.

    Returns a dict with a single record, keyed by handle 0x00010001, whose
    attributes are:
    - the record handle;
    - the public browse root group;
    - the service class ``DEFAULT_RFCOMM_UUID``;
    - a protocol descriptor list of L2CAP followed by RFCOMM on *channel*.
    """
    record_handle = 0x00010001

    l2cap_layer = DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)])
    rfcomm_layer = DataElement.sequence(
        [
            DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
            DataElement.unsigned_integer_8(channel),
        ]
    )
    protocol_stack = DataElement.sequence([l2cap_layer, rfcomm_layer])

    attributes = [
        ServiceAttribute(
            SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
            DataElement.unsigned_integer_32(record_handle),
        ),
        ServiceAttribute(
            SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
        ),
        ServiceAttribute(
            SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(UUID(DEFAULT_RFCOMM_UUID))]),
        ),
        ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, protocol_stack),
    ]
    return {record_handle: attributes}
198
199
def log_stats(title, stats):
    """Log the min, max and average of the numbers in *stats* (in cyan).

    *stats* must be non-empty: ``min`` raises ValueError otherwise.
    """
    lowest = min(stats)
    highest = max(stats)
    mean = sum(stats) / len(stats)
    message = (
        f'### {title} stats: min={lowest:.2f}, max={highest:.2f}, average={mean:.2f}'
    )
    logging.info(color(message, 'cyan'))
215
216
class PacketType(enum.IntEnum):
    """Type tag stored in the first byte of every speed-test packet."""

    # Sent by Sender/Ping at the start of each run; Receiver/Pong reset their
    # per-run state when they receive it.
    RESET = 0
    # Numbered data packet: flags byte + 32-bit index, followed by padding.
    SEQUENCE = 1
    # Reply carrying the flags/index of the packet being acknowledged
    # (Receiver: only the last packet; Pong: every packet).
    ACK = 2


# Bit in the flags byte marking the final SEQUENCE packet of a run.
PACKET_FLAG_LAST = 1
224
225
226# -----------------------------------------------------------------------------
227# Sender
228# -----------------------------------------------------------------------------
class Sender:
    """Throughput test sender.

    Each run sends a RESET packet, then ``packet_count`` SEQUENCE packets, and
    waits for the peer's ACK. When the ACK arrives, the average TX speed is
    appended to ``stats``: bytes handed to ``send_packet`` (framing overhead
    not counted) divided by the seconds elapsed since the RESET was sent.

    Args:
        packet_io: transport exposing ``ready`` (an event), ``overhead_size``,
            ``send_packet()``, ``drain()`` and a ``packet_listener`` slot.
        start_delay: seconds to sleep before every run (falsy disables it).
        repeat: number of extra runs after the first one.
        repeat_delay: seconds to sleep between runs (falsy disables it).
        pace: ``None`` to send back-to-back; a positive number of milliseconds
            to sleep after each packet; otherwise await ``packet_io.drain()``
            after each packet.
        packet_size: size of each SEQUENCE packet including the framing
            overhead of ``packet_io``.
        packet_count: number of SEQUENCE packets per run.
    """

    def __init__(
        self,
        packet_io,
        start_delay,
        repeat,
        repeat_delay,
        pace,
        packet_size,
        packet_count,
    ):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.repeat = repeat
        self.repeat_delay = repeat_delay
        self.pace = pace
        self.start_time = 0
        self.bytes_sent = 0
        self.stats = []
        self.done = asyncio.Event()

    def reset(self):
        """The sending side keeps no per-run state that needs resetting."""

    async def run(self):
        """Wait for the I/O to be ready, then perform all runs."""
        logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        logging.info(color('--- Go!', 'blue'))

        for run in range(self.repeat + 1):
            self.done.clear()
            await self._wait_before_run(run)

            logging.info(color('=== Sending RESET', 'magenta'))
            await self.packet_io.send_packet(bytes([PacketType.RESET]))
            self.start_time = time.time()
            self.bytes_sent = 0

            await self._send_sequence()

            # on_packet_received sets `done` when the peer's ACK arrives.
            await self.done.wait()

            run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
            logging.info(color(f'=== {run_counter} Done!', 'magenta'))

            if self.repeat:
                log_stats('Run', self.stats)

        if self.repeat:
            logging.info(color('--- End of runs', 'blue'))

    async def _wait_before_run(self, run):
        """Sleep for the inter-run delay (runs after the first), then the start delay."""
        if run and self.repeat and self.repeat_delay:
            logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
            await asyncio.sleep(self.repeat_delay)

        if self.tx_start_delay:
            logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
            await asyncio.sleep(self.tx_start_delay)

    async def _send_sequence(self):
        """Send every SEQUENCE packet of one run, flagging the last one."""
        for tx_i in range(self.tx_packet_count):
            is_last = tx_i == self.tx_packet_count - 1
            header = struct.pack(
                '>bbI',
                PacketType.SEQUENCE,
                PACKET_FLAG_LAST if is_last else 0,
                tx_i,
            )
            # 6 = type + flags + 32-bit index; pad so that the framed packet is
            # exactly tx_packet_size bytes on the wire.
            padding = bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
            packet = header + padding

            logging.info(
                color(f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow')
            )
            self.bytes_sent += len(packet)
            await self.packet_io.send_packet(packet)
            await self._apply_pacing()

    async def _apply_pacing(self):
        """Throttle after a packet according to ``pace`` (see class docstring)."""
        if self.pace is None:
            return
        if self.pace > 0:
            await asyncio.sleep(self.pace / 1000)
        else:
            await self.packet_io.drain()

    def on_packet_received(self, packet):
        """Record the TX speed and finish the run when an ACK arrives."""
        try:
            packet_type, _ = parse_packet(packet)
        except ValueError:
            return

        if packet_type != PacketType.ACK:
            return

        elapsed = time.time() - self.start_time
        average_tx_speed = self.bytes_sent / elapsed
        self.stats.append(average_tx_speed)
        logging.info(
            color(
                f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
                f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)',
                'green',
            )
        )
        self.done.set()
331
332
333# -----------------------------------------------------------------------------
334# Receiver
335# -----------------------------------------------------------------------------
class Receiver:
    """Throughput test receiver.

    Tracks incoming SEQUENCE packets and logs instant, windowed and average
    receive speeds (bytes per second). It answers the packet flagged
    ``PACKET_FLAG_LAST`` with an ACK echoing that packet's flags and index.
    A RESET packet restarts the measurements.

    Args:
        packet_io: transport exposing ``overhead_size``, ``send_packet()`` and
            a ``packet_listener`` slot (set to this receiver).
        linger: when falsy, ``run()`` returns after the last packet is ACKed.
    """

    expected_packet_index: int
    start_timestamp: float
    last_timestamp: float

    # Number of most recent measurements used for the windowed speed.
    _SPEED_WINDOW = 64

    def __init__(self, packet_io, linger):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.linger = linger
        self.done = asyncio.Event()

    def reset(self):
        """Restart measurements: expect index 0, seed the timeline with (now, 0)."""
        self.expected_packet_index = 0
        self.measurements = [(time.time(), 0)]
        self.total_bytes_received = 0

    @staticmethod
    def _rate(byte_count, seconds):
        """Return ``byte_count / seconds``, or 0.0 when the interval is not positive.

        Stream transports can complete several packets within one ``on_packet``
        call, and ``time.time()`` may have coarse resolution. Two measurements
        can therefore share a timestamp, and dividing by that zero interval
        would raise ZeroDivisionError.
        """
        return byte_count / seconds if seconds > 0 else 0.0

    def on_packet_received(self, packet):
        """Handle one packet delivered by ``packet_io``.

        Malformed packets are ignored (the parsers log them). RESET restarts
        the measurements. Any other packet is decoded as flags + index: the
        speeds are logged and, if the packet is flagged as last, an ACK is
        sent back.
        """
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.RESET:
            logging.info(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        logging.info(
            f'<<< Received packet {packet_index}: '
            f'flags=0x{packet_flags:02X}, '
            f'{len(packet) + self.packet_io.overhead_size} bytes'
        )

        if packet_index != self.expected_packet_index:
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}'
                )
            )

        now = time.time()
        packet_size = len(packet)
        elapsed_since_start = now - self.measurements[0][0]
        elapsed_since_last = now - self.measurements[-1][0]
        self.measurements.append((now, packet_size))
        self.total_bytes_received += packet_size

        instant_rx_speed = self._rate(packet_size, elapsed_since_last)
        average_rx_speed = self._rate(self.total_bytes_received, elapsed_since_start)
        # The first entry of the window only anchors its start time, so its
        # byte count is excluded from the sum.
        window = self.measurements[-self._SPEED_WINDOW :]
        windowed_rx_speed = self._rate(
            sum(size for _, size in window[1:]), window[-1][0] - window[0][0]
        )
        logging.info(
            color(
                'Speed: '
                f'instant={instant_rx_speed:.4f}, '
                f'windowed={windowed_rx_speed:.4f}, '
                f'average={average_rx_speed:.4f}',
                'yellow',
            )
        )

        self.expected_packet_index = packet_index + 1

        if packet_flags & PACKET_FLAG_LAST:
            AsyncRunner.spawn(
                self.packet_io.send_packet(
                    struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
                )
            )
            logging.info(color('@@@ Received last packet', 'green'))
            if not self.linger:
                self.done.set()

    async def run(self):
        """Wait until the last packet has been handled (when not lingering)."""
        await self.done.wait()
        logging.info(color('=== Done!', 'magenta'))
418
419
420# -----------------------------------------------------------------------------
421# Ping
422# -----------------------------------------------------------------------------
class Ping:
    """Latency test initiator (the counterpart of ``Pong``).

    Each run sends a RESET, then ``packet_count`` SEQUENCE packets one at a
    time: the next packet is sent only after an ACK for the previous one
    arrives. The delay between sending a packet and receiving its ACK is
    recorded in milliseconds. Per-run min/max/average latencies are logged
    and accumulated across runs.

    Args:
        packet_io: transport exposing ``ready`` (an event), ``overhead_size``,
            ``send_packet()`` and a ``packet_listener`` slot.
        start_delay: seconds to sleep before every run (falsy disables it).
        repeat: number of extra runs after the first one.
        repeat_delay: seconds to sleep between runs (falsy disables it).
        pace: milliseconds to sleep before each ping (falsy disables it).
        packet_size: size of each SEQUENCE packet including the framing
            overhead of ``packet_io``.
        packet_count: number of pings per run.
    """

    def __init__(
        self,
        packet_io,
        start_delay,
        repeat,
        repeat_delay,
        pace,
        packet_size,
        packet_count,
    ):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.repeat = repeat
        self.repeat_delay = repeat_delay
        self.pace = pace
        self.done = asyncio.Event()
        self.current_packet_index = 0
        self.ping_sent_time = 0.0
        self.latencies = []
        self.min_stats = []
        self.max_stats = []
        self.avg_stats = []

    def reset(self):
        pass

    async def run(self):
        """Wait for the I/O to be ready, then perform all runs."""
        logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        logging.info(color('--- Go!', 'blue'))

        for run in range(self.repeat + 1):
            self.done.clear()

            if run > 0 and self.repeat and self.repeat_delay:
                logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
                await asyncio.sleep(self.repeat_delay)

            if self.tx_start_delay:
                logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
                await asyncio.sleep(self.tx_start_delay)

            logging.info(color('=== Sending RESET', 'magenta'))
            await self.packet_io.send_packet(bytes([PacketType.RESET]))

            self.current_packet_index = 0
            self.latencies = []
            await self.send_next_ping()

            # on_packet_received sets `done` once the last ping is answered.
            await self.done.wait()

            min_latency = min(self.latencies)
            max_latency = max(self.latencies)
            avg_latency = sum(self.latencies) / len(self.latencies)
            logging.info(
                color(
                    '@@@ Latencies: '
                    f'min={min_latency:.2f}, '
                    f'max={max_latency:.2f}, '
                    f'average={avg_latency:.2f}'
                )
            )

            self.min_stats.append(min_latency)
            self.max_stats.append(max_latency)
            self.avg_stats.append(avg_latency)

            run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
            logging.info(color(f'=== {run_counter} Done!', 'magenta'))

            if self.repeat:
                log_stats('Min Latency', self.min_stats)
                log_stats('Max Latency', self.max_stats)
                log_stats('Average Latency', self.avg_stats)

        if self.repeat:
            logging.info(color('--- End of runs', 'blue'))

    async def send_next_ping(self):
        """Send the SEQUENCE packet for ``current_packet_index``.

        Sleeps ``pace`` milliseconds first when pacing is enabled. As in
        ``Sender``, the padding subtracts ``packet_io.overhead_size`` so that
        the framed packet on the wire is ``tx_packet_size`` bytes long.
        """
        if self.pace:
            await asyncio.sleep(self.pace / 1000)

        is_last = self.current_packet_index == self.tx_packet_count - 1
        # 6 = type + flags + 32-bit index.
        packet = struct.pack(
            '>bbI',
            PacketType.SEQUENCE,
            PACKET_FLAG_LAST if is_last else 0,
            self.current_packet_index,
        ) + bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
        logging.info(color(f'Sending packet {self.current_packet_index}', 'yellow'))
        self.ping_sent_time = time.time()
        await self.packet_io.send_packet(packet)

    def on_packet_received(self, packet):
        """Record the latency of an ACK and send the next ping (or finish)."""
        elapsed = time.time() - self.ping_sent_time

        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type != PacketType.ACK:
            # Only ACKs drive the ping loop; any other packet would otherwise
            # trigger an extra, unsolicited ping.
            logging.info(color(f'!!! Unexpected packet type {packet_type.name}', 'red'))
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return

        latency = elapsed * 1000
        self.latencies.append(latency)
        logging.info(
            color(
                f'<<< Received ACK [{packet_index}], latency={latency:.2f}ms',
                'green',
            )
        )

        if packet_index == self.current_packet_index:
            self.current_packet_index += 1
        else:
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.current_packet_index} '
                    f'but received {packet_index}'
                )
            )

        if packet_flags & PACKET_FLAG_LAST:
            self.done.set()
            return

        AsyncRunner.spawn(self.send_next_ping())
561
562
563# -----------------------------------------------------------------------------
564# Pong
565# -----------------------------------------------------------------------------
class Pong:
    """Latency test responder: ACKs every SEQUENCE packet it receives.

    A RESET packet restarts the expected index. Unless *linger* is set, the
    responder finishes after ACKing the packet flagged ``PACKET_FLAG_LAST``.

    Args:
        packet_io: transport exposing ``send_packet()`` and a
            ``packet_listener`` slot (set to this responder).
        linger: when truthy, keep running after the last packet.
    """

    expected_packet_index: int

    def __init__(self, packet_io, linger):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.linger = linger
        self.done = asyncio.Event()

    def reset(self):
        """Expect packet index 0 next."""
        self.expected_packet_index = 0

    def on_packet_received(self, packet):
        """Reset on RESET; otherwise echo the packet's flags/index in an ACK."""
        try:
            kind, body = parse_packet(packet)
        except ValueError:
            return

        if kind == PacketType.RESET:
            logging.info(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            flags, index = parse_packet_sequence(body)
        except ValueError:
            return

        logging.info(
            color(
                f'<<< Received packet {index}: flags=0x{flags:02X}, {len(packet)} bytes',
                'green',
            )
        )

        if index != self.expected_packet_index:
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {index}'
                )
            )
        self.expected_packet_index = index + 1

        ack = struct.pack('>bbI', PacketType.ACK, flags, index)
        AsyncRunner.spawn(self.packet_io.send_packet(ack))

        is_last = flags & PACKET_FLAG_LAST
        if is_last and not self.linger:
            self.done.set()

    async def run(self):
        """Wait until the last packet has been ACKed (when not lingering)."""
        await self.done.wait()
        logging.info(color('=== Done!', 'magenta'))
624
625
626# -----------------------------------------------------------------------------
627# GattClient
628# -----------------------------------------------------------------------------
class GattClient:
    """Packet I/O over the GATT speed service, client side.

    On connection it optionally requests a larger ATT MTU, discovers the
    speed service and its TX/RX characteristics, and subscribes to RX
    notifications. Outgoing packets are written to TX; notified RX values are
    forwarded to ``packet_listener``. ``ready`` is set once discovery and
    subscription succeed, and cleared on disconnection.
    """

    def __init__(self, _device, att_mtu=None):
        self.att_mtu = att_mtu
        self.speed_rx = None
        self.speed_tx = None
        self.packet_listener = None
        self.ready = asyncio.Event()
        self.overhead_size = 0

    async def on_connection(self, connection):
        """Discover the speed service on *connection* and become ready."""
        peer = Peer(connection)

        if self.att_mtu:
            logging.info(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
            await peer.request_mtu(self.att_mtu)

        logging.info(color('*** Discovering services...', 'blue'))
        await peer.discover_services()

        services = peer.get_services_by_uuid(SPEED_SERVICE_UUID)
        if not services:
            logging.info(color('!!! Speed Service not found', 'red'))
            return
        service = services[0]

        logging.info(color('*** Discovering characteristics...', 'blue'))
        await service.discover_characteristics()

        tx_matches = service.get_characteristics_by_uuid(SPEED_TX_UUID)
        if not tx_matches:
            logging.info(color('!!! Speed TX not found', 'red'))
            return
        self.speed_tx = tx_matches[0]

        rx_matches = service.get_characteristics_by_uuid(SPEED_RX_UUID)
        if not rx_matches:
            logging.info(color('!!! Speed RX not found', 'red'))
            return
        self.speed_rx = rx_matches[0]

        logging.info(color('*** Subscribing to RX', 'blue'))
        await self.speed_rx.subscribe(self.on_packet_received)

        logging.info(color('*** Discovery complete', 'blue'))

        connection.on('disconnection', self.on_disconnection)
        self.ready.set()

    def on_disconnection(self, _):
        """Mark the I/O as no longer ready."""
        self.ready.clear()

    def on_packet_received(self, packet):
        """Forward a notified RX value to the listener, if any."""
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(packet)

    async def send_packet(self, packet):
        """Write *packet* to the TX characteristic."""
        await self.speed_tx.write_value(packet)

    async def drain(self):
        """Nothing to drain: each write is awaited in send_packet."""
688
689
690# -----------------------------------------------------------------------------
691# GattServer
692# -----------------------------------------------------------------------------
class GattServer:
    """Packet I/O over the GATT speed service, server side.

    Registers the speed service on *device*. Values written by the client to
    the TX characteristic are forwarded to ``packet_listener``, and outgoing
    packets are sent as notifications on the RX characteristic. ``ready``
    tracks whether RX notifications are enabled and is cleared on
    disconnection.
    """

    def __init__(self, device):
        self.device = device
        self.packet_listener = None
        self.ready = asyncio.Event()
        self.overhead_size = 0

        # Client -> server direction: writes are delivered to on_tx_write.
        self.speed_tx = Characteristic(
            SPEED_TX_UUID,
            Characteristic.Properties.WRITE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=self.on_tx_write),
        )
        # Server -> client direction: notifications only.
        self.speed_rx = Characteristic(SPEED_RX_UUID, Characteristic.Properties.NOTIFY, 0)

        device.add_services(
            [Service(SPEED_SERVICE_UUID, [self.speed_tx, self.speed_rx])]
        )

        self.speed_rx.on('subscription', self.on_rx_subscription)

    async def on_connection(self, connection):
        """Track disconnection of *connection*."""
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        """Mark the I/O as no longer ready."""
        self.ready.clear()

    def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled):
        """Become ready when RX notifications are enabled, unready otherwise."""
        if not notify_enabled:
            logging.info(color('*** RX un-subscription', 'blue'))
            self.ready.clear()
            return
        logging.info(color('*** RX subscription', 'blue'))
        self.ready.set()

    def on_tx_write(self, _, value):
        """Forward a value written to TX to the listener, if any."""
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(value)

    async def send_packet(self, packet):
        """Notify RX subscribers with *packet*."""
        await self.device.notify_subscribers(self.speed_rx, packet)

    async def drain(self):
        """Nothing to drain: each notification is awaited in send_packet."""
742
743
744# -----------------------------------------------------------------------------
745# StreamedPacketIO
746# -----------------------------------------------------------------------------
class StreamedPacketIO:
    """Packet framing over a byte stream (base of the L2CAP and RFCOMM I/O).

    Each packet travels as a 2-byte big-endian length followed by the packet
    bytes; ``overhead_size`` accounts for that prefix. ``on_packet`` runs a
    small state machine that reassembles packets however the stream was
    chunked, and passes each complete packet to ``packet_listener``.

    Subclasses set ``io_sink`` (a callable taking the framed bytes) and route
    incoming stream data to ``on_packet``.
    """

    def __init__(self):
        self.packet_listener = None
        self.io_sink = None
        self.rx_packet = b''
        self.rx_packet_header = b''
        self.rx_packet_need = 0
        self.overhead_size = 2

    def _deliver_rx_packet(self):
        """Pass the reassembled packet to the listener and reset for the next one."""
        if self.packet_listener:
            self.packet_listener.on_packet_received(self.rx_packet)

        self.rx_packet = b''
        self.rx_packet_header = b''

    def on_packet(self, packet):
        """Consume one chunk of stream data, emitting every packet it completes."""
        while packet:
            if self.rx_packet_need:
                # Inside a packet body: take as much as is still missing.
                chunk = packet[: self.rx_packet_need]
                self.rx_packet += chunk
                packet = packet[len(chunk) :]
                self.rx_packet_need -= len(chunk)
                if not self.rx_packet_need:
                    self._deliver_rx_packet()
                continue

            # Between packets: accumulate the 2-byte length header, which may
            # itself be split across chunks.
            header_bytes = packet[: 2 - len(self.rx_packet_header)]
            self.rx_packet_header += header_bytes
            packet = packet[len(header_bytes) :]
            if len(self.rx_packet_header) != 2:
                # Header still incomplete; all input has been consumed.
                return
            self.rx_packet_need = struct.unpack('>H', self.rx_packet_header)[0]
            if not self.rx_packet_need:
                # Zero-length frame: it is already complete. Deliver it now and
                # clear the header; otherwise the complete stale header would
                # be re-parsed forever without consuming any input.
                self._deliver_rx_packet()

    async def send_packet(self, packet):
        """Frame *packet* with its 2-byte length and write it to ``io_sink``."""
        if not self.io_sink:
            logging.info(color('!!! No sink, dropping packet', 'red'))
            return

        # pylint: disable-next=not-callable
        self.io_sink(struct.pack('>H', len(packet)) + packet)
787
788
789# -----------------------------------------------------------------------------
790# L2capClient
791# -----------------------------------------------------------------------------
class L2capClient(StreamedPacketIO):
    """Framed packet I/O over an LE credit-based L2CAP channel (initiator).

    On connection, opens a channel on ``psm`` with the configured credits,
    MTU and MPS. It then wires the channel to the length-prefixed framing
    inherited from StreamedPacketIO and sets ``ready``. If the channel cannot
    be opened, the error is logged and ``ready`` stays unset.
    """

    def __init__(
        self,
        _device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.psm, self.max_credits = psm, max_credits
        self.mtu, self.mps = mtu, mps
        self.l2cap_channel = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection: Connection) -> None:
        """Open the L2CAP channel over *connection* and attach to it."""
        connection.on('disconnection', self.on_disconnection)

        logging.info(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
        try:
            channel = await connection.create_l2cap_channel(
                spec=l2cap.LeCreditBasedChannelSpec(
                    psm=self.psm,
                    max_credits=self.max_credits,
                    mtu=self.mtu,
                    mps=self.mps,
                )
            )
            logging.info(color(f'*** L2CAP channel: {channel}', 'cyan'))
        except Exception as error:
            logging.info(color(f'!!! Connection failed: {error}', 'red'))
            return

        # Outgoing frames go to the channel; incoming data feeds the deframer.
        self.io_sink = channel.write
        self.l2cap_channel = channel
        channel.on('close', self.on_l2cap_close)
        channel.sink = self.on_packet

        self.ready.set()

    def on_disconnection(self, _):
        """No cleanup is performed on disconnection."""

    def on_l2cap_close(self):
        """Log that the channel was closed."""
        logging.info(color('*** L2CAP channel closed', 'red'))

    async def drain(self):
        """Wait for the channel's pending output to drain."""
        assert self.l2cap_channel
        await self.l2cap_channel.drain()
844
845
846# -----------------------------------------------------------------------------
847# L2capServer
848# -----------------------------------------------------------------------------
class L2capServer(StreamedPacketIO):
    """Framed packet I/O over an LE credit-based L2CAP channel (acceptor).

    Registers an L2CAP server for *psm* on *device* at construction time.
    Each incoming channel replaces the current one, is wired to the
    length-prefixed framing inherited from StreamedPacketIO, and sets
    ``ready``.
    """

    def __init__(
        self,
        device: Device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.l2cap_channel = None
        self.ready = asyncio.Event()

        spec = l2cap.LeCreditBasedChannelSpec(
            psm=psm, mtu=mtu, mps=mps, max_credits=max_credits
        )
        device.create_l2cap_server(spec=spec, handler=self.on_l2cap_channel)
        logging.info(
            color(f'### Listening for L2CAP connection on PSM {psm}', 'yellow')
        )

    async def on_connection(self, connection):
        """Track disconnection of *connection*."""
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        """No cleanup is performed on disconnection."""

    def on_l2cap_channel(self, l2cap_channel):
        """Attach to a newly accepted channel and become ready."""
        logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))

        # Outgoing frames go to the channel; incoming data feeds the deframer.
        self.io_sink = l2cap_channel.write
        self.l2cap_channel = l2cap_channel
        l2cap_channel.on('close', self.on_l2cap_close)
        l2cap_channel.sink = self.on_packet

        self.ready.set()

    def on_l2cap_close(self):
        """Log the closure and forget the channel."""
        logging.info(color('*** L2CAP channel closed', 'red'))
        self.l2cap_channel = None

    async def drain(self):
        """Wait for the channel's pending output to drain."""
        assert self.l2cap_channel
        await self.l2cap_channel.drain()
896
897
898# -----------------------------------------------------------------------------
899# RfcommClient
900# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
    """Packet I/O carried over an RFCOMM DLC that this side opens.

    ``on_connection`` resolves the RFCOMM channel (looking it up by ``uuid``
    when ``channel`` is 0), starts an RFCOMM client on the connection, opens a
    DLC on that channel and uses it as the packet transport; ``ready`` is set
    once the DLC is open. Tuning values left as None (or a falsy
    ``l2cap_mtu``) are simply not passed on, so the callee's defaults apply.
    """

    def __init__(
        self,
        device,
        channel,
        uuid,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        super().__init__()
        self.device = device
        # Channel 0 means "discover the channel from the service UUID".
        self.channel = channel
        self.uuid = uuid
        # Optional RFCOMM / DLC tuning.
        self.l2cap_mtu = l2cap_mtu
        self.max_frame_size = max_frame_size
        self.initial_credits = initial_credits
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        # The open DLC, assigned by on_connection() on success.
        self.rfcomm_session = None
        self.ready = asyncio.Event()

    async def _resolve_channel(self, connection):
        """Return the configured channel, or look it up by UUID when it is 0.

        The looked-up value may itself be 0, meaning no matching service.
        """
        if self.channel != 0:
            return self.channel

        logging.info(
            color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
        )
        found = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
            connection, self.uuid
        )
        logging.info(color(f'@@@ Channel number = {found}', 'cyan'))
        return found

    async def on_connection(self, connection):
        """Open the RFCOMM DLC over ``connection`` and wire it up for I/O."""
        connection.on('disconnection', self.on_disconnection)

        channel = await self._resolve_channel(connection)
        if channel == 0:
            logging.info(color('!!! No RFComm service with this UUID found', 'red'))
            await connection.disconnect()
            return

        logging.info(color('*** Starting RFCOMM client...', 'blue'))
        client_options = {'l2cap_mtu': self.l2cap_mtu} if self.l2cap_mtu else {}
        client = bumble.rfcomm.Client(connection, **client_options)
        multiplexer = await client.start()
        logging.info(color('*** Started', 'blue'))

        logging.info(color(f'### Opening session for channel {channel}...', 'yellow'))
        open_options = {
            name: value
            for name, value in (
                ('max_frame_size', self.max_frame_size),
                ('initial_credits', self.initial_credits),
            )
            if value is not None
        }
        try:
            session = await multiplexer.open_dlc(channel, **open_options)
        except bumble.core.ConnectionError as error:
            logging.info(color(f'!!! Session open failed: {error}', 'red'))
            await multiplexer.disconnect()
            return

        logging.info(color(f'### Session open: {session}', 'yellow'))
        if self.max_credits is not None:
            session.rx_max_credits = self.max_credits
        if self.credits_threshold is not None:
            session.rx_credits_threshold = self.credits_threshold

        # Incoming data goes to on_packet, outgoing data to the DLC.
        session.sink = self.on_packet
        self.io_sink = session.write
        self.rfcomm_session = session

        self.ready.set()

    def on_disconnection(self, _):
        """Handle the connection's 'disconnection' event; intentionally a no-op."""

    async def drain(self):
        """Await ``drain()`` on the open RFCOMM session (which must exist)."""
        session = self.rfcomm_session
        assert session
        await session.drain()
983
984
985# -----------------------------------------------------------------------------
986# RfcommServer
987# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
    """Packet I/O carried over an RFCOMM DLC accepted from the peer.

    The constructor registers an RFCOMM server on ``device``, listens for DLC
    connections on ``channel`` and publishes SDP records for the channel
    number returned by ``listen``. Each accepted DLC becomes the packet
    transport (see ``on_dlc``), at which point ``ready`` is set.
    """

    def __init__(
        self,
        device,
        channel,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        super().__init__()
        # Credit settings applied to each accepted DLC (None leaves the DLC's
        # own value unchanged).
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        # The currently accepted DLC, if any.
        self.dlc = None
        self.ready = asyncio.Event()

        # Create and register a server
        server_options = {}
        if l2cap_mtu:
            server_options['l2cap_mtu'] = l2cap_mtu
        rfcomm_server = bumble.rfcomm.Server(device, **server_options)

        # Listen for incoming DLC connections
        dlc_options = {}
        if max_frame_size is not None:
            dlc_options['max_frame_size'] = max_frame_size
        if initial_credits is not None:
            dlc_options['initial_credits'] = initial_credits
        channel_number = rfcomm_server.listen(self.on_dlc, channel, **dlc_options)

        # Setup the SDP to advertise this channel
        device.sdp_service_records = make_sdp_records(channel_number)

        logging.info(
            color(
                f'### Listening for RFComm connection on channel {channel_number}',
                'yellow',
            )
        )

    async def on_connection(self, connection):
        """Subscribe to the connection's 'disconnection' event."""
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        """Handle the connection's 'disconnection' event; intentionally a no-op."""

    def on_dlc(self, dlc):
        """Adopt an accepted DLC as the packet transport and signal readiness."""
        logging.info(color(f'*** DLC connected: {dlc}', 'blue'))
        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        self.dlc = dlc
        if self.max_credits is not None:
            dlc.rx_max_credits = self.max_credits
        if self.credits_threshold is not None:
            dlc.rx_credits_threshold = self.credits_threshold

        # Nothing else sets this event; as in the L2CAP server and the RFCOMM
        # client, release anything awaiting ``ready`` now that I/O can start.
        self.ready.set()

    async def drain(self):
        """Await ``drain()`` on the accepted DLC (which must exist)."""
        assert self.dlc
        await self.dlc.drain()
1048
1049
1050# -----------------------------------------------------------------------------
1051# Central
1052# -----------------------------------------------------------------------------
class Central(Connection.Listener):
    """Benchmark central: initiates a connection, then runs the selected role.

    Constructor arguments:
        transport: HCI transport name, opened with ``open_transport_or_link``.
        peripheral_address: address or name of the peripheral to connect to.
        classic: connect over BR/EDR when true, over LE otherwise.
        role_factory: callable building the role object from the mode object.
        mode_factory: callable building the mode object from the ``Device``.
        connection_interval: when truthy, used as both the minimum and the
            maximum preferred connection interval.
        phy: one of '1m', '2m', 'coded', or a falsy value for no preference.
        authenticate: authenticate the connection once connected.
        encrypt: encrypt the connection once connected (also implied by
            ``authenticate``).
        extended_data_length: optional ``(tx_octets, tx_time)`` pair to request.
    """

    def __init__(
        self,
        transport,
        peripheral_address,
        classic,
        role_factory,
        mode_factory,
        connection_interval,
        phy,
        authenticate,
        encrypt,
        extended_data_length,
    ):
        super().__init__()
        self.transport = transport
        self.peripheral_address = peripheral_address
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        self.authenticate = authenticate
        # Authentication is always followed by encryption.
        self.encrypt = encrypt or authenticate
        self.extended_data_length = extended_data_length
        self.device = None
        # The live connection; reset to None by on_disconnection().
        self.connection = None

        if phy:
            self.phy = {
                '1m': HCI_LE_1M_PHY,
                '2m': HCI_LE_2M_PHY,
                'coded': HCI_LE_CODED_PHY,
            }[phy]
        else:
            self.phy = None

        if connection_interval:
            connection_parameter_preferences = ConnectionParametersPreferences()
            connection_parameter_preferences.connection_interval_min = (
                connection_interval
            )
            connection_parameter_preferences.connection_interval_max = (
                connection_interval
            )

            # Preferences for the 1M PHY are always set.
            self.connection_parameter_preferences = {
                HCI_LE_1M_PHY: connection_parameter_preferences,
            }

            if self.phy not in (None, HCI_LE_1M_PHY):
                # Add a connection parameters entry for this PHY.
                self.connection_parameter_preferences[self.phy] = (
                    connection_parameter_preferences
                )
        else:
            self.connection_parameter_preferences = None

    async def run(self):
        """Open HCI, connect, configure the link, run the role, then disconnect."""
        logging.info(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            logging.info(color('>>> Connected', 'green'))

            central_address = DEFAULT_CENTRAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink
            )
            mode = self.mode_factory(self.device)
            role = self.role_factory(mode)
            self.device.classic_enabled = self.classic

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(False)
                await self.device.set_connectable(False)

            if not await self._connect():
                return

            # Wait a bit after the connection, some controllers aren't very good when
            # we start sending data right away while some connection parameters are
            # updated post connection
            await asyncio.sleep(DEFAULT_POST_CONNECTION_WAIT_TIME)

            await self._configure_connection()

            await mode.on_connection(self.connection)

            await role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

            # The peer may have disconnected during the run or the linger
            # period; on_disconnection() then already cleared self.connection
            # and there is nothing left to disconnect.
            if self.connection is not None:
                await self.connection.disconnect()

    async def _connect(self):
        """Connect to the peripheral.

        Returns True once connected (``self.connection`` is set and this object
        is installed as its listener), or False after logging why the attempt
        failed.
        """
        logging.info(
            color(f'### Connecting to {self.peripheral_address}...', 'cyan')
        )
        try:
            self.connection = await self.device.connect(
                self.peripheral_address,
                connection_parameters_preferences=self.connection_parameter_preferences,
                transport=BT_BR_EDR_TRANSPORT if self.classic else BT_LE_TRANSPORT,
            )
        except CommandTimeoutError:
            logging.info(color('!!! Connection timed out', 'red'))
            return False
        except bumble.core.ConnectionError as error:
            logging.info(color(f'!!! Connection error: {error}', 'red'))
            return False
        except HCI_StatusError as error:
            logging.info(color(f'!!! Connection failed: {error.error_name}', 'red'))
            return False

        logging.info(color('### Connected', 'cyan'))
        self.connection.listener = self
        print_connection(self.connection)
        return True

    async def _configure_connection(self):
        """Apply the optional link settings: data length, auth, encryption, PHY."""
        # Request a new data length if requested
        if self.extended_data_length:
            logging.info(color('+++ Requesting extended data length', 'cyan'))
            await self.connection.set_data_length(
                self.extended_data_length[0], self.extended_data_length[1]
            )

        # Authenticate if requested
        if self.authenticate:
            logging.info(color('*** Authenticating...', 'cyan'))
            await self.connection.authenticate()
            logging.info(color('*** Authenticated', 'cyan'))

        # Encrypt if requested
        if self.encrypt:
            logging.info(color('*** Enabling encryption...', 'cyan'))
            await self.connection.encrypt()
            logging.info(color('*** Encryption on', 'cyan'))

        # Set the PHY if requested; failure is logged but not fatal.
        if self.phy is not None:
            try:
                await self.connection.set_phy(
                    tx_phys=[self.phy], rx_phys=[self.phy]
                )
            except HCI_Error as error:
                logging.info(
                    color(
                        f'!!! Unable to set the PHY: {error.error_name}', 'yellow'
                    )
                )

    def on_disconnection(self, reason):
        """Connection listener: log the reason and forget the connection."""
        logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)
1215
1216# -----------------------------------------------------------------------------
1217# Peripheral
1218# -----------------------------------------------------------------------------
class Peripheral(Device.Listener, Connection.Listener):
    """Benchmark peripheral: waits for a connection, then runs the selected role.

    Constructor arguments:
        transport: HCI transport name, opened with ``open_transport_or_link``.
        classic: become discoverable/connectable over BR/EDR when true,
            advertise over LE otherwise.
        extended_data_length: optional ``(tx_octets, tx_time)`` pair requested
            on each new connection.
        role_factory: callable building the role object from the mode object.
        mode_factory: callable building the mode object from the ``Device``.
    """

    def __init__(
        self, transport, classic, extended_data_length, role_factory, mode_factory
    ):
        # Initialize the listener base classes, as Central does.
        super().__init__()
        self.transport = transport
        self.classic = classic
        self.extended_data_length = extended_data_length
        self.role_factory = role_factory
        self.role = None
        self.mode_factory = mode_factory
        self.mode = None
        self.device = None
        # The live connection; reset to None by on_disconnection().
        self.connection = None
        # Set by on_connection() to wake up run().
        self.connected = asyncio.Event()

    async def run(self):
        """Open HCI, wait for a connection, then run the mode and role on it."""
        logging.info(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            logging.info(color('>>> Connected', 'green'))

            peripheral_address = DEFAULT_PERIPHERAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_PERIPHERAL_NAME, peripheral_address, hci_source, hci_sink
            )
            self.device.listener = self
            self.mode = self.mode_factory(self.device)
            self.role = self.role_factory(self.mode)
            self.device.classic_enabled = self.classic

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(True)
                await self.device.set_connectable(True)
                waiting_on = self.device.public_address
            else:
                await self.device.start_advertising(auto_restart=True)
                waiting_on = peripheral_address

            logging.info(
                color(f'### Waiting for connection on {waiting_on}...', 'cyan')
            )

            await self.connected.wait()
            logging.info(color('### Connected', 'cyan'))

            await self.mode.on_connection(self.connection)
            await self.role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_connection(self, connection):
        """Device listener: adopt the new connection and wake up run()."""
        connection.listener = self
        self.connection = connection
        self.connected.set()

        # Stop being discoverable and connectable
        if self.classic:
            AsyncRunner.spawn(self.device.set_discoverable(False))
            AsyncRunner.spawn(self.device.set_connectable(False))

        # Request a new data length if needed
        if self.extended_data_length:
            logging.info(color('+++ Requesting extended data length', 'cyan'))
            AsyncRunner.spawn(
                connection.set_data_length(
                    self.extended_data_length[0], self.extended_data_length[1]
                )
            )

    def on_disconnection(self, reason):
        """Connection listener: reset the role and, for BR/EDR, accept again."""
        logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None
        self.role.reset()

        if self.classic:
            AsyncRunner.spawn(self.device.set_discoverable(True))
            AsyncRunner.spawn(self.device.set_connectable(True))

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)
1322
1323# -----------------------------------------------------------------------------
def create_mode_factory(ctx, default_mode):
    """Return a callable that builds the packet-I/O "mode" object for a device.

    The mode is ``ctx.obj['mode']`` when set, otherwise ``default_mode``. The
    per-mode options are read from ``ctx.obj`` only when the returned callable
    is invoked. Invoking it for an unknown mode raises ``ValueError``.
    """
    selected_mode = default_mode if ctx.obj['mode'] is None else ctx.obj['mode']

    def l2cap_kwargs():
        # Options shared by the L2CAP client and server.
        return {
            'psm': ctx.obj['l2cap_psm'],
            'mtu': ctx.obj['l2cap_mtu'],
            'mps': ctx.obj['l2cap_mps'],
            'max_credits': ctx.obj['l2cap_max_credits'],
        }

    def rfcomm_kwargs():
        # Options shared by the RFCOMM client and server.
        return {
            'channel': ctx.obj['rfcomm_channel'],
            'l2cap_mtu': ctx.obj['rfcomm_l2cap_mtu'],
            'max_frame_size': ctx.obj['rfcomm_max_frame_size'],
            'initial_credits': ctx.obj['rfcomm_initial_credits'],
            'max_credits': ctx.obj['rfcomm_max_credits'],
            'credits_threshold': ctx.obj['rfcomm_credits_threshold'],
        }

    # Lambdas keep construction (and the class-name lookups) lazy.
    builders = {
        'gatt-client': lambda device: GattClient(device, att_mtu=ctx.obj['att_mtu']),
        'gatt-server': lambda device: GattServer(device),
        'l2cap-client': lambda device: L2capClient(device, **l2cap_kwargs()),
        'l2cap-server': lambda device: L2capServer(device, **l2cap_kwargs()),
        'rfcomm-client': lambda device: RfcommClient(
            device, uuid=ctx.obj['rfcomm_uuid'], **rfcomm_kwargs()
        ),
        'rfcomm-server': lambda device: RfcommServer(device, **rfcomm_kwargs()),
    }

    def create_mode(device):
        build = builders.get(selected_mode)
        if build is None:
            raise ValueError('invalid mode')
        return build(device)

    return create_mode
1380
1381
1382# -----------------------------------------------------------------------------
def create_role_factory(ctx, default_role):
    """Return a callable that builds the benchmark role for a packet-I/O object.

    The role is ``ctx.obj['role']`` when set, otherwise ``default_role``.
    'sender' and 'ping' share the transmit options; 'receiver' and 'pong' only
    take the linger flag. Invoking the callable for an unknown role raises
    ``ValueError``.
    """
    selected_role = default_role if ctx.obj['role'] is None else ctx.obj['role']

    def create_role(packet_io):
        options = ctx.obj
        if selected_role in ('sender', 'ping'):
            role_class = Sender if selected_role == 'sender' else Ping
            return role_class(
                packet_io,
                start_delay=options['start_delay'],
                repeat=options['repeat'],
                repeat_delay=options['repeat_delay'],
                pace=options['pace'],
                packet_size=options['packet_size'],
                packet_count=options['packet_count'],
            )

        if selected_role in ('receiver', 'pong'):
            role_class = Receiver if selected_role == 'receiver' else Pong
            return role_class(packet_io, options['linger'])

        raise ValueError('invalid role')

    return create_role
1420
1421
1422# -----------------------------------------------------------------------------
1423# Main
1424# -----------------------------------------------------------------------------
@click.group()
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.option('--role', type=click.Choice(['sender', 'receiver', 'ping', 'pong']))
@click.option(
    '--mode',
    type=click.Choice(
        [
            'gatt-client',
            'gatt-server',
            'l2cap-client',
            'l2cap-server',
            'rfcomm-client',
            'rfcomm-server',
        ]
    ),
)
@click.option(
    '--att-mtu',
    metavar='MTU',
    type=click.IntRange(23, 517),
    help='GATT MTU (gatt-client mode)',
)
@click.option(
    '--extended-data-length',
    help='Request a data length upon connection, specified as tx_octets/tx_time',
)
@click.option(
    '--rfcomm-channel',
    type=int,
    default=DEFAULT_RFCOMM_CHANNEL,
    help='RFComm channel to use',
)
@click.option(
    '--rfcomm-uuid',
    default=DEFAULT_RFCOMM_UUID,
    help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)',
)
@click.option(
    '--rfcomm-l2cap-mtu',
    type=int,
    help='RFComm L2CAP MTU',
)
@click.option(
    '--rfcomm-max-frame-size',
    type=int,
    help='RFComm maximum frame size',
)
@click.option(
    '--rfcomm-initial-credits',
    type=int,
    help='RFComm initial credits',
)
@click.option(
    '--rfcomm-max-credits',
    type=int,
    help='RFComm max credits',
)
@click.option(
    '--rfcomm-credits-threshold',
    type=int,
    help='RFComm credits threshold',
)
@click.option(
    '--l2cap-psm',
    type=int,
    default=DEFAULT_L2CAP_PSM,
    help='L2CAP PSM to use',
)
@click.option(
    '--l2cap-mtu',
    type=int,
    default=DEFAULT_L2CAP_MTU,
    help='L2CAP MTU to use',
)
@click.option(
    '--l2cap-mps',
    type=int,
    default=DEFAULT_L2CAP_MPS,
    help='L2CAP MPS to use',
)
@click.option(
    '--l2cap-max-credits',
    type=int,
    default=DEFAULT_L2CAP_MAX_CREDITS,
    help='L2CAP maximum number of credits allowed for the peer',
)
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(8, 4096),
    default=500,
    help='Packet size (sender or ping role)',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    type=int,
    default=10,
    help='Packet count (sender or ping role)',
)
@click.option(
    '--start-delay',
    '-sd',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Start delay (sender or ping role)',
)
@click.option(
    '--repeat',
    metavar='N',
    type=int,
    default=0,
    help=(
        'Repeat the run N times (sender and ping roles) '
        '(0, which is the default, to run just once)'
    ),
)
@click.option(
    '--repeat-delay',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Delay, in seconds, between repeats',
)
@click.option(
    '--pace',
    metavar='MILLISECONDS',
    type=int,
    default=0,
    help=(
        'Wait N milliseconds between packets '
        '(0, which is the default, to send as fast as possible)'
    ),
)
@click.option(
    '--linger',
    is_flag=True,
    help="Don't exit at the end of a run (receiver and pong roles)",
)
@click.pass_context
def bench(
    ctx,
    device_config,
    role,
    mode,
    att_mtu,
    extended_data_length,
    packet_size,
    packet_count,
    start_delay,
    repeat,
    repeat_delay,
    pace,
    linger,
    rfcomm_channel,
    rfcomm_uuid,
    rfcomm_l2cap_mtu,
    rfcomm_max_frame_size,
    rfcomm_initial_credits,
    rfcomm_max_credits,
    rfcomm_credits_threshold,
    l2cap_psm,
    l2cap_mtu,
    l2cap_mps,
    l2cap_max_credits,
):
    """Benchmark packet transfers between a central and a peripheral."""
    # Stash every group option in ctx.obj for the central/peripheral commands.
    ctx.ensure_object(dict)
    ctx.obj['device_config'] = device_config
    ctx.obj['role'] = role
    ctx.obj['mode'] = mode
    ctx.obj['att_mtu'] = att_mtu
    ctx.obj['rfcomm_channel'] = rfcomm_channel
    ctx.obj['rfcomm_uuid'] = rfcomm_uuid
    ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu
    ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size
    ctx.obj['rfcomm_initial_credits'] = rfcomm_initial_credits
    ctx.obj['rfcomm_max_credits'] = rfcomm_max_credits
    ctx.obj['rfcomm_credits_threshold'] = rfcomm_credits_threshold
    ctx.obj['l2cap_psm'] = l2cap_psm
    ctx.obj['l2cap_mtu'] = l2cap_mtu
    ctx.obj['l2cap_mps'] = l2cap_mps
    ctx.obj['l2cap_max_credits'] = l2cap_max_credits
    ctx.obj['packet_size'] = packet_size
    ctx.obj['packet_count'] = packet_count
    ctx.obj['start_delay'] = start_delay
    ctx.obj['repeat'] = repeat
    ctx.obj['repeat_delay'] = repeat_delay
    ctx.obj['pace'] = pace
    ctx.obj['linger'] = linger

    # --extended-data-length must be exactly "TX_OCTETS/TX_TIME". Reject
    # anything else now with a usage error, instead of failing later when the
    # two values are indexed after a connection has been established.
    if extended_data_length:
        try:
            tx_octets, tx_time = (
                int(value) for value in extended_data_length.split('/')
            )
        except ValueError as error:
            raise click.BadParameter(
                f'expected TX_OCTETS/TX_TIME, got {extended_data_length!r}',
                param_hint="'--extended-data-length'",
            ) from error
        ctx.obj['extended_data_length'] = [tx_octets, tx_time]
    else:
        ctx.obj['extended_data_length'] = None

    # RFCOMM modes run over BR/EDR; all other modes run over LE.
    ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server')
1624
1625
@bench.command()
@click.argument('transport')
@click.option(
    '--peripheral',
    'peripheral_address',
    metavar='ADDRESS_OR_NAME',
    default=DEFAULT_PERIPHERAL_ADDRESS,
    help='Address or name to connect to',
)
@click.option(
    '--connection-interval',
    '--ci',
    metavar='CONNECTION_INTERVAL',
    type=int,
    help='Connection interval (in ms)',
)
@click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use')
@click.option('--authenticate', is_flag=True, help='Authenticate (RFComm only)')
@click.option('--encrypt', is_flag=True, help='Encrypt the connection (RFComm only)')
@click.pass_context
def central(
    ctx, transport, peripheral_address, connection_interval, phy, authenticate, encrypt
):
    """Run as a central (initiates the connection)"""
    # Without --role / --mode, the central sends data as a GATT client.
    role_factory = create_role_factory(ctx, 'sender')
    mode_factory = create_mode_factory(ctx, 'gatt-client')
    classic = ctx.obj['classic']
    # --authenticate implies --encrypt.
    wants_encryption = encrypt or authenticate

    async def run_central():
        # Built inside the coroutine so it is created on the running loop.
        bench_central = Central(
            transport,
            peripheral_address,
            classic,
            role_factory,
            mode_factory,
            connection_interval,
            phy,
            authenticate,
            wants_encryption,
            ctx.obj['extended_data_length'],
        )
        await bench_central.run()

    asyncio.run(run_central())
1669
1670
@bench.command()
@click.argument('transport')
@click.pass_context
def peripheral(ctx, transport):
    """Run as a peripheral (waits for a connection)"""
    # Without --role / --mode, the peripheral receives data as a GATT server.
    role_factory = create_role_factory(ctx, 'receiver')
    mode_factory = create_mode_factory(ctx, 'gatt-server')

    async def run_peripheral():
        # Built inside the coroutine so its asyncio objects belong to the
        # running loop.
        options = ctx.obj
        bench_peripheral = Peripheral(
            transport,
            options['classic'],
            options['extended_data_length'],
            role_factory,
            mode_factory,
        )
        await bench_peripheral.run()

    asyncio.run(run_peripheral())
1689
1690
def main():
    """Console entry point: configure logging, then hand off to the CLI."""
    # BUMBLE_LOGLEVEL (default INFO) selects the root logging level.
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=level_name)
    bench()


# -----------------------------------------------------------------------------
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
1699