• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! C API for the DoH backend for the Android DnsResolver module.
18 
19 use crate::boot_time::{timeout, BootTime, Duration};
20 use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
21 use crate::network::{SocketTagger, ValidationReporter};
22 use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
23 use futures::FutureExt;
24 use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
25 use log::{error, warn};
26 use std::ffi::CString;
27 use std::net::{IpAddr, SocketAddr};
28 use std::ops::DerefMut;
29 use std::os::unix::io::RawFd;
30 use std::str::FromStr;
31 use std::sync::{Arc, Mutex};
32 use std::{ptr, slice};
33 use tokio::runtime::Builder;
34 use tokio::sync::oneshot;
35 use tokio::task;
36 use url::Url;
37 
/// C callback that receives the outcome of a DoH server validation.
///
/// `ip_addr` and `host` are NUL-terminated strings; `host` is the empty string when the
/// server has no domain. Both buffers are owned by the Rust side of the call, so the
/// pointers should not be kept after the callback returns.
pub type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
/// C callback that is handed the raw file descriptor of a UDP socket
/// (see `wrap_tag_socket_callback`).
pub type TagSocketCallback = extern "C" fn(sock: RawFd);
41 
/// Per-network tuning knobs supplied by the C caller of `doh_net_new()`.
///
/// The struct is `#[repr(C)]`, so the field order and types below are part of the C ABI.
#[repr(C)]
pub struct FeatureFlags {
    /// Timeout, in milliseconds, attached to the probe command issued by `doh_net_new()`.
    probe_timeout_ms: uint64_t,
    /// Copied verbatim into `ServerInfo::idle_timeout_ms`.
    idle_timeout_ms: uint64_t,
    /// Copied verbatim into `ServerInfo::use_session_resumption`.
    use_session_resumption: bool,
    /// Copied verbatim into `ServerInfo::enable_early_data`.
    enable_early_data: bool,
}
49 
wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter50 fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
51     Arc::new(move |info: &ServerInfo, success: bool| {
52         async move {
53             let (ip_addr, domain) = match (
54                 CString::new(info.peer_addr.ip().to_string()),
55                 CString::new(info.domain.clone().unwrap_or_default()),
56             ) {
57                 (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
58                 _ => {
59                     error!("validation_callback bad input");
60                     return;
61                 }
62             };
63             let netd_id = info.net_id;
64             task::spawn_blocking(move || {
65                 validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
66             })
67             .await
68             .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
69         }
70         .boxed()
71     })
72 }
73 
wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger74 fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
75     use std::os::unix::io::AsRawFd;
76     Arc::new(move |udp_socket: &std::net::UdpSocket| {
77         let fd = udp_socket.as_raw_fd();
78         async move {
79             task::spawn_blocking(move || {
80                 tag_socket_fn(fd);
81             })
82             .await
83             .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
84         }
85         .boxed()
86     })
87 }
88 
89 pub struct DohDispatcher(Mutex<Dispatcher>);
90 
91 impl DohDispatcher {
lock(&self) -> impl DerefMut<Target = Dispatcher> + '_92     fn lock(&self) -> impl DerefMut<Target = Dispatcher> + '_ {
93         self.0.lock().unwrap()
94     }
95 }
96 
/// Certificate directory used by `doh_net_new()` when a domain is supplied but the
/// caller's `cert_path` is empty.
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";

/// Return code of `doh_query()`: an answer was received but cannot be stored in the
/// caller's response buffer.
pub const DOH_RESULT_INTERNAL_ERROR: ssize_t = -1;
/// Return code of `doh_query()`: the query could not be sent, or no successful response
/// was obtained for a reason other than a timeout.
pub const DOH_RESULT_CAN_NOT_SEND: ssize_t = -2;
/// Return code of `doh_query()`: the query timed out.
pub const DOH_RESULT_TIMEOUT: ssize_t = -255;

/// The error log level, as accepted by `doh_init_logger()` and `doh_set_log_level()`.
pub const DOH_LOG_LEVEL_ERROR: u32 = 0;
/// The warning log level.
pub const DOH_LOG_LEVEL_WARN: u32 = 1;
/// The info log level.
pub const DOH_LOG_LEVEL_INFO: u32 = 2;
/// The debug log level.
pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
/// The trace log level.
pub const DOH_LOG_LEVEL_TRACE: u32 = 4;

