1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ 18 19 #include "grpcpp/alarm.h" 20 #include "grpcpp/completion_queue.h" 21 #include "grpcpp/server_builder.h" 22 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" 23 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" 24 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" 26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 27 28 namespace tensorflow { 29 namespace eager { 30 31 // This class is a wrapper that handles communication for gRPC. 
class GrpcEagerServiceImpl : public AsyncServiceInterface {
 public:
  // Unary call type: one request message in, one response message out.
  template <class RequestMessage, class ResponseMessage>
  using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
                         RequestMessage, ResponseMessage>;
  // Bidirectional streaming call type; a single StreamingCall object is
  // reused across the many request/response exchanges of one stream.
  template <class RequestMessage, class ResponseMessage>
  using StreamingCall =
      ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
                                       grpc::EagerService::AsyncService,
                                       RequestMessage, ResponseMessage>;

  GrpcEagerServiceImpl(const WorkerEnv* env,
                       ::grpc::ServerBuilder* server_builder);
  virtual ~GrpcEagerServiceImpl() {}

  // Create a master context in eager service.
  Status CreateMasterContext(const tensorflow::uint64 context_id,
                             EagerContext* context);

  // AsyncServiceInterface overrides: drive the gRPC completion queue until
  // Shutdown() is called (implementations live in the .cc file).
  void HandleRPCsLoop() override;
  void Shutdown() override;

 private:
  // Generates one handler method per unary RPC. Each handler:
  //   1. Schedules the actual work on the worker's compute thread pool,
  //      where the local EagerServiceImpl processes the request and the
  //      (possibly error) status is converted to grpc::Status and sent back.
  //   2. Immediately re-arms the completion queue with EnqueueRequest so the
  //      next incoming call of the same RPC method can be accepted.
#define HANDLER(method)                                                       \
  void method##Handler(EagerCall<method##Request, method##Response>* call) {  \
    env_->compute_pool->Schedule([this, call]() {                             \
      call->SendResponse(                                                     \
          ToGrpcStatus(local_impl_.method(&call->request, &call->response))); \
    });                                                                       \
    Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,              \
         method##Request, method##Response>::                                 \
        EnqueueRequest(&service_, cq_.get(),                                  \
                       &grpc::EagerService::AsyncService::Request##method,    \
                       &GrpcEagerServiceImpl::method##Handler, false);        \
  }
  HANDLER(CreateContext);
  HANDLER(UpdateContext);
  HANDLER(Enqueue);
  HANDLER(WaitQueueDone);
  HANDLER(KeepAlive);
  HANDLER(CloseContext);
#undef HANDLER

  // Called when a new request has been received as part of a StreamingEnqueue
  // call.
  // StreamingEnqueueHandler gets the request from the `call` and fills the
  // response (also found in `call`) by invoking the local EagerServiceImpl.
  // The local EagerServiceImpl is invoked in a single-threaded thread pool. We
  // do this to preserve request order. The local service can parallelize based
  // on context_id in request if necessary. Remote contexts are created in async
  // mode by default, so the local service impl just puts the request on eager
  // executor queue.
  void StreamingEnqueueHandler(
      StreamingCall<EnqueueRequest, EnqueueResponse>* call) {
    // Hold a reference for the duration of the scheduled closure so `call`
    // stays alive until the closure runs.
    call->Ref();
    enqueue_streaming_thread_.Schedule([this, call]() {
      if (call->RefCountIsOne()) {
        // This StreamingCall has already been shutdown. Don't need to do
        // anything.
        call->Unref();
        return;
      }
      // NOTE(fishx): Use the address of StreamingCall as the stream_id since we
      // reuse the same StreamingCall for multiple requests in the same
      // streaming connection.
      Status status = local_impl_.Enqueue(
          &call->request(), call->mutable_response(),
          reinterpret_cast<uint64>(static_cast<void*>(call)));

      if (status.ok()) {
        VLOG(1) << "local_impl_.Enqueue completed successfully";
        call->SendResponse();
      } else {
        VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString()
                << " on request " << call->request().DebugString();
        call->Finish(ToGrpcStatus(status));
      }
      call->Unref();

      // We do not tell gRPC to accept a new StreamingEnqueue request because
      // this method can be called multiple times for a given streaming call.
      // The StreamingCall does this per call instead, after a call has been
      // opened.
    });
  }

  const WorkerEnv* const env_;  // Not owned.
  // Does the actual request processing; this class only adapts gRPC plumbing
  // onto it.
  EagerServiceImpl local_impl_;

  // A single-threaded thread pool to handle streaming enqueue rpc request.
  thread::ThreadPool enqueue_streaming_thread_;
  // Used to wake up HandleRPCsLoop during Shutdown() — presumably armed by the
  // .cc implementation; confirm against the matching source file.
  std::unique_ptr<::grpc::Alarm> shutdown_alarm_;

  std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
  grpc::EagerService::AsyncService service_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcEagerServiceImpl);
};

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_