• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_bluetooth_sapphire/internal/host/gatt/gatt.h"
16 
17 #include <lib/fit/defer.h>
18 
19 #include <unordered_map>
20 
21 #include "pw_bluetooth_sapphire/internal/host/att/bearer.h"
22 #include "pw_bluetooth_sapphire/internal/host/common/assert.h"
23 #include "pw_bluetooth_sapphire/internal/host/common/log.h"
24 #include "pw_bluetooth_sapphire/internal/host/gatt/client.h"
25 #include "pw_bluetooth_sapphire/internal/host/gatt/connection.h"
26 #include "pw_bluetooth_sapphire/internal/host/gatt/generic_attribute_service.h"
27 #include "pw_bluetooth_sapphire/internal/host/gatt/remote_service.h"
28 #include "pw_bluetooth_sapphire/internal/host/gatt/server.h"
29 #include "pw_bluetooth_sapphire/internal/host/l2cap/channel.h"
30 
31 namespace bt::gatt {
32 
GATT()33 GATT::GATT() : WeakSelf(this) {}
34 
35 namespace {
36 
37 class Impl final : public GATT {
38  public:
Impl()39   explicit Impl() {
40     local_services_ = std::make_unique<LocalServiceManager>();
41 
42     // Forwards Service Changed payloads to clients.
43     auto send_indication_callback = [this](IdType service_id,
44                                            IdType chrc_id,
45                                            PeerId peer_id,
46                                            BufferView value) {
47       auto iter = connections_.find(peer_id);
48       if (iter == connections_.end()) {
49         bt_log(WARN, "gatt", "peer not registered: %s", bt_str(peer_id));
50         return;
51       }
52       auto indication_cb = [](att::Result<> result) {
53         bt_log(TRACE,
54                "gatt",
55                "service changed indication complete: %s",
56                bt_str(result));
57       };
58       iter->second.server()->SendUpdate(
59           service_id, chrc_id, value.view(), std::move(indication_cb));
60     };
61 
62     // Spin up Generic Attribute as the first service.
63     gatt_service_ = std::make_unique<GenericAttributeService>(
64         local_services_->GetWeakPtr(), std::move(send_indication_callback));
65 
66     bt_log(DEBUG, "gatt", "initialized");
67   }
68 
~Impl()69   ~Impl() override {
70     bt_log(DEBUG, "gatt", "shutting down");
71 
72     connections_.clear();
73     gatt_service_ = nullptr;
74     local_services_ = nullptr;
75   }
76 
77   // GATT overrides:
78 
AddConnection(PeerId peer_id,std::unique_ptr<Client> client,Server::FactoryFunction server_factory)79   void AddConnection(PeerId peer_id,
80                      std::unique_ptr<Client> client,
81                      Server::FactoryFunction server_factory) override {
82     bt_log(DEBUG, "gatt", "add connection %s", bt_str(peer_id));
83 
84     auto iter = connections_.find(peer_id);
85     if (iter != connections_.end()) {
86       bt_log(WARN, "gatt", "peer is already registered: %s", bt_str(peer_id));
87       return;
88     }
89 
90     RemoteServiceWatcher service_watcher =
91         [this, peer_id](std::vector<att::Handle> removed,
92                         std::vector<RemoteService::WeakPtr> added,
93                         std::vector<RemoteService::WeakPtr> modified) {
94           OnServicesChanged(peer_id, removed, added, modified);
95         };
96     std::unique_ptr<Server> server =
97         server_factory(peer_id, local_services_->GetWeakPtr());
98     connections_.try_emplace(peer_id,
99                              std::move(client),
100                              std::move(server),
101                              std::move(service_watcher));
102 
103     if (retrieve_service_changed_ccc_callback_) {
104       auto optional_service_changed_ccc_data =
105           retrieve_service_changed_ccc_callback_(peer_id);
106       if (optional_service_changed_ccc_data && gatt_service_) {
107         gatt_service_->SetServiceChangedIndicationSubscription(
108             peer_id, optional_service_changed_ccc_data->indicate);
109       }
110     } else {
111       bt_log(WARN,
112              "gatt",
113              "Unable to retrieve service changed CCC: callback not set.");
114     }
115   }
116 
RemoveConnection(PeerId peer_id)117   void RemoveConnection(PeerId peer_id) override {
118     bt_log(DEBUG, "gatt", "remove connection: %s", bt_str(peer_id));
119     local_services_->DisconnectClient(peer_id);
120     connections_.erase(peer_id);
121   }
122 
RegisterPeerMtuListener(PeerMtuListener listener)123   PeerMtuListenerId RegisterPeerMtuListener(PeerMtuListener listener) override {
124     peer_mtu_listeners_.insert({next_mtu_listener_id_, std::move(listener)});
125     return next_mtu_listener_id_++;
126   }
127 
UnregisterPeerMtuListener(PeerMtuListenerId listener_id)128   bool UnregisterPeerMtuListener(PeerMtuListenerId listener_id) override {
129     return peer_mtu_listeners_.erase(listener_id) == 1;
130   }
131 
RegisterService(ServicePtr service,ServiceIdCallback callback,ReadHandler read_handler,WriteHandler write_handler,ClientConfigCallback ccc_callback)132   void RegisterService(ServicePtr service,
133                        ServiceIdCallback callback,
134                        ReadHandler read_handler,
135                        WriteHandler write_handler,
136                        ClientConfigCallback ccc_callback) override {
137     IdType id = local_services_->RegisterService(std::move(service),
138                                                  std::move(read_handler),
139                                                  std::move(write_handler),
140                                                  std::move(ccc_callback));
141     callback(id);
142   }
143 
UnregisterService(IdType service_id)144   void UnregisterService(IdType service_id) override {
145     local_services_->UnregisterService(service_id);
146   }
147 
SendUpdate(IdType service_id,IdType chrc_id,PeerId peer_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)148   void SendUpdate(IdType service_id,
149                   IdType chrc_id,
150                   PeerId peer_id,
151                   ::std::vector<uint8_t> value,
152                   IndicationCallback indicate_cb) override {
153     // There is nothing to do if the requested peer is not connected.
154     auto iter = connections_.find(peer_id);
155     if (iter == connections_.end()) {
156       bt_log(TRACE,
157              "gatt",
158              "cannot notify disconnected peer: %s",
159              bt_str(peer_id));
160       if (indicate_cb) {
161         indicate_cb(ToResult(HostError::kNotFound));
162       }
163       return;
164     }
165     iter->second.server()->SendUpdate(service_id,
166                                       chrc_id,
167                                       BufferView(value.data(), value.size()),
168                                       std::move(indicate_cb));
169   }
170 
UpdateConnectedPeers(IdType service_id,IdType chrc_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)171   void UpdateConnectedPeers(IdType service_id,
172                             IdType chrc_id,
173                             ::std::vector<uint8_t> value,
174                             IndicationCallback indicate_cb) override {
175     att::ResultFunction<> shared_peer_results_cb(nullptr);
176     if (indicate_cb) {
177       // This notifies indicate_cb with success when destroyed (if indicate_cb
178       // has not been invoked)
179       auto deferred_success =
180           fit::defer([outer_cb = indicate_cb.share()]() mutable {
181             if (outer_cb) {
182               outer_cb(fit::ok());
183             }
184           });
185       // This captures, but doesn't use, deferred_success. Because this is later
186       // |share|d for each peer's SendUpdate callback, deferred_success is
187       // stored in this refcounted memory. If any of the SendUpdate callbacks
188       // fail, the outer callback is notified of failure. But if all of the
189       // callbacks succeed, shared_peer_results_cb's captures will be destroyed,
190       // and deferred_success will then notify indicate_cb of success.
191       shared_peer_results_cb =
192           [deferred = std::move(deferred_success),
193            outer_cb = std::move(indicate_cb)](att::Result<> res) mutable {
194             if (outer_cb && res.is_error()) {
195               outer_cb(res);
196             }
197           };
198     }
199     for (auto& iter : connections_) {
200       // The `shared_peer_results_cb.share()` *does* propagate indication vs.
201       // notification-ness correctly - `fit::function(nullptr).share` just
202       // creates another null fit::function.
203       iter.second.server()->SendUpdate(service_id,
204                                        chrc_id,
205                                        BufferView(value.data(), value.size()),
206                                        shared_peer_results_cb.share());
207     }
208   }
209 
SetPersistServiceChangedCCCCallback(PersistServiceChangedCCCCallback callback)210   void SetPersistServiceChangedCCCCallback(
211       PersistServiceChangedCCCCallback callback) override {
212     gatt_service_->SetPersistServiceChangedCCCCallback(std::move(callback));
213   }
214 
SetRetrieveServiceChangedCCCCallback(RetrieveServiceChangedCCCCallback callback)215   void SetRetrieveServiceChangedCCCCallback(
216       RetrieveServiceChangedCCCCallback callback) override {
217     retrieve_service_changed_ccc_callback_ = std::move(callback);
218   }
219 
InitializeClient(PeerId peer_id,std::vector<UUID> services_to_discover)220   void InitializeClient(PeerId peer_id,
221                         std::vector<UUID> services_to_discover) override {
222     bt_log(TRACE, "gatt", "initialize client: %s", bt_str(peer_id));
223 
224     auto iter = connections_.find(peer_id);
225     if (iter == connections_.end()) {
226       bt_log(WARN, "gatt", "unknown peer: %s", bt_str(peer_id));
227       return;
228     }
229     auto mtu_cb = [this, peer_id](uint16_t mtu) {
230       for (auto& [_id, listener] : peer_mtu_listeners_) {
231         listener(peer_id, mtu);
232       }
233     };
234     iter->second.Initialize(std::move(services_to_discover), std::move(mtu_cb));
235   }
236 
RegisterRemoteServiceWatcherForPeer(PeerId peer_id,RemoteServiceWatcher watcher)237   RemoteServiceWatcherId RegisterRemoteServiceWatcherForPeer(
238       PeerId peer_id, RemoteServiceWatcher watcher) override {
239     BT_ASSERT(watcher);
240 
241     RemoteServiceWatcherId id = next_watcher_id_++;
242     peer_remote_service_watchers_.emplace(
243         peer_id, std::make_pair(id, std::move(watcher)));
244     return id;
245   }
246 
UnregisterRemoteServiceWatcher(RemoteServiceWatcherId watcher_id)247   bool UnregisterRemoteServiceWatcher(
248       RemoteServiceWatcherId watcher_id) override {
249     for (auto it = peer_remote_service_watchers_.begin();
250          it != peer_remote_service_watchers_.end();) {
251       if (watcher_id == it->second.first) {
252         it = peer_remote_service_watchers_.erase(it);
253         return true;
254       }
255       it++;
256     }
257     return false;
258   }
259 
ListServices(PeerId peer_id,std::vector<UUID> uuids,ServiceListCallback callback)260   void ListServices(PeerId peer_id,
261                     std::vector<UUID> uuids,
262                     ServiceListCallback callback) override {
263     BT_ASSERT(callback);
264     auto iter = connections_.find(peer_id);
265     if (iter == connections_.end()) {
266       // Connection not found.
267       callback(ToResult(HostError::kNotFound), ServiceList());
268       return;
269     }
270     iter->second.remote_service_manager()->ListServices(uuids,
271                                                         std::move(callback));
272   }
273 
FindService(PeerId peer_id,IdType service_id)274   RemoteService::WeakPtr FindService(PeerId peer_id,
275                                      IdType service_id) override {
276     auto iter = connections_.find(peer_id);
277     if (iter == connections_.end()) {
278       // Connection not found.
279       return RemoteService::WeakPtr();
280     }
281     return iter->second.remote_service_manager()->FindService(
282         static_cast<att::Handle>(service_id));
283   }
284 
285  private:
OnServicesChanged(PeerId peer_id,const std::vector<att::Handle> & removed,const std::vector<RemoteService::WeakPtr> & added,const std::vector<RemoteService::WeakPtr> & modified)286   void OnServicesChanged(PeerId peer_id,
287                          const std::vector<att::Handle>& removed,
288                          const std::vector<RemoteService::WeakPtr>& added,
289                          const std::vector<RemoteService::WeakPtr>& modified) {
290     auto peer_watcher_range =
291         peer_remote_service_watchers_.equal_range(peer_id);
292     for (auto it = peer_watcher_range.first; it != peer_watcher_range.second;
293          it++) {
294       TRACE_DURATION("bluetooth", "GATT::OnServiceChanged notify watcher");
295       it->second.second(removed, added, modified);
296     }
297   }
298 
299   // The registry containing all local GATT services. This represents a single
300   // ATT database.
301   std::unique_ptr<LocalServiceManager> local_services_;
302 
303   // Local GATT service (first in database) for clients to subscribe to service
304   // registration and removal.
305   std::unique_ptr<GenericAttributeService> gatt_service_;
306 
307   // Contains the state of all GATT profile connections and their services.
308   std::unordered_map<PeerId, internal::Connection> connections_;
309 
310   // Callback to fetch CCC for Service Changed indications from upper layers.
311   RetrieveServiceChangedCCCCallback retrieve_service_changed_ccc_callback_;
312 
313   RemoteServiceWatcherId next_watcher_id_ = 0u;
314   std::unordered_multimap<
315       PeerId,
316       std::pair<RemoteServiceWatcherId, RemoteServiceWatcher>>
317       peer_remote_service_watchers_;
318   PeerMtuListenerId next_mtu_listener_id_ = 0u;
319   std::unordered_map<PeerMtuListenerId, PeerMtuListener> peer_mtu_listeners_;
320 
321   BT_DISALLOW_COPY_AND_ASSIGN_ALLOW_MOVE(Impl);
322 };
323 }  // namespace
324 
325 // static
Create()326 std::unique_ptr<GATT> GATT::Create() { return std::make_unique<Impl>(); }
327 
328 }  // namespace bt::gatt
329