• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2025 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::io;
15 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
16 use std::str::FromStr;
17 use std::time::{Duration, Instant};
18 
19 use crate::async_impl::dns::resolver::{DefaultDnsFuture, DnsManager, DnsResult, ResolvedAddrs};
20 use crate::async_impl::{Body, Client, Request, Resolver, SocketFuture};
21 use crate::HttpClientError;
22 
/// Default number of passes `DohConnector::retry` makes over the configured
/// DoH server list before reporting a failure.
const DEFAULT_MAX_RETRY_COUNT: i32 = 1;
24 
/// DNS-over-HTTPS (DoH) resolver used by the `Client`.
///
/// # Examples
///
/// ```
/// use ylong_http_client::async_impl::{Client, DohResolver};
///
/// let doh_resolver = DohResolver::new("https://1.12.12.12/dns-query");
/// let _doh_client = Client::builder()
///     .dns_resolver(doh_resolver)
///     .build()
///     .unwrap();
/// ```
pub struct DohResolver {
    // `Some` holds a DNS manager private to this resolver; `None` makes
    // `resolve` use the global DNS manager instead.
    manager: Option<DnsManager>,
    // Holds the DoH server list and performs the actual lookups.
    connector: DohConnector,
}
42 
43 impl DohResolver {
44     /// Creates a new DohResolver. And sets DOH server.
45     ///
46     /// # Examples
47     ///
48     /// ```
49     /// use ylong_http_client::async_impl::DohResolver;
50     ///
51     /// let res = DohResolver::new("https://1.12.12.12/dns-query");
52     /// ```
new(doh_server: &str) -> Self53     pub fn new(doh_server: &str) -> Self {
54         Self {
55             manager: Some(DnsManager::default()),
56             connector: DohConnector::new(doh_server),
57         }
58     }
59 
60     /// Adds the doh server.
61     ///
62     /// # Examples
63     ///
64     /// ```
65     /// use ylong_http_client::async_impl::DohResolver;
66     ///
67     /// let res = DohResolver::new("https://1.12.12.12/dns-query")
68     ///     .add_doh_server("https://1.12.12.12/dns-query");
69     /// ```
add_doh_server(mut self, doh_server: &str) -> Self70     pub fn add_doh_server(mut self, doh_server: &str) -> Self {
71         self.connector.add_doh_server(doh_server);
72         self
73     }
74 
75     /// Sets whether to use global DNS cache, default is false.
76     ///
77     /// # Examples
78     ///
79     /// ```
80     /// use ylong_http_client::async_impl::DohResolver;
81     ///
82     /// let res = DohResolver::new("https://1.12.12.12/dns-query").global_dns_cache(false);
83     /// ```
global_dns_cache(mut self, use_global: bool) -> Self84     pub fn global_dns_cache(mut self, use_global: bool) -> Self {
85         self.manager = (!use_global).then(DnsManager::default);
86         self
87     }
88 
89     /// Sets DNS ttl, default is 60 second.
90     ///
91     /// This will does nothing if `global_dns_cache` is set to true.
92     ///
93     /// # Examples
94     ///
95     /// ```
96     /// use std::time::Duration;
97     ///
98     /// use ylong_http_client::async_impl::DohResolver;
99     ///
100     /// let res = DohResolver::new("https://1.12.12.12/dns-query").set_ttl(Duration::from_secs(30));
101     /// ```
set_ttl(mut self, ttl: Duration) -> Self102     pub fn set_ttl(mut self, ttl: Duration) -> Self {
103         if let Some(manager) = self.manager.as_mut() {
104             manager.ttl = ttl
105         }
106         self
107     }
108 }
109 
/// Performs DoH lookups against a list of configured DoH servers.
#[derive(Clone)]
struct DohConnector {
    // DoH endpoint URLs, queried in insertion order.
    doh_servers: Vec<String>,
    // Number of passes made over `doh_servers` before giving up.
    max_retry_count: i32,
}
115 
116 impl DohConnector {
new(doh_server: &str) -> Self117     fn new(doh_server: &str) -> Self {
118         DohConnector {
119             doh_servers: vec![doh_server.to_string()],
120             max_retry_count: DEFAULT_MAX_RETRY_COUNT,
121         }
122     }
123 
add_doh_server(&mut self, doh_server: &str)124     fn add_doh_server(&mut self, doh_server: &str) {
125         self.doh_servers.push(doh_server.to_string());
126     }
127 
retry(&self, authority: &str) -> Result<(Vec<SocketAddr>, u64), HttpClientError>128     async fn retry(&self, authority: &str) -> Result<(Vec<SocketAddr>, u64), HttpClientError> {
129         for _ in 0..self.max_retry_count {
130             for server in self.doh_servers.iter() {
131                 if let Ok((socket_addr, ttl)) = self.doh_connect(authority, server.clone()).await {
132                     return Ok((socket_addr, ttl));
133                 }
134             }
135         }
136         Err(HttpClientError::from_str(
137             crate::ErrorKind::Connect,
138             "Can't find valid address",
139         ))
140     }
141 
142     /// Connects to the DOH server and retrieves DNS information.
doh_connect( &self, authority: &str, doh_server: String, ) -> Result<(Vec<SocketAddr>, u64), HttpClientError>143     async fn doh_connect(
144         &self,
145         authority: &str,
146         doh_server: String,
147     ) -> Result<(Vec<SocketAddr>, u64), HttpClientError> {
148         let part: Vec<&str> = authority.split(':').collect();
149         let host: &str = part[0];
150         let port: u16 = part[1].parse().unwrap();
151         let url_4 = format!("{}?name={}&type=A", doh_server, host);
152         let url_6 = format!("{}?name={}&type=AAAA", doh_server, host);
153         let client_4 = Client::builder().build()?;
154         let client_6 = Client::builder().build()?;
155         let request_4 = Request::builder().url(&url_4).body(Body::empty())?;
156         let request_6 = Request::builder().url(&url_6).body(Body::empty())?;
157         let response_4 = client_4.request(request_4).await?;
158         let response_6 = client_6.request(request_6).await?;
159         let text_4 = response_4.text().await?;
160         let text_6 = response_6.text().await?;
161         let text = format!("{},{}", text_4, text_6);
162         Ok(Self::get_info(&text, port))
163     }
164 
165     /// Parses and extracts information from the DNS response text.
get_info(text: &str, port: u16) -> (Vec<SocketAddr>, u64)166     fn get_info(text: &str, port: u16) -> (Vec<SocketAddr>, u64) {
167         let mut ips = Vec::new();
168         let mut start = 0;
169         let mut ttl = u64::MAX;
170         while let Some((answer_end, answer_str)) = Self::get_answer_str(text, start) {
171             if let Some(socket_addr) = Self::get_socket_addr(answer_str, port) {
172                 if let Some(answer_ttl) = Self::get_ttl(answer_str) {
173                     ips.push(socket_addr);
174                     ttl = std::cmp::min(ttl, answer_ttl);
175                 }
176             }
177             start = answer_end + 1;
178         }
179         ttl = if ttl == u64::MAX { 0 } else { ttl };
180         (ips, ttl)
181     }
182 
/// Finds the next `{ ... }` fragment of `answer_section` at or after byte
/// offset `start`.
///
/// Returns the byte index of the closing `}` and the fragment from the
/// opening `{` up to (not including) that `}`. Returns `None` when `start` is
/// out of range, when no `{` follows, or when the fragment is never closed
/// (for example a truncated response), instead of panicking.
fn get_answer_str(answer_section: &str, start: usize) -> Option<(usize, &str)> {
    let tail = answer_section.get(start..)?;
    let answer_start = start + tail.find('{')?;
    let answer_end = answer_start + answer_section[answer_start..].find('}')?;
    Some((answer_end, &answer_section[answer_start..answer_end]))
}
188 
/// Extracts the value of the `"data"` field from one answer fragment and
/// turns it into a socket address using `port`.
///
/// Returns `None` when the field is missing, is not terminated by a quote,
/// or is neither an IPv4 nor an IPv6 address.
fn get_socket_addr(answer_str: &str, port: u16) -> Option<SocketAddr> {
    const DATA_KEY: &str = r#""data":""#;
    let value_start = answer_str.find(DATA_KEY)? + DATA_KEY.len();
    let rest = &answer_str[value_start..];
    let value = &rest[..rest.find('"')?];
    // IPv4 is tried first, then IPv6.
    Ipv4Addr::from_str(value)
        .map(IpAddr::V4)
        .or_else(|_| Ipv6Addr::from_str(value).map(IpAddr::V6))
        .ok()
        .map(|ip| SocketAddr::new(ip, port))
}
205 
/// Extracts the numeric value of the `"TTL"` field from one answer fragment.
///
/// The value is the run of ASCII digits that follows the key (leading
/// whitespace is skipped), so the field is also found when it is the last
/// one in the fragment and has no trailing comma.
///
/// Returns `None` when the field is missing, has no digits, or does not fit
/// in a `u64`; malformed input never panics.
fn get_ttl(answer_str: &str) -> Option<u64> {
    const TTL_KEY: &str = r#""TTL":"#;
    let value_start = answer_str.find(TTL_KEY)? + TTL_KEY.len();
    let rest = answer_str[value_start..].trim_start();
    let digits_end = rest
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(rest.len());
    rest[..digits_end].parse().ok()
}
217 }
218 
219 impl Resolver for DohResolver {
resolve(&self, authority: &str) -> SocketFuture220     fn resolve(&self, authority: &str) -> SocketFuture {
221         let authority = authority.to_string();
222         let map = match &self.manager {
223             None => {
224                 let manager = DnsManager::global_dns_manager();
225                 let manager_guard = manager.lock().unwrap();
226                 manager_guard.clean_expired_entries();
227                 manager_guard.map.clone()
228             }
229             Some(manager) => {
230                 manager.clean_expired_entries();
231                 manager.map.clone()
232             }
233         };
234         let connector = self.connector.clone();
235         let handle = crate::runtime::spawn_blocking(move || {
236             let mut map_lock = map.lock().unwrap();
237             if let Some(addrs) = map_lock.get(&authority) {
238                 if addrs.is_valid() {
239                     return Ok(ResolvedAddrs::new(addrs.addr.clone().into_iter()));
240                 }
241             }
242             #[cfg(feature = "ylong_base")]
243             let result = ylong_runtime::block_on(connector.retry(&authority));
244             #[cfg(feature = "tokio_base")]
245             let result = tokio::runtime::Runtime::new()
246                 .unwrap()
247                 .block_on(connector.retry(&authority));
248             match result {
249                 Ok((addrs, ttl)) => {
250                     let dns_result =
251                         DnsResult::new(addrs.clone(), Instant::now() + Duration::from_secs(ttl));
252                     map_lock.insert(authority, dns_result);
253                     Ok(ResolvedAddrs::new(addrs.into_iter()))
254                 }
255                 Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)),
256             }
257         });
258         Box::pin(DefaultDnsFuture::new(handle))
259     }
260 }
261 
#[cfg(test)]
mod ut_doh_test {
    use super::*;

