/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace xla {

TpuDevice::TpuDevice(int id, int process_index,
                     const std::array<int, 3>& coords, int core_on_chip)
    : id_(id),
      process_index_(process_index),
      coords_(coords),
      core_on_chip_(core_on_chip) {
  debug_string_ =
      absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id_, process_index_,
                      coords_[0], coords_[1], coords_[2], core_on_chip_);
  to_string_ = absl::StrFormat(
      "TpuDevice(id=%i, process_index=%i, coords=(%s), core_on_chip=%i)", id_,
      process_index_, absl::StrJoin(coords_, ","), core_on_chip_);
}

absl::string_view TpuDevice::DebugString() const { return debug_string_; }

absl::string_view TpuDevice::ToString() const { return to_string_; }

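// Enumerates every core of every chip reported by the TPU driver and wraps
// each one as a TpuDevice; the device id is the driver-assigned core id.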
xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
  std::vector<std::shared_ptr<PjRtDevice>> devices;
  for (const auto& chip : system_info.tpu_chip()) {
    auto& coord = chip.chip_coord();
    std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
    int process_index = chip.host_id();
    for (const auto& core : chip.core()) {
      auto device = std::make_shared<TpuDevice>(
          core.id(), process_index, coords_array, core.core_on_chip_index());
      devices.push_back(device);
    }
  }

  return devices;
}

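// Opens a TPU driver connection to `worker`, queries the system topology, and
// builds a PyTpuClient that owns one PjRtDevice per TPU core in the system.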
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
    const std::string& worker) {
  tpu_driver::TpuDriverConfig driver_config;
  driver_config.set_worker(worker);
  auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
  if (!client_status.ok()) {
    return client_status.status();
  }

  auto client = std::move(client_status).value();

  tpu_driver::SystemInfo system_info;
  client->QuerySystemInfo(&system_info);

  TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<PjRtDevice>> devices,
                      TpuDevice::GetTpuDevices(system_info));

  return std::make_shared<PyTpuClient>(TpuPlatform(), std::move(client),
                                       std::move(devices),
                                       system_info.host_id());
}

PyTpuClient::PyTpuClient(std::string platform_name,
                         std::unique_ptr<tpu_driver::TpuDriver> driver,
                         std::vector<std::shared_ptr<PjRtDevice>> devices,
                         int process_index)
    : platform_name_(std::move(platform_name)),
      platform_version_("tpu_driver (deprecated)"),
      driver_(std::move(driver)),
      devices_(std::move(devices)),
      process_index_(process_index) {
  for (const std::shared_ptr<PjRtDevice>& device : devices_) {
    tensorflow::down_cast<TpuDevice*>(device.get())->set_tpu_client(this);
    CHECK(id_to_device_.insert({device->id(), device}).second)
        << "Duplicate device id: " << device->id();

    if (device->process_index() == process_index_) {
      LOG(INFO) << "Detected local device, host id: " << process_index_
                << ". device id: " << device->id();
      local_devices_.push_back(device);
    } else {
      VLOG(2) << "Other devices, device id: " << device->id();
    }
  }
  CHECK_GE(local_devices_.size(), 1);
  LOG(INFO) << "Creating " << local_devices_.size() << " TPU device(s).";

  for (int idx = 0; idx < local_devices_.size(); ++idx) {
    CHECK(local_devices_[idx] != nullptr) << idx;
  }

  // TODO(frankchn): Check if thread pool size needs to be adjusted (perhaps
  // something like min(cores, devices_.size()) might be appropriate depending
  // on the number of devices).
  pool_ = std::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "PyTpuClient", devices_.size());
}

Status PyTpuClient::TransferToInfeed(const LiteralSlice& literal,
                                     int device_id) {
  return Unimplemented("Infeed not implemented.");
}

StatusOr<Literal> PyTpuClient::TransferFromOutfeed(const Shape& shape,
                                                   int device_id) {
  return Unimplemented("Outfeed not implemented.");
}

StatusOr<DeviceAssignment> PyTpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  if (num_partitions > 1) {
    return InvalidArgument("Num partitions greater than 1 is not supported.");
  }
  if (num_replicas * num_partitions <= local_device_count()) {
    DeviceAssignment assignment(num_replicas, num_partitions);
    for (int replica = 0; replica < num_replicas; ++replica) {
      for (int partition = 0; partition < num_partitions; ++partition) {
        assignment(replica, partition) = local_devices_[replica]->id();
      }
    }
    return assignment;
  }

  // Fall back to the default global device assignment if we can't run locally.
  xla::ComputationPlacer placer;
  return placer.AssignDevices(num_replicas, num_partitions);
}

Status PyTpuClient::CheckDeviceId(int device_id,
                                  absl::string_view caller_name) {
  if (device_id < 0 || device_id >= device_count()) {
    return InvalidArgument("%s got bad device_id: %d (num_devices=%d).",
                           caller_name, device_id, device_count());
  }
  return OkStatus();
}

static Status CheckDataType(xla::PrimitiveType dtype) {
  if (dtype == xla::PrimitiveType::F64 || dtype == xla::PrimitiveType::S64 ||
      dtype == xla::PrimitiveType::U64) {
    return InvalidArgument(
        "64-bit data types are not yet supported on the TPU driver API. "
        "Convert inputs to float32/int32_t before using.");
  }
  return OkStatus();
}

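// Builds a device buffer from host literals. A non-tuple shape issues a
// single host-to-device transfer; a tuple shape transfers each leaf into its
// own child buffer and then assembles the children with MakeTuple.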
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
    std::vector<BorrowingLiteral> leaves, const Shape& tuple_shape,
    std::shared_ptr<void> leaves_references,
    std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals");
  VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString()
          << " device: " << device->DebugString();
  TF_RETURN_IF_ERROR(
      client->CheckDeviceId(device->id(), "PyTpuBuffer::FromLiterals"));
  tpu_driver::TpuDriver* driver = client->driver();

  if (!tuple_shape.IsTuple()) {
    TF_RET_CHECK(leaves.size() == 1);
    return CreateBuffer(
        tuple_shape,
        [driver, &leaves, &tuple_shape,
         leaves_references](tpu_driver::BufferHandle* handle) {
          auto event =
              driver->TransferToDevice(leaves[0].untyped_data(), handle, {});
          event->AddCallback([leaves_references](Status) {});
          return event;
        },
        std::move(client), std::move(device));
  }

  std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
  child_buffers.reserve(leaves.size());
  std::vector<PyTpuBuffer*> child_buffer_ptrs;
  child_buffer_ptrs.reserve(leaves.size());

  auto it_leaf = leaves.begin();
  for (const ShapeUtil::IndexedShape& indexed_shape :
       ShapeUtil::GetLeafShapes(tuple_shape)) {
    TF_RET_CHECK(it_leaf != leaves.end());
    auto& leaf = *it_leaf;
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PyTpuBuffer> child_buffer,
        CreateBuffer(
            indexed_shape.shape,
            [driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) {
              return driver->TransferToDevice(leaf.untyped_data(), handle, {});
            },
            client, device));
    child_buffer_ptrs.push_back(child_buffer.get());
    child_buffers.push_back(std::move(child_buffer));
    ++it_leaf;
  }
  TF_RET_CHECK(it_leaf == leaves.end());

  // `MakeTuple` will extract and make the tuple buffer hold onto the
  // `device_buffer_` contained in each `child_buffer`, so it's safe for
  // `child_buffers` to get destroyed before this call returns.
  return MakeTuple(std::move(child_buffer_ptrs), std::move(client),
                   std::move(device));
}

/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
    absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
    std::shared_ptr<PjRtDevice> device) {
  std::vector<Shape> child_shapes;
  std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
  std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
  std::vector<std::shared_ptr<tpu_driver::Event>> child_events;

  for (const auto& child_buffer : buffers) {
    child_shapes.push_back(child_buffer->on_host_shape());
    std::shared_ptr<TpuSharedBuffer> child_device_buffer =
        child_buffer->DeviceBuffer();
    // Merge all definition events from all children, so that anyone using this
    // tuple must wait for all its children to finish receiving transfers. This
    // works recursively up a nested tuple tree as well.
    for (std::shared_ptr<tpu_driver::Event> child_event :
         child_device_buffer->wait_for_use) {
      child_events.push_back(std::move(child_event));
    }
    child_handle_ptrs.push_back(child_device_buffer->handle.get());
    child_device_buffers.push_back(std::move(child_device_buffer));
  }

  Shape tuple_shape = ShapeUtil::MakeTupleShape(child_shapes);
  std::unique_ptr<tpu_driver::BufferHandle> tuple_handle =
      client->driver()->AllocateTuple(
          device->id(), tpu_driver::MemoryRegion::HBM, child_handle_ptrs, {});
  auto tuple_device_buffer = std::make_shared<TpuSharedBuffer>(
      client->driver(), std::move(tuple_handle), std::move(child_events),
      std::move(device));
  return std::make_unique<PyTpuBuffer>(
      tuple_shape, std::move(tuple_device_buffer),
      std::move(child_device_buffers), std::move(client));
}

PyTpuBuffer::PyTpuBuffer(
    Shape on_host_shape, std::shared_ptr<TpuSharedBuffer> device_buffer,
    std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
    std::shared_ptr<PyTpuClient> client)
    : client_(std::move(client)),
      on_host_shape_(std::move(on_host_shape)),
      device_(device_buffer->device),
      device_buffer_(std::move(device_buffer)),
      child_buffers_(std::move(child_buffers)) {}

void PyTpuBuffer::Delete() {
  absl::MutexLock lock(&mu_);
  device_buffer_ = nullptr;
  child_buffers_.clear();
  host_value_ = nullptr;
}

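// Starts asynchronous device-to-host transfers for this buffer (one per leaf
// for tuples). Each transfer callback decrements `pending_ops`, and `ready`
// is notified once every leaf has arrived on the host.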
Status PyTpuBuffer::CopyToHostAsync() {
  std::vector<std::shared_ptr<tpu_driver::Event>> transfer_events;
  std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();

  {
    absl::MutexLock lock(&mu_);
    if (!device_buffer_) {
      return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
    }

    if (host_value_) {
      // The host value has already been requested or is available.
      return OkStatus();
    }

    host_value->value = std::make_shared<Literal>(on_host_shape_);
    host_value->pending_ops = std::max(1ul, child_buffers_.size());
    host_value_ = host_value;

    std::vector<tpu_driver::Event*> events;
    for (const auto& e : device_buffer_->wait_for_use) {
      events.push_back(e.get());
    }

    VLOG(1) << "CopyToHostAsync:: host shape: "
            << host_value->value->shape().DebugString();

    if (!on_host_shape_.IsTuple()) {
      CHECK(child_buffers_.empty());
      transfer_events.push_back(client_->driver()->TransferFromDevice(
          device_buffer_->handle.get(), host_value->value->untyped_data(),
          events));
    } else {
      for (int i = 0; i < child_buffers_.size(); ++i) {
        auto& c = child_buffers_[i];
        transfer_events.push_back(client_->driver()->TransferFromDevice(
            c->handle.get(),
            host_value->value->untyped_data(xla::ShapeIndex({i})), events));
      }
    }
  }

  for (auto& t : transfer_events) {
    t->AddCallback([host_value](const xla::Status& status) {
      VLOG(1) << "Device to host transfer finished.";
      if (!status.ok()) {
        host_value->status =
            Status(static_cast<tensorflow::error::Code>(status.code()),
                   status.error_message());
      }

      absl::MutexLock m(&host_value->mutex);
      --host_value->pending_ops;
      if (host_value->pending_ops == 0) {
        VLOG(1) << "Host value done: " << host_value->status;
        host_value->ready.Notify();
      }
    });
  }
  return OkStatus();
}

StatusOr<std::shared_ptr<Literal>> PyTpuBuffer::ToLiteral() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::ToLiteral");
  TF_RETURN_IF_ERROR(CopyToHostAsync());

  mu_.Lock();
  std::shared_ptr<HostValue> host_value = host_value_;
  mu_.Unlock();

  VLOG(1) << "Waiting for device to host transfer " << host_value.get();
  host_value->ready.WaitForNotification();
  VLOG(1) << "Host copy finished, status:" << host_value->status;
  TF_RETURN_IF_ERROR(host_value->status);

  return host_value->value;
}

std::shared_ptr<TpuSharedBuffer> PyTpuBuffer::DeviceBuffer() const {
  absl::MutexLock lock(&mu_);
  return device_buffer_;
}

StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
PyTpuBuffer::DestructureTuple() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::DestructureTuple");
  if (!on_host_shape_.IsTuple()) {
    return InvalidArgument(
        "Attempted to destructure a PyTpuBuffer that did not have a tuple "
        "shape; shape: %s.",
        ShapeUtil::HumanString(on_host_shape_));
  }
  if (DeviceBuffer() == nullptr) {
    return InvalidArgument("Attempted to destructure a deleted buffer.");
  }

  absl::MutexLock lock(&mu_);
  int num_children = ShapeUtil::TupleElementCount(on_host_shape_);
  std::vector<std::unique_ptr<PyTpuBuffer>> results;
  results.reserve(num_children);
  for (int i = 0; i < num_children; ++i) {
    results.push_back(std::make_unique<PyTpuBuffer>(
        on_host_shape_.tuple_shapes(i), child_buffers_.at(i),
        std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_));
  }
  return results;
}

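// Copies a non-tuple buffer to `dst_device`. A copy onto the same device
// simply aliases the existing device buffer; otherwise a device-to-device
// transfer is issued that waits on the source buffer's pending events.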
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CopyToDevice(
    std::shared_ptr<PjRtDevice> dst_device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice");
  if (on_host_shape_.IsTuple()) {
    return Unimplemented("CopyToDevice for tuples is not supported.");
  }

  std::shared_ptr<TpuSharedBuffer> src_device_buffer = DeviceBuffer();
  if (dst_device->id() == device_->id()) {
    return std::make_unique<PyTpuBuffer>(
        on_host_shape_, src_device_buffer,
        std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_);
  }

  tpu_driver::TpuDriver* driver = client_->driver();
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PyTpuBuffer> dst_buffer,
      CreateBuffer(
          on_host_shape_,
          [driver, src_device_buffer](tpu_driver::BufferHandle* dst_handle) {
            std::vector<tpu_driver::Event*> src_wait_for_use;
            for (auto& event : src_device_buffer->wait_for_use) {
              src_wait_for_use.push_back(event.get());
            }
            return driver->TransferFromDeviceToDevice(
                src_device_buffer->handle.get(), dst_handle, src_wait_for_use);
          },
          client_, std::move(dst_device)));
  // TODO(jiawenhao): This may be too pessimistic: it prevents future readers
  // from reading `src_device_buffer` until the device-to-device copy is done.
  // Should this go into a new `TpuSharedBuffer::wait_for_dealloc` field?
  auto& wait_for_use = dst_buffer->DeviceBuffer()->wait_for_use;
  src_device_buffer->wait_for_use.insert(src_device_buffer->wait_for_use.end(),
                                         wait_for_use.begin(),
                                         wait_for_use.end());
  return dst_buffer;
}

Status PyTpuBuffer::BlockHostUntilReady() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::BlockHostUntilReady");
  std::shared_ptr<TpuSharedBuffer> device_buffer = DeviceBuffer();
  if (!device_buffer) {
    return InvalidArgument(
        "BlockHostUntilReady() called on deleted or donated buffer");
  }
  return device_buffer->handle->OnReady()->Await();
}

/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
    const Shape& shape, std::shared_ptr<PyTpuClient> client,
    std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer");
  VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString()
          << " device: " << device->DebugString();

  if (!shape.IsTuple()) {
    return CreateBuffer(shape, std::nullopt, std::move(client),
                        std::move(device));
  }

  std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
  child_buffers.reserve(shape.tuple_shapes().size());
  std::vector<PyTpuBuffer*> child_buffer_ptrs;
  child_buffer_ptrs.reserve(shape.tuple_shapes().size());

  for (const auto& child_shape : shape.tuple_shapes()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PyTpuBuffer> child_buffer,
                        AllocateBuffer(child_shape, client, device));
    child_buffer_ptrs.push_back(child_buffer.get());
    child_buffers.push_back(std::move(child_buffer));
  }

  // `MakeTuple` will extract and make the tuple buffer hold onto the
  // `device_buffer_` contained in each `child_buffer`, so it's safe for
  // `child_buffers` to get destroyed before this call returns.
  return PyTpuBuffer::MakeTuple(child_buffer_ptrs, std::move(client),
                                std::move(device));
}

/*static*/
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
    const Shape& non_tuple_shape, std::optional<BufferInitializer> initializer,
    std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer");
  VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: "
          << non_tuple_shape.DebugString()
          << " device: " << device->DebugString();
  TF_RET_CHECK(!non_tuple_shape.IsTuple());
  TF_RETURN_IF_ERROR(CheckDataType(non_tuple_shape.element_type()));

  std::unique_ptr<tpu_driver::BufferHandle> handle =
      client->driver()->Allocate(device->id(), tpu_driver::MemoryRegion::HBM,
                                 non_tuple_shape.ToProto(), {});

  // If this buffer needs to be initialized, anyone using this buffer must wait
  // for the initialization event in `wait_for_use` to finish first.
  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
  if (initializer.has_value()) {
    std::shared_ptr<tpu_driver::Event> init = initializer.value()(handle.get());
    wait_for_use.push_back(std::move(init));
  }
  auto device_buffer = std::make_shared<TpuSharedBuffer>(
      client->driver(), std::move(handle), std::move(wait_for_use),
      std::move(device));

  return std::make_unique<PyTpuBuffer>(
      non_tuple_shape, std::move(device_buffer),
      std::vector<std::shared_ptr<TpuSharedBuffer>>(), client);
}

static std::shared_ptr<PjRtDevice> LookupDevice(const PyTpuClient& client,
                                                int device_id) {
  auto it = client.id_to_device().find(device_id);
  CHECK(it != client.id_to_device().end())
      << "Unknown device id: " << device_id;
  return it->second;
}

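// Loads the compiled program onto every local device named by the device
// assignment, keeping one loaded executable per replica; devices belonging to
// other hosts are skipped.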
PyTpuExecutable::PyTpuExecutable(
    std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
    DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
    xla::Shape result_shape, bool tuple_arguments)
    : client_(std::move(client)),
      device_assignment_(std::move(device_assignment)),
      tuple_arguments_(tuple_arguments),
      result_shape_(std::move(result_shape)) {
  VLOG(1) << "DeviceAssignment. " << device_assignment_.ToString();
  const int num_replicas = device_assignment_.replica_count();
  const int num_partitions = device_assignment_.computation_count();
  CHECK_EQ(num_partitions, 1) << "partition count > 1 is not supported.";
  for (int replica = 0; replica < num_replicas; ++replica) {
    for (int partition = 0; partition < num_partitions; ++partition) {
      int device_id = device_assignment_(replica, partition);
      std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
      if (device->process_index() != client_->process_index()) {
        VLOG(3) << "Non-local device: " << device_id;
        continue;
      }
      // TODO(b/147895917): support replica + partition natively.
      CHECK(executables_.find(replica) == executables_.end())
          << "Inserting duplicate replica:" << replica;
      executables_[replica] =
          client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
      local_logical_device_ids_.emplace_back(replica, partition);
      local_devices_.push_back(device);
    }
  }
  CHECK_GE(local_devices_.size(), 1);
  CHECK_LE(executables_.size(), client_->device_count());
  CHECK_LE(local_devices_.size(), client_->local_device_count())
      << "Inconsistent local device count.";
}

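// Runs one replica of the program on the device that the device assignment
// maps (replica, partition) to: allocates an output buffer there, gathers the
// input handles plus every event the arguments still depend on, and launches
// the loaded executable, returning the output buffer and its completion event.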
PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
    absl::Span<const std::vector<PyTpuBuffer*>> maybe_tupled_args,
    absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
    int partition, const RunId& run_id) {
  const int device_id = device_assignment_(replica, partition);
  std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
  CHECK_EQ(device->process_index(), client_->process_index());
  tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
  VLOG(3) << "Replica " << replica << ", partition " << partition
          << " mapped to device id for execution: " << device_id;

  std::unique_ptr<::xla::PyTpuBuffer> output_buffer =
      ::xla::PyTpuBuffer::AllocateBuffer(result_shape_, client_,
                                         std::move(device))
          .ValueOrDie();
  VLOG(1) << "Created output buffer: " << result_shape_.DebugString();

  std::vector<tpu_driver::BufferHandle*> inputs;
  std::vector<tpu_driver::Event*> ready_to_execute;

  std::shared_ptr<tpu_driver::Event> output_buffer_ready =
      output_buffer->DeviceBuffer()->handle->OnReady();

  ready_to_execute.push_back(output_buffer_ready.get());

  for (auto* input_handle : this_core_arguments) {
    inputs.push_back(input_handle->DeviceBuffer()->handle.get());
  }

  for (const auto& core_args : maybe_tupled_args) {
    for (const auto* handle : core_args) {
      for (const auto& pending_event : handle->DeviceBuffer()->wait_for_use) {
        ready_to_execute.push_back(pending_event.get());
      }
    }
  }

  xla::DeviceAssignmentProto device_assignment;
  CHECK(device_assignment_.Serialize(&device_assignment).ok());
  std::shared_ptr<tpu_driver::Event> on_execute_finished =
      client_->driver()->ExecuteProgram(
          executables_.find(replica)->second.get(), inputs,
          {output_buffer->DeviceBuffer()->handle.get()}, device_assignment,
          {ready_to_execute});

  return {std::move(output_buffer), std::move(on_execute_finished)};
}

// Delay before warning about a slow execute call.
static const absl::Duration kWarnExecutionDelay = absl::Seconds(10);

// Delay before terminating a stalled execute call.
static const absl::Duration kMaxExecutionDelay = absl::Minutes(60);

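// Blocks until the execute event resolves, logging a warning every
// kWarnExecutionDelay and giving up with DeadlineExceeded after
// kMaxExecutionDelay.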
Status WaitForExecuteEvent(tpu_driver::Event* event) {
  std::optional<Status> opt_status;
  auto start_time = absl::Now();

  while (!opt_status.has_value() &&
         absl::Now() - start_time < kMaxExecutionDelay) {
    opt_status = event->AwaitWithTimeout(kWarnExecutionDelay);
    if (!opt_status.has_value()) {
      LOG(WARNING)
          << "TPU Execute is taking a long time. This might be due to a "
             "deadlock between multiple TPU cores or a very slow program.";
    }
  }

  if (!opt_status.has_value()) {
    return tensorflow::errors::DeadlineExceeded(
        absl::StrFormat("TPU program took more than %d seconds to complete.",
                        absl::ToInt64Seconds(kMaxExecutionDelay)));
  }

  return opt_status.value();
}

StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
    absl::Span<PyTpuBuffer* const> argument_handles) {
  if (num_replicas() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d replicas using Execute().",
        num_replicas());
  }
  if (num_partitions() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d partitions using Execute().",
        num_partitions());
  }

  std::vector<PyTpuBuffer*> maybe_tupled_args;
  std::unique_ptr<PyTpuBuffer> tupled_arguments;
  if (tuple_arguments_) {
    TF_ASSIGN_OR_RETURN(tupled_arguments,
                        PyTpuBuffer::MakeTuple(argument_handles, client_,
                                               local_devices_.front()));
    maybe_tupled_args = {tupled_arguments.get()};
  } else {
    maybe_tupled_args = std::vector<PyTpuBuffer*>(argument_handles.begin(),
                                                  argument_handles.end());
  }
  ExecuteResult result =
      ExecuteHelper(absl::MakeSpan(&maybe_tupled_args, 1), maybe_tupled_args,
                    /*replica=*/0, /*partition=*/0, RunId());

  Status status = WaitForExecuteEvent(result.on_execute_finished.get());

  if (!status.ok()) {
    LOG(ERROR) << "Failed to execute program: " << status;
    return status;
  }

  if (result.buffer->on_host_shape().IsTuple()) {
    return result.buffer->DestructureTuple();
  } else {
    std::vector<std::unique_ptr<PyTpuBuffer>> outputs;
    outputs.push_back(std::move(result.buffer));
    return outputs;
  }
}

StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteOnLocalDevices(
    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
  tensorflow::profiler::TraceMe traceme(
      "PyTpuExecutable::ExecuteOnLocalDevices");

  const int num_local_devices = local_devices_.size();

  if (argument_handles.size() != num_local_devices) {
    return InvalidArgument(
        "Attempted to execute with %d argument lists when local device "
        "count is %d (total replica count: %d, partition count: %d).",
        argument_handles.size(), num_local_devices, num_replicas(),
        num_partitions());
  }

  VLOG(1) << "Executing computation; num_replicas=" << num_replicas()
          << " num_partitions=" << num_partitions()
          << " num_local_devices=" << num_local_devices;

  std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
  std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
  if (tuple_arguments_) {
    tupled_arguments.resize(argument_handles.size());
    tupled_argument_pointers.resize(argument_handles.size());
    for (int i = 0; i < num_local_devices; ++i) {
      TF_ASSIGN_OR_RETURN(tupled_arguments[i],
                          PyTpuBuffer::MakeTuple(argument_handles[i], client_,
                                                 local_devices_.at(i)));
      tupled_argument_pointers[i] = {tupled_arguments[i].get()};
    }
    argument_handles = tupled_argument_pointers;
  }

  absl::Mutex results_lock;
  std::vector<ExecuteResult> results(num_local_devices);

  auto* thread_pool = client_->GetThreadPool();

  int failed = 0;
  Status first_failure_status;

  xla::Semaphore execute_semaphore(0);
  for (int i = 0; i < num_local_devices; ++i) {
    // We are scheduling Execute on a thread pool as ExecuteHelper can take a
    // long time and we want all cores to be scheduled in parallel.
    thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
                           &execute_semaphore]() {
      const int replica = local_logical_device_ids_[i].first;
      const int partition = local_logical_device_ids_[i].second;
      RunId run_id;
      auto result = ExecuteHelper(argument_handles, argument_handles[i],
                                  replica, partition, run_id);
      results[i] = std::move(result);
      execute_semaphore.Release(1);
    });
  }

  execute_semaphore.Acquire(num_local_devices);

  for (int i = 0; i < num_local_devices; ++i) {
    auto s = WaitForExecuteEvent(results[i].on_execute_finished.get());
    if (!s.ok()) {
      if (failed == 0) {
        first_failure_status = Status(
            static_cast<tensorflow::error::Code>(s.code()), s.error_message());
      }
      ++failed;
    }
  }
  if (failed > 0) {
    return first_failure_status;
  }
  VLOG(1) << "Replicated execution complete.";

  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> wrapped_results(
      num_local_devices);
  for (int i = 0; i < num_local_devices; ++i) {
    if (results[i].buffer->on_host_shape().IsTuple()) {
      TF_ASSIGN_OR_RETURN(wrapped_results[i],
                          results[i].buffer->DestructureTuple());
    } else {
      wrapped_results[i].push_back(std::move(results[i].buffer));
    }
  }
  return wrapped_results;
}

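// Accepts arguments laid out as [argument][computation], transposes them into
// the [computation][argument] layout that ExecuteOnLocalDevices expects, runs
// the program, and transposes the outputs back to [output][computation].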
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteShardedOnLocalDevices(
    absl::Span<const std::vector<PyTpuBuffer*>> args) {
  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> output_buffers;
  TF_RET_CHECK(!args.empty());
  int num_computations = args.front().size();
  for (const auto& arg : args) {
    if (arg.size() != num_computations) {
      return xla::InvalidArgument(
          "Expected args to execute_sharded_on_local_devices to have %d "
          "shards, got: [%s]",
          num_computations,
          absl::StrJoin(
              args, ", ",
              [](std::string* out, const std::vector<PyTpuBuffer*>& arg) {
                out->append(std::to_string(arg.size()));
              }));
    }
  }
  std::vector<std::vector<PyTpuBuffer*>> arg_buffers(num_computations);
  for (int computation = 0; computation < num_computations; ++computation) {
    arg_buffers[computation].resize(args.size());
    absl::c_transform(
        args, arg_buffers[computation].begin(),
        [&](const std::vector<PyTpuBuffer*>& arg) { return arg[computation]; });
  }
  TF_ASSIGN_OR_RETURN(output_buffers, ExecuteOnLocalDevices(arg_buffers));
  int num_output_buffers = output_buffers[0].size();
  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> outputs;
  outputs.resize(num_output_buffers);
  for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
    outputs[buffer_id].reserve(num_computations);
    for (int computation = 0; computation < num_computations; ++computation) {
      outputs[buffer_id].push_back(
          std::move(output_buffers[computation][buffer_id]));
    }
  }
  return outputs;
}

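// Compiles an XlaComputation through the TPU driver: the computation is
// serialized to an HloProto and handed to the driver, and the result shape is
// taken from the build options if provided, otherwise from the compiled
// program's reported program shape.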
/*static*/ StatusOr<std::unique_ptr<PyTpuExecutable>> PyTpuExecutable::Compile(
    const XlaComputation& computation,
    std::optional<std::vector<Shape>> argument_layouts,
    const ExecutableBuildOptions* build_options,
    std::shared_ptr<PyTpuClient> client, bool tuple_arguments) {
  tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile");

  VLOG(1) << "Compile: "
          << computation.GetProgramShape().ValueOrDie().DebugString();

  // TODO(power) -- handle argument layouts
  // TODO(power) -- handle build options
  ExecutableBuildOptions options;
  if (build_options != nullptr) {
    options = *build_options;
  }
  std::optional<xla::DeviceAssignment> device_assignment;

  // For the POD use case, device_assignment.num_replicas() may be greater
  // than the number of available local devices; where applicable, the
  // non-local devices must be filtered out of the local computation.
  if (options.has_device_assignment()) {
    if (options.device_assignment().replica_count() != options.num_replicas()) {
      return InvalidArgument(
          "Mismatched number of replicas for device "
          "assignment and computation (%d vs %d).",
          options.device_assignment().replica_count(), options.num_replicas());
    } else if (options.device_assignment().computation_count() != 1) {
      return Unimplemented(
          "Only 1 computation per replica supported, %d requested.",
          options.device_assignment().computation_count());
    }
    device_assignment = options.device_assignment();
  } else {
    TF_ASSIGN_OR_RETURN(device_assignment,
                        client->GetDefaultDeviceAssignment(
                            options.num_replicas(), options.num_partitions()));
  }
  CHECK_GE(options.num_replicas(), 1);
  CHECK_EQ(options.num_replicas(), device_assignment->replica_count());
  CHECK(!argument_layouts.has_value());

  // TODO(henrytan): an area for optimization with less buffer copy.
  xla::HloProto hlo_proto;
  *hlo_proto.mutable_hlo_module() = computation.proto();

  // TODO(henrytan): in the future, we want to consider argument Layout
  // information e.g. for linearization.
  std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program =
      client->driver()->CompileProgram(hlo_proto, options.num_replicas(), {});

  ::xla::Shape result_layout;
  if (options.result_layout()) {
    result_layout = *options.result_layout();
  } else {
    xla::ProgramShapeProto program_shape_proto;
    auto fetch_metadata_status =
        compiled_program->program_shape(&program_shape_proto);

    if (!fetch_metadata_status.ok()) {
      return Status(
          static_cast<tensorflow::error::Code>(fetch_metadata_status.code()),
          fetch_metadata_status.error_message());
    }
    result_layout = ::xla::Shape(program_shape_proto.result());
  }
  VLOG(1) << "Got result shape: " << result_layout.DebugString();

  return std::make_unique<PyTpuExecutable>(
      std::move(compiled_program), std::move(*device_assignment),
      std::move(client), std::move(result_layout), tuple_arguments);
}

/*static*/ StatusOr<std::unique_ptr<PyTpuExecutable>>
PyTpuExecutable::CompileMlir(mlir::ModuleOp module,
                             std::optional<std::vector<Shape>> argument_layouts,
                             const ExecutableBuildOptions* build_options,
                             std::shared_ptr<PyTpuClient> client,
                             bool tuple_arguments) {
  XlaComputation xla_computation;
  TF_RETURN_IF_ERROR(MlirToXlaComputation(module, xla_computation,
                                          /*use_tuple_args=*/tuple_arguments,
                                          /*return_tuple=*/false));
  return Compile(xla_computation, argument_layouts, build_options, client,
                 tuple_arguments);
}

}  // namespace xla