1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/dnssd/impl/publisher_impl.h"
6
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>
11
12 #include "absl/types/optional.h"
13 #include "discovery/common/reporting_client.h"
14 #include "discovery/dnssd/impl/conversion_layer.h"
15 #include "discovery/dnssd/impl/instance_key.h"
16 #include "discovery/dnssd/impl/network_interface_config.h"
17 #include "discovery/mdns/public/mdns_constants.h"
18 #include "platform/api/task_runner.h"
19 #include "platform/base/error.h"
20 #include "util/trace_logging.h"
21
22 namespace openscreen {
23 namespace discovery {
24 namespace {
25
CreateEndpoint(DnsSdInstance instance,InstanceKey key,const NetworkInterfaceConfig & network_config)26 DnsSdInstanceEndpoint CreateEndpoint(
27 DnsSdInstance instance,
28 InstanceKey key,
29 const NetworkInterfaceConfig& network_config) {
30 std::vector<IPEndpoint> endpoints;
31 if (network_config.HasAddressV4()) {
32 endpoints.push_back({network_config.address_v4(), instance.port()});
33 }
34 if (network_config.HasAddressV6()) {
35 endpoints.push_back({network_config.address_v6(), instance.port()});
36 }
37 return DnsSdInstanceEndpoint(
38 key.instance_id(), key.service_id(), key.domain_id(), instance.txt(),
39 network_config.network_interface(), std::move(endpoints));
40 }
41
UpdateDomain(const DomainName & name,DnsSdInstance instance,const NetworkInterfaceConfig & network_config)42 DnsSdInstanceEndpoint UpdateDomain(
43 const DomainName& name,
44 DnsSdInstance instance,
45 const NetworkInterfaceConfig& network_config) {
46 return CreateEndpoint(std::move(instance), InstanceKey(name), network_config);
47 }
48
CreateEndpoint(DnsSdInstance instance,const NetworkInterfaceConfig & network_config)49 DnsSdInstanceEndpoint CreateEndpoint(
50 DnsSdInstance instance,
51 const NetworkInterfaceConfig& network_config) {
52 InstanceKey key(instance);
53 return CreateEndpoint(std::move(instance), std::move(key), network_config);
54 }
55
56 template <typename T>
FindKey(std::map<DnsSdInstance,T> * instances,const InstanceKey & key)57 inline typename std::map<DnsSdInstance, T>::iterator FindKey(
58 std::map<DnsSdInstance, T>* instances,
59 const InstanceKey& key) {
60 return std::find_if(instances->begin(), instances->end(),
61 [&key](const std::pair<DnsSdInstance, T>& pair) {
62 return key == InstanceKey(pair.first);
63 });
64 }
65
66 template <typename T>
EraseInstancesWithServiceId(std::map<DnsSdInstance,T> * instances,const std::string & service_id)67 int EraseInstancesWithServiceId(std::map<DnsSdInstance, T>* instances,
68 const std::string& service_id) {
69 int removed_count = 0;
70 for (auto it = instances->begin(); it != instances->end();) {
71 if (it->first.service_id() == service_id) {
72 removed_count++;
73 it = instances->erase(it);
74 } else {
75 it++;
76 }
77 }
78
79 return removed_count;
80 }
81
82 } // namespace
83
PublisherImpl(MdnsService * publisher,ReportingClient * reporting_client,TaskRunner * task_runner,const NetworkInterfaceConfig * network_config)84 PublisherImpl::PublisherImpl(MdnsService* publisher,
85 ReportingClient* reporting_client,
86 TaskRunner* task_runner,
87 const NetworkInterfaceConfig* network_config)
88 : mdns_publisher_(publisher),
89 reporting_client_(reporting_client),
90 task_runner_(task_runner),
91 network_config_(network_config) {
92 OSP_DCHECK(mdns_publisher_);
93 OSP_DCHECK(reporting_client_);
94 OSP_DCHECK(task_runner_);
95 }
96
97 PublisherImpl::~PublisherImpl() = default;
98
Register(const DnsSdInstance & instance,Client * client)99 Error PublisherImpl::Register(const DnsSdInstance& instance, Client* client) {
100 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
101 OSP_DCHECK(client != nullptr);
102
103 if (published_instances_.find(instance) != published_instances_.end()) {
104 UpdateRegistration(instance);
105 } else if (pending_instances_.find(instance) != pending_instances_.end()) {
106 return Error::Code::kOperationInProgress;
107 }
108
109 InstanceKey key(instance);
110 const IPAddress& address = network_config_->GetAddress();
111 OSP_DCHECK(address);
112 pending_instances_.emplace(CreateEndpoint(instance, *network_config_),
113 client);
114
115 OSP_DVLOG << "Registering instance '" << instance.instance_id() << "'";
116
117 return mdns_publisher_->StartProbe(this, GetDomainName(key), address);
118 }
119
UpdateRegistration(const DnsSdInstance & instance)120 Error PublisherImpl::UpdateRegistration(const DnsSdInstance& instance) {
121 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
122
123 // Check if the instance is still pending publication.
124 auto it = FindKey(&pending_instances_, InstanceKey(instance));
125
126 OSP_DVLOG << "Updating instance '" << instance.instance_id() << "'";
127
128 // If it is a pending instance, update it. Else, try to update a published
129 // instance.
130 if (it != pending_instances_.end()) {
131 // The instance, service, and domain ids have not changed, so only the
132 // remaining data needs to change. The ongoing probe does not need to be
133 // modified.
134 Client* const client = it->second;
135 pending_instances_.erase(it);
136 pending_instances_.emplace(CreateEndpoint(instance, *network_config_),
137 client);
138 return Error::None();
139 } else {
140 return UpdatePublishedRegistration(instance);
141 }
142 }
143
UpdatePublishedRegistration(const DnsSdInstance & instance)144 Error PublisherImpl::UpdatePublishedRegistration(
145 const DnsSdInstance& instance) {
146 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
147
148 auto published_instance_it =
149 FindKey(&published_instances_, InstanceKey(instance));
150
151 // Check preconditions called out in header. Specifically, the updated
152 // instance must be making changes to an already published instance.
153 if (published_instance_it == published_instances_.end()) {
154 return Error::Code::kParameterInvalid;
155 }
156
157 const DnsSdInstanceEndpoint updated_endpoint =
158 UpdateDomain(GetDomainName(InstanceKey(published_instance_it->second)),
159 instance, *network_config_);
160 if (published_instance_it->second == updated_endpoint) {
161 return Error::Code::kParameterInvalid;
162 }
163
164 // Get all instances which have changed. By design, there an only be one
165 // instance of each DnsType, so use that here to simplify this step. First in
166 // each pair is the old instances, second is the new instance.
167 std::map<DnsType,
168 std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>>>
169 changed_records;
170 const std::vector<MdnsRecord> old_records =
171 GetDnsRecords(published_instance_it->second);
172 const std::vector<MdnsRecord> new_records = GetDnsRecords(updated_endpoint);
173
174 // Populate the first part of each pair in |changed_instances|.
175 for (size_t i = 0; i < old_records.size(); i++) {
176 const auto key = old_records[i].dns_type();
177 OSP_DCHECK(changed_records.find(key) == changed_records.end());
178 auto value = std::make_pair(std::move(old_records[i]), absl::nullopt);
179 changed_records.emplace(key, std::move(value));
180 }
181
182 // Populate the second part of each pair in |changed_records|.
183 for (size_t i = 0; i < new_records.size(); i++) {
184 const auto key = new_records[i].dns_type();
185 auto find_it = changed_records.find(key);
186 if (find_it == changed_records.end()) {
187 std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>> value(
188 absl::nullopt, std::move(new_records[i]));
189 changed_records.emplace(key, std::move(value));
190 } else {
191 find_it->second.second = std::move(new_records[i]);
192 }
193 }
194
195 // Apply changes called out in |changed_records|.
196 Error total_result = Error::None();
197 for (const auto& pair : changed_records) {
198 OSP_DCHECK(pair.second.first != absl::nullopt ||
199 pair.second.second != absl::nullopt);
200 if (pair.second.first == absl::nullopt) {
201 TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.RegisterRecord");
202 auto error = mdns_publisher_->RegisterRecord(pair.second.second.value());
203 TRACE_SET_RESULT(error);
204 if (!error.ok()) {
205 total_result = error;
206 }
207 } else if (pair.second.second == absl::nullopt) {
208 TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UnregisterRecord");
209 auto error = mdns_publisher_->UnregisterRecord(pair.second.first.value());
210 TRACE_SET_RESULT(error);
211 if (!error.ok()) {
212 total_result = error;
213 }
214 } else if (pair.second.first.value() != pair.second.second.value()) {
215 TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UpdateRegisteredRecord");
216 auto error = mdns_publisher_->UpdateRegisteredRecord(
217 pair.second.first.value(), pair.second.second.value());
218 TRACE_SET_RESULT(error);
219 if (!error.ok()) {
220 total_result = error;
221 }
222 }
223 }
224
225 // Replace the old instances with the new ones.
226 published_instances_.erase(published_instance_it);
227 published_instances_.emplace(instance, std::move(updated_endpoint));
228
229 return total_result;
230 }
231
DeregisterAll(const std::string & service)232 ErrorOr<int> PublisherImpl::DeregisterAll(const std::string& service) {
233 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
234
235 OSP_DVLOG << "Deregistering all instances";
236
237 int removed_count = 0;
238 Error error = Error::None();
239 for (auto it = published_instances_.begin();
240 it != published_instances_.end();) {
241 if (it->second.service_id() == service) {
242 for (const auto& mdns_record : GetDnsRecords(it->second)) {
243 TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UnregisterRecord");
244 auto publisher_error = mdns_publisher_->UnregisterRecord(mdns_record);
245 TRACE_SET_RESULT(error);
246 if (!publisher_error.ok()) {
247 error = publisher_error;
248 }
249 }
250 removed_count++;
251 it = published_instances_.erase(it);
252 } else {
253 it++;
254 }
255 }
256
257 removed_count += EraseInstancesWithServiceId(&pending_instances_, service);
258
259 if (!error.ok()) {
260 return error;
261 } else {
262 return removed_count;
263 }
264 }
265
OnDomainFound(const DomainName & requested_name,const DomainName & confirmed_name)266 void PublisherImpl::OnDomainFound(const DomainName& requested_name,
267 const DomainName& confirmed_name) {
268 TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
269 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
270
271 OSP_DVLOG << "Domain successfully claimed: '" << confirmed_name.ToString()
272 << "' based on requested name: '" << requested_name.ToString()
273 << "'";
274
275 auto it = FindKey(&pending_instances_, InstanceKey(requested_name));
276
277 if (it == pending_instances_.end()) {
278 // This will be hit if the instance was deregister'd before the probe phase
279 // was completed.
280 return;
281 }
282
283 DnsSdInstance requested_instance = std::move(it->first);
284 DnsSdInstanceEndpoint endpoint =
285 CreateEndpoint(requested_instance, *network_config_);
286 Client* const client = it->second;
287 pending_instances_.erase(it);
288
289 InstanceKey requested_key(requested_instance);
290
291 if (requested_name != confirmed_name) {
292 OSP_DCHECK(HasValidDnsRecordAddress(confirmed_name));
293 endpoint =
294 UpdateDomain(confirmed_name, requested_instance, *network_config_);
295 }
296
297 for (const auto& mdns_record : GetDnsRecords(endpoint)) {
298 TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.RegisterRecord");
299 Error result = mdns_publisher_->RegisterRecord(mdns_record);
300 if (!result.ok()) {
301 reporting_client_->OnRecoverableError(
302 Error(Error::Code::kRecordPublicationError, result.ToString()));
303 }
304 }
305
306 auto pair = published_instances_.emplace(std::move(requested_instance),
307 std::move(endpoint));
308 client->OnEndpointClaimed(pair.first->first, pair.first->second);
309 }
310
311 } // namespace discovery
312 } // namespace openscreen
313