1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2
3 use std::pin::Pin;
4 use std::ptr;
5 use std::sync::Arc;
6 use std::time::Duration;
7
8 use crate::grpc_sys;
9 use futures::ready;
10 use futures::sink::Sink;
11 use futures::stream::Stream;
12 use futures::task::{Context, Poll};
13 use parking_lot::Mutex;
14 use std::future::Future;
15
16 use super::{ShareCall, ShareCallHolder, SinkBase, WriteFlags};
17 use crate::buf::GrpcSlice;
18 use crate::call::{check_run, Call, MessageReader, Method};
19 use crate::channel::Channel;
20 use crate::codec::{DeserializeFn, SerializeFn};
21 use crate::error::{Error, Result};
22 use crate::metadata::Metadata;
23 use crate::task::{BatchFuture, BatchType};
24
/// Sets (`set == true`) or clears (`set == false`) every bit of `flag` in `res`.
///
/// Bits of `res` outside `flag` are left untouched.
#[inline]
pub fn change_flag(res: &mut u32, flag: u32, set: bool) {
    let updated = match set {
        true => *res | flag,
        false => *res & !flag,
    };
    *res = updated;
}
34
35 /// Options for calls made by client.
36 #[derive(Clone, Default)]
37 pub struct CallOption {
38 timeout: Option<Duration>,
39 write_flags: WriteFlags,
40 call_flags: u32,
41 headers: Option<Metadata>,
42 }
43
44 impl CallOption {
45 /// Signal that the call is idempotent.
idempotent(mut self, is_idempotent: bool) -> CallOption46 pub fn idempotent(mut self, is_idempotent: bool) -> CallOption {
47 change_flag(
48 &mut self.call_flags,
49 grpc_sys::GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST,
50 is_idempotent,
51 );
52 self
53 }
54
55 /// Signal that the call should not return UNAVAILABLE before it has started.
wait_for_ready(mut self, wait_for_ready: bool) -> CallOption56 pub fn wait_for_ready(mut self, wait_for_ready: bool) -> CallOption {
57 change_flag(
58 &mut self.call_flags,
59 grpc_sys::GRPC_INITIAL_METADATA_WAIT_FOR_READY,
60 wait_for_ready,
61 );
62 self
63 }
64
65 /// Signal that the call is cacheable. gRPC is free to use GET verb.
cacheable(mut self, cacheable: bool) -> CallOption66 pub fn cacheable(mut self, cacheable: bool) -> CallOption {
67 change_flag(
68 &mut self.call_flags,
69 grpc_sys::GRPC_INITIAL_METADATA_CACHEABLE_REQUEST,
70 cacheable,
71 );
72 self
73 }
74
75 /// Set write flags.
write_flags(mut self, write_flags: WriteFlags) -> CallOption76 pub fn write_flags(mut self, write_flags: WriteFlags) -> CallOption {
77 self.write_flags = write_flags;
78 self
79 }
80
81 /// Set a timeout.
timeout(mut self, timeout: Duration) -> CallOption82 pub fn timeout(mut self, timeout: Duration) -> CallOption {
83 self.timeout = Some(timeout);
84 self
85 }
86
87 /// Get the timeout.
get_timeout(&self) -> Option<Duration>88 pub fn get_timeout(&self) -> Option<Duration> {
89 self.timeout
90 }
91
92 /// Set the headers to be sent with the call.
headers(mut self, meta: Metadata) -> CallOption93 pub fn headers(mut self, meta: Metadata) -> CallOption {
94 self.headers = Some(meta);
95 self
96 }
97
98 /// Get headers to be sent with the call.
get_headers(&self) -> Option<&Metadata>99 pub fn get_headers(&self) -> Option<&Metadata> {
100 self.headers.as_ref()
101 }
102 }
103
impl Call {
    /// Starts a unary RPC.
    ///
    /// Serializes `req` with the method's request serializer, then starts the
    /// call with one `check_run` batch (`BatchType::CheckRead`). The batch
    /// carries the payload plus `opt`'s write flags, headers and call flags.
    /// The returned receiver is a future resolving to the deserialized response.
    ///
    /// # Errors
    ///
    /// Returns an error if `channel.create_call` fails.
    pub fn unary_async<Req, Resp>(
        channel: &Channel,
        method: &Method<Req, Resp>,
        req: &Req,
        mut opt: CallOption,
    ) -> Result<ClientUnaryReceiver<Resp>> {
        let call = channel.create_call(method, &opt)?;
        let mut payload = GrpcSlice::default();
        (method.req_ser())(req, &mut payload);
        // SAFETY: `call.call` is the handle just returned by `create_call`;
        // `payload` and `opt.headers` are locals that are alive for this FFI call.
        // NOTE(review): both are dropped when this function returns, so this
        // assumes grpcwrap copies or takes them over before returning — confirm.
        let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_start_unary(
                call.call,
                ctx,
                payload.as_mut_ptr(),
                opt.write_flags.flags,
                // Null pointer when no custom headers were configured.
                opt.headers
                    .as_mut()
                    .map_or_else(ptr::null_mut, |c| c as *mut _ as _),
                opt.call_flags,
                tag,
            )
        });
        Ok(ClientUnaryReceiver::new(call, cq_f, method.resp_de()))
    }

    /// Starts a client-streaming RPC.
    ///
    /// Starts the call with `opt`'s headers and call flags (`BatchType::CheckRead`).
    /// Returns a sink for request messages and a receiver for the single response.
    /// Both halves share one `ShareCall` behind an `Arc<Mutex<_>>`.
    ///
    /// # Errors
    ///
    /// Returns an error if `channel.create_call` fails.
    pub fn client_streaming<Req, Resp>(
        channel: &Channel,
        method: &Method<Req, Resp>,
        mut opt: CallOption,
    ) -> Result<(ClientCStreamSender<Req>, ClientCStreamReceiver<Resp>)> {
        let call = channel.create_call(method, &opt)?;
        // SAFETY: `call.call` was just created above; `opt.headers` is a live
        // local for the duration of this FFI call (see note in `unary_async`).
        let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_start_client_streaming(
                call.call,
                ctx,
                // Null pointer when no custom headers were configured.
                opt.headers
                    .as_mut()
                    .map_or_else(ptr::null_mut, |c| c as *mut _ as _),
                opt.call_flags,
                tag,
            )
        });

        // The start batch future becomes the shared call's finish future.
        let share_call = Arc::new(Mutex::new(ShareCall::new(call, cq_f)));
        let sink = ClientCStreamSender::new(share_call.clone(), method.req_ser());
        let recv = ClientCStreamReceiver {
            call: share_call,
            resp_de: method.resp_de(),
            finished: false,
        };
        Ok((sink, recv))
    }

    /// Starts a server-streaming RPC with the single request `req`.
    ///
    /// The start batch (`BatchType::Finish`) is handed to the receiver as the
    /// call's finish future. A second batch requests the server's initial
    /// metadata; its future is dropped, so those headers are currently ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if `channel.create_call` fails.
    pub fn server_streaming<Req, Resp>(
        channel: &Channel,
        method: &Method<Req, Resp>,
        req: &Req,
        mut opt: CallOption,
    ) -> Result<ClientSStreamReceiver<Resp>> {
        let call = channel.create_call(method, &opt)?;
        let mut payload = GrpcSlice::default();
        (method.req_ser())(req, &mut payload);
        // SAFETY: `call.call` was just created above; `payload` and `opt.headers`
        // are live locals for the duration of this FFI call (see note in
        // `unary_async` about what happens after return).
        let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_start_server_streaming(
                call.call,
                ctx,
                payload.as_mut_ptr(),
                opt.write_flags.flags,
                // Null pointer when no custom headers were configured.
                opt.headers
                    .as_mut()
                    .map_or_else(ptr::null_mut, |c| c as *mut _ as _),
                opt.call_flags,
                tag,
            )
        });

        // TODO: handle header
        // The returned batch future is discarded on purpose for now.
        check_run(BatchType::Finish, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_recv_initial_metadata(call.call, ctx, tag)
        });

        Ok(ClientSStreamReceiver::new(call, cq_f, method.resp_de()))
    }

    /// Starts a bidirectional (duplex) streaming RPC.
    ///
    /// Starts the call with `opt`'s headers and call flags (`BatchType::Finish`)
    /// and requests the server's initial metadata. That result is currently
    /// discarded. Returns a request sink and a response stream, which share one
    /// `ShareCall` behind an `Arc<Mutex<_>>`.
    ///
    /// # Errors
    ///
    /// Returns an error if `channel.create_call` fails.
    pub fn duplex_streaming<Req, Resp>(
        channel: &Channel,
        method: &Method<Req, Resp>,
        mut opt: CallOption,
    ) -> Result<(ClientDuplexSender<Req>, ClientDuplexReceiver<Resp>)> {
        let call = channel.create_call(method, &opt)?;
        // SAFETY: `call.call` was just created above; `opt.headers` is a live
        // local for the duration of this FFI call (see note in `unary_async`).
        let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_start_duplex_streaming(
                call.call,
                ctx,
                // Null pointer when no custom headers were configured.
                opt.headers
                    .as_mut()
                    .map_or_else(ptr::null_mut, |c| c as *mut _ as _),
                opt.call_flags,
                tag,
            )
        });

        // TODO: handle header.
        // The returned batch future is discarded on purpose for now.
        check_run(BatchType::Finish, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_recv_initial_metadata(call.call, ctx, tag)
        });

        let share_call = Arc::new(Mutex::new(ShareCall::new(call, cq_f)));
        let sink = ClientDuplexSender::new(share_call.clone(), method.req_ser());
        let recv = ClientDuplexReceiver::new(share_call, method.resp_de());
        Ok((sink, recv))
    }
}
218
219 /// A receiver for unary request.
220 ///
221 /// The future is resolved once response is received.
222 #[must_use = "if unused the ClientUnaryReceiver may immediately cancel the RPC"]
223 pub struct ClientUnaryReceiver<T> {
224 call: Call,
225 resp_f: BatchFuture,
226 resp_de: DeserializeFn<T>,
227 }
228
229 impl<T> ClientUnaryReceiver<T> {
new(call: Call, resp_f: BatchFuture, resp_de: DeserializeFn<T>) -> ClientUnaryReceiver<T>230 fn new(call: Call, resp_f: BatchFuture, resp_de: DeserializeFn<T>) -> ClientUnaryReceiver<T> {
231 ClientUnaryReceiver {
232 call,
233 resp_f,
234 resp_de,
235 }
236 }
237
238 /// Cancel the call.
239 #[inline]
cancel(&mut self)240 pub fn cancel(&mut self) {
241 self.call.cancel()
242 }
243
244 #[inline]
resp_de(&self, reader: MessageReader) -> Result<T>245 pub fn resp_de(&self, reader: MessageReader) -> Result<T> {
246 (self.resp_de)(reader)
247 }
248 }
249
250 impl<T> Future for ClientUnaryReceiver<T> {
251 type Output = Result<T>;
252
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<T>>253 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<T>> {
254 let data = ready!(Pin::new(&mut self.resp_f).poll(cx)?);
255 let t = self.resp_de(data.unwrap())?;
256 Poll::Ready(Ok(t))
257 }
258 }
259
260 /// A receiver for client streaming call.
261 ///
262 /// If the corresponding sink has dropped or cancelled, this will poll a
263 /// [`RpcFailure`] error with the [`Cancelled`] status.
264 ///
265 /// [`RpcFailure`]: ./enum.Error.html#variant.RpcFailure
266 /// [`Cancelled`]: ./enum.RpcStatusCode.html#variant.Cancelled
267 #[must_use = "if unused the ClientCStreamReceiver may immediately cancel the RPC"]
268 pub struct ClientCStreamReceiver<T> {
269 call: Arc<Mutex<ShareCall>>,
270 resp_de: DeserializeFn<T>,
271 finished: bool,
272 }
273
274 impl<T> ClientCStreamReceiver<T> {
275 /// Cancel the call.
cancel(&mut self)276 pub fn cancel(&mut self) {
277 let lock = self.call.lock();
278 lock.call.cancel()
279 }
280
281 #[inline]
resp_de(&self, reader: MessageReader) -> Result<T>282 pub fn resp_de(&self, reader: MessageReader) -> Result<T> {
283 (self.resp_de)(reader)
284 }
285 }
286
287 impl<T> Drop for ClientCStreamReceiver<T> {
288 /// The corresponding RPC will be canceled if the receiver did not
289 /// finish before dropping.
drop(&mut self)290 fn drop(&mut self) {
291 if !self.finished {
292 self.cancel();
293 }
294 }
295 }
296
297 impl<T> Future for ClientCStreamReceiver<T> {
298 type Output = Result<T>;
299
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<T>>300 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<T>> {
301 let data = {
302 let mut call = self.call.lock();
303 ready!(call.poll_finish(cx)?)
304 };
305 let t = (self.resp_de)(data.unwrap())?;
306 self.finished = true;
307 Poll::Ready(Ok(t))
308 }
309 }
310
311 /// A sink for client streaming call and duplex streaming call.
312 /// To close the sink properly, you should call [`close`] before dropping.
313 ///
314 /// [`close`]: #method.close
315 #[must_use = "if unused the StreamingCallSink may immediately cancel the RPC"]
316 pub struct StreamingCallSink<Req> {
317 call: Arc<Mutex<ShareCall>>,
318 sink_base: SinkBase,
319 close_f: Option<BatchFuture>,
320 req_ser: SerializeFn<Req>,
321 }
322
323 impl<Req> StreamingCallSink<Req> {
new(call: Arc<Mutex<ShareCall>>, req_ser: SerializeFn<Req>) -> StreamingCallSink<Req>324 fn new(call: Arc<Mutex<ShareCall>>, req_ser: SerializeFn<Req>) -> StreamingCallSink<Req> {
325 StreamingCallSink {
326 call,
327 sink_base: SinkBase::new(false),
328 close_f: None,
329 req_ser,
330 }
331 }
332
333 /// By default it always sends messages with their configured buffer hint. But when the
334 /// `enhance_batch` is enabled, messages will be batched together as many as possible.
335 /// The rules are listed as below:
336 /// - All messages except the last one will be sent with `buffer_hint` set to true.
337 /// - The last message will also be sent with `buffer_hint` set to true unless any message is
338 /// offered with buffer hint set to false.
339 ///
340 /// No matter `enhance_batch` is true or false, it's recommended to follow the contract of
341 /// Sink and call `poll_flush` to ensure messages are handled by gRPC C Core.
enhance_batch(&mut self, flag: bool)342 pub fn enhance_batch(&mut self, flag: bool) {
343 self.sink_base.enhance_buffer_strategy = flag;
344 }
345
cancel(&mut self)346 pub fn cancel(&mut self) {
347 let call = self.call.lock();
348 call.call.cancel()
349 }
350 }
351
352 impl<P> Drop for StreamingCallSink<P> {
353 /// The corresponding RPC will be canceled if the sink did not call
354 /// [`close`] before dropping.
355 ///
356 /// [`close`]: #method.close
drop(&mut self)357 fn drop(&mut self) {
358 if self.close_f.is_none() {
359 self.cancel();
360 }
361 }
362 }
363
364 impl<Req> Sink<(Req, WriteFlags)> for StreamingCallSink<Req> {
365 type Error = Error;
366
367 #[inline]
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>>368 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
369 Pin::new(&mut self.sink_base).poll_ready(cx)
370 }
371
372 #[inline]
start_send(mut self: Pin<&mut Self>, (msg, flags): (Req, WriteFlags)) -> Result<()>373 fn start_send(mut self: Pin<&mut Self>, (msg, flags): (Req, WriteFlags)) -> Result<()> {
374 {
375 let mut call = self.call.lock();
376 call.check_alive()?;
377 }
378 let t = &mut *self;
379 Pin::new(&mut t.sink_base).start_send(&mut t.call, &msg, flags, t.req_ser)
380 }
381
382 #[inline]
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>>383 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
384 {
385 let mut call = self.call.lock();
386 call.check_alive()?;
387 }
388 let t = &mut *self;
389 Pin::new(&mut t.sink_base).poll_flush(cx, &mut t.call)
390 }
391
poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>>392 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
393 let t = &mut *self;
394 let mut call = t.call.lock();
395 if t.close_f.is_none() {
396 ready!(Pin::new(&mut t.sink_base).poll_ready(cx)?);
397
398 let close_f = call.call.start_send_close_client()?;
399 t.close_f = Some(close_f);
400 }
401
402 if Pin::new(t.close_f.as_mut().unwrap()).poll(cx)?.is_pending() {
403 // if call is finished, can return early here.
404 call.check_alive()?;
405 return Poll::Pending;
406 }
407 Poll::Ready(Ok(()))
408 }
409 }
410
411 /// A sink for client streaming call.
412 ///
413 /// To close the sink properly, you should call [`close`] before dropping.
414 ///
415 /// [`close`]: #method.close
416 pub type ClientCStreamSender<T> = StreamingCallSink<T>;
417 /// A sink for duplex streaming call.
418 ///
419 /// To close the sink properly, you should call [`close`] before dropping.
420 ///
421 /// [`close`]: #method.close
422 pub type ClientDuplexSender<T> = StreamingCallSink<T>;
423
424 struct ResponseStreamImpl<H, T> {
425 call: H,
426 msg_f: Option<BatchFuture>,
427 read_done: bool,
428 finished: bool,
429 resp_de: DeserializeFn<T>,
430 }
431
impl<H: ShareCallHolder + Unpin, T> ResponseStreamImpl<H, T> {
    /// Creates fresh stream state around `call`: no receive started yet,
    /// neither end-of-messages nor call completion observed.
    fn new(call: H, resp_de: DeserializeFn<T>) -> ResponseStreamImpl<H, T> {
        ResponseStreamImpl {
            call,
            msg_f: None,
            read_done: false,
            finished: false,
            resp_de,
        }
    }

