• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import avatar
16import asyncio
17import logging
18import grpc
19
20from concurrent import futures
21from contextlib import suppress
22
23from mobly import test_runner, base_test
24
25from bumble.smp import PairingDelegate
26
27from avatar.utils import Address, AsyncQueue
28from avatar.controllers import pandora_device
29from pandora.host_pb2 import (
30    DiscoverabilityMode, DataTypes, OwnAddressType
31)
32from pandora.security_pb2 import (
33    PairingEventAnswer, SecurityLevel, LESecurityLevel
34)
35
36
class ExampleTest(base_test.BaseTestClass):
    """Example Avatar tests exercising a DUT against a reference device.

    ``self.dut`` is the first registered Pandora device and ``self.ref`` the
    second. The reference is typed as a ``BumblePandoraDevice`` and several
    tests reach into its underlying ``device`` (IO capability, random address).
    """

    def setup_class(self):
        """Register the Pandora devices and pick the DUT and the reference."""
        self.pandora_devices = self.register_controller(pandora_device)
        self.dut: pandora_device.PandoraDevice = self.pandora_devices[0]
        self.ref: pandora_device.BumblePandoraDevice = self.pandora_devices[1]

    @avatar.asynchronous
    async def setup_test(self):
        """Factory-reset both devices concurrently and cache their local addresses."""
        async def reset(device: pandora_device.PandoraDevice):
            await device.host.FactoryReset()
            # `wait_for_ready` makes the call wait for the gRPC channel to be
            # ready instead of failing fast (the reset presumably restarts the
            # device's server).
            device.address = (await device.host.ReadLocalAddress(wait_for_ready=True)).address

        await asyncio.gather(reset(self.dut), reset(self.ref))

    def test_print_addresses(self):
        """Log the local address read for each device in `setup_test`."""
        dut_address = self.dut.address
        self.dut.log.info(f'Address: {dut_address}')
        ref_address = self.ref.address
        self.ref.log.info(f'Address: {ref_address}')

    def test_get_remote_name(self):
        """Have each device look up and log the other device's name."""
        dut_name = self.ref.host.GetRemoteName(address=self.dut.address).name
        self.ref.log.info(f'DUT remote name: {dut_name}')
        ref_name = self.dut.host.GetRemoteName(address=self.ref.address).name
        self.dut.log.info(f'REF remote name: {ref_name}')

    def test_classic_connect(self):
        """Connect REF -> DUT over BR/EDR, read the DUT name on that link, disconnect."""
        dut_address = self.dut.address
        self.dut.log.info(f'Address: {dut_address}')
        connection = self.ref.host.Connect(address=dut_address).connection
        dut_name = self.ref.host.GetRemoteName(connection=connection).name
        self.ref.log.info(f'Connected with: "{dut_name}" {dut_address}')
        self.ref.host.Disconnect(connection=connection)

    # Using this decorator allows us to write a single `test_le_connect` and
    # run it multiple times with different parameters.
    # Here we check that, whatever address type is used on each side, the
    # connection still completes.
    @avatar.parameterized([
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC),
        (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
        (OwnAddressType.RANDOM, OwnAddressType.RANDOM),
        (OwnAddressType.RANDOM, OwnAddressType.PUBLIC),
    ])
    def test_le_connect(self, dut_address_type: OwnAddressType, ref_address_type: OwnAddressType):
        """Connect DUT -> REF over LE using the given own-address types, then disconnect."""
        self.ref.host.StartAdvertising(legacy=True, connectable=True, own_address_type=ref_address_type)
        peers = self.dut.host.Scan(own_address_type=dut_address_type)
        try:
            if ref_address_type == OwnAddressType.PUBLIC:
                scan_response = next(x for x in peers if x.public == self.ref.address)
                ref_address = {'public': scan_response.public}
            else:
                # The REF is a Bumble device, so its random address is known.
                ref_random_address = Address(self.ref.device.random_address)
                scan_response = next(x for x in peers if x.random == ref_random_address)
                ref_address = {'random': scan_response.random}
            connection = self.dut.host.ConnectLE(own_address_type=dut_address_type, **ref_address).connection
        finally:
            # Do not leave the scan stream open once it is no longer needed.
            peers.cancel()
        self.dut.host.Disconnect(connection=connection)

    def test_not_discoverable(self):
        """Check the REF does not find a non-discoverable DUT during a short inquiry."""
        self.dut.host.SetDiscoverabilityMode(mode=DiscoverabilityMode.NOT_DISCOVERABLE)
        peers = self.ref.host.Inquiry(timeout=3.0)
        try:
            assert not next((x for x in peers if x.address == self.dut.address), None)
        except grpc.RpcError as e:
            # The inquiry stream normally ends when its gRPC deadline expires.
            assert e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
        finally:
            peers.cancel()

    @avatar.parameterized([
        (DiscoverabilityMode.DISCOVERABLE_LIMITED, ),
        (DiscoverabilityMode.DISCOVERABLE_GENERAL, ),
    ])
    def test_discoverable(self, mode):
        """Check the REF finds the DUT by inquiry in the given discoverable mode.

        If the DUT is never found, iterating the stream raises a gRPC error
        once the 15 s deadline expires.
        """
        self.dut.host.SetDiscoverabilityMode(mode=mode)
        peers = self.ref.host.Inquiry(timeout=15.0)
        try:
            assert next((x for x in peers if x.address == self.dut.address), None)
        finally:
            # Stop the inquiry stream as soon as the DUT has been seen.
            peers.cancel()

    @avatar.asynchronous
    async def test_wait_connection(self):
        """Check DUT WaitConnection filtered on the REF address resolves when REF connects."""
        # Issue the wait RPC before connecting so the incoming connection is observed.
        dut_ref = self.dut.host.WaitConnection(address=self.ref.address)
        ref_dut = await self.ref.host.Connect(address=self.dut.address)
        dut_ref = await dut_ref
        assert ref_dut.connection and dut_ref.connection

    @avatar.asynchronous
    async def test_wait_any_connection(self):
        """Check an unfiltered DUT WaitConnection resolves when REF connects."""
        dut_ref = self.dut.host.WaitConnection()
        ref_dut = await self.ref.host.Connect(address=self.dut.address)
        dut_ref = await dut_ref
        assert ref_dut.connection and dut_ref.connection

    def test_scan_response_data(self):
        """Check advertising and scan-response data set on the DUT are reported by the REF scan."""
        self.dut.host.StartAdvertising(
            legacy=True,
            data=DataTypes(
                include_shortened_local_name=True,
                tx_power_level=42,
                incomplete_service_class_uuids16=['FDF0']
            ),
            scan_response_data=DataTypes(include_complete_local_name=True, include_class_of_device=True)
        )

        peers = self.ref.host.Scan()
        try:
            scan_response = next(x for x in peers if x.public == self.dut.address)
        finally:
            peers.cancel()

        data = scan_response.data
        assert isinstance(data.complete_local_name, str)
        assert isinstance(data.shortened_local_name, str)
        assert isinstance(data.class_of_device, int)
        assert isinstance(data.incomplete_service_class_uuids16[0], str)
        assert data.tx_power_level == 42

    async def _handle_pairing_events(self):
        """Answer pairing events from the DUT and the REF until cancelled.

        Each iteration reads one event from each side's `OnPairing` stream and
        answers by pushing a `PairingEventAnswer` into that side's queue:

        - numeric comparison / just works: confirm on both sides;
        - DUT displays a passkey: enter it on the REF;
        - DUT requests a passkey: enter the passkey displayed by the REF.

        Both `OnPairing` streams are cancelled when this coroutine exits.

        Raises:
            AssertionError: if the two sides report incompatible pairing methods.
        """
        on_ref_pairing = self.ref.security.OnPairing((ref_answer_queue := AsyncQueue()))
        on_dut_pairing = self.dut.security.OnPairing((dut_answer_queue := AsyncQueue()))

        try:
            while True:
                dut_pairing_event = await anext(aiter(on_dut_pairing))
                ref_pairing_event = await anext(aiter(on_ref_pairing))
                dut_method = dut_pairing_event.WhichOneof('method')
                ref_method = ref_pairing_event.WhichOneof('method')

                if dut_method in ('numeric_comparison', 'just_works'):
                    assert ref_method in ('numeric_comparison', 'just_works'), ref_method
                    dut_answer_queue.put_nowait(PairingEventAnswer(
                        event=dut_pairing_event,
                        confirm=True,
                    ))
                    ref_answer_queue.put_nowait(PairingEventAnswer(
                        event=ref_pairing_event,
                        confirm=True,
                    ))
                elif dut_method == 'passkey_entry_notification':
                    assert ref_method == 'passkey_entry_request', ref_method
                    ref_answer_queue.put_nowait(PairingEventAnswer(
                        event=ref_pairing_event,
                        passkey=dut_pairing_event.passkey_entry_notification,
                    ))
                elif dut_method == 'passkey_entry_request':
                    assert ref_method == 'passkey_entry_notification', ref_method
                    dut_answer_queue.put_nowait(PairingEventAnswer(
                        event=dut_pairing_event,
                        passkey=ref_pairing_event.passkey_entry_notification,
                    ))
                else:
                    raise AssertionError(f'unexpected DUT pairing method: {dut_method}')
        finally:
            on_ref_pairing.cancel()
            on_dut_pairing.cancel()

    @avatar.parameterized([
        (PairingDelegate.NO_OUTPUT_NO_INPUT, ),
        (PairingDelegate.KEYBOARD_INPUT_ONLY, ),
        (PairingDelegate.DISPLAY_OUTPUT_ONLY, ),
        (PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, ),
        (PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, ),
    ])
    @avatar.asynchronous
    async def test_classic_pairing(self, ref_io_capability):
        """Pair REF -> DUT over BR/EDR at `SecurityLevel.LEVEL2`, then disconnect."""
        # override reference device IO capability
        self.ref.device.io_capability = ref_io_capability

        await self.ref.security_storage.DeleteBond(public=self.dut.address)

        pairing = asyncio.create_task(self._handle_pairing_events())
        try:
            ref_dut = (await self.ref.host.Connect(address=self.dut.address)).connection
            dut_ref = (await self.dut.host.WaitConnection(address=self.ref.address)).connection

            # NOTE(review): the Secure/WaitSecurity responses are not inspected.
            await asyncio.gather(
                self.ref.security.Secure(connection=ref_dut, classic=SecurityLevel.LEVEL2),
                self.dut.security.WaitSecurity(connection=dut_ref, classic=SecurityLevel.LEVEL2)
            )
        finally:
            # Always stop the handler, even when connecting or securing failed,
            # so its OnPairing streams are released; re-raises its own errors.
            pairing.cancel()
            with suppress(asyncio.CancelledError, futures.CancelledError):
                await pairing

        await asyncio.gather(
            self.dut.host.Disconnect(connection=dut_ref),
            self.ref.host.WaitDisconnection(connection=ref_dut)
        )

    @avatar.parameterized([
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.NO_OUTPUT_NO_INPUT),
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.KEYBOARD_INPUT_ONLY),
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_ONLY),
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT),
        (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
        (OwnAddressType.PUBLIC, OwnAddressType.RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
        (OwnAddressType.RANDOM, OwnAddressType.RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
        (OwnAddressType.RANDOM, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
    ])
    @avatar.asynchronous
    async def test_le_pairing(self,
        dut_address_type: OwnAddressType,
        ref_address_type: OwnAddressType,
        ref_io_capability
    ):
        """Pair REF -> DUT over LE at `LE_LEVEL4` with the given address types, then disconnect."""
        # override reference device IO capability
        self.ref.device.io_capability = ref_io_capability

        if ref_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC):
            ref_address = {'public': self.ref.address}
        else:
            ref_address = {'random': Address(self.ref.device.random_address)}

        await self.dut.security_storage.DeleteBond(**ref_address)
        await self.dut.host.StartAdvertising(legacy=True, connectable=True, own_address_type=dut_address_type)

        scan = self.ref.host.Scan(own_address_type=ref_address_type)
        try:
            if dut_address_type == OwnAddressType.PUBLIC:
                # The DUT public address is known: wait for its advertisement
                # rather than trusting whichever advertiser is seen first.
                dut = await anext((x async for x in scan if x.public == self.dut.address))
            else:
                # NOTE(review): the DUT random address is unknown here, so the
                # first advertisement is assumed to come from the DUT.
                dut = await anext(aiter(scan))
        finally:
            scan.cancel()

        if dut_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC):
            dut_address = {'public': Address(dut.public)}
        else:
            dut_address = {'random': Address(dut.random)}

        pairing = asyncio.create_task(self._handle_pairing_events())
        try:
            ref_dut = (await self.ref.host.ConnectLE(own_address_type=ref_address_type, **dut_address)).connection
            dut_ref = (await self.dut.host.WaitLEConnection(**ref_address)).connection

            # NOTE(review): the Secure/WaitSecurity responses are not inspected.
            await asyncio.gather(
                self.ref.security.Secure(connection=ref_dut, le=LESecurityLevel.LE_LEVEL4),
                self.dut.security.WaitSecurity(connection=dut_ref, le=LESecurityLevel.LE_LEVEL4)
            )
        finally:
            # Always stop the handler, even when connecting or securing failed,
            # so its OnPairing streams are released; re-raises its own errors.
            pairing.cancel()
            with suppress(asyncio.CancelledError, futures.CancelledError):
                await pairing

        await asyncio.gather(
            self.dut.host.Disconnect(connection=dut_ref),
            self.ref.host.WaitDisconnection(connection=ref_dut)
        )
300
301
if __name__ == '__main__':
    # Run this module's Mobly test classes with verbose (DEBUG) logging.
    debug_level = logging.DEBUG
    logging.basicConfig(level=debug_level)
    test_runner.main()
305