• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import functools
20import logging
21import os
22from types import LambdaType
23import pytest
24from unittest import mock
25
26from bumble.core import (
27    BT_BR_EDR_TRANSPORT,
28    BT_LE_TRANSPORT,
29    BT_PERIPHERAL_ROLE,
30    ConnectionParameters,
31)
32from bumble.device import AdvertisingParameters, Connection, Device
33from bumble.host import AclPacketQueue, Host
34from bumble.hci import (
35    HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
36    HCI_COMMAND_STATUS_PENDING,
37    HCI_CREATE_CONNECTION_COMMAND,
38    HCI_SUCCESS,
39    HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
40    Address,
41    OwnAddressType,
42    HCI_Command_Complete_Event,
43    HCI_Command_Status_Event,
44    HCI_Connection_Complete_Event,
45    HCI_Connection_Request_Event,
46    HCI_Error,
47    HCI_Packet,
48)
49from bumble.gatt import (
50    GATT_GENERIC_ACCESS_SERVICE,
51    GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
52    GATT_DEVICE_NAME_CHARACTERISTIC,
53    GATT_APPEARANCE_CHARACTERISTIC,
54)
55
56from .test_utils import TwoDevices, async_barrier
57
58# -----------------------------------------------------------------------------
59# Constants
60# -----------------------------------------------------------------------------
# Upper bound, in seconds, handed to asyncio.wait_for() by the tests that
# expect an operation to fail quickly rather than hang.
_TIMEOUT: float = 0.1

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
67
68
69# -----------------------------------------------------------------------------
class Sink:
    """Host packet sink that feeds every packet into a generator-based flow.

    ``flow`` is a generator that receives packets through ``send()``. It is
    advanced to its first ``yield`` on construction, so it is ready to take
    the first packet immediately.
    """

    def __init__(self, flow):
        self.flow = flow
        next(flow)  # Prime: run the generator up to its first ``yield``.

    def on_packet(self, packet):
        """Resume the flow, delivering ``packet`` as the value of its pending ``yield``."""
        self.flow.send(packet)
77
78
79# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_device_connect_parallel():
    """Connect one device to two peers concurrently over BR/EDR.

    d0 calls ``connect()`` for d1 and d2 at the same time while d1 and d2 are
    waiting in ``accept()``. There is no controller: each host's packet sink is
    a scripted generator ("flow") that checks every HCI command the host sends
    and answers by injecting the HCI events a controller would produce. The
    test then checks that all four ``Connection`` objects were created and that
    both ends of each link agree on its handle.
    """
    d0 = Device(host=Host(None, None))
    d1 = Device(host=Host(None, None))
    d2 = Device(host=Host(None, None))

    def _send(packet):
        # Send callback for the ACL packet queues: anything sent is dropped.
        pass

    for device in (d0, d1, d2):
        device.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
        # Enable classic (BR/EDR) support.
        device.classic_enabled = True

    # Set public addresses.
    d0.public_address = Address(
        'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d1.public_address = Address(
        'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d2.public_address = Address(
        'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )

    def d0_flow():
        # First command from d0: create the connection to d1.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
        assert packet.bd_addr == d1.public_address

        d0.host.on_hci_packet(
            HCI_Command_Status_Event(
                status=HCI_COMMAND_STATUS_PENDING,
                num_hci_command_packets=1,
                command_opcode=HCI_CREATE_CONNECTION_COMMAND,
            )
        )

        d1.host.on_hci_packet(
            HCI_Connection_Request_Event(
                bd_addr=d0.public_address,
                class_of_device=0,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
            )
        )

        # Second command from d0: create the connection to d2, issued while
        # the first one is still pending.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
        assert packet.bd_addr == d2.public_address

        d0.host.on_hci_packet(
            HCI_Command_Status_Event(
                status=HCI_COMMAND_STATUS_PENDING,
                num_hci_command_packets=1,
                command_opcode=HCI_CREATE_CONNECTION_COMMAND,
            )
        )

        d2.host.on_hci_packet(
            HCI_Connection_Request_Event(
                bd_addr=d0.public_address,
                class_of_device=0,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
            )
        )

        # Any further packet from d0 is unexpected.
        assert (yield) is None

    def d1_flow():
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        d1.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Complete the d0 <-> d1 link on both ends with handle 0x100.
        d1.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x100,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x100,
                bd_addr=d1.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # Any further packet from d1 is unexpected.
        assert (yield) is None

    def d2_flow():
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        d2.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Complete the d0 <-> d2 link on both ends with handle 0x101.
        d2.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x101,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x101,
                bd_addr=d2.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # Any further packet from d2 is unexpected.
        assert (yield) is None

    d0.host.set_packet_sink(Sink(d0_flow()))
    d1.host.set_packet_sink(Sink(d1_flow()))
    d2.host.set_packet_sink(Sink(d2_flow()))

    d1_accept_task = asyncio.create_task(d1.accept(peer_address=d0.public_address))
    d2_accept_task = asyncio.create_task(d2.accept())

    # Ensure that the accept tasks have started.
    await async_barrier()

    [c01, c02, a10, a20] = await asyncio.gather(
        *[
            asyncio.create_task(
                d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)
            ),
            asyncio.create_task(
                d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)
            ),
            d1_accept_task,
            d2_accept_task,
        ]
    )

    for connection in (c01, c02, a10, a20):
        assert isinstance(connection, Connection)

    # Both ends of each link must report the handle injected by the flows.
    assert c01.handle == a10.handle == 0x100
    assert c02.handle == a20.handle == 0x101
