• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import asyncio
2import collections
3import enum
4import hci_packets as hci
5import link_layer_packets as ll
6import py.bluetooth
7import sys
8import typing
9import unittest
10from typing import Optional, Tuple, Union
11from hci_packets import ErrorCode
12
13from ctypes import *
14
# Load RootCanal's C FFI shared library once, at import time; every
# Controller instance and generate_rpa() call goes through this handle.
rootcanal = CDLL("lib_rootcanal_ffi.so")
# ffi_controller_new hands back an opaque pointer to the C++ controller.
# ctypes assumes a C int return type by default, which could truncate it.
rootcanal.ffi_controller_new.restype = c_void_p

# Prototype of the callback the controller uses to emit HCI packets:
# (packet indicator, data pointer, data length).
SEND_HCI_FUNC = CFUNCTYPE(None, c_int, POINTER(c_ubyte), c_size_t)
# Prototype of the callback the controller uses to emit LL pdus:
# (data pointer, data length, phy, tx power).
SEND_LL_FUNC = CFUNCTYPE(None, POINTER(c_ubyte), c_size_t, c_int, c_int)
20
21
class Idc(enum.IntEnum):
    """HCI packet indicators exchanged with the controller FFI.

    Passed as the ``idc`` argument when injecting HCI packets and received
    as the first argument of the HCI send callback.
    """

    Cmd = 1  # HCI command
    Acl = 2  # ACL data
    Sco = 3  # SCO data
    Evt = 4  # HCI event
    Iso = 5  # ISO data
28
29
class Phy(enum.IntEnum):
    """Physical transport selector passed along with injected LL pdus."""

    LowEnergy = 0  # Bluetooth Low Energy
    BrEdr = 1  # Bluetooth BR/EDR
33
34
class LeFeatures:
    """Decoded view of the LE local supported-features bitmask.

    ``mask`` keeps the raw integer; each boolean attribute reports whether
    the corresponding ``hci.LLFeaturesBits`` bit is set in it.
    """

    def __init__(self, le_features: int):
        feature_bits = hci.LLFeaturesBits
        self.mask = le_features
        self.ll_privacy = bool(le_features & feature_bits.LL_PRIVACY)
        self.le_extended_advertising = bool(le_features & feature_bits.LE_EXTENDED_ADVERTISING)
        self.le_periodic_advertising = bool(le_features & feature_bits.LE_PERIODIC_ADVERTISING)
42
43
def generate_rpa(irk: bytes) -> hci.Address:
    """Derive a resolvable private address from ``irk`` using RootCanal's FFI.

    The FFI writes 6 bytes into a caller-provided buffer; those bytes are
    reversed before being wrapped in an ``hci.Address`` (presumably because
    the FFI and ``hci.Address`` use opposite byte orders -- confirm).
    """
    address_bytes = bytearray(6)
    out_buffer = (c_char * len(address_bytes)).from_buffer(address_bytes)
    rootcanal.ffi_generate_rpa(c_char_p(irk), out_buffer)
    return hci.Address(bytes(reversed(address_bytes)))
