1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 //! C API for the DoH backend for the Android DnsResolver module.
18
19 use crate::boot_time::{timeout, BootTime, Duration};
20 use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
21 use crate::network::{SocketTagger, ValidationReporter};
22 use futures::FutureExt;
23 use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
24 use log::{error, warn};
25 use std::ffi::CString;
26 use std::net::{IpAddr, SocketAddr};
27 use std::ops::DerefMut;
28 use std::os::unix::io::RawFd;
29 use std::str::FromStr;
30 use std::sync::{Arc, Mutex};
31 use std::{ptr, slice};
32 use tokio::runtime::Builder;
33 use tokio::sync::oneshot;
34 use tokio::task;
35 use url::Url;
36
/// C callback invoked with the outcome of a DoH server validation.
/// `ip_addr` and `host` are NUL-terminated strings that are only valid for the duration of
/// the call; `host` is an empty string when the server has no domain.
pub type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
/// C callback invoked with the raw file descriptor of a UDP socket (see
/// `wrap_tag_socket_callback`).
pub type TagSocketCallback = extern "C" fn(sock: RawFd);
40
/// Per-server feature flags passed by pointer from C into `doh_net_new()`.
/// `#[repr(C)]` gives the struct a C-compatible layout, so field order and types are part of
/// the ABI and must not change.
#[repr(C)]
pub struct FeatureFlags {
    /// Timeout applied to the server probe, in milliseconds.
    probe_timeout_ms: uint64_t,
    /// Copied into `ServerInfo::idle_timeout_ms` (presumably a connection idle timeout — confirm).
    idle_timeout_ms: uint64_t,
    /// Copied into `ServerInfo::use_session_resumption`.
    use_session_resumption: bool,
}
47
wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter48 fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
49 Arc::new(move |info: &ServerInfo, success: bool| {
50 async move {
51 let (ip_addr, domain) = match (
52 CString::new(info.peer_addr.ip().to_string()),
53 CString::new(info.domain.clone().unwrap_or_default()),
54 ) {
55 (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
56 _ => {
57 error!("validation_callback bad input");
58 return;
59 }
60 };
61 let netd_id = info.net_id;
62 task::spawn_blocking(move || {
63 validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
64 })
65 .await
66 .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
67 }
68 .boxed()
69 })
70 }
71
wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger72 fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
73 use std::os::unix::io::AsRawFd;
74 Arc::new(move |udp_socket: &std::net::UdpSocket| {
75 let fd = udp_socket.as_raw_fd();
76 async move {
77 task::spawn_blocking(move || {
78 tag_socket_fn(fd);
79 })
80 .await
81 .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
82 }
83 .boxed()
84 })
85 }
86
87 pub struct DohDispatcher(Mutex<Dispatcher>);
88
89 impl DohDispatcher {
lock(&self) -> impl DerefMut<Target = Dispatcher> + '_90 fn lock(&self) -> impl DerefMut<Target = Dispatcher> + '_ {
91 self.0.lock().unwrap()
92 }
93 }
94
/// Certificate path used by `doh_net_new()` when a domain is given but `cert_path` is empty.
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";

/// Returned by `doh_query()` when an answer was received but cannot be delivered to the
/// caller because it does not fit in the response buffer.
pub const DOH_RESULT_INTERNAL_ERROR: ssize_t = -1;
/// Returned by `doh_query()` when the query cannot be sent or no usable answer comes back:
/// an out-of-range timeout, a failure sending the command to the dispatcher, a non-success
/// response, a closed reply channel, or a failure setting up the local runtime.
pub const DOH_RESULT_CAN_NOT_SEND: ssize_t = -2;
/// Returned by `doh_query()` when no reply arrives within the requested timeout.
pub const DOH_RESULT_TIMEOUT: ssize_t = -255;

/// The error log level (maps to `log::Level::Error`).
pub const DOH_LOG_LEVEL_ERROR: u32 = 0;
/// The warning log level (maps to `log::Level::Warn`).
pub const DOH_LOG_LEVEL_WARN: u32 = 1;
/// The info log level (maps to `log::Level::Info`).
pub const DOH_LOG_LEVEL_INFO: u32 = 2;
/// The debug log level (maps to `log::Level::Debug`).
pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
/// The trace log level (maps to `log::Level::Trace`).
pub const DOH_LOG_LEVEL_TRACE: u32 = 4;

