import asyncio
import collections
import enum
import hci_packets as hci
import link_layer_packets as ll
import py.bluetooth
import sys
import typing
import unittest
from typing import Optional, Tuple, Union
from hci_packets import ErrorCode

from ctypes import *

rootcanal = cdll.LoadLibrary("lib_rootcanal_ffi.so")
rootcanal.ffi_controller_new.restype = c_void_p

# Prototypes of the callbacks handed to ffi_controller_new:
#   SEND_HCI_FUNC(idc, data, data_len)            -- outgoing HCI packet
#   SEND_LL_FUNC(data, data_len, phy, tx_power)   -- outgoing LL PDU
SEND_HCI_FUNC = CFUNCTYPE(None, c_int, POINTER(c_ubyte), c_size_t)
SEND_LL_FUNC = CFUNCTYPE(None, POINTER(c_ubyte), c_size_t, c_int, c_int)


class Idc(enum.IntEnum):
    """Packet type indicator exchanged with the controller alongside HCI packets.

    NOTE(review): the values match the H4 transport packet indicators.
    """
    Cmd = 1
    Acl = 2
    Sco = 3
    Evt = 4
    Iso = 5


class Phy(enum.IntEnum):
    """Physical transport passed to ffi_controller_receive_ll with an injected LL PDU."""
    LowEnergy = 0
    BrEdr = 1


class LeFeatures:
    """Decoded view of the LE supported-features bitmask.

    Exposes the raw mask plus boolean flags for the features that tests
    check before running.
    """

    def __init__(self, le_features: int):
        self.mask = le_features
        self.ll_privacy = (le_features & hci.LLFeaturesBits.LL_PRIVACY) != 0
        self.le_extended_advertising = (le_features & hci.LLFeaturesBits.LE_EXTENDED_ADVERTISING) != 0
        self.le_periodic_advertising = (le_features & hci.LLFeaturesBits.LE_PERIODIC_ADVERTISING) != 0


def generate_rpa(irk: bytes) -> hci.Address:
    """Generate a resolvable private address from `irk` using RootCanal's FFI.

    Args:
        irk: identity resolving key bytes (presumably 16 bytes -- confirm
            against ffi_generate_rpa).

    Returns:
        The generated address as an hci.Address.
    """
    rpa = bytearray(6)
    rpa_type = c_char * 6
    rootcanal.ffi_generate_rpa(c_char_p(irk), rpa_type.from_buffer(rpa))
    # NOTE(review): the FFI output is reversed before building hci.Address,
    # presumably to convert from the FFI's byte order -- confirm.
    rpa.reverse()
    return hci.Address(bytes(rpa))


