/*
 * Copyright (C) 2021, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! DoH server frontend.

use crate::client::{ClientMap, ConnectionID, DNS_HEADER_SIZE, MAX_UDP_PAYLOAD_SIZE};
use crate::config::{Config, QUICHE_IDLE_TIMEOUT_MS};
use crate::stats::Stats;
use anyhow::{bail, ensure, Result};
use lazy_static::lazy_static;
use log::{debug, error, warn};
use std::fs::File;
use std::io::Write;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;

lazy_static! {
    static ref RUNTIME_STATIC: Arc<Runtime> = Arc::new(
        Builder::new_multi_thread()
            .worker_threads(2)
            .max_blocking_threads(1)
            .enable_all()
            .thread_name("DohFrontend")
            .build()
            .expect("Failed to create tokio runtime")
    );
}

/// Command used by worker_thread itself.
#[derive(Debug)]
enum InternalCommand {
    MaybeWrite { connection_id: ConnectionID },
}

/// Commands that DohFrontend sends to its worker_thread.
#[derive(Debug)]
enum ControlCommand {
    Stats { resp: oneshot::Sender<Stats> },
    StatsClearQueries,
    CloseConnection,
}

/// Frontend object.
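///
/// A sketch of typical usage (the addresses and the `cert_pem`/`key_pem` strings below are
/// illustrative placeholders, not values defined by this module):
///
/// ```ignore
/// let mut frontend = DohFrontend::new(
///     "127.0.0.1:443".parse()?, // address the frontend listens on for DoH queries
///     "127.0.0.1:53".parse()?,  // address of the plain DNS backend
/// )?;
/// frontend.set_certificate(&cert_pem)?; // PEM-encoded certificate chain
/// frontend.set_private_key(&key_pem)?;  // PEM-encoded private key
/// frontend.start()?;
/// // ... exercise the frontend ...
/// frontend.stop()?;
/// ```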
#[derive(Debug)]
pub struct DohFrontend {
    // Socket address the frontend listens on.
    listen_socket_addr: std::net::SocketAddr,

    // Socket address the backend listens on.
    backend_socket_addr: std::net::SocketAddr,

    /// The content of the certificate.
    certificate: String,

    /// The content of the private key.
    private_key: String,

    // The thread listening to the frontend socket and the backend socket
    // and processing the messages.
    worker_thread: Option<JoinHandle<Result<()>>>,

    // Custom runtime configuration to control the behavior of the worker thread.
    // It's shared with the worker thread.
    // TODO: use a channel to update the worker_thread configuration.
    config: Arc<Mutex<Config>>,

    // Caches the latest stats so that the stats remain available after worker_thread stops.
    latest_stats: Stats,

    // Wrapped in an Option because the channel is not created at DohFrontend construction time.
    command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
}

/// The parameters passed to the worker thread.
struct WorkerParams {
    frontend_socket: std::net::UdpSocket,
    backend_socket: std::net::UdpSocket,
    clients: ClientMap,
    config: Arc<Mutex<Config>>,
    command_rx: mpsc::UnboundedReceiver<ControlCommand>,
}

impl DohFrontend {
    pub fn new(
        listen: std::net::SocketAddr,
        backend: std::net::SocketAddr,
    ) -> Result<Box<DohFrontend>> {
        let doh = Box::new(DohFrontend {
            listen_socket_addr: listen,
            backend_socket_addr: backend,
            certificate: String::new(),
            private_key: String::new(),
            worker_thread: None,
            config: Arc::new(Mutex::new(Config::new())),
            latest_stats: Stats::new(),
            command_tx: None,
        });
        debug!("DohFrontend created: {:?}", doh);
        Ok(doh)
    }

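    /// Starts the worker thread. The certificate and the private key must already be set,
    /// and no worker thread may be running.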
    pub fn start(&mut self) -> Result<()> {
        ensure!(self.worker_thread.is_none(), "Worker thread is already running");
        ensure!(!self.certificate.is_empty(), "certificate is empty");
        ensure!(!self.private_key.is_empty(), "private_key is empty");

        // Doing error handling here is much simpler.
        let params = match self.init_worker_thread_params() {
            Ok(v) => v,
            Err(e) => return Err(e.context("init_worker_thread_params failed")),
        };

        self.worker_thread = Some(RUNTIME_STATIC.spawn(worker_thread(params)));
        Ok(())
    }

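    /// Stops the worker thread: caches the final stats, asks all clients to close their
    /// connections, waits briefly for them to drain, and then aborts the task.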
    pub fn stop(&mut self) -> Result<()> {
        debug!("DohFrontend: stopping: {:?}", self);
        if let Some(worker_thread) = self.worker_thread.take() {
            // Update latest_stats before stopping worker_thread.
            let _ = self.request_stats();

            self.command_tx.as_ref().unwrap().send(ControlCommand::CloseConnection)?;
            if let Err(e) = self.wait_for_connections_closed() {
                warn!("wait_for_connections_closed failed: {}", e);
            }

            worker_thread.abort();
        }

        debug!("DohFrontend: stopped: {:?}", self);
        Ok(())
    }

    pub fn set_certificate(&mut self, certificate: &str) -> Result<()> {
        self.certificate = certificate.to_string();
        Ok(())
    }

    pub fn set_private_key(&mut self, private_key: &str) -> Result<()> {
        self.private_key = private_key.to_string();
        Ok(())
    }

    pub fn set_delay_queries(&self, value: i32) -> Result<()> {
        self.config.lock().unwrap().delay_queries = value;
        Ok(())
    }

    pub fn set_max_idle_timeout(&self, value: u64) -> Result<()> {
        self.config.lock().unwrap().max_idle_timeout = value;
        Ok(())
    }

    pub fn set_max_buffer_size(&self, value: u64) -> Result<()> {
        self.config.lock().unwrap().max_buffer_size = value;
        Ok(())
    }

    pub fn set_max_streams_bidi(&self, value: u64) -> Result<()> {
        self.config.lock().unwrap().max_streams_bidi = value;
        Ok(())
    }

    pub fn block_sending(&self, value: bool) -> Result<()> {
        self.config.lock().unwrap().block_sending = value;
        Ok(())
    }

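    /// Fetches the current stats from the worker thread. If the worker thread has already
    /// exited, the most recently cached stats are returned instead.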
    pub fn request_stats(&mut self) -> Result<Stats> {
        ensure!(
            self.command_tx.is_some(),
            "command_tx is None because the worker thread is not yet initialized"
        );
        let command_tx = self.command_tx.as_ref().unwrap();

        if command_tx.is_closed() {
            return Ok(self.latest_stats.clone());
        }

        let (resp_tx, resp_rx) = oneshot::channel();
        command_tx.send(ControlCommand::Stats { resp: resp_tx })?;

        match RUNTIME_STATIC
            .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
        {
            Ok(v) => match v {
                Ok(stats) => {
                    self.latest_stats = stats.clone();
                    Ok(stats)
                }
                Err(e) => bail!(e),
            },
            Err(e) => bail!(e),
        }
    }

    pub fn stats_clear_queries(&self) -> Result<()> {
        ensure!(
            self.command_tx.is_some(),
            "command_tx is None because the worker thread is not yet initialized"
        );
        return self
            .command_tx
            .as_ref()
            .unwrap()
            .send(ControlCommand::StatsClearQueries)
            .or_else(|e| bail!(e));
    }

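    /// Builds everything the worker thread needs: a connected, non-blocking socket to the
    /// backend, the non-blocking frontend listening socket, the client map with its quiche
    /// configuration, and the control-command channel.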
    fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
        let bind_addr =
            if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
        let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
        backend_socket.connect(self.backend_socket_addr)?;
        backend_socket.set_nonblocking(true)?;

        let frontend_socket = bind_udp_socket_retry(self.listen_socket_addr)?;
        frontend_socket.set_nonblocking(true)?;

        let clients = ClientMap::new(create_quiche_config(
            self.certificate.to_string(),
            self.private_key.to_string(),
            self.config.clone(),
        )?)?;

        let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
        self.command_tx = Some(command_tx);

        Ok(WorkerParams {
            frontend_socket,
            backend_socket,
            clients,
            config: self.config.clone(),
            command_rx,
        })
    }

    fn wait_for_connections_closed(&mut self) -> Result<()> {
        for _ in 0..3 {
            std::thread::sleep(Duration::from_millis(50));
            match self.request_stats() {
                Ok(stats) if stats.alive_connections == 0 => return Ok(()),
                Ok(_) => (),

                // The worker thread is down. No connection is alive.
                Err(_) => return Ok(()),
            }
        }
        bail!("Some connections still alive")
    }
}

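/// The worker thread's event loop. On every iteration it waits for whichever happens first:
/// the earliest quiche connection timeout, a packet from a DoH client on the frontend socket,
/// a DNS answer from the backend socket, an internal "maybe write" event (unless sending is
/// blocked by the test config), or a ControlCommand from DohFrontend.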
async fn worker_thread(params: WorkerParams) -> Result<()> {
    let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
    let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
    let config = params.config;
    let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
    let mut command_rx = params.command_rx;
    let mut clients = params.clients;
    let mut frontend_buf = [0; 65535];
    let mut backend_buf = [0; 16384];
    let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
    let mut queries_received = 0;

    debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);

    loop {
        let timeout = clients
            .iter_mut()
            .filter_map(|(_, c)| c.timeout())
            .min()
            .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));

        tokio::select! {
            _ = tokio::time::sleep(timeout) => {
                debug!("timeout");
                for (_, client) in clients.iter_mut() {
                    // If no timeout has occurred, this does nothing.
                    client.on_timeout();

                    let connection_id = client.connection_id().clone();
                    event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
                }
            }

            Ok((len, src)) = frontend_socket.recv_from(&mut frontend_buf) => {
                debug!("Got {} bytes from {}", len, src);

                // Parse QUIC packet.
                let pkt_buf = &mut frontend_buf[..len];
                let hdr = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) {
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to parse QUIC header: {:?}", e);
                        continue;
                    }
                };
                debug!("Got QUIC packet: {:?}", hdr);

                let client = match clients.get_or_create(&hdr, &src) {
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
                        continue;
                    }
                };
                debug!("Got client: {:?}", client);

                match client.handle_frontend_message(pkt_buf) {
                    Ok(v) if !v.is_empty() => {
                        delay_queries_buffer.push(v);
                        queries_received += 1;
                    }
                    Err(e) => {
                        error!("Failed to process QUIC packet: {}", e);
                        continue;
                    }
                    _ => {}
                }

                if delay_queries_buffer.len() >= config.lock().unwrap().delay_queries as usize {
                    for query in delay_queries_buffer.drain(..) {
                        debug!("sending {} bytes to backend", query.len());
                        backend_socket.send(&query).await?;
                    }
                }

                let connection_id = client.connection_id().clone();
                event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
            }

            Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
                debug!("Got {} bytes from {}", len, src);
                if len < DNS_HEADER_SIZE {
                    error!("Received insufficient bytes for DNS header");
                    continue;
                }

                let query_id = [backend_buf[0], backend_buf[1]];
                for (_, client) in clients.iter_mut() {
                    if client.is_waiting_for_query(&query_id) {
                        if let Err(e) = client.handle_backend_message(&backend_buf[..len]) {
                            error!("Failed to handle message from backend: {}", e);
                        }
                        let connection_id = client.connection_id().clone();
                        event_tx.send(InternalCommand::MaybeWrite{connection_id})?;

                        // It's a bug if more than one client is waiting for this query.
                        break;
                    }
                }
            }

            Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
                match command {
                    InternalCommand::MaybeWrite {connection_id} => {
                        if let Some(client) = clients.get_mut(&connection_id) {
                            while let Ok(v) = client.flush_egress() {
                                let addr = client.addr();
                                debug!("Sending {} bytes to client {}", v.len(), addr);
                                if let Err(e) = frontend_socket.send_to(&v, addr).await {
                                    error!("Failed to send packet to {:?}: {:?}", client, e);
                                }
                            }
                            client.process_pending_answers()?;
                        }
                    }
                }
            }
            Some(command) = command_rx.recv() => {
                debug!("ControlCommand: {:?}", command);
                match command {
                    ControlCommand::Stats {resp} => {
                        let stats = Stats {
                            queries_received,
                            connections_accepted: clients.len() as u32,
                            alive_connections: clients.iter().filter(|(_, client)| client.is_alive()).count() as u32,
                            resumed_connections: clients.iter().filter(|(_, client)| client.is_resumed()).count() as u32,
                        };
                        if let Err(e) = resp.send(stats) {
                            error!("Failed to send ControlCommand::Stats response: {:?}", e);
                        }
                    }
                    ControlCommand::StatsClearQueries => queries_received = 0,
                    ControlCommand::CloseConnection => {
                        for (_, client) in clients.iter_mut() {
                            client.close();
                            event_tx.send(InternalCommand::MaybeWrite { connection_id: client.connection_id().clone() })?;
                        }
                    }
                }
            }
        }
    }
}

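/// Builds the quiche configuration from the given certificate, private key, and test config.
/// The certificate and the key are handed to quiche as file paths, so their PEM contents are
/// streamed through pipes and exposed as /proc/self/fd/<fd> paths.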
fn create_quiche_config(
    certificate: String,
    private_key: String,
    config: Arc<Mutex<Config>>,
) -> Result<quiche::Config> {
    let mut quiche_config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;

    // Use a pipe as a file path for Quiche to read the certificate and the private key.
    let (rd, mut wr) = build_pipe()?;
    let handle = std::thread::spawn(move || {
        wr.write_all(certificate.as_bytes()).expect("Failed to write to pipe");
    });
    let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
    quiche_config.load_cert_chain_from_pem_file(&filepath)?;
    handle.join().unwrap();

    let (rd, mut wr) = build_pipe()?;
    let handle = std::thread::spawn(move || {
        wr.write_all(private_key.as_bytes()).expect("Failed to write to pipe");
    });
    let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
    quiche_config.load_priv_key_from_pem_file(&filepath)?;
    handle.join().unwrap();

    quiche_config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?;
    quiche_config.set_max_idle_timeout(config.lock().unwrap().max_idle_timeout);
    quiche_config.set_max_recv_udp_payload_size(MAX_UDP_PAYLOAD_SIZE);

    let max_buffer_size = config.lock().unwrap().max_buffer_size;
    quiche_config.set_initial_max_data(max_buffer_size);
    quiche_config.set_initial_max_stream_data_bidi_local(max_buffer_size);
    quiche_config.set_initial_max_stream_data_bidi_remote(max_buffer_size);
    quiche_config.set_initial_max_stream_data_uni(max_buffer_size);

    quiche_config.set_initial_max_streams_bidi(config.lock().unwrap().max_streams_bidi);
    quiche_config.set_initial_max_streams_uni(100);
    quiche_config.set_disable_active_migration(true);

    Ok(quiche_config)
}

fn into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket> {
    match UdpSocket::from_std(socket) {
        Ok(v) => Ok(v),
        Err(e) => {
            error!("into_tokio_udp_socket failed: {}", e);
            bail!("into_tokio_udp_socket failed: {}", e)
        }
    }
}

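/// Creates an anonymous pipe and returns it as a (read end, write end) pair of Files.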
fn build_pipe() -> Result<(File, File)> {
    let mut fds = [0, 0];
    unsafe {
        if libc::pipe(fds.as_mut_ptr()) == 0 {
            return Ok((File::from_raw_fd(fds[0]), File::from_raw_fd(fds[1])));
        }
    }
    Err(anyhow::Error::new(std::io::Error::last_os_error()).context("build_pipe failed"))
}

// Retries a few times to bind the socket address if it is in use.
fn bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket> {
    for _ in 0..3 {
        match std::net::UdpSocket::bind(addr) {
            Ok(socket) => return Ok(socket),
            Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
                warn!("Socket address {} is in use. Trying again", addr);
                std::thread::sleep(Duration::from_millis(50));
            }
            Err(e) => return Err(anyhow::anyhow!(e)),
        }
    }
    Err(anyhow::anyhow!(std::io::Error::last_os_error()))
}