1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/dnssd/impl/service_dispatcher.h"
6
7 #include <utility>
8
9 #include "discovery/common/config.h"
10 #include "discovery/dnssd/impl/service_instance.h"
11 #include "discovery/dnssd/public/dns_sd_instance.h"
12 #include "discovery/mdns/public/mdns_service.h"
13 #include "platform/api/serial_delete_ptr.h"
14 #include "platform/api/task_runner.h"
15 #include "util/trace_logging.h"
16
17 namespace openscreen {
18 namespace discovery {
19 namespace {
20
ForAllQueriers(std::vector<std::unique_ptr<ServiceInstance>> * service_instances,std::function<void (DnsSdQuerier *)> action)21 void ForAllQueriers(
22 std::vector<std::unique_ptr<ServiceInstance>>* service_instances,
23 std::function<void(DnsSdQuerier*)> action) {
24 for (auto& service_instance : *service_instances) {
25 auto* querier = service_instance->GetQuerier();
26 OSP_CHECK(querier);
27
28 action(querier);
29 }
30 }
31
ForAllPublishers(std::vector<std::unique_ptr<ServiceInstance>> * service_instances,std::function<Error (DnsSdPublisher *)> action,const char * operation)32 Error ForAllPublishers(
33 std::vector<std::unique_ptr<ServiceInstance>>* service_instances,
34 std::function<Error(DnsSdPublisher*)> action,
35 const char* operation) {
36 Error result = Error::None();
37 for (auto& service_instance : *service_instances) {
38 auto* publisher = service_instance->GetPublisher();
39 OSP_CHECK(publisher);
40
41 TRACE_SCOPED(TraceCategory::kDiscovery, operation);
42 Error inner_result = action(publisher);
43 TRACE_SET_RESULT(inner_result);
44 if (!inner_result.ok()) {
45 result = std::move(inner_result);
46 }
47 }
48 return result;
49 }
50
51 } // namespace
52
// Free (non-member) factory function; it is not a static class member.
CreateDnsSdService(TaskRunner * task_runner,ReportingClient * reporting_client,const Config & config)54 SerialDeletePtr<DnsSdService> CreateDnsSdService(
55 TaskRunner* task_runner,
56 ReportingClient* reporting_client,
57 const Config& config) {
58 return SerialDeletePtr<DnsSdService>(
59 task_runner,
60 new ServiceDispatcher(task_runner, reporting_client, config));
61 }
62
ServiceDispatcher(TaskRunner * task_runner,ReportingClient * reporting_client,const Config & config)63 ServiceDispatcher::ServiceDispatcher(TaskRunner* task_runner,
64 ReportingClient* reporting_client,
65 const Config& config)
66 : task_runner_(task_runner),
67 publisher_(config.enable_publication ? this : nullptr),
68 querier_(config.enable_querying ? this : nullptr) {
69 OSP_DCHECK_GT(config.network_info.size(), 0);
70 OSP_DCHECK(task_runner);
71
72 service_instances_.reserve(config.network_info.size());
73 for (const auto& network_info : config.network_info) {
74 service_instances_.push_back(std::make_unique<ServiceInstance>(
75 task_runner_, reporting_client, config, network_info));
76 }
77 }
78
~ServiceDispatcher()79 ServiceDispatcher::~ServiceDispatcher() {
80 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
81 }
82
83 // DnsSdQuerier overrides.
StartQuery(const std::string & service,Callback * cb)84 void ServiceDispatcher::StartQuery(const std::string& service, Callback* cb) {
85 TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
86 auto start_query = [&service, cb](DnsSdQuerier* querier) {
87 querier->StartQuery(service, cb);
88 };
89 ForAllQueriers(&service_instances_, std::move(start_query));
90 }
91
StopQuery(const std::string & service,Callback * cb)92 void ServiceDispatcher::StopQuery(const std::string& service, Callback* cb) {
93 TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
94 auto stop_query = [&service, cb](DnsSdQuerier* querier) {
95 querier->StopQuery(service, cb);
96 };
97 ForAllQueriers(&service_instances_, std::move(stop_query));
98 }
99
ReinitializeQueries(const std::string & service)100 void ServiceDispatcher::ReinitializeQueries(const std::string& service) {
101 TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
102 auto reinitialize_queries = [&service](DnsSdQuerier* querier) {
103 querier->ReinitializeQueries(service);
104 };
105 ForAllQueriers(&service_instances_, std::move(reinitialize_queries));
106 }
107
108 // DnsSdPublisher overrides.
Register(const DnsSdInstance & instance,Client * client)109 Error ServiceDispatcher::Register(const DnsSdInstance& instance,
110 Client* client) {
111 TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
112 auto register_instance = [&instance, client](DnsSdPublisher* publisher) {
113 return publisher->Register(instance, client);
114 };
115 return ForAllPublishers(&service_instances_, std::move(register_instance),
116 "DNS-SD.Register");
117 }
118
UpdateRegistration(const DnsSdInstance & instance)119 Error ServiceDispatcher::UpdateRegistration(const DnsSdInstance& instance) {
120 TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
121 auto update_registration = [&instance](DnsSdPublisher* publisher) {
122 return publisher->UpdateRegistration(instance);
123 };
124 return ForAllPublishers(&service_instances_, std::move(update_registration),
125 "DNS-SD.UpdateRegistration");
126 }
127
DeregisterAll(const std::string & service)128 ErrorOr<int> ServiceDispatcher::DeregisterAll(const std::string& service) {
129 TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
130 int total = 0;
131 Error failure = Error::None();
132 for (auto& service_instance : service_instances_) {
133 auto* publisher = service_instance->GetPublisher();
134 OSP_CHECK(publisher);
135
136 TRACE_SCOPED(TraceCategory::kDiscovery, "DNS-SD.DeregisterAll");
137 auto result = publisher->DeregisterAll(service);
138 if (result.is_error()) {
139 TRACE_SET_RESULT(result.error());
140 failure = std::move(result.error());
141 } else {
142 total += result.value();
143 }
144 }
145
146 if (!failure.ok()) {
147 return failure;
148 } else {
149 return total;
150 }
151 }
152
153 } // namespace discovery
154 } // namespace openscreen
155