1 /*
2 * Copyright (C) 2021, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 //! DoH server frontend.
18
19 use super::client::{ClientMap, ConnectionID, CONN_ID_LEN, DNS_HEADER_SIZE, MAX_UDP_PAYLOAD_SIZE};
20 use super::config::{Config, QUICHE_IDLE_TIMEOUT_MS};
21 use super::stats::Stats;
22 use anyhow::{bail, ensure, Result};
23 use log::{debug, error, warn};
24 use std::fs::File;
25 use std::io::Write;
26 use std::os::unix::io::{AsRawFd, FromRawFd};
27 use std::sync::{Arc, LazyLock, Mutex};
28 use std::time::Duration;
29 use tokio::net::UdpSocket;
30 use tokio::runtime::{Builder, Runtime};
31 use tokio::sync::{mpsc, oneshot};
32 use tokio::task::JoinHandle;
33
34 static RUNTIME_STATIC: LazyLock<Arc<Runtime>> = LazyLock::new(|| {
35 Arc::new(
36 Builder::new_multi_thread()
37 .worker_threads(1)
38 .enable_all()
39 .thread_name("DohFrontend")
40 .build()
41 .expect("Failed to create tokio runtime"),
42 )
43 });
44
45 /// Command used by worker_thread itself.
46 #[derive(Debug)]
47 enum InternalCommand {
48 MaybeWrite { connection_id: ConnectionID },
49 }
50
51 /// Commands that DohFrontend to ask its worker_thread for.
52 #[derive(Debug)]
53 enum ControlCommand {
54 Stats { resp: oneshot::Sender<Stats> },
55 StatsClearQueries,
56 CloseConnection,
57 }
58
/// Frontend object.
///
/// Accepts QUIC/HTTP3 (DoH) traffic from clients on `listen_socket_addr` and forwards the DNS
/// queries over UDP to the backend at `backend_socket_addr`. The actual work happens in a task
/// spawned on the shared runtime by `start()`; this object controls it through `command_tx`.
#[derive(Debug)]
pub struct DohFrontend {
    /// Socket address the frontend listens to.
    listen_socket_addr: std::net::SocketAddr,

    /// Socket address the backend listens to.
    backend_socket_addr: std::net::SocketAddr,

    /// The content of the certificate.
    certificate: String,

    /// The content of the private key.
    private_key: String,

    /// The task listening to the frontend socket and the backend socket and processing the
    /// messages. `None` while the frontend is stopped.
    worker_thread: Option<JoinHandle<Result<()>>>,

    /// Custom runtime configuration to control the behavior of the worker thread.
    /// It's shared with the worker thread.
    // TODO: use channel to update worker_thread configuration.
    config: Arc<Mutex<Config>>,

    /// Caches the latest stats so that the stats remain available after worker_thread stops.
    latest_stats: Stats,

    /// Control channel to the worker. It is wrapped as Option because the channel is not
    /// created in DohFrontend construction, only when the worker is started.
    command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
}
89
/// The parameters passed to the worker thread.
///
/// Built by `DohFrontend::init_worker_thread_params()` and moved into `worker_thread()`.
struct WorkerParams {
    /// Non-blocking socket bound to the frontend listen address; receives QUIC from clients.
    frontend_socket: std::net::UdpSocket,
    /// Non-blocking socket connected to the backend DNS server.
    backend_socket: std::net::UdpSocket,
    /// Client connections, created with the quiche configuration built at start time.
    clients: ClientMap,
    /// Runtime configuration shared with `DohFrontend`.
    config: Arc<Mutex<Config>>,
    /// Receiving end of the control channel whose sender is kept in `DohFrontend::command_tx`.
    command_rx: mpsc::UnboundedReceiver<ControlCommand>,
}
98
99 impl DohFrontend {
new( listen: std::net::SocketAddr, backend: std::net::SocketAddr, ) -> Result<Box<DohFrontend>>100 pub fn new(
101 listen: std::net::SocketAddr,
102 backend: std::net::SocketAddr,
103 ) -> Result<Box<DohFrontend>> {
104 let doh = Box::new(DohFrontend {
105 listen_socket_addr: listen,
106 backend_socket_addr: backend,
107 certificate: String::new(),
108 private_key: String::new(),
109 worker_thread: None,
110 config: Arc::new(Mutex::new(Config::new())),
111 latest_stats: Stats::new(),
112 command_tx: None,
113 });
114 debug!("DohFrontend created: {:?}", doh);
115 Ok(doh)
116 }
117
start(&mut self) -> Result<()>118 pub fn start(&mut self) -> Result<()> {
119 ensure!(self.worker_thread.is_none(), "Worker thread has been running");
120 ensure!(!self.certificate.is_empty(), "certificate is empty");
121 ensure!(!self.private_key.is_empty(), "private_key is empty");
122
123 // Doing error handling here is much simpler.
124 let params = match self.init_worker_thread_params() {
125 Ok(v) => v,
126 Err(e) => return Err(e.context("init_worker_thread_params failed")),
127 };
128
129 self.worker_thread = Some(RUNTIME_STATIC.spawn(worker_thread(params)));
130 Ok(())
131 }
132
stop(&mut self) -> Result<()>133 pub fn stop(&mut self) -> Result<()> {
134 debug!("DohFrontend: stopping: {:?}", self);
135 if let Some(worker_thread) = self.worker_thread.take() {
136 // Update latest_stats before stopping worker_thread.
137 let _ = self.request_stats();
138
139 self.command_tx.as_ref().unwrap().send(ControlCommand::CloseConnection)?;
140 if let Err(e) = self.wait_for_connections_closed() {
141 warn!("wait_for_connections_closed failed: {}", e);
142 }
143
144 worker_thread.abort();
145 RUNTIME_STATIC.block_on(async {
146 debug!("worker_thread result: {:?}", worker_thread.await);
147 })
148 }
149
150 debug!("DohFrontend: stopped: {:?}", self);
151 Ok(())
152 }
153
set_certificate(&mut self, certificate: &str) -> Result<()>154 pub fn set_certificate(&mut self, certificate: &str) -> Result<()> {
155 self.certificate = certificate.to_string();
156 Ok(())
157 }
158
set_private_key(&mut self, private_key: &str) -> Result<()>159 pub fn set_private_key(&mut self, private_key: &str) -> Result<()> {
160 self.private_key = private_key.to_string();
161 Ok(())
162 }
163
set_delay_queries(&self, value: i32) -> Result<()>164 pub fn set_delay_queries(&self, value: i32) -> Result<()> {
165 self.config.lock().unwrap().delay_queries = value;
166 Ok(())
167 }
168
set_max_idle_timeout(&self, value: u64) -> Result<()>169 pub fn set_max_idle_timeout(&self, value: u64) -> Result<()> {
170 self.config.lock().unwrap().max_idle_timeout = value;
171 Ok(())
172 }
173
set_max_buffer_size(&self, value: u64) -> Result<()>174 pub fn set_max_buffer_size(&self, value: u64) -> Result<()> {
175 self.config.lock().unwrap().max_buffer_size = value;
176 Ok(())
177 }
178
set_max_streams_bidi(&self, value: u64) -> Result<()>179 pub fn set_max_streams_bidi(&self, value: u64) -> Result<()> {
180 self.config.lock().unwrap().max_streams_bidi = value;
181 Ok(())
182 }
183
block_sending(&self, value: bool) -> Result<()>184 pub fn block_sending(&self, value: bool) -> Result<()> {
185 self.config.lock().unwrap().block_sending = value;
186 Ok(())
187 }
188
set_reset_stream_id(&self, value: u64) -> Result<()>189 pub fn set_reset_stream_id(&self, value: u64) -> Result<()> {
190 self.config.lock().unwrap().reset_stream_id = Some(value);
191 Ok(())
192 }
193
request_stats(&mut self) -> Result<Stats>194 pub fn request_stats(&mut self) -> Result<Stats> {
195 ensure!(
196 self.command_tx.is_some(),
197 "command_tx is None because worker thread not yet initialized"
198 );
199 let command_tx = self.command_tx.as_ref().unwrap();
200
201 if command_tx.is_closed() {
202 return Ok(self.latest_stats.clone());
203 }
204
205 let (resp_tx, resp_rx) = oneshot::channel();
206 command_tx.send(ControlCommand::Stats { resp: resp_tx })?;
207
208 match RUNTIME_STATIC
209 .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
210 {
211 Ok(v) => match v {
212 Ok(stats) => {
213 self.latest_stats = stats.clone();
214 Ok(stats)
215 }
216 Err(e) => bail!(e),
217 },
218 Err(e) => bail!(e),
219 }
220 }
221
stats_clear_queries(&self) -> Result<()>222 pub fn stats_clear_queries(&self) -> Result<()> {
223 ensure!(
224 self.command_tx.is_some(),
225 "command_tx is None because worker thread not yet initialized"
226 );
227 self.command_tx
228 .as_ref()
229 .unwrap()
230 .send(ControlCommand::StatsClearQueries)
231 .or_else(|e| bail!(e))
232 }
233
init_worker_thread_params(&mut self) -> Result<WorkerParams>234 fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
235 let bind_addr =
236 if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
237 let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
238 backend_socket.connect(self.backend_socket_addr)?;
239 backend_socket.set_nonblocking(true)?;
240
241 let frontend_socket = bind_udp_socket_retry(self.listen_socket_addr)?;
242 frontend_socket.set_nonblocking(true)?;
243
244 let clients = ClientMap::new(create_quiche_config(
245 self.certificate.to_string(),
246 self.private_key.to_string(),
247 self.config.clone(),
248 )?)?;
249
250 let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
251 self.command_tx = Some(command_tx);
252
253 Ok(WorkerParams {
254 frontend_socket,
255 backend_socket,
256 clients,
257 config: self.config.clone(),
258 command_rx,
259 })
260 }
261
wait_for_connections_closed(&mut self) -> Result<()>262 fn wait_for_connections_closed(&mut self) -> Result<()> {
263 for _ in 0..3 {
264 std::thread::sleep(Duration::from_millis(50));
265 match self.request_stats() {
266 Ok(stats) if stats.alive_connections == 0 => return Ok(()),
267 Ok(_) => (),
268
269 // The worker thread is down. No connection is alive.
270 Err(_) => return Ok(()),
271 }
272 }
273 bail!("Some connections still alive")
274 }
275 }
276
worker_thread(params: WorkerParams) -> Result<()>277 async fn worker_thread(params: WorkerParams) -> Result<()> {
278 let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
279 let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
280 let config = params.config;
281 let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
282 let mut command_rx = params.command_rx;
283 let mut clients = params.clients;
284 let mut frontend_buf = [0; 65535];
285 let mut backend_buf = [0; 16384];
286 let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
287 let mut queries_received = 0;
288
289 debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);
290
291 loop {
292 let timeout = clients
293 .iter_mut()
294 .filter_map(|(_, c)| c.timeout())
295 .min()
296 .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
297
298 tokio::select! {
299 _ = tokio::time::sleep(timeout) => {
300 debug!("timeout");
301 for (_, client) in clients.iter_mut() {
302 // If no timeout has occurred it does nothing.
303 client.on_timeout();
304
305 let connection_id = client.connection_id().clone();
306 event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
307 }
308 }
309
310 Ok((len, peer)) = frontend_socket.recv_from(&mut frontend_buf) => {
311 debug!("Got {} bytes from {}", len, peer);
312
313 // Parse QUIC packet.
314 let pkt_buf = &mut frontend_buf[..len];
315 let hdr = match quiche::Header::from_slice(pkt_buf, CONN_ID_LEN) {
316 Ok(v) => v,
317 Err(e) => {
318 error!("Failed to parse QUIC header: {:?}", e);
319 continue;
320 }
321 };
322 debug!("Got QUIC packet: {:?}", hdr);
323
324 let local = frontend_socket.local_addr()?;
325 let client = match clients.get_or_create(&hdr, &peer, &local) {
326 Ok(v) => v,
327 Err(e) => {
328 error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
329 continue;
330 }
331 };
332 debug!("Got client: {:?}", client);
333
334 match client.handle_frontend_message(pkt_buf, &local) {
335 Ok(v) if !v.is_empty() => {
336 delay_queries_buffer.push(v);
337 queries_received += 1;
338 }
339 Err(e) => {
340 error!("Failed to process QUIC packet: {}", e);
341 continue;
342 }
343 _ => {}
344 }
345
346 if delay_queries_buffer.len() >= config.lock().unwrap().delay_queries as usize {
347 for query in delay_queries_buffer.drain(..) {
348 debug!("sending {} bytes to backend", query.len());
349 backend_socket.send(&query).await?;
350 }
351 }
352
353 let connection_id = client.connection_id().clone();
354 event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
355 }
356
357 Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
358 debug!("Got {} bytes from {}", len, src);
359 if len < DNS_HEADER_SIZE {
360 error!("Received insufficient bytes for DNS header");
361 continue;
362 }
363
364 let query_id = [backend_buf[0], backend_buf[1]];
365 for (_, client) in clients.iter_mut() {
366 if client.is_waiting_for_query(&query_id) {
367 let reset_stream_id = config.lock().unwrap().reset_stream_id;
368 if let Err(e) = client.handle_backend_message(&backend_buf[..len], reset_stream_id) {
369 error!("Failed to handle message from backend: {}", e);
370 }
371 let connection_id = client.connection_id().clone();
372 event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
373
374 // It's a bug if more than one client is waiting for this query.
375 break;
376 }
377 }
378 }
379
380 Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
381 match command {
382 InternalCommand::MaybeWrite {connection_id} => {
383 if let Some(client) = clients.get_mut(&connection_id) {
384 while let Ok(v) = client.flush_egress() {
385 let addr = client.addr();
386 debug!("Sending {} bytes to client {}", v.len(), addr);
387 if let Err(e) = frontend_socket.send_to(&v, addr).await {
388 error!("Failed to send packet to {:?}: {:?}", client, e);
389 }
390 }
391 client.process_pending_answers()?;
392 }
393 }
394 }
395 }
396 Some(command) = command_rx.recv() => {
397 debug!("ControlCommand: {:?}", command);
398 match command {
399 ControlCommand::Stats {resp} => {
400 let stats = Stats {
401 queries_received,
402 connections_accepted: clients.len() as u32,
403 alive_connections: clients.iter().filter(|(_, client)| client.is_alive()).count() as u32,
404 resumed_connections: clients.iter().filter(|(_, client)| client.is_resumed()).count() as u32,
405 early_data_connections: clients.iter().filter(|(_, client)| client.handled_early_data()).count() as u32,
406 };
407 if let Err(e) = resp.send(stats) {
408 error!("Failed to send ControlCommand::Stats response: {:?}", e);
409 }
410 }
411 ControlCommand::StatsClearQueries => queries_received = 0,
412 ControlCommand::CloseConnection => {
413 for (_, client) in clients.iter_mut() {
414 client.close();
415 event_tx.send(InternalCommand::MaybeWrite { connection_id: client.connection_id().clone() })?;
416 }
417 }
418 }
419 }
420 }
421 }
422 }
423
create_quiche_config( certificate: String, private_key: String, config: Arc<Mutex<Config>>, ) -> Result<quiche::Config>424 fn create_quiche_config(
425 certificate: String,
426 private_key: String,
427 config: Arc<Mutex<Config>>,
428 ) -> Result<quiche::Config> {
429 let mut quiche_config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
430
431 // Use pipe as a file path for Quiche to read the certificate and the private key.
432 let (rd, mut wr) = build_pipe()?;
433 let handle = std::thread::spawn(move || {
434 wr.write_all(certificate.as_bytes()).expect("Failed to write to pipe");
435 });
436 let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
437 quiche_config.load_cert_chain_from_pem_file(&filepath)?;
438 handle.join().unwrap();
439
440 let (rd, mut wr) = build_pipe()?;
441 let handle = std::thread::spawn(move || {
442 wr.write_all(private_key.as_bytes()).expect("Failed to write to pipe");
443 });
444 let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
445 quiche_config.load_priv_key_from_pem_file(&filepath)?;
446 handle.join().unwrap();
447
448 quiche_config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?;
449 quiche_config.set_max_idle_timeout(config.lock().unwrap().max_idle_timeout);
450 quiche_config.set_max_recv_udp_payload_size(MAX_UDP_PAYLOAD_SIZE);
451
452 let max_buffer_size = config.lock().unwrap().max_buffer_size;
453 quiche_config.set_initial_max_data(max_buffer_size);
454 quiche_config.set_initial_max_stream_data_bidi_local(max_buffer_size);
455 quiche_config.set_initial_max_stream_data_bidi_remote(max_buffer_size);
456 quiche_config.set_initial_max_stream_data_uni(max_buffer_size);
457
458 quiche_config.set_initial_max_streams_bidi(config.lock().unwrap().max_streams_bidi);
459 quiche_config.set_initial_max_streams_uni(100);
460 quiche_config.set_disable_active_migration(true);
461 quiche_config.enable_early_data();
462
463 Ok(quiche_config)
464 }
465
into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket>466 fn into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket> {
467 match UdpSocket::from_std(socket) {
468 Ok(v) => Ok(v),
469 Err(e) => {
470 error!("into_tokio_udp_socket failed: {}", e);
471 bail!("into_tokio_udp_socket failed: {}", e)
472 }
473 }
474 }
475
build_pipe() -> Result<(File, File)>476 fn build_pipe() -> Result<(File, File)> {
477 let mut fds = [0, 0];
478 // SAFETY: The pointer we pass to `pipe` must be valid because it comes from a reference. The
479 // file descriptors it returns must be valid and open, so they are safe to pass to
480 // `File::from_raw_fd`.
481 unsafe {
482 if libc::pipe(fds.as_mut_ptr()) == 0 {
483 return Ok((File::from_raw_fd(fds[0]), File::from_raw_fd(fds[1])));
484 }
485 }
486 Err(anyhow::Error::new(std::io::Error::last_os_error()).context("build_pipe failed"))
487 }
488
489 // Can retry to bind the socket address if it is in use.
bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket>490 fn bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket> {
491 for _ in 0..3 {
492 match std::net::UdpSocket::bind(addr) {
493 Ok(socket) => return Ok(socket),
494 Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
495 warn!("Binding socket address {} that is in use. Try again", addr);
496 std::thread::sleep(Duration::from_millis(50));
497 }
498 Err(e) => return Err(anyhow::anyhow!(e)),
499 }
500 }
501 Err(anyhow::anyhow!(std::io::Error::last_os_error()))
502 }
503