1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 18 19 // GrpcServer manages the lifecycle of an Eager, Worker and Master service. 20 21 #include <memory> 22 23 #include "grpcpp/grpcpp.h" 24 #include "grpcpp/security/credentials.h" 25 #include "tensorflow/core/common_runtime/eager/context.h" 26 #include "tensorflow/core/common_runtime/process_util.h" 27 #include "tensorflow/core/common_runtime/stats_publisher_interface.h" 28 #include "tensorflow/core/distributed_runtime/master_env.h" 29 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" 30 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 31 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" 32 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" 33 #include "tensorflow/core/distributed_runtime/server_lib.h" 34 #include "tensorflow/core/distributed_runtime/session_mgr.h" 35 #include "tensorflow/core/distributed_runtime/worker_env.h" 36 #include "tensorflow/core/framework/collective.h" 37 #include "tensorflow/core/framework/op.h" 38 #include "tensorflow/core/platform/env.h" 39 #include "tensorflow/core/profiler/profiler_service.grpc.pb.h" 40 41 namespace tensorflow { 42 43 class GrpcWorker; 44 class Master; 45 46 // function that creates a RendezvousMgr. 47 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)> 48 RendezvousMgrCreationFunction; 49 50 // function that creates a CollectiveExecutorMgr. 51 typedef std::function<CollectiveExecutorMgrInterface*( 52 const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)> 53 CollectiveMgrCreationFunction; 54 55 // function that registers a service to the server. The service needs to 56 // be registered before builder.BuildAndStart(). 57 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)> 58 ServiceInitFunction; 59 60 // function that creates a grpc based worker implementation. 61 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*, 62 const ConfigProto& config)> 63 WorkerCreationFunction; 64 65 struct GrpcServerOptions { 66 ServiceInitFunction service_func = nullptr; 67 RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr; 68 CollectiveMgrCreationFunction collective_mgr_func = nullptr; 69 WorkerCreationFunction worker_func = nullptr; 70 StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; 71 GrpcWorkerServiceOptions worker_service_options; 72 const DeviceMgr* local_device_mgr = nullptr; 73 }; 74 75 class GrpcServer : public ServerInterface { 76 protected: 77 GrpcServer(const ServerDef& server_def, Env* env); 78 GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr, 79 Env* env); 80 // Allow children classes to override this and provide custom args to the 81 // server before it is constructed. Default behavior is to do nothing. 82 virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder); 83 84 public: 85 static Status Create(const ServerDef& server_def, Env* env, 86 std::unique_ptr<ServerInterface>* out_server); 87 static Status Create(const ServerDef& server_def, Env* env, 88 std::unique_ptr<GrpcServer>* out_server); 89 // Reuse the local_device_mgr. 90 static Status Create(const ServerDef& server_def, Env* env, 91 const DeviceMgr* local_device_mgr, 92 std::unique_ptr<ServerInterface>* out_server); 93 94 // Destruction is only supported in the factory method. Clean 95 // shutdown is not currently implemented for this server type. 96 virtual ~GrpcServer(); 97 98 // Implementations of ServerInterface methods. 99 Status Start() override; 100 Status Stop() override; 101 Status Join() override; 102 const string target() const override; 103 worker_env()104 WorkerEnv* worker_env() { return &worker_env_; } master_env()105 MasterEnv* master_env() { return &master_env_; } 106 107 // Add master eager context to local eager service in order to handle enqueue 108 // requests from remote workers. 109 Status AddMasterEagerContextToEagerService( 110 const tensorflow::uint64 context_id, tensorflow::EagerContext* context); 111 // Update the set of workers that can be reached by the GRPC server 112 Status UpdateServerDef(const ServerDef& server_def); 113 114 protected: 115 virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name, 116 int* port) const; 117 Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); 118 119 // A subclass can override this method to support secure credentials. 120 virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( 121 const ServerDef& server_def) const; 122 123 virtual ChannelCreationFunction GetChannelCreationFunction() const; 124 125 virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env); 126 127 // Creates a WorkerCacheInterface for a session. 128 virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, 129 WorkerCacheInterface** worker_cache); 130 131 // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. 132 Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, 133 GrpcChannelSpec* channel_spec); 134 135 // Returns the port to which this server is bound. 136 // This method may only be called after `this->Init()` returns successfully. bound_port()137 int bound_port() const { return bound_port_; } 138 139 // Returns hostname. host_name()140 const string& host_name() const { return host_name_; } 141 server_def()142 const ServerDef& server_def() const { return server_def_; } worker_impl()143 GrpcWorker* worker_impl() const { return worker_impl_.get(); } grpc_worker_env()144 GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); } 145 146 private: 147 Env* env_; 148 149 // The port to which this server is bound. 150 int bound_port_ = 0; 151 152 // The host name of this server 153 string host_name_; 154 155 // Guards server configuration, server, and state. 156 mutex mu_; 157 158 // Represents the current state of the server, which changes as follows: 159 // 160 // Join() Join() 161 // ___ ___ 162 // Start() \ / Stop() \ / 163 // NEW ---------> STARTED --------> STOPPED 164 // \ / 165 // \________________________/ 166 // Stop(), Join() 167 enum State { NEW, STARTED, STOPPED }; 168 State state_ TF_GUARDED_BY(mu_); 169 170 // Implementation of a TensorFlow master, and RPC polling thread. 171 MasterEnv master_env_; 172 std::unique_ptr<Master> master_impl_; 173 AsyncServiceInterface* master_service_ = nullptr; 174 std::unique_ptr<Thread> master_thread_ TF_GUARDED_BY(mu_); 175 176 // Implementation of a TensorFlow worker, and RPC polling thread. 177 WorkerEnv worker_env_; 178 std::unique_ptr<const DeviceMgr> owned_device_manager_; 179 std::unique_ptr<GrpcWorker> worker_impl_; 180 AsyncServiceInterface* worker_service_ = nullptr; 181 std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_); 182 std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_; 183 184 // TensorFlow Eager implementation, and RPC polling thread. 185 AsyncServiceInterface* eager_service_ = nullptr; 186 std::unique_ptr<Thread> eager_thread_ TF_GUARDED_BY(mu_); 187 std::shared_ptr<WorkerSession> worker_session_; 188 189 // TensorFlow profiler service implementation. 190 std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr; 191 192 // The overall server configuration. 193 ServerDef server_def_ TF_GUARDED_BY(mu_); 194 195 std::unique_ptr<::grpc::Server> server_ TF_GUARDED_BY(mu_); 196 }; 197 198 } // namespace tensorflow 199 200 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 201