    /// Cancels the underlying RPC through the call holder.
    fn cancel(&mut self) {
        self.call.call(|c| c.call.cancel())
    }

    /// Polls for the next response message.
    ///
    /// Return values:
    /// - `Ready(Some(Ok(msg)))` for each decoded message.
    /// - `Ready(None)` once a receive yields no message *and* the call has
    ///   finished.
    /// - `Ready(Some(Err(_)))` if any of the following fails: polling the finish
    ///   batch, polling a receive batch, starting a receive, or decoding.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Result<T>>> {
        // Poll the finish batch first, so an error from it is surfaced before
        // any message handling, and mirror the shared `finished` flag locally.
        if !self.finished {
            let t = &mut *self;
            let finished = &mut t.finished;
            let _ = t.call.call(|c| {
                let res = c.poll_finish(cx);
                *finished = c.finished;
                res
            })?;
        }

        let mut bytes = None;
        loop {
            // Wait for the in-flight receive, if there is one. A `None` result
            // marks the end of the message stream.
            if !self.read_done {
                if let Some(msg_f) = &mut self.msg_f {
                    bytes = ready!(Pin::new(msg_f).poll(cx)?);
                    if bytes.is_none() {
                        self.read_done = true;
                    }
                }
            }

            if self.read_done {
                if self.finished {
                    return Poll::Ready(None);
                }
                // NOTE(review): relies on the `poll_finish(cx)` above having
                // registered `cx`'s waker so this task is woken when the call
                // finishes — confirm against ShareCall::poll_finish.
                return Poll::Pending;
            }

            // so msg_f must be either stale or not initialised yet.
            self.msg_f.take();
            // Start the next receive before yielding the current message, so a
            // receive is already in flight on the next poll.
            let msg_f = self.call.call(|c| c.call.start_recv_message())?;
            self.msg_f = Some(msg_f);
            if let Some(data) = bytes {
                let msg = (self.resp_de)(data)?;
                return Poll::Ready(Some(Ok(msg)));
            }
            // No message yet (e.g. first poll): loop to poll the receive just started.
        }
    }

    // Called from an owner's `Drop`: cancel the call unless the message stream
    // has ended (`read_done`) and the call has finished (`finished`), i.e. if we
    // may still have some messages or did not receive the status code.
    fn on_drop(&mut self) {
        if !self.read_done || !self.finished {
            self.cancel();
        }
    }
}
495
496 /// A receiver for server streaming call.
497 #[must_use = "if unused the ClientSStreamReceiver may immediately cancel the RPC"]
498 pub struct ClientSStreamReceiver<Resp> {
499 imp: ResponseStreamImpl<ShareCall, Resp>,
500 }
501
502 impl<Resp> ClientSStreamReceiver<Resp> {
new( call: Call, finish_f: BatchFuture, de: DeserializeFn<Resp>, ) -> ClientSStreamReceiver<Resp>503 fn new(
504 call: Call,
505 finish_f: BatchFuture,
506 de: DeserializeFn<Resp>,
507 ) -> ClientSStreamReceiver<Resp> {
508 let share_call = ShareCall::new(call, finish_f);
509 ClientSStreamReceiver {
510 imp: ResponseStreamImpl::new(share_call, de),
511 }
512 }
513
cancel(&mut self)514 pub fn cancel(&mut self) {
515 self.imp.cancel()
516 }
517 }
518
519 impl<Resp> Stream for ClientSStreamReceiver<Resp> {
520 type Item = Result<Resp>;
521
522 #[inline]
poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>>523 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
524 Pin::new(&mut self.imp).poll(cx)
525 }
526 }
527
528 /// A response receiver for duplex call.
529 ///
530 /// If the corresponding sink has dropped or cancelled, this will poll a
531 /// [`RpcFailure`] error with the [`Cancelled`] status.
532 ///
533 /// [`RpcFailure`]: ./enum.Error.html#variant.RpcFailure
534 /// [`Cancelled`]: ./enum.RpcStatusCode.html#variant.Cancelled
535 #[must_use = "if unused the ClientDuplexReceiver may immediately cancel the RPC"]
536 pub struct ClientDuplexReceiver<Resp> {
537 imp: ResponseStreamImpl<Arc<Mutex<ShareCall>>, Resp>,
538 }
539
540 impl<Resp> ClientDuplexReceiver<Resp> {
new(call: Arc<Mutex<ShareCall>>, de: DeserializeFn<Resp>) -> ClientDuplexReceiver<Resp>541 fn new(call: Arc<Mutex<ShareCall>>, de: DeserializeFn<Resp>) -> ClientDuplexReceiver<Resp> {
542 ClientDuplexReceiver {
543 imp: ResponseStreamImpl::new(call, de),
544 }
545 }
546
cancel(&mut self)547 pub fn cancel(&mut self) {
548 self.imp.cancel()
549 }
550 }
551
552 impl<Resp> Drop for ClientDuplexReceiver<Resp> {
553 /// The corresponding RPC will be canceled if the receiver did not
554 /// finish before dropping.
drop(&mut self)555 fn drop(&mut self) {
556 self.imp.on_drop()
557 }
558 }
559
560 impl<Resp> Stream for ClientDuplexReceiver<Resp> {
561 type Item = Result<Resp>;
562
poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>>563 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
564 Pin::new(&mut self.imp).poll(cx)
565 }
566 }
567
#[cfg(test)]
mod tests {
    use super::change_flag;

    /// Setting adds the bits of `flag`; clearing removes only those bits.
    #[test]
    fn test_change_flag() {
        let mut bits = 0b0110;
        change_flag(&mut bits, 0b1000, true);
        assert_eq!(bits, 0b1110);
        change_flag(&mut bits, 0b0100, false);
        assert_eq!(bits, 0b1010);
    }
}
579