    /// UT test case for `DohResolver::global_dns_cache`
    ///
    /// # Brief
    /// 1. Creates a new `DohResolver` instance.
    /// 2. Verifies the default `manager` is set (private cache).
    /// 3. Calls `global_dns_cache` and checks that `manager` is cleared for
    ///    `true` and set again for `false`.
    #[test]
    fn ut_dns_resolver_global() {
        let mut resolver = DohResolver::new("https://1.12.12.12/dns-query");
        assert!(resolver.manager.is_some());
        resolver = resolver.global_dns_cache(true);
        assert!(resolver.manager.is_none());
        resolver = resolver.global_dns_cache(false);
        assert!(resolver.manager.is_some());
    }

    /// UT test case for `DohResolver::set_ttl()`
    ///
    /// # Brief
    /// 1. Creates a new `DohResolver` instance.
    /// 2. Verifies the default `ttl` is 60 seconds.
    /// 3. Calls `set_ttl` and checks the ttl.
    #[test]
    fn ut_dns_resolver_ttl() {
        let mut resolver = DohResolver::new("https://1.12.12.12/dns-query");
        assert!(resolver.manager.is_some());
        assert_eq!(
            resolver.manager.as_ref().unwrap().ttl,
            Duration::from_secs(60)
        );
        resolver = resolver.set_ttl(Duration::from_secs(30));
        assert_eq!(
            resolver.manager.as_ref().unwrap().ttl,
            Duration::from_secs(30)
        );
    }

