• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
18 
// GrpcServer manages the lifecycle of the Eager, Worker, and Master services.
20 
#include <functional>
#include <memory>

#include "grpcpp/grpcpp.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
40 
41 namespace tensorflow {
42 
43 class GrpcWorker;
44 class Master;
45 
46 // function that creates a RendezvousMgr.
47 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
48     RendezvousMgrCreationFunction;
49 
50 // function that creates a CollectiveExecutorMgr.
51 typedef std::function<CollectiveExecutorMgrInterface*(
52     const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
53     CollectiveMgrCreationFunction;
54 
55 // function that registers a service to the server. The service needs to
56 // be registered before builder.BuildAndStart().
57 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
58     ServiceInitFunction;
59 
60 // function that creates a grpc based worker implementation.
61 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*,
62                                                   const ConfigProto& config)>
63     WorkerCreationFunction;
64 
65 struct GrpcServerOptions {
66   ServiceInitFunction service_func = nullptr;
67   RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr;
68   CollectiveMgrCreationFunction collective_mgr_func = nullptr;
69   WorkerCreationFunction worker_func = nullptr;
70   StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
71   GrpcWorkerServiceOptions worker_service_options;
72   const DeviceMgr* local_device_mgr = nullptr;
73 };
74 
75 class GrpcServer : public ServerInterface {
76  protected:
77   GrpcServer(const ServerDef& server_def, Env* env);
78   GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr,
79              Env* env);
80   // Allow children classes to override this and provide custom args to the
81   // server before it is constructed. Default behavior is to do nothing.
82   virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder);
83 
84  public:
85   static Status Create(const ServerDef& server_def, Env* env,
86                        std::unique_ptr<ServerInterface>* out_server);
87   static Status Create(const ServerDef& server_def, Env* env,
88                        std::unique_ptr<GrpcServer>* out_server);
89   // Reuse the local_device_mgr.
90   static Status Create(const ServerDef& server_def, Env* env,
91                        const DeviceMgr* local_device_mgr,
92                        std::unique_ptr<ServerInterface>* out_server);
93 
94   // Destruction is only supported in the factory method. Clean
95   // shutdown is not currently implemented for this server type.
96   virtual ~GrpcServer();
97 
98   // Implementations of ServerInterface methods.
99   Status Start() override;
100   Status Stop() override;
101   Status Join() override;
102   const string target() const override;
103 
worker_env()104   WorkerEnv* worker_env() { return &worker_env_; }
master_env()105   MasterEnv* master_env() { return &master_env_; }
106 
107   // Add master eager context to local eager service in order to handle enqueue
108   // requests from remote workers.
109   Status AddMasterEagerContextToEagerService(
110       const tensorflow::uint64 context_id, tensorflow::EagerContext* context);
111   // Update the set of workers that can be reached by the GRPC server
112   Status UpdateServerDef(const ServerDef& server_def);
113 
114  protected:
115   virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name,
116                                 int* port) const;
117   Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
118 
119   // A subclass can override this method to support secure credentials.
120   virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
121       const ServerDef& server_def) const;
122 
123   virtual ChannelCreationFunction GetChannelCreationFunction() const;
124 
125   virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
126 
127   // Creates a WorkerCacheInterface for a session.
128   virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
129                                     WorkerCacheInterface** worker_cache);
130 
131   // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
132   Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
133                           GrpcChannelSpec* channel_spec);
134 
135   // Returns the port to which this server is bound.
136   // This method may only be called after `this->Init()` returns successfully.
bound_port()137   int bound_port() const { return bound_port_; }
138 
139   // Returns hostname.
host_name()140   const string& host_name() const { return host_name_; }
141 
server_def()142   const ServerDef& server_def() const { return server_def_; }
worker_impl()143   GrpcWorker* worker_impl() const { return worker_impl_.get(); }
grpc_worker_env()144   GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); }
145 
146  private:
147   Env* env_;
148 
149   // The port to which this server is bound.
150   int bound_port_ = 0;
151 
152   // The host name of this server
153   string host_name_;
154 
155   // Guards server configuration, server, and state.
156   mutex mu_;
157 
158   // Represents the current state of the server, which changes as follows:
159   //
160   //                 Join()            Join()
161   //                  ___               ___
162   //      Start()     \ /    Stop()     \ /
163   // NEW ---------> STARTED --------> STOPPED
164   //   \                          /
165   //    \________________________/
166   //            Stop(), Join()
167   enum State { NEW, STARTED, STOPPED };
168   State state_ TF_GUARDED_BY(mu_);
169 
170   // Implementation of a TensorFlow master, and RPC polling thread.
171   MasterEnv master_env_;
172   std::unique_ptr<Master> master_impl_;
173   AsyncServiceInterface* master_service_ = nullptr;
174   std::unique_ptr<Thread> master_thread_ TF_GUARDED_BY(mu_);
175 
176   // Implementation of a TensorFlow worker, and RPC polling thread.
177   WorkerEnv worker_env_;
178   std::unique_ptr<const DeviceMgr> owned_device_manager_;
179   std::unique_ptr<GrpcWorker> worker_impl_;
180   AsyncServiceInterface* worker_service_ = nullptr;
181   std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_);
182   std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
183 
184   // TensorFlow Eager implementation, and RPC polling thread.
185   AsyncServiceInterface* eager_service_ = nullptr;
186   std::unique_ptr<Thread> eager_thread_ TF_GUARDED_BY(mu_);
187   std::shared_ptr<WorkerSession> worker_session_;
188 
189   // TensorFlow profiler service implementation.
190   std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr;
191 
192   // The overall server configuration.
193   ServerDef server_def_ TF_GUARDED_BY(mu_);
194 
195   std::unique_ptr<::grpc::Server> server_ TF_GUARDED_BY(mu_);
196 };
197 
198 }  // namespace tensorflow
199 
200 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
201