• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! C API for the DoH backend for the Android DnsResolver module.
18 
19 use crate::boot_time::{timeout, BootTime, Duration};
20 use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
21 use crate::network::{SocketTagger, ValidationReporter};
22 use futures::FutureExt;
23 use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
24 use log::{error, warn};
25 use std::ffi::CString;
26 use std::net::{IpAddr, SocketAddr};
27 use std::ops::DerefMut;
28 use std::os::unix::io::RawFd;
29 use std::str::FromStr;
30 use std::sync::{Arc, Mutex};
31 use std::{ptr, slice};
32 use tokio::runtime::Builder;
33 use tokio::sync::oneshot;
34 use tokio::task;
35 use url::Url;
36 
/// C callback invoked (via `wrap_validation_callback`) with the outcome of validating a DoH
/// server: the network id, whether validation succeeded, and NUL-terminated strings holding the
/// server's IP address and its domain (empty when no domain was configured). The string pointers
/// are only valid for the duration of the call.
pub type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
/// C callback that receives (via `wrap_tag_socket_callback`) the raw file descriptor of a UDP
/// socket so the C side can tag it.
pub type TagSocketCallback = extern "C" fn(sock: RawFd);
40 
/// Tunable options passed from C to `doh_net_new()`.
///
/// The struct is `#[repr(C)]` and received by reference across the FFI boundary, so field order
/// and types must stay in sync with the C-side definition.
#[repr(C)]
pub struct FeatureFlags {
    /// Timeout, in milliseconds, for the probe command issued by `doh_net_new()`.
    probe_timeout_ms: uint64_t,
    /// Copied into `ServerInfo::idle_timeout_ms`; presumably the connection idle timeout in
    /// milliseconds (TODO confirm against the dispatcher).
    idle_timeout_ms: uint64_t,
    /// Copied into `ServerInfo::use_session_resumption`; presumably enables TLS session
    /// resumption (TODO confirm against the dispatcher).
    use_session_resumption: bool,
}
47 
wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter48 fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
49     Arc::new(move |info: &ServerInfo, success: bool| {
50         async move {
51             let (ip_addr, domain) = match (
52                 CString::new(info.peer_addr.ip().to_string()),
53                 CString::new(info.domain.clone().unwrap_or_default()),
54             ) {
55                 (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
56                 _ => {
57                     error!("validation_callback bad input");
58                     return;
59                 }
60             };
61             let netd_id = info.net_id;
62             task::spawn_blocking(move || {
63                 validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
64             })
65             .await
66             .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
67         }
68         .boxed()
69     })
70 }
71 
wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger72 fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
73     use std::os::unix::io::AsRawFd;
74     Arc::new(move |udp_socket: &std::net::UdpSocket| {
75         let fd = udp_socket.as_raw_fd();
76         async move {
77             task::spawn_blocking(move || {
78                 tag_socket_fn(fd);
79             })
80             .await
81             .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
82         }
83         .boxed()
84     })
85 }
86 
87 pub struct DohDispatcher(Mutex<Dispatcher>);
88 
89 impl DohDispatcher {
lock(&self) -> impl DerefMut<Target = Dispatcher> + '_90     fn lock(&self) -> impl DerefMut<Target = Dispatcher> + '_ {
91         self.0.lock().unwrap()
92     }
93 }
94 
/// CA certificate directory used by `doh_net_new()` when a domain is given but `cert_path` is
/// an empty string.
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";

/// Returned by `doh_query` when no answer can be delivered, e.g. because the answer does not
/// fit in the caller's response buffer.
pub const DOH_RESULT_INTERNAL_ERROR: ssize_t = -1;
/// Returned by `doh_query` when the query cannot be dispatched or no successful response is
/// received.
pub const DOH_RESULT_CAN_NOT_SEND: ssize_t = -2;
/// Returned by `doh_query` when the query times out.
pub const DOH_RESULT_TIMEOUT: ssize_t = -255;

/// Log level value for `doh_init_logger` / `doh_set_log_level`: errors only.
pub const DOH_LOG_LEVEL_ERROR: u32 = 0;
/// Log level value: warnings and above.
pub const DOH_LOG_LEVEL_WARN: u32 = 1;
/// Log level value: info and above.
pub const DOH_LOG_LEVEL_INFO: u32 = 2;
/// Log level value: debug and above.
pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
/// Log level value: everything, including trace.
pub const DOH_LOG_LEVEL_TRACE: u32 = 4;

