/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_

// GrpcServer manages the lifecycle of an Eager, Worker and Master service.

#include <functional>
#include <memory>

#include "grpcpp/grpcpp.h"
#include "grpcpp/security/credentials.h"

#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

class GrpcWorker;
class Master;

// Function that creates a RendezvousMgr.
45 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)> 46 RendezvousMgrCreationFunction; 47 48 // function that creates a CollectiveExecutorMgr. 49 typedef std::function<CollectiveExecutorMgrInterface*( 50 const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)> 51 CollectiveMgrCreationFunction; 52 53 // function that registers a service to the server. The service needs to 54 // be registered before builder.BuildAndStart(). 55 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)> 56 ServiceInitFunction; 57 58 // function that creates a grpc based worker implementation. 59 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*, 60 const ConfigProto& config)> 61 WorkerCreationFunction; 62 63 struct GrpcServerOptions { 64 ServiceInitFunction service_func = nullptr; 65 RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr; 66 CollectiveMgrCreationFunction collective_mgr_func = nullptr; 67 WorkerCreationFunction worker_func = nullptr; 68 StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; 69 GrpcWorkerServiceOptions worker_service_options; 70 }; 71 72 class GrpcServer : public ServerInterface { 73 protected: 74 GrpcServer(const ServerDef& server_def, Env* env); 75 // Allow children classes to override this and provide custom args to the 76 // server before it is constructed. Default behavior is to do nothing. 77 virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder); 78 79 public: 80 static Status Create(const ServerDef& server_def, Env* env, 81 std::unique_ptr<ServerInterface>* out_server); 82 static Status Create(const ServerDef& server_def, Env* env, 83 std::unique_ptr<GrpcServer>* out_server); 84 85 // Destruction is only supported in the factory method. Clean 86 // shutdown is not currently implemented for this server type. 87 virtual ~GrpcServer(); 88 89 // Implementations of ServerInterface methods. 
90 Status Start() override; 91 Status Stop() override; 92 Status Join() override; 93 const string target() const override; 94 worker_env()95 WorkerEnv* worker_env() { return &worker_env_; } master_env()96 MasterEnv* master_env() { return &master_env_; } 97 channel_cache()98 std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } 99 100 protected: 101 Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); 102 103 // A subclass can override this method to support secure credentials. 104 virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( 105 const ServerDef& server_def) const; 106 107 virtual ChannelCreationFunction GetChannelCreationFunction() const; 108 109 virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env); 110 111 // Creates a WorkerCacheInterface for a session. 112 Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, 113 WorkerCacheInterface** worker_cache); 114 115 // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. 116 Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, 117 GrpcChannelSpec* channel_spec); 118 119 // Returns the port to which this server is bound. 120 // This method may only be called after `this->Init()` returns successfully. bound_port()121 int bound_port() const { return bound_port_; } 122 server_def()123 const ServerDef& server_def() const { return server_def_; } 124 125 private: 126 // The overall server configuration. 127 const ServerDef server_def_; 128 Env* env_; 129 130 // The port to which this server is bound. 131 int bound_port_ = 0; 132 133 // Guards state transitions. 
134 mutex mu_; 135 136 // Represents the current state of the server, which changes as follows: 137 // 138 // Join() Join() 139 // ___ ___ 140 // Start() \ / Stop() \ / 141 // NEW ---------> STARTED --------> STOPPED 142 // \ / 143 // \________________________/ 144 // Stop(), Join() 145 enum State { NEW, STARTED, STOPPED }; 146 State state_ GUARDED_BY(mu_); 147 148 // Implementation of a TensorFlow master, and RPC polling thread. 149 MasterEnv master_env_; 150 std::unique_ptr<Master> master_impl_; 151 AsyncServiceInterface* master_service_ = nullptr; 152 std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_); 153 std::shared_ptr<GrpcChannelCache> channel_cache_; 154 155 // Implementation of a TensorFlow worker, and RPC polling thread. 156 WorkerEnv worker_env_; 157 std::unique_ptr<GrpcWorker> worker_impl_; 158 AsyncServiceInterface* worker_service_ = nullptr; 159 std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_); 160 161 // TensorFlow Eager implementation, and RPC polling thread. 162 AsyncServiceInterface* eager_service_ = nullptr; 163 std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_); 164 std::shared_ptr<WorkerSession> worker_session_; 165 166 std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); 167 }; 168 169 } // namespace tensorflow 170 171 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 172