• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/dnssd/impl/service_dispatcher.h"
6 
7 #include <utility>
8 
9 #include "discovery/common/config.h"
10 #include "discovery/dnssd/impl/service_instance.h"
11 #include "discovery/dnssd/public/dns_sd_instance.h"
12 #include "discovery/mdns/public/mdns_service.h"
13 #include "platform/api/serial_delete_ptr.h"
14 #include "platform/api/task_runner.h"
15 #include "util/trace_logging.h"
16 
17 namespace openscreen {
18 namespace discovery {
19 namespace {
20 
ForAllQueriers(std::vector<std::unique_ptr<ServiceInstance>> * service_instances,std::function<void (DnsSdQuerier *)> action)21 void ForAllQueriers(
22     std::vector<std::unique_ptr<ServiceInstance>>* service_instances,
23     std::function<void(DnsSdQuerier*)> action) {
24   for (auto& service_instance : *service_instances) {
25     auto* querier = service_instance->GetQuerier();
26     OSP_CHECK(querier);
27 
28     action(querier);
29   }
30 }
31 
ForAllPublishers(std::vector<std::unique_ptr<ServiceInstance>> * service_instances,std::function<Error (DnsSdPublisher *)> action,const char * operation)32 Error ForAllPublishers(
33     std::vector<std::unique_ptr<ServiceInstance>>* service_instances,
34     std::function<Error(DnsSdPublisher*)> action,
35     const char* operation) {
36   Error result = Error::None();
37   for (auto& service_instance : *service_instances) {
38     auto* publisher = service_instance->GetPublisher();
39     OSP_CHECK(publisher);
40 
41     TRACE_SCOPED(TraceCategory::kDiscovery, operation);
42     Error inner_result = action(publisher);
43     TRACE_SET_RESULT(inner_result);
44     if (!inner_result.ok()) {
45       result = std::move(inner_result);
46     }
47   }
48   return result;
49 }
50 
51 }  // namespace
52 
53 // static
CreateDnsSdService(TaskRunner * task_runner,ReportingClient * reporting_client,const Config & config)54 SerialDeletePtr<DnsSdService> CreateDnsSdService(
55     TaskRunner* task_runner,
56     ReportingClient* reporting_client,
57     const Config& config) {
58   return SerialDeletePtr<DnsSdService>(
59       task_runner,
60       new ServiceDispatcher(task_runner, reporting_client, config));
61 }
62 
ServiceDispatcher(TaskRunner * task_runner,ReportingClient * reporting_client,const Config & config)63 ServiceDispatcher::ServiceDispatcher(TaskRunner* task_runner,
64                                      ReportingClient* reporting_client,
65                                      const Config& config)
66     : task_runner_(task_runner),
67       publisher_(config.enable_publication ? this : nullptr),
68       querier_(config.enable_querying ? this : nullptr) {
69   OSP_DCHECK_GT(config.network_info.size(), 0);
70   OSP_DCHECK(task_runner);
71 
72   service_instances_.reserve(config.network_info.size());
73   for (const auto& network_info : config.network_info) {
74     service_instances_.push_back(std::make_unique<ServiceInstance>(
75         task_runner_, reporting_client, config, network_info));
76   }
77 }
78 
~ServiceDispatcher()79 ServiceDispatcher::~ServiceDispatcher() {
80   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
81 }
82 
83 // DnsSdQuerier overrides.
StartQuery(const std::string & service,Callback * cb)84 void ServiceDispatcher::StartQuery(const std::string& service, Callback* cb) {
85   TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
86   auto start_query = [&service, cb](DnsSdQuerier* querier) {
87     querier->StartQuery(service, cb);
88   };
89   ForAllQueriers(&service_instances_, std::move(start_query));
90 }
91 
StopQuery(const std::string & service,Callback * cb)92 void ServiceDispatcher::StopQuery(const std::string& service, Callback* cb) {
93   TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
94   auto stop_query = [&service, cb](DnsSdQuerier* querier) {
95     querier->StopQuery(service, cb);
96   };
97   ForAllQueriers(&service_instances_, std::move(stop_query));
98 }
99 
ReinitializeQueries(const std::string & service)100 void ServiceDispatcher::ReinitializeQueries(const std::string& service) {
101   TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
102   auto reinitialize_queries = [&service](DnsSdQuerier* querier) {
103     querier->ReinitializeQueries(service);
104   };
105   ForAllQueriers(&service_instances_, std::move(reinitialize_queries));
106 }
107 
108 // DnsSdPublisher overrides.
Register(const DnsSdInstance & instance,Client * client)109 Error ServiceDispatcher::Register(const DnsSdInstance& instance,
110                                   Client* client) {
111   TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
112   auto register_instance = [&instance, client](DnsSdPublisher* publisher) {
113     return publisher->Register(instance, client);
114   };
115   return ForAllPublishers(&service_instances_, std::move(register_instance),
116                           "DNS-SD.Register");
117 }
118 
UpdateRegistration(const DnsSdInstance & instance)119 Error ServiceDispatcher::UpdateRegistration(const DnsSdInstance& instance) {
120   TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
121   auto update_registration = [&instance](DnsSdPublisher* publisher) {
122     return publisher->UpdateRegistration(instance);
123   };
124   return ForAllPublishers(&service_instances_, std::move(update_registration),
125                           "DNS-SD.UpdateRegistration");
126 }
127 
DeregisterAll(const std::string & service)128 ErrorOr<int> ServiceDispatcher::DeregisterAll(const std::string& service) {
129   TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
130   int total = 0;
131   Error failure = Error::None();
132   for (auto& service_instance : service_instances_) {
133     auto* publisher = service_instance->GetPublisher();
134     OSP_CHECK(publisher);
135 
136     TRACE_SCOPED(TraceCategory::kDiscovery, "DNS-SD.DeregisterAll");
137     auto result = publisher->DeregisterAll(service);
138     if (result.is_error()) {
139       TRACE_SET_RESULT(result.error());
140       failure = std::move(result.error());
141     } else {
142       total += result.value();
143     }
144   }
145 
146   if (!failure.ok()) {
147     return failure;
148   } else {
149     return total;
150   }
151 }
152 
153 }  // namespace discovery
154 }  // namespace openscreen
155