/// Port used for every DoH server address built by `doh_net_new()`.
const DOH_PORT: u16 = 443;
116
level_from_u32(level: u32) -> Option<log::Level>117 fn level_from_u32(level: u32) -> Option<log::Level> {
118 use log::Level::*;
119 match level {
120 DOH_LOG_LEVEL_ERROR => Some(Error),
121 DOH_LOG_LEVEL_WARN => Some(Warn),
122 DOH_LOG_LEVEL_INFO => Some(Info),
123 DOH_LOG_LEVEL_DEBUG => Some(Debug),
124 DOH_LOG_LEVEL_TRACE => Some(Trace),
125 _ => None,
126 }
127 }
128
129 /// Performs static initialization for android logger.
130 /// If an invalid level is passed, defaults to logging errors only.
131 /// If called more than once, it will have no effect on subsequent calls.
132 #[no_mangle]
doh_init_logger(level: u32)133 pub extern "C" fn doh_init_logger(level: u32) {
134 let log_level = level_from_u32(level).unwrap_or(log::Level::Error);
135 android_logger::init_once(android_logger::Config::default().with_min_level(log_level));
136 }
137
138 /// Set the log level.
139 /// If an invalid level is passed, defaults to logging errors only.
140 #[no_mangle]
doh_set_log_level(level: u32)141 pub extern "C" fn doh_set_log_level(level: u32) {
142 let level_filter = level_from_u32(level)
143 .map(|level| level.to_level_filter())
144 .unwrap_or(log::LevelFilter::Error);
145 log::set_max_level(level_filter);
146 }
147
148 /// Performs the initialization for the DoH engine.
149 /// Creates and returns a DoH engine instance.
150 #[no_mangle]
doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher151 pub extern "C" fn doh_dispatcher_new(
152 validation_fn: ValidationCallback,
153 tag_socket_fn: TagSocketCallback,
154 ) -> *mut DohDispatcher {
155 match Dispatcher::new(
156 wrap_validation_callback(validation_fn),
157 wrap_tag_socket_callback(tag_socket_fn),
158 ) {
159 Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
160 Err(e) => {
161 error!("doh_dispatcher_new: failed: {:?}", e);
162 ptr::null_mut()
163 }
164 }
165 }
166
167 /// Deletes a DoH engine created by doh_dispatcher_new().
168 /// # Safety
169 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
170 /// and not yet deleted by `doh_dispatcher_delete()`.
171 #[no_mangle]
doh_dispatcher_delete(doh: *mut DohDispatcher)172 pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
173 Box::from_raw(doh).lock().exit_handler()
174 }
175
176 /// Probes and stores the DoH server with the given configurations.
177 /// Use the negative errno-style codes as the return value to represent the result.
178 /// # Safety
179 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
180 /// and not yet deleted by `doh_dispatcher_delete()`.
181 /// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
182 #[no_mangle]
doh_net_new( doh: &DohDispatcher, net_id: uint32_t, url: *const c_char, domain: *const c_char, ip_addr: *const c_char, sk_mark: libc::uint32_t, cert_path: *const c_char, flags: &FeatureFlags, ) -> int32_t183 pub unsafe extern "C" fn doh_net_new(
184 doh: &DohDispatcher,
185 net_id: uint32_t,
186 url: *const c_char,
187 domain: *const c_char,
188 ip_addr: *const c_char,
189 sk_mark: libc::uint32_t,
190 cert_path: *const c_char,
191 flags: &FeatureFlags,
192 ) -> int32_t {
193 let (url, domain, ip_addr, cert_path) = match (
194 std::ffi::CStr::from_ptr(url).to_str(),
195 std::ffi::CStr::from_ptr(domain).to_str(),
196 std::ffi::CStr::from_ptr(ip_addr).to_str(),
197 std::ffi::CStr::from_ptr(cert_path).to_str(),
198 ) {
199 (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
200 if domain.is_empty() {
201 (url, None, ip_addr.to_string(), None)
202 } else if !cert_path.is_empty() {
203 (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
204 } else {
205 (
206 url,
207 Some(domain.to_string()),
208 ip_addr.to_string(),
209 Some(SYSTEM_CERT_PATH.to_string()),
210 )
211 }
212 }
213 _ => {
214 error!("bad input"); // Should not happen
215 return -libc::EINVAL;
216 }
217 };
218
219 let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
220 (Ok(url), Ok(ip_addr)) => (url, ip_addr),
221 _ => {
222 error!("bad ip or url"); // Should not happen
223 return -libc::EINVAL;
224 }
225 };
226 let cmd = Command::Probe {
227 info: ServerInfo {
228 net_id,
229 url,
230 peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
231 domain,
232 sk_mark,
233 cert_path,
234 idle_timeout_ms: flags.idle_timeout_ms,
235 use_session_resumption: flags.use_session_resumption,
236 },
237 timeout: Duration::from_millis(flags.probe_timeout_ms),
238 };
239 if let Err(e) = doh.lock().send_cmd(cmd) {
240 error!("Failed to send the probe: {:?}", e);
241 return -libc::EPIPE;
242 }
243 0
244 }
245
246 /// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
247 /// The return code should be either one of the public constant DOH_RESULT_* to indicate the error
248 /// or the size of the answer.
249 /// # Safety
250 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
251 /// and not yet deleted by `doh_dispatcher_delete()`.
252 /// `dns_query` must point to a buffer at least `dns_query_len` in size.
253 /// `response` must point to a buffer at least `response_len` in size.
254 #[no_mangle]
doh_query( doh: &DohDispatcher, net_id: uint32_t, dns_query: *mut u8, dns_query_len: size_t, response: *mut u8, response_len: size_t, timeout_ms: uint64_t, ) -> ssize_t255 pub unsafe extern "C" fn doh_query(
256 doh: &DohDispatcher,
257 net_id: uint32_t,
258 dns_query: *mut u8,
259 dns_query_len: size_t,
260 response: *mut u8,
261 response_len: size_t,
262 timeout_ms: uint64_t,
263 ) -> ssize_t {
264 let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
265
266 let (resp_tx, resp_rx) = oneshot::channel();
267 let t = Duration::from_millis(timeout_ms);
268 if let Some(expired_time) = BootTime::now().checked_add(t) {
269 let cmd = Command::Query {
270 net_id,
271 base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),
272 expired_time,
273 resp: resp_tx,
274 };
275
276 if let Err(e) = doh.lock().send_cmd(cmd) {
277 error!("Failed to send the query: {:?}", e);
278 return DOH_RESULT_CAN_NOT_SEND;
279 }
280 } else {
281 error!("Bad timeout parameter: {}", timeout_ms);
282 return DOH_RESULT_CAN_NOT_SEND;
283 }
284
285 if let Ok(rt) = Builder::new_current_thread().enable_all().build() {
286 let local = task::LocalSet::new();
287 match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
288 Ok(v) => match v {
289 Ok(v) => match v {
290 Response::Success { answer } => {
291 if answer.len() > response_len || answer.len() > isize::MAX as usize {
292 return DOH_RESULT_INTERNAL_ERROR;
293 }
294 let response = slice::from_raw_parts_mut(response, answer.len());
295 response.copy_from_slice(&answer);
296 answer.len() as ssize_t
297 }
298 rsp => {
299 error!("Non-successful response: {:?}", rsp);
300 DOH_RESULT_CAN_NOT_SEND
301 }
302 },
303 Err(e) => {
304 error!("no result {}", e);
305 DOH_RESULT_CAN_NOT_SEND
306 }
307 },
308 Err(e) => {
309 error!("timeout: {}", e);
310 DOH_RESULT_TIMEOUT
311 }
312 }
313 } else {
314 DOH_RESULT_CAN_NOT_SEND
315 }
316 }
317
318 /// Clears the DoH servers associated with the given |netid|.
319 /// # Safety
320 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
321 /// and not yet deleted by `doh_dispatcher_delete()`.
322 #[no_mangle]
doh_net_delete(doh: &DohDispatcher, net_id: uint32_t)323 pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
324 if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
325 error!("Failed to send the query: {:?}", e);
326 }
327 }
328
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    /// Validation callback that expects a successful validation of the test server.
    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    /// Validation callback that expects a failed validation of the test server.
    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    /// Checks the arguments the validation wrapper forwards for the test `ServerInfo`.
    /// # Safety
    /// `ip_addr`, `host` are null terminated strings.
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(ip_addr, expected_addr.ip().to_string());
        let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(host, "");
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
            idle_timeout_ms: 0,
            use_session_resumption: true,
        };

        wrap_validation_callback(success_cb)(&info, true).await;
        wrap_validation_callback(fail_cb)(&info, false).await;
    }

    /// Socket-tag callback that checks it received a valid file descriptor.
    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        // 0 is a legal descriptor number; only negative values indicate an invalid fd.
        assert!(raw_fd >= 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        wrap_tag_socket_callback(tag_socket_cb)(&sock).await;
    }

    #[test]
    fn level_from_u32_maps_known_levels() {
        assert_eq!(level_from_u32(DOH_LOG_LEVEL_ERROR), Some(log::Level::Error));
        assert_eq!(level_from_u32(DOH_LOG_LEVEL_WARN), Some(log::Level::Warn));
        assert_eq!(level_from_u32(DOH_LOG_LEVEL_INFO), Some(log::Level::Info));
        assert_eq!(level_from_u32(DOH_LOG_LEVEL_DEBUG), Some(log::Level::Debug));
        assert_eq!(level_from_u32(DOH_LOG_LEVEL_TRACE), Some(log::Level::Trace));
        assert_eq!(level_from_u32(DOH_LOG_LEVEL_TRACE + 1), None);
        assert_eq!(level_from_u32(u32::MAX), None);
    }
}
403