• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_bluetooth_sapphire/internal/host/sdp/server.h"
16 
17 #include <cstdint>
18 #include <cstdio>
19 
20 #include "pw_bluetooth_sapphire/internal/host/common/assert.h"
21 #include "pw_bluetooth_sapphire/internal/host/common/log.h"
22 #include "pw_bluetooth_sapphire/internal/host/common/random.h"
23 #include "pw_bluetooth_sapphire/internal/host/l2cap/l2cap_defs.h"
24 #include "pw_bluetooth_sapphire/internal/host/l2cap/types.h"
25 #include "pw_bluetooth_sapphire/internal/host/sdp/data_element.h"
26 #include "pw_bluetooth_sapphire/internal/host/sdp/pdu.h"
27 #include "pw_bluetooth_sapphire/internal/host/sdp/sdp.h"
28 
29 namespace bt::sdp {
30 
31 using RegistrationHandle = Server::RegistrationHandle;
32 
33 namespace {
34 
// Node and property names used in the SDP server's inspect hierarchy.
constexpr auto kInspectRegisteredPsmName = "registered_psms";
constexpr auto kInspectPsmName = "psm";
constexpr auto kInspectRecordName = "record";
38 
IsQueuedPsm(const std::vector<std::pair<l2cap::Psm,ServiceHandle>> * queued_psms,l2cap::Psm psm)39 bool IsQueuedPsm(
40     const std::vector<std::pair<l2cap::Psm, ServiceHandle>>* queued_psms,
41     l2cap::Psm psm) {
42   auto is_queued = [target = psm](const auto& psm_pair) {
43     return target == psm_pair.first;
44   };
45   auto iter = std::find_if(queued_psms->begin(), queued_psms->end(), is_queued);
46   return iter != queued_psms->end();
47 }
48 
49 // Returns true if the |psm| is considered valid.
IsValidPsm(l2cap::Psm psm)50 bool IsValidPsm(l2cap::Psm psm) {
51   // The least significant bit of the most significant octet must be 0
52   // (Core 5.4, Vol 3, Part A, 4.2).
53   constexpr uint16_t MS_OCTET_MASK = 0x0100;
54   if (psm & MS_OCTET_MASK) {
55     return false;
56   }
57 
58   // The least significant bit of all other octets must be 1
59   // (Core 5.4, Vol 3, Part A, 4.2).
60   constexpr uint16_t LOWER_OCTET_MASK = 0x0001;
61   if ((psm & LOWER_OCTET_MASK) != LOWER_OCTET_MASK) {
62     return false;
63   }
64   return true;
65 }
66 
67 // Updates the |protocol_list| with the provided dynamic |psm|.
68 // Returns true if the list was updated, false if the list couldn't be updated.
UpdateProtocolWithPsm(DataElement & protocol_list,l2cap::Psm psm)69 bool UpdateProtocolWithPsm(DataElement& protocol_list, l2cap::Psm psm) {
70   const auto* l2cap_protocol = protocol_list.At(0);
71   BT_DEBUG_ASSERT(l2cap_protocol);
72   const auto* prot_uuid = l2cap_protocol->At(0);
73   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid ||
74       *prot_uuid->Get<UUID>() != protocol::kL2CAP) {
75     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not valid or not L2CAP");
76     return false;
77   }
78   std::vector<DataElement> result;
79   // Rebuild the L2CAP protocol by adding UUID & new PSM.
80   result.emplace_back(prot_uuid->Clone());
81   result.emplace_back(DataElement(uint16_t{psm}));
82 
83   // Copy over the remaining protocol descriptors.
84   const DataElement* it;
85   for (size_t idx = 1; nullptr != (it = protocol_list.At(idx)); idx++) {
86     result.emplace_back(it->Clone());
87   }
88 
89   protocol_list = DataElement(std::move(result));
90   bt_log(TRACE,
91          "sdp",
92          "Updated protocol list with dynamic PSM %s",
93          protocol_list.ToString().c_str());
94   return true;
95 }
96 
97 // Finds the PSM that is specified in a ProtocolDescriptorList
98 // Returns l2cap::kInvalidPsm if none is found or the list is invalid
FindProtocolListPsm(const DataElement & protocol_list)99 l2cap::Psm FindProtocolListPsm(const DataElement& protocol_list) {
100   bt_log(TRACE,
101          "sdp",
102          "Trying to find PSM from %s",
103          protocol_list.ToString().c_str());
104   const auto* l2cap_protocol = protocol_list.At(0);
105   BT_DEBUG_ASSERT(l2cap_protocol);
106   const auto* prot_uuid = l2cap_protocol->At(0);
107   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid ||
108       *prot_uuid->Get<UUID>() != protocol::kL2CAP) {
109     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not valid or not L2CAP");
110     return l2cap::kInvalidPsm;
111   }
112 
113   const auto* psm_elem = l2cap_protocol->At(1);
114   if (psm_elem && psm_elem->Get<uint16_t>()) {
115     return *psm_elem->Get<uint16_t>();
116   }
117   if (psm_elem) {
118     bt_log(TRACE, "sdp", "ProtocolDescriptorList invalid L2CAP parameter type");
119     return l2cap::kInvalidPsm;
120   }
121 
122   // The PSM is missing, determined by the next protocol.
123   const auto* next_protocol = protocol_list.At(1);
124   if (!next_protocol) {
125     bt_log(TRACE, "sdp", "L2CAP has no PSM and no additional protocol");
126     return l2cap::kInvalidPsm;
127   }
128   const auto* next_protocol_uuid = next_protocol->At(0);
129   if (!next_protocol_uuid ||
130       next_protocol_uuid->type() != DataElement::Type::kUuid) {
131     bt_log(TRACE, "sdp", "L2CAP has no PSM and additional protocol invalid");
132     return l2cap::kInvalidPsm;
133   }
134   UUID protocol_uuid = *next_protocol_uuid->Get<UUID>();
135   // When it's RFCOMM, the L2CAP protocol descriptor omits the PSM parameter
136   // See example in the SPP Spec, v1.2
137   if (protocol_uuid == protocol::kRFCOMM) {
138     return l2cap::kRFCOMM;
139   }
140   bt_log(TRACE, "sdp", "Can't determine L2CAP PSM from protocol");
141   return l2cap::kInvalidPsm;
142 }
143 
PsmFromProtocolList(const DataElement * protocol_list)144 l2cap::Psm PsmFromProtocolList(const DataElement* protocol_list) {
145   const auto* primary_protocol = protocol_list->At(0);
146   if (!primary_protocol) {
147     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not a sequence");
148     return l2cap::kInvalidPsm;
149   }
150 
151   const auto* prot_uuid = primary_protocol->At(0);
152   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid) {
153     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not valid");
154     return l2cap::kInvalidPsm;
155   }
156 
157   // We do nothing for primary protocols that are not L2CAP
158   if (*prot_uuid->Get<UUID>() != protocol::kL2CAP) {
159     return l2cap::kInvalidPsm;
160   }
161 
162   l2cap::Psm psm = FindProtocolListPsm(*protocol_list);
163   if (psm == l2cap::kInvalidPsm) {
164     bt_log(TRACE, "sdp", "Couldn't find PSM from ProtocolDescriptorList");
165     return l2cap::kInvalidPsm;
166   }
167 
168   return psm;
169 }
170 
171 // Sets the browse group list of the record to be the top-level group.
SetBrowseGroupList(ServiceRecord * record)172 void SetBrowseGroupList(ServiceRecord* record) {
173   std::vector<DataElement> browse_list;
174   browse_list.emplace_back(kPublicBrowseRootUuid);
175   record->SetAttribute(kBrowseGroupList, DataElement(std::move(browse_list)));
176 }
177 
178 }  // namespace
179 
// Single entry of the SDP server's VersionNumberList attribute
// (v5.0, Vol 3, Part B, Sec 5.2.3): version 1.0.
constexpr uint16_t kVersion{0x0100};

