/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation notes:
//
// Asynchronous execution:
// -----------------------
//
// Computations and host-to-device transfers do not need to block the host
// waiting for the operation to complete; instead they return control to the
// host immediately. This allows client logic to overlap with device-side
// computation.
//
// For a good user experience, we must be careful only to enqueue operations
// that are unlikely to fail; as a rule, error checking must be done eagerly
// before returning control to the client.
//
// The degree to which the client can enqueue operations ahead of the device
// is limited by a semaphore. There are two modes: asynchronous, where we
// allow the client to enqueue up to 32 executions ahead of the device, and
// synchronous, where we limit the client to having one enqueued operation at
// a time. The value of 32 is arbitrary.
//
// Even in asynchronous mode, it is important that we do not permit
// unbounded queue-ahead. First, it is problematic when the user does
// something like the following in Python:
//   %timeit run_computation()
// To the timeit logic, run_computation() appears to be extremely cheap since
// it defers all of its real work and does not block, so %timeit will run it
// many (e.g., 10000) times to get better timing resolution, even though each
// call may in reality be expensive. Second, on CPU the allocator is
// synchronized with the head of the compute stream, and we allocate buffers
// for all of the enqueued programs without any reuse (unlike GPU). This
// means that memory usage is proportional to the queue size.
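//
// As an illustration only (the throttling happens inside the PjRt
// implementation, not in client code), the queue-ahead limit amounts to a
// counting semaphore wrapped around enqueue; the names below are hypothetical:
//
//   Semaphore inflight(/*capacity=*/kMaxInflight);  // 32 (async) or 1 (sync).
//   void EnqueueExecution(...) {
//     inflight.Acquire();  // Blocks once the client is kMaxInflight ops ahead.
//     LaunchOnStream(..., /*on_done=*/[&inflight]() { inflight.Release(); });
//   }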
//
// Multi-stream execution:
// -----------------------
//
// We use a multistream execution design, where different Streams are used for
// host-to-device transfers, device-to-host transfers, and compute. This allows
// us to overlap transfers on and off the device with computation.
//
// Synchronization between streams occurs via BufferSequencingEvents that
// describe when the contents of a logical buffer are known to be valid on
// a particular stream, and when a buffer's uses have all completed.
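//
// Conceptually (a simplified sketch in terms of raw StreamExecutor calls; the
// code below goes through BufferSequencingEvent and EventPool instead):
//
//   h2d_stream->ThenMemcpy(&dst, src, size);     // Producer: enqueue copy.
//   h2d_stream->ThenRecordEvent(&copy_done);     // Mark when it finishes.
//   compute_stream->ThenWaitFor(&copy_done);     // Consumer: order after it.
//   compute_stream->ThenLaunch(...);             // Safe to read dst here.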
//
// Synchronous vs asynchronous deallocation:
// -----------------------------------------
//
// See the comment on LocalDeviceState::AllocationModel for a discussion of the
// different allocation semantics on CPU, GPU, and TPU.

#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"

#include <cstddef>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/metrics.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
126 return client_->platform_id();
127 }
absl::string_view PjRtStreamExecutorDevice::platform_name() const {
129 return client_->platform_name();
130 }
131
StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
133 const {
134 if (local_device_state_) {
135 return local_device_state_.get();
136 }
137 return InvalidArgument("Device %s is not a local device.", DebugString());
138 }
139
std::string PjRtStreamExecutorDevice::DebugString() const {
141 return absl::StrCat(platform_name(), ":", id());
142 }
143
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
145 absl::Span<const std::vector<PjRtDevice*>> devices) {
146 if (devices.empty()) {
147 return InvalidArgument(
148 "Device assignment passed to Compile() must be non-empty.");
149 }
150 if (devices[0].empty()) {
151 return InvalidArgument(
152 "Device assignment passed to Compile() must have a nonzero number of "
153 "partitions per replica; replica 0 had 0 partitions.");
154 }
155 DeviceAssignment xla_assignment(devices.size(), devices[0].size());
156 for (int replica = 0; replica < devices.size(); ++replica) {
157 if (devices[replica].size() != devices[0].size()) {
158 return InvalidArgument(
159 "Device assignment passed to Compile() has different numbers of "
160 "partitions between replicas; %d partitions for replica %d versus %d "
161 "partitions for replica 0.",
162 devices[replica].size(), replica, devices[0].size());
163 }
164 for (int partition = 0; partition < devices[replica].size(); ++partition) {
165 if (devices[0][0]->client()->platform_id() !=
166 devices[replica][partition]->client()->platform_id()) {
167 return InvalidArgument(
168 "Device assignment passed to Compile() must have devices of a "
169 "single kind, got %s for replica 0 partition 0 and %s for replica "
170 "%d partition %d.",
171 devices[0][0]->client()->platform_name(),
172 devices[replica][partition]->client()->platform_name(), replica,
173 partition);
174 }
175 xla_assignment(replica, partition) = devices[replica][partition]->id();
176 }
177 }
178 return xla_assignment;
179 }
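// For example, the function above maps two replicas of one partition each to
// a 2x1 assignment (an illustrative sketch; `device0` and `device1` are
// assumed to be PjRtDevice* belonging to the same client):
//
//   std::vector<std::vector<PjRtDevice*>> devices = {{device0}, {device1}};
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       DevicesToDeviceAssignment(devices));
//   // assignment(0, 0) == device0->id() and assignment(1, 0) == device1->id().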
180
181 class CpuAllocator : public tensorflow::Allocator {
182 public:
183 CpuAllocator() = default;
184
  std::string Name() override { return "cpu"; }
186
  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
188 return tensorflow::port::AlignedMalloc(num_bytes, alignment);
189 }
  void DeallocateRaw(void* ptr) override {
191 return tensorflow::port::AlignedFree(ptr);
192 }
193 };
194
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
196 std::string platform_name, LocalClient* client,
197 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
198 int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
199 std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
200 bool should_stage_host_to_device_transfers,
201 std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
202 : platform_id_(tensorflow::Fingerprint64(platform_name)),
203 platform_name_(std::move(platform_name)),
204 client_(client),
205 host_memory_allocator_(std::move(host_memory_allocator)),
206 owned_allocator_(std::move(allocator)),
207 owned_devices_(std::move(devices)),
208 process_index_(process_index),
209 should_stage_host_to_device_transfers_(
210 should_stage_host_to_device_transfers),
211 gpu_run_options_(std::move(gpu_run_options)),
212 thread_pool_(
213 tensorflow::Env::Default(), "pjrt_thread_pool",
214 std::max<int>(DefaultThreadPoolSize(), client->device_count())),
215 transpose_cache_(1024) {
216 if (owned_allocator_ != nullptr) {
217 allocator_ = owned_allocator_.get();
218 } else {
219 allocator_ = client_->backend().memory_allocator();
220 }
221
222 if (!host_memory_allocator_) {
223 host_memory_allocator_ = std::make_unique<CpuAllocator>();
224 }
225
226 for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
227 owned_devices_) {
228 devices_.push_back(device.get());
229 CHECK(id_to_device_.insert({device->id(), device.get()}).second)
230 << "Duplicate device id: " << device->id();
231
232 if (device->IsAddressable()) {
233 int idx = device->local_hardware_id();
234 if (idx >= addressable_devices_.size()) {
235 addressable_devices_.resize(idx + 1);
236 }
237 CHECK(addressable_devices_[idx] == nullptr) << idx;
238 addressable_devices_[idx] = device.get();
239 }
240 device->SetClient(this);
241 }
242 for (int idx = 0; idx < addressable_devices_.size(); ++idx) {
243 CHECK(addressable_devices_[idx] != nullptr) << idx;
244 }
245 }
246
StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
248 int num_replicas, int num_partitions) const {
249 return client_->backend().computation_placer()->AssignDevices(num_replicas,
250 num_partitions);
251 }
252
253 StatusOr<std::unique_ptr<HloCostAnalysis>>
PjRtStreamExecutorClient::GetHloCostAnalysis() {
255 return absl::make_unique<HloCostAnalysis>(
256 client_->backend().compiler()->ShapeSizeBytesFunction());
257 }
258
259 namespace {
260
261 // Ensures that it is safe to deallocate any buffers that have been enqueued in
262 // an operation on stream. Called only in rare error cases that are triggered
263 // during enqueue. These cases generally correspond to resource exhaustion.
void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
265 switch (local_device->allocation_model()) {
266 case LocalDeviceState::kAsynchronous:
267 // We can safely deallocate any dangling buffers immediately. NOTE: this
268 // assumes that any buffers enqueued on stream are local to stream's
269 // executor, and manual action may be needed if that condition is not met.
270 break;
271
272 case LocalDeviceState::kComputeSynchronized:
273 // This will stall computation but that's ok in this very rare error
274 // case.
275 if (stream != local_device->compute_stream()) {
276 local_device->compute_stream()->ThenWaitFor(stream);
277 }
278 break;
279
280 case LocalDeviceState::kSynchronous:
281 // This will stall the calling thread but that's ok in this very rare
282 // error case. If the stall fails just crash, since we have no other
283 // way to synchronize.
284 TF_CHECK_OK(stream->BlockHostUntilDone());
285 break;
286 }
287 }
288
289 // Does all necessary bookkeeping, after a buffer is successfully enqueued onto
290 // a stream, to ensure that the buffer will be kept alive until its use on that
291 // stream is complete.
292 //
293 // device_buffer: the buffer that was enqueued.
294 // buffer_local_device: the device the buffer was allocated on.
295 // stream_local_device: the device that manages usage_stream.
296 // event: an event that was recorded on usage_stream
297 // after the usage of device_buffer was enqueued.
298 // usage_stream: the stream the operation using device_buffer
299 // was enqueued on.
300 // prefer_to_retain_reference: relevant only for the compute synchronous
301 // allocation model. If true, retain a reference
302 // to device_buffer until after the operation
303 // completes. If false then the compute stream
304 // will have to be synchronized past event before
305 // device_buffer can be freed.
306 //
307 // prefer_to_retain_reference encodes a heuristic set by the caller for the
308 // compute synchronous model:
309 //
310 // Generally when a buffer is the destination of a copy to a device, it will
311 // subsequently be used on the device's compute stream before being freed. In
312 // that case, there is no need to retain a reference to the buffer. If the
313 // buffer is freed before being used on the compute stream, the free will be
314 // delayed until the host knows that event has completed, but this is expected
315 // to be uncommon.
316 //
317 // When a buffer is the source of a copy from a device, we need to either retain
318 // a reference to the buffer until the copy completes or serialize the compute
319 // stream behind the copy. It is often better to retain a reference since while
320 // that keeps memory alive longer, it avoids stalling the compute stream.
void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
322 LocalDeviceState* buffer_local_device,
323 LocalDeviceState* stream_local_device,
324 std::shared_ptr<BufferSequencingEvent> event,
325 se::Stream* usage_stream, bool prefer_to_retain_reference,
326 std::vector<std::shared_ptr<TrackedDeviceBuffer>>*
327 buffers_to_release = nullptr) {
328 tensorflow::profiler::TraceMe traceme("RecordUsage");
329 bool retain_buffer_until_completion =
330 // If the buffer wasn't allocated on the same device as the stream, always
331 // retain a reference.
332 (stream_local_device != buffer_local_device) ||
333 // In the synchronous allocation model, always retain a reference.
334 (stream_local_device->allocation_model() ==
335 LocalDeviceState::kSynchronous) ||
336 // In the compute synchronous model, use the caller's heuristic.
337 (stream_local_device->allocation_model() ==
338 LocalDeviceState::kComputeSynchronized &&
339 prefer_to_retain_reference);
340 if (retain_buffer_until_completion) {
341 if (buffers_to_release) {
342 buffers_to_release->push_back(device_buffer.buffer());
343 } else {
344 buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer());
345 }
346 }
347 device_buffer.ConvertUsageHold(usage_stream, event,
348 retain_buffer_until_completion);
349 }
350
351 // Allocates the device buffers for a buffer that will be used as the
352 // destination of a copy, either from the host or another device. copy_stream
353 // may be nullptr, e.g., when allocating a buffer for a cross-host copy. If the
354 // buffer is a tuple then the tuple tables are allocated, and all necessary
355 // synchronization for them is dealt with, before the buffer is returned.
356 //
357 // It is safe to delete the returned PjRtBuffer without further
358 // synchronization if an error occurs before the buffer is used.
359 //
360 // The caller may optionally provide a definition event to be recorded in
361 // the buffer.
362 // TODO(phawkins): replace on_host_shape here with on_device_shape.
StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
364 const Shape& on_host_shape, PjRtDevice* device,
365 LocalDeviceState* local_device, se::Stream* copy_stream,
366 bool is_uninitialized_create, PjRtClient* client,
367 std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
368 if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
369 return InvalidArgument("Can't make a buffer from an empty tuple");
370 }
371
372 auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
373 TransferManager* transfer_manager =
374 se_client->client()->backend().transfer_manager();
375 TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
376 transfer_manager->AllocateScopedShapedBuffer(
377 on_host_shape, se_client->allocator(),
378 local_device->device_ordinal()));
379 if (local_device->allocation_model() ==
380 LocalDeviceState::kComputeSynchronized) {
381 if (copy_stream == nullptr) {
382 CHECK(is_uninitialized_create);
383 } else {
384 copy_stream->ThenWaitFor(local_device->compute_stream());
385 }
386 } else {
387 DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
388 local_device->compute_stream()->parent(), dst_buffer));
389 }
390 Shape on_device_shape = dst_buffer.on_device_shape();
391
392 absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
393 definition_events;
394 if (is_uninitialized_create) {
395 // There is not going to be any copy into the buffer so in general we don't
396 // need a definition event.
397 if (local_device->allocation_model() ==
398 LocalDeviceState::kComputeSynchronized) {
399 // The allocation is not valid until the compute stream passes this point,
400 // so add a definition event in the compute stream.
401 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
402 TF_ASSIGN_OR_RETURN(EventPool::Handle event,
403 local_device->event_pool().ThenAllocateAndRecordEvent(
404 local_device->compute_stream()));
405 definition_events.back()->SetSequencingEvent(
406 std::move(event), local_device->compute_stream());
407 }
    // If the caller provided a definition event then we record that.
409 if (definition_event) {
410 definition_events.emplace_back(definition_event);
411 }
412 } else {
413 // We have at least one definition event, for the copy completing to
414 // the device buffers.
415 if (definition_event) {
416 definition_events.emplace_back(definition_event);
417 } else {
418 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
419 }
420 }
421 se::Stream* tuple_table_stream = local_device->host_to_device_stream();
422 if (on_device_shape.IsTuple()) {
423 // We also need to copy the tuple tables, so we'll have an additional
424 // definition event for that copy to complete.
425 if (tuple_table_stream != copy_stream) {
426 if (local_device->allocation_model() ==
427 LocalDeviceState::kComputeSynchronized) {
428 tuple_table_stream->ThenWaitFor(local_device->compute_stream());
429 } else {
430 DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
431 local_device->compute_stream()->parent(), dst_buffer));
432 }
433 }
434
435 TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
436 tuple_table_stream, dst_buffer));
437 // CAUTION: From this point onwards we need to be careful about returning
438 // from error cases because we have started a transfer and must not allow
439 // dst_buffer to be freed too soon in the non-async allocation models.
440
441 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
442 StatusOr<EventPool::Handle> event_or =
443 local_device->event_pool().ThenAllocateAndRecordEvent(
444 tuple_table_stream);
445 if (!event_or.ok()) {
446 StallStreamOnError(local_device, tuple_table_stream);
447 return event_or.status();
448 }
449 definition_events.back()->SetSequencingEvent(event_or.ConsumeValueOrDie(),
450 tuple_table_stream);
451 }
452 std::shared_ptr<TrackedDeviceBuffer> dst_device_buffer =
453 TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
454 definition_events);
455
456 auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
457 on_device_shape, std::move(dst_device_buffer), client, device);
458
459 if (on_device_shape.IsTuple()) {
460 // Add a usage hold for the tuple table write and immediately convert it to
461 // the appropriate form of synchronization. prefer_to_retain_reference=false
462 // means don't retain a memory reference until the transfer is complete when
463 // using the ComputeSynchronized allocation model. This is a heuristic
464 // because in the common case destination buffers will be used on the
465 // compute stream and therefore don't require any synchronization before
466 // being freed. If the buffer is allocated and never used, the free will
467 // take longer and this is assumed to be ok.
468 RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
469 definition_events.back(), tuple_table_stream,
470 /*prefer_to_retain_reference=*/false);
471 }
472
473 return py_buffer;
474 }
475
476 // Adds necessary synchronization after a copy has been enqueued to a buffer.
477 // definition_event was added when the buffer was allocated, but has not yet
478 // had an event recorded.
Status AddDestinationBufferSynchronization(
480 LocalDeviceState* local_device,
481 PjRtStreamExecutorBuffer::ScopedHold device_buffer,
482 std::shared_ptr<BufferSequencingEvent> definition_event,
483 se::Stream* copy_stream) {
484 StatusOr<EventPool::Handle> event_or =
485 local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream);
486 if (!event_or.ok()) {
487 StallStreamOnError(local_device, copy_stream);
488 return event_or.status();
489 }
490 definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(),
491 copy_stream);
492 // prefer_to_retain_reference=false means don't retain a memory reference
493 // until the transfer is complete when using the ComputeSynchronized
494 // allocation model. This is a heuristic because in the common case
495 // destination buffers will be used on the compute stream and therefore don't
496 // require any synchronization before being freed. If the buffer is allocated
497 // and never used, the free will take longer and this is assumed to be ok.
498 RecordUsage(std::move(device_buffer), local_device, local_device,
499 definition_event, copy_stream,
500 /*prefer_to_retain_reference=*/false);
501 return Status::OK();
502 }
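// Taken together, the helpers above are used in the following pattern (a
// simplified sketch of what BufferFromHostBuffer/BufferFromHostLiteral do
// later in this file; error handling and the thread-pool hop are omitted):
//
//   TF_ASSIGN_OR_RETURN(auto py_buffer,
//                       AllocateDestinationBuffer(shape, device, local_device,
//                                                 h2d_stream,
//                                                 /*is_uninitialized_create=*/false,
//                                                 client));
//   PjRtStreamExecutorBuffer::ScopedHold hold =
//       py_buffer->GetBufferWithUsageHold();
//   std::shared_ptr<BufferSequencingEvent> event =
//       hold->definition_events()[0];
//   ... enqueue the host-to-device copy on h2d_stream ...
//   TF_CHECK_OK(AddDestinationBufferSynchronization(
//       local_device, std::move(hold), event, h2d_stream));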
503
504 } // namespace
505
PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() {
507 if (ok()) {
508 parent_->DropHold(type_, buffer().get());
509 }
510 }
511
PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
513 : parent_(other.parent_),
514 type_(other.type_),
515 state_(other.state_),
516 status_(std::move(other.status_)),
517 buffer_(std::move(other.buffer_)) {
518 // Preserve the invariant that status is invalid if buffer == nullptr.
519 other.SetState(kMoved);
520 }
521
void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
523 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
524 CHECK(!ok());
525 if (buffer_or.ok()) {
526 buffer_ = buffer_or.ValueOrDie();
527 SetState(kValid);
528 } else {
529 status_ = buffer_or.status();
530 buffer_ = nullptr;
531 SetState(kError);
532 }
533 // Check the invariant holds.
534 CHECK(!ok() || buffer_ != nullptr);
535 }
536
537 PjRtStreamExecutorBuffer::ScopedHold::ForClosure
PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
539 CHECK(ok());
540 ForClosure for_closure(parent_, type_, state_, std::move(status_),
541 std::move(buffer_));
542 SetState(kReleased);
543 return for_closure;
544 }
545
void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
547 se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
548 bool reference_held) {
549 CHECK(ok());
550 CHECK_EQ(type_, kUsage);
551 parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event),
552 reference_held);
553 SetState(kConverted);
554 }
555
void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() {
557 CHECK(ok());
558 CHECK_EQ(type_, kDonation);
559 parent_->ConfirmDonation(buffer().get());
560 SetState(kDonated);
561 }
562
void PjRtStreamExecutorBuffer::ScopedHold::AddToInput(
564 ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
565 const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
566 ExecutionInput* execution_input,
567 se::DeviceMemoryAllocator* allocator) const {
568 CHECK(ok());
569 if (type_ == kDonation) {
570 buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
571 } else {
572 CHECK_EQ(type_, kUsage);
573 buffer()->AddToInputAsImmutable(iterator, end);
574 }
575 }
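// Typical lifecycle of a usage hold (an illustrative sketch; the concrete call
// sites are the PjRtStreamExecutorBuffer and PjRtStreamExecutorClient methods
// elsewhere in this file, and `buffer`, `stream`, and `event` are assumed to
// exist in the surrounding code):
//
//   PjRtStreamExecutorBuffer::ScopedHold hold(buffer, ScopedHold::kUsage);
//   {
//     absl::MutexLock lock(&buffer->mu_);
//     buffer->WaitForOutstandingDonationHold();
//     buffer->AcquireHoldLocked(&hold);
//   }
//   ... enqueue an operation that reads the buffer on `stream` ...
//   hold.ConvertUsageHold(stream, event, /*reference_held=*/false);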
576
bool PjRtStreamExecutorBuffer::IsOnCpu() const {
578 return client()->platform_id() == kCpuId;
579 }
580
StatusOr<Shape> PjRtStreamExecutorBuffer::logical_on_device_shape() {
582 if (on_device_shape_.is_static()) {
583 return on_device_shape_;
584 }
585 auto* local_device = device_->local_device_state();
586 auto* stream = local_device->GetDeviceToHostStream();
587 ScopedHold device_buffer(this, ScopedHold::kUsage);
588 {
589 absl::MutexLock lock(&mu_);
590 // We can't perform any other action while a donation hold is in progress.
591 WaitForOutstandingDonationHold();
592 if (device_buffer_ == nullptr) {
593 return InvalidArgument(
594 "logical_on_device_shape() called on deleted or donated buffer");
595 }
596 AcquireHoldLocked(&device_buffer);
597 }
598
599 WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
600 ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
601 StatusOr<EventPool::Handle> event_or =
602 local_device->event_pool().AllocateEvent(stream->parent());
603 if (!event_or.ok()) {
604 return event_or.status();
605 }
606 Shape ret_shape = on_device_shape_;
607 TransferManager* transfer_manager =
608 client_->client()->backend().transfer_manager();
609 TF_RETURN_IF_ERROR(
610 transfer_manager->ReadDynamicShapes(stream, &shaped_buffer, &ret_shape));
611 return ret_shape;
612 }
613
614 namespace {
615
616 // Implements PjRtBuffer::ExternalReference as a wrapped
617 // ScopedHold::kExternalReference.
618 class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference {
619 public:
  explicit ScopedHoldAsExternalReference(
621 PjRtStreamExecutorBuffer::ScopedHold hold)
622 : external_reference_(std::move(hold)) {
623 CHECK(external_reference_.type() ==
624 PjRtStreamExecutorBuffer::ScopedHold::kExternalReference);
625 data_ptr_ = external_reference_->device_memory().front().opaque();
626 }
627
628 ~ScopedHoldAsExternalReference() override = default;
629
630 private:
631 PjRtStreamExecutorBuffer::ScopedHold external_reference_;
632 };
633
634 } // namespace
635
636 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
PjRtStreamExecutorBuffer::AcquireExternalReference() {
638 ScopedHold hold = GetBufferWithExternalReference();
639 Status hold_status = hold.status();
640 if (!hold_status.ok()) return hold_status;
641 return std::unique_ptr<ExternalReference>(
642 std::make_unique<ScopedHoldAsExternalReference>(std::move(hold)));
643 }
644
645 class TrackedDeviceBufferExternalReference
646 : public PjRtBuffer::ExternalReference {
647 public:
  explicit TrackedDeviceBufferExternalReference(
649 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)
650 : tracked_device_buffer_(std::move(tracked_device_buffer)) {
651 data_ptr_ = tracked_device_buffer_->device_memory()[0].opaque();
652 }
653
654 ~TrackedDeviceBufferExternalReference() override = default;
655
656 private:
657 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer_;
658 };
659
660 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership(
662 bool wait_for_operations_to_complete) {
663 if (on_device_shape_.IsTuple()) {
664 return InvalidArgument(
665 "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
666 }
667 TF_ASSIGN_OR_RETURN(
668 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer,
669 Release(wait_for_operations_to_complete));
670
671 std::unique_ptr<PjRtBuffer::ExternalReference> ref;
672 if (tracked_device_buffer) {
673 ref = std::make_unique<TrackedDeviceBufferExternalReference>(
674 std::move(tracked_device_buffer));
675 }
676 return ref;
677 }
678
679 StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostBuffer(
681 const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
682 absl::optional<absl::Span<int64_t const>> byte_strides,
683 HostBufferSemantics host_buffer_semantics,
684 std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
685 tensorflow::profiler::TraceMe traceme(
686 "PjRtStreamExecutorClient::BufferFromHostBuffer");
687 Shape shape = ShapeUtil::MakeShape(type, dims);
688 VLOG(1) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
689 << shape.ToString() << " device: " << device->DebugString();
690 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
691 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
692 ->GetLocalDeviceState());
693
694 absl::InlinedVector<int64, 4> tmp_strides;
695 if (!byte_strides) {
696 tmp_strides.resize(dims.size());
697 TF_RETURN_IF_ERROR(
698 ShapeUtil::ByteStrides(shape, absl::MakeSpan(tmp_strides)));
699 byte_strides = tmp_strides;
700 }
701 int64_t size = ShapeUtil::ByteSizeOf(shape);
702
703 TransferManager* transfer_manager = client()->backend().transfer_manager();
704 TF_ASSIGN_OR_RETURN(Shape compact_shape,
705 transfer_manager->ChooseCompactLayoutForShape(shape));
706 absl::InlinedVector<int64_t, 4> compact_shape_strides(
707 compact_shape.dimensions_size());
708 TF_RETURN_IF_ERROR(ShapeUtil::ByteStrides(
709 compact_shape, absl::MakeSpan(compact_shape_strides)));
710 bool host_and_device_strides_equal =
711 (size == 0 || *byte_strides == compact_shape_strides);
712 // The CPU platform is special because the "host" and the "device" are in the
713 // same memory space. If the input shape is in the correct layout and we don't
714 // want to defer the copy onto a thread, we can use the following fast
715 // path.
716 bool is_cpu_platform =
717 local_device->executor()->platform()->id() == se::host::kHostPlatformId;
718 if (is_cpu_platform) {
719 // If we are on the host platform and the input buffer is sufficiently
720 // aligned, we can simply point to the input array's data without any
721 // further copies. At the time of writing we require a 16-byte alignment
722 // because XLA may generate code which requires it.
723 bool can_use_zero_copy =
724 host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
725 ((absl::bit_cast<std::uintptr_t>(data) &
726 (cpu_function_runtime::kMinAlign - 1)) == 0);
727 if (host_and_device_strides_equal &&
728 (host_buffer_semantics ==
729 HostBufferSemantics::kImmutableOnlyDuringCall ||
730 can_use_zero_copy)) {
731 std::function<void()> on_delete_callback;
732 se::DeviceMemoryBase buffer;
737 if (can_use_zero_copy) {
738 on_delete_callback = std::move(on_done_with_host_buffer);
739 buffer = se::DeviceMemoryBase(
740 const_cast<void*>(static_cast<const void*>(data)), size);
741 } else {
742 void* staging_buffer = host_memory_allocator()->AllocateRaw(
743 cpu_function_runtime::kMinAlign, size);
744 buffer = se::DeviceMemoryBase(staging_buffer, size);
745 std::memcpy(staging_buffer, data, size);
746 if (on_done_with_host_buffer) {
747 on_done_with_host_buffer();
748 }
749 on_delete_callback = [staging_buffer, host_memory_allocator =
750 host_memory_allocator()]() {
751 host_memory_allocator->DeallocateRaw(staging_buffer);
752 };
753 }
754 absl::Span<const std::shared_ptr<BufferSequencingEvent>>
755 definition_events;
756 auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
757 /*allocator=*/nullptr, local_device->device_ordinal(),
758 std::initializer_list<se::DeviceMemoryBase>{buffer},
759 definition_events, std::move(on_delete_callback));
760 return std::unique_ptr<PjRtBuffer>(
761 std::make_unique<PjRtStreamExecutorBuffer>(
762 shape, std::move(device_buffer), this, device));
763 }
764 }
765
766 TF_ASSIGN_OR_RETURN(
767 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
768 AllocateDestinationBuffer(compact_shape, device, local_device,
769 local_device->host_to_device_stream(),
770 /*is_uninitialized_create=*/false, this));
771
772 PjRtStreamExecutorBuffer::ScopedHold device_buffer(
773 py_buffer->GetBufferWithUsageHold());
774 CHECK(device_buffer.ok());
775
776 // If necessary, allocate a host-side buffer for staging host-to-device
777 // transfers. On GPU this is a buffer in pinned memory.
778 std::shared_ptr<void> staging_buffer;
779 if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
780 should_stage_host_to_device_transfers() ||
781 !host_and_device_strides_equal) {
782 void* ptr = host_memory_allocator()->AllocateRaw(
783 tensorflow::Allocator::kAllocatorAlignment, size);
784 staging_buffer = std::shared_ptr<void>(
785 ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
786 host_memory_allocator->DeallocateRaw(ptr);
787 });
788 }
789
790 std::shared_ptr<TransposePlan> transpose;
791 if (!host_and_device_strides_equal) {
792 absl::InlinedVector<int64_t, 4> permutation(dims.size());
793 absl::c_reverse_copy(compact_shape.layout().minor_to_major(),
794 permutation.begin());
795 absl::MutexLock lock(&transpose_mu_);
796 TF_ASSIGN_OR_RETURN(transpose,
797 transpose_cache_.GetOrCreate(
798 primitive_util::ByteWidth(type), dims, permutation,
799 TransposePlan::Striding{*byte_strides}));
800 }
801
802 // Copy the buffer into a staging buffer before returning control to the
803 // caller if the caller only guaranteed that the buffer is valid for the
804 // duration of the call. Otherwise, we stage (if necessary) on a separate
805 // thread.
806 if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
807 if (transpose) {
808 transpose->Execute(data, staging_buffer.get());
809 } else {
810 std::memcpy(staging_buffer.get(), data, size);
811 }
812 if (on_done_with_host_buffer) {
813 on_done_with_host_buffer();
814 on_done_with_host_buffer = nullptr;
815 }
816 }
817
818 // The host to device transfer is performed on a thread pool, mostly because
819 // it includes linearization that may be slow. It is OK to capture the
820 // py_buffer pointer because the py_buffer can't be deleted until all the
821 // usage holds have gone away.
822 // TODO(misard) assess if it would be preferable to introduce a heuristic to
823 // put the transfer into the calling thread for small literals.
824 auto transfer_h2d =
825 [local_client = client(), transfer_manager, local_device, data, size,
826 movable_device_buffer{device_buffer.ToClosure()}, shape,
827 py_buffer{py_buffer.get()},
828 on_device_shape{py_buffer->on_device_shape()},
829 staging_buffer{std::move(staging_buffer)},
830 on_done_with_host_buffer{std::move(on_done_with_host_buffer)},
831 host_buffer_semantics, transpose{std::move(transpose)}]() {
832 PjRtStreamExecutorBuffer::ScopedHold device_buffer(
833 movable_device_buffer);
        // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
        // to report failures from a callback. However, the operations here are
        // unlikely to fail, and would not be recoverable even if they did
        // fail: DMAs to memory that has already been allocated, and a possible
        // Event allocation.
839
840 ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
841 // If applicable on the backend, stage the transfer via host memory
842 // allocated via the host_memory_allocator. On GPU, this is pinned
843 // memory.
844 if (staging_buffer) {
845 // If we didn't already copy the input buffer into the staging buffer,
846 // do so now.
847 if (host_buffer_semantics !=
848 HostBufferSemantics::kImmutableOnlyDuringCall) {
849 if (transpose) {
850 transpose->Execute(data, staging_buffer.get());
851 } else {
852 std::memcpy(staging_buffer.get(), data, size);
853 }
854 }
855 // The buffer has the same dimension order as the on-device shape, but
856 // is not tiled, etc.
857 BorrowingLiteral literal(
858 static_cast<const char*>(staging_buffer.get()),
859 ShapeUtil::DeviceShapeToHostShape(on_device_shape));
860 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
861 local_device->host_to_device_stream(), literal, buffer));
862 } else {
863 BorrowingLiteral literal(
864 reinterpret_cast<const char*>(data),
865 ShapeUtil::DeviceShapeToHostShape(on_device_shape));
866 // Otherwise, just transfer the literal.
867 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
868 local_device->host_to_device_stream(), literal, buffer));
869 }
870
871 std::shared_ptr<BufferSequencingEvent> event =
872 device_buffer->definition_events()[0];
873 TF_CHECK_OK(AddDestinationBufferSynchronization(
874 local_device, std::move(device_buffer), event,
875 local_device->host_to_device_stream()));
876
877 local_device->ThenExecuteCallback(
878 local_device->host_to_device_stream(),
879 [staging_buffer{std::move(staging_buffer)},
880 on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
881 if (on_done_with_host_buffer) {
882 on_done_with_host_buffer();
883 }
884 });
885 };
886 if (is_cpu_platform) {
887 // Using the thread_pool would be a double thread hop; the code
888 // already defers its work onto a stream (= thread on CPU).
889 transfer_h2d();
890 } else {
891 thread_pool()->Schedule(transfer_h2d);
892 }
893 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
894 }
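// Example call to BufferFromHostBuffer above (illustrative only; `client`,
// `device`, and the float data are assumptions of the sketch, not part of
// this file):
//
//   std::vector<float> host_data(6, 1.0f);
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PjRtBuffer> buffer,
//       client->BufferFromHostBuffer(
//           host_data.data(), F32, /*dims=*/{2, 3},
//           /*byte_strides=*/absl::nullopt,
//           PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
//           /*on_done_with_host_buffer=*/nullptr, device));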
895
896 StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
898 PjRtDevice* device) {
899 return CreateUninitializedBuffer(shape, device, nullptr);
900 }
901
902 StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(
904 const Shape& shape, PjRtDevice* device,
905 std::shared_ptr<BufferSequencingEvent> definition_event) {
906 tensorflow::profiler::TraceMe traceme(
907 "PjRtStreamExecutorClient::CreateUninitializedBuffer");
908 VLOG(1) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
909 << shape.ToString() << " device: " << device->DebugString();
910 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
911 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
912 ->GetLocalDeviceState());
913
914 TransferManager* transfer_manager = client()->backend().transfer_manager();
915 TF_ASSIGN_OR_RETURN(Shape compact_shape,
916 transfer_manager->ChooseCompactLayoutForShape(shape));
917
918 TF_ASSIGN_OR_RETURN(
919 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
920 AllocateDestinationBuffer(compact_shape, device, local_device,
921 /*copy_stream=*/nullptr,
922 /*is_uninitialized_create=*/true, this,
923 definition_event));
924 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
925 }
926
927 StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
929 PjRtDevice* device) {
930 tensorflow::profiler::TraceMe traceme(
931 "PjRtStreamExecutorClient::BufferFromHostLiteral");
932 VLOG(1) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
933 << literal.shape().ToString() << " device: " << device->DebugString();
934 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
935 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
936 ->GetLocalDeviceState());
937
938 TransferManager* transfer_manager = client()->backend().transfer_manager();
939 TF_ASSIGN_OR_RETURN(
940 Shape compact_shape,
941 transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
942 TF_ASSIGN_OR_RETURN(
943 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
944 AllocateDestinationBuffer(compact_shape, device, local_device,
945 local_device->host_to_device_stream(),
946 /*is_uninitialized_create=*/false, this));
947
948 PjRtStreamExecutorBuffer::ScopedHold device_buffer(
949 py_buffer->GetBufferWithUsageHold());
950 CHECK(device_buffer.ok());
951
952 // The host to device transfer is performed on a thread pool, mostly because
953 // it includes linearization that may be slow. It is OK to capture the
954 // py_buffer pointer because the py_buffer can't be deleted until all the
955 // usage holds have gone away.
956 // TODO(misard) assess if it would be preferable to introduce a heuristic to
957 // put the transfer into the calling thread for small literals.
958 auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
959 movable_device_buffer{device_buffer.ToClosure()},
960 literal, py_buffer{py_buffer.get()},
961 on_device_shape{py_buffer->on_device_shape()}]() {
962 PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
    // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
    // to report failures from a callback. However, the operations here are
    // unlikely to fail, and would not be recoverable even if they did fail:
    // DMAs to memory that has already been allocated, and a possible Event
    // allocation.
968
969 se::Stream* h2d_stream = local_device->host_to_device_stream();
970 ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
971 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
972 h2d_stream, literal, buffer));
973
974 std::shared_ptr<BufferSequencingEvent> event =
975 device_buffer->definition_events()[0];
976 TF_CHECK_OK(AddDestinationBufferSynchronization(
977 local_device, std::move(device_buffer), event, h2d_stream));
978
979 // This can sometimes catch the case where the literal memory has been
980 // freed before the H2D transfer was issued.
981 h2d_stream->RefreshStatus()
982 .IgnoreError(); // Can return error::Unimplemented
983 QCHECK(h2d_stream->ok());
984 };
985 thread_pool()->Schedule(transfer_h2d);
986 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
987 }
988
void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
990 absl::Span<const Shape> shapes, PjRtDevice* device,
991 PjRtCrossHostRecvNotifier&& notifier) {
992 if (shapes.empty()) {
993 notifier(InvalidArgument(
994 "shapes parameter empty in MakeCrossHostReceiveBuffers"));
995 return;
996 }
997
998 auto local_device_or =
999 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1000 ->GetLocalDeviceState();
1001 if (!local_device_or.ok()) {
1002 notifier(local_device_or.status());
1003 return;
1004 }
1005 LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
1006 std::shared_ptr<BufferSequencingEvent> definition_event =
1007 std::make_shared<BufferSequencingEvent>();
1008 std::vector<std::unique_ptr<PjRtBuffer>> buffers;
1009 buffers.reserve(shapes.size());
1010 for (const auto& shape : shapes) {
1011 StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
1012 shape, device, local_device,
1013 /*copy_stream=*/nullptr,
1014 /*is_uninitialized_create=*/false, this, definition_event);
1015 if (!buffer_or.ok()) {
1016 notifier(buffer_or.status());
1017 return;
1018 }
1019 buffers.push_back(buffer_or.ConsumeValueOrDie());
1020 }
1021
1022 EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
1023 std::move(notifier), absl::nullopt);
1024 }
void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffersForGather(
1026 absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
1027 PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) {
1028 VLOG(2) << "Making " << gather_details.size()
1029 << " cross host receive buffers for gather";
1030 if (gather_details.empty()) {
1031 notifier(
1032 InvalidArgument("gather_details parameter empty in "
1033 "MakeCrossHostReceiveBuffersForGather"));
1034 return;
1035 }
1036
1037 if (shapes.size() != gather_details.size()) {
1038 notifier(
1039 InvalidArgument("gather_details parameter has length %lld but shapes "
1040 "parameter has length %lld in "
1041 "MakeCrossHostReceiveBuffersForGather",
1042 gather_details.size(), shapes.size()));
1043 return;
1044 }
1045
1046 auto local_device_or =
1047 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1048 ->GetLocalDeviceState();
1049 if (!local_device_or.ok()) {
1050 notifier(local_device_or.status());
1051 return;
1052 }
1053 LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
1054 std::shared_ptr<BufferSequencingEvent> definition_event =
1055 std::make_shared<BufferSequencingEvent>();
1056 std::vector<std::unique_ptr<PjRtBuffer>> buffers;
1057 buffers.reserve(shapes.size());
1058 for (int i = 0; i < shapes.size(); ++i) {
1059 StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
1060 shapes[i], device, local_device,
1061 /*copy_stream=*/nullptr,
1062 /*is_uninitialized_create=*/false, this, definition_event);
1063 if (!buffer_or.ok()) {
1064 notifier(buffer_or.status());
1065 return;
1066 }
1067 buffers.push_back(buffer_or.ConsumeValueOrDie());
1068 }
1069
1070 EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
1071 std::move(notifier), gather_details);
1072 }
1073
1074 StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
1076 void* device_ptr, const Shape& shape, PjRtDevice* device,
1077 std::function<void()> on_delete_callback) {
1078 se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape));
1079 absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
1080 auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
1081 /*allocator=*/nullptr, device->local_hardware_id(),
1082 std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
1083 std::move(on_delete_callback));
1084 return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
1085 shape, std::move(device_buffer), this, device));
1086 }
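// Example use of CreateViewOfDeviceBuffer (illustrative sketch; `client`,
// `device`, and `device_ptr` come from the caller, which stays responsible
// for keeping the underlying memory alive until on_delete_callback runs):
//
//   Shape shape = ShapeUtil::MakeShape(F32, {1024});
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PjRtBuffer> view,
//       client->CreateViewOfDeviceBuffer(
//           device_ptr, shape, device,
//           /*on_delete_callback=*/[]() { /* notify the owner */ }));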
1087
1088 // Transfer the given literal to the infeed queue of the given local device.
Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
1090 // Only support infeed to local device.
1091 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
1092 return local_device->client()->TransferToInfeedLocal(
1093 literal, local_device->device_ordinal());
1094 }
1095
Status PjRtStreamExecutorDevice::TransferFromOutfeed(
1097 MutableBorrowingLiteral literal) {
1098 VLOG(1) << "PjRtStreamExecutorDevice::TransferFromOutfeed";
1099 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
1100 return local_device->client()->TransferFromOutfeedLocal(
1101 local_device->device_ordinal(), literal);
1102 }
1103
StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
1105 int local_hardware_id) const {
1106 for (auto* device : addressable_devices_) {
1107 if (local_hardware_id == device->local_hardware_id()) {
1108 return device;
1109 }
1110 }
1111 return InvalidArgument("No matching device found for local_hardware_id %d",
1112 local_hardware_id);
1113 }
1114
PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
1116 Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
1117 PjRtClient* client, PjRtDevice* device)
1118 : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
1119 on_device_shape_(std::move(on_device_shape)),
1120 device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
1121 device_buffer_(std::move(device_buffer)) {
1122 for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1123 holds_[i] = 0;
1124 }
1125 }
1126
PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() {
1128 Delete();
1129 for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1130 CHECK_EQ(holds_[i], 0);
1131 }
1132 }
1133
void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() {
1135 auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1136 return holds_[ScopedHold::kUsage] == 0;
1137 };
1138 mu_.Await(absl::Condition(¬_in_usage_hold));
1139 }
1140
void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() {
1142 auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1143 return holds_[ScopedHold::kDonation] == 0;
1144 };
1145 mu_.Await(absl::Condition(¬_in_donation_hold));
1146 }
1147
1148 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
1150 tensorflow::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release");
1151 std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1152 TrackedDeviceBuffer::StreamAndEventContainer events;
1153 {
1154 absl::MutexLock lock(&mu_);
1155 // We first wait for a donation hold to complete if there is one in
1156 // progress. If the donation succeeds via ConfirmDonation() then it will
1157 // set device_buffer_ to nullptr before returning to this thread.
1158 WaitForOutstandingDonationHold();
1159 if (device_buffer_ == nullptr) {
1160 return std::shared_ptr<TrackedDeviceBuffer>();
1161 }
1162 // Set device_buffer_ to null now so that no other
1163 // thread can add a hold while we are in WaitForOutstandingUsageHolds()
1164 // below.
1165 std::swap(device_buffer_, device_buffer);
1166 WaitForOutstandingUsageHolds();
1167 // Now that all holds have completed and no more can be added, we can get
1168 // the final set of usage events.
1169 events = device_buffer->LockUseAndTransferUsageEvents();
1170 }
1171 LocalDeviceState* local_device_state = device_->local_device_state();
1172 if (wait_for_operations_to_complete) {
1173 // Block the host until all usage events have completed. Usage events
1174 // dominate definition events, so this also waits for the buffer to be
1175 // defined.
1176 std::unique_ptr<se::Stream> stream;
1177 for (const auto& stream_and_event : events) {
1178 if (!stream_and_event.event->IsComplete()) {
1179 if (stream == nullptr) {
1180 stream = local_device_state->BorrowStreamFromPool();
1181 }
1182 stream_and_event.event->WaitForEventOnStream(stream.get());
1183 }
1184 }
1185 if (stream != nullptr) {
1186 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1187 local_device_state->ReturnStreamToPool(std::move(stream));
1188 }
1189 } else {
1190 if (local_device_state->allocation_model() ==
1191 LocalDeviceState::kComputeSynchronized) {
1192 std::unique_ptr<se::Stream> block_stream;
1193 for (const auto& stream_and_event : events) {
1194 // We only need to do something for events that didn't already acquire a
1195 // reference to the buffer, and also which the compute stream didn't
1196 // already wait for. Based on our heuristics this rare case should only
1197 // occur when a buffer was copied to a device and then never used there.
1198 // In that case we get a new stream and use it to hold onto a reference
1199 // to the buffer until the events are complete.
1200 if (!stream_and_event.reference_held &&
1201 !stream_and_event.event->DefinedOn(
1202 local_device_state->compute_stream()) &&
1203 !stream_and_event.event->IsComplete()) {
1204 if (block_stream == nullptr) {
1205 block_stream = local_device_state->BorrowStreamFromPool();
1206 }
1207 stream_and_event.event->WaitForEventOnStream(block_stream.get());
1208 }
1209 }
1210 if (block_stream != nullptr) {
1211 se::Stream* block_stream_ptr = block_stream.release();
1212 local_device_state->ThenExecuteCallback(
1213 block_stream_ptr,
1214 [device_buffer, block_stream_ptr, local_device_state]() {
1215 local_device_state->ReturnStreamToPool(
1216 std::unique_ptr<se::Stream>(block_stream_ptr));
1217 });
1218 }
1219 }
1220 }
1221 return device_buffer;
1222 }
1223
void PjRtStreamExecutorBuffer::Delete() {
1225 VLOG(1) << "PjRtStreamExecutorBuffer::Delete";
  // When wait_for_operations_to_complete is false, Release should never fail.
1227 TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
1228 }
1229
bool PjRtStreamExecutorBuffer::IsDeleted() {
1231 absl::MutexLock lock(&mu_);
1232 return device_buffer_ == nullptr;
1233 }
1234
1235 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
1237 // All callers should have called WaitForOutstandingDonationHold().
1238 CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1239 if (type == ScopedHold::kDonation) {
1240 if (device_buffer_ == nullptr) {
1241 return InvalidArgument("Donation requested for invalid buffer");
1242 }
1243 if (holds_[ScopedHold::kExternalReference] > 0) {
1244 return InvalidArgument(
1245 "Donation requested for buffer with external reference");
1246 }
1247 // First add the donation hold.
1248 ++holds_[type];
1249 // Then wait for any usage holds to be dropped or converted. No new usage
1250 // holds can be added until we drop the donation hold so this wait will
1251 // complete eventually.
1252 WaitForOutstandingUsageHolds();
1253 // Because we added a donation hold, nobody could release the buffer while
1254 // we were waiting.
1255 CHECK(device_buffer_ != nullptr);
1256 } else {
1257 if (device_buffer_ == nullptr) {
1258 return InvalidArgument("Buffer has been deleted or donated.");
1259 } else {
1260 ++holds_[type];
1261 }
1262 }
1263 return device_buffer_;
1264 }
1265
1266 void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) {
1267 hold->Acquire(GetBufferForHoldLocked(hold->type()));
1268 }
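// The acquisition pattern used by the member functions below pairs
// AcquireHoldLocked with the donation wait and the validity check, all under
// mu_. A minimal sketch of that pattern (mirroring ToLiteral and
// CopyToDevice):
//
//   ScopedHold hold(this, ScopedHold::kUsage);
//   {
//     absl::MutexLock lock(&mu_);
//     WaitForOutstandingDonationHold();
//     if (device_buffer_ == nullptr) {
//       return InvalidArgument("Buffer has been deleted or donated.");
//     }
//     AcquireHoldLocked(&hold);
//   }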
1269
1270 void PjRtStreamExecutorBuffer::ConvertUsageHold(
1271 TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
1272 std::shared_ptr<BufferSequencingEvent> event, bool reference_held) {
1273 absl::MutexLock lock(&mu_);
1274 CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1275 buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
1276 CHECK_GT(holds_[ScopedHold::kUsage], 0);
1277 --holds_[ScopedHold::kUsage];
1278 }
1279
1280 void PjRtStreamExecutorBuffer::ConfirmDonation(
1281 TrackedDeviceBuffer* device_buffer) {
1282 {
1283 absl::MutexLock lock(&mu_);
1284 CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1285 CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1286 CHECK_EQ(holds_[ScopedHold::kDonation], 1);
1287 holds_[ScopedHold::kDonation] = 0;
1288 CHECK(device_buffer_.get() == device_buffer);
1289 // As a sanity check ensure no more usage events can be added to the buffer.
1290 device_buffer->LockUseAndTransferUsageEvents();
1291 // Give up ownership of the device memory so we don't free it when the last
1292 // reference to device_buffer_ goes away.
1293 device_buffer->ReleaseDeviceMemory();
1294 // Make *this invalid so it can't be used again. Any threads blocking in
1295 // Release or GetBufferWithHold will see an invalid buffer and return.
1296 device_buffer_.reset();
1297 }
1298 }
1299
1300 void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
1301 TrackedDeviceBuffer* buffer) {
1302 absl::MutexLock lock(&mu_);
1303 CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1304 CHECK_GT(holds_[type], 0);
1305 --holds_[type];
1306 if (type == ScopedHold::kDonation) {
1307 CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1308 CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1309 CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1310 }
1311 }
1312
1313 void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
1314 std::function<void(Status)> on_ready) {
1315 VLOG(1) << "PjRtStreamExecutorBuffer::ToLiteral";
1316 if (IsEmptyTuple()) {
1317 on_ready(InvalidArgument("ToLiteral called on empty tuple"));
1318 return;
1319 }
1320 LocalDeviceState* local_device = device_->local_device_state();
1321 se::Stream* stream = local_device->GetDeviceToHostStream();
1322 ScopedHold device_buffer(this, ScopedHold::kUsage);
1323 {
1324 absl::MutexLock lock(&mu_);
1325 // We can't perform any other action while a donation hold is in progress.
1326 WaitForOutstandingDonationHold();
1327 if (device_buffer_ == nullptr) {
1328 on_ready(InvalidArgument(
1329 "ToLiteral() called on deleted or donated buffer"));
1330 return;
1331 }
1332 AcquireHoldLocked(&device_buffer);
1333 }
1334
1335 WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
1336 ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
1337 StatusOr<EventPool::Handle> event_or =
1338 local_device->event_pool().AllocateEvent(stream->parent());
1339 if (!event_or.ok()) {
1340 on_ready(event_or.status());
1341 return;
1342 }
1343 client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
1344 stream, shaped_buffer, literal, std::move(on_ready));
1345
1346 auto usage_event = std::make_shared<BufferSequencingEvent>();
1347 local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
1348 usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1349 // When using the ComputeSynchronized allocation model, retain a reference to
1350 // the device_buffer until the copy completes, to ensure that the buffer isn't
1351 // deleted or donated while it is still in use. The choice of retaining a
1352 // reference at the host is a heuristic; the alternative is to ensure, before
1353 // freeing the buffer, that the compute stream is synchronized past the
1354 // transfer, but it seems better to hold onto the buffer too long than to
1355 // stall the compute stream, particularly since the overwhelmingly common
1356 // use case of CopyToHostAsync will hold onto the reference long enough to
1357 // read the buffer in a subsequent call to ToLiteral.
1358 RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
1359 stream,
1360 /*prefer_to_retain_reference=*/true);
1361 }
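// Illustrative caller sketch (assumes a host `literal` whose shape matches
// the buffer's host shape); the callback receives the status of the
// device-to-host transfer once it completes:
//
//   buffer->ToLiteral(&literal, [](Status s) { TF_CHECK_OK(s); });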
1362
1363 StatusOr<size_t> PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes() const {
1364 absl::MutexLock lock(&mu_);
1365 if (device_buffer_ == nullptr) {
1366 return InvalidArgument(
1367 "GetOnDeviceSizeInBytes called on deleted or donated buffer");
1368 }
1369 if (device_buffer_->device_memory().size() != 1) {
1370 return InvalidArgument(
1371 "GetOnDeviceSizeInBytes called on tuple-shaped buffer");
1372 }
1373 return device_buffer_->device_memory()[0].size();
1374 }
1375
1376 Status PjRtStreamExecutorBuffer::CopyRawToHost(
1377 void* dst, int64_t offset, int64_t transfer_size,
1378 std::function<void(Status)> on_ready) {
1379 return client_->CopyRawSubBufferToHost(this, dst, offset, transfer_size,
1380 std::move(on_ready));
1381 }
1382
1383 StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
1384 absl::MutexLock lock(&mu_);
1385 if (device_buffer_ == nullptr) {
1386 return InvalidArgument(
1387 "Attempted to fetch value of invalid/deleted buffer.");
1388 }
1389 return device_buffer_->AsShapedBuffer(on_device_shape_);
1390 }
1391
1392 PjRtStreamExecutorBuffer::ScopedHold
1393 PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
1394 absl::MutexLock lock(&mu_);
1395 // Ensure that at most one donation hold can be in progress at a time.
1396 WaitForOutstandingDonationHold();
1397 ScopedHold hold(this, type);
1398 AcquireHoldLocked(&hold);
1399 return hold;
1400 }
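// Sketch of taking an explicit usage hold, assuming `buffer` is a
// PjRtStreamExecutorBuffer*; the hold is dropped or converted when the
// ScopedHold goes out of scope:
//
//   PjRtStreamExecutorBuffer::ScopedHold hold =
//       buffer->GetBufferWithHold(
//           PjRtStreamExecutorBuffer::ScopedHold::kUsage);
//   if (!hold.ok()) {
//     return hold.status();
//   }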
1401
1402 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1403 std::shared_ptr<BufferSequencingEvent>>>
1404 PjRtStreamExecutorBuffer::CopyToDeviceHelper(
1405 PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
1406 LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
1407 std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
1408 TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
1409 AllocateDestinationBuffer(
1410 ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
1411 dst_device, dst_local_device, transfer_stream,
1412 /*is_uninitialized_create=*/false, client_));
1413
1414 TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
1415
1416 WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream);
1417
1418 ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
1419 CHECK(dst_device_buffer.ok());
1420 ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);
1421
1422 // Copy the leaf buffers.
1423 StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
1424 [&]() -> StatusOr<std::shared_ptr<BufferSequencingEvent>> {
1425 for (const auto& leaf : src_buffer.buffers().leaves()) {
1426 const ShapeIndex& index = leaf.first;
1427 const se::DeviceMemoryBase& input_buffer = leaf.second;
1428 const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
1429 TF_RET_CHECK(input_buffer.size() == output_buffer.size())
1430 << "input: " << input_buffer.size()
1431 << " output: " << output_buffer.size();
1432 if (input_buffer.size() != 0) {
1433 TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
1434 transfer_stream, dst_local_device->compute_stream(), input_buffer,
1435 output_buffer));
1436 }
1437 }
1438 std::shared_ptr<BufferSequencingEvent> event =
1439 dst_device_buffer->definition_events()[0];
1440 TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization(
1441 transfer_local_device, std::move(dst_device_buffer), event,
1442 transfer_stream));
1443 return event;
1444 }();
1445 if (!copy_event_or.ok()) {
1446 StallStreamOnError(transfer_local_device, transfer_stream);
1447 if (transfer_local_device == dst_local_device) {
1448 // Some copies may have been enqueued before the error was returned, and
1449 // StallStreamOnError only makes sure the destination device is ok, so
1450 // make sure that the src buffer remains valid until after any transfers
1451 // have completed.
1452 device_->local_device_state()->ThenRelease(transfer_stream,
1453 std::move(src_device_buffer));
1454 }
1455 return copy_event_or.status();
1456 }
1457
1458 return std::pair<std::unique_ptr<PjRtBuffer>,
1459 std::shared_ptr<BufferSequencingEvent>>(
1460 std::unique_ptr<PjRtStreamExecutorBuffer>(std::move(py_buffer)),
1461 copy_event_or.ConsumeValueOrDie());
1462 }
1463
1464 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
1465 PjRtDevice* dst_device) {
1466 tensorflow::profiler::TraceMe traceme(
1467 "PjRtStreamExecutorBuffer::CopyToDevice");
1468 VLOG(1) << "PjRtStreamExecutorBuffer::CopyToDevice";
1469 if (dst_device == device_) {
1470 return InvalidArgument(
1471 "CopyToDevice cannot accept the same source and destination devices");
1472 }
1473
1474 // Copying across PjRtClients involves a copy through the host.
1475 if (dst_device->client() != client_) {
1476 TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
1477 // Avoid use-after-free on `literal` due to unsequenced move and use.
1478 Literal* literal_pointer = literal.get();
1479 absl::InlinedVector<int64_t, 4> byte_strides(
1480 literal->shape().dimensions_size());
1481 TF_RETURN_IF_ERROR(
1482 ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides)));
1483 return dst_device->client()->BufferFromHostBuffer(
1484 literal_pointer->untyped_data(),
1485 literal_pointer->shape().element_type(),
1486 literal_pointer->shape().dimensions(), byte_strides,
1487 PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
1488 [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
1489 }
1490
1491 TF_ASSIGN_OR_RETURN(
1492 LocalDeviceState * dst_local_device,
1493 tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
1494 ->GetLocalDeviceState());
1495 LocalDeviceState* transfer_local_device =
1496 client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
1497 : dst_local_device;
1498 CHECK_EQ(dst_local_device->allocation_model(),
1499 transfer_local_device->allocation_model());
1500
1501 se::Stream* transfer_stream =
1502 transfer_local_device->GetDeviceToDeviceStream();
1503
1504 ScopedHold src_device_buffer(this, ScopedHold::kUsage);
1505 {
1506 absl::MutexLock lock(&mu_);
1507 // We can't perform any other action while a donation hold is in progress.
1508 WaitForOutstandingDonationHold();
1509 if (device_buffer_ == nullptr) {
1510 return InvalidArgument(
1511 "CopyToDevice called on deleted or donated buffer");
1512 }
1513 AcquireHoldLocked(&src_device_buffer);
1514 }
1515
1516 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1517 std::shared_ptr<BufferSequencingEvent>>>
1518 buffer_and_event_or = CopyToDeviceHelper(
1519 dst_device, dst_local_device, transfer_local_device, transfer_stream,
1520 src_device_buffer.buffer());
1521 if (!buffer_and_event_or.ok()) {
1522 return buffer_and_event_or.status();
1523 }
1524
1525 auto& buffer_and_event = buffer_and_event_or.ValueOrDie();
1526 std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
1527 std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;
1528
1529 // prefer_to_retain_reference=true means that, when using the
1530 // ComputeSynchronized allocation model, retain a reference to the
1531 // src_device_buffer until the copy completes. This is a heuristic; the
1532 // alternative is to ensure, before freeing the buffer, that the compute
1533 // stream is synchronized past the transfer, but it seems better to hold onto
1534 // the buffer too long than to stall the compute stream.
1535 RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
1536 transfer_local_device, event, transfer_stream,
1537 /*prefer_to_retain_reference=*/true);
1538
1539 return std::move(buffer);
1540 }
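// Minimal sketch of a device-to-device copy, assuming `buffer` and a
// destination device `dst_device` distinct from the buffer's device:
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> copy,
//                       buffer->CopyToDevice(dst_device));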
1541
1542 Status PjRtStreamExecutorBuffer::CopyToRemoteDevice(
1543 absl::string_view serialized_descriptor) {
1544 VLOG(1) << "PjRtStreamExecutorBuffer::CopyToRemoteDevice";
1545 return client_->CopyToRemoteDevice(this, serialized_descriptor);
1546 }
1547
1548 Status PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered(
1549 absl::Span<const std::string> serialized_descriptors,
1550 const ScatterDetails& scatter_details) {
1551 VLOG(1) << "PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered";
1552 return client_->CopyToRemoteDeviceScattered(this, serialized_descriptors,
1553 scatter_details);
1554 }
1555
1556 Status PjRtStreamExecutorBuffer::BlockHostUntilReady() {
1557 tensorflow::profiler::TraceMe traceme(
1558 "PjRtStreamExecutorBuffer::BlockHostUntilReady");
1559 VLOG(1) << "PjRtStreamExecutorBuffer::BlockHostUntilReady";
1560 std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1561 {
1562 absl::MutexLock lock(&mu_);
1563 if (device_buffer_ == nullptr) {
1564 return InvalidArgument(
1565 "BlockHostUntilReady() called on deleted or donated buffer");
1566 }
1567 device_buffer = device_buffer_;
1568 }
1569 LocalDeviceState* local_device_state = device_->local_device_state();
1570 std::unique_ptr<se::Stream> stream;
1571 for (auto& event : device_buffer->definition_events()) {
1572 if (!event->IsComplete()) {
1573 if (stream == nullptr) {
1574 stream = local_device_state->BorrowStreamFromPool();
1575 }
1576 event->WaitForEventOnStream(stream.get());
1577 }
1578 }
1579 if (stream != nullptr) {
1580 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1581 local_device_state->ReturnStreamToPool(std::move(stream));
1582 }
1583 return Status::OK();
1584 }
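// Sketch: blocking the host until the buffer's definition events have been
// reached, e.g. before reading it synchronously (assumes `buffer`):
//
//   TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());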
1585
1586 namespace {
1587
1588 // Helper struct for the tuple that is transiently constructed to hold the
1589 // arguments of an execution.
1590 struct TupleHandle {
1591 // The ExecutionInput describing the tuple.
1592 ExecutionInput execution_input;
1593 // A definition event that has been recorded on the host_to_device stream
1594 // after the tuple table transfer.
1595 std::shared_ptr<BufferSequencingEvent> event;
1596 };
1597
1598 Status CheckCompatibleShapes(bool strict_shape_checking,
1599 const Shape& buffer_shape,
1600 const Shape& execution_shape,
1601 const TransferManager& transfer_manager,
1602 int parameter_index) {
1603 // TODO(misard) Support casting of tuple parameters.
1604 if (strict_shape_checking || buffer_shape.IsTuple()) {
1605 if (!ShapeUtil::Equal(buffer_shape, execution_shape)) {
1606 return InvalidArgument(
1607 "Executable expected shape %s for argument %d but got "
1608 "incompatible "
1609 "shape %s",
1610 ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
1611 ShapeUtil::HumanStringWithLayout(buffer_shape));
1612 }
1613 } else {
1614 if (transfer_manager.GetByteSizeRequirement(buffer_shape) !=
1615 transfer_manager.GetByteSizeRequirement(execution_shape)) {
1616 return InvalidArgument(
1617 "Executable expected shape %s for argument %d but got "
1618 "incompatible "
1619 "shape %s",
1620 ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
1621 ShapeUtil::HumanStringWithLayout(buffer_shape));
1622 }
1623 }
1624 return Status::OK();
1625 }
1626
1627 // Makes a tuple from the arguments to an execution.
1628 StatusOr<TupleHandle> MakeTupleHelper(
1629 PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
1630 bool strict_shape_checking, const Shape& tupled_parameter_shape,
1631 absl::Span<PjRtBuffer* const> py_buffers,
1632 absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1633 int device_ordinal) {
1634 se::DeviceMemoryAllocator* allocator = client->allocator();
1635 TransferManager* transfer_manager =
1636 client->client()->backend().transfer_manager();
1637
1638 if (tupled_parameter_shape.tuple_shapes_size() != py_buffers.size()) {
1639 return InvalidArgument("Executable expected %lld parameters but got %lld",
1640 tupled_parameter_shape.tuple_shapes_size(),
1641 py_buffers.size());
1642 }
1643 for (int i = 0; i < py_buffers.size(); ++i) {
1644 TF_RETURN_IF_ERROR(CheckCompatibleShapes(
1645 strict_shape_checking, py_buffers[i]->on_device_shape(),
1646 tupled_parameter_shape.tuple_shapes(i), *transfer_manager, i));
1647 }
1648
1649 se::Stream* stream = local_device->host_to_device_stream();
1650 TF_ASSIGN_OR_RETURN(
1651 se::OwningDeviceMemory root_table_memory,
1652 allocator->Allocate(
1653 device_ordinal,
1654 transfer_manager->GetByteSizeRequirement(tupled_parameter_shape)));
1655
1656 if (local_device->allocation_model() ==
1657 LocalDeviceState::kComputeSynchronized) {
1658 stream->ThenWaitFor(local_device->compute_stream());
1659 } else {
1660 DCHECK(transfer_manager->CanBufferBeAccessedNow(
1661 local_device->compute_stream()->parent(), root_table_memory.cref()));
1662 }
1663
1664 ExecutionInput execution_input(tupled_parameter_shape);
1665 ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1666 execution_input.MutableBuffers()->begin();
1667 ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1668 execution_input.MutableBuffers()->end();
1669 // First set the root tuple table which is the first buffer in the ShapeTree.
1670 execution_input.SetBuffer(
1671 input_iterator->first,
1672 MaybeOwningDeviceMemory(std::move(root_table_memory)));
1673 ++input_iterator;
1674 // Then set each sub-tuple in turn from the parameters.
1675 for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
1676 device_buffers) {
1677 device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input,
1678 allocator);
1679 }
1680 CHECK(input_iterator == iterator_end);
1681
1682 TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
1683 stream, execution_input.Buffers()));
1684 StatusOr<EventPool::Handle> event_or =
1685 local_device->event_pool().ThenAllocateAndRecordEvent(stream);
1686 if (!event_or.ok()) {
1687 StallStreamOnError(local_device, stream);
1688 return event_or.status();
1689 }
1690
1691 auto transfer_event = std::make_shared<BufferSequencingEvent>();
1692 transfer_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1693 return TupleHandle({std::move(execution_input), std::move(transfer_event)});
1694 }
1695
1696 // Converts a ScopedShapedBuffer returned from an execution into a
1697 // PjRtBuffer.
1698 std::unique_ptr<PjRtBuffer> OutputBufferHelper(
1699 ScopedShapedBuffer* result_buffer,
1700 std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
1701 PjRtDevice* device, LocalDeviceState* local_device,
1702 std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release) {
1703 std::shared_ptr<TrackedDeviceBuffer> out_buffer =
1704 TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
1705 {definition_event});
1706 auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
1707 result_buffer->on_device_shape(), std::move(out_buffer), client, device);
1708 RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
1709 definition_event, local_device->compute_stream(),
1710 /*prefer_to_retain_reference=*/false, &buffers_to_release);
1711 return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
1712 }
1713 } // namespace
1714
1715 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
1716 std::vector<std::unique_ptr<LocalExecutable>> executables,
1717 bool parameter_is_tupled_arguments,
1718 std::shared_ptr<DeviceAssignment> device_assignment,
1719 std::vector<LogicalDeviceIds> addressable_device_logical_ids,
1720 std::vector<PjRtDevice*> addressable_devices,
1721 PjRtStreamExecutorClient* client)
1722 : client_(client),
1723 device_assignment_(std::move(device_assignment)),
1724 parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
1725 addressable_device_logical_ids_(
1726 std::move(addressable_device_logical_ids)),
1727 addressable_devices_(std::move(addressable_devices)) {
1728 TransferManager* transfer_manager =
1729 client_->client()->backend().transfer_manager();
1730 executables_.reserve(executables.size());
1731 for (auto& executable : executables) {
1732 const auto& computation_layout =
1733 executable->executable()->module().entry_computation_layout();
1734 std::vector<Shape> parameter_shapes;
1735 parameter_shapes.reserve(computation_layout.parameter_count());
1736 for (int i = 0; i < computation_layout.parameter_count(); ++i) {
1737 parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape(
1738 computation_layout.parameter_shape(i)));
1739 }
1740 executables_.emplace_back(std::move(executable));
1741 on_device_executable_parameter_shapes_.push_back(
1742 std::move(parameter_shapes));
1743 }
1744
1745 int num_partitions;
1746 if (device_assignment_ == nullptr) {
1747 // This must go after `executables_` is initialized.
1748 VLOG(3) << "PjRtStreamExecutorExecutable portable single-core";
1749 num_partitions = 1;
1750 CHECK(addressable_devices_.empty());
1751 } else {
1752 // This must go after `executables_` is initialized.
1753 VLOG(3) << "PjRtStreamExecutorExecutable device_assignment:\n"
1754 << device_assignment_->ToString();
1755 CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
1756 CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
1757 << "Inconsistent local device count.";
1758 num_partitions = device_assignment_->computation_count();
1759 }
1760
1761 // SPMD sharding produces a single executable for multiple partitions.
1762 if (executables_.size() > 1) {
1763 CHECK_EQ(num_partitions, executables_.size())
1764 << "Number of executables " << executables_.size()
1765 << " did not match number of partitions " << num_partitions;
1766 }
1767 }
1768
1769 Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
1770 parameters_that_must_be_donated_.reserve(executables_.size());
1771 for (auto& executable : executables_) {
1772 TF_ASSIGN_OR_RETURN(std::vector<int> parameters_to_donate,
1773 ComputeParametersThatMustBeDonated(
1774 executable->executable()->module(), tuple_inputs));
1775 parameters_that_must_be_donated_.emplace_back(
1776 std::move(parameters_to_donate));
1777 }
1778 return Status::OK();
1779 }
1780
1781 absl::string_view PjRtStreamExecutorExecutable::name() const {
1782 Executable* executable = executables_[0]->executable();
1783 if (executable->has_module()) {
1784 return executable->module().name();
1785 } else {
1786 return "<unknown executable>";
1787 }
1788 }
1789
1790 absl::Span<int const> PjRtStreamExecutorExecutable::ParametersThatMustBeDonated(
1791 int executable_idx) const {
1792 return parameters_that_must_be_donated_[executable_idx];
1793 }
1794
1795 StatusOr<std::vector<ExecutionInput>>
1796 PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
1797 int device_ordinal, const ExecuteOptions& options,
1798 absl::Span<const Shape> executable_parameter_shapes,
1799 absl::Span<PjRtBuffer* const> argument_handles,
1800 absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1801 absl::flat_hash_set<BufferSequencingEvent*>& events) const {
1802 std::vector<ExecutionInput> execution_inputs;
1803 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1804 TransferManager* transfer_manager =
1805 client_->client()->backend().transfer_manager();
1806 // Lift tuple_handle outside the conditional so that the event it returns is
1807 // not destroyed until after the loop below that waits on events.
1808 absl::optional<TupleHandle> tuple_handle;
1809 if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
1810 TF_ASSIGN_OR_RETURN(
1811 tuple_handle,
1812 MakeTupleHelper(client_, device_state, options.strict_shape_checking,
1813 executable_parameter_shapes[0], argument_handles,
1814 device_buffers, device_ordinal));
1815 events.insert(tuple_handle->event.get());
1816 execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
1817 } else {
1818 if (argument_handles.size() != executable_parameter_shapes.size()) {
1819 return InvalidArgument("Executable expected %lld arguments but got %lld",
1820 executable_parameter_shapes.size(),
1821 argument_handles.size());
1822 }
1823 execution_inputs.reserve(argument_handles.size());
1824 for (int i = 0; i < argument_handles.size(); ++i) {
1825 PjRtBuffer* handle = argument_handles[i];
1826
1827 // Make an ExecutionInput from the device buffer.
1828 TF_RETURN_IF_ERROR(CheckCompatibleShapes(
1829 options.strict_shape_checking, handle->on_device_shape(),
1830 executable_parameter_shapes[i], *transfer_manager, i));
1831 execution_inputs.emplace_back(executable_parameter_shapes[i]);
1832 ExecutionInput& execution_input = execution_inputs.back();
1833 ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1834 execution_input.MutableBuffers()->begin();
1835 ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1836 execution_input.MutableBuffers()->end();
1837 device_buffers[i].AddToInput(&input_iterator, iterator_end,
1838 &execution_input, client_->allocator());
1839 CHECK(input_iterator == iterator_end);
1840 }
1841 }
1842
1843 for (BufferSequencingEvent* event : events) {
1844 event->WaitForEventOnStream(device_state->compute_stream());
1845 }
1846
1847 return execution_inputs;
1848 }
1849
1850 // Enqueues a computation onto the compute stream. Each buffer returned in
1851 // device_buffers has a usage hold added that must be dropped on error or
1852 // converted on success.
1853 StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
1854 absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
1855 int executable_idx, const RunId& run_id, const ExecuteOptions& options,
1856 PjRtDevice* device,
1857 std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
1858 std::shared_ptr<DeviceAssignment> device_assignment,
1859 std::vector<std::function<void()>>& compute_callbacks) const {
1860 int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1861 ->local_device_state()
1862 ->device_ordinal();
1863 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1864 tensorflow::profiler::TraceMeConsumer activity(
1865 "PjRtStreamExecutorExecutable::EnqueueExecution",
1866 tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
1867 VLOG(3) << "Replica " << replica << ", partition " << partition
1868 << " mapped to device ordinal for execution: " << device_ordinal;
1869
1870 absl::flat_hash_set<BufferSequencingEvent*> events;
1871 device_buffers->reserve(argument_handles.size());
1872 absl::Span<int const> donated_params =
1873 ParametersThatMustBeDonated(executable_idx);
1874 auto donate_it = donated_params.begin();
1875 for (int i = 0; i < argument_handles.size(); ++i) {
1876 auto* handle =
1877 tensorflow::down_cast<PjRtStreamExecutorBuffer*>(argument_handles[i]);
1878 if (handle->device() != device) {
1879 return InvalidArgument(
1880 "Buffer passed to Execute() as argument %d to replica %d is on "
1881 "device %s, but replica is assigned to device %s.",
1882 i, replica, handle->device()->DebugString(), device->DebugString());
1883 }
1884 bool must_donate = donate_it != donated_params.end() && *donate_it == i;
1885 if (must_donate) {
1886 ++donate_it;
1887 }
1888 device_buffers->emplace_back(handle->GetBufferWithHold(
1889 must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
1890 : PjRtStreamExecutorBuffer::ScopedHold::kUsage));
1891 PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
1892 device_buffers->back();
1893 if (!device_buffer.ok()) {
1894 return InvalidArgument(
1895 "Invalid buffer passed to Execute() as argument %d to replica %d: "
1896 "%s",
1897 i, replica, device_buffer.status().ToString());
1898 }
1899 // If we are trying to donate the buffer wait on the usage events as well
1900 // as the definition events to ensure that all reads have been completed
1901 // before the buffer is mutated. Usage holds are excluded during a donation
1902 // hold so we know that the set of usage events won't be modified while we
1903 // are enqueueing.
1904 GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
1905 &events);
1906 }
1907
1908 if (options.arguments_are_tupled) {
1909 if (!parameter_is_tupled_arguments_) {
1910 return InvalidArgument(
1911 "Arguments may only be supplied as a tuple when the executable was "
1912 "compiled with a single tupled parameter");
1913 }
1914 if (argument_handles.size() != 1) {
1915 return InvalidArgument(
1916 "Option arguments_are_tupled was true but %d buffers were passed to "
1917 "execution",
1918 argument_handles.size());
1919 }
1920 }
1921
1922 TF_ASSIGN_OR_RETURN(
1923 std::vector<ExecutionInput> execution_inputs,
1924 MakeExecutionInputsAndWaitForEvents(
1925 device_ordinal, options,
1926 on_device_executable_parameter_shapes_[executable_idx],
1927 argument_handles, *device_buffers, events));
1928
1929 ExecutableRunOptions run_options;
1930 run_options.set_stream(device_state->compute_stream());
1931 run_options.set_host_to_device_stream(device_state->host_to_device_stream());
1932 run_options.set_allocator(client_->allocator());
1933 run_options.set_intra_op_thread_pool(
1934 client_->client()->backend().eigen_intra_op_thread_pool_device());
1935 run_options.set_device_assignment(device_assignment.get());
1936 run_options.set_run_id(run_id);
1937 run_options.set_rng_seed(device_state->GetNewPrngSeed());
1938 run_options.set_gpu_executable_run_options(client_->gpu_run_options());
1939 run_options.set_launch_id(options.launch_id);
1940 if (run_options.launch_id() != 0) {
1941 VLOG(3) << "launch id for " << name() << ": " << run_options.launch_id();
1942 }
1943
1944 // The choice of where we wait is arbitrary; the reason for the wait is
1945 // pacing to avoid problems such as memory fragmentation and running ahead
1946 // too far, not for correctness. Placing it before the executable launch
1947 // allows the inputs for the next executable to be fetched even if the
1948 // launch is delayed.
1949 std::shared_ptr<Semaphore::ScopedReservation> compute_reservation;
1950 {
1951 tensorflow::profiler::TraceMe traceme("ComputeSemaphoreAcquire");
1952 compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
1953 device_state->compute_semaphore().ScopedAcquire(1));
1954 }
1955
1956 StatusOr<ExecutionOutput> result_buffer_or_status =
1957 executables_[executable_idx]->RunAsync(std::move(execution_inputs),
1958 run_options);
1959
1960 VLOG(1) << "Replica " << replica << " partition " << partition
1961 << " completed; ok=" << result_buffer_or_status.ok();
1962
1963 if (!result_buffer_or_status.ok()) {
1964 return result_buffer_or_status.status();
1965 }
1966
1967 if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
1968 ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
1969 // If we used a transient tuple for the arguments we donated its root table
1970 // buffer. In that case, and/or if we donated any input buffers that were
1971 // not aliased, the donated buffers are going to be passed back to us via
1972 // the execution output. We need to ensure they aren't freed until after
1973 // execution completes. (Currently XLA does not support aliasing tuple
1974 // tables, so if any donated parameter is a tuple there will be donated but
1975 // unaliased buffers.)
1976 std::vector<se::OwningDeviceMemory> donated_memory =
1977 execution_output.ConsumeToBeReleased();
1978 absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
1979 donated_ptrs.reserve(donated_memory.size());
1980 for (se::OwningDeviceMemory& owning : donated_memory) {
1981 // Release the owning memory so we can pass it to the closure.
1982 donated_ptrs.push_back(owning.Release());
1983 }
1984 compute_callbacks.push_back(
1985 [references{std::make_tuple(executables_[executable_idx],
1986 compute_reservation, device_assignment)},
1987 donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
1988 device_ordinal]() {
1989 for (const auto& ptr : donated_ptrs) {
1990 TF_CHECK_OK(allocator->Deallocate(device_ordinal, ptr));
1991 }
1992 });
1993 } else {
1994 // Any donated memory returned by the ExecutionOutput can be immediately
1995 // freed.
1996 compute_callbacks.push_back(
1997 [to_release{std::make_tuple(executables_[executable_idx],
1998 compute_reservation,
1999 device_assignment)}]() {});
2000 }
2001
2002 return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult();
2003 }
2004
2005 std::vector<std::unique_ptr<PjRtBuffer>>
2006 PjRtStreamExecutorExecutable::MakeOutputBuffers(
2007 int device_ordinal, const ExecuteOptions& options,
2008 ScopedShapedBuffer result_buffer,
2009 std::shared_ptr<BufferSequencingEvent> definition_event, PjRtDevice* device,
2010 std::vector<std::function<void()>>& compute_callbacks,
2011 std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
2012 const {
2013 tensorflow::profiler::TraceMe traceme("MakeOutputBuffers");
2014 std::vector<std::unique_ptr<PjRtBuffer>> outputs;
2015 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
2016 if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
2017 int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
2018 outputs.reserve(tuple_count);
2019 // Take ownership of each of the output values, leaving only the root table
2020 // in result_buffer.
2021 for (int i = 0; i < tuple_count; ++i) {
2022 ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i});
2023 outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event,
2024 client_, device, device_state,
2025 buffers_to_release));
2026 }
2027 if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
2028 // Don't release the root buffer until after execution completes.
2029 ShapedBuffer root_buffer_holder = result_buffer.release();
2030 se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer();
2031 compute_callbacks.push_back(
2032 [root_buffer, allocator{client_->allocator()}, device_ordinal]() {
2033 TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer));
2034 });
2035 }
2036 } else {
2037 outputs.push_back(OutputBufferHelper(&result_buffer, definition_event,
2038 client_, device, device_state,
2039 buffers_to_release));
2040 }
2041 return outputs;
2042 }
2043
2044 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2045 PjRtStreamExecutorExecutable::ExecuteHelper(
2046 absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
2047 const RunId& run_id, const ExecuteOptions& options,
2048 PjRtDevice* device) const {
2049 const uint64 start_time_usecs = tensorflow::Env::Default()->NowMicros();
2050 std::shared_ptr<DeviceAssignment> device_assignment;
2051 if (device == nullptr) {
2052 CHECK(device_assignment_ != nullptr);
2053 const int device_id = (*device_assignment_)(replica, partition);
2054 TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
2055 device_assignment = device_assignment_;
2056 } else {
2057 CHECK(device_assignment_ == nullptr);
2058 CHECK_EQ(replica, 0);
2059 CHECK_EQ(partition, 0);
2060 CHECK(addressable_devices_.empty());
2061 device_assignment = std::make_shared<DeviceAssignment>(1, 1);
2062 (*device_assignment)(0, 0) = device->id();
2063 }
2064
2065 CHECK_EQ(device->process_index(), client_->process_index());
2066 int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
2067 ->local_device_state()
2068 ->device_ordinal();
2069 tensorflow::profiler::TraceMe traceme(
2070 "PjRtStreamExecutorExecutable::ExecuteHelper");
2071 VLOG(1) << "Replica " << replica << ", partition " << partition
2072 << " mapped to device ordinal for execution: " << device_ordinal;
2073
2074 // SPMD sharding produces a single executable for multiple partitions.
2075 int executable_idx = executables_.size() > 1 ? partition : 0;
2076
2077 std::vector<std::function<void()>> compute_callbacks;
2078 std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
2079 device_buffers.reserve(argument_handles.size());
2080 StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
2081 argument_handles, replica, partition, executable_idx, run_id, options,
2082 device, &device_buffers, std::move(device_assignment), compute_callbacks);
2083
2084 if (!result_buffer_or_status.ok()) {
2085 LOG(ERROR) << "Execution of replica " << replica
2086 << " failed: " << result_buffer_or_status.status();
2087 return result_buffer_or_status.status();
2088 }
2089 ScopedShapedBuffer result_buffer =
2090 result_buffer_or_status.ConsumeValueOrDie();
2091
2092 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
2093 se::Stream* stream = device_state->compute_stream();
2094 StatusOr<EventPool::Handle> event_or =
2095 device_state->event_pool().ThenAllocateAndRecordEvent(stream);
2096 if (!event_or.ok()) {
2097 StallStreamOnError(device_state, stream);
2098 for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
2099 if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
2100 // Even though there was an error we need to call ConfirmDonation, which
2101 // renders b invalid, since the computation has been enqueued and b has
2102 // been donated.
2103 b.ConfirmDonation();
2104 }
2105 }
2106 return event_or.status();
2107 }
2108 auto definition_event = std::make_shared<BufferSequencingEvent>();
2109 definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
2110 std::vector<std::shared_ptr<TrackedDeviceBuffer>> buffers_to_release;
2111 std::vector<std::unique_ptr<PjRtBuffer>> outputs = MakeOutputBuffers(
2112 device_ordinal, options, std::move(result_buffer), definition_event,
2113 device, compute_callbacks, buffers_to_release);
2114
2115 for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
2116 // prefer_to_retain_reference=false because when using the
2117 // ComputeSynchronized allocation model we don't need to retain a reference
2118 // to the device_buffer during execution because by definition the compute
2119 // stream is synchronized past the execution.
2120 if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
2121 RecordUsage(std::move(b), device_state, device_state, definition_event,
2122 stream,
2123 /*prefer_to_retain_reference=*/false, &buffers_to_release);
2124 } else {
2125 CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
2126 b.ConfirmDonation();
2127 }
2128 }
2129
2130 if (!compute_callbacks.empty()) {
2131 device_state->ThenExecuteCallback(
2132 stream, [callbacks{std::move(compute_callbacks)},
2133 buffers_to_release{std::move(buffers_to_release)}]() {
2134 for (auto& fn : callbacks) {
2135 fn();
2136 }
2137 });
2138 }
2139 ReportExecutableEnqueueTime(tensorflow::Env::Default()->NowMicros() -
2140 start_time_usecs);
2141 return outputs;
2142 }
2143
2144 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
2145 PjRtStreamExecutorExecutable::Execute(
2146 absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
2147 const ExecuteOptions& options) {
2148 if (device_assignment_ == nullptr) {
2149 return InvalidArgument("Execute expects a non-null device_assignment");
2150 }
2151
2152 RunId run_id;
2153 tensorflow::profiler::TraceMeProducer activity(
2154 "PjRtStreamExecutorExecutable::Execute",
2155 tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
2156
2157 const int num_addressable_devices = addressable_devices_.size();
2158
2159 if (argument_handles.size() != num_addressable_devices) {
2160 return InvalidArgument(
2161 "Attempted to execute with %d argument lists when local device "
2162 "count is %d (total replica count: %d, partition count: %d)",
2163 argument_handles.size(), num_addressable_devices, num_replicas(),
2164 num_partitions());
2165 }
2166
2167 VLOG(1) << "Executing computation " << name()
2168 << "; num_replicas=" << num_replicas()
2169 << " num_partitions=" << num_partitions()
2170 << " num_addressable_devices=" << num_addressable_devices;
2171 std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results(
2172 num_addressable_devices);
2173 if (num_addressable_devices == 1) {
2174 // Fast-path if there is only one device — run the computation on the
2175 // current thread.
2176 const int replica = addressable_device_logical_ids_[0].replica;
2177 const int partition = addressable_device_logical_ids_[0].partition;
2178 results[0] =
2179 ExecuteHelper(argument_handles[0], replica, partition, run_id, options);
2180 } else {
2181 absl::Mutex mu;
2182 int running = num_addressable_devices;
2183 int failed = 0;
2184 Status first_failure_status;
2185
2186 for (int i = 0; i < num_addressable_devices; ++i) {
2187 const int replica = addressable_device_logical_ids_[i].replica;
2188 const int partition = addressable_device_logical_ids_[i].partition;
2189 PjRtDevice* device = addressable_devices_[i];
2190 const LocalDeviceState& device_state =
2191 *tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
2192 ->local_device_state();
2193 device_state.execute_thread()->Schedule([&, replica, partition, i] {
2194 results[i] = ExecuteHelper(argument_handles[i], replica, partition,
2195 run_id, options);
2196
2197 absl::MutexLock lock(&mu);
2198 --running;
2199 if (!results[i].ok()) {
2200 if (failed == 0) {
2201 first_failure_status = results[i].status();
2202 }
2203 ++failed;
2204 }
2205 });
2206 }
2207
2208 auto done_running_or_failed = [&]() {
2209 mu.AssertHeld();
2210 return running == 0 || failed > 0;
2211 };
2212 absl::MutexLock lock(&mu);
2213 mu.Await(absl::Condition(&done_running_or_failed));
2214 if (failed > 0) {
2215 auto done_running = [&]() {
2216 mu.AssertHeld();
2217 return running == 0;
2218 };
2219 // If execution does not terminate within a reasonable amount of time,
2220 // we may be stuck at a cross-replica barrier on-device. Terminate the
2221 // process since that's the only way we can escape this situation at the
2222 // moment (b/130629719).
2223 if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
2224 absl::Seconds(10))) {
2225 LOG(FATAL)
2226 << "Replicated computation launch failed, but not all replicas "
2227 "terminated. Aborting process to work around deadlock. "
2228 "Failure message (there may have been multiple failures, see "
2229 "the error log for all failures): \n\n"
2230 << first_failure_status.error_message();
2231 }
2232 }
2233 }
2234 VLOG(1) << "Replicated execution complete.";
2235
2236 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
2237 num_addressable_devices);
2238 for (int i = 0; i < num_addressable_devices; ++i) {
2239 const int replica = addressable_device_logical_ids_[i].replica;
2240 const int partition = addressable_device_logical_ids_[i].partition;
2241 auto& statusor = results[i];
2242 if (!statusor.ok()) {
2243 if (num_addressable_devices == 1) {
2244 return statusor.status();
2245 } else {
2246 return AppendStatus(
2247 statusor.status(),
2248 absl::StrFormat("while running replica %d and partition %d of a "
2249 "replicated computation (other "
2250 "replicas may have failed as well).",
2251 replica, partition));
2252 }
2253 }
2254 wrapped_results[i] = std::move(statusor.ValueOrDie());
2255 }
2256 return wrapped_results;
2257 }
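// Minimal sketch of a replicated launch, assuming `executable` was compiled
// with a device assignment and `args` holds one argument list per
// addressable device:
//
//   std::vector<std::vector<PjRtBuffer*>> args = ...;
//   TF_ASSIGN_OR_RETURN(auto per_device_outputs,
//                       executable->Execute(args, ExecuteOptions()));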
2258
2259 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2260 PjRtStreamExecutorExecutable::ExecuteSharded(
2261 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2262 const ExecuteOptions& options) {
2263 if (device_assignment_ == nullptr) {
2264 return InvalidArgument("ExecuteSharded expects a non-null device_assignment");
2265 }
2266 for (int i = 0; i < addressable_devices_.size(); ++i) {
2267 if (addressable_devices_[i] == device) {
2268 VLOG(1) << "ExecuteSharded executes computation " << name()
2269 << " on assigned replica/partition on device "
2270 << device->DebugString();
2271 return ExecuteHelper(
2272 argument_handles, addressable_device_logical_ids_[i].replica,
2273 addressable_device_logical_ids_[i].partition, RunId(), options);
2274 }
2275 }
2276 return InvalidArgument(
2277 "ExecuteSharded attempted to execute on device id %d which is not "
2278 "addressable by this client",
2279 device->id());
2280 }
2281
2282 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2283 PjRtStreamExecutorExecutable::ExecutePortable(
2284 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2285 const ExecuteOptions& options) {
2286 if (device_assignment_ != nullptr) {
2287 return InvalidArgument("ExecutePortable was called on a non-portable executable");
2288 }
2289 if (num_replicas() != 1 || num_partitions() != 1) {
2290 return InvalidArgument(
2291 "ExecutePortable expects a single-core executable but got "
2292 "one with %d replicas and %d partitions",
2293 num_replicas(), num_partitions());
2294 }
2295 if (device == nullptr) {
2296 return InvalidArgument("ExecutePortable expects a device to be specified");
2297 }
2298 VLOG(1) << "ExecutePortable executes single-core portable executable "
2299 << name();
2300 return ExecuteHelper(argument_handles,
2301 /*replica=*/0,
2302 /*partition=*/0, RunId(), options, device);
2303 }
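// Sketch for the portable path, assuming `executable` was compiled with
// compile_portable_executable set (no device assignment), `args` is a
// std::vector<PjRtBuffer*>, and `device` is addressable by this client:
//
//   TF_ASSIGN_OR_RETURN(auto outputs,
//                       executable->ExecutePortable(args, device,
//                                                   ExecuteOptions()));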
2304
2305 StatusOr<std::vector<std::shared_ptr<HloModule>>>
2306 PjRtStreamExecutorExecutable::GetHloModules() const {
2307 std::vector<std::shared_ptr<HloModule>> modules;
2308 modules.reserve(executables().size());
2309 for (const auto& local_exec : executables()) {
2310 if (!local_exec->executable()->has_module()) {
2311 return InvalidArgument("Executable does not have HLO modules.");
2312 }
2313 modules.push_back(local_exec->executable()->shared_module());
2314 }
2315 return std::move(modules);
2316 }
2317
2318 StatusOr<PjRtStreamExecutorClient::ExecutableExtras>
2319 PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
2320 ExecutableExtras extras;
2321 std::shared_ptr<DeviceAssignment>& device_assignment =
2322 extras.device_assignment;
2323 std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
2324 addressable_device_logical_ids = extras.addressable_device_logical_ids;
2325 std::vector<PjRtDevice*>& addressable_devices = extras.addressable_devices;
2326
2327 ExecutableBuildOptions& build_options = options->executable_build_options;
2328 if (!build_options.compile_thread_pool()) {
2329 build_options.set_compile_thread_pool(thread_pool());
2330 }
2331 if (!build_options.device_allocator()) {
2332 build_options.set_device_allocator(allocator());
2333 }
2334
2335 int num_replicas;
2336 int num_partitions;
2337 TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
2338 options->compile_portable_executable, &options->executable_build_options,
2339 [this](int num_replicas, int num_partitions) {
2340 return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
2341 },
2342 &num_replicas, &num_partitions, &device_assignment));
2343
2344 // Find devices that are addressable by this client/task.
2345 if (device_assignment != nullptr) {
2346 addressable_device_logical_ids.reserve(num_replicas * num_partitions);
2347 addressable_devices.reserve(num_replicas * num_partitions);
2348 for (int replica = 0; replica < num_replicas; ++replica) {
2349 for (int partition = 0; partition < num_partitions; ++partition) {
2350 int device_id = (*device_assignment)(replica, partition);
2351 TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
2352 if (device->process_index() != process_index()) {
2353 VLOG(3) << "Non-local device: " << device_id;
2354 continue;
2355 }
2356 PjRtExecutable::LogicalDeviceIds logical_device_ids;
2357 logical_device_ids.replica = replica;
2358 logical_device_ids.partition = partition;
2359 addressable_device_logical_ids.push_back(std::move(logical_device_ids));
2360 addressable_devices.push_back(device);
2361 }
2362 }
2363 if (addressable_devices.empty()) {
2364 return InvalidArgument(
2365 "Device assignment (%s) does not have any local devices.",
2366 device_assignment->ToString());
2367 }
2368
2369 if (build_options.device_ordinal() < 0) {
2370 build_options.set_device_ordinal(
2371 addressable_devices.front()->local_hardware_id());
2372 }
2373 }
2374 return extras;
2375 }
2376
2377 StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
2378 const XlaComputation& computation, CompileOptions options) {
2379 tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
2380 VLOG(1) << "PjRtStreamExecutorClient::Compile";
2381
2382 TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options));
2383 std::shared_ptr<DeviceAssignment>& device_assignment =
2384 extras.device_assignment;
2385 std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
2386 addressable_device_logical_ids = extras.addressable_device_logical_ids;
2387 std::vector<PjRtDevice*>& addressable_devices = extras.addressable_devices;
2388
2389 std::vector<const Shape*> argument_layout_pointers;
2390 TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
2391 computation,
2392 [local_client = client()](Shape shape) {
2393 return local_client->backend()
2394 .transfer_manager()
2395 ->ChooseCompactLayoutForShape(shape);
2396 },
2397 options.argument_layouts, &options.executable_build_options,
2398 &argument_layout_pointers));
2399
2400 TF_ASSIGN_OR_RETURN(
2401 std::vector<std::unique_ptr<LocalExecutable>> local_executables,
2402 client()->Compile(computation, argument_layout_pointers,
2403 options.executable_build_options));
2404
2405 auto executable = absl::make_unique<PjRtStreamExecutorExecutable>(
2406 std::move(local_executables), options.parameter_is_tupled_arguments,
2407 std::move(device_assignment), std::move(addressable_device_logical_ids),
2408 std::move(addressable_devices), this);
2409 TF_RETURN_IF_ERROR(
2410 executable->SetUpDonation(options.parameter_is_tupled_arguments));
2411 return std::unique_ptr<PjRtExecutable>(std::move(executable));
2412 }
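// Minimal compile sketch, assuming an XlaComputation `computation` and
// default options; donation setup is performed before the executable is
// returned:
//
//   CompileOptions options;
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
//                       client->Compile(computation, options));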
2413
2414 } // namespace xla
2415