• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import logging
20import os
21from types import LambdaType
22import pytest
23
24from bumble.core import BT_BR_EDR_TRANSPORT
25from bumble.device import Connection, Device
26from bumble.host import Host
27from bumble.hci import (
28    HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
29    HCI_COMMAND_STATUS_PENDING,
30    HCI_CREATE_CONNECTION_COMMAND,
31    HCI_SUCCESS,
32    Address,
33    HCI_Command_Complete_Event,
34    HCI_Command_Status_Event,
35    HCI_Connection_Complete_Event,
36    HCI_Connection_Request_Event,
37    HCI_Packet,
38)
39from bumble.gatt import (
40    GATT_GENERIC_ACCESS_SERVICE,
41    GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
42    GATT_DEVICE_NAME_CHARACTERISTIC,
43    GATT_APPEARANCE_CHARACTERISTIC,
44)
45
46# -----------------------------------------------------------------------------
47# Logging
48# -----------------------------------------------------------------------------
# Logger for this test module, registered under the module's own name.
logger = logging.getLogger(__name__)
50
51
52# -----------------------------------------------------------------------------
class Sink:
    """Packet sink that hands every packet it receives to a flow object.

    The flow is advanced once with ``next()`` at construction time and then
    fed with ``send()``, which is how a generator is driven: the first
    ``next()`` runs it up to its first ``yield`` so that it is ready to
    receive values.
    """

    def __init__(self, flow):
        # Prime the flow first so that it is parked and able to accept
        # the packets that `on_packet` will send into it.
        next(flow)
        self.flow = flow

    def on_packet(self, packet):
        """Forward `packet` into the flow."""
        self.flow.send(packet)
60
61
62# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_device_connect_parallel():
    """Check that one device can establish two classic connections in parallel.

    `d0` connects to `d1` and `d2` over the BR/EDR transport while both peers
    are accepting. No controller is attached: each host's packet sink is a
    generator that checks the HCI command the host sends and then injects the
    HCI events into the relevant hosts by calling `on_hci_packet` directly.
    """
    d0 = Device(host=Host(None, None))
    d1 = Device(host=Host(None, None))
    d2 = Device(host=Host(None, None))

    # enable classic
    d0.classic_enabled = True
    d1.classic_enabled = True
    d2.classic_enabled = True

    # set public addresses
    d0.public_address = Address(
        'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d1.public_address = Address(
        'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d2.public_address = Address(
        'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )

    def initiator_flow():
        """Script for d0: one Create Connection command per peer, in order."""
        for peer in (d1, d2):
            packet = HCI_Packet.from_bytes((yield))
            assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
            assert packet.bd_addr == peer.public_address

            # Report the command as pending to the initiator...
            d0.host.on_hci_packet(
                HCI_Command_Status_Event(
                    status=HCI_COMMAND_STATUS_PENDING,
                    num_hci_command_packets=1,
                    command_opcode=HCI_CREATE_CONNECTION_COMMAND,
                )
            )

            # ...and deliver the matching connection request to the peer.
            peer.host.on_hci_packet(
                HCI_Connection_Request_Event(
                    bd_addr=d0.public_address,
                    class_of_device=0,
                    link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                )
            )

        # Any further packet sent by d0 is unexpected.
        assert (yield) is None

    def acceptor_flow(peer, connection_handle):
        """Script for an accepting peer whose link uses `connection_handle`."""
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        peer.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Complete the connection on the acceptor side first, then on the
        # initiator side, both with the same handle.
        peer.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=connection_handle,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=connection_handle,
                bd_addr=peer.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # Any further packet sent by the peer is unexpected.
        assert (yield) is None

    d0.host.set_packet_sink(Sink(initiator_flow()))
    d1.host.set_packet_sink(Sink(acceptor_flow(d1, 0x100)))
    d2.host.set_packet_sink(Sink(acceptor_flow(d2, 0x101)))

    c01, c02, a10, a20 = await asyncio.gather(
        *[
            asyncio.create_task(
                d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)
            ),
            asyncio.create_task(
                d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)
            ),
            asyncio.create_task(d1.accept(peer_address=d0.public_address)),
            asyncio.create_task(d2.accept()),
        ]
    )

    for connection in (c01, c02, a10, a20):
        assert isinstance(connection, Connection)

    # Both ends of each link must report the handle that was injected for it.
    assert c01.handle == a10.handle == 0x100
    assert c02.handle == a20.handle == 0x101
220
221
222# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_flush():
    """Check that flushing the host cancels work registered with `abort_on('flush', ...)`."""
    d0 = Device(host=Host(None, None))

    # The sleep is long enough that the awaitable can only finish by being
    # aborted, never by completing normally within the test.
    task = d0.abort_on('flush', asyncio.sleep(10000))
    await d0.host.flush()

    # `pytest.raises` fails the test if the await returns normally. Unlike an
    # `assert False` placed in a `try` body, it is not stripped under
    # `python -O`.
    with pytest.raises(asyncio.CancelledError):
        await task
233
234
235# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
    """Check the GATT attributes present on a device created with default options."""
    device = Device(host=Host(None, None))
    attributes = device.gatt_server.attributes

    # One (field, expected value) pair per attribute, in attribute order:
    # the service, then a declaration followed by a value for each of the
    # two characteristics.
    expected = [
        ('uuid', GATT_GENERIC_ACCESS_SERVICE),
        ('type', GATT_CHARACTERISTIC_ATTRIBUTE_TYPE),
        ('uuid', GATT_DEVICE_NAME_CHARACTERISTIC),
        ('type', GATT_CHARACTERISTIC_ATTRIBUTE_TYPE),
        ('uuid', GATT_APPEARANCE_CHARACTERISTIC),
    ]

    # there should be one service and two chars, therefore 5 attributes
    assert len(attributes) == 5
    for index, (field, value) in enumerate(expected):
        assert getattr(attributes[index], field) == value
246
247
248# -----------------------------------------------------------------------------
def test_gatt_services_without_gas():
    """Check that a device created with `generic_access_service=False` has no GATT attributes."""
    host = Host(None, None)
    device = Device(host=host, generic_access_service=False)
    attributes = device.gatt_server.attributes

    # there should be no services
    assert len(attributes) == 0
254
255
256# -----------------------------------------------------------------------------
async def run_test_device():
    """Run every test of this module in sequence (used when executed as a script)."""
    await test_device_connect_parallel()
    await test_flush()

    # The GATT tests are plain synchronous functions that return None, so
    # they must be called directly: awaiting their result raises TypeError.
    test_gatt_services_with_gas()
    test_gatt_services_without_gas()
262
263
264# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # The log level comes from the BUMBLE_LOGLEVEL environment variable
    # (upper-cased), with INFO used when the variable is not set.
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)

    asyncio.run(run_test_device())
268