1 /*
2  * Copyright (C) 2021, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! DoH server frontend.
18 
19 use crate::client::{ClientMap, ConnectionID, DNS_HEADER_SIZE, MAX_UDP_PAYLOAD_SIZE};
20 use crate::config::{Config, QUICHE_IDLE_TIMEOUT_MS};
21 use crate::stats::Stats;
22 use anyhow::{bail, ensure, Result};
23 use lazy_static::lazy_static;
24 use log::{debug, error, warn};
25 use std::fs::File;
26 use std::io::Write;
27 use std::os::unix::io::{AsRawFd, FromRawFd};
28 use std::sync::{Arc, Mutex};
29 use std::time::Duration;
30 use tokio::net::UdpSocket;
31 use tokio::runtime::{Builder, Runtime};
32 use tokio::sync::{mpsc, oneshot};
33 use tokio::task::JoinHandle;
34 
lazy_static! {
    // Process-wide tokio runtime shared by the worker task (spawned in
    // `DohFrontend::start`) and the synchronous `request_stats` call, which
    // uses `block_on` against it.
    static ref RUNTIME_STATIC: Arc<Runtime> = Arc::new(
        Builder::new_multi_thread()
            .worker_threads(2)
            .max_blocking_threads(1)
            .enable_all()
            .thread_name("DohFrontend")
            .build()
            .expect("Failed to create tokio runtime")
    );
}
46 
/// Command used by worker_thread itself.
#[derive(Debug)]
enum InternalCommand {
    // Ask the worker loop to flush any pending egress packets for the QUIC
    // connection identified by `connection_id`.
    MaybeWrite { connection_id: ConnectionID },
}
52 
/// Commands that DohFrontend sends to its worker_thread.
#[derive(Debug)]
enum ControlCommand {
    // Request a snapshot of the current stats; the reply is sent on `resp`.
    Stats { resp: oneshot::Sender<Stats> },
    // Reset the received-queries counter to zero.
    StatsClearQueries,
    // Close every client connection.
    CloseConnection,
}
60 
/// Frontend object.
#[derive(Debug)]
pub struct DohFrontend {
    // Socket address the frontend listens to.
    listen_socket_addr: std::net::SocketAddr,

    // Socket address the backend listens to.
    backend_socket_addr: std::net::SocketAddr,

    /// The content of the certificate (PEM text, fed to quiche in `start`).
    certificate: String,

    /// The content of the private key (PEM text, fed to quiche in `start`).
    private_key: String,

    // The thread listening to frontend socket and backend socket
    // and processing the messages. `None` until `start` is called and
    // after `stop` takes it.
    worker_thread: Option<JoinHandle<Result<()>>>,

    // Custom runtime configuration to control the behavior of the worker thread.
    // It's shared with the worker thread.
    // TODO: use channel to update worker_thread configuration.
    config: Arc<Mutex<Config>>,

    // Caches the latest stats so that the stats remains after worker_thread stops.
    latest_stats: Stats,

    // Sender half of the control channel to the worker thread.
    // It is wrapped as Option because the channel is not created in DohFrontend construction.
    command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
}
91 
/// The parameters passed to the worker thread.
struct WorkerParams {
    // Socket receiving QUIC traffic from DoH clients.
    frontend_socket: std::net::UdpSocket,
    // Socket connected to the plain-DNS backend server.
    backend_socket: std::net::UdpSocket,
    // Per-connection client bookkeeping.
    clients: ClientMap,
    // Shared runtime configuration (see `DohFrontend::config`).
    config: Arc<Mutex<Config>>,
    // Receiver half of the control channel from `DohFrontend`.
    command_rx: mpsc::UnboundedReceiver<ControlCommand>,
}
100 
101 impl DohFrontend {
new( listen: std::net::SocketAddr, backend: std::net::SocketAddr, ) -> Result<Box<DohFrontend>>102     pub fn new(
103         listen: std::net::SocketAddr,
104         backend: std::net::SocketAddr,
105     ) -> Result<Box<DohFrontend>> {
106         let doh = Box::new(DohFrontend {
107             listen_socket_addr: listen,
108             backend_socket_addr: backend,
109             certificate: String::new(),
110             private_key: String::new(),
111             worker_thread: None,
112             config: Arc::new(Mutex::new(Config::new())),
113             latest_stats: Stats::new(),
114             command_tx: None,
115         });
116         debug!("DohFrontend created: {:?}", doh);
117         Ok(doh)
118     }
119 
start(&mut self) -> Result<()>120     pub fn start(&mut self) -> Result<()> {
121         ensure!(self.worker_thread.is_none(), "Worker thread has been running");
122         ensure!(!self.certificate.is_empty(), "certificate is empty");
123         ensure!(!self.private_key.is_empty(), "private_key is empty");
124 
125         // Doing error handling here is much simpler.
126         let params = match self.init_worker_thread_params() {
127             Ok(v) => v,
128             Err(e) => return Err(e.context("init_worker_thread_params failed")),
129         };
130 
131         self.worker_thread = Some(RUNTIME_STATIC.spawn(worker_thread(params)));
132         Ok(())
133     }
134 
stop(&mut self) -> Result<()>135     pub fn stop(&mut self) -> Result<()> {
136         debug!("DohFrontend: stopping: {:?}", self);
137         if let Some(worker_thread) = self.worker_thread.take() {
138             // Update latest_stats before stopping worker_thread.
139             let _ = self.request_stats();
140 
141             self.command_tx.as_ref().unwrap().send(ControlCommand::CloseConnection)?;
142             if let Err(e) = self.wait_for_connections_closed() {
143                 warn!("wait_for_connections_closed failed: {}", e);
144             }
145 
146             worker_thread.abort();
147         }
148 
149         debug!("DohFrontend: stopped: {:?}", self);
150         Ok(())
151     }
152 
set_certificate(&mut self, certificate: &str) -> Result<()>153     pub fn set_certificate(&mut self, certificate: &str) -> Result<()> {
154         self.certificate = certificate.to_string();
155         Ok(())
156     }
157 
set_private_key(&mut self, private_key: &str) -> Result<()>158     pub fn set_private_key(&mut self, private_key: &str) -> Result<()> {
159         self.private_key = private_key.to_string();
160         Ok(())
161     }
162 
set_delay_queries(&self, value: i32) -> Result<()>163     pub fn set_delay_queries(&self, value: i32) -> Result<()> {
164         self.config.lock().unwrap().delay_queries = value;
165         Ok(())
166     }
167 
set_max_idle_timeout(&self, value: u64) -> Result<()>168     pub fn set_max_idle_timeout(&self, value: u64) -> Result<()> {
169         self.config.lock().unwrap().max_idle_timeout = value;
170         Ok(())
171     }
172 
set_max_buffer_size(&self, value: u64) -> Result<()>173     pub fn set_max_buffer_size(&self, value: u64) -> Result<()> {
174         self.config.lock().unwrap().max_buffer_size = value;
175         Ok(())
176     }
177 
set_max_streams_bidi(&self, value: u64) -> Result<()>178     pub fn set_max_streams_bidi(&self, value: u64) -> Result<()> {
179         self.config.lock().unwrap().max_streams_bidi = value;
180         Ok(())
181     }
182 
block_sending(&self, value: bool) -> Result<()>183     pub fn block_sending(&self, value: bool) -> Result<()> {
184         self.config.lock().unwrap().block_sending = value;
185         Ok(())
186     }
187 
request_stats(&mut self) -> Result<Stats>188     pub fn request_stats(&mut self) -> Result<Stats> {
189         ensure!(
190             self.command_tx.is_some(),
191             "command_tx is None because worker thread not yet initialized"
192         );
193         let command_tx = self.command_tx.as_ref().unwrap();
194 
195         if command_tx.is_closed() {
196             return Ok(self.latest_stats.clone());
197         }
198 
199         let (resp_tx, resp_rx) = oneshot::channel();
200         command_tx.send(ControlCommand::Stats { resp: resp_tx })?;
201 
202         match RUNTIME_STATIC
203             .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
204         {
205             Ok(v) => match v {
206                 Ok(stats) => {
207                     self.latest_stats = stats.clone();
208                     Ok(stats)
209                 }
210                 Err(e) => bail!(e),
211             },
212             Err(e) => bail!(e),
213         }
214     }
215 
stats_clear_queries(&self) -> Result<()>216     pub fn stats_clear_queries(&self) -> Result<()> {
217         ensure!(
218             self.command_tx.is_some(),
219             "command_tx is None because worker thread not yet initialized"
220         );
221         return self
222             .command_tx
223             .as_ref()
224             .unwrap()
225             .send(ControlCommand::StatsClearQueries)
226             .or_else(|e| bail!(e));
227     }
228 
init_worker_thread_params(&mut self) -> Result<WorkerParams>229     fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
230         let bind_addr =
231             if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
232         let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
233         backend_socket.connect(self.backend_socket_addr)?;
234         backend_socket.set_nonblocking(true)?;
235 
236         let frontend_socket = bind_udp_socket_retry(self.listen_socket_addr)?;
237         frontend_socket.set_nonblocking(true)?;
238 
239         let clients = ClientMap::new(create_quiche_config(
240             self.certificate.to_string(),
241             self.private_key.to_string(),
242             self.config.clone(),
243         )?)?;
244 
245         let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
246         self.command_tx = Some(command_tx);
247 
248         Ok(WorkerParams {
249             frontend_socket,
250             backend_socket,
251             clients,
252             config: self.config.clone(),
253             command_rx,
254         })
255     }
256 
wait_for_connections_closed(&mut self) -> Result<()>257     fn wait_for_connections_closed(&mut self) -> Result<()> {
258         for _ in 0..3 {
259             std::thread::sleep(Duration::from_millis(50));
260             match self.request_stats() {
261                 Ok(stats) if stats.alive_connections == 0 => return Ok(()),
262                 Ok(_) => (),
263 
264                 // The worker thread is down. No connection is alive.
265                 Err(_) => return Ok(()),
266             }
267         }
268         bail!("Some connections still alive")
269     }
270 }
271 
/// Worker task main loop: multiplexes (1) QUIC timer expirations, (2) client
/// packets from the frontend socket, (3) answers from the DNS backend, and
/// (4) internal/control commands. Runs until the task is aborted by
/// `DohFrontend::stop` or an unrecoverable send error occurs.
async fn worker_thread(params: WorkerParams) -> Result<()> {
    let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
    let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
    let config = params.config;
    // Self-addressed channel the loop uses to schedule egress flushes.
    let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
    let mut command_rx = params.command_rx;
    let mut clients = params.clients;
    let mut frontend_buf = [0; 65535];
    let mut backend_buf = [0; 16384];
    // Backend-bound queries held back until `config.delay_queries` of them
    // have accumulated.
    let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
    let mut queries_received = 0;

    debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);

    loop {
        // Sleep until the earliest pending QUIC timeout among all clients,
        // falling back to the idle-timeout default when none is pending.
        let timeout = clients
            .iter_mut()
            .filter_map(|(_, c)| c.timeout())
            .min()
            .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));

        tokio::select! {
            _ = tokio::time::sleep(timeout) => {
                debug!("timeout");
                for (_, client) in clients.iter_mut() {
                    // If no timeout has occurred it does nothing.
                    client.on_timeout();

                    let connection_id = client.connection_id().clone();
                    event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
                }
            }

            // NOTE: on a recv error this pattern fails to match and the
            // branch is skipped for this iteration of the loop.
            Ok((len, src)) = frontend_socket.recv_from(&mut frontend_buf) => {
                debug!("Got {} bytes from {}", len, src);

                // Parse QUIC packet.
                let pkt_buf = &mut frontend_buf[..len];
                let hdr = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) {
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to parse QUIC header: {:?}", e);
                        continue;
                    }
                };
                debug!("Got QUIC packet: {:?}", hdr);

                let client = match clients.get_or_create(&hdr, &src) {
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
                        continue;
                    }
                };
                debug!("Got client: {:?}", client);

                match client.handle_frontend_message(pkt_buf) {
                    // A non-empty result is a DNS query to forward.
                    Ok(v) if !v.is_empty() => {
                        delay_queries_buffer.push(v);
                        queries_received += 1;
                    }
                    Err(e) => {
                        error!("Failed to process QUIC packet: {}", e);
                        continue;
                    }
                    _ => {}
                }

                // Flush buffered queries once the configured threshold is met
                // (threshold <= 1 means forward immediately).
                if delay_queries_buffer.len() >= config.lock().unwrap().delay_queries as usize {
                    for query in delay_queries_buffer.drain(..) {
                        debug!("sending {} bytes to backend", query.len());
                        backend_socket.send(&query).await?;
                    }
                }

                let connection_id = client.connection_id().clone();
                event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
            }

            Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
                debug!("Got {} bytes from {}", len, src);
                if len < DNS_HEADER_SIZE {
                    error!("Received insufficient bytes for DNS header");
                    continue;
                }

                // The first two bytes of a DNS message are the query ID; use
                // it to route the answer back to the waiting client.
                let query_id = [backend_buf[0], backend_buf[1]];
                for (_, client) in clients.iter_mut() {
                    if client.is_waiting_for_query(&query_id) {
                        if let Err(e) = client.handle_backend_message(&backend_buf[..len]) {
                            error!("Failed to handle message from backend: {}", e);
                        }
                        let connection_id = client.connection_id().clone();
                        event_tx.send(InternalCommand::MaybeWrite{connection_id})?;

                        // It's a bug if more than one client is waiting for this query.
                        break;
                    }
                }
            }

            // Egress flushing is suppressed while `block_sending` is set;
            // events queue up in `event_rx` until it is cleared.
            Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
                match command {
                    InternalCommand::MaybeWrite {connection_id} => {
                        if let Some(client) = clients.get_mut(&connection_id) {
                            // Drain all pending egress packets for the client.
                            while let Ok(v) = client.flush_egress() {
                                let addr = client.addr();
                                debug!("Sending {} bytes to client {}", v.len(), addr);
                                if let Err(e) = frontend_socket.send_to(&v, addr).await {
                                    error!("Failed to send packet to {:?}: {:?}", client, e);
                                }
                            }
                            client.process_pending_answers()?;
                        }
                    }
                }
            }
            Some(command) = command_rx.recv() => {
                debug!("ControlCommand: {:?}", command);
                match command {
                    ControlCommand::Stats {resp} => {
                        // Build a snapshot from live client state; the reply
                        // channel may be dropped if the requester timed out.
                        let stats = Stats {
                            queries_received,
                            connections_accepted: clients.len() as u32,
                            alive_connections: clients.iter().filter(|(_, client)| client.is_alive()).count() as u32,
                            resumed_connections: clients.iter().filter(|(_, client)| client.is_resumed()).count() as u32,
                        };
                        if let Err(e) = resp.send(stats) {
                            error!("Failed to send ControlCommand::Stats response: {:?}", e);
                        }
                    }
                    ControlCommand::StatsClearQueries => queries_received = 0,
                    ControlCommand::CloseConnection => {
                        for (_, client) in clients.iter_mut() {
                            client.close();
                            event_tx.send(InternalCommand::MaybeWrite { connection_id: client.connection_id().clone() })?;
                        }
                    }
                }
            }
        }
    }
}
415 
/// Builds a `quiche::Config` from the PEM-encoded certificate and private key
/// strings plus the tunables stored in `config`.
fn create_quiche_config(
    certificate: String,
    private_key: String,
    config: Arc<Mutex<Config>>,
) -> Result<quiche::Config> {
    let mut quiche_config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;

    // Use pipe as a file path for Quiche to read the certificate and the private key.
    // A helper thread writes the PEM bytes into the write end while quiche
    // reads from /proc/self/fd/<read-fd>; when the thread finishes, the write
    // end is dropped, which signals EOF to the reader.
    let (rd, mut wr) = build_pipe()?;
    let handle = std::thread::spawn(move || {
        wr.write_all(certificate.as_bytes()).expect("Failed to write to pipe");
    });
    let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
    quiche_config.load_cert_chain_from_pem_file(&filepath)?;
    handle.join().unwrap();

    // Same pipe trick for the private key.
    let (rd, mut wr) = build_pipe()?;
    let handle = std::thread::spawn(move || {
        wr.write_all(private_key.as_bytes()).expect("Failed to write to pipe");
    });
    let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
    quiche_config.load_priv_key_from_pem_file(&filepath)?;
    handle.join().unwrap();

    quiche_config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?;
    quiche_config.set_max_idle_timeout(config.lock().unwrap().max_idle_timeout);
    quiche_config.set_max_recv_udp_payload_size(MAX_UDP_PAYLOAD_SIZE);

    // Apply the same buffer size to connection-level and all stream-level
    // flow-control windows.
    let max_buffer_size = config.lock().unwrap().max_buffer_size;
    quiche_config.set_initial_max_data(max_buffer_size);
    quiche_config.set_initial_max_stream_data_bidi_local(max_buffer_size);
    quiche_config.set_initial_max_stream_data_bidi_remote(max_buffer_size);
    quiche_config.set_initial_max_stream_data_uni(max_buffer_size);

    quiche_config.set_initial_max_streams_bidi(config.lock().unwrap().max_streams_bidi);
    quiche_config.set_initial_max_streams_uni(100);
    quiche_config.set_disable_active_migration(true);

    Ok(quiche_config)
}
456 
into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket>457 fn into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket> {
458     match UdpSocket::from_std(socket) {
459         Ok(v) => Ok(v),
460         Err(e) => {
461             error!("into_tokio_udp_socket failed: {}", e);
462             bail!("into_tokio_udp_socket failed: {}", e)
463         }
464     }
465 }
466 
/// Creates an anonymous pipe, returning `(read_end, write_end)` as `File`s.
fn build_pipe() -> Result<(File, File)> {
    let mut fds = [0, 0];
    // SAFETY: `fds` is a valid, writable buffer of two c_ints for libc::pipe
    // to fill. On success each returned fd is owned by exactly one `File`,
    // which closes it on drop.
    unsafe {
        if libc::pipe(fds.as_mut_ptr()) == 0 {
            return Ok((File::from_raw_fd(fds[0]), File::from_raw_fd(fds[1])));
        }
    }
    Err(anyhow::Error::new(std::io::Error::last_os_error()).context("build_pipe failed"))
}
476 
477 // Can retry to bind the socket address if it is in use.
bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket>478 fn bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket> {
479     for _ in 0..3 {
480         match std::net::UdpSocket::bind(addr) {
481             Ok(socket) => return Ok(socket),
482             Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
483                 warn!("Binding socket address {} that is in use. Try again", addr);
484                 std::thread::sleep(Duration::from_millis(50));
485             }
486             Err(e) => return Err(anyhow::anyhow!(e)),
487         }
488     }
489     Err(anyhow::anyhow!(std::io::Error::last_os_error()))
490 }
491