1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 //! `Resolver` trait and `DefaultDnsResolver` implementation. 15 16 use std::collections::HashMap; 17 use std::future::Future; 18 use std::io; 19 use std::net::SocketAddr; 20 use std::pin::Pin; 21 use std::sync::{Arc, Mutex, OnceLock}; 22 use std::task::{Context, Poll}; 23 use std::time::{Duration, Instant}; 24 use std::vec::IntoIter; 25 26 use crate::runtime::JoinHandle; 27 28 const DEFAULT_MAX_LEN: usize = 30000; 29 const DEFAULT_TTL: Duration = Duration::from_secs(60); 30 31 /// `SocketAddr` resolved by `Resolver`. 32 pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Sync + Send>; 33 /// Possible errors that this resolver may generate when attempting to 34 /// resolve. 35 pub type StdError = Box<dyn std::error::Error + Send + Sync>; 36 /// Futures generated by this resolve when attempting to resolve an address. 37 pub type SocketFuture<'a> = 38 Pin<Box<dyn Future<Output = Result<Addrs, StdError>> + Sync + Send + 'a>>; 39 40 /// `Resolver` trait used by `async_impl::connector::HttpConnector`. `Resolver` 41 /// provides asynchronous dns resolve interfaces. 42 pub trait Resolver: Send + Sync + 'static { 43 /// resolve authority to a `SocketAddr` `Future`. resolve(&self, authority: &str) -> SocketFuture44 fn resolve(&self, authority: &str) -> SocketFuture; 45 } 46 47 /// `SocketAddr` resolved by `DefaultDnsResolver`. 
48 pub struct ResolvedAddrs { 49 iter: IntoIter<SocketAddr>, 50 } 51 52 impl ResolvedAddrs { new(iter: IntoIter<SocketAddr>) -> Self53 pub(super) fn new(iter: IntoIter<SocketAddr>) -> Self { 54 Self { iter } 55 } 56 57 // The first ip in the dns record is the preferred addrs type. split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>)58 pub(super) fn split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>) { 59 // get preferred address family type. 60 let is_ipv6 = self 61 .iter 62 .as_slice() 63 .first() 64 .map(SocketAddr::is_ipv6) 65 .unwrap_or(false); 66 self.iter 67 .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == is_ipv6) 68 } 69 } 70 71 impl Iterator for ResolvedAddrs { 72 type Item = SocketAddr; 73 next(&mut self) -> Option<Self::Item>74 fn next(&mut self) -> Option<Self::Item> { 75 self.iter.next() 76 } 77 } 78 79 /// Futures generated by `DefaultDnsResolver`. 80 pub struct DefaultDnsFuture { 81 inner: JoinHandle<Result<ResolvedAddrs, io::Error>>, 82 } 83 84 impl DefaultDnsFuture { new(handle: JoinHandle<Result<ResolvedAddrs, io::Error>>) -> Self85 pub(crate) fn new(handle: JoinHandle<Result<ResolvedAddrs, io::Error>>) -> Self { 86 DefaultDnsFuture { inner: handle } 87 } 88 } 89 90 impl Future for DefaultDnsFuture { 91 type Output = Result<Addrs, StdError>; 92 poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>93 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 94 Pin::new(&mut self.inner).poll(cx).map(|res| match res { 95 Ok(Ok(addrs)) => Ok(Box::new(addrs) as Addrs), 96 Ok(Err(err)) => Err(Box::new(err) as StdError), 97 Err(err) => Err(Box::new(io::Error::new(io::ErrorKind::Interrupted, err)) as StdError), 98 }) 99 } 100 } 101 102 pub(crate) struct DnsManager { 103 /// Cache storing authority and DNS results 104 pub(crate) map: Arc<Mutex<HashMap<String, DnsResult>>>, 105 max_entries_len: usize, 106 /// Time-to-live for the DNS cache 107 pub(crate) ttl: Duration, 108 } 109 110 impl 
Default for DnsManager { 111 // Default constructor for `DnsManager`, with a default TTL of 60 112 // seconds. default() -> Self113 fn default() -> Self { 114 DnsManager { 115 map: Default::default(), 116 max_entries_len: DEFAULT_MAX_LEN, 117 ttl: DEFAULT_TTL, // Default TTL set to 60 seconds 118 } 119 } 120 } 121 122 impl DnsManager { 123 /// Global DNS Manager global_dns_manager() -> Arc<Mutex<DnsManager>>124 pub(crate) fn global_dns_manager() -> Arc<Mutex<DnsManager>> { 125 static GLOBAL_DNS_MANAGER: OnceLock<Arc<Mutex<DnsManager>>> = OnceLock::new(); 126 GLOBAL_DNS_MANAGER 127 .get_or_init(|| Arc::new(Mutex::new(DnsManager::default()))) 128 .clone() 129 } 130 131 /// Cleans expired DNS cache entries by retaining only valid ones clean_expired_entries(&self)132 pub(crate) fn clean_expired_entries(&self) { 133 let mut map_lock = self.map.lock().unwrap(); 134 if map_lock.len() > self.max_entries_len { 135 map_lock.retain(|_, result| result.is_valid()); 136 } 137 } 138 } 139 140 #[derive(Clone)] 141 pub(crate) struct DnsResult { 142 /// List of resolved addresses for the authority 143 pub(crate) addr: Vec<SocketAddr>, 144 /// Expiration time for the cache entry 145 expiration_time: Instant, 146 } 147 148 impl DnsResult { 149 /// Creates a new DNS result with the given addresses and expiration time new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self150 pub(crate) fn new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self { 151 DnsResult { 152 addr, 153 expiration_time, 154 } 155 } 156 157 /// Checks if the DNS result is still valid is_valid(&self) -> bool158 pub(crate) fn is_valid(&self) -> bool { 159 self.expiration_time > Instant::now() 160 } 161 } 162 163 impl Default for DnsResult { 164 // Default constructor for `DnsResult`, with an empty address list and 60 165 // seconds expiration default() -> Self166 fn default() -> Self { 167 DnsResult { 168 addr: vec![], 169 expiration_time: Instant::now() + DEFAULT_TTL, 170 } 171 } 172 } 173 174 #[cfg(test)] 
mod ut_resover_test {
    use super::*;

    /// Returns an instant that is already expired for `DnsResult::is_valid`.
    ///
    /// Subtracting 60 seconds can fail when the monotonic clock started less
    /// than 60 seconds ago. In that case the fallback is `now` itself. That
    /// value is still expired, because validity requires the deadline to be
    /// strictly later than any subsequent `Instant::now()`.
    fn expired_instant() -> Instant {
        let now = Instant::now();
        now.checked_sub(Duration::from_secs(60)).unwrap_or(now)
    }

    /// UT test case for `DnsManager::default`
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` via `Default`.
    /// 2. Verifies the default `max_entries_len` is 30000 and `ttl` is 60
    ///    seconds, and that the cache starts empty.
    /// 3. Inserts an entry into the cache and verifies it can be found.
    #[test]
    fn ut_dns_manager_new() {
        let manager = DnsManager::default();
        assert_eq!(manager.max_entries_len, 30000);
        assert_eq!(manager.ttl, Duration::from_secs(60));
        let mut map = manager.map.lock().unwrap();
        assert!(map.is_empty());
        map.insert(
            "example.com".to_string(),
            DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 1], 1))], Instant::now()),
        );
        assert!(map.contains_key("example.com"));
    }

    /// UT test case for `DnsManager::clean_expired_entries`
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` instance and sets `max_entries_len` to 1.
    /// 2. Adds two DNS results to the cache: one valid and one expired.
    /// 3. Calls `clean_expired_entries` to remove expired entries.
    /// 4. Verifies only the expired entry is removed from the cache.
    #[test]
    fn ut_dns_manager_clean_cache() {
        let manager = DnsManager {
            max_entries_len: 1,
            ..Default::default()
        };
        let mut map = manager.map.lock().unwrap();
        map.insert(
            "example1.com".to_string(),
            DnsResult::new(
                vec![SocketAddr::from(([0, 0, 0, 1], 1))],
                Instant::now() + Duration::from_secs(60),
            ),
        );
        map.insert(
            "example2.com".to_string(),
            DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 2], 2))], expired_instant()),
        );
        drop(map);
        manager.clean_expired_entries();
        assert!(manager.map.lock().unwrap().contains_key("example1.com"));
        assert!(!manager.map.lock().unwrap().contains_key("example2.com"));
    }

    /// UT test case for `DnsManager::clean_expired_entries` below the threshold
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` with the default `max_entries_len`.
    /// 2. Adds a single expired DNS result.
    /// 3. Calls `clean_expired_entries`.
    /// 4. Verifies the expired entry is kept, because the cache size does not
    ///    exceed `max_entries_len`.
    #[test]
    fn ut_dns_manager_clean_cache_below_threshold() {
        let manager = DnsManager::default();
        manager.map.lock().unwrap().insert(
            "example.com".to_string(),
            DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 1], 1))], expired_instant()),
        );
        manager.clean_expired_entries();
        assert!(manager.map.lock().unwrap().contains_key("example.com"));
    }

    /// UT test case for `ResolvedAddrs`
    ///
    /// # Brief
    /// 1. Verifies iteration yields the addresses in their original order.
    /// 2. Verifies `split_preferred_addrs` groups addresses by the family of
    ///    the first address while preserving relative order.
    /// 3. Verifies an empty record splits into two empty vectors.
    #[test]
    fn ut_resolved_addrs_split_preferred() {
        let v6_a: SocketAddr = "[::1]:80".parse().unwrap();
        let v4: SocketAddr = "127.0.0.1:80".parse().unwrap();
        let v6_b: SocketAddr = "[::2]:443".parse().unwrap();

        let iterated: Vec<SocketAddr> = ResolvedAddrs::new(vec![v4, v6_a].into_iter()).collect();
        assert_eq!(iterated, vec![v4, v6_a]);

        let (preferred, others) =
            ResolvedAddrs::new(vec![v6_a, v4, v6_b].into_iter()).split_preferred_addrs();
        assert_eq!(preferred, vec![v6_a, v6_b]);
        assert_eq!(others, vec![v4]);

        let (preferred, others) =
            ResolvedAddrs::new(vec![v4, v6_a].into_iter()).split_preferred_addrs();
        assert_eq!(preferred, vec![v4]);
        assert_eq!(others, vec![v6_a]);

        let (preferred, others) = ResolvedAddrs::new(Vec::new().into_iter()).split_preferred_addrs();
        assert!(preferred.is_empty());
        assert!(others.is_empty());
    }

    /// UT test case for `DnsResult::is_valid` and `DnsResult::default`
    ///
    /// # Brief
    /// 1. Verifies a result expiring in the future is valid.
    /// 2. Verifies a result whose expiration time has passed is invalid.
    /// 3. Verifies the default result is valid and has no addresses.
    #[test]
    fn ut_dns_result_is_valid() {
        assert!(DnsResult::new(vec![], Instant::now() + Duration::from_secs(60)).is_valid());
        assert!(!DnsResult::new(vec![], expired_instant()).is_valid());
        let default_result = DnsResult::default();
        assert!(default_result.is_valid());
        assert!(default_result.addr.is_empty());
    }
}