1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/distributed/service.h"
17
18 #include "absl/time/time.h"
19 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
20 #include "tensorflow/compiler/xla/pjrt/distributed/util.h"
21 #include "tensorflow/compiler/xla/status.h"
22 #include "tensorflow/compiler/xla/util.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/platform/random.h"
25
26 namespace xla {
27
// Constructs the service. Pre-sizes the per-node bookkeeping vectors so that
// node IDs in [0, num_nodes) can be used as direct indices, and draws a fresh
// random session ID so requests from a previous incarnation can be rejected.
DistributedRuntimeServiceImpl::DistributedRuntimeServiceImpl(
    const Options& options)
    : options_(options), session_id_(tensorflow::random::New64()) {
  nodes_.resize(options.num_nodes);
  local_topologies_.resize(options.num_nodes);
}
34
DistributedRuntimeServiceImpl::~DistributedRuntimeServiceImpl() {
  {
    absl::MutexLock lock(&mu_);
    // Mark the service closed and record a failure status so any RPC blocked
    // in mu_.Await()/AwaitWithTimeout() wakes up and fails fast rather than
    // hanging on a dying service.
    state_ = State::kClosed;
    service_status_ =
        tensorflow::errors::FailedPrecondition("Service shutting down.");
    // Ask the heartbeat thread to exit. Notify() may only be called once, so
    // guard against a prior Shutdown() having already notified.
    if (!stop_heartbeat_thread_.HasBeenNotified()) {
      stop_heartbeat_thread_.Notify();
    }
  }
}
46
47 // Steals the contents of `local_topologies`.
BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,GlobalTopologyProto * global_topology)48 void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
49 GlobalTopologyProto* global_topology) {
50 int next_global_device_id = 0;
51 for (LocalTopologyProto& local : local_topologies) {
52 for (DeviceProto& device : *local.mutable_devices()) {
53 device.set_global_device_id(next_global_device_id++);
54 }
55 global_topology->add_nodes()->Swap(&local);
56 }
57 }
58
ValidateNodeId(int node_id)59 xla::Status DistributedRuntimeServiceImpl::ValidateNodeId(int node_id) {
60 if (node_id < 0) {
61 return xla::InvalidArgument("Invalid node ID %d, must be non-negative",
62 node_id);
63 }
64 if (node_id >= options_.num_nodes) {
65 return xla::FailedPrecondition(
66 "Invalid node ID %d, must be in the range [0, %d)", node_id,
67 options_.num_nodes);
68 }
69 return xla::Status::OK();
70 }
71
ValidateSessionId(uint64 session_id)72 xla::Status DistributedRuntimeServiceImpl::ValidateSessionId(
73 uint64 session_id) {
74 if (session_id != session_id_) {
75 return xla::FailedPrecondition(
76 "Session ID of request %llu does not match active session ID %llu",
77 session_id, session_id_);
78 }
79 return xla::Status::OK();
80 }
81
// Registers a node and blocks until all `num_nodes` nodes have connected, or
// the client-supplied timeout expires. Node 0 acts as coordinator: it flips
// the service to kRunning and starts the heartbeat watchdog thread; all other
// nodes wait for that transition. On success, returns the session ID that
// subsequent RPCs must present.
::grpc::Status DistributedRuntimeServiceImpl::Connect(
    ::grpc::ServerContext* context, const ConnectRequest* request,
    ConnectResponse* response) {
  VLOG(10) << "Connect " << request->DebugString();
  // Reject clients built against a different wire protocol outright.
  if (request->protocol_version() != kDistributedRuntimeProtocolVersion) {
    return ToGrpcStatus(xla::InvalidArgument("Invalid protocol version %d",
                                             request->protocol_version()));
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kInitializing) {
    // This most likely indicates that a client task was restarted but the
    // old master is still up. Clients should retry on failure.
    return ToGrpcStatus(tensorflow::errors::Aborted(
        "Connect() called when system is not initializing."));
  }
  int node_id = request->node_id();
  xla::Status status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return ToGrpcStatus(status);
  }
  // Record this node's arrival. A reconnect for an already-present node does
  // not bump the count, but does overwrite the client ID so the most recent
  // client wins (see the duplicate-ID handling below).
  if (!nodes_[node_id].present) {
    nodes_[node_id].present = true;
    ++num_nodes_present_;
  }
  nodes_[node_id].client_id = request->client_id();

  // Wake when either every node has connected, or another client has claimed
  // this node ID (in which case this request loses and aborts).
  auto all_nodes_present_or_duplicate_request = [&]() {
    mu_.AssertHeld();
    return num_nodes_present_ == nodes_.size() ||
           nodes_[node_id].client_id != request->client_id();
  };
  auto connect_timeout = absl::Milliseconds(request->timeout_milliseconds());
  if (!mu_.AwaitWithTimeout(
          absl::Condition(&all_nodes_present_or_duplicate_request),
          connect_timeout)) {
    // Undo our registration so a retry (or a restarted task) can connect.
    nodes_[node_id].present = false;
    --num_nodes_present_;
    return ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ", absl::FormatDuration(connect_timeout),
        " waiting for all nodes to call Connect()"));
  }

  if (nodes_[node_id].client_id != request->client_id()) {
    // This might happen either if two nodes are erroneously configured with the
    // same ID number, or it might happen if a task fails and is restarted
    // while we are waiting for nodes to connect. To elaborate on the second
    // scenario, it would look like this:
    // * a task calls Connect() with a particular node_id and client_id.
    // * the task is killed and restarted, or alternatively the client's RPC
    //   times out and it decides to retry.
    // * the task calls Connect() again with the same node_id and a different
    //   client_id.
    // In this scenario we take whichever client showed up most recently and
    // evict the client with an out-of-date client ID.
    return ToGrpcStatus(
        tensorflow::errors::Aborted("Duplicate node ID ", node_id));
  }

  if (node_id == 0) {
    // Node 0 transitions the service to running and launches the watchdog
    // that monitors heartbeats from all nodes.
    state_ = State::kRunning;
    heartbeat_thread_.reset(options_.env->StartThread(
        tensorflow::ThreadOptions(), "pjrt_service_heartbeat",
        [this]() { HeartbeatLoop(); }));
  } else {
    // Other nodes block until node 0 has performed the transition above.
    auto running = [&]() {
      mu_.AssertHeld();
      return state_ == State::kRunning;
    };
    mu_.Await(absl::Condition(&running));
  }
  // Connecting counts as a heartbeat, so the watchdog clock starts now.
  nodes_[node_id].last_heartbeat = absl::Now();
  response->set_session_id(session_id_);
  return ::grpc::Status::OK;
}
156
// Barrier-style shutdown: each node calls Shutdown(), and every call blocks
// until all nodes have done so (or `shutdown_timeout` expires). On success,
// closes the service and stops the heartbeat watchdog.
::grpc::Status DistributedRuntimeServiceImpl::Shutdown(
    ::grpc::ServerContext* context, const ShutdownRequest* request,
    ShutdownResponse* response) {
  VLOG(10) << "Shutdown " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    // Prefer reporting the recorded failure cause over a generic error.
    if (!service_status_.ok()) {
      return ToGrpcStatus(service_status_);
    }
    return ToGrpcStatus(xla::FailedPrecondition(
        "Shutdown() called when system is not running."));
  }
  int node_id = request->node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return ToGrpcStatus(status);
  }
  ++num_nodes_shutting_down_;

  // Wake when all nodes are shutting down, or the service has failed (e.g.
  // the watchdog recorded a missed heartbeat).
  auto all_nodes_shutting_down = [&]() {
    mu_.AssertHeld();
    return num_nodes_shutting_down_ == nodes_.size() || !service_status_.ok();
  };
  if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_shutting_down),
                            options_.shutdown_timeout)) {
    state_ = State::kClosed;
    return ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ", absl::FormatDuration(options_.shutdown_timeout),
        " waiting for all nodes to call Shutdown()"));
  }
  state_ = State::kClosed;
  // Notify() may only be called once; the watchdog may already have been
  // notified (or exited on its own after recording a failure).
  if (!stop_heartbeat_thread_.HasBeenNotified()) {
    stop_heartbeat_thread_.Notify();
  }
  if (!service_status_.ok()) {
    return ToGrpcStatus(service_status_);
  }
  return ::grpc::Status::OK;
}
200
// Collects each node's local device topology and, once every node has
// reported (or `enumerate_devices_timeout` expires), returns the merged
// global topology. Node 0 performs the merge; the other nodes block until
// the result is available.
::grpc::Status DistributedRuntimeServiceImpl::EnumerateDevices(
    ::grpc::ServerContext* context, const EnumerateDevicesRequest* request,
    EnumerateDevicesResponse* response) {
  VLOG(10) << "EnumerateDevices " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    // Prefer reporting the recorded failure cause over a generic error.
    if (!service_status_.ok()) {
      return ToGrpcStatus(service_status_);
    }
    return ToGrpcStatus(xla::FailedPrecondition(
        "EnumerateDevices() called when system is not running."));
  }
  int node_id = request->local_topology().node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return ToGrpcStatus(status);
  }
  local_topologies_[node_id] = request->local_topology();
  ++num_topologies_present_;

  // Wake when all nodes have reported their topology, or the service failed.
  auto all_topologies_present = [&]() {
    mu_.AssertHeld();
    return num_topologies_present_ == nodes_.size() || !service_status_.ok();
  };
  if (!mu_.AwaitWithTimeout(absl::Condition(&all_topologies_present),
                            options_.enumerate_devices_timeout)) {
    return ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ",
        absl::FormatDuration(options_.enumerate_devices_timeout),
        " waiting for all nodes to call EnumerateDevices()"));
  }
  if (!service_status_.ok()) {
    return ToGrpcStatus(service_status_);
  }

  if (node_id == 0) {
    // Node 0 builds the global topology once; BuildGlobalTopology steals the
    // local topologies, so clear the now-emptied vector afterwards.
    topology_.emplace();
    BuildGlobalTopology(absl::Span<LocalTopologyProto>(local_topologies_),
                        &*topology_);
    local_topologies_.clear();
  } else {
    // Other nodes wait for node 0 to publish the merged topology.
    auto topology_ready = [&]() -> bool {
      mu_.AssertHeld();
      return topology_.has_value();
    };
    mu_.Await(absl::Condition(&topology_ready));
  }
  *response->mutable_global_topology() = *topology_;
  return ::grpc::Status::OK;
}
255
Heartbeat(::grpc::ServerContext * context,const HeartbeatRequest * request,HeartbeatResponse * response)256 ::grpc::Status DistributedRuntimeServiceImpl::Heartbeat(
257 ::grpc::ServerContext* context, const HeartbeatRequest* request,
258 HeartbeatResponse* response) {
259 VLOG(10) << "Heartbeat " << request->DebugString();
260 xla::Status status = ValidateSessionId(request->session_id());
261 if (!status.ok()) {
262 return ToGrpcStatus(status);
263 }
264 absl::MutexLock lock(&mu_);
265 if (state_ != State::kRunning) {
266 if (!service_status_.ok()) {
267 return ToGrpcStatus(service_status_);
268 }
269 return ToGrpcStatus(xla::FailedPrecondition(
270 "Heartbeat() called when system is not running."));
271 }
272 int node_id = request->node_id();
273 status = ValidateNodeId(node_id);
274 if (!status.ok()) {
275 return ToGrpcStatus(status);
276 }
277 nodes_[node_id].last_heartbeat = absl::Now();
278 return ::grpc::Status::OK;
279 }
280
// Background watchdog run on `heartbeat_thread_`. Once per heartbeat
// interval, verifies that every node has pinged recently; if any node has
// been silent for `max_missing_heartbeats` intervals, records a failure
// status and closes the service so that all blocked RPCs abort.
void DistributedRuntimeServiceImpl::HeartbeatLoop() {
  while (true) {
    // Sleep one interval, or wake immediately if shutdown was requested.
    stop_heartbeat_thread_.WaitForNotificationWithTimeout(
        options_.heartbeat_interval);
    VLOG(10) << "Checking heartbeats";
    if (stop_heartbeat_thread_.HasBeenNotified()) {
      VLOG(10) << "Heartbeat checking stopped.";
      return;
    }
    absl::Time now = absl::Now();
    absl::MutexLock lock(&mu_);
    for (size_t i = 0; i < nodes_.size(); ++i) {
      // If we haven't heard from the node for a number of heartbeat intervals,
      // declare that we are unhealthy.
      VLOG(10) << "Node " << i
               << " last heartbeat: " << nodes_[i].last_heartbeat;
      if (nodes_[i].last_heartbeat +
              options_.max_missing_heartbeats * options_.heartbeat_interval <
          now) {
        LOG(INFO) << "Missed heartbeats from node " << i << ". Shutting down.";
        // Failing the whole service wakes every RPC blocked on a condition
        // that also tests service_status_.
        state_ = State::kClosed;
        service_status_ = tensorflow::errors::Aborted(
            "Shutting down due to missed heartbeat from task ", i);
        return;
      }
    }
  }
}
309
KeyValueGet(::grpc::ServerContext * context,const KeyValueGetRequest * request,KeyValueGetResponse * response)310 ::grpc::Status DistributedRuntimeServiceImpl::KeyValueGet(
311 ::grpc::ServerContext* context, const KeyValueGetRequest* request,
312 KeyValueGetResponse* response) {
313 VLOG(10) << "KeyValueGet " << request->DebugString();
314 xla::Status status = ValidateSessionId(request->session_id());
315 if (!status.ok()) {
316 return ToGrpcStatus(status);
317 }
318 {
319 absl::MutexLock lock(&mu_);
320 if (state_ != State::kRunning) {
321 if (!service_status_.ok()) {
322 return ToGrpcStatus(service_status_);
323 }
324 return ToGrpcStatus(xla::FailedPrecondition(
325 "KeyValueGet() called when system is not running."));
326 }
327 }
328 return key_value_store_.Get(
329 request->key(), absl::Milliseconds(request->timeout_milliseconds()),
330 response->mutable_value());
331 }
332
KeyValueSet(::grpc::ServerContext * context,const KeyValueSetRequest * request,KeyValueSetResponse * response)333 ::grpc::Status DistributedRuntimeServiceImpl::KeyValueSet(
334 ::grpc::ServerContext* context, const KeyValueSetRequest* request,
335 KeyValueSetResponse* response) {
336 VLOG(10) << "KeyValueSet " << request->DebugString();
337 xla::Status status = ValidateSessionId(request->session_id());
338 if (!status.ok()) {
339 return ToGrpcStatus(status);
340 }
341 {
342 absl::MutexLock lock(&mu_);
343 if (state_ != State::kRunning) {
344 if (!service_status_.ok()) {
345 return ToGrpcStatus(service_status_);
346 }
347 return ToGrpcStatus(xla::FailedPrecondition(
348 "KeyValueSet() called when system is not running; clients must call "
349 "Connect() first"));
350 }
351 }
352 return key_value_store_.Set(request->key(), request->value());
353 }
354
355 xla::StatusOr<std::unique_ptr<DistributedRuntimeService>>
Get(const std::string & address,std::shared_ptr<::grpc::ServerCredentials> credentials,const DistributedRuntimeServiceImpl::Options & options)356 DistributedRuntimeService::Get(
357 const std::string& address,
358 std::shared_ptr<::grpc::ServerCredentials> credentials,
359 const DistributedRuntimeServiceImpl::Options& options) {
360 auto service = absl::make_unique<DistributedRuntimeService>(options);
361 ::grpc::ServerBuilder builder;
362 builder.AddListeningPort(address, credentials);
363 VLOG(1) << "Distributed runtime service address " << address;
364 builder.RegisterService(&service->impl_);
365 service->server_ = builder.BuildAndStart();
366 if (!service->server_) {
367 return xla::Unknown("Failed to start RPC server");
368 }
369 LOG(INFO) << "Jax service listening on " << address;
370 return service;
371 }
372
// Wraps the RPC implementation; the gRPC server itself is created and
// started in DistributedRuntimeService::Get().
DistributedRuntimeService::DistributedRuntimeService(
    const DistributedRuntimeServiceImpl::Options& options)
    : impl_(options) {}
376
DistributedRuntimeService::~DistributedRuntimeService() {
  if (server_) {
    LOG(INFO) << "Jax service shutting down";
    // Shutdown() stops the server accepting new RPCs; Wait() then blocks
    // until in-flight handlers have drained.
    server_->Shutdown();
    server_->Wait();
  }
}
384
385 } // namespace xla
386