50
51
class Controller:
    """Binder class over RootCanal's ffi interfaces.

    The methods send_cmd and send_ll inject HCI commands or LL pdus into the
    controller. Outgoing HCI packets and LL pdus are delivered by the C++
    controller through FFI callbacks to receive_hci_ / receive_ll_, queued,
    and consumed with receive_evt, expect_evt and receive_ll.
    """

    def __init__(self, address: hci.Address):
        """Create a C++ controller instance configured with ``address``."""
        self.address = address

        # Create the queues before the C++ instance exists, so that a callback
        # fired while the controller is being constructed finds them in place.
        self.evt_queue = collections.deque()
        self.acl_queue = collections.deque()
        self.ll_queue = collections.deque()
        self.evt_queue_event = asyncio.Event()
        self.acl_queue_event = asyncio.Event()
        self.ll_queue_event = asyncio.Event()

        # Callbacks invoked by the C++ controller for outgoing HCI packets and
        # LL pdus. string_at copies data_len bytes out of the C buffer in a
        # single call instead of indexing the pointer byte by byte.
        def send_hci(idc: int, data, data_len: int):
            self.receive_hci_(int(idc), string_at(data, data_len))

        def send_ll(data, data_len: int, phy: int, tx_power: int):
            self.receive_ll_(string_at(data, data_len), int(phy), int(tx_power))

        # Wrap each callback exactly once. The ctypes callback objects are
        # kept on self because they must stay referenced for as long as C
        # code may call them.
        self.send_hci_callback = SEND_HCI_FUNC(send_hci)
        self.send_ll_callback = SEND_LL_FUNC(send_ll)

        # Create a c++ controller instance.
        self.instance = rootcanal.ffi_controller_new(c_char_p(address.address), self.send_hci_callback,
                                                     self.send_ll_callback)

    def __del__(self):
        """Release the C++ controller instance."""
        # getattr guards against a partially-constructed object whose
        # __init__ did not reach the point of creating the instance.
        instance = getattr(self, "instance", None)
        if instance is not None:
            rootcanal.ffi_controller_delete(c_void_p(instance))

    def receive_hci_(self, idc: int, packet: bytes):
        """Handle an HCI packet emitted by the controller.

        Events and ACL packets are queued and their waiters woken up; any
        other packet indicator is logged and dropped.
        """
        if idc == Idc.Evt:
            print(f"<-- received HCI event data={len(packet)}[..]")
            self.evt_queue.append(packet)
            self.evt_queue_event.set()
        elif idc == Idc.Acl:
            print(f"<-- received HCI ACL packet data={len(packet)}[..]")
            self.acl_queue.append(packet)
            self.acl_queue_event.set()
        else:
            print(f"ignoring HCI packet typ={idc}")

    def receive_ll_(self, packet: bytes, phy: int, tx_power: int):
        """Queue an LL pdu emitted by the controller.

        Only the raw pdu bytes are queued; ``phy`` and ``tx_power`` are
        currently not recorded.
        """
        print(f"<-- received LL pdu data={len(packet)}[..]")
        self.ll_queue.append(packet)
        self.ll_queue_event.set()

    def send_cmd(self, cmd: hci.Command):
        """Serialize ``cmd`` and inject it into the controller as an HCI command packet."""
        print(f"--> sending HCI command {cmd.__class__.__name__}")
        data = cmd.serialize()
        # NOTE(review): the length is passed as c_int while the send callbacks
        # use c_size_t; confirm against the FFI declaration.
        rootcanal.ffi_controller_receive_hci(c_void_p(self.instance), c_int(Idc.Cmd), c_char_p(data), c_int(len(data)))

    def send_ll(self, pdu: ll.LinkLayerPacket, phy: Phy = Phy.LowEnergy, rssi: int = -90):
        """Serialize ``pdu`` and inject it into the controller on ``phy`` with the given ``rssi``."""
        print(f"--> sending LL pdu {pdu.__class__.__name__}")
        data = pdu.serialize()
        rootcanal.ffi_controller_receive_ll(c_void_p(self.instance), c_char_p(data), c_int(len(data)), c_int(phy),
                                            c_int(rssi))

    async def start(self):
        """Spawn a background task that ticks the controller every 5 ms."""

        async def timer():
            while True:
                await asyncio.sleep(0.005)
                rootcanal.ffi_controller_tick(c_void_p(self.instance))

        # Spawn the controller timer task.
        self.timer_task = asyncio.create_task(timer())

    def stop(self):
        """Stop ticking the controller and check that no output was left unconsumed.

        Raises:
            Exception: if HCI events or LL pdus are still queued.
        """
        # Cancel the controller timer task. Dropping the attribute alone would
        # leave the task scheduled on the event loop, still ticking.
        timer_task = getattr(self, "timer_task", None)
        if timer_task is not None:
            timer_task.cancel()
            del self.timer_task

        if self.evt_queue:
            print("evt queue not empty at stop():")
            for packet in self.evt_queue:
                evt = hci.Event.parse_all(packet)
                evt.show()
            raise Exception("evt queue not empty at stop()")

        if self.ll_queue:
            print("ll queue not empty at stop():")
            # ll_queue holds raw pdu bytes (see receive_ll_), not tuples.
            for packet in self.ll_queue:
                pdu = ll.LinkLayerPacket.parse_all(packet)
                pdu.show()
            raise Exception("ll queue not empty at stop()")

    async def receive_evt(self):
        """Wait until an HCI event is queued, then pop and return the oldest one."""
        while not self.evt_queue:
            await self.evt_queue_event.wait()
            self.evt_queue_event.clear()
        return self.evt_queue.popleft()

    async def expect_evt(self, expected_evt: hci.Event):
        """Wait for the next HCI event and raise Exception if it differs from ``expected_evt``."""
        packet = await self.receive_evt()
        evt = hci.Event.parse_all(packet)
        if evt != expected_evt:
            print("received unexpected event")
            print("expected event:")
            expected_evt.show()
            print("received event:")
            evt.show()
            raise Exception(f"unexpected evt {evt.__class__.__name__}")

    async def receive_ll(self):
        """Wait until an LL pdu is queued, then pop and return the oldest one as raw bytes."""
        while not self.ll_queue:
            await self.ll_queue_event.wait()
            self.ll_queue_event.clear()
        return self.ll_queue.popleft()
169
170
class Any:
    """Wildcard that compares equal to every value.

    Put an instance in an expected packet to accept whatever value the
    controller produces for that field. Because only ``__eq__`` is defined,
    Python sets ``__hash__`` to None and instances are unhashable.
    """

    def __eq__(self, other) -> bool:
        """Report a match unconditionally; ``other`` is never inspected."""
        return True

    def __format__(self, format_spec: str) -> str:
        """Render as a single underscore, whatever ``format_spec`` asks for."""
        return "_"
