• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_SOCKET_UDP_SOCKET_WIN_H_
6 #define NET_SOCKET_UDP_SOCKET_WIN_H_
7 
8 #include <qos2.h>
9 #include <stdint.h>
10 #include <winsock2.h>
11 
12 #include <atomic>
13 #include <memory>
14 #include <set>
15 
16 #include "base/gtest_prod_util.h"
17 #include "base/memory/raw_ptr.h"
18 #include "base/memory/scoped_refptr.h"
19 #include "base/memory/weak_ptr.h"
20 #include "base/threading/thread_checker.h"
21 #include "base/win/object_watcher.h"
22 #include "base/win/scoped_handle.h"
23 #include "net/base/address_family.h"
24 #include "net/base/completion_once_callback.h"
25 #include "net/base/io_buffer.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_export.h"
28 #include "net/base/network_handle.h"
29 #include "net/log/net_log_with_source.h"
30 #include "net/socket/datagram_socket.h"
31 #include "net/socket/diff_serv_code_point.h"
32 #include "net/socket/udp_socket_global_limits.h"
33 #include "net/traffic_annotation/network_traffic_annotation.h"
34 
35 namespace net {
36 
37 class IPAddress;
38 class NetLog;
39 struct NetLogSource;
40 class SocketTag;
41 
42 // QWAVE (Quality Windows Audio/Video Experience) is the latest windows
43 // library for setting packet priorities (and other things). Unfortunately,
44 // Microsoft has decided that setting the DSCP bits with setsockopt() no
45 // longer works, so we have to use this API instead.
46 // This class is meant to be used as a singleton. It exposes a few dynamically
47 // loaded functions and a bool called "qwave_supported".
48 class NET_EXPORT QwaveApi {
49   typedef BOOL(WINAPI* CreateHandleFn)(PQOS_VERSION, PHANDLE);
50   typedef BOOL(WINAPI* CloseHandleFn)(HANDLE);
51   typedef BOOL(WINAPI* AddSocketToFlowFn)(HANDLE,
52                                           SOCKET,
53                                           PSOCKADDR,
54                                           QOS_TRAFFIC_TYPE,
55                                           DWORD,
56                                           PQOS_FLOWID);
57   typedef BOOL(WINAPI* RemoveSocketFromFlowFn)(HANDLE,
58                                                SOCKET,
59                                                QOS_FLOWID,
60                                                DWORD);
61   typedef BOOL(WINAPI* SetFlowFn)(HANDLE,
62                                   QOS_FLOWID,
63                                   QOS_SET_FLOW,
64                                   ULONG,
65                                   PVOID,
66                                   DWORD,
67                                   LPOVERLAPPED);
68 
69  public:
70   QwaveApi();
71 
72   QwaveApi(const QwaveApi&) = delete;
73   QwaveApi& operator=(const QwaveApi&) = delete;
74 
75   static QwaveApi* GetDefault();
76 
77   virtual bool qwave_supported() const;
78   virtual void OnFatalError();
79 
80   virtual BOOL CreateHandle(PQOS_VERSION version, PHANDLE handle);
81   virtual BOOL CloseHandle(HANDLE handle);
82   virtual BOOL AddSocketToFlow(HANDLE handle,
83                                SOCKET socket,
84                                PSOCKADDR addr,
85                                QOS_TRAFFIC_TYPE traffic_type,
86                                DWORD flags,
87                                PQOS_FLOWID flow_id);
88   virtual BOOL RemoveSocketFromFlow(HANDLE handle,
89                                     SOCKET socket,
90                                     QOS_FLOWID flow_id,
91                                     DWORD reserved);
92   virtual BOOL SetFlow(HANDLE handle,
93                        QOS_FLOWID flow_id,
94                        QOS_SET_FLOW op,
95                        ULONG size,
96                        PVOID data,
97                        DWORD reserved,
98                        LPOVERLAPPED overlapped);
99 
100  private:
101   std::atomic<bool> qwave_supported_{false};
102 
103   CreateHandleFn create_handle_func_;
104   CloseHandleFn close_handle_func_;
105   AddSocketToFlowFn add_socket_to_flow_func_;
106   RemoveSocketFromFlowFn remove_socket_from_flow_func_;
107   SetFlowFn set_flow_func_;
108 };
109 
110 //-----------------------------------------------------------------------------
111 
112 // Helper for maintaining the state that (unlike a blanket socket option), DSCP
113 // values are set per-remote endpoint instead of just per-socket on Windows.
114 // The implementation creates a single QWAVE 'flow' for the socket, and adds
115 // all encountered remote addresses to that flow.  Flows are the minimum
116 // manageable unit within the QWAVE API.  See
117 // https://docs.microsoft.com/en-us/previous-versions/windows/desktop/api/qos2/
118 // for Microsoft's documentation.
119 class NET_EXPORT DscpManager {
120  public:
121   DscpManager(QwaveApi* api, SOCKET socket);
122 
123   DscpManager(const DscpManager&) = delete;
124   DscpManager& operator=(const DscpManager&) = delete;
125 
126   ~DscpManager();
127 
128   // Remembers the latest |dscp| so PrepareToSend can add remote addresses to
129   // the qos flow. Destroys the old flow if it exists and |dscp| changes.
130   void Set(DiffServCodePoint dscp);
131 
132   // Constructs a qos flow for the latest set DSCP value if we don't already
133   // have one. Adds |remote_address| to the qos flow if it hasn't been added
134   // already. Does nothing if no DSCP value has been Set.
135   int PrepareForSend(const IPEndPoint& remote_address);
136 
137  private:
138   void RequestHandle();
139   static HANDLE DoCreateHandle(QwaveApi* api);
140   static void OnHandleCreated(QwaveApi* api,
141                               base::WeakPtr<DscpManager> dscp_manager,
142                               HANDLE handle);
143 
144   const raw_ptr<QwaveApi> api_;
145   const SOCKET socket_;
146 
147   DiffServCodePoint dscp_value_ = DSCP_NO_CHANGE;
148   // The remote addresses currently in the flow.
149   std::set<IPEndPoint> configured_;
150 
151   HANDLE qos_handle_ = nullptr;
152   bool handle_is_initializing_ = false;
153   // 0 means no flow has been constructed.
154   QOS_FLOWID flow_id_ = 0;
155   base::WeakPtrFactory<DscpManager> weak_ptr_factory_{this};
156 };
157 
158 //-----------------------------------------------------------------------------
159 
160 class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate {
161  public:
162   // BindType is ignored. Windows has an option to do random binds, so
163   // UDPSocketWin sets that whenever connecting a socket.
164   UDPSocketWin(DatagramSocket::BindType bind_type,
165                net::NetLog* net_log,
166                const net::NetLogSource& source);
167 
168   UDPSocketWin(DatagramSocket::BindType bind_type,
169                NetLogWithSource source_net_log);
170 
171   UDPSocketWin(const UDPSocketWin&) = delete;
172   UDPSocketWin& operator=(const UDPSocketWin&) = delete;
173 
174   ~UDPSocketWin() override;
175 
176   // Opens the socket.
177   // Returns a net error code.
178   int Open(AddressFamily address_family);
179 
180   // Not implemented. Returns ERR_NOT_IMPLEMENTED.
181   int BindToNetwork(handles::NetworkHandle network);
182 
183   // Connects the socket to connect with a certain |address|.
184   // Should be called after Open().
185   // Returns a net error code.
186   int Connect(const IPEndPoint& address);
187 
188   // Binds the address/port for this socket to |address|.  This is generally
189   // only used on a server. Should be called after Open().
190   // Returns a net error code.
191   int Bind(const IPEndPoint& address);
192 
193   // Closes the socket.
194   void Close();
195 
196   // Copies the remote udp address into |address| and returns a net error code.
197   int GetPeerAddress(IPEndPoint* address) const;
198 
199   // Copies the local udp address into |address| and returns a net error code.
200   // (similar to getsockname)
201   int GetLocalAddress(IPEndPoint* address) const;
202 
203   // IO:
204   // Multiple outstanding read requests are not supported.
205   // Full duplex mode (reading and writing at the same time) is supported
206 
207   // Reads from the socket.
208   // Only usable from the client-side of a UDP socket, after the socket
209   // has been connected.
210   int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback);
211 
212   // Writes to the socket.
213   // Only usable from the client-side of a UDP socket, after the socket
214   // has been connected.
215   int Write(IOBuffer* buf,
216             int buf_len,
217             CompletionOnceCallback callback,
218             const NetworkTrafficAnnotationTag& traffic_annotation);
219 
220   // Reads from a socket and receive sender address information.
221   // |buf| is the buffer to read data into.
222   // |buf_len| is the maximum amount of data to read.
223   // |address| is a buffer provided by the caller for receiving the sender
224   //   address information about the received data.  This buffer must be kept
225   //   alive by the caller until the callback is placed.
226   // |callback| is the callback on completion of the RecvFrom.
227   // Returns a net error code, or ERR_IO_PENDING if the IO is in progress.
228   // If ERR_IO_PENDING is returned, this socket takes a ref to |buf| to keep
229   // it alive until the data is received. However, the caller must keep
230   // |address| alive until the callback is called.
231   int RecvFrom(IOBuffer* buf,
232                int buf_len,
233                IPEndPoint* address,
234                CompletionOnceCallback callback);
235 
236   // Sends to a socket with a particular destination.
237   // |buf| is the buffer to send.
238   // |buf_len| is the number of bytes to send.
239   // |address| is the recipient address.
240   // |callback| is the user callback function to call on complete.
241   // Returns a net error code, or ERR_IO_PENDING if the IO is in progress.
242   // If ERR_IO_PENDING is returned, this socket copies |address| for
243   // asynchronous sending, and takes a ref to |buf| to keep it alive until the
244   // data is sent.
245   int SendTo(IOBuffer* buf,
246              int buf_len,
247              const IPEndPoint& address,
248              CompletionOnceCallback callback);
249 
250   // Sets the receive buffer size (in bytes) for the socket.
251   // Returns a net error code.
252   int SetReceiveBufferSize(int32_t size);
253 
254   // Sets the send buffer size (in bytes) for the socket.
255   // Returns a net error code.
256   int SetSendBufferSize(int32_t size);
257 
258   // Requests that packets sent by this socket not be fragment, either locally
259   // by the host, or by routers (via the DF bit in the IPv4 packet header).
260   // May not be supported by all platforms. Returns a network error code if
261   // there was a problem, but the socket will still be usable. Can not
262   // return ERR_IO_PENDING.
263   int SetDoNotFragment();
264 
265   // Requests that packets received by this socket have the ECN bit set. Returns
266   // a network error code if there was a problem.
267   int SetRecvEcn();
268 
269   // This is a no-op on Windows.
270   void SetMsgConfirm(bool confirm);
271 
272   // Returns true if the socket is already connected or bound.
is_connected()273   bool is_connected() const { return is_connected_; }
274 
NetLog()275   const NetLogWithSource& NetLog() const { return net_log_; }
276 
277   // Sets socket options to allow the socket to share the local address to which
278   // the socket will be bound with other processes. If multiple processes are
279   // bound to the same local address at the same time, behavior is undefined;
280   // e.g., it is not guaranteed that incoming  messages will be sent to all
281   // listening sockets. Returns a net error code.
282   //
283   // Should be called between Open() and Bind().
284   int AllowAddressReuse();
285 
286   // Sets socket options to allow sending and receiving packets to and from
287   // broadcast addresses.
288   int SetBroadcast(bool broadcast);
289 
290   // Sets socket options to allow the socket to share the local address to which
291   // the socket will be bound with other processes and attempt to allow all such
292   // sockets to receive the same multicast messages. Returns a net error code.
293   //
294   // For Windows, multicast messages should always be shared between sockets
295   // configured thusly as long as the sockets join the same multicast group and
296   // interface.
297   //
298   // Should be called between Open() and Bind().
299   int AllowAddressSharingForMulticast();
300 
301   // Joins the multicast group.
302   // |group_address| is the group address to join, could be either
303   // an IPv4 or IPv6 address.
304   // Returns a net error code.
305   int JoinGroup(const IPAddress& group_address) const;
306 
307   // Leaves the multicast group.
308   // |group_address| is the group address to leave, could be either
309   // an IPv4 or IPv6 address. If the socket hasn't joined the group,
310   // it will be ignored.
311   // It's optional to leave the multicast group before destroying
312   // the socket. It will be done by the OS.
313   // Return a net error code.
314   int LeaveGroup(const IPAddress& group_address) const;
315 
316   // Sets interface to use for multicast. If |interface_index| set to 0,
317   // default interface is used.
318   // Should be called before Bind().
319   // Returns a net error code.
320   int SetMulticastInterface(uint32_t interface_index);
321 
322   // Sets the time-to-live option for UDP packets sent to the multicast
323   // group address. The default value of this option is 1.
324   // Cannot be negative or more than 255.
325   // Should be called before Bind().
326   int SetMulticastTimeToLive(int time_to_live);
327 
328   // Sets the loopback flag for UDP socket. If this flag is true, the host
329   // will receive packets sent to the joined group from itself.
330   // The default value of this option is true.
331   // Should be called before Bind().
332   //
333   // Note: the behavior of |SetMulticastLoopbackMode| is slightly
334   // different between Windows and Unix-like systems. The inconsistency only
335   // happens when there are more than one applications on the same host
336   // joined to the same multicast group while having different settings on
337   // multicast loopback mode. On Windows, the applications with loopback off
338   // will not RECEIVE the loopback packets; while on Unix-like systems, the
339   // applications with loopback off will not SEND the loopback packets to
340   // other applications on the same host. See MSDN: http://goo.gl/6vqbj
341   int SetMulticastLoopbackMode(bool loopback);
342 
343   // Sets the differentiated services flags on outgoing packets. May not do
344   // anything on some platforms. A return value of ERR_INVALID_HANDLE indicates
345   // the value was not set but could succeed on a future call, because
346   // initialization is in progress.
347   int SetDiffServCodePoint(DiffServCodePoint dscp);
348 
349   // Sets IPV6_V6ONLY on the socket. If this flag is true, the socket will be
350   // restricted to only IPv6; false allows both IPv4 and IPv6 traffic.
351   int SetIPv6Only(bool ipv6_only);
352 
353   // Resets the thread to be used for thread-safety checks.
354   void DetachFromThread();
355 
356   // This class by default uses overlapped IO. Call this method before Open() or
357   // AdoptOpenedSocket() to switch to non-blocking IO.
358   void UseNonBlockingIO();
359 
360   // Apply |tag| to this socket.
361   void ApplySocketTag(const SocketTag& tag);
362 
363   // Takes ownership of `socket`, which should be a socket descriptor opened
364   // with the specified address family. The socket should only be created but
365   // not bound or connected to an address. This method must be called after
366   // UseNonBlockingIO, otherwise the adopted socket will not have the
367   // non-blocking IO flag set.
368   int AdoptOpenedSocket(AddressFamily address_family, SOCKET socket);
369 
get_multicast_interface_for_testing()370   uint32_t get_multicast_interface_for_testing() {
371     return multicast_interface_;
372   }
get_use_non_blocking_io_for_testing()373   bool get_use_non_blocking_io_for_testing() { return use_non_blocking_io_; }
374 
375  private:
376   enum SocketOptions {
377     SOCKET_OPTION_MULTICAST_LOOP = 1 << 0
378   };
379 
380   class Core;
381 
382   void DoReadCallback(int rv);
383   void DoWriteCallback(int rv);
384 
385   void DidCompleteRead();
386   void DidCompleteWrite();
387 
388   // base::ObjectWatcher::Delegate implementation.
389   void OnObjectSignaled(HANDLE object) override;
390   void OnReadSignaled();
391   void OnWriteSignaled();
392 
393   void WatchForReadWrite();
394 
395   // Handles stats and logging. |result| is the number of bytes transferred, on
396   // success, or the net error code on failure.
397   void LogRead(int result, const char* bytes, const IPEndPoint* address) const;
398   void LogWrite(int result, const char* bytes, const IPEndPoint* address) const;
399 
400   // Same as SendTo(), except that address is passed by pointer
401   // instead of by reference. It is called from Write() with |address|
402   // set to NULL.
403   int SendToOrWrite(IOBuffer* buf,
404                     int buf_len,
405                     const IPEndPoint* address,
406                     CompletionOnceCallback callback);
407 
408   int InternalConnect(const IPEndPoint& address);
409 
410   // Version for using overlapped IO.
411   int InternalRecvFromOverlapped(IOBuffer* buf,
412                                  int buf_len,
413                                  IPEndPoint* address);
414   int InternalSendToOverlapped(IOBuffer* buf,
415                                int buf_len,
416                                const IPEndPoint* address);
417 
418   // Version for using non-blocking IO.
419   int InternalRecvFromNonBlocking(IOBuffer* buf,
420                                   int buf_len,
421                                   IPEndPoint* address);
422   int InternalSendToNonBlocking(IOBuffer* buf,
423                                 int buf_len,
424                                 const IPEndPoint* address);
425 
426   // Applies |socket_options_| to |socket_|. Should be called before
427   // Bind().
428   int SetMulticastOptions();
429   int DoBind(const IPEndPoint& address);
430 
431   // Configures opened `socket_` depending on whether it uses nonblocking IO.
432   void ConfigureOpenedSocket();
433 
434   // This is provided to allow QwaveApi mocking in tests. |UDPSocketWin| method
435   // implementations should call |GetQwaveApi()| instead of
436   // |QwaveApi::GetDefault()| directly.
437   virtual QwaveApi* GetQwaveApi() const;
438 
439   SOCKET socket_;
440   int addr_family_ = 0;
441   bool is_connected_ = false;
442 
443   // Bitwise-or'd combination of SocketOptions. Specifies the set of
444   // options that should be applied to |socket_| before Bind().
445   int socket_options_;
446 
447   // Multicast interface.
448   uint32_t multicast_interface_ = 0;
449 
450   // Multicast socket options cached for SetMulticastOption.
451   // Cannot be used after Bind().
452   int multicast_time_to_live_ = 1;
453 
454   // These are mutable since they're just cached copies to make
455   // GetPeerAddress/GetLocalAddress smarter.
456   mutable std::unique_ptr<IPEndPoint> local_address_;
457   mutable std::unique_ptr<IPEndPoint> remote_address_;
458 
459   // The core of the socket that can live longer than the socket itself. We pass
460   // resources to the Windows async IO functions and we have to make sure that
461   // they are not destroyed while the OS still references them.
462   scoped_refptr<Core> core_;
463 
464   // True if non-blocking IO is used.
465   bool use_non_blocking_io_ = false;
466 
467   // Watches |read_write_event_|.
468   base::win::ObjectWatcher read_write_watcher_;
469 
470   // Events for read and write.
471   base::win::ScopedHandle read_write_event_;
472 
473   // The buffers used in Read() and Write().
474   scoped_refptr<IOBuffer> read_iobuffer_;
475   scoped_refptr<IOBuffer> write_iobuffer_;
476 
477   int read_iobuffer_len_ = 0;
478   int write_iobuffer_len_ = 0;
479 
480   raw_ptr<IPEndPoint> recv_from_address_ = nullptr;
481 
482   // Cached copy of the current address we're sending to, if any.  Used for
483   // logging.
484   std::unique_ptr<IPEndPoint> send_to_address_;
485 
486   // External callback; called when read is complete.
487   CompletionOnceCallback read_callback_;
488 
489   // External callback; called when write is complete.
490   CompletionOnceCallback write_callback_;
491 
492   NetLogWithSource net_log_;
493 
494   // Maintains remote addresses for QWAVE qos management.
495   std::unique_ptr<DscpManager> dscp_manager_;
496 
497   // Manages decrementing the global open UDP socket counter when this
498   // UDPSocket is destroyed.
499   OwnedUDPSocketCount owned_socket_count_;
500 
501   THREAD_CHECKER(thread_checker_);
502 
503   // Used to prevent null dereferences in OnObjectSignaled, when passing an
504   // error to both read and write callbacks. Cleared in Close()
505   base::WeakPtrFactory<UDPSocketWin> event_pending_{this};
506 };
507 
508 //-----------------------------------------------------------------------------
509 
510 
511 
512 }  // namespace net
513 
514 #endif  // NET_SOCKET_UDP_SOCKET_WIN_H_
515