1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ 18 19 #include "grpcpp/completion_queue.h" 20 #include "grpcpp/impl/service_type.h" 21 #include "grpcpp/server_builder.h" 22 #include "grpcpp/server_context.h" 23 #include "grpcpp/support/async_stream.h" 24 #include "grpcpp/support/async_unary_call.h" 25 #include "tensorflow/core/lib/core/refcount.h" 26 #include "tensorflow/core/platform/macros.h" 27 #include "tensorflow/core/platform/mutex.h" 28 29 namespace tensorflow { 30 31 // CALL STRUCTURES 32 // =============== 33 // 34 // Each pending (incoming) request corresponds to a call object that 35 // encapsulates the state of the call. Templates and 36 // pointers-to-member functions are used to avoid boilerplate and 37 // redundant closure creation. The class hierarchy is as follows: 38 // 39 // * `UntypedCall<Service>`: The base class represents a call that 40 // could be associated with any of the methods on a service of type 41 // `Service`. Also defines a `Tag` nested class that can be used as 42 // the tag in a `grpc::CompletionQueue`. 
//   Each class used as the `Service` argument should have a completion
//   queue polling loop that knows about `UntypedCall<Service>::Tag`
//   objects, and invokes their `OnCompleted()` method to continue
//   processing.
//
// * `Call<Service, GrpcService, Req, Resp>`: This class extends
//   `UntypedCall<Service>` and is additionally parameterized by the
//   gRPC-generated asynchronous service class, and the request and
//   response message types. It defines the state associated with a
//   call (whose type depends on the message types), and stores a
//   pointer to a `Service::HandleFoo()` handler method. Each
//   `Service::HandleFoo()` method knows about the corresponding
//   `Call` type, in order to access its state, and invoke its
//   `SendResponse()` method.
//
// The lifecycle of a call object is as follows.
//
// 1. A `Service` creates a `Call` for a particular method and
//    enqueues it in its completion queue (via an
//    `UntypedCall<Service>::Tag`).
//
// 2. When the tag is returned from `cq_->Next()`, the
//    `UntypedCall::RequestReceived()` method is invoked and takes
//    ownership of the call object. This indirectly invokes the
//    appropriate handler method on `Service`.
//
// 3. After the response has been written (perhaps in another thread),
//    the `Call::SendResponse()` method is invoked. It transfers
//    ownership of the call object back to the completion queue (via
//    an `UntypedCall::Tag`).
//
// 4. When the response has been sent, the tag is returned from
//    `cq_->Next()`, and the call object is deleted.
//

// Interface implemented by every object handed to gRPC as the opaque
// `void*` tag of a completion-queue operation. The service's polling loop
// casts the tag it receives back to `GrpcCallTag<Service>*` and calls
// `OnCompleted()` to continue processing.
template <class Service>
class GrpcCallTag {
 public:
  virtual ~GrpcCallTag() = default;

  // Calls the callback associated with this tag.
  //
  // `service` is the service whose completion queue returned the tag, and
  // `ok` is the success flag reported by the completion queue for the
  // operation this tag was attached to.
  virtual void OnCompleted(Service* service, bool ok) = 0;
};

