// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
#include "tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tpu_driver {
namespace {

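// Returns an ErrorEvent from the enclosing lambda when `target_op_id` has no
// entry in `container` (e.g. the corresponding handle was already
// deallocated). Expands to a block, so it may only be used inside functions
// that return std::shared_ptr<Event>.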
#define CHECK_EXISTS_OR_RETURN(container, target_op_id, operation_id) \
  {                                                                    \
    auto p = CheckHandleExists(container, target_op_id, operation_id); \
    if (p != nullptr) return p;                                        \
  }

using xla::Status;
using xla::WorkerThread;

const char kPodTpuDriverPrefix[] = "grpc+pod://";

class PodTpuDriver;

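// An Event keyed by the operation id assigned by PodTpuDriver. All waiting
// and callback registration is delegated to the driver, which tracks the
// status of in-flight operations.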
class PodEvent : public Event {
 public:
  explicit PodEvent(PodTpuDriver* driver, int64_t operation_id)
      : driver_(driver), operation_id_(operation_id) {}
  int64_t operation_id() const { return operation_id_; }

  xla::Status Await() override;

  absl::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) override;

  void AddCallback(std::function<void(Status)> callback) override;

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
};

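// An Event that has already failed with a fixed status; all waits and
// callbacks observe that status immediately.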
class ErrorEvent : public PodEvent {
 public:
  explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status)
      : PodEvent(driver, operation_id) {
    status_ = status;
  }

  xla::Status Await() override { return status_; }
  absl::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) override {
    return status_;
  }
  void AddCallback(std::function<void(Status)> callback) override {
    callback(status_);
  }

 private:
  Status status_;
};

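// An Event that completes once all of its constituent events have completed,
// used to gate operations that fan out to every per-host driver (e.g.
// CompileProgram). The status of the last event to complete is reported to
// callbacks.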
class CombinedEvent : public PodEvent {
 public:
  explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
                         std::vector<std::shared_ptr<Event>> events)
      : PodEvent(driver, operation_id), events_(events) {
    for (auto& event : events_) {
      event->AddCallback([this](Status s) { IncrementAndCheckComplete(s); });
    }
  }

  xla::Status Await() override {
    for (auto& event : events_) {
      TF_RETURN_IF_ERROR(event->Await());
    }
    return Status::OK();
  }

  absl::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) override {
    for (auto& event : events_) {
      auto start_time = absl::Now();
      auto status = event->AwaitWithTimeout(duration);
      duration -= absl::Now() - start_time;
      if (status == absl::nullopt) {
        return absl::nullopt;
      } else {
        TF_RETURN_IF_ERROR(status.value());
      }
    }
    return Status::OK();
  }

  void AddCallback(std::function<void(Status)> callback)
      TF_LOCKS_EXCLUDED(mu_) override {
    bool all_events_completed = false;
    {
      absl::MutexLock l(&mu_);
      // Check completion and enqueue under the same critical section so that
      // a completion landing between the check and the enqueue cannot strand
      // the callback in callbacks_ after IncrementAndCheckComplete has
      // already drained it.
      if (events_completed_ == events_.size()) {
        all_events_completed = true;
      } else {
        callbacks_.push_back(std::move(callback));
      }
    }
    if (all_events_completed) {
      callback(event_status_);
    }
  }

 private:
  void IncrementAndCheckComplete(Status s) TF_LOCKS_EXCLUDED(mu_) {
    std::vector<std::function<void(Status)>> callbacks;
    {
      absl::MutexLock l(&mu_);

      event_status_ = s;
      events_completed_++;
      if (events_completed_ == events_.size()) {
        // Copy callbacks to a temporary to be invoked outside the mutex.
        callbacks.assign(callbacks_.begin(), callbacks_.end());
        callbacks_.clear();
      } else {
        return;
      }
    }

    for (const auto& callback : callbacks) {
      callback(event_status_);
    }
  }

  absl::Mutex mu_;
  std::vector<std::shared_ptr<Event>> events_;
  std::vector<std::function<void(Status)>> callbacks_ ABSL_GUARDED_BY(mu_);
  int64_t events_completed_ ABSL_GUARDED_BY(mu_) = 0;
  Status event_status_;
};

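// A BufferHandle that records the operation id of the allocation that
// produced it plus the (virtual) core it lives on; the underlying per-host
// buffer is looked up by operation id inside PodTpuDriver.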
class PodBufferHandle : public BufferHandle {
 public:
  explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id,
                           int64_t size_in_bytes,
                           absl::optional<xla::ShapeProto> shape,
                           int64_t core_id)
      : driver_(driver),
        operation_id_(operation_id),
        size_in_bytes_(size_in_bytes),
        shape_(shape),
        event_(std::make_shared<PodEvent>(driver_, operation_id_)),
        core_id_(core_id) {}

  std::shared_ptr<Event> OnReady() override { return event_; }
  int64_t size_in_bytes() override { return size_in_bytes_; }
  absl::optional<xla::ShapeProto> shape() override { return shape_; }

  int64_t operation_id() const { return operation_id_; }
  int64_t core_id() const { return core_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  const int64_t size_in_bytes_;
  const absl::optional<xla::ShapeProto> shape_;
  std::shared_ptr<PodEvent> event_;
  const int64_t core_id_;
};

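// Handles for compiled and loaded programs follow the same pattern: they are
// identified by operation id, and the real per-host handles are owned by
// PodTpuDriver.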
class PodCompiledProgramHandle : public CompiledProgramHandle {
 public:
  explicit PodCompiledProgramHandle(PodTpuDriver* driver, int64_t operation_id)
      : driver_(driver),
        operation_id_(operation_id),
        event_(std::make_shared<PodEvent>(driver_, operation_id_)) {}

  std::shared_ptr<Event> OnReady() override { return event_; }

  xla::Status program_shape(xla::ProgramShapeProto* program_shape) override;

  int64_t operation_id() const { return operation_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  std::shared_ptr<PodEvent> event_;
};

class PodLoadedProgramHandle : public LoadedProgramHandle {
 public:
  explicit PodLoadedProgramHandle(PodTpuDriver* driver, int64_t operation_id,
                                  int64_t core_id)
      : driver_(driver),
        operation_id_(operation_id),
        core_id_(core_id),
        event_(std::make_shared<PodEvent>(driver_, operation_id_)) {}

  std::shared_ptr<Event> OnReady() override { return event_; }

  int64_t operation_id() const { return operation_id_; }
  int64_t core_id() const { return core_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  const int64_t core_id_;
  std::shared_ptr<PodEvent> event_;
};

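// Bookkeeping for one scheduled operation: the underlying driver event once
// the operation has actually been issued, the deferred function that issues
// it, the operation ids it is still waiting on, and callbacks to attach to
// the underlying event when it materializes.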
struct EventInFlight {
  EventInFlight()
      : underlying_event(nullptr),
        create_fn(nullptr),
        incomplete_deps(),
        callbacks() {}

  std::shared_ptr<Event> underlying_event;
  std::function<std::shared_ptr<Event>(void)> create_fn;

  absl::flat_hash_set<int64_t> incomplete_deps;
  std::vector<std::function<void(Status)>> callbacks;
};

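// A TpuDriver that fans out to one gRPC driver per pod worker and presents
// the result to clients as a single host with many cores. Every operation is
// assigned a monotonically increasing operation id, and operations are
// sequenced through the dependency bookkeeping in events_.
//
// Hypothetical usage sketch (for illustration only; real clients go through
// the factory registered at the bottom of this file, and the addresses below
// are made up):
//
//   TpuDriverConfig config;
//   config.set_worker("grpc+pod://10.0.0.1:8470,10.0.0.2:8470");
//   auto driver_or = CreatePodTpuDriver(
//       config, ::grpc::InsecureChannelCredentials());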
class PodTpuDriver : public TpuDriver {
 public:
  explicit PodTpuDriver(const TpuDriverConfig& config,
                        std::shared_ptr<::grpc::ChannelCredentials> creds)
      : config_(config),
        creds_(creds),
        event_thread_(tensorflow::Env::Default(), "grpc_pod_event_thread") {
    std::vector<std::string> workers = absl::StrSplit(
        absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ',');

    int worker_count = 0;

    // Flag for environments where the local core count equals the total core
    // count in the TPU system, which means that we are connecting to separate
    // TPU systems or we are in a test environment.
    bool in_local_core_environment = false;

    for (const auto& worker : workers) {
      TpuDriverConfig worker_config(config_);
      *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker);
      auto tpu_driver =
          CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie();

      SystemInfo driver_info;
      tpu_driver->QuerySystemInfo(&driver_info);

      if (driver_info.core_count() == driver_info.local_core_size()) {
        drivers_.insert({worker_count, std::move(tpu_driver)});
        in_local_core_environment = true;
      } else {
        drivers_.insert({driver_info.host_id(), std::move(tpu_driver)});
      }

      worker_count++;
    }

    absl::flat_hash_set<std::tuple<int, int, int>> processed_chips;

    for (int driver_num = 0; driver_num < workers.size(); ++driver_num) {
      SystemInfo driver_info;
      drivers_[driver_num]->QuerySystemInfo(&driver_info);

      for (const auto& tpu_chip : driver_info.tpu_chip()) {
        std::tuple<int, int, int> coord{tpu_chip.chip_coord().x(),
                                        tpu_chip.chip_coord().y(),
                                        tpu_chip.chip_coord().z()};
        // We only want to add chips that we have not seen before if we are in
        // a TPU pod slice, or we are only seeing local cores (e.g. we are
        // connected to individual TPUs or we are in a test environment).
        if (!processed_chips.contains(coord) ||
            driver_info.core_count() == driver_info.local_core_size()) {
          *(pod_info_.add_tpu_chip()) = tpu_chip;
          processed_chips.insert(coord);
        }
      }

      *(pod_info_.mutable_cpu()) = driver_info.cpu();
    }

    // Process all the unique chips that we have seen.
    int core_count = 0;
    for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) {
      for (auto& tpu_core : *tpu_chip.mutable_core()) {
        int current_core = tpu_core.id();
        if (in_local_core_environment) {
          current_core = core_count;
        }

        core_to_driver_.insert(
            {current_core, drivers_[tpu_chip.host_id()].get()});
        core_to_driver_id_.insert({current_core, tpu_chip.host_id()});
        core_to_driver_core_.insert({current_core, tpu_core.id()});

        tpu_core.set_id(current_core);
        tpu_core.set_core_on_host_index(current_core);
        *(pod_info_.add_local_core()) = tpu_core;

        core_count++;
      }

      // We are setting host_id to zero because we want this to look like one
      // host with many cores from the perspective of tpu_client.cc.
      tpu_chip.set_host_id(0);
    }

    pod_info_.set_chip_count(pod_info_.tpu_chip_size());
    pod_info_.set_core_count(pod_info_.local_core_size());

    // We want this to look like one host with many TPU chips/cores connected.
    pod_info_.set_host_count(1);
    pod_info_.set_host_id(0);
  }

  ~PodTpuDriver() override {
    // TODO(frankchn): Unload all handles, and wait for all events to finish.
  }

  void QuerySystemInfo(SystemInfo* system_info) override {
    *system_info = pod_info_;
  }

  xla::Status Reset() override {
    for (auto& driver : drivers_) {
      TF_RETURN_IF_ERROR(driver.second->Reset());
    }
    return xla::Status::OK();
  }

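  // Each of the methods below assigns a fresh operation id, derives the set
  // of operation ids it depends on (explicit wait_for events plus any handles
  // it touches), and defers the real per-host driver call via
  // ScheduleRequest. The returned Pod* handle or event is keyed by the
  // operation id rather than by a live driver object.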
  std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, int64_t num_bytes,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    ScheduleRequest(
        operation_id,
        [this, core_id, region, num_bytes,
         operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          underlying_buffers_.insert(
              {operation_id,
               core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id],
                                                  region, num_bytes, {})});
          return underlying_buffers_[operation_id]->OnReady();
        },
        deps);

    return absl::make_unique<PodBufferHandle>(this, operation_id, num_bytes,
                                              absl::nullopt, core_id);
  }

  std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    ScheduleRequest(
        operation_id,
        [this, core_id, region, shape,
         operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          underlying_buffers_.insert(
              {operation_id,
               core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id],
                                                  region, shape, {})});
          return underlying_buffers_[operation_id]->OnReady();
        },
        deps);

    return absl::make_unique<PodBufferHandle>(
        this, operation_id, ComputeBytesFromShape(shape), shape, core_id);
  }

  std::unique_ptr<BufferHandle> AllocateTuple(
      int32_t core_id, MemoryRegion region,
      absl::Span<BufferHandle* const> children,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    std::vector<int64_t> children_ids;
    for (int i = 0; i < children.size(); ++i) {
      auto child_op_id =
          static_cast<PodBufferHandle* const>(children[i])->operation_id();
      deps.insert(child_op_id);
      children_ids.push_back(child_op_id);
    }

    ScheduleRequest(
        operation_id,
        [this, core_id, region, children_ids,
         operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_)
            -> std::shared_ptr<Event> {
          std::vector<BufferHandle*> child_buffers;
          child_buffers.reserve(children_ids.size());
          for (int i = 0; i < children_ids.size(); ++i) {
            CHECK_EXISTS_OR_RETURN(underlying_buffers_, children_ids[i],
                                   operation_id);
            child_buffers.push_back(underlying_buffers_[children_ids[i]].get());
          }

          underlying_buffers_.insert(
              {operation_id,
               core_to_driver_[core_id]->AllocateTuple(
                   core_to_driver_core_[core_id], region, child_buffers, {})});
          return underlying_buffers_[operation_id]->OnReady();
        },
        deps);

    return absl::make_unique<PodBufferHandle>(this, operation_id, 0,
                                              absl::nullopt, core_id);
  }

  std::shared_ptr<Event> Deallocate(
      std::unique_ptr<BufferHandle> handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    auto op_id = static_cast<PodBufferHandle*>(handle.get())->operation_id();
    auto core_id = static_cast<PodBufferHandle*>(handle.get())->core_id();
    deps.insert(op_id);

    ScheduleRequest(
        operation_id,
        [this, operation_id, op_id,
         core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
          CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);

          auto buf_iter = underlying_buffers_.find(op_id);
          auto underlying_hn = std::move(buf_iter->second);
          underlying_buffers_.erase(buf_iter);

          return core_to_driver_[core_id]->Deallocate(std::move(underlying_hn),
                                                      {});
        },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> TransferToDevice(
      const void* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    auto op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
    auto core_id = static_cast<PodBufferHandle*>(dst)->core_id();
    deps.insert(op_id);

    ScheduleRequest(
        operation_id,
        [this, src, operation_id, op_id,
         core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
          CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);

          auto buf_iter = underlying_buffers_.find(op_id);
          return core_to_driver_[core_id]->TransferToDevice(
              src, buf_iter->second.get(), {});
        },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> TransferFromDevice(
      const BufferHandle* src, void* dst,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    auto op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
    auto core_id = static_cast<const PodBufferHandle*>(src)->core_id();
    deps.insert(op_id);

    ScheduleRequest(
        operation_id,
        [this, dst, operation_id, op_id,
         core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
          CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);
          auto buf_iter = underlying_buffers_.find(op_id);
          return core_to_driver_[core_id]->TransferFromDevice(
              buf_iter->second.get(), dst, {});
        },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

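  // A device-to-device transfer can be issued directly only when source and
  // destination live on the same host; otherwise the data is staged through a
  // temporary buffer in this process.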
  std::shared_ptr<Event> TransferFromDeviceToDevice(
      const BufferHandle* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) override {
    auto src_core_id = static_cast<const PodBufferHandle*>(src)->core_id();
    auto dst_core_id = static_cast<PodBufferHandle*>(dst)->core_id();

    auto src_driver_id = core_to_driver_id_[src_core_id];
    auto dst_driver_id = core_to_driver_id_[dst_core_id];

    if (src_driver_id == dst_driver_id) {
      // Both buffers are on the same host, so we can schedule the transfer
      // normally.
      int64_t operation_id = GetOperationId();
      auto deps = GetDependencyOperationIds(wait_for);

      auto src_op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
      auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
      deps.insert(src_op_id);
      deps.insert(dst_op_id);

      ScheduleRequest(
          operation_id,
          [this, operation_id, src_op_id, dst_op_id, dst_core_id]()
              TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
            CHECK_EXISTS_OR_RETURN(underlying_buffers_, src_op_id,
                                   operation_id);
            CHECK_EXISTS_OR_RETURN(underlying_buffers_, dst_op_id,
                                   operation_id);

            auto src_iter = underlying_buffers_.find(src_op_id);
            auto dst_iter = underlying_buffers_.find(dst_op_id);
            return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice(
                src_iter->second.get(), dst_iter->second.get(), {});
          },
          deps);
      return std::make_shared<PodEvent>(this, operation_id);
    } else {
      // src and dst are on different hosts, so we have to bounce the data
      // through a host buffer in this process.
      auto dst_size = dst->size_in_bytes();
      char* host_buf = new char[dst_size];

      auto src_event = TransferFromDevice(src, host_buf, wait_for);
      auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()});
      // src_event is captured to keep it alive until the chained transfer
      // completes; the staging buffer is freed once the upload is done.
      dst_event->AddCallback(
          [src_event, host_buf](xla::Status status) { delete[] host_buf; });
      return dst_event;
    }
  }

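  // Compilation is replicated across every per-host driver; the returned
  // handle becomes ready only when all hosts have finished compiling, which
  // is what the CombinedEvent below expresses.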
  std::unique_ptr<CompiledProgramHandle> CompileProgram(
      const xla::HloProto& source, int32_t num_replicas,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    ScheduleRequest(
        operation_id,
        [this, operation_id, source,
         num_replicas]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          auto cph_iterator =
              underlying_cph_
                  .insert(
                      {operation_id,
                       std::vector<std::unique_ptr<CompiledProgramHandle>>()})
                  .first;

          std::vector<std::shared_ptr<Event>> collected_events;
          for (int i = 0; i < drivers_.size(); ++i) {
            auto current_cph =
                drivers_[i]->CompileProgram(source, num_replicas, {});
            cph_iterator->second.push_back(std::move(current_cph));
            collected_events.push_back(cph_iterator->second[i]->OnReady());
          }
          return std::make_shared<CombinedEvent>(this, operation_id,
                                                 collected_events);
        },
        deps);

    return absl::make_unique<PodCompiledProgramHandle>(this, operation_id);
  }

  std::unique_ptr<LoadedProgramHandle> LoadProgram(
      int32_t core_id, const CompiledProgramHandle* handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    auto cph_op_id =
        static_cast<const PodCompiledProgramHandle*>(handle)->operation_id();
    deps.insert(cph_op_id);

    ScheduleRequest(
        operation_id,
        [this, operation_id, cph_op_id,
         core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
          CHECK_EXISTS_OR_RETURN(underlying_cph_, cph_op_id, operation_id);
          auto cph_iter = underlying_cph_.find(cph_op_id);

          underlying_lph_.insert(
              {operation_id,
               core_to_driver_[core_id]->LoadProgram(
                   core_to_driver_core_[core_id],
                   cph_iter->second[core_to_driver_id_[core_id]].get(), {})});

          return underlying_lph_[operation_id]->OnReady();
        },
        deps);

    return absl::make_unique<PodLoadedProgramHandle>(this, operation_id,
                                                     core_id);
  }

  std::shared_ptr<Event> UnloadProgram(
      std::unique_ptr<LoadedProgramHandle> handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    auto op_id =
        static_cast<PodLoadedProgramHandle*>(handle.get())->operation_id();
    auto core_id =
        static_cast<PodLoadedProgramHandle*>(handle.get())->core_id();
    deps.insert(op_id);

    ScheduleRequest(
        operation_id,
        [this, operation_id, op_id,
         core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
          CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id);
          auto lph_iter = underlying_lph_.find(op_id);
          auto event = core_to_driver_[core_id]->UnloadProgram(
              std::move(lph_iter->second), {});
          underlying_lph_.erase(lph_iter);

          return event;
        },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> ExecuteProgram(
      LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
      absl::Span<BufferHandle* const> outputs,
      const xla::DeviceAssignmentProto& device_assignment,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();

    auto deps = GetDependencyOperationIds(wait_for);

    auto op_id = static_cast<PodLoadedProgramHandle*>(program)->operation_id();
    auto core_id = static_cast<PodLoadedProgramHandle*>(program)->core_id();
    deps.insert(op_id);

    std::vector<int64_t> input_op_ids;
    std::vector<int64_t> output_op_ids;

    for (auto* input : inputs) {
      auto input_dep =
          static_cast<PodBufferHandle* const>(input)->operation_id();
      input_op_ids.push_back(input_dep);
      deps.insert(input_dep);
    }
    for (auto* output : outputs) {
      auto output_dep =
          static_cast<PodBufferHandle* const>(output)->operation_id();
      output_op_ids.push_back(output_dep);
      deps.insert(output_dep);
    }

    ScheduleRequest(
        operation_id,
        [this, operation_id, core_id, op_id, input_op_ids, output_op_ids,
         device_assignment]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_)
            -> std::shared_ptr<Event> {
          std::vector<BufferHandle*> underlying_inputs;
          std::vector<BufferHandle*> underlying_outputs;

          underlying_inputs.reserve(input_op_ids.size());
          for (auto input_op_id : input_op_ids) {
            CHECK_EXISTS_OR_RETURN(underlying_buffers_, input_op_id,
                                   operation_id);
            underlying_inputs.push_back(underlying_buffers_[input_op_id].get());
          }
          underlying_outputs.reserve(output_op_ids.size());
          for (auto output_op_id : output_op_ids) {
            CHECK_EXISTS_OR_RETURN(underlying_buffers_, output_op_id,
                                   operation_id);
            underlying_outputs.push_back(
                underlying_buffers_[output_op_id].get());
          }

          CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id);
          LoadedProgramHandle* handle = underlying_lph_[op_id].get();
          return core_to_driver_[core_id]->ExecuteProgram(
              handle, underlying_inputs, underlying_outputs, device_assignment,
              {});
        },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::unique_ptr<TpuLinearizer> GetLinearizer() override {
    return drivers_[0]->GetLinearizer();
  }

  // Helper methods for Event scheduling

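  // Blocks until the operation identified by event_id completes or the
  // timeout expires. A completed operation is erased from events_, so a
  // missing entry means success unless an abnormal status was recorded.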
  absl::optional<Status> WaitForEvent(int64_t event_id, absl::Duration duration)
      TF_LOCKS_EXCLUDED(mu_) {
    std::shared_ptr<Event> underlying_event;

    {
      absl::MutexLock l(&mu_);
      auto event = events_.find(event_id);

      if (event == events_.end()) {
        auto event_status = abnormal_event_status_.find(event_id);
        if (event_status == abnormal_event_status_.end()) {
          return Status::OK();
        } else {
          return event_status->second;
        }
      }

      auto done = [this, event_id]() {
        mu_.AssertHeld();
        // The event was either completed and erased from the map or we have
        // an underlying event available to us.
        return events_.count(event_id) == 0 ||
               (events_[event_id]->underlying_event != nullptr &&
                events_[event_id]->underlying_event.use_count() != 0);
      };

      auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration);
      if (!status) {
        return absl::nullopt;
      }

      if (events_.count(event_id) > 0) {
        underlying_event = events_[event_id]->underlying_event;
      } else {
        underlying_event = nullptr;
      }
    }

    // Wait for the underlying event without holding on to mu_, or else
    // incoming events will not be processed.
    if (underlying_event != nullptr) {
      return underlying_event->AwaitWithTimeout(duration);
    } else {
      absl::MutexLock l(&mu_);
      auto event_status = abnormal_event_status_.find(event_id);
      if (event_status == abnormal_event_status_.end()) {
        return Status::OK();
      } else {
        return event_status->second;
      }
    }
  }

  void AddCallbackForEvent(int64_t event_id, std::function<void(Status)> fn)
      TF_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock l(&mu_);
    auto event = events_.find(event_id);

    if (event == events_.end()) {
      auto event_status = abnormal_event_status_.find(event_id);
      if (event_status == abnormal_event_status_.end()) {
        fn(Status::OK());
      } else {
        fn(event_status->second);
      }
    } else {
      if (event->second->underlying_event != nullptr &&
          event->second->underlying_event.use_count() != 0) {
        event->second->underlying_event->AddCallback(fn);
      } else {
        event->second->callbacks.push_back(std::move(fn));
      }
    }
  }

  xla::Status GetCompiledProgramShape(int64_t op_id,
                                      xla::ProgramShapeProto* program_shape)
      TF_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock l(&mu_);

    auto done = [this, op_id]() {
      mu_.AssertHeld();
      return underlying_cph_.contains(op_id);
    };
    mu_.Await(absl::Condition(&done));

    return underlying_cph_[op_id][0]->program_shape(program_shape);
  }

 private:
  const TpuDriverConfig& config_;
  std::shared_ptr<::grpc::ChannelCredentials> creds_;

  absl::flat_hash_map<int32_t, std::unique_ptr<TpuDriver>> drivers_;
  absl::flat_hash_map<int32_t, int32_t> core_to_driver_id_;
  absl::flat_hash_map<int32_t, TpuDriver*> core_to_driver_;
  absl::flat_hash_map<int32_t, int32_t> core_to_driver_core_;
  SystemInfo pod_info_;

  absl::Mutex mu_;

  absl::flat_hash_map<int64_t, std::unique_ptr<BufferHandle>>
      underlying_buffers_ ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t,
                      std::vector<std::unique_ptr<CompiledProgramHandle>>>
      underlying_cph_ ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t, std::unique_ptr<LoadedProgramHandle>>
      underlying_lph_ ABSL_GUARDED_BY(mu_);

  absl::btree_map<int64_t, std::unique_ptr<EventInFlight>> events_
      ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t, Status> abnormal_event_status_
      ABSL_GUARDED_BY(mu_);

  std::atomic<int64_t> operation_id_counter_{0};

  WorkerThread event_thread_;

  int64_t GetOperationId() { return operation_id_counter_++; }

  absl::flat_hash_set<int64_t> GetDependencyOperationIds(
      absl::Span<Event* const> wait_for) {
    absl::flat_hash_set<int64_t> deps;
    for (auto* event : wait_for) {
      deps.insert(static_cast<PodEvent* const>(event)->operation_id());
    }
    return deps;
  }

  // EventCompleted is executed on the event_thread_ worker thread. We want
  // to propagate the fact that the event is completed to any subsequent
  // events that might depend on this event.
  void EventCompleted(int64_t event_id, Status status) TF_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock l(&mu_);

    if (!status.ok()) abnormal_event_status_.insert({event_id, status});
    auto curr_event = events_.find(event_id);

    DCHECK(curr_event->second->callbacks.empty());
    DCHECK(curr_event->second->incomplete_deps.empty());

    for (auto& event : events_) {
      event.second->incomplete_deps.erase(event_id);
      // Fire create_fn for any event whose dependencies are now all complete
      // (incomplete_deps.empty()) and whose op has not been issued yet
      // (create_fn != nullptr). create_fn issues the op; we then attach any
      // pending callbacks to the resulting event and set create_fn to
      // nullptr to mark it as already called.
      if (event.second->incomplete_deps.empty() &&
          event.second->create_fn != nullptr) {
        event.second->underlying_event = event.second->create_fn();
        for (auto& fn : event.second->callbacks) {
          event.second->underlying_event->AddCallback(std::move(fn));
        }
        event.second->callbacks.clear();
        event.second->create_fn = nullptr;
      }
    }

    // We erase the current event to signal that it has finished.
    events_.erase(curr_event);
  }

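  // Runs fn immediately if all of its dependencies have already completed;
  // otherwise parks it in events_ as a create_fn to be invoked by
  // EventCompleted once the last outstanding dependency finishes. In both
  // cases EventCompleted is eventually scheduled on event_thread_ so that
  // dependents of operation_id are unblocked in turn.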
  void ScheduleRequest(int64_t operation_id,
                       std::function<std::shared_ptr<Event>(void)> fn,
                       const absl::flat_hash_set<int64_t>& deps)
      TF_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock l(&mu_);
    absl::flat_hash_set<int64_t> incomplete_deps;

    auto event =
        events_.insert({operation_id, absl::make_unique<EventInFlight>()})
            .first;
    for (const auto& dep : deps) {
      if (events_.count(dep) > 0) incomplete_deps.insert(dep);
    }

    if (incomplete_deps.empty()) {
      // All dependencies have been fulfilled, so we execute the request
      // immediately and add a callback to inform the event thread when it is
      // done.
      event->second->create_fn = nullptr;
      event->second->underlying_event = fn();
      event->second->underlying_event->AddCallback(
          [this, operation_id](Status status) {
            event_thread_.Schedule([this, operation_id, status]() {
              EventCompleted(operation_id, status);
            });
          });
    } else {
      // Some dependencies are not yet fulfilled. We attach the request to the
      // event, and EventCompleted will execute it on the event thread once
      // all of its dependencies have been fulfilled.
      event->second->create_fn = std::move(fn);
      event->second->incomplete_deps = std::move(incomplete_deps);
      event->second->callbacks.push_back([this, operation_id](Status status) {
        event_thread_.Schedule([this, operation_id, status]() {
          EventCompleted(operation_id, status);
        });
      });
    }
  }

  template <typename T>
  std::shared_ptr<Event> CheckHandleExists(
      absl::flat_hash_map<int64_t, T>& container, int64_t target_op_id,
      int64_t operation_id) {
    if (container.count(target_op_id) == 0) {
      return std::make_shared<ErrorEvent>(
          this, operation_id,
          tensorflow::errors::InvalidArgument("Handle ", target_op_id,
                                              " does not exist."));
    }
    return nullptr;
  }
};

xla::Status PodEvent::Await() {
  return driver_->WaitForEvent(operation_id_, absl::InfiniteDuration()).value();
}

absl::optional<xla::Status> PodEvent::AwaitWithTimeout(
    absl::Duration duration) {
  return driver_->WaitForEvent(operation_id_, duration);
}

void PodEvent::AddCallback(std::function<void(Status)> callback) {
  driver_->AddCallbackForEvent(operation_id_, std::move(callback));
}

xla::StatusOr<std::unique_ptr<TpuDriver>> CreatePodTpuDriver(
    const TpuDriverConfig& config,
    std::shared_ptr<::grpc::ChannelCredentials> creds) {
  return std::unique_ptr<TpuDriver>(new PodTpuDriver(config, creds));
}

xla::Status PodCompiledProgramHandle::program_shape(
    xla::ProgramShapeProto* program_shape) {
  return driver_->GetCompiledProgramShape(operation_id(), program_shape);
}

}  // namespace

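// Registers this driver for worker strings of the form
// "grpc+pod://host1:port,host2:port,...". The prefix is stripped and each
// comma-separated address is dialed with plain gRPC; for example, a
// two-worker pod slice might (hypothetically) be addressed as
// "grpc+pod://10.0.0.1:8470,10.0.0.2:8470".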
REGISTER_TPU_DRIVER(kPodTpuDriverPrefix,
                    [](const TpuDriverConfig& config)
                        -> xla::StatusOr<std::unique_ptr<TpuDriver>> {
                      return CreatePodTpuDriver(
                          config,
                          ::grpc::InsecureChannelCredentials());  // NOLINT
                    });

}  // namespace tpu_driver