/// Port combined with the server IP address to form every DoH peer address (standard HTTPS port).
const DOH_PORT: u16 = 443;
116 
level_from_u32(level: u32) -> Option<log::Level>117 fn level_from_u32(level: u32) -> Option<log::Level> {
118     use log::Level::*;
119     match level {
120         DOH_LOG_LEVEL_ERROR => Some(Error),
121         DOH_LOG_LEVEL_WARN => Some(Warn),
122         DOH_LOG_LEVEL_INFO => Some(Info),
123         DOH_LOG_LEVEL_DEBUG => Some(Debug),
124         DOH_LOG_LEVEL_TRACE => Some(Trace),
125         _ => None,
126     }
127 }
128 
129 /// Performs static initialization for android logger.
130 /// If an invalid level is passed, defaults to logging errors only.
131 /// If called more than once, it will have no effect on subsequent calls.
132 #[no_mangle]
doh_init_logger(level: u32)133 pub extern "C" fn doh_init_logger(level: u32) {
134     let log_level = level_from_u32(level).unwrap_or(log::Level::Error);
135     android_logger::init_once(android_logger::Config::default().with_min_level(log_level));
136 }
137 
138 /// Set the log level.
139 /// If an invalid level is passed, defaults to logging errors only.
140 #[no_mangle]
doh_set_log_level(level: u32)141 pub extern "C" fn doh_set_log_level(level: u32) {
142     let level_filter = level_from_u32(level)
143         .map(|level| level.to_level_filter())
144         .unwrap_or(log::LevelFilter::Error);
145     log::set_max_level(level_filter);
146 }
147 
148 /// Performs the initialization for the DoH engine.
149 /// Creates and returns a DoH engine instance.
150 #[no_mangle]
doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher151 pub extern "C" fn doh_dispatcher_new(
152     validation_fn: ValidationCallback,
153     tag_socket_fn: TagSocketCallback,
154 ) -> *mut DohDispatcher {
155     match Dispatcher::new(
156         wrap_validation_callback(validation_fn),
157         wrap_tag_socket_callback(tag_socket_fn),
158     ) {
159         Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
160         Err(e) => {
161             error!("doh_dispatcher_new: failed: {:?}", e);
162             ptr::null_mut()
163         }
164     }
165 }
166 
167 /// Deletes a DoH engine created by doh_dispatcher_new().
168 /// # Safety
169 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
170 /// and not yet deleted by `doh_dispatcher_delete()`.
171 #[no_mangle]
doh_dispatcher_delete(doh: *mut DohDispatcher)172 pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
173     Box::from_raw(doh).lock().exit_handler()
174 }
175 
176 /// Probes and stores the DoH server with the given configurations.
177 /// Use the negative errno-style codes as the return value to represent the result.
178 /// # Safety
179 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
180 /// and not yet deleted by `doh_dispatcher_delete()`.
181 /// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
182 #[no_mangle]
doh_net_new( doh: &DohDispatcher, net_id: uint32_t, url: *const c_char, domain: *const c_char, ip_addr: *const c_char, sk_mark: libc::uint32_t, cert_path: *const c_char, flags: &FeatureFlags, ) -> int32_t183 pub unsafe extern "C" fn doh_net_new(
184     doh: &DohDispatcher,
185     net_id: uint32_t,
186     url: *const c_char,
187     domain: *const c_char,
188     ip_addr: *const c_char,
189     sk_mark: libc::uint32_t,
190     cert_path: *const c_char,
191     flags: &FeatureFlags,
192 ) -> int32_t {
193     let (url, domain, ip_addr, cert_path) = match (
194         std::ffi::CStr::from_ptr(url).to_str(),
195         std::ffi::CStr::from_ptr(domain).to_str(),
196         std::ffi::CStr::from_ptr(ip_addr).to_str(),
197         std::ffi::CStr::from_ptr(cert_path).to_str(),
198     ) {
199         (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
200             if domain.is_empty() {
201                 (url, None, ip_addr.to_string(), None)
202             } else if !cert_path.is_empty() {
203                 (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
204             } else {
205                 (
206                     url,
207                     Some(domain.to_string()),
208                     ip_addr.to_string(),
209                     Some(SYSTEM_CERT_PATH.to_string()),
210                 )
211             }
212         }
213         _ => {
214             error!("bad input"); // Should not happen
215             return -libc::EINVAL;
216         }
217     };
218 
219     let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
220         (Ok(url), Ok(ip_addr)) => (url, ip_addr),
221         _ => {
222             error!("bad ip or url"); // Should not happen
223             return -libc::EINVAL;
224         }
225     };
226     let cmd = Command::Probe {
227         info: ServerInfo {
228             net_id,
229             url,
230             peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
231             domain,
232             sk_mark,
233             cert_path,
234             idle_timeout_ms: flags.idle_timeout_ms,
235             use_session_resumption: flags.use_session_resumption,
236         },
237         timeout: Duration::from_millis(flags.probe_timeout_ms),
238     };
239     if let Err(e) = doh.lock().send_cmd(cmd) {
240         error!("Failed to send the probe: {:?}", e);
241         return -libc::EPIPE;
242     }
243     0
244 }
245 
246 /// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
247 /// The return code should be either one of the public constant DOH_RESULT_* to indicate the error
248 /// or the size of the answer.
249 /// # Safety
250 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
251 /// and not yet deleted by `doh_dispatcher_delete()`.
252 /// `dns_query` must point to a buffer at least `dns_query_len` in size.
253 /// `response` must point to a buffer at least `response_len` in size.
254 #[no_mangle]
doh_query( doh: &DohDispatcher, net_id: uint32_t, dns_query: *mut u8, dns_query_len: size_t, response: *mut u8, response_len: size_t, timeout_ms: uint64_t, ) -> ssize_t255 pub unsafe extern "C" fn doh_query(
256     doh: &DohDispatcher,
257     net_id: uint32_t,
258     dns_query: *mut u8,
259     dns_query_len: size_t,
260     response: *mut u8,
261     response_len: size_t,
262     timeout_ms: uint64_t,
263 ) -> ssize_t {
264     let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
265 
266     let (resp_tx, resp_rx) = oneshot::channel();
267     let t = Duration::from_millis(timeout_ms);
268     if let Some(expired_time) = BootTime::now().checked_add(t) {
269         let cmd = Command::Query {
270             net_id,
271             base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),
272             expired_time,
273             resp: resp_tx,
274         };
275 
276         if let Err(e) = doh.lock().send_cmd(cmd) {
277             error!("Failed to send the query: {:?}", e);
278             return DOH_RESULT_CAN_NOT_SEND;
279         }
280     } else {
281         error!("Bad timeout parameter: {}", timeout_ms);
282         return DOH_RESULT_CAN_NOT_SEND;
283     }
284 
285     if let Ok(rt) = Builder::new_current_thread().enable_all().build() {
286         let local = task::LocalSet::new();
287         match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
288             Ok(v) => match v {
289                 Ok(v) => match v {
290                     Response::Success { answer } => {
291                         if answer.len() > response_len || answer.len() > isize::MAX as usize {
292                             return DOH_RESULT_INTERNAL_ERROR;
293                         }
294                         let response = slice::from_raw_parts_mut(response, answer.len());
295                         response.copy_from_slice(&answer);
296                         answer.len() as ssize_t
297                     }
298                     rsp => {
299                         error!("Non-successful response: {:?}", rsp);
300                         DOH_RESULT_CAN_NOT_SEND
301                     }
302                 },
303                 Err(e) => {
304                     error!("no result {}", e);
305                     DOH_RESULT_CAN_NOT_SEND
306                 }
307             },
308             Err(e) => {
309                 error!("timeout: {}", e);
310                 DOH_RESULT_TIMEOUT
311             }
312         }
313     } else {
314         DOH_RESULT_CAN_NOT_SEND
315     }
316 }
317 
318 /// Clears the DoH servers associated with the given |netid|.
319 /// # Safety
320 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
321 /// and not yet deleted by `doh_dispatcher_delete()`.
322 #[no_mangle]
doh_net_delete(doh: &DohDispatcher, net_id: uint32_t)323 pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
324     if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
325         error!("Failed to send the query: {:?}", e);
326     }
327 }
328 
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    /// Validation callback that expects a successful validation of the test server.
    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    /// Validation callback that expects a failed validation of the test server.
    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    /// Checks that the callback arguments describe the test server (no domain configured).
    ///
    /// # Safety
    /// `ip_addr`, `host` are null terminated strings.
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(ip_addr, expected_addr.ip().to_string());
        let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(host, "");
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
            idle_timeout_ms: 0,
            use_session_resumption: true,
        };

        wrap_validation_callback(success_cb)(&info, true).await;
        wrap_validation_callback(fail_cb)(&info, false).await;
    }

    /// Socket-tagging callback that only checks it received a valid descriptor.
    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        // 0 is a legitimate descriptor number; only negative values mean "no fd".
        assert!(raw_fd >= 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        wrap_tag_socket_callback(tag_socket_cb)(&sock).await;
    }
}
403