• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
18 
19 // GrpcServer manages the lifecycle of the Eager, Worker and Master services,
// along with the coordination, profiler and any subclass-provided extra
// services.
20 
21 #include <memory>
22 
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 #include "tensorflow/core/common_runtime/eager/context.h"
26 #include "tensorflow/core/common_runtime/process_util.h"
27 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
28 #include "tensorflow/core/distributed_runtime/master_env.h"
29 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
30 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
32 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
33 #include "tensorflow/core/distributed_runtime/server_lib.h"
34 #include "tensorflow/core/distributed_runtime/session_mgr.h"
35 #include "tensorflow/core/distributed_runtime/worker_env.h"
36 #include "tensorflow/core/framework/collective.h"
37 #include "tensorflow/core/framework/op.h"
38 #include "tensorflow/core/platform/env.h"
39 #include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
40 
41 namespace tensorflow {
42 
43 class GrpcWorker;
44 class Master;
45 
46 // function that creates a RendezvousMgr.
47 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
48     RendezvousMgrCreationFunction;
49 
50 // function that creates a CollectiveExecutorMgr.
51 typedef std::function<CollectiveExecutorMgrInterface*(
52     const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
53     CollectiveMgrCreationFunction;
54 
55 // function that registers a service to the server. The service needs to
56 // be registered before builder.BuildAndStart().
57 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
58     ServiceInitFunction;
59 
60 // function that creates a grpc based worker implementation.
61 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*,
62                                                   const ConfigProto& config)>
63     WorkerCreationFunction;
64 
65 struct GrpcServerOptions {
66   ServiceInitFunction service_func = nullptr;
67   RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr;
68   CollectiveMgrCreationFunction collective_mgr_func = nullptr;
69   WorkerCreationFunction worker_func = nullptr;
70   StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
71   GrpcWorkerServiceOptions worker_service_options;
72   DeviceMgr* local_device_mgr = nullptr;
73 };
74 
75 class GrpcServer : public ServerInterface {
76  protected:
77   GrpcServer(const ServerDef& server_def, Env* env);
78   GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr,
79              Env* env);
80   // Allow children classes to override this and provide custom args to the
81   // server before it is constructed. Default behavior is to do nothing.
82   // requested_port provides the port requested by caller as bound_port() is
83   // not available till BuildAndStart has been called.
MaybeMutateBuilder(::grpc::ServerBuilder * builder,int requested_port)84   virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder,
85                                   int requested_port) {}
86 
87  public:
88   static Status Create(const ServerDef& server_def, Env* env,
89                        std::unique_ptr<ServerInterface>* out_server);
90   static Status Create(const ServerDef& server_def, Env* env,
91                        std::unique_ptr<GrpcServer>* out_server);
92   // Reuse the local_device_mgr.
93   static Status Create(const ServerDef& server_def, Env* env,
94                        DeviceMgr* local_device_mgr,
95                        std::unique_ptr<ServerInterface>* out_server);
96 
97   // Destruction is only supported in the factory method. Clean
98   // shutdown is not currently implemented for this server type.
99   virtual ~GrpcServer();
100 
101   // Implementations of ServerInterface methods.
102   Status Start() override;
103   Status Stop() override;
104   Status Join() override;
105   const string target() const override;
106 
worker_env()107   WorkerEnv* worker_env() override { return &worker_env_; }
master_env()108   MasterEnv* master_env() override { return &master_env_; }
109 
110   // Add master eager context to local eager service in order to handle enqueue
111   // requests from remote workers.
112   Status AddMasterEagerContextToEagerService(
113       const tensorflow::uint64 context_id,
114       tensorflow::EagerContext* context) override;
115   // Update the set of workers that can be reached by the GRPC server
116   Status UpdateServerDef(const ServerDef& server_def) override;
117   // Pass coordination service agent instance to server's RPC handler
118   Status SetCoordinationServiceAgentInstance(
119       CoordinationServiceAgent* agent) override;
120   // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is
121   // supported.
122   Status StopCoordinationService() override;
123 
124  protected:
125   virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name,
126                                 int* port) const;
127   Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
128 
129   // A subclass can override this method to support secure credentials.
130   virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
131       const ServerDef& server_def) const;
132 
133   virtual ChannelCreationFunction GetChannelCreationFunction() const;
134 
135   virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
136 
137   // Creates a WorkerCacheInterface for a session.
138   virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
139                                     WorkerCacheInterface** worker_cache);
140 
141   // Override to return extra services to be brought up and managed along with
142   // the standard {master, worker, eager} services. The map key is an aribtrary
143   // string and the value is a pointer to the service to be brought up.
144   // Ownership of the pointer is transferred to GrpcServer after this call
145   // returns, and the service will be destroyed during the destruction of
146   // GrpcServer. Each service will have its HandleRPCsLoop called in a separate
147   // thread. An example usage would be to add a RDMA based partial worker
148   // service to offload tensor and data buffer transfers.
ExtraServices(::grpc::ServerBuilder *)149   virtual std::map<std::string, AsyncServiceInterface*> ExtraServices(
150       ::grpc::ServerBuilder*) {
151     return {};
152   }
153 
GetExtraServices()154   virtual std::map<std::string, AsyncServiceInterface*> GetExtraServices() {
155     return extra_services_;
156   }
157 
158   // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
159   Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
160                           GrpcChannelSpec* channel_spec);
161 
162   // Returns the port to which this server is bound.
163   // This method may only be called after `this->Init()` returns successfully.
bound_port()164   int bound_port() const { return bound_port_; }
165 
166   // Returns hostname.
host_name()167   const string& host_name() const { return host_name_; }
168 
server_def()169   const ServerDef& server_def() const { return server_def_; }
worker_impl()170   GrpcWorker* worker_impl() const { return worker_impl_.get(); }
grpc_worker_env()171   GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); }
172 
173  private:
174   Env* env_;
175 
176   // The port to which this server is bound.
177   int bound_port_ = 0;
178 
179   // The host name of this server
180   string host_name_;
181 
182   // Guards server configuration, server, and state.
183   mutex mu_;
184 
185   // Represents the current state of the server, which changes as follows:
186   //
187   //                 Join()            Join()
188   //                  ___               ___
189   //      Start()     \ /    Stop()     \ /
190   // NEW ---------> STARTED --------> STOPPED
191   //   \                          /
192   //    \________________________/
193   //            Stop(), Join()
194   enum State { NEW, STARTED, STOPPED };
195   State state_ TF_GUARDED_BY(mu_);
196 
197   // Implementation of a TensorFlow master, and RPC polling thread.
198   MasterEnv master_env_;
199   std::unique_ptr<Master> master_impl_;
200   AsyncServiceInterface* master_service_ = nullptr;
201   std::unique_ptr<Thread> master_thread_ TF_GUARDED_BY(mu_);
202 
203   std::map<std::string, AsyncServiceInterface*> extra_services_;
204   std::vector<std::unique_ptr<Thread>> extra_service_threads_
205       TF_GUARDED_BY(mu_);
206 
207   // Implementation of a TensorFlow worker, and RPC polling thread.
208   WorkerEnv worker_env_;
209   std::unique_ptr<const DeviceMgr> owned_device_manager_;
210   std::unique_ptr<GrpcWorker> worker_impl_;
211   AsyncServiceInterface* worker_service_ = nullptr;
212   std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_);
213   std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
214 
215   // TensorFlow Eager implementation, and RPC polling thread.
216   AsyncServiceInterface* eager_service_ = nullptr;
217   std::unique_ptr<Thread> eager_thread_ TF_GUARDED_BY(mu_);
218   std::shared_ptr<WorkerSession> worker_session_;
219 
220   // Experimental coordination service implementation, and RPC polling thread.
221   AsyncServiceInterface* coordination_service_ = nullptr;
222   std::unique_ptr<Thread> coordination_thread_ TF_GUARDED_BY(mu_);
223 
224   // TensorFlow profiler service implementation.
225   std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr;
226 
227   // The overall server configuration.
228   ServerDef server_def_ TF_GUARDED_BY(mu_);
229 
230   std::unique_ptr<::grpc::Server> server_ TF_GUARDED_BY(mu_);
231 };
232 
233 }  // namespace tensorflow
234 
235 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
236