• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 //! `Resolver` trait and `DefaultDnsResolver` implementation.
15 
16 use std::collections::HashMap;
17 use std::future::Future;
18 use std::io;
19 use std::net::SocketAddr;
20 use std::pin::Pin;
21 use std::sync::{Arc, Mutex, OnceLock};
22 use std::task::{Context, Poll};
23 use std::time::{Duration, Instant};
24 use std::vec::IntoIter;
25 
26 use crate::runtime::JoinHandle;
27 
/// Default value of `DnsManager::max_entries_len`: the number of cached
/// entries above which expired DNS records are pruned.
const DEFAULT_MAX_LEN: usize = 30_000;
/// Default time-to-live of a DNS cache entry.
const DEFAULT_TTL: Duration = Duration::from_secs(60);
30 
/// Boxed iterator over the `SocketAddr`s produced by a `Resolver`.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Send + Sync>;
/// Boxed, thread-safe error that a resolver may report when an attempt to
/// resolve fails.
pub type StdError = Box<dyn std::error::Error + Send + Sync>;
/// Boxed future returned by a resolver when asked to resolve an address. It
/// completes with the resolved `Addrs` or with a `StdError`.
pub type SocketFuture<'a> =
    Pin<Box<dyn Future<Output = Result<Addrs, StdError>> + Send + Sync + 'a>>;
39 
40 /// `Resolver` trait used by `async_impl::connector::HttpConnector`. `Resolver`
41 /// provides asynchronous dns resolve interfaces.
42 pub trait Resolver: Send + Sync + 'static {
43     /// resolve authority to a `SocketAddr` `Future`.
resolve(&self, authority: &str) -> SocketFuture44     fn resolve(&self, authority: &str) -> SocketFuture;
45 }
46 
47 /// `SocketAddr` resolved by `DefaultDnsResolver`.
48 pub struct ResolvedAddrs {
49     iter: IntoIter<SocketAddr>,
50 }
51 
52 impl ResolvedAddrs {
new(iter: IntoIter<SocketAddr>) -> Self53     pub(super) fn new(iter: IntoIter<SocketAddr>) -> Self {
54         Self { iter }
55     }
56 
57     // The first ip in the dns record is the preferred addrs type.
split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>)58     pub(super) fn split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>) {
59         // get preferred address family type.
60         let is_ipv6 = self
61             .iter
62             .as_slice()
63             .first()
64             .map(SocketAddr::is_ipv6)
65             .unwrap_or(false);
66         self.iter
67             .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == is_ipv6)
68     }
69 }
70 
71 impl Iterator for ResolvedAddrs {
72     type Item = SocketAddr;
73 
next(&mut self) -> Option<Self::Item>74     fn next(&mut self) -> Option<Self::Item> {
75         self.iter.next()
76     }
77 }
78 
79 /// Futures generated by `DefaultDnsResolver`.
80 pub struct DefaultDnsFuture {
81     inner: JoinHandle<Result<ResolvedAddrs, io::Error>>,
82 }
83 
84 impl DefaultDnsFuture {
new(handle: JoinHandle<Result<ResolvedAddrs, io::Error>>) -> Self85     pub(crate) fn new(handle: JoinHandle<Result<ResolvedAddrs, io::Error>>) -> Self {
86         DefaultDnsFuture { inner: handle }
87     }
88 }
89 
90 impl Future for DefaultDnsFuture {
91     type Output = Result<Addrs, StdError>;
92 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>93     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
94         Pin::new(&mut self.inner).poll(cx).map(|res| match res {
95             Ok(Ok(addrs)) => Ok(Box::new(addrs) as Addrs),
96             Ok(Err(err)) => Err(Box::new(err) as StdError),
97             Err(err) => Err(Box::new(io::Error::new(io::ErrorKind::Interrupted, err)) as StdError),
98         })
99     }
100 }
101 
102 pub(crate) struct DnsManager {
103     /// Cache storing authority and DNS results
104     pub(crate) map: Arc<Mutex<HashMap<String, DnsResult>>>,
105     max_entries_len: usize,
106     /// Time-to-live for the DNS cache
107     pub(crate) ttl: Duration,
108 }
109 
110 impl Default for DnsManager {
111     // Default constructor for `DnsManager`, with a default TTL of 60
112     // seconds.
default() -> Self113     fn default() -> Self {
114         DnsManager {
115             map: Default::default(),
116             max_entries_len: DEFAULT_MAX_LEN,
117             ttl: DEFAULT_TTL, // Default TTL set to 60 seconds
118         }
119     }
120 }
121 
122 impl DnsManager {
123     /// Global DNS Manager
global_dns_manager() -> Arc<Mutex<DnsManager>>124     pub(crate) fn global_dns_manager() -> Arc<Mutex<DnsManager>> {
125         static GLOBAL_DNS_MANAGER: OnceLock<Arc<Mutex<DnsManager>>> = OnceLock::new();
126         GLOBAL_DNS_MANAGER
127             .get_or_init(|| Arc::new(Mutex::new(DnsManager::default())))
128             .clone()
129     }
130 
131     /// Cleans expired DNS cache entries by retaining only valid ones
clean_expired_entries(&self)132     pub(crate) fn clean_expired_entries(&self) {
133         let mut map_lock = self.map.lock().unwrap();
134         if map_lock.len() > self.max_entries_len {
135             map_lock.retain(|_, result| result.is_valid());
136         }
137     }
138 }
139 
140 #[derive(Clone)]
141 pub(crate) struct DnsResult {
142     /// List of resolved addresses for the authority
143     pub(crate) addr: Vec<SocketAddr>,
144     /// Expiration time for the cache entry
145     expiration_time: Instant,
146 }
147 
148 impl DnsResult {
149     /// Creates a new DNS result with the given addresses and expiration time
new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self150     pub(crate) fn new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self {
151         DnsResult {
152             addr,
153             expiration_time,
154         }
155     }
156 
157     /// Checks if the DNS result is still valid
is_valid(&self) -> bool158     pub(crate) fn is_valid(&self) -> bool {
159         self.expiration_time > Instant::now()
160     }
161 }
162 
163 impl Default for DnsResult {
164     // Default constructor for `DnsResult`, with an empty address list and 60
165     // seconds expiration
default() -> Self166     fn default() -> Self {
167         DnsResult {
168             addr: vec![],
169             expiration_time: Instant::now() + DEFAULT_TTL,
170         }
171     }
172 }
173 
#[cfg(test)]
mod ut_resolver_test {
    use super::*;

