• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_SOCKET_UDP_SOCKET_WIN_H_
6 #define NET_SOCKET_UDP_SOCKET_WIN_H_
7 
8 #include <qos2.h>
9 #include <stdint.h>
10 #include <winsock2.h>
11 
12 #include <atomic>
13 #include <memory>
14 #include <set>
15 
16 #include "base/gtest_prod_util.h"
17 #include "base/memory/raw_ptr.h"
18 #include "base/memory/scoped_refptr.h"
19 #include "base/memory/weak_ptr.h"
20 #include "base/threading/thread_checker.h"
21 #include "base/win/object_watcher.h"
22 #include "base/win/scoped_handle.h"
23 #include "net/base/address_family.h"
24 #include "net/base/completion_once_callback.h"
25 #include "net/base/io_buffer.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_export.h"
28 #include "net/base/network_handle.h"
29 #include "net/log/net_log_with_source.h"
30 #include "net/socket/datagram_socket.h"
31 #include "net/socket/diff_serv_code_point.h"
32 #include "net/socket/udp_socket_global_limits.h"
33 #include "net/traffic_annotation/network_traffic_annotation.h"
34 
35 namespace net {
36 
37 class IPAddress;
38 class NetLog;
39 struct NetLogSource;
40 class SocketTag;
41 
42 // QWAVE (Quality Windows Audio/Video Experience) is the latest windows
43 // library for setting packet priorities (and other things). Unfortunately,
44 // Microsoft has decided that setting the DSCP bits with setsockopt() no
45 // longer works, so we have to use this API instead.
46 // This class is meant to be used as a singleton. It exposes a few dynamically
47 // loaded functions and a bool called "qwave_supported".
48 class NET_EXPORT QwaveApi {
49   typedef BOOL(WINAPI* CreateHandleFn)(PQOS_VERSION, PHANDLE);
50   typedef BOOL(WINAPI* CloseHandleFn)(HANDLE);
51   typedef BOOL(WINAPI* AddSocketToFlowFn)(HANDLE,
52                                           SOCKET,
53                                           PSOCKADDR,
54                                           QOS_TRAFFIC_TYPE,
55                                           DWORD,
56                                           PQOS_FLOWID);
57   typedef BOOL(WINAPI* RemoveSocketFromFlowFn)(HANDLE,
58                                                SOCKET,
59                                                QOS_FLOWID,
60                                                DWORD);
61   typedef BOOL(WINAPI* SetFlowFn)(HANDLE,
62                                   QOS_FLOWID,
63                                   QOS_SET_FLOW,
64                                   ULONG,
65                                   PVOID,
66                                   DWORD,
67                                   LPOVERLAPPED);
68 
69  public:
70   QwaveApi();
71 
72   QwaveApi(const QwaveApi&) = delete;
73   QwaveApi& operator=(const QwaveApi&) = delete;
74 
75   static QwaveApi* GetDefault();
76 
77   virtual bool qwave_supported() const;
78   virtual void OnFatalError();
79 
80   virtual BOOL CreateHandle(PQOS_VERSION version, PHANDLE handle);
81   virtual BOOL CloseHandle(HANDLE handle);
82   virtual BOOL AddSocketToFlow(HANDLE handle,
83                                SOCKET socket,
84                                PSOCKADDR addr,
85                                QOS_TRAFFIC_TYPE traffic_type,
86                                DWORD flags,
87                                PQOS_FLOWID flow_id);
88   virtual BOOL RemoveSocketFromFlow(HANDLE handle,
89                                     SOCKET socket,
90                                     QOS_FLOWID flow_id,
91                                     DWORD reserved);
92   virtual BOOL SetFlow(HANDLE handle,
93                        QOS_FLOWID flow_id,
94                        QOS_SET_FLOW op,
95                        ULONG size,
96                        PVOID data,
97                        DWORD reserved,
98                        LPOVERLAPPED overlapped);
99 
100  private:
101   std::atomic<bool> qwave_supported_{false};
102 
103   CreateHandleFn create_handle_func_;
104   CloseHandleFn close_handle_func_;
105   AddSocketToFlowFn add_socket_to_flow_func_;
106   RemoveSocketFromFlowFn remove_socket_from_flow_func_;
107   SetFlowFn set_flow_func_;
108 };
109 
110 //-----------------------------------------------------------------------------
111 
112 // Helper for maintaining the state that (unlike a blanket socket option), DSCP
113 // values are set per-remote endpoint instead of just per-socket on Windows.
114 // The implementation creates a single QWAVE 'flow' for the socket, and adds
115 // all encountered remote addresses to that flow.  Flows are the minimum
116 // manageable unit within the QWAVE API.  See
117 // https://docs.microsoft.com/en-us/previous-versions/windows/desktop/api/qos2/
118 // for Microsoft's documentation.
119 class NET_EXPORT DscpManager {
120  public:
121   DscpManager(QwaveApi* api, SOCKET socket);
122 
123   DscpManager(const DscpManager&) = delete;
124   DscpManager& operator=(const DscpManager&) = delete;
125 
126   ~DscpManager();
127 
128   // Remembers the latest |dscp| so PrepareToSend can add remote addresses to
129   // the qos flow. Destroys the old flow if it exists and |dscp| changes.
130   void Set(DiffServCodePoint dscp);
131 
132   // Constructs a qos flow for the latest set DSCP value if we don't already
133   // have one. Adds |remote_address| to the qos flow if it hasn't been added
134   // already. Does nothing if no DSCP value has been Set.
135   int PrepareForSend(const IPEndPoint& remote_address);
136 
137  private:
138   void RequestHandle();
139   static HANDLE DoCreateHandle(QwaveApi* api);
140   static void OnHandleCreated(QwaveApi* api,
141                               base::WeakPtr<DscpManager> dscp_manager,
142                               HANDLE handle);
143 
144   const raw_ptr<QwaveApi> api_;
145   const SOCKET socket_;
146 
147   DiffServCodePoint dscp_value_ = DSCP_NO_CHANGE;
148   // The remote addresses currently in the flow.
149   std::set<IPEndPoint> configured_;
150 
151   HANDLE qos_handle_ = nullptr;
152   bool handle_is_initializing_ = false;
153   // 0 means no flow has been constructed.
154   QOS_FLOWID flow_id_ = 0;
155   base::WeakPtrFactory<DscpManager> weak_ptr_factory_{this};
156 };
157 
158 //-----------------------------------------------------------------------------
159 
160 class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate {
161  public:
162   // BindType is ignored. Windows has an option to do random binds, so
163   // UDPSocketWin sets that whenever connecting a socket.
164   UDPSocketWin(DatagramSocket::BindType bind_type,
165                net::NetLog* net_log,
166                const net::NetLogSource& source);
167 
168   UDPSocketWin(const UDPSocketWin&) = delete;
169   UDPSocketWin& operator=(const UDPSocketWin&) = delete;
170 
171   ~UDPSocketWin() override;
172 
173   // Opens the socket.
174   // Returns a net error code.
175   int Open(AddressFamily address_family);
176 
177   // Not implemented. Returns ERR_NOT_IMPLEMENTED.
178   int BindToNetwork(handles::NetworkHandle network);
179 
180   // Connects the socket to connect with a certain |address|.
181   // Should be called after Open().
182   // Returns a net error code.
183   int Connect(const IPEndPoint& address);
184 
185   // Binds the address/port for this socket to |address|.  This is generally
186   // only used on a server. Should be called after Open().
187   // Returns a net error code.
188   int Bind(const IPEndPoint& address);
189 
190   // Closes the socket.
191   void Close();
192 
193   // Copies the remote udp address into |address| and returns a net error code.
194   int GetPeerAddress(IPEndPoint* address) const;
195 
196   // Copies the local udp address into |address| and returns a net error code.
197   // (similar to getsockname)
198   int GetLocalAddress(IPEndPoint* address) const;
199 
200   // IO:
201   // Multiple outstanding read requests are not supported.
202   // Full duplex mode (reading and writing at the same time) is supported
203 
204   // Reads from the socket.
205   // Only usable from the client-side of a UDP socket, after the socket
206   // has been connected.
207   int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback);
208 
209   // Writes to the socket.
210   // Only usable from the client-side of a UDP socket, after the socket
211   // has been connected.
212   int Write(IOBuffer* buf,
213             int buf_len,
214             CompletionOnceCallback callback,
215             const NetworkTrafficAnnotationTag& traffic_annotation);
216 
217   // Reads from a socket and receive sender address information.
218   // |buf| is the buffer to read data into.
219   // |buf_len| is the maximum amount of data to read.
220   // |address| is a buffer provided by the caller for receiving the sender
221   //   address information about the received data.  This buffer must be kept
222   //   alive by the caller until the callback is placed.
223   // |callback| is the callback on completion of the RecvFrom.
224   // Returns a net error code, or ERR_IO_PENDING if the IO is in progress.
225   // If ERR_IO_PENDING is returned, this socket takes a ref to |buf| to keep
226   // it alive until the data is received. However, the caller must keep
227   // |address| alive until the callback is called.
228   int RecvFrom(IOBuffer* buf,
229                int buf_len,
230                IPEndPoint* address,
231                CompletionOnceCallback callback);
232 
233   // Sends to a socket with a particular destination.
234   // |buf| is the buffer to send.
235   // |buf_len| is the number of bytes to send.
236   // |address| is the recipient address.
237   // |callback| is the user callback function to call on complete.
238   // Returns a net error code, or ERR_IO_PENDING if the IO is in progress.
239   // If ERR_IO_PENDING is returned, this socket copies |address| for
240   // asynchronous sending, and takes a ref to |buf| to keep it alive until the
241   // data is sent.
242   int SendTo(IOBuffer* buf,
243              int buf_len,
244              const IPEndPoint& address,
245              CompletionOnceCallback callback);
246 
247   // Sets the receive buffer size (in bytes) for the socket.
248   // Returns a net error code.
249   int SetReceiveBufferSize(int32_t size);
250 
251   // Sets the send buffer size (in bytes) for the socket.
252   // Returns a net error code.
253   int SetSendBufferSize(int32_t size);
254 
255   // Requests that packets sent by this socket not be fragment, either locally
256   // by the host, or by routers (via the DF bit in the IPv4 packet header).
257   // May not be supported by all platforms. Returns a network error code if
258   // there was a problem, but the socket will still be usable. Can not
259   // return ERR_IO_PENDING.
260   int SetDoNotFragment();
261 
262   // This is a no-op on Windows.
263   void SetMsgConfirm(bool confirm);
264 
265   // Returns true if the socket is already connected or bound.
is_connected()266   bool is_connected() const { return is_connected_; }
267 
NetLog()268   const NetLogWithSource& NetLog() const { return net_log_; }
269 
270   // Sets socket options to allow the socket to share the local address to which
271   // the socket will be bound with other processes. If multiple processes are
272   // bound to the same local address at the same time, behavior is undefined;
273   // e.g., it is not guaranteed that incoming  messages will be sent to all
274   // listening sockets. Returns a net error code.
275   //
276   // Should be called between Open() and Bind().
277   int AllowAddressReuse();
278 
279   // Sets socket options to allow sending and receiving packets to and from
280   // broadcast addresses.
281   int SetBroadcast(bool broadcast);
282 
283   // Sets socket options to allow the socket to share the local address to which
284   // the socket will be bound with other processes and attempt to allow all such
285   // sockets to receive the same multicast messages. Returns a net error code.
286   //
287   // For Windows, multicast messages should always be shared between sockets
288   // configured thusly as long as the sockets join the same multicast group and
289   // interface.
290   //
291   // Should be called between Open() and Bind().
292   int AllowAddressSharingForMulticast();
293 
294   // Joins the multicast group.
295   // |group_address| is the group address to join, could be either
296   // an IPv4 or IPv6 address.
297   // Returns a net error code.
298   int JoinGroup(const IPAddress& group_address) const;
299 
300   // Leaves the multicast group.
301   // |group_address| is the group address to leave, could be either
302   // an IPv4 or IPv6 address. If the socket hasn't joined the group,
303   // it will be ignored.
304   // It's optional to leave the multicast group before destroying
305   // the socket. It will be done by the OS.
306   // Return a net error code.
307   int LeaveGroup(const IPAddress& group_address) const;
308 
309   // Sets interface to use for multicast. If |interface_index| set to 0,
310   // default interface is used.
311   // Should be called before Bind().
312   // Returns a net error code.
313   int SetMulticastInterface(uint32_t interface_index);
314 
315   // Sets the time-to-live option for UDP packets sent to the multicast
316   // group address. The default value of this option is 1.
317   // Cannot be negative or more than 255.
318   // Should be called before Bind().
319   int SetMulticastTimeToLive(int time_to_live);
320 
321   // Sets the loopback flag for UDP socket. If this flag is true, the host
322   // will receive packets sent to the joined group from itself.
323   // The default value of this option is true.
324   // Should be called before Bind().
325   //
326   // Note: the behavior of |SetMulticastLoopbackMode| is slightly
327   // different between Windows and Unix-like systems. The inconsistency only
328   // happens when there are more than one applications on the same host
329   // joined to the same multicast group while having different settings on
330   // multicast loopback mode. On Windows, the applications with loopback off
331   // will not RECEIVE the loopback packets; while on Unix-like systems, the
332   // applications with loopback off will not SEND the loopback packets to
333   // other applications on the same host. See MSDN: http://goo.gl/6vqbj
334   int SetMulticastLoopbackMode(bool loopback);
335 
336   // Sets the differentiated services flags on outgoing packets. May not do
337   // anything on some platforms. A return value of ERR_INVALID_HANDLE indicates
338   // the value was not set but could succeed on a future call, because
339   // initialization is in progress.
340   int SetDiffServCodePoint(DiffServCodePoint dscp);
341 
342   // Sets IPV6_V6ONLY on the socket. If this flag is true, the socket will be
343   // restricted to only IPv6; false allows both IPv4 and IPv6 traffic.
344   int SetIPv6Only(bool ipv6_only);
345 
346   // Resets the thread to be used for thread-safety checks.
347   void DetachFromThread();
348 
349   // This class by default uses overlapped IO. Call this method before Open()
350   // to switch to non-blocking IO.
351   void UseNonBlockingIO();
352 
353   // Apply |tag| to this socket.
354   void ApplySocketTag(const SocketTag& tag);
355 
356   // Takes ownership of `socket`, which should be a socket descriptor opened
357   // with the specified address family. The socket should only be created but
358   // not bound or connected to an address.
359   int AdoptOpenedSocket(AddressFamily address_family, SOCKET socket);
360 
361  private:
362   enum SocketOptions {
363     SOCKET_OPTION_MULTICAST_LOOP = 1 << 0
364   };
365 
366   class Core;
367 
368   void DoReadCallback(int rv);
369   void DoWriteCallback(int rv);
370 
371   void DidCompleteRead();
372   void DidCompleteWrite();
373 
374   // base::ObjectWatcher::Delegate implementation.
375   void OnObjectSignaled(HANDLE object) override;
376   void OnReadSignaled();
377   void OnWriteSignaled();
378 
379   void WatchForReadWrite();
380 
381   // Handles stats and logging. |result| is the number of bytes transferred, on
382   // success, or the net error code on failure.
383   void LogRead(int result, const char* bytes, const IPEndPoint* address) const;
384   void LogWrite(int result, const char* bytes, const IPEndPoint* address) const;
385 
386   // Same as SendTo(), except that address is passed by pointer
387   // instead of by reference. It is called from Write() with |address|
388   // set to NULL.
389   int SendToOrWrite(IOBuffer* buf,
390                     int buf_len,
391                     const IPEndPoint* address,
392                     CompletionOnceCallback callback);
393 
394   int InternalConnect(const IPEndPoint& address);
395 
396   // Version for using overlapped IO.
397   int InternalRecvFromOverlapped(IOBuffer* buf,
398                                  int buf_len,
399                                  IPEndPoint* address);
400   int InternalSendToOverlapped(IOBuffer* buf,
401                                int buf_len,
402                                const IPEndPoint* address);
403 
404   // Version for using non-blocking IO.
405   int InternalRecvFromNonBlocking(IOBuffer* buf,
406                                   int buf_len,
407                                   IPEndPoint* address);
408   int InternalSendToNonBlocking(IOBuffer* buf,
409                                 int buf_len,
410                                 const IPEndPoint* address);
411 
412   // Applies |socket_options_| to |socket_|. Should be called before
413   // Bind().
414   int SetMulticastOptions();
415   int DoBind(const IPEndPoint& address);
416 
417   // Configures opened `socket_` depending on whether it uses nonblocking IO.
418   void ConfigureOpenedSocket();
419 
420   // This is provided to allow QwaveApi mocking in tests. |UDPSocketWin| method
421   // implementations should call |GetQwaveApi()| instead of
422   // |QwaveApi::GetDefault()| directly.
423   virtual QwaveApi* GetQwaveApi() const;
424 
425   SOCKET socket_;
426   int addr_family_ = 0;
427   bool is_connected_ = false;
428 
429   // Bitwise-or'd combination of SocketOptions. Specifies the set of
430   // options that should be applied to |socket_| before Bind().
431   int socket_options_;
432 
433   // Multicast interface.
434   uint32_t multicast_interface_ = 0;
435 
436   // Multicast socket options cached for SetMulticastOption.
437   // Cannot be used after Bind().
438   int multicast_time_to_live_ = 1;
439 
440   // These are mutable since they're just cached copies to make
441   // GetPeerAddress/GetLocalAddress smarter.
442   mutable std::unique_ptr<IPEndPoint> local_address_;
443   mutable std::unique_ptr<IPEndPoint> remote_address_;
444 
445   // The core of the socket that can live longer than the socket itself. We pass
446   // resources to the Windows async IO functions and we have to make sure that
447   // they are not destroyed while the OS still references them.
448   scoped_refptr<Core> core_;
449 
450   // True if non-blocking IO is used.
451   bool use_non_blocking_io_ = false;
452 
453   // Watches |read_write_event_|.
454   base::win::ObjectWatcher read_write_watcher_;
455 
456   // Events for read and write.
457   base::win::ScopedHandle read_write_event_;
458 
459   // The buffers used in Read() and Write().
460   scoped_refptr<IOBuffer> read_iobuffer_;
461   scoped_refptr<IOBuffer> write_iobuffer_;
462 
463   int read_iobuffer_len_ = 0;
464   int write_iobuffer_len_ = 0;
465 
466   raw_ptr<IPEndPoint> recv_from_address_ = nullptr;
467 
468   // Cached copy of the current address we're sending to, if any.  Used for
469   // logging.
470   std::unique_ptr<IPEndPoint> send_to_address_;
471 
472   // External callback; called when read is complete.
473   CompletionOnceCallback read_callback_;
474 
475   // External callback; called when write is complete.
476   CompletionOnceCallback write_callback_;
477 
478   NetLogWithSource net_log_;
479 
480   // Maintains remote addresses for QWAVE qos management.
481   std::unique_ptr<DscpManager> dscp_manager_;
482 
483   // Manages decrementing the global open UDP socket counter when this
484   // UDPSocket is destroyed.
485   OwnedUDPSocketCount owned_socket_count_;
486 
487   THREAD_CHECKER(thread_checker_);
488 
489   // Used to prevent null dereferences in OnObjectSignaled, when passing an
490   // error to both read and write callbacks. Cleared in Close()
491   base::WeakPtrFactory<UDPSocketWin> event_pending_{this};
492 };
493 
494 //-----------------------------------------------------------------------------
495 
496 
497 
498 }  // namespace net
499 
500 #endif  // NET_SOCKET_UDP_SOCKET_WIN_H_
501