1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implementation notes:
17 //
18 // Asynchronous execution:
19 // -----------------------
20 //
21 // Computations and host-to-device transfers do not need to block the host
22 // waiting for the operation to complete but instead return control to the host
23 // immediately. This allows client logic to overlap with device-side
24 // computation.
25 //
26 // For a good user experience, we must be careful only to enqueue operations
27 // that are unlikely to fail; as a rule error checking must be done eagerly
28 // before returning control to the client.
29 //
30 // The degree to which the client can enqueue operations ahead of the device
31 // is limited by a semaphore. There are two modes: asynchronous, where we
32 // allow the client to enqueue up to 32 executions ahead of the device, and
33 // synchronous, where we limit the client to having one enqueued operation at
34 // a time. The value of 32 is arbitrary.
35 //
36 // Even in asynchronous mode, it is important that we do not permit
37 // unbounded queue-ahead. Firstly it is problematic when the user does something
38 // like the following in Python:
39 // %timeit run_computation()
40 // To the timeit logic, run_computation() appears to be extremely cheap since
41 // it defers all of its real work and does not block, so %timeit will run it
42 // many (e.g., 10000) times to get better timing resolution, even though in reality
43 // it may be expensive. Secondly, on CPU the allocator is synchronized with the
44 // head of the compute stream, and we allocate buffers for all of the enqueued
45 // programs without any reuse (unlike GPU). This means that the memory usage
46 // is proportional to the queue size.
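//
// A minimal sketch of the bounded queue-ahead idea (illustrative only; the
// Semaphore name and ScopedAcquire call below are assumptions made for this
// sketch rather than a reference to the exact API used by the client):
//
//   Semaphore compute_semaphore(asynchronous ? 32 : 1);
//   ...
//   auto token = compute_semaphore.ScopedAcquire(1);  // may block the host
//   EnqueueExecutionOnComputeStream(...);             // returns immediately
//   // `token` is released from a host callback once the enqueued work
//   // completes, allowing the next enqueue to proceed.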
47 //
48 // Multi-stream execution:
49 // -----------------------
50 //
51 // We use a multistream execution design, where different Streams are used for
52 // host-to-device transfers, device-to-host transfers, and compute. This allows
53 // us to overlap transfers on and off the device with computation.
54 //
55 // Synchronization between streams occurs via BufferSequencingEvents that
56 // describe when the contents of a logical buffer are known to be valid on
57 // a particular stream, and when a buffer's uses have all completed.
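//
// Schematically (pseudo-code for illustration only, not the exact calls used
// in this file):
//
//   // Producer: the host-to-device stream defines the buffer's contents.
//   EnqueueCopy(h2d_stream, dst_buffer);
//   definition_event = RecordEvent(h2d_stream);
//
//   // Consumer: the compute stream waits for the definition event before
//   // reading the buffer, and records a usage event afterwards so that the
//   // buffer cannot be freed while the computation may still be using it.
//   WaitForEvent(compute_stream, definition_event);
//   EnqueueComputation(compute_stream, dst_buffer);
//   usage_event = RecordEvent(compute_stream);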
58 //
59 // Synchronous vs asynchronous deallocation:
60 // -----------------------------------------
61 //
62 // See the comment on LocalDeviceState::AllocationModel for a discussion of the
63 // different allocation semantics on CPU, GPU, and TPU.
64
65 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
66
67 #include <cstddef>
68 #include <cstdlib>
69 #include <memory>
70 #include <string>
71 #include <vector>
72
73 #include "absl/base/casts.h"
74 #include "absl/container/flat_hash_set.h"
75 #include "absl/container/inlined_vector.h"
76 #include "absl/memory/memory.h"
77 #include "absl/strings/str_format.h"
78 #include "absl/synchronization/mutex.h"
79 #include "absl/time/time.h"
80 #include "absl/types/optional.h"
81 #include "absl/types/span.h"
82 #include "tensorflow/compiler/xla/client/local_client.h"
83 #include "tensorflow/compiler/xla/client/xla_computation.h"
84 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
85 #include "tensorflow/compiler/xla/executable_run_options.h"
86 #include "tensorflow/compiler/xla/layout.h"
87 #include "tensorflow/compiler/xla/literal.h"
88 #include "tensorflow/compiler/xla/literal_util.h"
89 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
90 #include "tensorflow/compiler/xla/pjrt/event_pool.h"
91 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
92 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
93 #include "tensorflow/compiler/xla/pjrt/utils.h"
94 #include "tensorflow/compiler/xla/service/executable.h"
95 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
96 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
97 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
98 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
99 #include "tensorflow/compiler/xla/shape_util.h"
100 #include "tensorflow/compiler/xla/util.h"
101 #include "tensorflow/compiler/xla/xla_data.pb.h"
102 #include "tensorflow/core/platform/cpu_info.h"
103 #include "tensorflow/core/platform/errors.h"
104 #include "tensorflow/core/platform/fingerprint.h"
105 #include "tensorflow/core/platform/mem.h"
106 #include "tensorflow/core/platform/status.h"
107 #include "tensorflow/core/platform/types.h"
108 #include "tensorflow/core/profiler/lib/connected_traceme.h"
109 #include "tensorflow/core/profiler/lib/traceme.h"
110 #include "tensorflow/core/profiler/lib/traceme_encode.h"
111 #include "tensorflow/stream_executor/device_memory.h"
112 #include "tensorflow/stream_executor/device_memory_allocator.h"
113 #include "tensorflow/stream_executor/event.h"
114 #include "tensorflow/stream_executor/host/host_platform_id.h"
115 #include "tensorflow/stream_executor/lib/statusor.h"
116 #include "tensorflow/stream_executor/stream.h"
117
118 namespace xla {
119
120 PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
121 return client_->platform_id();
122 }
123 absl::string_view PjRtStreamExecutorDevice::platform_name() const {
124 return client_->platform_name();
125 }
126
127 StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
128 const {
129 if (local_device_state_) {
130 return local_device_state_.get();
131 }
132 return InvalidArgument("Device %s is not a local device.", DebugString());
133 }
134
135 std::string PjRtStreamExecutorDevice::DebugString() const {
136 return absl::StrCat(platform_name(), ":", id());
137 }
138
139 StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
140 absl::Span<const std::vector<PjRtDevice*>> devices) {
141 if (devices.empty()) {
142 return InvalidArgument(
143 "Device assignment passed to Compile() must be non-empty.");
144 }
145 if (devices[0].empty()) {
146 return InvalidArgument(
147 "Device assignment passed to Compile() must have a nonzero number of "
148 "partitions per replica; replica 0 had 0 partitions.");
149 }
150 DeviceAssignment xla_assignment(devices.size(), devices[0].size());
151 for (int replica = 0; replica < devices.size(); ++replica) {
152 if (devices[replica].size() != devices[0].size()) {
153 return InvalidArgument(
154 "Device assignment passed to Compile() has different numbers of "
155 "partitions between replicas; %d partitions for replica %d versus %d "
156 "partitions for replica 0.",
157 devices[replica].size(), replica, devices[0].size());
158 }
159 for (int partition = 0; partition < devices[replica].size(); ++partition) {
160 if (devices[0][0]->client()->platform_id() !=
161 devices[replica][partition]->client()->platform_id()) {
162 return InvalidArgument(
163 "Device assignment passed to Compile() must have devices of a "
164 "single kind, got %s for replica 0 partition 0 and %s for replica "
165 "%d partition %d.",
166 devices[0][0]->client()->platform_name(),
167 devices[replica][partition]->client()->platform_name(), replica,
168 partition);
169 }
170 xla_assignment(replica, partition) = devices[replica][partition]->id();
171 }
172 }
173 return xla_assignment;
174 }
175
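// A trivial host-memory allocator backed by AlignedMalloc/AlignedFree. It is
// used as the default host_memory_allocator when the caller of the
// PjRtStreamExecutorClient constructor does not supply one.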
176 class CpuAllocator : public tensorflow::Allocator {
177 public:
178 CpuAllocator() = default;
179
180 std::string Name() override { return "cpu"; }
181
182 void* AllocateRaw(size_t alignment, size_t num_bytes) override {
183 return tensorflow::port::AlignedMalloc(num_bytes, alignment);
184 }
185 void DeallocateRaw(void* ptr) override {
186 return tensorflow::port::AlignedFree(ptr);
187 }
188 };
189
190 static int DefaultThreadPoolSize() {
191 // Google's CI system exposes an environment variable NPROC that describes
192 // a CPU reservation for tests.
193 // TODO(phawkins): expose a better thought-out set of knobs to control
194 // parallelism.
195 const char* nproc_str = std::getenv("NPROC");
196 int nproc = 0;
197 if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
198 return std::max(0, nproc);
199 }
200 return tensorflow::port::MaxParallelism();
201 }
202
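// Constructs the client. Takes ownership of the devices and of the optional
// device and host allocators, registers each device with this client, and
// builds the addressable-device table indexed by local_hardware_id.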
203 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
204 std::string platform_name, LocalClient* client,
205 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,
206 std::unique_ptr<se::DeviceMemoryAllocator> allocator,
207 std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
208 bool should_stage_host_to_device_transfers,
209 std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
210 : platform_id_(tensorflow::Fingerprint64(platform_name)),
211 platform_name_(std::move(platform_name)),
212 client_(client),
213 host_memory_allocator_(std::move(host_memory_allocator)),
214 owned_devices_(std::move(devices)),
215 task_id_(task_id),
216 owned_allocator_(std::move(allocator)),
217 should_stage_host_to_device_transfers_(
218 should_stage_host_to_device_transfers),
219 gpu_run_options_(std::move(gpu_run_options)),
220 thread_pool_(
221 tensorflow::Env::Default(), "pjrt_thread_pool",
222 std::max<int>(DefaultThreadPoolSize(), client->device_count())) {
223 if (owned_allocator_ != nullptr) {
224 allocator_ = owned_allocator_.get();
225 } else {
226 allocator_ = client_->backend().memory_allocator();
227 }
228
229 if (!host_memory_allocator_) {
230 host_memory_allocator_ = std::make_unique<CpuAllocator>();
231 }
232
233 for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
234 owned_devices_) {
235 devices_.push_back(device.get());
236 CHECK(id_to_device_.insert({device->id(), device.get()}).second)
237 << "Duplicate device id: " << device->id();
238
239 if (device->IsAddressable()) {
240 int idx = device->local_hardware_id();
241 if (idx >= addressable_devices_.size()) {
242 addressable_devices_.resize(idx + 1);
243 }
244 CHECK(addressable_devices_[idx] == nullptr) << idx;
245 addressable_devices_[idx] = device.get();
246 }
247 device->SetClient(this);
248 }
249 for (int idx = 0; idx < addressable_devices_.size(); ++idx) {
250 CHECK(addressable_devices_[idx] != nullptr) << idx;
251 }
252 }
253
254 StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
255 int num_replicas, int num_partitions) const {
256 return client_->backend().computation_placer()->AssignDevices(num_replicas,
257 num_partitions);
258 }
259
260 StatusOr<std::unique_ptr<HloCostAnalysis>>
261 PjRtStreamExecutorClient::GetHloCostAnalysis() {
262 return absl::make_unique<HloCostAnalysis>(
263 client_->backend().compiler()->ShapeSizeBytesFunction());
264 }
265
266 namespace {
267
268 // Ensures that it is safe to deallocate any buffers that have been enqueued in
269 // an operation on stream. Called only in rare error cases that are triggered
270 // during enqueue. These cases generally correspond to resource exhaustion.
271 void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
272 switch (local_device->allocation_model()) {
273 case LocalDeviceState::kAsynchronous:
274 // We can safely deallocate any dangling buffers immediately. NOTE: this
275 // assumes that any buffers enqueued on stream are local to stream's
276 // executor, and manual action may be needed if that condition is not met.
277 break;
278
279 case LocalDeviceState::kComputeSynchronized:
280 // This will stall computation but that's ok in this very rare error
281 // case.
282 if (stream != local_device->compute_stream()) {
283 local_device->compute_stream()->ThenWaitFor(stream);
284 }
285 break;
286
287 case LocalDeviceState::kSynchronous:
288 // This will stall the calling thread but that's ok in this very rare
289 // error case. If the stall fails just crash, since we have no other
290 // way to synchronize.
291 TF_CHECK_OK(stream->BlockHostUntilDone());
292 break;
293 }
294 }
295
296 // Does all necessary bookkeeping, after a buffer is successfully enqueued onto
297 // a stream, to ensure that the buffer will be kept alive until its use on that
298 // stream is complete.
299 //
300 // device_buffer: the buffer that was enqueued.
301 // buffer_local_device: the device the buffer was allocated on.
302 // stream_local_device: the device that manages usage_stream.
303 // event: an event that was recorded on usage_stream
304 // after the usage of device_buffer was enqueued.
305 // usage_stream: the stream the operation using device_buffer
306 // was enqueued on.
307 // prefer_to_retain_reference: relevant only for the compute synchronous
308 // allocation model. If true, retain a reference
309 // to device_buffer until after the operation
310 // completes. If false then the compute stream
311 // will have to be synchronized past event before
312 // device_buffer can be freed.
313 //
314 // prefer_to_retain_reference encodes a heuristic set by the caller for the
315 // compute synchronous model:
316 //
317 // Generally when a buffer is the destination of a copy to a device, it will
318 // subsequently be used on the device's compute stream before being freed. In
319 // that case, there is no need to retain a reference to the buffer. If the
320 // buffer is freed before being used on the compute stream, the free will be
321 // delayed until the host knows that event has completed, but this is expected
322 // to be uncommon.
323 //
324 // When a buffer is the source of a copy from a device, we need to either retain
325 // a reference to the buffer until the copy completes or serialize the compute
326 // stream behind the copy. It is often better to retain a reference since while
327 // that keeps memory alive longer, it avoids stalling the compute stream.
328 void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
329 LocalDeviceState* buffer_local_device,
330 LocalDeviceState* stream_local_device,
331 std::shared_ptr<BufferSequencingEvent> event,
332 se::Stream* usage_stream, bool prefer_to_retain_reference) {
333 bool retain_buffer_until_completion =
334 // If the buffer wasn't allocated on the same device as the stream, always
335 // retain a reference.
336 (stream_local_device != buffer_local_device) ||
337 // In the synchronous allocation model, always retain a reference.
338 (stream_local_device->allocation_model() ==
339 LocalDeviceState::kSynchronous) ||
340 // In the compute synchronous model, use the caller's heuristic.
341 (stream_local_device->allocation_model() ==
342 LocalDeviceState::kComputeSynchronized &&
343 prefer_to_retain_reference);
344 if (retain_buffer_until_completion) {
345 buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer());
346 }
347 device_buffer.ConvertUsageHold(usage_stream, event,
348 retain_buffer_until_completion);
349 }
350
351 // Allocates the device buffers for a buffer that will be used as the
352 // destination of a copy, either from the host or another device. copy_stream
353 // may be nullptr, e.g., when allocating a buffer for a cross-host copy. If the
354 // buffer is a tuple then the tuple tables are allocated, and all necessary
355 // synchronization for them is dealt with, before the buffer is returned.
356 //
357 // It is safe to delete the returned PjRtBuffer without further
358 // synchronization if an error occurs before the buffer is used.
359 //
360 // The caller may optionally provide a definition event to be recorded in
361 // the buffer.
362 // TODO(phawkins): replace on_host_shape here with on_device_shape.
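//
// A sketch of the typical call pattern (see BufferFromHostLiteral below for a
// concrete use; the ellipsis stands for the caller's copy logic):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
//       AllocateDestinationBuffer(shape, device, local_device, h2d_stream,
//                                 /*is_uninitialized_create=*/false, client));
//   PjRtStreamExecutorBuffer::ScopedHold hold =
//       py_buffer->GetBufferWithUsageHold();
//   ...enqueue the copy into the buffer on h2d_stream...
//   auto event = hold->definition_events()[0];
//   TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization(
//       local_device, std::move(hold), event, h2d_stream));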
363 StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
364 const Shape& on_host_shape, PjRtDevice* device,
365 LocalDeviceState* local_device, se::Stream* copy_stream,
366 bool is_uninitialized_create, PjRtClient* client,
367 std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
368 if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
369 return InvalidArgument("Can't make a buffer from an empty tuple");
370 }
371
372 auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
373 TransferManager* transfer_manager =
374 se_client->client()->backend().transfer_manager();
375 TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
376 transfer_manager->AllocateScopedShapedBuffer(
377 on_host_shape, se_client->allocator(),
378 local_device->device_ordinal()));
379 if (local_device->allocation_model() ==
380 LocalDeviceState::kComputeSynchronized) {
381 if (copy_stream == nullptr) {
382 CHECK(is_uninitialized_create);
383 } else {
384 copy_stream->ThenWaitFor(local_device->compute_stream());
385 }
386 } else {
387 DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
388 local_device->compute_stream()->parent(), dst_buffer));
389 }
390 Shape on_device_shape = dst_buffer.on_device_shape();
391
392 absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
393 definition_events;
394 if (is_uninitialized_create) {
395 // There is not going to be any copy into the buffer so in general we don't
396 // need a definition event.
397 if (local_device->allocation_model() ==
398 LocalDeviceState::kComputeSynchronized) {
399 // The allocation is not valid until the compute stream passes this point,
400 // so add a definition event in the compute stream.
401 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
402 TF_ASSIGN_OR_RETURN(EventPool::Handle event,
403 local_device->event_pool().ThenAllocateAndRecordEvent(
404 local_device->compute_stream()));
405 definition_events.back()->SetSequencingEvent(
406 std::move(event), local_device->compute_stream());
407 }
408 // if the caller provided a definition event then we record that.
409 if (definition_event) {
410 definition_events.emplace_back(definition_event);
411 }
412 } else {
413 // We have at least one definition event, for the copy completing to
414 // the device buffers.
415 if (definition_event) {
416 definition_events.emplace_back(definition_event);
417 } else {
418 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
419 }
420 }
421 se::Stream* tuple_table_stream = local_device->host_to_device_stream();
422 if (on_device_shape.IsTuple()) {
423 // We also need to copy the tuple tables, so we'll have an additional
424 // definition event for that copy to complete.
425 if (tuple_table_stream != copy_stream) {
426 if (local_device->allocation_model() ==
427 LocalDeviceState::kComputeSynchronized) {
428 tuple_table_stream->ThenWaitFor(local_device->compute_stream());
429 } else {
430 DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
431 local_device->compute_stream()->parent(), dst_buffer));
432 }
433 }
434
435 TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
436 tuple_table_stream, dst_buffer));
437 // CAUTION: From this point onwards we need to be careful about returning
438 // from error cases because we have started a transfer and must not allow
439 // dst_buffer to be freed too soon in the non-async allocation models.
440
441 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
442 StatusOr<EventPool::Handle> event_or =
443 local_device->event_pool().ThenAllocateAndRecordEvent(
444 tuple_table_stream);
445 if (!event_or.ok()) {
446 StallStreamOnError(local_device, tuple_table_stream);
447 return event_or.status();
448 }
449 definition_events.back()->SetSequencingEvent(event_or.ConsumeValueOrDie(),
450 tuple_table_stream);
451 }
452 std::shared_ptr<TrackedDeviceBuffer> dst_device_buffer =
453 TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
454 definition_events);
455
456 auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
457 on_device_shape, std::move(dst_device_buffer), client, device);
458
459 if (on_device_shape.IsTuple()) {
460 // Add a usage hold for the tuple table write and immediately convert it to
461 // the appropriate form of synchronization. prefer_to_retain_reference=false
462 // means don't retain a memory reference until the transfer is complete when
463 // using the ComputeSynchronized allocation model. This is a heuristic
464 // because in the common case destination buffers will be used on the
465 // compute stream and therefore don't require any synchronization before
466 // being freed. If the buffer is allocated and never used, the free will
467 // take longer and this is assumed to be ok.
468 RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
469 definition_events.back(), tuple_table_stream,
470 /*prefer_to_retain_reference=*/false);
471 }
472
473 return py_buffer;
474 }
475
476 // Adds necessary synchronization after a copy has been enqueued to a buffer.
477 // definition_event was added when the buffer was allocated, but has not yet
478 // had an event recorded.
479 Status AddDestinationBufferSynchronization(
480 LocalDeviceState* local_device,
481 PjRtStreamExecutorBuffer::ScopedHold device_buffer,
482 std::shared_ptr<BufferSequencingEvent> definition_event,
483 se::Stream* copy_stream) {
484 StatusOr<EventPool::Handle> event_or =
485 local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream);
486 if (!event_or.ok()) {
487 StallStreamOnError(local_device, copy_stream);
488 return event_or.status();
489 }
490 definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(),
491 copy_stream);
492 // prefer_to_retain_reference=false means don't retain a memory reference
493 // until the transfer is complete when using the ComputeSynchronized
494 // allocation model. This is a heuristic because in the common case
495 // destination buffers will be used on the compute stream and therefore don't
496 // require any synchronization before being freed. If the buffer is allocated
497 // and never used, the free will take longer and this is assumed to be ok.
498 RecordUsage(std::move(device_buffer), local_device, local_device,
499 definition_event, copy_stream,
500 /*prefer_to_retain_reference=*/false);
501 return Status::OK();
502 }
503
504 } // namespace
505
506 PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() {
507 if (ok()) {
508 parent_->DropHold(type_, buffer().get());
509 }
510 }
511
512 PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
513 : parent_(other.parent_),
514 type_(other.type_),
515 state_(other.state_),
516 status_(std::move(other.status_)),
517 buffer_(std::move(other.buffer_)) {
518 // Preserve the invariant that status is invalid if buffer == nullptr.
519 other.SetState(kMoved);
520 }
521
522 void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
523 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
524 CHECK(!ok());
525 if (buffer_or.ok()) {
526 buffer_ = buffer_or.ValueOrDie();
527 SetState(kValid);
528 } else {
529 status_ = buffer_or.status();
530 buffer_ = nullptr;
531 SetState(kError);
532 }
533 // Check the invariant holds.
534 CHECK(!ok() || buffer_ != nullptr);
535 }
536
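// Transfers the hold's contents into the tuple form used to move it into a
// closure; the original hold is marked released so that its destructor does
// not drop the hold again.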
537 PjRtStreamExecutorBuffer::ScopedHold::ForClosure
538 PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
539 CHECK(ok());
540 ForClosure for_closure(parent_, type_, state_, std::move(status_),
541 std::move(buffer_));
542 SetState(kReleased);
543 return for_closure;
544 }
545
546 void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
547 se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
548 bool reference_held) {
549 CHECK(ok());
550 CHECK_EQ(type_, kUsage);
551 parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event),
552 reference_held);
553 SetState(kConverted);
554 }
555
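// Completes a donation hold: notifies the parent buffer, which gives up
// ownership of the donated device memory, and marks this hold as donated.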
556 void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() {
557 CHECK(ok());
558 CHECK_EQ(type_, kDonation);
559 parent_->ConfirmDonation(buffer().get());
560 SetState(kDonated);
561 }
562
563 void PjRtStreamExecutorBuffer::ScopedHold::AddToInput(
564 ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
565 const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
566 ExecutionInput* execution_input,
567 se::DeviceMemoryAllocator* allocator) const {
568 CHECK(ok());
569 if (type_ == kDonation) {
570 buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
571 } else {
572 CHECK_EQ(type_, kUsage);
573 buffer()->AddToInputAsImmutable(iterator, end);
574 }
575 }
576
577 bool PjRtStreamExecutorBuffer::IsOnCpu() const {
578 return client()->platform_id() == kCpuId;
579 }
580
581 namespace {
582
583 // Implements PjRtBuffer::ExternalReference as a wrapped
584 // ScopedHold::kExternalReference.
585 class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference {
586 public:
587 explicit ScopedHoldAsExternalReference(
588 PjRtStreamExecutorBuffer::ScopedHold hold)
589 : external_reference_(std::move(hold)) {
590 CHECK(external_reference_.type() ==
591 PjRtStreamExecutorBuffer::ScopedHold::kExternalReference);
592 data_ptr_ = external_reference_->device_memory().front().opaque();
593 }
594
595 ~ScopedHoldAsExternalReference() override = default;
596
597 private:
598 PjRtStreamExecutorBuffer::ScopedHold external_reference_;
599 };
600
601 } // namespace
602
603 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
604 PjRtStreamExecutorBuffer::AcquireExternalReference() {
605 ScopedHold hold = GetBufferWithExternalReference();
606 Status hold_status = hold.status();
607 if (!hold_status.ok()) return hold_status;
608 return std::unique_ptr<ExternalReference>(
609 std::make_unique<ScopedHoldAsExternalReference>(std::move(hold)));
610 }
611
612 class TrackedDeviceBufferExternalReference
613 : public PjRtBuffer::ExternalReference {
614 public:
615 explicit TrackedDeviceBufferExternalReference(
616 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)
617 : tracked_device_buffer_(std::move(tracked_device_buffer)) {
618 data_ptr_ = tracked_device_buffer_->device_memory()[0].opaque();
619 }
620
621 ~TrackedDeviceBufferExternalReference() override = default;
622
623 private:
624 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer_;
625 };
626
627 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
628 PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership(
629 bool wait_for_operations_to_complete) {
630 if (on_device_shape_.IsTuple()) {
631 return InvalidArgument(
632 "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
633 }
634 TF_ASSIGN_OR_RETURN(
635 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer,
636 Release(wait_for_operations_to_complete));
637
638 std::unique_ptr<PjRtBuffer::ExternalReference> ref;
639 if (tracked_device_buffer) {
640 ref = std::make_unique<TrackedDeviceBufferExternalReference>(
641 std::move(tracked_device_buffer));
642 }
643 return ref;
644 }
645
646 StatusOr<std::unique_ptr<PjRtBuffer>>
647 PjRtStreamExecutorClient::BufferFromHostBuffer(
648 const void* data, const Shape& shape,
649 HostBufferSemantics host_buffer_semantics,
650 std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
651 tensorflow::profiler::TraceMe traceme(
652 "PjRtStreamExecutorClient::BufferFromHostBuffer");
653 VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
654 << shape.ToString() << " device: " << device->DebugString();
655 if (shape.IsTuple()) {
656 return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
657 }
658 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
659 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
660 ->GetLocalDeviceState());
661 int64 size = ShapeUtil::ByteSizeOf(shape);
662
663 TransferManager* transfer_manager = client()->backend().transfer_manager();
664 TF_ASSIGN_OR_RETURN(Shape compact_shape,
665 transfer_manager->ChooseCompactLayoutForShape(shape));
666
667 // The CPU platform is special because the "host" and the "device" are in the
668 // same memory space. If the input shape is in the correct layout and we don't
669 // want to defer the copy onto a thread, we can use the following fast
670 // path.
671 bool is_cpu_platform =
672 local_device->executor()->platform()->id() == se::host::kHostPlatformId;
673 if (is_cpu_platform) {
674 // If we are on the host platform and the input buffer is sufficiently
675 // aligned, we can simply point to the input array's data without any
676 // further copies. At the time of writing we require a 16-byte alignment
677 // because XLA may generate code which requires it.
678 bool can_use_zero_copy =
679 host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
680 ((absl::bit_cast<std::uintptr_t>(data) &
681 (cpu_function_runtime::kMinAlign - 1)) == 0);
682 if (shape.layout() == compact_shape.layout() &&
683 (host_buffer_semantics ==
684 HostBufferSemantics::kImmutableOnlyDuringCall ||
685 can_use_zero_copy)) {
686 std::function<void()> on_delete_callback;
687 se::DeviceMemoryBase buffer;
692 if (can_use_zero_copy) {
693 on_delete_callback = std::move(on_done_with_host_buffer);
694 buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
695 } else {
696 void* staging_buffer = host_memory_allocator()->AllocateRaw(
697 cpu_function_runtime::kMinAlign, size);
698 buffer = se::DeviceMemoryBase(staging_buffer, size);
699 std::memcpy(staging_buffer, data, size);
700 if (on_done_with_host_buffer) {
701 on_done_with_host_buffer();
702 }
703 on_delete_callback = [staging_buffer, host_memory_allocator =
704 host_memory_allocator()]() {
705 host_memory_allocator->DeallocateRaw(staging_buffer);
706 };
707 }
708 absl::Span<const std::shared_ptr<BufferSequencingEvent>>
709 definition_events;
710 auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
711 /*allocator=*/nullptr, local_device->device_ordinal(),
712 std::initializer_list<se::DeviceMemoryBase>{buffer},
713 definition_events, std::move(on_delete_callback));
714 return std::unique_ptr<PjRtBuffer>(
715 std::make_unique<PjRtStreamExecutorBuffer>(
716 shape, std::move(device_buffer), this, device));
717 }
718 }
719
720 TF_ASSIGN_OR_RETURN(
721 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
722 AllocateDestinationBuffer(compact_shape, device, local_device,
723 local_device->host_to_device_stream(),
724 /*is_uninitialized_create=*/false, this));
725
726 PjRtStreamExecutorBuffer::ScopedHold device_buffer(
727 py_buffer->GetBufferWithUsageHold());
728 CHECK(device_buffer.ok());
729
730 // If necessary, allocate a host-side buffer for staging host-to-device
731 // transfers. On GPU this is a buffer in pinned memory.
732 std::shared_ptr<void> staging_buffer;
733 if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
734 should_stage_host_to_device_transfers()) {
735 void* ptr = host_memory_allocator()->AllocateRaw(
736 tensorflow::Allocator::kAllocatorAlignment, size);
737 staging_buffer = std::shared_ptr<void>(
738 ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
739 host_memory_allocator->DeallocateRaw(ptr);
740 });
741 }
742
743 // Copy the buffer into a staging buffer before returning control to the
744 // caller if the caller only guaranteed that the buffer is valid for the
745 // duration of the call. Otherwise, we stage (if necessary) on a separate
746 // thread.
747 if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
748 std::memcpy(staging_buffer.get(), data, size);
749 if (on_done_with_host_buffer) {
750 on_done_with_host_buffer();
751 on_done_with_host_buffer = nullptr;
752 }
753 data = nullptr;
754 }
755
756 // The host to device transfer is performed on a thread pool, mostly because
757 // it includes linearization that may be slow. It is OK to capture the
758 // py_buffer pointer because the py_buffer can't be deleted until all the
759 // usage holds have gone away.
760 // TODO(misard) assess if it would be preferable to introduce a heuristic to
761 // put the transfer into the calling thread for small literals.
762 auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
763 data, size,
764 movable_device_buffer{device_buffer.ToClosure()}, shape,
765 py_buffer{py_buffer.get()},
766 on_device_shape{py_buffer->on_device_shape()},
767 staging_buffer{std::move(staging_buffer)},
768 on_done_with_host_buffer{
769 std::move(on_done_with_host_buffer)},
770 host_buffer_semantics]() {
771 PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
772 // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
773 // to report failures from a callback. However, the operations here are
774 // unlikely to fail and not recoverable even if we were to fail: DMAs to
775 // memory that has already been allocated, and a possible Event
776 // allocation.
777
778 ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
779 // If applicable on the backend, stage the transfer via host memory
780 // allocated via the host_memory_allocator. On GPU, this is pinned
781 // memory.
782 if (staging_buffer) {
783 // If we didn't already copy the input buffer into the staging buffer,
784 // do so now.
785 if (host_buffer_semantics !=
786 HostBufferSemantics::kImmutableOnlyDuringCall) {
787 std::memcpy(staging_buffer.get(), data, size);
788 }
789 BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
790 shape);
791 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
792 local_device->host_to_device_stream(), literal, buffer));
793 } else {
794 BorrowingLiteral literal(static_cast<const char*>(data), shape);
795 // Otherwise, just transfer the literal.
796 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
797 local_device->host_to_device_stream(), literal, buffer));
798 }
799
800 std::shared_ptr<BufferSequencingEvent> event =
801 device_buffer->definition_events()[0];
802 TF_CHECK_OK(AddDestinationBufferSynchronization(
803 local_device, std::move(device_buffer), event,
804 local_device->host_to_device_stream()));
805
806 local_device->callback_stream()->ThenWaitFor(
807 local_device->host_to_device_stream());
808 local_device->ThenExecuteOnCallbackThread(
809 local_device->callback_stream(),
810 [staging_buffer{std::move(staging_buffer)},
811 on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
812 if (on_done_with_host_buffer) {
813 on_done_with_host_buffer();
814 }
815 });
816 };
817 if (is_cpu_platform) {
818 // Using the thread_pool would be a double thread hop; the code
819 // already defers its work onto a stream (= thread on CPU).
820 transfer_h2d();
821 } else {
822 thread_pool()->Schedule(transfer_h2d);
823 }
824 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
825 }
826
827 StatusOr<std::unique_ptr<PjRtBuffer>>
828 PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
829 PjRtDevice* device) {
830 return CreateUninitializedBuffer(shape, device, nullptr);
831 }
832
833 StatusOr<std::unique_ptr<PjRtBuffer>>
834 PjRtStreamExecutorClient::CreateUninitializedBuffer(
835 const Shape& shape, PjRtDevice* device,
836 std::shared_ptr<BufferSequencingEvent> definition_event) {
837 tensorflow::profiler::TraceMe traceme(
838 "PjRtStreamExecutorClient::CreateUninitializedBuffer");
839 VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
840 << shape.ToString() << " device: " << device->DebugString();
841 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
842 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
843 ->GetLocalDeviceState());
844
845 TransferManager* transfer_manager = client()->backend().transfer_manager();
846 TF_ASSIGN_OR_RETURN(Shape compact_shape,
847 transfer_manager->ChooseCompactLayoutForShape(shape));
848
849 TF_ASSIGN_OR_RETURN(
850 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
851 AllocateDestinationBuffer(compact_shape, device, local_device,
852 /*copy_stream=*/nullptr,
853 /*is_uninitialized_create=*/true, this,
854 definition_event));
855 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
856 }
857
858 StatusOr<std::unique_ptr<PjRtBuffer>>
859 PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
860 PjRtDevice* device) {
861 tensorflow::profiler::TraceMe traceme(
862 "PjRtStreamExecutorClient::BufferFromHostLiteral");
863 VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
864 << literal.shape().ToString() << " device: " << device->DebugString();
865 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
866 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
867 ->GetLocalDeviceState());
868
869 TransferManager* transfer_manager = client()->backend().transfer_manager();
870 TF_ASSIGN_OR_RETURN(
871 Shape compact_shape,
872 transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
873 TF_ASSIGN_OR_RETURN(
874 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
875 AllocateDestinationBuffer(compact_shape, device, local_device,
876 local_device->host_to_device_stream(),
877 /*is_uninitialized_create=*/false, this));
878
879 PjRtStreamExecutorBuffer::ScopedHold device_buffer(
880 py_buffer->GetBufferWithUsageHold());
881 CHECK(device_buffer.ok());
882
883 // The host to device transfer is performed on a thread pool, mostly because
884 // it includes linearization that may be slow. It is OK to capture the
885 // py_buffer pointer because the py_buffer can't be deleted until all the
886 // usage holds have gone away.
887 // TODO(misard) assess if it would be preferable to introduce a heuristic to
888 // put the transfer into the calling thread for small literals.
889 auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
890 movable_device_buffer{device_buffer.ToClosure()},
891 literal, py_buffer{py_buffer.get()},
892 on_device_shape{py_buffer->on_device_shape()}]() {
893 PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
894 // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
895 // to report failures from a callback. However, the operations here are
896 // unlikely to fail and not recoverable even if we were to fail: DMAs to
897 // memory that has already been allocated, and a possible Event
898 // allocation.
899
900 se::Stream* h2d_stream = local_device->host_to_device_stream();
901 ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
902 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
903 h2d_stream, literal, buffer));
904
905 std::shared_ptr<BufferSequencingEvent> event =
906 device_buffer->definition_events()[0];
907 TF_CHECK_OK(AddDestinationBufferSynchronization(
908 local_device, std::move(device_buffer), event, h2d_stream));
909
910 // This can sometimes catch the case where the literal memory has been
911 // freed before the H2D transfer was issued.
912 h2d_stream->RefreshStatus()
913 .IgnoreError(); // Can return error::Unimplemented
914 QCHECK(h2d_stream->ok());
915 };
916 thread_pool()->Schedule(transfer_h2d);
917 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
918 }
919
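// Allocates one destination buffer per shape, all sharing a single definition
// event, and then hands the buffers to EnqueueCrossHostReceive, which will
// invoke `notifier` once the receive has been set up (or with an error).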
920 void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
921 absl::Span<const Shape> shapes, PjRtDevice* device,
922 PjRtCrossHostRecvNotifier&& notifier) {
923 if (shapes.empty()) {
924 notifier(InvalidArgument(
925 "shapes parameter empty in MakeCrossHostReceiveBuffers"));
926 return;
927 }
928
929 auto local_device_or =
930 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
931 ->GetLocalDeviceState();
932 if (!local_device_or.ok()) {
933 notifier(local_device_or.status());
934 return;
935 }
936 LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
937 std::shared_ptr<BufferSequencingEvent> definition_event =
938 std::make_shared<BufferSequencingEvent>();
939 std::vector<std::unique_ptr<PjRtBuffer>> buffers;
940 buffers.reserve(shapes.size());
941 for (const auto& shape : shapes) {
942 StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
943 shape, device, local_device,
944 /*copy_stream=*/nullptr,
945 /*is_uninitialized_create=*/false, this, definition_event);
946 if (!buffer_or.ok()) {
947 notifier(buffer_or.status());
948 return;
949 }
950 buffers.push_back(buffer_or.ConsumeValueOrDie());
951 }
952
953 EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
954 std::move(notifier));
955 }
956
957 StatusOr<std::unique_ptr<PjRtBuffer>>
958 PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
959 void* device_ptr, const Shape& shape, PjRtDevice* device,
960 std::function<void()> on_delete_callback) {
961 se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape));
962 absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
963 auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
964 /*allocator=*/nullptr, device->local_hardware_id(),
965 std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
966 std::move(on_delete_callback));
967 return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
968 shape, std::move(device_buffer), this, device));
969 }
970
971 // Transfer the given literal to the infeed queue of the given local device.
972 Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
973 // Only support infeed to local device.
974 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
975 return local_device->client()->TransferToInfeedLocal(
976 literal, local_device->device_ordinal());
977 }
978
979 Status PjRtStreamExecutorDevice::TransferFromOutfeed(
980 MutableBorrowingLiteral literal) {
981 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
982 return local_device->client()->TransferFromOutfeedLocal(
983 local_device->device_ordinal(), literal);
984 }
985
986 StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
987 int local_hardware_id) const {
988 for (auto* device : addressable_devices_) {
989 if (local_hardware_id == device->local_hardware_id()) {
990 return device;
991 }
992 }
993 return InvalidArgument("No matching device found for local_hardware_id %d",
994 local_hardware_id);
995 }
996
997 PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
998 Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
999 PjRtClient* client, PjRtDevice* device)
1000 : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
1001 on_device_shape_(std::move(on_device_shape)),
1002 device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
1003 device_buffer_(std::move(device_buffer)),
1004 donation_semaphore_(/*capacity=*/1) {
1005 for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1006 holds_[i] = 0;
1007 }
1008 }
1009
1010 PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() {
1011 Delete();
1012 for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1013 CHECK_EQ(holds_[i], 0);
1014 }
1015 }
1016
1017 int64 PjRtStreamExecutorBuffer::OnDeviceSizeInBytes() const {
1018 return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1019 ->client()
1020 ->backend()
1021 .transfer_manager()
1022 ->GetByteSizeRequirement(on_device_shape_);
1023 }
1024
1025 void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() {
1026 auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1027 return holds_[ScopedHold::kUsage] == 0;
1028 };
1029 mu_.Await(absl::Condition(&not_in_usage_hold));
1030 }
1031
1032 void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() {
1033 auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1034 return holds_[ScopedHold::kDonation] == 0;
1035 };
1036 mu_.Await(absl::Condition(&not_in_donation_hold));
1037 }
1038
1039 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
1040 PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
1041 tensorflow::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release");
1042 std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1043 TrackedDeviceBuffer::StreamAndEventContainer events;
1044 {
1045 absl::MutexLock lock(&mu_);
1046 // We first wait for a donation hold to complete if there is one in
1047 // progress. If the donation succeeds via ConfirmDonation() then it will
1048 // set device_buffer_ to nullptr before returning to this thread.
1049 WaitForOutstandingDonationHold();
1050 if (device_buffer_ == nullptr) {
1051 return std::shared_ptr<TrackedDeviceBuffer>();
1052 }
1053 // Set device_buffer_ to null now so that no other
1054 // thread can add a hold while we are in WaitForOutstandingUsageHolds()
1055 // below.
1056 std::swap(device_buffer_, device_buffer);
1057 WaitForOutstandingUsageHolds();
1058 // Now that all holds have completed and no more can be added, we can get
1059 // the final set of usage events.
1060 events = device_buffer->LockUseAndTransferUsageEvents();
1061 }
1062 LocalDeviceState* local_device_state =
1063 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1064 ->local_device_state();
1065 if (wait_for_operations_to_complete) {
1066 // Block the host until all usage events have completed. Usage events
1067 // dominate definition events, so this also waits for the buffer to be
1068 // defined.
1069 std::unique_ptr<se::Stream> stream;
1070 for (const auto& stream_and_event : events) {
1071 if (!stream_and_event.event->IsComplete()) {
1072 if (stream == nullptr) {
1073 stream = local_device_state->BorrowStreamFromPool();
1074 }
1075 stream_and_event.event->WaitForEventOnStream(stream.get());
1076 }
1077 }
1078 if (stream != nullptr) {
1079 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1080 local_device_state->ReturnStreamToPool(std::move(stream));
1081 }
1082 } else {
1083 if (local_device_state->allocation_model() ==
1084 LocalDeviceState::kComputeSynchronized) {
1085 std::unique_ptr<se::Stream> block_stream;
1086 for (const auto& stream_and_event : events) {
1087 // We only need to do something for events that didn't already acquire a
1088 // reference to the buffer, and also which the compute stream didn't
1089 // already wait for. Based on our heuristics this rare case should only
1090 // occur when a buffer was copied to a device and then never used there.
1091 // In that case we get a new stream and use it to hold onto a reference
1092 // to the buffer until the events are complete.
1093 if (!stream_and_event.reference_held &&
1094 !stream_and_event.event->DefinedOn(
1095 local_device_state->compute_stream()) &&
1096 !stream_and_event.event->IsComplete()) {
1097 if (block_stream == nullptr) {
1098 block_stream = local_device_state->BorrowStreamFromPool();
1099 }
1100 stream_and_event.event->WaitForEventOnStream(block_stream.get());
1101 }
1102 }
1103 if (block_stream != nullptr) {
1104 se::Stream* block_stream_ptr = block_stream.release();
1105 local_device_state->ThenExecuteOnCallbackThread(
1106 block_stream_ptr,
1107 [device_buffer, block_stream_ptr, local_device_state]() {
1108 local_device_state->ReturnStreamToPool(
1109 std::unique_ptr<se::Stream>(block_stream_ptr));
1110 });
1111 }
1112 }
1113 }
1114 return device_buffer;
1115 }
1116
1117 void PjRtStreamExecutorBuffer::Delete() {
1118 // When wait_for_operations_to_complete is false, Release should never fail.
1119 TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
1120 }
1121
1122 bool PjRtStreamExecutorBuffer::IsDeleted() {
1123 absl::MutexLock lock(&mu_);
1124 return device_buffer_ == nullptr;
1125 }
1126
1127 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
1128 PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
1129 if (type == ScopedHold::kDonation) {
1130 if (device_buffer_ == nullptr) {
1131 return InvalidArgument("Donation requested for invalid buffer");
1132 }
1133 if (holds_[ScopedHold::kExternalReference] > 0) {
1134 return InvalidArgument(
1135 "Donation requested for buffer with external reference");
1136 }
1137 // donation_semaphore_ was acquired in GetBufferWithHold so that only one
1138 // thread at a time can attempt to get a donation hold.
1139 CHECK_EQ(holds_[type], 0);
1140 // First add the donation hold.
1141 ++holds_[type];
1142 // Then wait for any usage holds to be dropped or converted. No new usage
1143 // holds can be added until we drop the donation hold so this wait will
1144 // complete eventually.
1145 WaitForOutstandingUsageHolds();
1146 // Because we added a donation hold, nobody could release the buffer while
1147 // we were waiting.
1148 CHECK(device_buffer_ != nullptr);
1149 } else {
1150 // If there is a donation hold in progress we have to wait before
1151 // acquiring any other kind of hold.
1152 WaitForOutstandingDonationHold();
1153 if (device_buffer_ == nullptr) {
1154 return InvalidArgument("Hold requested on deleted or donated buffer");
1155 } else {
1156 ++holds_[type];
1157 }
1158 }
1159 return device_buffer_;
1160 }
1161
1162 void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) {
1163 hold->Acquire(GetBufferForHoldLocked(hold->type()));
1164 }
1165
1166 void PjRtStreamExecutorBuffer::ConvertUsageHold(
1167 TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
1168 std::shared_ptr<BufferSequencingEvent> event, bool reference_held) {
1169 absl::MutexLock lock(&mu_);
1170 CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1171 buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
1172 CHECK_GT(holds_[ScopedHold::kUsage], 0);
1173 --holds_[ScopedHold::kUsage];
1174 }
1175
1176 void PjRtStreamExecutorBuffer::ConfirmDonation(
1177 TrackedDeviceBuffer* device_buffer) {
1178 {
1179 absl::MutexLock lock(&mu_);
1180 CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1181 CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1182 CHECK_EQ(holds_[ScopedHold::kDonation], 1);
1183 holds_[ScopedHold::kDonation] = 0;
1184 CHECK(device_buffer_.get() == device_buffer);
1185 // As a sanity check ensure no more usage events can be added to the buffer.
1186 device_buffer->LockUseAndTransferUsageEvents();
1187 // Give up ownership of the device memory so we don't free it when the last
1188 // reference to device_buffer_ goes away.
1189 device_buffer->ReleaseDeviceMemory();
1190 // Make *this invalid so it can't be used again. Any threads blocking in
1191 // Release or GetBufferWithHold will see an invalid buffer and return.
1192 device_buffer_.reset();
1193 }
1194 // Unblock another thread, if any, trying to get a donation hold.
1195 donation_semaphore_.Release(1);
1196 }
1197
1198 void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
1199 TrackedDeviceBuffer* buffer) {
1200 absl::MutexLock lock(&mu_);
1201 CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1202 CHECK_GT(holds_[type], 0);
1203 --holds_[type];
1204 if (type == ScopedHold::kDonation) {
1205 CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1206 CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1207 CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1208 donation_semaphore_.Release(1);
1209 }
1210 }
1211
1212 void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
1213 std::function<void(Status)> on_ready) {
1214 if (IsEmptyTuple()) {
1215 on_ready(InvalidArgument("ToLiteral called on empty tuple"));
1216 return;
1217 }
1218 LocalDeviceState* local_device =
1219 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1220 ->local_device_state();
1221 se::Stream* stream = local_device->GetDeviceToHostStream();
1222 ScopedHold device_buffer(this, ScopedHold::kUsage);
1223 {
1224 absl::MutexLock lock(&mu_);
1225 // We can't perform any other action while a donation hold is in progress.
1226 WaitForOutstandingDonationHold();
1227 if (device_buffer_ == nullptr) {
1228 on_ready(InvalidArgument(
1229 "CopyToHostAsync() called on deleted or donated buffer"));
1230 return;
1231 }
1232 AcquireHoldLocked(&device_buffer);
1233 }
1234
1235 WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
1236 ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
1237 StatusOr<EventPool::Handle> event_or =
1238 local_device->event_pool().AllocateEvent(stream->parent());
1239 if (!event_or.ok()) {
1240 on_ready(event_or.status());
1241 return;
1242 }
1243 tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1244 ->client()
1245 ->backend()
1246 .transfer_manager()
1247 ->TransferLiteralFromDevice(stream, shaped_buffer, literal,
1248 std::move(on_ready));
1249
1250 auto usage_event = std::make_shared<BufferSequencingEvent>();
1251 local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
1252 usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1253 // When using the ComputeSynchronized allocation model, retain a reference to
1254 // the device_buffer until the copy completes, to ensure that the buffer isn't
1255 // deleted or donated while it is still in use. The choice of retaining a
1256 // reference at the host is a heuristic; the alternative is to ensure, before
1257 // freeing the buffer, that the compute stream is synchronized past the
1258 // transfer, but it seems better to hold onto the buffer too long than to
1259 // stall the compute stream, particularly since the overwhelmingly common
1260 // use case of CopyToHostAsync will hold onto the reference long enough to
1261 // read the buffer in a subsequent call to ToLiteral.
1262 RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
1263 stream,
1264 /*prefer_to_retain_reference=*/true);
1265 }
1266
1267 StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
1268 absl::MutexLock lock(&mu_);
1269 if (device_buffer_ == nullptr) {
1270 return InvalidArgument(
1271 "Attempted to fetch value of invalid/deleted buffer.");
1272 }
1273 return device_buffer_->AsShapedBuffer(on_device_shape_);
1274 }
1275
1276 PjRtStreamExecutorBuffer::ScopedHold
1277 PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
1278 if (type == ScopedHold::kDonation) {
1279 // Ensure that at most one donation hold can be in progress at a time.
1280 donation_semaphore_.Acquire(1);
1281 }
1282 absl::MutexLock lock(&mu_);
1283 ScopedHold hold(this, type);
1284 AcquireHoldLocked(&hold);
1285 if (type == ScopedHold::kDonation && !hold.ok()) {
1286 donation_semaphore_.Release(1);
1287 }
1288 return hold;
1289 }
1290
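// Device-to-device copy body: allocates a destination buffer on dst_device,
// waits for the source's definition events on transfer_stream, copies each
// leaf buffer, and returns the new buffer together with the event that marks
// completion of the copy. On error the transfer stream is stalled and the
// source buffer is kept alive until any already-enqueued copies finish.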
1291 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1292 std::shared_ptr<BufferSequencingEvent>>>
1293 PjRtStreamExecutorBuffer::CopyToDeviceHelper(
1294 PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
1295 LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
1296 std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
1297 TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
1298 AllocateDestinationBuffer(
1299 ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
1300 dst_device, dst_local_device, transfer_stream,
1301 /*is_uninitialized_create=*/false, client_));
1302
1303 TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
1304
1305 WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream);
1306
1307 ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
1308 CHECK(dst_device_buffer.ok());
1309 ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);
1310
1311 // Copy the leaf buffers.
1312 StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
1313 [&]() -> StatusOr<std::shared_ptr<BufferSequencingEvent>> {
1314 for (const auto& leaf : src_buffer.buffers().leaves()) {
1315 const ShapeIndex& index = leaf.first;
1316 const se::DeviceMemoryBase& input_buffer = leaf.second;
1317 const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
1318 TF_RET_CHECK(input_buffer.size() == output_buffer.size())
1319 << "input: " << input_buffer.size()
1320 << " output: " << output_buffer.size();
1321 if (input_buffer.size() != 0) {
1322 TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
1323 transfer_stream, dst_local_device->compute_stream(), input_buffer,
1324 output_buffer));
1325 }
1326 }
1327 std::shared_ptr<BufferSequencingEvent> event =
1328 dst_device_buffer->definition_events()[0];
1329 TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization(
1330 transfer_local_device, std::move(dst_device_buffer), event,
1331 transfer_stream));
1332 return event;
1333 }();
1334 if (!copy_event_or.ok()) {
1335 StallStreamOnError(transfer_local_device, transfer_stream);
1336 if (transfer_local_device == dst_local_device) {
1337 // Some copies may have been enqueued before the error was returned, and
1338 // StallStreamOnError only makes sure the destination device is ok, so
1339 // make sure that the src buffer remains valid until after any transfers
1340 // have completed.
1341 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1342 ->local_device_state()
1343 ->ThenRelease(transfer_stream, src_device_buffer);
1344 }
1345 return copy_event_or.status();
1346 }
1347
1348 return std::pair<std::unique_ptr<PjRtBuffer>,
1349 std::shared_ptr<BufferSequencingEvent>>(
1350 std::unique_ptr<PjRtStreamExecutorBuffer>(std::move(py_buffer)),
1351 copy_event_or.ConsumeValueOrDie());
1352 }
1353
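// Copies to another device. A copy to a device owned by a different
// PjRtClient goes through a host literal; a copy within the same client is
// enqueued on either the source or destination device's device-to-device
// stream, depending on EnqueueD2DTransfersOnSrcStream().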
1354 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
1355 PjRtDevice* dst_device) {
1356 tensorflow::profiler::TraceMe traceme(
1357 "PjRtStreamExecutorBuffer::CopyToDevice");
1358 if (dst_device == device_) {
1359 return InvalidArgument(
1360 "CopyToDevice cannot accept the same source and destination devices");
1361 }
1362
1363 // Copying across PjRtClients involves a copy through the host.
1364 if (dst_device->client() != client_) {
1365 TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
1366 // Avoid use-after-free on `literal` due to unsequenced move and use.
1367 Literal* literal_pointer = literal.get();
1368 return dst_device->client()->BufferFromHostBuffer(
1369 literal_pointer->untyped_data(), literal_pointer->shape(),
1370 PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
1371 [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
1372 }
1373
1374 TF_ASSIGN_OR_RETURN(
1375 LocalDeviceState * dst_local_device,
1376 tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
1377 ->GetLocalDeviceState());
1378 LocalDeviceState* transfer_local_device =
1379 tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1380 ->EnqueueD2DTransfersOnSrcStream()
1381 ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1382 ->local_device_state()
1383 : dst_local_device;
1384 CHECK_EQ(dst_local_device->allocation_model(),
1385 transfer_local_device->allocation_model());
1386
1387 se::Stream* transfer_stream =
1388 transfer_local_device->GetDeviceToDeviceStream();
1389
1390 ScopedHold src_device_buffer(this, ScopedHold::kUsage);
1391 {
1392 absl::MutexLock lock(&mu_);
1393 // We can't perform any other action while a donation hold is in progress.
1394 WaitForOutstandingDonationHold();
1395 if (device_buffer_ == nullptr) {
1396 return InvalidArgument(
1397 "CopyToDevice called on deleted or donated buffer");
1398 }
1399 AcquireHoldLocked(&src_device_buffer);
1400 }
1401
1402 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1403 std::shared_ptr<BufferSequencingEvent>>>
1404 buffer_and_event_or = CopyToDeviceHelper(
1405 dst_device, dst_local_device, transfer_local_device, transfer_stream,
1406 src_device_buffer.buffer());
1407 if (!buffer_and_event_or.ok()) {
1408 return buffer_and_event_or.status();
1409 }
1410
1411 auto& buffer_and_event = buffer_and_event_or.ValueOrDie();
1412 std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
1413 std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;
1414
1415   // prefer_to_retain_reference=true means that, when using the
1416 // ComputeSynchronized allocation model, retain a reference to the
1417 // src_device_buffer until the copy completes. This is a heuristic; the
1418 // alternative is to ensure, before freeing the buffer, that the compute
1419 // stream is synchronized past the transfer, but it seems better to hold onto
1420 // the buffer too long than to stall the compute stream.
1421 RecordUsage(std::move(src_device_buffer),
1422 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1423 ->local_device_state(),
1424 transfer_local_device, event, transfer_stream,
1425 /*prefer_to_retain_reference=*/true);
1426
1427 return std::move(buffer);
1428 }
1429
1430 Status PjRtStreamExecutorBuffer::CopyToRemoteDevice(
1431 absl::string_view serialized_descriptor) {
1432 return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1433 ->CopyToRemoteDevice(this, serialized_descriptor);
1434 }
1435
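// Blocks the host until all of the buffer's definition events have been
// reached, by borrowing a stream from the pool, making it wait on any
// incomplete events, and then blocking on that stream.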
1436 Status PjRtStreamExecutorBuffer::BlockHostUntilReady() {
1437 tensorflow::profiler::TraceMe traceme(
1438 "PjRtStreamExecutorBuffer::BlockHostUntilReady");
1439 std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1440 {
1441 absl::MutexLock lock(&mu_);
1442 if (device_buffer_ == nullptr) {
1443 return InvalidArgument(
1444 "BlockHostUntilReady() called on deleted or donated buffer");
1445 }
1446 device_buffer = device_buffer_;
1447 }
1448 LocalDeviceState* local_device_state =
1449 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1450 ->local_device_state();
1451 std::unique_ptr<se::Stream> stream;
1452 for (auto& event : device_buffer->definition_events()) {
1453 if (!event->IsComplete()) {
1454 if (stream == nullptr) {
1455 stream = local_device_state->BorrowStreamFromPool();
1456 }
1457 event->WaitForEventOnStream(stream.get());
1458 }
1459 }
1460 if (stream != nullptr) {
1461 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1462 local_device_state->ReturnStreamToPool(std::move(stream));
1463 }
1464 return Status::OK();
1465 }
1466
1467 namespace {
1468
1469 // Helper struct for the tuple that is transiently constructed to hold the
1470 // arguments of an execution.
1471 struct TupleHandle {
1472 // The ExecutionInput describing the tuple.
1473 ExecutionInput execution_input;
1474 // A definition event that has been recorded on the host_to_device stream
1475 // after the tuple table transfer.
1476 std::shared_ptr<BufferSequencingEvent> event;
1477 };
1478
1479 // Makes a tuple from the arguments to an execution.
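// Allocates a root tuple index table on the device, writes it on the
// host-to-device stream, and chains the argument buffers into a single
// ExecutionInput; the returned event marks completion of the table transfer.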
1480 StatusOr<TupleHandle> MakeTupleHelper(
1481 PjRtClient* client, LocalDeviceState* local_device,
1482 absl::Span<PjRtBuffer* const> py_buffers,
1483 absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1484 int device_ordinal) {
1485 std::vector<Shape> host_shapes;
1486 std::vector<Shape> device_shapes;
1487 host_shapes.reserve(py_buffers.size());
1488 device_shapes.reserve(py_buffers.size());
1489 for (const PjRtBuffer* buffer : py_buffers) {
1490 device_shapes.push_back(buffer->on_device_shape());
1491 }
1492 Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
1493
1494 se::DeviceMemoryAllocator* allocator =
1495 tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
1496 TransferManager* transfer_manager =
1497 tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
1498 ->client()
1499 ->backend()
1500 .transfer_manager();
1501 se::Stream* stream = local_device->host_to_device_stream();
1502 TF_ASSIGN_OR_RETURN(
1503 se::OwningDeviceMemory root_table_memory,
1504 allocator->Allocate(
1505 device_ordinal,
1506 transfer_manager->GetByteSizeRequirement(on_device_shape)));
1507
1508 if (local_device->allocation_model() ==
1509 LocalDeviceState::kComputeSynchronized) {
1510 stream->ThenWaitFor(local_device->compute_stream());
1511 } else {
1512 DCHECK(transfer_manager->CanBufferBeAccessedNow(
1513 local_device->compute_stream()->parent(), root_table_memory.cref()));
1514 }
1515
1516 ExecutionInput execution_input(on_device_shape);
1517 ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1518 execution_input.MutableBuffers()->begin();
1519 ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1520 execution_input.MutableBuffers()->end();
1521 // First set the root tuple table which is the first buffer in the ShapeTree.
1522 execution_input.SetBuffer(
1523 input_iterator->first,
1524 MaybeOwningDeviceMemory(std::move(root_table_memory)));
1525 ++input_iterator;
1526 // Then set each sub-tuple in turn from the parameters.
1527 for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
1528 device_buffers) {
1529 device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input,
1530 allocator);
1531 }
1532 CHECK(input_iterator == iterator_end);
1533
1534 TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
1535 stream, execution_input.Buffers()));
1536 StatusOr<EventPool::Handle> event_or =
1537 local_device->event_pool().ThenAllocateAndRecordEvent(stream);
1538 if (!event_or.ok()) {
1539 StallStreamOnError(local_device, stream);
1540 return event_or.status();
1541 }
1542
1543 auto transfer_event = std::make_shared<BufferSequencingEvent>();
1544 transfer_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1545 return TupleHandle({std::move(execution_input), std::move(transfer_event)});
1546 }
1547
1548 // Converts a ScopedShapedBuffer returned from an execution into a
1549 // PjRtBuffer.
1550 std::unique_ptr<PjRtBuffer> OutputBufferHelper(
1551 ScopedShapedBuffer* result_buffer,
1552 std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
1553 PjRtDevice* device, LocalDeviceState* local_device) {
1554 std::shared_ptr<TrackedDeviceBuffer> out_buffer =
1555 TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
1556 {definition_event});
1557 auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
1558 result_buffer->on_device_shape(), std::move(out_buffer), client, device);
1559 RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
1560 definition_event, local_device->compute_stream(),
1561 /*prefer_to_retain_reference=*/false);
1562 return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
1563 }
1564 } // namespace
1565
1566 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
1567 std::vector<std::unique_ptr<LocalExecutable>> executables,
1568 bool parameter_is_tupled_arguments,
1569 std::shared_ptr<DeviceAssignment> device_assignment,
1570 std::vector<LogicalDeviceIds> addressable_device_logical_ids,
1571 std::vector<PjRtDevice*> addressable_devices,
1572 PjRtStreamExecutorClient* client)
1573 : client_(client),
1574 device_assignment_(std::move(device_assignment)),
1575 parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
1576 addressable_device_logical_ids_(
1577 std::move(addressable_device_logical_ids)),
1578 addressable_devices_(std::move(addressable_devices)) {
1579 executables_.reserve(executables.size());
1580 for (auto& executable : executables) {
1581 executables_.emplace_back(std::move(executable));
1582 }
1583
1584 int num_partitions;
1585 if (device_assignment_ == nullptr) {
1586 // This must go after `executables_` is initialized.
1587 VLOG(1) << "PjRtStreamExecutorExecutable portable single-core";
1588 num_partitions = 1;
1589 CHECK(addressable_devices_.empty());
1590 } else {
1591 // This must go after `executables_` is initialized.
1592 VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
1593 << device_assignment_->ToString();
1594 CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
1595 CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
1596 << "Inconsistent local device count.";
1597 num_partitions = device_assignment_->computation_count();
1598 }
1599
1600 // SPMD sharding produces a single executable for multiple partitions.
1601 if (executables_.size() > 1) {
1602 CHECK_EQ(num_partitions, executables_.size())
1603 << "Number of executables " << executables_.size()
1604 << " did not match number of partitions " << num_partitions;
1605 }
1606 }
1607
1608 Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
1609 parameters_that_must_be_donated_.reserve(executables_.size());
1610 for (auto& executable : executables_) {
1611 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> parameters_to_donate,
1612 GetParametersThatMustBeDonated(
1613 executable->executable()->module(), tuple_inputs));
1614 parameters_that_must_be_donated_.emplace_back(
1615 std::move(parameters_to_donate));
1616 }
1617 return Status::OK();
1618 }
1619
1620 absl::string_view PjRtStreamExecutorExecutable::name() const {
1621 Executable* executable = executables_[0]->executable();
1622 if (executable->has_module()) {
1623 return executable->module().name();
1624 } else {
1625 return "<unknown executable>";
1626 }
1627 }
1628
1629 bool PjRtStreamExecutorExecutable::MustDonateParameter(int executable_idx,
1630 int parameter) const {
1631 return parameters_that_must_be_donated_[executable_idx].contains(parameter);
1632 }
1633
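// Builds the ExecutionInputs for a launch: either a single transient tuple
// (when the executable expects tupled parameters) or one input per argument
// buffer. Afterwards the compute stream is made to wait on every event in
// `events` so the launch cannot run ahead of its inputs.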
1634 StatusOr<std::vector<ExecutionInput>>
1635 PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
1636 int device_ordinal, const ExecuteOptions& options,
1637 absl::Span<PjRtBuffer* const> argument_handles,
1638 absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1639 absl::flat_hash_set<BufferSequencingEvent*>& events) const {
1640 std::vector<ExecutionInput> execution_inputs;
1641 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1642 // Lift tuple_handle outside the conditional so that the event it returns is
1643 // not destroyed until after the loop below that waits on events.
1644 absl::optional<TupleHandle> tuple_handle;
1645 if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
1646 TF_ASSIGN_OR_RETURN(tuple_handle,
1647 MakeTupleHelper(client_, device_state, argument_handles,
1648 device_buffers, device_ordinal));
1649 events.insert(tuple_handle->event.get());
1650 execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
1651 } else {
1652 execution_inputs.reserve(argument_handles.size());
1653 for (int i = 0; i < argument_handles.size(); ++i) {
1654 PjRtBuffer* handle = argument_handles[i];
1655
1656 // Make an ExecutionInput from the device buffer.
1657 execution_inputs.emplace_back(handle->on_device_shape());
1658 ExecutionInput& execution_input = execution_inputs.back();
1659 ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1660 execution_input.MutableBuffers()->begin();
1661 ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1662 execution_input.MutableBuffers()->end();
1663 device_buffers[i].AddToInput(
1664 &input_iterator, iterator_end, &execution_input,
1665 tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1666 ->allocator());
1667 CHECK(input_iterator == iterator_end);
1668 }
1669 }
1670
1671 for (BufferSequencingEvent* event : events) {
1672 event->WaitForEventOnStream(device_state->compute_stream());
1673 }
1674
1675 return execution_inputs;
1676 }
1677
1678 // Enqueues a computation onto the compute stream. Each buffer returned in
1679 // device_buffers has a usage hold added that must be dropped on error or
1680 // converted on success.
1681 StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
1682 absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
1683 int executable_idx, const RunId& run_id, const ExecuteOptions& options,
1684 PjRtDevice* device,
1685 std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
1686 std::shared_ptr<DeviceAssignment> device_assignment) const {
1687 int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1688 ->local_device_state()
1689 ->device_ordinal();
1690 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1691 tensorflow::profiler::TraceMeConsumer activity(
1692 "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
1693 run_id.ToInt());
1694 VLOG(3) << "Replica " << replica << ", partition " << partition
1695 << " mapped to device ordinal for execution: " << device_ordinal;
1696
1697 absl::flat_hash_set<BufferSequencingEvent*> events;
1698 device_buffers->reserve(argument_handles.size());
1699 for (int i = 0; i < argument_handles.size(); ++i) {
1700 auto* handle =
1701 tensorflow::down_cast<PjRtStreamExecutorBuffer*>(argument_handles[i]);
1702 if (handle->device() != device) {
1703 return InvalidArgument(
1704 "Buffer passed to Execute() as argument %d to replica %d is on "
1705 "device %s, but replica is assigned to device %s.",
1706 i, replica, handle->device()->DebugString(), device->DebugString());
1707 }
1708 bool must_donate = MustDonateParameter(executable_idx, i);
1709 device_buffers->emplace_back(handle->GetBufferWithHold(
1710 must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
1711 : PjRtStreamExecutorBuffer::ScopedHold::kUsage));
1712 PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
1713 device_buffers->back();
1714 if (!device_buffer.ok()) {
1715 return InvalidArgument(
1716 "Invalid buffer passed to Execute() as argument %d to replica %d: "
1717 "%s",
1718 i, replica, device_buffer.status().ToString());
1719 }
1720 // If we are trying to donate the buffer wait on the usage events as well
1721 // as the definition events to ensure that all reads have been completed
1722 // before the buffer is mutated. Usage holds are excluded during a donation
1723 // hold so we know that the set of usage events won't be modified while we
1724 // are enqueueing.
1725 GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
1726 &events);
1727 }
1728
1729 if (options.arguments_are_tupled) {
1730 if (!parameter_is_tupled_arguments_) {
1731 return InvalidArgument(
1732 "Arguments may only be supplied as a tuple when the executable was "
1733 "compiled with a single tupled parameter");
1734 }
1735 if (argument_handles.size() != 1) {
1736 return InvalidArgument(
1737 "Option arguments_are_tupled was true but %d buffers were passed to "
1738 "execution",
1739 argument_handles.size());
1740 }
1741 }
1742
1743 TF_ASSIGN_OR_RETURN(
1744 std::vector<ExecutionInput> execution_inputs,
1745 MakeExecutionInputsAndWaitForEvents(
1746 device_ordinal, options, argument_handles, *device_buffers, events));
1747
1748 ExecutableRunOptions run_options;
1749 run_options.set_stream(device_state->compute_stream());
1750 run_options.set_host_to_device_stream(device_state->host_to_device_stream());
1751 run_options.set_allocator(client_->allocator());
1752 run_options.set_intra_op_thread_pool(
1753 client_->client()->backend().eigen_intra_op_thread_pool_device());
1754 run_options.set_device_assignment(device_assignment.get());
1755 run_options.set_run_id(run_id);
1756 run_options.set_rng_seed(device_state->GetNewPrngSeed());
1757 run_options.set_gpu_executable_run_options(client_->gpu_run_options());
1758 run_options.set_launch_id(options.launch_id);
1759 if (run_options.launch_id() != 0) {
1760 VLOG(1) << "launch id for " << name() << ": " << run_options.launch_id();
1761 }
1762
1763 // The choice of where we wait is arbitrary; the reason for the wait is
1764 // pacing to avoid problems such as memory fragmentation and running ahead
1765 // too far, not for correctness. Placing it before the executable launch
1766 // allows the inputs for the next executable to be fetched even if the
1767 // launch is delayed.
1768 auto compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
1769 device_state->compute_semaphore().ScopedAcquire(1));
1770
1771 StatusOr<ExecutionOutput> result_buffer_or_status =
1772 executables_[executable_idx]->RunAsync(std::move(execution_inputs),
1773 run_options);
1774
1775 VLOG(1) << "Replica " << replica << " partition " << partition
1776 << " completed; ok=" << result_buffer_or_status.ok();
1777
1778 if (!result_buffer_or_status.ok()) {
1779 return result_buffer_or_status.status();
1780 }
1781
1782 if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
1783 ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
1784 // If we used a transient tuple for the arguments we donated its root table
1785 // buffer. In that case, and/or if we donated any input buffers that were
1786 // not aliased, the donated buffers are going to be passed back to us via
1787 // the execution output. We need to ensure they aren't freed until after
1788 // execution completes. (Currently XLA does not support aliasing tuple
1789 // tables, so if any donated parameter is a tuple there will be donated but
1790 // unaliased buffers.)
1791 std::vector<se::OwningDeviceMemory> donated_memory =
1792 execution_output.ConsumeToBeReleased();
1793 absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
1794 donated_ptrs.reserve(donated_memory.size());
1795 for (se::OwningDeviceMemory& owning : donated_memory) {
1796 // Release the owning memory so we can pass it to the closure.
1797 donated_ptrs.push_back(owning.Release());
1798 }
1799 device_state->ThenExecuteOnCallbackThread(
1800 device_state->compute_stream(),
1801 [references{std::make_tuple(executables_[executable_idx],
1802 compute_reservation, device_assignment)},
1803 donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
1804 device_ordinal]() {
1805 for (const auto& ptr : donated_ptrs) {
1806 TF_CHECK_OK(allocator->Deallocate(device_ordinal, ptr));
1807 }
1808 });
1809 } else {
1810 // Any donated memory returned by the ExecutionOutput can be immediately
1811 // freed.
1812 device_state->ThenRelease(
1813 device_state->compute_stream(),
1814 std::make_tuple(executables_[executable_idx], compute_reservation,
1815 device_assignment));
1816 }
1817
1818 return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult();
1819 }
1820
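// Wraps an execution result in PjRtBuffers. With options.untuple_result set,
// each tuple element becomes its own buffer; under the synchronous allocation
// model the root tuple table is freed only once the compute stream has passed
// the execution that produced it.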
1821 std::vector<std::unique_ptr<PjRtBuffer>>
1822 PjRtStreamExecutorExecutable::MakeOutputBuffers(
1823 int device_ordinal, const ExecuteOptions& options,
1824 ScopedShapedBuffer result_buffer,
1825 std::shared_ptr<BufferSequencingEvent> definition_event,
1826 PjRtDevice* device) const {
1827 std::vector<std::unique_ptr<PjRtBuffer>> outputs;
1828 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1829 if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
1830 int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
1831 outputs.reserve(tuple_count);
1832 // Take ownership of each of the output values, leaving only the root table
1833 // in result_buffer.
1834 for (int i = 0; i < tuple_count; ++i) {
1835 ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i});
1836 outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event,
1837 client_, device, device_state));
1838 }
1839 if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
1840 // Don't release the root buffer until after execution completes.
1841 ShapedBuffer root_buffer_holder = result_buffer.release();
1842 se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer();
1843 device_state->ThenExecuteOnCallbackThread(
1844 device_state->compute_stream(),
1845 [root_buffer, allocator{client_->allocator()}, device_ordinal]() {
1846 TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer));
1847 });
1848 }
1849 } else {
1850 outputs.push_back(OutputBufferHelper(&result_buffer, definition_event,
1851 client_, device, device_state));
1852 }
1853 return outputs;
1854 }
1855
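// Runs a single replica/partition: resolves the target device (or builds a
// trivial 1x1 device assignment for portable execution), enqueues the
// execution, records a definition event for the outputs on the compute
// stream, and finally converts the argument holds (usage holds are recorded,
// donation holds are confirmed).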
1856 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1857 PjRtStreamExecutorExecutable::ExecuteHelper(
1858 absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
1859 const RunId& run_id, const ExecuteOptions& options,
1860 PjRtDevice* device) const {
1861 std::shared_ptr<DeviceAssignment> device_assignment;
1862 if (device == nullptr) {
1863 CHECK(device_assignment_ != nullptr);
1864 const int device_id = (*device_assignment_)(replica, partition);
1865 TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
1866 device_assignment = device_assignment_;
1867 } else {
1868 CHECK(device_assignment_ == nullptr);
1869 CHECK_EQ(replica, 0);
1870 CHECK_EQ(partition, 0);
1871 CHECK(addressable_devices_.empty());
1872 device_assignment = std::make_shared<DeviceAssignment>(1, 1);
1873 (*device_assignment)(0, 0) = device->id();
1874 }
1875
1876 CHECK_EQ(device->task_id(), client_->task_id());
1877 int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1878 ->local_device_state()
1879 ->device_ordinal();
1880 tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
1881 VLOG(3) << "Replica " << replica << ", partition " << partition
1882 << " mapped to device ordinal for execution: " << device_ordinal;
1883
1884 // SPMD sharding produces a single executable for multiple partitions.
1885 int executable_idx = executables_.size() > 1 ? partition : 0;
1886
1887 std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
1888 device_buffers.reserve(argument_handles.size());
1889 StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
1890 argument_handles, replica, partition, executable_idx, run_id, options,
1891 device, &device_buffers, std::move(device_assignment));
1892
1893 if (!result_buffer_or_status.ok()) {
1894 LOG(ERROR) << "Execution of replica " << replica
1895 << " failed: " << result_buffer_or_status.status();
1896 return result_buffer_or_status.status();
1897 }
1898 ScopedShapedBuffer result_buffer =
1899 result_buffer_or_status.ConsumeValueOrDie();
1900
1901 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1902 se::Stream* stream = device_state->compute_stream();
1903 StatusOr<EventPool::Handle> event_or =
1904 device_state->event_pool().ThenAllocateAndRecordEvent(stream);
1905 if (!event_or.ok()) {
1906 StallStreamOnError(device_state, stream);
1907 for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
1908 if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
1909 // Even though there was an error we need to call ConfirmDonation, which
1910 // renders b invalid, since the computation has been enqueued and b has
1911 // been donated.
1912 b.ConfirmDonation();
1913 }
1914 }
1915 return event_or.status();
1916 }
1917 auto definition_event = std::make_shared<BufferSequencingEvent>();
1918 definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1919 std::vector<std::unique_ptr<PjRtBuffer>> outputs =
1920 MakeOutputBuffers(device_ordinal, options, std::move(result_buffer),
1921 definition_event, device);
1922
1923 for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
1924 // prefer_to_retain_reference=false because when using the
1925 // ComputeSynchronized allocation model we don't need to retain a reference
1926 // to the device_buffer during execution because by definition the compute
1927 // stream is synchronized past the execution.
1928 if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
1929 RecordUsage(std::move(b), device_state, device_state, definition_event,
1930 stream,
1931 /*prefer_to_retain_reference=*/false);
1932 } else {
1933 CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
1934 b.ConfirmDonation();
1935 }
1936 }
1937
1938 return outputs;
1939 }
1940
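// Launches one ExecuteHelper per addressable device. A single-device launch
// runs on the calling thread; otherwise each launch is scheduled on its
// device's execute thread, and if any launch fails while the others do not
// terminate within 10 seconds the process is aborted to break a possible
// cross-replica deadlock (see b/130629719).
//
// Example (sketch, not part of this file): one argument list per addressable
// device, e.g.
//   std::vector<std::vector<PjRtBuffer*>> args = {{input.get()}};
//   auto outputs = executable->Execute(args, ExecuteOptions());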
1941 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
1942 PjRtStreamExecutorExecutable::Execute(
1943 absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
1944 const ExecuteOptions& options) {
1945 if (device_assignment_ == nullptr) {
1946 return InvalidArgument("Execute expects a non-null device_assignment");
1947 }
1948
1949 RunId run_id;
1950 tensorflow::profiler::TraceMeProducer activity(
1951 "PjRtStreamExecutorExecutable::Execute",
1952 tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
1953
1954 const int num_addressable_devices = addressable_devices_.size();
1955
1956 if (argument_handles.size() != num_addressable_devices) {
1957 return InvalidArgument(
1958 "Attempted to execute with %d argument lists when local device "
1959 "count is %d (total replica count: %d, partition count: %d)",
1960 argument_handles.size(), num_addressable_devices, num_replicas(),
1961 num_partitions());
1962 }
1963
1964 VLOG(1) << "Executing computation " << name()
1965 << "; num_replicas=" << num_replicas()
1966 << " num_partitions=" << num_partitions()
1967 << " num_addressable_devices=" << num_addressable_devices;
1968 std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results(
1969 num_addressable_devices);
1970 if (num_addressable_devices == 1) {
1971 // Fast-path if there is only one device — run the computation on the
1972 // current thread.
1973 const int replica = addressable_device_logical_ids_[0].replica;
1974 const int partition = addressable_device_logical_ids_[0].partition;
1975 results[0] =
1976 ExecuteHelper(argument_handles[0], replica, partition, run_id, options);
1977 } else {
1978 absl::Mutex mu;
1979 int running = num_addressable_devices;
1980 int failed = 0;
1981 Status first_failure_status;
1982
1983 for (int i = 0; i < num_addressable_devices; ++i) {
1984 const int replica = addressable_device_logical_ids_[i].replica;
1985 const int partition = addressable_device_logical_ids_[i].partition;
1986 PjRtDevice* device = addressable_devices_[i];
1987 const LocalDeviceState& device_state =
1988 *tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1989 ->local_device_state();
1990 device_state.execute_thread()->Schedule([&, replica, partition, i] {
1991 results[i] = ExecuteHelper(argument_handles[i], replica, partition,
1992 run_id, options);
1993
1994 absl::MutexLock lock(&mu);
1995 --running;
1996 if (!results[i].ok()) {
1997 if (failed == 0) {
1998 first_failure_status = results[i].status();
1999 }
2000 ++failed;
2001 }
2002 });
2003 }
2004
2005 auto done_running_or_failed = [&]() {
2006 mu.AssertHeld();
2007 return running == 0 || failed > 0;
2008 };
2009 absl::MutexLock lock(&mu);
2010 mu.Await(absl::Condition(&done_running_or_failed));
2011 if (failed > 0) {
2012 auto done_running = [&]() {
2013 mu.AssertHeld();
2014 return running == 0;
2015 };
2016 // If execution does not terminate within a reasonable amount of time,
2017 // we may be stuck at a cross-replica barrier on-device. Terminate the
2018 // process since that's the only way we can escape this situation at the
2019 // moment (b/130629719).
2020 if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
2021 absl::Seconds(10))) {
2022 LOG(FATAL)
2023 << "Replicated computation launch failed, but not all replicas "
2024 "terminated. Aborting process to work around deadlock. "
2025 "Failure message (there may have been multiple failures, see "
2026 "the error log for all failures): \n\n"
2027 << first_failure_status.error_message();
2028 }
2029 }
2030 }
2031 VLOG(1) << "Replicated execution complete.";
2032
2033 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
2034 num_addressable_devices);
2035 for (int i = 0; i < num_addressable_devices; ++i) {
2036 const int replica = addressable_device_logical_ids_[i].replica;
2037 const int partition = addressable_device_logical_ids_[i].partition;
2038 auto& statusor = results[i];
2039 if (!statusor.ok()) {
2040 if (num_addressable_devices == 1) {
2041 return statusor.status();
2042 } else {
2043 return AppendStatus(
2044 statusor.status(),
2045 absl::StrFormat("while running replica %d and partition %d of a "
2046 "replicated computation (other "
2047 "replicas may have failed as well).",
2048 replica, partition));
2049 }
2050 }
2051 wrapped_results[i] = std::move(statusor.ValueOrDie());
2052 }
2053 return wrapped_results;
2054 }
2055
2056 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2057 PjRtStreamExecutorExecutable::ExecuteSharded(
2058 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2059 const ExecuteOptions& options) {
2060 if (device_assignment_ == nullptr) {
2061     return InvalidArgument(
2061         "ExecuteSharded expects a non-null device_assignment");
2062 }
2063 for (int i = 0; i < addressable_devices_.size(); ++i) {
2064 if (addressable_devices_[i] == device) {
2065       VLOG(1) << "ExecuteSharded executes computation " << name()
2066 << " on assigned replica/partition on device "
2067 << device->DebugString();
2068 return ExecuteHelper(
2069 argument_handles, addressable_device_logical_ids_[i].replica,
2070 addressable_device_logical_ids_[i].partition, RunId(), options);
2071 }
2072 }
2073 return InvalidArgument(
2074       "ExecuteSharded attempted to execute on device id %d which is not "
2075 "addressable by this client",
2076 device->id());
2077 }
2078
2079 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2080 PjRtStreamExecutorExecutable::ExecutePortable(
2081 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2082 const ExecuteOptions& options) {
2083 if (device_assignment_ != nullptr) {
2084     return InvalidArgument(
2084         "ExecutePortable was called on a non-portable executable");
2085 }
2086 if (num_replicas() != 1 || num_partitions() != 1) {
2087 return InvalidArgument(
2088         "ExecutePortable expects a single-core executable but got "
2089         "one with %d replicas and %d partitions",
2090 num_replicas(), num_partitions());
2091 }
2092 if (device == nullptr) {
2093 return InvalidArgument("ExecutePortable expects a device to be specified");
2094 }
2095 VLOG(1) << "ExecutePortable executes single-core portable executable "
2096 << name();
2097 return ExecuteHelper(argument_handles,
2098 /*replica=*/0,
2099 /*partition=*/0, RunId(), options, device);
2100 }
2101
2102 StatusOr<std::vector<std::shared_ptr<HloModule>>>
2103 PjRtStreamExecutorExecutable::GetHloModules() const {
2104 std::vector<std::shared_ptr<HloModule>> modules;
2105 modules.reserve(executables().size());
2106 for (const auto& local_exec : executables()) {
2107 if (!local_exec->executable()->has_module()) {
2108 return InvalidArgument("Executable does not have HLO modules.");
2109 }
2110 modules.push_back(local_exec->executable()->shared_module());
2111 }
2112 return std::move(modules);
2113 }
2114
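// Compiles an XlaComputation: fills in default build options, resolves the
// device assignment and argument layouts, collects the devices addressable by
// this task, compiles via the wrapped LocalClient, and wraps the result in a
// PjRtStreamExecutorExecutable with donation bookkeeping set up.
//
// Example (sketch, not part of this file):
//   auto executable_or = client->Compile(computation, CompileOptions());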
2115 StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
2116 const XlaComputation& computation, CompileOptions options) {
2117 tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
2118
2119 ExecutableBuildOptions& build_options = options.executable_build_options;
2120 if (!build_options.compile_thread_pool()) {
2121 build_options.set_compile_thread_pool(thread_pool());
2122 }
2123 if (!build_options.device_allocator()) {
2124 build_options.set_device_allocator(allocator());
2125 }
2126
2127 int num_replicas;
2128 int num_partitions;
2129 std::shared_ptr<DeviceAssignment> device_assignment;
2130 TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
2131 options.compile_portable_executable, &options.executable_build_options,
2132 [this](int num_replicas, int num_partitions) {
2133 return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
2134 },
2135 &num_replicas, &num_partitions, &device_assignment));
2136
2137 std::vector<const Shape*> argument_layout_pointers;
2138 TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
2139 computation,
2140 [local_client = client()](Shape shape) {
2141 return local_client->backend()
2142 .transfer_manager()
2143 ->ChooseCompactLayoutForShape(shape);
2144 },
2145 options.argument_layouts, &options.executable_build_options,
2146 &argument_layout_pointers));
2147
2148 // Find devices that are addressable by this client/task.
2149 std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;
2150 std::vector<PjRtDevice*> addressable_devices;
2151 if (device_assignment != nullptr) {
2152 addressable_device_logical_ids.reserve(num_replicas * num_partitions);
2153 addressable_devices.reserve(num_replicas * num_partitions);
2154 for (int replica = 0; replica < num_replicas; ++replica) {
2155 for (int partition = 0; partition < num_partitions; ++partition) {
2156 int device_id = (*device_assignment)(replica, partition);
2157 TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
2158 if (device->task_id() != task_id()) {
2159 VLOG(3) << "Non-local device: " << device_id;
2160 continue;
2161 }
2162         PjRtExecutable::LogicalDeviceIds logical_device_ids;
2163         logical_device_ids.replica = replica;
2164         logical_device_ids.partition = partition;
2165         addressable_device_logical_ids.push_back(
2165             std::move(logical_device_ids));
2166 addressable_devices.push_back(device);
2167 }
2168 }
2169 if (addressable_devices.empty()) {
2170 return InvalidArgument(
2171 "Device assignment (%s) does not have any local devices.",
2172 device_assignment->ToString());
2173 }
2174
2175 if (build_options.device_ordinal() < 0) {
2176 build_options.set_device_ordinal(
2177 addressable_devices.front()->local_hardware_id());
2178 }
2179 }
2180
2181 TF_ASSIGN_OR_RETURN(
2182 std::vector<std::unique_ptr<LocalExecutable>> local_executables,
2183 client()->Compile(computation, argument_layout_pointers, build_options));
2184
2185 auto executable = absl::make_unique<PjRtStreamExecutorExecutable>(
2186 std::move(local_executables), options.parameter_is_tupled_arguments,
2187 std::move(device_assignment), std::move(addressable_device_logical_ids),
2188 std::move(addressable_devices), this);
2189 TF_RETURN_IF_ERROR(
2190 executable->SetUpDonation(options.parameter_is_tupled_arguments));
2191 return std::unique_ptr<PjRtExecutable>(std::move(executable));
2192 }
2193
2194 } // namespace xla
2195