1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 //! C API for the DoH backend for the Android DnsResolver module.
18
19 use crate::boot_time::{timeout, BootTime, Duration};
20 use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
21 use crate::network::{SocketTagger, ValidationReporter};
22 use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
23 use futures::FutureExt;
24 use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
25 use log::{error, warn};
26 use std::ffi::CString;
27 use std::net::{IpAddr, SocketAddr};
28 use std::ops::DerefMut;
29 use std::os::unix::io::RawFd;
30 use std::str::FromStr;
31 use std::sync::{Arc, Mutex};
32 use std::{ptr, slice};
33 use tokio::runtime::Builder;
34 use tokio::sync::oneshot;
35 use tokio::task;
36 use url::Url;
37
38 pub type ValidationCallback =
39 extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
40 pub type TagSocketCallback = extern "C" fn(sock: RawFd);
41
42 #[repr(C)]
43 pub struct FeatureFlags {
44 probe_timeout_ms: uint64_t,
45 idle_timeout_ms: uint64_t,
46 use_session_resumption: bool,
47 enable_early_data: bool,
48 }
49
wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter50 fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
51 Arc::new(move |info: &ServerInfo, success: bool| {
52 async move {
53 let (ip_addr, domain) = match (
54 CString::new(info.peer_addr.ip().to_string()),
55 CString::new(info.domain.clone().unwrap_or_default()),
56 ) {
57 (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
58 _ => {
59 error!("validation_callback bad input");
60 return;
61 }
62 };
63 let netd_id = info.net_id;
64 task::spawn_blocking(move || {
65 validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
66 })
67 .await
68 .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
69 }
70 .boxed()
71 })
72 }
73
wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger74 fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
75 use std::os::unix::io::AsRawFd;
76 Arc::new(move |udp_socket: &std::net::UdpSocket| {
77 let fd = udp_socket.as_raw_fd();
78 async move {
79 task::spawn_blocking(move || {
80 tag_socket_fn(fd);
81 })
82 .await
83 .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
84 }
85 .boxed()
86 })
87 }
88
89 pub struct DohDispatcher(Mutex<Dispatcher>);
90
91 impl DohDispatcher {
lock(&self) -> impl DerefMut<Target = Dispatcher> + '_92 fn lock(&self) -> impl DerefMut<Target = Dispatcher> + '_ {
93 self.0.lock().unwrap()
94 }
95 }
96
97 const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
98
99 /// The return code of doh_query means that there is no answer.
100 pub const DOH_RESULT_INTERNAL_ERROR: ssize_t = -1;
101 /// The return code of doh_query means that query can't be sent.
102 pub const DOH_RESULT_CAN_NOT_SEND: ssize_t = -2;
103 /// The return code of doh_query to indicate that the query timed out.
104 pub const DOH_RESULT_TIMEOUT: ssize_t = -255;
105
106 /// The error log level.
107 pub const DOH_LOG_LEVEL_ERROR: u32 = 0;
108 /// The warning log level.
109 pub const DOH_LOG_LEVEL_WARN: u32 = 1;
110 /// The info log level.
111 pub const DOH_LOG_LEVEL_INFO: u32 = 2;
112 /// The debug log level.
113 pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
114 /// The trace log level.
115 pub const DOH_LOG_LEVEL_TRACE: u32 = 4;
116
117 const DOH_PORT: u16 = 443;
118
level_from_u32(level: u32) -> Option<log::Level>119 fn level_from_u32(level: u32) -> Option<log::Level> {
120 use log::Level::*;
121 match level {
122 DOH_LOG_LEVEL_ERROR => Some(Error),
123 DOH_LOG_LEVEL_WARN => Some(Warn),
124 DOH_LOG_LEVEL_INFO => Some(Info),
125 DOH_LOG_LEVEL_DEBUG => Some(Debug),
126 DOH_LOG_LEVEL_TRACE => Some(Trace),
127 _ => None,
128 }
129 }
130
131 /// Performs static initialization for android logger.
132 /// If an invalid level is passed, defaults to logging errors only.
133 /// If called more than once, it will have no effect on subsequent calls.
134 #[no_mangle]
doh_init_logger(level: u32)135 pub extern "C" fn doh_init_logger(level: u32) {
136 let log_level = level_from_u32(level).unwrap_or(log::Level::Error);
137 android_logger::init_once(android_logger::Config::default().with_min_level(log_level));
138 }
139
140 /// Set the log level.
141 /// If an invalid level is passed, defaults to logging errors only.
142 #[no_mangle]
doh_set_log_level(level: u32)143 pub extern "C" fn doh_set_log_level(level: u32) {
144 let level_filter = level_from_u32(level)
145 .map(|level| level.to_level_filter())
146 .unwrap_or(log::LevelFilter::Error);
147 log::set_max_level(level_filter);
148 }
149
150 /// Performs the initialization for the DoH engine.
151 /// Creates and returns a DoH engine instance.
152 #[no_mangle]
doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher153 pub extern "C" fn doh_dispatcher_new(
154 validation_fn: ValidationCallback,
155 tag_socket_fn: TagSocketCallback,
156 ) -> *mut DohDispatcher {
157 match Dispatcher::new(
158 wrap_validation_callback(validation_fn),
159 wrap_tag_socket_callback(tag_socket_fn),
160 ) {
161 Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
162 Err(e) => {
163 error!("doh_dispatcher_new: failed: {:?}", e);
164 ptr::null_mut()
165 }
166 }
167 }
168
169 /// Deletes a DoH engine created by doh_dispatcher_new().
170 /// # Safety
171 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
172 /// and not yet deleted by `doh_dispatcher_delete()`.
173 #[no_mangle]
doh_dispatcher_delete(doh: *mut DohDispatcher)174 pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
175 Box::from_raw(doh).lock().exit_handler()
176 }
177
178 /// Probes and stores the DoH server with the given configurations.
179 /// Use the negative errno-style codes as the return value to represent the result.
180 /// # Safety
181 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
182 /// and not yet deleted by `doh_dispatcher_delete()`.
183 /// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
184 #[no_mangle]
doh_net_new( doh: &DohDispatcher, net_id: uint32_t, url: *const c_char, domain: *const c_char, ip_addr: *const c_char, sk_mark: libc::uint32_t, cert_path: *const c_char, flags: &FeatureFlags, ) -> int32_t185 pub unsafe extern "C" fn doh_net_new(
186 doh: &DohDispatcher,
187 net_id: uint32_t,
188 url: *const c_char,
189 domain: *const c_char,
190 ip_addr: *const c_char,
191 sk_mark: libc::uint32_t,
192 cert_path: *const c_char,
193 flags: &FeatureFlags,
194 ) -> int32_t {
195 let (url, domain, ip_addr, cert_path) = match (
196 std::ffi::CStr::from_ptr(url).to_str(),
197 std::ffi::CStr::from_ptr(domain).to_str(),
198 std::ffi::CStr::from_ptr(ip_addr).to_str(),
199 std::ffi::CStr::from_ptr(cert_path).to_str(),
200 ) {
201 (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
202 if domain.is_empty() {
203 (url, None, ip_addr.to_string(), None)
204 } else if !cert_path.is_empty() {
205 (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
206 } else {
207 (
208 url,
209 Some(domain.to_string()),
210 ip_addr.to_string(),
211 Some(SYSTEM_CERT_PATH.to_string()),
212 )
213 }
214 }
215 _ => {
216 error!("bad input"); // Should not happen
217 return -libc::EINVAL;
218 }
219 };
220
221 let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
222 (Ok(url), Ok(ip_addr)) => (url, ip_addr),
223 _ => {
224 error!("bad ip or url"); // Should not happen
225 return -libc::EINVAL;
226 }
227 };
228 let cmd = Command::Probe {
229 info: ServerInfo {
230 net_id,
231 url,
232 peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
233 domain,
234 sk_mark,
235 cert_path,
236 idle_timeout_ms: flags.idle_timeout_ms,
237 use_session_resumption: flags.use_session_resumption,
238 enable_early_data: flags.enable_early_data,
239 },
240 timeout: Duration::from_millis(flags.probe_timeout_ms),
241 };
242 if let Err(e) = doh.lock().send_cmd(cmd) {
243 error!("Failed to send the probe: {:?}", e);
244 return -libc::EPIPE;
245 }
246 0
247 }
248
249 /// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
250 /// The return code should be either one of the public constant DOH_RESULT_* to indicate the error
251 /// or the size of the answer.
252 /// # Safety
253 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
254 /// and not yet deleted by `doh_dispatcher_delete()`.
255 /// `dns_query` must point to a buffer at least `dns_query_len` in size.
256 /// `response` must point to a buffer at least `response_len` in size.
257 #[no_mangle]
doh_query( doh: &DohDispatcher, net_id: uint32_t, dns_query: *mut u8, dns_query_len: size_t, response: *mut u8, response_len: size_t, timeout_ms: uint64_t, ) -> ssize_t258 pub unsafe extern "C" fn doh_query(
259 doh: &DohDispatcher,
260 net_id: uint32_t,
261 dns_query: *mut u8,
262 dns_query_len: size_t,
263 response: *mut u8,
264 response_len: size_t,
265 timeout_ms: uint64_t,
266 ) -> ssize_t {
267 let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
268
269 let (resp_tx, resp_rx) = oneshot::channel();
270 let t = Duration::from_millis(timeout_ms);
271 if let Some(expired_time) = BootTime::now().checked_add(t) {
272 let cmd = Command::Query {
273 net_id,
274 base64_query: BASE64_URL_SAFE_NO_PAD.encode(q),
275 expired_time,
276 resp: resp_tx,
277 };
278
279 if let Err(e) = doh.lock().send_cmd(cmd) {
280 error!("Failed to send the query: {:?}", e);
281 return DOH_RESULT_CAN_NOT_SEND;
282 }
283 } else {
284 error!("Bad timeout parameter: {}", timeout_ms);
285 return DOH_RESULT_CAN_NOT_SEND;
286 }
287
288 if let Ok(rt) = Builder::new_current_thread().enable_all().build() {
289 let local = task::LocalSet::new();
290 match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
291 Ok(v) => match v {
292 Ok(v) => match v {
293 Response::Success { answer } => {
294 if answer.len() > response_len || answer.len() > isize::MAX as usize {
295 return DOH_RESULT_INTERNAL_ERROR;
296 }
297 let response = slice::from_raw_parts_mut(response, answer.len());
298 response.copy_from_slice(&answer);
299 answer.len() as ssize_t
300 }
301 rsp => {
302 error!("Non-successful response: {:?}", rsp);
303 DOH_RESULT_CAN_NOT_SEND
304 }
305 },
306 Err(e) => {
307 error!("no result {}", e);
308 DOH_RESULT_CAN_NOT_SEND
309 }
310 },
311 Err(e) => {
312 error!("timeout: {}", e);
313 DOH_RESULT_TIMEOUT
314 }
315 }
316 } else {
317 DOH_RESULT_CAN_NOT_SEND
318 }
319 }
320
321 /// Clears the DoH servers associated with the given |netid|.
322 /// # Safety
323 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
324 /// and not yet deleted by `doh_dispatcher_delete()`.
325 #[no_mangle]
doh_net_delete(doh: &DohDispatcher, net_id: uint32_t)326 pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
327 if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
328 error!("Failed to send the query: {:?}", e);
329 }
330 }
331
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    /// Validation callback that expects the probe to have succeeded.
    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        // SAFETY: wrap_validation_callback passes pointers into CStrings it keeps alive for
        // the duration of this call.
        unsafe { assert_validation_info(net_id, ip_addr, host) }
    }

    /// Validation callback that expects the probe to have failed.
    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        // SAFETY: see success_cb.
        unsafe { assert_validation_info(net_id, ip_addr, host) }
    }

    /// Checks the arguments a validation callback receives for the test server.
    ///
    /// # Safety
    /// `ip_addr` and `host` must be valid null-terminated strings.
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let actual_ip = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let loopback: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(actual_ip, loopback.ip().to_string());
        let actual_host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(actual_host, "");
    }

    /// Builds the loopback server description shared by the validation tests.
    fn loopback_server_info() -> ServerInfo {
        ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
            idle_timeout_ms: 0,
            use_session_resumption: true,
            enable_early_data: true,
        }
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = loopback_server_info();
        wrap_validation_callback(success_cb)(&info, true).await;
        wrap_validation_callback(fail_cb)(&info, false).await;
    }

    /// Socket-tag callback that only checks it was handed a plausible descriptor.
    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        assert!(raw_fd > 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        let tagger = wrap_tag_socket_callback(tag_socket_cb);
        tagger(&socket).await;
    }
}
407