# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import avatar
import asyncio
import logging
import grpc

from concurrent import futures
from contextlib import suppress

from mobly import test_runner, base_test

from bumble.smp import PairingDelegate

from avatar.utils import Address, AsyncQueue
from avatar.controllers import pandora_device
from pandora.host_pb2 import (
    DiscoverabilityMode, DataTypes, OwnAddressType
)
from pandora.security_pb2 import (
    PairingEventAnswer, SecurityLevel, LESecurityLevel
)


class ExampleTest(base_test.BaseTestClass):
    """Example Avatar tests driving a DUT and a reference (REF) device over Pandora gRPC.

    The first registered `pandora_device` controller is used as the DUT and the
    second as the REF (annotated as a `BumblePandoraDevice`).
    """

    def setup_class(self):
        """Register the Pandora device controllers and pick the DUT and REF."""
        self.pandora_devices = self.register_controller(pandora_device)
        self.dut: pandora_device.PandoraDevice = self.pandora_devices[0]
        self.ref: pandora_device.BumblePandoraDevice = self.pandora_devices[1]

    @avatar.asynchronous
    async def setup_test(self):
        """Factory-reset both devices concurrently and refresh their cached addresses."""
        async def reset(device: pandora_device.PandoraDevice):
            await device.host.FactoryReset()
            device.address = (await device.host.ReadLocalAddress(wait_for_ready=True)).address

        await asyncio.gather(reset(self.dut), reset(self.ref))

    def test_print_addresses(self):
        """Log the DUT and REF local addresses."""
        dut_address = self.dut.address
        self.dut.log.info(f'Address: {dut_address}')
        ref_address = self.ref.address
        self.ref.log.info(f'Address: {ref_address}')

    def test_get_remote_name(self):
        """Have each device look up the other's name by address and log it."""
        dut_name = self.ref.host.GetRemoteName(address=self.dut.address).name
        self.ref.log.info(f'DUT remote name: {dut_name}')
        ref_name = self.dut.host.GetRemoteName(address=self.ref.address).name
        self.dut.log.info(f'REF remote name: {ref_name}')

    def test_classic_connect(self):
        """Connect the REF to the DUT, log the DUT name read over that connection, then disconnect."""
        dut_address = self.dut.address
        self.dut.log.info(f'Address: {dut_address}')
        connection = self.ref.host.Connect(address=dut_address).connection
        dut_name = self.ref.host.GetRemoteName(connection=connection).name
        self.ref.log.info(f'Connected with: "{dut_name}" {dut_address}')
        self.ref.host.Disconnect(connection=connection)

    # This decorator allows us to write a single `test_le_connect` and run it
    # multiple times with different parameters.
    # Here we check that, whatever address type each side uses, the connection
    # still completes.
    @avatar.parameterized([
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC),
        (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
        (OwnAddressType.RANDOM, OwnAddressType.RANDOM),
        (OwnAddressType.RANDOM, OwnAddressType.PUBLIC),
    ])
    def test_le_connect(self, dut_address_type: OwnAddressType, ref_address_type: OwnAddressType):
        """Connect the DUT to an advertising REF over LE, then disconnect.

        The REF advertises with `ref_address_type`. The DUT scans with
        `dut_address_type` and picks the REF by its public address, or by
        the Bumble device's random address for non-public address types.
        """
        self.ref.host.StartAdvertising(legacy=True, connectable=True, own_address_type=ref_address_type)
        peers = self.dut.host.Scan(own_address_type=dut_address_type)
        if ref_address_type == OwnAddressType.PUBLIC:
            scan_response = next((x for x in peers if x.public == self.ref.address))
            connection = self.dut.host.ConnectLE(public=scan_response.public, own_address_type=dut_address_type).connection
        else:
            scan_response = next((x for x in peers if x.random == Address(self.ref.device.random_address)))
            connection = self.dut.host.ConnectLE(random=scan_response.random, own_address_type=dut_address_type).connection
        self.dut.host.Disconnect(connection=connection)

    def test_not_discoverable(self):
        """Check that a REF inquiry does not report a DUT set to NOT_DISCOVERABLE.

        An inquiry that ends with DEADLINE_EXCEEDED is also accepted as a pass.
        """
        self.dut.host.SetDiscoverabilityMode(mode=DiscoverabilityMode.NOT_DISCOVERABLE)
        peers = self.ref.host.Inquiry(timeout=3.0)
        try:
            assert not next((x for x in peers if x.address == self.dut.address), None)
        except grpc.RpcError as e:
            assert e.code() == grpc.StatusCode.DEADLINE_EXCEEDED

    @avatar.parameterized([
        (DiscoverabilityMode.DISCOVERABLE_LIMITED, ),
        (DiscoverabilityMode.DISCOVERABLE_GENERAL, ),
    ])
    def test_discoverable(self, mode):
        """Check that a REF inquiry reports the DUT after it is set to `mode`."""
        self.dut.host.SetDiscoverabilityMode(mode=mode)
        peers = self.ref.host.Inquiry(timeout=15.0)
        assert next((x for x in peers if x.address == self.dut.address), None)

    @avatar.asynchronous
    async def test_wait_connection(self):
        """Check that a DUT `WaitConnection` for the REF address resolves when the REF connects."""
        # Issued before Connect and awaited afterwards, so both calls are in flight together.
        dut_ref = self.dut.host.WaitConnection(address=self.ref.address)
        ref_dut = await self.ref.host.Connect(address=self.dut.address)
        dut_ref = await dut_ref
        assert ref_dut.connection and dut_ref.connection

    @avatar.asynchronous
    async def test_wait_any_connection(self):
        """Check that an address-less DUT `WaitConnection` resolves when the REF connects."""
        dut_ref = self.dut.host.WaitConnection()
        ref_dut = await self.ref.host.Connect(address=self.dut.address)
        dut_ref = await dut_ref
        assert ref_dut.connection and dut_ref.connection

    def test_scan_response_data(self):
        """Check that advertising and scan-response fields set on the DUT reach the REF's scan result."""
        self.dut.host.StartAdvertising(
            legacy=True,
            data=DataTypes(
                include_shortened_local_name=True,
                tx_power_level=42,
                incomplete_service_class_uuids16=['FDF0']
            ),
            scan_response_data=DataTypes(include_complete_local_name=True, include_class_of_device=True)
        )

        peers = self.ref.host.Scan()
        scan_response = next((x for x in peers if x.public == self.dut.address))
        assert isinstance(scan_response.data.complete_local_name, str)
        assert isinstance(scan_response.data.shortened_local_name, str)
        assert isinstance(scan_response.data.class_of_device, int)
        assert isinstance(scan_response.data.incomplete_service_class_uuids16[0], str)
        assert scan_response.data.tx_power_level == 42

    async def _handle_pairing_events(self) -> None:
        """Answer pairing events raised on the DUT and the REF until cancelled.

        Used by both pairing tests. Each iteration takes one event from the
        DUT's `OnPairing` stream, then one from the REF's, and answers them
        according to the method the DUT event carries:

        - `numeric_comparison` / `just_works`: both sides answer `confirm=True`.
        - `passkey_entry_notification`: the REF answers with the DUT's
          notified passkey.
        - `passkey_entry_request`: the DUT answers with the REF's notified
          passkey.

        Both `OnPairing` streams are cancelled whenever this coroutine exits.

        Raises:
            AssertionError: if the REF event's method does not pair with the
                DUT event's method, or the DUT event has any other method.
        """
        on_ref_pairing = self.ref.security.OnPairing((ref_answer_queue := AsyncQueue()))
        on_dut_pairing = self.dut.security.OnPairing((dut_answer_queue := AsyncQueue()))

        try:
            while True:
                dut_pairing_event = await anext(aiter(on_dut_pairing))
                ref_pairing_event = await anext(aiter(on_ref_pairing))
                dut_method = dut_pairing_event.WhichOneof('method')
                ref_method = ref_pairing_event.WhichOneof('method')
                mismatch = f'DUT pairing method {dut_method!r} vs REF pairing method {ref_method!r}'

                if dut_method in ('numeric_comparison', 'just_works'):
                    assert ref_method in ('numeric_comparison', 'just_works'), mismatch
                    dut_answer_queue.put_nowait(PairingEventAnswer(
                        event=dut_pairing_event,
                        confirm=True,
                    ))
                    ref_answer_queue.put_nowait(PairingEventAnswer(
                        event=ref_pairing_event,
                        confirm=True,
                    ))
                elif dut_method == 'passkey_entry_notification':
                    assert ref_method == 'passkey_entry_request', mismatch
                    ref_answer_queue.put_nowait(PairingEventAnswer(
                        event=ref_pairing_event,
                        passkey=dut_pairing_event.passkey_entry_notification,
                    ))
                elif dut_method == 'passkey_entry_request':
                    assert ref_method == 'passkey_entry_notification', mismatch
                    dut_answer_queue.put_nowait(PairingEventAnswer(
                        event=dut_pairing_event,
                        passkey=ref_pairing_event.passkey_entry_notification,
                    ))
                else:
                    # Explicit raise (not `assert False`) so the failure names the method.
                    raise AssertionError(f'unexpected {mismatch}')

        finally:
            on_ref_pairing.cancel()
            on_dut_pairing.cancel()

    @avatar.parameterized([
        (PairingDelegate.NO_OUTPUT_NO_INPUT, ),
        (PairingDelegate.KEYBOARD_INPUT_ONLY, ),
        (PairingDelegate.DISPLAY_OUTPUT_ONLY, ),
        (PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, ),
        (PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, ),
    ])
    @avatar.asynchronous
    async def test_classic_pairing(self, ref_io_capability):
        """Pair the REF (with `ref_io_capability`) and the DUT at `SecurityLevel.LEVEL2`.

        The REF deletes any bond with the DUT and connects with `Connect`. The
        REF then calls `Secure` and the DUT `WaitSecurity` while
        `_handle_pairing_events` answers pairing events. Finally the DUT
        disconnects and the REF waits for the disconnection.
        """
        # override reference device IO capability
        self.ref.device.io_capability = ref_io_capability

        await self.ref.security_storage.DeleteBond(public=self.dut.address)

        # Started before connecting, presumably so no pairing event raised
        # during connection is missed.
        pairing = asyncio.create_task(self._handle_pairing_events())
        try:
            ref_dut = (await self.ref.host.Connect(address=self.dut.address)).connection
            dut_ref = (await self.dut.host.WaitConnection(address=self.ref.address)).connection

            await asyncio.gather(
                self.ref.security.Secure(connection=ref_dut, classic=SecurityLevel.LEVEL2),
                self.dut.security.WaitSecurity(connection=dut_ref, classic=SecurityLevel.LEVEL2)
            )
        finally:
            # Always stop and reap the handler, even if connecting or securing
            # failed, so its OnPairing streams are closed. Awaiting it
            # re-raises any error the handler hit other than the cancellation.
            pairing.cancel()
            with suppress(asyncio.CancelledError, futures.CancelledError):
                await pairing

        await asyncio.gather(
            self.dut.host.Disconnect(connection=dut_ref),
            self.ref.host.WaitDisconnection(connection=ref_dut)
        )

    @avatar.parameterized([
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.NO_OUTPUT_NO_INPUT),
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.KEYBOARD_INPUT_ONLY),
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_ONLY),
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT),
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
        (OwnAddressType.PUBLIC, OwnAddressType.RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
        (OwnAddressType.RANDOM, OwnAddressType.RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
        (OwnAddressType.RANDOM, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
    ])
    @avatar.asynchronous
    async def test_le_pairing(self,
        dut_address_type: OwnAddressType,
        ref_address_type: OwnAddressType,
        ref_io_capability
    ):
        """Pair the REF (with `ref_io_capability`) and the DUT over LE at `LE_LEVEL4`.

        The DUT deletes any bond with the REF and advertises with
        `dut_address_type`. The REF scans with `ref_address_type` and connects
        with `ConnectLE`. The REF then calls `Secure` and the DUT
        `WaitSecurity` while `_handle_pairing_events` answers pairing events.
        Finally the DUT disconnects and the REF waits for the disconnection.
        """
        # override reference device IO capability
        self.ref.device.io_capability = ref_io_capability

        if ref_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC):
            ref_address = {'public': self.ref.address}
        else:
            ref_address = {'random': Address(self.ref.device.random_address)}

        await self.dut.security_storage.DeleteBond(**ref_address)
        await self.dut.host.StartAdvertising(legacy=True, connectable=True, own_address_type=dut_address_type)

        # NOTE(review): this takes the first scan result the REF receives and
        # treats it as the DUT; confirm no other LE advertiser is in range.
        dut = await anext(aiter(self.ref.host.Scan(own_address_type=ref_address_type)))
        if dut_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC):
            dut_address = {'public': Address(dut.public)}
        else:
            dut_address = {'random': Address(dut.random)}

        # Started before connecting, presumably so no pairing event raised
        # during connection is missed.
        pairing = asyncio.create_task(self._handle_pairing_events())
        try:
            ref_dut = (await self.ref.host.ConnectLE(own_address_type=ref_address_type, **dut_address)).connection
            dut_ref = (await self.dut.host.WaitLEConnection(**ref_address)).connection

            await asyncio.gather(
                self.ref.security.Secure(connection=ref_dut, le=LESecurityLevel.LE_LEVEL4),
                self.dut.security.WaitSecurity(connection=dut_ref, le=LESecurityLevel.LE_LEVEL4)
            )
        finally:
            # Always stop and reap the handler, even if connecting or securing
            # failed, so its OnPairing streams are closed. Awaiting it
            # re-raises any error the handler hit other than the cancellation.
            pairing.cancel()
            with suppress(asyncio.CancelledError, futures.CancelledError):
                await pairing

        await asyncio.gather(
            self.dut.host.Disconnect(connection=dut_ref),
            self.ref.host.WaitDisconnection(connection=ref_dut)
        )


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    test_runner.main()