• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/dnssd/impl/dns_data_graph.h"
6 
#include <algorithm>
#include <map>
#include <memory>
#include <utility>
#include <vector>
8 
9 #include "discovery/dnssd/impl/conversion_layer.h"
10 #include "discovery/dnssd/impl/instance_key.h"
11 
12 namespace openscreen {
13 namespace discovery {
14 namespace {
15 
CreateEndpoint(const DomainName & domain,const absl::optional<ARecordRdata> & a,const absl::optional<AAAARecordRdata> & aaaa,const SrvRecordRdata & srv,const TxtRecordRdata & txt,NetworkInterfaceIndex network_interface)16 ErrorOr<DnsSdInstanceEndpoint> CreateEndpoint(
17     const DomainName& domain,
18     const absl::optional<ARecordRdata>& a,
19     const absl::optional<AAAARecordRdata>& aaaa,
20     const SrvRecordRdata& srv,
21     const TxtRecordRdata& txt,
22     NetworkInterfaceIndex network_interface) {
23   // Create the user-visible TXT record representation.
24   ErrorOr<DnsSdTxtRecord> txt_or_error = CreateFromDnsTxt(txt);
25   if (txt_or_error.is_error()) {
26     return txt_or_error.error();
27   }
28 
29   InstanceKey instance_id(domain);
30   std::vector<IPEndpoint> endpoints;
31   if (a.has_value()) {
32     endpoints.push_back({a.value().ipv4_address(), srv.port()});
33   }
34   if (aaaa.has_value()) {
35     endpoints.push_back({aaaa.value().ipv6_address(), srv.port()});
36   }
37 
38   return DnsSdInstanceEndpoint(
39       instance_id.instance_id(), instance_id.service_id(),
40       instance_id.domain_id(), std::move(txt_or_error.value()),
41       network_interface, std::move(endpoints));
42 }
43 
44 class DnsDataGraphImpl : public DnsDataGraph {
45  public:
46   using DnsDataGraph::DomainChangeCallback;
47 
DnsDataGraphImpl(NetworkInterfaceIndex network_interface)48   explicit DnsDataGraphImpl(NetworkInterfaceIndex network_interface)
49       : network_interface_(network_interface) {}
50   DnsDataGraphImpl(const DnsDataGraphImpl& other) = delete;
51   DnsDataGraphImpl(DnsDataGraphImpl&& other) = delete;
52 
~DnsDataGraphImpl()53   ~DnsDataGraphImpl() override { is_dtor_running_ = true; }
54 
55   DnsDataGraphImpl& operator=(const DnsDataGraphImpl& rhs) = delete;
56   DnsDataGraphImpl& operator=(DnsDataGraphImpl&& rhs) = delete;
57 
58   // DnsDataGraph overrides.
59   void StartTracking(const DomainName& domain,
60                      DomainChangeCallback on_start_tracking) override;
61 
62   void StopTracking(const DomainName& domain,
63                     DomainChangeCallback on_stop_tracking) override;
64 
65   std::vector<ErrorOr<DnsSdInstanceEndpoint>> CreateEndpoints(
66       DomainGroup domain_group,
67       const DomainName& name) const override;
68 
69   Error ApplyDataRecordChange(MdnsRecord record,
70                               RecordChangedEvent event,
71                               DomainChangeCallback on_start_tracking,
72                               DomainChangeCallback on_stop_tracking) override;
73 
GetTrackedDomainCount() const74   size_t GetTrackedDomainCount() const override { return nodes_.size(); }
75 
IsTracked(const DomainName & name) const76   bool IsTracked(const DomainName& name) const override {
77     return nodes_.find(name) != nodes_.end();
78   }
79 
80  private:
81   class NodeLifetimeHandler;
82 
83   using ScopedCallbackHandler = std::unique_ptr<NodeLifetimeHandler>;
84 
85   // A single node of the graph represented by this type.
86   class Node {
87    public:
88     // NOE: This class is non-copyable, non-movable because either operation
89     // would invalidate the pointer references or bidirectional edge states
90     // maintained by instances of this class.
91     Node(DomainName name, DnsDataGraphImpl* graph);
92     Node(const Node& other) = delete;
93     Node(Node&& other) = delete;
94 
95     ~Node();
96 
97     Node& operator=(const Node& rhs) = delete;
98     Node& operator=(Node&& rhs) = delete;
99 
100     // Applies a record change for this node.
101     Error ApplyDataRecordChange(MdnsRecord record, RecordChangedEvent event);
102 
103     // Returns the first rdata of a record with type matching |type| in this
104     // node's |records_|, or absl::nullopt if no such record exists.
105     template <typename T>
GetRdata(DnsType type)106     absl::optional<T> GetRdata(DnsType type) {
107       auto it = FindRecord(type);
108       if (it == records_.end()) {
109         return absl::nullopt;
110       } else {
111         return std::cref(absl::get<T>(it->rdata()));
112       }
113     }
114 
name() const115     const DomainName& name() const { return name_; }
parents() const116     const std::vector<Node*>& parents() const { return parents_; }
children() const117     const std::vector<Node*>& children() const { return children_; }
records() const118     const std::vector<MdnsRecord>& records() const { return records_; }
119 
120    private:
121     // Adds or removes an edge in |graph_|.
122     // NOTE: The same edge may be added multiple times, and one call to remove
123     // is needed for every such call.
124     void AddChild(Node* child);
125     void RemoveChild(Node* child);
126 
127     // Applies the specified change to domain |child| for this node.
128     void ApplyChildChange(DomainName child_name, RecordChangedEvent event);
129 
130     // Finds an iterator to the record of the provided type, or to
131     // records_.end() if no such record exists.
132     std::vector<MdnsRecord>::iterator FindRecord(DnsType type);
133 
134     // The domain with which the data records stored in this node are
135     // associated.
136     const DomainName name_;
137 
138     // Currently extant mDNS Records at |name_|.
139     std::vector<MdnsRecord> records_;
140 
141     // Nodes which contain records pointing to this node's |name|.
142     std::vector<Node*> parents_;
143 
144     // Nodes containing records pointed to by the records in this node.
145     std::vector<Node*> children_;
146 
147     // Graph containing this node.
148     DnsDataGraphImpl* graph_;
149   };
150 
151   // Wrapper to handle the creation and deletion callbacks. When the object is
152   // created, it sets the callback to use, and erases the callback when it goes
153   // out of scope. This class allows all node creations to complete before
154   // calling the user-provided callback to ensure there are no race-conditions.
155   class NodeLifetimeHandler {
156    public:
157     NodeLifetimeHandler(DomainChangeCallback* callback_ptr,
158                         DomainChangeCallback callback);
159 
160     // NOTE: The copy and delete ctors and operators must be deleted because
161     // they would invalidate the pointer logic used here.
162     NodeLifetimeHandler(const NodeLifetimeHandler& other) = delete;
163     NodeLifetimeHandler(NodeLifetimeHandler&& other) = delete;
164 
165     ~NodeLifetimeHandler();
166 
167     NodeLifetimeHandler operator=(const NodeLifetimeHandler& other) = delete;
168     NodeLifetimeHandler operator=(NodeLifetimeHandler&& other) = delete;
169 
170    private:
171     std::vector<DomainName> domains_changed;
172 
173     DomainChangeCallback* callback_ptr_;
174     DomainChangeCallback callback_;
175   };
176 
177   // Helpers to create the ScopedCallbackHandlers for creation and deletion
178   // callbacks.
179   ScopedCallbackHandler GetScopedCreationHandler(
180       DomainChangeCallback creation_callback);
181   ScopedCallbackHandler GetScopedDeletionHandler(
182       DomainChangeCallback deletion_callback);
183 
184   // Determines whether the provided node has the necessary records to be a
185   // valid node at the specified domain level.
186   static bool IsValidAddressNode(Node* node);
187   static bool IsValidSrvAndTxtNode(Node* node);
188 
189   // Calculates the set of DnsSdInstanceEndpoints associated with the PTR
190   // records present at the given |node|.
191   std::vector<ErrorOr<DnsSdInstanceEndpoint>> CalculatePtrRecordEndpoints(
192       Node* node) const;
193 
194   // Denotes whether the dtor for this instance has been called. This is
195   // required for validation of Node instance functionality. See the
196   // implementation of DnsDataGraph::Node::~Node() for more details.
197   bool is_dtor_running_ = false;
198 
199   // Map from domain name to the node containing all records associated with the
200   // name.
201   std::map<DomainName, std::unique_ptr<Node>> nodes_;
202 
203   const NetworkInterfaceIndex network_interface_;
204 
205   // The methods to be called when a domain name either starts or stops being
206   // referenced. These will only be set when a record change is ongoing, and act
207   // as a single source of truth for the creation and deletion callbacks that
208   // should be used during that operation.
209   DomainChangeCallback on_node_creation_;
210   DomainChangeCallback on_node_deletion_;
211 };
212 
Node(DomainName name,DnsDataGraphImpl * graph)213 DnsDataGraphImpl::Node::Node(DomainName name, DnsDataGraphImpl* graph)
214     : name_(std::move(name)), graph_(graph) {
215   OSP_DCHECK(graph_);
216 
217   graph_->on_node_creation_(name_);
218 }
219 
~Node()220 DnsDataGraphImpl::Node::~Node() {
221   // A node should only be deleted when it has no parents. The only case where
222   // a deletion can occur when parents are still extant is during destruction of
223   // the holding graph. In that case, the state of the graph no longer matters
224   // and all nodes will be deleted, so no need to consider the child pointers.
225   if (!graph_->is_dtor_running_) {
226     auto it = std::find_if(parents_.begin(), parents_.end(),
227                            [this](Node* parent) { return parent != this; });
228     OSP_DCHECK(it == parents_.end());
229 
230     // Erase all childrens' parent pointers to this node.
231     for (Node* child : children_) {
232       RemoveChild(child);
233     }
234 
235     OSP_DCHECK(graph_->on_node_deletion_);
236     graph_->on_node_deletion_(name_);
237   }
238 }
239 
ApplyDataRecordChange(MdnsRecord record,RecordChangedEvent event)240 Error DnsDataGraphImpl::Node::ApplyDataRecordChange(MdnsRecord record,
241                                                     RecordChangedEvent event) {
242   OSP_DCHECK(record.name() == name_);
243 
244   // The child domain to which the changed record points, or none. This is only
245   // applicable for PTR and SRV records, and is empty in all other cases.
246   DomainName child_name;
247 
248   // The location of the current record. In the case of PTR records, multiple
249   // records are allowed for the same domain. In all other cases, this is not
250   // valid.
251   std::vector<MdnsRecord>::iterator it;
252 
253   if (record.dns_type() == DnsType::kPTR) {
254     child_name = absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
255     it = std::find_if(records_.begin(), records_.end(),
256                       [record](const MdnsRecord& rhs) {
257                         return record.IsReannouncementOf(rhs);
258                       });
259   } else {
260     if (record.dns_type() == DnsType::kSRV) {
261       child_name = absl::get<SrvRecordRdata>(record.rdata()).target();
262     }
263     it = FindRecord(record.dns_type());
264   }
265 
266   // Validate that the requested change is allowed and apply it.
267   switch (event) {
268     case RecordChangedEvent::kCreated:
269       if (it != records_.end()) {
270         return Error::Code::kItemAlreadyExists;
271       }
272       records_.push_back(std::move(record));
273       break;
274 
275     case RecordChangedEvent::kUpdated:
276       if (it == records_.end()) {
277         return Error::Code::kItemNotFound;
278       }
279       *it = std::move(record);
280       break;
281 
282     case RecordChangedEvent::kExpired:
283       if (it == records_.end()) {
284         return Error::Code::kItemNotFound;
285       }
286       records_.erase(it);
287       break;
288   }
289 
290   // Apply any required edge changes to the graph. This is only applicable if
291   // a |child| was found earlier. Note that the same child can be added multiple
292   // times to the |children_| vector, which simplifies the code dramatically.
293   if (!child_name.empty()) {
294     ApplyChildChange(std::move(child_name), event);
295   }
296 
297   return Error::None();
298 }
299 
ApplyChildChange(DomainName child_name,RecordChangedEvent event)300 void DnsDataGraphImpl::Node::ApplyChildChange(DomainName child_name,
301                                               RecordChangedEvent event) {
302   if (event == RecordChangedEvent::kCreated) {
303     const auto pair =
304         graph_->nodes_.emplace(child_name, std::unique_ptr<Node>());
305     if (pair.second) {
306       auto new_node = std::make_unique<Node>(std::move(child_name), graph_);
307       pair.first->second.swap(new_node);
308     }
309 
310     AddChild(pair.first->second.get());
311   } else if (event == RecordChangedEvent::kExpired) {
312     const auto it = graph_->nodes_.find(child_name);
313     OSP_DCHECK(it != graph_->nodes_.end());
314     RemoveChild(it->second.get());
315   }
316 }
317 
AddChild(Node * child)318 void DnsDataGraphImpl::Node::AddChild(Node* child) {
319   OSP_DCHECK(child);
320   children_.push_back(child);
321   child->parents_.push_back(this);
322 }
323 
RemoveChild(Node * child)324 void DnsDataGraphImpl::Node::RemoveChild(Node* child) {
325   OSP_DCHECK(child);
326 
327   auto it = std::find(children_.begin(), children_.end(), child);
328   OSP_DCHECK(it != children_.end());
329   children_.erase(it);
330 
331   it = std::find(child->parents_.begin(), child->parents_.end(), this);
332   OSP_DCHECK(it != child->parents_.end());
333   child->parents_.erase(it);
334 
335   // If the node has been orphaned, remove it.
336   it = std::find_if(child->parents_.begin(), child->parents_.end(),
337                     [child](Node* parent) { return parent != child; });
338   if (it == child->parents_.end()) {
339     DomainName child_name = child->name();
340     const size_t count = graph_->nodes_.erase(child_name);
341     OSP_DCHECK(child == this || count);
342   }
343 }
344 
FindRecord(DnsType type)345 std::vector<MdnsRecord>::iterator DnsDataGraphImpl::Node::FindRecord(
346     DnsType type) {
347   return std::find_if(
348       records_.begin(), records_.end(),
349       [type](const MdnsRecord& record) { return record.dns_type() == type; });
350 }
351 
NodeLifetimeHandler(DomainChangeCallback * callback_ptr,DomainChangeCallback callback)352 DnsDataGraphImpl::NodeLifetimeHandler::NodeLifetimeHandler(
353     DomainChangeCallback* callback_ptr,
354     DomainChangeCallback callback)
355     : callback_ptr_(callback_ptr), callback_(callback) {
356   OSP_DCHECK(callback_ptr_);
357   OSP_DCHECK(callback);
358   OSP_DCHECK(*callback_ptr_ == nullptr);
359   *callback_ptr = [this](DomainName domain) {
360     domains_changed.push_back(std::move(domain));
361   };
362 }
363 
~NodeLifetimeHandler()364 DnsDataGraphImpl::NodeLifetimeHandler::~NodeLifetimeHandler() {
365   *callback_ptr_ = nullptr;
366   for (DomainName& domain : domains_changed) {
367     callback_(domain);
368   }
369 }
370 
371 DnsDataGraphImpl::ScopedCallbackHandler
GetScopedCreationHandler(DomainChangeCallback creation_callback)372 DnsDataGraphImpl::GetScopedCreationHandler(
373     DomainChangeCallback creation_callback) {
374   return std::make_unique<NodeLifetimeHandler>(&on_node_creation_,
375                                                std::move(creation_callback));
376 }
377 
378 DnsDataGraphImpl::ScopedCallbackHandler
GetScopedDeletionHandler(DomainChangeCallback deletion_callback)379 DnsDataGraphImpl::GetScopedDeletionHandler(
380     DomainChangeCallback deletion_callback) {
381   return std::make_unique<NodeLifetimeHandler>(&on_node_deletion_,
382                                                std::move(deletion_callback));
383 }
384 
StartTracking(const DomainName & domain,DomainChangeCallback on_start_tracking)385 void DnsDataGraphImpl::StartTracking(const DomainName& domain,
386                                      DomainChangeCallback on_start_tracking) {
387   ScopedCallbackHandler creation_handler =
388       GetScopedCreationHandler(std::move(on_start_tracking));
389 
390   auto pair = nodes_.emplace(domain, std::make_unique<Node>(domain, this));
391 
392   OSP_DCHECK(pair.second);
393   OSP_DCHECK(nodes_.find(domain) != nodes_.end());
394 }
395 
StopTracking(const DomainName & domain,DomainChangeCallback on_stop_tracking)396 void DnsDataGraphImpl::StopTracking(const DomainName& domain,
397                                     DomainChangeCallback on_stop_tracking) {
398   ScopedCallbackHandler deletion_handler =
399       GetScopedDeletionHandler(std::move(on_stop_tracking));
400 
401   auto it = nodes_.find(domain);
402   OSP_CHECK(it != nodes_.end());
403   OSP_DCHECK(it->second->parents().empty());
404   it->second.reset();
405   const size_t erased_count = nodes_.erase(domain);
406   OSP_DCHECK(erased_count);
407 }
408 
ApplyDataRecordChange(MdnsRecord record,RecordChangedEvent event,DomainChangeCallback on_start_tracking,DomainChangeCallback on_stop_tracking)409 Error DnsDataGraphImpl::ApplyDataRecordChange(
410     MdnsRecord record,
411     RecordChangedEvent event,
412     DomainChangeCallback on_start_tracking,
413     DomainChangeCallback on_stop_tracking) {
414   ScopedCallbackHandler creation_handler =
415       GetScopedCreationHandler(std::move(on_start_tracking));
416   ScopedCallbackHandler deletion_handler =
417       GetScopedDeletionHandler(std::move(on_stop_tracking));
418 
419   auto it = nodes_.find(record.name());
420   if (it == nodes_.end()) {
421     return Error::Code::kOperationCancelled;
422   }
423 
424   const auto result =
425       it->second->ApplyDataRecordChange(std::move(record), event);
426 
427   return result;
428 }
429 
CreateEndpoints(DomainGroup domain_group,const DomainName & name) const430 std::vector<ErrorOr<DnsSdInstanceEndpoint>> DnsDataGraphImpl::CreateEndpoints(
431     DomainGroup domain_group,
432     const DomainName& name) const {
433   const auto it = nodes_.find(name);
434   if (it == nodes_.end()) {
435     return {};
436   }
437   Node* target_node = it->second.get();
438 
439   // NOTE: One of these will contain no more than one element, so iterating over
440   // them both will be fast.
441   std::vector<Node*> srv_and_txt_record_nodes;
442   std::vector<Node*> address_record_nodes;
443 
444   switch (domain_group) {
445     case DomainGroup::kAddress:
446       if (!IsValidAddressNode(target_node)) {
447         return {};
448       }
449 
450       address_record_nodes.push_back(target_node);
451       srv_and_txt_record_nodes = target_node->parents();
452       break;
453 
454     case DomainGroup::kSrvAndTxt:
455       if (!IsValidSrvAndTxtNode(target_node)) {
456         return {};
457       }
458 
459       srv_and_txt_record_nodes.push_back(target_node);
460       address_record_nodes = target_node->children();
461       break;
462 
463     case DomainGroup::kPtr:
464       return CalculatePtrRecordEndpoints(target_node);
465 
466     default:
467       return {};
468   }
469 
470   // Iterate across all node pairs and create all possible DnsSdInstanceEndpoint
471   // objects.
472   std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints;
473   for (Node* srv_and_txt : srv_and_txt_record_nodes) {
474     for (Node* address : address_record_nodes) {
475       // First, there has to be a SRV record present (to provide the port
476       // number), and the target of that SRV record has to be the node where the
477       // address records are sourced from.
478       const absl::optional<SrvRecordRdata> srv =
479           srv_and_txt->GetRdata<SrvRecordRdata>(DnsType::kSRV);
480       if (!srv.has_value() || srv.value().target() != address->name()) {
481         continue;
482       }
483 
484       // Next, a TXT record must be present to provide additional connection
485       // information about the service per RFC 6763.
486       const absl::optional<TxtRecordRdata> txt =
487           srv_and_txt->GetRdata<TxtRecordRdata>(DnsType::kTXT);
488       if (!txt.has_value()) {
489         continue;
490       }
491 
492       // Last, at least one address record must be present to provide an
493       // endpoint for this instance.
494       const absl::optional<ARecordRdata> a =
495           address->GetRdata<ARecordRdata>(DnsType::kA);
496       const absl::optional<AAAARecordRdata> aaaa =
497           address->GetRdata<AAAARecordRdata>(DnsType::kAAAA);
498       if (!a.has_value() && !aaaa.has_value()) {
499         continue;
500       }
501 
502       // Then use the above info to create an endpoint object. If an error
503       // occurs, this is only related to the one endpoint and its possible that
504       // other endpoints may still be valid, so only the one endpoint is treated
505       // as failing. For instance, a bad TXT record for service A will not
506       // affect the endpoints for service B.
507       ErrorOr<DnsSdInstanceEndpoint> endpoint =
508           CreateEndpoint(srv_and_txt->name(), a, aaaa, srv.value(), txt.value(),
509                          network_interface_);
510       endpoints.push_back(std::move(endpoint));
511     }
512   }
513 
514   return endpoints;
515 }
516 
517 // static
IsValidAddressNode(Node * node)518 bool DnsDataGraphImpl::IsValidAddressNode(Node* node) {
519   const absl::optional<ARecordRdata> a =
520       node->GetRdata<ARecordRdata>(DnsType::kA);
521   const absl::optional<AAAARecordRdata> aaaa =
522       node->GetRdata<AAAARecordRdata>(DnsType::kAAAA);
523   return a.has_value() || aaaa.has_value();
524 }
525 
526 // static
IsValidSrvAndTxtNode(Node * node)527 bool DnsDataGraphImpl::IsValidSrvAndTxtNode(Node* node) {
528   const absl::optional<SrvRecordRdata> srv =
529       node->GetRdata<SrvRecordRdata>(DnsType::kSRV);
530   const absl::optional<TxtRecordRdata> txt =
531       node->GetRdata<TxtRecordRdata>(DnsType::kTXT);
532 
533   return srv.has_value() && txt.has_value();
534 }
535 
536 std::vector<ErrorOr<DnsSdInstanceEndpoint>>
CalculatePtrRecordEndpoints(Node * node) const537 DnsDataGraphImpl::CalculatePtrRecordEndpoints(Node* node) const {
538   // PTR records aren't actually part of the generated endpoint objects, so
539   // call this method recursively on all children and
540   std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints;
541   for (const MdnsRecord& record : node->records()) {
542     if (record.dns_type() != DnsType::kPTR) {
543       continue;
544     }
545 
546     const DomainName domain =
547         absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
548     const Node* child = nodes_.find(domain)->second.get();
549     std::vector<ErrorOr<DnsSdInstanceEndpoint>> child_endpoints =
550         CreateEndpoints(DomainGroup::kSrvAndTxt, child->name());
551     for (auto& endpoint_or_error : child_endpoints) {
552       endpoints.push_back(std::move(endpoint_or_error));
553     }
554   }
555   return endpoints;
556 }
557 
558 }  // namespace
559 
560 DnsDataGraph::~DnsDataGraph() = default;
561 
562 // static
Create(NetworkInterfaceIndex network_interface)563 std::unique_ptr<DnsDataGraph> DnsDataGraph::Create(
564     NetworkInterfaceIndex network_interface) {
565   return std::make_unique<DnsDataGraphImpl>(network_interface);
566 }
567 
568 // static
GetDomainGroup(DnsType type)569 DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(DnsType type) {
570   switch (type) {
571     case DnsType::kA:
572     case DnsType::kAAAA:
573       return DnsDataGraphImpl::DomainGroup::kAddress;
574     case DnsType::kSRV:
575     case DnsType::kTXT:
576       return DnsDataGraphImpl::DomainGroup::kSrvAndTxt;
577     case DnsType::kPTR:
578       return DnsDataGraphImpl::DomainGroup::kPtr;
579     default:
580       OSP_NOTREACHED();
581   }
582 }
583 
584 // static
GetDomainGroup(const MdnsRecord record)585 DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(
586     const MdnsRecord record) {
587   return GetDomainGroup(record.dns_type());
588 }
589 
590 }  // namespace discovery
591 }  // namespace openscreen
592