1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ 18 19 #include "grpcpp/alarm.h" 20 #include "grpcpp/completion_queue.h" 21 #include "grpcpp/server_builder.h" 22 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" 23 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" 24 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" 26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 27 28 namespace tensorflow { 29 namespace eager { 30 31 // This class is a wrapper that handles communication for gRPC. 
// Async gRPC front-end for the eager service: owns the completion queue and
// the generated AsyncService stub, and forwards each decoded RPC to the
// in-process EagerServiceImpl (`local_impl_`).
class GrpcEagerServiceImpl : public AsyncServiceInterface {
 public:
  // Unary call wrapper: one request, one response.
  template <class RequestMessage, class ResponseMessage>
  using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
                         RequestMessage, ResponseMessage>;
  // Bidirectional streaming call wrapper; a single StreamingCall object is
  // reused for every request on the same stream.
  template <class RequestMessage, class ResponseMessage>
  using StreamingCall =
      ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
                                       grpc::EagerService::AsyncService,
                                       RequestMessage, ResponseMessage>;

  GrpcEagerServiceImpl(const WorkerEnv* env,
                       ::grpc::ServerBuilder* server_builder);
  virtual ~GrpcEagerServiceImpl() {}

  // Create a master context in eager service.
  Status CreateMasterContext(const tensorflow::uint64 context_id,
                             EagerContext* context);

  // AsyncServiceInterface: drive the completion queue / stop serving.
  void HandleRPCsLoop() override;
  void Shutdown() override;

 private:
  // Generates a handler for each simple unary RPC: run the local
  // implementation on the compute pool, send the (possibly error) status back,
  // and immediately re-enqueue a request slot so the next call of this kind
  // can be accepted.  `false` = does not support cancellation.
#define HANDLER(method)                                                       \
  void method##Handler(EagerCall<method##Request, method##Response>* call) {  \
    env_->compute_pool->Schedule([this, call]() {                             \
      call->SendResponse(                                                     \
          ToGrpcStatus(local_impl_.method(&call->request, &call->response))); \
    });                                                                       \
    Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,              \
         method##Request, method##Response>::                                 \
        EnqueueRequest(&service_, cq_.get(),                                  \
                       &grpc::EagerService::AsyncService::Request##method,    \
                       &GrpcEagerServiceImpl::method##Handler, false);        \
  }
  HANDLER(CreateContext);
  HANDLER(UpdateContext);
  HANDLER(WaitQueueDone);
  HANDLER(KeepAlive);
  HANDLER(CloseContext);
#undef HANDLER

  // Unary Enqueue: like the HANDLER-generated methods above, but wires client
  // cancellation through a shared CallOptions so an in-flight Enqueue can be
  // cancelled.  The shared_ptr keeps the CallOptions alive for the callback.
  void EnqueueHandler(EagerCall<EnqueueRequest, EnqueueResponse>* call) {
    env_->compute_pool->Schedule([this, call]() {
      auto call_opts = std::make_shared<CallOptions>();
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      call->SendResponse(ToGrpcStatus(local_impl_.Enqueue(
          call_opts.get(), &call->request, &call->response)));
    });
    Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, EnqueueRequest,
         EnqueueResponse>::
        EnqueueRequest(&service_, cq_.get(),
                       &grpc::EagerService::AsyncService::RequestEnqueue,
                       &GrpcEagerServiceImpl::EnqueueHandler,
                       /*supports_cancel=*/true);
  }

  // RunComponentFunction is asynchronous on the local service: the response is
  // sent from the done-callback rather than inline, and cancellation is
  // supported the same way as in EnqueueHandler.
  void RunComponentFunctionHandler(
      EagerCall<RunComponentFunctionRequest, RunComponentFunctionResponse>*
          call) {
    env_->compute_pool->Schedule([this, call]() {
      auto call_opts = std::make_shared<CallOptions>();
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      local_impl_.RunComponentFunction(call_opts.get(), &call->request,
                                       &call->response,
                                       [call, call_opts](const Status& s) {
                                         call->ClearCancelCallback();
                                         call->SendResponse(ToGrpcStatus(s));
                                       });
    });
    Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
         RunComponentFunctionRequest, RunComponentFunctionResponse>::
        EnqueueRequest(
            &service_, cq_.get(),
            &grpc::EagerService::AsyncService::RequestRunComponentFunction,
            &GrpcEagerServiceImpl::RunComponentFunctionHandler,
            /*supports_cancel=*/true);
  }

  // Called when a new request has been received as part of a StreamingEnqueue
  // call.
  // StreamingEnqueueHandler gets the request from the `call` and fills the
  // response (also found in `call`) by invoking the local EagerServiceImpl.
  // The local EagerServiceImpl is invoked in a single-threaded thread pool. We
  // do this to preserve request order. The local service can parallelize based
  // on context_id in request if necessary. Remote contexts are created in
  // async mode by default, so the local service impl just puts the request on
  // eager executor queue.
  void StreamingEnqueueHandler(
      StreamingCall<EnqueueRequest, EnqueueResponse>* call) {
    // Extra ref keeps `call` alive until the scheduled closure runs.
    call->Ref();
    enqueue_streaming_thread_.Schedule([this, call]() {
      if (call->RefCountIsOne()) {
        // This StreamingCall has already been shutdown. Don't need to do
        // anything.
        call->Unref();
        return;
      }
      // NOTE(fishx): Use the address of StreamingCall as the stream_id since we
      // reuse the same StreamingCall for multiple requests in the same
      // streaming connection.
      Status status = local_impl_.Enqueue(
          /*call_opts=*/nullptr, &call->request(), call->mutable_response(),
          reinterpret_cast<uint64>(static_cast<void*>(call)));

      if (status.ok()) {
        VLOG(1) << "local_impl_.Enqueue completed successfully";
        call->SendResponse();
      } else {
        VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString()
                << " on request " << call->request().DebugString();
        call->Finish(ToGrpcStatus(status));
      }
      call->Unref();

      // We do not tell gRPC to accept a new StreamingEnqueue request because
      // this method can be called multiple times for a given streaming call.
      // The StreamingCall does this per call instead, after a call has been
      // opened.
    });
  }

  const WorkerEnv* const env_;  // Not owned.
  EagerServiceImpl local_impl_;

  // A single-threaded thread pool to handle streaming enqueue rpc request.
  thread::ThreadPool enqueue_streaming_thread_;
  // Used by Shutdown() to wake the completion queue loop.
  std::unique_ptr<::grpc::Alarm> shutdown_alarm_;

  std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
  grpc::EagerService::AsyncService service_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcEagerServiceImpl);
};

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_