    /// UT test case for `get_info` function with IPv4 address
    ///
    /// # Brief
    /// 1. Provides a DNS response text for an IPv4 address.
    /// 2. Calls `get_info` to extract addresses and TTL.
    /// 3. Verifies the extracted address and TTL.
    #[test]
    fn ut_get_info_ipv4() {
        let ipv4_text = r#"{"Status":0,"TC":false,"RD":true,"RA":true,"AD":false,"CD":false,"Question":[{"name":"example.com.","type":1}],"Answer":[{"name":"example.com.","type":1,"TTL":3378,"data":"93.184.215.14"}]}"#;
        let (addrs, ttl) = DohConnector::get_info(ipv4_text, 0);
        assert_eq!(addrs, vec![SocketAddr::from(([93, 184, 215, 14], 0))]);
        assert_eq!(ttl, 3378);
    }

    /// UT test case for `get_info` function with IPv6 address
    ///
    /// # Brief
    /// 1. Provides a DNS response text for an IPv6 address.
    /// 2. Calls `get_info` to extract addresses and TTL.
    /// 3. Verifies the extracted address and TTL.
    #[test]
    fn ut_get_info_ipv6() {
        // Well-formed JSON: the fixture previously lacked a comma before
        // "Answer" and an opening quote around the answer name.
        let ipv6_text = r#"{"Status":0,"TC":false,"RD":true,"RA":true,"AD":false,"CD":false,"Question":[{"name":"example.com.","type":28}],"Answer":[{"name":"example.com.","type":28,"TTL":1466,"data":"2606:2800:21f:cb07:6820:80da:af6b:8b2c"}]}"#;
        let (addrs, ttl) = DohConnector::get_info(ipv6_text, 0);
        assert_eq!(
            addrs,
            vec![SocketAddr::from((
                [0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c],
                0
            ))]
        );
        assert_eq!(ttl, 1466);
    }

