/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace xla {

TpuDevice::TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
                     int core_on_chip)
    : id_(id),
      task_id_(task_id),
      coords_(coords),
      core_on_chip_(core_on_chip) {}

std::string TpuDevice::DebugString() const {
  return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), task_id(),
                         coords_[0], coords_[1], coords_[2], core_on_chip_);
}

xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
  std::vector<std::shared_ptr<PjRtDevice>> devices;
  for (const auto& chip : system_info.tpu_chip()) {
    auto& coord = chip.chip_coord();
    std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
    int task_id = chip.host_id();
    for (const auto& core : chip.core()) {
      auto device = std::make_shared<TpuDevice>(
          core.id(), task_id, coords_array, core.core_on_chip_index());
      devices.push_back(device);
    }
  }

  return devices;
}

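// Opens a TPU driver for `worker`, queries the system topology, and wraps
// both in a PyTpuClient. The worker string is forwarded unchanged to the
// driver registry.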
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
    const std::string& worker) {
  tpu_driver::TpuDriverConfig driver_config;
  driver_config.set_worker(worker);
  auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
  if (!client_status.ok()) {
    return client_status.status();
  }

  auto client = client_status.ConsumeValueOrDie();

  tpu_driver::SystemInfo system_info;
  client->QuerySystemInfo(&system_info);

  TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<PjRtDevice>> devices,
                      TpuDevice::GetTpuDevices(system_info));

  return std::make_shared<PyTpuClient>(kTpuPlatform, std::move(client),
                                       std::move(devices),
                                       system_info.host_id());
}

PyTpuClient::PyTpuClient(std::string platform_name,
                         std::unique_ptr<tpu_driver::TpuDriver> driver,
                         std::vector<std::shared_ptr<PjRtDevice>> devices,
                         int task_id)
    : platform_name_(std::move(platform_name)),
      driver_(std::move(driver)),
      devices_(std::move(devices)),
      task_id_(task_id) {
  for (const std::shared_ptr<PjRtDevice>& device : devices_) {
    CHECK(id_to_device_.insert({device->id(), device}).second)
        << "Duplicate device id: " << device->id();

    if (device->task_id() == task_id_) {
      LOG(INFO) << "Detected local device, host id: " << task_id_
                << ", device id: " << device->id();
      local_devices_.push_back(device);
    } else {
      VLOG(2) << "Other device, device id: " << device->id();
    }
  }
  CHECK_GE(local_devices_.size(), 1);
  LOG(INFO) << "Creating " << local_devices_.size() << " TPU device(s).";

  for (int idx = 0; idx < local_devices_.size(); ++idx) {
    CHECK(local_devices_[idx] != nullptr) << idx;
  }

  // TODO(frankchn): Check if the thread pool size needs to be adjusted
  // (perhaps something like min(cores, devices_.size()) might be appropriate
  // depending on the number of devices).
  pool_ = std::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "PyTpuClient", devices_.size());
}

Status PyTpuClient::TransferToInfeed(const LiteralSlice& literal,
                                     int device_id) {
  return Unimplemented("Infeed not implemented.");
}

StatusOr<Literal> PyTpuClient::TransferFromOutfeed(const Shape& shape,
                                                   int device_id) {
  return Unimplemented("Outfeed not implemented.");
}

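// Prefers an assignment that maps replica i to the i-th local device when
// enough local devices are available; otherwise defers to the global
// ComputationPlacer.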
StatusOr<DeviceAssignment> PyTpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  if (num_partitions > 1) {
    return InvalidArgument("Num partitions greater than 1 is not supported.");
  }
  if (num_replicas * num_partitions <= local_device_count()) {
    DeviceAssignment assignment(num_replicas, num_partitions);
    for (int replica = 0; replica < num_replicas; ++replica) {
      for (int partition = 0; partition < num_partitions; ++partition) {
        assignment(replica, partition) = local_devices_[replica]->id();
      }
    }
    return assignment;
  }

  // Fall back to the default global device assignment if we can't run locally.
  xla::ComputationPlacer placer;
  return placer.AssignDevices(num_replicas, num_partitions);
}

Status PyTpuClient::CheckDeviceId(int device_id,
                                  absl::string_view caller_name) {
  if (device_id < 0 || device_id >= device_count()) {
    return InvalidArgument("%s got bad device_id: %d (num_devices=%d).",
                           caller_name, device_id, device_count());
  }
  return Status::OK();
}

static Status CheckDataType(xla::PrimitiveType dtype) {
  if (dtype == xla::PrimitiveType::F64 || dtype == xla::PrimitiveType::S64 ||
      dtype == xla::PrimitiveType::U64) {
    return InvalidArgument(
        "64-bit data types are not yet supported on the TPU driver API. "
        "Convert inputs to float32/int32 before using.");
  }
  return Status::OK();
}

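// Transfers one or more host literals into a freshly allocated device buffer.
// A non-tuple shape takes the single-transfer fast path below; a tuple shape
// uploads each leaf into its own child buffer and then assembles them with
// MakeTuple. In the single-leaf path, `leaves_references` is captured by a
// completion callback to keep the host memory alive until the transfer
// finishes.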
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
    std::vector<BorrowingLiteral> leaves, const Shape& tuple_shape,
    std::shared_ptr<void> leaves_references,
    std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals");
  VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString()
          << " device: " << device->DebugString();
  TF_RETURN_IF_ERROR(
      client->CheckDeviceId(device->id(), "PyTpuBuffer::FromLiterals"));
  tpu_driver::TpuDriver* driver = client->driver();

  if (!tuple_shape.IsTuple()) {
    TF_RET_CHECK(leaves.size() == 1);
    return CreateBuffer(
        tuple_shape,
        [driver, &leaves, &tuple_shape,
         leaves_references](tpu_driver::BufferHandle* handle) {
          auto event =
              driver->TransferToDevice(leaves[0].untyped_data(), handle, {});
          event->AddCallback([leaves_references](Status) {});
          return event;
        },
        std::move(client), std::move(device));
  }

  std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
  child_buffers.reserve(leaves.size());
  std::vector<PyTpuBuffer*> child_buffer_ptrs;
  child_buffer_ptrs.reserve(leaves.size());

  auto it_leaf = leaves.begin();
  for (const ShapeUtil::IndexedShape& indexed_shape :
       ShapeUtil::GetLeafShapes(tuple_shape)) {
    TF_RET_CHECK(it_leaf != leaves.end());
    auto& leaf = *it_leaf;
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PyTpuBuffer> child_buffer,
        CreateBuffer(
            indexed_shape.shape,
            [driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) {
              return driver->TransferToDevice(leaf.untyped_data(), handle, {});
            },
            client, device));
    child_buffer_ptrs.push_back(child_buffer.get());
    child_buffers.push_back(std::move(child_buffer));
    ++it_leaf;
  }
  TF_RET_CHECK(it_leaf == leaves.end());

  // `MakeTuple` will extract and make the tuple buffer hold onto the
  // `device_buffer_` contained in each `child_buffer`, so it's safe for
  // `child_buffers` to get destroyed before this call returns.
  return MakeTuple(std::move(child_buffer_ptrs), std::move(client),
                   std::move(device));
}

/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
    absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
    std::shared_ptr<PjRtDevice> device) {
  std::vector<Shape> child_shapes;
  std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
  std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
  std::vector<std::shared_ptr<tpu_driver::Event>> child_events;

  for (const auto& child_buffer : buffers) {
    child_shapes.push_back(child_buffer->on_host_shape());
    std::shared_ptr<TpuSharedBuffer> child_device_buffer =
        child_buffer->DeviceBuffer();
    // Merge all definition events from all children, so that anyone using this
    // tuple must wait for all its children to finish receiving transfers. This
    // works recursively up a nested tuple tree as well.
    for (std::shared_ptr<tpu_driver::Event> child_event :
         child_device_buffer->wait_for_use) {
      child_events.push_back(std::move(child_event));
    }
    child_handle_ptrs.push_back(child_device_buffer->handle.get());
    child_device_buffers.push_back(std::move(child_device_buffer));
  }

  Shape tuple_shape = ShapeUtil::MakeTupleShape(child_shapes);
  std::unique_ptr<tpu_driver::BufferHandle> tuple_handle =
      client->driver()->AllocateTuple(
          device->id(), tpu_driver::MemoryRegion::HBM, child_handle_ptrs, {});
  auto tuple_device_buffer = std::make_shared<TpuSharedBuffer>(
      client->driver(), std::move(tuple_handle), std::move(child_events),
      std::move(device));
  return absl::make_unique<PyTpuBuffer>(
      tuple_shape, std::move(tuple_device_buffer),
      std::move(child_device_buffers), std::move(client));
}

PyTpuBuffer::PyTpuBuffer(
    Shape on_host_shape, std::shared_ptr<TpuSharedBuffer> device_buffer,
    std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
    std::shared_ptr<PyTpuClient> client)
    : client_(std::move(client)),
      on_host_shape_(std::move(on_host_shape)),
      device_(device_buffer->device),
      device_buffer_(std::move(device_buffer)),
      child_buffers_(std::move(child_buffers)) {}

void PyTpuBuffer::Delete() {
  absl::MutexLock lock(&mu_);
  device_buffer_ = nullptr;
  child_buffers_.clear();
  host_value_ = nullptr;
}

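// Starts an asynchronous device-to-host copy into `host_value_`. One
// TransferFromDevice is issued per leaf buffer; `pending_ops` counts the
// outstanding transfers and `ready` is notified once the last callback runs.
// Calling this again while a copy is pending (or already complete) is a no-op.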
Status PyTpuBuffer::CopyToHostAsync() {
  std::vector<std::shared_ptr<tpu_driver::Event>> transfer_events;
  std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();

  {
    absl::MutexLock lock(&mu_);
    if (!device_buffer_) {
      return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
    }

    if (host_value_) {
      // The host value has already been requested or is available.
      return Status::OK();
    }

    host_value->value = std::make_shared<Literal>(on_host_shape_);
    host_value->pending_ops = std::max(1ul, child_buffers_.size());
    host_value_ = host_value;

    std::vector<tpu_driver::Event*> events;
    for (const auto& e : device_buffer_->wait_for_use) {
      events.push_back(e.get());
    }

    VLOG(1) << "CopyToHostAsync: host shape: "
            << host_value->value->shape().DebugString();

    if (!on_host_shape_.IsTuple()) {
      CHECK(child_buffers_.empty());
      transfer_events.push_back(client_->driver()->TransferFromDevice(
          device_buffer_->handle.get(), host_value->value->untyped_data(),
          events));
    } else {
      for (int i = 0; i < child_buffers_.size(); ++i) {
        auto& c = child_buffers_[i];
        transfer_events.push_back(client_->driver()->TransferFromDevice(
            c->handle.get(),
            host_value->value->untyped_data(xla::ShapeIndex({i})), events));
      }
    }
  }

  for (auto& t : transfer_events) {
    t->AddCallback([host_value](const xla::Status& status) {
      VLOG(1) << "Device to host transfer finished.";
      if (!status.ok()) {
        host_value->status =
            Status(static_cast<tensorflow::error::Code>(status.code()),
                   status.error_message());
      }

      absl::MutexLock m(&host_value->mutex);
      --host_value->pending_ops;
      if (host_value->pending_ops == 0) {
        VLOG(1) << "Host value done: " << host_value->status;
        host_value->ready.Notify();
      }
    });
  }
  return Status::OK();
}

StatusOr<std::shared_ptr<Literal>> PyTpuBuffer::ToLiteral() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::ToLiteral");
  TF_RETURN_IF_ERROR(CopyToHostAsync());

  mu_.Lock();
  std::shared_ptr<HostValue> host_value = host_value_;
  mu_.Unlock();

  VLOG(1) << "Waiting for device to host transfer " << host_value.get();
  host_value->ready.WaitForNotification();
  VLOG(1) << "Host copy finished, status:" << host_value->status;
  TF_RETURN_IF_ERROR(host_value->status);

  return host_value->value;
}

std::shared_ptr<TpuSharedBuffer> PyTpuBuffer::DeviceBuffer() const {
  absl::MutexLock lock(&mu_);
  return device_buffer_;
}

StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
PyTpuBuffer::DestructureTuple() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::DestructureTuple");
  if (!on_host_shape_.IsTuple()) {
    return InvalidArgument(
        "Attempted to destructure a PyTpuBuffer that did not have a tuple "
        "shape; shape: %s.",
        ShapeUtil::HumanString(on_host_shape_));
  }
  if (DeviceBuffer() == nullptr) {
    return InvalidArgument("Attempted to destructure a deleted buffer.");
  }

  absl::MutexLock lock(&mu_);
  int num_children = ShapeUtil::TupleElementCount(on_host_shape_);
  std::vector<std::unique_ptr<PyTpuBuffer>> results;
  results.reserve(num_children);
  for (int i = 0; i < num_children; ++i) {
    results.push_back(absl::make_unique<PyTpuBuffer>(
        on_host_shape_.tuple_shapes(i), child_buffers_.at(i),
        std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_));
  }
  return results;
}

StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CopyToDevice(
    std::shared_ptr<PjRtDevice> dst_device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice");
  if (on_host_shape_.IsTuple()) {
    return Unimplemented("CopyToDevice for tuples is not supported.");
  }

  std::shared_ptr<TpuSharedBuffer> src_device_buffer = DeviceBuffer();
  if (dst_device->id() == device_->id()) {
    return absl::make_unique<PyTpuBuffer>(
        on_host_shape_, src_device_buffer,
        std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_);
  }

  tpu_driver::TpuDriver* driver = client_->driver();
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PyTpuBuffer> dst_buffer,
      CreateBuffer(
          on_host_shape_,
          [driver, src_device_buffer](tpu_driver::BufferHandle* dst_handle) {
            std::vector<tpu_driver::Event*> src_wait_for_use;
            for (auto& event : src_device_buffer->wait_for_use) {
              src_wait_for_use.push_back(event.get());
            }
            return driver->TransferFromDeviceToDevice(
                src_device_buffer->handle.get(), dst_handle, src_wait_for_use);
          },
          client_, std::move(dst_device)));
  // TODO(jiawenhao): This may be too pessimistic: it prevents future readers
  // from reading `src_device_buffer` until the device-to-device copy is done.
  // Should this go into a new `TpuSharedBuffer::wait_for_dealloc` field?
  auto& wait_for_use = dst_buffer->DeviceBuffer()->wait_for_use;
  src_device_buffer->wait_for_use.insert(src_device_buffer->wait_for_use.end(),
                                         wait_for_use.begin(),
                                         wait_for_use.end());
  return dst_buffer;
}

Status PyTpuBuffer::BlockHostUntilReady() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::BlockHostUntilReady");
  std::shared_ptr<TpuSharedBuffer> device_buffer = DeviceBuffer();
  if (!device_buffer) {
    return InvalidArgument(
        "BlockHostUntilReady() called on deleted or donated buffer");
  }
  return device_buffer->handle->OnReady()->Await();
}

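// AllocateBuffer creates uninitialized device memory for `shape`, recursing
// over tuple elements and assembling the children with MakeTuple. CreateBuffer
// below performs the single, non-tuple allocation and, when an initializer is
// supplied, records its completion event in `wait_for_use`.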
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
    const Shape& shape, std::shared_ptr<PyTpuClient> client,
    std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer");
  VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString()
          << " device: " << device->DebugString();

  if (!shape.IsTuple()) {
    return CreateBuffer(shape, absl::nullopt, std::move(client),
                        std::move(device));
  }

  std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
  child_buffers.reserve(shape.tuple_shapes().size());
  std::vector<PyTpuBuffer*> child_buffer_ptrs;
  child_buffer_ptrs.reserve(shape.tuple_shapes().size());

  for (const auto& child_shape : shape.tuple_shapes()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PyTpuBuffer> child_buffer,
                        AllocateBuffer(child_shape, client, device));
    child_buffer_ptrs.push_back(child_buffer.get());
    child_buffers.push_back(std::move(child_buffer));
  }

  // `MakeTuple` will extract and make the tuple buffer hold onto the
  // `device_buffer_` contained in each `child_buffer`, so it's safe for
  // `child_buffers` to get destroyed before this call returns.
  return PyTpuBuffer::MakeTuple(child_buffer_ptrs, std::move(client),
                                std::move(device));
}

/*static*/
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
    const Shape& non_tuple_shape, absl::optional<BufferInitializer> initializer,
    std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer");
  VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: "
          << non_tuple_shape.DebugString()
          << " device: " << device->DebugString();
  TF_RET_CHECK(!non_tuple_shape.IsTuple());
  TF_RETURN_IF_ERROR(CheckDataType(non_tuple_shape.element_type()));

  std::unique_ptr<tpu_driver::BufferHandle> handle =
      client->driver()->Allocate(device->id(), tpu_driver::MemoryRegion::HBM,
                                 non_tuple_shape.ToProto(), {});

  // If this buffer needs to be initialized, anyone using this buffer must wait
  // for the initialization event in `wait_for_use` to finish first.
  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
  if (initializer.has_value()) {
    std::shared_ptr<tpu_driver::Event> init = initializer.value()(handle.get());
    wait_for_use.push_back(std::move(init));
  }
  auto device_buffer = std::make_shared<TpuSharedBuffer>(
      client->driver(), std::move(handle), std::move(wait_for_use),
      std::move(device));

  return absl::make_unique<PyTpuBuffer>(
      non_tuple_shape, std::move(device_buffer),
      std::vector<std::shared_ptr<TpuSharedBuffer>>(), client);
}

static std::shared_ptr<PjRtDevice> LookupDevice(const PyTpuClient& client,
                                                int device_id) {
  auto it = client.id_to_device().find(device_id);
  CHECK(it != client.id_to_device().end())
      << "Unknown device id: " << device_id;
  return it->second;
}

PyTpuExecutable::PyTpuExecutable(
    std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
    DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
    xla::Shape result_shape, bool tuple_arguments)
    : client_(std::move(client)),
      device_assignment_(std::move(device_assignment)),
      tuple_arguments_(tuple_arguments),
      result_shape_(std::move(result_shape)) {
  VLOG(1) << "DeviceAssignment: " << device_assignment_.ToString();
  const int num_replicas = device_assignment_.replica_count();
  const int num_partitions = device_assignment_.computation_count();
  CHECK_EQ(num_partitions, 1) << "partition count > 1 is not supported.";
  for (int replica = 0; replica < num_replicas; ++replica) {
    for (int partition = 0; partition < num_partitions; ++partition) {
      int device_id = device_assignment_(replica, partition);
      std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
      if (device->task_id() != client_->task_id()) {
        VLOG(3) << "Non-local device: " << device_id;
        continue;
      }
      // TODO(b/147895917): support replica + partition natively.
      CHECK(executables_.find(replica) == executables_.end())
          << "Inserting duplicate replica: " << replica;
      executables_[replica] =
          client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
      local_logical_device_ids_.emplace_back(replica, partition);
      local_devices_.push_back(device);
    }
  }
  CHECK_GE(local_devices_.size(), 1);
  CHECK_LE(executables_.size(), client_->device_count());
  CHECK_LE(local_devices_.size(), client_->local_device_count())
      << "Inconsistent local device count.";
}

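// Runs the loaded program for one (replica, partition) pair on its assigned
// local device: allocates the output buffer, collects this core's input
// handles plus every pending definition event across all cores' arguments,
// and enqueues ExecuteProgram. The returned event completes when execution
// finishes.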
PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
    absl::Span<const std::vector<PyTpuBuffer*>> all_core_arguments,
    absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
    int partition, const RunId& run_id) {
  const int device_id = device_assignment_(replica, partition);
  std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
  CHECK_EQ(device->task_id(), client_->task_id());
  tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
  VLOG(3) << "Replica " << replica << ", partition " << partition
          << " mapped to device id for execution: " << device_id;

  std::unique_ptr<::xla::PyTpuBuffer> output_buffer =
      ::xla::PyTpuBuffer::AllocateBuffer(result_shape_, client_,
                                         std::move(device))
          .ValueOrDie();
  VLOG(1) << "Created output buffer: " << result_shape_.DebugString();

  std::vector<tpu_driver::BufferHandle*> inputs;
  std::vector<tpu_driver::Event*> ready_to_execute;

  std::shared_ptr<tpu_driver::Event> output_buffer_ready =
      output_buffer->DeviceBuffer()->handle->OnReady();

  ready_to_execute.push_back(output_buffer_ready.get());

  for (auto* input_handle : this_core_arguments) {
    inputs.push_back(input_handle->DeviceBuffer()->handle.get());
  }

  for (const auto& core_args : all_core_arguments) {
    for (const auto* handle : core_args) {
      for (const auto& pending_event : handle->DeviceBuffer()->wait_for_use) {
        ready_to_execute.push_back(pending_event.get());
      }
    }
  }

  xla::DeviceAssignmentProto device_assignment;
  CHECK(device_assignment_.Serialize(&device_assignment).ok());
  std::shared_ptr<tpu_driver::Event> on_execute_finished =
      client_->driver()->ExecuteProgram(
          executables_.find(replica)->second.get(), inputs,
          {output_buffer->DeviceBuffer()->handle.get()}, device_assignment,
          {ready_to_execute});

  return {std::move(output_buffer), std::move(on_execute_finished)};
}

// Delay before warning about a slow execute call.
static const absl::Duration kWarnExecutionDelay = absl::Seconds(10);

// Delay before terminating a stalled execute call.
static const absl::Duration kMaxExecutionDelay = absl::Minutes(60);

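// Blocks until `event` completes, waking up every kWarnExecutionDelay to log a
// warning about slow execution, and giving up with DeadlineExceeded once
// kMaxExecutionDelay has elapsed.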
Status WaitForExecuteEvent(tpu_driver::Event* event) {
  absl::optional<Status> opt_status;
  auto start_time = absl::Now();

  while (!opt_status.has_value() &&
         absl::Now() - start_time < kMaxExecutionDelay) {
    opt_status = event->AwaitWithTimeout(kWarnExecutionDelay);
    if (!opt_status.has_value()) {
      LOG(WARNING)
          << "TPU Execute is taking a long time. This might be due to a "
             "deadlock between multiple TPU cores or a very slow program.";
    }
  }

  if (!opt_status.has_value()) {
    return tensorflow::errors::DeadlineExceeded(
        absl::StrFormat("TPU program took more than %d seconds to complete.",
                        absl::ToInt64Seconds(kMaxExecutionDelay)));
  }

  return opt_status.value();
}

StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
    absl::Span<PyTpuBuffer* const> argument_handles) {
  if (num_replicas() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d replicas using Execute().",
        num_replicas());
  }
  if (num_partitions() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d partitions using Execute().",
        num_partitions());
  }

  std::vector<PyTpuBuffer*> all_core_arguments;

  std::unique_ptr<PyTpuBuffer> tupled_arguments;
  if (tuple_arguments_) {
    TF_ASSIGN_OR_RETURN(tupled_arguments,
                        PyTpuBuffer::MakeTuple(argument_handles, client_,
                                               local_devices_.front()));
    all_core_arguments = {tupled_arguments.get()};
  } else {
    all_core_arguments = std::vector<PyTpuBuffer*>(argument_handles.begin(),
                                                   argument_handles.end());
  }
  ExecuteResult result =
      ExecuteHelper(absl::MakeSpan(&all_core_arguments, 1), argument_handles,
                    /*replica=*/0, /*partition=*/0, RunId());

  Status status = WaitForExecuteEvent(result.on_execute_finished.get());

  if (!status.ok()) {
    LOG(ERROR) << "Failed to execute program: " << status;
    return status;
  }

  if (result.buffer->on_host_shape().IsTuple()) {
    return result.buffer->DestructureTuple();
  } else {
    std::vector<std::unique_ptr<PyTpuBuffer>> outputs;
    outputs.push_back(std::move(result.buffer));
    return outputs;
  }
}

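// Launches one ExecuteHelper per local device on the client's thread pool,
// waits for all launches via a semaphore, and then blocks on each completion
// event. If any replica fails, the first failure status observed is returned.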
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteOnLocalDevices(
    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
  tensorflow::profiler::TraceMe traceme(
      "PyTpuExecutable::ExecuteOnLocalDevices");

  const int num_local_devices = local_devices_.size();

  if (argument_handles.size() != num_local_devices) {
    return InvalidArgument(
        "Attempted to execute with %d argument lists when local device "
        "count is %d (total replica count: %d, partition count: %d).",
        argument_handles.size(), num_local_devices, num_replicas(),
        num_partitions());
  }

  VLOG(1) << "Executing computation; num_replicas=" << num_replicas()
          << " num_partitions=" << num_partitions()
          << " num_local_devices=" << num_local_devices;

  std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
  std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
  if (tuple_arguments_) {
    tupled_arguments.resize(argument_handles.size());
    tupled_argument_pointers.resize(argument_handles.size());
    for (int i = 0; i < num_local_devices; ++i) {
      TF_ASSIGN_OR_RETURN(tupled_arguments[i],
                          PyTpuBuffer::MakeTuple(argument_handles[i], client_,
                                                 local_devices_.at(i)));
      tupled_argument_pointers[i] = {tupled_arguments[i].get()};
    }
    argument_handles = tupled_argument_pointers;
  }

  absl::Mutex results_lock;
  std::vector<ExecuteResult> results(num_local_devices);

  auto* thread_pool = client_->GetThreadPool();

  int failed = 0;
  Status first_failure_status;

  xla::Semaphore execute_semaphore(0);
  for (int i = 0; i < num_local_devices; ++i) {
    // We are scheduling Execute on a thread pool as ExecuteHelper can take a
    // long time and we want all cores to be scheduled in parallel.
    thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
                           &execute_semaphore]() {
      const int replica = local_logical_device_ids_[i].first;
      const int partition = local_logical_device_ids_[i].second;
      RunId run_id;
      auto result = ExecuteHelper(argument_handles, argument_handles[i],
                                  replica, partition, run_id);
      results[i] = std::move(result);
      execute_semaphore.Release(1);
    });
  }

  execute_semaphore.Acquire(num_local_devices);

  for (int i = 0; i < num_local_devices; ++i) {
    auto s = WaitForExecuteEvent(results[i].on_execute_finished.get());
    if (!s.ok()) {
      if (failed == 0) {
        first_failure_status = Status(
            static_cast<tensorflow::error::Code>(s.code()), s.error_message());
      }
      ++failed;
    }
  }
  if (failed > 0) {
    return first_failure_status;
  }
  VLOG(1) << "Replicated execution complete.";

  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> wrapped_results(
      num_local_devices);
  for (int i = 0; i < num_local_devices; ++i) {
    if (results[i].buffer->on_host_shape().IsTuple()) {
      TF_ASSIGN_OR_RETURN(wrapped_results[i],
                          results[i].buffer->DestructureTuple());
    } else {
      wrapped_results[i].push_back(std::move(results[i].buffer));
    }
  }
  return wrapped_results;
}

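// Accepts arguments indexed as [argument][computation], transposes them to
// the per-device [computation][argument] layout expected by
// ExecuteOnLocalDevices, and transposes the outputs back to
// [output][computation].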
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteShardedOnLocalDevices(
    absl::Span<const std::vector<PyTpuBuffer*>> args) {
  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> output_buffers;
  TF_RET_CHECK(!args.empty());
  int num_computations = args.front().size();
  for (const auto& arg : args) {
    if (arg.size() != num_computations) {
      return xla::InvalidArgument(
          "Expected args to execute_sharded_on_local_devices to have %d "
          "shards, got: [%s]",
          num_computations,
          absl::StrJoin(
              args, ", ",
              [](std::string* out, const std::vector<PyTpuBuffer*>& arg) {
                out->append(std::to_string(arg.size()));
              }));
    }
  }
  std::vector<std::vector<PyTpuBuffer*>> arg_buffers(num_computations);
  for (int computation = 0; computation < num_computations; ++computation) {
    arg_buffers[computation].resize(args.size());
    absl::c_transform(
        args, arg_buffers[computation].begin(),
        [&](const std::vector<PyTpuBuffer*>& arg) { return arg[computation]; });
  }
  TF_ASSIGN_OR_RETURN(output_buffers, ExecuteOnLocalDevices(arg_buffers));
  int num_output_buffers = output_buffers[0].size();
  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> outputs;
  outputs.resize(num_output_buffers);
  for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
    outputs[buffer_id].reserve(num_computations);
    for (int computation = 0; computation < num_computations; ++computation) {
      outputs[buffer_id].push_back(
          std::move(output_buffers[computation][buffer_id]));
    }
  }
  return outputs;
}

/*static*/ StatusOr<std::unique_ptr<PyTpuExecutable>> PyTpuExecutable::Compile(
    const XlaComputation& computation,
    absl::optional<std::vector<Shape>> argument_layouts,
    const ExecutableBuildOptions* build_options,
    std::shared_ptr<PyTpuClient> client, bool tuple_arguments) {
  tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile");

  VLOG(1) << "Compile: "
          << computation.GetProgramShape().ValueOrDie().DebugString();

  // TODO(power) -- handle argument layouts
  // TODO(power) -- handle build options
  ExecutableBuildOptions options;
  if (build_options != nullptr) {
    options = *build_options;
  }
  absl::optional<xla::DeviceAssignment> device_assignment;

  // For the pod use case, the device assignment's replica count may be
  // greater than the number of available local devices; where applicable,
  // the non-local devices must be filtered out from participating in the
  // local computation.
  if (options.has_device_assignment()) {
    if (options.device_assignment().replica_count() != options.num_replicas()) {
      return InvalidArgument(
          "Mismatched number of replicas for device "
          "assignment and computation (%d vs %d).",
          options.device_assignment().replica_count(), options.num_replicas());
    } else if (options.device_assignment().computation_count() != 1) {
      return Unimplemented(
          "Only 1 computation per replica supported, %d requested.",
          options.device_assignment().computation_count());
    }
    device_assignment = options.device_assignment();
  } else {
    TF_ASSIGN_OR_RETURN(device_assignment,
                        client->GetDefaultDeviceAssignment(
                            options.num_replicas(), options.num_partitions()));
  }
  CHECK_GE(options.num_replicas(), 1);
  CHECK_EQ(options.num_replicas(), device_assignment->replica_count());
  CHECK(!argument_layouts.has_value());

  // TODO(henrytan): an area for optimization with less buffer copy.
  xla::HloProto hlo_proto;
  *hlo_proto.mutable_hlo_module() = computation.proto();

  // TODO(henrytan): in the future, we want to consider argument Layout
  // information e.g. for linearization.
  std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program =
      client->driver()->CompileProgram(hlo_proto, options.num_replicas(), {});

  ::xla::Shape result_layout;
  if (options.result_layout()) {
    result_layout = *options.result_layout();
  } else {
    xla::ProgramShapeProto program_shape_proto;
    auto fetch_metadata_status =
        compiled_program->program_shape(&program_shape_proto);

    if (!fetch_metadata_status.ok()) {
      return Status(
          static_cast<tensorflow::error::Code>(fetch_metadata_status.code()),
          fetch_metadata_status.error_message());
    }
    result_layout = ::xla::Shape(program_shape_proto.result());
  }
  VLOG(1) << "Got result shape: " << result_layout.DebugString();

  return absl::make_unique<PyTpuExecutable>(
      std::move(compiled_program), std::move(*device_assignment),
      std::move(client), std::move(result_layout), tuple_arguments);
}
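
// A minimal end-to-end sketch of how these pieces fit together (illustrative
// only; the worker address and `computation`/`argument_buffers` are
// hypothetical stand-ins, and error handling is elided):
//
//   auto client = PyTpuClient::Get("grpc://<tpu-worker>").ValueOrDie();
//   auto executable =
//       PyTpuExecutable::Compile(computation,
//                                /*argument_layouts=*/absl::nullopt,
//                                /*build_options=*/nullptr, client,
//                                /*tuple_arguments=*/false)
//           .ValueOrDie();
//   auto outputs = executable->Execute(argument_buffers).ValueOrDie();
//   auto literal = outputs[0]->ToLiteral().ValueOrDie();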

}  // namespace xla