250
251
252# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_flush():
    """Work registered with ``abort_on('flush', ...)`` is cancelled by ``Host.flush()``."""
    d0 = Device(host=Host(None, None))
    task = d0.abort_on('flush', asyncio.sleep(10000))
    await d0.host.flush()
    # The long sleep must end in cancellation rather than completing normally;
    # pytest.raises fails the test if no CancelledError is raised.
    with pytest.raises(asyncio.CancelledError):
        await task
263
264
265# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_legacy_advertising():
    """``is_advertising`` follows ``start_advertising()`` and ``stop_advertising()``."""
    device = Device(host=mock.AsyncMock(Host))

    # Each step toggles advertising, then checks the resulting state.
    steps = (
        (device.start_advertising, True),
        (device.stop_advertising, False),
    )
    for toggle, expected in steps:
        await toggle()
        assert bool(device.is_advertising) is expected
277
278
279# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'own_address_type,',
    (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_legacy_advertising_connection(own_address_type):
    """Check ``self_address`` of a connection accepted while legacy advertising.

    Advertising is started with ``own_address_type``. The peripheral LE
    connection that follows must record the device's public address for
    ``OwnAddressType.PUBLIC`` and its random address for
    ``OwnAddressType.RANDOM``.
    """
    device = Device(host=mock.AsyncMock(Host))
    peer_address = Address('F0:F1:F2:F3:F4:F5')

    # Advertise with the address type under test. Without passing it, every
    # parametrization would advertise with start_advertising()'s default, and
    # the check below would not depend on the parameter at all.
    await device.start_advertising(own_address_type=own_address_type)
    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer_address,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )

    if own_address_type == OwnAddressType.PUBLIC:
        assert device.lookup_connection(0x0001).self_address == device.public_address
    else:
        assert device.lookup_connection(0x0001).self_address == device.random_address

    # For an unknown reason, read_phy() in on_connection() would be killed at the
    # end of the test, so force a scheduling point here to avoid a warning.
    await asyncio.sleep(0.0001)
