• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/udp_socket_win.h"
6 
7 #include <mstcpip.h>
8 #include <winsock2.h>
9 
10 #include <memory>
11 
12 #include "base/check_op.h"
13 #include "base/functional/bind.h"
14 #include "base/functional/callback.h"
15 #include "base/lazy_instance.h"
16 #include "base/memory/raw_ptr.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/metrics/histogram_macros.h"
19 #include "base/notreached.h"
20 #include "base/rand_util.h"
21 #include "base/task/thread_pool.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/ip_address.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/network_activity_monitor.h"
27 #include "net/base/network_change_notifier.h"
28 #include "net/base/sockaddr_storage.h"
29 #include "net/base/winsock_init.h"
30 #include "net/base/winsock_util.h"
31 #include "net/log/net_log.h"
32 #include "net/log/net_log_event_type.h"
33 #include "net/log/net_log_source.h"
34 #include "net/log/net_log_source_type.h"
35 #include "net/socket/socket_descriptor.h"
36 #include "net/socket/socket_options.h"
37 #include "net/socket/socket_tag.h"
38 #include "net/socket/udp_net_log_parameters.h"
39 #include "net/traffic_annotation/network_traffic_annotation.h"
40 
41 namespace net {
42 
43 // This class encapsulates all the state that has to be preserved as long as
44 // there is a network IO operation in progress. If the owner UDPSocketWin
45 // is destroyed while an operation is in progress, the Core is detached and it
46 // lives until the operation completes and the OS doesn't reference any resource
47 // declared on this class anymore.
48 class UDPSocketWin::Core : public base::RefCounted<Core> {
49  public:
50   explicit Core(UDPSocketWin* socket);
51 
52   Core(const Core&) = delete;
53   Core& operator=(const Core&) = delete;
54 
55   // Start watching for the end of a read or write operation.
56   void WatchForRead();
57   void WatchForWrite();
58 
59   // The UDPSocketWin is going away.
Detach()60   void Detach() { socket_ = nullptr; }
61 
62   // The separate OVERLAPPED variables for asynchronous operation.
63   OVERLAPPED read_overlapped_;
64   OVERLAPPED write_overlapped_;
65 
66   // The buffers used in Read() and Write().
67   scoped_refptr<IOBuffer> read_iobuffer_;
68   scoped_refptr<IOBuffer> write_iobuffer_;
69 
70   // The address storage passed to WSARecvFrom().
71   SockaddrStorage recv_addr_storage_;
72 
73  private:
74   friend class base::RefCounted<Core>;
75 
76   class ReadDelegate : public base::win::ObjectWatcher::Delegate {
77    public:
ReadDelegate(Core * core)78     explicit ReadDelegate(Core* core) : core_(core) {}
79     ~ReadDelegate() override = default;
80 
81     // base::ObjectWatcher::Delegate methods:
82     void OnObjectSignaled(HANDLE object) override;
83 
84    private:
85     const raw_ptr<Core> core_;
86   };
87 
88   class WriteDelegate : public base::win::ObjectWatcher::Delegate {
89    public:
WriteDelegate(Core * core)90     explicit WriteDelegate(Core* core) : core_(core) {}
91     ~WriteDelegate() override = default;
92 
93     // base::ObjectWatcher::Delegate methods:
94     void OnObjectSignaled(HANDLE object) override;
95 
96    private:
97     const raw_ptr<Core> core_;
98   };
99 
100   ~Core();
101 
102   // The socket that created this object.
103   raw_ptr<UDPSocketWin> socket_;
104 
105   // |reader_| handles the signals from |read_watcher_|.
106   ReadDelegate reader_;
107   // |writer_| handles the signals from |write_watcher_|.
108   WriteDelegate writer_;
109 
110   // |read_watcher_| watches for events from Read().
111   base::win::ObjectWatcher read_watcher_;
112   // |write_watcher_| watches for events from Write();
113   base::win::ObjectWatcher write_watcher_;
114 };
115 
Core(UDPSocketWin * socket)116 UDPSocketWin::Core::Core(UDPSocketWin* socket)
117     : socket_(socket),
118       reader_(this),
119       writer_(this) {
120   memset(&read_overlapped_, 0, sizeof(read_overlapped_));
121   memset(&write_overlapped_, 0, sizeof(write_overlapped_));
122 
123   read_overlapped_.hEvent = WSACreateEvent();
124   write_overlapped_.hEvent = WSACreateEvent();
125 }
126 
~Core()127 UDPSocketWin::Core::~Core() {
128   // Make sure the message loop is not watching this object anymore.
129   read_watcher_.StopWatching();
130   write_watcher_.StopWatching();
131 
132   WSACloseEvent(read_overlapped_.hEvent);
133   memset(&read_overlapped_, 0xaf, sizeof(read_overlapped_));
134   WSACloseEvent(write_overlapped_.hEvent);
135   memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_));
136 }
137 
WatchForRead()138 void UDPSocketWin::Core::WatchForRead() {
139   // We grab an extra reference because there is an IO operation in progress.
140   // Balanced in ReadDelegate::OnObjectSignaled().
141   AddRef();
142   read_watcher_.StartWatchingOnce(read_overlapped_.hEvent, &reader_);
143 }
144 
WatchForWrite()145 void UDPSocketWin::Core::WatchForWrite() {
146   // We grab an extra reference because there is an IO operation in progress.
147   // Balanced in WriteDelegate::OnObjectSignaled().
148   AddRef();
149   write_watcher_.StartWatchingOnce(write_overlapped_.hEvent, &writer_);
150 }
151 
OnObjectSignaled(HANDLE object)152 void UDPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
153   DCHECK_EQ(object, core_->read_overlapped_.hEvent);
154   if (core_->socket_)
155     core_->socket_->DidCompleteRead();
156 
157   core_->Release();
158 }
159 
OnObjectSignaled(HANDLE object)160 void UDPSocketWin::Core::WriteDelegate::OnObjectSignaled(HANDLE object) {
161   DCHECK_EQ(object, core_->write_overlapped_.hEvent);
162   if (core_->socket_)
163     core_->socket_->DidCompleteWrite();
164 
165   core_->Release();
166 }
167 //-----------------------------------------------------------------------------
168 
QwaveApi()169 QwaveApi::QwaveApi() {
170   HMODULE qwave = LoadLibrary(L"qwave.dll");
171   if (!qwave)
172     return;
173   create_handle_func_ =
174       (CreateHandleFn)GetProcAddress(qwave, "QOSCreateHandle");
175   close_handle_func_ =
176       (CloseHandleFn)GetProcAddress(qwave, "QOSCloseHandle");
177   add_socket_to_flow_func_ =
178       (AddSocketToFlowFn)GetProcAddress(qwave, "QOSAddSocketToFlow");
179   remove_socket_from_flow_func_ =
180       (RemoveSocketFromFlowFn)GetProcAddress(qwave, "QOSRemoveSocketFromFlow");
181   set_flow_func_ = (SetFlowFn)GetProcAddress(qwave, "QOSSetFlow");
182 
183   if (create_handle_func_ && close_handle_func_ &&
184       add_socket_to_flow_func_ && remove_socket_from_flow_func_ &&
185       set_flow_func_) {
186     qwave_supported_ = true;
187   }
188 }
189 
GetDefault()190 QwaveApi* QwaveApi::GetDefault() {
191   static base::LazyInstance<QwaveApi>::Leaky lazy_qwave =
192       LAZY_INSTANCE_INITIALIZER;
193   return lazy_qwave.Pointer();
194 }
195 
qwave_supported() const196 bool QwaveApi::qwave_supported() const {
197   return qwave_supported_;
198 }
199 
OnFatalError()200 void QwaveApi::OnFatalError() {
201   // Disable everything moving forward.
202   qwave_supported_ = false;
203 }
204 
CreateHandle(PQOS_VERSION version,PHANDLE handle)205 BOOL QwaveApi::CreateHandle(PQOS_VERSION version, PHANDLE handle) {
206   return create_handle_func_(version, handle);
207 }
208 
CloseHandle(HANDLE handle)209 BOOL QwaveApi::CloseHandle(HANDLE handle) {
210   return close_handle_func_(handle);
211 }
212 
AddSocketToFlow(HANDLE handle,SOCKET socket,PSOCKADDR addr,QOS_TRAFFIC_TYPE traffic_type,DWORD flags,PQOS_FLOWID flow_id)213 BOOL QwaveApi::AddSocketToFlow(HANDLE handle,
214                                SOCKET socket,
215                                PSOCKADDR addr,
216                                QOS_TRAFFIC_TYPE traffic_type,
217                                DWORD flags,
218                                PQOS_FLOWID flow_id) {
219   return add_socket_to_flow_func_(handle, socket, addr, traffic_type, flags,
220                                   flow_id);
221 }
222 
RemoveSocketFromFlow(HANDLE handle,SOCKET socket,QOS_FLOWID flow_id,DWORD reserved)223 BOOL QwaveApi::RemoveSocketFromFlow(HANDLE handle,
224                                     SOCKET socket,
225                                     QOS_FLOWID flow_id,
226                                     DWORD reserved) {
227   return remove_socket_from_flow_func_(handle, socket, flow_id, reserved);
228 }
229 
SetFlow(HANDLE handle,QOS_FLOWID flow_id,QOS_SET_FLOW op,ULONG size,PVOID data,DWORD reserved,LPOVERLAPPED overlapped)230 BOOL QwaveApi::SetFlow(HANDLE handle,
231                        QOS_FLOWID flow_id,
232                        QOS_SET_FLOW op,
233                        ULONG size,
234                        PVOID data,
235                        DWORD reserved,
236                        LPOVERLAPPED overlapped) {
237   return set_flow_func_(handle, flow_id, op, size, data, reserved, overlapped);
238 }
239 
240 //-----------------------------------------------------------------------------
241 
UDPSocketWin(DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)242 UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
243                            net::NetLog* net_log,
244                            const net::NetLogSource& source)
245     : socket_(INVALID_SOCKET),
246       socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
247       net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::UDP_SOCKET)) {
248   EnsureWinsockInit();
249   net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source);
250 }
251 
UDPSocketWin(DatagramSocket::BindType bind_type,NetLogWithSource source_net_log)252 UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
253                            NetLogWithSource source_net_log)
254     : socket_(INVALID_SOCKET),
255       socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
256       net_log_(source_net_log) {
257   EnsureWinsockInit();
258   net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE,
259                                        net_log_.source());
260 }
261 
~UDPSocketWin()262 UDPSocketWin::~UDPSocketWin() {
263   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
264   Close();
265   net_log_.EndEvent(NetLogEventType::SOCKET_ALIVE);
266 }
267 
Open(AddressFamily address_family)268 int UDPSocketWin::Open(AddressFamily address_family) {
269   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
270   DCHECK_EQ(socket_, INVALID_SOCKET);
271 
272   auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
273   if (owned_socket_count.empty())
274     return ERR_INSUFFICIENT_RESOURCES;
275 
276   owned_socket_count_ = std::move(owned_socket_count);
277   addr_family_ = ConvertAddressFamily(address_family);
278   socket_ = CreatePlatformSocket(addr_family_, SOCK_DGRAM, IPPROTO_UDP);
279   if (socket_ == INVALID_SOCKET) {
280     owned_socket_count_.Reset();
281     return MapSystemError(WSAGetLastError());
282   }
283   ConfigureOpenedSocket();
284   return OK;
285 }
286 
AdoptOpenedSocket(AddressFamily address_family,SOCKET socket)287 int UDPSocketWin::AdoptOpenedSocket(AddressFamily address_family,
288                                     SOCKET socket) {
289   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
290   auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
291   if (owned_socket_count.empty()) {
292     return ERR_INSUFFICIENT_RESOURCES;
293   }
294 
295   owned_socket_count_ = std::move(owned_socket_count);
296   addr_family_ = ConvertAddressFamily(address_family);
297   socket_ = socket;
298   ConfigureOpenedSocket();
299   return OK;
300 }
301 
ConfigureOpenedSocket()302 void UDPSocketWin::ConfigureOpenedSocket() {
303   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
304   if (!use_non_blocking_io_) {
305     core_ = base::MakeRefCounted<Core>(this);
306   } else {
307     read_write_event_.Set(WSACreateEvent());
308     WSAEventSelect(socket_, read_write_event_.Get(), FD_READ | FD_WRITE);
309   }
310 }
311 
Close()312 void UDPSocketWin::Close() {
313   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
314 
315   owned_socket_count_.Reset();
316 
317   if (socket_ == INVALID_SOCKET)
318     return;
319 
320   // Remove socket_ from the QoS subsystem before we invalidate it.
321   dscp_manager_ = nullptr;
322 
323   // Zero out any pending read/write callback state.
324   read_callback_.Reset();
325   recv_from_address_ = nullptr;
326   write_callback_.Reset();
327 
328   base::TimeTicks start_time = base::TimeTicks::Now();
329   closesocket(socket_);
330   UMA_HISTOGRAM_TIMES("Net.UDPSocketWinClose",
331                       base::TimeTicks::Now() - start_time);
332   socket_ = INVALID_SOCKET;
333   addr_family_ = 0;
334   is_connected_ = false;
335 
336   // Release buffers to free up memory.
337   read_iobuffer_ = nullptr;
338   read_iobuffer_len_ = 0;
339   write_iobuffer_ = nullptr;
340   write_iobuffer_len_ = 0;
341 
342   read_write_watcher_.StopWatching();
343   read_write_event_.Close();
344 
345   event_pending_.InvalidateWeakPtrs();
346 
347   if (core_) {
348     core_->Detach();
349     core_ = nullptr;
350   }
351 }
352 
GetPeerAddress(IPEndPoint * address) const353 int UDPSocketWin::GetPeerAddress(IPEndPoint* address) const {
354   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
355   DCHECK(address);
356   if (!is_connected())
357     return ERR_SOCKET_NOT_CONNECTED;
358 
359   // TODO(szym): Simplify. http://crbug.com/126152
360   if (!remote_address_.get()) {
361     SockaddrStorage storage;
362     if (getpeername(socket_, storage.addr, &storage.addr_len))
363       return MapSystemError(WSAGetLastError());
364     auto remote_address = std::make_unique<IPEndPoint>();
365     if (!remote_address->FromSockAddr(storage.addr, storage.addr_len))
366       return ERR_ADDRESS_INVALID;
367     remote_address_ = std::move(remote_address);
368   }
369 
370   *address = *remote_address_;
371   return OK;
372 }
373 
GetLocalAddress(IPEndPoint * address) const374 int UDPSocketWin::GetLocalAddress(IPEndPoint* address) const {
375   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
376   DCHECK(address);
377   if (!is_connected())
378     return ERR_SOCKET_NOT_CONNECTED;
379 
380   // TODO(szym): Simplify. http://crbug.com/126152
381   if (!local_address_.get()) {
382     SockaddrStorage storage;
383     if (getsockname(socket_, storage.addr, &storage.addr_len))
384       return MapSystemError(WSAGetLastError());
385     auto local_address = std::make_unique<IPEndPoint>();
386     if (!local_address->FromSockAddr(storage.addr, storage.addr_len))
387       return ERR_ADDRESS_INVALID;
388     local_address_ = std::move(local_address);
389     net_log_.AddEvent(NetLogEventType::UDP_LOCAL_ADDRESS, [&] {
390       return CreateNetLogUDPConnectParams(*local_address_,
391                                           handles::kInvalidNetworkHandle);
392     });
393   }
394 
395   *address = *local_address_;
396   return OK;
397 }
398 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)399 int UDPSocketWin::Read(IOBuffer* buf,
400                        int buf_len,
401                        CompletionOnceCallback callback) {
402   return RecvFrom(buf, buf_len, nullptr, std::move(callback));
403 }
404 
RecvFrom(IOBuffer * buf,int buf_len,IPEndPoint * address,CompletionOnceCallback callback)405 int UDPSocketWin::RecvFrom(IOBuffer* buf,
406                            int buf_len,
407                            IPEndPoint* address,
408                            CompletionOnceCallback callback) {
409   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
410   DCHECK_NE(INVALID_SOCKET, socket_);
411   CHECK(read_callback_.is_null());
412   DCHECK(!recv_from_address_);
413   DCHECK(!callback.is_null());  // Synchronous operation not supported.
414   DCHECK_GT(buf_len, 0);
415 
416   int nread = core_ ? InternalRecvFromOverlapped(buf, buf_len, address)
417                     : InternalRecvFromNonBlocking(buf, buf_len, address);
418   if (nread != ERR_IO_PENDING)
419     return nread;
420 
421   read_callback_ = std::move(callback);
422   recv_from_address_ = address;
423   return ERR_IO_PENDING;
424 }
425 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)426 int UDPSocketWin::Write(
427     IOBuffer* buf,
428     int buf_len,
429     CompletionOnceCallback callback,
430     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
431   return SendToOrWrite(buf, buf_len, remote_address_.get(),
432                        std::move(callback));
433 }
434 
SendTo(IOBuffer * buf,int buf_len,const IPEndPoint & address,CompletionOnceCallback callback)435 int UDPSocketWin::SendTo(IOBuffer* buf,
436                          int buf_len,
437                          const IPEndPoint& address,
438                          CompletionOnceCallback callback) {
439   if (dscp_manager_) {
440     // Alert DscpManager in case this is a new remote address.  Failure to
441     // apply Dscp code is never fatal.
442     int rv = dscp_manager_->PrepareForSend(address);
443     if (rv != OK)
444       net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, rv);
445   }
446   return SendToOrWrite(buf, buf_len, &address, std::move(callback));
447 }
448 
SendToOrWrite(IOBuffer * buf,int buf_len,const IPEndPoint * address,CompletionOnceCallback callback)449 int UDPSocketWin::SendToOrWrite(IOBuffer* buf,
450                                 int buf_len,
451                                 const IPEndPoint* address,
452                                 CompletionOnceCallback callback) {
453   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
454   DCHECK_NE(INVALID_SOCKET, socket_);
455   CHECK(write_callback_.is_null());
456   DCHECK(!callback.is_null());  // Synchronous operation not supported.
457   DCHECK_GT(buf_len, 0);
458   DCHECK(!send_to_address_.get());
459 
460   int nwrite = core_ ? InternalSendToOverlapped(buf, buf_len, address)
461                      : InternalSendToNonBlocking(buf, buf_len, address);
462   if (nwrite != ERR_IO_PENDING)
463     return nwrite;
464 
465   if (address)
466     send_to_address_ = std::make_unique<IPEndPoint>(*address);
467   write_callback_ = std::move(callback);
468   return ERR_IO_PENDING;
469 }
470 
Connect(const IPEndPoint & address)471 int UDPSocketWin::Connect(const IPEndPoint& address) {
472   DCHECK_NE(socket_, INVALID_SOCKET);
473   net_log_.BeginEvent(NetLogEventType::UDP_CONNECT, [&] {
474     return CreateNetLogUDPConnectParams(address,
475                                         handles::kInvalidNetworkHandle);
476   });
477   int rv = SetMulticastOptions();
478   if (rv != OK)
479     return rv;
480   rv = InternalConnect(address);
481   net_log_.EndEventWithNetErrorCode(NetLogEventType::UDP_CONNECT, rv);
482   is_connected_ = (rv == OK);
483   return rv;
484 }
485 
InternalConnect(const IPEndPoint & address)486 int UDPSocketWin::InternalConnect(const IPEndPoint& address) {
487   DCHECK(!is_connected());
488   DCHECK(!remote_address_.get());
489 
490   // Always do a random bind.
491   // Ignore failures, which may happen if the socket was already bound.
492   DWORD randomize_port_value = 1;
493   setsockopt(socket_, SOL_SOCKET, SO_RANDOMIZE_PORT,
494              reinterpret_cast<const char*>(&randomize_port_value),
495              sizeof(randomize_port_value));
496 
497   SockaddrStorage storage;
498   if (!address.ToSockAddr(storage.addr, &storage.addr_len))
499     return ERR_ADDRESS_INVALID;
500 
501   int rv = connect(socket_, storage.addr, storage.addr_len);
502   if (rv < 0)
503     return MapSystemError(WSAGetLastError());
504 
505   remote_address_ = std::make_unique<IPEndPoint>(address);
506 
507   if (dscp_manager_)
508     dscp_manager_->PrepareForSend(*remote_address_.get());
509 
510   return rv;
511 }
512 
Bind(const IPEndPoint & address)513 int UDPSocketWin::Bind(const IPEndPoint& address) {
514   DCHECK_NE(socket_, INVALID_SOCKET);
515   DCHECK(!is_connected());
516 
517   int rv = SetMulticastOptions();
518   if (rv < 0)
519     return rv;
520 
521   rv = DoBind(address);
522   if (rv < 0)
523     return rv;
524 
525   local_address_.reset();
526   is_connected_ = true;
527   return rv;
528 }
529 
BindToNetwork(handles::NetworkHandle network)530 int UDPSocketWin::BindToNetwork(handles::NetworkHandle network) {
531   NOTIMPLEMENTED();
532   return ERR_NOT_IMPLEMENTED;
533 }
534 
SetReceiveBufferSize(int32_t size)535 int UDPSocketWin::SetReceiveBufferSize(int32_t size) {
536   DCHECK_NE(socket_, INVALID_SOCKET);
537   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
538   int rv = SetSocketReceiveBufferSize(socket_, size);
539 
540   if (rv != 0)
541     return MapSystemError(WSAGetLastError());
542 
543   // According to documentation, setsockopt may succeed, but we need to check
544   // the results via getsockopt to be sure it works on Windows.
545   int32_t actual_size = 0;
546   int option_size = sizeof(actual_size);
547   rv = getsockopt(socket_, SOL_SOCKET, SO_RCVBUF,
548                   reinterpret_cast<char*>(&actual_size), &option_size);
549   if (rv != 0)
550     return MapSystemError(WSAGetLastError());
551   if (actual_size >= size)
552     return OK;
553   UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableReceiveBuffer",
554                               actual_size, 1000, 1000000, 50);
555   return ERR_SOCKET_RECEIVE_BUFFER_SIZE_UNCHANGEABLE;
556 }
557 
SetSendBufferSize(int32_t size)558 int UDPSocketWin::SetSendBufferSize(int32_t size) {
559   DCHECK_NE(socket_, INVALID_SOCKET);
560   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
561   int rv = SetSocketSendBufferSize(socket_, size);
562   if (rv != 0)
563     return MapSystemError(WSAGetLastError());
564   // According to documentation, setsockopt may succeed, but we need to check
565   // the results via getsockopt to be sure it works on Windows.
566   int32_t actual_size = 0;
567   int option_size = sizeof(actual_size);
568   rv = getsockopt(socket_, SOL_SOCKET, SO_SNDBUF,
569                   reinterpret_cast<char*>(&actual_size), &option_size);
570   if (rv != 0)
571     return MapSystemError(WSAGetLastError());
572   if (actual_size >= size)
573     return OK;
574   UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableSendBuffer",
575                               actual_size, 1000, 1000000, 50);
576   return ERR_SOCKET_SEND_BUFFER_SIZE_UNCHANGEABLE;
577 }
578 
SetDoNotFragment()579 int UDPSocketWin::SetDoNotFragment() {
580   DCHECK_NE(socket_, INVALID_SOCKET);
581   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
582 
583   if (addr_family_ == AF_INET6)
584     return OK;
585 
586   DWORD val = 1;
587   int rv = setsockopt(socket_, IPPROTO_IP, IP_DONTFRAGMENT,
588                       reinterpret_cast<const char*>(&val), sizeof(val));
589   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
590 }
591 
SetRecvEcn()592 int UDPSocketWin::SetRecvEcn() {
593   DCHECK_NE(socket_, INVALID_SOCKET);
594   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
595 
596   int rv;
597   unsigned int ecn = 1;
598   if (addr_family_ == AF_INET6) {
599     rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_RECVTCLASS,
600                     reinterpret_cast<const char*>(&ecn), sizeof(ecn));
601   } else {
602     DCHECK_EQ(addr_family_, AF_INET);
603     rv = setsockopt(socket_, IPPROTO_IP, IP_RECVTOS,
604                     reinterpret_cast<const char*>(&ecn), sizeof(ecn));
605   }
606   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
607 }
608 
SetMsgConfirm(bool confirm)609 void UDPSocketWin::SetMsgConfirm(bool confirm) {}
610 
AllowAddressReuse()611 int UDPSocketWin::AllowAddressReuse() {
612   DCHECK_NE(socket_, INVALID_SOCKET);
613   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
614   DCHECK(!is_connected());
615 
616   BOOL true_value = TRUE;
617   int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
618                       reinterpret_cast<const char*>(&true_value),
619                       sizeof(true_value));
620   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
621 }
622 
SetBroadcast(bool broadcast)623 int UDPSocketWin::SetBroadcast(bool broadcast) {
624   DCHECK_NE(socket_, INVALID_SOCKET);
625   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
626 
627   BOOL value = broadcast ? TRUE : FALSE;
628   int rv = setsockopt(socket_, SOL_SOCKET, SO_BROADCAST,
629                       reinterpret_cast<const char*>(&value), sizeof(value));
630   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
631 }
632 
AllowAddressSharingForMulticast()633 int UDPSocketWin::AllowAddressSharingForMulticast() {
634   // When proper multicast groups are used, Windows further defines the address
635   // resuse option (SO_REUSEADDR) to ensure all listening sockets can receive
636   // all incoming messages for the multicast group.
637   return AllowAddressReuse();
638 }
639 
DoReadCallback(int rv)640 void UDPSocketWin::DoReadCallback(int rv) {
641   DCHECK_NE(rv, ERR_IO_PENDING);
642   DCHECK(!read_callback_.is_null());
643 
644   // since Run may result in Read being called, clear read_callback_ up front.
645   std::move(read_callback_).Run(rv);
646 }
647 
DoWriteCallback(int rv)648 void UDPSocketWin::DoWriteCallback(int rv) {
649   DCHECK_NE(rv, ERR_IO_PENDING);
650   DCHECK(!write_callback_.is_null());
651 
652   // since Run may result in Write being called, clear write_callback_ up front.
653   std::move(write_callback_).Run(rv);
654 }
655 
DidCompleteRead()656 void UDPSocketWin::DidCompleteRead() {
657   DWORD num_bytes, flags;
658   BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_,
659                                    &num_bytes, FALSE, &flags);
660   WSAResetEvent(core_->read_overlapped_.hEvent);
661   int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
662   // Convert address.
663   IPEndPoint address;
664   IPEndPoint* address_to_log = nullptr;
665   if (result >= 0) {
666     if (address.FromSockAddr(core_->recv_addr_storage_.addr,
667                              core_->recv_addr_storage_.addr_len)) {
668       if (recv_from_address_)
669         *recv_from_address_ = address;
670       address_to_log = &address;
671     } else {
672       result = ERR_ADDRESS_INVALID;
673     }
674   }
675   LogRead(result, core_->read_iobuffer_->data(), address_to_log);
676   core_->read_iobuffer_ = nullptr;
677   recv_from_address_ = nullptr;
678   DoReadCallback(result);
679 }
680 
DidCompleteWrite()681 void UDPSocketWin::DidCompleteWrite() {
682   DWORD num_bytes, flags;
683   BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_,
684                                    &num_bytes, FALSE, &flags);
685   WSAResetEvent(core_->write_overlapped_.hEvent);
686   int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
687   LogWrite(result, core_->write_iobuffer_->data(), send_to_address_.get());
688 
689   send_to_address_.reset();
690   core_->write_iobuffer_ = nullptr;
691   DoWriteCallback(result);
692 }
693 
OnObjectSignaled(HANDLE object)694 void UDPSocketWin::OnObjectSignaled(HANDLE object) {
695   DCHECK(object == read_write_event_.Get());
696   WSANETWORKEVENTS network_events;
697   int os_error = 0;
698   int rv =
699       WSAEnumNetworkEvents(socket_, read_write_event_.Get(), &network_events);
700   // Protects against trying to call the write callback if the read callback
701   // either closes or destroys |this|.
702   base::WeakPtr<UDPSocketWin> event_pending = event_pending_.GetWeakPtr();
703   if (rv == SOCKET_ERROR) {
704     os_error = WSAGetLastError();
705     rv = MapSystemError(os_error);
706 
707     if (read_iobuffer_) {
708       read_iobuffer_ = nullptr;
709       read_iobuffer_len_ = 0;
710       recv_from_address_ = nullptr;
711       DoReadCallback(rv);
712     }
713 
714     // Socket may have been closed or destroyed here.
715     if (event_pending && write_iobuffer_) {
716       write_iobuffer_ = nullptr;
717       write_iobuffer_len_ = 0;
718       send_to_address_.reset();
719       DoWriteCallback(rv);
720     }
721     return;
722   }
723 
724   if ((network_events.lNetworkEvents & FD_READ) && read_iobuffer_)
725     OnReadSignaled();
726   if (!event_pending)
727     return;
728 
729   if ((network_events.lNetworkEvents & FD_WRITE) && write_iobuffer_)
730     OnWriteSignaled();
731   if (!event_pending)
732     return;
733 
734   // There's still pending read / write. Watch for further events.
735   if (read_iobuffer_ || write_iobuffer_)
736     WatchForReadWrite();
737 }
738 
OnReadSignaled()739 void UDPSocketWin::OnReadSignaled() {
740   int rv = InternalRecvFromNonBlocking(read_iobuffer_.get(), read_iobuffer_len_,
741                                        recv_from_address_);
742   if (rv == ERR_IO_PENDING)
743     return;
744   read_iobuffer_ = nullptr;
745   read_iobuffer_len_ = 0;
746   recv_from_address_ = nullptr;
747   DoReadCallback(rv);
748 }
749 
OnWriteSignaled()750 void UDPSocketWin::OnWriteSignaled() {
751   int rv = InternalSendToNonBlocking(write_iobuffer_.get(), write_iobuffer_len_,
752                                      send_to_address_.get());
753   if (rv == ERR_IO_PENDING)
754     return;
755   write_iobuffer_ = nullptr;
756   write_iobuffer_len_ = 0;
757   send_to_address_.reset();
758   DoWriteCallback(rv);
759 }
760 
WatchForReadWrite()761 void UDPSocketWin::WatchForReadWrite() {
762   if (read_write_watcher_.IsWatching())
763     return;
764   bool watched =
765       read_write_watcher_.StartWatchingOnce(read_write_event_.Get(), this);
766   DCHECK(watched);
767 }
768 
LogRead(int result,const char * bytes,const IPEndPoint * address) const769 void UDPSocketWin::LogRead(int result,
770                            const char* bytes,
771                            const IPEndPoint* address) const {
772   if (result < 0) {
773     net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_RECEIVE_ERROR,
774                                       result);
775     return;
776   }
777 
778   if (net_log_.IsCapturing()) {
779     NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_RECEIVED, result,
780                           bytes, address);
781   }
782 
783   activity_monitor::IncrementBytesReceived(result);
784 }
785 
LogWrite(int result,const char * bytes,const IPEndPoint * address) const786 void UDPSocketWin::LogWrite(int result,
787                             const char* bytes,
788                             const IPEndPoint* address) const {
789   if (result < 0) {
790     net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, result);
791     return;
792   }
793 
794   if (net_log_.IsCapturing()) {
795     NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_SENT, result,
796                           bytes, address);
797   }
798 }
799 
InternalRecvFromOverlapped(IOBuffer * buf,int buf_len,IPEndPoint * address)800 int UDPSocketWin::InternalRecvFromOverlapped(IOBuffer* buf,
801                                              int buf_len,
802                                              IPEndPoint* address) {
803   DCHECK(!core_->read_iobuffer_.get());
804   SockaddrStorage& storage = core_->recv_addr_storage_;
805   storage.addr_len = sizeof(storage.addr_storage);
806 
807   WSABUF read_buffer;
808   read_buffer.buf = buf->data();
809   read_buffer.len = buf_len;
810 
811   DWORD flags = 0;
812   DWORD num;
813   CHECK_NE(INVALID_SOCKET, socket_);
814   int rv = WSARecvFrom(socket_, &read_buffer, 1, &num, &flags, storage.addr,
815                        &storage.addr_len, &core_->read_overlapped_, nullptr);
816   if (rv == 0) {
817     if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) {
818       int result = num;
819       // Convert address.
820       IPEndPoint address_storage;
821       IPEndPoint* address_to_log = nullptr;
822       if (result >= 0) {
823         if (address_storage.FromSockAddr(core_->recv_addr_storage_.addr,
824                                          core_->recv_addr_storage_.addr_len)) {
825           if (address)
826             *address = address_storage;
827           address_to_log = &address_storage;
828         } else {
829           result = ERR_ADDRESS_INVALID;
830         }
831       }
832       LogRead(result, buf->data(), address_to_log);
833       return result;
834     }
835   } else {
836     int os_error = WSAGetLastError();
837     if (os_error != WSA_IO_PENDING) {
838       int result = MapSystemError(os_error);
839       LogRead(result, nullptr, nullptr);
840       return result;
841     }
842   }
843   core_->WatchForRead();
844   core_->read_iobuffer_ = buf;
845   return ERR_IO_PENDING;
846 }
847 
InternalSendToOverlapped(IOBuffer * buf,int buf_len,const IPEndPoint * address)848 int UDPSocketWin::InternalSendToOverlapped(IOBuffer* buf,
849                                            int buf_len,
850                                            const IPEndPoint* address) {
851   DCHECK(!core_->write_iobuffer_.get());
852   SockaddrStorage storage;
853   struct sockaddr* addr = storage.addr;
854   // Convert address.
855   if (!address) {
856     addr = nullptr;
857     storage.addr_len = 0;
858   } else {
859     if (!address->ToSockAddr(addr, &storage.addr_len)) {
860       int result = ERR_ADDRESS_INVALID;
861       LogWrite(result, nullptr, nullptr);
862       return result;
863     }
864   }
865 
866   WSABUF write_buffer;
867   write_buffer.buf = buf->data();
868   write_buffer.len = buf_len;
869 
870   DWORD flags = 0;
871   DWORD num;
872   int rv = WSASendTo(socket_, &write_buffer, 1, &num, flags, addr,
873                      storage.addr_len, &core_->write_overlapped_, nullptr);
874   if (rv == 0) {
875     if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) {
876       int result = num;
877       LogWrite(result, buf->data(), address);
878       return result;
879     }
880   } else {
881     int os_error = WSAGetLastError();
882     if (os_error != WSA_IO_PENDING) {
883       int result = MapSystemError(os_error);
884       LogWrite(result, nullptr, nullptr);
885       return result;
886     }
887   }
888 
889   core_->WatchForWrite();
890   core_->write_iobuffer_ = buf;
891   return ERR_IO_PENDING;
892 }
893 
InternalRecvFromNonBlocking(IOBuffer * buf,int buf_len,IPEndPoint * address)894 int UDPSocketWin::InternalRecvFromNonBlocking(IOBuffer* buf,
895                                               int buf_len,
896                                               IPEndPoint* address) {
897   DCHECK(!read_iobuffer_ || read_iobuffer_.get() == buf);
898   SockaddrStorage storage;
899   storage.addr_len = sizeof(storage.addr_storage);
900 
901   CHECK_NE(INVALID_SOCKET, socket_);
902   int rv = recvfrom(socket_, buf->data(), buf_len, 0, storage.addr,
903                     &storage.addr_len);
904   if (rv == SOCKET_ERROR) {
905     int os_error = WSAGetLastError();
906     if (os_error == WSAEWOULDBLOCK) {
907       read_iobuffer_ = buf;
908       read_iobuffer_len_ = buf_len;
909       WatchForReadWrite();
910       return ERR_IO_PENDING;
911     }
912     rv = MapSystemError(os_error);
913     LogRead(rv, nullptr, nullptr);
914     return rv;
915   }
916   IPEndPoint address_storage;
917   IPEndPoint* address_to_log = nullptr;
918   if (rv >= 0) {
919     if (address_storage.FromSockAddr(storage.addr, storage.addr_len)) {
920       if (address)
921         *address = address_storage;
922       address_to_log = &address_storage;
923     } else {
924       rv = ERR_ADDRESS_INVALID;
925     }
926   }
927   LogRead(rv, buf->data(), address_to_log);
928   return rv;
929 }
930 
InternalSendToNonBlocking(IOBuffer * buf,int buf_len,const IPEndPoint * address)931 int UDPSocketWin::InternalSendToNonBlocking(IOBuffer* buf,
932                                             int buf_len,
933                                             const IPEndPoint* address) {
934   DCHECK(!write_iobuffer_ || write_iobuffer_.get() == buf);
935   SockaddrStorage storage;
936   struct sockaddr* addr = storage.addr;
937   // Convert address.
938   if (address) {
939     if (!address->ToSockAddr(addr, &storage.addr_len)) {
940       int result = ERR_ADDRESS_INVALID;
941       LogWrite(result, nullptr, nullptr);
942       return result;
943     }
944   } else {
945     addr = nullptr;
946     storage.addr_len = 0;
947   }
948 
949   int rv = sendto(socket_, buf->data(), buf_len, 0, addr, storage.addr_len);
950   if (rv == SOCKET_ERROR) {
951     int os_error = WSAGetLastError();
952     if (os_error == WSAEWOULDBLOCK) {
953       write_iobuffer_ = buf;
954       write_iobuffer_len_ = buf_len;
955       WatchForReadWrite();
956       return ERR_IO_PENDING;
957     }
958     rv = MapSystemError(os_error);
959     LogWrite(rv, nullptr, nullptr);
960     return rv;
961   }
962   LogWrite(rv, buf->data(), address);
963   return rv;
964 }
965 
SetMulticastOptions()966 int UDPSocketWin::SetMulticastOptions() {
967   if (!(socket_options_ & SOCKET_OPTION_MULTICAST_LOOP)) {
968     DWORD loop = 0;
969     int protocol_level =
970         addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
971     int option =
972         addr_family_ == AF_INET ? IP_MULTICAST_LOOP: IPV6_MULTICAST_LOOP;
973     int rv = setsockopt(socket_, protocol_level, option,
974                         reinterpret_cast<const char*>(&loop), sizeof(loop));
975     if (rv < 0)
976       return MapSystemError(WSAGetLastError());
977   }
978   if (multicast_time_to_live_ != 1) {
979     DWORD hops = multicast_time_to_live_;
980     int protocol_level =
981         addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
982     int option =
983         addr_family_ == AF_INET ? IP_MULTICAST_TTL: IPV6_MULTICAST_HOPS;
984     int rv = setsockopt(socket_, protocol_level, option,
985                         reinterpret_cast<const char*>(&hops), sizeof(hops));
986     if (rv < 0)
987       return MapSystemError(WSAGetLastError());
988   }
989   if (multicast_interface_ != 0) {
990     switch (addr_family_) {
991       case AF_INET: {
992         in_addr address;
993         address.s_addr = htonl(multicast_interface_);
994         int rv = setsockopt(socket_, IPPROTO_IP, IP_MULTICAST_IF,
995                             reinterpret_cast<const char*>(&address),
996                             sizeof(address));
997         if (rv)
998           return MapSystemError(WSAGetLastError());
999         break;
1000       }
1001       case AF_INET6: {
1002         uint32_t interface_index = multicast_interface_;
1003         int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_MULTICAST_IF,
1004                             reinterpret_cast<const char*>(&interface_index),
1005                             sizeof(interface_index));
1006         if (rv)
1007           return MapSystemError(WSAGetLastError());
1008         break;
1009       }
1010       default:
1011         NOTREACHED() << "Invalid address family";
1012         return ERR_ADDRESS_INVALID;
1013     }
1014   }
1015   return OK;
1016 }
1017 
DoBind(const IPEndPoint & address)1018 int UDPSocketWin::DoBind(const IPEndPoint& address) {
1019   SockaddrStorage storage;
1020   if (!address.ToSockAddr(storage.addr, &storage.addr_len))
1021     return ERR_ADDRESS_INVALID;
1022   int rv = bind(socket_, storage.addr, storage.addr_len);
1023   if (rv == 0)
1024     return OK;
1025   int last_error = WSAGetLastError();
1026   // Map some codes that are special to bind() separately.
1027   // * WSAEACCES: If a port is already bound to a socket, WSAEACCES may be
1028   //   returned instead of WSAEADDRINUSE, depending on whether the socket
1029   //   option SO_REUSEADDR or SO_EXCLUSIVEADDRUSE is set and whether the
1030   //   conflicting socket is owned by a different user account. See the MSDN
1031   //   page "Using SO_REUSEADDR and SO_EXCLUSIVEADDRUSE" for the gory details.
1032   if (last_error == WSAEACCES || last_error == WSAEADDRNOTAVAIL)
1033     return ERR_ADDRESS_IN_USE;
1034   return MapSystemError(last_error);
1035 }
1036 
GetQwaveApi() const1037 QwaveApi* UDPSocketWin::GetQwaveApi() const {
1038   return QwaveApi::GetDefault();
1039 }
1040 
JoinGroup(const IPAddress & group_address) const1041 int UDPSocketWin::JoinGroup(const IPAddress& group_address) const {
1042   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1043   if (!is_connected())
1044     return ERR_SOCKET_NOT_CONNECTED;
1045 
1046   switch (group_address.size()) {
1047     case IPAddress::kIPv4AddressSize: {
1048       if (addr_family_ != AF_INET)
1049         return ERR_ADDRESS_INVALID;
1050       ip_mreq mreq;
1051       mreq.imr_interface.s_addr = htonl(multicast_interface_);
1052       memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1053              IPAddress::kIPv4AddressSize);
1054       int rv = setsockopt(socket_, IPPROTO_IP, IP_ADD_MEMBERSHIP,
1055                           reinterpret_cast<const char*>(&mreq),
1056                           sizeof(mreq));
1057       if (rv)
1058         return MapSystemError(WSAGetLastError());
1059       return OK;
1060     }
1061     case IPAddress::kIPv6AddressSize: {
1062       if (addr_family_ != AF_INET6)
1063         return ERR_ADDRESS_INVALID;
1064       ipv6_mreq mreq;
1065       mreq.ipv6mr_interface = multicast_interface_;
1066       memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1067              IPAddress::kIPv6AddressSize);
1068       int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
1069                           reinterpret_cast<const char*>(&mreq),
1070                           sizeof(mreq));
1071       if (rv)
1072         return MapSystemError(WSAGetLastError());
1073       return OK;
1074     }
1075     default:
1076       NOTREACHED() << "Invalid address family";
1077       return ERR_ADDRESS_INVALID;
1078   }
1079 }
1080 
LeaveGroup(const IPAddress & group_address) const1081 int UDPSocketWin::LeaveGroup(const IPAddress& group_address) const {
1082   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1083   if (!is_connected())
1084     return ERR_SOCKET_NOT_CONNECTED;
1085 
1086   switch (group_address.size()) {
1087     case IPAddress::kIPv4AddressSize: {
1088       if (addr_family_ != AF_INET)
1089         return ERR_ADDRESS_INVALID;
1090       ip_mreq mreq;
1091       mreq.imr_interface.s_addr = htonl(multicast_interface_);
1092       memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1093              IPAddress::kIPv4AddressSize);
1094       int rv = setsockopt(socket_, IPPROTO_IP, IP_DROP_MEMBERSHIP,
1095                           reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1096       if (rv)
1097         return MapSystemError(WSAGetLastError());
1098       return OK;
1099     }
1100     case IPAddress::kIPv6AddressSize: {
1101       if (addr_family_ != AF_INET6)
1102         return ERR_ADDRESS_INVALID;
1103       ipv6_mreq mreq;
1104       mreq.ipv6mr_interface = multicast_interface_;
1105       memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1106              IPAddress::kIPv6AddressSize);
1107       int rv = setsockopt(socket_, IPPROTO_IPV6, IP_DROP_MEMBERSHIP,
1108                           reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1109       if (rv)
1110         return MapSystemError(WSAGetLastError());
1111       return OK;
1112     }
1113     default:
1114       NOTREACHED() << "Invalid address family";
1115       return ERR_ADDRESS_INVALID;
1116   }
1117 }
1118 
SetMulticastInterface(uint32_t interface_index)1119 int UDPSocketWin::SetMulticastInterface(uint32_t interface_index) {
1120   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1121   if (is_connected())
1122     return ERR_SOCKET_IS_CONNECTED;
1123   multicast_interface_ = interface_index;
1124   return OK;
1125 }
1126 
SetMulticastTimeToLive(int time_to_live)1127 int UDPSocketWin::SetMulticastTimeToLive(int time_to_live) {
1128   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1129   if (is_connected())
1130     return ERR_SOCKET_IS_CONNECTED;
1131 
1132   if (time_to_live < 0 || time_to_live > 255)
1133     return ERR_INVALID_ARGUMENT;
1134   multicast_time_to_live_ = time_to_live;
1135   return OK;
1136 }
1137 
SetMulticastLoopbackMode(bool loopback)1138 int UDPSocketWin::SetMulticastLoopbackMode(bool loopback) {
1139   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1140   if (is_connected())
1141     return ERR_SOCKET_IS_CONNECTED;
1142 
1143   if (loopback)
1144     socket_options_ |= SOCKET_OPTION_MULTICAST_LOOP;
1145   else
1146     socket_options_ &= ~SOCKET_OPTION_MULTICAST_LOOP;
1147   return OK;
1148 }
1149 
DscpToTrafficType(DiffServCodePoint dscp)1150 QOS_TRAFFIC_TYPE DscpToTrafficType(DiffServCodePoint dscp) {
1151   QOS_TRAFFIC_TYPE traffic_type = QOSTrafficTypeBestEffort;
1152   switch (dscp) {
1153     case DSCP_CS0:
1154       traffic_type = QOSTrafficTypeBestEffort;
1155       break;
1156     case DSCP_CS1:
1157       traffic_type = QOSTrafficTypeBackground;
1158       break;
1159     case DSCP_AF11:
1160     case DSCP_AF12:
1161     case DSCP_AF13:
1162     case DSCP_CS2:
1163     case DSCP_AF21:
1164     case DSCP_AF22:
1165     case DSCP_AF23:
1166     case DSCP_CS3:
1167     case DSCP_AF31:
1168     case DSCP_AF32:
1169     case DSCP_AF33:
1170     case DSCP_CS4:
1171       traffic_type = QOSTrafficTypeExcellentEffort;
1172       break;
1173     case DSCP_AF41:
1174     case DSCP_AF42:
1175     case DSCP_AF43:
1176     case DSCP_CS5:
1177       traffic_type = QOSTrafficTypeAudioVideo;
1178       break;
1179     case DSCP_EF:
1180     case DSCP_CS6:
1181       traffic_type = QOSTrafficTypeVoice;
1182       break;
1183     case DSCP_CS7:
1184       traffic_type = QOSTrafficTypeControl;
1185       break;
1186     case DSCP_NO_CHANGE:
1187       NOTREACHED();
1188       break;
1189   }
1190   return traffic_type;
1191 }
1192 
SetDiffServCodePoint(DiffServCodePoint dscp)1193 int UDPSocketWin::SetDiffServCodePoint(DiffServCodePoint dscp) {
1194   if (dscp == DSCP_NO_CHANGE)
1195     return OK;
1196 
1197   if (!is_connected())
1198     return ERR_SOCKET_NOT_CONNECTED;
1199 
1200   QwaveApi* api = GetQwaveApi();
1201 
1202   if (!api->qwave_supported())
1203     return ERR_NOT_IMPLEMENTED;
1204 
1205   if (!dscp_manager_)
1206     dscp_manager_ = std::make_unique<DscpManager>(api, socket_);
1207 
1208   dscp_manager_->Set(dscp);
1209   if (remote_address_)
1210     return dscp_manager_->PrepareForSend(*remote_address_.get());
1211 
1212   return OK;
1213 }
1214 
SetIPv6Only(bool ipv6_only)1215 int UDPSocketWin::SetIPv6Only(bool ipv6_only) {
1216   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1217   if (is_connected()) {
1218     return ERR_SOCKET_IS_CONNECTED;
1219   }
1220   return net::SetIPv6Only(socket_, ipv6_only);
1221 }
1222 
DetachFromThread()1223 void UDPSocketWin::DetachFromThread() {
1224   DETACH_FROM_THREAD(thread_checker_);
1225 }
1226 
UseNonBlockingIO()1227 void UDPSocketWin::UseNonBlockingIO() {
1228   DCHECK(!core_);
1229   use_non_blocking_io_ = true;
1230 }
1231 
ApplySocketTag(const SocketTag & tag)1232 void UDPSocketWin::ApplySocketTag(const SocketTag& tag) {
1233   // Windows does not support any specific SocketTags so fail if any non-default
1234   // tag is applied.
1235   CHECK(tag == SocketTag());
1236 }
1237 
DscpManager(QwaveApi * api,SOCKET socket)1238 DscpManager::DscpManager(QwaveApi* api, SOCKET socket)
1239     : api_(api), socket_(socket) {
1240   RequestHandle();
1241 }
1242 
~DscpManager()1243 DscpManager::~DscpManager() {
1244   if (!qos_handle_)
1245     return;
1246 
1247   if (flow_id_ != 0)
1248     api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1249 
1250   api_->CloseHandle(qos_handle_);
1251 }
1252 
Set(DiffServCodePoint dscp)1253 void DscpManager::Set(DiffServCodePoint dscp) {
1254   if (dscp == DSCP_NO_CHANGE || dscp == dscp_value_)
1255     return;
1256 
1257   dscp_value_ = dscp;
1258 
1259   // TODO(zstein): We could reuse the flow when the value changes
1260   // by calling QOSSetFlow with the new traffic type and dscp value.
1261   if (flow_id_ != 0 && qos_handle_) {
1262     api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1263     configured_.clear();
1264     flow_id_ = 0;
1265   }
1266 }
1267 
PrepareForSend(const IPEndPoint & remote_address)1268 int DscpManager::PrepareForSend(const IPEndPoint& remote_address) {
1269   if (dscp_value_ == DSCP_NO_CHANGE) {
1270     // No DSCP value has been set.
1271     return OK;
1272   }
1273 
1274   if (!api_->qwave_supported())
1275     return ERR_NOT_IMPLEMENTED;
1276 
1277   if (!qos_handle_)
1278     return ERR_INVALID_HANDLE;  // The closest net error to try again later.
1279 
1280   if (configured_.find(remote_address) != configured_.end())
1281     return OK;
1282 
1283   SockaddrStorage storage;
1284   if (!remote_address.ToSockAddr(storage.addr, &storage.addr_len))
1285     return ERR_ADDRESS_INVALID;
1286 
1287   // We won't try this address again if we get an error.
1288   configured_.emplace(remote_address);
1289 
1290   // We don't need to call SetFlow if we already have a qos flow.
1291   bool new_flow = flow_id_ == 0;
1292 
1293   const QOS_TRAFFIC_TYPE traffic_type = DscpToTrafficType(dscp_value_);
1294 
1295   if (!api_->AddSocketToFlow(qos_handle_, socket_, storage.addr, traffic_type,
1296                              QOS_NON_ADAPTIVE_FLOW, &flow_id_)) {
1297     DWORD err = ::GetLastError();
1298     if (err == ERROR_DEVICE_REINITIALIZATION_NEEDED) {
1299       // Reset. PrepareForSend is called for every packet.  Once RequestHandle
1300       // completes asynchronously the next PrepareForSend call will re-register
1301       // the address with the new QoS Handle.  In the meantime, sends will
1302       // continue without DSCP.
1303       RequestHandle();
1304       configured_.clear();
1305       flow_id_ = 0;
1306       return ERR_INVALID_HANDLE;
1307     }
1308     return MapSystemError(err);
1309   }
1310 
1311   if (new_flow) {
1312     DWORD buf = dscp_value_;
1313     // This requires admin rights, and may fail, if so we ignore it
1314     // as AddSocketToFlow should still do *approximately* the right thing.
1315     api_->SetFlow(qos_handle_, flow_id_, QOSSetOutgoingDSCPValue, sizeof(buf),
1316                   &buf, 0, nullptr);
1317   }
1318 
1319   return OK;
1320 }
1321 
RequestHandle()1322 void DscpManager::RequestHandle() {
1323   if (handle_is_initializing_)
1324     return;
1325 
1326   if (qos_handle_) {
1327     api_->CloseHandle(qos_handle_);
1328     qos_handle_ = nullptr;
1329   }
1330 
1331   handle_is_initializing_ = true;
1332   base::ThreadPool::PostTaskAndReplyWithResult(
1333       FROM_HERE, {base::MayBlock()},
1334       base::BindOnce(&DscpManager::DoCreateHandle, api_),
1335       base::BindOnce(&DscpManager::OnHandleCreated, api_,
1336                      weak_ptr_factory_.GetWeakPtr()));
1337 }
1338 
DoCreateHandle(QwaveApi * api)1339 HANDLE DscpManager::DoCreateHandle(QwaveApi* api) {
1340   QOS_VERSION version;
1341   version.MajorVersion = 1;
1342   version.MinorVersion = 0;
1343 
1344   HANDLE handle = nullptr;
1345 
1346   // No access to net_log_ so swallow any errors here.
1347   api->CreateHandle(&version, &handle);
1348   return handle;
1349 }
1350 
OnHandleCreated(QwaveApi * api,base::WeakPtr<DscpManager> dscp_manager,HANDLE handle)1351 void DscpManager::OnHandleCreated(QwaveApi* api,
1352                                   base::WeakPtr<DscpManager> dscp_manager,
1353                                   HANDLE handle) {
1354   if (!handle)
1355     api->OnFatalError();
1356 
1357   if (!dscp_manager) {
1358     api->CloseHandle(handle);
1359     return;
1360   }
1361 
1362   DCHECK(dscp_manager->handle_is_initializing_);
1363   DCHECK(!dscp_manager->qos_handle_);
1364 
1365   dscp_manager->qos_handle_ = handle;
1366   dscp_manager->handle_is_initializing_ = false;
1367 }
1368 
1369 }  // namespace net
1370