• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/udp_socket_win.h"
6 
7 #include <mstcpip.h>
8 #include <winsock2.h>
9 
10 #include <memory>
11 
12 #include "base/check_op.h"
13 #include "base/functional/bind.h"
14 #include "base/functional/callback.h"
15 #include "base/lazy_instance.h"
16 #include "base/memory/raw_ptr.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/metrics/histogram_macros.h"
19 #include "base/notreached.h"
20 #include "base/rand_util.h"
21 #include "base/task/thread_pool.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/ip_address.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/network_activity_monitor.h"
27 #include "net/base/network_change_notifier.h"
28 #include "net/base/sockaddr_storage.h"
29 #include "net/base/winsock_init.h"
30 #include "net/base/winsock_util.h"
31 #include "net/log/net_log.h"
32 #include "net/log/net_log_event_type.h"
33 #include "net/log/net_log_source.h"
34 #include "net/log/net_log_source_type.h"
35 #include "net/socket/socket_descriptor.h"
36 #include "net/socket/socket_options.h"
37 #include "net/socket/socket_tag.h"
38 #include "net/socket/udp_net_log_parameters.h"
39 #include "net/traffic_annotation/network_traffic_annotation.h"
40 
41 namespace net {
42 
43 // This class encapsulates all the state that has to be preserved as long as
44 // there is a network IO operation in progress. If the owner UDPSocketWin
45 // is destroyed while an operation is in progress, the Core is detached and it
46 // lives until the operation completes and the OS doesn't reference any resource
47 // declared on this class anymore.
48 class UDPSocketWin::Core : public base::RefCounted<Core> {
49  public:
50   explicit Core(UDPSocketWin* socket);
51 
52   Core(const Core&) = delete;
53   Core& operator=(const Core&) = delete;
54 
55   // Start watching for the end of a read or write operation.
56   void WatchForRead();
57   void WatchForWrite();
58 
59   // The UDPSocketWin is going away.
Detach()60   void Detach() { socket_ = nullptr; }
61 
62   // The separate OVERLAPPED variables for asynchronous operation.
63   OVERLAPPED read_overlapped_;
64   OVERLAPPED write_overlapped_;
65 
66   // The buffers used in Read() and Write().
67   scoped_refptr<IOBuffer> read_iobuffer_;
68   scoped_refptr<IOBuffer> write_iobuffer_;
69 
70   // The address storage passed to WSARecvFrom().
71   SockaddrStorage recv_addr_storage_;
72 
73  private:
74   friend class base::RefCounted<Core>;
75 
76   class ReadDelegate : public base::win::ObjectWatcher::Delegate {
77    public:
ReadDelegate(Core * core)78     explicit ReadDelegate(Core* core) : core_(core) {}
79     ~ReadDelegate() override = default;
80 
81     // base::ObjectWatcher::Delegate methods:
82     void OnObjectSignaled(HANDLE object) override;
83 
84    private:
85     const raw_ptr<Core> core_;
86   };
87 
88   class WriteDelegate : public base::win::ObjectWatcher::Delegate {
89    public:
WriteDelegate(Core * core)90     explicit WriteDelegate(Core* core) : core_(core) {}
91     ~WriteDelegate() override = default;
92 
93     // base::ObjectWatcher::Delegate methods:
94     void OnObjectSignaled(HANDLE object) override;
95 
96    private:
97     const raw_ptr<Core> core_;
98   };
99 
100   ~Core();
101 
102   // The socket that created this object.
103   raw_ptr<UDPSocketWin> socket_;
104 
105   // |reader_| handles the signals from |read_watcher_|.
106   ReadDelegate reader_;
107   // |writer_| handles the signals from |write_watcher_|.
108   WriteDelegate writer_;
109 
110   // |read_watcher_| watches for events from Read().
111   base::win::ObjectWatcher read_watcher_;
112   // |write_watcher_| watches for events from Write();
113   base::win::ObjectWatcher write_watcher_;
114 };
115 
Core(UDPSocketWin * socket)116 UDPSocketWin::Core::Core(UDPSocketWin* socket)
117     : socket_(socket),
118       reader_(this),
119       writer_(this) {
120   memset(&read_overlapped_, 0, sizeof(read_overlapped_));
121   memset(&write_overlapped_, 0, sizeof(write_overlapped_));
122 
123   read_overlapped_.hEvent = WSACreateEvent();
124   write_overlapped_.hEvent = WSACreateEvent();
125 }
126 
~Core()127 UDPSocketWin::Core::~Core() {
128   // Make sure the message loop is not watching this object anymore.
129   read_watcher_.StopWatching();
130   write_watcher_.StopWatching();
131 
132   WSACloseEvent(read_overlapped_.hEvent);
133   memset(&read_overlapped_, 0xaf, sizeof(read_overlapped_));
134   WSACloseEvent(write_overlapped_.hEvent);
135   memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_));
136 }
137 
WatchForRead()138 void UDPSocketWin::Core::WatchForRead() {
139   // We grab an extra reference because there is an IO operation in progress.
140   // Balanced in ReadDelegate::OnObjectSignaled().
141   AddRef();
142   read_watcher_.StartWatchingOnce(read_overlapped_.hEvent, &reader_);
143 }
144 
WatchForWrite()145 void UDPSocketWin::Core::WatchForWrite() {
146   // We grab an extra reference because there is an IO operation in progress.
147   // Balanced in WriteDelegate::OnObjectSignaled().
148   AddRef();
149   write_watcher_.StartWatchingOnce(write_overlapped_.hEvent, &writer_);
150 }
151 
OnObjectSignaled(HANDLE object)152 void UDPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
153   DCHECK_EQ(object, core_->read_overlapped_.hEvent);
154   if (core_->socket_)
155     core_->socket_->DidCompleteRead();
156 
157   core_->Release();
158 }
159 
OnObjectSignaled(HANDLE object)160 void UDPSocketWin::Core::WriteDelegate::OnObjectSignaled(HANDLE object) {
161   DCHECK_EQ(object, core_->write_overlapped_.hEvent);
162   if (core_->socket_)
163     core_->socket_->DidCompleteWrite();
164 
165   core_->Release();
166 }
167 //-----------------------------------------------------------------------------
168 
QwaveApi()169 QwaveApi::QwaveApi() {
170   HMODULE qwave = LoadLibrary(L"qwave.dll");
171   if (!qwave)
172     return;
173   create_handle_func_ =
174       (CreateHandleFn)GetProcAddress(qwave, "QOSCreateHandle");
175   close_handle_func_ =
176       (CloseHandleFn)GetProcAddress(qwave, "QOSCloseHandle");
177   add_socket_to_flow_func_ =
178       (AddSocketToFlowFn)GetProcAddress(qwave, "QOSAddSocketToFlow");
179   remove_socket_from_flow_func_ =
180       (RemoveSocketFromFlowFn)GetProcAddress(qwave, "QOSRemoveSocketFromFlow");
181   set_flow_func_ = (SetFlowFn)GetProcAddress(qwave, "QOSSetFlow");
182 
183   if (create_handle_func_ && close_handle_func_ &&
184       add_socket_to_flow_func_ && remove_socket_from_flow_func_ &&
185       set_flow_func_) {
186     qwave_supported_ = true;
187   }
188 }
189 
GetDefault()190 QwaveApi* QwaveApi::GetDefault() {
191   static base::LazyInstance<QwaveApi>::Leaky lazy_qwave =
192       LAZY_INSTANCE_INITIALIZER;
193   return lazy_qwave.Pointer();
194 }
195 
qwave_supported() const196 bool QwaveApi::qwave_supported() const {
197   return qwave_supported_;
198 }
199 
OnFatalError()200 void QwaveApi::OnFatalError() {
201   // Disable everything moving forward.
202   qwave_supported_ = false;
203 }
204 
CreateHandle(PQOS_VERSION version,PHANDLE handle)205 BOOL QwaveApi::CreateHandle(PQOS_VERSION version, PHANDLE handle) {
206   return create_handle_func_(version, handle);
207 }
208 
CloseHandle(HANDLE handle)209 BOOL QwaveApi::CloseHandle(HANDLE handle) {
210   return close_handle_func_(handle);
211 }
212 
AddSocketToFlow(HANDLE handle,SOCKET socket,PSOCKADDR addr,QOS_TRAFFIC_TYPE traffic_type,DWORD flags,PQOS_FLOWID flow_id)213 BOOL QwaveApi::AddSocketToFlow(HANDLE handle,
214                                SOCKET socket,
215                                PSOCKADDR addr,
216                                QOS_TRAFFIC_TYPE traffic_type,
217                                DWORD flags,
218                                PQOS_FLOWID flow_id) {
219   return add_socket_to_flow_func_(handle, socket, addr, traffic_type, flags,
220                                   flow_id);
221 }
222 
RemoveSocketFromFlow(HANDLE handle,SOCKET socket,QOS_FLOWID flow_id,DWORD reserved)223 BOOL QwaveApi::RemoveSocketFromFlow(HANDLE handle,
224                                     SOCKET socket,
225                                     QOS_FLOWID flow_id,
226                                     DWORD reserved) {
227   return remove_socket_from_flow_func_(handle, socket, flow_id, reserved);
228 }
229 
SetFlow(HANDLE handle,QOS_FLOWID flow_id,QOS_SET_FLOW op,ULONG size,PVOID data,DWORD reserved,LPOVERLAPPED overlapped)230 BOOL QwaveApi::SetFlow(HANDLE handle,
231                        QOS_FLOWID flow_id,
232                        QOS_SET_FLOW op,
233                        ULONG size,
234                        PVOID data,
235                        DWORD reserved,
236                        LPOVERLAPPED overlapped) {
237   return set_flow_func_(handle, flow_id, op, size, data, reserved, overlapped);
238 }
239 
240 //-----------------------------------------------------------------------------
241 
UDPSocketWin(DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)242 UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
243                            net::NetLog* net_log,
244                            const net::NetLogSource& source)
245     : socket_(INVALID_SOCKET),
246       socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
247       net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::UDP_SOCKET)) {
248   EnsureWinsockInit();
249   net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source);
250 }
251 
~UDPSocketWin()252 UDPSocketWin::~UDPSocketWin() {
253   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
254   Close();
255   net_log_.EndEvent(NetLogEventType::SOCKET_ALIVE);
256 }
257 
Open(AddressFamily address_family)258 int UDPSocketWin::Open(AddressFamily address_family) {
259   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
260   DCHECK_EQ(socket_, INVALID_SOCKET);
261 
262   auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
263   if (owned_socket_count.empty())
264     return ERR_INSUFFICIENT_RESOURCES;
265 
266   owned_socket_count_ = std::move(owned_socket_count);
267   addr_family_ = ConvertAddressFamily(address_family);
268   socket_ = CreatePlatformSocket(addr_family_, SOCK_DGRAM, IPPROTO_UDP);
269   if (socket_ == INVALID_SOCKET) {
270     owned_socket_count_.Reset();
271     return MapSystemError(WSAGetLastError());
272   }
273   ConfigureOpenedSocket();
274   return OK;
275 }
276 
AdoptOpenedSocket(AddressFamily address_family,SOCKET socket)277 int UDPSocketWin::AdoptOpenedSocket(AddressFamily address_family,
278                                     SOCKET socket) {
279   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
280   auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
281   if (owned_socket_count.empty()) {
282     return ERR_INSUFFICIENT_RESOURCES;
283   }
284 
285   owned_socket_count_ = std::move(owned_socket_count);
286   addr_family_ = ConvertAddressFamily(address_family);
287   socket_ = socket;
288   ConfigureOpenedSocket();
289   return OK;
290 }
291 
ConfigureOpenedSocket()292 void UDPSocketWin::ConfigureOpenedSocket() {
293   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
294   if (!use_non_blocking_io_) {
295     core_ = base::MakeRefCounted<Core>(this);
296   } else {
297     read_write_event_.Set(WSACreateEvent());
298     WSAEventSelect(socket_, read_write_event_.Get(), FD_READ | FD_WRITE);
299   }
300 }
301 
// Closes the socket and resets per-socket state. If no socket is open, only
// the global socket-count slot is released. Pending read/write callbacks are
// dropped without being run. Also called from the destructor.
// NOTE(review): remote_address_, local_address_ and send_to_address_ are not
// reset here; confirm whether reopening after Close() is meant to be supported.
void UDPSocketWin::Close() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  // Give back this object's slot in the global UDP socket count (if held).
  owned_socket_count_.Reset();

  if (socket_ == INVALID_SOCKET)
    return;

  // Remove socket_ from the QoS subsystem before we invalidate it.
  dscp_manager_ = nullptr;

  // Zero out any pending read/write callback state.
  read_callback_.Reset();
  recv_from_address_ = nullptr;
  write_callback_.Reset();

  // Record how long closesocket() takes in Net.UDPSocketWinClose.
  base::TimeTicks start_time = base::TimeTicks::Now();
  closesocket(socket_);
  UMA_HISTOGRAM_TIMES("Net.UDPSocketWinClose",
                      base::TimeTicks::Now() - start_time);
  socket_ = INVALID_SOCKET;
  addr_family_ = 0;
  is_connected_ = false;

  // Release buffers to free up memory.
  read_iobuffer_ = nullptr;
  read_iobuffer_len_ = 0;
  write_iobuffer_ = nullptr;
  write_iobuffer_len_ = 0;

  // Non-blocking mode: stop watching and close the shared read/write event.
  read_write_watcher_.StopWatching();
  read_write_event_.Close();

  // Invalidates the weak pointer OnObjectSignaled() takes, so an
  // OnObjectSignaled() whose callback led here stops before touching more
  // state.
  event_pending_.InvalidateWeakPtrs();

  // Overlapped mode: detach the Core so a completion that arrives later does
  // not call back into this object, then drop our reference to it.
  if (core_) {
    core_->Detach();
    core_ = nullptr;
  }
}
342 
GetPeerAddress(IPEndPoint * address) const343 int UDPSocketWin::GetPeerAddress(IPEndPoint* address) const {
344   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
345   DCHECK(address);
346   if (!is_connected())
347     return ERR_SOCKET_NOT_CONNECTED;
348 
349   // TODO(szym): Simplify. http://crbug.com/126152
350   if (!remote_address_.get()) {
351     SockaddrStorage storage;
352     if (getpeername(socket_, storage.addr, &storage.addr_len))
353       return MapSystemError(WSAGetLastError());
354     auto remote_address = std::make_unique<IPEndPoint>();
355     if (!remote_address->FromSockAddr(storage.addr, storage.addr_len))
356       return ERR_ADDRESS_INVALID;
357     remote_address_ = std::move(remote_address);
358   }
359 
360   *address = *remote_address_;
361   return OK;
362 }
363 
GetLocalAddress(IPEndPoint * address) const364 int UDPSocketWin::GetLocalAddress(IPEndPoint* address) const {
365   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
366   DCHECK(address);
367   if (!is_connected())
368     return ERR_SOCKET_NOT_CONNECTED;
369 
370   // TODO(szym): Simplify. http://crbug.com/126152
371   if (!local_address_.get()) {
372     SockaddrStorage storage;
373     if (getsockname(socket_, storage.addr, &storage.addr_len))
374       return MapSystemError(WSAGetLastError());
375     auto local_address = std::make_unique<IPEndPoint>();
376     if (!local_address->FromSockAddr(storage.addr, storage.addr_len))
377       return ERR_ADDRESS_INVALID;
378     local_address_ = std::move(local_address);
379     net_log_.AddEvent(NetLogEventType::UDP_LOCAL_ADDRESS, [&] {
380       return CreateNetLogUDPConnectParams(*local_address_,
381                                           handles::kInvalidNetworkHandle);
382     });
383   }
384 
385   *address = *local_address_;
386   return OK;
387 }
388 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)389 int UDPSocketWin::Read(IOBuffer* buf,
390                        int buf_len,
391                        CompletionOnceCallback callback) {
392   return RecvFrom(buf, buf_len, nullptr, std::move(callback));
393 }
394 
RecvFrom(IOBuffer * buf,int buf_len,IPEndPoint * address,CompletionOnceCallback callback)395 int UDPSocketWin::RecvFrom(IOBuffer* buf,
396                            int buf_len,
397                            IPEndPoint* address,
398                            CompletionOnceCallback callback) {
399   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
400   DCHECK_NE(INVALID_SOCKET, socket_);
401   CHECK(read_callback_.is_null());
402   DCHECK(!recv_from_address_);
403   DCHECK(!callback.is_null());  // Synchronous operation not supported.
404   DCHECK_GT(buf_len, 0);
405 
406   int nread = core_ ? InternalRecvFromOverlapped(buf, buf_len, address)
407                     : InternalRecvFromNonBlocking(buf, buf_len, address);
408   if (nread != ERR_IO_PENDING)
409     return nread;
410 
411   read_callback_ = std::move(callback);
412   recv_from_address_ = address;
413   return ERR_IO_PENDING;
414 }
415 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)416 int UDPSocketWin::Write(
417     IOBuffer* buf,
418     int buf_len,
419     CompletionOnceCallback callback,
420     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
421   return SendToOrWrite(buf, buf_len, remote_address_.get(),
422                        std::move(callback));
423 }
424 
SendTo(IOBuffer * buf,int buf_len,const IPEndPoint & address,CompletionOnceCallback callback)425 int UDPSocketWin::SendTo(IOBuffer* buf,
426                          int buf_len,
427                          const IPEndPoint& address,
428                          CompletionOnceCallback callback) {
429   if (dscp_manager_) {
430     // Alert DscpManager in case this is a new remote address.  Failure to
431     // apply Dscp code is never fatal.
432     int rv = dscp_manager_->PrepareForSend(address);
433     if (rv != OK)
434       net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, rv);
435   }
436   return SendToOrWrite(buf, buf_len, &address, std::move(callback));
437 }
438 
SendToOrWrite(IOBuffer * buf,int buf_len,const IPEndPoint * address,CompletionOnceCallback callback)439 int UDPSocketWin::SendToOrWrite(IOBuffer* buf,
440                                 int buf_len,
441                                 const IPEndPoint* address,
442                                 CompletionOnceCallback callback) {
443   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
444   DCHECK_NE(INVALID_SOCKET, socket_);
445   CHECK(write_callback_.is_null());
446   DCHECK(!callback.is_null());  // Synchronous operation not supported.
447   DCHECK_GT(buf_len, 0);
448   DCHECK(!send_to_address_.get());
449 
450   int nwrite = core_ ? InternalSendToOverlapped(buf, buf_len, address)
451                      : InternalSendToNonBlocking(buf, buf_len, address);
452   if (nwrite != ERR_IO_PENDING)
453     return nwrite;
454 
455   if (address)
456     send_to_address_ = std::make_unique<IPEndPoint>(*address);
457   write_callback_ = std::move(callback);
458   return ERR_IO_PENDING;
459 }
460 
Connect(const IPEndPoint & address)461 int UDPSocketWin::Connect(const IPEndPoint& address) {
462   DCHECK_NE(socket_, INVALID_SOCKET);
463   net_log_.BeginEvent(NetLogEventType::UDP_CONNECT, [&] {
464     return CreateNetLogUDPConnectParams(address,
465                                         handles::kInvalidNetworkHandle);
466   });
467   int rv = SetMulticastOptions();
468   if (rv != OK)
469     return rv;
470   rv = InternalConnect(address);
471   net_log_.EndEventWithNetErrorCode(NetLogEventType::UDP_CONNECT, rv);
472   is_connected_ = (rv == OK);
473   return rv;
474 }
475 
InternalConnect(const IPEndPoint & address)476 int UDPSocketWin::InternalConnect(const IPEndPoint& address) {
477   DCHECK(!is_connected());
478   DCHECK(!remote_address_.get());
479 
480   // Always do a random bind.
481   // Ignore failures, which may happen if the socket was already bound.
482   DWORD randomize_port_value = 1;
483   setsockopt(socket_, SOL_SOCKET, SO_RANDOMIZE_PORT,
484              reinterpret_cast<const char*>(&randomize_port_value),
485              sizeof(randomize_port_value));
486 
487   SockaddrStorage storage;
488   if (!address.ToSockAddr(storage.addr, &storage.addr_len))
489     return ERR_ADDRESS_INVALID;
490 
491   int rv = connect(socket_, storage.addr, storage.addr_len);
492   if (rv < 0)
493     return MapSystemError(WSAGetLastError());
494 
495   remote_address_ = std::make_unique<IPEndPoint>(address);
496 
497   if (dscp_manager_)
498     dscp_manager_->PrepareForSend(*remote_address_.get());
499 
500   return rv;
501 }
502 
Bind(const IPEndPoint & address)503 int UDPSocketWin::Bind(const IPEndPoint& address) {
504   DCHECK_NE(socket_, INVALID_SOCKET);
505   DCHECK(!is_connected());
506 
507   int rv = SetMulticastOptions();
508   if (rv < 0)
509     return rv;
510 
511   rv = DoBind(address);
512   if (rv < 0)
513     return rv;
514 
515   local_address_.reset();
516   is_connected_ = true;
517   return rv;
518 }
519 
BindToNetwork(handles::NetworkHandle network)520 int UDPSocketWin::BindToNetwork(handles::NetworkHandle network) {
521   NOTIMPLEMENTED();
522   return ERR_NOT_IMPLEMENTED;
523 }
524 
SetReceiveBufferSize(int32_t size)525 int UDPSocketWin::SetReceiveBufferSize(int32_t size) {
526   DCHECK_NE(socket_, INVALID_SOCKET);
527   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
528   int rv = SetSocketReceiveBufferSize(socket_, size);
529 
530   if (rv != 0)
531     return MapSystemError(WSAGetLastError());
532 
533   // According to documentation, setsockopt may succeed, but we need to check
534   // the results via getsockopt to be sure it works on Windows.
535   int32_t actual_size = 0;
536   int option_size = sizeof(actual_size);
537   rv = getsockopt(socket_, SOL_SOCKET, SO_RCVBUF,
538                   reinterpret_cast<char*>(&actual_size), &option_size);
539   if (rv != 0)
540     return MapSystemError(WSAGetLastError());
541   if (actual_size >= size)
542     return OK;
543   UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableReceiveBuffer",
544                               actual_size, 1000, 1000000, 50);
545   return ERR_SOCKET_RECEIVE_BUFFER_SIZE_UNCHANGEABLE;
546 }
547 
SetSendBufferSize(int32_t size)548 int UDPSocketWin::SetSendBufferSize(int32_t size) {
549   DCHECK_NE(socket_, INVALID_SOCKET);
550   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
551   int rv = SetSocketSendBufferSize(socket_, size);
552   if (rv != 0)
553     return MapSystemError(WSAGetLastError());
554   // According to documentation, setsockopt may succeed, but we need to check
555   // the results via getsockopt to be sure it works on Windows.
556   int32_t actual_size = 0;
557   int option_size = sizeof(actual_size);
558   rv = getsockopt(socket_, SOL_SOCKET, SO_SNDBUF,
559                   reinterpret_cast<char*>(&actual_size), &option_size);
560   if (rv != 0)
561     return MapSystemError(WSAGetLastError());
562   if (actual_size >= size)
563     return OK;
564   UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableSendBuffer",
565                               actual_size, 1000, 1000000, 50);
566   return ERR_SOCKET_SEND_BUFFER_SIZE_UNCHANGEABLE;
567 }
568 
SetDoNotFragment()569 int UDPSocketWin::SetDoNotFragment() {
570   DCHECK_NE(socket_, INVALID_SOCKET);
571   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
572 
573   if (addr_family_ == AF_INET6)
574     return OK;
575 
576   DWORD val = 1;
577   int rv = setsockopt(socket_, IPPROTO_IP, IP_DONTFRAGMENT,
578                       reinterpret_cast<const char*>(&val), sizeof(val));
579   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
580 }
581 
SetMsgConfirm(bool confirm)582 void UDPSocketWin::SetMsgConfirm(bool confirm) {}
583 
AllowAddressReuse()584 int UDPSocketWin::AllowAddressReuse() {
585   DCHECK_NE(socket_, INVALID_SOCKET);
586   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
587   DCHECK(!is_connected());
588 
589   BOOL true_value = TRUE;
590   int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
591                       reinterpret_cast<const char*>(&true_value),
592                       sizeof(true_value));
593   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
594 }
595 
SetBroadcast(bool broadcast)596 int UDPSocketWin::SetBroadcast(bool broadcast) {
597   DCHECK_NE(socket_, INVALID_SOCKET);
598   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
599 
600   BOOL value = broadcast ? TRUE : FALSE;
601   int rv = setsockopt(socket_, SOL_SOCKET, SO_BROADCAST,
602                       reinterpret_cast<const char*>(&value), sizeof(value));
603   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
604 }
605 
AllowAddressSharingForMulticast()606 int UDPSocketWin::AllowAddressSharingForMulticast() {
607   // When proper multicast groups are used, Windows further defines the address
608   // resuse option (SO_REUSEADDR) to ensure all listening sockets can receive
609   // all incoming messages for the multicast group.
610   return AllowAddressReuse();
611 }
612 
DoReadCallback(int rv)613 void UDPSocketWin::DoReadCallback(int rv) {
614   DCHECK_NE(rv, ERR_IO_PENDING);
615   DCHECK(!read_callback_.is_null());
616 
617   // since Run may result in Read being called, clear read_callback_ up front.
618   std::move(read_callback_).Run(rv);
619 }
620 
DoWriteCallback(int rv)621 void UDPSocketWin::DoWriteCallback(int rv) {
622   DCHECK_NE(rv, ERR_IO_PENDING);
623   DCHECK(!write_callback_.is_null());
624 
625   // since Run may result in Write being called, clear write_callback_ up front.
626   std::move(write_callback_).Run(rv);
627 }
628 
DidCompleteRead()629 void UDPSocketWin::DidCompleteRead() {
630   DWORD num_bytes, flags;
631   BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_,
632                                    &num_bytes, FALSE, &flags);
633   WSAResetEvent(core_->read_overlapped_.hEvent);
634   int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
635   // Convert address.
636   IPEndPoint address;
637   IPEndPoint* address_to_log = nullptr;
638   if (result >= 0) {
639     if (address.FromSockAddr(core_->recv_addr_storage_.addr,
640                              core_->recv_addr_storage_.addr_len)) {
641       if (recv_from_address_)
642         *recv_from_address_ = address;
643       address_to_log = &address;
644     } else {
645       result = ERR_ADDRESS_INVALID;
646     }
647   }
648   LogRead(result, core_->read_iobuffer_->data(), address_to_log);
649   core_->read_iobuffer_ = nullptr;
650   recv_from_address_ = nullptr;
651   DoReadCallback(result);
652 }
653 
DidCompleteWrite()654 void UDPSocketWin::DidCompleteWrite() {
655   DWORD num_bytes, flags;
656   BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_,
657                                    &num_bytes, FALSE, &flags);
658   WSAResetEvent(core_->write_overlapped_.hEvent);
659   int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
660   LogWrite(result, core_->write_iobuffer_->data(), send_to_address_.get());
661 
662   send_to_address_.reset();
663   core_->write_iobuffer_ = nullptr;
664   DoWriteCallback(result);
665 }
666 
// base::win::ObjectWatcher::Delegate callback for non-blocking mode. It runs
// when |read_write_event_|, registered with WSAEventSelect() for
// FD_READ | FD_WRITE, is signaled. It services the pending read and/or write,
// then re-arms the watcher if either is still outstanding.
void UDPSocketWin::OnObjectSignaled(HANDLE object) {
  DCHECK(object == read_write_event_.Get());
  WSANETWORKEVENTS network_events;
  int os_error = 0;
  // Collect the network events recorded for the socket.
  int rv =
      WSAEnumNetworkEvents(socket_, read_write_event_.Get(), &network_events);
  // Protects against trying to call the write callback if the read callback
  // either closes or destroys |this|.
  base::WeakPtr<UDPSocketWin> event_pending = event_pending_.GetWeakPtr();
  if (rv == SOCKET_ERROR) {
    // The events could not be queried: fail each pending operation with the
    // mapped error.
    os_error = WSAGetLastError();
    rv = MapSystemError(os_error);

    if (read_iobuffer_) {
      read_iobuffer_ = nullptr;
      read_iobuffer_len_ = 0;
      recv_from_address_ = nullptr;
      DoReadCallback(rv);
    }

    // Socket may have been closed or destroyed here.
    if (event_pending && write_iobuffer_) {
      write_iobuffer_ = nullptr;
      write_iobuffer_len_ = 0;
      send_to_address_.reset();
      DoWriteCallback(rv);
    }
    return;
  }

  // Retry the pending read if the socket reported readability.
  if ((network_events.lNetworkEvents & FD_READ) && read_iobuffer_)
    OnReadSignaled();
  // The read callback may have closed or destroyed |this|.
  if (!event_pending)
    return;

  // Retry the pending write if the socket reported writability.
  if ((network_events.lNetworkEvents & FD_WRITE) && write_iobuffer_)
    OnWriteSignaled();
  // The write callback may have closed or destroyed |this|.
  if (!event_pending)
    return;

  // There's still pending read / write. Watch for further events.
  if (read_iobuffer_ || write_iobuffer_)
    WatchForReadWrite();
}
711 
OnReadSignaled()712 void UDPSocketWin::OnReadSignaled() {
713   int rv = InternalRecvFromNonBlocking(read_iobuffer_.get(), read_iobuffer_len_,
714                                        recv_from_address_);
715   if (rv == ERR_IO_PENDING)
716     return;
717   read_iobuffer_ = nullptr;
718   read_iobuffer_len_ = 0;
719   recv_from_address_ = nullptr;
720   DoReadCallback(rv);
721 }
722 
OnWriteSignaled()723 void UDPSocketWin::OnWriteSignaled() {
724   int rv = InternalSendToNonBlocking(write_iobuffer_.get(), write_iobuffer_len_,
725                                      send_to_address_.get());
726   if (rv == ERR_IO_PENDING)
727     return;
728   write_iobuffer_ = nullptr;
729   write_iobuffer_len_ = 0;
730   send_to_address_.reset();
731   DoWriteCallback(rv);
732 }
733 
WatchForReadWrite()734 void UDPSocketWin::WatchForReadWrite() {
735   if (read_write_watcher_.IsWatching())
736     return;
737   bool watched =
738       read_write_watcher_.StartWatchingOnce(read_write_event_.Get(), this);
739   DCHECK(watched);
740 }
741 
LogRead(int result,const char * bytes,const IPEndPoint * address) const742 void UDPSocketWin::LogRead(int result,
743                            const char* bytes,
744                            const IPEndPoint* address) const {
745   if (result < 0) {
746     net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_RECEIVE_ERROR,
747                                       result);
748     return;
749   }
750 
751   if (net_log_.IsCapturing()) {
752     NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_RECEIVED, result,
753                           bytes, address);
754   }
755 
756   activity_monitor::IncrementBytesReceived(result);
757 }
758 
LogWrite(int result,const char * bytes,const IPEndPoint * address) const759 void UDPSocketWin::LogWrite(int result,
760                             const char* bytes,
761                             const IPEndPoint* address) const {
762   if (result < 0) {
763     net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, result);
764     return;
765   }
766 
767   if (net_log_.IsCapturing()) {
768     NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_SENT, result,
769                           bytes, address);
770   }
771 }
772 
// Overlapped-mode receive. Issues WSARecvFrom() using the Core's OVERLAPPED
// and address storage. If WSARecvFrom() returns 0 and the completion event is
// already signaled, the result is returned synchronously and |address|
// (if non-null) is filled in. Otherwise the Core starts watching for
// completion, holds a reference to |buf|, and ERR_IO_PENDING is returned; the
// result is later delivered through DidCompleteRead().
int UDPSocketWin::InternalRecvFromOverlapped(IOBuffer* buf,
                                             int buf_len,
                                             IPEndPoint* address) {
  DCHECK(!core_->read_iobuffer_.get());
  // The sender address is written into Core-owned storage, so it stays valid
  // if the receive completes asynchronously.
  SockaddrStorage& storage = core_->recv_addr_storage_;
  storage.addr_len = sizeof(storage.addr_storage);

  WSABUF read_buffer;
  read_buffer.buf = buf->data();
  read_buffer.len = buf_len;

  DWORD flags = 0;
  DWORD num;
  CHECK_NE(INVALID_SOCKET, socket_);
  int rv = WSARecvFrom(socket_, &read_buffer, 1, &num, &flags, storage.addr,
                       &storage.addr_len, &core_->read_overlapped_, nullptr);
  if (rv == 0) {
    // Presumably resets the event and reports whether it was signaled (helper
    // defined elsewhere). If true, the data is already here: finish
    // synchronously.
    if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) {
      int result = num;
      // Convert address.
      IPEndPoint address_storage;
      IPEndPoint* address_to_log = nullptr;
      if (result >= 0) {
        if (address_storage.FromSockAddr(core_->recv_addr_storage_.addr,
                                         core_->recv_addr_storage_.addr_len)) {
          if (address)
            *address = address_storage;
          address_to_log = &address_storage;
        } else {
          result = ERR_ADDRESS_INVALID;
        }
      }
      LogRead(result, buf->data(), address_to_log);
      return result;
    }
  } else {
    // WSA_IO_PENDING means the receive was queued; anything else is a failure.
    int os_error = WSAGetLastError();
    if (os_error != WSA_IO_PENDING) {
      int result = MapSystemError(os_error);
      LogRead(result, nullptr, nullptr);
      return result;
    }
  }
  // Asynchronous completion: keep the Core (and |buf|) alive until signaled.
  core_->WatchForRead();
  core_->read_iobuffer_ = buf;
  return ERR_IO_PENDING;
}
820 
InternalSendToOverlapped(IOBuffer * buf,int buf_len,const IPEndPoint * address)821 int UDPSocketWin::InternalSendToOverlapped(IOBuffer* buf,
822                                            int buf_len,
823                                            const IPEndPoint* address) {
824   DCHECK(!core_->write_iobuffer_.get());
825   SockaddrStorage storage;
826   struct sockaddr* addr = storage.addr;
827   // Convert address.
828   if (!address) {
829     addr = nullptr;
830     storage.addr_len = 0;
831   } else {
832     if (!address->ToSockAddr(addr, &storage.addr_len)) {
833       int result = ERR_ADDRESS_INVALID;
834       LogWrite(result, nullptr, nullptr);
835       return result;
836     }
837   }
838 
839   WSABUF write_buffer;
840   write_buffer.buf = buf->data();
841   write_buffer.len = buf_len;
842 
843   DWORD flags = 0;
844   DWORD num;
845   int rv = WSASendTo(socket_, &write_buffer, 1, &num, flags, addr,
846                      storage.addr_len, &core_->write_overlapped_, nullptr);
847   if (rv == 0) {
848     if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) {
849       int result = num;
850       LogWrite(result, buf->data(), address);
851       return result;
852     }
853   } else {
854     int os_error = WSAGetLastError();
855     if (os_error != WSA_IO_PENDING) {
856       int result = MapSystemError(os_error);
857       LogWrite(result, nullptr, nullptr);
858       return result;
859     }
860   }
861 
862   core_->WatchForWrite();
863   core_->write_iobuffer_ = buf;
864   return ERR_IO_PENDING;
865 }
866 
InternalRecvFromNonBlocking(IOBuffer * buf,int buf_len,IPEndPoint * address)867 int UDPSocketWin::InternalRecvFromNonBlocking(IOBuffer* buf,
868                                               int buf_len,
869                                               IPEndPoint* address) {
870   DCHECK(!read_iobuffer_ || read_iobuffer_.get() == buf);
871   SockaddrStorage storage;
872   storage.addr_len = sizeof(storage.addr_storage);
873 
874   CHECK_NE(INVALID_SOCKET, socket_);
875   int rv = recvfrom(socket_, buf->data(), buf_len, 0, storage.addr,
876                     &storage.addr_len);
877   if (rv == SOCKET_ERROR) {
878     int os_error = WSAGetLastError();
879     if (os_error == WSAEWOULDBLOCK) {
880       read_iobuffer_ = buf;
881       read_iobuffer_len_ = buf_len;
882       WatchForReadWrite();
883       return ERR_IO_PENDING;
884     }
885     rv = MapSystemError(os_error);
886     LogRead(rv, nullptr, nullptr);
887     return rv;
888   }
889   IPEndPoint address_storage;
890   IPEndPoint* address_to_log = nullptr;
891   if (rv >= 0) {
892     if (address_storage.FromSockAddr(storage.addr, storage.addr_len)) {
893       if (address)
894         *address = address_storage;
895       address_to_log = &address_storage;
896     } else {
897       rv = ERR_ADDRESS_INVALID;
898     }
899   }
900   LogRead(rv, buf->data(), address_to_log);
901   return rv;
902 }
903 
InternalSendToNonBlocking(IOBuffer * buf,int buf_len,const IPEndPoint * address)904 int UDPSocketWin::InternalSendToNonBlocking(IOBuffer* buf,
905                                             int buf_len,
906                                             const IPEndPoint* address) {
907   DCHECK(!write_iobuffer_ || write_iobuffer_.get() == buf);
908   SockaddrStorage storage;
909   struct sockaddr* addr = storage.addr;
910   // Convert address.
911   if (address) {
912     if (!address->ToSockAddr(addr, &storage.addr_len)) {
913       int result = ERR_ADDRESS_INVALID;
914       LogWrite(result, nullptr, nullptr);
915       return result;
916     }
917   } else {
918     addr = nullptr;
919     storage.addr_len = 0;
920   }
921 
922   int rv = sendto(socket_, buf->data(), buf_len, 0, addr, storage.addr_len);
923   if (rv == SOCKET_ERROR) {
924     int os_error = WSAGetLastError();
925     if (os_error == WSAEWOULDBLOCK) {
926       write_iobuffer_ = buf;
927       write_iobuffer_len_ = buf_len;
928       WatchForReadWrite();
929       return ERR_IO_PENDING;
930     }
931     rv = MapSystemError(os_error);
932     LogWrite(rv, nullptr, nullptr);
933     return rv;
934   }
935   LogWrite(rv, buf->data(), address);
936   return rv;
937 }
938 
SetMulticastOptions()939 int UDPSocketWin::SetMulticastOptions() {
940   if (!(socket_options_ & SOCKET_OPTION_MULTICAST_LOOP)) {
941     DWORD loop = 0;
942     int protocol_level =
943         addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
944     int option =
945         addr_family_ == AF_INET ? IP_MULTICAST_LOOP: IPV6_MULTICAST_LOOP;
946     int rv = setsockopt(socket_, protocol_level, option,
947                         reinterpret_cast<const char*>(&loop), sizeof(loop));
948     if (rv < 0)
949       return MapSystemError(WSAGetLastError());
950   }
951   if (multicast_time_to_live_ != 1) {
952     DWORD hops = multicast_time_to_live_;
953     int protocol_level =
954         addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
955     int option =
956         addr_family_ == AF_INET ? IP_MULTICAST_TTL: IPV6_MULTICAST_HOPS;
957     int rv = setsockopt(socket_, protocol_level, option,
958                         reinterpret_cast<const char*>(&hops), sizeof(hops));
959     if (rv < 0)
960       return MapSystemError(WSAGetLastError());
961   }
962   if (multicast_interface_ != 0) {
963     switch (addr_family_) {
964       case AF_INET: {
965         in_addr address;
966         address.s_addr = htonl(multicast_interface_);
967         int rv = setsockopt(socket_, IPPROTO_IP, IP_MULTICAST_IF,
968                             reinterpret_cast<const char*>(&address),
969                             sizeof(address));
970         if (rv)
971           return MapSystemError(WSAGetLastError());
972         break;
973       }
974       case AF_INET6: {
975         uint32_t interface_index = multicast_interface_;
976         int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_MULTICAST_IF,
977                             reinterpret_cast<const char*>(&interface_index),
978                             sizeof(interface_index));
979         if (rv)
980           return MapSystemError(WSAGetLastError());
981         break;
982       }
983       default:
984         NOTREACHED() << "Invalid address family";
985         return ERR_ADDRESS_INVALID;
986     }
987   }
988   return OK;
989 }
990 
DoBind(const IPEndPoint & address)991 int UDPSocketWin::DoBind(const IPEndPoint& address) {
992   SockaddrStorage storage;
993   if (!address.ToSockAddr(storage.addr, &storage.addr_len))
994     return ERR_ADDRESS_INVALID;
995   int rv = bind(socket_, storage.addr, storage.addr_len);
996   if (rv == 0)
997     return OK;
998   int last_error = WSAGetLastError();
999   // Map some codes that are special to bind() separately.
1000   // * WSAEACCES: If a port is already bound to a socket, WSAEACCES may be
1001   //   returned instead of WSAEADDRINUSE, depending on whether the socket
1002   //   option SO_REUSEADDR or SO_EXCLUSIVEADDRUSE is set and whether the
1003   //   conflicting socket is owned by a different user account. See the MSDN
1004   //   page "Using SO_REUSEADDR and SO_EXCLUSIVEADDRUSE" for the gory details.
1005   if (last_error == WSAEACCES || last_error == WSAEADDRNOTAVAIL)
1006     return ERR_ADDRESS_IN_USE;
1007   return MapSystemError(last_error);
1008 }
1009 
GetQwaveApi() const1010 QwaveApi* UDPSocketWin::GetQwaveApi() const {
1011   return QwaveApi::GetDefault();
1012 }
1013 
JoinGroup(const IPAddress & group_address) const1014 int UDPSocketWin::JoinGroup(const IPAddress& group_address) const {
1015   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1016   if (!is_connected())
1017     return ERR_SOCKET_NOT_CONNECTED;
1018 
1019   switch (group_address.size()) {
1020     case IPAddress::kIPv4AddressSize: {
1021       if (addr_family_ != AF_INET)
1022         return ERR_ADDRESS_INVALID;
1023       ip_mreq mreq;
1024       mreq.imr_interface.s_addr = htonl(multicast_interface_);
1025       memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1026              IPAddress::kIPv4AddressSize);
1027       int rv = setsockopt(socket_, IPPROTO_IP, IP_ADD_MEMBERSHIP,
1028                           reinterpret_cast<const char*>(&mreq),
1029                           sizeof(mreq));
1030       if (rv)
1031         return MapSystemError(WSAGetLastError());
1032       return OK;
1033     }
1034     case IPAddress::kIPv6AddressSize: {
1035       if (addr_family_ != AF_INET6)
1036         return ERR_ADDRESS_INVALID;
1037       ipv6_mreq mreq;
1038       mreq.ipv6mr_interface = multicast_interface_;
1039       memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1040              IPAddress::kIPv6AddressSize);
1041       int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
1042                           reinterpret_cast<const char*>(&mreq),
1043                           sizeof(mreq));
1044       if (rv)
1045         return MapSystemError(WSAGetLastError());
1046       return OK;
1047     }
1048     default:
1049       NOTREACHED() << "Invalid address family";
1050       return ERR_ADDRESS_INVALID;
1051   }
1052 }
1053 
LeaveGroup(const IPAddress & group_address) const1054 int UDPSocketWin::LeaveGroup(const IPAddress& group_address) const {
1055   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1056   if (!is_connected())
1057     return ERR_SOCKET_NOT_CONNECTED;
1058 
1059   switch (group_address.size()) {
1060     case IPAddress::kIPv4AddressSize: {
1061       if (addr_family_ != AF_INET)
1062         return ERR_ADDRESS_INVALID;
1063       ip_mreq mreq;
1064       mreq.imr_interface.s_addr = htonl(multicast_interface_);
1065       memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1066              IPAddress::kIPv4AddressSize);
1067       int rv = setsockopt(socket_, IPPROTO_IP, IP_DROP_MEMBERSHIP,
1068                           reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1069       if (rv)
1070         return MapSystemError(WSAGetLastError());
1071       return OK;
1072     }
1073     case IPAddress::kIPv6AddressSize: {
1074       if (addr_family_ != AF_INET6)
1075         return ERR_ADDRESS_INVALID;
1076       ipv6_mreq mreq;
1077       mreq.ipv6mr_interface = multicast_interface_;
1078       memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1079              IPAddress::kIPv6AddressSize);
1080       int rv = setsockopt(socket_, IPPROTO_IPV6, IP_DROP_MEMBERSHIP,
1081                           reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1082       if (rv)
1083         return MapSystemError(WSAGetLastError());
1084       return OK;
1085     }
1086     default:
1087       NOTREACHED() << "Invalid address family";
1088       return ERR_ADDRESS_INVALID;
1089   }
1090 }
1091 
SetMulticastInterface(uint32_t interface_index)1092 int UDPSocketWin::SetMulticastInterface(uint32_t interface_index) {
1093   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1094   if (is_connected())
1095     return ERR_SOCKET_IS_CONNECTED;
1096   multicast_interface_ = interface_index;
1097   return OK;
1098 }
1099 
SetMulticastTimeToLive(int time_to_live)1100 int UDPSocketWin::SetMulticastTimeToLive(int time_to_live) {
1101   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1102   if (is_connected())
1103     return ERR_SOCKET_IS_CONNECTED;
1104 
1105   if (time_to_live < 0 || time_to_live > 255)
1106     return ERR_INVALID_ARGUMENT;
1107   multicast_time_to_live_ = time_to_live;
1108   return OK;
1109 }
1110 
SetMulticastLoopbackMode(bool loopback)1111 int UDPSocketWin::SetMulticastLoopbackMode(bool loopback) {
1112   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1113   if (is_connected())
1114     return ERR_SOCKET_IS_CONNECTED;
1115 
1116   if (loopback)
1117     socket_options_ |= SOCKET_OPTION_MULTICAST_LOOP;
1118   else
1119     socket_options_ &= ~SOCKET_OPTION_MULTICAST_LOOP;
1120   return OK;
1121 }
1122 
DscpToTrafficType(DiffServCodePoint dscp)1123 QOS_TRAFFIC_TYPE DscpToTrafficType(DiffServCodePoint dscp) {
1124   QOS_TRAFFIC_TYPE traffic_type = QOSTrafficTypeBestEffort;
1125   switch (dscp) {
1126     case DSCP_CS0:
1127       traffic_type = QOSTrafficTypeBestEffort;
1128       break;
1129     case DSCP_CS1:
1130       traffic_type = QOSTrafficTypeBackground;
1131       break;
1132     case DSCP_AF11:
1133     case DSCP_AF12:
1134     case DSCP_AF13:
1135     case DSCP_CS2:
1136     case DSCP_AF21:
1137     case DSCP_AF22:
1138     case DSCP_AF23:
1139     case DSCP_CS3:
1140     case DSCP_AF31:
1141     case DSCP_AF32:
1142     case DSCP_AF33:
1143     case DSCP_CS4:
1144       traffic_type = QOSTrafficTypeExcellentEffort;
1145       break;
1146     case DSCP_AF41:
1147     case DSCP_AF42:
1148     case DSCP_AF43:
1149     case DSCP_CS5:
1150       traffic_type = QOSTrafficTypeAudioVideo;
1151       break;
1152     case DSCP_EF:
1153     case DSCP_CS6:
1154       traffic_type = QOSTrafficTypeVoice;
1155       break;
1156     case DSCP_CS7:
1157       traffic_type = QOSTrafficTypeControl;
1158       break;
1159     case DSCP_NO_CHANGE:
1160       NOTREACHED();
1161       break;
1162   }
1163   return traffic_type;
1164 }
1165 
SetDiffServCodePoint(DiffServCodePoint dscp)1166 int UDPSocketWin::SetDiffServCodePoint(DiffServCodePoint dscp) {
1167   if (dscp == DSCP_NO_CHANGE)
1168     return OK;
1169 
1170   if (!is_connected())
1171     return ERR_SOCKET_NOT_CONNECTED;
1172 
1173   QwaveApi* api = GetQwaveApi();
1174 
1175   if (!api->qwave_supported())
1176     return ERR_NOT_IMPLEMENTED;
1177 
1178   if (!dscp_manager_)
1179     dscp_manager_ = std::make_unique<DscpManager>(api, socket_);
1180 
1181   dscp_manager_->Set(dscp);
1182   if (remote_address_)
1183     return dscp_manager_->PrepareForSend(*remote_address_.get());
1184 
1185   return OK;
1186 }
1187 
SetIPv6Only(bool ipv6_only)1188 int UDPSocketWin::SetIPv6Only(bool ipv6_only) {
1189   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1190   if (is_connected()) {
1191     return ERR_SOCKET_IS_CONNECTED;
1192   }
1193   return net::SetIPv6Only(socket_, ipv6_only);
1194 }
1195 
DetachFromThread()1196 void UDPSocketWin::DetachFromThread() {
1197   DETACH_FROM_THREAD(thread_checker_);
1198 }
1199 
UseNonBlockingIO()1200 void UDPSocketWin::UseNonBlockingIO() {
1201   DCHECK(!core_);
1202   use_non_blocking_io_ = true;
1203 }
1204 
ApplySocketTag(const SocketTag & tag)1205 void UDPSocketWin::ApplySocketTag(const SocketTag& tag) {
1206   // Windows does not support any specific SocketTags so fail if any non-default
1207   // tag is applied.
1208   CHECK(tag == SocketTag());
1209 }
1210 
DscpManager(QwaveApi * api,SOCKET socket)1211 DscpManager::DscpManager(QwaveApi* api, SOCKET socket)
1212     : api_(api), socket_(socket) {
1213   RequestHandle();
1214 }
1215 
~DscpManager()1216 DscpManager::~DscpManager() {
1217   if (!qos_handle_)
1218     return;
1219 
1220   if (flow_id_ != 0)
1221     api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1222 
1223   api_->CloseHandle(qos_handle_);
1224 }
1225 
Set(DiffServCodePoint dscp)1226 void DscpManager::Set(DiffServCodePoint dscp) {
1227   if (dscp == DSCP_NO_CHANGE || dscp == dscp_value_)
1228     return;
1229 
1230   dscp_value_ = dscp;
1231 
1232   // TODO(zstein): We could reuse the flow when the value changes
1233   // by calling QOSSetFlow with the new traffic type and dscp value.
1234   if (flow_id_ != 0 && qos_handle_) {
1235     api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1236     configured_.clear();
1237     flow_id_ = 0;
1238   }
1239 }
1240 
PrepareForSend(const IPEndPoint & remote_address)1241 int DscpManager::PrepareForSend(const IPEndPoint& remote_address) {
1242   if (dscp_value_ == DSCP_NO_CHANGE) {
1243     // No DSCP value has been set.
1244     return OK;
1245   }
1246 
1247   if (!api_->qwave_supported())
1248     return ERR_NOT_IMPLEMENTED;
1249 
1250   if (!qos_handle_)
1251     return ERR_INVALID_HANDLE;  // The closest net error to try again later.
1252 
1253   if (configured_.find(remote_address) != configured_.end())
1254     return OK;
1255 
1256   SockaddrStorage storage;
1257   if (!remote_address.ToSockAddr(storage.addr, &storage.addr_len))
1258     return ERR_ADDRESS_INVALID;
1259 
1260   // We won't try this address again if we get an error.
1261   configured_.emplace(remote_address);
1262 
1263   // We don't need to call SetFlow if we already have a qos flow.
1264   bool new_flow = flow_id_ == 0;
1265 
1266   const QOS_TRAFFIC_TYPE traffic_type = DscpToTrafficType(dscp_value_);
1267 
1268   if (!api_->AddSocketToFlow(qos_handle_, socket_, storage.addr, traffic_type,
1269                              QOS_NON_ADAPTIVE_FLOW, &flow_id_)) {
1270     DWORD err = ::GetLastError();
1271     if (err == ERROR_DEVICE_REINITIALIZATION_NEEDED) {
1272       // Reset. PrepareForSend is called for every packet.  Once RequestHandle
1273       // completes asynchronously the next PrepareForSend call will re-register
1274       // the address with the new QoS Handle.  In the meantime, sends will
1275       // continue without DSCP.
1276       RequestHandle();
1277       configured_.clear();
1278       flow_id_ = 0;
1279       return ERR_INVALID_HANDLE;
1280     }
1281     return MapSystemError(err);
1282   }
1283 
1284   if (new_flow) {
1285     DWORD buf = dscp_value_;
1286     // This requires admin rights, and may fail, if so we ignore it
1287     // as AddSocketToFlow should still do *approximately* the right thing.
1288     api_->SetFlow(qos_handle_, flow_id_, QOSSetOutgoingDSCPValue, sizeof(buf),
1289                   &buf, 0, nullptr);
1290   }
1291 
1292   return OK;
1293 }
1294 
RequestHandle()1295 void DscpManager::RequestHandle() {
1296   if (handle_is_initializing_)
1297     return;
1298 
1299   if (qos_handle_) {
1300     api_->CloseHandle(qos_handle_);
1301     qos_handle_ = nullptr;
1302   }
1303 
1304   handle_is_initializing_ = true;
1305   base::ThreadPool::PostTaskAndReplyWithResult(
1306       FROM_HERE, {base::MayBlock()},
1307       base::BindOnce(&DscpManager::DoCreateHandle, api_),
1308       base::BindOnce(&DscpManager::OnHandleCreated, api_,
1309                      weak_ptr_factory_.GetWeakPtr()));
1310 }
1311 
DoCreateHandle(QwaveApi * api)1312 HANDLE DscpManager::DoCreateHandle(QwaveApi* api) {
1313   QOS_VERSION version;
1314   version.MajorVersion = 1;
1315   version.MinorVersion = 0;
1316 
1317   HANDLE handle = nullptr;
1318 
1319   // No access to net_log_ so swallow any errors here.
1320   api->CreateHandle(&version, &handle);
1321   return handle;
1322 }
1323 
OnHandleCreated(QwaveApi * api,base::WeakPtr<DscpManager> dscp_manager,HANDLE handle)1324 void DscpManager::OnHandleCreated(QwaveApi* api,
1325                                   base::WeakPtr<DscpManager> dscp_manager,
1326                                   HANDLE handle) {
1327   if (!handle)
1328     api->OnFatalError();
1329 
1330   if (!dscp_manager) {
1331     api->CloseHandle(handle);
1332     return;
1333   }
1334 
1335   DCHECK(dscp_manager->handle_is_initializing_);
1336   DCHECK(!dscp_manager->qos_handle_);
1337 
1338   dscp_manager->qos_handle_ = handle;
1339   dscp_manager->handle_is_initializing_ = false;
1340 }
1341 
1342 }  // namespace net
1343