• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
18 
19 // GrpcServer manages the lifecycle of an Eager, Worker and Master service.
20 
21 #include <memory>
22 
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 #include "tensorflow/core/common_runtime/eager/context.h"
26 #include "tensorflow/core/common_runtime/process_util.h"
27 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
28 #include "tensorflow/core/distributed_runtime/master_env.h"
29 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
30 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
32 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
33 #include "tensorflow/core/distributed_runtime/server_lib.h"
34 #include "tensorflow/core/distributed_runtime/session_mgr.h"
35 #include "tensorflow/core/distributed_runtime/worker_env.h"
36 #include "tensorflow/core/framework/collective.h"
37 #include "tensorflow/core/framework/op.h"
38 #include "tensorflow/core/platform/env.h"
39 
40 namespace tensorflow {
41 
42 class GrpcWorker;
43 class Master;
44 
45 // function that creates a RendezvousMgr.
46 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
47     RendezvousMgrCreationFunction;
48 
49 // function that creates a CollectiveExecutorMgr.
50 typedef std::function<CollectiveExecutorMgrInterface*(
51     const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
52     CollectiveMgrCreationFunction;
53 
54 // function that registers a service to the server. The service needs to
55 // be registered before builder.BuildAndStart().
56 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
57     ServiceInitFunction;
58 
59 // function that creates a grpc based worker implementation.
60 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*,
61                                                   const ConfigProto& config)>
62     WorkerCreationFunction;
63 
64 struct GrpcServerOptions {
65   ServiceInitFunction service_func = nullptr;
66   RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr;
67   CollectiveMgrCreationFunction collective_mgr_func = nullptr;
68   WorkerCreationFunction worker_func = nullptr;
69   StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
70   GrpcWorkerServiceOptions worker_service_options;
71 };
72 
73 class GrpcServer : public ServerInterface {
74  protected:
75   GrpcServer(const ServerDef& server_def, Env* env);
76   // Allow children classes to override this and provide custom args to the
77   // server before it is constructed. Default behavior is to do nothing.
78   virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder);
79 
80  public:
81   static Status Create(const ServerDef& server_def, Env* env,
82                        std::unique_ptr<ServerInterface>* out_server);
83   static Status Create(const ServerDef& server_def, Env* env,
84                        std::unique_ptr<GrpcServer>* out_server);
85 
86   // Destruction is only supported in the factory method. Clean
87   // shutdown is not currently implemented for this server type.
88   virtual ~GrpcServer();
89 
90   // Implementations of ServerInterface methods.
91   Status Start() override;
92   Status Stop() override;
93   Status Join() override;
94   const string target() const override;
95 
worker_env()96   WorkerEnv* worker_env() { return &worker_env_; }
master_env()97   MasterEnv* master_env() { return &master_env_; }
98 
99   // Add master eager context to local eager service in order to handle enqueue
100   // requests from remote workers.
101   Status AddMasterEagerContextToEagerService(
102       const tensorflow::uint64 context_id, tensorflow::EagerContext* context);
103   // Update the set of workers that can be reached by the GRPC server
104   Status UpdateServerDef(const ServerDef& server_def);
105 
106  protected:
107   virtual Status GetPort(const ServerDef& server_def, int* port) const;
108   Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
109 
110   // A subclass can override this method to support secure credentials.
111   virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
112       const ServerDef& server_def) const;
113 
114   virtual ChannelCreationFunction GetChannelCreationFunction() const;
115 
116   virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
117 
118   // Creates a WorkerCacheInterface for a session.
119   virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
120                                     WorkerCacheInterface** worker_cache);
121 
122   // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
123   Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
124                           GrpcChannelSpec* channel_spec);
125 
126   // Returns the port to which this server is bound.
127   // This method may only be called after `this->Init()` returns successfully.
bound_port()128   int bound_port() const { return bound_port_; }
129 
server_def()130   const ServerDef& server_def() const { return server_def_; }
worker_impl()131   GrpcWorker* worker_impl() const { return worker_impl_.get(); }
132 
133  private:
134   Env* env_;
135 
136   // The port to which this server is bound.
137   int bound_port_ = 0;
138 
139   // Guards server configuration, server, and state.
140   mutex mu_;
141 
142   // Represents the current state of the server, which changes as follows:
143   //
144   //                 Join()            Join()
145   //                  ___               ___
146   //      Start()     \ /    Stop()     \ /
147   // NEW ---------> STARTED --------> STOPPED
148   //   \                          /
149   //    \________________________/
150   //            Stop(), Join()
151   enum State { NEW, STARTED, STOPPED };
152   State state_ GUARDED_BY(mu_);
153 
154   // Implementation of a TensorFlow master, and RPC polling thread.
155   MasterEnv master_env_;
156   std::unique_ptr<Master> master_impl_;
157   AsyncServiceInterface* master_service_ = nullptr;
158   std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_);
159 
160   // Implementation of a TensorFlow worker, and RPC polling thread.
161   WorkerEnv worker_env_;
162   std::unique_ptr<GrpcWorker> worker_impl_;
163   AsyncServiceInterface* worker_service_ = nullptr;
164   std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_);
165   std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
166 
167   // TensorFlow Eager implementation, and RPC polling thread.
168   AsyncServiceInterface* eager_service_ = nullptr;
169   std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_);
170   std::shared_ptr<WorkerSession> worker_session_;
171 
172   // The overall server configuration.
173   ServerDef server_def_ GUARDED_BY(mu_);
174 
175   std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);
176 };
177 
178 }  // namespace tensorflow
179 
180 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
181