/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace xla {

TpuDevice::TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
                     int core_on_chip)
    : id_(id),
      task_id_(task_id),
      coords_(coords),
      core_on_chip_(core_on_chip) {}

std::string TpuDevice::DebugString() const {
  return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), task_id(),
                         coords_[0], coords_[1], coords_[2], core_on_chip_);
}

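// Enumerates every TPU core reported in `system_info` and wraps each one as a
// TpuDevice, using the driver-assigned core id as the device id and the chip's
// host id as the task id.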
xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
  std::vector<std::shared_ptr<PjRtDevice>> devices;
  for (const auto& chip : system_info.tpu_chip()) {
    auto& coord = chip.chip_coord();
    std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
    int task_id = chip.host_id();
    for (const auto& core : chip.core()) {
      auto device = std::make_shared<TpuDevice>(
          core.id(), task_id, coords_array, core.core_on_chip_index());
      devices.push_back(device);
    }
  }

  return devices;
}

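// Opens a TpuDriver for `worker`, queries the system topology, and builds a
// PyTpuClient with one device per TPU core. Illustrative call (the worker
// address format depends on which driver is registered):
//   auto client = PyTpuClient::Get("grpc://localhost:8470").ValueOrDie();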
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
    const std::string& worker) {
  tpu_driver::TpuDriverConfig driver_config;
  driver_config.set_worker(worker);
  auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
  if (!client_status.ok()) {
    return client_status.status();
  }

  auto client = client_status.ConsumeValueOrDie();

  tpu_driver::SystemInfo system_info;
  client->QuerySystemInfo(&system_info);

  TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<PjRtDevice>> devices,
                      TpuDevice::GetTpuDevices(system_info));

  return std::make_shared<PyTpuClient>(kTpuPlatform, std::move(client),
                                       std::move(devices),
                                       system_info.host_id());
}

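// Builds the id-to-device map, splits `devices` into local and remote devices
// by task id, and creates a thread pool with one thread per device.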
PyTpuClient::PyTpuClient(std::string platform_name,
                         std::unique_ptr<tpu_driver::TpuDriver> driver,
                         std::vector<std::shared_ptr<PjRtDevice>> devices,
                         int task_id)
    : platform_name_(std::move(platform_name)),
      driver_(std::move(driver)),
      devices_(std::move(devices)),
      task_id_(task_id) {
  for (const std::shared_ptr<PjRtDevice>& device : devices_) {
    CHECK(id_to_device_.insert({device->id(), device}).second)
        << "Duplicate device id: " << device->id();

    if (device->task_id() == task_id_) {
      LOG(INFO) << "Detected local device, host id: " << task_id_
                << ". device id: " << device->id();
      local_devices_.push_back(device);
    } else {
      VLOG(2) << "Other devices, device id: " << device->id();
    }
  }
  CHECK_GE(local_devices_.size(), 1);
  LOG(INFO) << "Creating " << local_devices_.size() << " TPU device(s).";

  for (int idx = 0; idx < local_devices_.size(); ++idx) {
    CHECK(local_devices_[idx] != nullptr) << idx;
  }

  // TODO(frankchn): Check whether the thread pool size needs to be adjusted;
  // something like min(cores, devices_.size()) might be appropriate depending
  // on the number of devices.
  pool_ = std::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "PyTpuClient", devices_.size());
}

Status PyTpuClient::TransferToInfeed(const LiteralSlice& literal,
                                     int device_id) {
  return Unimplemented("Infeed not implemented.");
}

StatusOr<Literal> PyTpuClient::TransferFromOutfeed(const Shape& shape,
                                                   int device_id) {
  return Unimplemented("Outfeed not implemented.");
}

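// Maps replica i to the i-th local device when the computation fits on this
// host's devices; otherwise falls back to ComputationPlacer's global
// assignment. Partition counts greater than 1 are rejected.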
StatusOr<DeviceAssignment> PyTpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  if (num_partitions > 1) {
    return InvalidArgument("Num partitions greater than 1 is not supported.");
  }
  if (num_replicas * num_partitions <= local_device_count()) {
    DeviceAssignment assignment(num_replicas, num_partitions);
    for (int replica = 0; replica < num_replicas; ++replica) {
      for (int partition = 0; partition < num_partitions; ++partition) {
        assignment(replica, partition) = local_devices_[replica]->id();
      }
    }
    return assignment;
  }

  // Fall back to the default global device assignment if we can't run locally.
  xla::ComputationPlacer placer;
  return placer.AssignDevices(num_replicas, num_partitions);
}

Status PyTpuClient::CheckDeviceId(int device_id,
                                  absl::string_view caller_name) {
  if (device_id < 0 || device_id >= device_count()) {
    return InvalidArgument("%s got bad device_id: %d (num_devices=%d).",
                           caller_name, device_id, device_count());
  }
  return Status::OK();
}

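// Rejects 64-bit element types, which the TPU driver API does not yet support.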
static Status CheckDataType(xla::PrimitiveType dtype) {
  if (dtype == xla::PrimitiveType::F64 || dtype == xla::PrimitiveType::S64 ||
      dtype == xla::PrimitiveType::U64) {
    return InvalidArgument(
        "64-bit data types are not yet supported on the TPU driver API. "
        "Convert inputs to float32/int32 before using.");
  }
  return Status::OK();
}

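// Transfers the host literals in `leaves` to `device`. A non-tuple shape
// becomes a single buffer (holding `leaves_references` alive until the
// transfer completes); a tuple shape becomes one child buffer per leaf,
// combined into a tuple via MakeTuple.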
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
    std::vector<BorrowingLiteral> leaves, const Shape& tuple_shape,
    std::shared_ptr<void> leaves_references,
    std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals");
  VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString()
          << " device: " << device->DebugString();
  TF_RETURN_IF_ERROR(
      client->CheckDeviceId(device->id(), "PyTpuBuffer::FromLiterals"));
  tpu_driver::TpuDriver* driver = client->driver();

  if (!tuple_shape.IsTuple()) {
    TF_RET_CHECK(leaves.size() == 1);
    return CreateBuffer(
        tuple_shape,
        [driver, &leaves, &tuple_shape,
         leaves_references](tpu_driver::BufferHandle* handle) {
          auto event =
              driver->TransferToDevice(leaves[0].untyped_data(), handle, {});
          event->AddCallback([leaves_references](Status) {});
          return event;
        },
        std::move(client), std::move(device));
  }

  std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
  child_buffers.reserve(leaves.size());
  std::vector<PyTpuBuffer*> child_buffer_ptrs;
  child_buffer_ptrs.reserve(leaves.size());

  auto it_leaf = leaves.begin();
  for (const ShapeUtil::IndexedShape& indexed_shape :
       ShapeUtil::GetLeafShapes(tuple_shape)) {
    TF_RET_CHECK(it_leaf != leaves.end());
    auto& leaf = *it_leaf;
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PyTpuBuffer> child_buffer,
        CreateBuffer(
            indexed_shape.shape,
            [driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) {
              return driver->TransferToDevice(leaf.untyped_data(), handle, {});
            },
            client, device));
    child_buffer_ptrs.push_back(child_buffer.get());
    child_buffers.push_back(std::move(child_buffer));
    ++it_leaf;
  }
  TF_RET_CHECK(it_leaf == leaves.end());

  // `MakeTuple` will extract and make the tuple buffer hold onto the
  // `device_buffer_` contained in each `child_buffer`, so it's safe for
  // `child_buffers` to get destroyed before this call returns.
  return MakeTuple(std::move(child_buffer_ptrs), std::move(client),
                   std::move(device));
}

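// Allocates a tuple buffer whose handle references the children's buffer
// handles and which inherits every child's pending definition events.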
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
    absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
    std::shared_ptr<PjRtDevice> device) {
  std::vector<Shape> child_shapes;
  std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
  std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
  std::vector<std::shared_ptr<tpu_driver::Event>> child_events;

  for (const auto& child_buffer : buffers) {
    child_shapes.push_back(child_buffer->on_host_shape());
    std::shared_ptr<TpuSharedBuffer> child_device_buffer =
        child_buffer->DeviceBuffer();
    // Merge all definition events from all children, so that anyone using this
    // tuple must wait for all its children to finish receiving transfers. This
    // works recursively up a nested tuple tree as well.
    for (std::shared_ptr<tpu_driver::Event> child_event :
         child_device_buffer->wait_for_use) {
      child_events.push_back(std::move(child_event));
    }
    child_handle_ptrs.push_back(child_device_buffer->handle.get());
    child_device_buffers.push_back(std::move(child_device_buffer));
  }

  Shape tuple_shape = ShapeUtil::MakeTupleShape(child_shapes);
  std::unique_ptr<tpu_driver::BufferHandle> tuple_handle =
      client->driver()->AllocateTuple(
          device->id(), tpu_driver::MemoryRegion::HBM, child_handle_ptrs, {});
  auto tuple_device_buffer = std::make_shared<TpuSharedBuffer>(
      client->driver(), std::move(tuple_handle), std::move(child_events),
      std::move(device));
  return absl::make_unique<PyTpuBuffer>(
      tuple_shape, std::move(tuple_device_buffer),
      std::move(child_device_buffers), std::move(client));
}

PyTpuBuffer::PyTpuBuffer(
    Shape on_host_shape, std::shared_ptr<TpuSharedBuffer> device_buffer,
    std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
    std::shared_ptr<PyTpuClient> client)
    : client_(std::move(client)),
      on_host_shape_(std::move(on_host_shape)),
      device_(device_buffer->device),
      device_buffer_(std::move(device_buffer)),
      child_buffers_(std::move(child_buffers)) {}

void PyTpuBuffer::Delete() {
  absl::MutexLock lock(&mu_);
  device_buffer_ = nullptr;
  child_buffers_.clear();
  host_value_ = nullptr;
}

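// Starts one asynchronous device-to-host transfer per leaf buffer, recording
// the result and completion status in host_value_. Calling it again while a
// copy is pending or already complete is a no-op.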
Status PyTpuBuffer::CopyToHostAsync() {
  std::vector<std::shared_ptr<tpu_driver::Event>> transfer_events;
  std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();

  {
    absl::MutexLock lock(&mu_);
    if (!device_buffer_) {
      return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
    }

    if (host_value_) {
      // The host value has already been requested or is available.
      return Status::OK();
    }

    host_value->value = std::make_shared<Literal>(on_host_shape_);
    host_value->pending_ops = std::max(1ul, child_buffers_.size());
    host_value_ = host_value;

    std::vector<tpu_driver::Event*> events;
    for (const auto& e : device_buffer_->wait_for_use) {
      events.push_back(e.get());
    }

    VLOG(1) << "CopyToHostAsync:: host shape: "
            << host_value->value->shape().DebugString();

    if (!on_host_shape_.IsTuple()) {
      CHECK(child_buffers_.empty());
      transfer_events.push_back(client_->driver()->TransferFromDevice(
          device_buffer_->handle.get(), host_value->value->untyped_data(),
          events));
    } else {
      for (int i = 0; i < child_buffers_.size(); ++i) {
        auto& c = child_buffers_[i];
        transfer_events.push_back(client_->driver()->TransferFromDevice(
            c->handle.get(),
            host_value->value->untyped_data(xla::ShapeIndex({i})), events));
      }
    }
  }

  for (auto& t : transfer_events) {
    t->AddCallback([host_value](const xla::Status& status) {
      VLOG(1) << "Device to host transfer finished.";
      if (!status.ok()) {
        host_value->status =
            Status(static_cast<tensorflow::error::Code>(status.code()),
                   status.error_message());
      }

      absl::MutexLock m(&host_value->mutex);
      --host_value->pending_ops;
      if (host_value->pending_ops == 0) {
        VLOG(1) << "Host value done: " << host_value->status;
        host_value->ready.Notify();
      }
    });
  }
  return Status::OK();
}

StatusOr<std::shared_ptr<Literal>> PyTpuBuffer::ToLiteral() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::ToLiteral");
  TF_RETURN_IF_ERROR(CopyToHostAsync());

  mu_.Lock();
  std::shared_ptr<HostValue> host_value = host_value_;
  mu_.Unlock();

  VLOG(1) << "Waiting for device to host transfer " << host_value.get();
  host_value->ready.WaitForNotification();
  VLOG(1) << "Host copy finished, status:" << host_value->status;
  TF_RETURN_IF_ERROR(host_value->status);

  return host_value->value;
}

std::shared_ptr<TpuSharedBuffer> PyTpuBuffer::DeviceBuffer() const {
  absl::MutexLock lock(&mu_);
  return device_buffer_;
}

StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
PyTpuBuffer::DestructureTuple() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::DestructureTuple");
  if (!on_host_shape_.IsTuple()) {
    return InvalidArgument(
        "Attempted to destructure a PyTpuBuffer that did not have a tuple "
        "shape; shape: %s.",
        ShapeUtil::HumanString(on_host_shape_));
  }
  if (DeviceBuffer() == nullptr) {
    return InvalidArgument("Attempted to destructure a deleted buffer.");
  }

  absl::MutexLock lock(&mu_);
  int num_children = ShapeUtil::TupleElementCount(on_host_shape_);
  std::vector<std::unique_ptr<PyTpuBuffer>> results;
  results.reserve(num_children);
  for (int i = 0; i < num_children; ++i) {
    results.push_back(absl::make_unique<PyTpuBuffer>(
        on_host_shape_.tuple_shapes(i), child_buffers_.at(i),
        std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_));
  }
  return results;
}

StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CopyToDevice(
    std::shared_ptr<PjRtDevice> dst_device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice");
  if (on_host_shape_.IsTuple()) {
    return Unimplemented("CopyToDevice for tuples is not supported.");
  }

  std::shared_ptr<TpuSharedBuffer> src_device_buffer = DeviceBuffer();
  if (dst_device->id() == device_->id()) {
    return absl::make_unique<PyTpuBuffer>(
        on_host_shape_, src_device_buffer,
        std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_);
  }

  tpu_driver::TpuDriver* driver = client_->driver();
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PyTpuBuffer> dst_buffer,
      CreateBuffer(
          on_host_shape_,
          [driver, src_device_buffer](tpu_driver::BufferHandle* dst_handle) {
            std::vector<tpu_driver::Event*> src_wait_for_use;
            for (auto& event : src_device_buffer->wait_for_use) {
              src_wait_for_use.push_back(event.get());
            }
            return driver->TransferFromDeviceToDevice(
                src_device_buffer->handle.get(), dst_handle, src_wait_for_use);
          },
          client_, std::move(dst_device)));
  // TODO(jiawenhao): This may be too pessimistic: it prevents future readers
  // from reading `src_device_buffer` until the device-to-device copy is done.
  // Should this go into a new `TpuSharedBuffer::wait_for_dealloc` field?
  auto& wait_for_use = dst_buffer->DeviceBuffer()->wait_for_use;
  src_device_buffer->wait_for_use.insert(src_device_buffer->wait_for_use.end(),
                                         wait_for_use.begin(),
                                         wait_for_use.end());
  return dst_buffer;
}

Status PyTpuBuffer::BlockHostUntilReady() {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::BlockHostUntilReady");
  std::shared_ptr<TpuSharedBuffer> device_buffer = DeviceBuffer();
  if (!device_buffer) {
    return InvalidArgument(
        "BlockHostUntilReady() called on deleted or donated buffer");
  }
  return device_buffer->handle->OnReady()->Await();
}

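// Allocates uninitialized device memory for `shape` on `device`, recursing
// into tuple element shapes and combining the children via MakeTuple.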
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
    const Shape& shape, std::shared_ptr<PyTpuClient> client,
    std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer");
  VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString()
          << " device: " << device->DebugString();

  if (!shape.IsTuple()) {
    return CreateBuffer(shape, absl::nullopt, std::move(client),
                        std::move(device));
  }

  std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
  child_buffers.reserve(shape.tuple_shapes().size());
  std::vector<PyTpuBuffer*> child_buffer_ptrs;
  child_buffer_ptrs.reserve(shape.tuple_shapes().size());

  for (const auto& child_shape : shape.tuple_shapes()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PyTpuBuffer> child_buffer,
                        AllocateBuffer(child_shape, client, device));
    child_buffer_ptrs.push_back(child_buffer.get());
    child_buffers.push_back(std::move(child_buffer));
  }

  // `MakeTuple` will extract and make the tuple buffer hold onto the
  // `device_buffer_` contained in each `child_buffer`, so it's safe for
  // `child_buffers` to get destroyed before this call returns.
  return PyTpuBuffer::MakeTuple(child_buffer_ptrs, std::move(client),
                                std::move(device));
}

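// Allocates HBM for a single non-tuple shape and, if `initializer` is given,
// runs it to fill the buffer; the resulting event becomes a definition event
// that later users of the buffer must wait on.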
/*static*/
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
    const Shape& non_tuple_shape, absl::optional<BufferInitializer> initializer,
    std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
  tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer");
  VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: "
          << non_tuple_shape.DebugString()
          << " device: " << device->DebugString();
  TF_RET_CHECK(!non_tuple_shape.IsTuple());
  TF_RETURN_IF_ERROR(CheckDataType(non_tuple_shape.element_type()));

  std::unique_ptr<tpu_driver::BufferHandle> handle =
      client->driver()->Allocate(device->id(), tpu_driver::MemoryRegion::HBM,
                                 non_tuple_shape.ToProto(), {});

  // If this buffer needs to be initialized, anyone using this buffer must wait
  // for the initialization event in `wait_for_use` to finish first.
  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
  if (initializer.has_value()) {
    std::shared_ptr<tpu_driver::Event> init = initializer.value()(handle.get());
    wait_for_use.push_back(std::move(init));
  }
  auto device_buffer = std::make_shared<TpuSharedBuffer>(
      client->driver(), std::move(handle), std::move(wait_for_use),
      std::move(device));

  return absl::make_unique<PyTpuBuffer>(
      non_tuple_shape, std::move(device_buffer),
      std::vector<std::shared_ptr<TpuSharedBuffer>>(), client);
}

static std::shared_ptr<PjRtDevice> LookupDevice(const PyTpuClient& client,
                                                int device_id) {
  auto it = client.id_to_device().find(device_id);
  CHECK(it != client.id_to_device().end())
      << "Unknown device id: " << device_id;
  return it->second;
}

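// Loads the compiled program onto every device in the assignment that belongs
// to this host and records its (replica, partition) pair; non-local devices
// are skipped.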
PyTpuExecutable::PyTpuExecutable(
    std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
    DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
    xla::Shape result_shape, bool tuple_arguments)
    : client_(std::move(client)),
      device_assignment_(std::move(device_assignment)),
      tuple_arguments_(tuple_arguments),
      result_shape_(std::move(result_shape)) {
  VLOG(1) << "DeviceAssignment. " << device_assignment_.ToString();
  const int num_replicas = device_assignment_.replica_count();
  const int num_partitions = device_assignment_.computation_count();
  CHECK_EQ(num_partitions, 1) << "partition count > 1 is not supported.";
  for (int replica = 0; replica < num_replicas; ++replica) {
    for (int partition = 0; partition < num_partitions; ++partition) {
      int device_id = device_assignment_(replica, partition);
      std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
      if (device->task_id() != client_->task_id()) {
        VLOG(3) << "Non-local device: " << device_id;
        continue;
      }
      // TODO(b/147895917): support replica + partition natively.
      CHECK(executables_.find(replica) == executables_.end())
          << "Inserting duplicate replica:" << replica;
      executables_[replica] =
          client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
      local_logical_device_ids_.emplace_back(replica, partition);
      local_devices_.push_back(device);
    }
  }
  CHECK_GE(local_devices_.size(), 1);
  CHECK_LE(executables_.size(), client_->device_count());
  CHECK_LE(local_devices_.size(), client_->local_device_count())
      << "Inconsistent local device count.";
}

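// Allocates the output buffer, gathers this core's input handles plus the
// pending definition events of all cores' arguments (and of the output
// allocation), and launches the loaded program for (replica, partition) on its
// assigned device.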
PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
    absl::Span<const std::vector<PyTpuBuffer*>> all_core_arguments,
    absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
    int partition, const RunId& run_id) {
  const int device_id = device_assignment_(replica, partition);
  std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
  CHECK_EQ(device->task_id(), client_->task_id());
  tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
  VLOG(3) << "Replica " << replica << ", partition " << partition
          << " mapped to device id for execution: " << device_id;

  std::unique_ptr<::xla::PyTpuBuffer> output_buffer =
      ::xla::PyTpuBuffer::AllocateBuffer(result_shape_, client_,
                                         std::move(device))
          .ValueOrDie();
  VLOG(1) << "Created output buffer: " << result_shape_.DebugString();

  std::vector<tpu_driver::BufferHandle*> inputs;
  std::vector<tpu_driver::Event*> ready_to_execute;

  std::shared_ptr<tpu_driver::Event> output_buffer_ready =
      output_buffer->DeviceBuffer()->handle->OnReady();

  ready_to_execute.push_back(output_buffer_ready.get());

  for (auto* input_handle : this_core_arguments) {
    inputs.push_back(input_handle->DeviceBuffer()->handle.get());
  }

  for (const auto& core_args : all_core_arguments) {
    for (const auto* handle : core_args) {
      for (const auto& pending_event : handle->DeviceBuffer()->wait_for_use) {
        ready_to_execute.push_back(pending_event.get());
      }
    }
  }

  xla::DeviceAssignmentProto device_assignment;
  CHECK(device_assignment_.Serialize(&device_assignment).ok());
  std::shared_ptr<tpu_driver::Event> on_execute_finished =
      client_->driver()->ExecuteProgram(
          executables_.find(replica)->second.get(), inputs,
          {output_buffer->DeviceBuffer()->handle.get()}, device_assignment,
          {ready_to_execute});

  return {std::move(output_buffer), std::move(on_execute_finished)};
}

// Delay before warning about a slow execute call.
static const absl::Duration kWarnExecutionDelay = absl::Seconds(10);

// Delay before terminating a stalled execute call.
static const absl::Duration kMaxExecutionDelay = absl::Minutes(60);

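// Polls `event` in kWarnExecutionDelay increments, logging a warning each time
// the timeout expires, and returns DeadlineExceeded once the total wait
// exceeds kMaxExecutionDelay.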
Status WaitForExecuteEvent(tpu_driver::Event* event) {
  absl::optional<Status> opt_status;
  auto start_time = absl::Now();

  while (!opt_status.has_value() &&
         absl::Now() - start_time < kMaxExecutionDelay) {
    opt_status = event->AwaitWithTimeout(kWarnExecutionDelay);
    if (!opt_status.has_value()) {
      LOG(WARNING)
          << "TPU Execute is taking a long time. This might be due to a "
             "deadlock between multiple TPU cores or a very slow program.";
    }
  }

  if (!opt_status.has_value()) {
    return tensorflow::errors::DeadlineExceeded(
        absl::StrFormat("TPU program took more than %d seconds to complete.",
                        absl::ToInt64Seconds(kMaxExecutionDelay)));
  }

  return opt_status.value();
}

StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
    absl::Span<PyTpuBuffer* const> argument_handles) {
  if (num_replicas() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d replicas using Execute().",
        num_replicas());
  }
  if (num_partitions() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d partitions using Execute().",
        num_partitions());
  }

  std::vector<PyTpuBuffer*> all_core_arguments;

  std::unique_ptr<PyTpuBuffer> tupled_arguments;
  if (tuple_arguments_) {
    TF_ASSIGN_OR_RETURN(tupled_arguments,
                        PyTpuBuffer::MakeTuple(argument_handles, client_,
                                               local_devices_.front()));
    all_core_arguments = {tupled_arguments.get()};
  } else {
    all_core_arguments = std::vector<PyTpuBuffer*>(argument_handles.begin(),
                                                   argument_handles.end());
  }
  ExecuteResult result =
      ExecuteHelper(absl::MakeSpan(&all_core_arguments, 1), argument_handles,
                    /*replica=*/0, /*partition=*/0, RunId());

  Status status = WaitForExecuteEvent(result.on_execute_finished.get());

  if (!status.ok()) {
    LOG(ERROR) << "Failed to execute program: " << status;
    return status;
  }

  if (result.buffer->on_host_shape().IsTuple()) {
    return result.buffer->DestructureTuple();
  } else {
    std::vector<std::unique_ptr<PyTpuBuffer>> outputs;
    outputs.push_back(std::move(result.buffer));
    return outputs;
  }
}

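// Runs one ExecuteHelper per local device on the client's thread pool, waits
// for every launch to be issued, then blocks on each device's execution event
// and returns the per-device output buffers.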
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteOnLocalDevices(
    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
  tensorflow::profiler::TraceMe traceme(
      "PyTpuExecutable::ExecuteOnLocalDevices");

  const int num_local_devices = local_devices_.size();

  if (argument_handles.size() != num_local_devices) {
    return InvalidArgument(
        "Attempted to execute with %d argument lists when local device "
        "count is %d (total replica count: %d, partition count: %d).",
        argument_handles.size(), num_local_devices, num_replicas(),
        num_partitions());
  }

  VLOG(1) << "Executing computation; num_replicas=" << num_replicas()
          << " num_partitions=" << num_partitions()
          << " num_local_devices=" << num_local_devices;

  std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
  std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
  if (tuple_arguments_) {
    tupled_arguments.resize(argument_handles.size());
    tupled_argument_pointers.resize(argument_handles.size());
    for (int i = 0; i < num_local_devices; ++i) {
      TF_ASSIGN_OR_RETURN(tupled_arguments[i],
                          PyTpuBuffer::MakeTuple(argument_handles[i], client_,
                                                 local_devices_.at(i)));
      tupled_argument_pointers[i] = {tupled_arguments[i].get()};
    }
    argument_handles = tupled_argument_pointers;
  }

  absl::Mutex results_lock;
  std::vector<ExecuteResult> results(num_local_devices);

  auto* thread_pool = client_->GetThreadPool();

  int failed = 0;
  Status first_failure_status;

  xla::Semaphore execute_semaphore(0);
  for (int i = 0; i < num_local_devices; ++i) {
    // We are scheduling Execute on a thread pool as ExecuteHelper can take a
    // long time and we want all cores to be scheduled in parallel.
    thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
                           &execute_semaphore]() {
      const int replica = local_logical_device_ids_[i].first;
      const int partition = local_logical_device_ids_[i].second;
      RunId run_id;
      auto result = ExecuteHelper(argument_handles, argument_handles[i],
                                  replica, partition, run_id);
      results[i] = std::move(result);
      execute_semaphore.Release(1);
    });
  }

  execute_semaphore.Acquire(num_local_devices);

  for (int i = 0; i < num_local_devices; ++i) {
    auto s = WaitForExecuteEvent(results[i].on_execute_finished.get());
    if (!s.ok()) {
      if (failed == 0) {
        first_failure_status = Status(
            static_cast<tensorflow::error::Code>(s.code()), s.error_message());
      }
      ++failed;
    }
  }
  if (failed > 0) {
    return first_failure_status;
  }
  VLOG(1) << "Replicated execution complete.";

  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> wrapped_results(
      num_local_devices);
  for (int i = 0; i < num_local_devices; ++i) {
    if (results[i].buffer->on_host_shape().IsTuple()) {
      TF_ASSIGN_OR_RETURN(wrapped_results[i],
                          results[i].buffer->DestructureTuple());
    } else {
      wrapped_results[i].push_back(std::move(results[i].buffer));
    }
  }
  return wrapped_results;
}

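// Transposes `args` from one list per argument (indexed by computation) into
// one argument list per computation, runs ExecuteOnLocalDevices, and then
// transposes the results back so each output is grouped across computations.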
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteShardedOnLocalDevices(
    absl::Span<const std::vector<PyTpuBuffer*>> args) {
  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> output_buffers;
  TF_RET_CHECK(!args.empty());
  int num_computations = args.front().size();
  for (const auto& arg : args) {
    if (arg.size() != num_computations) {
      return xla::InvalidArgument(
          "Expected args to execute_sharded_on_local_devices to have %d "
          "shards, got: [%s]",
          num_computations,
          absl::StrJoin(
              args, ", ",
              [](std::string* out, const std::vector<PyTpuBuffer*>& arg) {
                out->append(std::to_string(arg.size()));
              }));
    }
  }
  std::vector<std::vector<PyTpuBuffer*>> arg_buffers(num_computations);
  for (int computation = 0; computation < num_computations; ++computation) {
    arg_buffers[computation].resize(args.size());
    absl::c_transform(
        args, arg_buffers[computation].begin(),
        [&](const std::vector<PyTpuBuffer*>& arg) { return arg[computation]; });
  }
  TF_ASSIGN_OR_RETURN(output_buffers, ExecuteOnLocalDevices(arg_buffers));
  int num_output_buffers = output_buffers[0].size();
  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> outputs;
  outputs.resize(num_output_buffers);
  for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
    outputs[buffer_id].reserve(num_computations);
    for (int computation = 0; computation < num_computations; ++computation) {
      outputs[buffer_id].push_back(
          std::move(output_buffers[computation][buffer_id]));
    }
  }
  return outputs;
}

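// Compiles `computation` through the TPU driver. The device assignment comes
// from `build_options` when provided, otherwise from the client's default; the
// result shape is taken from the build options or from the compiled program's
// reported program shape.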
/*static*/ StatusOr<std::unique_ptr<PyTpuExecutable>> PyTpuExecutable::Compile(
    const XlaComputation& computation,
    absl::optional<std::vector<Shape>> argument_layouts,
    const ExecutableBuildOptions* build_options,
    std::shared_ptr<PyTpuClient> client, bool tuple_arguments) {
  tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile");

  VLOG(1) << "Compile: "
          << computation.GetProgramShape().ValueOrDie().DebugString();

  // TODO(power) -- handle argument layouts
  // TODO(power) -- handle build options
  ExecutableBuildOptions options;
  if (build_options != nullptr) {
    options = *build_options;
  }
  absl::optional<xla::DeviceAssignment> device_assignment;

  // For the pod use case, device_assignment.num_replicas() may be greater than
  // the number of available local devices; where applicable, the non-local
  // devices must be filtered out from participating in the local computation.
  if (options.has_device_assignment()) {
    if (options.device_assignment().replica_count() != options.num_replicas()) {
      return InvalidArgument(
          "Mismatched number of replicas for device "
          "assignment and computation (%d vs %d).",
          options.device_assignment().replica_count(), options.num_replicas());
    } else if (options.device_assignment().computation_count() != 1) {
      return Unimplemented(
          "Only 1 computation per replica supported, %d requested.",
          options.device_assignment().computation_count());
    }
    device_assignment = options.device_assignment();
  } else {
    TF_ASSIGN_OR_RETURN(device_assignment,
                        client->GetDefaultDeviceAssignment(
                            options.num_replicas(), options.num_partitions()));
  }
  CHECK_GE(options.num_replicas(), 1);
  CHECK_EQ(options.num_replicas(), device_assignment->replica_count());
  CHECK(!argument_layouts.has_value());

  // TODO(henrytan): an area for optimization with less buffer copy.
  xla::HloProto hlo_proto;
  *hlo_proto.mutable_hlo_module() = computation.proto();

  // TODO(henrytan): in the future, we want to consider argument Layout
  // information e.g. for linearization.
  std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program =
      client->driver()->CompileProgram(hlo_proto, options.num_replicas(), {});

  ::xla::Shape result_layout;
  if (options.result_layout()) {
    result_layout = *options.result_layout();
  } else {
    xla::ProgramShapeProto program_shape_proto;
    auto fetch_metadata_status =
        compiled_program->program_shape(&program_shape_proto);

    if (!fetch_metadata_status.ok()) {
      return Status(
          static_cast<tensorflow::error::Code>(fetch_metadata_status.code()),
          fetch_metadata_status.error_message());
    }
    result_layout = ::xla::Shape(program_shape_proto.result());
  }
  VLOG(1) << "Got result shape: " << result_layout.DebugString();

  return absl::make_unique<PyTpuExecutable>(
      std::move(compiled_program), std::move(*device_assignment),
      std::move(client), std::move(result_layout), tuple_arguments);
}

}  // namespace xla