• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
18 
19 // GrpcServer manages the lifecycle of an Eager, Worker and Master service.
20 
21 #include <memory>
22 
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 #include "tensorflow/core/common_runtime/eager/context.h"
26 #include "tensorflow/core/common_runtime/process_util.h"
27 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
28 #include "tensorflow/core/distributed_runtime/master_env.h"
29 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
30 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
32 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
33 #include "tensorflow/core/distributed_runtime/server_lib.h"
34 #include "tensorflow/core/distributed_runtime/session_mgr.h"
35 #include "tensorflow/core/distributed_runtime/worker_env.h"
36 #include "tensorflow/core/framework/collective.h"
37 #include "tensorflow/core/framework/op.h"
38 #include "tensorflow/core/platform/env.h"
39 #include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
40 
41 namespace tensorflow {
42 
43 class GrpcWorker;
44 class Master;
45 
46 // function that creates a RendezvousMgr.
47 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
48     RendezvousMgrCreationFunction;
49 
50 // function that creates a CollectiveExecutorMgr.
51 typedef std::function<CollectiveExecutorMgrInterface*(
52     const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
53     CollectiveMgrCreationFunction;
54 
55 // function that registers a service to the server. The service needs to
56 // be registered before builder.BuildAndStart().
57 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
58     ServiceInitFunction;
59 
60 // function that creates a grpc based worker implementation.
61 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*,
62                                                   const ConfigProto& config)>
63     WorkerCreationFunction;
64 
65 struct GrpcServerOptions {
66   ServiceInitFunction service_func = nullptr;
67   RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr;
68   CollectiveMgrCreationFunction collective_mgr_func = nullptr;
69   WorkerCreationFunction worker_func = nullptr;
70   StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
71   GrpcWorkerServiceOptions worker_service_options;
72   DeviceMgr* local_device_mgr = nullptr;
73 };
74 
75 class GrpcServer : public ServerInterface {
76  protected:
77   GrpcServer(const ServerDef& server_def, Env* env);
78   GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr,
79              Env* env);
80   // Allow children classes to override this and provide custom args to the
81   // server before it is constructed. Default behavior is to do nothing.
82   // requested_port provides the port requested by caller as bound_port() is
83   // not available till BuildAndStart has been called.
MaybeMutateBuilder(::grpc::ServerBuilder * builder,int requested_port)84   virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder,
85                                   int requested_port) {}
86 
87  public:
88   static Status Create(const ServerDef& server_def, Env* env,
89                        std::unique_ptr<ServerInterface>* out_server);
90   static Status Create(const ServerDef& server_def, Env* env,
91                        std::unique_ptr<GrpcServer>* out_server);
92   // Reuse the local_device_mgr.
93   static Status Create(const ServerDef& server_def, Env* env,
94                        DeviceMgr* local_device_mgr,
95                        std::unique_ptr<ServerInterface>* out_server);
96 
97   // Destruction is only supported in the factory method. Clean
98   // shutdown is not currently implemented for this server type.
99   virtual ~GrpcServer();
100 
101   // Implementations of ServerInterface methods.
102   Status Start() override;
103   Status Stop() override;
104   Status Join() override;
105   const string target() const override;
106 
worker_env()107   WorkerEnv* worker_env() override { return &worker_env_; }
master_env()108   MasterEnv* master_env() override { return &master_env_; }
109 
110   // Add master eager context to local eager service in order to handle enqueue
111   // requests from remote workers.
112   Status AddMasterEagerContextToEagerService(
113       const tensorflow::uint64 context_id,
114       tensorflow::EagerContext* context) override;
115   // Update the set of workers that can be reached by the GRPC server
116   Status UpdateServerDef(const ServerDef& server_def) override;
117   // Pass coordination service agent instance to server's RPC handler
118   Status SetCoordinationServiceAgentInstance(
119       CoordinationServiceAgent* agent) override;
120 
121  protected:
122   virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name,
123                                 int* port) const;
124   Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
125 
126   // A subclass can override this method to support secure credentials.
127   virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
128       const ServerDef& server_def) const;
129 
130   virtual ChannelCreationFunction GetChannelCreationFunction() const;
131 
132   virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
133 
134   // Creates a WorkerCacheInterface for a session.
135   virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
136                                     WorkerCacheInterface** worker_cache);
137 
138   // Override to return extra services to be brought up and managed along with
139   // the standard {master, worker, eager} services. The map key is an aribtrary
140   // string and the value is a pointer to the service to be brought up.
141   // Ownership of the pointer is transferred to GrpcServer after this call
142   // returns, and the service will be destroyed during the destruction of
143   // GrpcServer. Each service will have its HandleRPCsLoop called in a separate
144   // thread. An example usage would be to add a RDMA based partial worker
145   // service to offload tensor and data buffer transfers.
ExtraServices(::grpc::ServerBuilder *)146   virtual std::map<std::string, AsyncServiceInterface*> ExtraServices(
147       ::grpc::ServerBuilder*) {
148     return {};
149   }
150 
GetExtraServices()151   virtual std::map<std::string, AsyncServiceInterface*> GetExtraServices() {
152     return extra_services_;
153   }
154 
155   // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
156   Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
157                           GrpcChannelSpec* channel_spec);
158 
159   // Returns the port to which this server is bound.
160   // This method may only be called after `this->Init()` returns successfully.
bound_port()161   int bound_port() const { return bound_port_; }
162 
163   // Returns hostname.
host_name()164   const string& host_name() const { return host_name_; }
165 
server_def()166   const ServerDef& server_def() const { return server_def_; }
worker_impl()167   GrpcWorker* worker_impl() const { return worker_impl_.get(); }
grpc_worker_env()168   GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); }
169 
170  private:
171   Env* env_;
172 
173   // The port to which this server is bound.
174   int bound_port_ = 0;
175 
176   // The host name of this server
177   string host_name_;
178 
179   // Guards server configuration, server, and state.
180   mutex mu_;
181 
182   // Represents the current state of the server, which changes as follows:
183   //
184   //                 Join()            Join()
185   //                  ___               ___
186   //      Start()     \ /    Stop()     \ /
187   // NEW ---------> STARTED --------> STOPPED
188   //   \                          /
189   //    \________________________/
190   //            Stop(), Join()
191   enum State { NEW, STARTED, STOPPED };
192   State state_ TF_GUARDED_BY(mu_);
193 
194   // Implementation of a TensorFlow master, and RPC polling thread.
195   MasterEnv master_env_;
196   std::unique_ptr<Master> master_impl_;
197   AsyncServiceInterface* master_service_ = nullptr;
198   std::unique_ptr<Thread> master_thread_ TF_GUARDED_BY(mu_);
199 
200   std::map<std::string, AsyncServiceInterface*> extra_services_;
201   std::vector<std::unique_ptr<Thread>> extra_service_threads_
202       TF_GUARDED_BY(mu_);
203 
204   // Implementation of a TensorFlow worker, and RPC polling thread.
205   WorkerEnv worker_env_;
206   std::unique_ptr<const DeviceMgr> owned_device_manager_;
207   std::unique_ptr<GrpcWorker> worker_impl_;
208   AsyncServiceInterface* worker_service_ = nullptr;
209   std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_);
210   std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
211 
212   // TensorFlow Eager implementation, and RPC polling thread.
213   AsyncServiceInterface* eager_service_ = nullptr;
214   std::unique_ptr<Thread> eager_thread_ TF_GUARDED_BY(mu_);
215   std::shared_ptr<WorkerSession> worker_session_;
216 
217   // TensorFlow profiler service implementation.
218   std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr;
219 
220   // The overall server configuration.
221   ServerDef server_def_ TF_GUARDED_BY(mu_);
222 
223   std::unique_ptr<::grpc::Server> server_ TF_GUARDED_BY(mu_);
224 };
225 
226 }  // namespace tensorflow
227 
228 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
229