1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/gatt/gatt.h"
16
17 #include <lib/fit/defer.h>
18
19 #include <unordered_map>
20
21 #include "pw_bluetooth_sapphire/internal/host/att/bearer.h"
22 #include "pw_bluetooth_sapphire/internal/host/common/assert.h"
23 #include "pw_bluetooth_sapphire/internal/host/common/log.h"
24 #include "pw_bluetooth_sapphire/internal/host/gatt/client.h"
25 #include "pw_bluetooth_sapphire/internal/host/gatt/connection.h"
26 #include "pw_bluetooth_sapphire/internal/host/gatt/generic_attribute_service.h"
27 #include "pw_bluetooth_sapphire/internal/host/gatt/remote_service.h"
28 #include "pw_bluetooth_sapphire/internal/host/gatt/server.h"
29 #include "pw_bluetooth_sapphire/internal/host/l2cap/channel.h"
30
31 namespace bt::gatt {
32
GATT()33 GATT::GATT() : WeakSelf(this) {}
34
35 namespace {
36
37 class Impl final : public GATT {
38 public:
Impl()39 explicit Impl() {
40 local_services_ = std::make_unique<LocalServiceManager>();
41
42 // Forwards Service Changed payloads to clients.
43 auto send_indication_callback = [this](IdType service_id,
44 IdType chrc_id,
45 PeerId peer_id,
46 BufferView value) {
47 auto iter = connections_.find(peer_id);
48 if (iter == connections_.end()) {
49 bt_log(WARN, "gatt", "peer not registered: %s", bt_str(peer_id));
50 return;
51 }
52 auto indication_cb = [](att::Result<> result) {
53 bt_log(TRACE,
54 "gatt",
55 "service changed indication complete: %s",
56 bt_str(result));
57 };
58 iter->second.server()->SendUpdate(
59 service_id, chrc_id, value.view(), std::move(indication_cb));
60 };
61
62 // Spin up Generic Attribute as the first service.
63 gatt_service_ = std::make_unique<GenericAttributeService>(
64 local_services_->GetWeakPtr(), std::move(send_indication_callback));
65
66 bt_log(DEBUG, "gatt", "initialized");
67 }
68
~Impl()69 ~Impl() override {
70 bt_log(DEBUG, "gatt", "shutting down");
71
72 connections_.clear();
73 gatt_service_ = nullptr;
74 local_services_ = nullptr;
75 }
76
77 // GATT overrides:
78
AddConnection(PeerId peer_id,std::unique_ptr<Client> client,Server::FactoryFunction server_factory)79 void AddConnection(PeerId peer_id,
80 std::unique_ptr<Client> client,
81 Server::FactoryFunction server_factory) override {
82 bt_log(DEBUG, "gatt", "add connection %s", bt_str(peer_id));
83
84 auto iter = connections_.find(peer_id);
85 if (iter != connections_.end()) {
86 bt_log(WARN, "gatt", "peer is already registered: %s", bt_str(peer_id));
87 return;
88 }
89
90 RemoteServiceWatcher service_watcher =
91 [this, peer_id](std::vector<att::Handle> removed,
92 std::vector<RemoteService::WeakPtr> added,
93 std::vector<RemoteService::WeakPtr> modified) {
94 OnServicesChanged(peer_id, removed, added, modified);
95 };
96 std::unique_ptr<Server> server =
97 server_factory(peer_id, local_services_->GetWeakPtr());
98 connections_.try_emplace(peer_id,
99 std::move(client),
100 std::move(server),
101 std::move(service_watcher));
102
103 if (retrieve_service_changed_ccc_callback_) {
104 auto optional_service_changed_ccc_data =
105 retrieve_service_changed_ccc_callback_(peer_id);
106 if (optional_service_changed_ccc_data && gatt_service_) {
107 gatt_service_->SetServiceChangedIndicationSubscription(
108 peer_id, optional_service_changed_ccc_data->indicate);
109 }
110 } else {
111 bt_log(WARN,
112 "gatt",
113 "Unable to retrieve service changed CCC: callback not set.");
114 }
115 }
116
RemoveConnection(PeerId peer_id)117 void RemoveConnection(PeerId peer_id) override {
118 bt_log(DEBUG, "gatt", "remove connection: %s", bt_str(peer_id));
119 local_services_->DisconnectClient(peer_id);
120 connections_.erase(peer_id);
121 }
122
RegisterPeerMtuListener(PeerMtuListener listener)123 PeerMtuListenerId RegisterPeerMtuListener(PeerMtuListener listener) override {
124 peer_mtu_listeners_.insert({next_mtu_listener_id_, std::move(listener)});
125 return next_mtu_listener_id_++;
126 }
127
UnregisterPeerMtuListener(PeerMtuListenerId listener_id)128 bool UnregisterPeerMtuListener(PeerMtuListenerId listener_id) override {
129 return peer_mtu_listeners_.erase(listener_id) == 1;
130 }
131
RegisterService(ServicePtr service,ServiceIdCallback callback,ReadHandler read_handler,WriteHandler write_handler,ClientConfigCallback ccc_callback)132 void RegisterService(ServicePtr service,
133 ServiceIdCallback callback,
134 ReadHandler read_handler,
135 WriteHandler write_handler,
136 ClientConfigCallback ccc_callback) override {
137 IdType id = local_services_->RegisterService(std::move(service),
138 std::move(read_handler),
139 std::move(write_handler),
140 std::move(ccc_callback));
141 callback(id);
142 }
143
UnregisterService(IdType service_id)144 void UnregisterService(IdType service_id) override {
145 local_services_->UnregisterService(service_id);
146 }
147
SendUpdate(IdType service_id,IdType chrc_id,PeerId peer_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)148 void SendUpdate(IdType service_id,
149 IdType chrc_id,
150 PeerId peer_id,
151 ::std::vector<uint8_t> value,
152 IndicationCallback indicate_cb) override {
153 // There is nothing to do if the requested peer is not connected.
154 auto iter = connections_.find(peer_id);
155 if (iter == connections_.end()) {
156 bt_log(TRACE,
157 "gatt",
158 "cannot notify disconnected peer: %s",
159 bt_str(peer_id));
160 if (indicate_cb) {
161 indicate_cb(ToResult(HostError::kNotFound));
162 }
163 return;
164 }
165 iter->second.server()->SendUpdate(service_id,
166 chrc_id,
167 BufferView(value.data(), value.size()),
168 std::move(indicate_cb));
169 }
170
UpdateConnectedPeers(IdType service_id,IdType chrc_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)171 void UpdateConnectedPeers(IdType service_id,
172 IdType chrc_id,
173 ::std::vector<uint8_t> value,
174 IndicationCallback indicate_cb) override {
175 att::ResultFunction<> shared_peer_results_cb(nullptr);
176 if (indicate_cb) {
177 // This notifies indicate_cb with success when destroyed (if indicate_cb
178 // has not been invoked)
179 auto deferred_success =
180 fit::defer([outer_cb = indicate_cb.share()]() mutable {
181 if (outer_cb) {
182 outer_cb(fit::ok());
183 }
184 });
185 // This captures, but doesn't use, deferred_success. Because this is later
186 // |share|d for each peer's SendUpdate callback, deferred_success is
187 // stored in this refcounted memory. If any of the SendUpdate callbacks
188 // fail, the outer callback is notified of failure. But if all of the
189 // callbacks succeed, shared_peer_results_cb's captures will be destroyed,
190 // and deferred_success will then notify indicate_cb of success.
191 shared_peer_results_cb =
192 [deferred = std::move(deferred_success),
193 outer_cb = std::move(indicate_cb)](att::Result<> res) mutable {
194 if (outer_cb && res.is_error()) {
195 outer_cb(res);
196 }
197 };
198 }
199 for (auto& iter : connections_) {
200 // The `shared_peer_results_cb.share()` *does* propagate indication vs.
201 // notification-ness correctly - `fit::function(nullptr).share` just
202 // creates another null fit::function.
203 iter.second.server()->SendUpdate(service_id,
204 chrc_id,
205 BufferView(value.data(), value.size()),
206 shared_peer_results_cb.share());
207 }
208 }
209
SetPersistServiceChangedCCCCallback(PersistServiceChangedCCCCallback callback)210 void SetPersistServiceChangedCCCCallback(
211 PersistServiceChangedCCCCallback callback) override {
212 gatt_service_->SetPersistServiceChangedCCCCallback(std::move(callback));
213 }
214
SetRetrieveServiceChangedCCCCallback(RetrieveServiceChangedCCCCallback callback)215 void SetRetrieveServiceChangedCCCCallback(
216 RetrieveServiceChangedCCCCallback callback) override {
217 retrieve_service_changed_ccc_callback_ = std::move(callback);
218 }
219
InitializeClient(PeerId peer_id,std::vector<UUID> services_to_discover)220 void InitializeClient(PeerId peer_id,
221 std::vector<UUID> services_to_discover) override {
222 bt_log(TRACE, "gatt", "initialize client: %s", bt_str(peer_id));
223
224 auto iter = connections_.find(peer_id);
225 if (iter == connections_.end()) {
226 bt_log(WARN, "gatt", "unknown peer: %s", bt_str(peer_id));
227 return;
228 }
229 auto mtu_cb = [this, peer_id](uint16_t mtu) {
230 for (auto& [_id, listener] : peer_mtu_listeners_) {
231 listener(peer_id, mtu);
232 }
233 };
234 iter->second.Initialize(std::move(services_to_discover), std::move(mtu_cb));
235 }
236
RegisterRemoteServiceWatcherForPeer(PeerId peer_id,RemoteServiceWatcher watcher)237 RemoteServiceWatcherId RegisterRemoteServiceWatcherForPeer(
238 PeerId peer_id, RemoteServiceWatcher watcher) override {
239 BT_ASSERT(watcher);
240
241 RemoteServiceWatcherId id = next_watcher_id_++;
242 peer_remote_service_watchers_.emplace(
243 peer_id, std::make_pair(id, std::move(watcher)));
244 return id;
245 }
246
UnregisterRemoteServiceWatcher(RemoteServiceWatcherId watcher_id)247 bool UnregisterRemoteServiceWatcher(
248 RemoteServiceWatcherId watcher_id) override {
249 for (auto it = peer_remote_service_watchers_.begin();
250 it != peer_remote_service_watchers_.end();) {
251 if (watcher_id == it->second.first) {
252 it = peer_remote_service_watchers_.erase(it);
253 return true;
254 }
255 it++;
256 }
257 return false;
258 }
259
ListServices(PeerId peer_id,std::vector<UUID> uuids,ServiceListCallback callback)260 void ListServices(PeerId peer_id,
261 std::vector<UUID> uuids,
262 ServiceListCallback callback) override {
263 BT_ASSERT(callback);
264 auto iter = connections_.find(peer_id);
265 if (iter == connections_.end()) {
266 // Connection not found.
267 callback(ToResult(HostError::kNotFound), ServiceList());
268 return;
269 }
270 iter->second.remote_service_manager()->ListServices(uuids,
271 std::move(callback));
272 }
273
FindService(PeerId peer_id,IdType service_id)274 RemoteService::WeakPtr FindService(PeerId peer_id,
275 IdType service_id) override {
276 auto iter = connections_.find(peer_id);
277 if (iter == connections_.end()) {
278 // Connection not found.
279 return RemoteService::WeakPtr();
280 }
281 return iter->second.remote_service_manager()->FindService(
282 static_cast<att::Handle>(service_id));
283 }
284
285 private:
OnServicesChanged(PeerId peer_id,const std::vector<att::Handle> & removed,const std::vector<RemoteService::WeakPtr> & added,const std::vector<RemoteService::WeakPtr> & modified)286 void OnServicesChanged(PeerId peer_id,
287 const std::vector<att::Handle>& removed,
288 const std::vector<RemoteService::WeakPtr>& added,
289 const std::vector<RemoteService::WeakPtr>& modified) {
290 auto peer_watcher_range =
291 peer_remote_service_watchers_.equal_range(peer_id);
292 for (auto it = peer_watcher_range.first; it != peer_watcher_range.second;
293 it++) {
294 TRACE_DURATION("bluetooth", "GATT::OnServiceChanged notify watcher");
295 it->second.second(removed, added, modified);
296 }
297 }
298
299 // The registry containing all local GATT services. This represents a single
300 // ATT database.
301 std::unique_ptr<LocalServiceManager> local_services_;
302
303 // Local GATT service (first in database) for clients to subscribe to service
304 // registration and removal.
305 std::unique_ptr<GenericAttributeService> gatt_service_;
306
307 // Contains the state of all GATT profile connections and their services.
308 std::unordered_map<PeerId, internal::Connection> connections_;
309
310 // Callback to fetch CCC for Service Changed indications from upper layers.
311 RetrieveServiceChangedCCCCallback retrieve_service_changed_ccc_callback_;
312
313 RemoteServiceWatcherId next_watcher_id_ = 0u;
314 std::unordered_multimap<
315 PeerId,
316 std::pair<RemoteServiceWatcherId, RemoteServiceWatcher>>
317 peer_remote_service_watchers_;
318 PeerMtuListenerId next_mtu_listener_id_ = 0u;
319 std::unordered_map<PeerMtuListenerId, PeerMtuListener> peer_mtu_listeners_;
320
321 BT_DISALLOW_COPY_AND_ASSIGN_ALLOW_MOVE(Impl);
322 };
323 } // namespace
324
325 // static
Create()326 std::unique_ptr<GATT> GATT::Create() { return std::make_unique<Impl>(); }
327
328 } // namespace bt::gatt
329