1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // ==============================================================================
15
16 #include <complex>
17 #include <cstddef>
18 #include <functional>
19 #include <memory>
20
21 #include "grpcpp/grpcpp.h"
22 #include "absl/base/thread_annotations.h"
23 #include "absl/strings/strip.h"
24 #include "absl/synchronization/mutex.h"
25 #include "absl/time/clock.h"
26 #include "absl/time/time.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/compiler/xla/python/tpu_driver/event_id.h"
29 #include "tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h"
30 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
31 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
32 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_service.grpc.pb.h"
33 #include "tensorflow/compiler/xla/util.h"
34
35 namespace tpu_driver {
36 namespace {
37
38 using xla::Status;
39
40 const int64_t kMaxStreamWriteSize = 10 * 1000 * 1000;
41 const absl::Duration kWriteEpochDuration = absl::Microseconds(10);
42
43 constexpr char kGrpcProtocol[] = "grpc://";
44
45 class GrpcTpuStream;
46 class GrpcTpuDriver;
47
48 class GrpcEvent : public Event {
49 public:
GrpcEvent(EventId id,GrpcTpuStream * stream)50 explicit GrpcEvent(EventId id, GrpcTpuStream* stream)
51 : id_(id), stream_(stream) {}
52 ~GrpcEvent() override;
53
54 xla::Status Await() override;
55 absl::optional<xla::Status> AwaitWithTimeout(
56 absl::Duration duration) override;
57 void AddCallback(std::function<void(Status)> callback) override;
58
id() const59 EventId id() const { return id_; }
stream() const60 GrpcTpuStream* stream() const { return stream_; }
61
62 private:
63 const EventId id_;
64 GrpcTpuStream* stream_;
65 };
66
67 class ErrorEvent : public GrpcEvent {
68 public:
ErrorEvent(Status status)69 explicit ErrorEvent(Status status) : GrpcEvent(EventId{0, 0}, nullptr) {
70 status_ = status;
71 }
72
Await()73 xla::Status Await() override { return status_; }
AwaitWithTimeout(absl::Duration duration)74 absl::optional<xla::Status> AwaitWithTimeout(
75 absl::Duration duration) override {
76 return status_;
77 }
AddCallback(std::function<void (Status)> callback)78 void AddCallback(std::function<void(Status)> callback) override {
79 callback(status_);
80 }
81
82 private:
83 Status status_;
84 };
85
86 class GrpcBufferHandle : public BufferHandle {
87 public:
GrpcBufferHandle(EventId id,std::shared_ptr<GrpcEvent> event,int64_t bytes,absl::optional<xla::ShapeProto> shape=absl::nullopt)88 explicit GrpcBufferHandle(
89 EventId id, std::shared_ptr<GrpcEvent> event, int64_t bytes,
90 absl::optional<xla::ShapeProto> shape = absl::nullopt)
91 : id_(id),
92 stream_(event->stream()),
93 event_(std::move(event)),
94 bytes_(bytes),
95 shape_(shape) {}
96
OnReady()97 std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()98 int64_t size_in_bytes() override { return bytes_; }
99
id() const100 EventId id() const { return id_; }
stream() const101 GrpcTpuStream* stream() const { return stream_; }
102
shape()103 absl::optional<xla::ShapeProto> shape() override { return shape_; }
104
105 private:
106 const EventId id_;
107 GrpcTpuStream* stream_;
108 std::shared_ptr<GrpcEvent> event_;
109 int64_t bytes_;
110 absl::optional<xla::ShapeProto> shape_;
111 };
112
113 class GrpcCompiledProgramHandle : public CompiledProgramHandle {
114 public:
GrpcCompiledProgramHandle(EventId id,std::shared_ptr<GrpcEvent> event)115 explicit GrpcCompiledProgramHandle(EventId id,
116 std::shared_ptr<GrpcEvent> event)
117 : id_(id),
118 stream_(event->stream()),
119 event_(std::move(event)),
120 metadata_(std::make_shared<CompiledProgramMetadata>()) {}
121
OnReady()122 std::shared_ptr<Event> OnReady() override { return event_; }
123
id() const124 EventId id() const { return id_; }
stream() const125 GrpcTpuStream* stream() const { return stream_; }
126
program_shape(xla::ProgramShapeProto * program_shape)127 Status program_shape(xla::ProgramShapeProto* program_shape) override {
128 auto opt_status = OnReady()->AwaitWithTimeout(absl::Hours(1));
129 if (!opt_status.has_value()) {
130 return xla::InternalError("Compile failed to finish within 1 hour.");
131 }
132
133 Status status = opt_status.value();
134 if (!status.ok()) {
135 return status;
136 }
137 *program_shape = metadata_->program_shape();
138 return Status::OK();
139 }
140
metadata()141 std::shared_ptr<CompiledProgramMetadata> metadata() { return metadata_; }
142
143 private:
144 const EventId id_;
145 GrpcTpuStream* stream_;
146 std::shared_ptr<GrpcEvent> event_;
147
148 // Using a shared pointer here because the program handle can go out of scope
149 // before we get a response back, but we want a valid location to write things
150 // into regardless.
151 std::shared_ptr<CompiledProgramMetadata> metadata_;
152 };
153
154 class GrpcLoadedProgramHandle : public LoadedProgramHandle {
155 public:
GrpcLoadedProgramHandle(EventId id,std::shared_ptr<GrpcEvent> event)156 explicit GrpcLoadedProgramHandle(EventId id, std::shared_ptr<GrpcEvent> event)
157 : id_(id), stream_(event->stream()), event_(std::move(event)) {}
158
OnReady()159 std::shared_ptr<Event> OnReady() override { return event_; }
160
id() const161 EventId id() const { return id_; }
stream() const162 GrpcTpuStream* stream() const { return stream_; }
163
164 private:
165 const EventId id_;
166 GrpcTpuStream* stream_;
167 std::shared_ptr<GrpcEvent> event_;
168 };
169
170 class GrpcTpuStream {
171 public:
172 explicit GrpcTpuStream(int32_t id, GrpcTpuDriver* driver,
173 std::unique_ptr<grpc::CloudTpuDriver::Stub> stub);
174 virtual ~GrpcTpuStream();
175
176 std::unique_ptr<BufferHandle> Allocate(int32_t core_id, MemoryRegion region,
177 int64_t num_bytes,
178 absl::Span<Event* const> wait_for);
179 std::unique_ptr<BufferHandle> Allocate(int32_t core_id, MemoryRegion region,
180 const xla::ShapeProto& shape,
181 absl::Span<Event* const> wait_for);
182 std::unique_ptr<BufferHandle> AllocateTuple(
183 int32_t core_id, MemoryRegion region,
184 absl::Span<BufferHandle* const> children,
185 absl::Span<Event* const> wait_for);
186 std::shared_ptr<Event> Deallocate(std::unique_ptr<BufferHandle> handle,
187 absl::Span<Event* const> wait_for);
188
189 std::shared_ptr<Event> TransferToDevice(const void* src, BufferHandle* dst,
190 absl::Span<Event* const> wait_for);
191 std::shared_ptr<Event> TransferFromDevice(const BufferHandle* src, void* dst,
192 absl::Span<Event* const> wait_for);
193
194 std::shared_ptr<Event> TransferFromDeviceToDevice(
195 const BufferHandle* src, BufferHandle* dst,
196 absl::Span<Event* const> wait_for);
197
198 std::unique_ptr<CompiledProgramHandle> CompileProgram(
199 const xla::HloProto& source, int32_t num_replicas,
200 absl::Span<Event* const> wait_for);
201 std::unique_ptr<LoadedProgramHandle> LoadProgram(
202 int32_t core_id, const CompiledProgramHandle* handle,
203 absl::Span<Event* const> wait_for);
204 std::shared_ptr<Event> UnloadProgram(
205 std::unique_ptr<LoadedProgramHandle> handle,
206 absl::Span<Event* const> wait_for);
207 std::shared_ptr<Event> ExecuteProgram(
208 LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
209 absl::Span<BufferHandle* const> outputs,
210 const xla::DeviceAssignmentProto& device_assignment,
211 absl::Span<Event* const> wait_for);
212
213 private:
214 friend class GrpcEvent;
215 friend class GrpcTpuDriver;
216
217 struct EventInfo {
218 bool all_deps_done = false;
219 bool done = false; // response received
220 bool deleted = false; // deleted by the user
221 Status status;
222 absl::InlinedVector<std::function<void(Status)>, 1> callbacks;
223 // Most events should have <= 2 requirement events.
224 absl::InlinedVector<EventId, 2> deps;
225 };
226
227 struct TransferInfo {
TransferInfotpu_driver::__anon4073f9c10111::GrpcTpuStream::TransferInfo228 explicit TransferInfo(void* dst, int64_t num_bytes)
229 : dst(dst), num_bytes(num_bytes) {}
230
231 void* const dst;
232 const uint64_t num_bytes;
233 };
234
235 struct CompileMetadataInfo {
CompileMetadataInfotpu_driver::__anon4073f9c10111::GrpcTpuStream::CompileMetadataInfo236 explicit CompileMetadataInfo(
237 std::shared_ptr<CompiledProgramMetadata> metadata) {
238 compiled_metadata = metadata;
239 }
240 std::shared_ptr<CompiledProgramMetadata> compiled_metadata;
241 };
242
243 // Every public method above should call this first.
244 void InitializeRequest(StreamRequest::Entry* req,
245 absl::Span<Event* const> wait_for)
246 ABSL_LOCKS_EXCLUDED(events_mutex_);
247
248 // The first update to an event marks it done and calls registered callbacks.
249 // All subsequent updates must have the same OK-ness as the first update.
250 // Among non-OK updates, only the first error status is remembered.
251 void UpdateEventStatus(EventId id, Status status)
252 ABSL_EXCLUSIVE_LOCKS_REQUIRED(events_mutex_);
253
254 // To ensure callbacks are still triggered, after this is called, we do not
255 // remove the event from the event mapping until a response is received from
256 // the server.
257 void DeleteEvent(EventId id) ABSL_LOCKS_EXCLUDED(events_mutex_);
258
259 // Wait at most `duration` for event `id` to complete. Returns the event
260 // status or an empty optional if the event does not complete in time.
261 absl::optional<Status> WaitForEvent(EventId id, absl::Duration duration)
262 ABSL_LOCKS_EXCLUDED(events_mutex_);
263
264 void AddEventCallback(EventId id, std::function<void(Status)> callback)
265 ABSL_LOCKS_EXCLUDED(events_mutex_);
266
AddWriteRequest(std::unique_ptr<StreamRequest::Entry> req)267 void AddWriteRequest(std::unique_ptr<StreamRequest::Entry> req) {
268 absl::MutexLock m(&request_lock_);
269 VLOG(2) << "Adding request: " << req->DebugString();
270 requests_.push_back(std::move(req));
271 }
272
273 // Unique identifier for this stream.
274 int32_t id_;
275 // The parent driver that created this stream.
276 GrpcTpuDriver* driver_;
277
278 std::unique_ptr<grpc::CloudTpuDriver::Stub> stub_;
279 ::grpc::ClientContext ctx_;
280 std::unique_ptr<
281 ::grpc::ClientReaderWriterInterface<StreamRequest, StreamResponse>>
282 stream_;
283
284 absl::Mutex request_lock_;
285 std::deque<std::unique_ptr<StreamRequest::Entry>> requests_
286 ABSL_GUARDED_BY(request_lock_);
287 int64_t num_pending_requests_ ABSL_GUARDED_BY(request_lock_) = 0;
288
289 bool shutting_down_ ABSL_GUARDED_BY(request_lock_) = false;
290
291 void StreamWriterFn();
292 Thread writer_thread_;
293
294 void StreamReaderFn();
295 Thread reader_thread_;
296
297 // Map from operation ID to event information.
298 absl::Mutex events_mutex_;
299 absl::flat_hash_map<EventId, EventInfo> events_
300 ABSL_GUARDED_BY(events_mutex_);
301
302 // Map from operation ID to transfer information.
303 // When a D2H transfer completes, received data is copied into the `dst`
304 // pointer in `TransferInfo`.
305 absl::Mutex transfers_mutex_;
306 absl::flat_hash_map<EventId, TransferInfo> transfers_
307 ABSL_GUARDED_BY(transfers_mutex_);
308
309 absl::Mutex compiles_mutex_;
310 absl::flat_hash_map<EventId, CompileMetadataInfo> compiles_
311 ABSL_GUARDED_BY(compiles_mutex_);
312 };
313
314 class GrpcTpuDriver : public TpuDriver {
315 public:
GrpcTpuDriver(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds,int32_t client_id)316 explicit GrpcTpuDriver(const TpuDriverConfig& config,
317 std::shared_ptr<::grpc::ChannelCredentials> creds,
318 int32_t client_id)
319 : config_(config), creds_(creds), client_id_(client_id) {
320 SystemInfo system_info;
321 QuerySystemInfo(&system_info);
322 for (auto& chip_info : system_info.tpu_chip()) {
323 for (auto& core_info : chip_info.core()) {
324 int32_t core_id = core_info.id();
325 // We have one stream per core, so use core ID as stream ID.
326 streams_[core_id] = AllocateStream(core_id);
327 }
328 }
329 CHECK_GT(streams_.size(), 0) << "Can't find any TPU chip in the system.";
330
331 host_stream_ = AllocateStream(-1);
332 }
333
~GrpcTpuDriver()334 ~GrpcTpuDriver() override {
335 if (closed_) {
336 return;
337 }
338 auto status = Close();
339 if (!status.ok()) {
340 LOG(ERROR) << status;
341 }
342 }
343
344 void QuerySystemInfo(SystemInfo* system_info) override;
345 Status Reset() override;
346
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)347 std::unique_ptr<BufferHandle> Allocate(
348 int32_t core_id, MemoryRegion region, int64_t num_bytes,
349 absl::Span<Event* const> wait_for) override {
350 return streams_[core_id]->Allocate(core_id, region, num_bytes, wait_for);
351 }
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)352 std::unique_ptr<BufferHandle> Allocate(
353 int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
354 absl::Span<Event* const> wait_for) override {
355 return streams_[core_id]->Allocate(core_id, region, shape, wait_for);
356 }
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)357 std::unique_ptr<BufferHandle> AllocateTuple(
358 int32_t core_id, MemoryRegion region,
359 absl::Span<BufferHandle* const> children,
360 absl::Span<Event* const> wait_for) override {
361 return streams_[core_id]->AllocateTuple(core_id, region, children,
362 wait_for);
363 }
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)364 std::shared_ptr<Event> Deallocate(
365 std::unique_ptr<BufferHandle> handle,
366 absl::Span<Event* const> wait_for) override {
367 auto* stream = static_cast<GrpcBufferHandle*>(handle.get())->stream();
368 return stream->Deallocate(std::move(handle), wait_for);
369 }
370
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)371 std::shared_ptr<Event> TransferToDevice(
372 const void* src, BufferHandle* dst,
373 absl::Span<Event* const> wait_for) override {
374 auto* stream = static_cast<GrpcBufferHandle*>(dst)->stream();
375 return stream->TransferToDevice(src, dst, wait_for);
376 }
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)377 std::shared_ptr<Event> TransferFromDevice(
378 const BufferHandle* src, void* dst,
379 absl::Span<Event* const> wait_for) override {
380 auto* stream = static_cast<const GrpcBufferHandle*>(src)->stream();
381 return stream->TransferFromDevice(src, dst, wait_for);
382 }
383
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)384 std::shared_ptr<Event> TransferFromDeviceToDevice(
385 const BufferHandle* src, BufferHandle* dst,
386 absl::Span<Event* const> wait_for) override {
387 auto* stream = static_cast<const GrpcBufferHandle*>(src)->stream();
388 return stream->TransferFromDeviceToDevice(src, dst, wait_for);
389 }
390
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)391 std::unique_ptr<CompiledProgramHandle> CompileProgram(
392 const xla::HloProto& source, int32_t num_replicas,
393 absl::Span<Event* const> wait_for) override {
394 // Always compile using the first/default core's stream.
395 return streams_[0]->CompileProgram(source, num_replicas, wait_for);
396 }
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)397 std::unique_ptr<LoadedProgramHandle> LoadProgram(
398 int32_t core_id, const CompiledProgramHandle* handle,
399 absl::Span<Event* const> wait_for) override {
400 return streams_[core_id]->LoadProgram(core_id, handle, wait_for);
401 }
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)402 std::shared_ptr<Event> UnloadProgram(
403 std::unique_ptr<LoadedProgramHandle> handle,
404 absl::Span<Event* const> wait_for) override {
405 auto* stream =
406 static_cast<const GrpcLoadedProgramHandle*>(handle.get())->stream();
407 return stream->UnloadProgram(std::move(handle), wait_for);
408 }
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)409 std::shared_ptr<Event> ExecuteProgram(
410 LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
411 absl::Span<BufferHandle* const> outputs,
412 const xla::DeviceAssignmentProto& device_assignment,
413 absl::Span<Event* const> wait_for) override {
414 auto* stream =
415 static_cast<const GrpcLoadedProgramHandle*>(program)->stream();
416 return stream->ExecuteProgram(program, inputs, outputs, device_assignment,
417 wait_for);
418 }
419
NewOperationId()420 EventId NewOperationId() { return EventId{client_id_, ++operation_id_}; }
421
422 static std::unique_ptr<grpc::CloudTpuDriver::Stub> CreateTpuDriverStub(
423 const TpuDriverConfig& config,
424 std::shared_ptr<::grpc::ChannelCredentials> creds);
425
client_id() const426 uint32_t client_id() const { return client_id_; }
427
428 private:
429 Status Close();
430 std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
431
432 const TpuDriverConfig config_;
433 std::shared_ptr<::grpc::ChannelCredentials> creds_;
434 const uint32_t client_id_;
435 // Map from stream IDs to streams.
436 absl::flat_hash_map<int32_t, std::unique_ptr<GrpcTpuStream>> streams_;
437 std::unique_ptr<GrpcTpuStream> host_stream_;
438 // Shared by all streams.
439 std::atomic<uint64_t> operation_id_{0};
440 std::atomic<bool> closed_{false};
441 }; // namespace
442
~GrpcEvent()443 GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
444
Await()445 Status GrpcEvent::Await() {
446 auto opt_status = stream_->WaitForEvent(id_, absl::InfiniteDuration());
447 return opt_status.value();
448 }
449
AwaitWithTimeout(absl::Duration duration)450 absl::optional<Status> GrpcEvent::AwaitWithTimeout(absl::Duration duration) {
451 return stream_->WaitForEvent(id_, duration);
452 }
453
AddCallback(std::function<void (Status)> callback)454 void GrpcEvent::AddCallback(std::function<void(Status)> callback) {
455 stream_->AddEventCallback(id_, std::move(callback));
456 }
457
// Opens the StreamExecute RPC on `stub` and starts the writer and reader
// threads that service it.
// NOTE(review): members are initialized in their declaration order, not the
// order written here. `stream_` needs `stub_` and `ctx_` to exist first, and
// each thread starts running as soon as its member is constructed, so all
// state the threads touch must be declared before them.
GrpcTpuStream::GrpcTpuStream(int32_t id, GrpcTpuDriver* driver,
                             std::unique_ptr<grpc::CloudTpuDriver::Stub> stub)
    : id_(id),
      driver_(driver),
      stub_(std::move(stub)),
      stream_(stub_->StreamExecute(&ctx_)),
      writer_thread_(&GrpcTpuStream::StreamWriterFn, this),
      reader_thread_(&GrpcTpuStream::StreamReaderFn, this) {}
