1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 18 19 // GrpcServer manages the lifecycle of an Eager, Worker and Master service. 20 21 #include <memory> 22 23 #include "grpcpp/grpcpp.h" 24 #include "grpcpp/security/credentials.h" 25 #include "tensorflow/core/common_runtime/eager/context.h" 26 #include "tensorflow/core/common_runtime/process_util.h" 27 #include "tensorflow/core/common_runtime/stats_publisher_interface.h" 28 #include "tensorflow/core/distributed_runtime/master_env.h" 29 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" 30 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 31 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" 32 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" 33 #include "tensorflow/core/distributed_runtime/server_lib.h" 34 #include "tensorflow/core/distributed_runtime/session_mgr.h" 35 #include "tensorflow/core/distributed_runtime/worker_env.h" 36 #include "tensorflow/core/framework/collective.h" 37 #include "tensorflow/core/framework/op.h" 38 #include "tensorflow/core/platform/env.h" 39 #include "tensorflow/core/profiler/profiler_service.grpc.pb.h" 40 41 namespace tensorflow { 42 43 class GrpcWorker; 44 class Master; 45 46 // function that creates a RendezvousMgr. 47 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)> 48 RendezvousMgrCreationFunction; 49 50 // function that creates a CollectiveExecutorMgr. 51 typedef std::function<CollectiveExecutorMgrInterface*( 52 const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)> 53 CollectiveMgrCreationFunction; 54 55 // function that registers a service to the server. The service needs to 56 // be registered before builder.BuildAndStart(). 57 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)> 58 ServiceInitFunction; 59 60 // function that creates a grpc based worker implementation. 61 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*, 62 const ConfigProto& config)> 63 WorkerCreationFunction; 64 65 struct GrpcServerOptions { 66 ServiceInitFunction service_func = nullptr; 67 RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr; 68 CollectiveMgrCreationFunction collective_mgr_func = nullptr; 69 WorkerCreationFunction worker_func = nullptr; 70 StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; 71 GrpcWorkerServiceOptions worker_service_options; 72 DeviceMgr* local_device_mgr = nullptr; 73 }; 74 75 class GrpcServer : public ServerInterface { 76 protected: 77 GrpcServer(const ServerDef& server_def, Env* env); 78 GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr, 79 Env* env); 80 // Allow children classes to override this and provide custom args to the 81 // server before it is constructed. Default behavior is to do nothing. 82 // requested_port provides the port requested by caller as bound_port() is 83 // not available till BuildAndStart has been called. MaybeMutateBuilder(::grpc::ServerBuilder * builder,int requested_port)84 virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder, 85 int requested_port) {} 86 87 public: 88 static Status Create(const ServerDef& server_def, Env* env, 89 std::unique_ptr<ServerInterface>* out_server); 90 static Status Create(const ServerDef& server_def, Env* env, 91 std::unique_ptr<GrpcServer>* out_server); 92 // Reuse the local_device_mgr. 93 static Status Create(const ServerDef& server_def, Env* env, 94 DeviceMgr* local_device_mgr, 95 std::unique_ptr<ServerInterface>* out_server); 96 97 // Destruction is only supported in the factory method. Clean 98 // shutdown is not currently implemented for this server type. 99 virtual ~GrpcServer(); 100 101 // Implementations of ServerInterface methods. 102 Status Start() override; 103 Status Stop() override; 104 Status Join() override; 105 const string target() const override; 106 worker_env()107 WorkerEnv* worker_env() override { return &worker_env_; } master_env()108 MasterEnv* master_env() override { return &master_env_; } 109 110 // Add master eager context to local eager service in order to handle enqueue 111 // requests from remote workers. 112 Status AddMasterEagerContextToEagerService( 113 const tensorflow::uint64 context_id, 114 tensorflow::EagerContext* context) override; 115 // Update the set of workers that can be reached by the GRPC server 116 Status UpdateServerDef(const ServerDef& server_def) override; 117 // Pass coordination service agent instance to server's RPC handler 118 Status SetCoordinationServiceAgentInstance( 119 CoordinationServiceAgent* agent) override; 120 121 protected: 122 virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name, 123 int* port) const; 124 Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); 125 126 // A subclass can override this method to support secure credentials. 127 virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( 128 const ServerDef& server_def) const; 129 130 virtual ChannelCreationFunction GetChannelCreationFunction() const; 131 132 virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env); 133 134 // Creates a WorkerCacheInterface for a session. 135 virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, 136 WorkerCacheInterface** worker_cache); 137 138 // Override to return extra services to be brought up and managed along with 139 // the standard {master, worker, eager} services. The map key is an aribtrary 140 // string and the value is a pointer to the service to be brought up. 141 // Ownership of the pointer is transferred to GrpcServer after this call 142 // returns, and the service will be destroyed during the destruction of 143 // GrpcServer. Each service will have its HandleRPCsLoop called in a separate 144 // thread. An example usage would be to add a RDMA based partial worker 145 // service to offload tensor and data buffer transfers. ExtraServices(::grpc::ServerBuilder *)146 virtual std::map<std::string, AsyncServiceInterface*> ExtraServices( 147 ::grpc::ServerBuilder*) { 148 return {}; 149 } 150 GetExtraServices()151 virtual std::map<std::string, AsyncServiceInterface*> GetExtraServices() { 152 return extra_services_; 153 } 154 155 // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. 156 Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, 157 GrpcChannelSpec* channel_spec); 158 159 // Returns the port to which this server is bound. 160 // This method may only be called after `this->Init()` returns successfully. bound_port()161 int bound_port() const { return bound_port_; } 162 163 // Returns hostname. host_name()164 const string& host_name() const { return host_name_; } 165 server_def()166 const ServerDef& server_def() const { return server_def_; } worker_impl()167 GrpcWorker* worker_impl() const { return worker_impl_.get(); } grpc_worker_env()168 GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); } 169 170 private: 171 Env* env_; 172 173 // The port to which this server is bound. 174 int bound_port_ = 0; 175 176 // The host name of this server 177 string host_name_; 178 179 // Guards server configuration, server, and state. 180 mutex mu_; 181 182 // Represents the current state of the server, which changes as follows: 183 // 184 // Join() Join() 185 // ___ ___ 186 // Start() \ / Stop() \ / 187 // NEW ---------> STARTED --------> STOPPED 188 // \ / 189 // \________________________/ 190 // Stop(), Join() 191 enum State { NEW, STARTED, STOPPED }; 192 State state_ TF_GUARDED_BY(mu_); 193 194 // Implementation of a TensorFlow master, and RPC polling thread. 195 MasterEnv master_env_; 196 std::unique_ptr<Master> master_impl_; 197 AsyncServiceInterface* master_service_ = nullptr; 198 std::unique_ptr<Thread> master_thread_ TF_GUARDED_BY(mu_); 199 200 std::map<std::string, AsyncServiceInterface*> extra_services_; 201 std::vector<std::unique_ptr<Thread>> extra_service_threads_ 202 TF_GUARDED_BY(mu_); 203 204 // Implementation of a TensorFlow worker, and RPC polling thread. 205 WorkerEnv worker_env_; 206 std::unique_ptr<const DeviceMgr> owned_device_manager_; 207 std::unique_ptr<GrpcWorker> worker_impl_; 208 AsyncServiceInterface* worker_service_ = nullptr; 209 std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_); 210 std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_; 211 212 // TensorFlow Eager implementation, and RPC polling thread. 213 AsyncServiceInterface* eager_service_ = nullptr; 214 std::unique_ptr<Thread> eager_thread_ TF_GUARDED_BY(mu_); 215 std::shared_ptr<WorkerSession> worker_session_; 216 217 // TensorFlow profiler service implementation. 218 std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr; 219 220 // The overall server configuration. 221 ServerDef server_def_ TF_GUARDED_BY(mu_); 222 223 std::unique_ptr<::grpc::Server> server_ TF_GUARDED_BY(mu_); 224 }; 225 226 } // namespace tensorflow 227 228 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 229