• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2021, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 //! DoH server frontend.
18 
19 use super::client::{ClientMap, ConnectionID, CONN_ID_LEN, DNS_HEADER_SIZE, MAX_UDP_PAYLOAD_SIZE};
20 use super::config::{Config, QUICHE_IDLE_TIMEOUT_MS};
21 use super::stats::Stats;
22 use anyhow::{bail, ensure, Result};
23 use log::{debug, error, warn};
24 use std::fs::File;
25 use std::io::Write;
26 use std::os::unix::io::{AsRawFd, FromRawFd};
27 use std::sync::{Arc, LazyLock, Mutex};
28 use std::time::Duration;
29 use tokio::net::UdpSocket;
30 use tokio::runtime::{Builder, Runtime};
31 use tokio::sync::{mpsc, oneshot};
32 use tokio::task::JoinHandle;
33 
34 static RUNTIME_STATIC: LazyLock<Arc<Runtime>> = LazyLock::new(|| {
35     Arc::new(
36         Builder::new_multi_thread()
37             .worker_threads(1)
38             .enable_all()
39             .thread_name("DohFrontend")
40             .build()
41             .expect("Failed to create tokio runtime"),
42     )
43 });
44 
45 /// Command used by worker_thread itself.
46 #[derive(Debug)]
47 enum InternalCommand {
48     MaybeWrite { connection_id: ConnectionID },
49 }
50 
51 /// Commands that DohFrontend to ask its worker_thread for.
52 #[derive(Debug)]
53 enum ControlCommand {
54     Stats { resp: oneshot::Sender<Stats> },
55     StatsClearQueries,
56     CloseConnection,
57 }
58 
59 /// Frontend object.
60 #[derive(Debug)]
61 pub struct DohFrontend {
62     // Socket address the frontend listens to.
63     listen_socket_addr: std::net::SocketAddr,
64 
65     // Socket address the backend listens to.
66     backend_socket_addr: std::net::SocketAddr,
67 
68     /// The content of the certificate.
69     certificate: String,
70 
71     /// The content of the private key.
72     private_key: String,
73 
74     // The thread listening to frontend socket and backend socket
75     // and processing the messages.
76     worker_thread: Option<JoinHandle<Result<()>>>,
77 
78     // Custom runtime configuration to control the behavior of the worker thread.
79     // It's shared with the worker thread.
80     // TODO: use channel to update worker_thread configuration.
81     config: Arc<Mutex<Config>>,
82 
83     // Caches the latest stats so that the stats remains after worker_thread stops.
84     latest_stats: Stats,
85 
86     // It is wrapped as Option because the channel is not created in DohFrontend construction.
87     command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
88 }
89 
90 /// The parameters passed to the worker thread.
91 struct WorkerParams {
92     frontend_socket: std::net::UdpSocket,
93     backend_socket: std::net::UdpSocket,
94     clients: ClientMap,
95     config: Arc<Mutex<Config>>,
96     command_rx: mpsc::UnboundedReceiver<ControlCommand>,
97 }
98 
99 impl DohFrontend {
new( listen: std::net::SocketAddr, backend: std::net::SocketAddr, ) -> Result<Box<DohFrontend>>100     pub fn new(
101         listen: std::net::SocketAddr,
102         backend: std::net::SocketAddr,
103     ) -> Result<Box<DohFrontend>> {
104         let doh = Box::new(DohFrontend {
105             listen_socket_addr: listen,
106             backend_socket_addr: backend,
107             certificate: String::new(),
108             private_key: String::new(),
109             worker_thread: None,
110             config: Arc::new(Mutex::new(Config::new())),
111             latest_stats: Stats::new(),
112             command_tx: None,
113         });
114         debug!("DohFrontend created: {:?}", doh);
115         Ok(doh)
116     }
117 
start(&mut self) -> Result<()>118     pub fn start(&mut self) -> Result<()> {
119         ensure!(self.worker_thread.is_none(), "Worker thread has been running");
120         ensure!(!self.certificate.is_empty(), "certificate is empty");
121         ensure!(!self.private_key.is_empty(), "private_key is empty");
122 
123         // Doing error handling here is much simpler.
124         let params = match self.init_worker_thread_params() {
125             Ok(v) => v,
126             Err(e) => return Err(e.context("init_worker_thread_params failed")),
127         };
128 
129         self.worker_thread = Some(RUNTIME_STATIC.spawn(worker_thread(params)));
130         Ok(())
131     }
132 
stop(&mut self) -> Result<()>133     pub fn stop(&mut self) -> Result<()> {
134         debug!("DohFrontend: stopping: {:?}", self);
135         if let Some(worker_thread) = self.worker_thread.take() {
136             // Update latest_stats before stopping worker_thread.
137             let _ = self.request_stats();
138 
139             self.command_tx.as_ref().unwrap().send(ControlCommand::CloseConnection)?;
140             if let Err(e) = self.wait_for_connections_closed() {
141                 warn!("wait_for_connections_closed failed: {}", e);
142             }
143 
144             worker_thread.abort();
145             RUNTIME_STATIC.block_on(async {
146                 debug!("worker_thread result: {:?}", worker_thread.await);
147             })
148         }
149 
150         debug!("DohFrontend: stopped: {:?}", self);
151         Ok(())
152     }
153 
set_certificate(&mut self, certificate: &str) -> Result<()>154     pub fn set_certificate(&mut self, certificate: &str) -> Result<()> {
155         self.certificate = certificate.to_string();
156         Ok(())
157     }
158 
set_private_key(&mut self, private_key: &str) -> Result<()>159     pub fn set_private_key(&mut self, private_key: &str) -> Result<()> {
160         self.private_key = private_key.to_string();
161         Ok(())
162     }
163 
set_delay_queries(&self, value: i32) -> Result<()>164     pub fn set_delay_queries(&self, value: i32) -> Result<()> {
165         self.config.lock().unwrap().delay_queries = value;
166         Ok(())
167     }
168 
set_max_idle_timeout(&self, value: u64) -> Result<()>169     pub fn set_max_idle_timeout(&self, value: u64) -> Result<()> {
170         self.config.lock().unwrap().max_idle_timeout = value;
171         Ok(())
172     }
173 
set_max_buffer_size(&self, value: u64) -> Result<()>174     pub fn set_max_buffer_size(&self, value: u64) -> Result<()> {
175         self.config.lock().unwrap().max_buffer_size = value;
176         Ok(())
177     }
178 
set_max_streams_bidi(&self, value: u64) -> Result<()>179     pub fn set_max_streams_bidi(&self, value: u64) -> Result<()> {
180         self.config.lock().unwrap().max_streams_bidi = value;
181         Ok(())
182     }
183 
block_sending(&self, value: bool) -> Result<()>184     pub fn block_sending(&self, value: bool) -> Result<()> {
185         self.config.lock().unwrap().block_sending = value;
186         Ok(())
187     }
188 
set_reset_stream_id(&self, value: u64) -> Result<()>189     pub fn set_reset_stream_id(&self, value: u64) -> Result<()> {
190         self.config.lock().unwrap().reset_stream_id = Some(value);
191         Ok(())
192     }
193 
request_stats(&mut self) -> Result<Stats>194     pub fn request_stats(&mut self) -> Result<Stats> {
195         ensure!(
196             self.command_tx.is_some(),
197             "command_tx is None because worker thread not yet initialized"
198         );
199         let command_tx = self.command_tx.as_ref().unwrap();
200 
201         if command_tx.is_closed() {
202             return Ok(self.latest_stats.clone());
203         }
204 
205         let (resp_tx, resp_rx) = oneshot::channel();
206         command_tx.send(ControlCommand::Stats { resp: resp_tx })?;
207 
208         match RUNTIME_STATIC
209             .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
210         {
211             Ok(v) => match v {
212                 Ok(stats) => {
213                     self.latest_stats = stats.clone();
214                     Ok(stats)
215                 }
216                 Err(e) => bail!(e),
217             },
218             Err(e) => bail!(e),
219         }
220     }
221 
stats_clear_queries(&self) -> Result<()>222     pub fn stats_clear_queries(&self) -> Result<()> {
223         ensure!(
224             self.command_tx.is_some(),
225             "command_tx is None because worker thread not yet initialized"
226         );
227         self.command_tx
228             .as_ref()
229             .unwrap()
230             .send(ControlCommand::StatsClearQueries)
231             .or_else(|e| bail!(e))
232     }
233 
init_worker_thread_params(&mut self) -> Result<WorkerParams>234     fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
235         let bind_addr =
236             if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
237         let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
238         backend_socket.connect(self.backend_socket_addr)?;
239         backend_socket.set_nonblocking(true)?;
240 
241         let frontend_socket = bind_udp_socket_retry(self.listen_socket_addr)?;
242         frontend_socket.set_nonblocking(true)?;
243 
244         let clients = ClientMap::new(create_quiche_config(
245             self.certificate.to_string(),
246             self.private_key.to_string(),
247             self.config.clone(),
248         )?)?;
249 
250         let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
251         self.command_tx = Some(command_tx);
252 
253         Ok(WorkerParams {
254             frontend_socket,
255             backend_socket,
256             clients,
257             config: self.config.clone(),
258             command_rx,
259         })
260     }
261 
wait_for_connections_closed(&mut self) -> Result<()>262     fn wait_for_connections_closed(&mut self) -> Result<()> {
263         for _ in 0..3 {
264             std::thread::sleep(Duration::from_millis(50));
265             match self.request_stats() {
266                 Ok(stats) if stats.alive_connections == 0 => return Ok(()),
267                 Ok(_) => (),
268 
269                 // The worker thread is down. No connection is alive.
270                 Err(_) => return Ok(()),
271             }
272         }
273         bail!("Some connections still alive")
274     }
275 }
276 
worker_thread(params: WorkerParams) -> Result<()>277 async fn worker_thread(params: WorkerParams) -> Result<()> {
278     let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
279     let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
280     let config = params.config;
281     let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
282     let mut command_rx = params.command_rx;
283     let mut clients = params.clients;
284     let mut frontend_buf = [0; 65535];
285     let mut backend_buf = [0; 16384];
286     let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
287     let mut queries_received = 0;
288 
289     debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);
290 
291     loop {
292         let timeout = clients
293             .iter_mut()
294             .filter_map(|(_, c)| c.timeout())
295             .min()
296             .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
297 
298         tokio::select! {
299             _ = tokio::time::sleep(timeout) => {
300                 debug!("timeout");
301                 for (_, client) in clients.iter_mut() {
302                     // If no timeout has occurred it does nothing.
303                     client.on_timeout();
304 
305                     let connection_id = client.connection_id().clone();
306                     event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
307                 }
308             }
309 
310             Ok((len, peer)) = frontend_socket.recv_from(&mut frontend_buf) => {
311                 debug!("Got {} bytes from {}", len, peer);
312 
313                 // Parse QUIC packet.
314                 let pkt_buf = &mut frontend_buf[..len];
315                 let hdr = match quiche::Header::from_slice(pkt_buf, CONN_ID_LEN) {
316                     Ok(v) => v,
317                     Err(e) => {
318                         error!("Failed to parse QUIC header: {:?}", e);
319                         continue;
320                     }
321                 };
322                 debug!("Got QUIC packet: {:?}", hdr);
323 
324                 let local = frontend_socket.local_addr()?;
325                 let client = match clients.get_or_create(&hdr, &peer, &local) {
326                     Ok(v) => v,
327                     Err(e) => {
328                         error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
329                         continue;
330                     }
331                 };
332                 debug!("Got client: {:?}", client);
333 
334                 match client.handle_frontend_message(pkt_buf, &local) {
335                     Ok(v) if !v.is_empty() => {
336                         delay_queries_buffer.push(v);
337                         queries_received += 1;
338                     }
339                     Err(e) => {
340                         error!("Failed to process QUIC packet: {}", e);
341                         continue;
342                     }
343                     _ => {}
344                 }
345 
346                 if delay_queries_buffer.len() >= config.lock().unwrap().delay_queries as usize {
347                     for query in delay_queries_buffer.drain(..) {
348                         debug!("sending {} bytes to backend", query.len());
349                         backend_socket.send(&query).await?;
350                     }
351                 }
352 
353                 let connection_id = client.connection_id().clone();
354                 event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
355             }
356 
357             Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
358                 debug!("Got {} bytes from {}", len, src);
359                 if len < DNS_HEADER_SIZE {
360                     error!("Received insufficient bytes for DNS header");
361                     continue;
362                 }
363 
364                 let query_id = [backend_buf[0], backend_buf[1]];
365                 for (_, client) in clients.iter_mut() {
366                     if client.is_waiting_for_query(&query_id) {
367                         let reset_stream_id = config.lock().unwrap().reset_stream_id;
368                         if let Err(e) = client.handle_backend_message(&backend_buf[..len], reset_stream_id) {
369                             error!("Failed to handle message from backend: {}", e);
370                         }
371                         let connection_id = client.connection_id().clone();
372                         event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
373 
374                         // It's a bug if more than one client is waiting for this query.
375                         break;
376                     }
377                 }
378             }
379 
380             Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
381                 match command {
382                     InternalCommand::MaybeWrite {connection_id} => {
383                         if let Some(client) = clients.get_mut(&connection_id) {
384                             while let Ok(v) = client.flush_egress() {
385                                 let addr = client.addr();
386                                 debug!("Sending {} bytes to client {}", v.len(), addr);
387                                 if let Err(e) = frontend_socket.send_to(&v, addr).await {
388                                     error!("Failed to send packet to {:?}: {:?}", client, e);
389                                 }
390                             }
391                             client.process_pending_answers()?;
392                         }
393                     }
394                 }
395             }
396             Some(command) = command_rx.recv() => {
397                 debug!("ControlCommand: {:?}", command);
398                 match command {
399                     ControlCommand::Stats {resp} => {
400                         let stats = Stats {
401                             queries_received,
402                             connections_accepted: clients.len() as u32,
403                             alive_connections: clients.iter().filter(|(_, client)| client.is_alive()).count() as u32,
404                             resumed_connections: clients.iter().filter(|(_, client)| client.is_resumed()).count() as u32,
405                             early_data_connections: clients.iter().filter(|(_, client)| client.handled_early_data()).count() as u32,
406                         };
407                         if let Err(e) = resp.send(stats) {
408                             error!("Failed to send ControlCommand::Stats response: {:?}", e);
409                         }
410                     }
411                     ControlCommand::StatsClearQueries => queries_received = 0,
412                     ControlCommand::CloseConnection => {
413                         for (_, client) in clients.iter_mut() {
414                             client.close();
415                             event_tx.send(InternalCommand::MaybeWrite { connection_id: client.connection_id().clone() })?;
416                         }
417                     }
418                 }
419             }
420         }
421     }
422 }
423 
create_quiche_config( certificate: String, private_key: String, config: Arc<Mutex<Config>>, ) -> Result<quiche::Config>424 fn create_quiche_config(
425     certificate: String,
426     private_key: String,
427     config: Arc<Mutex<Config>>,
428 ) -> Result<quiche::Config> {
429     let mut quiche_config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
430 
431     // Use pipe as a file path for Quiche to read the certificate and the private key.
432     let (rd, mut wr) = build_pipe()?;
433     let handle = std::thread::spawn(move || {
434         wr.write_all(certificate.as_bytes()).expect("Failed to write to pipe");
435     });
436     let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
437     quiche_config.load_cert_chain_from_pem_file(&filepath)?;
438     handle.join().unwrap();
439 
440     let (rd, mut wr) = build_pipe()?;
441     let handle = std::thread::spawn(move || {
442         wr.write_all(private_key.as_bytes()).expect("Failed to write to pipe");
443     });
444     let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
445     quiche_config.load_priv_key_from_pem_file(&filepath)?;
446     handle.join().unwrap();
447 
448     quiche_config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?;
449     quiche_config.set_max_idle_timeout(config.lock().unwrap().max_idle_timeout);
450     quiche_config.set_max_recv_udp_payload_size(MAX_UDP_PAYLOAD_SIZE);
451 
452     let max_buffer_size = config.lock().unwrap().max_buffer_size;
453     quiche_config.set_initial_max_data(max_buffer_size);
454     quiche_config.set_initial_max_stream_data_bidi_local(max_buffer_size);
455     quiche_config.set_initial_max_stream_data_bidi_remote(max_buffer_size);
456     quiche_config.set_initial_max_stream_data_uni(max_buffer_size);
457 
458     quiche_config.set_initial_max_streams_bidi(config.lock().unwrap().max_streams_bidi);
459     quiche_config.set_initial_max_streams_uni(100);
460     quiche_config.set_disable_active_migration(true);
461     quiche_config.enable_early_data();
462 
463     Ok(quiche_config)
464 }
465 
into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket>466 fn into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket> {
467     match UdpSocket::from_std(socket) {
468         Ok(v) => Ok(v),
469         Err(e) => {
470             error!("into_tokio_udp_socket failed: {}", e);
471             bail!("into_tokio_udp_socket failed: {}", e)
472         }
473     }
474 }
475 
build_pipe() -> Result<(File, File)>476 fn build_pipe() -> Result<(File, File)> {
477     let mut fds = [0, 0];
478     // SAFETY: The pointer we pass to `pipe` must be valid because it comes from a reference. The
479     // file descriptors it returns must be valid and open, so they are safe to pass to
480     // `File::from_raw_fd`.
481     unsafe {
482         if libc::pipe(fds.as_mut_ptr()) == 0 {
483             return Ok((File::from_raw_fd(fds[0]), File::from_raw_fd(fds[1])));
484         }
485     }
486     Err(anyhow::Error::new(std::io::Error::last_os_error()).context("build_pipe failed"))
487 }
488 
489 // Can retry to bind the socket address if it is in use.
bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket>490 fn bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket> {
491     for _ in 0..3 {
492         match std::net::UdpSocket::bind(addr) {
493             Ok(socket) => return Ok(socket),
494             Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
495                 warn!("Binding socket address {} that is in use. Try again", addr);
496                 std::thread::sleep(Duration::from_millis(50));
497             }
498             Err(e) => return Err(anyhow::anyhow!(e)),
499         }
500     }
501     Err(anyhow::anyhow!(std::io::Error::last_os_error()))
502 }
503