• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import logging
20import os
21import random
22import pytest
23
24from bumble.controller import Controller
25from bumble.link import LocalLink
26from bumble.device import Device
27from bumble.host import Host
28from bumble.transport import AsyncPipeSink
29from bumble.core import ProtocolError
30from bumble.l2cap import L2CAP_Connection_Request
31
32
33# -----------------------------------------------------------------------------
34# Logging
35# -----------------------------------------------------------------------------
36logger = logging.getLogger(__name__)
37
38
39# -----------------------------------------------------------------------------
class TwoDevices:
    """Two bumble Devices whose controllers share a single LocalLink.

    Attributes (all lists are indexed by device number, 0 or 1):
        connections: connection recorded via on_connection, else None.
        link: the LocalLink both controllers are attached to.
        controllers: Controller 'C1' and Controller 'C2'.
        devices: one Device per controller, each driven by its own Host.
        paired: keys recorded via on_paired, else None.
    """

    def __init__(self):
        self.connections = [None, None]

        self.link = LocalLink()
        self.controllers = [Controller(name, link=self.link) for name in ('C1', 'C2')]

        device_addresses = ('F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0')
        self.devices = [
            Device(
                address=device_address,
                host=Host(controller, AsyncPipeSink(controller)),
            )
            for controller, device_address in zip(self.controllers, device_addresses)
        ]

        self.paired = [None, None]

    def on_connection(self, which, connection):
        """Store `connection` as the connection seen by device number `which`."""
        self.connections[which] = connection

    def on_paired(self, which, keys):
        """Store `keys` as the pairing keys reported by device number `which`."""
        self.paired[which] = keys
67
68
69# -----------------------------------------------------------------------------
async def setup_connection():
    """Create a TwoDevices pair, power both devices on and connect them.

    Device 0 initiates the connection to device 1's random address. Each
    device's 'connection' event is routed into TwoDevices.on_connection, and
    both recorded connections are asserted to be present afterwards.

    Returns:
        The connected TwoDevices instance.
    """
    two_devices = TwoDevices()

    def route_connection_to(index):
        # Capture `index` by value; a lambda built directly in the loop
        # would late-bind it and route both events to the last index.
        return lambda connection: two_devices.on_connection(index, connection)

    for index, device in enumerate(two_devices.devices):
        device.on('connection', route_connection_to(index))

    for device in two_devices.devices:
        await device.power_on()

    initiator, responder = two_devices.devices
    await initiator.connect(responder.random_address)

    # Both sides must have reported their end of the connection.
    assert all(connection is not None for connection in two_devices.connections)

    return two_devices
94
95
96# -----------------------------------------------------------------------------
def test_helpers():
    """Check PSM (de)serialization and a Connection Request round trip."""
    # PSMs serialize least-significant byte first; 0x01 still occupies two
    # bytes while 0x242311 occupies three.
    serialize_cases = (
        (0x01, bytes([0x01, 0x00])),
        (0x1023, bytes([0x23, 0x10])),
        (0x242311, bytes([0x11, 0x23, 0x24])),
    )
    for value, encoded in serialize_cases:
        assert L2CAP_Connection_Request.serialize_psm(value) == encoded

    # parse_psm(data, 1) starts after a leading 0x00 byte and yields the
    # offset just past the PSM together with the decoded value; the trailing
    # 0x44 byte must not be consumed.
    parse_cases = (
        (bytes([0x00, 0x01, 0x00, 0x44]), 3, 0x01),
        (bytes([0x00, 0x23, 0x10, 0x44]), 3, 0x1023),
        (bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 4, 0x242311),
    )
    for data, expected_offset, expected_psm in parse_cases:
        (offset, psm) = L2CAP_Connection_Request.parse_psm(data, 1)
        assert offset == expected_offset
        assert psm == expected_psm

    # A full PDU must survive bytes() -> from_bytes() unchanged.
    request = L2CAP_Connection_Request(psm=0x01, source_cid=0x44)
    decoded = L2CAP_Connection_Request.from_bytes(bytes(request))
    assert (decoded.psm, decoded.source_cid) == (request.psm, request.source_cid)
