1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/gatt/fake_layer.h"
16
17 #include <pw_assert/check.h>
18
19 #include "pw_bluetooth_sapphire/internal/host/gatt/remote_service.h"
20
21 namespace bt::gatt::testing {
22
TestPeer(pw::async::Dispatcher & pw_dispatcher)23 FakeLayer::TestPeer::TestPeer(pw::async::Dispatcher& pw_dispatcher)
24 : fake_client(pw_dispatcher) {}
25
26 std::pair<RemoteService::WeakPtr, FakeClient::WeakPtr>
AddPeerService(PeerId peer_id,const ServiceData & info,bool notify)27 FakeLayer::AddPeerService(PeerId peer_id,
28 const ServiceData& info,
29 bool notify) {
30 auto [iter, _] = peers_.try_emplace(peer_id, pw_dispatcher_);
31 auto& peer = iter->second;
32
33 PW_CHECK(info.range_start <= info.range_end);
34 auto service =
35 std::make_unique<RemoteService>(info, peer.fake_client.GetWeakPtr());
36 RemoteService::WeakPtr service_weak = service->GetWeakPtr();
37
38 std::vector<att::Handle> removed;
39 ServiceList added;
40 ServiceList modified;
41
42 auto svc_iter = peer.services.find(info.range_start);
43 if (svc_iter != peer.services.end()) {
44 if (svc_iter->second->uuid() == info.type) {
45 modified.push_back(service_weak);
46 } else {
47 removed.push_back(svc_iter->second->handle());
48 added.push_back(service_weak);
49 }
50
51 svc_iter->second->set_service_changed(true);
52 peer.services.erase(svc_iter);
53 } else {
54 added.push_back(service_weak);
55 }
56
57 bt_log(DEBUG,
58 "gatt",
59 "services changed (removed: %zu, added: %zu, modified: %zu)",
60 removed.size(),
61 added.size(),
62 modified.size());
63
64 peer.services.emplace(info.range_start, std::move(service));
65
66 if (notify && remote_service_watchers_.count(peer_id)) {
67 remote_service_watchers_[peer_id](removed, added, modified);
68 }
69
70 return {service_weak, peer.fake_client.AsFakeWeakPtr()};
71 }
72
RemovePeerService(PeerId peer_id,att::Handle handle)73 void FakeLayer::RemovePeerService(PeerId peer_id, att::Handle handle) {
74 auto peer_iter = peers_.find(peer_id);
75 if (peer_iter == peers_.end()) {
76 return;
77 }
78 auto svc_iter = peer_iter->second.services.find(handle);
79 if (svc_iter == peer_iter->second.services.end()) {
80 return;
81 }
82 svc_iter->second->set_service_changed(true);
83 peer_iter->second.services.erase(svc_iter);
84
85 if (remote_service_watchers_.count(peer_id)) {
86 remote_service_watchers_[peer_id](
87 /*removed=*/{handle}, /*added=*/{}, /*modified=*/{});
88 }
89 }
90
AddConnection(PeerId peer_id,std::unique_ptr<Client>,Server::FactoryFunction)91 void FakeLayer::AddConnection(PeerId peer_id,
92 std::unique_ptr<Client>,
93 Server::FactoryFunction) {
94 peers_.try_emplace(peer_id, pw_dispatcher_);
95 }
96
RemoveConnection(PeerId peer_id)97 void FakeLayer::RemoveConnection(PeerId peer_id) { peers_.erase(peer_id); }
98
RegisterPeerMtuListener(PeerMtuListener)99 GATT::PeerMtuListenerId FakeLayer::RegisterPeerMtuListener(PeerMtuListener) {
100 PW_CRASH("implement fake behavior if needed");
101 }
102
UnregisterPeerMtuListener(PeerMtuListenerId)103 bool FakeLayer::UnregisterPeerMtuListener(PeerMtuListenerId) {
104 PW_CRASH("implement fake behavior if needed");
105 }
106
RegisterService(ServicePtr service,ServiceIdCallback callback,ReadHandler read_handler,WriteHandler write_handler,ClientConfigCallback ccc_callback)107 void FakeLayer::RegisterService(ServicePtr service,
108 ServiceIdCallback callback,
109 ReadHandler read_handler,
110 WriteHandler write_handler,
111 ClientConfigCallback ccc_callback) {
112 if (register_service_fails_) {
113 callback(kInvalidId);
114 return;
115 }
116
117 IdType id = next_local_service_id_++;
118 local_services_.try_emplace(id,
119 LocalService{std::move(service),
120 std::move(read_handler),
121 std::move(write_handler),
122 std::move(ccc_callback),
123 {}});
124 callback(id);
125 }
126
UnregisterService(IdType service_id)127 void FakeLayer::UnregisterService(IdType service_id) {
128 local_services_.erase(service_id);
129 }
130
SendUpdate(IdType service_id,IdType chrc_id,PeerId peer_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)131 void FakeLayer::SendUpdate(IdType service_id,
132 IdType chrc_id,
133 PeerId peer_id,
134 ::std::vector<uint8_t> value,
135 IndicationCallback indicate_cb) {
136 auto iter = local_services_.find(service_id);
137 if (iter == local_services_.end()) {
138 indicate_cb(fit::error(att::ErrorCode::kInvalidHandle));
139 return;
140 }
141 iter->second.updates.push_back(
142 Update{chrc_id, std::move(value), std::move(indicate_cb), peer_id});
143 }
144
UpdateConnectedPeers(IdType service_id,IdType chrc_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)145 void FakeLayer::UpdateConnectedPeers(IdType service_id,
146 IdType chrc_id,
147 ::std::vector<uint8_t> value,
148 IndicationCallback indicate_cb) {
149 auto iter = local_services_.find(service_id);
150 if (iter == local_services_.end()) {
151 indicate_cb(fit::error(att::ErrorCode::kInvalidHandle));
152 return;
153 }
154 iter->second.updates.push_back(
155 Update{chrc_id, std::move(value), std::move(indicate_cb), std::nullopt});
156 }
157
SetPersistServiceChangedCCCCallback(PersistServiceChangedCCCCallback callback)158 void FakeLayer::SetPersistServiceChangedCCCCallback(
159 PersistServiceChangedCCCCallback callback) {
160 if (set_persist_service_changed_ccc_cb_cb_) {
161 set_persist_service_changed_ccc_cb_cb_();
162 }
163 persist_service_changed_ccc_cb_ = std::move(callback);
164 }
165
SetRetrieveServiceChangedCCCCallback(RetrieveServiceChangedCCCCallback callback)166 void FakeLayer::SetRetrieveServiceChangedCCCCallback(
167 RetrieveServiceChangedCCCCallback callback) {
168 if (set_retrieve_service_changed_ccc_cb_cb_) {
169 set_retrieve_service_changed_ccc_cb_cb_();
170 }
171 retrieve_service_changed_ccc_cb_ = std::move(callback);
172 }
173
InitializeClient(PeerId peer_id,std::vector<UUID> services_to_discover)174 void FakeLayer::InitializeClient(PeerId peer_id,
175 std::vector<UUID> services_to_discover) {
176 std::vector<UUID> uuids = std::move(services_to_discover);
177 if (initialize_client_cb_) {
178 initialize_client_cb_(peer_id, uuids);
179 }
180
181 auto iter = peers_.find(peer_id);
182 if (iter == peers_.end()) {
183 return;
184 }
185
186 std::vector<RemoteService::WeakPtr> added;
187 if (uuids.empty()) {
188 for (auto& svc_pair : iter->second.services) {
189 added.push_back(svc_pair.second->GetWeakPtr());
190 }
191 } else {
192 for (auto& svc_pair : iter->second.services) {
193 auto uuid_iter =
194 std::find_if(uuids.begin(), uuids.end(), [&svc_pair](auto uuid) {
195 return svc_pair.second->uuid() == uuid;
196 });
197 if (uuid_iter != uuids.end()) {
198 added.push_back(svc_pair.second->GetWeakPtr());
199 }
200 }
201 }
202
203 if (remote_service_watchers_.count(peer_id)) {
204 remote_service_watchers_[peer_id](
205 /*removed=*/{}, /*added=*/added, /*modified=*/{});
206 }
207 }
208
RegisterRemoteServiceWatcherForPeer(PeerId peer_id,RemoteServiceWatcher watcher)209 GATT::RemoteServiceWatcherId FakeLayer::RegisterRemoteServiceWatcherForPeer(
210 PeerId peer_id, RemoteServiceWatcher watcher) {
211 PW_CHECK(remote_service_watchers_.count(peer_id) == 0);
212 remote_service_watchers_[peer_id] = std::move(watcher);
213 // Use the PeerId as the watcher ID because FakeLayer only needs to support 1
214 // watcher per peer.
215 return peer_id.value();
216 }
UnregisterRemoteServiceWatcher(RemoteServiceWatcherId watcher_id)217 bool FakeLayer::UnregisterRemoteServiceWatcher(
218 RemoteServiceWatcherId watcher_id) {
219 bool result = remote_service_watchers_.count(PeerId(watcher_id));
220 remote_service_watchers_.erase(PeerId(watcher_id));
221 return result;
222 }
223
ListServices(PeerId peer_id,std::vector<UUID> uuids,ServiceListCallback callback)224 void FakeLayer::ListServices(PeerId peer_id,
225 std::vector<UUID> uuids,
226 ServiceListCallback callback) {
227 if (pause_list_services_) {
228 return;
229 }
230
231 ServiceList services;
232
233 auto iter = peers_.find(peer_id);
234 if (iter != peers_.end()) {
235 for (auto& svc_pair : iter->second.services) {
236 auto pred = [&](const UUID& uuid) {
237 return svc_pair.second->uuid() == uuid;
238 };
239 if (uuids.empty() ||
240 std::find_if(uuids.begin(), uuids.end(), pred) != uuids.end()) {
241 services.push_back(svc_pair.second->GetWeakPtr());
242 }
243 }
244 }
245
246 callback(list_services_status_, std::move(services));
247 }
248
FindService(PeerId peer_id,IdType service_id)249 RemoteService::WeakPtr FakeLayer::FindService(PeerId peer_id,
250 IdType service_id) {
251 auto peer_iter = peers_.find(peer_id);
252 if (peer_iter == peers_.end()) {
253 return RemoteService::WeakPtr();
254 }
255 auto svc_iter = peer_iter->second.services.find(service_id);
256 if (svc_iter == peer_iter->second.services.end()) {
257 return RemoteService::WeakPtr();
258 }
259 return svc_iter->second->GetWeakPtr();
260 }
261
SetInitializeClientCallback(InitializeClientCallback cb)262 void FakeLayer::SetInitializeClientCallback(InitializeClientCallback cb) {
263 initialize_client_cb_ = std::move(cb);
264 }
265
set_list_services_status(att::Result<> status)266 void FakeLayer::set_list_services_status(att::Result<> status) {
267 list_services_status_ = status;
268 }
269
SetSetPersistServiceChangedCCCCallbackCallback(SetPersistServiceChangedCCCCallbackCallback cb)270 void FakeLayer::SetSetPersistServiceChangedCCCCallbackCallback(
271 SetPersistServiceChangedCCCCallbackCallback cb) {
272 set_persist_service_changed_ccc_cb_cb_ = std::move(cb);
273 }
274
SetSetRetrieveServiceChangedCCCCallbackCallback(SetRetrieveServiceChangedCCCCallbackCallback cb)275 void FakeLayer::SetSetRetrieveServiceChangedCCCCallbackCallback(
276 SetRetrieveServiceChangedCCCCallbackCallback cb) {
277 set_retrieve_service_changed_ccc_cb_cb_ = std::move(cb);
278 }
279
CallPersistServiceChangedCCCCallback(PeerId peer_id,bool notify,bool indicate)280 void FakeLayer::CallPersistServiceChangedCCCCallback(PeerId peer_id,
281 bool notify,
282 bool indicate) {
283 persist_service_changed_ccc_cb_(peer_id,
284 {.notify = notify, .indicate = indicate});
285 }
286
287 std::optional<ServiceChangedCCCPersistedData>
CallRetrieveServiceChangedCCCCallback(PeerId peer_id)288 FakeLayer::CallRetrieveServiceChangedCCCCallback(PeerId peer_id) {
289 return retrieve_service_changed_ccc_cb_(peer_id);
290 }
291
292 } // namespace bt::gatt::testing
293