• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 //! `Resolver` trait and `DefaultDnsResolver` implementation.
15 
16 use std::collections::HashMap;
17 use std::future::Future;
18 use std::io;
19 use std::io::Error;
20 use std::net::{SocketAddr, ToSocketAddrs};
21 use std::pin::Pin;
22 use std::sync::{Arc, Mutex};
23 use std::task::{Context, Poll};
24 use std::time::{Duration, Instant};
25 use std::vec::IntoIter;
26 
27 use crate::runtime::JoinHandle;
28 
/// Time-to-live applied to cached DNS answers by `DefaultDnsResolver::default`.
const DEFAULT_TTL: Duration = Duration::from_secs(60);
/// Number of cache entries above which expired entries are purged.
const MAX_ENTRIES_LEN: usize = 30_000;
31 
/// Boxed iterator over the `SocketAddr`s produced by a `Resolver`.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Send + Sync>;
/// Boxed error that a resolver may return when a resolution attempt
/// fails.
pub type StdError = Box<dyn std::error::Error + Send + Sync>;
/// Pinned, boxed future returned by a resolver; it completes with the
/// resolved addresses or with the resolution error.
pub type SocketFuture<'a> =
    Pin<Box<dyn Future<Output = Result<Addrs, StdError>> + Send + Sync + 'a>>;
40 
41 /// `Resolver` trait used by `async_impl::connector::HttpConnector`. `Resolver`
42 /// provides asynchronous dns resolve interfaces.
43 pub trait Resolver: Send + Sync + 'static {
44     /// resolve authority to a `SocketAddr` `Future`.
resolve(&self, authority: &str) -> SocketFuture45     fn resolve(&self, authority: &str) -> SocketFuture;
46 }
47 
/// Socket addresses resolved by `DefaultDnsResolver`, stored as an owning
/// vector iterator.
pub struct ResolvedAddrs {
    // Remaining, not yet yielded addresses.
    iter: IntoIter<SocketAddr>,
}
52 
53 impl ResolvedAddrs {
new(iter: IntoIter<SocketAddr>) -> Self54     pub(super) fn new(iter: IntoIter<SocketAddr>) -> Self {
55         Self { iter }
56     }
57 
58     // The first ip in the dns record is the preferred addrs type.
split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>)59     pub(super) fn split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>) {
60         // get preferred address family type.
61         let is_ipv6 = self
62             .iter
63             .as_slice()
64             .first()
65             .map(SocketAddr::is_ipv6)
66             .unwrap_or(false);
67         self.iter
68             .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == is_ipv6)
69     }
70 }
71 
72 impl Iterator for ResolvedAddrs {
73     type Item = SocketAddr;
74 
next(&mut self) -> Option<Self::Item>75     fn next(&mut self) -> Option<Self::Item> {
76         self.iter.next()
77     }
78 }
79 
80 /// Futures generated by `DefaultDnsResolver`.
81 pub struct DefaultDnsFuture {
82     inner: JoinHandle<Result<ResolvedAddrs, Error>>,
83 }
84 
85 impl Future for DefaultDnsFuture {
86     type Output = Result<Addrs, StdError>;
87 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>88     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
89         Pin::new(&mut self.inner).poll(cx).map(|res| match res {
90             Ok(Ok(addrs)) => Ok(Box::new(addrs) as Addrs),
91             Ok(Err(err)) => Err(Box::new(err) as StdError),
92             Err(err) => Err(Box::new(Error::new(io::ErrorKind::Interrupted, err)) as StdError),
93         })
94     }
95 }
96 
97 /// Default dns resolver used by the `Client`.
98 /// DefaultDnsResolver provides DNS resolver with caching machanism.
99 pub struct DefaultDnsResolver {
100     manager: DnsManager,     // Manages DNS cache
101     connector: DnsConnector, // Performing DNS resolution
102     ttl: Duration,           // Time-to-live for the DNS cache
103 }
104 
105 impl Default for DefaultDnsResolver {
106     // Default constructor for `DefaultDnsResolver`, with a default TTL of 60
107     // seconds.
default() -> Self108     fn default() -> Self {
109         DefaultDnsResolver {
110             manager: DnsManager::default(),
111             connector: DnsConnector {},
112             ttl: DEFAULT_TTL, // Default TTL set to 60 seconds
113         }
114     }
115 }
116 
117 impl DefaultDnsResolver {
118     /// Create a new DefaultDnsResolver. And TTL is Time to live for cache.
119     ///
120     /// # Examples
121     ///
122     /// ```
123     /// use std::time::Duration;
124     ///
125     /// use ylong_http_client::async_impl::DefaultDnsResolver;
126     ///
127     /// let res = DefaultDnsResolver::new(Duration::from_secs(1));
128     /// ```
new(ttl: Duration) -> Self129     pub fn new(ttl: Duration) -> Self {
130         DefaultDnsResolver {
131             manager: DnsManager::new(),
132             connector: DnsConnector {},
133             ttl, // Set TTL through the passed parameters
134         }
135     }
136 }
137 
138 #[derive(Default)]
139 struct DnsManager {
140     // Cache storing authority and DNS results
141     map: Mutex<HashMap<String, DnsResult>>,
142 }
143 
144 impl DnsManager {
145     // Creates a new `DnsManager` instance with an empty cache
new() -> Self146     fn new() -> Self {
147         DnsManager {
148             map: Mutex::new(HashMap::new()),
149         }
150     }
151 
152     // Cleans expired DNS cache entries by retaining only valid ones
clean_expired_entries(&self)153     fn clean_expired_entries(&self) {
154         let mut map_lock = self.map.lock().unwrap();
155         if map_lock.len() > MAX_ENTRIES_LEN {
156             map_lock.retain(|_, result| result.inner.lock().unwrap().is_valid());
157         }
158     }
159 }
160 
161 struct DnsResult {
162     inner: Arc<Mutex<DnsResultInner>>,
163 }
164 
165 impl DnsResult {
166     // Creates a new DNS result with the given addresses and expiration time
new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self167     fn new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self {
168         DnsResult {
169             inner: Arc::new(Mutex::new(DnsResultInner {
170                 addr,
171                 expiration_time,
172             })),
173         }
174     }
175 }
176 
/// Resolution data stored in a cache entry.
#[derive(Clone)]
struct DnsResultInner {
    /// Addresses resolved for the authority.
    addr: Vec<SocketAddr>,
    /// Instant after which the entry is no longer valid.
    expiration_time: Instant,
}

