• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import logging
17
18from avatar import BumblePandoraDevice, PandoraDevice, PandoraDevices
19from avatar.aio import asynchronous
20from bumble import smp
21from bumble.hci import Address
22from concurrent import futures
23from contextlib import suppress
24from mobly import base_test, signals, test_runner
25from mobly.asserts import assert_equal  # type: ignore
26from mobly.asserts import assert_false  # type: ignore
27from mobly.asserts import assert_is_not_none  # type: ignore
28from mobly.asserts import assert_true  # type: ignore
29from pandora.host_pb2 import RANDOM, DataTypes, OwnAddressType, ScanningResponse
30from pandora.security_pb2 import LE_LEVEL3, PairingEventAnswer
31from typing import NoReturn, Optional
32
33
class SmpTest(base_test.BaseTestClass):  # type: ignore[misc]
    """Security Manager Protocol (SMP) pairing tests.

    The first two configured Pandora devices act as the device under test
    (``dut``) and the reference device (``ref``); any further devices are
    ignored.
    """

    # Populated by `setup_class`; stays None if class setup never ran.
    devices: Optional[PandoraDevices] = None

    dut: PandoraDevice
    ref: PandoraDevice

    def setup_class(self) -> None:
        """Create the Pandora devices and pick the DUT and the reference device."""
        self.devices = PandoraDevices(self)
        self.dut, self.ref, *_ = self.devices

        # Enable BR/EDR mode for Bumble devices, unless their config already sets it.
        for device in self.devices:
            if isinstance(device, BumblePandoraDevice):
                device.config.setdefault('classic_enabled', True)

    def teardown_class(self) -> None:
        """Stop all devices, if `setup_class` created any."""
        if self.devices:
            self.devices.stop_all()

    @asynchronous
    async def setup_test(self) -> None:
        """Reset the DUT and the reference device concurrently before each test."""
        await asyncio.gather(self.dut.reset(), self.ref.reset())

    async def handle_pairing_events(self) -> NoReturn:
        """Confirm every pairing event reported by the DUT, until cancelled.

        Intended to run as a background task (see `dut_pair`). Each event read
        from the DUT's `OnPairing` stream is answered with ``confirm=True``.
        The stream is cancelled whenever this coroutine exits, whether through
        task cancellation or an error while reading the stream.
        """
        dut_pairing_stream = self.dut.aio.security.OnPairing()
        try:
            while True:
                dut_pairing_event = await anext(dut_pairing_stream)
                dut_pairing_stream.send_nowait(
                    PairingEventAnswer(
                        event=dut_pairing_event,
                        confirm=True,
                    )
                )
        finally:
            dut_pairing_stream.cancel()

    async def dut_pair(self, dut_address_type: OwnAddressType, ref_address_type: OwnAddressType) -> ScanningResponse:
        """Connect the DUT to the reference over LE, secure the link at LE level 3, then disconnect.

        The reference advertises (legacy, connectable) with a marker in its
        manufacturer-specific data. The DUT scans for that marker and connects.
        The DUT then requests LE level 3 security while the reference waits for
        it. The DUT's pairing prompts are auto-confirmed by
        `handle_pairing_events`. Both security procedures must report success.
        Finally the reference disconnects and the DUT observes the
        disconnection.

        Streams and the background pairing task are always released, even when
        a step fails, so they cannot leak into subsequent tests.

        Args:
            dut_address_type: own address type used by the DUT to scan and connect.
            ref_address_type: own address type used by the reference to advertise.

        Returns:
            The scan result through which the DUT found the reference; callers
            use its address to query the DUT's bond storage.
        """
        advertisement = self.ref.aio.host.Advertise(
            legacy=True,
            connectable=True,
            own_address_type=ref_address_type,
            data=DataTypes(manufacturer_specific_data=b'pause cafe'),
        )

        pairing: Optional[asyncio.Task[NoReturn]] = None
        try:
            try:
                scan = self.dut.aio.host.Scan(own_address_type=dut_address_type)
                try:
                    ref = await anext((x async for x in scan if b'pause cafe' in x.data.manufacturer_specific_data))
                finally:
                    scan.cancel()

                # Auto-confirm the DUT's pairing prompts for the rest of the procedure.
                pairing = asyncio.create_task(self.handle_pairing_events())
                (dut_ref_res, ref_dut_res) = await asyncio.gather(
                    self.dut.aio.host.ConnectLE(own_address_type=dut_address_type, **ref.address_asdict()),
                    anext(aiter(advertisement)),
                )
            finally:
                # Stop advertising once connected, or as soon as any step above failed.
                advertisement.cancel()

            ref_dut, dut_ref = ref_dut_res.connection, dut_ref_res.connection
            assert_is_not_none(dut_ref)
            assert_is_not_none(ref_dut)
            # Narrow the Optional connection types for static type checkers.
            assert dut_ref and ref_dut

            (secure, wait_security) = await asyncio.gather(
                self.dut.aio.security.Secure(connection=dut_ref, le=LE_LEVEL3),
                self.ref.aio.security.WaitSecurity(connection=ref_dut, le=LE_LEVEL3),
            )
        finally:
            # Always stop the auto-confirm task so it cannot outlive this call
            # (and keep consuming the DUT's pairing events) when a step failed.
            if pairing is not None:
                pairing.cancel()
                with suppress(asyncio.CancelledError, futures.CancelledError):
                    await pairing

        assert_equal(secure.result_variant(), 'success')
        assert_equal(wait_security.result_variant(), 'success')

        await asyncio.gather(
            self.ref.aio.host.Disconnect(connection=ref_dut),
            self.dut.aio.host.WaitDisconnection(connection=dut_ref),
        )
        return ref

    @asynchronous
    async def test_le_pairing_delete_dup_bond_record(self) -> None:
        """Re-pairing with the same identity address replaces the previous bond.

        The reference pairs with the DUT twice. It advertises from a random
        address each time but always distributes the same identity address.
        The DUT is expected to keep only the bond from the second pairing.
        """
        if isinstance(self.dut, BumblePandoraDevice):
            raise signals.TestSkip('TODO: Fix test for Bumble DUT')
        if not isinstance(self.ref, BumblePandoraDevice):
            raise signals.TestSkip('Test require Bumble as reference device(s)')

        class Session(smp.Session):
            # Hack: make the reference distribute the same identity address in
            # both pairings by rewriting every outgoing Identity Address
            # Information command to a fixed random static address.
            def send_command(self: smp.Session, command: smp.SMP_Command) -> None:
                if isinstance(command, smp.SMP_Identity_Address_Information_Command):
                    command = smp.SMP_Identity_Address_Information_Command(
                        addr_type=Address.RANDOM_IDENTITY_ADDRESS,
                        bd_addr=Address(
                            'F6:F7:F8:F9:FA:FB',
                            Address.RANDOM_IDENTITY_ADDRESS,
                        ),
                    )
                self.manager.send_command(self.connection, command)

        self.ref.device.smp_session_proxy = Session

        # Pair with the same device 2 times: the reference advertises with a
        # different random address each time but uses the same identity address.
        ref1 = await self.dut_pair(dut_address_type=RANDOM, ref_address_type=RANDOM)
        is_bonded = await self.dut.aio.security_storage.IsBonded(random=ref1.random)
        assert_true(is_bonded.value, "")

        # Reset the reference before the second pairing, then reinstall the
        # identity-address hack on its (reset) device.
        await self.ref.reset()
        self.ref.device.smp_session_proxy = Session

        ref2 = await self.dut_pair(dut_address_type=RANDOM, ref_address_type=RANDOM)
        is_bonded = await self.dut.aio.security_storage.IsBonded(random=ref2.random)
        assert_true(is_bonded.value, "")

        # The bond recorded for the first pairing must have been replaced.
        is_bonded = await self.dut.aio.security_storage.IsBonded(random=ref1.random)
        assert_false(is_bonded.value, "")
150
151
if __name__ == '__main__':
    # Turn on DEBUG-level logging on the root logger before handing control
    # to Mobly's test runner.
    logging.basicConfig(level=logging.DEBUG)
    test_runner.main()  # type: ignore
155