1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/dnssd/impl/dns_data_graph.h"
6
7 #include <utility>
8
9 #include "discovery/dnssd/impl/conversion_layer.h"
10 #include "discovery/dnssd/impl/instance_key.h"
11
12 namespace openscreen {
13 namespace discovery {
14 namespace {
15
CreateEndpoint(const DomainName & domain,const absl::optional<ARecordRdata> & a,const absl::optional<AAAARecordRdata> & aaaa,const SrvRecordRdata & srv,const TxtRecordRdata & txt,NetworkInterfaceIndex network_interface)16 ErrorOr<DnsSdInstanceEndpoint> CreateEndpoint(
17 const DomainName& domain,
18 const absl::optional<ARecordRdata>& a,
19 const absl::optional<AAAARecordRdata>& aaaa,
20 const SrvRecordRdata& srv,
21 const TxtRecordRdata& txt,
22 NetworkInterfaceIndex network_interface) {
23 // Create the user-visible TXT record representation.
24 ErrorOr<DnsSdTxtRecord> txt_or_error = CreateFromDnsTxt(txt);
25 if (txt_or_error.is_error()) {
26 return txt_or_error.error();
27 }
28
29 InstanceKey instance_id(domain);
30 std::vector<IPEndpoint> endpoints;
31 if (a.has_value()) {
32 endpoints.push_back({a.value().ipv4_address(), srv.port()});
33 }
34 if (aaaa.has_value()) {
35 endpoints.push_back({aaaa.value().ipv6_address(), srv.port()});
36 }
37
38 return DnsSdInstanceEndpoint(
39 instance_id.instance_id(), instance_id.service_id(),
40 instance_id.domain_id(), std::move(txt_or_error.value()),
41 network_interface, std::move(endpoints));
42 }
43
44 class DnsDataGraphImpl : public DnsDataGraph {
45 public:
46 using DnsDataGraph::DomainChangeCallback;
47
DnsDataGraphImpl(NetworkInterfaceIndex network_interface)48 explicit DnsDataGraphImpl(NetworkInterfaceIndex network_interface)
49 : network_interface_(network_interface) {}
50 DnsDataGraphImpl(const DnsDataGraphImpl& other) = delete;
51 DnsDataGraphImpl(DnsDataGraphImpl&& other) = delete;
52
~DnsDataGraphImpl()53 ~DnsDataGraphImpl() override { is_dtor_running_ = true; }
54
55 DnsDataGraphImpl& operator=(const DnsDataGraphImpl& rhs) = delete;
56 DnsDataGraphImpl& operator=(DnsDataGraphImpl&& rhs) = delete;
57
58 // DnsDataGraph overrides.
59 void StartTracking(const DomainName& domain,
60 DomainChangeCallback on_start_tracking) override;
61
62 void StopTracking(const DomainName& domain,
63 DomainChangeCallback on_stop_tracking) override;
64
65 std::vector<ErrorOr<DnsSdInstanceEndpoint>> CreateEndpoints(
66 DomainGroup domain_group,
67 const DomainName& name) const override;
68
69 Error ApplyDataRecordChange(MdnsRecord record,
70 RecordChangedEvent event,
71 DomainChangeCallback on_start_tracking,
72 DomainChangeCallback on_stop_tracking) override;
73
GetTrackedDomainCount() const74 size_t GetTrackedDomainCount() const override { return nodes_.size(); }
75
IsTracked(const DomainName & name) const76 bool IsTracked(const DomainName& name) const override {
77 return nodes_.find(name) != nodes_.end();
78 }
79
80 private:
81 class NodeLifetimeHandler;
82
83 using ScopedCallbackHandler = std::unique_ptr<NodeLifetimeHandler>;
84
85 // A single node of the graph represented by this type.
86 class Node {
87 public:
88 // NOE: This class is non-copyable, non-movable because either operation
89 // would invalidate the pointer references or bidirectional edge states
90 // maintained by instances of this class.
91 Node(DomainName name, DnsDataGraphImpl* graph);
92 Node(const Node& other) = delete;
93 Node(Node&& other) = delete;
94
95 ~Node();
96
97 Node& operator=(const Node& rhs) = delete;
98 Node& operator=(Node&& rhs) = delete;
99
100 // Applies a record change for this node.
101 Error ApplyDataRecordChange(MdnsRecord record, RecordChangedEvent event);
102
103 // Returns the first rdata of a record with type matching |type| in this
104 // node's |records_|, or absl::nullopt if no such record exists.
105 template <typename T>
GetRdata(DnsType type)106 absl::optional<T> GetRdata(DnsType type) {
107 auto it = FindRecord(type);
108 if (it == records_.end()) {
109 return absl::nullopt;
110 } else {
111 return std::cref(absl::get<T>(it->rdata()));
112 }
113 }
114
name() const115 const DomainName& name() const { return name_; }
parents() const116 const std::vector<Node*>& parents() const { return parents_; }
children() const117 const std::vector<Node*>& children() const { return children_; }
records() const118 const std::vector<MdnsRecord>& records() const { return records_; }
119
120 private:
121 // Adds or removes an edge in |graph_|.
122 // NOTE: The same edge may be added multiple times, and one call to remove
123 // is needed for every such call.
124 void AddChild(Node* child);
125 void RemoveChild(Node* child);
126
127 // Applies the specified change to domain |child| for this node.
128 void ApplyChildChange(DomainName child_name, RecordChangedEvent event);
129
130 // Finds an iterator to the record of the provided type, or to
131 // records_.end() if no such record exists.
132 std::vector<MdnsRecord>::iterator FindRecord(DnsType type);
133
134 // The domain with which the data records stored in this node are
135 // associated.
136 const DomainName name_;
137
138 // Currently extant mDNS Records at |name_|.
139 std::vector<MdnsRecord> records_;
140
141 // Nodes which contain records pointing to this node's |name|.
142 std::vector<Node*> parents_;
143
144 // Nodes containing records pointed to by the records in this node.
145 std::vector<Node*> children_;
146
147 // Graph containing this node.
148 DnsDataGraphImpl* graph_;
149 };
150
151 // Wrapper to handle the creation and deletion callbacks. When the object is
152 // created, it sets the callback to use, and erases the callback when it goes
153 // out of scope. This class allows all node creations to complete before
154 // calling the user-provided callback to ensure there are no race-conditions.
155 class NodeLifetimeHandler {
156 public:
157 NodeLifetimeHandler(DomainChangeCallback* callback_ptr,
158 DomainChangeCallback callback);
159
160 // NOTE: The copy and delete ctors and operators must be deleted because
161 // they would invalidate the pointer logic used here.
162 NodeLifetimeHandler(const NodeLifetimeHandler& other) = delete;
163 NodeLifetimeHandler(NodeLifetimeHandler&& other) = delete;
164
165 ~NodeLifetimeHandler();
166
167 NodeLifetimeHandler operator=(const NodeLifetimeHandler& other) = delete;
168 NodeLifetimeHandler operator=(NodeLifetimeHandler&& other) = delete;
169
170 private:
171 std::vector<DomainName> domains_changed;
172
173 DomainChangeCallback* callback_ptr_;
174 DomainChangeCallback callback_;
175 };
176
177 // Helpers to create the ScopedCallbackHandlers for creation and deletion
178 // callbacks.
179 ScopedCallbackHandler GetScopedCreationHandler(
180 DomainChangeCallback creation_callback);
181 ScopedCallbackHandler GetScopedDeletionHandler(
182 DomainChangeCallback deletion_callback);
183
184 // Determines whether the provided node has the necessary records to be a
185 // valid node at the specified domain level.
186 static bool IsValidAddressNode(Node* node);
187 static bool IsValidSrvAndTxtNode(Node* node);
188
189 // Calculates the set of DnsSdInstanceEndpoints associated with the PTR
190 // records present at the given |node|.
191 std::vector<ErrorOr<DnsSdInstanceEndpoint>> CalculatePtrRecordEndpoints(
192 Node* node) const;
193
194 // Denotes whether the dtor for this instance has been called. This is
195 // required for validation of Node instance functionality. See the
196 // implementation of DnsDataGraph::Node::~Node() for more details.
197 bool is_dtor_running_ = false;
198
199 // Map from domain name to the node containing all records associated with the
200 // name.
201 std::map<DomainName, std::unique_ptr<Node>> nodes_;
202
203 const NetworkInterfaceIndex network_interface_;
204
205 // The methods to be called when a domain name either starts or stops being
206 // referenced. These will only be set when a record change is ongoing, and act
207 // as a single source of truth for the creation and deletion callbacks that
208 // should be used during that operation.
209 DomainChangeCallback on_node_creation_;
210 DomainChangeCallback on_node_deletion_;
211 };
212
Node(DomainName name,DnsDataGraphImpl * graph)213 DnsDataGraphImpl::Node::Node(DomainName name, DnsDataGraphImpl* graph)
214 : name_(std::move(name)), graph_(graph) {
215 OSP_DCHECK(graph_);
216
217 graph_->on_node_creation_(name_);
218 }
219
~Node()220 DnsDataGraphImpl::Node::~Node() {
221 // A node should only be deleted when it has no parents. The only case where
222 // a deletion can occur when parents are still extant is during destruction of
223 // the holding graph. In that case, the state of the graph no longer matters
224 // and all nodes will be deleted, so no need to consider the child pointers.
225 if (!graph_->is_dtor_running_) {
226 auto it = std::find_if(parents_.begin(), parents_.end(),
227 [this](Node* parent) { return parent != this; });
228 OSP_DCHECK(it == parents_.end());
229
230 // Erase all childrens' parent pointers to this node.
231 for (Node* child : children_) {
232 RemoveChild(child);
233 }
234
235 OSP_DCHECK(graph_->on_node_deletion_);
236 graph_->on_node_deletion_(name_);
237 }
238 }
239
ApplyDataRecordChange(MdnsRecord record,RecordChangedEvent event)240 Error DnsDataGraphImpl::Node::ApplyDataRecordChange(MdnsRecord record,
241 RecordChangedEvent event) {
242 OSP_DCHECK(record.name() == name_);
243
244 // The child domain to which the changed record points, or none. This is only
245 // applicable for PTR and SRV records, and is empty in all other cases.
246 DomainName child_name;
247
248 // The location of the current record. In the case of PTR records, multiple
249 // records are allowed for the same domain. In all other cases, this is not
250 // valid.
251 std::vector<MdnsRecord>::iterator it;
252
253 if (record.dns_type() == DnsType::kPTR) {
254 child_name = absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
255 it = std::find_if(records_.begin(), records_.end(),
256 [record](const MdnsRecord& rhs) {
257 return record.IsReannouncementOf(rhs);
258 });
259 } else {
260 if (record.dns_type() == DnsType::kSRV) {
261 child_name = absl::get<SrvRecordRdata>(record.rdata()).target();
262 }
263 it = FindRecord(record.dns_type());
264 }
265
266 // Validate that the requested change is allowed and apply it.
267 switch (event) {
268 case RecordChangedEvent::kCreated:
269 if (it != records_.end()) {
270 return Error::Code::kItemAlreadyExists;
271 }
272 records_.push_back(std::move(record));
273 break;
274
275 case RecordChangedEvent::kUpdated:
276 if (it == records_.end()) {
277 return Error::Code::kItemNotFound;
278 }
279 *it = std::move(record);
280 break;
281
282 case RecordChangedEvent::kExpired:
283 if (it == records_.end()) {
284 return Error::Code::kItemNotFound;
285 }
286 records_.erase(it);
287 break;
288 }
289
290 // Apply any required edge changes to the graph. This is only applicable if
291 // a |child| was found earlier. Note that the same child can be added multiple
292 // times to the |children_| vector, which simplifies the code dramatically.
293 if (!child_name.empty()) {
294 ApplyChildChange(std::move(child_name), event);
295 }
296
297 return Error::None();
298 }
299
ApplyChildChange(DomainName child_name,RecordChangedEvent event)300 void DnsDataGraphImpl::Node::ApplyChildChange(DomainName child_name,
301 RecordChangedEvent event) {
302 if (event == RecordChangedEvent::kCreated) {
303 const auto pair =
304 graph_->nodes_.emplace(child_name, std::unique_ptr<Node>());
305 if (pair.second) {
306 auto new_node = std::make_unique<Node>(std::move(child_name), graph_);
307 pair.first->second.swap(new_node);
308 }
309
310 AddChild(pair.first->second.get());
311 } else if (event == RecordChangedEvent::kExpired) {
312 const auto it = graph_->nodes_.find(child_name);
313 OSP_DCHECK(it != graph_->nodes_.end());
314 RemoveChild(it->second.get());
315 }
316 }
317
AddChild(Node * child)318 void DnsDataGraphImpl::Node::AddChild(Node* child) {
319 OSP_DCHECK(child);
320 children_.push_back(child);
321 child->parents_.push_back(this);
322 }
323
RemoveChild(Node * child)324 void DnsDataGraphImpl::Node::RemoveChild(Node* child) {
325 OSP_DCHECK(child);
326
327 auto it = std::find(children_.begin(), children_.end(), child);
328 OSP_DCHECK(it != children_.end());
329 children_.erase(it);
330
331 it = std::find(child->parents_.begin(), child->parents_.end(), this);
332 OSP_DCHECK(it != child->parents_.end());
333 child->parents_.erase(it);
334
335 // If the node has been orphaned, remove it.
336 it = std::find_if(child->parents_.begin(), child->parents_.end(),
337 [child](Node* parent) { return parent != child; });
338 if (it == child->parents_.end()) {
339 DomainName child_name = child->name();
340 const size_t count = graph_->nodes_.erase(child_name);
341 OSP_DCHECK(child == this || count);
342 }
343 }
344
FindRecord(DnsType type)345 std::vector<MdnsRecord>::iterator DnsDataGraphImpl::Node::FindRecord(
346 DnsType type) {
347 return std::find_if(
348 records_.begin(), records_.end(),
349 [type](const MdnsRecord& record) { return record.dns_type() == type; });
350 }
351
NodeLifetimeHandler(DomainChangeCallback * callback_ptr,DomainChangeCallback callback)352 DnsDataGraphImpl::NodeLifetimeHandler::NodeLifetimeHandler(
353 DomainChangeCallback* callback_ptr,
354 DomainChangeCallback callback)
355 : callback_ptr_(callback_ptr), callback_(callback) {
356 OSP_DCHECK(callback_ptr_);
357 OSP_DCHECK(callback);
358 OSP_DCHECK(*callback_ptr_ == nullptr);
359 *callback_ptr = [this](DomainName domain) {
360 domains_changed.push_back(std::move(domain));
361 };
362 }
363
~NodeLifetimeHandler()364 DnsDataGraphImpl::NodeLifetimeHandler::~NodeLifetimeHandler() {
365 *callback_ptr_ = nullptr;
366 for (DomainName& domain : domains_changed) {
367 callback_(domain);
368 }
369 }
370
371 DnsDataGraphImpl::ScopedCallbackHandler
GetScopedCreationHandler(DomainChangeCallback creation_callback)372 DnsDataGraphImpl::GetScopedCreationHandler(
373 DomainChangeCallback creation_callback) {
374 return std::make_unique<NodeLifetimeHandler>(&on_node_creation_,
375 std::move(creation_callback));
376 }
377
378 DnsDataGraphImpl::ScopedCallbackHandler
GetScopedDeletionHandler(DomainChangeCallback deletion_callback)379 DnsDataGraphImpl::GetScopedDeletionHandler(
380 DomainChangeCallback deletion_callback) {
381 return std::make_unique<NodeLifetimeHandler>(&on_node_deletion_,
382 std::move(deletion_callback));
383 }
384
StartTracking(const DomainName & domain,DomainChangeCallback on_start_tracking)385 void DnsDataGraphImpl::StartTracking(const DomainName& domain,
386 DomainChangeCallback on_start_tracking) {
387 ScopedCallbackHandler creation_handler =
388 GetScopedCreationHandler(std::move(on_start_tracking));
389
390 auto pair = nodes_.emplace(domain, std::make_unique<Node>(domain, this));
391
392 OSP_DCHECK(pair.second);
393 OSP_DCHECK(nodes_.find(domain) != nodes_.end());
394 }
395
StopTracking(const DomainName & domain,DomainChangeCallback on_stop_tracking)396 void DnsDataGraphImpl::StopTracking(const DomainName& domain,
397 DomainChangeCallback on_stop_tracking) {
398 ScopedCallbackHandler deletion_handler =
399 GetScopedDeletionHandler(std::move(on_stop_tracking));
400
401 auto it = nodes_.find(domain);
402 OSP_CHECK(it != nodes_.end());
403 OSP_DCHECK(it->second->parents().empty());
404 it->second.reset();
405 const size_t erased_count = nodes_.erase(domain);
406 OSP_DCHECK(erased_count);
407 }
408
ApplyDataRecordChange(MdnsRecord record,RecordChangedEvent event,DomainChangeCallback on_start_tracking,DomainChangeCallback on_stop_tracking)409 Error DnsDataGraphImpl::ApplyDataRecordChange(
410 MdnsRecord record,
411 RecordChangedEvent event,
412 DomainChangeCallback on_start_tracking,
413 DomainChangeCallback on_stop_tracking) {
414 ScopedCallbackHandler creation_handler =
415 GetScopedCreationHandler(std::move(on_start_tracking));
416 ScopedCallbackHandler deletion_handler =
417 GetScopedDeletionHandler(std::move(on_stop_tracking));
418
419 auto it = nodes_.find(record.name());
420 if (it == nodes_.end()) {
421 return Error::Code::kOperationCancelled;
422 }
423
424 const auto result =
425 it->second->ApplyDataRecordChange(std::move(record), event);
426
427 return result;
428 }
429
CreateEndpoints(DomainGroup domain_group,const DomainName & name) const430 std::vector<ErrorOr<DnsSdInstanceEndpoint>> DnsDataGraphImpl::CreateEndpoints(
431 DomainGroup domain_group,
432 const DomainName& name) const {
433 const auto it = nodes_.find(name);
434 if (it == nodes_.end()) {
435 return {};
436 }
437 Node* target_node = it->second.get();
438
439 // NOTE: One of these will contain no more than one element, so iterating over
440 // them both will be fast.
441 std::vector<Node*> srv_and_txt_record_nodes;
442 std::vector<Node*> address_record_nodes;
443
444 switch (domain_group) {
445 case DomainGroup::kAddress:
446 if (!IsValidAddressNode(target_node)) {
447 return {};
448 }
449
450 address_record_nodes.push_back(target_node);
451 srv_and_txt_record_nodes = target_node->parents();
452 break;
453
454 case DomainGroup::kSrvAndTxt:
455 if (!IsValidSrvAndTxtNode(target_node)) {
456 return {};
457 }
458
459 srv_and_txt_record_nodes.push_back(target_node);
460 address_record_nodes = target_node->children();
461 break;
462
463 case DomainGroup::kPtr:
464 return CalculatePtrRecordEndpoints(target_node);
465
466 default:
467 return {};
468 }
469
470 // Iterate across all node pairs and create all possible DnsSdInstanceEndpoint
471 // objects.
472 std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints;
473 for (Node* srv_and_txt : srv_and_txt_record_nodes) {
474 for (Node* address : address_record_nodes) {
475 // First, there has to be a SRV record present (to provide the port
476 // number), and the target of that SRV record has to be the node where the
477 // address records are sourced from.
478 const absl::optional<SrvRecordRdata> srv =
479 srv_and_txt->GetRdata<SrvRecordRdata>(DnsType::kSRV);
480 if (!srv.has_value() || srv.value().target() != address->name()) {
481 continue;
482 }
483
484 // Next, a TXT record must be present to provide additional connection
485 // information about the service per RFC 6763.
486 const absl::optional<TxtRecordRdata> txt =
487 srv_and_txt->GetRdata<TxtRecordRdata>(DnsType::kTXT);
488 if (!txt.has_value()) {
489 continue;
490 }
491
492 // Last, at least one address record must be present to provide an
493 // endpoint for this instance.
494 const absl::optional<ARecordRdata> a =
495 address->GetRdata<ARecordRdata>(DnsType::kA);
496 const absl::optional<AAAARecordRdata> aaaa =
497 address->GetRdata<AAAARecordRdata>(DnsType::kAAAA);
498 if (!a.has_value() && !aaaa.has_value()) {
499 continue;
500 }
501
502 // Then use the above info to create an endpoint object. If an error
503 // occurs, this is only related to the one endpoint and its possible that
504 // other endpoints may still be valid, so only the one endpoint is treated
505 // as failing. For instance, a bad TXT record for service A will not
506 // affect the endpoints for service B.
507 ErrorOr<DnsSdInstanceEndpoint> endpoint =
508 CreateEndpoint(srv_and_txt->name(), a, aaaa, srv.value(), txt.value(),
509 network_interface_);
510 endpoints.push_back(std::move(endpoint));
511 }
512 }
513
514 return endpoints;
515 }
516
517 // static
IsValidAddressNode(Node * node)518 bool DnsDataGraphImpl::IsValidAddressNode(Node* node) {
519 const absl::optional<ARecordRdata> a =
520 node->GetRdata<ARecordRdata>(DnsType::kA);
521 const absl::optional<AAAARecordRdata> aaaa =
522 node->GetRdata<AAAARecordRdata>(DnsType::kAAAA);
523 return a.has_value() || aaaa.has_value();
524 }
525
526 // static
IsValidSrvAndTxtNode(Node * node)527 bool DnsDataGraphImpl::IsValidSrvAndTxtNode(Node* node) {
528 const absl::optional<SrvRecordRdata> srv =
529 node->GetRdata<SrvRecordRdata>(DnsType::kSRV);
530 const absl::optional<TxtRecordRdata> txt =
531 node->GetRdata<TxtRecordRdata>(DnsType::kTXT);
532
533 return srv.has_value() && txt.has_value();
534 }
535
536 std::vector<ErrorOr<DnsSdInstanceEndpoint>>
CalculatePtrRecordEndpoints(Node * node) const537 DnsDataGraphImpl::CalculatePtrRecordEndpoints(Node* node) const {
538 // PTR records aren't actually part of the generated endpoint objects, so
539 // call this method recursively on all children and
540 std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints;
541 for (const MdnsRecord& record : node->records()) {
542 if (record.dns_type() != DnsType::kPTR) {
543 continue;
544 }
545
546 const DomainName domain =
547 absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
548 const Node* child = nodes_.find(domain)->second.get();
549 std::vector<ErrorOr<DnsSdInstanceEndpoint>> child_endpoints =
550 CreateEndpoints(DomainGroup::kSrvAndTxt, child->name());
551 for (auto& endpoint_or_error : child_endpoints) {
552 endpoints.push_back(std::move(endpoint_or_error));
553 }
554 }
555 return endpoints;
556 }
557
558 } // namespace
559
560 DnsDataGraph::~DnsDataGraph() = default;
561
562 // static
Create(NetworkInterfaceIndex network_interface)563 std::unique_ptr<DnsDataGraph> DnsDataGraph::Create(
564 NetworkInterfaceIndex network_interface) {
565 return std::make_unique<DnsDataGraphImpl>(network_interface);
566 }
567
568 // static
GetDomainGroup(DnsType type)569 DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(DnsType type) {
570 switch (type) {
571 case DnsType::kA:
572 case DnsType::kAAAA:
573 return DnsDataGraphImpl::DomainGroup::kAddress;
574 case DnsType::kSRV:
575 case DnsType::kTXT:
576 return DnsDataGraphImpl::DomainGroup::kSrvAndTxt;
577 case DnsType::kPTR:
578 return DnsDataGraphImpl::DomainGroup::kPtr;
579 default:
580 OSP_NOTREACHED();
581 }
582 }
583
584 // static
GetDomainGroup(const MdnsRecord record)585 DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(
586 const MdnsRecord record) {
587 return GetDomainGroup(record.dns_type());
588 }
589
590 } // namespace discovery
591 } // namespace openscreen
592