1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/udp_socket_win.h"
6
7 #include <mstcpip.h>
8 #include <winsock2.h>
9
10 #include <memory>
11
12 #include "base/check_op.h"
13 #include "base/functional/bind.h"
14 #include "base/functional/callback.h"
15 #include "base/lazy_instance.h"
16 #include "base/memory/raw_ptr.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/metrics/histogram_macros.h"
19 #include "base/notreached.h"
20 #include "base/rand_util.h"
21 #include "base/task/thread_pool.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/ip_address.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/network_activity_monitor.h"
27 #include "net/base/network_change_notifier.h"
28 #include "net/base/sockaddr_storage.h"
29 #include "net/base/winsock_init.h"
30 #include "net/base/winsock_util.h"
31 #include "net/log/net_log.h"
32 #include "net/log/net_log_event_type.h"
33 #include "net/log/net_log_source.h"
34 #include "net/log/net_log_source_type.h"
35 #include "net/socket/socket_descriptor.h"
36 #include "net/socket/socket_options.h"
37 #include "net/socket/socket_tag.h"
38 #include "net/socket/udp_net_log_parameters.h"
39 #include "net/traffic_annotation/network_traffic_annotation.h"
40
41 namespace net {
42
43 // This class encapsulates all the state that has to be preserved as long as
44 // there is a network IO operation in progress. If the owner UDPSocketWin
45 // is destroyed while an operation is in progress, the Core is detached and it
46 // lives until the operation completes and the OS doesn't reference any resource
47 // declared on this class anymore.
class UDPSocketWin::Core : public base::RefCounted<Core> {
 public:
  // |socket| is the UDPSocketWin that owns this Core. The Core does not own
  // the socket; the pointer is cleared by Detach().
  explicit Core(UDPSocketWin* socket);

  Core(const Core&) = delete;
  Core& operator=(const Core&) = delete;

  // Start watching for the end of a read or write operation. Each call adds a
  // reference to this Core that is released when the watched event fires.
  void WatchForRead();
  void WatchForWrite();

  // The UDPSocketWin is going away. Once detached, completion signals are no
  // longer forwarded to a socket.
  void Detach() { socket_ = nullptr; }

  // The separate OVERLAPPED variables for asynchronous operation. Each hEvent
  // is created in the constructor and closed in the destructor.
  OVERLAPPED read_overlapped_;
  OVERLAPPED write_overlapped_;

  // The buffers used in Read() and Write(). They are stored here while an
  // overlapped operation is pending so that they stay alive with the Core
  // even if the socket is detached.
  scoped_refptr<IOBuffer> read_iobuffer_;
  scoped_refptr<IOBuffer> write_iobuffer_;

  // The address storage passed to WSARecvFrom(). It is read again when a
  // pending read completes, so it has to live as long as the Core.
  SockaddrStorage recv_addr_storage_;

 private:
  friend class base::RefCounted<Core>;

  // Receives |read_watcher_| signals and forwards them via the Core.
  class ReadDelegate : public base::win::ObjectWatcher::Delegate {
   public:
    explicit ReadDelegate(Core* core) : core_(core) {}
    ~ReadDelegate() override = default;

    // base::win::ObjectWatcher::Delegate methods:
    void OnObjectSignaled(HANDLE object) override;

   private:
    // Non-owning back pointer; this delegate is itself a member of that Core.
    const raw_ptr<Core> core_;
  };

  // Receives |write_watcher_| signals and forwards them via the Core.
  class WriteDelegate : public base::win::ObjectWatcher::Delegate {
   public:
    explicit WriteDelegate(Core* core) : core_(core) {}
    ~WriteDelegate() override = default;

    // base::win::ObjectWatcher::Delegate methods:
    void OnObjectSignaled(HANDLE object) override;

   private:
    // Non-owning back pointer; this delegate is itself a member of that Core.
    const raw_ptr<Core> core_;
  };

  // Private so that destruction only happens through
  // RefCounted<Core>::Release().
  ~Core();

  // The socket that created this object, or null after Detach().
  raw_ptr<UDPSocketWin> socket_;

  // |reader_| handles the signals from |read_watcher_|.
  ReadDelegate reader_;
  // |writer_| handles the signals from |write_watcher_|.
  WriteDelegate writer_;

  // |read_watcher_| watches for events from Read().
  base::win::ObjectWatcher read_watcher_;
  // |write_watcher_| watches for events from Write().
  base::win::ObjectWatcher write_watcher_;
};
115
Core(UDPSocketWin * socket)116 UDPSocketWin::Core::Core(UDPSocketWin* socket)
117 : socket_(socket),
118 reader_(this),
119 writer_(this) {
120 memset(&read_overlapped_, 0, sizeof(read_overlapped_));
121 memset(&write_overlapped_, 0, sizeof(write_overlapped_));
122
123 read_overlapped_.hEvent = WSACreateEvent();
124 write_overlapped_.hEvent = WSACreateEvent();
125 }
126
~Core()127 UDPSocketWin::Core::~Core() {
128 // Make sure the message loop is not watching this object anymore.
129 read_watcher_.StopWatching();
130 write_watcher_.StopWatching();
131
132 WSACloseEvent(read_overlapped_.hEvent);
133 memset(&read_overlapped_, 0xaf, sizeof(read_overlapped_));
134 WSACloseEvent(write_overlapped_.hEvent);
135 memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_));
136 }
137
WatchForRead()138 void UDPSocketWin::Core::WatchForRead() {
139 // We grab an extra reference because there is an IO operation in progress.
140 // Balanced in ReadDelegate::OnObjectSignaled().
141 AddRef();
142 read_watcher_.StartWatchingOnce(read_overlapped_.hEvent, &reader_);
143 }
144
WatchForWrite()145 void UDPSocketWin::Core::WatchForWrite() {
146 // We grab an extra reference because there is an IO operation in progress.
147 // Balanced in WriteDelegate::OnObjectSignaled().
148 AddRef();
149 write_watcher_.StartWatchingOnce(write_overlapped_.hEvent, &writer_);
150 }
151
OnObjectSignaled(HANDLE object)152 void UDPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
153 DCHECK_EQ(object, core_->read_overlapped_.hEvent);
154 if (core_->socket_)
155 core_->socket_->DidCompleteRead();
156
157 core_->Release();
158 }
159
OnObjectSignaled(HANDLE object)160 void UDPSocketWin::Core::WriteDelegate::OnObjectSignaled(HANDLE object) {
161 DCHECK_EQ(object, core_->write_overlapped_.hEvent);
162 if (core_->socket_)
163 core_->socket_->DidCompleteWrite();
164
165 core_->Release();
166 }
167 //-----------------------------------------------------------------------------
168
QwaveApi()169 QwaveApi::QwaveApi() {
170 HMODULE qwave = LoadLibrary(L"qwave.dll");
171 if (!qwave)
172 return;
173 create_handle_func_ =
174 (CreateHandleFn)GetProcAddress(qwave, "QOSCreateHandle");
175 close_handle_func_ =
176 (CloseHandleFn)GetProcAddress(qwave, "QOSCloseHandle");
177 add_socket_to_flow_func_ =
178 (AddSocketToFlowFn)GetProcAddress(qwave, "QOSAddSocketToFlow");
179 remove_socket_from_flow_func_ =
180 (RemoveSocketFromFlowFn)GetProcAddress(qwave, "QOSRemoveSocketFromFlow");
181 set_flow_func_ = (SetFlowFn)GetProcAddress(qwave, "QOSSetFlow");
182
183 if (create_handle_func_ && close_handle_func_ &&
184 add_socket_to_flow_func_ && remove_socket_from_flow_func_ &&
185 set_flow_func_) {
186 qwave_supported_ = true;
187 }
188 }
189
GetDefault()190 QwaveApi* QwaveApi::GetDefault() {
191 static base::LazyInstance<QwaveApi>::Leaky lazy_qwave =
192 LAZY_INSTANCE_INITIALIZER;
193 return lazy_qwave.Pointer();
194 }
195
qwave_supported() const196 bool QwaveApi::qwave_supported() const {
197 return qwave_supported_;
198 }
199
OnFatalError()200 void QwaveApi::OnFatalError() {
201 // Disable everything moving forward.
202 qwave_supported_ = false;
203 }
204
CreateHandle(PQOS_VERSION version,PHANDLE handle)205 BOOL QwaveApi::CreateHandle(PQOS_VERSION version, PHANDLE handle) {
206 return create_handle_func_(version, handle);
207 }
208
CloseHandle(HANDLE handle)209 BOOL QwaveApi::CloseHandle(HANDLE handle) {
210 return close_handle_func_(handle);
211 }
212
AddSocketToFlow(HANDLE handle,SOCKET socket,PSOCKADDR addr,QOS_TRAFFIC_TYPE traffic_type,DWORD flags,PQOS_FLOWID flow_id)213 BOOL QwaveApi::AddSocketToFlow(HANDLE handle,
214 SOCKET socket,
215 PSOCKADDR addr,
216 QOS_TRAFFIC_TYPE traffic_type,
217 DWORD flags,
218 PQOS_FLOWID flow_id) {
219 return add_socket_to_flow_func_(handle, socket, addr, traffic_type, flags,
220 flow_id);
221 }
222
RemoveSocketFromFlow(HANDLE handle,SOCKET socket,QOS_FLOWID flow_id,DWORD reserved)223 BOOL QwaveApi::RemoveSocketFromFlow(HANDLE handle,
224 SOCKET socket,
225 QOS_FLOWID flow_id,
226 DWORD reserved) {
227 return remove_socket_from_flow_func_(handle, socket, flow_id, reserved);
228 }
229
SetFlow(HANDLE handle,QOS_FLOWID flow_id,QOS_SET_FLOW op,ULONG size,PVOID data,DWORD reserved,LPOVERLAPPED overlapped)230 BOOL QwaveApi::SetFlow(HANDLE handle,
231 QOS_FLOWID flow_id,
232 QOS_SET_FLOW op,
233 ULONG size,
234 PVOID data,
235 DWORD reserved,
236 LPOVERLAPPED overlapped) {
237 return set_flow_func_(handle, flow_id, op, size, data, reserved, overlapped);
238 }
239
240 //-----------------------------------------------------------------------------
241
UDPSocketWin(DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)242 UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
243 net::NetLog* net_log,
244 const net::NetLogSource& source)
245 : socket_(INVALID_SOCKET),
246 socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
247 net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::UDP_SOCKET)) {
248 EnsureWinsockInit();
249 net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source);
250 }
251
~UDPSocketWin()252 UDPSocketWin::~UDPSocketWin() {
253 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
254 Close();
255 net_log_.EndEvent(NetLogEventType::SOCKET_ALIVE);
256 }
257
Open(AddressFamily address_family)258 int UDPSocketWin::Open(AddressFamily address_family) {
259 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
260 DCHECK_EQ(socket_, INVALID_SOCKET);
261
262 auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
263 if (owned_socket_count.empty())
264 return ERR_INSUFFICIENT_RESOURCES;
265
266 owned_socket_count_ = std::move(owned_socket_count);
267 addr_family_ = ConvertAddressFamily(address_family);
268 socket_ = CreatePlatformSocket(addr_family_, SOCK_DGRAM, IPPROTO_UDP);
269 if (socket_ == INVALID_SOCKET) {
270 owned_socket_count_.Reset();
271 return MapSystemError(WSAGetLastError());
272 }
273 ConfigureOpenedSocket();
274 return OK;
275 }
276
AdoptOpenedSocket(AddressFamily address_family,SOCKET socket)277 int UDPSocketWin::AdoptOpenedSocket(AddressFamily address_family,
278 SOCKET socket) {
279 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
280 auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
281 if (owned_socket_count.empty()) {
282 return ERR_INSUFFICIENT_RESOURCES;
283 }
284
285 owned_socket_count_ = std::move(owned_socket_count);
286 addr_family_ = ConvertAddressFamily(address_family);
287 socket_ = socket;
288 ConfigureOpenedSocket();
289 return OK;
290 }
291
ConfigureOpenedSocket()292 void UDPSocketWin::ConfigureOpenedSocket() {
293 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
294 if (!use_non_blocking_io_) {
295 core_ = base::MakeRefCounted<Core>(this);
296 } else {
297 read_write_event_.Set(WSACreateEvent());
298 WSAEventSelect(socket_, read_write_event_.Get(), FD_READ | FD_WRITE);
299 }
300 }
301
Close()302 void UDPSocketWin::Close() {
303 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
304
305 owned_socket_count_.Reset();
306
307 if (socket_ == INVALID_SOCKET)
308 return;
309
310 // Remove socket_ from the QoS subsystem before we invalidate it.
311 dscp_manager_ = nullptr;
312
313 // Zero out any pending read/write callback state.
314 read_callback_.Reset();
315 recv_from_address_ = nullptr;
316 write_callback_.Reset();
317
318 base::TimeTicks start_time = base::TimeTicks::Now();
319 closesocket(socket_);
320 UMA_HISTOGRAM_TIMES("Net.UDPSocketWinClose",
321 base::TimeTicks::Now() - start_time);
322 socket_ = INVALID_SOCKET;
323 addr_family_ = 0;
324 is_connected_ = false;
325
326 // Release buffers to free up memory.
327 read_iobuffer_ = nullptr;
328 read_iobuffer_len_ = 0;
329 write_iobuffer_ = nullptr;
330 write_iobuffer_len_ = 0;
331
332 read_write_watcher_.StopWatching();
333 read_write_event_.Close();
334
335 event_pending_.InvalidateWeakPtrs();
336
337 if (core_) {
338 core_->Detach();
339 core_ = nullptr;
340 }
341 }
342
GetPeerAddress(IPEndPoint * address) const343 int UDPSocketWin::GetPeerAddress(IPEndPoint* address) const {
344 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
345 DCHECK(address);
346 if (!is_connected())
347 return ERR_SOCKET_NOT_CONNECTED;
348
349 // TODO(szym): Simplify. http://crbug.com/126152
350 if (!remote_address_.get()) {
351 SockaddrStorage storage;
352 if (getpeername(socket_, storage.addr, &storage.addr_len))
353 return MapSystemError(WSAGetLastError());
354 auto remote_address = std::make_unique<IPEndPoint>();
355 if (!remote_address->FromSockAddr(storage.addr, storage.addr_len))
356 return ERR_ADDRESS_INVALID;
357 remote_address_ = std::move(remote_address);
358 }
359
360 *address = *remote_address_;
361 return OK;
362 }
363
GetLocalAddress(IPEndPoint * address) const364 int UDPSocketWin::GetLocalAddress(IPEndPoint* address) const {
365 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
366 DCHECK(address);
367 if (!is_connected())
368 return ERR_SOCKET_NOT_CONNECTED;
369
370 // TODO(szym): Simplify. http://crbug.com/126152
371 if (!local_address_.get()) {
372 SockaddrStorage storage;
373 if (getsockname(socket_, storage.addr, &storage.addr_len))
374 return MapSystemError(WSAGetLastError());
375 auto local_address = std::make_unique<IPEndPoint>();
376 if (!local_address->FromSockAddr(storage.addr, storage.addr_len))
377 return ERR_ADDRESS_INVALID;
378 local_address_ = std::move(local_address);
379 net_log_.AddEvent(NetLogEventType::UDP_LOCAL_ADDRESS, [&] {
380 return CreateNetLogUDPConnectParams(*local_address_,
381 handles::kInvalidNetworkHandle);
382 });
383 }
384
385 *address = *local_address_;
386 return OK;
387 }
388
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)389 int UDPSocketWin::Read(IOBuffer* buf,
390 int buf_len,
391 CompletionOnceCallback callback) {
392 return RecvFrom(buf, buf_len, nullptr, std::move(callback));
393 }
394
RecvFrom(IOBuffer * buf,int buf_len,IPEndPoint * address,CompletionOnceCallback callback)395 int UDPSocketWin::RecvFrom(IOBuffer* buf,
396 int buf_len,
397 IPEndPoint* address,
398 CompletionOnceCallback callback) {
399 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
400 DCHECK_NE(INVALID_SOCKET, socket_);
401 CHECK(read_callback_.is_null());
402 DCHECK(!recv_from_address_);
403 DCHECK(!callback.is_null()); // Synchronous operation not supported.
404 DCHECK_GT(buf_len, 0);
405
406 int nread = core_ ? InternalRecvFromOverlapped(buf, buf_len, address)
407 : InternalRecvFromNonBlocking(buf, buf_len, address);
408 if (nread != ERR_IO_PENDING)
409 return nread;
410
411 read_callback_ = std::move(callback);
412 recv_from_address_ = address;
413 return ERR_IO_PENDING;
414 }
415
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)416 int UDPSocketWin::Write(
417 IOBuffer* buf,
418 int buf_len,
419 CompletionOnceCallback callback,
420 const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
421 return SendToOrWrite(buf, buf_len, remote_address_.get(),
422 std::move(callback));
423 }
424
SendTo(IOBuffer * buf,int buf_len,const IPEndPoint & address,CompletionOnceCallback callback)425 int UDPSocketWin::SendTo(IOBuffer* buf,
426 int buf_len,
427 const IPEndPoint& address,
428 CompletionOnceCallback callback) {
429 if (dscp_manager_) {
430 // Alert DscpManager in case this is a new remote address. Failure to
431 // apply Dscp code is never fatal.
432 int rv = dscp_manager_->PrepareForSend(address);
433 if (rv != OK)
434 net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, rv);
435 }
436 return SendToOrWrite(buf, buf_len, &address, std::move(callback));
437 }
438
SendToOrWrite(IOBuffer * buf,int buf_len,const IPEndPoint * address,CompletionOnceCallback callback)439 int UDPSocketWin::SendToOrWrite(IOBuffer* buf,
440 int buf_len,
441 const IPEndPoint* address,
442 CompletionOnceCallback callback) {
443 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
444 DCHECK_NE(INVALID_SOCKET, socket_);
445 CHECK(write_callback_.is_null());
446 DCHECK(!callback.is_null()); // Synchronous operation not supported.
447 DCHECK_GT(buf_len, 0);
448 DCHECK(!send_to_address_.get());
449
450 int nwrite = core_ ? InternalSendToOverlapped(buf, buf_len, address)
451 : InternalSendToNonBlocking(buf, buf_len, address);
452 if (nwrite != ERR_IO_PENDING)
453 return nwrite;
454
455 if (address)
456 send_to_address_ = std::make_unique<IPEndPoint>(*address);
457 write_callback_ = std::move(callback);
458 return ERR_IO_PENDING;
459 }
460
Connect(const IPEndPoint & address)461 int UDPSocketWin::Connect(const IPEndPoint& address) {
462 DCHECK_NE(socket_, INVALID_SOCKET);
463 net_log_.BeginEvent(NetLogEventType::UDP_CONNECT, [&] {
464 return CreateNetLogUDPConnectParams(address,
465 handles::kInvalidNetworkHandle);
466 });
467 int rv = SetMulticastOptions();
468 if (rv != OK)
469 return rv;
470 rv = InternalConnect(address);
471 net_log_.EndEventWithNetErrorCode(NetLogEventType::UDP_CONNECT, rv);
472 is_connected_ = (rv == OK);
473 return rv;
474 }
475
InternalConnect(const IPEndPoint & address)476 int UDPSocketWin::InternalConnect(const IPEndPoint& address) {
477 DCHECK(!is_connected());
478 DCHECK(!remote_address_.get());
479
480 // Always do a random bind.
481 // Ignore failures, which may happen if the socket was already bound.
482 DWORD randomize_port_value = 1;
483 setsockopt(socket_, SOL_SOCKET, SO_RANDOMIZE_PORT,
484 reinterpret_cast<const char*>(&randomize_port_value),
485 sizeof(randomize_port_value));
486
487 SockaddrStorage storage;
488 if (!address.ToSockAddr(storage.addr, &storage.addr_len))
489 return ERR_ADDRESS_INVALID;
490
491 int rv = connect(socket_, storage.addr, storage.addr_len);
492 if (rv < 0)
493 return MapSystemError(WSAGetLastError());
494
495 remote_address_ = std::make_unique<IPEndPoint>(address);
496
497 if (dscp_manager_)
498 dscp_manager_->PrepareForSend(*remote_address_.get());
499
500 return rv;
501 }
502
Bind(const IPEndPoint & address)503 int UDPSocketWin::Bind(const IPEndPoint& address) {
504 DCHECK_NE(socket_, INVALID_SOCKET);
505 DCHECK(!is_connected());
506
507 int rv = SetMulticastOptions();
508 if (rv < 0)
509 return rv;
510
511 rv = DoBind(address);
512 if (rv < 0)
513 return rv;
514
515 local_address_.reset();
516 is_connected_ = true;
517 return rv;
518 }
519
BindToNetwork(handles::NetworkHandle network)520 int UDPSocketWin::BindToNetwork(handles::NetworkHandle network) {
521 NOTIMPLEMENTED();
522 return ERR_NOT_IMPLEMENTED;
523 }
524
SetReceiveBufferSize(int32_t size)525 int UDPSocketWin::SetReceiveBufferSize(int32_t size) {
526 DCHECK_NE(socket_, INVALID_SOCKET);
527 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
528 int rv = SetSocketReceiveBufferSize(socket_, size);
529
530 if (rv != 0)
531 return MapSystemError(WSAGetLastError());
532
533 // According to documentation, setsockopt may succeed, but we need to check
534 // the results via getsockopt to be sure it works on Windows.
535 int32_t actual_size = 0;
536 int option_size = sizeof(actual_size);
537 rv = getsockopt(socket_, SOL_SOCKET, SO_RCVBUF,
538 reinterpret_cast<char*>(&actual_size), &option_size);
539 if (rv != 0)
540 return MapSystemError(WSAGetLastError());
541 if (actual_size >= size)
542 return OK;
543 UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableReceiveBuffer",
544 actual_size, 1000, 1000000, 50);
545 return ERR_SOCKET_RECEIVE_BUFFER_SIZE_UNCHANGEABLE;
546 }
547
SetSendBufferSize(int32_t size)548 int UDPSocketWin::SetSendBufferSize(int32_t size) {
549 DCHECK_NE(socket_, INVALID_SOCKET);
550 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
551 int rv = SetSocketSendBufferSize(socket_, size);
552 if (rv != 0)
553 return MapSystemError(WSAGetLastError());
554 // According to documentation, setsockopt may succeed, but we need to check
555 // the results via getsockopt to be sure it works on Windows.
556 int32_t actual_size = 0;
557 int option_size = sizeof(actual_size);
558 rv = getsockopt(socket_, SOL_SOCKET, SO_SNDBUF,
559 reinterpret_cast<char*>(&actual_size), &option_size);
560 if (rv != 0)
561 return MapSystemError(WSAGetLastError());
562 if (actual_size >= size)
563 return OK;
564 UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableSendBuffer",
565 actual_size, 1000, 1000000, 50);
566 return ERR_SOCKET_SEND_BUFFER_SIZE_UNCHANGEABLE;
567 }
568
SetDoNotFragment()569 int UDPSocketWin::SetDoNotFragment() {
570 DCHECK_NE(socket_, INVALID_SOCKET);
571 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
572
573 if (addr_family_ == AF_INET6)
574 return OK;
575
576 DWORD val = 1;
577 int rv = setsockopt(socket_, IPPROTO_IP, IP_DONTFRAGMENT,
578 reinterpret_cast<const char*>(&val), sizeof(val));
579 return rv == 0 ? OK : MapSystemError(WSAGetLastError());
580 }
581
SetMsgConfirm(bool confirm)582 void UDPSocketWin::SetMsgConfirm(bool confirm) {}
583
AllowAddressReuse()584 int UDPSocketWin::AllowAddressReuse() {
585 DCHECK_NE(socket_, INVALID_SOCKET);
586 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
587 DCHECK(!is_connected());
588
589 BOOL true_value = TRUE;
590 int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
591 reinterpret_cast<const char*>(&true_value),
592 sizeof(true_value));
593 return rv == 0 ? OK : MapSystemError(WSAGetLastError());
594 }
595
SetBroadcast(bool broadcast)596 int UDPSocketWin::SetBroadcast(bool broadcast) {
597 DCHECK_NE(socket_, INVALID_SOCKET);
598 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
599
600 BOOL value = broadcast ? TRUE : FALSE;
601 int rv = setsockopt(socket_, SOL_SOCKET, SO_BROADCAST,
602 reinterpret_cast<const char*>(&value), sizeof(value));
603 return rv == 0 ? OK : MapSystemError(WSAGetLastError());
604 }
605
AllowAddressSharingForMulticast()606 int UDPSocketWin::AllowAddressSharingForMulticast() {
607 // When proper multicast groups are used, Windows further defines the address
608 // resuse option (SO_REUSEADDR) to ensure all listening sockets can receive
609 // all incoming messages for the multicast group.
610 return AllowAddressReuse();
611 }
612
DoReadCallback(int rv)613 void UDPSocketWin::DoReadCallback(int rv) {
614 DCHECK_NE(rv, ERR_IO_PENDING);
615 DCHECK(!read_callback_.is_null());
616
617 // since Run may result in Read being called, clear read_callback_ up front.
618 std::move(read_callback_).Run(rv);
619 }
620
DoWriteCallback(int rv)621 void UDPSocketWin::DoWriteCallback(int rv) {
622 DCHECK_NE(rv, ERR_IO_PENDING);
623 DCHECK(!write_callback_.is_null());
624
625 // since Run may result in Write being called, clear write_callback_ up front.
626 std::move(write_callback_).Run(rv);
627 }
628
DidCompleteRead()629 void UDPSocketWin::DidCompleteRead() {
630 DWORD num_bytes, flags;
631 BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_,
632 &num_bytes, FALSE, &flags);
633 WSAResetEvent(core_->read_overlapped_.hEvent);
634 int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
635 // Convert address.
636 IPEndPoint address;
637 IPEndPoint* address_to_log = nullptr;
638 if (result >= 0) {
639 if (address.FromSockAddr(core_->recv_addr_storage_.addr,
640 core_->recv_addr_storage_.addr_len)) {
641 if (recv_from_address_)
642 *recv_from_address_ = address;
643 address_to_log = &address;
644 } else {
645 result = ERR_ADDRESS_INVALID;
646 }
647 }
648 LogRead(result, core_->read_iobuffer_->data(), address_to_log);
649 core_->read_iobuffer_ = nullptr;
650 recv_from_address_ = nullptr;
651 DoReadCallback(result);
652 }
653
DidCompleteWrite()654 void UDPSocketWin::DidCompleteWrite() {
655 DWORD num_bytes, flags;
656 BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_,
657 &num_bytes, FALSE, &flags);
658 WSAResetEvent(core_->write_overlapped_.hEvent);
659 int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
660 LogWrite(result, core_->write_iobuffer_->data(), send_to_address_.get());
661
662 send_to_address_.reset();
663 core_->write_iobuffer_ = nullptr;
664 DoWriteCallback(result);
665 }
666
// Non-blocking mode handler for |read_write_event_|. It asks Winsock which
// network events fired on |socket_| and retries the pending read and/or write.
// If WSAEnumNetworkEvents() itself fails, every pending operation is failed
// with the mapped error instead. Callbacks may close or delete |this|, which
// is detected through a weak pointer that Close() invalidates.
void UDPSocketWin::OnObjectSignaled(HANDLE object) {
  DCHECK(object == read_write_event_.Get());
  WSANETWORKEVENTS network_events;
  int os_error = 0;
  int rv =
      WSAEnumNetworkEvents(socket_, read_write_event_.Get(), &network_events);
  // Protects against trying to call the write callback if the read callback
  // either closes or destroys |this|.
  base::WeakPtr<UDPSocketWin> event_pending = event_pending_.GetWeakPtr();
  if (rv == SOCKET_ERROR) {
    os_error = WSAGetLastError();
    rv = MapSystemError(os_error);

    // Fail the pending read (if any), clearing its state before the callback.
    if (read_iobuffer_) {
      read_iobuffer_ = nullptr;
      read_iobuffer_len_ = 0;
      recv_from_address_ = nullptr;
      DoReadCallback(rv);
    }

    // Socket may have been closed or destroyed here.
    if (event_pending && write_iobuffer_) {
      write_iobuffer_ = nullptr;
      write_iobuffer_len_ = 0;
      send_to_address_.reset();
      DoWriteCallback(rv);
    }
    return;
  }

  // Retry the read only if FD_READ fired and a read is actually pending.
  if ((network_events.lNetworkEvents & FD_READ) && read_iobuffer_)
    OnReadSignaled();
  if (!event_pending)
    return;

  // Same for the write side with FD_WRITE.
  if ((network_events.lNetworkEvents & FD_WRITE) && write_iobuffer_)
    OnWriteSignaled();
  if (!event_pending)
    return;

  // There's still pending read / write. Watch for further events.
  if (read_iobuffer_ || write_iobuffer_)
    WatchForReadWrite();
}
711
OnReadSignaled()712 void UDPSocketWin::OnReadSignaled() {
713 int rv = InternalRecvFromNonBlocking(read_iobuffer_.get(), read_iobuffer_len_,
714 recv_from_address_);
715 if (rv == ERR_IO_PENDING)
716 return;
717 read_iobuffer_ = nullptr;
718 read_iobuffer_len_ = 0;
719 recv_from_address_ = nullptr;
720 DoReadCallback(rv);
721 }
722
OnWriteSignaled()723 void UDPSocketWin::OnWriteSignaled() {
724 int rv = InternalSendToNonBlocking(write_iobuffer_.get(), write_iobuffer_len_,
725 send_to_address_.get());
726 if (rv == ERR_IO_PENDING)
727 return;
728 write_iobuffer_ = nullptr;
729 write_iobuffer_len_ = 0;
730 send_to_address_.reset();
731 DoWriteCallback(rv);
732 }
733
WatchForReadWrite()734 void UDPSocketWin::WatchForReadWrite() {
735 if (read_write_watcher_.IsWatching())
736 return;
737 bool watched =
738 read_write_watcher_.StartWatchingOnce(read_write_event_.Get(), this);
739 DCHECK(watched);
740 }
741
LogRead(int result,const char * bytes,const IPEndPoint * address) const742 void UDPSocketWin::LogRead(int result,
743 const char* bytes,
744 const IPEndPoint* address) const {
745 if (result < 0) {
746 net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_RECEIVE_ERROR,
747 result);
748 return;
749 }
750
751 if (net_log_.IsCapturing()) {
752 NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_RECEIVED, result,
753 bytes, address);
754 }
755
756 activity_monitor::IncrementBytesReceived(result);
757 }
758
LogWrite(int result,const char * bytes,const IPEndPoint * address) const759 void UDPSocketWin::LogWrite(int result,
760 const char* bytes,
761 const IPEndPoint* address) const {
762 if (result < 0) {
763 net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, result);
764 return;
765 }
766
767 if (net_log_.IsCapturing()) {
768 NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_SENT, result,
769 bytes, address);
770 }
771 }
772
// Overlapped-mode receive. Issues WSARecvFrom() into |buf| using the Core's
// OVERLAPPED and address storage.
//  - Synchronous completion: returns the byte count (or ERR_ADDRESS_INVALID)
//    and writes the sender to |address| when it is non-null.
//  - Immediate failure: returns the mapped net error.
//  - Otherwise: starts watching the read event, stores |buf| in the Core so
//    it outlives the operation, and returns ERR_IO_PENDING.
int UDPSocketWin::InternalRecvFromOverlapped(IOBuffer* buf,
                                             int buf_len,
                                             IPEndPoint* address) {
  DCHECK(!core_->read_iobuffer_.get());
  SockaddrStorage& storage = core_->recv_addr_storage_;
  storage.addr_len = sizeof(storage.addr_storage);

  WSABUF read_buffer;
  read_buffer.buf = buf->data();
  read_buffer.len = buf_len;

  DWORD flags = 0;
  DWORD num;
  CHECK_NE(INVALID_SOCKET, socket_);
  int rv = WSARecvFrom(socket_, &read_buffer, 1, &num, &flags, storage.addr,
                       &storage.addr_len, &core_->read_overlapped_, nullptr);
  if (rv == 0) {
    // Presumably true (and the event reset) when the data was delivered
    // immediately; otherwise fall through and wait for the event. See
    // ResetEventIfSignaled() in winsock_util to confirm.
    if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) {
      int result = num;
      // Convert address.
      IPEndPoint address_storage;
      IPEndPoint* address_to_log = nullptr;
      if (result >= 0) {
        if (address_storage.FromSockAddr(core_->recv_addr_storage_.addr,
                                         core_->recv_addr_storage_.addr_len)) {
          if (address)
            *address = address_storage;
          address_to_log = &address_storage;
        } else {
          result = ERR_ADDRESS_INVALID;
        }
      }
      LogRead(result, buf->data(), address_to_log);
      return result;
    }
  } else {
    int os_error = WSAGetLastError();
    if (os_error != WSA_IO_PENDING) {
      int result = MapSystemError(os_error);
      LogRead(result, nullptr, nullptr);
      return result;
    }
  }
  // The read is in flight: WatchForRead() takes a Core reference that is
  // released when the event fires, and the Core keeps |buf| alive meanwhile.
  core_->WatchForRead();
  core_->read_iobuffer_ = buf;
  return ERR_IO_PENDING;
}
820
// Overlapped-mode send. Issues WSASendTo() of |buf| to |address|, or to the
// connected peer when |address| is null, using the Core's OVERLAPPED.
//  - Synchronous completion: returns the byte count.
//  - Immediate failure or unconvertible address: returns a net error.
//  - Otherwise: starts watching the write event, stores |buf| in the Core so
//    it outlives the operation, and returns ERR_IO_PENDING.
int UDPSocketWin::InternalSendToOverlapped(IOBuffer* buf,
                                           int buf_len,
                                           const IPEndPoint* address) {
  DCHECK(!core_->write_iobuffer_.get());
  SockaddrStorage storage;
  struct sockaddr* addr = storage.addr;
  // Convert address.
  if (!address) {
    // No destination: pass a null address so the connected peer is used.
    addr = nullptr;
    storage.addr_len = 0;
  } else {
    if (!address->ToSockAddr(addr, &storage.addr_len)) {
      int result = ERR_ADDRESS_INVALID;
      LogWrite(result, nullptr, nullptr);
      return result;
    }
  }

  WSABUF write_buffer;
  write_buffer.buf = buf->data();
  write_buffer.len = buf_len;

  DWORD flags = 0;
  DWORD num;
  int rv = WSASendTo(socket_, &write_buffer, 1, &num, flags, addr,
                     storage.addr_len, &core_->write_overlapped_, nullptr);
  if (rv == 0) {
    // Presumably true (and the event reset) when the send completed
    // immediately; otherwise fall through and wait for the event. See
    // ResetEventIfSignaled() in winsock_util to confirm.
    if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) {
      int result = num;
      LogWrite(result, buf->data(), address);
      return result;
    }
  } else {
    int os_error = WSAGetLastError();
    if (os_error != WSA_IO_PENDING) {
      int result = MapSystemError(os_error);
      LogWrite(result, nullptr, nullptr);
      return result;
    }
  }

  // The send is in flight: WatchForWrite() takes a Core reference that is
  // released when the event fires, and the Core keeps |buf| alive meanwhile.
  core_->WatchForWrite();
  core_->write_iobuffer_ = buf;
  return ERR_IO_PENDING;
}
866
InternalRecvFromNonBlocking(IOBuffer * buf,int buf_len,IPEndPoint * address)867 int UDPSocketWin::InternalRecvFromNonBlocking(IOBuffer* buf,
868 int buf_len,
869 IPEndPoint* address) {
870 DCHECK(!read_iobuffer_ || read_iobuffer_.get() == buf);
871 SockaddrStorage storage;
872 storage.addr_len = sizeof(storage.addr_storage);
873
874 CHECK_NE(INVALID_SOCKET, socket_);
875 int rv = recvfrom(socket_, buf->data(), buf_len, 0, storage.addr,
876 &storage.addr_len);
877 if (rv == SOCKET_ERROR) {
878 int os_error = WSAGetLastError();
879 if (os_error == WSAEWOULDBLOCK) {
880 read_iobuffer_ = buf;
881 read_iobuffer_len_ = buf_len;
882 WatchForReadWrite();
883 return ERR_IO_PENDING;
884 }
885 rv = MapSystemError(os_error);
886 LogRead(rv, nullptr, nullptr);
887 return rv;
888 }
889 IPEndPoint address_storage;
890 IPEndPoint* address_to_log = nullptr;
891 if (rv >= 0) {
892 if (address_storage.FromSockAddr(storage.addr, storage.addr_len)) {
893 if (address)
894 *address = address_storage;
895 address_to_log = &address_storage;
896 } else {
897 rv = ERR_ADDRESS_INVALID;
898 }
899 }
900 LogRead(rv, buf->data(), address_to_log);
901 return rv;
902 }
903
InternalSendToNonBlocking(IOBuffer * buf,int buf_len,const IPEndPoint * address)904 int UDPSocketWin::InternalSendToNonBlocking(IOBuffer* buf,
905 int buf_len,
906 const IPEndPoint* address) {
907 DCHECK(!write_iobuffer_ || write_iobuffer_.get() == buf);
908 SockaddrStorage storage;
909 struct sockaddr* addr = storage.addr;
910 // Convert address.
911 if (address) {
912 if (!address->ToSockAddr(addr, &storage.addr_len)) {
913 int result = ERR_ADDRESS_INVALID;
914 LogWrite(result, nullptr, nullptr);
915 return result;
916 }
917 } else {
918 addr = nullptr;
919 storage.addr_len = 0;
920 }
921
922 int rv = sendto(socket_, buf->data(), buf_len, 0, addr, storage.addr_len);
923 if (rv == SOCKET_ERROR) {
924 int os_error = WSAGetLastError();
925 if (os_error == WSAEWOULDBLOCK) {
926 write_iobuffer_ = buf;
927 write_iobuffer_len_ = buf_len;
928 WatchForReadWrite();
929 return ERR_IO_PENDING;
930 }
931 rv = MapSystemError(os_error);
932 LogWrite(rv, nullptr, nullptr);
933 return rv;
934 }
935 LogWrite(rv, buf->data(), address);
936 return rv;
937 }
938
SetMulticastOptions()939 int UDPSocketWin::SetMulticastOptions() {
940 if (!(socket_options_ & SOCKET_OPTION_MULTICAST_LOOP)) {
941 DWORD loop = 0;
942 int protocol_level =
943 addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
944 int option =
945 addr_family_ == AF_INET ? IP_MULTICAST_LOOP: IPV6_MULTICAST_LOOP;
946 int rv = setsockopt(socket_, protocol_level, option,
947 reinterpret_cast<const char*>(&loop), sizeof(loop));
948 if (rv < 0)
949 return MapSystemError(WSAGetLastError());
950 }
951 if (multicast_time_to_live_ != 1) {
952 DWORD hops = multicast_time_to_live_;
953 int protocol_level =
954 addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
955 int option =
956 addr_family_ == AF_INET ? IP_MULTICAST_TTL: IPV6_MULTICAST_HOPS;
957 int rv = setsockopt(socket_, protocol_level, option,
958 reinterpret_cast<const char*>(&hops), sizeof(hops));
959 if (rv < 0)
960 return MapSystemError(WSAGetLastError());
961 }
962 if (multicast_interface_ != 0) {
963 switch (addr_family_) {
964 case AF_INET: {
965 in_addr address;
966 address.s_addr = htonl(multicast_interface_);
967 int rv = setsockopt(socket_, IPPROTO_IP, IP_MULTICAST_IF,
968 reinterpret_cast<const char*>(&address),
969 sizeof(address));
970 if (rv)
971 return MapSystemError(WSAGetLastError());
972 break;
973 }
974 case AF_INET6: {
975 uint32_t interface_index = multicast_interface_;
976 int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_MULTICAST_IF,
977 reinterpret_cast<const char*>(&interface_index),
978 sizeof(interface_index));
979 if (rv)
980 return MapSystemError(WSAGetLastError());
981 break;
982 }
983 default:
984 NOTREACHED() << "Invalid address family";
985 return ERR_ADDRESS_INVALID;
986 }
987 }
988 return OK;
989 }
990
DoBind(const IPEndPoint & address)991 int UDPSocketWin::DoBind(const IPEndPoint& address) {
992 SockaddrStorage storage;
993 if (!address.ToSockAddr(storage.addr, &storage.addr_len))
994 return ERR_ADDRESS_INVALID;
995 int rv = bind(socket_, storage.addr, storage.addr_len);
996 if (rv == 0)
997 return OK;
998 int last_error = WSAGetLastError();
999 // Map some codes that are special to bind() separately.
1000 // * WSAEACCES: If a port is already bound to a socket, WSAEACCES may be
1001 // returned instead of WSAEADDRINUSE, depending on whether the socket
1002 // option SO_REUSEADDR or SO_EXCLUSIVEADDRUSE is set and whether the
1003 // conflicting socket is owned by a different user account. See the MSDN
1004 // page "Using SO_REUSEADDR and SO_EXCLUSIVEADDRUSE" for the gory details.
1005 if (last_error == WSAEACCES || last_error == WSAEADDRNOTAVAIL)
1006 return ERR_ADDRESS_IN_USE;
1007 return MapSystemError(last_error);
1008 }
1009
GetQwaveApi() const1010 QwaveApi* UDPSocketWin::GetQwaveApi() const {
1011 return QwaveApi::GetDefault();
1012 }
1013
JoinGroup(const IPAddress & group_address) const1014 int UDPSocketWin::JoinGroup(const IPAddress& group_address) const {
1015 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1016 if (!is_connected())
1017 return ERR_SOCKET_NOT_CONNECTED;
1018
1019 switch (group_address.size()) {
1020 case IPAddress::kIPv4AddressSize: {
1021 if (addr_family_ != AF_INET)
1022 return ERR_ADDRESS_INVALID;
1023 ip_mreq mreq;
1024 mreq.imr_interface.s_addr = htonl(multicast_interface_);
1025 memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1026 IPAddress::kIPv4AddressSize);
1027 int rv = setsockopt(socket_, IPPROTO_IP, IP_ADD_MEMBERSHIP,
1028 reinterpret_cast<const char*>(&mreq),
1029 sizeof(mreq));
1030 if (rv)
1031 return MapSystemError(WSAGetLastError());
1032 return OK;
1033 }
1034 case IPAddress::kIPv6AddressSize: {
1035 if (addr_family_ != AF_INET6)
1036 return ERR_ADDRESS_INVALID;
1037 ipv6_mreq mreq;
1038 mreq.ipv6mr_interface = multicast_interface_;
1039 memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1040 IPAddress::kIPv6AddressSize);
1041 int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
1042 reinterpret_cast<const char*>(&mreq),
1043 sizeof(mreq));
1044 if (rv)
1045 return MapSystemError(WSAGetLastError());
1046 return OK;
1047 }
1048 default:
1049 NOTREACHED() << "Invalid address family";
1050 return ERR_ADDRESS_INVALID;
1051 }
1052 }
1053
LeaveGroup(const IPAddress & group_address) const1054 int UDPSocketWin::LeaveGroup(const IPAddress& group_address) const {
1055 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1056 if (!is_connected())
1057 return ERR_SOCKET_NOT_CONNECTED;
1058
1059 switch (group_address.size()) {
1060 case IPAddress::kIPv4AddressSize: {
1061 if (addr_family_ != AF_INET)
1062 return ERR_ADDRESS_INVALID;
1063 ip_mreq mreq;
1064 mreq.imr_interface.s_addr = htonl(multicast_interface_);
1065 memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1066 IPAddress::kIPv4AddressSize);
1067 int rv = setsockopt(socket_, IPPROTO_IP, IP_DROP_MEMBERSHIP,
1068 reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1069 if (rv)
1070 return MapSystemError(WSAGetLastError());
1071 return OK;
1072 }
1073 case IPAddress::kIPv6AddressSize: {
1074 if (addr_family_ != AF_INET6)
1075 return ERR_ADDRESS_INVALID;
1076 ipv6_mreq mreq;
1077 mreq.ipv6mr_interface = multicast_interface_;
1078 memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1079 IPAddress::kIPv6AddressSize);
1080 int rv = setsockopt(socket_, IPPROTO_IPV6, IP_DROP_MEMBERSHIP,
1081 reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1082 if (rv)
1083 return MapSystemError(WSAGetLastError());
1084 return OK;
1085 }
1086 default:
1087 NOTREACHED() << "Invalid address family";
1088 return ERR_ADDRESS_INVALID;
1089 }
1090 }
1091
SetMulticastInterface(uint32_t interface_index)1092 int UDPSocketWin::SetMulticastInterface(uint32_t interface_index) {
1093 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1094 if (is_connected())
1095 return ERR_SOCKET_IS_CONNECTED;
1096 multicast_interface_ = interface_index;
1097 return OK;
1098 }
1099
SetMulticastTimeToLive(int time_to_live)1100 int UDPSocketWin::SetMulticastTimeToLive(int time_to_live) {
1101 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1102 if (is_connected())
1103 return ERR_SOCKET_IS_CONNECTED;
1104
1105 if (time_to_live < 0 || time_to_live > 255)
1106 return ERR_INVALID_ARGUMENT;
1107 multicast_time_to_live_ = time_to_live;
1108 return OK;
1109 }
1110
SetMulticastLoopbackMode(bool loopback)1111 int UDPSocketWin::SetMulticastLoopbackMode(bool loopback) {
1112 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1113 if (is_connected())
1114 return ERR_SOCKET_IS_CONNECTED;
1115
1116 if (loopback)
1117 socket_options_ |= SOCKET_OPTION_MULTICAST_LOOP;
1118 else
1119 socket_options_ &= ~SOCKET_OPTION_MULTICAST_LOOP;
1120 return OK;
1121 }
1122
DscpToTrafficType(DiffServCodePoint dscp)1123 QOS_TRAFFIC_TYPE DscpToTrafficType(DiffServCodePoint dscp) {
1124 QOS_TRAFFIC_TYPE traffic_type = QOSTrafficTypeBestEffort;
1125 switch (dscp) {
1126 case DSCP_CS0:
1127 traffic_type = QOSTrafficTypeBestEffort;
1128 break;
1129 case DSCP_CS1:
1130 traffic_type = QOSTrafficTypeBackground;
1131 break;
1132 case DSCP_AF11:
1133 case DSCP_AF12:
1134 case DSCP_AF13:
1135 case DSCP_CS2:
1136 case DSCP_AF21:
1137 case DSCP_AF22:
1138 case DSCP_AF23:
1139 case DSCP_CS3:
1140 case DSCP_AF31:
1141 case DSCP_AF32:
1142 case DSCP_AF33:
1143 case DSCP_CS4:
1144 traffic_type = QOSTrafficTypeExcellentEffort;
1145 break;
1146 case DSCP_AF41:
1147 case DSCP_AF42:
1148 case DSCP_AF43:
1149 case DSCP_CS5:
1150 traffic_type = QOSTrafficTypeAudioVideo;
1151 break;
1152 case DSCP_EF:
1153 case DSCP_CS6:
1154 traffic_type = QOSTrafficTypeVoice;
1155 break;
1156 case DSCP_CS7:
1157 traffic_type = QOSTrafficTypeControl;
1158 break;
1159 case DSCP_NO_CHANGE:
1160 NOTREACHED();
1161 break;
1162 }
1163 return traffic_type;
1164 }
1165
SetDiffServCodePoint(DiffServCodePoint dscp)1166 int UDPSocketWin::SetDiffServCodePoint(DiffServCodePoint dscp) {
1167 if (dscp == DSCP_NO_CHANGE)
1168 return OK;
1169
1170 if (!is_connected())
1171 return ERR_SOCKET_NOT_CONNECTED;
1172
1173 QwaveApi* api = GetQwaveApi();
1174
1175 if (!api->qwave_supported())
1176 return ERR_NOT_IMPLEMENTED;
1177
1178 if (!dscp_manager_)
1179 dscp_manager_ = std::make_unique<DscpManager>(api, socket_);
1180
1181 dscp_manager_->Set(dscp);
1182 if (remote_address_)
1183 return dscp_manager_->PrepareForSend(*remote_address_.get());
1184
1185 return OK;
1186 }
1187
SetIPv6Only(bool ipv6_only)1188 int UDPSocketWin::SetIPv6Only(bool ipv6_only) {
1189 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1190 if (is_connected()) {
1191 return ERR_SOCKET_IS_CONNECTED;
1192 }
1193 return net::SetIPv6Only(socket_, ipv6_only);
1194 }
1195
DetachFromThread()1196 void UDPSocketWin::DetachFromThread() {
1197 DETACH_FROM_THREAD(thread_checker_);
1198 }
1199
UseNonBlockingIO()1200 void UDPSocketWin::UseNonBlockingIO() {
1201 DCHECK(!core_);
1202 use_non_blocking_io_ = true;
1203 }
1204
ApplySocketTag(const SocketTag & tag)1205 void UDPSocketWin::ApplySocketTag(const SocketTag& tag) {
1206 // Windows does not support any specific SocketTags so fail if any non-default
1207 // tag is applied.
1208 CHECK(tag == SocketTag());
1209 }
1210
DscpManager(QwaveApi * api,SOCKET socket)1211 DscpManager::DscpManager(QwaveApi* api, SOCKET socket)
1212 : api_(api), socket_(socket) {
1213 RequestHandle();
1214 }
1215
~DscpManager()1216 DscpManager::~DscpManager() {
1217 if (!qos_handle_)
1218 return;
1219
1220 if (flow_id_ != 0)
1221 api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1222
1223 api_->CloseHandle(qos_handle_);
1224 }
1225
Set(DiffServCodePoint dscp)1226 void DscpManager::Set(DiffServCodePoint dscp) {
1227 if (dscp == DSCP_NO_CHANGE || dscp == dscp_value_)
1228 return;
1229
1230 dscp_value_ = dscp;
1231
1232 // TODO(zstein): We could reuse the flow when the value changes
1233 // by calling QOSSetFlow with the new traffic type and dscp value.
1234 if (flow_id_ != 0 && qos_handle_) {
1235 api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1236 configured_.clear();
1237 flow_id_ = 0;
1238 }
1239 }
1240
// Ensures |socket_| has been added to a qWAVE flow for |remote_address| using
// the traffic type derived from dscp_value_. Each address is attempted only
// once until configured_ is cleared, which happens in Set() or on handle
// reinitialization below. Repeat calls for an already-attempted address return
// OK without calling into qWAVE.
//
// Returns:
//   OK                  - no DSCP value set, address already attempted, or
//                         AddSocketToFlow() succeeded.
//   ERR_NOT_IMPLEMENTED - qWAVE is not supported.
//   ERR_INVALID_HANDLE  - no QoS handle yet (creation still pending), or the
//                         handle needed reinitialization and a new one was
//                         requested.
//   ERR_ADDRESS_INVALID - |remote_address| could not be converted.
//   otherwise           - MapSystemError() of ::GetLastError() after a failed
//                         AddSocketToFlow().
int DscpManager::PrepareForSend(const IPEndPoint& remote_address) {
  if (dscp_value_ == DSCP_NO_CHANGE) {
    // No DSCP value has been set.
    return OK;
  }

  if (!api_->qwave_supported())
    return ERR_NOT_IMPLEMENTED;

  // qos_handle_ is filled in asynchronously by RequestHandle() /
  // OnHandleCreated(); until then there is nothing to attach to.
  if (!qos_handle_)
    return ERR_INVALID_HANDLE;  // The closest net error to try again later.

  if (configured_.find(remote_address) != configured_.end())
    return OK;

  SockaddrStorage storage;
  if (!remote_address.ToSockAddr(storage.addr, &storage.addr_len))
    return ERR_ADDRESS_INVALID;

  // We won't try this address again if we get an error.
  configured_.emplace(remote_address);

  // We don't need to call SetFlow if we already have a qos flow.
  // Sampled before AddSocketToFlow(), which writes flow_id_ on success.
  bool new_flow = flow_id_ == 0;

  const QOS_TRAFFIC_TYPE traffic_type = DscpToTrafficType(dscp_value_);

  if (!api_->AddSocketToFlow(qos_handle_, socket_, storage.addr, traffic_type,
                             QOS_NON_ADAPTIVE_FLOW, &flow_id_)) {
    // Read the error immediately, before any other call can overwrite it.
    DWORD err = ::GetLastError();
    if (err == ERROR_DEVICE_REINITIALIZATION_NEEDED) {
      // Reset. PrepareForSend is called for every packet. Once RequestHandle
      // completes asynchronously the next PrepareForSend call will re-register
      // the address with the new QoS Handle. In the meantime, sends will
      // continue without DSCP.
      RequestHandle();
      configured_.clear();
      flow_id_ = 0;
      return ERR_INVALID_HANDLE;
    }
    return MapSystemError(err);
  }

  if (new_flow) {
    DWORD buf = dscp_value_;
    // This requires admin rights, and may fail, if so we ignore it
    // as AddSocketToFlow should still do *approximately* the right thing.
    api_->SetFlow(qos_handle_, flow_id_, QOSSetOutgoingDSCPValue, sizeof(buf),
                  &buf, 0, nullptr);
  }

  return OK;
}
1294
RequestHandle()1295 void DscpManager::RequestHandle() {
1296 if (handle_is_initializing_)
1297 return;
1298
1299 if (qos_handle_) {
1300 api_->CloseHandle(qos_handle_);
1301 qos_handle_ = nullptr;
1302 }
1303
1304 handle_is_initializing_ = true;
1305 base::ThreadPool::PostTaskAndReplyWithResult(
1306 FROM_HERE, {base::MayBlock()},
1307 base::BindOnce(&DscpManager::DoCreateHandle, api_),
1308 base::BindOnce(&DscpManager::OnHandleCreated, api_,
1309 weak_ptr_factory_.GetWeakPtr()));
1310 }
1311
DoCreateHandle(QwaveApi * api)1312 HANDLE DscpManager::DoCreateHandle(QwaveApi* api) {
1313 QOS_VERSION version;
1314 version.MajorVersion = 1;
1315 version.MinorVersion = 0;
1316
1317 HANDLE handle = nullptr;
1318
1319 // No access to net_log_ so swallow any errors here.
1320 api->CreateHandle(&version, &handle);
1321 return handle;
1322 }
1323
OnHandleCreated(QwaveApi * api,base::WeakPtr<DscpManager> dscp_manager,HANDLE handle)1324 void DscpManager::OnHandleCreated(QwaveApi* api,
1325 base::WeakPtr<DscpManager> dscp_manager,
1326 HANDLE handle) {
1327 if (!handle)
1328 api->OnFatalError();
1329
1330 if (!dscp_manager) {
1331 api->CloseHandle(handle);
1332 return;
1333 }
1334
1335 DCHECK(dscp_manager->handle_is_initializing_);
1336 DCHECK(!dscp_manager->qos_handle_);
1337
1338 dscp_manager->qos_handle_ = handle;
1339 dscp_manager->handle_is_initializing_ = false;
1340 }
1341
1342 } // namespace net
1343