// Represents a pending request with unknown message types.
87 template <class Service> 88 class UntypedCall : public core::RefCounted { 89 public: ~UntypedCall()90 virtual ~UntypedCall() {} 91 92 // The implementation of this method should use `service` to handle 93 // an incoming request, and (perhaps asynchronously) send the 94 // response. 95 // 96 // One reference on `this` is transferred to the callee, and the 97 // callee is responsible for releasing it (typically via 98 // `Call::SendResponse()`). 99 // 100 // `ok` is true if the request was received in a "regular event", 101 // otherwise false. 102 virtual void RequestReceived(Service* service, bool ok) = 0; 103 104 // This method will be called either (i) when the server is notified 105 // that the request has been canceled, or (ii) when the request completes 106 // normally. The implementation should distinguish these cases by querying 107 // the `grpc::ServerContext` associated with the request. 108 virtual void RequestCancelled(Service* service, bool ok) = 0; 109 110 // Associates a tag in a `::grpc::CompletionQueue` with a callback 111 // for an incoming RPC. An active Tag owns a reference on the corresponding 112 // Call object. 113 class Tag : public GrpcCallTag<Service> { 114 public: 115 // One enum value per supported callback. 116 enum Callback { kRequestReceived, kResponseSent, kCancelled }; 117 Tag(UntypedCall * call,Callback cb)118 Tag(UntypedCall* call, Callback cb) : call_(call), callback_(cb) {} 119 120 // Calls the callback associated with this tag. 121 // 122 // The callback takes ownership of `this->call_`. OnCompleted(Service * service,bool ok)123 void OnCompleted(Service* service, bool ok) override { 124 switch (callback_) { 125 case kRequestReceived: 126 call_->RequestReceived(service, ok); 127 break; 128 case kResponseSent: 129 // No special handling needed apart from the Unref below. 130 break; 131 case kCancelled: 132 call_->RequestCancelled(service, ok); 133 break; 134 } 135 call_->Unref(); // Ref acquired when tag handed to grpc. 
136 } 137 138 private: 139 UntypedCall* const call_; // `this` owns one reference. 140 Callback callback_; 141 }; 142 }; 143 144 // Represents a pending call with known request and response message 145 // types, and a known request-handling method. 146 template <class Service, class GrpcService, class RequestMessage, 147 class ResponseMessage> 148 class Call : public UntypedCall<Service> { 149 public: 150 // Represents the generic signature of a generated 151 // `GrpcService::RequestFoo()` method, where `Foo` is the name of an 152 // RPC method. 153 using EnqueueFunction = void (GrpcService::*)( 154 ::grpc::ServerContext*, RequestMessage*, 155 ::grpc::ServerAsyncResponseWriter<ResponseMessage>*, 156 ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*); 157 158 // Represents the generic signature of a `Service::HandleFoo()` 159 // method, where `Foo` is the name of an RPC method. 160 using HandleRequestFunction = void (Service::*)( 161 Call<Service, GrpcService, RequestMessage, ResponseMessage>*); 162 Call(HandleRequestFunction handle_request_function)163 Call(HandleRequestFunction handle_request_function) 164 : handle_request_function_(handle_request_function), responder_(&ctx_) {} 165 ~Call()166 virtual ~Call() {} 167 RequestReceived(Service * service,bool ok)168 void RequestReceived(Service* service, bool ok) override { 169 if (ok) { 170 this->Ref(); 171 (service->*handle_request_function_)(this); 172 } 173 } 174 SendResponse(::grpc::Status status)175 void SendResponse(::grpc::Status status) { 176 this->Ref(); // Ref for grpc; released in Tag callback. 
177 responder_.Finish(response, status, &response_sent_tag_); 178 this->Unref(); 179 } 180 RequestCancelled(Service * service,bool ok)181 void RequestCancelled(Service* service, bool ok) override { 182 if (ctx_.IsCancelled()) { 183 mutex_lock l(mu_); 184 if (cancel_callback_) { 185 cancel_callback_(); 186 } 187 } 188 } 189 190 // Registers `callback` as the function that should be called if and when this 191 // call is canceled by the client. SetCancelCallback(std::function<void ()> callback)192 void SetCancelCallback(std::function<void()> callback) { 193 mutex_lock l(mu_); 194 cancel_callback_ = std::move(callback); 195 } 196 197 // Clears any cancellation callback that has been registered for this call. ClearCancelCallback()198 void ClearCancelCallback() { 199 mutex_lock l(mu_); 200 cancel_callback_ = nullptr; 201 } 202 203 // Enqueues a new request for the given service on the given 204 // completion queue, using the given `enqueue_function`. 205 // 206 // The request will be handled with the given 207 // `handle_request_function`. EnqueueRequest(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,EnqueueFunction enqueue_function,HandleRequestFunction handle_request_function,bool supports_cancel)208 static void EnqueueRequest(GrpcService* grpc_service, 209 ::grpc::ServerCompletionQueue* cq, 210 EnqueueFunction enqueue_function, 211 HandleRequestFunction handle_request_function, 212 bool supports_cancel) { 213 auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>( 214 handle_request_function); 215 if (supports_cancel) { 216 call->RegisterCancellationHandler(); 217 } 218 219 // Initial ref for call handed to grpc; released in Tag callback. 220 (grpc_service->*enqueue_function)(&call->ctx_, &call->request, 221 &call->responder_, cq, cq, 222 &call->request_received_tag_); 223 } 224 225 // Enqueues a new request for the given service on the given 226 // completion queue, using the given `method_id`. 
227 // 228 // The request will be handled with the given 229 // `handle_request_function`. EnqueueRequestForMethod(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,int method_id,HandleRequestFunction handle_request_function,bool supports_cancel)230 static void EnqueueRequestForMethod( 231 GrpcService* grpc_service, ::grpc::ServerCompletionQueue* cq, 232 int method_id, HandleRequestFunction handle_request_function, 233 bool supports_cancel) { 234 auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>( 235 handle_request_function); 236 if (supports_cancel) { 237 call->RegisterCancellationHandler(); 238 } 239 240 // Initial ref for call handed to grpc; released in Tag callback. 241 grpc_service->RequestAsyncUnary(method_id, &call->ctx_, &call->request, 242 &call->responder_, cq, cq, 243 &call->request_received_tag_); 244 } 245 246 RequestMessage request; 247 ResponseMessage response; 248 client_metadata()249 const std::multimap<::grpc::string_ref, ::grpc::string_ref>& client_metadata() 250 const { 251 return ctx_.client_metadata(); 252 } 253 254 private: 255 // Creates a completion queue tag for handling cancellation by the client. 256 // NOTE: This method must be called before this call is enqueued on a 257 // completion queue. RegisterCancellationHandler()258 void RegisterCancellationHandler() { 259 this->Ref(); // Ref for grpc; released in Tag callback. 260 ctx_.AsyncNotifyWhenDone(&cancelled_tag_); 261 } 262 263 HandleRequestFunction handle_request_function_; 264 ::grpc::ServerContext ctx_; 265 ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_; 266 267 // Used as void* completion markers from grpc to indicate different 268 // events of interest for a Call. 
269 typedef typename UntypedCall<Service>::Tag Tag; 270 Tag request_received_tag_{this, Tag::kRequestReceived}; 271 Tag response_sent_tag_{this, Tag::kResponseSent}; 272 Tag cancelled_tag_{this, Tag::kCancelled}; 273 274 mutex mu_; 275 std::function<void()> cancel_callback_ GUARDED_BY(mu_); 276 }; 277 278 // Lifetime of a server-side bidirectional streaming call: 279 // - The call is created in the static EnqueueRequest method. It transfers 280 // ownership to the kCallOpen tag pushed onto the completion queue. 281 // - If kCallOpen completes successfully, a read is requested and the 282 // kRequestReceived tag takes ownership of the call. If kCallOpen fails, 283 // e.g. server is shutdown, no further requests are pushed and the call is 284 // destroyed (at the end of Tag::OnCompleted). 285 // - When the first request is received, we Ref() the call and invoke the 286 // handler method thereby transferring ownership to the handler method. 287 // The handler is responsible for calling SendResponse() or Finish() on this 288 // call. 289 // - If the handler calls Finish(), e.g. the request was invalid, Finish() 290 // transfers ownership from the handler to the kServerFinished tag that 291 // it pushes on the completion queue. The ownership is transferred because 292 // the ref count is not incremented before putting the tag on the queue. 293 // - If the handler calls SendResponse(), SendResponse() transfers ownership 294 // to the kResponseSent tag. 295 // - When kResponseSent completes, we request a new read, which owns the call 296 // now. 297 // - When the next request is received, it is handled the same way as the first 298 // request. 299 // 300 // Because we request a read only after the write is sent, we can safely reuse 301 // the same request and response messages for the whole call. 
302 template <class Service> 303 class ServerUntypedBidirectionalStreamingCall : public core::RefCounted { 304 public: 305 virtual void RequestReceived(Service* service) = 0; 306 307 // Enqueues a request on the completion queue to read the next request. 308 virtual void CallOpen() = 0; 309 310 virtual void RequestRead() = 0; 311 312 // Associates a tag in a `::grpc::CompletionQueue` with a callback. 313 // An active Tag owns a reference on the corresponding Call object. 314 class Tag : public GrpcCallTag<Service> { 315 public: 316 // One enum value per supported callback. 317 enum class TagType { 318 kCallOpen, 319 kRequestReceived, 320 kResponseSent, 321 kServerFinished, 322 }; 323 Tag(ServerUntypedBidirectionalStreamingCall * call,TagType cb)324 Tag(ServerUntypedBidirectionalStreamingCall* call, TagType cb) 325 : call_(call), callback_(cb) {} 326 327 // Calls the callback associated with this tag and Unrefs this->call_. OnCompleted(Service * service,bool ok)328 void OnCompleted(Service* service, bool ok) override { 329 switch (callback_) { 330 case TagType::kCallOpen: 331 // Non-ok value indicates that the server has been shutdown before we 332 // received a message for this call type. We do nothing to let this 333 // call object be destroyed and avoid enqueuing request for another 334 // call. 335 if (ok) { 336 call_->CallOpen(); 337 } 338 break; 339 case TagType::kRequestReceived: 340 // Non-ok value from completion queue here means that we will not 341 // receive any more messages from the client, e.g. the client called 342 // WritesDone. There is nothing we need to do in this case. The call 343 // will be Unref'ed and deleted. If the client wants to open a new 344 // call, we have already enqueued a request for a new call in CallOpen 345 // above. 346 if (ok) { 347 call_->RequestReceived(service); 348 } 349 break; 350 case TagType::kResponseSent: 351 if (ok) { 352 // The obvious place to request a read would be at the end of 353 // RequestReceived(). 
Unfortunately, this can result in multiple 354 // outstanding write requests in the completion queue. This is 355 // currently not supported by gRPC, which requires at most one 356 // outstanding write request in the completion queue. 357 // Requesting a read here, in ResponseSent, works because at 358 // this point, the completion queue has no write requests 359 // (kResponseSent happens when a write completes). 360 // This might be synchronizing the processing more than strictly 361 // necessary, but is probably fine because, AFAICT from gRPC docs, 362 // the write request completes as soon as it can be written to 363 // outgoing buffer. 364 call_->RequestRead(); 365 } 366 // ok == false means that the response is not going on the wire 367 // because the call is already dead (i.e., canceled, deadline 368 // expired, other side dropped the channel, etc). Since the call is 369 // dead, there is nothing for us to do, we just let the call be 370 // deleted. 371 break; 372 case TagType::kServerFinished: 373 // Whether our finish request is successful or not (whether it went 374 // on the wire towards the client), there is nothing for us to do. 375 // In the current implementation, there can be no read or write 376 // requests in the completion queue (see the comment in kResponseSent) 377 // above. Even if there were pending requests, they would complete 378 // with a non-ok status, we would not do anything, and let the call be 379 // deleted. 380 break; 381 } 382 call_->Unref(); // Ref acquired when tag was handed to grpc. 383 } 384 385 private: 386 ServerUntypedBidirectionalStreamingCall* const 387 call_; // `this` owns one reference. 388 TagType callback_; 389 }; 390 }; 391 392 // Represents a pending call with known request and response message 393 // types, and a known request-handling method. 394 // Common usage pattern is to have a single thread waiting on events from 395 // completion queue and calling Tag::OnCompleted(), which invokes methods 396 // on this. 
397 // This implementation assumes that the server will generate a single response 398 // message for each request message. More precisely, this class expects that 399 // each time it invokes handle_request_function_, the service implementation 400 // will either call SendResponse or Finish exactly once. 401 // Not thread-safe. 402 template <class Service, class GrpcService, class RequestMessage, 403 class ResponseMessage> 404 class ServerBidirectionalStreamingCall 405 : public ServerUntypedBidirectionalStreamingCall<Service> { 406 public: 407 // Represents the generic signature of a generated 408 // `GrpcService::RequestFoo()` method, where `Foo` is the name of an 409 // RPC method. 410 using EnqueueFunction = void (GrpcService::*)( 411 ::grpc::ServerContext*, 412 ::grpc::ServerAsyncReaderWriter<ResponseMessage, RequestMessage>*, 413 ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*); 414 415 // Represents the generic signature of a `Service::HandleFoo()` 416 // method, where `Foo` is the name of an RPC method. 
417 using HandleRequestFunction = void (Service::*)( 418 ServerBidirectionalStreamingCall<Service, GrpcService, RequestMessage, 419 ResponseMessage>*); 420 ServerBidirectionalStreamingCall(HandleRequestFunction handle_request_function,GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,EnqueueFunction enqueue_function)421 ServerBidirectionalStreamingCall( 422 HandleRequestFunction handle_request_function, GrpcService* grpc_service, 423 ::grpc::ServerCompletionQueue* cq, EnqueueFunction enqueue_function) 424 : handle_request_function_(handle_request_function), 425 stream_(&ctx_), 426 grpc_service_(grpc_service), 427 cq_(cq), 428 enqueue_function_(enqueue_function) { 429 VLOG(3) << "Creating ServerBidirectionalStreamingCall " << this; 430 } 431 ~ServerBidirectionalStreamingCall()432 ~ServerBidirectionalStreamingCall() override { 433 VLOG(3) << "Destroying ServerBidirectionalStreamingCall " << this; 434 } 435 CallOpen()436 void CallOpen() override { 437 // Let gRPC know that we can accept another call. 438 ServerBidirectionalStreamingCall< 439 Service, GrpcService, RequestMessage, 440 ResponseMessage>::EnqueueRequest(grpc_service_, cq_, enqueue_function_, 441 handle_request_function_); 442 RequestRead(); 443 } 444 RequestRead()445 void RequestRead() override { 446 this->Ref(); 447 request_.Clear(); 448 stream_.Read(&request_, &request_received_tag_); 449 } 450 RequestReceived(Service * service)451 void RequestReceived(Service* service) override { 452 this->Ref(); 453 // Request handling should result in a call to SendResponse or Finish. 454 (service->*handle_request_function_)(this); 455 } 456 SendResponse()457 void SendResponse() { 458 // Transferring ownership of this to the response_sent_tag_. 459 stream_.Write(response_, &response_sent_tag_); 460 // stream_.Write does not save references to response_. We are free to muck 461 // around with it as soon as Write returns. 462 // We clear the response_ to prepare it for the next response. 
463 response_.Clear(); 464 } 465 Finish(::grpc::Status status)466 void Finish(::grpc::Status status) { 467 // Transferring ownership of this to the server_finished_tag_. 468 stream_.Finish(status, &server_finished_tag_); 469 } 470 471 // Enqueues a new request for the given service on the given 472 // completion queue, using the given `enqueue_function`. 473 // 474 // The request will be handled by the given `handle_request_function`. EnqueueRequest(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,EnqueueFunction enqueue_function,HandleRequestFunction handle_request_function)475 static void EnqueueRequest(GrpcService* grpc_service, 476 ::grpc::ServerCompletionQueue* cq, 477 EnqueueFunction enqueue_function, 478 HandleRequestFunction handle_request_function) { 479 auto call = 480 new ServerBidirectionalStreamingCall<Service, GrpcService, 481 RequestMessage, ResponseMessage>( 482 handle_request_function, grpc_service, cq, enqueue_function); 483 484 // Initial ref for call handed to grpc; released in Tag callback. 485 (grpc_service->*enqueue_function)(&call->ctx_, &call->stream_, cq, cq, 486 &call->call_open_tag_); 487 } 488 request()489 const RequestMessage& request() const { return request_; } mutable_response()490 ResponseMessage* mutable_response() { return &response_; } 491 492 private: 493 // Request and response messages are reused for each request/response exchange 494 // between the client and the server. 495 RequestMessage request_; 496 ResponseMessage response_; 497 ::grpc::ServerContext ctx_; 498 499 HandleRequestFunction handle_request_function_; 500 ::grpc::ServerAsyncReaderWriter<ResponseMessage, RequestMessage> stream_; 501 502 // Used as void* completion markers from grpc to indicate different 503 // events of interest for a ServerBidirectionalStreamingCall. 504 typedef typename ServerUntypedBidirectionalStreamingCall<Service>::Tag Tag; 505 // At most one tag of each kind may be given to gRPC at any one time. 
506 // Beyond semantic sanity, this is needed to ensure proper ref counting 507 // of this call object. 508 Tag call_open_tag_{this, Tag::TagType::kCallOpen}; 509 Tag request_received_tag_{this, Tag::TagType::kRequestReceived}; 510 Tag response_sent_tag_{this, Tag::TagType::kResponseSent}; 511 Tag server_finished_tag_{this, Tag::TagType::kServerFinished}; 512 513 // These fields are used only to spawn another instance of this to accept 514 // more streaming calls. 515 GrpcService* grpc_service_; 516 ::grpc::ServerCompletionQueue* cq_; 517 EnqueueFunction enqueue_function_; 518 }; 519 520 } // namespace tensorflow 521 522 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ 523