/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace xla {

TpuDevice::TpuDevice(int id, int process_index,
                     const std::array<int, 3>& coords, int core_on_chip)
    : id_(id),
      process_index_(process_index),
      coords_(coords),
      core_on_chip_(core_on_chip) {
  debug_string_ =
      absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id_, process_index_,
                      coords_[0], coords_[1], coords_[2], core_on_chip_);
  to_string_ = absl::StrFormat(
      "TpuDevice(id=%i, process_index=%i, coords=(%s), core_on_chip=%i)", id_,
      process_index_, absl::StrJoin(coords_, ","), core_on_chip_);
}

absl::string_view TpuDevice::DebugString() const { return debug_string_; }

absl::string_view TpuDevice::ToString() const { return to_string_; }

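// Enumerates the TPU devices reported by the driver: every core on every chip
// in `system_info` becomes one TpuDevice, tagged with the chip's (x, y, z)
// coordinates and the core's index on that chip.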
xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
  std::vector<std::shared_ptr<PjRtDevice>> devices;
  for (const auto& chip : system_info.tpu_chip()) {
    auto& coord = chip.chip_coord();
    std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
    int process_index = chip.host_id();
    for (const auto& core : chip.core()) {
      auto device = std::make_shared<TpuDevice>(
          core.id(), process_index, coords_array, core.core_on_chip_index());
      devices.push_back(device);
    }
  }

  return devices;
}

StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
    const std::string& worker) {
  tpu_driver::TpuDriverConfig driver_config;
  driver_config.set_worker(worker);
  auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
  if (!client_status.ok()) {
    return client_status.status();
  }

  auto client = std::move(client_status).value();

  tpu_driver::SystemInfo system_info;
  client->QuerySystemInfo(&system_info);

  TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<PjRtDevice>> devices,
                      TpuDevice::GetTpuDevices(system_info));

  return std::make_shared<PyTpuClient>(TpuPlatform(), std::move(client),
                                       std::move(devices),
                                       system_info.host_id());
}

PyTpuClient::PyTpuClient(std::string platform_name,
                         std::unique_ptr<tpu_driver::TpuDriver> driver,
                         std::vector<std::shared_ptr<PjRtDevice>> devices,
                         int process_index)
    : platform_name_(std::move(platform_name)),
      platform_version_("tpu_driver (deprecated)"),
      driver_(std::move(driver)),
      devices_(std::move(devices)),
      process_index_(process_index) {
  for (const std::shared_ptr<PjRtDevice>& device : devices_) {
    tensorflow::down_cast<TpuDevice*>(device.get())->set_tpu_client(this);
    CHECK(id_to_device_.insert({device->id(), device}).second)
        << "Duplicate device id: " << device->id();

    if (device->process_index() == process_index_) {
      LOG(INFO) << "Detected local device, host id: " << process_index_
                << ". device id: " << device->id();
      local_devices_.push_back(device);
    } else {
      VLOG(2) << "Other devices, device id: " << device->id();
    }
  }
  CHECK_GE(local_devices_.size(), 1);
  LOG(INFO) << "Creating " << local_devices_.size() << " TPU device(s).";

  for (int idx = 0; idx < local_devices_.size(); ++idx) {
    CHECK(local_devices_[idx] != nullptr) << idx;
  }

  // TODO(frankchn): Check if the thread pool size needs to be adjusted
  // (perhaps something like min(cores, devices_.size()) might be appropriate,
  // depending on the number of devices).
  pool_ = std::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "PyTpuClient", devices_.size());
}

Status PyTpuClient::TransferToInfeed(const LiteralSlice& literal,
                                     int device_id) {
  return Unimplemented("Infeed not implemented.");
}

StatusOr<Literal> PyTpuClient::TransferFromOutfeed(const Shape& shape,
                                                   int device_id) {
  return Unimplemented("Outfeed not implemented.");
}

StatusOr<DeviceAssignment> PyTpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  if (num_partitions > 1) {
    return InvalidArgument("Num partitions greater than 1 is not supported.");
  }
  if (num_replicas * num_partitions <= local_device_count()) {
    DeviceAssignment assignment(num_replicas, num_partitions);
    for (int replica = 0; replica < num_replicas; ++replica) {
      for (int partition = 0; partition < num_partitions; ++partition) {
        assignment(replica, partition) = local_devices_[replica]->id();
      }
    }
    return assignment;
  }

  // Fall back to the default global device assignment if we can't run locally.
  xla::ComputationPlacer placer;
  return placer.AssignDevices(num_replicas, num_partitions);
}

Status PyTpuClient::CheckDeviceId(int device_id,
                                  absl::string_view caller_name) {
  if (device_id < 0 || device_id >= device_count()) {
    return InvalidArgument("%s got bad device_id: %d (num_devices=%d).",
                           caller_name, device_id, device_count());
  }
  return OkStatus();
}

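// Rejects element types that the TPU driver API cannot transfer; 64-bit
// integer and floating-point inputs must be narrowed on the host first.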
static Status CheckDataType(xla::PrimitiveType dtype) {
  if (dtype == xla::PrimitiveType::F64 || dtype == xla::PrimitiveType::S64 ||
      dtype == xla::PrimitiveType::U64) {
    return InvalidArgument(
        "64-bit data types are not yet supported on the TPU driver API. "
        "Convert inputs to float32/int32_t before using.");
  }
  return OkStatus();
}

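// Transfers a flattened list of host literals into a device buffer. For a
// non-tuple shape the single leaf is copied into one freshly allocated buffer;
// for a tuple shape each leaf is transferred into its own child buffer and the
// children are then assembled with MakeTuple.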
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
    std::vector<BorrowingLiteral> leaves, const Shape& tuple_shape,
    std::shared_ptr<void> leaves_references,
    std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals");
  VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString()
          << " device: " << device->DebugString();
  TF_RETURN_IF_ERROR(
      client->CheckDeviceId(device->id(), "PyTpuBuffer::FromLiterals"));
  tpu_driver::TpuDriver* driver = client->driver();

  if (!tuple_shape.IsTuple()) {
    TF_RET_CHECK(leaves.size() == 1);
    return CreateBuffer(
        tuple_shape,
        [driver, &leaves, &tuple_shape,
         leaves_references](tpu_driver::BufferHandle* handle) {
          auto event =
              driver->TransferToDevice(leaves[0].untyped_data(), handle, {});
          event->AddCallback([leaves_references](Status) {});
          return event;
        },
        std::move(client), std::move(device));
  }

  std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
  child_buffers.reserve(leaves.size());
  std::vector<PyTpuBuffer*> child_buffer_ptrs;
  child_buffer_ptrs.reserve(leaves.size());

  auto it_leaf = leaves.begin();
  for (const ShapeUtil::IndexedShape& indexed_shape :
       ShapeUtil::GetLeafShapes(tuple_shape)) {
    TF_RET_CHECK(it_leaf != leaves.end());
    auto& leaf = *it_leaf;
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PyTpuBuffer> child_buffer,
        CreateBuffer(
            indexed_shape.shape,
            [driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) {
              return driver->TransferToDevice(leaf.untyped_data(), handle, {});
            },
            client, device));
    child_buffer_ptrs.push_back(child_buffer.get());
    child_buffers.push_back(std::move(child_buffer));
    ++it_leaf;
  }
  TF_RET_CHECK(it_leaf == leaves.end());

  // `MakeTuple` will extract and make the tuple buffer hold onto the
  // `device_buffer_` contained in each `child_buffer`, so it's safe for
  // `child_buffers` to get destroyed before this call returns.
  return MakeTuple(std::move(child_buffer_ptrs), std::move(client),
                   std::move(device));
}

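// Builds a tuple buffer out of existing child buffers: collects each child's
// buffer handle and pending definition events, allocates a tuple handle over
// the child handles, and wraps everything in a new PyTpuBuffer that keeps the
// child device buffers alive.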
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
    absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
    std::shared_ptr<PjRtDevice> device) {
  std::vector<Shape> child_shapes;
  std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
  std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
  std::vector<std::shared_ptr<tpu_driver::Event>> child_events;

  for (const auto& child_buffer : buffers) {
    child_shapes.push_back(child_buffer->on_host_shape());
    std::shared_ptr<TpuSharedBuffer> child_device_buffer =
        child_buffer->DeviceBuffer();
    // Merge all definition events from all children, so that anyone using this
    // tuple must wait for all its children to finish receiving transfers. This
    // works recursively up a nested tuple tree as well.
    for (std::shared_ptr<tpu_driver::Event> child_event :
         child_device_buffer->wait_for_use) {
      child_events.push_back(std::move(child_event));
    }
    child_handle_ptrs.push_back(child_device_buffer->handle.get());
    child_device_buffers.push_back(std::move(child_device_buffer));
  }

  Shape tuple_shape = ShapeUtil::MakeTupleShape(child_shapes);
  std::unique_ptr<tpu_driver::BufferHandle> tuple_handle =
      client->driver()->AllocateTuple(
          device->id(), tpu_driver::MemoryRegion::HBM, child_handle_ptrs, {});
  auto tuple_device_buffer = std::make_shared<TpuSharedBuffer>(
      client->driver(), std::move(tuple_handle), std::move(child_events),
      std::move(device));
  return std::make_unique<PyTpuBuffer>(
      tuple_shape, std::move(tuple_device_buffer),
      std::move(child_device_buffers), std::move(client));
}

PyTpuBuffer::PyTpuBuffer(
    Shape on_host_shape, std::shared_ptr<TpuSharedBuffer> device_buffer,
    std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
    std::shared_ptr<PyTpuClient> client)
    : client_(std::move(client)),
      on_host_shape_(std::move(on_host_shape)),
      device_(device_buffer->device),
      device_buffer_(std::move(device_buffer)),
      child_buffers_(std::move(child_buffers)) {}

void PyTpuBuffer::Delete() {
  absl::MutexLock lock(&mu_);
  device_buffer_ = nullptr;
  child_buffers_.clear();
  host_value_ = nullptr;
}

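// Starts an asynchronous device-to-host copy. One TransferFromDevice is issued
// per child buffer (or a single transfer for a non-tuple shape), and the
// shared HostValue is notified once every pending transfer has completed.
// Repeated calls reuse the HostValue that is already in flight.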
Status PyTpuBuffer::CopyToHostAsync() {
  std::vector<std::shared_ptr<tpu_driver::Event>> transfer_events;
  std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();

  {
    absl::MutexLock lock(&mu_);
    if (!device_buffer_) {
      return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
    }

    if (host_value_) {
      // The host value has already been requested or is available.
      return OkStatus();
    }

    host_value->value = std::make_shared<Literal>(on_host_shape_);
    host_value->pending_ops = std::max(1ul, child_buffers_.size());
    host_value_ = host_value;

    std::vector<tpu_driver::Event*> events;
    for (const auto& e : device_buffer_->wait_for_use) {
      events.push_back(e.get());
    }

    VLOG(1) << "CopyToHostAsync:: host shape: "
            << host_value->value->shape().DebugString();

    if (!on_host_shape_.IsTuple()) {
      CHECK(child_buffers_.empty());
      transfer_events.push_back(client_->driver()->TransferFromDevice(
          device_buffer_->handle.get(), host_value->value->untyped_data(),
          events));
    } else {
      for (int i = 0; i < child_buffers_.size(); ++i) {
        auto& c = child_buffers_[i];
        transfer_events.push_back(client_->driver()->TransferFromDevice(
            c->handle.get(),
            host_value->value->untyped_data(xla::ShapeIndex({i})), events));
      }
    }
  }

  for (auto& t : transfer_events) {
    t->AddCallback([host_value](const xla::Status& status) {
      VLOG(1) << "Device to host transfer finished.";
      if (!status.ok()) {
        host_value->status =
            Status(static_cast<tensorflow::error::Code>(status.code()),
                   status.error_message());
      }

      absl::MutexLock m(&host_value->mutex);
      --host_value->pending_ops;
      if (host_value->pending_ops == 0) {
        VLOG(1) << "Host value done: " << host_value->status;
        host_value->ready.Notify();
      }
    });
  }
  return OkStatus();
}

StatusOr<std::shared_ptr<Literal>> PyTpuBuffer::ToLiteral() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::ToLiteral");
  TF_RETURN_IF_ERROR(CopyToHostAsync());

  mu_.Lock();
  std::shared_ptr<HostValue> host_value = host_value_;
  mu_.Unlock();

  VLOG(1) << "Waiting for device to host transfer " << host_value.get();
  host_value->ready.WaitForNotification();
  VLOG(1) << "Host copy finished, status:" << host_value->status;
  TF_RETURN_IF_ERROR(host_value->status);

  return host_value->value;
}

std::shared_ptr<TpuSharedBuffer> PyTpuBuffer::DeviceBuffer() const {
  absl::MutexLock lock(&mu_);
  return device_buffer_;
}

StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
PyTpuBuffer::DestructureTuple() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::DestructureTuple");
  if (!on_host_shape_.IsTuple()) {
    return InvalidArgument(
        "Attempted to destructure a PyTpuBuffer that did not have a tuple "
        "shape; shape: %s.",
        ShapeUtil::HumanString(on_host_shape_));
  }
  if (DeviceBuffer() == nullptr) {
    return InvalidArgument("Attempted to destructure a deleted buffer.");
  }

  absl::MutexLock lock(&mu_);
  int num_children = ShapeUtil::TupleElementCount(on_host_shape_);
  std::vector<std::unique_ptr<PyTpuBuffer>> results;
  results.reserve(num_children);
  for (int i = 0; i < num_children; ++i) {
    results.push_back(std::make_unique<PyTpuBuffer>(
        on_host_shape_.tuple_shapes(i), child_buffers_.at(i),
        std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_));
  }
  return results;
}

StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CopyToDevice(
    std::shared_ptr<PjRtDevice> dst_device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice");
  if (on_host_shape_.IsTuple()) {
    return Unimplemented("CopyToDevice for tuples is not supported.");
  }

  std::shared_ptr<TpuSharedBuffer> src_device_buffer = DeviceBuffer();
  if (dst_device->id() == device_->id()) {
    return std::make_unique<PyTpuBuffer>(
        on_host_shape_, src_device_buffer,
        std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_);
  }

  tpu_driver::TpuDriver* driver = client_->driver();
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PyTpuBuffer> dst_buffer,
      CreateBuffer(
          on_host_shape_,
          [driver, src_device_buffer](tpu_driver::BufferHandle* dst_handle) {
            std::vector<tpu_driver::Event*> src_wait_for_use;
            for (auto& event : src_device_buffer->wait_for_use) {
              src_wait_for_use.push_back(event.get());
            }
            return driver->TransferFromDeviceToDevice(
                src_device_buffer->handle.get(), dst_handle, src_wait_for_use);
          },
          client_, std::move(dst_device)));
  // TODO(jiawenhao): This may be too pessimistic: it prevents future readers
  // from reading `src_device_buffer` until the device-to-device copy is done.
  // Should this go into a new `TpuSharedBuffer::wait_for_dealloc` field?
  auto& wait_for_use = dst_buffer->DeviceBuffer()->wait_for_use;
  src_device_buffer->wait_for_use.insert(src_device_buffer->wait_for_use.end(),
                                         wait_for_use.begin(),
                                         wait_for_use.end());
  return dst_buffer;
}

Status PyTpuBuffer::BlockHostUntilReady() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::BlockHostUntilReady");
  std::shared_ptr<TpuSharedBuffer> device_buffer = DeviceBuffer();
  if (!device_buffer) {
    return InvalidArgument(
        "BlockHostUntilReady() called on deleted or donated buffer");
  }
  return device_buffer->handle->OnReady()->Await();
}

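// Allocates an uninitialized device buffer for `shape`. Tuple shapes are
// handled by recursively allocating a buffer for each element and assembling
// the results with MakeTuple.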
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
    const Shape& shape, std::shared_ptr<PyTpuClient> client,
    std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer");
  VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString()
          << " device: " << device->DebugString();

  if (!shape.IsTuple()) {
    return CreateBuffer(shape, std::nullopt, std::move(client),
                        std::move(device));
  }

  std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
  child_buffers.reserve(shape.tuple_shapes().size());
  std::vector<PyTpuBuffer*> child_buffer_ptrs;
  child_buffer_ptrs.reserve(shape.tuple_shapes().size());

  for (const auto& child_shape : shape.tuple_shapes()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PyTpuBuffer> child_buffer,
                        AllocateBuffer(child_shape, client, device));
    child_buffer_ptrs.push_back(child_buffer.get());
    child_buffers.push_back(std::move(child_buffer));
  }

  // `MakeTuple` will extract and make the tuple buffer hold onto the
  // `device_buffer_` contained in each `child_buffer`, so it's safe for
  // `child_buffers` to get destroyed before this call returns.
  return PyTpuBuffer::MakeTuple(child_buffer_ptrs, std::move(client),
                                std::move(device));
}

/*static*/
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
    const Shape& non_tuple_shape, std::optional<BufferInitializer> initializer,
    std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer");
  VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: "
          << non_tuple_shape.DebugString()
          << " device: " << device->DebugString();
  TF_RET_CHECK(!non_tuple_shape.IsTuple());
  TF_RETURN_IF_ERROR(CheckDataType(non_tuple_shape.element_type()));

  std::unique_ptr<tpu_driver::BufferHandle> handle =
      client->driver()->Allocate(device->id(), tpu_driver::MemoryRegion::HBM,
                                 non_tuple_shape.ToProto(), {});

  // If this buffer needs to be initialized, anyone using this buffer must wait
  // for the initialization event in `wait_for_use` to finish first.
  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
  if (initializer.has_value()) {
    std::shared_ptr<tpu_driver::Event> init = initializer.value()(handle.get());
    wait_for_use.push_back(std::move(init));
  }
  auto device_buffer = std::make_shared<TpuSharedBuffer>(
      client->driver(), std::move(handle), std::move(wait_for_use),
      std::move(device));

  return std::make_unique<PyTpuBuffer>(
      non_tuple_shape, std::move(device_buffer),
      std::vector<std::shared_ptr<TpuSharedBuffer>>(), client);
}

static std::shared_ptr<PjRtDevice> LookupDevice(const PyTpuClient& client,
                                                int device_id) {
  auto it = client.id_to_device().find(device_id);
  CHECK(it != client.id_to_device().end())
      << "Unknown device id: " << device_id;
  return it->second;
}

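// Loads the compiled program onto every local device named in the device
// assignment; devices owned by other processes are skipped.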
PyTpuExecutable::PyTpuExecutable(
    std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
    DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
    xla::Shape result_shape, bool tuple_arguments)
    : client_(std::move(client)),
      device_assignment_(std::move(device_assignment)),
      tuple_arguments_(tuple_arguments),
      result_shape_(std::move(result_shape)) {
  VLOG(1) << "DeviceAssignment. " << device_assignment_.ToString();
  const int num_replicas = device_assignment_.replica_count();
  const int num_partitions = device_assignment_.computation_count();
  CHECK_EQ(num_partitions, 1) << "partition count > 1 is not supported.";
  for (int replica = 0; replica < num_replicas; ++replica) {
    for (int partition = 0; partition < num_partitions; ++partition) {
      int device_id = device_assignment_(replica, partition);
      std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
      if (device->process_index() != client_->process_index()) {
        VLOG(3) << "Non-local device: " << device_id;
        continue;
      }
      // TODO(b/147895917): support replica + partition natively.
      CHECK(executables_.find(replica) == executables_.end())
          << "Inserting duplicate replica:" << replica;
      executables_[replica] =
          client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
      local_logical_device_ids_.emplace_back(replica, partition);
      local_devices_.push_back(device);
    }
  }
  CHECK_GE(local_devices_.size(), 1);
  CHECK_LE(executables_.size(), client_->device_count());
  CHECK_LE(local_devices_.size(), client_->local_device_count())
      << "Inconsistent local device count.";
}

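// Runs the loaded program for one (replica, partition) pair: allocates the
// output buffer, gathers the input buffer handles plus every event they are
// still waiting on, and issues ExecuteProgram. Returns the output buffer
// together with the event that signals execution completion.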
PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
    absl::Span<const std::vector<PyTpuBuffer*>> maybe_tupled_args,
    absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
    int partition, const RunId& run_id) {
  const int device_id = device_assignment_(replica, partition);
  std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
  CHECK_EQ(device->process_index(), client_->process_index());
  tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
  VLOG(3) << "Replica " << replica << ", partition " << partition
          << " mapped to device id for execution: " << device_id;

  std::unique_ptr<::xla::PyTpuBuffer> output_buffer =
      ::xla::PyTpuBuffer::AllocateBuffer(result_shape_, client_,
                                         std::move(device))
          .ValueOrDie();
  VLOG(1) << "Created output buffer: " << result_shape_.DebugString();

  std::vector<tpu_driver::BufferHandle*> inputs;
  std::vector<tpu_driver::Event*> ready_to_execute;

  std::shared_ptr<tpu_driver::Event> output_buffer_ready =
      output_buffer->DeviceBuffer()->handle->OnReady();

  ready_to_execute.push_back(output_buffer_ready.get());

  for (auto* input_handle : this_core_arguments) {
    inputs.push_back(input_handle->DeviceBuffer()->handle.get());
  }

  for (const auto& core_args : maybe_tupled_args) {
    for (const auto* handle : core_args) {
      for (const auto& pending_event : handle->DeviceBuffer()->wait_for_use) {
        ready_to_execute.push_back(pending_event.get());
      }
    }
  }

  xla::DeviceAssignmentProto device_assignment;
  CHECK(device_assignment_.Serialize(&device_assignment).ok());
  std::shared_ptr<tpu_driver::Event> on_execute_finished =
      client_->driver()->ExecuteProgram(
          executables_.find(replica)->second.get(), inputs,
          {output_buffer->DeviceBuffer()->handle.get()}, device_assignment,
          {ready_to_execute});

  return {std::move(output_buffer), std::move(on_execute_finished)};
}

// Delay before warning about a slow execute call.
static const absl::Duration kWarnExecutionDelay = absl::Seconds(10);

// Delay before terminating a stalled execute call.
static const absl::Duration kMaxExecutionDelay = absl::Minutes(60);

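// Blocks until `event` completes, logging a warning every kWarnExecutionDelay
// and giving up with DeadlineExceeded once kMaxExecutionDelay has elapsed.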
Status WaitForExecuteEvent(tpu_driver::Event* event) {
  std::optional<Status> opt_status;
  auto start_time = absl::Now();

  while (!opt_status.has_value() &&
         absl::Now() - start_time < kMaxExecutionDelay) {
    opt_status = event->AwaitWithTimeout(kWarnExecutionDelay);
    if (!opt_status.has_value()) {
      LOG(WARNING)
          << "TPU Execute is taking a long time. This might be due to a "
             "deadlock between multiple TPU cores or a very slow program.";
    }
  }

  if (!opt_status.has_value()) {
    return tensorflow::errors::DeadlineExceeded(
        absl::StrFormat("TPU program took more than %d seconds to complete.",
                        absl::ToInt64Seconds(kMaxExecutionDelay)));
  }

  return opt_status.value();
}

StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
    absl::Span<PyTpuBuffer* const> argument_handles) {
  if (num_replicas() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d replicas using Execute().",
        num_replicas());
  }
  if (num_partitions() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d partitions using Execute().",
        num_partitions());
  }

  std::vector<PyTpuBuffer*> maybe_tupled_args;
  std::unique_ptr<PyTpuBuffer> tupled_arguments;
  if (tuple_arguments_) {
    TF_ASSIGN_OR_RETURN(tupled_arguments,
                        PyTpuBuffer::MakeTuple(argument_handles, client_,
                                               local_devices_.front()));
    maybe_tupled_args = {tupled_arguments.get()};
  } else {
    maybe_tupled_args = std::vector<PyTpuBuffer*>(argument_handles.begin(),
                                                  argument_handles.end());
  }
  ExecuteResult result =
      ExecuteHelper(absl::MakeSpan(&maybe_tupled_args, 1), maybe_tupled_args,
                    /*replica=*/0, /*partition=*/0, RunId());

  Status status = WaitForExecuteEvent(result.on_execute_finished.get());

  if (!status.ok()) {
    LOG(ERROR) << "Failed to execute program: " << status;
    return status;
  }

  if (result.buffer->on_host_shape().IsTuple()) {
    return result.buffer->DestructureTuple();
  } else {
    std::vector<std::unique_ptr<PyTpuBuffer>> outputs;
    outputs.push_back(std::move(result.buffer));
    return outputs;
  }
}

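// Executes the program once per local device. Arguments are optionally tupled
// per device, each ExecuteHelper call is scheduled on the client's thread pool
// so all cores are dispatched in parallel, and the calling thread then waits
// on every execution's completion event, returning the first failure if any.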
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteOnLocalDevices(
    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
  tensorflow::profiler::TraceMe traceme(
      "PyTpuExecutable::ExecuteOnLocalDevices");

  const int num_local_devices = local_devices_.size();

  if (argument_handles.size() != num_local_devices) {
    return InvalidArgument(
        "Attempted to execute with %d argument lists when local device "
        "count is %d (total replica count: %d, partition count: %d).",
        argument_handles.size(), num_local_devices, num_replicas(),
        num_partitions());
  }

  VLOG(1) << "Executing computation; num_replicas=" << num_replicas()
          << " num_partitions=" << num_partitions()
          << " num_local_devices=" << num_local_devices;

  std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
  std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
  if (tuple_arguments_) {
    tupled_arguments.resize(argument_handles.size());
    tupled_argument_pointers.resize(argument_handles.size());
    for (int i = 0; i < num_local_devices; ++i) {
      TF_ASSIGN_OR_RETURN(tupled_arguments[i],
                          PyTpuBuffer::MakeTuple(argument_handles[i], client_,
                                                 local_devices_.at(i)));
      tupled_argument_pointers[i] = {tupled_arguments[i].get()};
    }
    argument_handles = tupled_argument_pointers;
  }

  absl::Mutex results_lock;
  std::vector<ExecuteResult> results(num_local_devices);

  auto* thread_pool = client_->GetThreadPool();

  int failed = 0;
  Status first_failure_status;

  xla::Semaphore execute_semaphore(0);
  for (int i = 0; i < num_local_devices; ++i) {
    // We are scheduling Execute on a thread pool as ExecuteHelper can take a
    // long time and we want all cores to be scheduled in parallel.
    thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
                           &execute_semaphore]() {
      const int replica = local_logical_device_ids_[i].first;
      const int partition = local_logical_device_ids_[i].second;
      RunId run_id;
      auto result = ExecuteHelper(argument_handles, argument_handles[i],
                                  replica, partition, run_id);
      results[i] = std::move(result);
      execute_semaphore.Release(1);
    });
  }

  execute_semaphore.Acquire(num_local_devices);

  for (int i = 0; i < num_local_devices; ++i) {
    auto s = WaitForExecuteEvent(results[i].on_execute_finished.get());
    if (!s.ok()) {
      if (failed == 0) {
        first_failure_status = Status(
            static_cast<tensorflow::error::Code>(s.code()), s.error_message());
      }
      ++failed;
    }
  }
  if (failed > 0) {
    return first_failure_status;
  }
  VLOG(1) << "Replicated execution complete.";

  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> wrapped_results(
      num_local_devices);
  for (int i = 0; i < num_local_devices; ++i) {
    if (results[i].buffer->on_host_shape().IsTuple()) {
      TF_ASSIGN_OR_RETURN(wrapped_results[i],
                          results[i].buffer->DestructureTuple());
    } else {
      wrapped_results[i].push_back(std::move(results[i].buffer));
    }
  }
  return wrapped_results;
}

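// Accepts arguments indexed as [argument][computation], transposes them into
// the [computation][argument] layout that ExecuteOnLocalDevices expects, and
// transposes the results back into [output][computation] order.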
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteShardedOnLocalDevices(
    absl::Span<const std::vector<PyTpuBuffer*>> args) {
  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> output_buffers;
  TF_RET_CHECK(!args.empty());
  int num_computations = args.front().size();
  for (const auto& arg : args) {
    if (arg.size() != num_computations) {
      return xla::InvalidArgument(
          "Expected args to execute_sharded_on_local_devices to have %d "
          "shards, got: [%s]",
          num_computations,
          absl::StrJoin(
              args, ", ",
              [](std::string* out, const std::vector<PyTpuBuffer*>& arg) {
                out->append(std::to_string(arg.size()));
              }));
    }
  }
  std::vector<std::vector<PyTpuBuffer*>> arg_buffers(num_computations);
  for (int computation = 0; computation < num_computations; ++computation) {
    arg_buffers[computation].resize(args.size());
    absl::c_transform(
        args, arg_buffers[computation].begin(),
        [&](const std::vector<PyTpuBuffer*>& arg) { return arg[computation]; });
  }
  TF_ASSIGN_OR_RETURN(output_buffers, ExecuteOnLocalDevices(arg_buffers));
  int num_output_buffers = output_buffers[0].size();
  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> outputs;
  outputs.resize(num_output_buffers);
  for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
    outputs[buffer_id].reserve(num_computations);
    for (int computation = 0; computation < num_computations; ++computation) {
      outputs[buffer_id].push_back(
          std::move(output_buffers[computation][buffer_id]));
    }
  }
  return outputs;
}

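// Compiles `computation` through the TPU driver. The device assignment comes
// from the build options when provided (limited to one computation per
// replica), otherwise from the client's default assignment; the result shape
// is taken from options.result_layout() or from the compiled program's
// metadata.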
/*static*/ StatusOr<std::unique_ptr<PyTpuExecutable>> PyTpuExecutable::Compile(
    const XlaComputation& computation,
    std::optional<std::vector<Shape>> argument_layouts,
    const ExecutableBuildOptions* build_options,
    std::shared_ptr<PyTpuClient> client, bool tuple_arguments) {
  tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile");

  VLOG(1) << "Compile: "
          << computation.GetProgramShape().ValueOrDie().DebugString();

  // TODO(power) -- handle argument layouts
  // TODO(power) -- handle build options
  ExecutableBuildOptions options;
  if (build_options != nullptr) {
    options = *build_options;
  }
  std::optional<xla::DeviceAssignment> device_assignment;

  // For the pod use case, device_assignment.num_replicas() may be greater
  // than the number of available local devices; where applicable, the
  // non-local devices must be filtered out from participating in the local
  // computation.
  if (options.has_device_assignment()) {
    if (options.device_assignment().replica_count() != options.num_replicas()) {
      return InvalidArgument(
          "Mismatched number of replicas for device "
          "assignment and computation (%d vs %d).",
          options.device_assignment().replica_count(), options.num_replicas());
    } else if (options.device_assignment().computation_count() != 1) {
      return Unimplemented(
          "Only 1 computation per replica supported, %d requested.",
          options.device_assignment().computation_count());
    }
    device_assignment = options.device_assignment();
  } else {
    TF_ASSIGN_OR_RETURN(device_assignment,
                        client->GetDefaultDeviceAssignment(
                            options.num_replicas(), options.num_partitions()));
  }
  CHECK_GE(options.num_replicas(), 1);
  CHECK_EQ(options.num_replicas(), device_assignment->replica_count());
  CHECK(!argument_layouts.has_value());

  // TODO(henrytan): an area for optimization with less buffer copy.
  xla::HloProto hlo_proto;
  *hlo_proto.mutable_hlo_module() = computation.proto();

  // TODO(henrytan): in the future, we want to consider argument Layout
  // information e.g. for linearization.
  std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program =
      client->driver()->CompileProgram(hlo_proto, options.num_replicas(), {});

  ::xla::Shape result_layout;
  if (options.result_layout()) {
    result_layout = *options.result_layout();
  } else {
    xla::ProgramShapeProto program_shape_proto;
    auto fetch_metadata_status =
        compiled_program->program_shape(&program_shape_proto);

    if (!fetch_metadata_status.ok()) {
      return Status(
          static_cast<tensorflow::error::Code>(fetch_metadata_status.code()),
          fetch_metadata_status.error_message());
    }
    result_layout = ::xla::Shape(program_shape_proto.result());
  }
  VLOG(1) << "Got result shape: " << result_layout.DebugString();

  return std::make_unique<PyTpuExecutable>(
      std::move(compiled_program), std::move(*device_assignment),
      std::move(client), std::move(result_layout), tuple_arguments);
}

/*static*/ StatusOr<std::unique_ptr<PyTpuExecutable>>
PyTpuExecutable::CompileMlir(mlir::ModuleOp module,
                             std::optional<std::vector<Shape>> argument_layouts,
                             const ExecutableBuildOptions* build_options,
                             std::shared_ptr<PyTpuClient> client,
                             bool tuple_arguments) {
  XlaComputation xla_computation;
  TF_RETURN_IF_ERROR(MlirToXlaComputation(module, xla_computation,
                                          /*use_tuple_args=*/tuple_arguments,
                                          /*return_tuple=*/false));
  return Compile(xla_computation, argument_layouts, build_options, client,
                 tuple_arguments);
}

}  // namespace xla