1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ 18 19 #include "tensorflow/core/lib/core/refcount.h" 20 #include "tensorflow/core/platform/macros.h" 21 #include "tensorflow/core/platform/mutex.h" 22 23 #include "grpcpp/grpcpp.h" 24 #include "grpcpp/impl/codegen/service_type.h" 25 #include "grpcpp/server_builder.h" 26 27 namespace tensorflow { 28 29 // CALL STRUCTURES 30 // =============== 31 // 32 // Each pending (incoming) request corresponds to a call object that 33 // encapsulates the state of the call. Templates and 34 // pointers-to-member functions are used to avoid boilerplate and 35 // redundant closure creation. The class hierarchy is as follows: 36 // 37 // * `UntypedCall<Service>`: The base class represents a call that 38 // could be associated with any of the methods on a service of type 39 // `Service`. Also defines a `Tag` nested class that can be used as 40 // the tag in a `grpc::CompletionQueue`. Each class that 41 // instantiates `Service` should have a completion queue polling 42 // loop that knows about `UntypedCall<Service>::Tag` objects, and 43 // invokes their `OnCompleted()` method to continue processing. 44 // 45 // * `Call<Service, GrpcService, Req, Resp>`: This class extends 46 // `UntypedCall<Service>` and is additionally parameterized by the 47 // gRPC-generated asynchronous service class, and the request and 48 // response message types. It defines the state associated with a 49 // call (whose type depends on the message types), and stores a 50 // pointer to a `Service::HandleFoo()` handler method. Each 51 // `Service::HandleFoo()` method knows about the corresponding 52 // `Call` type, in order to access its state, and invoke its 53 // `SendResponse()` method. 54 // 55 // The lifecycle of a call object is as follows. 56 // 57 // 1. A `Service` creates a `Call` for a particular method and 58 // enqueues it in its completion queue (via an 59 // `UntypedCall<Service>::Tag`). 60 // 61 // 2. When the tag is returned from `cq_->Next()`, the 62 // `UntypedCall::RequestReceived()` method is invoked and takes 63 // ownership of the call object. This indirectly invokes the 64 // appropriate handler method on `Service`. 65 // 66 // 3. After the response has been written (perhaps in another thread), 67 // the `Call::SendResponse()` method is invoked. It transfers 68 // ownership of the call object back to the completion queue (via 69 // an `UntypedCall::Tag`). 70 // 71 // 4. When the response has been sent, the tag is returned from 72 // `cq_->Next()`, and the call object is deleted. 73 74 // Represents a pending request with unknown message types. 75 template <class Service> 76 class UntypedCall : public core::RefCounted { 77 public: ~UntypedCall()78 virtual ~UntypedCall() {} 79 80 // The implementation of this method should use `service` to handle 81 // an incoming request, and (perhaps asynchronously) send the 82 // response. 83 // 84 // One reference on `this` is transferred to the callee, and the 85 // callee is responsible for releasing it (typically via 86 // `Call::SendResponse()`). 87 // 88 // `ok` is true if the request was received in a "regular event", 89 // otherwise false. 90 virtual void RequestReceived(Service* service, bool ok) = 0; 91 92 // This method will be called either (i) when the server is notified 93 // that the request has been canceled, or (ii) when the request completes 94 // normally. The implementation should distinguish these cases by querying 95 // the `grpc::ServerContext` associated with the request. 96 virtual void RequestCancelled(Service* service, bool ok) = 0; 97 98 // Associates a tag in a `::grpc::CompletionQueue` with a callback 99 // for an incoming RPC. An active Tag owns a reference on the corresponding 100 // Call object. 101 class Tag { 102 public: 103 // One enum value per supported callback. 104 enum Callback { kRequestReceived, kResponseSent, kCancelled }; 105 Tag(UntypedCall * call,Callback cb)106 Tag(UntypedCall* call, Callback cb) : call_(call), callback_(cb) {} 107 108 // Calls the callback associated with this tag. 109 // 110 // The callback takes ownership of `this->call_`. OnCompleted(Service * service,bool ok)111 void OnCompleted(Service* service, bool ok) { 112 switch (callback_) { 113 case kRequestReceived: 114 call_->RequestReceived(service, ok); 115 break; 116 case kResponseSent: 117 // No special handling needed apart from the Unref below. 118 break; 119 case kCancelled: 120 call_->RequestCancelled(service, ok); 121 break; 122 } 123 call_->Unref(); // Ref acquired when tag handed to grpc. 124 } 125 126 private: 127 UntypedCall* const call_; // `this` owns one reference. 128 Callback callback_; 129 }; 130 }; 131 132 // Represents a pending call with known request and response message 133 // types, and a known request-handling method. 134 template <class Service, class GrpcService, class RequestMessage, 135 class ResponseMessage> 136 class Call : public UntypedCall<Service> { 137 public: 138 // Represents the generic signature of a generated 139 // `GrpcService::RequestFoo()` method, where `Foo` is the name of an 140 // RPC method. 141 using EnqueueFunction = void (GrpcService::*)( 142 ::grpc::ServerContext*, RequestMessage*, 143 ::grpc::ServerAsyncResponseWriter<ResponseMessage>*, 144 ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*); 145 146 // Represents the generic signature of a `Service::HandleFoo()` 147 // method, where `Foo` is the name of an RPC method. 148 using HandleRequestFunction = void (Service::*)( 149 Call<Service, GrpcService, RequestMessage, ResponseMessage>*); 150 Call(HandleRequestFunction handle_request_function)151 Call(HandleRequestFunction handle_request_function) 152 : handle_request_function_(handle_request_function), responder_(&ctx_) {} 153 ~Call()154 virtual ~Call() {} 155 RequestReceived(Service * service,bool ok)156 void RequestReceived(Service* service, bool ok) override { 157 if (ok) { 158 this->Ref(); 159 (service->*handle_request_function_)(this); 160 } 161 } 162 SendResponse(::grpc::Status status)163 void SendResponse(::grpc::Status status) { 164 this->Ref(); // Ref for grpc; released in Tag callback. 165 responder_.Finish(response, status, &response_sent_tag_); 166 this->Unref(); 167 } 168 RequestCancelled(Service * service,bool ok)169 void RequestCancelled(Service* service, bool ok) override { 170 if (ctx_.IsCancelled()) { 171 mutex_lock l(mu_); 172 if (cancel_callback_) { 173 cancel_callback_(); 174 } 175 } 176 } 177 178 // Registers `callback` as the function that should be called if and when this 179 // call is canceled by the client. SetCancelCallback(std::function<void ()> callback)180 void SetCancelCallback(std::function<void()> callback) { 181 mutex_lock l(mu_); 182 cancel_callback_ = std::move(callback); 183 } 184 185 // Clears any cancellation callback that has been registered for this call. ClearCancelCallback()186 void ClearCancelCallback() { 187 mutex_lock l(mu_); 188 cancel_callback_ = nullptr; 189 } 190 191 // Enqueues a new request for the given service on the given 192 // completion queue, using the given `enqueue_function`. 193 // 194 // The request will be handled with the given 195 // `handle_request_function`. EnqueueRequest(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,EnqueueFunction enqueue_function,HandleRequestFunction handle_request_function,bool supports_cancel)196 static void EnqueueRequest(GrpcService* grpc_service, 197 ::grpc::ServerCompletionQueue* cq, 198 EnqueueFunction enqueue_function, 199 HandleRequestFunction handle_request_function, 200 bool supports_cancel) { 201 auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>( 202 handle_request_function); 203 if (supports_cancel) { 204 call->RegisterCancellationHandler(); 205 } 206 207 // Initial ref for call handed to grpc; released in Tag callback. 208 (grpc_service->*enqueue_function)(&call->ctx_, &call->request, 209 &call->responder_, cq, cq, 210 &call->request_received_tag_); 211 } 212 213 // Enqueues a new request for the given service on the given 214 // completion queue, using the given `method_id`. 215 // 216 // The request will be handled with the given 217 // `handle_request_function`. EnqueueRequestForMethod(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,int method_id,HandleRequestFunction handle_request_function,bool supports_cancel)218 static void EnqueueRequestForMethod( 219 GrpcService* grpc_service, ::grpc::ServerCompletionQueue* cq, 220 int method_id, HandleRequestFunction handle_request_function, 221 bool supports_cancel) { 222 auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>( 223 handle_request_function); 224 if (supports_cancel) { 225 call->RegisterCancellationHandler(); 226 } 227 228 // Initial ref for call handed to grpc; released in Tag callback. 229 grpc_service->RequestAsyncUnary(method_id, &call->ctx_, &call->request, 230 &call->responder_, cq, cq, 231 &call->request_received_tag_); 232 } 233 234 RequestMessage request; 235 ResponseMessage response; 236 client_metadata()237 const std::multimap<::grpc::string_ref, ::grpc::string_ref>& client_metadata() 238 const { 239 return ctx_.client_metadata(); 240 } 241 242 private: 243 // Creates a completion queue tag for handling cancellation by the client. 244 // NOTE: This method must be called before this call is enqueued on a 245 // completion queue. RegisterCancellationHandler()246 void RegisterCancellationHandler() { 247 this->Ref(); // Ref for grpc; released in Tag callback. 248 ctx_.AsyncNotifyWhenDone(&cancelled_tag_); 249 } 250 251 HandleRequestFunction handle_request_function_; 252 ::grpc::ServerContext ctx_; 253 ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_; 254 255 // Used as void* completion markers from grpc to indicate different 256 // events of interest for a Call. 257 typedef typename UntypedCall<Service>::Tag Tag; 258 Tag request_received_tag_{this, Tag::kRequestReceived}; 259 Tag response_sent_tag_{this, Tag::kResponseSent}; 260 Tag cancelled_tag_{this, Tag::kCancelled}; 261 262 mutex mu_; 263 std::function<void()> cancel_callback_ GUARDED_BY(mu_); 264 }; 265 266 } // namespace tensorflow 267 268 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ 269