• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_
18 
19 #include "grpcpp/alarm.h"
20 #include "grpcpp/completion_queue.h"
21 #include "grpcpp/server_builder.h"
22 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
23 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
24 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 
28 namespace tensorflow {
29 namespace eager {
30 
31 // This class is a wrapper that handles communication for gRPC.
32 class GrpcEagerServiceImpl : public AsyncServiceInterface {
33  public:
34   template <class RequestMessage, class ResponseMessage>
35   using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
36                          RequestMessage, ResponseMessage>;
37   template <class RequestMessage, class ResponseMessage>
38   using StreamingCall =
39       ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
40                                        grpc::EagerService::AsyncService,
41                                        RequestMessage, ResponseMessage>;
42 
43   GrpcEagerServiceImpl(const WorkerEnv* env,
44                        ::grpc::ServerBuilder* server_builder);
~GrpcEagerServiceImpl()45   virtual ~GrpcEagerServiceImpl() {}
46 
47   // Create a master context in eager service.
48   Status CreateMasterContext(const tensorflow::uint64 context_id,
49                              EagerContext* context);
50 
51   void HandleRPCsLoop() override;
52   void Shutdown() override;
53 
54  private:
55 #define HANDLER(method)                                                       \
56   void method##Handler(EagerCall<method##Request, method##Response>* call) {  \
57     env_->compute_pool->Schedule([this, call]() {                             \
58       call->SendResponse(                                                     \
59           ToGrpcStatus(local_impl_.method(&call->request, &call->response))); \
60     });                                                                       \
61     Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,              \
62          method##Request, method##Response>::                                 \
63         EnqueueRequest(&service_, cq_.get(),                                  \
64                        &grpc::EagerService::AsyncService::Request##method,    \
65                        &GrpcEagerServiceImpl::method##Handler, false);        \
66   }
67   HANDLER(CreateContext);
68   HANDLER(UpdateContext);
69   HANDLER(Enqueue);
70   HANDLER(WaitQueueDone);
71   HANDLER(KeepAlive);
72   HANDLER(CloseContext);
73 #undef HANDLER
74 
75   // Called when a new request has been received as part of a StreamingEnqueue
76   // call.
77   // StreamingEnqueueHandler gets the request from the `call` and fills the
78   // response (also found in `call`) by invoking the local EagerServiceImpl.
79   // The local EagerServiceImpl is invoked in a single-threaded thread pool. We
80   // do this to preserve request order. The local service can parallelize based
81   // on context_id in request if necessary. Remote contexts are created in async
82   // mode by default, so the local service impl just puts the request on eager
83   // executor queue.
StreamingEnqueueHandler(StreamingCall<EnqueueRequest,EnqueueResponse> * call)84   void StreamingEnqueueHandler(
85       StreamingCall<EnqueueRequest, EnqueueResponse>* call) {
86     call->Ref();
87     enqueue_streaming_thread_.Schedule([this, call]() {
88       if (call->RefCountIsOne()) {
89         // This StreamingCall has already been shutdown. Don't need to anything.
90         call->Unref();
91         return;
92       }
93       // NOTE(fishx): Use the address of StreamingCall as the stream_id since we
94       // reuse the same StreamingCall for multiple requests in the same
95       // streaming connection.
96       Status status = local_impl_.Enqueue(
97           &call->request(), call->mutable_response(),
98           reinterpret_cast<uint64>(static_cast<void*>(call)));
99 
100       if (status.ok()) {
101         VLOG(1) << "local_impl_.Enqueue completed successfully";
102         call->SendResponse();
103       } else {
104         VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString()
105                 << " on request " << call->request().DebugString();
106         call->Finish(ToGrpcStatus(status));
107       }
108       call->Unref();
109 
110       // We do not tell gRPC to accept a new StreamingEnqueue request because
111       // this method can be called multiple times for a given streaming call.
112       // The StreamingCall does this per call instead, after a call has been
113       // opened.
114     });
115   }
116 
117   const WorkerEnv* const env_;  // Not owned.
118   EagerServiceImpl local_impl_;
119 
120   // A single-threaded thread pool to handle streaming enqueue rpc request.
121   thread::ThreadPool enqueue_streaming_thread_;
122   std::unique_ptr<::grpc::Alarm> shutdown_alarm_;
123 
124   std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
125   grpc::EagerService::AsyncService service_;
126 
127   TF_DISALLOW_COPY_AND_ASSIGN(GrpcEagerServiceImpl);
128 };
129 
130 }  // namespace eager
131 }  // namespace tensorflow
132 
133 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_
134