1 use http::{header::USER_AGENT, HeaderValue, Request}; 2 use std::task::{Context, Poll}; 3 use tower_service::Service; 4 5 const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION")); 6 7 #[derive(Debug)] 8 pub(crate) struct UserAgent<T> { 9 inner: T, 10 user_agent: HeaderValue, 11 } 12 13 impl<T> UserAgent<T> { new(inner: T, user_agent: Option<HeaderValue>) -> Self14 pub(crate) fn new(inner: T, user_agent: Option<HeaderValue>) -> Self { 15 let user_agent = user_agent 16 .map(|value| { 17 let mut buf = Vec::new(); 18 buf.extend(value.as_bytes()); 19 buf.push(b' '); 20 buf.extend(TONIC_USER_AGENT.as_bytes()); 21 HeaderValue::from_bytes(&buf).expect("user-agent should be valid") 22 }) 23 .unwrap_or_else(|| HeaderValue::from_static(TONIC_USER_AGENT)); 24 25 Self { inner, user_agent } 26 } 27 } 28 29 impl<T, ReqBody> Service<Request<ReqBody>> for UserAgent<T> 30 where 31 T: Service<Request<ReqBody>>, 32 { 33 type Response = T::Response; 34 type Error = T::Error; 35 type Future = T::Future; 36 poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>37 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 38 self.inner.poll_ready(cx) 39 } 40 call(&mut self, mut req: Request<ReqBody>) -> Self::Future41 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { 42 req.headers_mut() 43 .insert(USER_AGENT, self.user_agent.clone()); 44 45 self.inner.call(req) 46 } 47 } 48 49 #[cfg(test)] 50 mod tests { 51 use super::*; 52 53 struct Svc; 54 55 #[test] sets_default_if_no_custom_user_agent()56 fn sets_default_if_no_custom_user_agent() { 57 assert_eq!( 58 UserAgent::new(Svc, None).user_agent, 59 HeaderValue::from_static(TONIC_USER_AGENT) 60 ) 61 } 62 63 #[test] prepends_custom_user_agent_to_default()64 fn prepends_custom_user_agent_to_default() { 65 assert_eq!( 66 UserAgent::new(Svc, Some(HeaderValue::from_static("Greeter 1.1"))).user_agent, 67 HeaderValue::from_str(&format!("Greeter 1.1 {}", TONIC_USER_AGENT)).unwrap() 68 ) 69 } 70 } 71