1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ 18 19 #include "absl/synchronization/mutex.h" 20 #include "absl/synchronization/notification.h" 21 #include "absl/time/time.h" 22 #include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h" 23 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h" 24 #include "tensorflow/compiler/xla/statusor.h" 25 #include "tensorflow/compiler/xla/types.h" 26 #include "tensorflow/core/platform/env.h" 27 28 namespace xla { 29 30 typedef int NodeId; 31 32 class DistributedRuntimeServiceImpl final 33 : public grpc::DistributedRuntimeService::Service { 34 public: 35 struct Options { 36 // Number of nodes in the job. Mandatory. Must be non-negative. 37 int num_nodes = -1; 38 39 tensorflow::Env* env = tensorflow::Env::Default(); 40 41 // Interval at which the service should check for missed heartbeat RPCs 42 // from the clients. 43 absl::Duration heartbeat_interval = absl::Seconds(10); 44 45 // Number of heartbeats that a client may miss in a row before the 46 // coordinator concludes that a client has vanished. 47 int max_missing_heartbeats = 10; 48 49 // How long should we wait for all clients to call EnumerateDevices() before 50 // giving up? 51 absl::Duration enumerate_devices_timeout = absl::Seconds(60); 52 53 // How long should we wait for all clients to call Shutdown() before giving 54 // up and returning a failure? 55 absl::Duration shutdown_timeout = absl::Seconds(60); 56 }; 57 explicit DistributedRuntimeServiceImpl(const Options& options); 58 ~DistributedRuntimeServiceImpl() override; 59 60 DistributedRuntimeServiceImpl(const DistributedRuntimeServiceImpl&) = delete; 61 DistributedRuntimeServiceImpl(DistributedRuntimeServiceImpl&&) = delete; 62 DistributedRuntimeServiceImpl& operator=( 63 const DistributedRuntimeServiceImpl&) = delete; 64 DistributedRuntimeServiceImpl&& operator=(DistributedRuntimeServiceImpl&&) = 65 delete; 66 67 ::grpc::Status Connect(::grpc::ServerContext* context, 68 const ConnectRequest* request, 69 ConnectResponse* response) override; 70 71 ::grpc::Status Shutdown(::grpc::ServerContext* context, 72 const ShutdownRequest* request, 73 ShutdownResponse* response) override; 74 75 ::grpc::Status Heartbeat(::grpc::ServerContext* context, 76 const HeartbeatRequest* request, 77 HeartbeatResponse* response) override; 78 79 ::grpc::Status EnumerateDevices(::grpc::ServerContext* context, 80 const EnumerateDevicesRequest* request, 81 EnumerateDevicesResponse* response) override; 82 83 ::grpc::Status KeyValueGet(::grpc::ServerContext* context, 84 const KeyValueGetRequest* request, 85 KeyValueGetResponse* response) override; 86 87 ::grpc::Status KeyValueSet(::grpc::ServerContext* context, 88 const KeyValueSetRequest* request, 89 KeyValueSetResponse* response) override; 90 91 private: 92 // Entry point for the heartbeat checking thread. 93 void HeartbeatLoop(); 94 95 // Validates a session id number matches the current session id. 96 xla::Status ValidateSessionId(uint64 session_id); 97 98 // Validates a node id number. 99 xla::Status ValidateNodeId(int node_id); 100 101 const Options options_; 102 const uint64 session_id_; 103 104 absl::Mutex mu_; 105 enum class State { kInitializing, kRunning, kClosed }; 106 State state_ ABSL_GUARDED_BY(mu_) = State::kInitializing; 107 Status service_status_ ABSL_GUARDED_BY(mu_); 108 109 // State for Connect() and heartbeats. 110 struct Node { 111 // Have we heard from a task with this ID? 112 bool present = false; 113 114 // A unique ID belonging to the client. Used to identify the client that 115 // most recently called Connect() with a particular task id. 116 uint64 client_id = 0; 117 118 // When did we last receive a heartbeat from this task? 119 absl::Time last_heartbeat = absl::InfinitePast(); 120 }; 121 int num_nodes_present_ ABSL_GUARDED_BY(mu_) = 0; 122 std::vector<Node> nodes_ ABSL_GUARDED_BY(mu_); 123 124 // State for EnumerateDevices. 125 int num_topologies_present_ ABSL_GUARDED_BY(mu_) = 0; 126 std::vector<LocalTopologyProto> local_topologies_ ABSL_GUARDED_BY(mu_); 127 absl::optional<GlobalTopologyProto> topology_ ABSL_GUARDED_BY(mu_); 128 129 // State for Shutdown(). Counter of how many nodes are blocked at the 130 // Shutdown() barrier. 131 int num_nodes_shutting_down_ ABSL_GUARDED_BY(mu_) = 0; 132 133 // Key-value store, used by distributed GPU code to share NCCL state. 134 KeyValueStore key_value_store_; 135 136 // Notification that tells the heartbeat thread to stop. 137 absl::Notification stop_heartbeat_thread_; 138 139 // Thread that checks for missing hearbeats from the clients periodically. 140 std::unique_ptr<tensorflow::Thread> heartbeat_thread_; 141 }; 142 143 class DistributedRuntimeService { 144 public: 145 static xla::StatusOr<std::unique_ptr<DistributedRuntimeService>> Get( 146 const std::string& address, 147 std::shared_ptr<::grpc::ServerCredentials> credentials, 148 const DistributedRuntimeServiceImpl::Options& options); 149 150 explicit DistributedRuntimeService( 151 const DistributedRuntimeServiceImpl::Options& options); 152 ~DistributedRuntimeService(); 153 154 DistributedRuntimeService(const DistributedRuntimeService&) = delete; 155 DistributedRuntimeService(DistributedRuntimeService&&) = delete; 156 DistributedRuntimeService& operator=(const DistributedRuntimeService&) = 157 delete; 158 DistributedRuntimeService& operator=(DistributedRuntimeService&&) = delete; 159 server()160 ::grpc::Server* server() const { return server_.get(); } 161 162 private: 163 DistributedRuntimeServiceImpl impl_; 164 std::unique_ptr<::grpc::Server> server_; 165 }; 166 167 // Everything below this point is exposed only for tests. 168 169 // Given a LocalTopologyProto object from each node, builds a 170 // GlobalTopologyProto that describes all nodes. 171 void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies, 172 GlobalTopologyProto* global_topology); 173 174 } // namespace xla 175 176 #endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ 177