• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::mem::MaybeUninit;
15 use std::net::SocketAddr;
16 use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
17 use std::time::Duration;
18 use std::{io, mem, net};
19 
20 use libc::{c_int, getsockopt};
21 use windows_sys::Win32::Networking::WinSock::{
22     self, closesocket, ioctlsocket, setsockopt, socket, ADDRESS_FAMILY, AF_INET, AF_INET6, FIONBIO,
23     INVALID_SOCKET, LINGER, SOCKET, SOCKET_ERROR, SOCK_STREAM, SOL_SOCKET, SO_LINGER,
24 };
25 
26 use crate::sys::windows::net::init;
27 use crate::sys::windows::socket_addr::socket_addr_trans;
28 
29 pub(crate) struct TcpSocket {
30     socket: SOCKET,
31 }
32 
33 impl TcpSocket {
34     /// Gets new socket
new_socket(addr: SocketAddr) -> io::Result<TcpSocket>35     pub(crate) fn new_socket(addr: SocketAddr) -> io::Result<TcpSocket> {
36         if addr.is_ipv4() {
37             Self::create_socket(AF_INET, SOCK_STREAM)
38         } else {
39             Self::create_socket(AF_INET6, SOCK_STREAM)
40         }
41     }
42 
create_socket(domain: ADDRESS_FAMILY, socket_type: u16) -> io::Result<TcpSocket>43     fn create_socket(domain: ADDRESS_FAMILY, socket_type: u16) -> io::Result<TcpSocket> {
44         init();
45 
46         let socket = socket_syscall!(
47             socket(domain as i32, socket_type as i32, 0),
48             PartialEq::eq,
49             INVALID_SOCKET
50         )?;
51 
52         match socket_syscall!(ioctlsocket(socket, FIONBIO, &mut 1), PartialEq::ne, 0) {
53             Err(err) => {
54                 let _ = unsafe { closesocket(socket) };
55                 Err(err)
56             }
57             Ok(_) => Ok(TcpSocket {
58                 socket: socket as SOCKET,
59             }),
60         }
61     }
62 
63     /// System call to bind Socket.
bind(&self, addr: SocketAddr) -> io::Result<()>64     pub(crate) fn bind(&self, addr: SocketAddr) -> io::Result<()> {
65         use WinSock::bind;
66 
67         let (raw_addr, raw_addr_length) = socket_addr_trans(&addr);
68         socket_syscall!(
69             bind(self.socket as _, raw_addr.as_ptr(), raw_addr_length),
70             PartialEq::eq,
71             SOCKET_ERROR
72         )?;
73         Ok(())
74     }
75 
76     /// System call to listen.
listen(self, backlog: u32) -> io::Result<()>77     pub(crate) fn listen(self, backlog: u32) -> io::Result<()> {
78         use std::convert::TryInto;
79 
80         use WinSock::listen;
81 
82         let backlog = backlog.try_into().unwrap_or(i32::MAX);
83         socket_syscall!(
84             listen(self.socket as _, backlog),
85             PartialEq::eq,
86             SOCKET_ERROR
87         )?;
88         mem::forget(self);
89         Ok(())
90     }
91 
92     /// System call to connect.
connect(self, addr: SocketAddr) -> io::Result<()>93     pub(crate) fn connect(self, addr: SocketAddr) -> io::Result<()> {
94         use WinSock::connect;
95 
96         let (socket_addr, socket_addr_length) = socket_addr_trans(&addr);
97         let res = socket_syscall!(
98             connect(self.socket as _, socket_addr.as_ptr(), socket_addr_length),
99             PartialEq::eq,
100             SOCKET_ERROR
101         );
102 
103         match res {
104             Err(e) if e.kind() != io::ErrorKind::WouldBlock => Err(e),
105             _ => {
106                 mem::forget(self);
107                 Ok(())
108             }
109         }
110     }
111 
112     /// Closes Socket
close(&self)113     pub(crate) fn close(&self) {
114         let _ = unsafe { net::TcpStream::from_raw_socket(self.socket as RawSocket) };
115     }
116 }
117 
118 impl AsRawSocket for TcpSocket {
as_raw_socket(&self) -> RawSocket119     fn as_raw_socket(&self) -> RawSocket {
120         self.socket as RawSocket
121     }
122 }
123 
124 impl FromRawSocket for TcpSocket {
from_raw_socket(sock: RawSocket) -> Self125     unsafe fn from_raw_socket(sock: RawSocket) -> Self {
126         TcpSocket {
127             socket: sock as SOCKET,
128         }
129     }
130 }
131 
132 impl Drop for TcpSocket {
drop(&mut self)133     fn drop(&mut self) {
134         self.close();
135     }
136 }
137 
get_sock_linger(socket: RawSocket) -> io::Result<Option<Duration>>138 pub(crate) fn get_sock_linger(socket: RawSocket) -> io::Result<Option<Duration>> {
139     let mut optval: MaybeUninit<LINGER> = MaybeUninit::uninit();
140     let mut optlen = mem::size_of::<LINGER>() as c_int;
141 
142     socket_syscall!(
143         getsockopt(
144             socket as SOCKET,
145             SOL_SOCKET as c_int,
146             SO_LINGER as c_int,
147             optval.as_mut_ptr().cast(),
148             &mut optlen,
149         ),
150         PartialEq::eq,
151         SOCKET_ERROR
152     )
153     .map(|_| {
154         let linger = unsafe { optval.assume_init() };
155         from_linger(linger)
156     })
157 }
158 
set_sock_linger(socket: RawSocket, linger: Option<Duration>) -> io::Result<()>159 pub(crate) fn set_sock_linger(socket: RawSocket, linger: Option<Duration>) -> io::Result<()> {
160     let optval = into_linger(linger);
161     socket_syscall!(
162         setsockopt(
163             socket as SOCKET,
164             SOL_SOCKET as c_int,
165             SO_LINGER as c_int,
166             (&optval as *const LINGER).cast(),
167             mem::size_of::<LINGER>() as c_int,
168         ),
169         PartialEq::eq,
170         SOCKET_ERROR
171     )
172     .map(|_| ())
173 }
174 
from_linger(linger: LINGER) -> Option<Duration>175 fn from_linger(linger: LINGER) -> Option<Duration> {
176     if linger.l_onoff == 0 {
177         None
178     } else {
179         Some(Duration::from_secs(linger.l_linger as u64))
180     }
181 }
182 
into_linger(linger: Option<Duration>) -> LINGER183 fn into_linger(linger: Option<Duration>) -> LINGER {
184     match linger {
185         None => LINGER {
186             l_onoff: 0,
187             l_linger: 0,
188         },
189         Some(dur) => LINGER {
190             l_onoff: 1,
191             l_linger: dur.as_secs() as _,
192         },
193     }
194 }
195