• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::{error::Error, io::ErrorKind, net::ToSocketAddrs, pin::Pin, time::Duration};
2 use tokio::sync::mpsc;
3 use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
4 use tonic::{transport::Server, Request, Response, Status, Streaming};
5 
6 use echo_proto::echo::{echo_server, EchoRequest, EchoResponse};
7 
8 type EchoResult<T> = Result<Response<T>, Status>;
9 type ResponseStream = Pin<Box<dyn Stream<Item = Result<EchoResponse, Status>> + Send>>;
10 
match_for_io_error(err_status: &Status) -> Option<&std::io::Error>11 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
12     let mut err: &(dyn Error + 'static) = err_status;
13 
14     loop {
15         if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
16             return Some(io_err);
17         }
18 
19         // h2::Error do not expose std::io::Error with `source()`
20         // https://github.com/hyperium/h2/pull/462
21         if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
22             if let Some(io_err) = h2_err.get_io() {
23                 return Some(io_err);
24             }
25         }
26 
27         err = match err.source() {
28             Some(err) => err,
29             None => return None,
30         };
31     }
32 }
33 
/// Stateless gRPC service that implements the `Echo` RPCs.
///
/// It carries no fields, so it derives `Default`, `Clone` and `Copy` in addition to `Debug`.
/// It can be built with `EchoServer {}` or `EchoServer::default()` and duplicated freely.
#[derive(Debug, Default, Clone, Copy)]
pub struct EchoServer {}
36 
37 #[tonic::async_trait]
38 impl echo_server::Echo for EchoServer {
unary_echo(&self, _: Request<EchoRequest>) -> EchoResult<EchoResponse>39     async fn unary_echo(&self, _: Request<EchoRequest>) -> EchoResult<EchoResponse> {
40         Err(Status::unimplemented("not implemented"))
41     }
42 
43     type ServerStreamingEchoStream = ResponseStream;
44 
server_streaming_echo( &self, req: Request<EchoRequest>, ) -> EchoResult<Self::ServerStreamingEchoStream>45     async fn server_streaming_echo(
46         &self,
47         req: Request<EchoRequest>,
48     ) -> EchoResult<Self::ServerStreamingEchoStream> {
49         println!("EchoServer::server_streaming_echo");
50         println!("\tclient connected from: {:?}", req.remote_addr());
51 
52         // creating infinite stream with requested message
53         let repeat = std::iter::repeat(EchoResponse {
54             message: req.into_inner().message,
55         });
56         let mut stream = Box::pin(tokio_stream::iter(repeat).throttle(Duration::from_millis(200)));
57 
58         // spawn and channel are required if you want handle "disconnect" functionality
59         // the `out_stream` will not be polled after client disconnect
60         let (tx, rx) = mpsc::channel(128);
61         tokio::spawn(async move {
62             while let Some(item) = stream.next().await {
63                 match tx.send(Result::<_, Status>::Ok(item)).await {
64                     Ok(_) => {
65                         // item (server response) was queued to be send to client
66                     }
67                     Err(_item) => {
68                         // output_stream was build from rx and both are dropped
69                         break;
70                     }
71                 }
72             }
73             println!("\tclient disconnected");
74         });
75 
76         let output_stream = ReceiverStream::new(rx);
77         Ok(Response::new(
78             Box::pin(output_stream) as Self::ServerStreamingEchoStream
79         ))
80     }
81 
client_streaming_echo( &self, _: Request<Streaming<EchoRequest>>, ) -> EchoResult<EchoResponse>82     async fn client_streaming_echo(
83         &self,
84         _: Request<Streaming<EchoRequest>>,
85     ) -> EchoResult<EchoResponse> {
86         Err(Status::unimplemented("not implemented"))
87     }
88 
89     type BidirectionalStreamingEchoStream = ResponseStream;
90 
bidirectional_streaming_echo( &self, req: Request<Streaming<EchoRequest>>, ) -> EchoResult<Self::BidirectionalStreamingEchoStream>91     async fn bidirectional_streaming_echo(
92         &self,
93         req: Request<Streaming<EchoRequest>>,
94     ) -> EchoResult<Self::BidirectionalStreamingEchoStream> {
95         println!("EchoServer::bidirectional_streaming_echo");
96 
97         let mut in_stream = req.into_inner();
98         let (tx, rx) = mpsc::channel(128);
99 
100         // this spawn here is required if you want to handle connection error.
101         // If we just map `in_stream` and write it back as `out_stream` the `out_stream`
102         // will be drooped when connection error occurs and error will never be propagated
103         // to mapped version of `in_stream`.
104         tokio::spawn(async move {
105             while let Some(result) = in_stream.next().await {
106                 match result {
107                     Ok(v) => tx
108                         .send(Ok(EchoResponse { message: v.message }))
109                         .await
110                         .expect("working rx"),
111                     Err(err) => {
112                         if let Some(io_err) = match_for_io_error(&err) {
113                             if io_err.kind() == ErrorKind::BrokenPipe {
114                                 // here you can handle special case when client
115                                 // disconnected in unexpected way
116                                 eprintln!("\tclient disconnected: broken pipe");
117                                 break;
118                             }
119                         }
120 
121                         match tx.send(Err(err)).await {
122                             Ok(_) => (),
123                             Err(_err) => break, // response was droped
124                         }
125                     }
126                 }
127             }
128             println!("\tstream ended");
129         });
130 
131         // echo just write the same data that was received
132         let out_stream = ReceiverStream::new(rx);
133 
134         Ok(Response::new(
135             Box::pin(out_stream) as Self::BidirectionalStreamingEchoStream
136         ))
137     }
138 }
139 
140 #[tokio::main]
main() -> Result<(), Box<dyn std::error::Error>>141 async fn main() -> Result<(), Box<dyn std::error::Error>> {
142     let server = EchoServer {};
143     Server::builder()
144         .add_service(echo_server::EchoServer::new(server))
145         .serve("[::1]:50051".to_socket_addrs().unwrap().next().unwrap())
146         .await
147         .unwrap();
148 
149     Ok(())
150 }
151