// Copyright 2020 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. /// Structs to supplement std::net. use std::io; use std::mem::{self, size_of}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs}; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use libc::{ c_int, in6_addr, in_addr, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, socklen_t, AF_INET, AF_INET6, SOCK_CLOEXEC, SOCK_STREAM, }; /// Assist in handling both IP version 4 and IP version 6. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum InetVersion { V4, V6, } impl InetVersion { pub fn from_sockaddr(s: &SocketAddr) -> Self { match s { SocketAddr::V4(_) => InetVersion::V4, SocketAddr::V6(_) => InetVersion::V6, } } } impl Into for InetVersion { fn into(self) -> sa_family_t { match self { InetVersion::V4 => AF_INET as sa_family_t, InetVersion::V6 => AF_INET6 as sa_family_t, } } } fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in { sockaddr_in { sin_family: AF_INET as sa_family_t, sin_port: s.port().to_be(), sin_addr: in_addr { s_addr: u32::from_ne_bytes(s.ip().octets()), }, sin_zero: [0; 8], } } fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 { sockaddr_in6 { sin6_family: AF_INET6 as sa_family_t, sin6_port: s.port().to_be(), sin6_flowinfo: 0, sin6_addr: in6_addr { s6_addr: s.ip().octets(), }, sin6_scope_id: 0, } } /// A TCP socket. /// /// Do not use this class unless you need to change socket options or query the /// state of the socket prior to calling listen or connect. Instead use either TcpStream or /// TcpListener. #[derive(Debug)] pub struct TcpSocket { inet_version: InetVersion, fd: RawFd, } impl TcpSocket { pub fn new(inet_version: InetVersion) -> io::Result { let fd = unsafe { libc::socket( Into::::into(inet_version) as c_int, SOCK_STREAM | SOCK_CLOEXEC, 0, ) }; if fd < 0 { Err(io::Error::last_os_error()) } else { Ok(TcpSocket { inet_version, fd }) } } pub fn bind(&mut self, addr: A) -> io::Result<()> { let sockaddr = addr .to_socket_addrs() .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))? .next() .unwrap(); let ret = match sockaddr { SocketAddr::V4(a) => { let sin = sockaddrv4_to_lib_c(&a); // Safe because this doesn't modify any memory and we check the return value. unsafe { libc::bind( self.fd, &sin as *const sockaddr_in as *const sockaddr, size_of::() as socklen_t, ) } } SocketAddr::V6(a) => { let sin6 = sockaddrv6_to_lib_c(&a); // Safe because this doesn't modify any memory and we check the return value. unsafe { libc::bind( self.fd, &sin6 as *const sockaddr_in6 as *const sockaddr, size_of::() as socklen_t, ) } } }; if ret < 0 { let bind_err = io::Error::last_os_error(); Err(bind_err) } else { Ok(()) } } pub fn connect(self, addr: A) -> io::Result { let sockaddr = addr .to_socket_addrs() .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))? .next() .unwrap(); let ret = match sockaddr { SocketAddr::V4(a) => { let sin = sockaddrv4_to_lib_c(&a); // Safe because this doesn't modify any memory and we check the return value. unsafe { libc::connect( self.fd, &sin as *const sockaddr_in as *const sockaddr, size_of::() as socklen_t, ) } } SocketAddr::V6(a) => { let sin6 = sockaddrv6_to_lib_c(&a); // Safe because this doesn't modify any memory and we check the return value. unsafe { libc::connect( self.fd, &sin6 as *const sockaddr_in6 as *const sockaddr, size_of::() as socklen_t, ) } } }; if ret < 0 { let connect_err = io::Error::last_os_error(); Err(connect_err) } else { // Safe because the ownership of the raw fd is released from self and taken over by the // new TcpStream. Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) }) } } pub fn listen(self) -> io::Result { // Safe because this doesn't modify any memory and we check the return value. let ret = unsafe { libc::listen(self.fd, 1) }; if ret < 0 { let listen_err = io::Error::last_os_error(); Err(listen_err) } else { // Safe because the ownership of the raw fd is released from self and taken over by the // new TcpListener. Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) }) } } /// Returns the port that this socket is bound to. This can only succeed after bind is called. pub fn local_port(&self) -> io::Result { match self.inet_version { InetVersion::V4 => { let mut sin = sockaddr_in { sin_family: 0, sin_port: 0, sin_addr: in_addr { s_addr: 0 }, sin_zero: [0; 8], }; // Safe because we give a valid pointer for addrlen and check the length. let mut addrlen = size_of::() as socklen_t; let ret = unsafe { // Get the socket address that was actually bound. libc::getsockname( self.fd, &mut sin as *mut sockaddr_in as *mut sockaddr, &mut addrlen as *mut socklen_t, ) }; if ret < 0 { let getsockname_err = io::Error::last_os_error(); Err(getsockname_err) } else { // If this doesn't match, it's not safe to get the port out of the sockaddr. assert_eq!(addrlen as usize, size_of::()); Ok(sin.sin_port) } } InetVersion::V6 => { let mut sin6 = sockaddr_in6 { sin6_family: 0, sin6_port: 0, sin6_flowinfo: 0, sin6_addr: in6_addr { s6_addr: [0; 16] }, sin6_scope_id: 0, }; // Safe because we give a valid pointer for addrlen and check the length. let mut addrlen = size_of::() as socklen_t; let ret = unsafe { // Get the socket address that was actually bound. libc::getsockname( self.fd, &mut sin6 as *mut sockaddr_in6 as *mut sockaddr, &mut addrlen as *mut socklen_t, ) }; if ret < 0 { let getsockname_err = io::Error::last_os_error(); Err(getsockname_err) } else { // If this doesn't match, it's not safe to get the port out of the sockaddr. assert_eq!(addrlen as usize, size_of::()); Ok(sin6.sin6_port) } } } } } impl IntoRawFd for TcpSocket { fn into_raw_fd(self) -> RawFd { let fd = self.fd; mem::forget(self); fd } } impl AsRawFd for TcpSocket { fn as_raw_fd(&self) -> RawFd { self.fd } } impl Drop for TcpSocket { fn drop(&mut self) { // Safe because this doesn't modify any memory and we are the only // owner of the file descriptor. unsafe { libc::close(self.fd) }; } }