• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_bluetooth_sapphire/internal/host/sdp/server.h"
16 
17 #include <pw_assert/check.h>
18 
19 #include <cstdint>
20 #include <cstdio>
21 
22 #include "pw_bluetooth_sapphire/internal/host/common/log.h"
23 #include "pw_bluetooth_sapphire/internal/host/common/random.h"
24 #include "pw_bluetooth_sapphire/internal/host/l2cap/l2cap_defs.h"
25 #include "pw_bluetooth_sapphire/internal/host/l2cap/types.h"
26 #include "pw_bluetooth_sapphire/internal/host/sdp/data_element.h"
27 #include "pw_bluetooth_sapphire/internal/host/sdp/pdu.h"
28 #include "pw_bluetooth_sapphire/internal/host/sdp/sdp.h"
29 
30 namespace bt::sdp {
31 
// Shorthand for Server::RegistrationHandle, the value returned by
// Server::RegisterService() to identify a group of registered records.
using RegistrationHandle = Server::RegistrationHandle;
33 
34 namespace {
35 
// Names of nodes/properties published in the SDP server's inspect hierarchy.
// NOTE(review): presumably consumed by UpdateInspectProperties(); verify.
constexpr const char* kInspectRegisteredPsmName{"registered_psms"};
constexpr const char* kInspectPsmName{"psm"};
constexpr const char* kInspectRecordName{"record"};
39 
IsQueuedPsm(const std::vector<std::pair<l2cap::Psm,ServiceHandle>> * queued_psms,l2cap::Psm psm)40 bool IsQueuedPsm(
41     const std::vector<std::pair<l2cap::Psm, ServiceHandle>>* queued_psms,
42     l2cap::Psm psm) {
43   auto is_queued = [target = psm](const auto& psm_pair) {
44     return target == psm_pair.first;
45   };
46   auto iter = std::find_if(queued_psms->begin(), queued_psms->end(), is_queued);
47   return iter != queued_psms->end();
48 }
49 
50 // Returns true if the |psm| is considered valid.
IsValidPsm(l2cap::Psm psm)51 bool IsValidPsm(l2cap::Psm psm) {
52   // The least significant bit of the most significant octet must be 0
53   // (Core 5.4, Vol 3, Part A, 4.2).
54   constexpr uint16_t MS_OCTET_MASK = 0x0100;
55   if (psm & MS_OCTET_MASK) {
56     return false;
57   }
58 
59   // The least significant bit of all other octets must be 1
60   // (Core 5.4, Vol 3, Part A, 4.2).
61   constexpr uint16_t LOWER_OCTET_MASK = 0x0001;
62   if ((psm & LOWER_OCTET_MASK) != LOWER_OCTET_MASK) {
63     return false;
64   }
65   return true;
66 }
67 
68 // Updates the L2CAP |protocol| with the provided dynamic |new_psm|.
69 // Returns true if the list was updated, false if |protocol| is invalid.
UpdateProtocolWithL2capPsm(DataElement * protocol,l2cap::Psm new_psm)70 bool UpdateProtocolWithL2capPsm(DataElement* protocol, l2cap::Psm new_psm) {
71   bt_log(TRACE,
72          "sdp",
73          "Updating protocol with dynamic PSM: %s",
74          protocol->ToString().c_str());
75 
76   // A valid protocol is a sequence containing a UUID and PSM value (2
77   // elements).
78   auto l2cap_protocol = protocol->Get<std::vector<DataElement>>();
79   if (!l2cap_protocol || (*l2cap_protocol).size() != 2) {
80     return false;
81   }
82 
83   // The protocol should specify the L2CAP UUID.
84   const auto prot_uuid = (*l2cap_protocol).data();
85   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid ||
86       *prot_uuid->Get<UUID>() != protocol::kL2CAP) {
87     return false;
88   }
89 
90   // The second element should be the dynamic PSM. If found, update it.
91   auto dynamic_psm_elem = &(*l2cap_protocol)[1];
92   if (!dynamic_psm_elem->Get<uint16_t>() ||
93       dynamic_psm_elem->Get<uint16_t>() != Server::kDynamicPsm) {
94     bt_log(WARN, "sdp", "Request to update non-dynamic L2CAP PSM. Ignoring");
95     return false;
96   }
97   (*l2cap_protocol)[1] = DataElement(uint16_t{new_psm});
98   protocol->Set(std::move(*l2cap_protocol));
99 
100   bt_log(TRACE,
101          "sdp",
102          "Updated protocol list with dynamic PSM %s",
103          protocol->ToString().c_str());
104   return true;
105 }
106 
107 // Updates the L2CAP |protocol_list| with the dynamic |new_psm|.
108 // |protocol_list| must be a list of protocols- one of which must be L2CAP.
109 // Returns true if the list was updated with the |new_psm|, false otherwise.
UpdateProtocolListWithL2capPsm(DataElement & protocol_list,l2cap::Psm new_psm)110 bool UpdateProtocolListWithL2capPsm(DataElement& protocol_list,
111                                     l2cap::Psm new_psm) {
112   bt_log(TRACE,
113          "sdp",
114          "Updating protocol list with dynamic psm: %s",
115          protocol_list.ToString().c_str());
116 
117   auto protocol_seq = protocol_list.Get<std::vector<DataElement>>();
118   if (!protocol_seq) {
119     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not a valid sequence");
120     return false;
121   }
122 
123   bool updated = false;
124   for (DataElement& protocol : (*protocol_seq)) {
125     if (UpdateProtocolWithL2capPsm(&protocol, new_psm)) {
126       updated = true;
127       break;
128     }
129   }
130 
131   protocol_list.Set(std::move(*protocol_seq));
132   return updated;
133 }
134 
135 // Finds the PSM that is specified in a ProtocolDescriptorList
136 // Returns l2cap::kInvalidPsm if none is found or the list is invalid
FindProtocolListPsm(const DataElement & protocol_list)137 l2cap::Psm FindProtocolListPsm(const DataElement& protocol_list) {
138   bt_log(TRACE,
139          "sdp",
140          "Trying to find PSM from %s",
141          protocol_list.ToString().c_str());
142   const auto* l2cap_protocol = protocol_list.At(0);
143   PW_DCHECK(l2cap_protocol);
144   const auto* prot_uuid = l2cap_protocol->At(0);
145   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid ||
146       *prot_uuid->Get<UUID>() != protocol::kL2CAP) {
147     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not valid or not L2CAP");
148     return l2cap::kInvalidPsm;
149   }
150 
151   const auto* psm_elem = l2cap_protocol->At(1);
152   if (psm_elem && psm_elem->Get<uint16_t>()) {
153     return *psm_elem->Get<uint16_t>();
154   }
155   if (psm_elem) {
156     bt_log(TRACE, "sdp", "ProtocolDescriptorList invalid L2CAP parameter type");
157     return l2cap::kInvalidPsm;
158   }
159 
160   // The PSM is missing, determined by the next protocol.
161   const auto* next_protocol = protocol_list.At(1);
162   if (!next_protocol) {
163     bt_log(TRACE, "sdp", "L2CAP has no PSM and no additional protocol");
164     return l2cap::kInvalidPsm;
165   }
166   const auto* next_protocol_uuid = next_protocol->At(0);
167   if (!next_protocol_uuid ||
168       next_protocol_uuid->type() != DataElement::Type::kUuid) {
169     bt_log(TRACE, "sdp", "L2CAP has no PSM and additional protocol invalid");
170     return l2cap::kInvalidPsm;
171   }
172   UUID protocol_uuid = *next_protocol_uuid->Get<UUID>();
173   // When it's RFCOMM, the L2CAP protocol descriptor omits the PSM parameter
174   // See example in the SPP Spec, v1.2
175   if (protocol_uuid == protocol::kRFCOMM) {
176     return l2cap::kRFCOMM;
177   }
178   bt_log(TRACE, "sdp", "Can't determine L2CAP PSM from protocol");
179   return l2cap::kInvalidPsm;
180 }
181 
PsmFromProtocolList(const DataElement * protocol_list)182 l2cap::Psm PsmFromProtocolList(const DataElement* protocol_list) {
183   const auto* primary_protocol = protocol_list->At(0);
184   if (!primary_protocol) {
185     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not a sequence");
186     return l2cap::kInvalidPsm;
187   }
188 
189   const auto* prot_uuid = primary_protocol->At(0);
190   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid) {
191     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not valid");
192     return l2cap::kInvalidPsm;
193   }
194 
195   // We do nothing for primary protocols that are not L2CAP
196   if (*prot_uuid->Get<UUID>() != protocol::kL2CAP) {
197     return l2cap::kInvalidPsm;
198   }
199 
200   l2cap::Psm psm = FindProtocolListPsm(*protocol_list);
201   if (psm == l2cap::kInvalidPsm) {
202     bt_log(TRACE, "sdp", "Couldn't find PSM from ProtocolDescriptorList");
203     return l2cap::kInvalidPsm;
204   }
205 
206   return psm;
207 }
208 
209 // Sets the browse group list of the record to be the top-level group.
SetBrowseGroupList(ServiceRecord * record)210 void SetBrowseGroupList(ServiceRecord* record) {
211   std::vector<DataElement> browse_list;
212   browse_list.emplace_back(kPublicBrowseRootUuid);
213   record->SetAttribute(kBrowseGroupList, DataElement(std::move(browse_list)));
214 }
215 
216 }  // namespace
217 
// Value advertised in the SDP server record's VersionNumberList attribute
// (Core v5.0, Vol 3, Part B, 5.2.3); 0x0100 encodes version 1.0.
constexpr uint16_t kVersion{0x0100};