307
308
309# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'auto_restart,',
    (True, False),
)
@pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart):
    """Advertising is active after a disconnection only when ``auto_restart`` is set.

    A peripheral LE connection arrives while legacy advertising, the legacy
    advertising set reports termination for that connection, and the link is
    then disconnected.
    """
    handle = 0x0001
    device = Device(host=mock.AsyncMock(spec=Host))
    await device.start_advertising(auto_restart=auto_restart)

    device.on_connection(
        handle,
        BT_LE_TRANSPORT,
        Address('F0:F1:F2:F3:F4:F5'),
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )
    device.on_advertising_set_termination(
        HCI_SUCCESS, device.legacy_advertising_set.advertising_handle, handle, 0
    )
    device.on_disconnection(handle, 0)

    # Yield to the event loop twice so that work triggered by the
    # disconnection (such as an advertising restart) gets a chance to run.
    for _ in range(2):
        await async_barrier()

    assert bool(device.is_advertising) == auto_restart
339
340
341# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extended_advertising():
    """An extended advertising set is registered and enabled on creation, and stop() disables it."""
    device = Device(host=mock.AsyncMock(Host))

    adv_set = await device.create_advertising_set()
    assert device.extended_advertising_sets and adv_set.enabled

    await adv_set.stop()
    assert not adv_set.enabled
354
355
356# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'own_address_type,',
    (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type):
    """Check ``self_address`` of a connection accepted through an extended advertising set.

    The set advertises with ``own_address_type``. After the connection and the
    set's termination event, the connection must record the public address for
    ``OwnAddressType.PUBLIC`` and the random address otherwise.
    """
    handle = 0x0001
    device = Device(host=mock.AsyncMock(spec=Host))
    parameters = AdvertisingParameters(own_address_type=own_address_type)
    advertising_set = await device.create_advertising_set(
        advertising_parameters=parameters
    )

    device.on_connection(
        handle,
        BT_LE_TRANSPORT,
        Address('F0:F1:F2:F3:F4:F5'),
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )
    device.on_advertising_set_termination(
        HCI_SUCCESS,
        advertising_set.advertising_handle,
        handle,
        0,
    )

    connection = device.lookup_connection(handle)
    if own_address_type == OwnAddressType.PUBLIC:
        expected_address = device.public_address
    else:
        expected_address = device.random_address
    assert connection.self_address == expected_address

    # For an unknown reason, read_phy() in on_connection() would be killed at the
    # end of the test, so force a scheduling point here to avoid a warning.
    await asyncio.sleep(0.0001)
390
391
392# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features():
    """Reading the remote LE features over an established connection returns a value."""
    two_devices = TwoDevices()
    await two_devices.setup_connection()

    features = await two_devices.connections[0].get_remote_le_features()
    assert features is not None
399
400
401# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features_failed():
    """A failed remote LE features read surfaces to the caller as ``HCI_Error``."""
    devices = TwoDevices()
    await devices.setup_connection()
    host = devices[0].host

    def report_features_failure(event):
        # Turn the completion event into a failure report for its connection.
        host.emit(
            'le_remote_features_failure',
            event.connection_handle,
            HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
        )

    host.on_hci_le_read_remote_features_complete_event = report_features_failure

    pending_read = devices.connections[0].get_remote_le_features()
    with pytest.raises(HCI_Error):
        await asyncio.wait_for(pending_read, _TIMEOUT)
422
423
424# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis():
    """Device 0 sets up a CIG with two CISes; both get established and are then disconnected."""
    devices = TwoDevices()
    await devices.setup_connection()

    # Device 1 side: one future per CIS handle, resolved on 'cis_establishment'.
    peripheral_cis_futures = {}

    def on_cis_request(
        acl_connection: Connection,
        cis_handle: int,
        _cig_id: int,
        _cis_id: int,
    ):
        # Accept the request; abort_on() ties the accept to the ACL
        # connection's 'disconnection' event.
        acl_connection.abort_on(
            'disconnection', devices[1].accept_cis_request(cis_handle)
        )
        peripheral_cis_futures[cis_handle] = asyncio.get_running_loop().create_future()

    def on_cis_establishment(cis_link):
        peripheral_cis_futures[cis_link.handle].set_result(None)

    devices[1].on('cis_request', on_cis_request)
    devices[1].on('cis_establishment', on_cis_establishment)

    cis_handles = await devices[0].setup_cig(
        cig_id=1,
        cis_id=[2, 3],
        sdu_interval=(0, 0),
        framing=0,
        max_sdu=(0, 0),
        retransmission_number=0,
        max_transport_latency=(0, 0),
    )
    assert len(cis_handles) == 2

    # Both CISes ride on the same ACL connection.
    acl_handle = devices.connections[0].handle
    cis_links = await devices[0].create_cis(
        [(cis_handle, acl_handle) for cis_handle in cis_handles]
    )
    await asyncio.gather(*peripheral_cis_futures.values())
    assert len(cis_links) == 2

    for cis_link in cis_links:
        await cis_link.disconnect()
