1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/gatt/fake_client.h"
16
17 #include <unordered_set>
18
19 #include "pw_bluetooth_sapphire/internal/host/gatt/client.h"
20
21 namespace bt::gatt::testing {
22
FakeClient(pw::async::Dispatcher & pw_dispatcher)23 FakeClient::FakeClient(pw::async::Dispatcher& pw_dispatcher)
24 : heap_dispatcher_(pw_dispatcher), weak_self_(this), weak_fake_(this) {}
25
mtu() const26 uint16_t FakeClient::mtu() const {
27 // TODO(armansito): Return a configurable value.
28 return att::kLEMinMTU;
29 }
30
ExchangeMTU(MTUCallback callback)31 void FakeClient::ExchangeMTU(MTUCallback callback) {
32 (void)heap_dispatcher_.Post(
33 [mtu_status = exchange_mtu_status_,
34 mtu = server_mtu_,
35 callback = std::move(callback)](pw::async::Context /*ctx*/,
36 pw::Status status) mutable {
37 if (!status.ok()) {
38 return;
39 }
40
41 if (mtu_status.is_error()) {
42 callback(fit::error(mtu_status.error_value()));
43 } else {
44 callback(fit::ok(mtu));
45 }
46 });
47 }
48
DiscoverServices(ServiceKind kind,ServiceCallback svc_callback,att::ResultFunction<> status_callback)49 void FakeClient::DiscoverServices(ServiceKind kind,
50 ServiceCallback svc_callback,
51 att::ResultFunction<> status_callback) {
52 DiscoverServicesInRange(kind,
53 /*start=*/att::kHandleMin,
54 /*end=*/att::kHandleMax,
55 std::move(svc_callback),
56 std::move(status_callback));
57 }
58
DiscoverServicesInRange(ServiceKind kind,att::Handle start,att::Handle end,ServiceCallback svc_callback,att::ResultFunction<> status_callback)59 void FakeClient::DiscoverServicesInRange(
60 ServiceKind kind,
61 att::Handle start,
62 att::Handle end,
63 ServiceCallback svc_callback,
64 att::ResultFunction<> status_callback) {
65 DiscoverServicesWithUuidsInRange(kind,
66 start,
67 end,
68 std::move(svc_callback),
69 std::move(status_callback),
70 /*uuids=*/{});
71 }
72
DiscoverServicesWithUuids(ServiceKind kind,ServiceCallback svc_callback,att::ResultFunction<> status_callback,std::vector<UUID> uuids)73 void FakeClient::DiscoverServicesWithUuids(
74 ServiceKind kind,
75 ServiceCallback svc_callback,
76 att::ResultFunction<> status_callback,
77 std::vector<UUID> uuids) {
78 DiscoverServicesWithUuidsInRange(kind,
79 /*start=*/att::kHandleMin,
80 /*end=*/att::kHandleMax,
81 std::move(svc_callback),
82 std::move(status_callback),
83 std::move(uuids));
84 }
85
DiscoverServicesWithUuidsInRange(ServiceKind kind,att::Handle start,att::Handle end,ServiceCallback svc_callback,att::ResultFunction<> status_callback,std::vector<UUID> uuids)86 void FakeClient::DiscoverServicesWithUuidsInRange(
87 ServiceKind kind,
88 att::Handle start,
89 att::Handle end,
90 ServiceCallback svc_callback,
91 att::ResultFunction<> status_callback,
92 std::vector<UUID> uuids) {
93 att::Result<> callback_status = fit::ok();
94 if (discover_services_callback_) {
95 callback_status = discover_services_callback_(kind);
96 }
97
98 std::unordered_set<UUID> uuids_set(uuids.cbegin(), uuids.cend());
99
100 if (callback_status.is_ok()) {
101 for (const ServiceData& svc : services_) {
102 bool uuid_matches =
103 uuids.empty() || uuids_set.find(svc.type) != uuids_set.end();
104 if (svc.kind == kind && uuid_matches && svc.range_start >= start &&
105 svc.range_start <= end) {
106 (void)heap_dispatcher_.Post(
107 [svc, cb = svc_callback.share()](pw::async::Context /*ctx*/,
108 pw::Status status) {
109 if (status.ok()) {
110 cb(svc);
111 }
112 });
113 }
114 }
115 }
116
117 (void)heap_dispatcher_.Post(
118 [callback_status, cb = std::move(status_callback)](
119 pw::async::Context /*ctx*/, pw::Status status) {
120 if (status.ok()) {
121 cb(callback_status);
122 }
123 });
124 }
125
DiscoverCharacteristics(att::Handle range_start,att::Handle range_end,CharacteristicCallback chrc_callback,att::ResultFunction<> status_callback)126 void FakeClient::DiscoverCharacteristics(
127 att::Handle range_start,
128 att::Handle range_end,
129 CharacteristicCallback chrc_callback,
130 att::ResultFunction<> status_callback) {
131 last_chrc_discovery_start_handle_ = range_start;
132 last_chrc_discovery_end_handle_ = range_end;
133 chrc_discovery_count_++;
134
135 (void)heap_dispatcher_.Post(
136 [this,
137 range_start,
138 range_end,
139 chrc_callback = std::move(chrc_callback),
140 status_callback = std::move(status_callback)](pw::async::Context /*ctx*/,
141 pw::Status status) {
142 if (!status.ok()) {
143 return;
144 }
145 for (const auto& chrc : chrcs_) {
146 if (chrc.handle >= range_start && chrc.handle <= range_end) {
147 chrc_callback(chrc);
148 }
149 }
150 status_callback(chrc_discovery_status_);
151 });
152 }
153
DiscoverDescriptors(att::Handle range_start,att::Handle range_end,DescriptorCallback desc_callback,att::ResultFunction<> status_callback)154 void FakeClient::DiscoverDescriptors(att::Handle range_start,
155 att::Handle range_end,
156 DescriptorCallback desc_callback,
157 att::ResultFunction<> status_callback) {
158 last_desc_discovery_start_handle_ = range_start;
159 last_desc_discovery_end_handle_ = range_end;
160 desc_discovery_count_++;
161
162 att::Result<> discovery_status = fit::ok();
163 if (!desc_discovery_status_target_ ||
164 desc_discovery_count_ == desc_discovery_status_target_) {
165 discovery_status = desc_discovery_status_;
166 }
167
168 (void)heap_dispatcher_.Post(
169 [this,
170 discovery_status,
171 range_start,
172 range_end,
173 desc_callback = std::move(desc_callback),
174 status_callback = std::move(status_callback)](pw::async::Context /*ctx*/,
175 pw::Status status) {
176 if (!status.ok()) {
177 return;
178 }
179 for (const auto& desc : descs_) {
180 if (desc.handle >= range_start && desc.handle <= range_end) {
181 desc_callback(desc);
182 }
183 }
184 status_callback(discovery_status);
185 });
186 }
187
ReadRequest(att::Handle handle,ReadCallback callback)188 void FakeClient::ReadRequest(att::Handle handle, ReadCallback callback) {
189 if (read_request_callback_) {
190 read_request_callback_(handle, std::move(callback));
191 }
192 }
193
ReadByTypeRequest(const UUID & type,att::Handle start_handle,att::Handle end_handle,ReadByTypeCallback callback)194 void FakeClient::ReadByTypeRequest(const UUID& type,
195 att::Handle start_handle,
196 att::Handle end_handle,
197 ReadByTypeCallback callback) {
198 if (read_by_type_request_callback_) {
199 read_by_type_request_callback_(
200 type, start_handle, end_handle, std::move(callback));
201 }
202 }
203
ReadBlobRequest(att::Handle handle,uint16_t offset,ReadCallback callback)204 void FakeClient::ReadBlobRequest(att::Handle handle,
205 uint16_t offset,
206 ReadCallback callback) {
207 if (read_blob_request_callback_) {
208 read_blob_request_callback_(handle, offset, std::move(callback));
209 }
210 }
211
WriteRequest(att::Handle handle,const ByteBuffer & value,att::ResultFunction<> callback)212 void FakeClient::WriteRequest(att::Handle handle,
213 const ByteBuffer& value,
214 att::ResultFunction<> callback) {
215 if (write_request_callback_) {
216 write_request_callback_(handle, value, std::move(callback));
217 }
218 }
219
ExecutePrepareWrites(att::PrepareWriteQueue write_queue,ReliableMode reliable_mode,att::ResultFunction<> callback)220 void FakeClient::ExecutePrepareWrites(att::PrepareWriteQueue write_queue,
221 ReliableMode reliable_mode,
222 att::ResultFunction<> callback) {
223 if (execute_prepare_writes_callback_) {
224 execute_prepare_writes_callback_(
225 std::move(write_queue), reliable_mode, std::move(callback));
226 }
227 }
228
PrepareWriteRequest(att::Handle handle,uint16_t offset,const ByteBuffer & part_value,PrepareCallback callback)229 void FakeClient::PrepareWriteRequest(att::Handle handle,
230 uint16_t offset,
231 const ByteBuffer& part_value,
232 PrepareCallback callback) {
233 if (prepare_write_request_callback_) {
234 prepare_write_request_callback_(
235 handle, offset, part_value, std::move(callback));
236 }
237 }
ExecuteWriteRequest(att::ExecuteWriteFlag flag,att::ResultFunction<> callback)238 void FakeClient::ExecuteWriteRequest(att::ExecuteWriteFlag flag,
239 att::ResultFunction<> callback) {
240 if (execute_write_request_callback_) {
241 execute_write_request_callback_(flag, std::move(callback));
242 }
243 }
244
WriteWithoutResponse(att::Handle handle,const ByteBuffer & value,att::ResultFunction<> callback)245 void FakeClient::WriteWithoutResponse(att::Handle handle,
246 const ByteBuffer& value,
247 att::ResultFunction<> callback) {
248 if (write_without_rsp_callback_) {
249 write_without_rsp_callback_(handle, value, std::move(callback));
250 }
251 }
252
SendNotification(bool indicate,att::Handle handle,const ByteBuffer & value,bool maybe_truncated)253 void FakeClient::SendNotification(bool indicate,
254 att::Handle handle,
255 const ByteBuffer& value,
256 bool maybe_truncated) {
257 if (notification_callback_) {
258 notification_callback_(indicate, handle, value, maybe_truncated);
259 }
260 }
261
SetNotificationHandler(NotificationCallback callback)262 void FakeClient::SetNotificationHandler(NotificationCallback callback) {
263 notification_callback_ = std::move(callback);
264 }
265
266 } // namespace bt::gatt::testing
267