impl DnsResultInner {
    /// Returns `true` while the current instant is strictly before
    /// `expiration_time`.
    fn is_valid(&self) -> bool {
        Instant::now() < self.expiration_time
    }
}

impl Default for DnsResultInner {
    /// Builds an entry with no addresses that expires 60 seconds from now.
    fn default() -> Self {
        let expiration_time = Instant::now() + Duration::from_secs(60);
        Self {
            addr: Vec::new(),
            expiration_time,
        }
    }
}
200 
/// Stateless helper that performs the actual address resolution.
struct DnsConnector {}

impl DnsConnector {
    /// Resolves `authority` into a list of socket addresses via
    /// `ToSocketAddrs`.
    ///
    /// Any resolution failure is re-wrapped in an `io::Error` of kind
    /// `Other`.
    fn get_socket_addrs(&self, authority: &str) -> Result<Vec<SocketAddr>, io::Error> {
        match authority.to_socket_addrs() {
            Ok(resolved) => Ok(resolved.collect()),
            Err(cause) => Err(io::Error::new(io::ErrorKind::Other, cause)),
        }
    }
}
212 
213 impl Resolver for DefaultDnsResolver {
resolve(&self, authority: &str) -> SocketFuture214     fn resolve(&self, authority: &str) -> SocketFuture {
215         let authority = authority.to_string();
216         self.manager.clean_expired_entries();
217         Box::pin(async move {
218             let mut map_lock = self.manager.map.lock().unwrap();
219             if let Some(addrs) = map_lock.get(&authority) {
220                 let lock_inner = addrs.inner.lock().unwrap();
221                 if lock_inner.is_valid() {
222                     return Ok(Box::new(lock_inner.addr.clone().into_iter()) as Addrs);
223                 }
224             }
225             match self.connector.get_socket_addrs(&authority) {
226                 Ok(addrs) => {
227                     let dns_result = DnsResult::new(addrs.clone(), Instant::now() + self.ttl);
228                     map_lock.insert(authority, dns_result);
229                     Ok(Box::new(addrs.into_iter()) as Addrs)
230                 }
231                 Err(err) => Err(Box::new(err) as StdError),
232             }
233         })
234     }
235 }
236 
#[cfg(feature = "tokio_base")]
#[cfg(test)]
mod ut_dns_cache {
    use super::*;

    // Reduces a resolution outcome to its error message; successful
    // outcomes are drained and map to `None`.
    fn error_text(outcome: Result<Addrs, StdError>) -> Option<String> {
        match outcome {
            Ok(addrs) => {
                let _ = addrs.collect::<Vec<_>>();
                None
            }
            Err(err) => Some(err.to_string()),
        }
    }

    /// UT test cases for `DefaultDnsResolver::resolve`.
    ///
    /// # Brief
    /// 1. Resolve the same authority twice with one resolver (the first
    ///    call caches the result when connected to the Internet, or returns
    ///    an error without it).
    /// 2. Verify both calls agree on their error outcome: both succeed, or
    ///    both fail with the same message.
    #[tokio::test]
    async fn ut_default_dns_resolver() {
        let resolver = DefaultDnsResolver::new(std::time::Duration::from_millis(100));
        let authority = "example.com:0";
        let first = resolver.resolve(authority).await;
        let second = resolver.resolve(authority).await;
        assert_eq!(error_text(first), error_text(second));
    }
}
265 
#[cfg(feature = "ylong_base")]
#[cfg(test)]
mod ut_dns_cache {
    use super::*;

    // Reduces a resolution outcome to its error message; successful
    // outcomes are drained and map to `None`.
    fn error_text(outcome: Result<Addrs, StdError>) -> Option<String> {
        match outcome {
            Ok(addrs) => {
                let _ = addrs.collect::<Vec<_>>();
                None
            }
            Err(err) => Some(err.to_string()),
        }
    }

    /// UT test cases for `DefaultDnsResolver::resolve`.
    ///
    /// # Brief
    /// 1. Resolve the same authority twice with one resolver (the first
    ///    call caches the result when connected to the Internet, or returns
    ///    an error without it).
    /// 2. Verify both calls agree on their error outcome: both succeed, or
    ///    both fail with the same message.
    #[test]
    fn ut_default_dns_resolver() {
        ylong_runtime::block_on(ut_default_dns_resolver_async());
    }

    // Async body of `ut_default_dns_resolver`, driven by `block_on`.
    async fn ut_default_dns_resolver_async() {
        let resolver = DefaultDnsResolver::new(std::time::Duration::from_millis(100));
        let authority = "example.com:0";
        let first = resolver.resolve(authority).await;
        let second = resolver.resolve(authority).await;
        assert_eq!(error_text(first), error_text(second));
    }
}
298