• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/utility/local_discovery/service_discovery_message_handler.h"
6 
7 #include <algorithm>
8 
9 #include "base/lazy_instance.h"
10 #include "chrome/common/local_discovery/local_discovery_messages.h"
11 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
12 #include "content/public/utility/utility_thread.h"
13 #include "net/socket/socket_descriptor.h"
14 #include "net/udp/datagram_server_socket.h"
15 
16 namespace local_discovery {
17 
18 namespace {
19 
20 void ClosePlatformSocket(net::SocketDescriptor socket);
21 
22 // Sets socket factory used by |net::CreatePlatformSocket|. Implemetation
23 // keeps single socket that will be returned to the first call to
24 // |net::CreatePlatformSocket| during object lifetime.
25 class ScopedSocketFactory : public net::PlatformSocketFactory {
26  public:
ScopedSocketFactory(net::SocketDescriptor socket)27   explicit ScopedSocketFactory(net::SocketDescriptor socket) : socket_(socket) {
28     net::PlatformSocketFactory::SetInstance(this);
29   }
30 
~ScopedSocketFactory()31   virtual ~ScopedSocketFactory() {
32     net::PlatformSocketFactory::SetInstance(NULL);
33     ClosePlatformSocket(socket_);
34     socket_ = net::kInvalidSocket;
35   }
36 
CreateSocket(int family,int type,int protocol)37   virtual net::SocketDescriptor CreateSocket(int family, int type,
38                                              int protocol) OVERRIDE {
39     DCHECK_EQ(type, SOCK_DGRAM);
40     DCHECK(family == AF_INET || family == AF_INET6);
41     net::SocketDescriptor result = net::kInvalidSocket;
42     std::swap(result, socket_);
43     return result;
44   }
45 
46  private:
47   net::SocketDescriptor socket_;
48   DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory);
49 };
50 
51 struct SocketInfo {
SocketInfolocal_discovery::__anon25e7e6800111::SocketInfo52   SocketInfo(net::SocketDescriptor socket,
53              net::AddressFamily address_family,
54              uint32 interface_index)
55       : socket(socket),
56         address_family(address_family),
57         interface_index(interface_index) {
58   }
59   net::SocketDescriptor socket;
60   net::AddressFamily address_family;
61   uint32 interface_index;
62 };
63 
64 // Returns list of sockets preallocated before.
65 class PreCreatedMDnsSocketFactory : public net::MDnsSocketFactory {
66  public:
PreCreatedMDnsSocketFactory()67   PreCreatedMDnsSocketFactory() {}
~PreCreatedMDnsSocketFactory()68   virtual ~PreCreatedMDnsSocketFactory() {
69     // Not empty if process exits too fast, before starting mDns code. If
70     // happened, destructors may crash accessing destroyed global objects.
71     sockets_.weak_clear();
72   }
73 
74   // net::MDnsSocketFactory implementation:
CreateSockets(ScopedVector<net::DatagramServerSocket> * sockets)75   virtual void CreateSockets(
76       ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE {
77     sockets->swap(sockets_);
78     Reset();
79   }
80 
AddSocket(const SocketInfo & socket_info)81   void AddSocket(const SocketInfo& socket_info) {
82     // Takes ownership of socket_info.socket;
83     ScopedSocketFactory platform_factory(socket_info.socket);
84     scoped_ptr<net::DatagramServerSocket> socket(
85         net::CreateAndBindMDnsSocket(socket_info.address_family,
86                                      socket_info.interface_index));
87     if (socket) {
88       socket->DetachFromThread();
89       sockets_.push_back(socket.release());
90     }
91   }
92 
Reset()93   void Reset() {
94     sockets_.clear();
95   }
96 
97  private:
98   ScopedVector<net::DatagramServerSocket> sockets_;
99 
100   DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory);
101 };
102 
103 base::LazyInstance<PreCreatedMDnsSocketFactory>
104     g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER;
105 
106 #if defined(OS_WIN)
107 
ClosePlatformSocket(net::SocketDescriptor socket)108 void ClosePlatformSocket(net::SocketDescriptor socket) {
109   ::closesocket(socket);
110 }
111 
StaticInitializeSocketFactory()112 void StaticInitializeSocketFactory() {
113   net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind());
114   for (size_t i = 0; i < interfaces.size(); ++i) {
115     DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 ||
116            interfaces[i].second == net::ADDRESS_FAMILY_IPV6);
117     net::SocketDescriptor descriptor =
118         net::CreatePlatformSocket(
119             net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM,
120                                       IPPROTO_UDP);
121     g_local_discovery_socket_factory.Get().AddSocket(
122         SocketInfo(descriptor, interfaces[i].second, interfaces[i].first));
123   }
124 }
125 
126 #else  // OS_WIN
127 
ClosePlatformSocket(net::SocketDescriptor socket)128 void ClosePlatformSocket(net::SocketDescriptor socket) {
129   ::close(socket);
130 }
131 
// Non-Windows: nothing is opened before the sandbox. On POSIX the sockets are
// delivered later through the LocalDiscoveryMsg_SetSockets IPC.
void StaticInitializeSocketFactory() {}
134 
135 #endif  // OS_WIN
136 
SendHostMessageOnUtilityThread(IPC::Message * msg)137 void SendHostMessageOnUtilityThread(IPC::Message* msg) {
138   content::UtilityThread::Get()->Send(msg);
139 }
140 
WatcherUpdateToString(ServiceWatcher::UpdateType update)141 std::string WatcherUpdateToString(ServiceWatcher::UpdateType update) {
142   switch (update) {
143     case ServiceWatcher::UPDATE_ADDED:
144       return "UPDATE_ADDED";
145     case ServiceWatcher::UPDATE_CHANGED:
146       return "UPDATE_CHANGED";
147     case ServiceWatcher::UPDATE_REMOVED:
148       return "UPDATE_REMOVED";
149     case ServiceWatcher::UPDATE_INVALIDATED:
150       return "UPDATE_INVALIDATED";
151   }
152   return "Unknown Update";
153 }
154 
ResolverStatusToString(ServiceResolver::RequestStatus status)155 std::string ResolverStatusToString(ServiceResolver::RequestStatus status) {
156   switch (status) {
157     case ServiceResolver::STATUS_SUCCESS:
158       return "STATUS_SUCESS";
159     case ServiceResolver::STATUS_REQUEST_TIMEOUT:
160       return "STATUS_REQUEST_TIMEOUT";
161     case ServiceResolver::STATUS_KNOWN_NONEXISTENT:
162       return "STATUS_KNOWN_NONEXISTENT";
163   }
164   return "Unknown Status";
165 }
166 
167 }  // namespace
168 
ServiceDiscoveryMessageHandler()169 ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() {
170 }
171 
~ServiceDiscoveryMessageHandler()172 ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() {
173   DCHECK(!discovery_thread_);
174 }
175 
PreSandboxStartup()176 void ServiceDiscoveryMessageHandler::PreSandboxStartup() {
177   StaticInitializeSocketFactory();
178 }
179 
InitializeMdns()180 void ServiceDiscoveryMessageHandler::InitializeMdns() {
181   if (service_discovery_client_ || mdns_client_)
182     return;
183 
184   mdns_client_ = net::MDnsClient::CreateDefault();
185   bool result =
186       mdns_client_->StartListening(g_local_discovery_socket_factory.Pointer());
187   // Close unused sockets.
188   g_local_discovery_socket_factory.Get().Reset();
189   if (!result) {
190     VLOG(1) << "Failed to start MDnsClient";
191     Send(new LocalDiscoveryHostMsg_Error());
192     return;
193   }
194 
195   service_discovery_client_.reset(
196       new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get()));
197 }
198 
InitializeThread()199 bool ServiceDiscoveryMessageHandler::InitializeThread() {
200   if (discovery_task_runner_.get())
201     return true;
202   if (discovery_thread_)
203     return false;
204   utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy();
205   discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread"));
206   base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0);
207   if (discovery_thread_->StartWithOptions(thread_options)) {
208     discovery_task_runner_ = discovery_thread_->message_loop_proxy();
209     discovery_task_runner_->PostTask(FROM_HERE,
210         base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns,
211                     base::Unretained(this)));
212   }
213   return discovery_task_runner_.get() != NULL;
214 }
215 
OnMessageReceived(const IPC::Message & message)216 bool ServiceDiscoveryMessageHandler::OnMessageReceived(
217     const IPC::Message& message) {
218   bool handled = true;
219   IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message)
220 #if defined(OS_POSIX)
221     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets, OnSetSockets)
222 #endif  // OS_POSIX
223     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher)
224     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices)
225     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices,
226                         OnSetActivelyRefreshServices)
227     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher)
228     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService)
229     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver)
230     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain,
231                         OnResolveLocalDomain)
232     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver,
233                         OnDestroyLocalDomainResolver)
234     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery,
235                         ShutdownLocalDiscovery)
236     IPC_MESSAGE_UNHANDLED(handled = false)
237   IPC_END_MESSAGE_MAP()
238   return handled;
239 }
240 
PostTask(const tracked_objects::Location & from_here,const base::Closure & task)241 void ServiceDiscoveryMessageHandler::PostTask(
242     const tracked_objects::Location& from_here,
243     const base::Closure& task) {
244   if (!InitializeThread())
245     return;
246   discovery_task_runner_->PostTask(from_here, task);
247 }
248 
249 #if defined(OS_POSIX)
OnSetSockets(const std::vector<LocalDiscoveryMsg_SocketInfo> & sockets)250 void ServiceDiscoveryMessageHandler::OnSetSockets(
251     const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) {
252   for (size_t i = 0; i < sockets.size(); ++i) {
253     g_local_discovery_socket_factory.Get().AddSocket(
254         SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family,
255                    sockets[i].interface_index));
256   }
257 }
258 #endif  // OS_POSIX
259 
OnStartWatcher(uint64 id,const std::string & service_type)260 void ServiceDiscoveryMessageHandler::OnStartWatcher(
261     uint64 id,
262     const std::string& service_type) {
263   PostTask(FROM_HERE,
264            base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher,
265                       base::Unretained(this), id, service_type));
266 }
267 
OnDiscoverServices(uint64 id,bool force_update)268 void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id,
269                                                         bool force_update) {
270   PostTask(FROM_HERE,
271            base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices,
272                       base::Unretained(this), id, force_update));
273 }
274 
OnSetActivelyRefreshServices(uint64 id,bool actively_refresh_services)275 void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices(
276     uint64 id, bool actively_refresh_services) {
277   PostTask(FROM_HERE,
278            base::Bind(
279                &ServiceDiscoveryMessageHandler::SetActivelyRefreshServices,
280                base::Unretained(this), id, actively_refresh_services));
281 }
282 
OnDestroyWatcher(uint64 id)283 void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) {
284   PostTask(FROM_HERE,
285            base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher,
286                       base::Unretained(this), id));
287 }
288 
OnResolveService(uint64 id,const std::string & service_name)289 void ServiceDiscoveryMessageHandler::OnResolveService(
290     uint64 id,
291     const std::string& service_name) {
292   PostTask(FROM_HERE,
293            base::Bind(&ServiceDiscoveryMessageHandler::ResolveService,
294                       base::Unretained(this), id, service_name));
295 }
296 
OnDestroyResolver(uint64 id)297 void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) {
298   PostTask(FROM_HERE,
299            base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver,
300                       base::Unretained(this), id));
301 }
302 
OnResolveLocalDomain(uint64 id,const std::string & domain,net::AddressFamily address_family)303 void ServiceDiscoveryMessageHandler::OnResolveLocalDomain(
304     uint64 id, const std::string& domain,
305     net::AddressFamily address_family) {
306     PostTask(FROM_HERE,
307            base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain,
308                       base::Unretained(this), id, domain, address_family));
309 }
310 
OnDestroyLocalDomainResolver(uint64 id)311 void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) {
312   PostTask(FROM_HERE,
313            base::Bind(
314                &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver,
315                base::Unretained(this), id));
316 }
317 
StartWatcher(uint64 id,const std::string & service_type)318 void ServiceDiscoveryMessageHandler::StartWatcher(
319     uint64 id,
320     const std::string& service_type) {
321   VLOG(1) << "StartWatcher, id=" << id << ", type=" << service_type;
322   if (!service_discovery_client_)
323     return;
324   DCHECK(!ContainsKey(service_watchers_, id));
325   scoped_ptr<ServiceWatcher> watcher(
326       service_discovery_client_->CreateServiceWatcher(
327           service_type,
328           base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated,
329                      base::Unretained(this), id)));
330   watcher->Start();
331   service_watchers_[id].reset(watcher.release());
332 }
333 
DiscoverServices(uint64 id,bool force_update)334 void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id,
335                                                       bool force_update) {
336   VLOG(1) << "DiscoverServices, id=" << id;
337   if (!service_discovery_client_)
338     return;
339   DCHECK(ContainsKey(service_watchers_, id));
340   service_watchers_[id]->DiscoverNewServices(force_update);
341 }
342 
SetActivelyRefreshServices(uint64 id,bool actively_refresh_services)343 void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices(
344     uint64 id,
345     bool actively_refresh_services) {
346   VLOG(1) << "ActivelyRefreshServices, id=" << id;
347   if (!service_discovery_client_)
348     return;
349   DCHECK(ContainsKey(service_watchers_, id));
350   service_watchers_[id]->SetActivelyRefreshServices(actively_refresh_services);
351 }
352 
DestroyWatcher(uint64 id)353 void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) {
354   VLOG(1) << "DestoryWatcher, id=" << id;
355   if (!service_discovery_client_)
356     return;
357   service_watchers_.erase(id);
358 }
359 
ResolveService(uint64 id,const std::string & service_name)360 void ServiceDiscoveryMessageHandler::ResolveService(
361     uint64 id,
362     const std::string& service_name) {
363   VLOG(1) << "ResolveService, id=" << id << ", name=" << service_name;
364   if (!service_discovery_client_)
365     return;
366   DCHECK(!ContainsKey(service_resolvers_, id));
367   scoped_ptr<ServiceResolver> resolver(
368       service_discovery_client_->CreateServiceResolver(
369           service_name,
370           base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved,
371                      base::Unretained(this), id)));
372   resolver->StartResolving();
373   service_resolvers_[id].reset(resolver.release());
374 }
375 
DestroyResolver(uint64 id)376 void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) {
377   VLOG(1) << "DestroyResolver, id=" << id;
378   if (!service_discovery_client_)
379     return;
380   service_resolvers_.erase(id);
381 }
382 
ResolveLocalDomain(uint64 id,const std::string & domain,net::AddressFamily address_family)383 void ServiceDiscoveryMessageHandler::ResolveLocalDomain(
384     uint64 id,
385     const std::string& domain,
386     net::AddressFamily address_family) {
387   VLOG(1) << "ResolveLocalDomain, id=" << id << ", domain=" << domain;
388   if (!service_discovery_client_)
389     return;
390   DCHECK(!ContainsKey(local_domain_resolvers_, id));
391   scoped_ptr<LocalDomainResolver> resolver(
392       service_discovery_client_->CreateLocalDomainResolver(
393           domain, address_family,
394           base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved,
395                      base::Unretained(this), id)));
396   resolver->Start();
397   local_domain_resolvers_[id].reset(resolver.release());
398 }
399 
DestroyLocalDomainResolver(uint64 id)400 void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) {
401   VLOG(1) << "DestroyLocalDomainResolver, id=" << id;
402   if (!service_discovery_client_)
403     return;
404   local_domain_resolvers_.erase(id);
405 }
406 
ShutdownLocalDiscovery()407 void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() {
408   if (!discovery_task_runner_.get())
409     return;
410 
411   discovery_task_runner_->PostTask(
412       FROM_HERE,
413       base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread,
414                  base::Unretained(this)));
415 
416   // This will wait for message loop to drain, so ShutdownOnIOThread will
417   // definitely be called.
418   discovery_thread_.reset();
419 }
420 
ShutdownOnIOThread()421 void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() {
422   VLOG(1) << "ShutdownLocalDiscovery";
423   service_watchers_.clear();
424   service_resolvers_.clear();
425   local_domain_resolvers_.clear();
426   service_discovery_client_.reset();
427   mdns_client_.reset();
428 }
429 
OnServiceUpdated(uint64 id,ServiceWatcher::UpdateType update,const std::string & name)430 void ServiceDiscoveryMessageHandler::OnServiceUpdated(
431     uint64 id,
432     ServiceWatcher::UpdateType update,
433     const std::string& name) {
434   VLOG(1) << "OnServiceUpdated, id=" << id
435           << ", status=" << WatcherUpdateToString(update) << ", name=" << name;
436   DCHECK(service_discovery_client_);
437 
438   Send(new LocalDiscoveryHostMsg_WatcherCallback(id, update, name));
439 }
440 
OnServiceResolved(uint64 id,ServiceResolver::RequestStatus status,const ServiceDescription & description)441 void ServiceDiscoveryMessageHandler::OnServiceResolved(
442     uint64 id,
443     ServiceResolver::RequestStatus status,
444     const ServiceDescription& description) {
445   VLOG(1) << "OnServiceResolved, id=" << id
446           << ", status=" << ResolverStatusToString(status)
447           << ", name=" << description.service_name;
448 
449   DCHECK(service_discovery_client_);
450   Send(new LocalDiscoveryHostMsg_ResolverCallback(id, status, description));
451 }
452 
OnLocalDomainResolved(uint64 id,bool success,const net::IPAddressNumber & address_ipv4,const net::IPAddressNumber & address_ipv6)453 void ServiceDiscoveryMessageHandler::OnLocalDomainResolved(
454     uint64 id,
455     bool success,
456     const net::IPAddressNumber& address_ipv4,
457     const net::IPAddressNumber& address_ipv6) {
458   VLOG(1) << "OnLocalDomainResolved, id=" << id
459           << ", IPv4=" << (address_ipv4.empty() ? "" :
460                            net::IPAddressToString(address_ipv4))
461           << ", IPv6=" << (address_ipv6.empty() ? "" :
462                            net::IPAddressToString(address_ipv6));
463 
464   DCHECK(service_discovery_client_);
465   Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback(
466           id, success, address_ipv4, address_ipv6));
467 }
468 
Send(IPC::Message * msg)469 void ServiceDiscoveryMessageHandler::Send(IPC::Message* msg) {
470   utility_task_runner_->PostTask(FROM_HERE,
471                                  base::Bind(&SendHostMessageOnUtilityThread,
472                                             msg));
473 }
474 
475 }  // namespace local_discovery
476