470
471
472# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis_setup_failure():
    """A CIS whose establishment fails raises ``HCI_Error`` on both devices.

    Both hosts are patched so that an LE CIS Established event is re-emitted as
    'cis_establishment_failure'. Device 1's ``accept_cis_request()`` and
    device 0's ``create_cis()`` must then both raise ``HCI_Error``.
    """
    devices = TwoDevices()
    await devices.setup_connection()

    # CIS handles announced to device 1 through 'cis_request' events.
    incoming_cis_handles = asyncio.Queue()

    def on_cis_request(
        _acl_connection: Connection,
        cis_handle: int,
        _cig_id: int,
        _cis_id: int,
    ):
        incoming_cis_handles.put_nowait(cis_handle)

    devices[1].on('cis_request', on_cis_request)

    cis_handles = await devices[0].setup_cig(
        cig_id=1,
        cis_id=[2],
        sdu_interval=(0, 0),
        framing=0,
        max_sdu=(0, 0),
        retransmission_number=0,
        max_transport_latency=(0, 0),
    )
    assert len(cis_handles) == 1

    create_cis_task = asyncio.create_task(
        devices[0].create_cis([(cis_handles[0], devices.connections[0].handle)])
    )

    def report_cis_failure(host, event):
        host.emit(
            'cis_establishment_failure',
            event.connection_handle,
            HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
        )

    # There is no await between create_task() above and this patch, so the
    # create_cis task has not started yet when both hosts are patched.
    for device in devices:
        device.host.on_hci_le_cis_established_event = functools.partial(
            report_cis_failure, device.host
        )

    requested_handle = await asyncio.wait_for(incoming_cis_handles.get(), _TIMEOUT)

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(
            devices[1].accept_cis_request(requested_handle), _TIMEOUT
        )

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(create_cis_task, _TIMEOUT)
529
530
531# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
    """By default the GATT server holds only the Generic Access Service.

    That is one service declaration plus two characteristics (Device Name and
    Appearance), each contributing a declaration attribute and the
    characteristic itself: 5 attributes in all.
    """
    device = Device(host=Host(None, None))
    attributes = device.gatt_server.attributes
    assert len(attributes) == 5

    (
        service,
        name_declaration,
        name_characteristic,
        appearance_declaration,
        appearance_characteristic,
    ) = attributes
    assert service.uuid == GATT_GENERIC_ACCESS_SERVICE
    assert name_declaration.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
    assert name_characteristic.uuid == GATT_DEVICE_NAME_CHARACTERISTIC
    assert appearance_declaration.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
    assert appearance_characteristic.uuid == GATT_APPEARANCE_CHARACTERISTIC
542
543
544# -----------------------------------------------------------------------------
def test_gatt_services_without_gas():
    """With ``generic_access_service=False`` the GATT server starts out empty."""
    device = Device(host=Host(None, None), generic_access_service=False)
    assert not device.gatt_server.attributes
550
551
552# -----------------------------------------------------------------------------
async def run_test_device():
    """Run a subset of this module's tests outside pytest (script entry point)."""
    await test_device_connect_parallel()
    await test_flush()
    # The GATT tests are plain synchronous functions that return None;
    # awaiting that None would raise TypeError, so call them directly.
    test_gatt_services_with_gas()
    test_gatt_services_without_gas()
558
559
560# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Log level comes from $BUMBLE_LOGLEVEL (default INFO), upper-cased first.
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=level_name)
    asyncio.run(run_test_device())
564