• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
18 
19 #include "grpcpp/completion_queue.h"
20 #include "grpcpp/impl/service_type.h"
21 #include "grpcpp/server_builder.h"
22 #include "grpcpp/server_context.h"
23 #include "grpcpp/support/async_stream.h"
24 #include "grpcpp/support/async_unary_call.h"
25 #include "tensorflow/core/lib/core/refcount.h"
26 #include "tensorflow/core/platform/macros.h"
27 #include "tensorflow/core/platform/mutex.h"
28 
29 namespace tensorflow {
30 
31 // CALL STRUCTURES
32 // ===============
33 //
34 // Each pending (incoming) request corresponds to a call object that
35 // encapsulates the state of the call. Templates and
36 // pointers-to-member functions are used to avoid boilerplate and
37 // redundant closure creation. The class hierarchy is as follows:
38 //
39 // * `UntypedCall<Service>`: The base class represents a call that
40 //   could be associated with any of the methods on a service of type
41 //   `Service`. Also defines a `Tag` nested class that can be used as
//   the tag in a `grpc::CompletionQueue`. Each concrete `Service`
//   implementation should run a completion queue polling loop that
//   knows about `UntypedCall<Service>::Tag` objects and invokes their
//   `OnCompleted()` method to continue processing.
46 //
47 // * `Call<Service, GrpcService, Req, Resp>`: This class extends
48 //   `UntypedCall<Service>` and is additionally parameterized by the
49 //   gRPC-generated asynchronous service class, and the request and
50 //   response message types. It defines the state associated with a
51 //   call (whose type depends on the message types), and stores a
52 //   pointer to a `Service::HandleFoo()` handler method. Each
53 //   `Service::HandleFoo()` method knows about the corresponding
54 //   `Call` type, in order to access its state, and invoke its
55 //   `SendResponse()` method.
56 //
57 // The lifecycle of a call object is as follows.
58 //
59 // 1. A `Service` creates a `Call` for a particular method and
60 //    enqueues it in its completion queue (via an
61 //    `UntypedCall<Service>::Tag`).
62 //
63 // 2. When the tag is returned from `cq_->Next()`, the
64 //    `UntypedCall::RequestReceived()` method is invoked and takes
65 //    ownership of the call object. This indirectly invokes the
66 //    appropriate handler method on `Service`.
67 //
68 // 3. After the response has been written (perhaps in another thread),
69 //    the `Call::SendResponse()` method is invoked. It transfers
70 //    ownership of the call object back to the completion queue (via
71 //    an `UntypedCall::Tag`).
72 //
73 // 4. When the response has been sent, the tag is returned from
74 //    `cq_->Next()`, and the call object is deleted.
75 //
76 
// Abstract tag handed to gRPC as the `void*` tag of an asynchronous
// operation. A completion-queue polling loop is expected to call
// `OnCompleted()` on each tag it dequeues.
template <class Service>
class GrpcCallTag {
 public:
  // Tags are deleted and invoked through base-class pointers.
  virtual ~GrpcCallTag() = default;

  // Runs whatever action is bound to this tag. `ok` is presumably the status
  // flag reported by the completion queue for the finished operation.
  virtual void OnCompleted(Service* service, bool ok) = 0;
};
85 
86 // Represents a pending request with unknown message types.
87 template <class Service>
88 class UntypedCall : public core::RefCounted {
89  public:
~UntypedCall()90   virtual ~UntypedCall() {}
91 
92   // The implementation of this method should use `service` to handle
93   // an incoming request, and (perhaps asynchronously) send the
94   // response.
95   //
96   // One reference on `this` is transferred to the callee, and the
97   // callee is responsible for releasing it (typically via
98   // `Call::SendResponse()`).
99   //
100   // `ok` is true if the request was received in a "regular event",
101   // otherwise false.
102   virtual void RequestReceived(Service* service, bool ok) = 0;
103 
104   // This method will be called either (i) when the server is notified
105   // that the request has been canceled, or (ii) when the request completes
106   // normally. The implementation should distinguish these cases by querying
107   // the `grpc::ServerContext` associated with the request.
108   virtual void RequestCancelled(Service* service, bool ok) = 0;
109 
110   // Associates a tag in a `::grpc::CompletionQueue` with a callback
111   // for an incoming RPC.  An active Tag owns a reference on the corresponding
112   // Call object.
113   class Tag : public GrpcCallTag<Service> {
114    public:
115     // One enum value per supported callback.
116     enum Callback { kRequestReceived, kResponseSent, kCancelled };
117 
Tag(UntypedCall * call,Callback cb)118     Tag(UntypedCall* call, Callback cb) : call_(call), callback_(cb) {}
119 
120     // Calls the callback associated with this tag.
121     //
122     // The callback takes ownership of `this->call_`.
OnCompleted(Service * service,bool ok)123     void OnCompleted(Service* service, bool ok) override {
124       switch (callback_) {
125         case kRequestReceived:
126           call_->RequestReceived(service, ok);
127           break;
128         case kResponseSent:
129           // No special handling needed apart from the Unref below.
130           break;
131         case kCancelled:
132           call_->RequestCancelled(service, ok);
133           break;
134       }
135       call_->Unref();  // Ref acquired when tag handed to grpc.
136     }
137 
138    private:
139     UntypedCall* const call_;  // `this` owns one reference.
140     Callback callback_;
141   };
142 };
143 
144 // Represents a pending call with known request and response message
145 // types, and a known request-handling method.
146 template <class Service, class GrpcService, class RequestMessage,
147           class ResponseMessage>
148 class Call : public UntypedCall<Service> {
149  public:
150   // Represents the generic signature of a generated
151   // `GrpcService::RequestFoo()` method, where `Foo` is the name of an
152   // RPC method.
153   using EnqueueFunction = void (GrpcService::*)(
154       ::grpc::ServerContext*, RequestMessage*,
155       ::grpc::ServerAsyncResponseWriter<ResponseMessage>*,
156       ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*);
157 
158   // Represents the generic signature of a `Service::HandleFoo()`
159   // method, where `Foo` is the name of an RPC method.
160   using HandleRequestFunction = void (Service::*)(
161       Call<Service, GrpcService, RequestMessage, ResponseMessage>*);
162 
Call(HandleRequestFunction handle_request_function)163   Call(HandleRequestFunction handle_request_function)
164       : handle_request_function_(handle_request_function), responder_(&ctx_) {}
165 
~Call()166   virtual ~Call() {}
167 
RequestReceived(Service * service,bool ok)168   void RequestReceived(Service* service, bool ok) override {
169     if (ok) {
170       this->Ref();
171       (service->*handle_request_function_)(this);
172     }
173   }
174 
SendResponse(::grpc::Status status)175   void SendResponse(::grpc::Status status) {
176     this->Ref();  // Ref for grpc; released in Tag callback.
177     responder_.Finish(response, status, &response_sent_tag_);
178     this->Unref();
179   }
180 
RequestCancelled(Service * service,bool ok)181   void RequestCancelled(Service* service, bool ok) override {
182     if (ctx_.IsCancelled()) {
183       mutex_lock l(mu_);
184       if (cancel_callback_) {
185         cancel_callback_();
186       }
187     }
188   }
189 
190   // Registers `callback` as the function that should be called if and when this
191   // call is canceled by the client.
SetCancelCallback(std::function<void ()> callback)192   void SetCancelCallback(std::function<void()> callback) {
193     mutex_lock l(mu_);
194     cancel_callback_ = std::move(callback);
195   }
196 
197   // Clears any cancellation callback that has been registered for this call.
ClearCancelCallback()198   void ClearCancelCallback() {
199     mutex_lock l(mu_);
200     cancel_callback_ = nullptr;
201   }
202 
203   // Enqueues a new request for the given service on the given
204   // completion queue, using the given `enqueue_function`.
205   //
206   // The request will be handled with the given
207   // `handle_request_function`.
EnqueueRequest(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,EnqueueFunction enqueue_function,HandleRequestFunction handle_request_function,bool supports_cancel)208   static void EnqueueRequest(GrpcService* grpc_service,
209                              ::grpc::ServerCompletionQueue* cq,
210                              EnqueueFunction enqueue_function,
211                              HandleRequestFunction handle_request_function,
212                              bool supports_cancel) {
213     auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
214         handle_request_function);
215     if (supports_cancel) {
216       call->RegisterCancellationHandler();
217     }
218 
219     // Initial ref for call handed to grpc; released in Tag callback.
220     (grpc_service->*enqueue_function)(&call->ctx_, &call->request,
221                                       &call->responder_, cq, cq,
222                                       &call->request_received_tag_);
223   }
224 
225   // Enqueues a new request for the given service on the given
226   // completion queue, using the given `method_id`.
227   //
228   // The request will be handled with the given
229   // `handle_request_function`.
EnqueueRequestForMethod(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,int method_id,HandleRequestFunction handle_request_function,bool supports_cancel)230   static void EnqueueRequestForMethod(
231       GrpcService* grpc_service, ::grpc::ServerCompletionQueue* cq,
232       int method_id, HandleRequestFunction handle_request_function,
233       bool supports_cancel) {
234     auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
235         handle_request_function);
236     if (supports_cancel) {
237       call->RegisterCancellationHandler();
238     }
239 
240     // Initial ref for call handed to grpc; released in Tag callback.
241     grpc_service->RequestAsyncUnary(method_id, &call->ctx_, &call->request,
242                                     &call->responder_, cq, cq,
243                                     &call->request_received_tag_);
244   }
245 
246   RequestMessage request;
247   ResponseMessage response;
248 
client_metadata()249   const std::multimap<::grpc::string_ref, ::grpc::string_ref>& client_metadata()
250       const {
251     return ctx_.client_metadata();
252   }
253 
254  private:
255   // Creates a completion queue tag for handling cancellation by the client.
256   // NOTE: This method must be called before this call is enqueued on a
257   // completion queue.
RegisterCancellationHandler()258   void RegisterCancellationHandler() {
259     this->Ref();  // Ref for grpc; released in Tag callback.
260     ctx_.AsyncNotifyWhenDone(&cancelled_tag_);
261   }
262 
263   HandleRequestFunction handle_request_function_;
264   ::grpc::ServerContext ctx_;
265   ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
266 
267   // Used as void* completion markers from grpc to indicate different
268   // events of interest for a Call.
269   typedef typename UntypedCall<Service>::Tag Tag;
270   Tag request_received_tag_{this, Tag::kRequestReceived};
271   Tag response_sent_tag_{this, Tag::kResponseSent};
272   Tag cancelled_tag_{this, Tag::kCancelled};
273 
274   mutex mu_;
275   std::function<void()> cancel_callback_ TF_GUARDED_BY(mu_);
276 };
277 
278 // Lifetime of a server-side bidirectional streaming call:
279 // - The call is created in the static EnqueueRequest method. It transfers
280 //   ownership to the kCallOpen tag pushed onto the completion queue.
281 // - If kCallOpen completes successfully, a read is requested and the
282 //   kRequestReceived tag takes ownership of the call. If kCallOpen fails,
283 //   e.g. server is shutdown, no further requests are pushed and the call is
284 //   destroyed (at the end of Tag::OnCompleted).
285 // - When the first request is received, we Ref() the call and invoke the
286 //   handler method thereby transferring ownership to the handler method.
287 //   The handler is responsible for calling SendResponse() or Finish() on this
288 //   call.
289 //   - If the handler calls Finish(), e.g. the request was invalid, Finish()
290 //     transfers ownership from the handler to the kServerFinished tag that
291 //     it pushes on the completion queue. The ownership is transferred because
292 //     the ref count is not incremented before putting the tag on the queue.
293 //   - If the handler calls SendResponse(), SendResponse() transfers ownership
294 //     to the kResponseSent tag.
295 // - When kResponseSent completes, we request a new read, which owns the call
296 //   now.
297 // - When the next request is received, it is handled the same way as the first
298 //   request.
299 //
300 // Because we request a read only after the write is sent, we can safely reuse
301 // the same request and response messages for the whole call.
302 template <class Service>
303 class ServerUntypedBidirectionalStreamingCall : public core::RefCounted {
304  public:
305   virtual void RequestReceived(Service* service) = 0;
306 
307   // Enqueues a request on the completion queue to read the next request.
308   virtual void CallOpen() = 0;
309 
310   virtual void RequestRead() = 0;
311 
312   // Associates a tag in a `::grpc::CompletionQueue` with a callback.
313   // An active Tag owns a reference on the corresponding Call object.
314   class Tag : public GrpcCallTag<Service> {
315    public:
316     // One enum value per supported callback.
317     enum class TagType {
318       kCallOpen,
319       kRequestReceived,
320       kResponseSent,
321       kServerFinished,
322     };
323 
Tag(ServerUntypedBidirectionalStreamingCall * call,TagType cb)324     Tag(ServerUntypedBidirectionalStreamingCall* call, TagType cb)
325         : call_(call), callback_(cb) {}
326 
327     // Calls the callback associated with this tag and Unrefs this->call_.
OnCompleted(Service * service,bool ok)328     void OnCompleted(Service* service, bool ok) override {
329       switch (callback_) {
330         case TagType::kCallOpen:
331           // Non-ok value indicates that the server has been shutdown before we
332           // received a message for this call type. We do nothing to let this
333           // call object be destroyed and avoid enqueuing request for another
334           // call.
335           if (ok) {
336             call_->CallOpen();
337           }
338           break;
339         case TagType::kRequestReceived:
340           // Non-ok value from completion queue here means that we will not
341           // receive any more messages from the client, e.g. the client called
342           // WritesDone. There is nothing we need to do in this case. The call
343           // will be Unref'ed and deleted. If the client wants to open a new
344           // call, we have already enqueued a request for a new call in CallOpen
345           // above.
346           if (ok) {
347             call_->RequestReceived(service);
348           }
349           break;
350         case TagType::kResponseSent:
351           if (ok) {
352             // The obvious place to request a read would be at the end of
353             // RequestReceived(). Unfortunately, this can result in multiple
354             // outstanding write requests in the completion queue. This is
355             // currently not supported by gRPC, which requires at most one
356             // outstanding write request in the completion queue.
357             // Requesting a read here, in ResponseSent, works because at
358             // this point, the completion queue has no write requests
359             // (kResponseSent happens when a write completes).
360             // This might be synchronizing the processing more than strictly
361             // necessary, but is probably fine because, AFAICT from gRPC docs,
362             // the write request completes as soon as it can be written to
363             // outgoing buffer.
364             call_->RequestRead();
365           }
366           // ok == false means that the response is not going on the wire
367           // because the call is already dead (i.e., canceled, deadline
368           // expired, other side dropped the channel, etc). Since the call is
369           // dead, there is nothing for us to do, we just let the call be
370           // deleted.
371           break;
372         case TagType::kServerFinished:
373           // Whether our finish request is successful or not (whether it went
374           // on the wire towards the client), there is nothing for us to do.
375           // In the current implementation, there can be no read or write
376           // requests in the completion queue (see the comment in kResponseSent)
377           // above. Even if there were pending requests, they would complete
378           // with a non-ok status, we would not do anything, and let the call be
379           // deleted.
380           break;
381       }
382       call_->Unref();  // Ref acquired when tag was handed to grpc.
383     }
384 
385    private:
386     ServerUntypedBidirectionalStreamingCall* const
387         call_;  // `this` owns one reference.
388     TagType callback_;
389   };
390 };
391 
392 // Represents a pending call with known request and response message
393 // types, and a known request-handling method.
394 // Common usage pattern is to have a single thread waiting on events from
395 // completion queue and calling Tag::OnCompleted(), which invokes methods
396 // on this.
397 // This implementation assumes that the server will generate a single response
398 // message for each request message. More precisely, this class expects that
399 // each time it invokes handle_request_function_, the service implementation
400 // will either call SendResponse or Finish exactly once.
401 // Not thread-safe.
402 template <class Service, class GrpcService, class RequestMessage,
403           class ResponseMessage>
404 class ServerBidirectionalStreamingCall
405     : public ServerUntypedBidirectionalStreamingCall<Service> {
406  public:
407   // Represents the generic signature of a generated
408   // `GrpcService::RequestFoo()` method, where `Foo` is the name of an
409   // RPC method.
410   using EnqueueFunction = void (GrpcService::*)(
411       ::grpc::ServerContext*,
412       ::grpc::ServerAsyncReaderWriter<ResponseMessage, RequestMessage>*,
413       ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*);
414 
415   // Represents the generic signature of a `Service::HandleFoo()`
416   // method, where `Foo` is the name of an RPC method.
417   using HandleRequestFunction = void (Service::*)(
418       ServerBidirectionalStreamingCall<Service, GrpcService, RequestMessage,
419                                        ResponseMessage>*);
420 
ServerBidirectionalStreamingCall(HandleRequestFunction handle_request_function,GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,EnqueueFunction enqueue_function)421   ServerBidirectionalStreamingCall(
422       HandleRequestFunction handle_request_function, GrpcService* grpc_service,
423       ::grpc::ServerCompletionQueue* cq, EnqueueFunction enqueue_function)
424       : handle_request_function_(handle_request_function),
425         stream_(&ctx_),
426         grpc_service_(grpc_service),
427         cq_(cq),
428         enqueue_function_(enqueue_function) {
429     VLOG(3) << "Creating ServerBidirectionalStreamingCall " << this;
430   }
431 
~ServerBidirectionalStreamingCall()432   ~ServerBidirectionalStreamingCall() override {
433     VLOG(3) << "Destroying ServerBidirectionalStreamingCall " << this;
434   }
435 
CallOpen()436   void CallOpen() override {
437     // Let gRPC know that we can accept another call.
438     ServerBidirectionalStreamingCall<
439         Service, GrpcService, RequestMessage,
440         ResponseMessage>::EnqueueRequest(grpc_service_, cq_, enqueue_function_,
441                                          handle_request_function_);
442     RequestRead();
443   }
444 
RequestRead()445   void RequestRead() override {
446     this->Ref();
447     request_.Clear();
448     stream_.Read(&request_, &request_received_tag_);
449   }
450 
RequestReceived(Service * service)451   void RequestReceived(Service* service) override {
452     this->Ref();
453     // Request handling should result in a call to SendResponse or Finish.
454     (service->*handle_request_function_)(this);
455   }
456 
SendResponse()457   void SendResponse() {
458     // Transferring ownership of this to the response_sent_tag_.
459     stream_.Write(response_, &response_sent_tag_);
460     // stream_.Write does not save references to response_. We are free to muck
461     // around with it as soon as Write returns.
462     // We clear the response_ to prepare it for the next response.
463     response_.Clear();
464   }
465 
Finish(::grpc::Status status)466   void Finish(::grpc::Status status) {
467     // Transferring ownership of this to the server_finished_tag_.
468     stream_.Finish(status, &server_finished_tag_);
469   }
470 
471   // Enqueues a new request for the given service on the given
472   // completion queue, using the given `enqueue_function`.
473   //
474   // The request will be handled by the given `handle_request_function`.
EnqueueRequest(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,EnqueueFunction enqueue_function,HandleRequestFunction handle_request_function)475   static void EnqueueRequest(GrpcService* grpc_service,
476                              ::grpc::ServerCompletionQueue* cq,
477                              EnqueueFunction enqueue_function,
478                              HandleRequestFunction handle_request_function) {
479     auto call =
480         new ServerBidirectionalStreamingCall<Service, GrpcService,
481                                              RequestMessage, ResponseMessage>(
482             handle_request_function, grpc_service, cq, enqueue_function);
483 
484     // Initial ref for call handed to grpc; released in Tag callback.
485     (grpc_service->*enqueue_function)(&call->ctx_, &call->stream_, cq, cq,
486                                       &call->call_open_tag_);
487   }
488 
request()489   const RequestMessage& request() const { return request_; }
mutable_response()490   ResponseMessage* mutable_response() { return &response_; }
491 
492  private:
493   // Request and response messages are reused for each request/response exchange
494   // between the client and the server.
495   RequestMessage request_;
496   ResponseMessage response_;
497   ::grpc::ServerContext ctx_;
498 
499   HandleRequestFunction handle_request_function_;
500   ::grpc::ServerAsyncReaderWriter<ResponseMessage, RequestMessage> stream_;
501 
502   // Used as void* completion markers from grpc to indicate different
503   // events of interest for a ServerBidirectionalStreamingCall.
504   typedef typename ServerUntypedBidirectionalStreamingCall<Service>::Tag Tag;
505   // At most one tag of each kind may be given to gRPC at any one time.
506   // Beyond semantic sanity, this is needed to ensure proper ref counting
507   // of this call object.
508   Tag call_open_tag_{this, Tag::TagType::kCallOpen};
509   Tag request_received_tag_{this, Tag::TagType::kRequestReceived};
510   Tag response_sent_tag_{this, Tag::TagType::kResponseSent};
511   Tag server_finished_tag_{this, Tag::TagType::kServerFinished};
512 
513   // These fields are used only to spawn another instance of this to accept
514   // more streaming calls.
515   GrpcService* grpc_service_;
516   ::grpc::ServerCompletionQueue* cq_;
517   EnqueueFunction enqueue_function_;
518 };
519 
520 }  // namespace tensorflow
521 
522 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
523