1 use std::io::{Read, Write}; 2 use std::net::{SocketAddr, TcpListener, TcpStream}; 3 use std::thread::{self, JoinHandle}; 4 5 use crate::ssl::{Ssl, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslRef, SslStream}; 6 7 pub struct Server { 8 handle: Option<JoinHandle<()>>, 9 addr: SocketAddr, 10 } 11 12 impl Drop for Server { drop(&mut self)13 fn drop(&mut self) { 14 if !thread::panicking() { 15 self.handle.take().unwrap().join().unwrap(); 16 } 17 } 18 } 19 20 impl Server { builder() -> Builder21 pub fn builder() -> Builder { 22 let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); 23 ctx.set_certificate_chain_file("test/cert.pem").unwrap(); 24 ctx.set_private_key_file("test/key.pem", SslFiletype::PEM) 25 .unwrap(); 26 27 Builder { 28 ctx, 29 ssl_cb: Box::new(|_| {}), 30 io_cb: Box::new(|_| {}), 31 should_error: false, 32 } 33 } 34 client(&self) -> ClientBuilder35 pub fn client(&self) -> ClientBuilder { 36 ClientBuilder { 37 ctx: SslContext::builder(SslMethod::tls()).unwrap(), 38 addr: self.addr, 39 } 40 } 41 connect_tcp(&self) -> TcpStream42 pub fn connect_tcp(&self) -> TcpStream { 43 TcpStream::connect(self.addr).unwrap() 44 } 45 } 46 47 pub struct Builder { 48 ctx: SslContextBuilder, 49 ssl_cb: Box<dyn FnMut(&mut SslRef) + Send>, 50 io_cb: Box<dyn FnMut(SslStream<TcpStream>) + Send>, 51 should_error: bool, 52 } 53 54 impl Builder { ctx(&mut self) -> &mut SslContextBuilder55 pub fn ctx(&mut self) -> &mut SslContextBuilder { 56 &mut self.ctx 57 } 58 ssl_cb<F>(&mut self, cb: F) where F: 'static + FnMut(&mut SslRef) + Send,59 pub fn ssl_cb<F>(&mut self, cb: F) 60 where 61 F: 'static + FnMut(&mut SslRef) + Send, 62 { 63 self.ssl_cb = Box::new(cb); 64 } 65 io_cb<F>(&mut self, cb: F) where F: 'static + FnMut(SslStream<TcpStream>) + Send,66 pub fn io_cb<F>(&mut self, cb: F) 67 where 68 F: 'static + FnMut(SslStream<TcpStream>) + Send, 69 { 70 self.io_cb = Box::new(cb); 71 } 72 should_error(&mut self)73 pub fn should_error(&mut self) { 74 self.should_error = true; 75 } 76 build(self) -> Server77 pub fn build(self) -> Server { 78 let ctx = self.ctx.build(); 79 let socket = TcpListener::bind("127.0.0.1:0").unwrap(); 80 let addr = socket.local_addr().unwrap(); 81 let mut ssl_cb = self.ssl_cb; 82 let mut io_cb = self.io_cb; 83 let should_error = self.should_error; 84 85 let handle = thread::spawn(move || { 86 let socket = socket.accept().unwrap().0; 87 let mut ssl = Ssl::new(&ctx).unwrap(); 88 ssl_cb(&mut ssl); 89 let r = ssl.accept(socket); 90 if should_error { 91 r.unwrap_err(); 92 } else { 93 let mut socket = r.unwrap(); 94 socket.write_all(&[0]).unwrap(); 95 io_cb(socket); 96 } 97 }); 98 99 Server { 100 handle: Some(handle), 101 addr, 102 } 103 } 104 } 105 106 pub struct ClientBuilder { 107 ctx: SslContextBuilder, 108 addr: SocketAddr, 109 } 110 111 impl ClientBuilder { ctx(&mut self) -> &mut SslContextBuilder112 pub fn ctx(&mut self) -> &mut SslContextBuilder { 113 &mut self.ctx 114 } 115 build(self) -> Client116 pub fn build(self) -> Client { 117 Client { 118 ctx: self.ctx.build(), 119 addr: self.addr, 120 } 121 } 122 connect(self) -> SslStream<TcpStream>123 pub fn connect(self) -> SslStream<TcpStream> { 124 self.build().builder().connect() 125 } 126 connect_err(self)127 pub fn connect_err(self) { 128 self.build().builder().connect_err(); 129 } 130 } 131 132 pub struct Client { 133 ctx: SslContext, 134 addr: SocketAddr, 135 } 136 137 impl Client { builder(&self) -> ClientSslBuilder138 pub fn builder(&self) -> ClientSslBuilder { 139 ClientSslBuilder { 140 ssl: Ssl::new(&self.ctx).unwrap(), 141 addr: self.addr, 142 } 143 } 144 } 145 146 pub struct ClientSslBuilder { 147 ssl: Ssl, 148 addr: SocketAddr, 149 } 150 151 impl ClientSslBuilder { ssl(&mut self) -> &mut SslRef152 pub fn ssl(&mut self) -> &mut SslRef { 153 &mut self.ssl 154 } 155 connect(self) -> SslStream<TcpStream>156 pub fn connect(self) -> SslStream<TcpStream> { 157 let socket = TcpStream::connect(self.addr).unwrap(); 158 let mut s = self.ssl.connect(socket).unwrap(); 159 s.read_exact(&mut [0]).unwrap(); 160 s 161 } 162 connect_err(self)163 pub fn connect_err(self) { 164 let socket = TcpStream::connect(self.addr).unwrap(); 165 self.ssl.connect(socket).unwrap_err(); 166 } 167 } 168