1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 //! `Resolver` trait and `DefaultDnsResolver` implementation. 15 16 use std::collections::HashMap; 17 use std::future::Future; 18 use std::io; 19 use std::io::Error; 20 use std::net::{SocketAddr, ToSocketAddrs}; 21 use std::pin::Pin; 22 use std::sync::{Arc, Mutex}; 23 use std::task::{Context, Poll}; 24 use std::time::{Duration, Instant}; 25 use std::vec::IntoIter; 26 27 use crate::runtime::JoinHandle; 28 29 const DEFAULT_TTL: Duration = Duration::from_secs(60); 30 const MAX_ENTRIES_LEN: usize = 30000; 31 32 /// `SocketAddr` resolved by `Resolver`. 33 pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Sync + Send>; 34 /// Possible errors that this resolver may generate when attempting to 35 /// resolve. 36 pub type StdError = Box<dyn std::error::Error + Send + Sync>; 37 /// Futures generated by this resolve when attempting to resolve an address. 38 pub type SocketFuture<'a> = 39 Pin<Box<dyn Future<Output = Result<Addrs, StdError>> + Sync + Send + 'a>>; 40 41 /// `Resolver` trait used by `async_impl::connector::HttpConnector`. `Resolver` 42 /// provides asynchronous dns resolve interfaces. 43 pub trait Resolver: Send + Sync + 'static { 44 /// resolve authority to a `SocketAddr` `Future`. resolve(&self, authority: &str) -> SocketFuture45 fn resolve(&self, authority: &str) -> SocketFuture; 46 } 47 48 /// `SocketAddr` resolved by `DefaultDnsResolver`. 49 pub struct ResolvedAddrs { 50 iter: IntoIter<SocketAddr>, 51 } 52 53 impl ResolvedAddrs { new(iter: IntoIter<SocketAddr>) -> Self54 pub(super) fn new(iter: IntoIter<SocketAddr>) -> Self { 55 Self { iter } 56 } 57 58 // The first ip in the dns record is the preferred addrs type. split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>)59 pub(super) fn split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>) { 60 // get preferred address family type. 61 let is_ipv6 = self 62 .iter 63 .as_slice() 64 .first() 65 .map(SocketAddr::is_ipv6) 66 .unwrap_or(false); 67 self.iter 68 .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == is_ipv6) 69 } 70 } 71 72 impl Iterator for ResolvedAddrs { 73 type Item = SocketAddr; 74 next(&mut self) -> Option<Self::Item>75 fn next(&mut self) -> Option<Self::Item> { 76 self.iter.next() 77 } 78 } 79 80 /// Futures generated by `DefaultDnsResolver`. 81 pub struct DefaultDnsFuture { 82 inner: JoinHandle<Result<ResolvedAddrs, Error>>, 83 } 84 85 impl Future for DefaultDnsFuture { 86 type Output = Result<Addrs, StdError>; 87 poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>88 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 89 Pin::new(&mut self.inner).poll(cx).map(|res| match res { 90 Ok(Ok(addrs)) => Ok(Box::new(addrs) as Addrs), 91 Ok(Err(err)) => Err(Box::new(err) as StdError), 92 Err(err) => Err(Box::new(Error::new(io::ErrorKind::Interrupted, err)) as StdError), 93 }) 94 } 95 } 96 97 /// Default dns resolver used by the `Client`. 98 /// DefaultDnsResolver provides DNS resolver with caching machanism. 99 pub struct DefaultDnsResolver { 100 manager: DnsManager, // Manages DNS cache 101 connector: DnsConnector, // Performing DNS resolution 102 ttl: Duration, // Time-to-live for the DNS cache 103 } 104 105 impl Default for DefaultDnsResolver { 106 // Default constructor for `DefaultDnsResolver`, with a default TTL of 60 107 // seconds. default() -> Self108 fn default() -> Self { 109 DefaultDnsResolver { 110 manager: DnsManager::default(), 111 connector: DnsConnector {}, 112 ttl: DEFAULT_TTL, // Default TTL set to 60 seconds 113 } 114 } 115 } 116 117 impl DefaultDnsResolver { 118 /// Create a new DefaultDnsResolver. And TTL is Time to live for cache. 119 /// 120 /// # Examples 121 /// 122 /// ``` 123 /// use std::time::Duration; 124 /// 125 /// use ylong_http_client::async_impl::DefaultDnsResolver; 126 /// 127 /// let res = DefaultDnsResolver::new(Duration::from_secs(1)); 128 /// ``` new(ttl: Duration) -> Self129 pub fn new(ttl: Duration) -> Self { 130 DefaultDnsResolver { 131 manager: DnsManager::new(), 132 connector: DnsConnector {}, 133 ttl, // Set TTL through the passed parameters 134 } 135 } 136 } 137 138 #[derive(Default)] 139 struct DnsManager { 140 // Cache storing authority and DNS results 141 map: Mutex<HashMap<String, DnsResult>>, 142 } 143 144 impl DnsManager { 145 // Creates a new `DnsManager` instance with an empty cache new() -> Self146 fn new() -> Self { 147 DnsManager { 148 map: Mutex::new(HashMap::new()), 149 } 150 } 151 152 // Cleans expired DNS cache entries by retaining only valid ones clean_expired_entries(&self)153 fn clean_expired_entries(&self) { 154 let mut map_lock = self.map.lock().unwrap(); 155 if map_lock.len() > MAX_ENTRIES_LEN { 156 map_lock.retain(|_, result| result.inner.lock().unwrap().is_valid()); 157 } 158 } 159 } 160 161 struct DnsResult { 162 inner: Arc<Mutex<DnsResultInner>>, 163 } 164 165 impl DnsResult { 166 // Creates a new DNS result with the given addresses and expiration time new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self167 fn new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self { 168 DnsResult { 169 inner: Arc::new(Mutex::new(DnsResultInner { 170 addr, 171 expiration_time, 172 })), 173 } 174 } 175 } 176 177 #[derive(Clone)] 178 struct DnsResultInner { 179 addr: Vec<SocketAddr>, // List of resolved addresses for the authority 180 expiration_time: Instant, // Expiration time for the cache entry 181 } 182 183 impl DnsResultInner { 184 // Checks if the DNS result is still valid is_valid(&self) -> bool185 fn is_valid(&self) -> bool { 186 self.expiration_time > Instant::now() 187 } 188 } 189 190 impl Default for DnsResultInner { 191 // Default constructor for `DnsResultInner`, with an empty address list and 60 192 // seconds expiration default() -> Self193 fn default() -> Self { 194 DnsResultInner { 195 addr: vec![], 196 expiration_time: Instant::now() + Duration::from_secs(60), 197 } 198 } 199 } 200 201 struct DnsConnector {} 202 203 impl DnsConnector { 204 // Resolves the authority to a list of socket addresses get_socket_addrs(&self, authority: &str) -> Result<Vec<SocketAddr>, io::Error>205 fn get_socket_addrs(&self, authority: &str) -> Result<Vec<SocketAddr>, io::Error> { 206 authority 207 .to_socket_addrs() 208 .map(|addrs| addrs.collect()) 209 .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) 210 } 211 } 212 213 impl Resolver for DefaultDnsResolver { resolve(&self, authority: &str) -> SocketFuture214 fn resolve(&self, authority: &str) -> SocketFuture { 215 let authority = authority.to_string(); 216 self.manager.clean_expired_entries(); 217 Box::pin(async move { 218 let mut map_lock = self.manager.map.lock().unwrap(); 219 if let Some(addrs) = map_lock.get(&authority) { 220 let lock_inner = addrs.inner.lock().unwrap(); 221 if lock_inner.is_valid() { 222 return Ok(Box::new(lock_inner.addr.clone().into_iter()) as Addrs); 223 } 224 } 225 match self.connector.get_socket_addrs(&authority) { 226 Ok(addrs) => { 227 let dns_result = DnsResult::new(addrs.clone(), Instant::now() + self.ttl); 228 map_lock.insert(authority, dns_result); 229 Ok(Box::new(addrs.into_iter()) as Addrs) 230 } 231 Err(err) => Err(Box::new(err) as StdError), 232 } 233 }) 234 } 235 } 236 237 #[cfg(feature = "tokio_base")] 238 #[cfg(test)] 239 mod ut_dns_cache { 240 use super::*; 241 242 /// UT test cases for `DefaultDnsResolver::resolve`. 243 /// 244 /// # Brief 245 /// 1. Verify the first DNS result is cached when connected to Internet or 246 /// return error when without Internet. 247 /// 2. Verify the second DNS result as same as the first one. 248 #[tokio::test] ut_default_dns_resolver()249 async fn ut_default_dns_resolver() { 250 let domain = "example.com:0"; 251 let resolver = DefaultDnsResolver::new(std::time::Duration::from_millis(100)); 252 let result1 = resolver.resolve(domain).await; 253 let result2 = resolver.resolve(domain).await; 254 let result1 = result1 255 .map(|a| a.collect::<Vec<_>>()) 256 .err() 257 .map(|e| e.to_string()); 258 let result2 = result2 259 .map(|a| a.collect::<Vec<_>>()) 260 .err() 261 .map(|e| e.to_string()); 262 assert_eq!(result1, result2); 263 } 264 } 265 266 #[cfg(feature = "ylong_base")] 267 #[cfg(test)] 268 mod ut_dns_cache { 269 use super::*; 270 271 /// UT test cases for `DefaultDnsResolver::resolve`. 272 /// 273 /// # Brief 274 /// 1. Verify the first DNS result is cached when connected to Internet or 275 /// return error when without Internet. 276 /// 2. Verify the second DNS result as same as the first one. 277 #[test] ut_default_dns_resolver()278 fn ut_default_dns_resolver() { 279 ylong_runtime::block_on(ut_default_dns_resolver_async()); 280 } 281 ut_default_dns_resolver_async()282 async fn ut_default_dns_resolver_async() { 283 let domain = "example.com:0"; 284 let resolver = DefaultDnsResolver::new(std::time::Duration::from_millis(100)); 285 let result1 = resolver.resolve(domain).await; 286 let result2 = resolver.resolve(domain).await; 287 let result1 = result1 288 .map(|a| a.collect::<Vec<_>>()) 289 .err() 290 .map(|e| e.to_string()); 291 let result2 = result2 292 .map(|a| a.collect::<Vec<_>>()) 293 .err() 294 .map(|e| e.to_string()); 295 assert_eq!(result1, result2); 296 } 297 } 298