1 /* Copyright 2020 Google LLC 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ 18 19 #include <memory> 20 21 #include "grpcpp/channel.h" 22 #include "absl/synchronization/mutex.h" 23 #include "absl/synchronization/notification.h" 24 #include "absl/time/time.h" 25 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/compiler/xla/types.h" 28 #include "tensorflow/core/platform/env.h" 29 30 namespace xla { 31 32 class DistributedRuntimeClient { 33 public: 34 struct Options { 35 // This node's global ID. Required. 36 int32 node_id = -1; 37 38 // Environment used for starting threads. 39 tensorflow::Env* env = tensorflow::Env::Default(); 40 41 // RPC timeout used for RPC that don't have their own timeouts. 42 absl::Duration rpc_timeout = absl::Seconds(120); 43 44 // Time period for which Connect() should be retried. The client will keep 45 // trying to open the initial connection for this period, even if any 46 // individual Connect() RPC fails. May be zero, in which case Connect() will 47 // only be attempted once. 48 absl::Duration init_timeout = absl::ZeroDuration(); 49 50 // How long to wait for all nodes to call Shutdown(). If the timeout 51 // expires, then shutdown() reports an error and returns control. 52 absl::Duration shutdown_timeout = absl::Seconds(60); 53 54 // Interval at which the client should send heartbeat RPCs to the 55 // coordinator. 56 absl::Duration heartbeat_interval = absl::Seconds(10); 57 58 // How many failed heartbeat RPCs may fail due to a possibly-ephemeral 59 // reason before we decide the coordinator has vanished and that we should 60 // shut down. 61 int max_missing_heartbeats = 10; 62 63 // Callback invoked by the client when notification of a missing heartbeat 64 // is reported by the coordinator, or we have not heard from the coordinator 65 // recently. `coordinator_reported_failure` is true in the former case. 66 // Exposed so tests can override this behavior to something non-fatal. 67 std::function<void(xla::Status, bool coordinator_reported_failure)> 68 missed_heartbeat_callback = 69 [](xla::Status status, bool coordinator_reported_failure) { 70 if (coordinator_reported_failure) { 71 LOG(QFATAL) 72 << "Terminating process because the coordinator detected " 73 "missing heartbeats. This most likely indicates that " 74 "another task died; see the other task logs for more " 75 "details. Status: " 76 << status; 77 } else { 78 LOG(QFATAL) 79 << "Terminating process because of missing heartbeat " 80 "response from the coordinator. This most likely " 81 "indicates that the coordinator task died; see the " 82 "coordinator's task logs for more details. Status: " 83 << status; 84 } 85 }; 86 87 // For testing. Should the client explicitly Shutdown() on destruction? 88 bool shutdown_on_destruction = true; 89 }; 90 DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel, 91 const Options& options); DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel)92 explicit DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel) 93 : DistributedRuntimeClient(channel, Options()) {} 94 ~DistributedRuntimeClient(); 95 96 // Connects to the master, and blocks until all clients have successfully 97 // connected. 98 // Not thread-safe, i.e., calls to Connect()/Shutdown()/EnumerateDevices() 99 // must be serialized by some other means. 100 xla::Status Connect(); 101 102 // Reports to the master that the client is ready to shutdown, and blocks 103 // until all clients are ready to shutdown or the shutdown timeout expires. 104 // Not thread-safe. 105 xla::Status Shutdown(); 106 107 // Blocking enumeration of global devices. Used by the GPU platform. 108 // Not thread-safe. 109 xla::Status EnumerateDevices(const LocalTopologyProto& local_topology, 110 GlobalTopologyProto* global_topology); 111 112 // The following APIs are thread-safe. 113 xla::StatusOr<std::string> BlockingKeyValueGet(std::string key, 114 absl::Duration timeout); 115 116 xla::Status KeyValueSet(std::string key, std::string value); 117 118 private: 119 // Entry point for the heartbeat thread. 120 void HeartbeatLoop(); 121 122 const std::unique_ptr<grpc::DistributedRuntimeService::Stub> stub_; 123 const Options options_; 124 125 // Possible states of the client. 126 // The only legal transitions are downwards in the order below. i.e., there is 127 // no way to reopen a closed client. 128 enum class State { 129 // The client has not yet connected to the server, i.e., had a Connect() 130 // RPC succeed. 131 kNotConnected, 132 133 // The client is connected to the server and as far as we are aware the 134 // connection is healthy. 135 kConnected, 136 137 // The client is in the process of shutting down, i.e., Shutdown() has been 138 // called. 139 kShuttingDown, 140 141 // The client has shut down its server connection, either due to an error 142 // or due to an explicit shutdown. 143 kClosed, 144 }; 145 146 static absl::string_view StateToString(State state); 147 148 // state_ is protected by a mutex because the heartbeat thread needs to look 149 // at it. 150 absl::Mutex mu_; 151 State state_ ABSL_GUARDED_BY(mu_) = State::kNotConnected; 152 153 // A unique session ID, assigned by the server during Connect(). 154 uint64 session_id_; 155 156 // Notification that tells the heartbeat thread to stop running. 157 absl::Notification stop_heartbeats_; 158 159 // Thread responsible for performing heartbeats. 160 std::unique_ptr<tensorflow::Thread> heartbeat_thread_; 161 }; 162 163 } // namespace xla 164 165 #endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ 166