    /// UT test case for `DnsManager::default`
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` instance through `Default`.
    /// 2. Verifies the default `max_entries_len` is 30000 and the default
    ///    `ttl` is 60 seconds.
    /// 3. Inserts a DNS result into the cache and verifies it can be found.
    #[test]
    fn ut_dns_manager_new() {
        let manager = DnsManager::default();
        assert_eq!(manager.max_entries_len, 30000);
        assert_eq!(manager.ttl, Duration::from_secs(60));
        let mut map = manager.map.lock().unwrap();
        map.insert(
            "example.com".to_string(),
            DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 1], 1))], Instant::now()),
        );
        assert!(map.contains_key("example.com"));
    }

    /// UT test case for `DnsManager::clean_expired_entries`
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` instance with `max_entries_len` set to 1.
    /// 2. Adds two DNS results to the cache: one valid and one expired.
    /// 3. Calls `clean_expired_entries` to remove expired entries.
    /// 4. Verifies the expired entry is removed and the valid one is kept.
    #[test]
    fn ut_dns_manager_clean_cache() {
        let manager = DnsManager {
            max_entries_len: 1,
            ..Default::default()
        };
        let now = Instant::now();
        // `Instant - Duration` may panic when the result cannot be
        // represented; an entry expiring at `now` is already invalid too.
        let expired_at = now.checked_sub(Duration::from_secs(60)).unwrap_or(now);
        let mut map = manager.map.lock().unwrap();
        map.insert(
            "example1.com".to_string(),
            DnsResult::new(
                vec![SocketAddr::from(([0, 0, 0, 1], 1))],
                now + Duration::from_secs(60),
            ),
        );
        map.insert(
            "example2.com".to_string(),
            DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 2], 2))], expired_at),
        );
        // Release the guard: `clean_expired_entries` locks the map itself.
        drop(map);
        manager.clean_expired_entries();
        assert!(manager.map.lock().unwrap().contains_key("example1.com"));
        assert!(!manager.map.lock().unwrap().contains_key("example2.com"));
    }

    /// UT test case for `DnsManager::clean_expired_entries`
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` instance with the default `max_entries_len`.
    /// 2. Adds one expired DNS result to the cache.
    /// 3. Calls `clean_expired_entries`.
    /// 4. Verifies the expired entry is kept because the cache does not
    ///    exceed `max_entries_len`.
    #[test]
    fn ut_dns_manager_keep_cache_below_limit() {
        let manager = DnsManager::default();
        manager.map.lock().unwrap().insert(
            "example.com".to_string(),
            DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 1], 1))], Instant::now()),
        );
        manager.clean_expired_entries();
        assert!(manager.map.lock().unwrap().contains_key("example.com"));
    }
}
230