130
131
132# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
    """Open an L2CAP channel on a fixed PSM, send data, then close it.

    Checks that:
      * opening a channel on a PSM with no registered server raises
        ProtocolError;
      * after a server is registered, the channel opens and every byte
        written by the client reaches the server unchanged;
      * disconnecting from the client side emits 'close' on both ends.
    """
    devices = await setup_connection()
    psm = 1234

    # Check that if there's no one listening, we can't connect.
    with pytest.raises(ProtocolError):
        await devices.connections[0].open_l2cap_channel(psm)

    # Now add a listener that records the server-side channel and its data.
    incoming_channel = None
    received = []

    def on_coc(channel):
        nonlocal incoming_channel
        incoming_channel = channel

        def on_data(data):
            received.append(data)

        channel.sink = on_data

    devices.devices[1].register_l2cap_channel_server(psm, on_coc)
    l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)

    messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000))
    for message in messages:
        l2cap_channel.write(message)
        await asyncio.sleep(0)

    await l2cap_channel.drain()

    # Test closing. Fail clearly, not with an AttributeError, if the server
    # callback never handed us the incoming channel.
    assert incoming_channel is not None, 'server never received the channel'

    closed = [False, False]
    closed_events = [asyncio.Event(), asyncio.Event()]

    def on_close(which):
        closed[which] = True
        closed_events[which].set()

    l2cap_channel.on('close', lambda: on_close(0))
    incoming_channel.on('close', lambda: on_close(1))
    await l2cap_channel.disconnect()

    # Wait for both ends *before* asserting. Nothing here guarantees the
    # peer's 'close' has been emitted by the time disconnect() returns, and
    # asserting first makes the later wait pointless. The timeout turns a
    # lost event into a test failure instead of a hang.
    await asyncio.wait_for(
        asyncio.gather(*(event.wait() for event in closed_events)),
        timeout=5.0,
    )
    assert closed == [True, True]

    sent_bytes = b''.join(messages)
    received_bytes = b''.join(received)
    assert sent_bytes == received_bytes
183
184
185# -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps):
    """Push several payloads through a fresh L2CAP channel and verify them.

    The server on device 1 is registered with psm=0 and the given
    `max_credits`, `mtu` and `mps`. The client on device 0 connects to the
    PSM that registration returns. Payloads of increasing size are written,
    with randomly interleaved drains. The bytes the server received,
    concatenated, must equal the bytes sent.
    """
    devices = await setup_connection()

    server_inbox = []

    def on_coc(channel):
        # Every chunk delivered to the server is appended to the inbox.
        channel.sink = server_inbox.append

    psm = devices.devices[1].register_l2cap_channel_server(
        psm=0, server=on_coc, max_credits=max_credits, mtu=mtu, mps=mps
    )
    client = await devices.connections[0].open_l2cap_channel(psm)

    pattern = bytes([1, 2, 3, 4, 5, 6, 7])
    payloads = [pattern * repeat for repeat in (3, 10, 100, 789)]
    for payload in payloads:
        client.write(payload)
        await asyncio.sleep(0)
        # Drain mid-stream about one time in six to vary the interleaving.
        if random.randint(0, 5) == 1:
            await client.drain()

    await client.drain()
    await client.disconnect()

    assert b''.join(payloads) == b''.join(server_inbox)
215
216
@pytest.mark.asyncio
async def test_transfer():
    """Run transfer_payload over every credit/MTU/MPS combination."""
    credit_choices = (1, 10, 100, 10000)
    size_choices = (50, 255, 256, 1000)
    parameter_grid = [
        (max_credits, mtu, mps)
        for max_credits in credit_choices
        for mtu in size_choices
        for mps in size_choices
    ]
    for max_credits, mtu, mps in parameter_grid:
        await transfer_payload(max_credits, mtu, mps)
224
225
226# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bidirectional_transfer():
    """Check that both ends of one L2CAP channel can send and receive.

    For each payload, the client writes and drains, yields once to the event
    loop, and then the server echoes the same payload back and drains. Both
    sides must end up with the full concatenated payload stream.
    """
    devices = await setup_connection()

    client_inbox = []
    server_inbox = []
    server_end = None

    def on_server_coc(channel):
        nonlocal server_end
        server_end = channel
        channel.sink = server_inbox.append

    psm = devices.devices[1].register_l2cap_channel_server(psm=0, server=on_server_coc)
    client_end = await devices.connections[0].open_l2cap_channel(psm)
    client_end.sink = client_inbox.append

    payloads = [bytes([1, 2, 3, 4, 5, 6, 7]) * repeat for repeat in (3, 10, 100)]
    for payload in payloads:
        # Client -> server, flushed before the server replies.
        client_end.write(payload)
        await client_end.drain()
        await asyncio.sleep(0)
        # Server -> client with the same bytes.
        server_end.write(payload)
        await server_end.drain()

    await client_end.disconnect()

    expected = b''.join(payloads)
    assert b''.join(client_inbox) == expected
    assert b''.join(server_inbox) == expected
266
267
268# -----------------------------------------------------------------------------
async def run():
    """Run every test in this module in sequence, without pytest."""
    test_helpers()
    for coroutine_test in (
        test_basic_connection,
        test_transfer,
        test_bidirectional_transfer,
    ):
        await coroutine_test()
274
275
276# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # BUMBLE_LOGLEVEL (case-insensitive, default INFO) selects the log level.
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)
    asyncio.run(run())
280