class Controller:
    """Binder class over RootCanal's ffi interfaces.
    The methods send_cmd, send_ll are used to inject HCI or LL
    packets into the controller, and receive_evt, receive_ll to
    catch outgoing HCI packets or LL pdus."""

    def __init__(self, address: hci.Address):
        """Create the underlying C++ controller with the given public address."""
        self.address = address

        # Set up all state the callbacks touch *before* the C++ controller
        # exists, so a callback can never observe a half-built object.
        self.evt_queue = collections.deque()
        self.acl_queue = collections.deque()
        self.ll_queue = collections.deque()
        self.evt_queue_event = asyncio.Event()
        self.acl_queue_event = asyncio.Event()
        self.ll_queue_event = asyncio.Event()
        self.timer_task = None
        self.instance = None

        # Callbacks for handling HCI and LL send events. ctypes hands simple
        # arguments (c_int, c_size_t) to Python as ints and `data` as a
        # POINTER(c_ubyte); slicing the pointer copies the payload out.
        def send_hci(idc, data, data_len):
            self.receive_hci_(int(idc), bytes(data[:data_len]))

        def send_ll(data, data_len, phy, tx_power):
            self.receive_ll_(bytes(data[:data_len]), int(phy), int(tx_power))

        # Wrap each callback exactly once and keep the wrappers referenced
        # from self: ctypes callback objects must stay alive for as long as
        # the C++ side may call them.
        self.send_hci_callback = SEND_HCI_FUNC(send_hci)
        self.send_ll_callback = SEND_LL_FUNC(send_ll)

        # Create a c++ controller instance.
        self.instance = rootcanal.ffi_controller_new(c_char_p(address.address), self.send_hci_callback,
                                                     self.send_ll_callback)

    def __del__(self):
        # Guard against construction having failed before (or while) the
        # C++ instance was created; restype c_void_p maps NULL to None.
        instance = getattr(self, "instance", None)
        if instance is not None:
            rootcanal.ffi_controller_delete(c_void_p(instance))

    def receive_hci_(self, idc: int, packet: bytes):
        """Queue an HCI packet emitted by the controller.

        Events go to evt_queue and ACL packets to acl_queue (waking any
        waiter on the matching asyncio.Event); other packet types are
        logged and dropped.
        """
        if idc == Idc.Evt:
            print(f"<-- received HCI event data={len(packet)}[..]")
            self.evt_queue.append(packet)
            self.evt_queue_event.set()
        elif idc == Idc.Acl:
            print(f"<-- received HCI ACL packet data={len(packet)}[..]")
            self.acl_queue.append(packet)
            self.acl_queue_event.set()
        else:
            print(f"ignoring HCI packet idc={idc}")

    def receive_ll_(self, packet: bytes, phy: int, tx_power: int):
        """Queue an LL PDU emitted by the controller (phy and tx_power are discarded)."""
        print(f"<-- received LL pdu data={len(packet)}[..]")
        self.ll_queue.append(packet)
        self.ll_queue_event.set()

    def send_cmd(self, cmd: hci.Command):
        """Serialize `cmd` and inject it into the controller as an HCI command packet."""
        print(f"--> sending HCI command {cmd.__class__.__name__}")
        data = cmd.serialize()
        rootcanal.ffi_controller_receive_hci(c_void_p(self.instance), c_int(Idc.Cmd), c_char_p(data), c_int(len(data)))

    def send_ll(self, pdu: ll.LinkLayerPacket, phy: Phy = Phy.LowEnergy, rssi: int = -90):
        """Serialize `pdu` and inject it into the controller on `phy` with the given `rssi`."""
        print(f"--> sending LL pdu {pdu.__class__.__name__}")
        data = pdu.serialize()
        rootcanal.ffi_controller_receive_ll(c_void_p(self.instance), c_char_p(data), c_int(len(data)), c_int(phy),
                                            c_int(rssi))

    async def start(self):
        """Spawn a background task that ticks the controller every 5 ms."""

        async def timer():
            while True:
                await asyncio.sleep(0.005)
                rootcanal.ffi_controller_tick(c_void_p(self.instance))

        # Spawn the controller timer task.
        self.timer_task = asyncio.create_task(timer())

    def stop(self):
        """Stop the controller timer and verify that all output was consumed.

        Raises:
            Exception: if the HCI event queue or the LL PDU queue is not
                empty; the leftover packets are printed first.
        """
        # Cancel the controller timer task (a no-op if start() never ran).
        timer_task = getattr(self, "timer_task", None)
        if timer_task is not None:
            timer_task.cancel()
            self.timer_task = None

        if self.evt_queue:
            print("evt queue not empty at stop():")
            for packet in self.evt_queue:
                evt = hci.Event.parse_all(packet)
                evt.show()
            raise Exception("evt queue not empty at stop()")

        if self.ll_queue:
            print("ll queue not empty at stop():")
            # ll_queue holds raw PDU bytes (see receive_ll_), not tuples.
            for packet in self.ll_queue:
                pdu = ll.LinkLayerPacket.parse_all(packet)
                pdu.show()
            raise Exception("ll queue not empty at stop()")

    async def receive_evt(self):
        """Wait for and pop the oldest queued HCI event packet (raw bytes)."""
        while not self.evt_queue:
            await self.evt_queue_event.wait()
            self.evt_queue_event.clear()
        return self.evt_queue.popleft()

    async def expect_evt(self, expected_evt: hci.Event):
        """Receive the next HCI event and raise Exception unless it equals `expected_evt`."""
        packet = await self.receive_evt()
        evt = hci.Event.parse_all(packet)
        if evt != expected_evt:
            print("received unexpected event")
            print("expected event:")
            expected_evt.show()
            print("received event:")
            evt.show()
            raise Exception(f"unexpected evt {evt.__class__.__name__}")

    async def receive_ll(self):
        """Wait for and pop the oldest queued LL PDU (raw bytes)."""
        while not self.ll_queue:
            await self.ll_queue_event.wait()
            self.ll_queue_event.clear()
        return self.ll_queue.popleft()


class Any:
    """Helper class that will match all other values.
    Use an element of this class in expected packets to match any value
    returned by the Controller stack."""

    def __eq__(self, other) -> bool:
        return True

    def __format__(self, format_spec: str) -> str:
        return "_"


