1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use std::error::Error; 15 use std::future::Future; 16 use std::mem::take; 17 use std::sync::{Arc, Mutex}; 18 19 use crate::async_impl::Connector; 20 use crate::error::HttpClientError; 21 use crate::util::dispatcher::{Conn, ConnDispatcher, Dispatcher}; 22 use crate::util::pool::{Pool, PoolKey}; 23 use crate::{AsyncRead, AsyncWrite, ErrorKind, HttpConfig, HttpVersion, Uri}; 24 25 pub(crate) struct ConnPool<C, S> { 26 pool: Pool<PoolKey, Conns<S>>, 27 connector: Arc<C>, 28 config: HttpConfig, 29 } 30 31 impl<C: Connector> ConnPool<C, C::Stream> { new(config: HttpConfig, connector: C) -> Self32 pub(crate) fn new(config: HttpConfig, connector: C) -> Self { 33 Self { 34 pool: Pool::new(), 35 connector: Arc::new(connector), 36 config, 37 } 38 } 39 connect_to(&self, uri: Uri) -> Result<Conn<C::Stream>, HttpClientError>40 pub(crate) async fn connect_to(&self, uri: Uri) -> Result<Conn<C::Stream>, HttpClientError> { 41 let key = PoolKey::new( 42 uri.scheme().unwrap().clone(), 43 uri.authority().unwrap().clone(), 44 ); 45 46 self.pool 47 .get(key, Conns::new) 48 .conn(self.config.clone(), self.connector.clone().connect(&uri)) 49 .await 50 } 51 } 52 53 pub(crate) struct Conns<S> { 54 list: Arc<Mutex<Vec<ConnDispatcher<S>>>>, 55 56 #[cfg(feature = "http2")] 57 h2_occupation: Arc<crate::AsyncMutex<()>>, 58 } 59 60 impl<S> Conns<S> { new() -> Self61 fn new() -> Self { 62 Self { 63 list: Arc::new(Mutex::new(Vec::new())), 64 65 #[cfg(feature = "http2")] 66 h2_occupation: Arc::new(crate::AsyncMutex::new(())), 67 } 68 } 69 } 70 71 impl<S> Clone for Conns<S> { clone(&self) -> Self72 fn clone(&self) -> Self { 73 Self { 74 list: self.list.clone(), 75 76 #[cfg(feature = "http2")] 77 h2_occupation: self.h2_occupation.clone(), 78 } 79 } 80 } 81 82 impl<S: AsyncRead + AsyncWrite + Unpin + Send + Sync> Conns<S> { conn<F, E>( &self, config: HttpConfig, connect_fut: F, ) -> Result<Conn<S>, HttpClientError> where F: Future<Output = Result<S, E>>, E: Into<Box<dyn Error + Send + Sync>>,83 async fn conn<F, E>( 84 &self, 85 config: HttpConfig, 86 connect_fut: F, 87 ) -> Result<Conn<S>, HttpClientError> 88 where 89 F: Future<Output = Result<S, E>>, 90 E: Into<Box<dyn Error + Send + Sync>>, 91 { 92 match config.version { 93 #[cfg(feature = "http2")] 94 HttpVersion::Http2PriorKnowledge => { 95 { 96 // The lock `h2_occupation` is used to prevent multiple coroutines from sending 97 // Requests at the same time under concurrent conditions, 98 // resulting in the creation of multiple tcp connections 99 let _lock = self.h2_occupation.lock().await; 100 if let Some(conn) = self.get_exist_conn() { 101 return Ok(conn); 102 } 103 // create tcp connection. 104 let dispatcher = ConnDispatcher::http2( 105 config.http2_config, 106 connect_fut.await.map_err(|e| { 107 HttpClientError::new_with_cause(ErrorKind::Connect, Some(e)) 108 })?, 109 ); 110 Ok(self.dispatch_conn(dispatcher)) 111 } 112 } 113 #[cfg(feature = "http1_1")] 114 HttpVersion::Http11 => { 115 if let Some(conn) = self.get_exist_conn() { 116 return Ok(conn); 117 } 118 let dispatcher = 119 ConnDispatcher::http1(connect_fut.await.map_err(|e| { 120 HttpClientError::new_with_cause(ErrorKind::Connect, Some(e)) 121 })?); 122 Ok(self.dispatch_conn(dispatcher)) 123 } 124 #[cfg(not(feature = "http1_1"))] 125 HttpVersion::Http11 => Err(HttpClientError::new_with_message( 126 ErrorKind::Connect, 127 "Invalid HTTP VERSION", 128 )), 129 } 130 } 131 dispatch_conn(&self, dispatcher: ConnDispatcher<S>) -> Conn<S>132 fn dispatch_conn(&self, dispatcher: ConnDispatcher<S>) -> Conn<S> { 133 // We must be able to get the `Conn` here. 134 let conn = dispatcher.dispatch().unwrap(); 135 let mut list = self.list.lock().unwrap(); 136 list.push(dispatcher); 137 conn 138 } 139 get_exist_conn(&self) -> Option<Conn<S>>140 fn get_exist_conn(&self) -> Option<Conn<S>> { 141 { 142 let mut list = self.list.lock().unwrap(); 143 let mut conn = None; 144 let curr = take(&mut *list); 145 for dispatcher in curr.into_iter() { 146 // Discard invalid dispatchers. 147 if dispatcher.is_shutdown() { 148 continue; 149 } 150 if conn.is_none() { 151 conn = dispatcher.dispatch(); 152 } 153 list.push(dispatcher); 154 } 155 conn 156 } 157 } 158 } 159