/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_

#include "grpcpp/alarm.h"
#include "grpcpp/completion_queue.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"

namespace tensorflow {
namespace eager {

// This class is a gRPC wrapper around the local EagerServiceImpl: it drives
// the asynchronous completion queue and dispatches incoming eager service
// RPCs to the local implementation.
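//
// A rough usage sketch (hedged: `worker_env` is a placeholder for the
// WorkerEnv owned by the surrounding server, and the exact wiring lives in
// the gRPC worker-server setup code, not in this header):
//
//   ::grpc::ServerBuilder builder;
//   auto eager_service =
//       std::make_unique<eager::GrpcEagerServiceImpl>(worker_env, &builder);
//   std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
//   // Drive the completion queue until Shutdown() is called.
//   std::thread rpc_thread([&] { eager_service->HandleRPCsLoop(); });
//   ...
//   eager_service->Shutdown();
//   rpc_thread.join();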
class GrpcEagerServiceImpl : public AsyncServiceInterface {
 public:
  template <class RequestMessage, class ResponseMessage>
  using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
                         RequestMessage, ResponseMessage>;
  template <class RequestMessage, class ResponseMessage>
  using StreamingCall =
      ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
                                       grpc::EagerService::AsyncService,
                                       RequestMessage, ResponseMessage>;

  GrpcEagerServiceImpl(const WorkerEnv* env,
                       ::grpc::ServerBuilder* server_builder);
  virtual ~GrpcEagerServiceImpl() {}

  // Creates a master context in the eager service.
  Status CreateMasterContext(const tensorflow::uint64 context_id,
                             EagerContext* context);

  void HandleRPCsLoop() override;
  void Shutdown() override;

 private:
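  // HANDLER(method) expands to a `method`Handler member that (1) schedules the
  // corresponding EagerServiceImpl::method call on the worker's compute thread
  // pool and sends its status back as the gRPC response, and (2) immediately
  // re-enqueues a request tag on the completion queue so the next `method` RPC
  // can be accepted. Handlers generated this way do not support cancellation
  // (note the `false` passed as supports_cancel below).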
#define HANDLER(method)                                                       \
  void method##Handler(EagerCall<method##Request, method##Response>* call) {  \
    env_->compute_pool->Schedule([this, call]() {                             \
      call->SendResponse(                                                     \
          ToGrpcStatus(local_impl_.method(&call->request, &call->response))); \
    });                                                                       \
    Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,              \
         method##Request, method##Response>::                                 \
        EnqueueRequest(&service_, cq_.get(),                                  \
                       &grpc::EagerService::AsyncService::Request##method,    \
                       &GrpcEagerServiceImpl::method##Handler, false);        \
  }
  HANDLER(CreateContext);
  HANDLER(UpdateContext);
  HANDLER(WaitQueueDone);
  HANDLER(KeepAlive);
  HANDLER(CloseContext);
#undef HANDLER

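  // Enqueue gets a hand-written handler instead of HANDLER because it supports
  // client cancellation: a shared CallOptions object is wired into the call's
  // cancel callback so an in-flight Enqueue can be cancelled, and the request
  // is re-enqueued with /*supports_cancel=*/true.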
  void EnqueueHandler(EagerCall<EnqueueRequest, EnqueueResponse>* call) {
    env_->compute_pool->Schedule([this, call]() {
      auto call_opts = std::make_shared<CallOptions>();
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      call->SendResponse(ToGrpcStatus(local_impl_.Enqueue(
          call_opts.get(), &call->request, &call->response)));
    });
    Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, EnqueueRequest,
         EnqueueResponse>::
        EnqueueRequest(&service_, cq_.get(),
                       &grpc::EagerService::AsyncService::RequestEnqueue,
                       &GrpcEagerServiceImpl::EnqueueHandler,
                       /*supports_cancel=*/true);
  }

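  // RunComponentFunction is cancellable as well and, unlike Enqueue, completes
  // asynchronously: the response is sent from the done callback passed to
  // local_impl_.RunComponentFunction rather than after a blocking call.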
  void RunComponentFunctionHandler(
      EagerCall<RunComponentFunctionRequest, RunComponentFunctionResponse>*
          call) {
    env_->compute_pool->Schedule([this, call]() {
      auto call_opts = std::make_shared<CallOptions>();
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      local_impl_.RunComponentFunction(call_opts.get(), &call->request,
                                       &call->response,
                                       [call, call_opts](const Status& s) {
                                         call->ClearCancelCallback();
                                         call->SendResponse(ToGrpcStatus(s));
                                       });
    });
    Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
         RunComponentFunctionRequest, RunComponentFunctionResponse>::
        EnqueueRequest(
            &service_, cq_.get(),
            &grpc::EagerService::AsyncService::RequestRunComponentFunction,
            &GrpcEagerServiceImpl::RunComponentFunctionHandler,
            /*supports_cancel=*/true);
  }

  // Called when a new request has been received as part of a StreamingEnqueue
  // call.
  // StreamingEnqueueHandler gets the request from the `call` and fills the
  // response (also found in `call`) by invoking the local EagerServiceImpl.
  // The local EagerServiceImpl is invoked in a single-threaded thread pool to
  // preserve request order. The local service can parallelize based on the
  // context_id in the request if necessary. Remote contexts are created in
  // async mode by default, so the local service impl just puts the request on
  // the eager executor queue.
  void StreamingEnqueueHandler(
      StreamingCall<EnqueueRequest, EnqueueResponse>* call) {
    call->Ref();
    enqueue_streaming_thread_.Schedule([this, call]() {
      if (call->RefCountIsOne()) {
        // This StreamingCall has already been shut down. Nothing to do here.
        call->Unref();
        return;
      }
      // NOTE(fishx): Use the address of StreamingCall as the stream_id since we
      // reuse the same StreamingCall for multiple requests in the same
      // streaming connection.
      Status status = local_impl_.Enqueue(
          /*call_opts=*/nullptr, &call->request(), call->mutable_response(),
          reinterpret_cast<uint64>(static_cast<void*>(call)));

      if (status.ok()) {
        VLOG(1) << "local_impl_.Enqueue completed successfully";
        call->SendResponse();
      } else {
        VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString()
                << " on request " << call->request().DebugString();
        call->Finish(ToGrpcStatus(status));
      }
      call->Unref();

      // We do not tell gRPC to accept a new StreamingEnqueue request because
      // this method can be called multiple times for a given streaming call.
      // The StreamingCall does this per call instead, after a call has been
      // opened.
    });
  }

  const WorkerEnv* const env_;  // Not owned.
  EagerServiceImpl local_impl_;

  // A single-threaded thread pool to handle streaming enqueue RPC requests.
  thread::ThreadPool enqueue_streaming_thread_;
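  // Presumably used by Shutdown() to push an event onto the completion queue
  // so that HandleRPCsLoop() can wake up and exit (the implementation lives in
  // the corresponding .cc file).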
  std::unique_ptr<::grpc::Alarm> shutdown_alarm_;

  std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
  grpc::EagerService::AsyncService service_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcEagerServiceImpl);
};

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_