• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/dns_socket_pool.h"
6 
7 #include "base/logging.h"
8 #include "base/rand_util.h"
9 #include "base/stl_util.h"
10 #include "net/base/address_list.h"
11 #include "net/base/ip_endpoint.h"
12 #include "net/base/net_errors.h"
13 #include "net/base/rand_callback.h"
14 #include "net/socket/client_socket_factory.h"
15 #include "net/socket/stream_socket.h"
16 #include "net/udp/datagram_client_socket.h"
17 
18 namespace net {
19 
20 namespace {
21 
22 // When we initialize the SocketPool, we allocate kInitialPoolSize sockets.
23 // When we allocate a socket, we ensure we have at least kAllocateMinSize
24 // sockets to choose from.  Freed sockets are not retained.
25 
26 // On Windows, we can't request specific (random) ports, since that will
27 // trigger firewall prompts, so request default ones, but keep a pile of
28 // them.  Everywhere else, request fresh, random ports each time.
29 #if defined(OS_WIN)
30 const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND;
31 const unsigned kInitialPoolSize = 256;
32 const unsigned kAllocateMinSize = 256;
33 #else
34 const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND;
35 const unsigned kInitialPoolSize = 0;
36 const unsigned kAllocateMinSize = 1;
37 #endif
38 
39 } // namespace
40 
DnsSocketPool(ClientSocketFactory * socket_factory)41 DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory)
42     : socket_factory_(socket_factory),
43       net_log_(NULL),
44       nameservers_(NULL),
45       initialized_(false) {
46 }
47 
InitializeInternal(const std::vector<IPEndPoint> * nameservers,NetLog * net_log)48 void DnsSocketPool::InitializeInternal(
49     const std::vector<IPEndPoint>* nameservers,
50     NetLog* net_log) {
51   DCHECK(nameservers);
52   DCHECK(!initialized_);
53 
54   net_log_ = net_log;
55   nameservers_ = nameservers;
56   initialized_ = true;
57 }
58 
CreateTCPSocket(unsigned server_index,const NetLog::Source & source)59 scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket(
60     unsigned server_index,
61     const NetLog::Source& source) {
62   DCHECK_LT(server_index, nameservers_->size());
63 
64   return scoped_ptr<StreamSocket>(
65       socket_factory_->CreateTransportClientSocket(
66           AddressList((*nameservers_)[server_index]), net_log_, source));
67 }
68 
CreateConnectedSocket(unsigned server_index)69 scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket(
70     unsigned server_index) {
71   DCHECK_LT(server_index, nameservers_->size());
72 
73   scoped_ptr<DatagramClientSocket> socket;
74 
75   NetLog::Source no_source;
76   socket = socket_factory_->CreateDatagramClientSocket(
77       kBindType, base::Bind(&base::RandInt), net_log_, no_source);
78 
79   if (socket.get()) {
80     int rv = socket->Connect((*nameservers_)[server_index]);
81     if (rv != OK) {
82       VLOG(1) << "Failed to connect socket: " << rv;
83       socket.reset();
84     }
85   } else {
86     LOG(WARNING) << "Failed to create socket.";
87   }
88 
89   return socket.Pass();
90 }
91 
92 class NullDnsSocketPool : public DnsSocketPool {
93  public:
NullDnsSocketPool(ClientSocketFactory * factory)94   NullDnsSocketPool(ClientSocketFactory* factory)
95      : DnsSocketPool(factory) {
96   }
97 
Initialize(const std::vector<IPEndPoint> * nameservers,NetLog * net_log)98   virtual void Initialize(
99       const std::vector<IPEndPoint>* nameservers,
100       NetLog* net_log) OVERRIDE {
101     InitializeInternal(nameservers, net_log);
102   }
103 
AllocateSocket(unsigned server_index)104   virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
105       unsigned server_index) OVERRIDE {
106     return CreateConnectedSocket(server_index);
107   }
108 
FreeSocket(unsigned server_index,scoped_ptr<DatagramClientSocket> socket)109   virtual void FreeSocket(
110       unsigned server_index,
111       scoped_ptr<DatagramClientSocket> socket) OVERRIDE {
112   }
113 
114  private:
115   DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool);
116 };
117 
118 // static
CreateNull(ClientSocketFactory * factory)119 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull(
120     ClientSocketFactory* factory) {
121   return scoped_ptr<DnsSocketPool>(new NullDnsSocketPool(factory));
122 }
123 
124 class DefaultDnsSocketPool : public DnsSocketPool {
125  public:
DefaultDnsSocketPool(ClientSocketFactory * factory)126   DefaultDnsSocketPool(ClientSocketFactory* factory)
127      : DnsSocketPool(factory) {
128   };
129 
130   virtual ~DefaultDnsSocketPool();
131 
132   virtual void Initialize(
133       const std::vector<IPEndPoint>* nameservers,
134       NetLog* net_log) OVERRIDE;
135 
136   virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
137       unsigned server_index) OVERRIDE;
138 
139   virtual void FreeSocket(
140       unsigned server_index,
141       scoped_ptr<DatagramClientSocket> socket) OVERRIDE;
142 
143  private:
144   void FillPool(unsigned server_index, unsigned size);
145 
146   typedef std::vector<DatagramClientSocket*> SocketVector;
147 
148   std::vector<SocketVector> pools_;
149 
150   DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool);
151 };
152 
153 // static
CreateDefault(ClientSocketFactory * factory)154 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault(
155     ClientSocketFactory* factory) {
156   return scoped_ptr<DnsSocketPool>(new DefaultDnsSocketPool(factory));
157 }
158 
Initialize(const std::vector<IPEndPoint> * nameservers,NetLog * net_log)159 void DefaultDnsSocketPool::Initialize(
160     const std::vector<IPEndPoint>* nameservers,
161     NetLog* net_log) {
162   InitializeInternal(nameservers, net_log);
163 
164   DCHECK(pools_.empty());
165   const unsigned num_servers = nameservers->size();
166   pools_.resize(num_servers);
167   for (unsigned server_index = 0; server_index < num_servers; ++server_index)
168     FillPool(server_index, kInitialPoolSize);
169 }
170 
~DefaultDnsSocketPool()171 DefaultDnsSocketPool::~DefaultDnsSocketPool() {
172   unsigned num_servers = pools_.size();
173   for (unsigned server_index = 0; server_index < num_servers; ++server_index) {
174     SocketVector& pool = pools_[server_index];
175     STLDeleteElements(&pool);
176   }
177 }
178 
AllocateSocket(unsigned server_index)179 scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket(
180     unsigned server_index) {
181   DCHECK_LT(server_index, pools_.size());
182   SocketVector& pool = pools_[server_index];
183 
184   FillPool(server_index, kAllocateMinSize);
185   if (pool.size() == 0) {
186     LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!";
187     return scoped_ptr<DatagramClientSocket>();
188   }
189 
190   if (pool.size() < kAllocateMinSize) {
191     LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize
192                  << " sockets to choose from, but only have " << pool.size()
193                  << " in pool " << server_index << ".";
194   }
195 
196   unsigned socket_index = base::RandInt(0, pool.size() - 1);
197   DatagramClientSocket* socket = pool[socket_index];
198   pool[socket_index] = pool.back();
199   pool.pop_back();
200 
201   return scoped_ptr<DatagramClientSocket>(socket);
202 }
203 
FreeSocket(unsigned server_index,scoped_ptr<DatagramClientSocket> socket)204 void DefaultDnsSocketPool::FreeSocket(
205     unsigned server_index,
206     scoped_ptr<DatagramClientSocket> socket) {
207   DCHECK_LT(server_index, pools_.size());
208 }
209 
FillPool(unsigned server_index,unsigned size)210 void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) {
211   SocketVector& pool = pools_[server_index];
212 
213   for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) {
214     DatagramClientSocket* socket =
215         CreateConnectedSocket(server_index).release();
216     if (!socket)
217       break;
218     pool.push_back(socket);
219   }
220 }
221 
222 } // namespace net
223