1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/udp_socket_win.h"
6
7 #include <mstcpip.h>
8 #include <winsock2.h>
9
10 #include <memory>
11
12 #include "base/check_op.h"
13 #include "base/functional/bind.h"
14 #include "base/functional/callback.h"
15 #include "base/lazy_instance.h"
16 #include "base/memory/raw_ptr.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/metrics/histogram_macros.h"
19 #include "base/notreached.h"
20 #include "base/rand_util.h"
21 #include "base/task/thread_pool.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/ip_address.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/network_activity_monitor.h"
27 #include "net/base/network_change_notifier.h"
28 #include "net/base/sockaddr_storage.h"
29 #include "net/base/winsock_init.h"
30 #include "net/base/winsock_util.h"
31 #include "net/log/net_log.h"
32 #include "net/log/net_log_event_type.h"
33 #include "net/log/net_log_source.h"
34 #include "net/log/net_log_source_type.h"
35 #include "net/socket/socket_descriptor.h"
36 #include "net/socket/socket_options.h"
37 #include "net/socket/socket_tag.h"
38 #include "net/socket/udp_net_log_parameters.h"
39 #include "net/traffic_annotation/network_traffic_annotation.h"
40
41 namespace net {
42
43 // This class encapsulates all the state that has to be preserved as long as
44 // there is a network IO operation in progress. If the owner UDPSocketWin
45 // is destroyed while an operation is in progress, the Core is detached and it
46 // lives until the operation completes and the OS doesn't reference any resource
47 // declared on this class anymore.
class UDPSocketWin::Core : public base::RefCounted<Core> {
 public:
  explicit Core(UDPSocketWin* socket);

  Core(const Core&) = delete;
  Core& operator=(const Core&) = delete;

  // Start watching for the end of a read or write operation. Each call takes
  // an extra reference on this Core; the matching delegate's
  // OnObjectSignaled() releases it.
  void WatchForRead();
  void WatchForWrite();

  // The UDPSocketWin is going away. Afterwards the delegates no longer forward
  // completions to the socket; they only drop their reference.
  void Detach() { socket_ = nullptr; }

  // The separate OVERLAPPED variables for asynchronous operation. Each one
  // carries its own WSA event, created in the constructor and closed in the
  // destructor.
  OVERLAPPED read_overlapped_;
  OVERLAPPED write_overlapped_;

  // The buffers used in Read() and Write(). They are parked here, rather than
  // in the socket, so they stay alive for as long as this Core does.
  scoped_refptr<IOBuffer> read_iobuffer_;
  scoped_refptr<IOBuffer> write_iobuffer_;

  // The address storage passed to WSARecvFrom().
  SockaddrStorage recv_addr_storage_;

 private:
  friend class base::RefCounted<Core>;

  // Forwards signals of |read_overlapped_.hEvent| to the owning socket.
  class ReadDelegate : public base::win::ObjectWatcher::Delegate {
   public:
    explicit ReadDelegate(Core* core) : core_(core) {}
    ~ReadDelegate() override = default;

    // base::ObjectWatcher::Delegate methods:
    void OnObjectSignaled(HANDLE object) override;

   private:
    // The Core that embeds this delegate (non-owning).
    const raw_ptr<Core> core_;
  };

  // Forwards signals of |write_overlapped_.hEvent| to the owning socket.
  class WriteDelegate : public base::win::ObjectWatcher::Delegate {
   public:
    explicit WriteDelegate(Core* core) : core_(core) {}
    ~WriteDelegate() override = default;

    // base::ObjectWatcher::Delegate methods:
    void OnObjectSignaled(HANDLE object) override;

   private:
    // The Core that embeds this delegate (non-owning).
    const raw_ptr<Core> core_;
  };

  // Private: a Core is only destroyed through base::RefCounted::Release().
  ~Core();

  // The socket that created this object; null after Detach().
  raw_ptr<UDPSocketWin> socket_;

  // NOTE: declaration order matters. The watchers below are declared after
  // (and therefore destroyed before) the delegates they call into.

  // |reader_| handles the signals from |read_watcher_|.
  ReadDelegate reader_;
  // |writer_| handles the signals from |write_watcher_|.
  WriteDelegate writer_;

  // |read_watcher_| watches for events from Read().
  base::win::ObjectWatcher read_watcher_;
  // |write_watcher_| watches for events from Write();
  base::win::ObjectWatcher write_watcher_;
};
115
Core(UDPSocketWin * socket)116 UDPSocketWin::Core::Core(UDPSocketWin* socket)
117 : socket_(socket),
118 reader_(this),
119 writer_(this) {
120 memset(&read_overlapped_, 0, sizeof(read_overlapped_));
121 memset(&write_overlapped_, 0, sizeof(write_overlapped_));
122
123 read_overlapped_.hEvent = WSACreateEvent();
124 write_overlapped_.hEvent = WSACreateEvent();
125 }
126
~Core()127 UDPSocketWin::Core::~Core() {
128 // Make sure the message loop is not watching this object anymore.
129 read_watcher_.StopWatching();
130 write_watcher_.StopWatching();
131
132 WSACloseEvent(read_overlapped_.hEvent);
133 memset(&read_overlapped_, 0xaf, sizeof(read_overlapped_));
134 WSACloseEvent(write_overlapped_.hEvent);
135 memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_));
136 }
137
WatchForRead()138 void UDPSocketWin::Core::WatchForRead() {
139 // We grab an extra reference because there is an IO operation in progress.
140 // Balanced in ReadDelegate::OnObjectSignaled().
141 AddRef();
142 read_watcher_.StartWatchingOnce(read_overlapped_.hEvent, &reader_);
143 }
144
WatchForWrite()145 void UDPSocketWin::Core::WatchForWrite() {
146 // We grab an extra reference because there is an IO operation in progress.
147 // Balanced in WriteDelegate::OnObjectSignaled().
148 AddRef();
149 write_watcher_.StartWatchingOnce(write_overlapped_.hEvent, &writer_);
150 }
151
OnObjectSignaled(HANDLE object)152 void UDPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
153 DCHECK_EQ(object, core_->read_overlapped_.hEvent);
154 if (core_->socket_)
155 core_->socket_->DidCompleteRead();
156
157 core_->Release();
158 }
159
OnObjectSignaled(HANDLE object)160 void UDPSocketWin::Core::WriteDelegate::OnObjectSignaled(HANDLE object) {
161 DCHECK_EQ(object, core_->write_overlapped_.hEvent);
162 if (core_->socket_)
163 core_->socket_->DidCompleteWrite();
164
165 core_->Release();
166 }
167 //-----------------------------------------------------------------------------
168
QwaveApi()169 QwaveApi::QwaveApi() {
170 HMODULE qwave = LoadLibrary(L"qwave.dll");
171 if (!qwave)
172 return;
173 create_handle_func_ =
174 (CreateHandleFn)GetProcAddress(qwave, "QOSCreateHandle");
175 close_handle_func_ =
176 (CloseHandleFn)GetProcAddress(qwave, "QOSCloseHandle");
177 add_socket_to_flow_func_ =
178 (AddSocketToFlowFn)GetProcAddress(qwave, "QOSAddSocketToFlow");
179 remove_socket_from_flow_func_ =
180 (RemoveSocketFromFlowFn)GetProcAddress(qwave, "QOSRemoveSocketFromFlow");
181 set_flow_func_ = (SetFlowFn)GetProcAddress(qwave, "QOSSetFlow");
182
183 if (create_handle_func_ && close_handle_func_ &&
184 add_socket_to_flow_func_ && remove_socket_from_flow_func_ &&
185 set_flow_func_) {
186 qwave_supported_ = true;
187 }
188 }
189
GetDefault()190 QwaveApi* QwaveApi::GetDefault() {
191 static base::LazyInstance<QwaveApi>::Leaky lazy_qwave =
192 LAZY_INSTANCE_INITIALIZER;
193 return lazy_qwave.Pointer();
194 }
195
qwave_supported() const196 bool QwaveApi::qwave_supported() const {
197 return qwave_supported_;
198 }
199
// Permanently marks this QwaveApi as unusable: qwave_supported() returns
// false from now on. No visible code sets it back to true outside the
// constructor.
void QwaveApi::OnFatalError() {
  // Disable everything moving forward.
  qwave_supported_ = false;
}
204
CreateHandle(PQOS_VERSION version,PHANDLE handle)205 BOOL QwaveApi::CreateHandle(PQOS_VERSION version, PHANDLE handle) {
206 return create_handle_func_(version, handle);
207 }
208
CloseHandle(HANDLE handle)209 BOOL QwaveApi::CloseHandle(HANDLE handle) {
210 return close_handle_func_(handle);
211 }
212
AddSocketToFlow(HANDLE handle,SOCKET socket,PSOCKADDR addr,QOS_TRAFFIC_TYPE traffic_type,DWORD flags,PQOS_FLOWID flow_id)213 BOOL QwaveApi::AddSocketToFlow(HANDLE handle,
214 SOCKET socket,
215 PSOCKADDR addr,
216 QOS_TRAFFIC_TYPE traffic_type,
217 DWORD flags,
218 PQOS_FLOWID flow_id) {
219 return add_socket_to_flow_func_(handle, socket, addr, traffic_type, flags,
220 flow_id);
221 }
222
RemoveSocketFromFlow(HANDLE handle,SOCKET socket,QOS_FLOWID flow_id,DWORD reserved)223 BOOL QwaveApi::RemoveSocketFromFlow(HANDLE handle,
224 SOCKET socket,
225 QOS_FLOWID flow_id,
226 DWORD reserved) {
227 return remove_socket_from_flow_func_(handle, socket, flow_id, reserved);
228 }
229
SetFlow(HANDLE handle,QOS_FLOWID flow_id,QOS_SET_FLOW op,ULONG size,PVOID data,DWORD reserved,LPOVERLAPPED overlapped)230 BOOL QwaveApi::SetFlow(HANDLE handle,
231 QOS_FLOWID flow_id,
232 QOS_SET_FLOW op,
233 ULONG size,
234 PVOID data,
235 DWORD reserved,
236 LPOVERLAPPED overlapped) {
237 return set_flow_func_(handle, flow_id, op, size, data, reserved, overlapped);
238 }
239
240 //-----------------------------------------------------------------------------
241
// Creates an unopened socket that logs to a new UDP_SOCKET NetLog source on
// |net_log|. The SOCKET_ALIVE event it begins references |source|, and the
// destructor ends it. |bind_type| is not used by this implementation.
UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
                           net::NetLog* net_log,
                           const net::NetLogSource& source)
    : socket_(INVALID_SOCKET),
      socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
      net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::UDP_SOCKET)) {
  EnsureWinsockInit();
  net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source);
}
251
// Creates an unopened socket that logs directly to |source_net_log|. The
// SOCKET_ALIVE event it begins references that same source. |bind_type| is not
// used by this implementation.
UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
                           NetLogWithSource source_net_log)
    : socket_(INVALID_SOCKET),
      socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
      net_log_(source_net_log) {
  EnsureWinsockInit();
  net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE,
                                       net_log_.source());
}
261
// Closes the socket (pending callbacks are dropped, not run) and ends the
// SOCKET_ALIVE NetLog event begun in the constructor.
UDPSocketWin::~UDPSocketWin() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  Close();
  net_log_.EndEvent(NetLogEventType::SOCKET_ALIVE);
}
267
Open(AddressFamily address_family)268 int UDPSocketWin::Open(AddressFamily address_family) {
269 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
270 DCHECK_EQ(socket_, INVALID_SOCKET);
271
272 auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
273 if (owned_socket_count.empty())
274 return ERR_INSUFFICIENT_RESOURCES;
275
276 owned_socket_count_ = std::move(owned_socket_count);
277 addr_family_ = ConvertAddressFamily(address_family);
278 socket_ = CreatePlatformSocket(addr_family_, SOCK_DGRAM, IPPROTO_UDP);
279 if (socket_ == INVALID_SOCKET) {
280 owned_socket_count_.Reset();
281 return MapSystemError(WSAGetLastError());
282 }
283 ConfigureOpenedSocket();
284 return OK;
285 }
286
AdoptOpenedSocket(AddressFamily address_family,SOCKET socket)287 int UDPSocketWin::AdoptOpenedSocket(AddressFamily address_family,
288 SOCKET socket) {
289 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
290 auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
291 if (owned_socket_count.empty()) {
292 return ERR_INSUFFICIENT_RESOURCES;
293 }
294
295 owned_socket_count_ = std::move(owned_socket_count);
296 addr_family_ = ConvertAddressFamily(address_family);
297 socket_ = socket;
298 ConfigureOpenedSocket();
299 return OK;
300 }
301
ConfigureOpenedSocket()302 void UDPSocketWin::ConfigureOpenedSocket() {
303 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
304 if (!use_non_blocking_io_) {
305 core_ = base::MakeRefCounted<Core>(this);
306 } else {
307 read_write_event_.Set(WSACreateEvent());
308 WSAEventSelect(socket_, read_write_event_.Get(), FD_READ | FD_WRITE);
309 }
310 }
311
Close()312 void UDPSocketWin::Close() {
313 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
314
315 owned_socket_count_.Reset();
316
317 if (socket_ == INVALID_SOCKET)
318 return;
319
320 // Remove socket_ from the QoS subsystem before we invalidate it.
321 dscp_manager_ = nullptr;
322
323 // Zero out any pending read/write callback state.
324 read_callback_.Reset();
325 recv_from_address_ = nullptr;
326 write_callback_.Reset();
327
328 base::TimeTicks start_time = base::TimeTicks::Now();
329 closesocket(socket_);
330 UMA_HISTOGRAM_TIMES("Net.UDPSocketWinClose",
331 base::TimeTicks::Now() - start_time);
332 socket_ = INVALID_SOCKET;
333 addr_family_ = 0;
334 is_connected_ = false;
335
336 // Release buffers to free up memory.
337 read_iobuffer_ = nullptr;
338 read_iobuffer_len_ = 0;
339 write_iobuffer_ = nullptr;
340 write_iobuffer_len_ = 0;
341
342 read_write_watcher_.StopWatching();
343 read_write_event_.Close();
344
345 event_pending_.InvalidateWeakPtrs();
346
347 if (core_) {
348 core_->Detach();
349 core_ = nullptr;
350 }
351 }
352
GetPeerAddress(IPEndPoint * address) const353 int UDPSocketWin::GetPeerAddress(IPEndPoint* address) const {
354 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
355 DCHECK(address);
356 if (!is_connected())
357 return ERR_SOCKET_NOT_CONNECTED;
358
359 // TODO(szym): Simplify. http://crbug.com/126152
360 if (!remote_address_.get()) {
361 SockaddrStorage storage;
362 if (getpeername(socket_, storage.addr, &storage.addr_len))
363 return MapSystemError(WSAGetLastError());
364 auto remote_address = std::make_unique<IPEndPoint>();
365 if (!remote_address->FromSockAddr(storage.addr, storage.addr_len))
366 return ERR_ADDRESS_INVALID;
367 remote_address_ = std::move(remote_address);
368 }
369
370 *address = *remote_address_;
371 return OK;
372 }
373
GetLocalAddress(IPEndPoint * address) const374 int UDPSocketWin::GetLocalAddress(IPEndPoint* address) const {
375 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
376 DCHECK(address);
377 if (!is_connected())
378 return ERR_SOCKET_NOT_CONNECTED;
379
380 // TODO(szym): Simplify. http://crbug.com/126152
381 if (!local_address_.get()) {
382 SockaddrStorage storage;
383 if (getsockname(socket_, storage.addr, &storage.addr_len))
384 return MapSystemError(WSAGetLastError());
385 auto local_address = std::make_unique<IPEndPoint>();
386 if (!local_address->FromSockAddr(storage.addr, storage.addr_len))
387 return ERR_ADDRESS_INVALID;
388 local_address_ = std::move(local_address);
389 net_log_.AddEvent(NetLogEventType::UDP_LOCAL_ADDRESS, [&] {
390 return CreateNetLogUDPConnectParams(*local_address_,
391 handles::kInvalidNetworkHandle);
392 });
393 }
394
395 *address = *local_address_;
396 return OK;
397 }
398
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)399 int UDPSocketWin::Read(IOBuffer* buf,
400 int buf_len,
401 CompletionOnceCallback callback) {
402 return RecvFrom(buf, buf_len, nullptr, std::move(callback));
403 }
404
RecvFrom(IOBuffer * buf,int buf_len,IPEndPoint * address,CompletionOnceCallback callback)405 int UDPSocketWin::RecvFrom(IOBuffer* buf,
406 int buf_len,
407 IPEndPoint* address,
408 CompletionOnceCallback callback) {
409 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
410 DCHECK_NE(INVALID_SOCKET, socket_);
411 CHECK(read_callback_.is_null());
412 DCHECK(!recv_from_address_);
413 DCHECK(!callback.is_null()); // Synchronous operation not supported.
414 DCHECK_GT(buf_len, 0);
415
416 int nread = core_ ? InternalRecvFromOverlapped(buf, buf_len, address)
417 : InternalRecvFromNonBlocking(buf, buf_len, address);
418 if (nread != ERR_IO_PENDING)
419 return nread;
420
421 read_callback_ = std::move(callback);
422 recv_from_address_ = address;
423 return ERR_IO_PENDING;
424 }
425
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)426 int UDPSocketWin::Write(
427 IOBuffer* buf,
428 int buf_len,
429 CompletionOnceCallback callback,
430 const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
431 return SendToOrWrite(buf, buf_len, remote_address_.get(),
432 std::move(callback));
433 }
434
SendTo(IOBuffer * buf,int buf_len,const IPEndPoint & address,CompletionOnceCallback callback)435 int UDPSocketWin::SendTo(IOBuffer* buf,
436 int buf_len,
437 const IPEndPoint& address,
438 CompletionOnceCallback callback) {
439 if (dscp_manager_) {
440 // Alert DscpManager in case this is a new remote address. Failure to
441 // apply Dscp code is never fatal.
442 int rv = dscp_manager_->PrepareForSend(address);
443 if (rv != OK)
444 net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, rv);
445 }
446 return SendToOrWrite(buf, buf_len, &address, std::move(callback));
447 }
448
SendToOrWrite(IOBuffer * buf,int buf_len,const IPEndPoint * address,CompletionOnceCallback callback)449 int UDPSocketWin::SendToOrWrite(IOBuffer* buf,
450 int buf_len,
451 const IPEndPoint* address,
452 CompletionOnceCallback callback) {
453 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
454 DCHECK_NE(INVALID_SOCKET, socket_);
455 CHECK(write_callback_.is_null());
456 DCHECK(!callback.is_null()); // Synchronous operation not supported.
457 DCHECK_GT(buf_len, 0);
458 DCHECK(!send_to_address_.get());
459
460 int nwrite = core_ ? InternalSendToOverlapped(buf, buf_len, address)
461 : InternalSendToNonBlocking(buf, buf_len, address);
462 if (nwrite != ERR_IO_PENDING)
463 return nwrite;
464
465 if (address)
466 send_to_address_ = std::make_unique<IPEndPoint>(*address);
467 write_callback_ = std::move(callback);
468 return ERR_IO_PENDING;
469 }
470
Connect(const IPEndPoint & address)471 int UDPSocketWin::Connect(const IPEndPoint& address) {
472 DCHECK_NE(socket_, INVALID_SOCKET);
473 net_log_.BeginEvent(NetLogEventType::UDP_CONNECT, [&] {
474 return CreateNetLogUDPConnectParams(address,
475 handles::kInvalidNetworkHandle);
476 });
477 int rv = SetMulticastOptions();
478 if (rv != OK)
479 return rv;
480 rv = InternalConnect(address);
481 net_log_.EndEventWithNetErrorCode(NetLogEventType::UDP_CONNECT, rv);
482 is_connected_ = (rv == OK);
483 return rv;
484 }
485
InternalConnect(const IPEndPoint & address)486 int UDPSocketWin::InternalConnect(const IPEndPoint& address) {
487 DCHECK(!is_connected());
488 DCHECK(!remote_address_.get());
489
490 // Always do a random bind.
491 // Ignore failures, which may happen if the socket was already bound.
492 DWORD randomize_port_value = 1;
493 setsockopt(socket_, SOL_SOCKET, SO_RANDOMIZE_PORT,
494 reinterpret_cast<const char*>(&randomize_port_value),
495 sizeof(randomize_port_value));
496
497 SockaddrStorage storage;
498 if (!address.ToSockAddr(storage.addr, &storage.addr_len))
499 return ERR_ADDRESS_INVALID;
500
501 int rv = connect(socket_, storage.addr, storage.addr_len);
502 if (rv < 0)
503 return MapSystemError(WSAGetLastError());
504
505 remote_address_ = std::make_unique<IPEndPoint>(address);
506
507 if (dscp_manager_)
508 dscp_manager_->PrepareForSend(*remote_address_.get());
509
510 return rv;
511 }
512
Bind(const IPEndPoint & address)513 int UDPSocketWin::Bind(const IPEndPoint& address) {
514 DCHECK_NE(socket_, INVALID_SOCKET);
515 DCHECK(!is_connected());
516
517 int rv = SetMulticastOptions();
518 if (rv < 0)
519 return rv;
520
521 rv = DoBind(address);
522 if (rv < 0)
523 return rv;
524
525 local_address_.reset();
526 is_connected_ = true;
527 return rv;
528 }
529
// Binding to a specific network is not implemented for this Windows socket:
// |network| is ignored and ERR_NOT_IMPLEMENTED is always returned.
int UDPSocketWin::BindToNetwork(handles::NetworkHandle network) {
  NOTIMPLEMENTED();
  return ERR_NOT_IMPLEMENTED;
}
534
SetReceiveBufferSize(int32_t size)535 int UDPSocketWin::SetReceiveBufferSize(int32_t size) {
536 DCHECK_NE(socket_, INVALID_SOCKET);
537 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
538 int rv = SetSocketReceiveBufferSize(socket_, size);
539
540 if (rv != 0)
541 return MapSystemError(WSAGetLastError());
542
543 // According to documentation, setsockopt may succeed, but we need to check
544 // the results via getsockopt to be sure it works on Windows.
545 int32_t actual_size = 0;
546 int option_size = sizeof(actual_size);
547 rv = getsockopt(socket_, SOL_SOCKET, SO_RCVBUF,
548 reinterpret_cast<char*>(&actual_size), &option_size);
549 if (rv != 0)
550 return MapSystemError(WSAGetLastError());
551 if (actual_size >= size)
552 return OK;
553 UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableReceiveBuffer",
554 actual_size, 1000, 1000000, 50);
555 return ERR_SOCKET_RECEIVE_BUFFER_SIZE_UNCHANGEABLE;
556 }
557
SetSendBufferSize(int32_t size)558 int UDPSocketWin::SetSendBufferSize(int32_t size) {
559 DCHECK_NE(socket_, INVALID_SOCKET);
560 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
561 int rv = SetSocketSendBufferSize(socket_, size);
562 if (rv != 0)
563 return MapSystemError(WSAGetLastError());
564 // According to documentation, setsockopt may succeed, but we need to check
565 // the results via getsockopt to be sure it works on Windows.
566 int32_t actual_size = 0;
567 int option_size = sizeof(actual_size);
568 rv = getsockopt(socket_, SOL_SOCKET, SO_SNDBUF,
569 reinterpret_cast<char*>(&actual_size), &option_size);
570 if (rv != 0)
571 return MapSystemError(WSAGetLastError());
572 if (actual_size >= size)
573 return OK;
574 UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableSendBuffer",
575 actual_size, 1000, 1000000, 50);
576 return ERR_SOCKET_SEND_BUFFER_SIZE_UNCHANGEABLE;
577 }
578
SetDoNotFragment()579 int UDPSocketWin::SetDoNotFragment() {
580 DCHECK_NE(socket_, INVALID_SOCKET);
581 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
582
583 if (addr_family_ == AF_INET6)
584 return OK;
585
586 DWORD val = 1;
587 int rv = setsockopt(socket_, IPPROTO_IP, IP_DONTFRAGMENT,
588 reinterpret_cast<const char*>(&val), sizeof(val));
589 return rv == 0 ? OK : MapSystemError(WSAGetLastError());
590 }
591
SetRecvEcn()592 int UDPSocketWin::SetRecvEcn() {
593 DCHECK_NE(socket_, INVALID_SOCKET);
594 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
595
596 int rv;
597 unsigned int ecn = 1;
598 if (addr_family_ == AF_INET6) {
599 rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_RECVTCLASS,
600 reinterpret_cast<const char*>(&ecn), sizeof(ecn));
601 } else {
602 DCHECK_EQ(addr_family_, AF_INET);
603 rv = setsockopt(socket_, IPPROTO_IP, IP_RECVTOS,
604 reinterpret_cast<const char*>(&ecn), sizeof(ecn));
605 }
606 return rv == 0 ? OK : MapSystemError(WSAGetLastError());
607 }
608
// No-op in this implementation: |confirm| is ignored.
// NOTE(review): presumably because MSG_CONFIRM has no Winsock equivalent; confirm.
void UDPSocketWin::SetMsgConfirm(bool confirm) {}
610
AllowAddressReuse()611 int UDPSocketWin::AllowAddressReuse() {
612 DCHECK_NE(socket_, INVALID_SOCKET);
613 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
614 DCHECK(!is_connected());
615
616 BOOL true_value = TRUE;
617 int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
618 reinterpret_cast<const char*>(&true_value),
619 sizeof(true_value));
620 return rv == 0 ? OK : MapSystemError(WSAGetLastError());
621 }
622
SetBroadcast(bool broadcast)623 int UDPSocketWin::SetBroadcast(bool broadcast) {
624 DCHECK_NE(socket_, INVALID_SOCKET);
625 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
626
627 BOOL value = broadcast ? TRUE : FALSE;
628 int rv = setsockopt(socket_, SOL_SOCKET, SO_BROADCAST,
629 reinterpret_cast<const char*>(&value), sizeof(value));
630 return rv == 0 ? OK : MapSystemError(WSAGetLastError());
631 }
632
// Enables address sharing for multicast listeners by delegating to
// AllowAddressReuse(), which sets SO_REUSEADDR.
int UDPSocketWin::AllowAddressSharingForMulticast() {
  // When proper multicast groups are used, Windows further defines the address
  // reuse option (SO_REUSEADDR) to ensure all listening sockets can receive
  // all incoming messages for the multicast group.
  return AllowAddressReuse();
}
639
DoReadCallback(int rv)640 void UDPSocketWin::DoReadCallback(int rv) {
641 DCHECK_NE(rv, ERR_IO_PENDING);
642 DCHECK(!read_callback_.is_null());
643
644 // since Run may result in Read being called, clear read_callback_ up front.
645 std::move(read_callback_).Run(rv);
646 }
647
DoWriteCallback(int rv)648 void UDPSocketWin::DoWriteCallback(int rv) {
649 DCHECK_NE(rv, ERR_IO_PENDING);
650 DCHECK(!write_callback_.is_null());
651
652 // since Run may result in Write being called, clear write_callback_ up front.
653 std::move(write_callback_).Run(rv);
654 }
655
DidCompleteRead()656 void UDPSocketWin::DidCompleteRead() {
657 DWORD num_bytes, flags;
658 BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_,
659 &num_bytes, FALSE, &flags);
660 WSAResetEvent(core_->read_overlapped_.hEvent);
661 int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
662 // Convert address.
663 IPEndPoint address;
664 IPEndPoint* address_to_log = nullptr;
665 if (result >= 0) {
666 if (address.FromSockAddr(core_->recv_addr_storage_.addr,
667 core_->recv_addr_storage_.addr_len)) {
668 if (recv_from_address_)
669 *recv_from_address_ = address;
670 address_to_log = &address;
671 } else {
672 result = ERR_ADDRESS_INVALID;
673 }
674 }
675 LogRead(result, core_->read_iobuffer_->data(), address_to_log);
676 core_->read_iobuffer_ = nullptr;
677 recv_from_address_ = nullptr;
678 DoReadCallback(result);
679 }
680
DidCompleteWrite()681 void UDPSocketWin::DidCompleteWrite() {
682 DWORD num_bytes, flags;
683 BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_,
684 &num_bytes, FALSE, &flags);
685 WSAResetEvent(core_->write_overlapped_.hEvent);
686 int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
687 LogWrite(result, core_->write_iobuffer_->data(), send_to_address_.get());
688
689 send_to_address_.reset();
690 core_->write_iobuffer_ = nullptr;
691 DoWriteCallback(result);
692 }
693
// ObjectWatcher callback for the non-blocking I/O mode. |read_write_event_|
// was signaled, so query which network events fired and retry the pending
// read and/or write. On an enumeration error, both pending operations are
// failed with that error. Running a callback may close or destroy |this|, so
// the weak pointer is checked before touching members afterwards.
void UDPSocketWin::OnObjectSignaled(HANDLE object) {
  DCHECK(object == read_write_event_.Get());
  WSANETWORKEVENTS network_events;
  int os_error = 0;
  int rv =
      WSAEnumNetworkEvents(socket_, read_write_event_.Get(), &network_events);
  // Protects against trying to call the write callback if the read callback
  // either closes or destroys |this|.
  base::WeakPtr<UDPSocketWin> event_pending = event_pending_.GetWeakPtr();
  if (rv == SOCKET_ERROR) {
    os_error = WSAGetLastError();
    rv = MapSystemError(os_error);

    // Fail the pending read (if any) with the enumeration error.
    if (read_iobuffer_) {
      read_iobuffer_ = nullptr;
      read_iobuffer_len_ = 0;
      recv_from_address_ = nullptr;
      DoReadCallback(rv);
    }

    // Socket may have been closed or destroyed here.
    if (event_pending && write_iobuffer_) {
      write_iobuffer_ = nullptr;
      write_iobuffer_len_ = 0;
      send_to_address_.reset();
      DoWriteCallback(rv);
    }
    return;
  }

  // Only retry a direction that both has a pending operation and was reported
  // ready.
  if ((network_events.lNetworkEvents & FD_READ) && read_iobuffer_)
    OnReadSignaled();
  if (!event_pending)
    return;

  if ((network_events.lNetworkEvents & FD_WRITE) && write_iobuffer_)
    OnWriteSignaled();
  if (!event_pending)
    return;

  // There's still pending read / write. Watch for further events.
  if (read_iobuffer_ || write_iobuffer_)
    WatchForReadWrite();
}
738
OnReadSignaled()739 void UDPSocketWin::OnReadSignaled() {
740 int rv = InternalRecvFromNonBlocking(read_iobuffer_.get(), read_iobuffer_len_,
741 recv_from_address_);
742 if (rv == ERR_IO_PENDING)
743 return;
744 read_iobuffer_ = nullptr;
745 read_iobuffer_len_ = 0;
746 recv_from_address_ = nullptr;
747 DoReadCallback(rv);
748 }
749
OnWriteSignaled()750 void UDPSocketWin::OnWriteSignaled() {
751 int rv = InternalSendToNonBlocking(write_iobuffer_.get(), write_iobuffer_len_,
752 send_to_address_.get());
753 if (rv == ERR_IO_PENDING)
754 return;
755 write_iobuffer_ = nullptr;
756 write_iobuffer_len_ = 0;
757 send_to_address_.reset();
758 DoWriteCallback(rv);
759 }
760
WatchForReadWrite()761 void UDPSocketWin::WatchForReadWrite() {
762 if (read_write_watcher_.IsWatching())
763 return;
764 bool watched =
765 read_write_watcher_.StartWatchingOnce(read_write_event_.Get(), this);
766 DCHECK(watched);
767 }
768
LogRead(int result,const char * bytes,const IPEndPoint * address) const769 void UDPSocketWin::LogRead(int result,
770 const char* bytes,
771 const IPEndPoint* address) const {
772 if (result < 0) {
773 net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_RECEIVE_ERROR,
774 result);
775 return;
776 }
777
778 if (net_log_.IsCapturing()) {
779 NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_RECEIVED, result,
780 bytes, address);
781 }
782
783 activity_monitor::IncrementBytesReceived(result);
784 }
785
LogWrite(int result,const char * bytes,const IPEndPoint * address) const786 void UDPSocketWin::LogWrite(int result,
787 const char* bytes,
788 const IPEndPoint* address) const {
789 if (result < 0) {
790 net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, result);
791 return;
792 }
793
794 if (net_log_.IsCapturing()) {
795 NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_SENT, result,
796 bytes, address);
797 }
798 }
799
// Overlapped receive path. Issues WSARecvFrom() into |buf| using |core_|'s
// OVERLAPPED and address storage.
// - Completes inline: returns the byte count (or a net error) and fills
//   |address| if non-null.
// - Outstanding: arms the read watcher and returns ERR_IO_PENDING;
//   DidCompleteRead() finishes the operation later.
int UDPSocketWin::InternalRecvFromOverlapped(IOBuffer* buf,
                                             int buf_len,
                                             IPEndPoint* address) {
  DCHECK(!core_->read_iobuffer_.get());
  // The sender address goes into |core_| rather than a local, because
  // DidCompleteRead() reads it after this function has returned.
  SockaddrStorage& storage = core_->recv_addr_storage_;
  storage.addr_len = sizeof(storage.addr_storage);

  WSABUF read_buffer;
  read_buffer.buf = buf->data();
  read_buffer.len = buf_len;

  DWORD flags = 0;
  DWORD num;
  CHECK_NE(INVALID_SOCKET, socket_);
  int rv = WSARecvFrom(socket_, &read_buffer, 1, &num, &flags, storage.addr,
                       &storage.addr_len, &core_->read_overlapped_, nullptr);
  if (rv == 0) {
    // Completed immediately. It is handled inline only when the completion
    // event is signaled; otherwise control falls through to the watcher path
    // below. NOTE(review): ResetEventIfSignaled() lives in winsock_util and is
    // presumably also what clears the event here; confirm.
    if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) {
      int result = num;
      // Convert address.
      IPEndPoint address_storage;
      IPEndPoint* address_to_log = nullptr;
      if (result >= 0) {
        if (address_storage.FromSockAddr(core_->recv_addr_storage_.addr,
                                         core_->recv_addr_storage_.addr_len)) {
          if (address)
            *address = address_storage;
          address_to_log = &address_storage;
        } else {
          result = ERR_ADDRESS_INVALID;
        }
      }
      LogRead(result, buf->data(), address_to_log);
      return result;
    }
  } else {
    int os_error = WSAGetLastError();
    if (os_error != WSA_IO_PENDING) {
      // Immediate failure: nothing is outstanding, so no watcher is armed.
      int result = MapSystemError(os_error);
      LogRead(result, nullptr, nullptr);
      return result;
    }
  }
  // The operation is outstanding. WatchForRead() takes a Core reference that
  // is dropped when the completion is observed. The buffer is parked in
  // |core_| so it lives as long as the Core, even if this socket closes first.
  core_->WatchForRead();
  core_->read_iobuffer_ = buf;
  return ERR_IO_PENDING;
}
847
// Overlapped send path. Issues WSASendTo() of |buf| to |address|, or with no
// explicit destination when |address| is null, using |core_|'s OVERLAPPED.
// - Completes inline: returns the byte count (or a net error).
// - Outstanding: arms the write watcher and returns ERR_IO_PENDING;
//   DidCompleteWrite() finishes the operation later.
int UDPSocketWin::InternalSendToOverlapped(IOBuffer* buf,
                                           int buf_len,
                                           const IPEndPoint* address) {
  DCHECK(!core_->write_iobuffer_.get());
  SockaddrStorage storage;
  struct sockaddr* addr = storage.addr;
  // Convert address.
  if (!address) {
    // No destination: pass a null sockaddr with zero length.
    addr = nullptr;
    storage.addr_len = 0;
  } else {
    if (!address->ToSockAddr(addr, &storage.addr_len)) {
      int result = ERR_ADDRESS_INVALID;
      LogWrite(result, nullptr, nullptr);
      return result;
    }
  }

  WSABUF write_buffer;
  write_buffer.buf = buf->data();
  write_buffer.len = buf_len;

  DWORD flags = 0;
  DWORD num;
  int rv = WSASendTo(socket_, &write_buffer, 1, &num, flags, addr,
                     storage.addr_len, &core_->write_overlapped_, nullptr);
  if (rv == 0) {
    // Completed immediately. It is handled inline only when the completion
    // event is signaled; otherwise control falls through to the watcher path
    // below. NOTE(review): ResetEventIfSignaled() lives in winsock_util and is
    // presumably also what clears the event here; confirm.
    if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) {
      int result = num;
      LogWrite(result, buf->data(), address);
      return result;
    }
  } else {
    int os_error = WSAGetLastError();
    if (os_error != WSA_IO_PENDING) {
      // Immediate failure: nothing is outstanding, so no watcher is armed.
      int result = MapSystemError(os_error);
      LogWrite(result, nullptr, nullptr);
      return result;
    }
  }

  // The operation is outstanding. WatchForWrite() takes a Core reference that
  // is dropped when the completion is observed. The buffer is parked in
  // |core_| so it lives as long as the Core, even if this socket closes first.
  core_->WatchForWrite();
  core_->write_iobuffer_ = buf;
  return ERR_IO_PENDING;
}
893
InternalRecvFromNonBlocking(IOBuffer * buf,int buf_len,IPEndPoint * address)894 int UDPSocketWin::InternalRecvFromNonBlocking(IOBuffer* buf,
895 int buf_len,
896 IPEndPoint* address) {
897 DCHECK(!read_iobuffer_ || read_iobuffer_.get() == buf);
898 SockaddrStorage storage;
899 storage.addr_len = sizeof(storage.addr_storage);
900
901 CHECK_NE(INVALID_SOCKET, socket_);
902 int rv = recvfrom(socket_, buf->data(), buf_len, 0, storage.addr,
903 &storage.addr_len);
904 if (rv == SOCKET_ERROR) {
905 int os_error = WSAGetLastError();
906 if (os_error == WSAEWOULDBLOCK) {
907 read_iobuffer_ = buf;
908 read_iobuffer_len_ = buf_len;
909 WatchForReadWrite();
910 return ERR_IO_PENDING;
911 }
912 rv = MapSystemError(os_error);
913 LogRead(rv, nullptr, nullptr);
914 return rv;
915 }
916 IPEndPoint address_storage;
917 IPEndPoint* address_to_log = nullptr;
918 if (rv >= 0) {
919 if (address_storage.FromSockAddr(storage.addr, storage.addr_len)) {
920 if (address)
921 *address = address_storage;
922 address_to_log = &address_storage;
923 } else {
924 rv = ERR_ADDRESS_INVALID;
925 }
926 }
927 LogRead(rv, buf->data(), address_to_log);
928 return rv;
929 }
930
InternalSendToNonBlocking(IOBuffer * buf,int buf_len,const IPEndPoint * address)931 int UDPSocketWin::InternalSendToNonBlocking(IOBuffer* buf,
932 int buf_len,
933 const IPEndPoint* address) {
934 DCHECK(!write_iobuffer_ || write_iobuffer_.get() == buf);
935 SockaddrStorage storage;
936 struct sockaddr* addr = storage.addr;
937 // Convert address.
938 if (address) {
939 if (!address->ToSockAddr(addr, &storage.addr_len)) {
940 int result = ERR_ADDRESS_INVALID;
941 LogWrite(result, nullptr, nullptr);
942 return result;
943 }
944 } else {
945 addr = nullptr;
946 storage.addr_len = 0;
947 }
948
949 int rv = sendto(socket_, buf->data(), buf_len, 0, addr, storage.addr_len);
950 if (rv == SOCKET_ERROR) {
951 int os_error = WSAGetLastError();
952 if (os_error == WSAEWOULDBLOCK) {
953 write_iobuffer_ = buf;
954 write_iobuffer_len_ = buf_len;
955 WatchForReadWrite();
956 return ERR_IO_PENDING;
957 }
958 rv = MapSystemError(os_error);
959 LogWrite(rv, nullptr, nullptr);
960 return rv;
961 }
962 LogWrite(rv, buf->data(), address);
963 return rv;
964 }
965
SetMulticastOptions()966 int UDPSocketWin::SetMulticastOptions() {
967 if (!(socket_options_ & SOCKET_OPTION_MULTICAST_LOOP)) {
968 DWORD loop = 0;
969 int protocol_level =
970 addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
971 int option =
972 addr_family_ == AF_INET ? IP_MULTICAST_LOOP: IPV6_MULTICAST_LOOP;
973 int rv = setsockopt(socket_, protocol_level, option,
974 reinterpret_cast<const char*>(&loop), sizeof(loop));
975 if (rv < 0)
976 return MapSystemError(WSAGetLastError());
977 }
978 if (multicast_time_to_live_ != 1) {
979 DWORD hops = multicast_time_to_live_;
980 int protocol_level =
981 addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
982 int option =
983 addr_family_ == AF_INET ? IP_MULTICAST_TTL: IPV6_MULTICAST_HOPS;
984 int rv = setsockopt(socket_, protocol_level, option,
985 reinterpret_cast<const char*>(&hops), sizeof(hops));
986 if (rv < 0)
987 return MapSystemError(WSAGetLastError());
988 }
989 if (multicast_interface_ != 0) {
990 switch (addr_family_) {
991 case AF_INET: {
992 in_addr address;
993 address.s_addr = htonl(multicast_interface_);
994 int rv = setsockopt(socket_, IPPROTO_IP, IP_MULTICAST_IF,
995 reinterpret_cast<const char*>(&address),
996 sizeof(address));
997 if (rv)
998 return MapSystemError(WSAGetLastError());
999 break;
1000 }
1001 case AF_INET6: {
1002 uint32_t interface_index = multicast_interface_;
1003 int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_MULTICAST_IF,
1004 reinterpret_cast<const char*>(&interface_index),
1005 sizeof(interface_index));
1006 if (rv)
1007 return MapSystemError(WSAGetLastError());
1008 break;
1009 }
1010 default:
1011 NOTREACHED() << "Invalid address family";
1012 return ERR_ADDRESS_INVALID;
1013 }
1014 }
1015 return OK;
1016 }
1017
DoBind(const IPEndPoint & address)1018 int UDPSocketWin::DoBind(const IPEndPoint& address) {
1019 SockaddrStorage storage;
1020 if (!address.ToSockAddr(storage.addr, &storage.addr_len))
1021 return ERR_ADDRESS_INVALID;
1022 int rv = bind(socket_, storage.addr, storage.addr_len);
1023 if (rv == 0)
1024 return OK;
1025 int last_error = WSAGetLastError();
1026 // Map some codes that are special to bind() separately.
1027 // * WSAEACCES: If a port is already bound to a socket, WSAEACCES may be
1028 // returned instead of WSAEADDRINUSE, depending on whether the socket
1029 // option SO_REUSEADDR or SO_EXCLUSIVEADDRUSE is set and whether the
1030 // conflicting socket is owned by a different user account. See the MSDN
1031 // page "Using SO_REUSEADDR and SO_EXCLUSIVEADDRUSE" for the gory details.
1032 if (last_error == WSAEACCES || last_error == WSAEADDRNOTAVAIL)
1033 return ERR_ADDRESS_IN_USE;
1034 return MapSystemError(last_error);
1035 }
1036
// Returns the qWAVE API wrapper used for DSCP support. This implementation
// always hands out the process-wide QwaveApi::GetDefault() instance.
QwaveApi* UDPSocketWin::GetQwaveApi() const {
  return QwaveApi::GetDefault();
}
1040
JoinGroup(const IPAddress & group_address) const1041 int UDPSocketWin::JoinGroup(const IPAddress& group_address) const {
1042 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1043 if (!is_connected())
1044 return ERR_SOCKET_NOT_CONNECTED;
1045
1046 switch (group_address.size()) {
1047 case IPAddress::kIPv4AddressSize: {
1048 if (addr_family_ != AF_INET)
1049 return ERR_ADDRESS_INVALID;
1050 ip_mreq mreq;
1051 mreq.imr_interface.s_addr = htonl(multicast_interface_);
1052 memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1053 IPAddress::kIPv4AddressSize);
1054 int rv = setsockopt(socket_, IPPROTO_IP, IP_ADD_MEMBERSHIP,
1055 reinterpret_cast<const char*>(&mreq),
1056 sizeof(mreq));
1057 if (rv)
1058 return MapSystemError(WSAGetLastError());
1059 return OK;
1060 }
1061 case IPAddress::kIPv6AddressSize: {
1062 if (addr_family_ != AF_INET6)
1063 return ERR_ADDRESS_INVALID;
1064 ipv6_mreq mreq;
1065 mreq.ipv6mr_interface = multicast_interface_;
1066 memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1067 IPAddress::kIPv6AddressSize);
1068 int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
1069 reinterpret_cast<const char*>(&mreq),
1070 sizeof(mreq));
1071 if (rv)
1072 return MapSystemError(WSAGetLastError());
1073 return OK;
1074 }
1075 default:
1076 NOTREACHED() << "Invalid address family";
1077 return ERR_ADDRESS_INVALID;
1078 }
1079 }
1080
LeaveGroup(const IPAddress & group_address) const1081 int UDPSocketWin::LeaveGroup(const IPAddress& group_address) const {
1082 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1083 if (!is_connected())
1084 return ERR_SOCKET_NOT_CONNECTED;
1085
1086 switch (group_address.size()) {
1087 case IPAddress::kIPv4AddressSize: {
1088 if (addr_family_ != AF_INET)
1089 return ERR_ADDRESS_INVALID;
1090 ip_mreq mreq;
1091 mreq.imr_interface.s_addr = htonl(multicast_interface_);
1092 memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1093 IPAddress::kIPv4AddressSize);
1094 int rv = setsockopt(socket_, IPPROTO_IP, IP_DROP_MEMBERSHIP,
1095 reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1096 if (rv)
1097 return MapSystemError(WSAGetLastError());
1098 return OK;
1099 }
1100 case IPAddress::kIPv6AddressSize: {
1101 if (addr_family_ != AF_INET6)
1102 return ERR_ADDRESS_INVALID;
1103 ipv6_mreq mreq;
1104 mreq.ipv6mr_interface = multicast_interface_;
1105 memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1106 IPAddress::kIPv6AddressSize);
1107 int rv = setsockopt(socket_, IPPROTO_IPV6, IP_DROP_MEMBERSHIP,
1108 reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1109 if (rv)
1110 return MapSystemError(WSAGetLastError());
1111 return OK;
1112 }
1113 default:
1114 NOTREACHED() << "Invalid address family";
1115 return ERR_ADDRESS_INVALID;
1116 }
1117 }
1118
SetMulticastInterface(uint32_t interface_index)1119 int UDPSocketWin::SetMulticastInterface(uint32_t interface_index) {
1120 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1121 if (is_connected())
1122 return ERR_SOCKET_IS_CONNECTED;
1123 multicast_interface_ = interface_index;
1124 return OK;
1125 }
1126
SetMulticastTimeToLive(int time_to_live)1127 int UDPSocketWin::SetMulticastTimeToLive(int time_to_live) {
1128 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1129 if (is_connected())
1130 return ERR_SOCKET_IS_CONNECTED;
1131
1132 if (time_to_live < 0 || time_to_live > 255)
1133 return ERR_INVALID_ARGUMENT;
1134 multicast_time_to_live_ = time_to_live;
1135 return OK;
1136 }
1137
SetMulticastLoopbackMode(bool loopback)1138 int UDPSocketWin::SetMulticastLoopbackMode(bool loopback) {
1139 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1140 if (is_connected())
1141 return ERR_SOCKET_IS_CONNECTED;
1142
1143 if (loopback)
1144 socket_options_ |= SOCKET_OPTION_MULTICAST_LOOP;
1145 else
1146 socket_options_ &= ~SOCKET_OPTION_MULTICAST_LOOP;
1147 return OK;
1148 }
1149
DscpToTrafficType(DiffServCodePoint dscp)1150 QOS_TRAFFIC_TYPE DscpToTrafficType(DiffServCodePoint dscp) {
1151 QOS_TRAFFIC_TYPE traffic_type = QOSTrafficTypeBestEffort;
1152 switch (dscp) {
1153 case DSCP_CS0:
1154 traffic_type = QOSTrafficTypeBestEffort;
1155 break;
1156 case DSCP_CS1:
1157 traffic_type = QOSTrafficTypeBackground;
1158 break;
1159 case DSCP_AF11:
1160 case DSCP_AF12:
1161 case DSCP_AF13:
1162 case DSCP_CS2:
1163 case DSCP_AF21:
1164 case DSCP_AF22:
1165 case DSCP_AF23:
1166 case DSCP_CS3:
1167 case DSCP_AF31:
1168 case DSCP_AF32:
1169 case DSCP_AF33:
1170 case DSCP_CS4:
1171 traffic_type = QOSTrafficTypeExcellentEffort;
1172 break;
1173 case DSCP_AF41:
1174 case DSCP_AF42:
1175 case DSCP_AF43:
1176 case DSCP_CS5:
1177 traffic_type = QOSTrafficTypeAudioVideo;
1178 break;
1179 case DSCP_EF:
1180 case DSCP_CS6:
1181 traffic_type = QOSTrafficTypeVoice;
1182 break;
1183 case DSCP_CS7:
1184 traffic_type = QOSTrafficTypeControl;
1185 break;
1186 case DSCP_NO_CHANGE:
1187 NOTREACHED();
1188 break;
1189 }
1190 return traffic_type;
1191 }
1192
SetDiffServCodePoint(DiffServCodePoint dscp)1193 int UDPSocketWin::SetDiffServCodePoint(DiffServCodePoint dscp) {
1194 if (dscp == DSCP_NO_CHANGE)
1195 return OK;
1196
1197 if (!is_connected())
1198 return ERR_SOCKET_NOT_CONNECTED;
1199
1200 QwaveApi* api = GetQwaveApi();
1201
1202 if (!api->qwave_supported())
1203 return ERR_NOT_IMPLEMENTED;
1204
1205 if (!dscp_manager_)
1206 dscp_manager_ = std::make_unique<DscpManager>(api, socket_);
1207
1208 dscp_manager_->Set(dscp);
1209 if (remote_address_)
1210 return dscp_manager_->PrepareForSend(*remote_address_.get());
1211
1212 return OK;
1213 }
1214
SetIPv6Only(bool ipv6_only)1215 int UDPSocketWin::SetIPv6Only(bool ipv6_only) {
1216 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1217 if (is_connected()) {
1218 return ERR_SOCKET_IS_CONNECTED;
1219 }
1220 return net::SetIPv6Only(socket_, ipv6_only);
1221 }
1222
// Detaches |thread_checker_| from its current thread so that the socket can
// be handed to another thread; the next thread-checked call binds the checker
// to the calling thread.
void UDPSocketWin::DetachFromThread() {
  DETACH_FROM_THREAD(thread_checker_);
}
1226
// Sets |use_non_blocking_io_|. Must be called before |core_| (the overlapped
// I/O state) has been created, as enforced by the DCHECK.
// NOTE(review): presumably selects the Internal*NonBlocking helpers over the
// overlapped path -- confirm at the call sites that read the flag.
void UDPSocketWin::UseNonBlockingIO() {
  DCHECK(!core_);
  use_non_blocking_io_ = true;
}
1231
// Applies |tag| to the socket. Only a default-constructed SocketTag is
// accepted; anything else CHECK-fails.
void UDPSocketWin::ApplySocketTag(const SocketTag& tag) {
  // Windows does not support any specific SocketTags so fail if any non-default
  // tag is applied.
  CHECK(tag == SocketTag());
}
1237
// Creates a DSCP manager for |socket| that issues qWAVE calls through |api|,
// and immediately starts asynchronous creation of the QoS handle.
DscpManager::DscpManager(QwaveApi* api, SOCKET socket)
    : api_(api), socket_(socket) {
  RequestHandle();
}
1242
~DscpManager()1243 DscpManager::~DscpManager() {
1244 if (!qos_handle_)
1245 return;
1246
1247 if (flow_id_ != 0)
1248 api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1249
1250 api_->CloseHandle(qos_handle_);
1251 }
1252
Set(DiffServCodePoint dscp)1253 void DscpManager::Set(DiffServCodePoint dscp) {
1254 if (dscp == DSCP_NO_CHANGE || dscp == dscp_value_)
1255 return;
1256
1257 dscp_value_ = dscp;
1258
1259 // TODO(zstein): We could reuse the flow when the value changes
1260 // by calling QOSSetFlow with the new traffic type and dscp value.
1261 if (flow_id_ != 0 && qos_handle_) {
1262 api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1263 configured_.clear();
1264 flow_id_ = 0;
1265 }
1266 }
1267
PrepareForSend(const IPEndPoint & remote_address)1268 int DscpManager::PrepareForSend(const IPEndPoint& remote_address) {
1269 if (dscp_value_ == DSCP_NO_CHANGE) {
1270 // No DSCP value has been set.
1271 return OK;
1272 }
1273
1274 if (!api_->qwave_supported())
1275 return ERR_NOT_IMPLEMENTED;
1276
1277 if (!qos_handle_)
1278 return ERR_INVALID_HANDLE; // The closest net error to try again later.
1279
1280 if (configured_.find(remote_address) != configured_.end())
1281 return OK;
1282
1283 SockaddrStorage storage;
1284 if (!remote_address.ToSockAddr(storage.addr, &storage.addr_len))
1285 return ERR_ADDRESS_INVALID;
1286
1287 // We won't try this address again if we get an error.
1288 configured_.emplace(remote_address);
1289
1290 // We don't need to call SetFlow if we already have a qos flow.
1291 bool new_flow = flow_id_ == 0;
1292
1293 const QOS_TRAFFIC_TYPE traffic_type = DscpToTrafficType(dscp_value_);
1294
1295 if (!api_->AddSocketToFlow(qos_handle_, socket_, storage.addr, traffic_type,
1296 QOS_NON_ADAPTIVE_FLOW, &flow_id_)) {
1297 DWORD err = ::GetLastError();
1298 if (err == ERROR_DEVICE_REINITIALIZATION_NEEDED) {
1299 // Reset. PrepareForSend is called for every packet. Once RequestHandle
1300 // completes asynchronously the next PrepareForSend call will re-register
1301 // the address with the new QoS Handle. In the meantime, sends will
1302 // continue without DSCP.
1303 RequestHandle();
1304 configured_.clear();
1305 flow_id_ = 0;
1306 return ERR_INVALID_HANDLE;
1307 }
1308 return MapSystemError(err);
1309 }
1310
1311 if (new_flow) {
1312 DWORD buf = dscp_value_;
1313 // This requires admin rights, and may fail, if so we ignore it
1314 // as AddSocketToFlow should still do *approximately* the right thing.
1315 api_->SetFlow(qos_handle_, flow_id_, QOSSetOutgoingDSCPValue, sizeof(buf),
1316 &buf, 0, nullptr);
1317 }
1318
1319 return OK;
1320 }
1321
RequestHandle()1322 void DscpManager::RequestHandle() {
1323 if (handle_is_initializing_)
1324 return;
1325
1326 if (qos_handle_) {
1327 api_->CloseHandle(qos_handle_);
1328 qos_handle_ = nullptr;
1329 }
1330
1331 handle_is_initializing_ = true;
1332 base::ThreadPool::PostTaskAndReplyWithResult(
1333 FROM_HERE, {base::MayBlock()},
1334 base::BindOnce(&DscpManager::DoCreateHandle, api_),
1335 base::BindOnce(&DscpManager::OnHandleCreated, api_,
1336 weak_ptr_factory_.GetWeakPtr()));
1337 }
1338
DoCreateHandle(QwaveApi * api)1339 HANDLE DscpManager::DoCreateHandle(QwaveApi* api) {
1340 QOS_VERSION version;
1341 version.MajorVersion = 1;
1342 version.MinorVersion = 0;
1343
1344 HANDLE handle = nullptr;
1345
1346 // No access to net_log_ so swallow any errors here.
1347 api->CreateHandle(&version, &handle);
1348 return handle;
1349 }
1350
OnHandleCreated(QwaveApi * api,base::WeakPtr<DscpManager> dscp_manager,HANDLE handle)1351 void DscpManager::OnHandleCreated(QwaveApi* api,
1352 base::WeakPtr<DscpManager> dscp_manager,
1353 HANDLE handle) {
1354 if (!handle)
1355 api->OnFatalError();
1356
1357 if (!dscp_manager) {
1358 api->CloseHandle(handle);
1359 return;
1360 }
1361
1362 DCHECK(dscp_manager->handle_is_initializing_);
1363 DCHECK(!dscp_manager->qos_handle_);
1364
1365 dscp_manager->qos_handle_ = handle;
1366 dscp_manager->handle_is_initializing_ = false;
1367 }
1368
1369 } // namespace net
1370