1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/dnssd/impl/querier_impl.h"
6
#include <algorithm>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
11
12 #include "discovery/common/reporting_client.h"
13 #include "discovery/dnssd/impl/conversion_layer.h"
14 #include "discovery/dnssd/impl/network_interface_config.h"
15 #include "platform/api/task_runner.h"
16 #include "util/osp_logging.h"
17
18 namespace openscreen {
19 namespace discovery {
20 namespace {
21
// Domain used for every DNS-SD service query issued by this querier.
constexpr char kLocalDomain[] = "local";
23
24 // Removes all error instances from the below records, and calls the log
25 // function on all errors present in |new_endpoints|. Input vectors are expected
26 // to be sorted in ascending order.
ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>> * old_endpoints,std::vector<ErrorOr<DnsSdInstanceEndpoint>> * new_endpoints,std::function<void (Error)> log)27 void ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>>* old_endpoints,
28 std::vector<ErrorOr<DnsSdInstanceEndpoint>>* new_endpoints,
29 std::function<void(Error)> log) {
30 OSP_DCHECK(old_endpoints);
31 OSP_DCHECK(new_endpoints);
32
33 auto old_it = old_endpoints->begin();
34 auto new_it = new_endpoints->begin();
35
36 // Iterate across both vectors and log new errors in the process.
37 // NOTE: In sorted order, all errors will appear before all non-errors.
38 while (old_it != old_endpoints->end() && new_it != new_endpoints->end()) {
39 ErrorOr<DnsSdInstanceEndpoint>& old_ep = *old_it;
40 ErrorOr<DnsSdInstanceEndpoint>& new_ep = *new_it;
41
42 if (new_ep.is_value()) {
43 break;
44 }
45
46 // If they are equal, the element is in both |old_endpoints| and
47 // |new_endpoints|, so skip it in both vectors.
48 if (old_ep == new_ep) {
49 old_it++;
50 new_it++;
51 continue;
52 }
53
54 // There's an error in |old_endpoints| not in |new_endpoints|, so skip it.
55 if (old_ep < new_ep) {
56 old_it++;
57 continue;
58 }
59
60 // There's an error in |new_endpoints| not in |old_endpoints|, so it's a new
61 // error from the applied changes. Log it.
62 log(std::move(new_ep.error()));
63 new_it++;
64 }
65
66 // Skip all remaining errors in the old vector.
67 for (; old_it != old_endpoints->end() && old_it->is_error(); old_it++) {
68 }
69
70 // Log all errors remaining in the new vector.
71 for (; new_it != new_endpoints->end() && new_it->is_error(); new_it++) {
72 log(std::move(new_it->error()));
73 }
74
75 // Erase errors.
76 old_endpoints->erase(old_endpoints->begin(), old_it);
77 new_endpoints->erase(new_endpoints->begin(), new_it);
78 }
79
80 // Returns a vector containing the value of each ErrorOr<> instance provided.
81 // All ErrorOr<> values are expected to be non-errors.
GetValues(std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints)82 std::vector<DnsSdInstanceEndpoint> GetValues(
83 std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints) {
84 std::vector<DnsSdInstanceEndpoint> results;
85 results.reserve(endpoints.size());
86 for (ErrorOr<DnsSdInstanceEndpoint>& endpoint : endpoints) {
87 results.push_back(std::move(endpoint.value()));
88 }
89 return results;
90 }
91
IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint> & first,const absl::optional<DnsSdInstanceEndpoint> & second)92 bool IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
93 const absl::optional<DnsSdInstanceEndpoint>& second) {
94 if (!first.has_value() || !second.has_value()) {
95 return !first.has_value() && !second.has_value();
96 }
97
98 // In the remaining case, both |first| and |second| must be values.
99 const DnsSdInstanceEndpoint& a = first.value();
100 const DnsSdInstanceEndpoint& b = second.value();
101
102 // All endpoints from this querier should have the same network interface
103 // because the querier is only associated with a single network interface.
104 OSP_DCHECK_EQ(a.network_interface(), b.network_interface());
105
106 // Function returns true if first < second.
107 return a.instance_id() == b.instance_id() &&
108 a.service_id() == b.service_id() && a.domain_id() == b.domain_id();
109 }
110
IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint> & first,const absl::optional<DnsSdInstanceEndpoint> & second)111 bool IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
112 const absl::optional<DnsSdInstanceEndpoint>& second) {
113 return !IsEqualOrUpdate(first, second);
114 }
115
116 // Calculates the created, updated, and deleted elements using the provided
117 // sets, appending these values to the provided vectors. Each of the input
118 // vectors is expected to contain only elements such that
119 // |element|.is_error() == false. Additionally, input vectors are expected to
120 // be sorted in ascending order.
121 //
122 // NOTE: A lot of operations are used to do this, but each is only O(n) so the
123 // resulting algorithm is still fast.
CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints,std::vector<DnsSdInstanceEndpoint> new_endpoints,std::vector<DnsSdInstanceEndpoint> * created_out,std::vector<DnsSdInstanceEndpoint> * updated_out,std::vector<DnsSdInstanceEndpoint> * deleted_out)124 void CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints,
125 std::vector<DnsSdInstanceEndpoint> new_endpoints,
126 std::vector<DnsSdInstanceEndpoint>* created_out,
127 std::vector<DnsSdInstanceEndpoint>* updated_out,
128 std::vector<DnsSdInstanceEndpoint>* deleted_out) {
129 OSP_DCHECK(created_out);
130 OSP_DCHECK(updated_out);
131 OSP_DCHECK(deleted_out);
132
133 // Use set difference with default operators to find the elements present in
134 // one list but not the others.
135 //
136 // NOTE: Because absl::optional<...> types are used here and below, calls to
137 // the ctor and dtor for empty elements are no-ops.
138 const int total_count = old_endpoints.size() + new_endpoints.size();
139
140 // This is the set of elements that aren't in the old endpoints, meaning the
141 // old endpoint either didn't exist or had different TXT / Address / etc..
142 std::vector<absl::optional<DnsSdInstanceEndpoint>> created_or_updated(
143 total_count);
144 auto new_end = std::set_difference(new_endpoints.begin(), new_endpoints.end(),
145 old_endpoints.begin(), old_endpoints.end(),
146 created_or_updated.begin());
147 created_or_updated.erase(new_end, created_or_updated.end());
148
149 // This is the set of elements that are only in the old endpoints, similar to
150 // the above.
151 std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted_or_updated(
152 total_count);
153 new_end = std::set_difference(old_endpoints.begin(), old_endpoints.end(),
154 new_endpoints.begin(), new_endpoints.end(),
155 deleted_or_updated.begin());
156 deleted_or_updated.erase(new_end, deleted_or_updated.end());
157
158 // Next, find the elements which were updated.
159 const size_t max_count =
160 std::max(created_or_updated.size(), deleted_or_updated.size());
161 std::vector<absl::optional<DnsSdInstanceEndpoint>> updated(max_count);
162 new_end = std::set_intersection(
163 created_or_updated.begin(), created_or_updated.end(),
164 deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
165 IsNotEqualOrUpdate);
166 updated.erase(new_end, updated.end());
167
168 // Use the updated elements to find all created and deleted elements.
169 std::vector<absl::optional<DnsSdInstanceEndpoint>> created(
170 created_or_updated.size());
171 new_end = std::set_difference(
172 created_or_updated.begin(), created_or_updated.end(), updated.begin(),
173 updated.end(), created.begin(), IsNotEqualOrUpdate);
174 created.erase(new_end, created.end());
175
176 std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted(
177 deleted_or_updated.size());
178 new_end = std::set_difference(
179 deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
180 updated.end(), deleted.begin(), IsNotEqualOrUpdate);
181 deleted.erase(new_end, deleted.end());
182
183 // Return the calculated elements back to the caller in the output variables.
184 created_out->reserve(created.size());
185 for (absl::optional<DnsSdInstanceEndpoint>& endpoint : created) {
186 OSP_DCHECK(endpoint.has_value());
187 created_out->push_back(std::move(endpoint.value()));
188 }
189
190 updated_out->reserve(updated.size());
191 for (absl::optional<DnsSdInstanceEndpoint>& endpoint : updated) {
192 OSP_DCHECK(endpoint.has_value());
193 updated_out->push_back(std::move(endpoint.value()));
194 }
195
196 deleted_out->reserve(deleted.size());
197 for (absl::optional<DnsSdInstanceEndpoint>& endpoint : deleted) {
198 OSP_DCHECK(endpoint.has_value());
199 deleted_out->push_back(std::move(endpoint.value()));
200 }
201 }
202
203 } // namespace
204
QuerierImpl(MdnsService * mdns_querier,TaskRunner * task_runner,ReportingClient * reporting_client,const NetworkInterfaceConfig * network_config)205 QuerierImpl::QuerierImpl(MdnsService* mdns_querier,
206 TaskRunner* task_runner,
207 ReportingClient* reporting_client,
208 const NetworkInterfaceConfig* network_config)
209 : mdns_querier_(mdns_querier),
210 task_runner_(task_runner),
211 reporting_client_(reporting_client) {
212 OSP_DCHECK(mdns_querier_);
213 OSP_DCHECK(task_runner_);
214
215 OSP_DCHECK(network_config);
216 graph_ = DnsDataGraph::Create(network_config->network_interface());
217 }
218
219 QuerierImpl::~QuerierImpl() = default;
220
StartQuery(const std::string & service,Callback * callback)221 void QuerierImpl::StartQuery(const std::string& service, Callback* callback) {
222 OSP_DCHECK(callback);
223 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
224
225 OSP_DVLOG << "Starting DNS-SD query for service '" << service << "'";
226
227 // Start tracking the new callback
228 const ServiceKey key(service, kLocalDomain);
229 auto it = callback_map_.emplace(key, std::vector<Callback*>{}).first;
230 it->second.push_back(callback);
231
232 const DomainName domain = key.GetName();
233
234 // If the associated service isn't tracked yet, start tracking it and start
235 // queries for the relevant PTR records.
236 if (!graph_->IsTracked(domain)) {
237 std::function<void(const DomainName&)> mdns_query(
238 [this, &domain](const DomainName& changed_domain) {
239 OSP_DVLOG << "Starting mDNS query for '" << domain.ToString() << "'";
240 mdns_querier_->StartQuery(changed_domain, DnsType::kANY,
241 DnsClass::kANY, this);
242 });
243 graph_->StartTracking(domain, std::move(mdns_query));
244 return;
245 }
246
247 // Else, it's already being tracked so fire creation callbacks for any already
248 // found service instances.
249 const std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints =
250 graph_->CreateEndpoints(DnsDataGraph::DomainGroup::kPtr, domain);
251 for (const auto& endpoint : endpoints) {
252 if (endpoint.is_value()) {
253 callback->OnEndpointCreated(endpoint.value());
254 }
255 }
256 }
257
StopQuery(const std::string & service,Callback * callback)258 void QuerierImpl::StopQuery(const std::string& service, Callback* callback) {
259 OSP_DCHECK(callback);
260 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
261
262 OSP_DVLOG << "Stopping DNS-SD query for service '" << service << "'";
263
264 ServiceKey key(service, kLocalDomain);
265 const auto callbacks_it = callback_map_.find(key);
266 if (callbacks_it == callback_map_.end()) {
267 return;
268 }
269
270 std::vector<Callback*>& callbacks = callbacks_it->second;
271 const auto it = std::find(callbacks.begin(), callbacks.end(), callback);
272 if (it == callbacks.end()) {
273 return;
274 }
275
276 callbacks.erase(it);
277 if (callbacks.empty()) {
278 callback_map_.erase(callbacks_it);
279
280 ServiceKey key(service, kLocalDomain);
281 DomainName domain = key.GetName();
282
283 std::function<void(const DomainName&)> stop_mdns_query(
284 [this](const DomainName& changed_domain) {
285 OSP_DVLOG << "Stopping mDNS query for '" << changed_domain.ToString()
286 << "'";
287 mdns_querier_->StopQuery(changed_domain, DnsType::kANY,
288 DnsClass::kANY, this);
289 });
290 graph_->StopTracking(domain, std::move(stop_mdns_query));
291 }
292 }
293
IsQueryRunning(const std::string & service) const294 bool QuerierImpl::IsQueryRunning(const std::string& service) const {
295 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
296 const ServiceKey key(service, kLocalDomain);
297 return graph_->IsTracked(key.GetName());
298 }
299
ReinitializeQueries(const std::string & service)300 void QuerierImpl::ReinitializeQueries(const std::string& service) {
301 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
302
303 OSP_DVLOG << "Re-initializing query for service '" << service << "'";
304
305 const ServiceKey key(service, kLocalDomain);
306 const DomainName domain = key.GetName();
307
308 std::function<void(const DomainName&)> start_callback(
309 [this](const DomainName& domain) {
310 mdns_querier_->StartQuery(domain, DnsType::kANY, DnsClass::kANY, this);
311 });
312 std::function<void(const DomainName&)> stop_callback(
313 [this](const DomainName& domain) {
314 mdns_querier_->StopQuery(domain, DnsType::kANY, DnsClass::kANY, this);
315 });
316 graph_->StopTracking(domain, std::move(stop_callback));
317
318 // Restart top-level queries.
319 mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name);
320
321 graph_->StartTracking(domain, std::move(start_callback));
322 }
323
OnRecordChanged(const MdnsRecord & record,RecordChangedEvent event)324 std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged(
325 const MdnsRecord& record,
326 RecordChangedEvent event) {
327 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
328
329 OSP_DVLOG << "Record " << record.ToString()
330 << " has received change of type '" << event << "'";
331
332 std::function<void(Error)> log = [this](Error error) mutable {
333 reporting_client_->OnRecoverableError(
334 Error(Error::Code::kProcessReceivedRecordFailure));
335 };
336
337 // Get the details to use for calling CreateEndpoints(). Special case PTR
338 // records to optimize performance.
339 const DomainName& create_endpoints_domain =
340 record.dns_type() != DnsType::kPTR
341 ? record.name()
342 : absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
343 const DnsDataGraph::DomainGroup create_endpoints_group =
344 record.dns_type() != DnsType::kPTR
345 ? DnsDataGraph::GetDomainGroup(record)
346 : DnsDataGraph::DomainGroup::kSrvAndTxt;
347
348 // Get the current set of DnsSdInstanceEndpoints prior to this change. Special
349 // case PTR records to avoid iterating over unrelated child domains.
350 std::vector<ErrorOr<DnsSdInstanceEndpoint>> old_endpoints_or_errors =
351 graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
352
353 // Apply the changes, creating a list of all pending changes that should be
354 // applied afterwards.
355 ErrorOr<std::vector<PendingQueryChange>> pending_changes_or_error =
356 ApplyRecordChanges(record, event);
357 if (pending_changes_or_error.is_error()) {
358 OSP_DVLOG << "Failed to apply changes for " << record.dns_type()
359 << " record change of type " << event << " with error "
360 << pending_changes_or_error.error();
361 log(std::move(pending_changes_or_error.error()));
362 return {};
363 }
364 std::vector<PendingQueryChange>& pending_changes =
365 pending_changes_or_error.value();
366
367 // Get the new set of DnsSdInstanceEndpoints following this change.
368 std::vector<ErrorOr<DnsSdInstanceEndpoint>> new_endpoints_or_errors =
369 graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
370
371 // Return early if the resulting sets are equal. This will frequently be the
372 // case, especially when both sets are empty.
373 std::sort(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end());
374 std::sort(new_endpoints_or_errors.begin(), new_endpoints_or_errors.end());
375 if (old_endpoints_or_errors.size() == new_endpoints_or_errors.size() &&
376 std::equal(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end(),
377 new_endpoints_or_errors.begin())) {
378 return pending_changes;
379 }
380
381 // Log all errors and erase them.
382 ProcessErrors(&old_endpoints_or_errors, &new_endpoints_or_errors,
383 std::move(log));
384 const size_t old_endpoints_or_errors_count = old_endpoints_or_errors.size();
385 const size_t new_endpoints_or_errors_count = new_endpoints_or_errors.size();
386 std::vector<DnsSdInstanceEndpoint> old_endpoints =
387 GetValues(std::move(old_endpoints_or_errors));
388 std::vector<DnsSdInstanceEndpoint> new_endpoints =
389 GetValues(std::move(new_endpoints_or_errors));
390 OSP_DCHECK_EQ(old_endpoints.size(), old_endpoints_or_errors_count);
391 OSP_DCHECK_EQ(new_endpoints.size(), new_endpoints_or_errors_count);
392
393 // Calculate the changes and call callbacks.
394 //
395 // NOTE: As the input sets are expected to be small, the generated sets will
396 // also be small.
397 std::vector<DnsSdInstanceEndpoint> created;
398 std::vector<DnsSdInstanceEndpoint> updated;
399 std::vector<DnsSdInstanceEndpoint> deleted;
400 CalculateChangeSets(std::move(old_endpoints), std::move(new_endpoints),
401 &created, &updated, &deleted);
402
403 InvokeChangeCallbacks(std::move(created), std::move(updated),
404 std::move(deleted));
405 return pending_changes;
406 }
407
InvokeChangeCallbacks(std::vector<DnsSdInstanceEndpoint> created,std::vector<DnsSdInstanceEndpoint> updated,std::vector<DnsSdInstanceEndpoint> deleted)408 void QuerierImpl::InvokeChangeCallbacks(
409 std::vector<DnsSdInstanceEndpoint> created,
410 std::vector<DnsSdInstanceEndpoint> updated,
411 std::vector<DnsSdInstanceEndpoint> deleted) {
412 // Find an endpoint and use it to create the key, or return if there is none.
413 DnsSdInstanceEndpoint* some_endpoint;
414 if (!created.empty()) {
415 some_endpoint = &created.front();
416 } else if (!updated.empty()) {
417 some_endpoint = &updated.front();
418 } else if (!deleted.empty()) {
419 some_endpoint = &deleted.front();
420 } else {
421 return;
422 }
423 ServiceKey key(some_endpoint->service_id(), some_endpoint->domain_id());
424
425 // Find all callbacks.
426 auto it = callback_map_.find(key);
427 if (it == callback_map_.end()) {
428 return;
429 }
430
431 // Call relevant callbacks.
432 std::vector<Callback*>& callbacks = it->second;
433 for (Callback* callback : callbacks) {
434 for (const DnsSdInstanceEndpoint& endpoint : created) {
435 callback->OnEndpointCreated(endpoint);
436 }
437 for (const DnsSdInstanceEndpoint& endpoint : updated) {
438 callback->OnEndpointUpdated(endpoint);
439 }
440 for (const DnsSdInstanceEndpoint& endpoint : deleted) {
441 callback->OnEndpointDeleted(endpoint);
442 }
443 }
444 }
445
ApplyRecordChanges(const MdnsRecord & record,RecordChangedEvent event)446 ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::ApplyRecordChanges(
447 const MdnsRecord& record,
448 RecordChangedEvent event) {
449 std::vector<PendingQueryChange> pending_changes;
450 std::function<void(DomainName)> creation_callback(
451 [this, &pending_changes](DomainName domain) mutable {
452 pending_changes.push_back({std::move(domain), DnsType::kANY,
453 DnsClass::kANY, this,
454 PendingQueryChange::kStartQuery});
455 });
456 std::function<void(DomainName)> deletion_callback(
457 [this, &pending_changes](DomainName domain) mutable {
458 pending_changes.push_back({std::move(domain), DnsType::kANY,
459 DnsClass::kANY, this,
460 PendingQueryChange::kStopQuery});
461 });
462 Error result =
463 graph_->ApplyDataRecordChange(record, event, std::move(creation_callback),
464 std::move(deletion_callback));
465 if (!result.ok()) {
466 return result;
467 }
468
469 return pending_changes;
470 }
471
472 } // namespace discovery
473 } // namespace openscreen
474