1 /* 2 * Copyright (C) 2021 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 use crate::boot_time::{BootTime, Duration}; 18 use anyhow::Result; 19 use log::error; 20 use tokio::runtime::{Builder, Runtime}; 21 use tokio::sync::{mpsc, oneshot}; 22 use tokio::task; 23 24 pub use crate::network::{ServerInfo, SocketTagger, ValidationReporter}; 25 26 const MAX_BUFFERED_CMD_COUNT: usize = 400; 27 28 mod driver; 29 use driver::Driver; 30 31 #[derive(Eq, PartialEq, Debug)] 32 /// Error response to a query 33 pub enum QueryError { 34 /// Network failed probing 35 BrokenServer, 36 /// HTTP/3 connection died 37 ConnectionError, 38 /// Network not probed yet 39 ServerNotReady, 40 /// Server reset HTTP/3 stream 41 Reset(u64), 42 /// Tried to query non-existent network 43 Unexpected, 44 } 45 46 #[derive(Eq, PartialEq, Debug)] 47 pub enum Response { 48 Error { error: QueryError }, 49 Success { answer: Vec<u8> }, 50 } 51 52 #[derive(Debug)] 53 pub enum Command { 54 Probe { 55 info: ServerInfo, 56 timeout: Duration, 57 }, 58 Query { 59 net_id: u32, 60 base64_query: String, 61 expired_time: BootTime, 62 resp: oneshot::Sender<Response>, 63 }, 64 Clear { 65 net_id: u32, 66 }, 67 Exit, 68 } 69 70 /// Context for a running DoH engine. 71 pub struct Dispatcher { 72 /// Used to submit cmds to the I/O task. 73 cmd_sender: mpsc::Sender<Command>, 74 join_handle: task::JoinHandle<Result<()>>, 75 runtime: Runtime, 76 } 77 78 impl Dispatcher { 79 const DOH_THREADS: usize = 1; 80 new(validation: ValidationReporter, tagger: SocketTagger) -> Result<Dispatcher>81 pub fn new(validation: ValidationReporter, tagger: SocketTagger) -> Result<Dispatcher> { 82 let (cmd_sender, cmd_receiver) = mpsc::channel::<Command>(MAX_BUFFERED_CMD_COUNT); 83 let runtime = Builder::new_multi_thread() 84 .worker_threads(Self::DOH_THREADS) 85 .enable_all() 86 .thread_name("doh-handler") 87 .build()?; 88 let join_handle = runtime.spawn(async { 89 let result = Driver::new(cmd_receiver, validation, tagger).drive().await; 90 if let Err(ref e) = result { error!("Dispatcher driver exited due to {:?}", e) } 91 result 92 }); 93 Ok(Dispatcher { cmd_sender, join_handle, runtime }) 94 } 95 send_cmd(&self, cmd: Command) -> Result<()>96 pub fn send_cmd(&self, cmd: Command) -> Result<()> { 97 self.cmd_sender.blocking_send(cmd)?; 98 Ok(()) 99 } 100 exit_handler(&mut self)101 pub fn exit_handler(&mut self) { 102 if self.cmd_sender.blocking_send(Command::Exit).is_err() { 103 return; 104 } 105 let _ = self.runtime.block_on(&mut self.join_handle); 106 } 107 } 108