class ControllerTest(unittest.IsolatedAsyncioTestCase):
    """Helper class for writing controller tests using the python bindings.
    The test sets up the controller by sending the Reset command and configuring
    the event masks and LE event masks to allow all events. The local device
    address is always configured as 11:11:11:11:11:11."""

    # Shared wildcard, usable as self.Any inside expected packets.
    Any = Any()

    def setUp(self):
        self.controller = Controller(hci.Address('11:11:11:11:11:11'))

    async def asyncSetUp(self):
        controller = self.controller

        # Start the controller timer.
        await controller.start()

        # Reset the controller and enable all events and LE events.
        controller.send_cmd(hci.Reset())
        await controller.expect_evt(hci.ResetComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))
        controller.send_cmd(hci.SetEventMask(event_mask=0xffffffffffffffff))
        await controller.expect_evt(hci.SetEventMaskComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))
        controller.send_cmd(hci.LeSetEventMask(le_event_mask=0xffffffffffffffff))
        await controller.expect_evt(hci.LeSetEventMaskComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        # Load the local supported features to be able to disable tests
        # that rely on unsupported features.
        controller.send_cmd(hci.LeReadLocalSupportedFeatures())
        evt = await self.expect_cmd_complete(hci.LeReadLocalSupportedFeaturesComplete)
        controller.le_features = LeFeatures(evt.le_features)

    async def expect_evt(self, expected_evt: typing.Union[hci.Event, type], timeout: int = 3) -> hci.Event:
        """Receive the next HCI event and check it against `expected_evt`.

        Args:
            expected_evt: either an event class (matched with isinstance) or
                an event instance (matched with ==; Any fields match anything).
            timeout: seconds to wait before asyncio.wait_for raises TimeoutError.

        Returns:
            The parsed event.
        """
        packet = await asyncio.wait_for(self.controller.receive_evt(), timeout=timeout)
        evt = hci.Event.parse_all(packet)

        if isinstance(expected_evt, type) and not isinstance(evt, expected_evt):
            print("received unexpected event")
            print(f"expected event: {expected_evt.__name__}")
            print("received event:")
            evt.show()
            self.fail(f"unexpected evt {evt.__class__.__name__}")

        if isinstance(expected_evt, hci.Event) and evt != expected_evt:
            print("received unexpected event")
            print("expected event:")
            expected_evt.show()
            print("received event:")
            evt.show()
            self.fail(f"unexpected evt {evt.__class__.__name__}")

        return evt

    async def expect_cmd_complete(self, expected_evt: type, timeout: int = 3) -> hci.Event:
        """Expect a command-complete event of class `expected_evt` with SUCCESS status and one command credit."""
        evt = await self.expect_evt(expected_evt, timeout=timeout)
        # unittest assertions, unlike bare `assert`, still run under `python -O`.
        self.assertEqual(evt.status, ErrorCode.SUCCESS)
        self.assertEqual(evt.num_hci_command_packets, 1)
        return evt

    async def expect_ll(self,
                        expected_pdus: typing.Union[list, typing.Union[ll.LinkLayerPacket, type]],
                        timeout: int = 3) -> ll.LinkLayerPacket:
        """Receive the next LL PDU and check it matches one of `expected_pdus`.

        Args:
            expected_pdus: a PDU class, a PDU instance, or a list mixing
                both; classes match with isinstance, instances with ==.
            timeout: seconds to wait before asyncio.wait_for raises TimeoutError.

        Returns:
            The parsed PDU that matched.
        """
        if not isinstance(expected_pdus, list):
            expected_pdus = [expected_pdus]

        packet = await asyncio.wait_for(self.controller.receive_ll(), timeout=timeout)
        pdu = ll.LinkLayerPacket.parse_all(packet)

        for expected_pdu in expected_pdus:
            if isinstance(expected_pdu, type) and isinstance(pdu, expected_pdu):
                return pdu
            if isinstance(expected_pdu, ll.LinkLayerPacket) and pdu == expected_pdu:
                return pdu

        print("received unexpected pdu:")
        pdu.show()
        print("expected pdus:")
        for expected_pdu in expected_pdus:
            if isinstance(expected_pdu, type):
                print(f"- {expected_pdu.__name__}")
            if isinstance(expected_pdu, ll.LinkLayerPacket):
                print(f"- {expected_pdu.__class__.__name__}")
                expected_pdu.show()

        self.fail(f"unexpected pdu {pdu.__class__.__name__}")

    def tearDown(self):
        self.controller.stop()