// Value the SDP server record's ServiceDatabaseState attribute starts with.
constexpr uint32_t kInitialDbState{0};
185 
186 // Populates the ServiceDiscoveryService record.
MakeServiceDiscoveryService()187 ServiceRecord Server::MakeServiceDiscoveryService() {
188   ServiceRecord sdp;
189   sdp.SetHandle(kSDPHandle);
190 
191   // ServiceClassIDList attribute should have the
192   // ServiceDiscoveryServerServiceClassID
193   // See v5.0, Vol 3, Part B, Sec 5.2.2
194   sdp.SetServiceClassUUIDs({profile::kServiceDiscoveryClass});
195 
196   // The VersionNumberList attribute. See v5.0, Vol 3, Part B, Sec 5.2.3
197   // Version 1.0
198   std::vector<DataElement> version_attribute;
199   version_attribute.emplace_back(kVersion);
200   sdp.SetAttribute(kSDP_VersionNumberList,
201                    DataElement(std::move(version_attribute)));
202 
203   // ServiceDatabaseState attribute. Changes when a service gets added or
204   // removed.
205   sdp.SetAttribute(kSDP_ServiceDatabaseState, DataElement(kInitialDbState));
206 
207   return sdp;
208 }
209 
Server(l2cap::ChannelManager * l2cap)210 Server::Server(l2cap::ChannelManager* l2cap)
211     : l2cap_(l2cap),
212       next_handle_(kFirstUnreservedHandle),
213       db_state_(0),
214       weak_ptr_factory_(this) {
215   BT_ASSERT(l2cap_);
216 
217   records_.emplace(kSDPHandle, Server::MakeServiceDiscoveryService());
218 
219   // Register SDP
220   l2cap::ChannelParameters sdp_chan_params;
221   sdp_chan_params.mode = l2cap::RetransmissionAndFlowControlMode::kBasic;
222   l2cap_->RegisterService(
223       l2cap::kSDP,
224       sdp_chan_params,
225       [self = weak_ptr_factory_.GetWeakPtr()](auto channel) {
226         if (self.is_alive())
227           self->AddConnection(channel);
228       });
229 
230   // SDP is used by SDP server.
231   psm_to_service_.emplace(l2cap::kSDP,
232                           std::unordered_set<ServiceHandle>({kSDPHandle}));
233   service_to_psms_.emplace(kSDPHandle,
234                            std::unordered_set<l2cap::Psm>({l2cap::kSDP}));
235 
236   // Update the inspect properties after Server initialization.
237   UpdateInspectProperties();
238 }
239 
~Server()240 Server::~Server() { l2cap_->UnregisterService(l2cap::kSDP); }
241 
AttachInspect(inspect::Node & parent,std::string name)242 void Server::AttachInspect(inspect::Node& parent, std::string name) {
243   inspect_properties_.sdp_server_node = parent.CreateChild(name);
244   UpdateInspectProperties();
245 }
246 
AddConnection(l2cap::Channel::WeakPtr channel)247 bool Server::AddConnection(l2cap::Channel::WeakPtr channel) {
248   BT_ASSERT(channel.is_alive());
249   hci_spec::ConnectionHandle handle = channel->link_handle();
250   bt_log(DEBUG, "sdp", "add connection handle %#.4x", handle);
251 
252   l2cap::Channel::UniqueId chan_id = channel->unique_id();
253   auto iter = channels_.find(chan_id);
254   if (iter != channels_.end()) {
255     bt_log(WARN, "sdp", "l2cap channel to %#.4x already connected", handle);
256     return false;
257   }
258 
259   auto self = weak_ptr_factory_.GetWeakPtr();
260   bool activated = channel->Activate(
261       [self, chan_id, max_tx_sdu_size = channel->max_tx_sdu_size()](
262           ByteBufferPtr sdu) {
263         if (self.is_alive()) {
264           auto packet = self->HandleRequest(std::move(sdu), max_tx_sdu_size);
265           if (packet) {
266             self->Send(chan_id, std::move(packet.value()));
267           }
268         }
269       },
270       [self, chan_id] {
271         if (self.is_alive()) {
272           self->OnChannelClosed(chan_id);
273         }
274       });
275   if (!activated) {
276     bt_log(WARN, "sdp", "failed to activate channel (handle %#.4x)", handle);
277     return false;
278   }
279   self->channels_.emplace(chan_id, std::move(channel));
280   return true;
281 }
282 
AddPsmToProtocol(ProtocolQueue * protocols_to_register,l2cap::Psm psm,ServiceHandle handle) const283 bool Server::AddPsmToProtocol(ProtocolQueue* protocols_to_register,
284                               l2cap::Psm psm,
285                               ServiceHandle handle) const {
286   if (psm == l2cap::kInvalidPsm) {
287     return false;
288   }
289 
290   if (IsAllocated(psm)) {
291     bt_log(TRACE, "sdp", "L2CAP PSM %#.4x is already allocated", psm);
292     return false;
293   }
294 
295   auto data = std::make_pair(psm, handle);
296   protocols_to_register->emplace_back(std::move(data));
297   return true;
298 }
299 
GetDynamicPsm(const ProtocolQueue * queued_psms) const300 l2cap::Psm Server::GetDynamicPsm(const ProtocolQueue* queued_psms) const {
301   // Generate a random PSM in the valid range of PSMs.
302   // RNG(Range(MIN, MAX)) = MIN + RNG(MAX-MIN) where MIN = kMinDynamicPSM =
303   // 0x1001. MAX = 0xffff.
304   uint16_t offset = 0;
305   constexpr uint16_t MAX_MINUS_MIN = 0xeffe;
306   random_generator()->GetInt(offset, MAX_MINUS_MIN);
307   uint16_t psm = l2cap::kMinDynamicPsm + offset;
308   // LSB of upper octet must be 0. LSB of lower octet must be 1.
309   constexpr uint16_t UPPER_OCTET_MASK = 0xFEFF;
310   constexpr uint16_t LOWER_OCTET_MASK = 0x0001;
311   psm &= UPPER_OCTET_MASK;
312   psm |= LOWER_OCTET_MASK;
313   bt_log(DEBUG, "sdp", "Trying random dynamic PSM %#.4x", psm);
314 
315   // Check if the PSM is valid (e.g. valid construction, not allocated, & not
316   // queued).
317   if ((IsValidPsm(psm)) && (!IsAllocated(psm)) &&
318       (!IsQueuedPsm(queued_psms, psm))) {
319     bt_log(TRACE, "sdp", "Generated random dynamic PSM %#.4x", psm);
320     return psm;
321   }
322 
323   // Otherwise, fall back to sequentially finding the next available PSM.
324   bool search_wrapped = false;
325   for (uint16_t next_psm = psm + 2; next_psm <= UINT16_MAX; next_psm += 2) {
326     if ((IsValidPsm(next_psm)) && (!IsAllocated(next_psm)) &&
327         (!IsQueuedPsm(queued_psms, next_psm))) {
328       bt_log(TRACE, "sdp", "Generated sequential dynamic PSM %#.4x", next_psm);
329       return next_psm;
330     }
331 
332     // If we reach the max valid PSM, wrap around to the minimum valid dynamic
333     // PSM. Only try this once.
334     if (next_psm == 0xFEFF) {
335       next_psm = l2cap::kMinDynamicPsm;
336       if (search_wrapped) {
337         break;
338       }
339       search_wrapped = true;
340     }
341   }
342   bt_log(WARN, "sdp", "Couldn't find an available dynamic PSM");
343   return l2cap::kInvalidPsm;
344 }
345 
QueueService(ServiceRecord * record,ProtocolQueue * protocols_to_register)346 bool Server::QueueService(ServiceRecord* record,
347                           ProtocolQueue* protocols_to_register) {
348   // ProtocolDescriptorList handling:
349   if (record->HasAttribute(kProtocolDescriptorList)) {
350     const auto& primary_protocol =
351         record->GetAttribute(kProtocolDescriptorList);
352     auto psm = PsmFromProtocolList(&primary_protocol);
353     if (psm == kDynamicPsm) {
354       bt_log(TRACE, "sdp", "Primary protocol contains dynamic PSM");
355       auto primary_protocol_copy = primary_protocol.Clone();
356       psm = GetDynamicPsm(protocols_to_register);
357       if (!UpdateProtocolWithPsm(primary_protocol_copy, psm)) {
358         return false;
359       }
360       record->SetAttribute(kProtocolDescriptorList,
361                            std::move(primary_protocol_copy));
362     }
363     if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
364       return false;
365     }
366   }
367 
368   // AdditionalProtocolDescriptorList handling:
369   if (record->HasAttribute(kAdditionalProtocolDescriptorList)) {
370     // |additional_list| is a list of ProtocolDescriptorLists.
371     const auto& additional_list =
372         record->GetAttribute(kAdditionalProtocolDescriptorList);
373     size_t attribute_id = 0;
374     const auto* additional = additional_list.At(attribute_id);
375 
376     // If `kAdditionalProtocolDescriptorList` exists, there should be at least
377     // one protocol provided.
378     if (!additional) {
379       bt_log(
380           TRACE, "sdp", "AdditionalProtocolDescriptorList provided but empty");
381       return false;
382     }
383 
384     // Add valid additional PSMs to the register queue. Because some additional
385     // protocols may need dynamic PSM assignment, modify the relevant protocols
386     // and rebuild the list.
387     std::vector<DataElement> additional_protocols;
388     while (additional) {
389       auto psm = PsmFromProtocolList(additional);
390       auto additional_protocol_copy = additional->Clone();
391       if (psm == kDynamicPsm) {
392         bt_log(TRACE, "sdp", "Additional protocol contains dynamic PSM");
393         psm = GetDynamicPsm(protocols_to_register);
394         if (!UpdateProtocolWithPsm(additional_protocol_copy, psm)) {
395           return l2cap::kInvalidPsm;
396         }
397       }
398       if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
399         return false;
400       }
401 
402       attribute_id++;
403       additional_protocols.emplace_back(std::move(additional_protocol_copy));
404       additional = additional_list.At(attribute_id);
405     }
406     record->SetAttribute(kAdditionalProtocolDescriptorList,
407                          DataElement(std::move(additional_protocols)));
408   }
409 
410   // For some services that depend on OBEX, the L2CAP PSM is specified in the
411   // GoepL2capPsm attribute.
412   bool has_obex = record->FindUUID(std::unordered_set<UUID>({protocol::kOBEX}));
413   if (has_obex && record->HasAttribute(kGoepL2capPsm)) {
414     const auto& attribute = record->GetAttribute(kGoepL2capPsm);
415     if (attribute.Get<uint16_t>()) {
416       auto psm = *attribute.Get<uint16_t>();
417       // If a dynamic PSM was requested, attempt to allocate the next available
418       // PSM.
419       if (psm == kDynamicPsm) {
420         bt_log(TRACE, "sdp", "GoepL2capAttribute contains dynamic PSM");
421         psm = GetDynamicPsm(protocols_to_register);
422         record->SetAttribute(kGoepL2capPsm, DataElement(uint16_t{psm}));
423       }
424       if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
425         return false;
426       }
427     }
428   }
429 
430   return true;
431 }
432 
RegisterService(std::vector<ServiceRecord> records,l2cap::ChannelParameters chan_params,ConnectCallback conn_cb)433 RegistrationHandle Server::RegisterService(std::vector<ServiceRecord> records,
434                                            l2cap::ChannelParameters chan_params,
435                                            ConnectCallback conn_cb) {
436   if (records.empty()) {
437     return 0;
438   }
439 
440   // The PSMs and their ServiceHandles to register.
441   ProtocolQueue protocols_to_register;
442 
443   // The ServiceHandles that are assigned to each ServiceRecord.
444   // There should be one ServiceHandle per ServiceRecord in |records|.
445   std::set<ServiceHandle> assigned_handles;
446 
447   for (auto& record : records) {
448     ServiceHandle next = GetNextHandle();
449     if (!next) {
450       return 0;
451     }
452     // Assign a new handle for the service record.
453     record.SetHandle(next);
454 
455     if (!record.IsProtocolOnly()) {
456       // Place record in a browse group.
457       SetBrowseGroupList(&record);
458 
459       // Validate the |ServiceRecord|.
460       if (!record.IsRegisterable()) {
461         return 0;
462       }
463     }
464 
465     // Attempt to queue the |record| for registration.
466     // Note: Since the validation & queueing operations for ALL the records
467     // occur before registration, multiple ServiceRecords can share the same
468     // PSM.
469     //
470     // If any |record| is not parsable, exit the registration process early.
471     if (!QueueService(&record, &protocols_to_register)) {
472       return 0;
473     }
474 
475     // For every ServiceRecord, there will be one ServiceHandle assigned.
476     assigned_handles.emplace(next);
477   }
478 
479   BT_ASSERT(assigned_handles.size() == records.size());
480 
481   // The RegistrationHandle is the smallest ServiceHandle that was assigned.
482   RegistrationHandle reg_handle = *assigned_handles.begin();
483 
484   // Multiple ServiceRecords in |records| can request the same PSM. However,
485   // |l2cap_| expects a single target for each PSM to go to. Consequently,
486   // only the first occurrence of a PSM needs to be registered with the
487   // |l2cap_|.
488   std::unordered_set<l2cap::Psm> psms_to_register;
489 
490   // All PSMs have assigned handles and will be registered.
491   for (auto& [psm, handle] : protocols_to_register) {
492     psm_to_service_[psm].insert(handle);
493     service_to_psms_[handle].insert(psm);
494 
495     // Add unique PSMs to the data domain registration queue.
496     psms_to_register.insert(psm);
497   }
498 
499   for (const auto& psm : psms_to_register) {
500     bt_log(TRACE, "sdp", "Allocating PSM %#.4x for new service", psm);
501     l2cap_->RegisterService(
502         psm,
503         chan_params,
504         [psm = psm,
505          conn_cb = conn_cb.share()](l2cap::Channel::WeakPtr channel) mutable {
506           bt_log(TRACE, "sdp", "Channel connected to %#.4x", psm);
507           // Build the L2CAP descriptor
508           std::vector<DataElement> protocol_l2cap;
509           protocol_l2cap.emplace_back(protocol::kL2CAP);
510           protocol_l2cap.emplace_back(psm);
511           std::vector<DataElement> protocol;
512           protocol.emplace_back(std::move(protocol_l2cap));
513           conn_cb(std::move(channel), DataElement(std::move(protocol)));
514         });
515   }
516 
517   // Store the complete records.
518   for (auto& record : records) {
519     auto [it, success] = records_.emplace(record.handle(), std::move(record));
520     BT_DEBUG_ASSERT(success);
521     const ServiceRecord& placed_record = it->second;
522     if (placed_record.IsProtocolOnly()) {
523       bt_log(TRACE,
524              "sdp",
525              "registered protocol-only service %#.8x, Protocol: %s",
526              placed_record.handle(),
527              bt_str(placed_record.GetAttribute(kProtocolDescriptorList)));
528     } else {
529       bt_log(TRACE,
530              "sdp",
531              "registered service %#.8x, classes: %s",
532              placed_record.handle(),
533              bt_str(placed_record.GetAttribute(kServiceClassIdList)));
534     }
535   }
536 
537   // Store the RegistrationHandle that represents the set of services that were
538   // registered.
539   reg_to_service_[reg_handle] = std::move(assigned_handles);
540 
541   // Update the inspect properties.
542   UpdateInspectProperties();
543 
544   return reg_handle;
545 }
546 
UnregisterService(RegistrationHandle handle)547 bool Server::UnregisterService(RegistrationHandle handle) {
548   if (handle == kNotRegistered) {
549     return false;
550   }
551 
552   auto handles_it = reg_to_service_.extract(handle);
553   if (!handles_it) {
554     return false;
555   }
556 
557   for (const auto& svc_h : handles_it.mapped()) {
558     BT_ASSERT(svc_h != kSDPHandle);
559     BT_ASSERT(records_.find(svc_h) != records_.end());
560     bt_log(DEBUG, "sdp", "unregistering service (handle: %#.8x)", svc_h);
561 
562     // Unregister any service callbacks from L2CAP
563     auto psms_it = service_to_psms_.extract(svc_h);
564     if (psms_it) {
565       for (const auto& psm : psms_it.mapped()) {
566         bt_log(DEBUG, "sdp", "removing registration for psm %#.4x", psm);
567         l2cap_->UnregisterService(psm);
568         psm_to_service_.erase(psm);
569       }
570     }
571 
572     records_.erase(svc_h);
573   }
574 
575   // Update the inspect properties as the registered PSMs may have changed.
576   UpdateInspectProperties();
577 
578   return true;
579 }
580 
GetNextHandle()581 ServiceHandle Server::GetNextHandle() {
582   ServiceHandle initial_next_handle = next_handle_;
583   // We expect most of these to be free.
584   // Safeguard against possibly having to wrap-around and reuse handles.
585   while (records_.count(next_handle_)) {
586     if (next_handle_ == kLastHandle) {
587       bt_log(WARN, "sdp", "service handle wrapped to start");
588       next_handle_ = kFirstUnreservedHandle;
589     } else {
590       next_handle_++;
591     }
592     if (next_handle_ == initial_next_handle) {
593       return 0;
594     }
595   }
596   return next_handle_++;
597 }
598 
SearchServices(const std::unordered_set<UUID> & pattern) const599 ServiceSearchResponse Server::SearchServices(
600     const std::unordered_set<UUID>& pattern) const {
601   ServiceSearchResponse resp;
602   std::vector<ServiceHandle> matched;
603   for (const auto& it : records_) {
604     if (it.second.FindUUID(pattern) && !it.second.IsProtocolOnly()) {
605       matched.push_back(it.first);
606     }
607   }
608   bt_log(TRACE, "sdp", "ServiceSearch matched %zu records", matched.size());
609   resp.set_service_record_handle_list(matched);
610   return resp;
611 }
612 
GetServiceAttributes(ServiceHandle handle,const std::list<AttributeRange> & ranges) const613 ServiceAttributeResponse Server::GetServiceAttributes(
614     ServiceHandle handle, const std::list<AttributeRange>& ranges) const {
615   ServiceAttributeResponse resp;
616   const auto& record = records_.at(handle);
617   for (const auto& range : ranges) {
618     auto attrs = record.GetAttributesInRange(range.start, range.end);
619     for (const auto& attr : attrs) {
620       resp.set_attribute(attr, record.GetAttribute(attr).Clone());
621     }
622   }
623   bt_log(TRACE,
624          "sdp",
625          "ServiceAttribute %zu attributes",
626          resp.attributes().size());
627   return resp;
628 }
629 
SearchAllServiceAttributes(const std::unordered_set<UUID> & search_pattern,const std::list<AttributeRange> & attribute_ranges) const630 ServiceSearchAttributeResponse Server::SearchAllServiceAttributes(
631     const std::unordered_set<UUID>& search_pattern,
632     const std::list<AttributeRange>& attribute_ranges) const {
633   ServiceSearchAttributeResponse resp;
634   for (const auto& it : records_) {
635     const auto& rec = it.second;
636     if (rec.IsProtocolOnly()) {
637       continue;
638     }
639     if (rec.FindUUID(search_pattern)) {
640       for (const auto& range : attribute_ranges) {
641         auto attrs = rec.GetAttributesInRange(range.start, range.end);
642         for (const auto& attr : attrs) {
643           resp.SetAttribute(it.first, attr, rec.GetAttribute(attr).Clone());
644         }
645       }
646     }
647   }
648 
649   bt_log(TRACE,
650          "sdp",
651          "ServiceSearchAttribute %zu records",
652          resp.num_attribute_lists());
653   return resp;
654 }
655 
OnChannelClosed(l2cap::Channel::UniqueId channel_id)656 void Server::OnChannelClosed(l2cap::Channel::UniqueId channel_id) {
657   channels_.erase(channel_id);
658 }
659 
HandleRequest(ByteBufferPtr sdu,uint16_t max_tx_sdu_size)660 std::optional<ByteBufferPtr> Server::HandleRequest(ByteBufferPtr sdu,
661                                                    uint16_t max_tx_sdu_size) {
662   BT_DEBUG_ASSERT(sdu);
663   TRACE_DURATION("bluetooth", "sdp::Server::HandleRequest");
664   if (sdu->size() < sizeof(Header)) {
665     bt_log(DEBUG, "sdp", "PDU too short; dropping");
666     return std::nullopt;
667   }
668   PacketView<Header> packet(sdu.get());
669   TransactionId tid = be16toh(static_cast<uint16_t>(packet.header().tid));
670   uint16_t param_length = be16toh(packet.header().param_length);
671   auto error_response_builder =
672       [tid, max_tx_sdu_size](ErrorCode code) -> ByteBufferPtr {
673     return ErrorResponse(code).GetPDU(
674         0 /* ignored */, tid, max_tx_sdu_size, BufferView());
675   };
676   if (param_length != (sdu->size() - sizeof(Header))) {
677     bt_log(TRACE,
678            "sdp",
679            "request isn't the correct size (%hu != %zu)",
680            param_length,
681            sdu->size() - sizeof(Header));
682     return error_response_builder(ErrorCode::kInvalidSize);
683   }
684   packet.Resize(param_length);
685   switch (packet.header().pdu_id) {
686     case kServiceSearchRequest: {
687       ServiceSearchRequest request(packet.payload_data());
688       if (!request.valid()) {
689         bt_log(DEBUG, "sdp", "ServiceSearchRequest not valid");
690         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
691       }
692       auto resp = SearchServices(request.service_search_pattern());
693 
694       auto bytes = resp.GetPDU(request.max_service_record_count(),
695                                tid,
696                                max_tx_sdu_size,
697                                request.ContinuationState());
698       if (!bytes) {
699         return error_response_builder(ErrorCode::kInvalidContinuationState);
700       }
701       return std::move(bytes);
702     }
703     case kServiceAttributeRequest: {
704       ServiceAttributeRequest request(packet.payload_data());
705       if (!request.valid()) {
706         bt_log(TRACE, "sdp", "ServiceAttributeRequest not valid");
707         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
708       }
709       auto handle = request.service_record_handle();
710       auto record_it = records_.find(handle);
711       if (record_it == records_.end() || record_it->second.IsProtocolOnly()) {
712         bt_log(TRACE,
713                "sdp",
714                "ServiceAttributeRequest can't find handle %#.8x",
715                handle);
716         return error_response_builder(ErrorCode::kInvalidRecordHandle);
717       }
718       auto resp = GetServiceAttributes(handle, request.attribute_ranges());
719       auto bytes = resp.GetPDU(request.max_attribute_byte_count(),
720                                tid,
721                                max_tx_sdu_size,
722                                request.ContinuationState());
723       if (!bytes) {
724         return error_response_builder(ErrorCode::kInvalidContinuationState);
725       }
726       return std::move(bytes);
727     }
728     case kServiceSearchAttributeRequest: {
729       ServiceSearchAttributeRequest request(packet.payload_data());
730       if (!request.valid()) {
731         bt_log(TRACE, "sdp", "ServiceSearchAttributeRequest not valid");
732         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
733       }
734       auto resp = SearchAllServiceAttributes(request.service_search_pattern(),
735                                              request.attribute_ranges());
736       auto bytes = resp.GetPDU(request.max_attribute_byte_count(),
737                                tid,
738                                max_tx_sdu_size,
739                                request.ContinuationState());
740       if (!bytes) {
741         return error_response_builder(ErrorCode::kInvalidContinuationState);
742       }
743       return std::move(bytes);
744     }
745     case kErrorResponse: {
746       bt_log(TRACE, "sdp", "ErrorResponse isn't allowed as a request");
747       return error_response_builder(ErrorCode::kInvalidRequestSyntax);
748     }
749     default: {
750       bt_log(TRACE, "sdp", "unhandled request, returning InvalidRequest");
751       return error_response_builder(ErrorCode::kInvalidRequestSyntax);
752     }
753   }
754 }
755 
Send(l2cap::Channel::UniqueId channel_id,ByteBufferPtr bytes)756 void Server::Send(l2cap::Channel::UniqueId channel_id, ByteBufferPtr bytes) {
757   auto it = channels_.find(channel_id);
758   if (it == channels_.end()) {
759     bt_log(ERROR, "sdp", "can't find peer to respond to; dropping");
760     return;
761   }
762   l2cap::Channel::WeakPtr chan = it->second.get();
763   chan->Send(std::move(bytes));
764 }
765 
UpdateInspectProperties()766 void Server::UpdateInspectProperties() {
767   // Skip update if node has not been attached.
768   if (!inspect_properties_.sdp_server_node) {
769     return;
770   }
771 
772   // Clear the previous inspect data.
773   inspect_properties_.svc_record_properties.clear();
774 
775   for (const auto& svc_record : records_) {
776     auto record_string = svc_record.second.ToString();
777     auto psms_it = service_to_psms_.find(svc_record.first);
778     std::unordered_set<l2cap::Psm> psm_set;
779     if (psms_it != service_to_psms_.end()) {
780       psm_set = psms_it->second;
781     }
782 
783     InspectProperties::InspectServiceRecordProperties svc_rec_props(
784         std::move(record_string), std::move(psm_set));
785     auto& parent = inspect_properties_.sdp_server_node;
786     svc_rec_props.AttachInspect(parent, parent.UniqueName(kInspectRecordName));
787 
788     inspect_properties_.svc_record_properties.push_back(
789         std::move(svc_rec_props));
790   }
791 }
792 
AllocatedPsmsForTest() const793 std::set<l2cap::Psm> Server::AllocatedPsmsForTest() const {
794   std::set<l2cap::Psm> allocated;
795   for (auto it = psm_to_service_.begin(); it != psm_to_service_.end(); ++it) {
796     allocated.insert(it->first);
797   }
798   return allocated;
799 }
800 
801 Server::InspectProperties::InspectServiceRecordProperties::
InspectServiceRecordProperties(std::string record,std::unordered_set<l2cap::Psm> psms)802     InspectServiceRecordProperties(std::string record,
803                                    std::unordered_set<l2cap::Psm> psms)
804     : record(std::move(record)), psms(std::move(psms)) {}
805 
AttachInspect(inspect::Node & parent,std::string name)806 void Server::InspectProperties::InspectServiceRecordProperties::AttachInspect(
807     inspect::Node& parent, std::string name) {
808   node = parent.CreateChild(name);
809   record_property = node.CreateString(kInspectRecordName, record);
810   psms_node = node.CreateChild(kInspectRegisteredPsmName);
811   psm_nodes.clear();
812   for (const auto& psm : psms) {
813     auto psm_node =
814         psms_node.CreateChild(psms_node.UniqueName(kInspectPsmName));
815     auto psm_string =
816         psm_node.CreateString(kInspectPsmName, l2cap::PsmToString(psm));
817     psm_nodes.emplace_back(std::move(psm_node), std::move(psm_string));
818   }
819 }
820 
821 }  // namespace bt::sdp
822