1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/sm/security_request_phase.h"
16
17 #include <memory>
18
19 #include "pw_bluetooth_sapphire/internal/host/common/byte_buffer.h"
20 #include "pw_bluetooth_sapphire/internal/host/common/macros.h"
21 #include "pw_bluetooth_sapphire/internal/host/hci/connection.h"
22 #include "pw_bluetooth_sapphire/internal/host/l2cap/fake_channel_test.h"
23 #include "pw_bluetooth_sapphire/internal/host/sm/fake_phase_listener.h"
24 #include "pw_bluetooth_sapphire/internal/host/sm/packet.h"
25 #include "pw_bluetooth_sapphire/internal/host/sm/smp.h"
26 #include "pw_bluetooth_sapphire/internal/host/sm/types.h"
27 #include "pw_bluetooth_sapphire/internal/host/sm/util.h"
28 #include "pw_bluetooth_sapphire/internal/host/testing/test_helpers.h"
29 #include "pw_unit_test/framework.h"
30
31 namespace bt::sm {
32 namespace {
33 struct SecurityRequestOptions {
34 SecurityLevel requested_level = SecurityLevel::kEncrypted;
35 BondableMode bondable = BondableMode::Bondable;
36 };
37
38 class SecurityRequestPhaseTest : public l2cap::testing::FakeChannelTest {
39 public:
40 SecurityRequestPhaseTest() = default;
41 ~SecurityRequestPhaseTest() override = default;
42
43 protected:
SetUp()44 void SetUp() override { NewSecurityRequestPhase(); }
45
TearDown()46 void TearDown() override { security_request_phase_ = nullptr; }
47
heap_dispatcher()48 pw::async::HeapDispatcher& heap_dispatcher() { return heap_dispatcher_; }
49
NewSecurityRequestPhase(SecurityRequestOptions opts=SecurityRequestOptions (),bt::LinkType ll_type=bt::LinkType::kLE)50 void NewSecurityRequestPhase(
51 SecurityRequestOptions opts = SecurityRequestOptions(),
52 bt::LinkType ll_type = bt::LinkType::kLE) {
53 l2cap::ChannelId cid = ll_type == bt::LinkType::kLE ? l2cap::kLESMPChannelId
54 : l2cap::kSMPChannelId;
55 ChannelOptions options(cid);
56 options.link_type = ll_type;
57
58 fake_chan_ = CreateFakeChannel(options);
59 sm_chan_ = std::make_unique<PairingChannel>(fake_chan_->GetWeakPtr());
60 fake_listener_ = std::make_unique<FakeListener>();
61 security_request_phase_ = std::make_unique<SecurityRequestPhase>(
62 sm_chan_->GetWeakPtr(),
63 fake_listener_->as_weak_ptr(),
64 opts.requested_level,
65 opts.bondable,
66 [this](PairingRequestParams preq) { last_pairing_req_ = preq; });
67 }
68
fake_chan() const69 l2cap::testing::FakeChannel* fake_chan() const { return fake_chan_.get(); }
security_request_phase()70 SecurityRequestPhase* security_request_phase() {
71 return security_request_phase_.get();
72 }
73
last_pairing_req()74 std::optional<PairingRequestParams> last_pairing_req() {
75 return last_pairing_req_;
76 }
77
78 private:
79 std::unique_ptr<l2cap::testing::FakeChannel> fake_chan_;
80 std::unique_ptr<PairingChannel> sm_chan_;
81 std::unique_ptr<FakeListener> fake_listener_;
82 std::unique_ptr<SecurityRequestPhase> security_request_phase_;
83
84 std::optional<PairingRequestParams> last_pairing_req_;
85
86 pw::async::HeapDispatcher heap_dispatcher_{dispatcher()};
87
88 BT_DISALLOW_COPY_AND_ASSIGN_ALLOW_MOVE(SecurityRequestPhaseTest);
89 };
90
// Alias of the fixture for death tests (GoogleTest expects suites containing
// death tests to be named with a "DeathTest" suffix). No test in this file
// currently uses it.
using SMP_SecurityRequestPhaseDeathTest = SecurityRequestPhaseTest;
92
// An encrypted, bondable phase sends a Security Request whose AuthReq carries
// only the bonding flag, and reports kEncrypted as its pending request.
TEST_F(SecurityRequestPhaseTest, MakeEncryptedBondableSecurityRequest) {
  NewSecurityRequestPhase(
      SecurityRequestOptions{.requested_level = SecurityLevel::kEncrypted,
                             .bondable = BondableMode::Bondable});
  StaticByteBuffer kExpectedReq(kSecurityRequest, AuthReq::kBondingFlag);
  // Start() is posted rather than called directly so it runs inside Expect(),
  // which presumably drains the dispatcher while capturing outbound SDUs (see
  // FakeChannelTest). A failed Post() would otherwise only show up as a
  // confusing Expect() failure, so check it explicitly.
  const auto post_status = heap_dispatcher().Post(
      [this](pw::async::Context /*ctx*/, pw::Status status) {
        if (status.ok()) {
          security_request_phase()->Start();
        }
      });
  ASSERT_TRUE(post_status.ok());
  ASSERT_TRUE(Expect(kExpectedReq));
  EXPECT_EQ(SecurityLevel::kEncrypted,
            security_request_phase()->pending_security_request());
}
108
// An authenticated, non-bondable phase sends a Security Request without the
// bonding flag, and reports kAuthenticated as its pending request.
TEST_F(SecurityRequestPhaseTest, MakeAuthenticatedNonBondableSecurityRequest) {
  NewSecurityRequestPhase(
      SecurityRequestOptions{.requested_level = SecurityLevel::kAuthenticated,
                             .bondable = BondableMode::NonBondable});
  // inclusive-language: ignore
  StaticByteBuffer kExpectedReq(kSecurityRequest, AuthReq::kMITM);
  // Start() is posted rather than called directly so it runs inside Expect(),
  // which presumably drains the dispatcher while capturing outbound SDUs (see
  // FakeChannelTest). A failed Post() would otherwise only show up as a
  // confusing Expect() failure, so check it explicitly.
  const auto post_status = heap_dispatcher().Post(
      [this](pw::async::Context /*ctx*/, pw::Status status) {
        if (status.ok()) {
          security_request_phase()->Start();
        }
      });
  ASSERT_TRUE(post_status.ok());
  ASSERT_TRUE(Expect(kExpectedReq));
  EXPECT_EQ(SecurityLevel::kAuthenticated,
            security_request_phase()->pending_security_request());
}
125
// A Secure Connections authenticated phase with the default (Bondable) mode
// sends a Security Request with the bonding, authentication and SC bits set,
// and reports kSecureAuthenticated as its pending request.
TEST_F(SecurityRequestPhaseTest,
       MakeSecureAuthenticatedBondableSecurityRequest) {
  NewSecurityRequestPhase(SecurityRequestOptions{
      .requested_level = SecurityLevel::kSecureAuthenticated});

  // inclusive-language: disable
  StaticByteBuffer kExpectedReq(
      kSecurityRequest, AuthReq::kBondingFlag | AuthReq::kMITM | AuthReq::kSC);
  // inclusive-language: enable

  // Start() is posted rather than called directly so it runs inside Expect(),
  // which presumably drains the dispatcher while capturing outbound SDUs (see
  // FakeChannelTest). A failed Post() would otherwise only show up as a
  // confusing Expect() failure, so check it explicitly.
  const auto post_status = heap_dispatcher().Post(
      [this](pw::async::Context /*ctx*/, pw::Status status) {
        if (status.ok()) {
          security_request_phase()->Start();
        }
      });
  ASSERT_TRUE(post_status.ok());
  ASSERT_TRUE(Expect(kExpectedReq));
  EXPECT_EQ(SecurityLevel::kSecureAuthenticated,
            security_request_phase()->pending_security_request());
}
146
// Closing the underlying fake channel and then draining the dispatcher must
// not crash; the test has no further expectations.
TEST_F(SecurityRequestPhaseTest, HandlesChannelClosedGracefully) {
  l2cap::testing::FakeChannel* const chan = fake_chan();
  chan->Close();
  RunUntilIdle();
}
151
// An inbound Pairing Request is handed, byte-for-byte unchanged, to the
// callback the phase was constructed with.
TEST_F(SecurityRequestPhaseTest, PairingRequestAsResponderPassedThrough) {
  const PairingRequestParams sent_preq{
      .io_capability = IOCapability::kDisplayOnly,
      .oob_data_flag = OOBDataFlag::kNotPresent,
      .auth_req = AuthReq::kBondingFlag,
      .max_encryption_key_size = 0,
      .initiator_key_dist_gen = 0,
      .responder_key_dist_gen = 0};

  // Serialize the parameters behind a kPairingRequest header.
  StaticByteBuffer<util::PacketSize<PairingRequestParams>()> packet;
  PacketWriter packet_writer(kPairingRequest, &packet);
  *packet_writer.mutable_payload<PairingRequestParams>() = sent_preq;

  ASSERT_FALSE(last_pairing_req().has_value());
  fake_chan()->Receive(packet);
  RunUntilIdle();

  const std::optional<PairingRequestParams> received = last_pairing_req();
  ASSERT_TRUE(received.has_value());
  ASSERT_EQ(0, memcmp(&*received, &sent_preq, sizeof(PairingRequestParams)));
}
169
// Despite the test name, the inbound packet here is a Pairing Response. The
// phase is expected to answer it with Pairing Failed and to forward nothing to
// its pairing-request callback.
TEST_F(SecurityRequestPhaseTest, InboundSecurityRequestFails) {
  StaticByteBuffer<util::PacketSize<PairingResponseParams>()> pres_packet;
  PacketWriter writer(kPairingResponse, &pres_packet);
  *writer.mutable_payload<PairingResponseParams>() = PairingResponseParams();

  bool message_sent = false;
  fake_chan()->SetSendCallback(
      [&message_sent](ByteBufferPtr sdu) {
        auto parse_result = ValidPacketReader::ParseSdu(sdu);
        // Report a malformed outbound SDU as a test failure rather than
        // calling value() on an error result.
        ASSERT_TRUE(parse_result.is_ok());
        ASSERT_EQ(parse_result.value().code(), kPairingFailed);
        message_sent = true;
      },
      dispatcher());

  fake_chan()->Receive(pres_packet);
  RunUntilIdle();
  ASSERT_FALSE(last_pairing_req().has_value());
  ASSERT_TRUE(message_sent);
}
189
// A packet with an unknown SMP code is not forwarded to the pairing-request
// callback; as asserted below, the phase also replies with Pairing Failed.
TEST_F(SecurityRequestPhaseTest, DropsInvalidPacket) {
  StaticByteBuffer bad_packet(0xFF);  // 0xFF is not a valid SMP header code

  bool message_sent = false;
  fake_chan()->SetSendCallback(
      [&message_sent](ByteBufferPtr sdu) {
        auto parse_result = ValidPacketReader::ParseSdu(sdu);
        // Report a malformed outbound SDU as a test failure rather than
        // calling value() on an error result.
        ASSERT_TRUE(parse_result.is_ok());
        ASSERT_EQ(parse_result.value().code(), kPairingFailed);
        message_sent = true;
      },
      dispatcher());

  fake_chan()->Receive(bad_packet);
  RunUntilIdle();
  ASSERT_FALSE(last_pairing_req().has_value());
  ASSERT_TRUE(message_sent);
}
207
208 } // namespace
209 } // namespace bt::sm
210