1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::mem::MaybeUninit;
15 use std::net::SocketAddr;
16 use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
17 use std::time::Duration;
18 use std::{io, mem, net};
19
20 use libc::{c_int, getsockopt};
21 use windows_sys::Win32::Networking::WinSock::{
22 self, closesocket, ioctlsocket, setsockopt, socket, ADDRESS_FAMILY, AF_INET, AF_INET6, FIONBIO,
23 INVALID_SOCKET, LINGER, SOCKET, SOCKET_ERROR, SOCK_STREAM, SOL_SOCKET, SO_LINGER,
24 };
25
26 use crate::sys::windows::net::init;
27 use crate::sys::windows::socket_addr::socket_addr_trans;
28
/// Owned WinSock TCP socket handle.
///
/// Dropping a `TcpSocket` closes the wrapped handle (its `Drop` impl calls
/// `close`). `listen` and `connect` consume the value and release ownership of
/// the handle on success, so the handle stays open afterwards.
pub(crate) struct TcpSocket {
    // Raw WinSock handle owned by this value.
    socket: SOCKET,
}
32
33 impl TcpSocket {
34 /// Gets new socket
new_socket(addr: SocketAddr) -> io::Result<TcpSocket>35 pub(crate) fn new_socket(addr: SocketAddr) -> io::Result<TcpSocket> {
36 if addr.is_ipv4() {
37 Self::create_socket(AF_INET, SOCK_STREAM)
38 } else {
39 Self::create_socket(AF_INET6, SOCK_STREAM)
40 }
41 }
42
create_socket(domain: ADDRESS_FAMILY, socket_type: u16) -> io::Result<TcpSocket>43 fn create_socket(domain: ADDRESS_FAMILY, socket_type: u16) -> io::Result<TcpSocket> {
44 init();
45
46 let socket = socket_syscall!(
47 socket(domain as i32, socket_type as i32, 0),
48 PartialEq::eq,
49 INVALID_SOCKET
50 )?;
51
52 match socket_syscall!(ioctlsocket(socket, FIONBIO, &mut 1), PartialEq::ne, 0) {
53 Err(err) => {
54 let _ = unsafe { closesocket(socket) };
55 Err(err)
56 }
57 Ok(_) => Ok(TcpSocket {
58 socket: socket as SOCKET,
59 }),
60 }
61 }
62
63 /// System call to bind Socket.
bind(&self, addr: SocketAddr) -> io::Result<()>64 pub(crate) fn bind(&self, addr: SocketAddr) -> io::Result<()> {
65 use WinSock::bind;
66
67 let (raw_addr, raw_addr_length) = socket_addr_trans(&addr);
68 socket_syscall!(
69 bind(self.socket as _, raw_addr.as_ptr(), raw_addr_length),
70 PartialEq::eq,
71 SOCKET_ERROR
72 )?;
73 Ok(())
74 }
75
76 /// System call to listen.
listen(self, backlog: u32) -> io::Result<()>77 pub(crate) fn listen(self, backlog: u32) -> io::Result<()> {
78 use std::convert::TryInto;
79
80 use WinSock::listen;
81
82 let backlog = backlog.try_into().unwrap_or(i32::MAX);
83 socket_syscall!(
84 listen(self.socket as _, backlog),
85 PartialEq::eq,
86 SOCKET_ERROR
87 )?;
88 mem::forget(self);
89 Ok(())
90 }
91
92 /// System call to connect.
connect(self, addr: SocketAddr) -> io::Result<()>93 pub(crate) fn connect(self, addr: SocketAddr) -> io::Result<()> {
94 use WinSock::connect;
95
96 let (socket_addr, socket_addr_length) = socket_addr_trans(&addr);
97 let res = socket_syscall!(
98 connect(self.socket as _, socket_addr.as_ptr(), socket_addr_length),
99 PartialEq::eq,
100 SOCKET_ERROR
101 );
102
103 match res {
104 Err(e) if e.kind() != io::ErrorKind::WouldBlock => Err(e),
105 _ => {
106 mem::forget(self);
107 Ok(())
108 }
109 }
110 }
111
112 /// Closes Socket
close(&self)113 pub(crate) fn close(&self) {
114 let _ = unsafe { net::TcpStream::from_raw_socket(self.socket as RawSocket) };
115 }
116 }
117
118 impl AsRawSocket for TcpSocket {
as_raw_socket(&self) -> RawSocket119 fn as_raw_socket(&self) -> RawSocket {
120 self.socket as RawSocket
121 }
122 }
123
124 impl FromRawSocket for TcpSocket {
from_raw_socket(sock: RawSocket) -> Self125 unsafe fn from_raw_socket(sock: RawSocket) -> Self {
126 TcpSocket {
127 socket: sock as SOCKET,
128 }
129 }
130 }
131
132 impl Drop for TcpSocket {
drop(&mut self)133 fn drop(&mut self) {
134 self.close();
135 }
136 }
137
get_sock_linger(socket: RawSocket) -> io::Result<Option<Duration>>138 pub(crate) fn get_sock_linger(socket: RawSocket) -> io::Result<Option<Duration>> {
139 let mut optval: MaybeUninit<LINGER> = MaybeUninit::uninit();
140 let mut optlen = mem::size_of::<LINGER>() as c_int;
141
142 socket_syscall!(
143 getsockopt(
144 socket as SOCKET,
145 SOL_SOCKET as c_int,
146 SO_LINGER as c_int,
147 optval.as_mut_ptr().cast(),
148 &mut optlen,
149 ),
150 PartialEq::eq,
151 SOCKET_ERROR
152 )
153 .map(|_| {
154 let linger = unsafe { optval.assume_init() };
155 from_linger(linger)
156 })
157 }
158
set_sock_linger(socket: RawSocket, linger: Option<Duration>) -> io::Result<()>159 pub(crate) fn set_sock_linger(socket: RawSocket, linger: Option<Duration>) -> io::Result<()> {
160 let optval = into_linger(linger);
161 socket_syscall!(
162 setsockopt(
163 socket as SOCKET,
164 SOL_SOCKET as c_int,
165 SO_LINGER as c_int,
166 (&optval as *const LINGER).cast(),
167 mem::size_of::<LINGER>() as c_int,
168 ),
169 PartialEq::eq,
170 SOCKET_ERROR
171 )
172 .map(|_| ())
173 }
174
from_linger(linger: LINGER) -> Option<Duration>175 fn from_linger(linger: LINGER) -> Option<Duration> {
176 if linger.l_onoff == 0 {
177 None
178 } else {
179 Some(Duration::from_secs(linger.l_linger as u64))
180 }
181 }
182
into_linger(linger: Option<Duration>) -> LINGER183 fn into_linger(linger: Option<Duration>) -> LINGER {
184 match linger {
185 None => LINGER {
186 l_onoff: 0,
187 l_linger: 0,
188 },
189 Some(dur) => LINGER {
190 l_onoff: 1,
191 l_linger: dur.as_secs() as _,
192 },
193 }
194 }
195