    /// UT test case for `get_info` function with both IPv4 and IPv6 addresses
    ///
    /// # Brief
    /// 1. Provides a DNS response text with both IPv4 and IPv6 addresses.
    /// 2. Calls `get_info` to extract the addresses and TTL.
    /// 3. Verifies the extracted addresses and that the TTL is the smaller
    ///    of the two.
    #[test]
    fn ut_get_info_both() {
        let text = r#"{"Status":0,"TC":false,"RD":true,"RA":true,"AD":false,"CD":false,"Question":[{"name":"example.com.","type":1}],"Answer":[{"name":"example.com.","type":1,"TTL":3378,"data":"93.184.215.14"}]},{"Status":0,"TC":false,"RD":true,"RA":true,"AD":false,"CD":false,"Question":[{"name":"example.com.","type":28}],"Answer":[{"name":"example.com.","type":28,"TTL":1466,"data":"2606:2800:21f:cb07:6820:80da:af6b:8b2c"}]}"#;
        let (addrs, ttl) = DohConnector::get_info(text, 0);
        assert_eq!(
            addrs,
            vec![
                SocketAddr::from(([93, 184, 215, 14], 0)),
                SocketAddr::from((
                    [0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c],
                    0
                )),
            ]
        );
        assert_eq!(ttl, 1466);
    }

    /// UT test case for `get_info` function with some error response.
    ///
    /// # Brief
    /// 1. Provides a DNS response text with some error response.
    /// 2. Calls `get_info` to extract the addresses and TTL.
    /// 3. Verifies addresses is empty and TTL is 0.
    #[test]
    fn ut_get_info_error() {
        let error_text = "This is some error response.";
        let (addrs, ttl) = DohConnector::get_info(error_text, 0);
        assert_eq!(addrs, vec![]);
        assert_eq!(ttl, 0);
    }
}
374