1 /* Copyright 2020 Google LLC
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/distributed/client.h"
17
18 #include <chrono> // NOLINT
19 #include <random>
20
21 #include "absl/time/time.h"
22 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
23 #include "tensorflow/compiler/xla/pjrt/distributed/util.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/core/platform/random.h"
27
28 namespace xla {
29
DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel,const Options & options)30 DistributedRuntimeClient::DistributedRuntimeClient(
31 std::shared_ptr<::grpc::Channel> channel, const Options& options)
32 : stub_(grpc::DistributedRuntimeService::NewStub(std::move(channel))),
33 options_(options) {}
34
~DistributedRuntimeClient()35 DistributedRuntimeClient::~DistributedRuntimeClient() {
36 bool connected;
37 {
38 absl::MutexLock lock(&mu_);
39 connected = (state_ == State::kConnected);
40 }
41 if (connected) {
42 if (options_.shutdown_on_destruction) {
43 Status status = Shutdown();
44 if (!status.ok()) {
45 LOG(WARNING) << "PJRT shutdown failed: " << status;
46 }
47 } else {
48 if (!stop_heartbeats_.HasBeenNotified()) {
49 stop_heartbeats_.Notify();
50 }
51 }
52 }
53 }
54
StateToString(State state)55 /*static*/ absl::string_view DistributedRuntimeClient::StateToString(
56 State state) {
57 switch (state) {
58 case State::kNotConnected:
59 return "kNotConnected";
60 case State::kConnected:
61 return "kConnected";
62 case State::kShuttingDown:
63 return "kShuttingDown";
64 case State::kClosed:
65 return "kClosed";
66 }
67 }
68
Connect()69 xla::Status DistributedRuntimeClient::Connect() {
70 {
71 absl::MutexLock lock(&mu_);
72 if (state_ != State::kNotConnected) {
73 return xla::FailedPrecondition("Connect() called when client in state %s",
74 StateToString(state_));
75 }
76 }
77 ConnectRequest request;
78 request.set_protocol_version(kDistributedRuntimeProtocolVersion);
79 request.set_timeout_milliseconds(
80 absl::ToInt64Milliseconds(options_.rpc_timeout) / 2);
81 request.set_node_id(options_.node_id);
82 VLOG(10) << "Connect: " << request.DebugString();
83 ConnectResponse response;
84 ::grpc::Status status;
85 absl::Time deadline = absl::Now() + options_.init_timeout;
86 int attempt = 0;
87 std::default_random_engine generator;
88 std::uniform_real_distribution<double> distribution(0.0, 1.0);
89 do {
90 ::grpc::ClientContext ctx;
91 ctx.set_fail_fast(false);
92 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
93 request.set_client_id(tensorflow::random::New64());
94 response.Clear();
95 status = stub_->Connect(&ctx, request, &response);
96 if (!status.ok()) {
97 VLOG(1) << "Connect failed() with status: " << FromGrpcStatus(status);
98 if (attempt % 10 == 0) {
99 LOG(INFO) << "Connect failed() with status: " << FromGrpcStatus(status);
100 }
101 // Exponential backoff with jitter. Note we will retry for `init_timeout`
102 // time in total; the `14` here corresponds to an ~16s maximum interval
103 // between connection attempts.
104 int backoff = 1 << std::min(14, attempt);
105 absl::SleepFor(absl::Milliseconds(backoff * distribution(generator)));
106 }
107 ++attempt;
108 } while (!status.ok() && absl::Now() < deadline);
109 if (!status.ok()) {
110 LOG(ERROR) << "Connect() failed after " << attempt << " retries in "
111 << options_.init_timeout
112 << "; most recent failure status: " << FromGrpcStatus(status);
113 return tensorflow::errors::DeadlineExceeded(
114 absl::StrFormat("Connect() timed out after %s with %d attempts. Most "
115 "recent failure was: %s",
116 absl::FormatDuration(options_.init_timeout), attempt,
117 FromGrpcStatus(status).ToString()));
118 }
119 VLOG(10) << "Connect() response: " << response.DebugString();
120 {
121 absl::MutexLock lock(&mu_);
122 state_ = State::kConnected;
123 }
124 session_id_ = response.session_id();
125
126 heartbeat_thread_.reset(options_.env->StartThread(
127 tensorflow::ThreadOptions(), "pjrt_distributed_heartbeat",
128 [this]() { HeartbeatLoop(); }));
129 LOG(INFO) << "Connected to distributed JAX controller";
130 return xla::Status::OK();
131 }
132
EnumerateDevices(const LocalTopologyProto & local_topology,GlobalTopologyProto * global_topology)133 xla::Status DistributedRuntimeClient::EnumerateDevices(
134 const LocalTopologyProto& local_topology,
135 GlobalTopologyProto* global_topology) {
136 {
137 absl::MutexLock lock(&mu_);
138 if (state_ != State::kConnected) {
139 return xla::FailedPrecondition(
140 "EnumerateDevices() called when client not connected.");
141 }
142 }
143 ::grpc::ClientContext ctx;
144 ctx.set_fail_fast(false);
145 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
146 EnumerateDevicesRequest request;
147 request.set_session_id(session_id_);
148 *request.mutable_local_topology() = local_topology;
149 request.mutable_local_topology()->set_node_id(options_.node_id);
150
151 VLOG(10) << "EnumerateDevices: " << request.DebugString();
152 EnumerateDevicesResponse response;
153 ::grpc::Status status = stub_->EnumerateDevices(&ctx, request, &response);
154 if (!status.ok()) {
155 return FromGrpcStatus(status);
156 }
157 VLOG(10) << "EnumerateDevices() response: " << response.DebugString();
158 response.mutable_global_topology()->Swap(global_topology);
159 return xla::Status::OK();
160 }
161
Shutdown()162 xla::Status DistributedRuntimeClient::Shutdown() {
163 LOG(INFO) << "Waiting for all distributed JAX tasks to shut down.";
164 ::grpc::ClientContext ctx;
165 {
166 absl::MutexLock lock(&mu_);
167 if (state_ != State::kConnected) {
168 return xla::FailedPrecondition(
169 "Shutdown() called when client not connected.");
170 }
171 state_ = State::kShuttingDown;
172 }
173 ctx.set_fail_fast(false);
174 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.shutdown_timeout));
175 ShutdownRequest request;
176 request.set_session_id(session_id_);
177 VLOG(10) << "Shutdown: " << request.DebugString();
178 ShutdownResponse response;
179 ::grpc::Status status = stub_->Shutdown(&ctx, request, &response);
180 LOG(INFO) << "Distributed task shutdown result: " << FromGrpcStatus(status);
181 if (!status.ok()) {
182 return FromGrpcStatus(status);
183 }
184 if (!stop_heartbeats_.HasBeenNotified()) {
185 stop_heartbeats_.Notify();
186 }
187 VLOG(10) << "Shutdown() response: " << response.DebugString();
188 absl::MutexLock lock(&mu_);
189 state_ = State::kClosed;
190 return xla::Status::OK();
191 }
192
BlockingKeyValueGet(std::string key,absl::Duration timeout)193 xla::StatusOr<std::string> DistributedRuntimeClient::BlockingKeyValueGet(
194 std::string key, absl::Duration timeout) {
195 {
196 absl::MutexLock lock(&mu_);
197 if (state_ != State::kConnected) {
198 return xla::FailedPrecondition(
199 "BlockingKeyValueGet() called when client not connected.");
200 }
201 }
202 ::grpc::ClientContext ctx;
203 ctx.set_fail_fast(false);
204 ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout));
205 KeyValueGetRequest request;
206 request.set_session_id(session_id_);
207 request.set_key(std::move(key));
208 timeout = std::min(timeout, absl::Minutes(10)); // Avoid overflow
209 request.set_timeout_milliseconds(timeout / absl::Milliseconds(1));
210 VLOG(10) << "BlockingKeyValueGet: " << request.DebugString();
211 KeyValueGetResponse response;
212 ::grpc::Status status = stub_->KeyValueGet(&ctx, request, &response);
213 if (!status.ok()) {
214 return FromGrpcStatus(status);
215 }
216 return response.value();
217 }
218
KeyValueSet(std::string key,std::string value)219 xla::Status DistributedRuntimeClient::KeyValueSet(std::string key,
220 std::string value) {
221 {
222 absl::MutexLock lock(&mu_);
223 if (state_ != State::kConnected) {
224 return xla::FailedPrecondition(
225 "KeyValueSet() called when client not connected.");
226 }
227 }
228 ::grpc::ClientContext ctx;
229 ctx.set_fail_fast(false);
230 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
231 KeyValueSetRequest request;
232 request.set_session_id(session_id_);
233 request.set_key(std::move(key));
234 request.set_value(std::move(value));
235 VLOG(10) << "KeyValueSet: " << request.DebugString();
236 KeyValueSetResponse response;
237 ::grpc::Status status = stub_->KeyValueSet(&ctx, request, &response);
238 return FromGrpcStatus(status);
239 }
240
HeartbeatLoop()241 void DistributedRuntimeClient::HeartbeatLoop() {
242 int num_missing_heartbeats = 0;
243 while (true) {
244 stop_heartbeats_.WaitForNotificationWithTimeout(
245 options_.heartbeat_interval);
246 if (stop_heartbeats_.HasBeenNotified()) {
247 return;
248 }
249
250 ::grpc::ClientContext ctx;
251 ctx.set_fail_fast(false);
252 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
253 HeartbeatRequest request;
254 request.set_session_id(session_id_);
255 request.set_node_id(options_.node_id);
256 VLOG(10) << "Heartbeat: " << request.DebugString();
257 HeartbeatResponse response;
258 ::grpc::Status status = stub_->Heartbeat(&ctx, request, &response);
259 if (status.ok()) {
260 num_missing_heartbeats = 0;
261 } else {
262 ++num_missing_heartbeats;
263 bool is_transient_error =
264 (status.error_code() == ::grpc::StatusCode::DEADLINE_EXCEEDED ||
265 status.error_code() == ::grpc::StatusCode::UNAVAILABLE);
266 if (!stop_heartbeats_.HasBeenNotified() &&
267 (!is_transient_error ||
268 num_missing_heartbeats > options_.max_missing_heartbeats)) {
269 // If we are shutting down, missed heartbeats are benign: they may
270 // simply mean that the server has shut down already before it saw
271 // the heartbeat request.
272 absl::MutexLock lock(&mu_);
273 if (state_ != State::kShuttingDown) {
274 options_.missed_heartbeat_callback(FromGrpcStatus(status),
275 !is_transient_error);
276 }
277 return;
278 }
279 }
280 }
281 }
282
283 } // namespace xla
284