/// Port paired with the server IP address to form the peer address in `doh_net_new()`.
const DOH_PORT: u16 = 443;
118 
level_from_u32(level: u32) -> Option<log::Level>119 fn level_from_u32(level: u32) -> Option<log::Level> {
120     use log::Level::*;
121     match level {
122         DOH_LOG_LEVEL_ERROR => Some(Error),
123         DOH_LOG_LEVEL_WARN => Some(Warn),
124         DOH_LOG_LEVEL_INFO => Some(Info),
125         DOH_LOG_LEVEL_DEBUG => Some(Debug),
126         DOH_LOG_LEVEL_TRACE => Some(Trace),
127         _ => None,
128     }
129 }
130 
131 /// Performs static initialization for android logger.
132 /// If an invalid level is passed, defaults to logging errors only.
133 /// If called more than once, it will have no effect on subsequent calls.
134 #[no_mangle]
doh_init_logger(level: u32)135 pub extern "C" fn doh_init_logger(level: u32) {
136     let log_level = level_from_u32(level).unwrap_or(log::Level::Error);
137     android_logger::init_once(android_logger::Config::default().with_min_level(log_level));
138 }
139 
140 /// Set the log level.
141 /// If an invalid level is passed, defaults to logging errors only.
142 #[no_mangle]
doh_set_log_level(level: u32)143 pub extern "C" fn doh_set_log_level(level: u32) {
144     let level_filter = level_from_u32(level)
145         .map(|level| level.to_level_filter())
146         .unwrap_or(log::LevelFilter::Error);
147     log::set_max_level(level_filter);
148 }
149 
150 /// Performs the initialization for the DoH engine.
151 /// Creates and returns a DoH engine instance.
152 #[no_mangle]
doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher153 pub extern "C" fn doh_dispatcher_new(
154     validation_fn: ValidationCallback,
155     tag_socket_fn: TagSocketCallback,
156 ) -> *mut DohDispatcher {
157     match Dispatcher::new(
158         wrap_validation_callback(validation_fn),
159         wrap_tag_socket_callback(tag_socket_fn),
160     ) {
161         Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
162         Err(e) => {
163             error!("doh_dispatcher_new: failed: {:?}", e);
164             ptr::null_mut()
165         }
166     }
167 }
168 
169 /// Deletes a DoH engine created by doh_dispatcher_new().
170 /// # Safety
171 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
172 /// and not yet deleted by `doh_dispatcher_delete()`.
173 #[no_mangle]
doh_dispatcher_delete(doh: *mut DohDispatcher)174 pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
175     Box::from_raw(doh).lock().exit_handler()
176 }
177 
178 /// Probes and stores the DoH server with the given configurations.
179 /// Use the negative errno-style codes as the return value to represent the result.
180 /// # Safety
181 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
182 /// and not yet deleted by `doh_dispatcher_delete()`.
183 /// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
184 #[no_mangle]
doh_net_new( doh: &DohDispatcher, net_id: uint32_t, url: *const c_char, domain: *const c_char, ip_addr: *const c_char, sk_mark: libc::uint32_t, cert_path: *const c_char, flags: &FeatureFlags, ) -> int32_t185 pub unsafe extern "C" fn doh_net_new(
186     doh: &DohDispatcher,
187     net_id: uint32_t,
188     url: *const c_char,
189     domain: *const c_char,
190     ip_addr: *const c_char,
191     sk_mark: libc::uint32_t,
192     cert_path: *const c_char,
193     flags: &FeatureFlags,
194 ) -> int32_t {
195     let (url, domain, ip_addr, cert_path) = match (
196         std::ffi::CStr::from_ptr(url).to_str(),
197         std::ffi::CStr::from_ptr(domain).to_str(),
198         std::ffi::CStr::from_ptr(ip_addr).to_str(),
199         std::ffi::CStr::from_ptr(cert_path).to_str(),
200     ) {
201         (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
202             if domain.is_empty() {
203                 (url, None, ip_addr.to_string(), None)
204             } else if !cert_path.is_empty() {
205                 (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
206             } else {
207                 (
208                     url,
209                     Some(domain.to_string()),
210                     ip_addr.to_string(),
211                     Some(SYSTEM_CERT_PATH.to_string()),
212                 )
213             }
214         }
215         _ => {
216             error!("bad input"); // Should not happen
217             return -libc::EINVAL;
218         }
219     };
220 
221     let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
222         (Ok(url), Ok(ip_addr)) => (url, ip_addr),
223         _ => {
224             error!("bad ip or url"); // Should not happen
225             return -libc::EINVAL;
226         }
227     };
228     let cmd = Command::Probe {
229         info: ServerInfo {
230             net_id,
231             url,
232             peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
233             domain,
234             sk_mark,
235             cert_path,
236             idle_timeout_ms: flags.idle_timeout_ms,
237             use_session_resumption: flags.use_session_resumption,
238             enable_early_data: flags.enable_early_data,
239         },
240         timeout: Duration::from_millis(flags.probe_timeout_ms),
241     };
242     if let Err(e) = doh.lock().send_cmd(cmd) {
243         error!("Failed to send the probe: {:?}", e);
244         return -libc::EPIPE;
245     }
246     0
247 }
248 
249 /// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
250 /// The return code should be either one of the public constant DOH_RESULT_* to indicate the error
251 /// or the size of the answer.
252 /// # Safety
253 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
254 /// and not yet deleted by `doh_dispatcher_delete()`.
255 /// `dns_query` must point to a buffer at least `dns_query_len` in size.
256 /// `response` must point to a buffer at least `response_len` in size.
257 #[no_mangle]
doh_query( doh: &DohDispatcher, net_id: uint32_t, dns_query: *mut u8, dns_query_len: size_t, response: *mut u8, response_len: size_t, timeout_ms: uint64_t, ) -> ssize_t258 pub unsafe extern "C" fn doh_query(
259     doh: &DohDispatcher,
260     net_id: uint32_t,
261     dns_query: *mut u8,
262     dns_query_len: size_t,
263     response: *mut u8,
264     response_len: size_t,
265     timeout_ms: uint64_t,
266 ) -> ssize_t {
267     let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
268 
269     let (resp_tx, resp_rx) = oneshot::channel();
270     let t = Duration::from_millis(timeout_ms);
271     if let Some(expired_time) = BootTime::now().checked_add(t) {
272         let cmd = Command::Query {
273             net_id,
274             base64_query: BASE64_URL_SAFE_NO_PAD.encode(q),
275             expired_time,
276             resp: resp_tx,
277         };
278 
279         if let Err(e) = doh.lock().send_cmd(cmd) {
280             error!("Failed to send the query: {:?}", e);
281             return DOH_RESULT_CAN_NOT_SEND;
282         }
283     } else {
284         error!("Bad timeout parameter: {}", timeout_ms);
285         return DOH_RESULT_CAN_NOT_SEND;
286     }
287 
288     if let Ok(rt) = Builder::new_current_thread().enable_all().build() {
289         let local = task::LocalSet::new();
290         match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
291             Ok(v) => match v {
292                 Ok(v) => match v {
293                     Response::Success { answer } => {
294                         if answer.len() > response_len || answer.len() > isize::MAX as usize {
295                             return DOH_RESULT_INTERNAL_ERROR;
296                         }
297                         let response = slice::from_raw_parts_mut(response, answer.len());
298                         response.copy_from_slice(&answer);
299                         answer.len() as ssize_t
300                     }
301                     rsp => {
302                         error!("Non-successful response: {:?}", rsp);
303                         DOH_RESULT_CAN_NOT_SEND
304                     }
305                 },
306                 Err(e) => {
307                     error!("no result {}", e);
308                     DOH_RESULT_CAN_NOT_SEND
309                 }
310             },
311             Err(e) => {
312                 error!("timeout: {}", e);
313                 DOH_RESULT_TIMEOUT
314             }
315         }
316     } else {
317         DOH_RESULT_CAN_NOT_SEND
318     }
319 }
320 
321 /// Clears the DoH servers associated with the given |netid|.
322 /// # Safety
323 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
324 /// and not yet deleted by `doh_dispatcher_delete()`.
325 #[no_mangle]
doh_net_delete(doh: &DohDispatcher, net_id: uint32_t)326 pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
327     if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
328         error!("Failed to send the query: {:?}", e);
329     }
330 }
331 
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    // NOTE(review): the callbacks below run inside `spawn_blocking`, and the wrappers only
    // log a failed join, so a failed assertion in a callback may not fail the test —
    // confirm against the toolchain's panic behaviour for `extern "C"` functions.

    /// Validation callback that expects a successful report for the test server.
    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        unsafe { assert_validation_info(net_id, ip_addr, host) }
    }

    /// Validation callback that expects a failed report for the test server.
    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        unsafe { assert_validation_info(net_id, ip_addr, host) }
    }

    // Checks that the reported network id, IP address and host match the test server.
    //
    // # Safety
    // `ip_addr`, `host` are null terminated strings
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);

        let reported_ip = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_ip = LOOPBACK_ADDR.parse::<SocketAddr>().unwrap().ip().to_string();
        assert_eq!(reported_ip, expected_ip);

        let reported_host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(reported_host, "");
    }

    /// Builds the server description shared by the validation test: loopback peer, no domain.
    fn loopback_server_info() -> ServerInfo {
        ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
            idle_timeout_ms: 0,
            use_session_resumption: true,
            enable_early_data: true,
        }
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = loopback_server_info();

        let report_success = wrap_validation_callback(success_cb);
        report_success(&info, true).await;

        let report_failure = wrap_validation_callback(fail_cb);
        report_failure(&info, false).await;
    }

    /// Socket-tagging callback that expects a positive file descriptor.
    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        assert!(raw_fd > 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        let tag_socket = wrap_tag_socket_callback(tag_socket_cb);
        tag_socket(&sock).await;
    }
}
407