// Initial value of the ServiceDatabaseState attribute in the SDP server's own
// record.
constexpr uint32_t kInitialDbState{0};
223 
224 // Populates the ServiceDiscoveryService record.
MakeServiceDiscoveryService()225 ServiceRecord Server::MakeServiceDiscoveryService() {
226   ServiceRecord sdp;
227   sdp.SetHandle(kSDPHandle);
228 
229   // ServiceClassIDList attribute should have the
230   // ServiceDiscoveryServerServiceClassID
231   // See v5.0, Vol 3, Part B, Sec 5.2.2
232   sdp.SetServiceClassUUIDs({profile::kServiceDiscoveryClass});
233 
234   // The VersionNumberList attribute. See v5.0, Vol 3, Part B, Sec 5.2.3
235   // Version 1.0
236   std::vector<DataElement> version_attribute;
237   version_attribute.emplace_back(kVersion);
238   sdp.SetAttribute(kSDP_VersionNumberList,
239                    DataElement(std::move(version_attribute)));
240 
241   // ServiceDatabaseState attribute. Changes when a service gets added or
242   // removed.
243   sdp.SetAttribute(kSDP_ServiceDatabaseState, DataElement(kInitialDbState));
244 
245   return sdp;
246 }
247 
Server(l2cap::ChannelManager * l2cap)248 Server::Server(l2cap::ChannelManager* l2cap)
249     : l2cap_(l2cap),
250       next_handle_(kFirstUnreservedHandle),
251       db_state_(0),
252       weak_ptr_factory_(this) {
253   PW_CHECK(l2cap_);
254 
255   records_.emplace(kSDPHandle, Server::MakeServiceDiscoveryService());
256 
257   // Register SDP
258   l2cap::ChannelParameters sdp_chan_params;
259   sdp_chan_params.mode = l2cap::RetransmissionAndFlowControlMode::kBasic;
260   l2cap_->RegisterService(
261       l2cap::kSDP,
262       sdp_chan_params,
263       [self = weak_ptr_factory_.GetWeakPtr()](auto channel) {
264         if (self.is_alive())
265           self->AddConnection(channel);
266       });
267 
268   // SDP is used by SDP server.
269   psm_to_service_.emplace(l2cap::kSDP,
270                           std::unordered_set<ServiceHandle>({kSDPHandle}));
271   service_to_psms_.emplace(kSDPHandle,
272                            std::unordered_set<l2cap::Psm>({l2cap::kSDP}));
273 
274   // Update the inspect properties after Server initialization.
275   UpdateInspectProperties();
276 }
277 
~Server()278 Server::~Server() { l2cap_->UnregisterService(l2cap::kSDP); }
279 
AttachInspect(inspect::Node & parent,std::string name)280 void Server::AttachInspect(inspect::Node& parent, std::string name) {
281   inspect_properties_.sdp_server_node = parent.CreateChild(name);
282   UpdateInspectProperties();
283 }
284 
AddConnection(l2cap::Channel::WeakPtr channel)285 bool Server::AddConnection(l2cap::Channel::WeakPtr channel) {
286   PW_CHECK(channel.is_alive());
287   hci_spec::ConnectionHandle handle = channel->link_handle();
288   bt_log(DEBUG, "sdp", "add connection handle %#.4x", handle);
289 
290   l2cap::Channel::UniqueId chan_id = channel->unique_id();
291   auto iter = channels_.find(chan_id);
292   if (iter != channels_.end()) {
293     bt_log(WARN, "sdp", "l2cap channel to %#.4x already connected", handle);
294     return false;
295   }
296 
297   auto self = weak_ptr_factory_.GetWeakPtr();
298   bool activated = channel->Activate(
299       [self, chan_id, max_tx_sdu_size = channel->max_tx_sdu_size()](
300           ByteBufferPtr sdu) {
301         if (self.is_alive()) {
302           auto packet = self->HandleRequest(std::move(sdu), max_tx_sdu_size);
303           if (packet) {
304             self->Send(chan_id, std::move(packet.value()));
305           }
306         }
307       },
308       [self, chan_id] {
309         if (self.is_alive()) {
310           self->OnChannelClosed(chan_id);
311         }
312       });
313   if (!activated) {
314     bt_log(WARN, "sdp", "failed to activate channel (handle %#.4x)", handle);
315     return false;
316   }
317   self->channels_.emplace(chan_id, std::move(channel));
318   return true;
319 }
320 
AddPsmToProtocol(ProtocolQueue * protocols_to_register,l2cap::Psm psm,ServiceHandle handle) const321 bool Server::AddPsmToProtocol(ProtocolQueue* protocols_to_register,
322                               l2cap::Psm psm,
323                               ServiceHandle handle) const {
324   if (psm == l2cap::kInvalidPsm) {
325     return false;
326   }
327 
328   if (IsAllocated(psm)) {
329     bt_log(TRACE, "sdp", "L2CAP PSM %#.4x is already allocated", psm);
330     return false;
331   }
332 
333   auto data = std::make_pair(psm, handle);
334   protocols_to_register->emplace_back(std::move(data));
335   return true;
336 }
337 
GetDynamicPsm(const ProtocolQueue * queued_psms) const338 l2cap::Psm Server::GetDynamicPsm(const ProtocolQueue* queued_psms) const {
339   // Generate a random PSM in the valid range of PSMs.
340   // RNG(Range(MIN, MAX)) = MIN + RNG(MAX-MIN) where MIN = kMinDynamicPSM =
341   // 0x1001. MAX = 0xffff.
342   uint16_t offset = 0;
343   constexpr uint16_t MAX_MINUS_MIN = 0xeffe;
344   random_generator()->GetInt(offset, MAX_MINUS_MIN);
345   uint16_t psm = l2cap::kMinDynamicPsm + offset;
346   // LSB of upper octet must be 0. LSB of lower octet must be 1.
347   constexpr uint16_t UPPER_OCTET_MASK = 0xFEFF;
348   constexpr uint16_t LOWER_OCTET_MASK = 0x0001;
349   psm &= UPPER_OCTET_MASK;
350   psm |= LOWER_OCTET_MASK;
351   bt_log(DEBUG, "sdp", "Trying random dynamic PSM %#.4x", psm);
352 
353   // Check if the PSM is valid (e.g. valid construction, not allocated, & not
354   // queued).
355   if ((IsValidPsm(psm)) && (!IsAllocated(psm)) &&
356       (!IsQueuedPsm(queued_psms, psm))) {
357     bt_log(TRACE, "sdp", "Generated random dynamic PSM %#.4x", psm);
358     return psm;
359   }
360 
361   // Otherwise, fall back to sequentially finding the next available PSM.
362   bool search_wrapped = false;
363   for (uint16_t next_psm = psm + 2; next_psm <= UINT16_MAX; next_psm += 2) {
364     if ((IsValidPsm(next_psm)) && (!IsAllocated(next_psm)) &&
365         (!IsQueuedPsm(queued_psms, next_psm))) {
366       bt_log(TRACE, "sdp", "Generated sequential dynamic PSM %#.4x", next_psm);
367       return next_psm;
368     }
369 
370     // If we reach the max valid PSM, wrap around to the minimum valid dynamic
371     // PSM. Only try this once.
372     if (next_psm == 0xFEFF) {
373       next_psm = l2cap::kMinDynamicPsm;
374       if (search_wrapped) {
375         break;
376       }
377       search_wrapped = true;
378     }
379   }
380   bt_log(WARN, "sdp", "Couldn't find an available dynamic PSM");
381   return l2cap::kInvalidPsm;
382 }
383 
QueueService(ServiceRecord * record,ProtocolQueue * protocols_to_register)384 bool Server::QueueService(ServiceRecord* record,
385                           ProtocolQueue* protocols_to_register) {
386   // ProtocolDescriptorList handling:
387   if (record->HasAttribute(kProtocolDescriptorList)) {
388     const auto& primary_protocol =
389         record->GetAttribute(kProtocolDescriptorList);
390     auto psm = PsmFromProtocolList(&primary_protocol);
391     if (psm == kDynamicPsm) {
392       bt_log(TRACE, "sdp", "Primary protocol contains dynamic PSM");
393       auto primary_protocol_copy = primary_protocol.Clone();
394       psm = GetDynamicPsm(protocols_to_register);
395       if (!UpdateProtocolListWithL2capPsm(primary_protocol_copy, psm)) {
396         return false;
397       }
398       record->SetAttribute(kProtocolDescriptorList,
399                            std::move(primary_protocol_copy));
400     }
401     if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
402       return false;
403     }
404   }
405 
406   // AdditionalProtocolDescriptorList handling:
407   if (record->HasAttribute(kAdditionalProtocolDescriptorList)) {
408     // |additional_list| is a list of ProtocolDescriptorLists.
409     const auto& additional_list =
410         record->GetAttribute(kAdditionalProtocolDescriptorList);
411     size_t attribute_id = 0;
412     const auto* additional = additional_list.At(attribute_id);
413 
414     // If `kAdditionalProtocolDescriptorList` exists, there should be at least
415     // one protocol provided.
416     if (!additional) {
417       bt_log(
418           TRACE, "sdp", "AdditionalProtocolDescriptorList provided but empty");
419       return false;
420     }
421 
422     // Add valid additional PSMs to the register queue. Because some additional
423     // protocols may need dynamic PSM assignment, modify the relevant protocols
424     // and rebuild the list.
425     std::vector<DataElement> additional_protocols;
426     while (additional) {
427       auto psm = PsmFromProtocolList(additional);
428       auto additional_protocol_copy = additional->Clone();
429       if (psm == kDynamicPsm) {
430         bt_log(TRACE, "sdp", "Additional protocol contains dynamic PSM");
431         psm = GetDynamicPsm(protocols_to_register);
432         if (!UpdateProtocolListWithL2capPsm(additional_protocol_copy, psm)) {
433           return l2cap::kInvalidPsm;
434         }
435       }
436       if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
437         return false;
438       }
439 
440       attribute_id++;
441       additional_protocols.emplace_back(std::move(additional_protocol_copy));
442       additional = additional_list.At(attribute_id);
443     }
444     record->SetAttribute(kAdditionalProtocolDescriptorList,
445                          DataElement(std::move(additional_protocols)));
446   }
447 
448   // For some services that depend on OBEX, the L2CAP PSM is specified in the
449   // GoepL2capPsm attribute.
450   bool has_obex = record->FindUUID(std::unordered_set<UUID>({protocol::kOBEX}));
451   if (has_obex && record->HasAttribute(kGoepL2capPsm)) {
452     const auto& attribute = record->GetAttribute(kGoepL2capPsm);
453     if (attribute.Get<uint16_t>()) {
454       auto psm = *attribute.Get<uint16_t>();
455       // If a dynamic PSM was requested, attempt to allocate the next available
456       // PSM.
457       if (psm == kDynamicPsm) {
458         bt_log(TRACE, "sdp", "GoepL2capAttribute contains dynamic PSM");
459         psm = GetDynamicPsm(protocols_to_register);
460         record->SetAttribute(kGoepL2capPsm, DataElement(uint16_t{psm}));
461       }
462       if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
463         return false;
464       }
465     }
466   }
467 
468   return true;
469 }
470 
RegisterService(std::vector<ServiceRecord> records,l2cap::ChannelParameters chan_params,ConnectCallback conn_cb)471 RegistrationHandle Server::RegisterService(std::vector<ServiceRecord> records,
472                                            l2cap::ChannelParameters chan_params,
473                                            ConnectCallback conn_cb) {
474   if (records.empty()) {
475     return 0;
476   }
477 
478   // The PSMs and their ServiceHandles to register.
479   ProtocolQueue protocols_to_register;
480 
481   // The ServiceHandles that are assigned to each ServiceRecord.
482   // There should be one ServiceHandle per ServiceRecord in |records|.
483   std::set<ServiceHandle> assigned_handles;
484 
485   for (auto& record : records) {
486     ServiceHandle next = GetNextHandle();
487     if (!next) {
488       return 0;
489     }
490     // Assign a new handle for the service record.
491     record.SetHandle(next);
492 
493     if (!record.IsProtocolOnly()) {
494       // Place record in a browse group.
495       SetBrowseGroupList(&record);
496 
497       // Validate the |ServiceRecord|.
498       if (!record.IsRegisterable()) {
499         return 0;
500       }
501     }
502 
503     // Attempt to queue the |record| for registration.
504     // Note: Since the validation & queueing operations for ALL the records
505     // occur before registration, multiple ServiceRecords can share the same
506     // PSM.
507     //
508     // If any |record| is not parsable, exit the registration process early.
509     if (!QueueService(&record, &protocols_to_register)) {
510       return 0;
511     }
512 
513     // For every ServiceRecord, there will be one ServiceHandle assigned.
514     assigned_handles.emplace(next);
515   }
516 
517   PW_CHECK(assigned_handles.size() == records.size());
518 
519   // The RegistrationHandle is the smallest ServiceHandle that was assigned.
520   RegistrationHandle reg_handle = *assigned_handles.begin();
521 
522   // Multiple ServiceRecords in |records| can request the same PSM. However,
523   // |l2cap_| expects a single target for each PSM to go to. Consequently,
524   // only the first occurrence of a PSM needs to be registered with the
525   // |l2cap_|.
526   std::unordered_set<l2cap::Psm> psms_to_register;
527 
528   // All PSMs have assigned handles and will be registered.
529   for (auto& [psm, handle] : protocols_to_register) {
530     psm_to_service_[psm].insert(handle);
531     service_to_psms_[handle].insert(psm);
532 
533     // Add unique PSMs to the data domain registration queue.
534     psms_to_register.insert(psm);
535   }
536 
537   for (const auto& psm : psms_to_register) {
538     bt_log(TRACE, "sdp", "Allocating PSM %#.4x for new service", psm);
539     l2cap_->RegisterService(
540         psm,
541         chan_params,
542         [l2cap_psm = psm, conn_cb_shared = conn_cb.share()](
543             l2cap::Channel::WeakPtr channel) mutable {
544           bt_log(TRACE, "sdp", "Channel connected to %#.4x", l2cap_psm);
545           // Build the L2CAP descriptor
546           std::vector<DataElement> protocol_l2cap;
547           protocol_l2cap.emplace_back(protocol::kL2CAP);
548           protocol_l2cap.emplace_back(l2cap_psm);
549           std::vector<DataElement> protocol;
550           protocol.emplace_back(std::move(protocol_l2cap));
551           conn_cb_shared(std::move(channel), DataElement(std::move(protocol)));
552         });
553   }
554 
555   // Store the complete records.
556   for (auto& record : records) {
557     auto [it, success] = records_.emplace(record.handle(), std::move(record));
558     PW_DCHECK(success);
559     const ServiceRecord& placed_record = it->second;
560     if (placed_record.IsProtocolOnly()) {
561       bt_log(TRACE,
562              "sdp",
563              "registered protocol-only service %#.8x, Protocol: %s",
564              placed_record.handle(),
565              bt_str(placed_record.GetAttribute(kProtocolDescriptorList)));
566     } else {
567       bt_log(TRACE,
568              "sdp",
569              "registered service %#.8x, classes: %s",
570              placed_record.handle(),
571              bt_str(placed_record.GetAttribute(kServiceClassIdList)));
572     }
573   }
574 
575   // Store the RegistrationHandle that represents the set of services that were
576   // registered.
577   reg_to_service_[reg_handle] = std::move(assigned_handles);
578 
579   // Update the inspect properties.
580   UpdateInspectProperties();
581 
582   return reg_handle;
583 }
584 
UnregisterService(RegistrationHandle handle)585 bool Server::UnregisterService(RegistrationHandle handle) {
586   if (handle == kNotRegistered) {
587     return false;
588   }
589 
590   auto handles_it = reg_to_service_.extract(handle);
591   if (!handles_it) {
592     return false;
593   }
594 
595   for (const auto& svc_h : handles_it.mapped()) {
596     PW_CHECK(svc_h != kSDPHandle);
597     PW_CHECK(records_.find(svc_h) != records_.end());
598     bt_log(DEBUG, "sdp", "unregistering service (handle: %#.8x)", svc_h);
599 
600     // Unregister any service callbacks from L2CAP
601     auto psms_it = service_to_psms_.extract(svc_h);
602     if (psms_it) {
603       for (const auto& psm : psms_it.mapped()) {
604         bt_log(DEBUG, "sdp", "removing registration for psm %#.4x", psm);
605         l2cap_->UnregisterService(psm);
606         psm_to_service_.erase(psm);
607       }
608     }
609 
610     records_.erase(svc_h);
611   }
612 
613   // Update the inspect properties as the registered PSMs may have changed.
614   UpdateInspectProperties();
615 
616   return true;
617 }
618 
GetRegisteredServices(RegistrationHandle handle) const619 std::vector<ServiceRecord> Server::GetRegisteredServices(
620     RegistrationHandle handle) const {
621   std::vector<ServiceRecord> out;
622   if (handle == kNotRegistered) {
623     return out;
624   }
625 
626   auto service_handles_it = reg_to_service_.find(handle);
627   if (service_handles_it == reg_to_service_.end()) {
628     return out;
629   }
630 
631   for (const auto& service_handle : service_handles_it->second) {
632     auto record_it = records_.find(service_handle);
633     if (record_it != records_.end()) {
634       ServiceRecord record_copy = record_it->second;
635       out.emplace_back(std::move(record_copy));
636     }
637   }
638 
639   return out;
640 }
641 
GetNextHandle()642 ServiceHandle Server::GetNextHandle() {
643   ServiceHandle initial_next_handle = next_handle_;
644   // We expect most of these to be free.
645   // Safeguard against possibly having to wrap-around and reuse handles.
646   while (records_.count(next_handle_)) {
647     if (next_handle_ == kLastHandle) {
648       bt_log(WARN, "sdp", "service handle wrapped to start");
649       next_handle_ = kFirstUnreservedHandle;
650     } else {
651       next_handle_++;
652     }
653     if (next_handle_ == initial_next_handle) {
654       return 0;
655     }
656   }
657   return next_handle_++;
658 }
659 
SearchServices(const std::unordered_set<UUID> & pattern) const660 ServiceSearchResponse Server::SearchServices(
661     const std::unordered_set<UUID>& pattern) const {
662   ServiceSearchResponse resp;
663   std::vector<ServiceHandle> matched;
664   for (const auto& it : records_) {
665     if (it.second.FindUUID(pattern) && !it.second.IsProtocolOnly()) {
666       matched.push_back(it.first);
667     }
668   }
669   bt_log(TRACE, "sdp", "ServiceSearch matched %zu records", matched.size());
670   resp.set_service_record_handle_list(matched);
671   return resp;
672 }
673 
GetServiceAttributes(ServiceHandle handle,const std::list<AttributeRange> & ranges) const674 ServiceAttributeResponse Server::GetServiceAttributes(
675     ServiceHandle handle, const std::list<AttributeRange>& ranges) const {
676   ServiceAttributeResponse resp;
677   const auto& record = records_.at(handle);
678   for (const auto& range : ranges) {
679     auto attrs = record.GetAttributesInRange(range.start, range.end);
680     for (const auto& attr : attrs) {
681       resp.set_attribute(attr, record.GetAttribute(attr).Clone());
682     }
683   }
684   bt_log(TRACE,
685          "sdp",
686          "ServiceAttribute %zu attributes",
687          resp.attributes().size());
688   return resp;
689 }
690 
SearchAllServiceAttributes(const std::unordered_set<UUID> & search_pattern,const std::list<AttributeRange> & attribute_ranges) const691 ServiceSearchAttributeResponse Server::SearchAllServiceAttributes(
692     const std::unordered_set<UUID>& search_pattern,
693     const std::list<AttributeRange>& attribute_ranges) const {
694   ServiceSearchAttributeResponse resp;
695   for (const auto& it : records_) {
696     const auto& rec = it.second;
697     if (rec.IsProtocolOnly()) {
698       continue;
699     }
700     if (rec.FindUUID(search_pattern)) {
701       for (const auto& range : attribute_ranges) {
702         auto attrs = rec.GetAttributesInRange(range.start, range.end);
703         for (const auto& attr : attrs) {
704           resp.SetAttribute(it.first, attr, rec.GetAttribute(attr).Clone());
705         }
706       }
707     }
708   }
709 
710   bt_log(TRACE,
711          "sdp",
712          "ServiceSearchAttribute %zu records",
713          resp.num_attribute_lists());
714   return resp;
715 }
716 
OnChannelClosed(l2cap::Channel::UniqueId channel_id)717 void Server::OnChannelClosed(l2cap::Channel::UniqueId channel_id) {
718   channels_.erase(channel_id);
719 }
720 
HandleRequest(ByteBufferPtr sdu,uint16_t max_tx_sdu_size)721 std::optional<ByteBufferPtr> Server::HandleRequest(ByteBufferPtr sdu,
722                                                    uint16_t max_tx_sdu_size) {
723   PW_DCHECK(sdu);
724   TRACE_DURATION("bluetooth", "sdp::Server::HandleRequest");
725   if (sdu->size() < sizeof(Header)) {
726     bt_log(DEBUG, "sdp", "PDU too short; dropping");
727     return std::nullopt;
728   }
729   PacketView<Header> packet(sdu.get());
730   TransactionId tid =
731       pw::bytes::ConvertOrderFrom(cpp20::endian::big, packet.header().tid);
732   uint16_t param_length = pw::bytes::ConvertOrderFrom(
733       cpp20::endian::big, packet.header().param_length);
734   auto error_response_builder =
735       [tid, max_tx_sdu_size](ErrorCode code) -> ByteBufferPtr {
736     return ErrorResponse(code).GetPDU(
737         0 /* ignored */, tid, max_tx_sdu_size, BufferView());
738   };
739   if (param_length != (sdu->size() - sizeof(Header))) {
740     bt_log(TRACE,
741            "sdp",
742            "request isn't the correct size (%hu != %zu)",
743            param_length,
744            sdu->size() - sizeof(Header));
745     return error_response_builder(ErrorCode::kInvalidSize);
746   }
747   packet.Resize(param_length);
748   switch (packet.header().pdu_id) {
749     case kServiceSearchRequest: {
750       ServiceSearchRequest request(packet.payload_data());
751       if (!request.valid()) {
752         bt_log(DEBUG, "sdp", "ServiceSearchRequest not valid");
753         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
754       }
755       auto resp = SearchServices(request.service_search_pattern());
756 
757       auto bytes = resp.GetPDU(request.max_service_record_count(),
758                                tid,
759                                max_tx_sdu_size,
760                                request.ContinuationState());
761       if (!bytes) {
762         return error_response_builder(ErrorCode::kInvalidContinuationState);
763       }
764       return std::move(bytes);
765     }
766     case kServiceAttributeRequest: {
767       ServiceAttributeRequest request(packet.payload_data());
768       if (!request.valid()) {
769         bt_log(TRACE, "sdp", "ServiceAttributeRequest not valid");
770         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
771       }
772       auto handle = request.service_record_handle();
773       auto record_it = records_.find(handle);
774       if (record_it == records_.end() || record_it->second.IsProtocolOnly()) {
775         bt_log(TRACE,
776                "sdp",
777                "ServiceAttributeRequest can't find handle %#.8x",
778                handle);
779         return error_response_builder(ErrorCode::kInvalidRecordHandle);
780       }
781       auto resp = GetServiceAttributes(handle, request.attribute_ranges());
782       auto bytes = resp.GetPDU(request.max_attribute_byte_count(),
783                                tid,
784                                max_tx_sdu_size,
785                                request.ContinuationState());
786       if (!bytes) {
787         return error_response_builder(ErrorCode::kInvalidContinuationState);
788       }
789       return std::move(bytes);
790     }
791     case kServiceSearchAttributeRequest: {
792       ServiceSearchAttributeRequest request(packet.payload_data());
793       if (!request.valid()) {
794         bt_log(TRACE, "sdp", "ServiceSearchAttributeRequest not valid");
795         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
796       }
797       auto resp = SearchAllServiceAttributes(request.service_search_pattern(),
798                                              request.attribute_ranges());
799       auto bytes = resp.GetPDU(request.max_attribute_byte_count(),
800                                tid,
801                                max_tx_sdu_size,
802                                request.ContinuationState());
803       if (!bytes) {
804         return error_response_builder(ErrorCode::kInvalidContinuationState);
805       }
806       return std::move(bytes);
807     }
808     case kErrorResponse: {
809       bt_log(TRACE, "sdp", "ErrorResponse isn't allowed as a request");
810       return error_response_builder(ErrorCode::kInvalidRequestSyntax);
811     }
812     default: {
813       bt_log(TRACE, "sdp", "unhandled request, returning InvalidRequest");
814       return error_response_builder(ErrorCode::kInvalidRequestSyntax);
815     }
816   }
817 }
818 
Send(l2cap::Channel::UniqueId channel_id,ByteBufferPtr bytes)819 void Server::Send(l2cap::Channel::UniqueId channel_id, ByteBufferPtr bytes) {
820   auto it = channels_.find(channel_id);
821   if (it == channels_.end()) {
822     bt_log(ERROR, "sdp", "can't find peer to respond to; dropping");
823     return;
824   }
825   l2cap::Channel::WeakPtr chan = it->second.get();
826   chan->Send(std::move(bytes));
827 }
828 
UpdateInspectProperties()829 void Server::UpdateInspectProperties() {
830   // Skip update if node has not been attached.
831   if (!inspect_properties_.sdp_server_node) {
832     return;
833   }
834 
835   // Clear the previous inspect data.
836   inspect_properties_.svc_record_properties.clear();
837 
838   for (const auto& svc_record : records_) {
839     auto record_string = svc_record.second.ToString();
840     auto psms_it = service_to_psms_.find(svc_record.first);
841     std::unordered_set<l2cap::Psm> psm_set;
842     if (psms_it != service_to_psms_.end()) {
843       psm_set = psms_it->second;
844     }
845 
846     InspectProperties::InspectServiceRecordProperties svc_rec_props(
847         std::move(record_string), std::move(psm_set));
848     auto& parent = inspect_properties_.sdp_server_node;
849     svc_rec_props.AttachInspect(parent, parent.UniqueName(kInspectRecordName));
850 
851     inspect_properties_.svc_record_properties.push_back(
852         std::move(svc_rec_props));
853   }
854 }
855 
AllocatedPsmsForTest() const856 std::set<l2cap::Psm> Server::AllocatedPsmsForTest() const {
857   std::set<l2cap::Psm> allocated;
858   for (auto it = psm_to_service_.begin(); it != psm_to_service_.end(); ++it) {
859     allocated.insert(it->first);
860   }
861   return allocated;
862 }
863 
864 Server::InspectProperties::InspectServiceRecordProperties::
InspectServiceRecordProperties(std::string record_in,std::unordered_set<l2cap::Psm> psms_in)865     InspectServiceRecordProperties(std::string record_in,
866                                    std::unordered_set<l2cap::Psm> psms_in)
867     : record(std::move(record_in)), psms(std::move(psms_in)) {}
868 
AttachInspect(inspect::Node & parent,std::string name)869 void Server::InspectProperties::InspectServiceRecordProperties::AttachInspect(
870     inspect::Node& parent, std::string name) {
871   node = parent.CreateChild(name);
872   record_property = node.CreateString(kInspectRecordName, record);
873   psms_node = node.CreateChild(kInspectRegisteredPsmName);
874   psm_nodes.clear();
875   for (const auto& psm : psms) {
876     auto psm_node =
877         psms_node.CreateChild(psms_node.UniqueName(kInspectPsmName));
878     auto psm_string =
879         psm_node.CreateString(kInspectPsmName, l2cap::PsmToString(psm));
880     psm_nodes.emplace_back(std::move(psm_node), std::move(psm_string));
881   }
882 }
883 
884 }  // namespace bt::sdp
885