181
182
class ControllerTest(unittest.IsolatedAsyncioTestCase):
    """Helper class for writing controller tests using the python bindings.

    Each test sets up the controller by sending the Reset command and
    configuring the event masks to allow all events. The local device
    address is always configured as 11:11:11:11:11:11. tearDown stops the
    controller, which raises if HCI events or LL pdus were left unconsumed.
    """

    # Wildcard usable in expected packets (self.Any) to match any field value.
    Any = Any()

    def setUp(self):
        self.controller = Controller(hci.Address('11:11:11:11:11:11'))

    async def asyncSetUp(self):
        controller = self.controller

        # Start the controller timer.
        await controller.start()

        # Reset the controller and enable all events and LE events.
        controller.send_cmd(hci.Reset())
        await controller.expect_evt(hci.ResetComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))
        controller.send_cmd(hci.SetEventMask(event_mask=0xffffffffffffffff))
        await controller.expect_evt(hci.SetEventMaskComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))
        controller.send_cmd(hci.LeSetEventMask(le_event_mask=0xffffffffffffffff))
        await controller.expect_evt(hci.LeSetEventMaskComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        # Load the local supported features to be able to disable tests
        # that rely on unsupported features.
        controller.send_cmd(hci.LeReadLocalSupportedFeatures())
        evt = await self.expect_cmd_complete(hci.LeReadLocalSupportedFeaturesComplete)
        controller.le_features = LeFeatures(evt.le_features)

    async def expect_evt(self, expected_evt: typing.Union[hci.Event, type], timeout: float = 3) -> hci.Event:
        """Wait for the next HCI event and check it against ``expected_evt``.

        Args:
            expected_evt: either an event class (the received event must be
                an instance of it) or an event instance (the received event
                must compare equal to it; fields may hold ``self.Any``).
            timeout: seconds to wait before asyncio.wait_for times out.

        Returns:
            The parsed event.

        Raises:
            AssertionError: (via self.fail) if the event does not match.
        """
        packet = await asyncio.wait_for(self.controller.receive_evt(), timeout=timeout)
        evt = hci.Event.parse_all(packet)

        if isinstance(expected_evt, type) and not isinstance(evt, expected_evt):
            print("received unexpected event")
            # expected_evt is itself a class here, so use its own __name__.
            print(f"expected event: {expected_evt.__name__}")
            print("received event:")
            evt.show()
            self.fail(f"expected event {expected_evt.__name__}, received {evt.__class__.__name__}")

        if isinstance(expected_evt, hci.Event) and evt != expected_evt:
            print("received unexpected event")
            print("expected event:")
            expected_evt.show()
            print("received event:")
            evt.show()
            self.fail(f"unexpected event {evt.__class__.__name__}")

        return evt

    async def expect_cmd_complete(self, expected_evt: type, timeout: float = 3) -> hci.Event:
        """Wait for a command-complete event of class ``expected_evt`` that reports
        SUCCESS and num_hci_command_packets == 1, and return it."""
        evt = await self.expect_evt(expected_evt, timeout=timeout)
        # unittest assertions are not stripped under `python -O`, unlike `assert`.
        self.assertEqual(evt.status, ErrorCode.SUCCESS)
        self.assertEqual(evt.num_hci_command_packets, 1)
        return evt

    async def expect_ll(self,
                        expected_pdus: typing.Union[list, typing.Union[ll.LinkLayerPacket, type]],
                        timeout: float = 3) -> ll.LinkLayerPacket:
        """Wait for the next LL pdu and check it matches one of ``expected_pdus``.

        Args:
            expected_pdus: a single expectation or a list of them; each is
                either a pdu class (isinstance match) or a pdu instance
                (equality match).
            timeout: seconds to wait before asyncio.wait_for times out.

        Returns:
            The first parsed pdu that matches.

        Raises:
            AssertionError: (via self.fail) if no expectation matches.
        """
        if not isinstance(expected_pdus, list):
            expected_pdus = [expected_pdus]

        packet = await asyncio.wait_for(self.controller.receive_ll(), timeout=timeout)
        pdu = ll.LinkLayerPacket.parse_all(packet)

        for expected_pdu in expected_pdus:
            if isinstance(expected_pdu, type) and isinstance(pdu, expected_pdu):
                return pdu
            if isinstance(expected_pdu, ll.LinkLayerPacket) and pdu == expected_pdu:
                return pdu

        print("received unexpected pdu:")
        pdu.show()
        print("expected pdus:")
        for expected_pdu in expected_pdus:
            if isinstance(expected_pdu, type):
                print(f"- {expected_pdu.__name__}")
            if isinstance(expected_pdu, ll.LinkLayerPacket):
                print(f"- {expected_pdu.__class__.__name__}")
                expected_pdu.show()

        self.fail(f"unexpected pdu {pdu.__class__.__name__}")

    def tearDown(self):
        self.controller.stop()
270