1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 18 19 // GrpcServer manages the lifecycle of an Eager, Worker and Master service. 20 21 #include <memory> 22 23 #include "grpcpp/grpcpp.h" 24 #include "grpcpp/security/credentials.h" 25 #include "tensorflow/core/common_runtime/eager/context.h" 26 #include "tensorflow/core/common_runtime/process_util.h" 27 #include "tensorflow/core/common_runtime/stats_publisher_interface.h" 28 #include "tensorflow/core/distributed_runtime/master_env.h" 29 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" 30 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 31 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" 32 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" 33 #include "tensorflow/core/distributed_runtime/server_lib.h" 34 #include "tensorflow/core/distributed_runtime/session_mgr.h" 35 #include "tensorflow/core/distributed_runtime/worker_env.h" 36 #include "tensorflow/core/framework/collective.h" 37 #include "tensorflow/core/framework/op.h" 38 #include "tensorflow/core/platform/env.h" 39 40 namespace tensorflow { 41 42 class GrpcWorker; 43 class Master; 44 45 // function that creates a RendezvousMgr. 46 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)> 47 RendezvousMgrCreationFunction; 48 49 // function that creates a CollectiveExecutorMgr. 50 typedef std::function<CollectiveExecutorMgrInterface*( 51 const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)> 52 CollectiveMgrCreationFunction; 53 54 // function that registers a service to the server. The service needs to 55 // be registered before builder.BuildAndStart(). 56 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)> 57 ServiceInitFunction; 58 59 // function that creates a grpc based worker implementation. 60 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*, 61 const ConfigProto& config)> 62 WorkerCreationFunction; 63 64 struct GrpcServerOptions { 65 ServiceInitFunction service_func = nullptr; 66 RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr; 67 CollectiveMgrCreationFunction collective_mgr_func = nullptr; 68 WorkerCreationFunction worker_func = nullptr; 69 StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; 70 GrpcWorkerServiceOptions worker_service_options; 71 }; 72 73 class GrpcServer : public ServerInterface { 74 protected: 75 GrpcServer(const ServerDef& server_def, Env* env); 76 // Allow children classes to override this and provide custom args to the 77 // server before it is constructed. Default behavior is to do nothing. 78 virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder); 79 80 public: 81 static Status Create(const ServerDef& server_def, Env* env, 82 std::unique_ptr<ServerInterface>* out_server); 83 static Status Create(const ServerDef& server_def, Env* env, 84 std::unique_ptr<GrpcServer>* out_server); 85 86 // Destruction is only supported in the factory method. Clean 87 // shutdown is not currently implemented for this server type. 88 virtual ~GrpcServer(); 89 90 // Implementations of ServerInterface methods. 91 Status Start() override; 92 Status Stop() override; 93 Status Join() override; 94 const string target() const override; 95 worker_env()96 WorkerEnv* worker_env() { return &worker_env_; } master_env()97 MasterEnv* master_env() { return &master_env_; } 98 99 // Add master eager context to local eager service in order to handle enqueue 100 // requests from remote workers. 101 Status AddMasterEagerContextToEagerService( 102 const tensorflow::uint64 context_id, tensorflow::EagerContext* context); 103 // Update the set of workers that can be reached by the GRPC server 104 Status UpdateServerDef(const ServerDef& server_def); 105 106 protected: 107 virtual Status GetPort(const ServerDef& server_def, int* port) const; 108 Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); 109 110 // A subclass can override this method to support secure credentials. 111 virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( 112 const ServerDef& server_def) const; 113 114 virtual ChannelCreationFunction GetChannelCreationFunction() const; 115 116 virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env); 117 118 // Creates a WorkerCacheInterface for a session. 119 virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, 120 WorkerCacheInterface** worker_cache); 121 122 // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. 123 Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, 124 GrpcChannelSpec* channel_spec); 125 126 // Returns the port to which this server is bound. 127 // This method may only be called after `this->Init()` returns successfully. bound_port()128 int bound_port() const { return bound_port_; } 129 server_def()130 const ServerDef& server_def() const { return server_def_; } worker_impl()131 GrpcWorker* worker_impl() const { return worker_impl_.get(); } 132 133 private: 134 Env* env_; 135 136 // The port to which this server is bound. 137 int bound_port_ = 0; 138 139 // Guards server configuration, server, and state. 140 mutex mu_; 141 142 // Represents the current state of the server, which changes as follows: 143 // 144 // Join() Join() 145 // ___ ___ 146 // Start() \ / Stop() \ / 147 // NEW ---------> STARTED --------> STOPPED 148 // \ / 149 // \________________________/ 150 // Stop(), Join() 151 enum State { NEW, STARTED, STOPPED }; 152 State state_ GUARDED_BY(mu_); 153 154 // Implementation of a TensorFlow master, and RPC polling thread. 155 MasterEnv master_env_; 156 std::unique_ptr<Master> master_impl_; 157 AsyncServiceInterface* master_service_ = nullptr; 158 std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_); 159 160 // Implementation of a TensorFlow worker, and RPC polling thread. 161 WorkerEnv worker_env_; 162 std::unique_ptr<GrpcWorker> worker_impl_; 163 AsyncServiceInterface* worker_service_ = nullptr; 164 std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_); 165 std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_; 166 167 // TensorFlow Eager implementation, and RPC polling thread. 168 AsyncServiceInterface* eager_service_ = nullptr; 169 std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_); 170 std::shared_ptr<WorkerSession> worker_session_; 171 172 // The overall server configuration. 173 ServerDef server_def_ GUARDED_BY(mu_); 174 175 std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); 176 }; 177 178 } // namespace tensorflow 179 180 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 181