• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_bluetooth_sapphire/internal/host/gatt/fake_layer.h"
16 
17 #include <pw_assert/check.h>
18 
19 #include "pw_bluetooth_sapphire/internal/host/gatt/remote_service.h"
20 
21 namespace bt::gatt::testing {
22 
TestPeer(pw::async::Dispatcher & pw_dispatcher)23 FakeLayer::TestPeer::TestPeer(pw::async::Dispatcher& pw_dispatcher)
24     : fake_client(pw_dispatcher) {}
25 
26 std::pair<RemoteService::WeakPtr, FakeClient::WeakPtr>
AddPeerService(PeerId peer_id,const ServiceData & info,bool notify)27 FakeLayer::AddPeerService(PeerId peer_id,
28                           const ServiceData& info,
29                           bool notify) {
30   auto [iter, _] = peers_.try_emplace(peer_id, pw_dispatcher_);
31   auto& peer = iter->second;
32 
33   PW_CHECK(info.range_start <= info.range_end);
34   auto service =
35       std::make_unique<RemoteService>(info, peer.fake_client.GetWeakPtr());
36   RemoteService::WeakPtr service_weak = service->GetWeakPtr();
37 
38   std::vector<att::Handle> removed;
39   ServiceList added;
40   ServiceList modified;
41 
42   auto svc_iter = peer.services.find(info.range_start);
43   if (svc_iter != peer.services.end()) {
44     if (svc_iter->second->uuid() == info.type) {
45       modified.push_back(service_weak);
46     } else {
47       removed.push_back(svc_iter->second->handle());
48       added.push_back(service_weak);
49     }
50 
51     svc_iter->second->set_service_changed(true);
52     peer.services.erase(svc_iter);
53   } else {
54     added.push_back(service_weak);
55   }
56 
57   bt_log(DEBUG,
58          "gatt",
59          "services changed (removed: %zu, added: %zu, modified: %zu)",
60          removed.size(),
61          added.size(),
62          modified.size());
63 
64   peer.services.emplace(info.range_start, std::move(service));
65 
66   if (notify && remote_service_watchers_.count(peer_id)) {
67     remote_service_watchers_[peer_id](removed, added, modified);
68   }
69 
70   return {service_weak, peer.fake_client.AsFakeWeakPtr()};
71 }
72 
RemovePeerService(PeerId peer_id,att::Handle handle)73 void FakeLayer::RemovePeerService(PeerId peer_id, att::Handle handle) {
74   auto peer_iter = peers_.find(peer_id);
75   if (peer_iter == peers_.end()) {
76     return;
77   }
78   auto svc_iter = peer_iter->second.services.find(handle);
79   if (svc_iter == peer_iter->second.services.end()) {
80     return;
81   }
82   svc_iter->second->set_service_changed(true);
83   peer_iter->second.services.erase(svc_iter);
84 
85   if (remote_service_watchers_.count(peer_id)) {
86     remote_service_watchers_[peer_id](
87         /*removed=*/{handle}, /*added=*/{}, /*modified=*/{});
88   }
89 }
90 
AddConnection(PeerId peer_id,std::unique_ptr<Client>,Server::FactoryFunction)91 void FakeLayer::AddConnection(PeerId peer_id,
92                               std::unique_ptr<Client>,
93                               Server::FactoryFunction) {
94   peers_.try_emplace(peer_id, pw_dispatcher_);
95 }
96 
RemoveConnection(PeerId peer_id)97 void FakeLayer::RemoveConnection(PeerId peer_id) { peers_.erase(peer_id); }
98 
RegisterPeerMtuListener(PeerMtuListener)99 GATT::PeerMtuListenerId FakeLayer::RegisterPeerMtuListener(PeerMtuListener) {
100   PW_CRASH("implement fake behavior if needed");
101 }
102 
UnregisterPeerMtuListener(PeerMtuListenerId)103 bool FakeLayer::UnregisterPeerMtuListener(PeerMtuListenerId) {
104   PW_CRASH("implement fake behavior if needed");
105 }
106 
RegisterService(ServicePtr service,ServiceIdCallback callback,ReadHandler read_handler,WriteHandler write_handler,ClientConfigCallback ccc_callback)107 void FakeLayer::RegisterService(ServicePtr service,
108                                 ServiceIdCallback callback,
109                                 ReadHandler read_handler,
110                                 WriteHandler write_handler,
111                                 ClientConfigCallback ccc_callback) {
112   if (register_service_fails_) {
113     callback(kInvalidId);
114     return;
115   }
116 
117   IdType id = next_local_service_id_++;
118   local_services_.try_emplace(id,
119                               LocalService{std::move(service),
120                                            std::move(read_handler),
121                                            std::move(write_handler),
122                                            std::move(ccc_callback),
123                                            {}});
124   callback(id);
125 }
126 
UnregisterService(IdType service_id)127 void FakeLayer::UnregisterService(IdType service_id) {
128   local_services_.erase(service_id);
129 }
130 
SendUpdate(IdType service_id,IdType chrc_id,PeerId peer_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)131 void FakeLayer::SendUpdate(IdType service_id,
132                            IdType chrc_id,
133                            PeerId peer_id,
134                            ::std::vector<uint8_t> value,
135                            IndicationCallback indicate_cb) {
136   auto iter = local_services_.find(service_id);
137   if (iter == local_services_.end()) {
138     indicate_cb(fit::error(att::ErrorCode::kInvalidHandle));
139     return;
140   }
141   iter->second.updates.push_back(
142       Update{chrc_id, std::move(value), std::move(indicate_cb), peer_id});
143 }
144 
UpdateConnectedPeers(IdType service_id,IdType chrc_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)145 void FakeLayer::UpdateConnectedPeers(IdType service_id,
146                                      IdType chrc_id,
147                                      ::std::vector<uint8_t> value,
148                                      IndicationCallback indicate_cb) {
149   auto iter = local_services_.find(service_id);
150   if (iter == local_services_.end()) {
151     indicate_cb(fit::error(att::ErrorCode::kInvalidHandle));
152     return;
153   }
154   iter->second.updates.push_back(
155       Update{chrc_id, std::move(value), std::move(indicate_cb), std::nullopt});
156 }
157 
SetPersistServiceChangedCCCCallback(PersistServiceChangedCCCCallback callback)158 void FakeLayer::SetPersistServiceChangedCCCCallback(
159     PersistServiceChangedCCCCallback callback) {
160   if (set_persist_service_changed_ccc_cb_cb_) {
161     set_persist_service_changed_ccc_cb_cb_();
162   }
163   persist_service_changed_ccc_cb_ = std::move(callback);
164 }
165 
SetRetrieveServiceChangedCCCCallback(RetrieveServiceChangedCCCCallback callback)166 void FakeLayer::SetRetrieveServiceChangedCCCCallback(
167     RetrieveServiceChangedCCCCallback callback) {
168   if (set_retrieve_service_changed_ccc_cb_cb_) {
169     set_retrieve_service_changed_ccc_cb_cb_();
170   }
171   retrieve_service_changed_ccc_cb_ = std::move(callback);
172 }
173 
InitializeClient(PeerId peer_id,std::vector<UUID> services_to_discover)174 void FakeLayer::InitializeClient(PeerId peer_id,
175                                  std::vector<UUID> services_to_discover) {
176   std::vector<UUID> uuids = std::move(services_to_discover);
177   if (initialize_client_cb_) {
178     initialize_client_cb_(peer_id, uuids);
179   }
180 
181   auto iter = peers_.find(peer_id);
182   if (iter == peers_.end()) {
183     return;
184   }
185 
186   std::vector<RemoteService::WeakPtr> added;
187   if (uuids.empty()) {
188     for (auto& svc_pair : iter->second.services) {
189       added.push_back(svc_pair.second->GetWeakPtr());
190     }
191   } else {
192     for (auto& svc_pair : iter->second.services) {
193       auto uuid_iter =
194           std::find_if(uuids.begin(), uuids.end(), [&svc_pair](auto uuid) {
195             return svc_pair.second->uuid() == uuid;
196           });
197       if (uuid_iter != uuids.end()) {
198         added.push_back(svc_pair.second->GetWeakPtr());
199       }
200     }
201   }
202 
203   if (remote_service_watchers_.count(peer_id)) {
204     remote_service_watchers_[peer_id](
205         /*removed=*/{}, /*added=*/added, /*modified=*/{});
206   }
207 }
208 
RegisterRemoteServiceWatcherForPeer(PeerId peer_id,RemoteServiceWatcher watcher)209 GATT::RemoteServiceWatcherId FakeLayer::RegisterRemoteServiceWatcherForPeer(
210     PeerId peer_id, RemoteServiceWatcher watcher) {
211   PW_CHECK(remote_service_watchers_.count(peer_id) == 0);
212   remote_service_watchers_[peer_id] = std::move(watcher);
213   // Use the PeerId as the watcher ID because FakeLayer only needs to support 1
214   // watcher per peer.
215   return peer_id.value();
216 }
UnregisterRemoteServiceWatcher(RemoteServiceWatcherId watcher_id)217 bool FakeLayer::UnregisterRemoteServiceWatcher(
218     RemoteServiceWatcherId watcher_id) {
219   bool result = remote_service_watchers_.count(PeerId(watcher_id));
220   remote_service_watchers_.erase(PeerId(watcher_id));
221   return result;
222 }
223 
ListServices(PeerId peer_id,std::vector<UUID> uuids,ServiceListCallback callback)224 void FakeLayer::ListServices(PeerId peer_id,
225                              std::vector<UUID> uuids,
226                              ServiceListCallback callback) {
227   if (pause_list_services_) {
228     return;
229   }
230 
231   ServiceList services;
232 
233   auto iter = peers_.find(peer_id);
234   if (iter != peers_.end()) {
235     for (auto& svc_pair : iter->second.services) {
236       auto pred = [&](const UUID& uuid) {
237         return svc_pair.second->uuid() == uuid;
238       };
239       if (uuids.empty() ||
240           std::find_if(uuids.begin(), uuids.end(), pred) != uuids.end()) {
241         services.push_back(svc_pair.second->GetWeakPtr());
242       }
243     }
244   }
245 
246   callback(list_services_status_, std::move(services));
247 }
248 
FindService(PeerId peer_id,IdType service_id)249 RemoteService::WeakPtr FakeLayer::FindService(PeerId peer_id,
250                                               IdType service_id) {
251   auto peer_iter = peers_.find(peer_id);
252   if (peer_iter == peers_.end()) {
253     return RemoteService::WeakPtr();
254   }
255   auto svc_iter = peer_iter->second.services.find(service_id);
256   if (svc_iter == peer_iter->second.services.end()) {
257     return RemoteService::WeakPtr();
258   }
259   return svc_iter->second->GetWeakPtr();
260 }
261 
SetInitializeClientCallback(InitializeClientCallback cb)262 void FakeLayer::SetInitializeClientCallback(InitializeClientCallback cb) {
263   initialize_client_cb_ = std::move(cb);
264 }
265 
set_list_services_status(att::Result<> status)266 void FakeLayer::set_list_services_status(att::Result<> status) {
267   list_services_status_ = status;
268 }
269 
SetSetPersistServiceChangedCCCCallbackCallback(SetPersistServiceChangedCCCCallbackCallback cb)270 void FakeLayer::SetSetPersistServiceChangedCCCCallbackCallback(
271     SetPersistServiceChangedCCCCallbackCallback cb) {
272   set_persist_service_changed_ccc_cb_cb_ = std::move(cb);
273 }
274 
SetSetRetrieveServiceChangedCCCCallbackCallback(SetRetrieveServiceChangedCCCCallbackCallback cb)275 void FakeLayer::SetSetRetrieveServiceChangedCCCCallbackCallback(
276     SetRetrieveServiceChangedCCCCallbackCallback cb) {
277   set_retrieve_service_changed_ccc_cb_cb_ = std::move(cb);
278 }
279 
CallPersistServiceChangedCCCCallback(PeerId peer_id,bool notify,bool indicate)280 void FakeLayer::CallPersistServiceChangedCCCCallback(PeerId peer_id,
281                                                      bool notify,
282                                                      bool indicate) {
283   persist_service_changed_ccc_cb_(peer_id,
284                                   {.notify = notify, .indicate = indicate});
285 }
286 
287 std::optional<ServiceChangedCCCPersistedData>
CallRetrieveServiceChangedCCCCallback(PeerId peer_id)288 FakeLayer::CallRetrieveServiceChangedCCCCallback(PeerId peer_id) {
289   return retrieve_service_changed_ccc_cb_(peer_id);
290 }
291 
292 }  // namespace bt::gatt::testing
293