466
~GrpcTpuStream()467 GrpcTpuStream::~GrpcTpuStream() {
468 {
469 absl::MutexLock lock(&request_lock_);
470 shutting_down_ = true;
471 }
472
473 VLOG(1) << "Shutting down stream.";
474 {
475 // Mark all remaining events invalid.
476 absl::MutexLock lock(&events_mutex_);
477 for (const auto& e : events_) {
478 if (!e.second.done) {
479 LOG(ERROR) << "Resetting: " << e.first;
480 UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
481 "Driver was closed."));
482 }
483 }
484 }
485 VLOG(1) << "Closing stream.";
486 stream_->WritesDone();
487 stream_->Finish().IgnoreError();
488 VLOG(1) << "Waiting for writer.";
489 writer_thread_.join();
490 VLOG(1) << "Waiting for reader.";
491 reader_thread_.join();
492 }
493
InitializeRequest(StreamRequest::Entry * req,absl::Span<Event * const> wait_for)494 void GrpcTpuStream::InitializeRequest(StreamRequest::Entry* req,
495 absl::Span<Event* const> wait_for) {
496 auto operation_id = driver_->NewOperationId();
497 EventInfo event_info;
498
499 req->set_operation_id(operation_id.AsInt());
500 if (wait_for.empty()) {
501 event_info.all_deps_done = true;
502 } else {
503 event_info.deps.reserve(wait_for.size());
504 for (auto* event : wait_for) {
505 auto grpc_event = static_cast<const GrpcEvent*>(event);
506 req->add_wait_for_id(grpc_event->id().AsInt());
507 event_info.deps.push_back(grpc_event->id());
508 }
509 }
510
511 absl::MutexLock lock(&events_mutex_);
512 events_[operation_id] = event_info;
513 }
514
UpdateEventStatus(EventId id,Status status)515 void GrpcTpuStream::UpdateEventStatus(EventId id, Status status) {
516 auto it = events_.find(id);
517
518 // These should only happen when the server shuts down, and our local event
519 // cancellation interleaves with server responses. It should be safe to ignore
520 // the second updates in these situations.
521 if (it == events_.end()) {
522 VLOG(1) << "Received a status update: " << status
523 << ", but cannot find GrpcEvent " << id;
524 return;
525 }
526 if (it->second.done) {
527 // Done and deleted events must have already been removed.
528 CHECK(!it->second.deleted);
529 VLOG(1) << "Received a second status update: " << status.error_message()
530 << ", for GrpcEvent " << id << " already done with status: "
531 << it->second.status.error_message();
532 return;
533 }
534
535 // This is the first time this event finishes. Remember the results and call
536 // the callbacks.
537 VLOG(1) << "Response received for GrpcEvent " << id << ". "
538 << status.ToString() << ". Firing " << it->second.callbacks.size()
539 << " callbacks.";
540 it->second.done = true;
541 it->second.status = status;
542 for (const auto& callback : it->second.callbacks) {
543 callback(status);
544 }
545
546 // Truly remove the event if it's both done and deleted.
547 if (it->second.deleted) {
548 events_.erase(it);
549 }
550 }
551
DeleteEvent(EventId id)552 void GrpcTpuStream::DeleteEvent(EventId id) {
553 absl::MutexLock lock(&events_mutex_);
554 auto it = events_.find(id);
555 CHECK(it != events_.end());
556 CHECK(!it->second.deleted);
557 it->second.deleted = true;
558 // Truly remove the event if it's both done and deleted.
559 if (it->second.done) {
560 events_.erase(it);
561 }
562 }
563
WaitForEvent(EventId id,absl::Duration duration)564 absl::optional<Status> GrpcTpuStream::WaitForEvent(EventId id,
565 absl::Duration duration) {
566 events_mutex_.Lock();
567 auto it = events_.find(id);
568
569 if (it == events_.end()) {
570 // This event has already been marked as done and deleted. Assume success.
571 events_mutex_.Unlock();
572 return Status::OK();
573 }
574
575 if (!it->second.all_deps_done) {
576 absl::InlinedVector<EventId, 2> deps = it->second.deps;
577 events_mutex_.Unlock();
578 for (auto dep : deps) {
579 // If a requirement event timed out, no point in any further waiting.
580 if (!WaitForEvent(dep, duration)) {
581 return absl::nullopt;
582 }
583 }
584 events_mutex_.Lock();
585 }
586
587 // Set the flag here, as we're guaranteed they have all completed at this
588 // point. This helps terminate recursion on a chain of completed events as
589 // soon as possible, at this event.
590 it = events_.find(id);
591 if (it != events_.end()) {
592 it->second.all_deps_done = true;
593 }
594
595 auto done = [this, id]() {
596 events_mutex_.AssertHeld();
597 return !events_.contains(id) || events_[id].done;
598 };
599 if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) {
600 auto status = events_.contains(id) ? events_[id].status : Status::OK();
601 events_mutex_.Unlock();
602 return status;
603 }
604 events_mutex_.Unlock();
605 return absl::nullopt;
606 }
607
AddEventCallback(EventId id,std::function<void (Status)> callback)608 void GrpcTpuStream::AddEventCallback(EventId id,
609 std::function<void(Status)> callback) {
610 absl::MutexLock lock(&events_mutex_);
611 auto it = events_.find(id);
612 if (it == events_.end()) {
613 callback(Status());
614 return;
615 }
616 if (it->second.done) {
617 callback(it->second.status);
618 return;
619 }
620 it->second.callbacks.push_back(std::move(callback));
621 }
622
// absl::Condition predicate for the writer thread: true once more than 32
// requests are pending, so a backlog is flushed without waiting out the
// write epoch.
static bool ShouldBeginWriting(int64_t* pending_requests) {
  constexpr int64_t kFlushThreshold = 32;
  return kFlushThreshold < *pending_requests;
}
626
StreamWriterFn()627 void GrpcTpuStream::StreamWriterFn() {
628 while (true) {
629 request_lock_.LockWhenWithTimeout(
630 absl::Condition(&ShouldBeginWriting, &num_pending_requests_),
631 kWriteEpochDuration);
632 if (shutting_down_) {
633 request_lock_.Unlock();
634 return;
635 }
636
637 if (requests_.empty()) {
638 request_lock_.Unlock();
639 continue;
640 }
641
642 std::vector<StreamRequest> reqs;
643 int64_t request_bytes = 0;
644 while (!requests_.empty()) {
645 StreamRequest::Entry* e = requests_.front().release();
646 requests_.pop_front();
647 const int64_t entry_bytes = e->ByteSizeLong();
648 if (reqs.empty() || request_bytes + entry_bytes > kMaxStreamWriteSize) {
649 reqs.push_back(StreamRequest());
650 request_bytes = 0;
651 }
652 VLOG(1) << "Sending request: " << EventId::FromInt(e->operation_id());
653 VLOG(2) << "Sending request: " << e->DebugString();
654 reqs.back().mutable_entry()->AddAllocated(e);
655 }
656 num_pending_requests_ = 0;
657 request_lock_.Unlock();
658
659 for (const auto& r : reqs) {
660 TraceMe activity("GrpcTpuStream::Send ");
661 ::grpc::WriteOptions opts;
662 opts.set_no_compression().clear_buffer_hint();
663 stream_->Write(r, opts);
664 }
665 }
666 }
667
StreamReaderFn()668 void GrpcTpuStream::StreamReaderFn() {
669 StreamResponse resp;
670 while (stream_->Read(&resp)) {
671 VLOG(2) << "Received response: " << resp.DebugString();
672 for (const StreamResponse::Entry& entry : resp.entry()) {
673 EventId event_id = EventId::FromInt(entry.operation_id());
674 VLOG(1) << "Received response for: " << event_id;
675
676 TraceMe activity("GrpcTpuStream::RequestComplete");
677 if (entry.has_transfer_from()) {
678 TraceMe activity("GrpcTpuStream::TransferFromComplete");
679 absl::MutexLock lock(&transfers_mutex_);
680 auto it = transfers_.find(event_id);
681 CHECK(it != transfers_.end());
682 VLOG(1) << "Copying: " << it->second.num_bytes << " to position "
683 << it->second.dst;
684 if (entry.transfer_from().data().size() != it->second.num_bytes) {
685 absl::MutexLock lock(&events_mutex_);
686 UpdateEventStatus(
687 event_id,
688 Status(
689 tensorflow::error::Code::DATA_LOSS,
690 absl::StrCat("Expected ", it->second.num_bytes, " received ",
691 entry.transfer_from().data().size())));
692 continue;
693 }
694 memcpy(it->second.dst, entry.transfer_from().data().data(),
695 it->second.num_bytes);
696 }
697
698 if (entry.has_compile()) {
699 TraceMe activity("GrpcTpuStream::CompileComplete");
700 absl::MutexLock lock(&compiles_mutex_);
701 auto it = compiles_.find(event_id);
702 CHECK(it != compiles_.end());
703 *it->second.compiled_metadata = entry.compile().metadata();
704 }
705
706 absl::MutexLock lock(&events_mutex_);
707 if (entry.status().code() != tensorflow::error::Code::OK) {
708 UpdateEventStatus(
709 event_id,
710 Status(static_cast<tensorflow::error::Code>(entry.status().code()),
711 entry.status().message()));
712 } else {
713 UpdateEventStatus(event_id, Status::OK());
714 }
715 }
716 }
717 }
718
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)719 std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
720 int32_t core_id, MemoryRegion region, int64_t num_bytes,
721 absl::Span<Event* const> wait_for) {
722 auto req = absl::make_unique<StreamRequest::Entry>();
723 InitializeRequest(req.get(), wait_for);
724 TraceMe activity("GrpcTpuStream::Allocate(num_bytes)");
725 req->mutable_alloc()->set_core_id(core_id);
726 req->mutable_alloc()->set_region(region);
727 req->mutable_alloc()->set_num_bytes(num_bytes);
728 auto event =
729 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
730 AddWriteRequest(std::move(req));
731 return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event),
732 num_bytes);
733 }
734
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)735 std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
736 int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
737 absl::Span<Event* const> wait_for) {
738 auto req = absl::make_unique<StreamRequest::Entry>();
739 InitializeRequest(req.get(), wait_for);
740 TraceMe activity("GrpcTpuStream::Allocate(shape)");
741 req->mutable_alloc()->set_core_id(core_id);
742 req->mutable_alloc()->set_region(region);
743 *req->mutable_alloc()->mutable_shape() = shape;
744 auto event =
745 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
746 AddWriteRequest(std::move(req));
747 return absl::make_unique<GrpcBufferHandle>(
748 event->id(), std::move(event), ComputeBytesFromShape(shape), shape);
749 }
750
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)751 std::unique_ptr<BufferHandle> GrpcTpuStream::AllocateTuple(
752 int32_t core_id, MemoryRegion region,
753 absl::Span<BufferHandle* const> children,
754 absl::Span<Event* const> wait_for) {
755 auto req = absl::make_unique<StreamRequest::Entry>();
756 InitializeRequest(req.get(), wait_for);
757 TraceMe activity("GrpcTpuStream::AllocateTuple");
758 req->mutable_alloc_tuple()->set_core_id(core_id);
759 req->mutable_alloc_tuple()->set_region(region);
760 for (auto child : children) {
761 auto grpc_child = static_cast<GrpcBufferHandle*>(child);
762 req->mutable_alloc_tuple()->add_children(grpc_child->id().AsInt());
763 }
764 auto event =
765 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
766 AddWriteRequest(std::move(req));
767 return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event), 0);
768 }
769
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)770 std::shared_ptr<Event> GrpcTpuStream::Deallocate(
771 std::unique_ptr<BufferHandle> handle, absl::Span<Event* const> wait_for) {
772 auto req = absl::make_unique<StreamRequest::Entry>();
773 InitializeRequest(req.get(), wait_for);
774 TraceMe activity("GrpcTpuStream::Deallocate");
775 auto grpc_handle = static_cast<GrpcBufferHandle*>(handle.get());
776 req->mutable_dealloc()->set_handle(grpc_handle->id().AsInt());
777 auto event =
778 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
779 AddWriteRequest(std::move(req));
780 return event;
781 }
782
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)783 std::shared_ptr<Event> GrpcTpuStream::TransferToDevice(
784 const void* src, BufferHandle* dst, absl::Span<Event* const> wait_for) {
785 auto req = absl::make_unique<StreamRequest::Entry>();
786 InitializeRequest(req.get(), wait_for);
787 TraceMe activity("GrpcTpuStream::TransferToDevice");
788 req->mutable_transfer_to()->mutable_data()->assign(
789 static_cast<const char*>(src), dst->size_in_bytes());
790 req->mutable_transfer_to()->set_target_handle(
791 static_cast<GrpcBufferHandle*>(dst)->id().AsInt());
792 auto event =
793 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
794 AddWriteRequest(std::move(req));
795 return event;
796 }
797
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)798 std::shared_ptr<Event> GrpcTpuStream::TransferFromDevice(
799 const BufferHandle* src, void* dst, absl::Span<Event* const> wait_for) {
800 auto req = absl::make_unique<StreamRequest::Entry>();
801 InitializeRequest(req.get(), wait_for);
802 TraceMe activity("GrpcTpuStream::TransferFromDevice");
803 req->mutable_transfer_from()->set_source_handle(
804 static_cast<const GrpcBufferHandle*>(src)->id().AsInt());
805 EventId event_id = EventId::FromInt(req->operation_id());
806 {
807 absl::MutexLock lock(&transfers_mutex_);
808 TransferInfo info(dst, const_cast<BufferHandle*>(src)->size_in_bytes());
809 transfers_.insert(std::make_pair(event_id, info));
810 }
811 auto event = std::make_shared<GrpcEvent>(event_id, this);
812 AddWriteRequest(std::move(req));
813 return event;
814 }
815
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)816 std::shared_ptr<Event> GrpcTpuStream::TransferFromDeviceToDevice(
817 const BufferHandle* src, BufferHandle* dst,
818 absl::Span<Event* const> wait_for) {
819 auto req = absl::make_unique<StreamRequest::Entry>();
820 InitializeRequest(req.get(), wait_for);
821 TraceMe activity([&req] {
822 return absl::StrCat("GrpcTpuStream::TransferFromDeviceToDevice",
823 req->operation_id());
824 });
825
826 req->mutable_transfer_from_to()->set_source_handle(
827 static_cast<const GrpcBufferHandle*>(src)->id().AsInt());
828 req->mutable_transfer_from_to()->set_target_handle(
829 static_cast<const GrpcBufferHandle*>(dst)->id().AsInt());
830 EventId event_id = EventId::FromInt(req->operation_id());
831 auto event = std::make_shared<GrpcEvent>(event_id, this);
832 AddWriteRequest(std::move(req));
833 return event;
834 }
835
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)836 std::unique_ptr<CompiledProgramHandle> GrpcTpuStream::CompileProgram(
837 const xla::HloProto& source, int32_t num_replicas,
838 absl::Span<Event* const> wait_for) {
839 auto req = absl::make_unique<StreamRequest::Entry>();
840 InitializeRequest(req.get(), wait_for);
841 TraceMe activity("GrpcTpuStream::CompileProgram");
842 *req->mutable_compile()->mutable_hlo_program() = source;
843 req->mutable_compile()->set_num_replicas(num_replicas);
844 EventId event_id = EventId::FromInt(req->operation_id());
845
846 auto event =
847 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
848
849 auto handle = absl::make_unique<GrpcCompiledProgramHandle>(event->id(),
850 std::move(event));
851 {
852 absl::MutexLock lock(&compiles_mutex_);
853 CompileMetadataInfo info(handle->metadata());
854 compiles_.insert(std::make_pair(event_id, info));
855 }
856
857 AddWriteRequest(std::move(req));
858 return std::move(handle);
859 }
860
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)861 std::unique_ptr<LoadedProgramHandle> GrpcTpuStream::LoadProgram(
862 int32_t core_id, const CompiledProgramHandle* handle,
863 absl::Span<Event* const> wait_for) {
864 auto req = absl::make_unique<StreamRequest::Entry>();
865 InitializeRequest(req.get(), wait_for);
866 TraceMe activity("GrpcTpuStream::LoadProgram");
867 req->mutable_load()->set_core_id(core_id);
868 auto grpc_handle = static_cast<const GrpcCompiledProgramHandle*>(handle);
869 if (grpc_handle->id().client_id != driver_->client_id()) {
870 auto event = std::make_shared<ErrorEvent>(
871 xla::InvalidArgument("Invalid program handle (wrong client id). Did "
872 "you restart the server or use a stale handle?"));
873 return absl::make_unique<GrpcLoadedProgramHandle>(event->id(),
874 std::move(event));
875 }
876 req->mutable_load()->set_compiled_program_handle(grpc_handle->id().AsInt());
877 auto event =
878 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
879 AddWriteRequest(std::move(req));
880 return absl::make_unique<GrpcLoadedProgramHandle>(event->id(),
881 std::move(event));
882 }
883
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)884 std::shared_ptr<Event> GrpcTpuStream::UnloadProgram(
885 std::unique_ptr<LoadedProgramHandle> handle,
886 absl::Span<Event* const> wait_for) {
887 auto req = absl::make_unique<StreamRequest::Entry>();
888 InitializeRequest(req.get(), wait_for);
889 TraceMe activity("GrpcTpuStream::UnloadProgram");
890 req->mutable_unload()->set_loaded_program_handle(
891 static_cast<GrpcLoadedProgramHandle*>(handle.get())->id().AsInt());
892 auto event =
893 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
894 AddWriteRequest(std::move(req));
895 return event;
896 }
897
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)898 std::shared_ptr<Event> GrpcTpuStream::ExecuteProgram(
899 LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
900 absl::Span<BufferHandle* const> outputs,
901 const xla::DeviceAssignmentProto& device_assignment,
902 absl::Span<Event* const> wait_for) {
903 auto req = absl::make_unique<StreamRequest::Entry>();
904 InitializeRequest(req.get(), wait_for);
905 auto program_handle = static_cast<GrpcLoadedProgramHandle*>(program);
906 if (program_handle->id().client_id != driver_->client_id()) {
907 return std::make_shared<ErrorEvent>(
908 xla::InvalidArgument("Invalid program handle (wrong client id). Did "
909 "you restart the server or use a stale handle?"));
910 }
911
912 req->mutable_execute()->set_loaded_program_handle(
913 program_handle->id().AsInt());
914
915 for (BufferHandle* input : inputs) {
916 auto* grpc_handle = static_cast<GrpcBufferHandle*>(input);
917 if (grpc_handle->id().client_id != driver_->client_id()) {
918 return std::make_shared<ErrorEvent>(xla::InvalidArgument(
919 "Invalid input buffer (wrong client id). Did you restart the server "
920 "or use a stale handle?"));
921 }
922 req->mutable_execute()->add_input_handle(grpc_handle->id().AsInt());
923 }
924
925 for (BufferHandle* output : outputs) {
926 auto* grpc_handle = static_cast<GrpcBufferHandle*>(output);
927 if (grpc_handle->id().client_id != driver_->client_id()) {
928 return std::make_shared<ErrorEvent>(xla::InvalidArgument(
929 "Invalid output buffer (wrong client id). Did you restart the server "
930 "or use a stale handle?"));
931 }
932 req->mutable_execute()->add_output_handle(
933 static_cast<GrpcBufferHandle*>(output)->id().AsInt());
934 }
935 // Only pass along device_assignment if it's not default constructed.
936 if (!(device_assignment.replica_count() == 0 &&
937 device_assignment.computation_count() == 0)) {
938 *req->mutable_execute()->mutable_device_assignment() = device_assignment;
939 }
940 auto event =
941 std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
942 AddWriteRequest(std::move(req));
943 return event;
944 }
945
946 /*static*/ std::unique_ptr<grpc::CloudTpuDriver::Stub>
CreateTpuDriverStub(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds)947 GrpcTpuDriver::CreateTpuDriverStub(
948 const TpuDriverConfig& config,
949 std::shared_ptr<::grpc::ChannelCredentials> creds) {
950 ::grpc::ChannelArguments args;
951 args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
952 args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
953
954 // Send at least 20 keep-alives before giving up.
955 int keepalive_timeout_ms = config.grpc().keepalive_timeout_secs() * 1000;
956 int keepalive_interval_ms = keepalive_timeout_ms / 20;
957
958 grpc_arg client_arg_vals[] = {
959 {.type = GRPC_ARG_INTEGER,
960 .key = const_cast<char*>(
961 GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS),
962 .value = {.integer = keepalive_interval_ms}},
963 {.type = GRPC_ARG_INTEGER,
964 .key = const_cast<char*>(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA),
965 .value = {.integer = 0}}, // unlimited
966 {.type = GRPC_ARG_INTEGER,
967 .key = const_cast<char*>(GRPC_ARG_KEEPALIVE_TIME_MS),
968 .value = {.integer = keepalive_interval_ms}},
969 {.type = GRPC_ARG_INTEGER,
970 .key = const_cast<char*>(GRPC_ARG_KEEPALIVE_TIMEOUT_MS),
971 .value = {.integer = keepalive_timeout_ms}},
972 {.type = GRPC_ARG_INTEGER,
973 .key = const_cast<char*>(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS),
974 .value = {.integer = 1}},
975 {.type = GRPC_ARG_INTEGER,
976 .key = const_cast<char*>(GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE),
977 .value = {.integer = 64 * 1000 * 1000}}};
978
979 grpc_channel_args client_args = {.num_args = 6, .args = client_arg_vals};
980 args.SetChannelArgs(&client_args);
981
982 // strips out 'grpc://'
983 auto worker_addr = absl::StripPrefix(config.worker(), kGrpcProtocol);
984 std::shared_ptr<::grpc::Channel> channel =
985 ::grpc::CreateCustomChannel(std::string(worker_addr), creds, args);
986 return grpc::CloudTpuDriver::NewStub(channel);
987 }
988
AllocateStream(int32_t id)989 std::unique_ptr<GrpcTpuStream> GrpcTpuDriver::AllocateStream(int32_t id) {
990 auto stub = CreateTpuDriverStub(config_, creds_);
991 ::grpc::ClientContext ctx;
992 ctx.set_fail_fast(false);
993 ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
994 return absl::make_unique<GrpcTpuStream>(id, this, std::move(stub));
995 }
996
QuerySystemInfo(SystemInfo * system_info)997 void GrpcTpuDriver::QuerySystemInfo(SystemInfo* system_info) {
998 auto stub = CreateTpuDriverStub(config_, creds_);
999 ::grpc::ClientContext ctx;
1000 ctx.set_fail_fast(false);
1001 ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
1002
1003 QuerySystemInfoRequest req;
1004 QuerySystemInfoResponse resp;
1005 ::grpc::Status status = stub->QuerySystemInfo(&ctx, req, &resp);
1006 if (!status.ok()) {
1007 LOG(ERROR) << "QuerySystemInfo request failed: " << status.error_code()
1008 << ": " << status.error_message() << ": "
1009 << status.error_details();
1010 return;
1011 }
1012 *system_info = resp.system_info();
1013 }
1014
Reset()1015 Status GrpcTpuDriver::Reset() {
1016 auto stub = CreateTpuDriverStub(config_, creds_);
1017 ::grpc::ClientContext ctx;
1018 ctx.set_fail_fast(false);
1019 ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
1020 ResetRequest req;
1021 ResetResponse resp;
1022 ::grpc::Status status = stub->Reset(&ctx, req, &resp);
1023 if (!status.ok()) {
1024 LOG(ERROR) << "Failed to reset the gRPC driver: " << status.error_code()
1025 << ": " << status.error_message() << ": "
1026 << status.error_details();
1027 return xla::Status(tensorflow::error::Code(status.error_code()),
1028 absl::StrCat("Failed to reset TPU driver. Error was: ",
1029 status.error_message(),
1030 ". Details: ", status.error_details()));
1031 }
1032 streams_.clear();
1033 host_stream_.reset();
1034 return Close();
1035 }
1036
Close()1037 Status GrpcTpuDriver::Close() {
1038 auto stub = CreateTpuDriverStub(config_, creds_);
1039 ::grpc::ClientContext ctx;
1040 ctx.set_fail_fast(false);
1041 ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
1042 CloseRequest req;
1043 req.set_client_id(client_id_);
1044 CloseResponse resp;
1045 ::grpc::Status status = stub->Close(&ctx, req, &resp);
1046 if (!status.ok()) {
1047 return xla::Status(tensorflow::error::Code(status.error_code()),
1048 absl::StrCat("Failed to close TPU driver. Error was: ",
1049 status.error_message(),
1050 ". Details: ", status.error_details()));
1051 }
1052 closed_ = true;
1053 return Status::OK();
1054 }
1055 } // namespace
1056
CreateGrpcTpuDriver(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds)1057 xla::StatusOr<std::unique_ptr<TpuDriver>> CreateGrpcTpuDriver(
1058 const TpuDriverConfig& config,
1059 std::shared_ptr<::grpc::ChannelCredentials> creds) {
1060 auto stub = GrpcTpuDriver::CreateTpuDriverStub(config, creds);
1061 ::grpc::ClientContext ctx;
1062 ctx.set_fail_fast(false);
1063 ctx.set_deadline(
1064 std::chrono::system_clock::now() +
1065 std::chrono::seconds(config.grpc().connection_timeout_secs()));
1066 OpenRequest req;
1067 OpenResponse resp;
1068 ::grpc::Status status = stub->Open(&ctx, req, &resp);
1069 if (!status.ok()) {
1070 LOG(ERROR) << "Failed to open the gRPC driver: " << status.error_code()
1071 << ": " << status.error_message() << ": "
1072 << status.error_details();
1073 return xla::Status(
1074 tensorflow::error::Code(status.error_code()),
1075 absl::StrCat(
1076 "Failed to connect to remote server at address: ", config.worker(),
1077 ". Error from gRPC: ", status.error_message(),
1078 ". Details: ", status.error_details()));
1079 }
1080 return std::unique_ptr<TpuDriver>(
1081 new GrpcTpuDriver(config, creds, resp.client_id()));
1082 }
1083
1084 REGISTER_TPU_DRIVER(
1085 "grpc://",
1086 [](const TpuDriverConfig& config)
__anon4073f9c10402(const TpuDriverConfig& config) 1087 -> xla::StatusOr<std::unique_ptr<TpuDriver>> {
1088 if (absl::StartsWith(config.worker(), "grpc://localhost")) {
1089 LOG(INFO) << "Using local credentials for localhost: connection.";
1090 return CreateGrpcTpuDriver(
1091 config, ::grpc::experimental::LocalCredentials(LOCAL_TCP));
1092 } else {
1093 return CreateGrpcTpuDriver(config,
1094 ::grpc::InsecureChannelCredentials());
1095 }
1096 });
1097
1098 } // namespace tpu_driver
1099