1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implementation notes:
17 //
18 // Asynchronous execution:
19 // -----------------------
20 //
21 // Computations and host-to-device transfers do not need to block the host
22 // waiting for the operation to complete but instead return control to the host
23 // immediately. This allows client logic to overlap with device-side
24 // computation.
25 //
26 // For a good user experience, we must be careful only to enqueue operations
27 // that are unlikely to fail; as a rule error checking must be done eagerly
28 // before returning control to the client.
29 //
30 // The degree to which the client can enqueue operations ahead of the device
31 // is limited by a semaphore. There are two modes: asynchronous, where we
32 // allow the client to enqueue up to 32 executions ahead of the device, and
33 // synchronous, where we limit the client to having one enqueued operation at
34 // a time. The value of 32 is arbitrary.
35 //
36 // Even in asynchronous mode, it is important that we do not permit
37 // unbounded queue-ahead. Firstly it is problematic when the user does something
38 // like the following in Python:
39 // %timeit run_computation()
40 // To the timeit logic, run_computation() appears to be extremely cheap since it
41 // defers all of its real work and does not block, so %timeit will run it many
42 // (e.g., 10000) times to get better timing resolution, even though in reality
43 // it may be expensive. Secondly, on CPU the allocator is synchronized with the
44 // head of the compute stream, and we allocate buffers for all of the enqueued
45 // programs without any reuse (unlike GPU). This means that the memory usage
46 // is proportional to the queue size.
47 //
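// As a rough illustration only (this is not the code in this file), the bound
// can be thought of as a counting semaphore sized by the mode, where the
// enqueue and completion hooks below are hypothetical placeholders:
//
//   std::counting_semaphore<> slots(asynchronous ? 32 : 1);
//   void RunComputation() {
//     slots.acquire();        // Blocks once the client is too far ahead.
//     EnqueueWorkOnStream();  // Returns as soon as the work is enqueued.
//     EnqueueHostCallback([&] { slots.release(); });  // Runs on completion.
//   }
//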
48 // Multi-stream execution:
49 // -----------------------
50 //
51 // We use a multistream execution design, where different Streams are used for
52 // host-to-device transfers, device-to-host transfers, and compute. This allows
53 // us to overlap transfers on and off the device with computation.
54 //
55 // Synchronization between streams occurs via BufferSequencingEvents that
56 // describe when the contents of a logical buffer are known to be valid on
57 // a particular stream, and when a buffer's uses have all completed.
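//
// Sketched very roughly (the real call sites below differ in detail), the
// pattern is: the producer records an event on its stream and attaches it to
// the buffer's BufferSequencingEvent; consumers then wait on that event from
// their own streams before using the buffer:
//
//   // Producer: record an event after enqueueing the write.
//   TF_ASSIGN_OR_RETURN(EventPool::Handle handle,
//                       event_pool.ThenAllocateAndRecordEvent(producer_stream));
//   definition_event->SetSequencingEvent(std::move(handle), producer_stream);
//
//   // Consumer: make the consuming stream wait before reading the buffer.
//   definition_event->WaitForEventOnStream(consumer_stream);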
58 //
59 // Synchronous vs asynchronous deallocation:
60 // -----------------------------------------
61 //
62 // See the comment on LocalDeviceState::AllocationModel for a discussion of the
63 // different allocation semantics on CPU, GPU, and TPU.
64 
65 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
66 
67 #include <cstddef>
68 #include <cstdlib>
69 #include <memory>
70 #include <string>
71 #include <vector>
72 
73 #include "absl/base/casts.h"
74 #include "absl/container/flat_hash_set.h"
75 #include "absl/container/inlined_vector.h"
76 #include "absl/memory/memory.h"
77 #include "absl/strings/str_format.h"
78 #include "absl/synchronization/mutex.h"
79 #include "absl/time/time.h"
80 #include "absl/types/optional.h"
81 #include "absl/types/span.h"
82 #include "tensorflow/compiler/xla/client/local_client.h"
83 #include "tensorflow/compiler/xla/client/xla_computation.h"
84 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
85 #include "tensorflow/compiler/xla/executable_run_options.h"
86 #include "tensorflow/compiler/xla/layout.h"
87 #include "tensorflow/compiler/xla/literal.h"
88 #include "tensorflow/compiler/xla/literal_util.h"
89 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
90 #include "tensorflow/compiler/xla/pjrt/event_pool.h"
91 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
92 #include "tensorflow/compiler/xla/pjrt/metrics.h"
93 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
94 #include "tensorflow/compiler/xla/pjrt/utils.h"
95 #include "tensorflow/compiler/xla/service/computation_layout.h"
96 #include "tensorflow/compiler/xla/service/executable.h"
97 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
98 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
99 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
100 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
101 #include "tensorflow/compiler/xla/service/transfer_manager.h"
102 #include "tensorflow/compiler/xla/shape.h"
103 #include "tensorflow/compiler/xla/shape_util.h"
104 #include "tensorflow/compiler/xla/util.h"
105 #include "tensorflow/compiler/xla/xla_data.pb.h"
106 #include "tensorflow/core/platform/cpu_info.h"
107 #include "tensorflow/core/platform/env.h"
108 #include "tensorflow/core/platform/errors.h"
109 #include "tensorflow/core/platform/fingerprint.h"
110 #include "tensorflow/core/platform/mem.h"
111 #include "tensorflow/core/platform/status.h"
112 #include "tensorflow/core/platform/types.h"
113 #include "tensorflow/core/profiler/lib/connected_traceme.h"
114 #include "tensorflow/core/profiler/lib/traceme.h"
115 #include "tensorflow/core/profiler/lib/traceme_encode.h"
116 #include "tensorflow/stream_executor/device_memory.h"
117 #include "tensorflow/stream_executor/device_memory_allocator.h"
118 #include "tensorflow/stream_executor/event.h"
119 #include "tensorflow/stream_executor/host/host_platform_id.h"
120 #include "tensorflow/stream_executor/lib/statusor.h"
121 #include "tensorflow/stream_executor/stream.h"
122 
123 namespace xla {
124 
125 PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
126   return client_->platform_id();
127 }
128 absl::string_view PjRtStreamExecutorDevice::platform_name() const {
129   return client_->platform_name();
130 }
131 
132 StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
133     const {
134   if (local_device_state_) {
135     return local_device_state_.get();
136   }
137   return InvalidArgument("Device %s is not a local device.", DebugString());
138 }
139 
140 std::string PjRtStreamExecutorDevice::DebugString() const {
141   return absl::StrCat(platform_name(), ":", id());
142 }
143 
144 StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
145     absl::Span<const std::vector<PjRtDevice*>> devices) {
146   if (devices.empty()) {
147     return InvalidArgument(
148         "Device assignment passed to Compile() must be non-empty.");
149   }
150   if (devices[0].empty()) {
151     return InvalidArgument(
152         "Device assignment passed to Compile() must have a nonzero number of "
153         "partitions per replica; replica 0 had 0 partitions.");
154   }
155   DeviceAssignment xla_assignment(devices.size(), devices[0].size());
156   for (int replica = 0; replica < devices.size(); ++replica) {
157     if (devices[replica].size() != devices[0].size()) {
158       return InvalidArgument(
159           "Device assignment passed to Compile() has different numbers of "
160           "partitions between replicas; %d partitions for replica %d versus %d "
161           "partitions for replica 0.",
162           devices[replica].size(), replica, devices[0].size());
163     }
164     for (int partition = 0; partition < devices[replica].size(); ++partition) {
165       if (devices[0][0]->client()->platform_id() !=
166           devices[replica][partition]->client()->platform_id()) {
167         return InvalidArgument(
168             "Device assignment passed to Compile() must have devices of a "
169             "single kind, got %s for replica 0 partition 0 and %s for replica "
170             "%d partition %d.",
171             devices[0][0]->client()->platform_name(),
172             devices[replica][partition]->client()->platform_name(), replica,
173             partition);
174       }
175       xla_assignment(replica, partition) = devices[replica][partition]->id();
176     }
177   }
178   return xla_assignment;
179 }
180 
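// Illustrative example (hypothetical devices d0..d3): with
// devices = {{d0, d1}, {d2, d3}}, where the outer index is the replica and the
// inner index is the partition, the returned assignment is the 2x2 matrix
//
//   assignment(0, 0) = d0->id();  assignment(0, 1) = d1->id();
//   assignment(1, 0) = d2->id();  assignment(1, 1) = d3->id();
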
181 class CpuAllocator : public tensorflow::Allocator {
182  public:
183   CpuAllocator() = default;
184 
185   std::string Name() override { return "cpu"; }
186 
187   void* AllocateRaw(size_t alignment, size_t num_bytes) override {
188     return tensorflow::port::AlignedMalloc(num_bytes, alignment);
189   }
190   void DeallocateRaw(void* ptr) override {
191     return tensorflow::port::AlignedFree(ptr);
192   }
193 };
194 
195 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
196     std::string platform_name, LocalClient* client,
197     std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
198     int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
199     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
200     bool should_stage_host_to_device_transfers,
201     std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
202     : platform_id_(tensorflow::Fingerprint64(platform_name)),
203       platform_name_(std::move(platform_name)),
204       client_(client),
205       host_memory_allocator_(std::move(host_memory_allocator)),
206       owned_allocator_(std::move(allocator)),
207       owned_devices_(std::move(devices)),
208       process_index_(process_index),
209       should_stage_host_to_device_transfers_(
210           should_stage_host_to_device_transfers),
211       gpu_run_options_(std::move(gpu_run_options)),
212       thread_pool_(
213           tensorflow::Env::Default(), "pjrt_thread_pool",
214           std::max<int>(DefaultThreadPoolSize(), client->device_count())),
215       transpose_cache_(1024) {
216   if (owned_allocator_ != nullptr) {
217     allocator_ = owned_allocator_.get();
218   } else {
219     allocator_ = client_->backend().memory_allocator();
220   }
221 
222   if (!host_memory_allocator_) {
223     host_memory_allocator_ = std::make_unique<CpuAllocator>();
224   }
225 
226   for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
227        owned_devices_) {
228     devices_.push_back(device.get());
229     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
230         << "Duplicate device id: " << device->id();
231 
232     if (device->IsAddressable()) {
233       int idx = device->local_hardware_id();
234       if (idx >= addressable_devices_.size()) {
235         addressable_devices_.resize(idx + 1);
236       }
237       CHECK(addressable_devices_[idx] == nullptr) << idx;
238       addressable_devices_[idx] = device.get();
239     }
240     device->SetClient(this);
241   }
242   for (int idx = 0; idx < addressable_devices_.size(); ++idx) {
243     CHECK(addressable_devices_[idx] != nullptr) << idx;
244   }
245 }
246 
247 StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
248     int num_replicas, int num_partitions) const {
249   return client_->backend().computation_placer()->AssignDevices(num_replicas,
250                                                                 num_partitions);
251 }
252 
253 StatusOr<std::unique_ptr<HloCostAnalysis>>
254 PjRtStreamExecutorClient::GetHloCostAnalysis() {
255   return absl::make_unique<HloCostAnalysis>(
256       client_->backend().compiler()->ShapeSizeBytesFunction());
257 }
258 
259 namespace {
260 
261 // Ensures that it is safe to deallocate any buffers that have been enqueued in
262 // an operation on stream. Called only in rare error cases that are triggered
263 // during enqueue. These cases generally correspond to resource exhaustion.
264 void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
265   switch (local_device->allocation_model()) {
266     case LocalDeviceState::kAsynchronous:
267       // We can safely deallocate any dangling buffers immediately. NOTE: this
268       // assumes that any buffers enqueued on stream are local to stream's
269       // executor, and manual action may be needed if that condition is not met.
270       break;
271 
272     case LocalDeviceState::kComputeSynchronized:
273       // This will stall computation but that's ok in this very rare error
274       // case.
275       if (stream != local_device->compute_stream()) {
276         local_device->compute_stream()->ThenWaitFor(stream);
277       }
278       break;
279 
280     case LocalDeviceState::kSynchronous:
281       // This will stall the calling thread but that's ok in this very rare
282       // error case. If the stall fails just crash, since we have no other
283       // way to synchronize.
284       TF_CHECK_OK(stream->BlockHostUntilDone());
285       break;
286   }
287 }
288 
289 // Does all necessary bookkeeping, after a buffer is successfully enqueued onto
290 // a stream, to ensure that the buffer will be kept alive until its use on that
291 // stream is complete.
292 //
293 //   device_buffer:              the buffer that was enqueued.
294 //   buffer_local_device:        the device the buffer was allocated on.
295 //   stream_local_device:        the device that manages usage_stream.
296 //   event:                      an event that was recorded on usage_stream
297 //                               after the usage of device_buffer was enqueued.
298 //   usage_stream:               the stream the operation using device_buffer
299 //                               was enqueued on.
300 //   prefer_to_retain_reference: relevant only for the compute synchronous
301 //                               allocation model. If true, retain a reference
302 //                               to device_buffer until after the operation
303 //                               completes. If false then the compute stream
304 //                               will have to be synchronized past event before
305 //                               device_buffer can be freed.
306 //
307 // prefer_to_retain_reference encodes a heuristic set by the caller for the
308 // compute synchronous model:
309 //
310 // Generally when a buffer is the destination of a copy to a device, it will
311 // subsequently be used on the device's compute stream before being freed. In
312 // that case, there is no need to retain a reference to the buffer. If the
313 // buffer is freed before being used on the compute stream, the free will be
314 // delayed until the host knows that event has completed, but this is expected
315 // to be uncommon.
316 //
317 // When a buffer is the source of a copy from a device, we need to either retain
318 // a reference to the buffer until the copy completes or serialize the compute
319 // stream behind the copy. It is often better to retain a reference since while
320 // that keeps memory alive longer, it avoids stalling the compute stream.
321 void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
322                  LocalDeviceState* buffer_local_device,
323                  LocalDeviceState* stream_local_device,
324                  std::shared_ptr<BufferSequencingEvent> event,
325                  se::Stream* usage_stream, bool prefer_to_retain_reference,
326                  std::vector<std::shared_ptr<TrackedDeviceBuffer>>*
327                      buffers_to_release = nullptr) {
328   tensorflow::profiler::TraceMe traceme("RecordUsage");
329   bool retain_buffer_until_completion =
330       // If the buffer wasn't allocated on the same device as the stream, always
331       // retain a reference.
332       (stream_local_device != buffer_local_device) ||
333       // In the synchronous allocation model, always retain a reference.
334       (stream_local_device->allocation_model() ==
335        LocalDeviceState::kSynchronous) ||
336       // In the compute synchronous model, use the caller's heuristic.
337       (stream_local_device->allocation_model() ==
338            LocalDeviceState::kComputeSynchronized &&
339        prefer_to_retain_reference);
340   if (retain_buffer_until_completion) {
341     if (buffers_to_release) {
342       buffers_to_release->push_back(device_buffer.buffer());
343     } else {
344       buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer());
345     }
346   }
347   device_buffer.ConvertUsageHold(usage_stream, event,
348                                  retain_buffer_until_completion);
349 }
350 
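// A representative call site (taken from AllocateDestinationBuffer below):
//
//   RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
//               definition_events.back(), tuple_table_stream,
//               /*prefer_to_retain_reference=*/false);
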
351 // Allocates the device buffers for a buffer that will be used as the
352 // destination of a copy, either from the host or another device. copy_stream
353 // may be nullptr, e.g., when allocating a buffer for a cross-host copy. If the
354 // buffer is a tuple then the tuple tables are allocated, and all necessary
355 // synchronization for them is dealt with, before the buffer is returned.
356 //
357 // It is safe to delete the returned PjRtBuffer without further
358 // synchronization if an error occurs before the buffer is used.
359 //
360 // The caller may optionally provide a definition event to be recorded in
361 // the buffer.
362 // TODO(phawkins): replace on_host_shape here with on_device_shape.
363 StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
364     const Shape& on_host_shape, PjRtDevice* device,
365     LocalDeviceState* local_device, se::Stream* copy_stream,
366     bool is_uninitialized_create, PjRtClient* client,
367     std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
368   if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
369     return InvalidArgument("Can't make a buffer from an empty tuple");
370   }
371 
372   auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
373   TransferManager* transfer_manager =
374       se_client->client()->backend().transfer_manager();
375   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
376                       transfer_manager->AllocateScopedShapedBuffer(
377                           on_host_shape, se_client->allocator(),
378                           local_device->device_ordinal()));
379   if (local_device->allocation_model() ==
380       LocalDeviceState::kComputeSynchronized) {
381     if (copy_stream == nullptr) {
382       CHECK(is_uninitialized_create);
383     } else {
384       copy_stream->ThenWaitFor(local_device->compute_stream());
385     }
386   } else {
387     DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
388         local_device->compute_stream()->parent(), dst_buffer));
389   }
390   Shape on_device_shape = dst_buffer.on_device_shape();
391 
392   absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
393       definition_events;
394   if (is_uninitialized_create) {
395     // There is not going to be any copy into the buffer so in general we don't
396     // need a definition event.
397     if (local_device->allocation_model() ==
398         LocalDeviceState::kComputeSynchronized) {
399       // The allocation is not valid until the compute stream passes this point,
400       // so add a definition event in the compute stream.
401       definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
402       TF_ASSIGN_OR_RETURN(EventPool::Handle event,
403                           local_device->event_pool().ThenAllocateAndRecordEvent(
404                               local_device->compute_stream()));
405       definition_events.back()->SetSequencingEvent(
406           std::move(event), local_device->compute_stream());
407     }
408     // If the caller provided a definition event, we record that as well.
409     if (definition_event) {
410       definition_events.emplace_back(definition_event);
411     }
412   } else {
413     // We have at least one definition event, for the copy completing to
414     // the device buffers.
415     if (definition_event) {
416       definition_events.emplace_back(definition_event);
417     } else {
418       definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
419     }
420   }
421   se::Stream* tuple_table_stream = local_device->host_to_device_stream();
422   if (on_device_shape.IsTuple()) {
423     // We also need to copy the tuple tables, so we'll have an additional
424     // definition event for that copy to complete.
425     if (tuple_table_stream != copy_stream) {
426       if (local_device->allocation_model() ==
427           LocalDeviceState::kComputeSynchronized) {
428         tuple_table_stream->ThenWaitFor(local_device->compute_stream());
429       } else {
430         DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
431             local_device->compute_stream()->parent(), dst_buffer));
432       }
433     }
434 
435     TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
436         tuple_table_stream, dst_buffer));
437     // CAUTION: From this point onwards we need to be careful about returning
438     // from error cases because we have started a transfer and must not allow
439     // dst_buffer to be freed too soon in the non-async allocation models.
440 
441     definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
442     StatusOr<EventPool::Handle> event_or =
443         local_device->event_pool().ThenAllocateAndRecordEvent(
444             tuple_table_stream);
445     if (!event_or.ok()) {
446       StallStreamOnError(local_device, tuple_table_stream);
447       return event_or.status();
448     }
449     definition_events.back()->SetSequencingEvent(event_or.ConsumeValueOrDie(),
450                                                  tuple_table_stream);
451   }
452   std::shared_ptr<TrackedDeviceBuffer> dst_device_buffer =
453       TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
454                                                   definition_events);
455 
456   auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
457       on_device_shape, std::move(dst_device_buffer), client, device);
458 
459   if (on_device_shape.IsTuple()) {
460     // Add a usage hold for the tuple table write and immediately convert it to
461     // the appropriate form of synchronization. prefer_to_retain_reference=false
462     // means don't retain a memory reference until the transfer is complete when
463     // using the ComputeSynchronized allocation model. This is a heuristic
464     // because in the common case destination buffers will be used on the
465     // compute stream and therefore don't require any synchronization before
466     // being freed. If the buffer is allocated and never used, the free will
467     // take longer and this is assumed to be ok.
468     RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
469                 definition_events.back(), tuple_table_stream,
470                 /*prefer_to_retain_reference=*/false);
471   }
472 
473   return py_buffer;
474 }
475 
476 // Adds necessary synchronization after a copy has been enqueued to a buffer.
477 // definition_event was added when the buffer was allocated, but has not yet
478 // had an event recorded.
479 Status AddDestinationBufferSynchronization(
480     LocalDeviceState* local_device,
481     PjRtStreamExecutorBuffer::ScopedHold device_buffer,
482     std::shared_ptr<BufferSequencingEvent> definition_event,
483     se::Stream* copy_stream) {
484   StatusOr<EventPool::Handle> event_or =
485       local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream);
486   if (!event_or.ok()) {
487     StallStreamOnError(local_device, copy_stream);
488     return event_or.status();
489   }
490   definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(),
491                                        copy_stream);
492   // prefer_to_retain_reference=false means don't retain a memory reference
493   // until the transfer is complete when using the ComputeSynchronized
494   // allocation model. This is a heuristic because in the common case
495   // destination buffers will be used on the compute stream and therefore don't
496   // require any synchronization before being freed. If the buffer is allocated
497   // and never used, the free will take longer and this is assumed to be ok.
498   RecordUsage(std::move(device_buffer), local_device, local_device,
499               definition_event, copy_stream,
500               /*prefer_to_retain_reference=*/false);
501   return Status::OK();
502 }
503 
504 }  // namespace
505 
506 PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() {
507   if (ok()) {
508     parent_->DropHold(type_, buffer().get());
509   }
510 }
511 
512 PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
513     : parent_(other.parent_),
514       type_(other.type_),
515       state_(other.state_),
516       status_(std::move(other.status_)),
517       buffer_(std::move(other.buffer_)) {
518   // Preserve the invariant that status is invalid if buffer == nullptr.
519   other.SetState(kMoved);
520 }
521 
522 void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
523     StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
524   CHECK(!ok());
525   if (buffer_or.ok()) {
526     buffer_ = buffer_or.ValueOrDie();
527     SetState(kValid);
528   } else {
529     status_ = buffer_or.status();
530     buffer_ = nullptr;
531     SetState(kError);
532   }
533   // Check the invariant holds.
534   CHECK(!ok() || buffer_ != nullptr);
535 }
536 
537 PjRtStreamExecutorBuffer::ScopedHold::ForClosure
538 PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
539   CHECK(ok());
540   ForClosure for_closure(parent_, type_, state_, std::move(status_),
541                          std::move(buffer_));
542   SetState(kReleased);
543   return for_closure;
544 }
545 
546 void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
547     se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
548     bool reference_held) {
549   CHECK(ok());
550   CHECK_EQ(type_, kUsage);
551   parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event),
552                             reference_held);
553   SetState(kConverted);
554 }
555 
556 void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() {
557   CHECK(ok());
558   CHECK_EQ(type_, kDonation);
559   parent_->ConfirmDonation(buffer().get());
560   SetState(kDonated);
561 }
562 
563 void PjRtStreamExecutorBuffer::ScopedHold::AddToInput(
564     ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
565     const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
566     ExecutionInput* execution_input,
567     se::DeviceMemoryAllocator* allocator) const {
568   CHECK(ok());
569   if (type_ == kDonation) {
570     buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
571   } else {
572     CHECK_EQ(type_, kUsage);
573     buffer()->AddToInputAsImmutable(iterator, end);
574   }
575 }
576 
577 bool PjRtStreamExecutorBuffer::IsOnCpu() const {
578   return client()->platform_id() == kCpuId;
579 }
580 
581 StatusOr<Shape> PjRtStreamExecutorBuffer::logical_on_device_shape() {
582   if (on_device_shape_.is_static()) {
583     return on_device_shape_;
584   }
585   auto* local_device = device_->local_device_state();
586   auto* stream = local_device->GetDeviceToHostStream();
587   ScopedHold device_buffer(this, ScopedHold::kUsage);
588   {
589     absl::MutexLock lock(&mu_);
590     // We can't perform any other action while a donation hold is in progress.
591     WaitForOutstandingDonationHold();
592     if (device_buffer_ == nullptr) {
593       return InvalidArgument(
594           "logical_on_device_shape() called on deleted or donated buffer");
595     }
596     AcquireHoldLocked(&device_buffer);
597   }
598 
599   WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
600   ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
601   StatusOr<EventPool::Handle> event_or =
602       local_device->event_pool().AllocateEvent(stream->parent());
603   if (!event_or.ok()) {
604     return event_or.status();
605   }
606   Shape ret_shape = on_device_shape_;
607   TransferManager* transfer_manager =
608       client_->client()->backend().transfer_manager();
609   TF_RETURN_IF_ERROR(
610       transfer_manager->ReadDynamicShapes(stream, &shaped_buffer, &ret_shape));
611   return ret_shape;
612 }
613 
614 namespace {
615 
616 // Implements PjRtBuffer::ExternalReference as a wrapped
617 // ScopedHold::kExternalReference.
618 class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference {
619  public:
620   explicit ScopedHoldAsExternalReference(
621       PjRtStreamExecutorBuffer::ScopedHold hold)
622       : external_reference_(std::move(hold)) {
623     CHECK(external_reference_.type() ==
624           PjRtStreamExecutorBuffer::ScopedHold::kExternalReference);
625     data_ptr_ = external_reference_->device_memory().front().opaque();
626   }
627 
628   ~ScopedHoldAsExternalReference() override = default;
629 
630  private:
631   PjRtStreamExecutorBuffer::ScopedHold external_reference_;
632 };
633 
634 }  // namespace
635 
636 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
637 PjRtStreamExecutorBuffer::AcquireExternalReference() {
638   ScopedHold hold = GetBufferWithExternalReference();
639   Status hold_status = hold.status();
640   if (!hold_status.ok()) return hold_status;
641   return std::unique_ptr<ExternalReference>(
642       std::make_unique<ScopedHoldAsExternalReference>(std::move(hold)));
643 }
644 
645 class TrackedDeviceBufferExternalReference
646     : public PjRtBuffer::ExternalReference {
647  public:
648   explicit TrackedDeviceBufferExternalReference(
649       std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)
650       : tracked_device_buffer_(std::move(tracked_device_buffer)) {
651     data_ptr_ = tracked_device_buffer_->device_memory()[0].opaque();
652   }
653 
654   ~TrackedDeviceBufferExternalReference() override = default;
655 
656  private:
657   std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer_;
658 };
659 
660 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
661 PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership(
662     bool wait_for_operations_to_complete) {
663   if (on_device_shape_.IsTuple()) {
664     return InvalidArgument(
665         "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
666   }
667   TF_ASSIGN_OR_RETURN(
668       std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer,
669       Release(wait_for_operations_to_complete));
670 
671   std::unique_ptr<PjRtBuffer::ExternalReference> ref;
672   if (tracked_device_buffer) {
673     ref = std::make_unique<TrackedDeviceBufferExternalReference>(
674         std::move(tracked_device_buffer));
675   }
676   return ref;
677 }
678 
679 StatusOr<std::unique_ptr<PjRtBuffer>>
680 PjRtStreamExecutorClient::BufferFromHostBuffer(
681     const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
682     absl::optional<absl::Span<int64_t const>> byte_strides,
683     HostBufferSemantics host_buffer_semantics,
684     std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
685   tensorflow::profiler::TraceMe traceme(
686       "PjRtStreamExecutorClient::BufferFromHostBuffer");
687   Shape shape = ShapeUtil::MakeShape(type, dims);
688   VLOG(1) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
689           << shape.ToString() << " device: " << device->DebugString();
690   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
691                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
692                           ->GetLocalDeviceState());
693 
694   absl::InlinedVector<int64, 4> tmp_strides;
695   if (!byte_strides) {
696     tmp_strides.resize(dims.size());
697     TF_RETURN_IF_ERROR(
698         ShapeUtil::ByteStrides(shape, absl::MakeSpan(tmp_strides)));
699     byte_strides = tmp_strides;
700   }
701   int64_t size = ShapeUtil::ByteSizeOf(shape);
702 
703   TransferManager* transfer_manager = client()->backend().transfer_manager();
704   TF_ASSIGN_OR_RETURN(Shape compact_shape,
705                       transfer_manager->ChooseCompactLayoutForShape(shape));
706   absl::InlinedVector<int64_t, 4> compact_shape_strides(
707       compact_shape.dimensions_size());
708   TF_RETURN_IF_ERROR(ShapeUtil::ByteStrides(
709       compact_shape, absl::MakeSpan(compact_shape_strides)));
710   bool host_and_device_strides_equal =
711       (size == 0 || *byte_strides == compact_shape_strides);
712   // The CPU platform is special because the "host" and the "device" are in the
713   // same memory space. If the input shape is in the correct layout and we don't
714   // want to defer the copy onto a thread, we can use the following fast
715   // path.
716   bool is_cpu_platform =
717       local_device->executor()->platform()->id() == se::host::kHostPlatformId;
718   if (is_cpu_platform) {
719     // If we are on the host platform and the input buffer is sufficiently
720     // aligned, we can simply point to the input array's data without any
721     // further copies. At the time of writing we require a 16-byte alignment
722     // because XLA may generate code which requires it.
723     bool can_use_zero_copy =
724         host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
725         ((absl::bit_cast<std::uintptr_t>(data) &
726           (cpu_function_runtime::kMinAlign - 1)) == 0);
727     if (host_and_device_strides_equal &&
728         (host_buffer_semantics ==
729              HostBufferSemantics::kImmutableOnlyDuringCall ||
730          can_use_zero_copy)) {
731       std::function<void()> on_delete_callback;
732       se::DeviceMemoryBase buffer;
733       // Zero-copy path: point the device buffer directly at the caller's data
734       // (the alignment requirement was already checked when computing
735       // can_use_zero_copy). The caller's on_done_with_host_buffer callback
736       // becomes the deletion callback, so the data stays alive until the buffer is freed.
737       if (can_use_zero_copy) {
738         on_delete_callback = std::move(on_done_with_host_buffer);
739         buffer = se::DeviceMemoryBase(
740             const_cast<void*>(static_cast<const void*>(data)), size);
741       } else {
742         void* staging_buffer = host_memory_allocator()->AllocateRaw(
743             cpu_function_runtime::kMinAlign, size);
744         buffer = se::DeviceMemoryBase(staging_buffer, size);
745         std::memcpy(staging_buffer, data, size);
746         if (on_done_with_host_buffer) {
747           on_done_with_host_buffer();
748         }
749         on_delete_callback = [staging_buffer, host_memory_allocator =
750                                                   host_memory_allocator()]() {
751           host_memory_allocator->DeallocateRaw(staging_buffer);
752         };
753       }
754       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
755           definition_events;
756       auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
757           /*allocator=*/nullptr, local_device->device_ordinal(),
758           std::initializer_list<se::DeviceMemoryBase>{buffer},
759           definition_events, std::move(on_delete_callback));
760       return std::unique_ptr<PjRtBuffer>(
761           std::make_unique<PjRtStreamExecutorBuffer>(
762               shape, std::move(device_buffer), this, device));
763     }
764   }
765 
766   TF_ASSIGN_OR_RETURN(
767       std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
768       AllocateDestinationBuffer(compact_shape, device, local_device,
769                                 local_device->host_to_device_stream(),
770                                 /*is_uninitialized_create=*/false, this));
771 
772   PjRtStreamExecutorBuffer::ScopedHold device_buffer(
773       py_buffer->GetBufferWithUsageHold());
774   CHECK(device_buffer.ok());
775 
776   // If necessary, allocate a host-side buffer for staging host-to-device
777   // transfers. On GPU this is a buffer in pinned memory.
778   std::shared_ptr<void> staging_buffer;
779   if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
780       should_stage_host_to_device_transfers() ||
781       !host_and_device_strides_equal) {
782     void* ptr = host_memory_allocator()->AllocateRaw(
783         tensorflow::Allocator::kAllocatorAlignment, size);
784     staging_buffer = std::shared_ptr<void>(
785         ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
786           host_memory_allocator->DeallocateRaw(ptr);
787         });
788   }
789 
790   std::shared_ptr<TransposePlan> transpose;
791   if (!host_and_device_strides_equal) {
792     absl::InlinedVector<int64_t, 4> permutation(dims.size());
793     absl::c_reverse_copy(compact_shape.layout().minor_to_major(),
794                          permutation.begin());
795     absl::MutexLock lock(&transpose_mu_);
796     TF_ASSIGN_OR_RETURN(transpose,
797                         transpose_cache_.GetOrCreate(
798                             primitive_util::ByteWidth(type), dims, permutation,
799                             TransposePlan::Striding{*byte_strides}));
800   }
801 
802   // Copy the buffer into a staging buffer before returning control to the
803   // caller if the caller only guaranteed that the buffer is valid for the
804   // duration of the call. Otherwise, we stage (if necessary) on a separate
805   // thread.
806   if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
807     if (transpose) {
808       transpose->Execute(data, staging_buffer.get());
809     } else {
810       std::memcpy(staging_buffer.get(), data, size);
811     }
812     if (on_done_with_host_buffer) {
813       on_done_with_host_buffer();
814       on_done_with_host_buffer = nullptr;
815     }
816   }
817 
818   // The host to device transfer is performed on a thread pool, mostly because
819   // it includes linearization that may be slow. It is OK to capture the
820   // py_buffer pointer because the py_buffer can't be deleted until all the
821   // usage holds have gone away.
822   // TODO(misard) assess if it would be preferable to introduce a heuristic to
823   // put the transfer into the calling thread for small literals.
824   auto transfer_h2d =
825       [local_client = client(), transfer_manager, local_device, data, size,
826        movable_device_buffer{device_buffer.ToClosure()}, shape,
827        py_buffer{py_buffer.get()},
828        on_device_shape{py_buffer->on_device_shape()},
829        staging_buffer{std::move(staging_buffer)},
830        on_done_with_host_buffer{std::move(on_done_with_host_buffer)},
831        host_buffer_semantics, transpose{std::move(transpose)}]() {
832         PjRtStreamExecutorBuffer::ScopedHold device_buffer(
833             movable_device_buffer);
834         // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
835         // to report failures from a callback. However, the operations here are
836         // unlikely to fail and would not be recoverable even if they did: DMAs
837         // into memory that has already been allocated, and a possible Event
838         // allocation.
839 
840         ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
841         // If applicable on the backend, stage the transfer via host memory
842         // allocated via the host_memory_allocator. On GPU, this is pinned
843         // memory.
844         if (staging_buffer) {
845           // If we didn't already copy the input buffer into the staging buffer,
846           // do so now.
847           if (host_buffer_semantics !=
848               HostBufferSemantics::kImmutableOnlyDuringCall) {
849             if (transpose) {
850               transpose->Execute(data, staging_buffer.get());
851             } else {
852               std::memcpy(staging_buffer.get(), data, size);
853             }
854           }
855           // The buffer has the same dimension order as the on-device shape, but
856           // is not tiled, etc.
857           BorrowingLiteral literal(
858               static_cast<const char*>(staging_buffer.get()),
859               ShapeUtil::DeviceShapeToHostShape(on_device_shape));
860           TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
861               local_device->host_to_device_stream(), literal, buffer));
862         } else {
863           BorrowingLiteral literal(
864               reinterpret_cast<const char*>(data),
865               ShapeUtil::DeviceShapeToHostShape(on_device_shape));
866           // Otherwise, just transfer the literal.
867           TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
868               local_device->host_to_device_stream(), literal, buffer));
869         }
870 
871         std::shared_ptr<BufferSequencingEvent> event =
872             device_buffer->definition_events()[0];
873         TF_CHECK_OK(AddDestinationBufferSynchronization(
874             local_device, std::move(device_buffer), event,
875             local_device->host_to_device_stream()));
876 
877         local_device->ThenExecuteCallback(
878             local_device->host_to_device_stream(),
879             [staging_buffer{std::move(staging_buffer)},
880              on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
881               if (on_done_with_host_buffer) {
882                 on_done_with_host_buffer();
883               }
884             });
885       };
886   if (is_cpu_platform) {
887     // Using the thread_pool would be a double thread hop; the code
888     // already defers its work onto a stream (= thread on CPU).
889     transfer_h2d();
890   } else {
891     thread_pool()->Schedule(transfer_h2d);
892   }
893   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
894 }
895 
896 StatusOr<std::unique_ptr<PjRtBuffer>>
897 PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
898                                                     PjRtDevice* device) {
899   return CreateUninitializedBuffer(shape, device, nullptr);
900 }
901 
902 StatusOr<std::unique_ptr<PjRtBuffer>>
903 PjRtStreamExecutorClient::CreateUninitializedBuffer(
904     const Shape& shape, PjRtDevice* device,
905     std::shared_ptr<BufferSequencingEvent> definition_event) {
906   tensorflow::profiler::TraceMe traceme(
907       "PjRtStreamExecutorClient::CreateUninitializedBuffer");
908   VLOG(1) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
909           << shape.ToString() << " device: " << device->DebugString();
910   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
911                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
912                           ->GetLocalDeviceState());
913 
914   TransferManager* transfer_manager = client()->backend().transfer_manager();
915   TF_ASSIGN_OR_RETURN(Shape compact_shape,
916                       transfer_manager->ChooseCompactLayoutForShape(shape));
917 
918   TF_ASSIGN_OR_RETURN(
919       std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
920       AllocateDestinationBuffer(compact_shape, device, local_device,
921                                 /*copy_stream=*/nullptr,
922                                 /*is_uninitialized_create=*/true, this,
923                                 definition_event));
924   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
925 }
926 
927 StatusOr<std::unique_ptr<PjRtBuffer>>
928 PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
929                                                 PjRtDevice* device) {
930   tensorflow::profiler::TraceMe traceme(
931       "PjRtStreamExecutorClient::BufferFromHostLiteral");
932   VLOG(1) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
933           << literal.shape().ToString() << " device: " << device->DebugString();
934   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
935                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
936                           ->GetLocalDeviceState());
937 
938   TransferManager* transfer_manager = client()->backend().transfer_manager();
939   TF_ASSIGN_OR_RETURN(
940       Shape compact_shape,
941       transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
942   TF_ASSIGN_OR_RETURN(
943       std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
944       AllocateDestinationBuffer(compact_shape, device, local_device,
945                                 local_device->host_to_device_stream(),
946                                 /*is_uninitialized_create=*/false, this));
947 
948   PjRtStreamExecutorBuffer::ScopedHold device_buffer(
949       py_buffer->GetBufferWithUsageHold());
950   CHECK(device_buffer.ok());
951 
952   // The host to device transfer is performed on a thread pool, mostly because
953   // it includes linearization that may be slow. It is OK to capture the
954   // py_buffer pointer because the py_buffer can't be deleted until all the
955   // usage holds have gone away.
956   // TODO(misard) assess if it would be preferable to introduce a heuristic to
957   // put the transfer into the calling thread for small literals.
958   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
959                        movable_device_buffer{device_buffer.ToClosure()},
960                        literal, py_buffer{py_buffer.get()},
961                        on_device_shape{py_buffer->on_device_shape()}]() {
962     PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
963     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
964     // to report failures from a callback. However, the operations here are
965     // unlikely to fail and would not be recoverable even if they did: DMAs
966     // into memory that has already been allocated, and a possible Event
967     // allocation.
968 
969     se::Stream* h2d_stream = local_device->host_to_device_stream();
970     ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
971     TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
972         h2d_stream, literal, buffer));
973 
974     std::shared_ptr<BufferSequencingEvent> event =
975         device_buffer->definition_events()[0];
976     TF_CHECK_OK(AddDestinationBufferSynchronization(
977         local_device, std::move(device_buffer), event, h2d_stream));
978 
979     // This can sometimes catch the case where the literal memory has been
980     // freed before the H2D transfer was issued.
981     h2d_stream->RefreshStatus()
982         .IgnoreError();  // Can return error::Unimplemented
983     QCHECK(h2d_stream->ok());
984   };
985   thread_pool()->Schedule(transfer_h2d);
986   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
987 }
988 
989 void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
990     absl::Span<const Shape> shapes, PjRtDevice* device,
991     PjRtCrossHostRecvNotifier&& notifier) {
992   if (shapes.empty()) {
993     notifier(InvalidArgument(
994         "shapes parameter empty in MakeCrossHostReceiveBuffers"));
995     return;
996   }
997 
998   auto local_device_or =
999       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1000           ->GetLocalDeviceState();
1001   if (!local_device_or.ok()) {
1002     notifier(local_device_or.status());
1003     return;
1004   }
1005   LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
1006   std::shared_ptr<BufferSequencingEvent> definition_event =
1007       std::make_shared<BufferSequencingEvent>();
1008   std::vector<std::unique_ptr<PjRtBuffer>> buffers;
1009   buffers.reserve(shapes.size());
1010   for (const auto& shape : shapes) {
1011     StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
1012         shape, device, local_device,
1013         /*copy_stream=*/nullptr,
1014         /*is_uninitialized_create=*/false, this, definition_event);
1015     if (!buffer_or.ok()) {
1016       notifier(buffer_or.status());
1017       return;
1018     }
1019     buffers.push_back(buffer_or.ConsumeValueOrDie());
1020   }
1021 
1022   EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
1023                           std::move(notifier), absl::nullopt);
1024 }
1025 void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffersForGather(
1026     absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
1027     PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) {
1028   VLOG(2) << "Making " << gather_details.size()
1029           << " cross host receive buffers for gather";
1030   if (gather_details.empty()) {
1031     notifier(
1032         InvalidArgument("gather_details parameter empty in "
1033                         "MakeCrossHostReceiveBuffersForGather"));
1034     return;
1035   }
1036 
1037   if (shapes.size() != gather_details.size()) {
1038     notifier(
1039         InvalidArgument("gather_details parameter has length %lld but shapes "
1040                         "parameter has length %lld in "
1041                         "MakeCrossHostReceiveBuffersForGather",
1042                         gather_details.size(), shapes.size()));
1043     return;
1044   }
1045 
1046   auto local_device_or =
1047       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1048           ->GetLocalDeviceState();
1049   if (!local_device_or.ok()) {
1050     notifier(local_device_or.status());
1051     return;
1052   }
1053   LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
1054   std::shared_ptr<BufferSequencingEvent> definition_event =
1055       std::make_shared<BufferSequencingEvent>();
1056   std::vector<std::unique_ptr<PjRtBuffer>> buffers;
1057   buffers.reserve(shapes.size());
1058   for (int i = 0; i < shapes.size(); ++i) {
1059     StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
1060         shapes[i], device, local_device,
1061         /*copy_stream=*/nullptr,
1062         /*is_uninitialized_create=*/false, this, definition_event);
1063     if (!buffer_or.ok()) {
1064       notifier(buffer_or.status());
1065       return;
1066     }
1067     buffers.push_back(buffer_or.ConsumeValueOrDie());
1068   }
1069 
1070   EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
1071                           std::move(notifier), gather_details);
1072 }
1073 
1074 StatusOr<std::unique_ptr<PjRtBuffer>>
1075 PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
1076     void* device_ptr, const Shape& shape, PjRtDevice* device,
1077     std::function<void()> on_delete_callback) {
1078   se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape));
1079   absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
1080   auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
1081       /*allocator=*/nullptr, device->local_hardware_id(),
1082       std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
1083       std::move(on_delete_callback));
1084   return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
1085       shape, std::move(device_buffer), this, device));
1086 }
1087 
1088 // Transfer the given literal to the infeed queue of the given local device.
1089 Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
1090   // Only support infeed to local device.
1091   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
1092   return local_device->client()->TransferToInfeedLocal(
1093       literal, local_device->device_ordinal());
1094 }
1095 
1096 Status PjRtStreamExecutorDevice::TransferFromOutfeed(
1097     MutableBorrowingLiteral literal) {
1098   VLOG(1) << "PjRtStreamExecutorDevice::TransferFromOutfeed";
1099   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
1100   return local_device->client()->TransferFromOutfeedLocal(
1101       local_device->device_ordinal(), literal);
1102 }
1103 
1104 StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
1105     int local_hardware_id) const {
1106   for (auto* device : addressable_devices_) {
1107     if (local_hardware_id == device->local_hardware_id()) {
1108       return device;
1109     }
1110   }
1111   return InvalidArgument("No matching device found for local_hardware_id %d",
1112                          local_hardware_id);
1113 }
1114 
1115 PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
1116     Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
1117     PjRtClient* client, PjRtDevice* device)
1118     : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
1119       on_device_shape_(std::move(on_device_shape)),
1120       device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
1121       device_buffer_(std::move(device_buffer)) {
1122   for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1123     holds_[i] = 0;
1124   }
1125 }
1126 
1127 PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() {
1128   Delete();
1129   for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1130     CHECK_EQ(holds_[i], 0);
1131   }
1132 }
1133 
1134 void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() {
1135   auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1136     return holds_[ScopedHold::kUsage] == 0;
1137   };
1138   mu_.Await(absl::Condition(&not_in_usage_hold));
1139 }
1140 
1141 void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() {
1142   auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1143     return holds_[ScopedHold::kDonation] == 0;
1144   };
1145   mu_.Await(absl::Condition(&not_in_donation_hold));
1146 }
1147 
1148 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
1149 PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
1150   tensorflow::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release");
1151   std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1152   TrackedDeviceBuffer::StreamAndEventContainer events;
1153   {
1154     absl::MutexLock lock(&mu_);
1155     // We first wait for a donation hold to complete if there is one in
1156     // progress. If the donation succeeds via ConfirmDonation() then it will
1157     // set device_buffer_ to nullptr before returning to this thread.
1158     WaitForOutstandingDonationHold();
1159     if (device_buffer_ == nullptr) {
1160       return std::shared_ptr<TrackedDeviceBuffer>();
1161     }
1162     // Set device_buffer_ to null now so that no other
1163     // thread can add a hold while we are in WaitForOutstandingUsageHolds()
1164     // below.
1165     std::swap(device_buffer_, device_buffer);
1166     WaitForOutstandingUsageHolds();
1167     // Now that all holds have completed and no more can be added, we can get
1168     // the final set of usage events.
1169     events = device_buffer->LockUseAndTransferUsageEvents();
1170   }
1171   LocalDeviceState* local_device_state = device_->local_device_state();
1172   if (wait_for_operations_to_complete) {
1173     // Block the host until all usage events have completed. Usage events
1174     // dominate definition events, so this also waits for the buffer to be
1175     // defined.
1176     std::unique_ptr<se::Stream> stream;
1177     for (const auto& stream_and_event : events) {
1178       if (!stream_and_event.event->IsComplete()) {
1179         if (stream == nullptr) {
1180           stream = local_device_state->BorrowStreamFromPool();
1181         }
1182         stream_and_event.event->WaitForEventOnStream(stream.get());
1183       }
1184     }
1185     if (stream != nullptr) {
1186       TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1187       local_device_state->ReturnStreamToPool(std::move(stream));
1188     }
1189   } else {
1190     if (local_device_state->allocation_model() ==
1191         LocalDeviceState::kComputeSynchronized) {
1192       std::unique_ptr<se::Stream> block_stream;
1193       for (const auto& stream_and_event : events) {
1194         // We only need to do something for events that didn't already acquire a
1195         // reference to the buffer and that the compute stream hasn't already
1196         // waited for. Based on our heuristics this rare case should only
1197         // occur when a buffer was copied to a device and then never used there.
1198         // In that case we get a new stream and use it to hold onto a reference
1199         // to the buffer until the events are complete.
1200         if (!stream_and_event.reference_held &&
1201             !stream_and_event.event->DefinedOn(
1202                 local_device_state->compute_stream()) &&
1203             !stream_and_event.event->IsComplete()) {
1204           if (block_stream == nullptr) {
1205             block_stream = local_device_state->BorrowStreamFromPool();
1206           }
1207           stream_and_event.event->WaitForEventOnStream(block_stream.get());
1208         }
1209       }
1210       if (block_stream != nullptr) {
1211         se::Stream* block_stream_ptr = block_stream.release();
1212         local_device_state->ThenExecuteCallback(
1213             block_stream_ptr,
1214             [device_buffer, block_stream_ptr, local_device_state]() {
1215               local_device_state->ReturnStreamToPool(
1216                   std::unique_ptr<se::Stream>(block_stream_ptr));
1217             });
1218       }
1219     }
1220   }
1221   return device_buffer;
1222 }
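// Editorial note on the two Release modes above: with
// wait_for_operations_to_complete=true the host blocks until every
// outstanding usage event has completed, so the returned TrackedDeviceBuffer
// (and its device memory) can safely be handed to code that is not aware of
// XLA's stream ordering; with false, cleanup is deferred to the stream
// synchronization logic above. Illustrative call, assuming `buffer` is a
// PjRtStreamExecutorBuffer*:
//
//   TF_ASSIGN_OR_RETURN(
//       std::shared_ptr<TrackedDeviceBuffer> released,
//       buffer->Release(/*wait_for_operations_to_complete=*/true));
//   // `released` is null if the buffer was already deleted or donated.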
1223 
1224 void PjRtStreamExecutorBuffer::Delete() {
1225   VLOG(1) << "PjRtStreamExecutorBuffer::Delete";
1226   // When wait_for_operations_to_complete is false, Release should never fail.
1227   TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
1228 }
1229 
1230 bool PjRtStreamExecutorBuffer::IsDeleted() {
1231   absl::MutexLock lock(&mu_);
1232   return device_buffer_ == nullptr;
1233 }
1234 
1235 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
1236 PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
1237   // All callers should have called WaitForOutstandingDonationHold().
1238   CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1239   if (type == ScopedHold::kDonation) {
1240     if (device_buffer_ == nullptr) {
1241       return InvalidArgument("Donation requested for invalid buffer");
1242     }
1243     if (holds_[ScopedHold::kExternalReference] > 0) {
1244       return InvalidArgument(
1245           "Donation requested for buffer with external reference");
1246     }
1247     // First add the donation hold.
1248     ++holds_[type];
1249     // Then wait for any usage holds to be dropped or converted. No new usage
1250     // holds can be added until we drop the donation hold so this wait will
1251     // complete eventually.
1252     WaitForOutstandingUsageHolds();
1253     // Because we added a donation hold, nobody could release the buffer while
1254     // we were waiting.
1255     CHECK(device_buffer_ != nullptr);
1256   } else {
1257     if (device_buffer_ == nullptr) {
1258       return InvalidArgument("Buffer has been deleted or donated.");
1259     } else {
1260       ++holds_[type];
1261     }
1262   }
1263   return device_buffer_;
1264 }
1265 
1266 void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) {
1267   hold->Acquire(GetBufferForHoldLocked(hold->type()));
1268 }
1269 
1270 void PjRtStreamExecutorBuffer::ConvertUsageHold(
1271     TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
1272     std::shared_ptr<BufferSequencingEvent> event, bool reference_held) {
1273   absl::MutexLock lock(&mu_);
1274   CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1275   buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
1276   CHECK_GT(holds_[ScopedHold::kUsage], 0);
1277   --holds_[ScopedHold::kUsage];
1278 }
1279 
1280 void PjRtStreamExecutorBuffer::ConfirmDonation(
1281     TrackedDeviceBuffer* device_buffer) {
1282   {
1283     absl::MutexLock lock(&mu_);
1284     CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1285     CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1286     CHECK_EQ(holds_[ScopedHold::kDonation], 1);
1287     holds_[ScopedHold::kDonation] = 0;
1288     CHECK(device_buffer_.get() == device_buffer);
1289     // As a sanity check ensure no more usage events can be added to the buffer.
1290     device_buffer->LockUseAndTransferUsageEvents();
1291     // Give up ownership of the device memory so we don't free it when the last
1292     // reference to device_buffer_ goes away.
1293     device_buffer->ReleaseDeviceMemory();
1294     // Make *this invalid so it can't be used again. Any threads blocking in
1295     // Release or GetBufferWithHold will see an invalid buffer and return.
1296     device_buffer_.reset();
1297   }
1298 }
1299 
1300 void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
1301                                         TrackedDeviceBuffer* buffer) {
1302   absl::MutexLock lock(&mu_);
1303   CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1304   CHECK_GT(holds_[type], 0);
1305   --holds_[type];
1306   if (type == ScopedHold::kDonation) {
1307     CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1308     CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1309     CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1310   }
1311 }
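// Editorial summary of the ScopedHold life cycle implemented by the methods
// above (no new behavior): a hold is acquired under mu_ via
// GetBufferForHoldLocked/AcquireHoldLocked (or the public GetBufferWithHold
// below) and must end in exactly one of three ways:
//   * ConvertUsageHold - a usage hold becomes a usage event recorded on the
//     stream that read or wrote the buffer (normally via RecordUsage);
//   * ConfirmDonation - a donation hold commits, the buffer gives up its
//     device memory, and device_buffer_ becomes null;
//   * DropHold - the hold is abandoned, e.g. because enqueueing the work
//     that would have used the buffer failed.
//
// Minimal usage-hold sketch, assuming `buffer`, `stream`, `local_device` and
// a recorded `usage_event` exist in the caller:
//
//   PjRtStreamExecutorBuffer::ScopedHold hold = buffer->GetBufferWithHold(
//       PjRtStreamExecutorBuffer::ScopedHold::kUsage);
//   if (hold.ok()) {
//     // ... enqueue device work that reads hold.buffer() on `stream` ...
//     RecordUsage(std::move(hold), local_device, local_device, usage_event,
//                 stream, /*prefer_to_retain_reference=*/false);
//   }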
1312 
1313 void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
1314                                          std::function<void(Status)> on_ready) {
1315   VLOG(1) << "PjRtStreamExecutorBuffer::ToLiteral";
1316   if (IsEmptyTuple()) {
1317     on_ready(InvalidArgument("ToLiteral called on empty tuple"));
1318     return;
1319   }
1320   LocalDeviceState* local_device = device_->local_device_state();
1321   se::Stream* stream = local_device->GetDeviceToHostStream();
1322   ScopedHold device_buffer(this, ScopedHold::kUsage);
1323   {
1324     absl::MutexLock lock(&mu_);
1325     // We can't perform any other action while a donation hold is in progress.
1326     WaitForOutstandingDonationHold();
1327     if (device_buffer_ == nullptr) {
1328       on_ready(InvalidArgument(
1329           "CopyToHostAsync() called on deleted or donated buffer"));
1330       return;
1331     }
1332     AcquireHoldLocked(&device_buffer);
1333   }
1334 
1335   WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
1336   ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
1337   StatusOr<EventPool::Handle> event_or =
1338       local_device->event_pool().AllocateEvent(stream->parent());
1339   if (!event_or.ok()) {
1340     on_ready(event_or.status());
1341     return;
1342   }
1343   client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
1344       stream, shaped_buffer, literal, std::move(on_ready));
1345 
1346   auto usage_event = std::make_shared<BufferSequencingEvent>();
1347   local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
1348   usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1349   // When using the ComputeSynchronized allocation model, retain a reference to
1350   // the device_buffer until the copy completes, to ensure that the buffer isn't
1351   // deleted or donated while it is still in use. The choice of retaining a
1352   // reference at the host is a heuristic; the alternative is to ensure, before
1353   // freeing the buffer, that the compute stream is synchronized past the
1354   // transfer, but it seems better to hold onto the buffer too long than to
1355   // stall the compute stream, particularly since the overwhelmingly common
1356   // use case of CopyToHostAsync will hold onto the reference long enough to
1357   // read the buffer in a subsequent call to ToLiteral.
1358   RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
1359               stream,
1360               /*prefer_to_retain_reference=*/true);
1361 }
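// Editorial illustration (not part of the original file): reading a buffer
// back to the host asynchronously with ToLiteral above. The literal must
// outlive the transfer; an absl::Notification is used here only to wait for
// the callback. `buffer` is an assumed PjRtStreamExecutorBuffer* whose host
// shape is `host_shape`.
//
//   Literal literal(host_shape);
//   absl::Notification done;
//   Status transfer_status;
//   buffer->ToLiteral(&literal, [&](Status s) {
//     transfer_status = s;
//     done.Notify();
//   });
//   done.WaitForNotification();
//   TF_RETURN_IF_ERROR(transfer_status);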
1362 
1363 StatusOr<size_t> PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes() const {
1364   absl::MutexLock lock(&mu_);
1365   if (device_buffer_ == nullptr) {
1366     return InvalidArgument(
1367         "GetOnDeviceSizeInBytes called on deleted or donated buffer");
1368   }
1369   if (device_buffer_->device_memory().size() != 1) {
1370     return InvalidArgument(
1371         "GetOnDeviceSizeInBytes called on tuple-shaped buffer");
1372   }
1373   return device_buffer_->device_memory()[0].size();
1374 }
1375 
1376 Status PjRtStreamExecutorBuffer::CopyRawToHost(
1377     void* dst, int64_t offset, int64_t transfer_size,
1378     std::function<void(Status)> on_ready) {
1379   return client_->CopyRawSubBufferToHost(this, dst, offset, transfer_size,
1380                                          std::move(on_ready));
1381 }
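// Editorial illustration (not part of the original file): reading the raw
// bytes of a non-tuple buffer into host memory, sizing the destination with
// GetOnDeviceSizeInBytes above. `buffer` is assumed to exist in the caller.
//
//   TF_ASSIGN_OR_RETURN(size_t nbytes, buffer->GetOnDeviceSizeInBytes());
//   auto dst = std::make_unique<char[]>(nbytes);
//   absl::Notification done;
//   TF_RETURN_IF_ERROR(buffer->CopyRawToHost(
//       dst.get(), /*offset=*/0, /*transfer_size=*/nbytes,
//       [&](Status s) { TF_CHECK_OK(s); done.Notify(); }));
//   done.WaitForNotification();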
1382 
1383 StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
1384   absl::MutexLock lock(&mu_);
1385   if (device_buffer_ == nullptr) {
1386     return InvalidArgument(
1387         "Attempted to fetch value of invalid/deleted buffer.");
1388   }
1389   return device_buffer_->AsShapedBuffer(on_device_shape_);
1390 }
1391 
1392 PjRtStreamExecutorBuffer::ScopedHold
1393 PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
1394   absl::MutexLock lock(&mu_);
1395   // Ensure that at most one donation hold can be in progress at a time.
1396   WaitForOutstandingDonationHold();
1397   ScopedHold hold(this, type);
1398   AcquireHoldLocked(&hold);
1399   return hold;
1400 }
1401 
1402 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1403                    std::shared_ptr<BufferSequencingEvent>>>
1404 PjRtStreamExecutorBuffer::CopyToDeviceHelper(
1405     PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
1406     LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
1407     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
1408   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
1409                       AllocateDestinationBuffer(
1410                           ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
1411                           dst_device, dst_local_device, transfer_stream,
1412                           /*is_uninitialized_create=*/false, client_));
1413 
1414   TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
1415 
1416   WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream);
1417 
1418   ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
1419   CHECK(dst_device_buffer.ok());
1420   ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);
1421 
1422   // Copy the leaf buffers.
1423   StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
1424       [&]() -> StatusOr<std::shared_ptr<BufferSequencingEvent>> {
1425     for (const auto& leaf : src_buffer.buffers().leaves()) {
1426       const ShapeIndex& index = leaf.first;
1427       const se::DeviceMemoryBase& input_buffer = leaf.second;
1428       const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
1429       TF_RET_CHECK(input_buffer.size() == output_buffer.size())
1430           << "input: " << input_buffer.size()
1431           << " output: " << output_buffer.size();
1432       if (input_buffer.size() != 0) {
1433         TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
1434             transfer_stream, dst_local_device->compute_stream(), input_buffer,
1435             output_buffer));
1436       }
1437     }
1438     std::shared_ptr<BufferSequencingEvent> event =
1439         dst_device_buffer->definition_events()[0];
1440     TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization(
1441         transfer_local_device, std::move(dst_device_buffer), event,
1442         transfer_stream));
1443     return event;
1444   }();
1445   if (!copy_event_or.ok()) {
1446     StallStreamOnError(transfer_local_device, transfer_stream);
1447     if (transfer_local_device == dst_local_device) {
1448       // Some copies may have been enqueued before the error was returned, and
1449       // StallStreamOnError only makes sure the destination device is ok, so
1450       // make sure that the src buffer remains valid until after any transfers
1451       // have completed.
1452       device_->local_device_state()->ThenRelease(transfer_stream,
1453                                                  std::move(src_device_buffer));
1454     }
1455     return copy_event_or.status();
1456   }
1457 
1458   return std::pair<std::unique_ptr<PjRtBuffer>,
1459                    std::shared_ptr<BufferSequencingEvent>>(
1460       std::unique_ptr<PjRtStreamExecutorBuffer>(std::move(py_buffer)),
1461       copy_event_or.ConsumeValueOrDie());
1462 }
1463 
1464 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
1465     PjRtDevice* dst_device) {
1466   tensorflow::profiler::TraceMe traceme(
1467       "PjRtStreamExecutorBuffer::CopyToDevice");
1468   VLOG(1) << "PjRtStreamExecutorBuffer::CopyToDevice";
1469   if (dst_device == device_) {
1470     return InvalidArgument(
1471         "CopyToDevice cannot accept the same source and destination devices");
1472   }
1473 
1474   // Copying across PjRtClients involves a copy through the host.
1475   if (dst_device->client() != client_) {
1476     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
1477     // Avoid use-after-free on `literal` due to unsequenced move and use.
1478     Literal* literal_pointer = literal.get();
1479     absl::InlinedVector<int64_t, 4> byte_strides(
1480         literal->shape().dimensions_size());
1481     TF_RETURN_IF_ERROR(
1482         ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides)));
1483     return dst_device->client()->BufferFromHostBuffer(
1484         literal_pointer->untyped_data(),
1485         literal_pointer->shape().element_type(),
1486         literal_pointer->shape().dimensions(), byte_strides,
1487         PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
1488         [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
1489   }
1490 
1491   TF_ASSIGN_OR_RETURN(
1492       LocalDeviceState * dst_local_device,
1493       tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
1494           ->GetLocalDeviceState());
1495   LocalDeviceState* transfer_local_device =
1496       client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
1497                                                 : dst_local_device;
1498   CHECK_EQ(dst_local_device->allocation_model(),
1499            transfer_local_device->allocation_model());
1500 
1501   se::Stream* transfer_stream =
1502       transfer_local_device->GetDeviceToDeviceStream();
1503 
1504   ScopedHold src_device_buffer(this, ScopedHold::kUsage);
1505   {
1506     absl::MutexLock lock(&mu_);
1507     // We can't perform any other action while a donation hold is in progress.
1508     WaitForOutstandingDonationHold();
1509     if (device_buffer_ == nullptr) {
1510       return InvalidArgument(
1511           "CopyToDevice called on deleted or donated buffer");
1512     }
1513     AcquireHoldLocked(&src_device_buffer);
1514   }
1515 
1516   StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1517                      std::shared_ptr<BufferSequencingEvent>>>
1518       buffer_and_event_or = CopyToDeviceHelper(
1519           dst_device, dst_local_device, transfer_local_device, transfer_stream,
1520           src_device_buffer.buffer());
1521   if (!buffer_and_event_or.ok()) {
1522     return buffer_and_event_or.status();
1523   }
1524 
1525   auto& buffer_and_event = buffer_and_event_or.ValueOrDie();
1526   std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
1527   std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;
1528 
1529   // prefer_to_retain_reference=true means that, when using the
1530   // ComputeSynchronized allocation model, retain a reference to the
1531   // src_device_buffer until the copy completes. This is a heuristic; the
1532   // alternative is to ensure, before freeing the buffer, that the compute
1533   // stream is synchronized past the transfer, but it seems better to hold onto
1534   // the buffer too long than to stall the compute stream.
1535   RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
1536               transfer_local_device, event, transfer_stream,
1537               /*prefer_to_retain_reference=*/true);
1538 
1539   return std::move(buffer);
1540 }
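// Editorial illustration (not part of the original file): copying a buffer
// to another device. When `dst_device` belongs to the same client the copy
// is enqueued device-to-device as above; when it belongs to a different
// client the data is staged through a host literal. `buffer` and
// `dst_device` are assumed to exist in the caller.
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> copy,
//                       buffer->CopyToDevice(dst_device));
//   // Optionally wait until the copy's value is defined on dst_device:
//   TF_RETURN_IF_ERROR(copy->BlockHostUntilReady());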
1541 
1542 Status PjRtStreamExecutorBuffer::CopyToRemoteDevice(
1543     absl::string_view serialized_descriptor) {
1544   VLOG(1) << "PjRtStreamExecutorBuffer::CopyToRemoteDevice";
1545   return client_->CopyToRemoteDevice(this, serialized_descriptor);
1546 }
1547 
1548 Status PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered(
1549     absl::Span<const std::string> serialized_descriptors,
1550     const ScatterDetails& scatter_details) {
1551   VLOG(1) << "PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered";
1552   return client_->CopyToRemoteDeviceScattered(this, serialized_descriptors,
1553                                               scatter_details);
1554 }
1555 
1556 Status PjRtStreamExecutorBuffer::BlockHostUntilReady() {
1557   tensorflow::profiler::TraceMe traceme(
1558       "PjRtStreamExecutorBuffer::BlockHostUntilReady");
1559   VLOG(1) << "PjRtStreamExecutorBuffer::BlockHostUntilReady";
1560   std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1561   {
1562     absl::MutexLock lock(&mu_);
1563     if (device_buffer_ == nullptr) {
1564       return InvalidArgument(
1565           "BlockHostUntilReady() called on deleted or donated buffer");
1566     }
1567     device_buffer = device_buffer_;
1568   }
1569   LocalDeviceState* local_device_state = device_->local_device_state();
1570   std::unique_ptr<se::Stream> stream;
1571   for (auto& event : device_buffer->definition_events()) {
1572     if (!event->IsComplete()) {
1573       if (stream == nullptr) {
1574         stream = local_device_state->BorrowStreamFromPool();
1575       }
1576       event->WaitForEventOnStream(stream.get());
1577     }
1578   }
1579   if (stream != nullptr) {
1580     TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1581     local_device_state->ReturnStreamToPool(std::move(stream));
1582   }
1583   return Status::OK();
1584 }
1585 
1586 namespace {
1587 
1588 // Helper struct for the tuple that is transiently constructed to hold the
1589 // arguments of an execution.
1590 struct TupleHandle {
1591   // The ExecutionInput describing the tuple.
1592   ExecutionInput execution_input;
1593   // A definition event that has been recorded on the host_to_device stream
1594   // after the tuple table transfer.
1595   std::shared_ptr<BufferSequencingEvent> event;
1596 };
1597 
1598 Status CheckCompatibleShapes(bool strict_shape_checking,
1599                              const Shape& buffer_shape,
1600                              const Shape& execution_shape,
1601                              const TransferManager& transfer_manager,
1602                              int parameter_index) {
1603   // TODO(misard) Support casting of tuple parameters.
1604   if (strict_shape_checking || buffer_shape.IsTuple()) {
1605     if (!ShapeUtil::Equal(buffer_shape, execution_shape)) {
1606       return InvalidArgument(
1607           "Executable expected shape %s for argument %d but got "
1608           "incompatible "
1609           "shape %s",
1610           ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
1611           ShapeUtil::HumanStringWithLayout(buffer_shape));
1612     }
1613   } else {
1614     if (transfer_manager.GetByteSizeRequirement(buffer_shape) !=
1615         transfer_manager.GetByteSizeRequirement(execution_shape)) {
1616       return InvalidArgument(
1617           "Executable expected shape %s for argument %d but got "
1618           "incompatible "
1619           "shape %s",
1620           ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
1621           ShapeUtil::HumanStringWithLayout(buffer_shape));
1622     }
1623   }
1624   return Status::OK();
1625 }
1626 
1627 // Makes a tuple from the arguments to an execution.
1628 StatusOr<TupleHandle> MakeTupleHelper(
1629     PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
1630     bool strict_shape_checking, const Shape& tupled_parameter_shape,
1631     absl::Span<PjRtBuffer* const> py_buffers,
1632     absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1633     int device_ordinal) {
1634   se::DeviceMemoryAllocator* allocator = client->allocator();
1635   TransferManager* transfer_manager =
1636       client->client()->backend().transfer_manager();
1637 
1638   if (tupled_parameter_shape.tuple_shapes_size() != py_buffers.size()) {
1639     return InvalidArgument("Executable expected %lld parameters but got %lld",
1640                            tupled_parameter_shape.tuple_shapes_size(),
1641                            py_buffers.size());
1642   }
1643   for (int i = 0; i < py_buffers.size(); ++i) {
1644     TF_RETURN_IF_ERROR(CheckCompatibleShapes(
1645         strict_shape_checking, py_buffers[i]->on_device_shape(),
1646         tupled_parameter_shape.tuple_shapes(i), *transfer_manager, i));
1647   }
1648 
1649   se::Stream* stream = local_device->host_to_device_stream();
1650   TF_ASSIGN_OR_RETURN(
1651       se::OwningDeviceMemory root_table_memory,
1652       allocator->Allocate(
1653           device_ordinal,
1654           transfer_manager->GetByteSizeRequirement(tupled_parameter_shape)));
1655 
1656   if (local_device->allocation_model() ==
1657       LocalDeviceState::kComputeSynchronized) {
1658     stream->ThenWaitFor(local_device->compute_stream());
1659   } else {
1660     DCHECK(transfer_manager->CanBufferBeAccessedNow(
1661         local_device->compute_stream()->parent(), root_table_memory.cref()));
1662   }
1663 
1664   ExecutionInput execution_input(tupled_parameter_shape);
1665   ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1666       execution_input.MutableBuffers()->begin();
1667   ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1668       execution_input.MutableBuffers()->end();
1669   // First set the root tuple table which is the first buffer in the ShapeTree.
1670   execution_input.SetBuffer(
1671       input_iterator->first,
1672       MaybeOwningDeviceMemory(std::move(root_table_memory)));
1673   ++input_iterator;
1674   // Then set each sub-tuple in turn from the parameters.
1675   for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
1676        device_buffers) {
1677     device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input,
1678                              allocator);
1679   }
1680   CHECK(input_iterator == iterator_end);
1681 
1682   TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
1683       stream, execution_input.Buffers()));
1684   StatusOr<EventPool::Handle> event_or =
1685       local_device->event_pool().ThenAllocateAndRecordEvent(stream);
1686   if (!event_or.ok()) {
1687     StallStreamOnError(local_device, stream);
1688     return event_or.status();
1689   }
1690 
1691   auto transfer_event = std::make_shared<BufferSequencingEvent>();
1692   transfer_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1693   return TupleHandle({std::move(execution_input), std::move(transfer_event)});
1694 }
1695 
1696 // Converts a ScopedShapedBuffer returned from an execution into a
1697 // PjRtBuffer.
1698 std::unique_ptr<PjRtBuffer> OutputBufferHelper(
1699     ScopedShapedBuffer* result_buffer,
1700     std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
1701     PjRtDevice* device, LocalDeviceState* local_device,
1702     std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release) {
1703   std::shared_ptr<TrackedDeviceBuffer> out_buffer =
1704       TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
1705                                                   {definition_event});
1706   auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
1707       result_buffer->on_device_shape(), std::move(out_buffer), client, device);
1708   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
1709               definition_event, local_device->compute_stream(),
1710               /*prefer_to_retain_reference=*/false, &buffers_to_release);
1711   return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
1712 }
1713 }  // namespace
1714 
1715 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
1716     std::vector<std::unique_ptr<LocalExecutable>> executables,
1717     bool parameter_is_tupled_arguments,
1718     std::shared_ptr<DeviceAssignment> device_assignment,
1719     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
1720     std::vector<PjRtDevice*> addressable_devices,
1721     PjRtStreamExecutorClient* client)
1722     : client_(client),
1723       device_assignment_(std::move(device_assignment)),
1724       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
1725       addressable_device_logical_ids_(
1726           std::move(addressable_device_logical_ids)),
1727       addressable_devices_(std::move(addressable_devices)) {
1728   TransferManager* transfer_manager =
1729       client_->client()->backend().transfer_manager();
1730   executables_.reserve(executables.size());
1731   for (auto& executable : executables) {
1732     const auto& computation_layout =
1733         executable->executable()->module().entry_computation_layout();
1734     std::vector<Shape> parameter_shapes;
1735     parameter_shapes.reserve(computation_layout.parameter_count());
1736     for (int i = 0; i < computation_layout.parameter_count(); ++i) {
1737       parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape(
1738           computation_layout.parameter_shape(i)));
1739     }
1740     executables_.emplace_back(std::move(executable));
1741     on_device_executable_parameter_shapes_.push_back(
1742         std::move(parameter_shapes));
1743   }
1744 
1745   int num_partitions;
1746   if (device_assignment_ == nullptr) {
1747     // This must go after `executables_` is initialized.
1748     VLOG(3) << "PjRtStreamExecutorExecutable portable single-core";
1749     num_partitions = 1;
1750     CHECK(addressable_devices_.empty());
1751   } else {
1752     // This must go after `executables_` is initialized.
1753     VLOG(3) << "PjRtStreamExecutorExecutable device_assignment:\n"
1754             << device_assignment_->ToString();
1755     CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
1756     CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
1757         << "Inconsistent local device count.";
1758     num_partitions = device_assignment_->computation_count();
1759   }
1760 
1761   // SPMD sharding produces a single executable for multiple partitions.
1762   if (executables_.size() > 1) {
1763     CHECK_EQ(num_partitions, executables_.size())
1764         << "Number of executables " << executables_.size()
1765         << " did not match number of partitions " << num_partitions;
1766   }
1767 }
1768 
1769 Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
1770   parameters_that_must_be_donated_.reserve(executables_.size());
1771   for (auto& executable : executables_) {
1772     TF_ASSIGN_OR_RETURN(std::vector<int> parameters_to_donate,
1773                         ComputeParametersThatMustBeDonated(
1774                             executable->executable()->module(), tuple_inputs));
1775     parameters_that_must_be_donated_.emplace_back(
1776         std::move(parameters_to_donate));
1777   }
1778   return Status::OK();
1779 }
1780 
1781 absl::string_view PjRtStreamExecutorExecutable::name() const {
1782   Executable* executable = executables_[0]->executable();
1783   if (executable->has_module()) {
1784     return executable->module().name();
1785   } else {
1786     return "<unknown executable>";
1787   }
1788 }
1789 
1790 absl::Span<int const> PjRtStreamExecutorExecutable::ParametersThatMustBeDonated(
1791     int executable_idx) const {
1792   return parameters_that_must_be_donated_[executable_idx];
1793 }
1794 
1795 StatusOr<std::vector<ExecutionInput>>
1796 PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
1797     int device_ordinal, const ExecuteOptions& options,
1798     absl::Span<const Shape> executable_parameter_shapes,
1799     absl::Span<PjRtBuffer* const> argument_handles,
1800     absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1801     absl::flat_hash_set<BufferSequencingEvent*>& events) const {
1802   std::vector<ExecutionInput> execution_inputs;
1803   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1804   TransferManager* transfer_manager =
1805       client_->client()->backend().transfer_manager();
1806   // Lift tuple_handle outside the conditional so that the event it returns is
1807   // not destroyed until after the loop below that waits on events.
1808   absl::optional<TupleHandle> tuple_handle;
1809   if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
1810     TF_ASSIGN_OR_RETURN(
1811         tuple_handle,
1812         MakeTupleHelper(client_, device_state, options.strict_shape_checking,
1813                         executable_parameter_shapes[0], argument_handles,
1814                         device_buffers, device_ordinal));
1815     events.insert(tuple_handle->event.get());
1816     execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
1817   } else {
1818     if (argument_handles.size() != executable_parameter_shapes.size()) {
1819       return InvalidArgument("Executable expected %lld arguments but got %lld",
1820                              executable_parameter_shapes.size(),
1821                              argument_handles.size());
1822     }
1823     execution_inputs.reserve(argument_handles.size());
1824     for (int i = 0; i < argument_handles.size(); ++i) {
1825       PjRtBuffer* handle = argument_handles[i];
1826 
1827       // Make an ExecutionInput from the device buffer.
1828       TF_RETURN_IF_ERROR(CheckCompatibleShapes(
1829           options.strict_shape_checking, handle->on_device_shape(),
1830           executable_parameter_shapes[i], *transfer_manager, i));
1831       execution_inputs.emplace_back(executable_parameter_shapes[i]);
1832       ExecutionInput& execution_input = execution_inputs.back();
1833       ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1834           execution_input.MutableBuffers()->begin();
1835       ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1836           execution_input.MutableBuffers()->end();
1837       device_buffers[i].AddToInput(&input_iterator, iterator_end,
1838                                    &execution_input, client_->allocator());
1839       CHECK(input_iterator == iterator_end);
1840     }
1841   }
1842 
1843   for (BufferSequencingEvent* event : events) {
1844     event->WaitForEventOnStream(device_state->compute_stream());
1845   }
1846 
1847   return execution_inputs;
1848 }
1849 
1850 // Enqueues a computation onto the compute stream. Each buffer returned in
1851 // device_buffers has a usage hold added that must be dropped on error or
1852 // converted on success.
1853 StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
1854     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
1855     int executable_idx, const RunId& run_id, const ExecuteOptions& options,
1856     PjRtDevice* device,
1857     std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
1858     std::shared_ptr<DeviceAssignment> device_assignment,
1859     std::vector<std::function<void()>>& compute_callbacks) const {
1860   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1861                            ->local_device_state()
1862                            ->device_ordinal();
1863   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1864   tensorflow::profiler::TraceMeConsumer activity(
1865       "PjRtStreamExecutorExecutable::EnqueueExecution",
1866       tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
1867   VLOG(3) << "Replica " << replica << ", partition " << partition
1868           << " mapped to device ordinal for execution: " << device_ordinal;
1869 
1870   absl::flat_hash_set<BufferSequencingEvent*> events;
1871   device_buffers->reserve(argument_handles.size());
1872   absl::Span<int const> donated_params =
1873       ParametersThatMustBeDonated(executable_idx);
1874   auto donate_it = donated_params.begin();
1875   for (int i = 0; i < argument_handles.size(); ++i) {
1876     auto* handle =
1877         tensorflow::down_cast<PjRtStreamExecutorBuffer*>(argument_handles[i]);
1878     if (handle->device() != device) {
1879       return InvalidArgument(
1880           "Buffer passed to Execute() as argument %d to replica %d is on "
1881           "device %s, but replica is assigned to device %s.",
1882           i, replica, handle->device()->DebugString(), device->DebugString());
1883     }
1884     bool must_donate = donate_it != donated_params.end() && *donate_it == i;
1885     if (must_donate) {
1886       ++donate_it;
1887     }
1888     device_buffers->emplace_back(handle->GetBufferWithHold(
1889         must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
1890                     : PjRtStreamExecutorBuffer::ScopedHold::kUsage));
1891     PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
1892         device_buffers->back();
1893     if (!device_buffer.ok()) {
1894       return InvalidArgument(
1895           "Invalid buffer passed to Execute() as argument %d to replica %d: "
1896           "%s",
1897           i, replica, device_buffer.status().ToString());
1898     }
1899     // If we are trying to donate the buffer wait on the usage events as well
1900     // as the definition events to ensure that all reads have been completed
1901     // before the buffer is mutated. Usage holds are excluded during a donation
1902     // hold so we know that the set of usage events won't be modified while we
1903     // are enqueueing.
1904     GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
1905                           &events);
1906   }
1907 
1908   if (options.arguments_are_tupled) {
1909     if (!parameter_is_tupled_arguments_) {
1910       return InvalidArgument(
1911           "Arguments may only be supplied as a tuple when the executable was "
1912           "compiled with a single tupled parameter");
1913     }
1914     if (argument_handles.size() != 1) {
1915       return InvalidArgument(
1916           "Option arguments_are_tupled was true but %d buffers were passed to "
1917           "execution",
1918           argument_handles.size());
1919     }
1920   }
1921 
1922   TF_ASSIGN_OR_RETURN(
1923       std::vector<ExecutionInput> execution_inputs,
1924       MakeExecutionInputsAndWaitForEvents(
1925           device_ordinal, options,
1926           on_device_executable_parameter_shapes_[executable_idx],
1927           argument_handles, *device_buffers, events));
1928 
1929   ExecutableRunOptions run_options;
1930   run_options.set_stream(device_state->compute_stream());
1931   run_options.set_host_to_device_stream(device_state->host_to_device_stream());
1932   run_options.set_allocator(client_->allocator());
1933   run_options.set_intra_op_thread_pool(
1934       client_->client()->backend().eigen_intra_op_thread_pool_device());
1935   run_options.set_device_assignment(device_assignment.get());
1936   run_options.set_run_id(run_id);
1937   run_options.set_rng_seed(device_state->GetNewPrngSeed());
1938   run_options.set_gpu_executable_run_options(client_->gpu_run_options());
1939   run_options.set_launch_id(options.launch_id);
1940   if (run_options.launch_id() != 0) {
1941     VLOG(3) << "launch id for " << name() << ": " << run_options.launch_id();
1942   }
1943 
1944   // The choice of where we wait is arbitrary; the reason for the wait is
1945   // pacing to avoid problems such as memory fragmentation and running ahead
1946   // too far, not for correctness. Placing it before the executable launch
1947   // allows the inputs for the next executable to be fetched even if the
1948   // launch is delayed.
1949   std::shared_ptr<Semaphore::ScopedReservation> compute_reservation;
1950   {
1951     tensorflow::profiler::TraceMe traceme("ComputeSemaphoreAcquire");
1952     compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
1953         device_state->compute_semaphore().ScopedAcquire(1));
1954   }
1955 
1956   StatusOr<ExecutionOutput> result_buffer_or_status =
1957       executables_[executable_idx]->RunAsync(std::move(execution_inputs),
1958                                              run_options);
1959 
1960   VLOG(1) << "Replica " << replica << " partition " << partition
1961           << " completed; ok=" << result_buffer_or_status.ok();
1962 
1963   if (!result_buffer_or_status.ok()) {
1964     return result_buffer_or_status.status();
1965   }
1966 
1967   if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
1968     ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
1969     // If we used a transient tuple for the arguments we donated its root table
1970     // buffer. In that case, and/or if we donated any input buffers that were
1971     // not aliased, the donated buffers are going to be passed back to us via
1972     // the execution output. We need to ensure they aren't freed until after
1973     // execution completes. (Currently XLA does not support aliasing tuple
1974     // tables, so if any donated parameter is a tuple there will be donated but
1975     // unaliased buffers.)
1976     std::vector<se::OwningDeviceMemory> donated_memory =
1977         execution_output.ConsumeToBeReleased();
1978     absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
1979     donated_ptrs.reserve(donated_memory.size());
1980     for (se::OwningDeviceMemory& owning : donated_memory) {
1981       // Release the owning memory so we can pass it to the closure.
1982       donated_ptrs.push_back(owning.Release());
1983     }
1984     compute_callbacks.push_back(
1985         [references{std::make_tuple(executables_[executable_idx],
1986                                     compute_reservation, device_assignment)},
1987          donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
1988          device_ordinal]() {
1989           for (const auto& ptr : donated_ptrs) {
1990             TF_CHECK_OK(allocator->Deallocate(device_ordinal, ptr));
1991           }
1992         });
1993   } else {
1994     // Any donated memory returned by the ExecutionOutput can be immediately
1995     // freed.
1996     compute_callbacks.push_back(
1997         [to_release{std::make_tuple(executables_[executable_idx],
1998                                     compute_reservation,
1999                                     device_assignment)}]() {});
2000   }
2001 
2002   return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult();
2003 }
2004 
2005 std::vector<std::unique_ptr<PjRtBuffer>>
2006 PjRtStreamExecutorExecutable::MakeOutputBuffers(
2007     int device_ordinal, const ExecuteOptions& options,
2008     ScopedShapedBuffer result_buffer,
2009     std::shared_ptr<BufferSequencingEvent> definition_event, PjRtDevice* device,
2010     std::vector<std::function<void()>>& compute_callbacks,
2011     std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
2012     const {
2013   tensorflow::profiler::TraceMe traceme("MakeOutputBuffers");
2014   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
2015   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
2016   if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
2017     int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
2018     outputs.reserve(tuple_count);
2019     // Take ownership of each of the output values, leaving only the root table
2020     // in result_buffer.
2021     for (int i = 0; i < tuple_count; ++i) {
2022       ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i});
2023       outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event,
2024                                            client_, device, device_state,
2025                                            buffers_to_release));
2026     }
2027     if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
2028       // Don't release the root buffer until after execution completes.
2029       ShapedBuffer root_buffer_holder = result_buffer.release();
2030       se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer();
2031       compute_callbacks.push_back(
2032           [root_buffer, allocator{client_->allocator()}, device_ordinal]() {
2033             TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer));
2034           });
2035     }
2036   } else {
2037     outputs.push_back(OutputBufferHelper(&result_buffer, definition_event,
2038                                          client_, device, device_state,
2039                                          buffers_to_release));
2040   }
2041   return outputs;
2042 }
2043 
2044 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2045 PjRtStreamExecutorExecutable::ExecuteHelper(
2046     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
2047     const RunId& run_id, const ExecuteOptions& options,
2048     PjRtDevice* device) const {
2049   const uint64 start_time_usecs = tensorflow::Env::Default()->NowMicros();
2050   std::shared_ptr<DeviceAssignment> device_assignment;
2051   if (device == nullptr) {
2052     CHECK(device_assignment_ != nullptr);
2053     const int device_id = (*device_assignment_)(replica, partition);
2054     TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
2055     device_assignment = device_assignment_;
2056   } else {
2057     CHECK(device_assignment_ == nullptr);
2058     CHECK_EQ(replica, 0);
2059     CHECK_EQ(partition, 0);
2060     CHECK(addressable_devices_.empty());
2061     device_assignment = std::make_shared<DeviceAssignment>(1, 1);
2062     (*device_assignment)(0, 0) = device->id();
2063   }
2064 
2065   CHECK_EQ(device->process_index(), client_->process_index());
2066   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
2067                            ->local_device_state()
2068                            ->device_ordinal();
2069   tensorflow::profiler::TraceMe traceme(
2070       "PjRtStreamExecutorExecutable::ExecuteHelper");
2071   VLOG(1) << "Replica " << replica << ", partition " << partition
2072           << " mapped to device ordinal for execution: " << device_ordinal;
2073 
2074   // SPMD sharding produces a single executable for multiple partitions.
2075   int executable_idx = executables_.size() > 1 ? partition : 0;
2076 
2077   std::vector<std::function<void()>> compute_callbacks;
2078   std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
2079   device_buffers.reserve(argument_handles.size());
2080   StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
2081       argument_handles, replica, partition, executable_idx, run_id, options,
2082       device, &device_buffers, std::move(device_assignment), compute_callbacks);
2083 
2084   if (!result_buffer_or_status.ok()) {
2085     LOG(ERROR) << "Execution of replica " << replica
2086                << " failed: " << result_buffer_or_status.status();
2087     return result_buffer_or_status.status();
2088   }
2089   ScopedShapedBuffer result_buffer =
2090       result_buffer_or_status.ConsumeValueOrDie();
2091 
2092   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
2093   se::Stream* stream = device_state->compute_stream();
2094   StatusOr<EventPool::Handle> event_or =
2095       device_state->event_pool().ThenAllocateAndRecordEvent(stream);
2096   if (!event_or.ok()) {
2097     StallStreamOnError(device_state, stream);
2098     for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
2099       if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
2100         // Even though there was an error we need to call ConfirmDonation, which
2101         // renders b invalid, since the computation has been enqueued and b has
2102         // been donated.
2103         b.ConfirmDonation();
2104       }
2105     }
2106     return event_or.status();
2107   }
2108   auto definition_event = std::make_shared<BufferSequencingEvent>();
2109   definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
2110   std::vector<std::shared_ptr<TrackedDeviceBuffer>> buffers_to_release;
2111   std::vector<std::unique_ptr<PjRtBuffer>> outputs = MakeOutputBuffers(
2112       device_ordinal, options, std::move(result_buffer), definition_event,
2113       device, compute_callbacks, buffers_to_release);
2114 
2115   for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
2116     // prefer_to_retain_reference=false because when using the
2117     // ComputeSynchronized allocation model we don't need to retain a reference
2118     // to the device_buffer during execution because by definition the compute
2119     // stream is synchronized past the execution.
2120     if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
2121       RecordUsage(std::move(b), device_state, device_state, definition_event,
2122                   stream,
2123                   /*prefer_to_retain_reference=*/false, &buffers_to_release);
2124     } else {
2125       CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
2126       b.ConfirmDonation();
2127     }
2128   }
2129 
2130   if (!compute_callbacks.empty()) {
2131     device_state->ThenExecuteCallback(
2132         stream, [callbacks{std::move(compute_callbacks)},
2133                  buffers_to_release{std::move(buffers_to_release)}]() {
2134           for (auto& fn : callbacks) {
2135             fn();
2136           }
2137         });
2138   }
2139   ReportExecutableEnqueueTime(tensorflow::Env::Default()->NowMicros() -
2140                               start_time_usecs);
2141   return outputs;
2142 }
2143 
2144 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
2145 PjRtStreamExecutorExecutable::Execute(
2146     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
2147     const ExecuteOptions& options) {
2148   if (device_assignment_ == nullptr) {
    return InvalidArgument("Execute expects a non-null device_assignment");
  }

  RunId run_id;
  tensorflow::profiler::TraceMeProducer activity(
      "PjRtStreamExecutorExecutable::Execute",
      tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());

  const int num_addressable_devices = addressable_devices_.size();

  if (argument_handles.size() != num_addressable_devices) {
    return InvalidArgument(
        "Attempted to execute with %d argument lists when local device "
        "count is %d (total replica count: %d, partition count: %d)",
        argument_handles.size(), num_addressable_devices, num_replicas(),
        num_partitions());
  }

  VLOG(1) << "Executing computation " << name()
          << "; num_replicas=" << num_replicas()
          << " num_partitions=" << num_partitions()
          << " num_addressable_devices=" << num_addressable_devices;
  std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results(
      num_addressable_devices);
  if (num_addressable_devices == 1) {
    // Fast-path if there is only one device: run the computation on the
    // current thread.
    const int replica = addressable_device_logical_ids_[0].replica;
    const int partition = addressable_device_logical_ids_[0].partition;
    results[0] =
        ExecuteHelper(argument_handles[0], replica, partition, run_id, options);
  } else {
    absl::Mutex mu;
    int running = num_addressable_devices;
    int failed = 0;
    Status first_failure_status;

    for (int i = 0; i < num_addressable_devices; ++i) {
      const int replica = addressable_device_logical_ids_[i].replica;
      const int partition = addressable_device_logical_ids_[i].partition;
      PjRtDevice* device = addressable_devices_[i];
      const LocalDeviceState& device_state =
          *tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
               ->local_device_state();
      device_state.execute_thread()->Schedule([&, replica, partition, i] {
        results[i] = ExecuteHelper(argument_handles[i], replica, partition,
                                   run_id, options);

        absl::MutexLock lock(&mu);
        --running;
        if (!results[i].ok()) {
          if (failed == 0) {
            first_failure_status = results[i].status();
          }
          ++failed;
        }
      });
    }

    auto done_running_or_failed = [&]() {
      mu.AssertHeld();
      return running == 0 || failed > 0;
    };
    absl::MutexLock lock(&mu);
    mu.Await(absl::Condition(&done_running_or_failed));
    if (failed > 0) {
      auto done_running = [&]() {
        mu.AssertHeld();
        return running == 0;
      };
      // If execution does not terminate within a reasonable amount of time,
      // we may be stuck at a cross-replica barrier on-device. Terminate the
      // process since that's the only way we can escape this situation at the
      // moment (b/130629719).
      if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
                               absl::Seconds(10))) {
        LOG(FATAL)
            << "Replicated computation launch failed, but not all replicas "
               "terminated. Aborting process to work around deadlock. "
               "Failure message (there may have been multiple failures, see "
               "the error log for all failures): \n\n"
            << first_failure_status.error_message();
      }
    }
  }
  VLOG(1) << "Replicated execution complete.";

  std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
      num_addressable_devices);
  for (int i = 0; i < num_addressable_devices; ++i) {
    const int replica = addressable_device_logical_ids_[i].replica;
    const int partition = addressable_device_logical_ids_[i].partition;
    auto& statusor = results[i];
    if (!statusor.ok()) {
      if (num_addressable_devices == 1) {
        return statusor.status();
      } else {
        return AppendStatus(
            statusor.status(),
            absl::StrFormat("while running replica %d and partition %d of a "
                            "replicated computation (other "
                            "replicas may have failed as well).",
                            replica, partition));
      }
    }
    wrapped_results[i] = std::move(statusor.ValueOrDie());
  }
  return wrapped_results;
}
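
// Example (illustrative sketch, not part of the implementation): a typical
// replicated launch passes one argument list per addressable device, in the
// same order as executable->addressable_devices(). The `executable` and
// `inputs` names below are assumptions made for the example.
//
//   std::vector<std::vector<PjRtBuffer*>> args_per_device;
//   for (int i = 0; i < executable->addressable_devices().size(); ++i) {
//     args_per_device.push_back({inputs[i].get()});
//   }
//   ExecuteOptions options;
//   TF_ASSIGN_OR_RETURN(auto outputs,
//                       executable->Execute(args_per_device, options));
//   // outputs[i] holds the result buffers produced on addressable device i.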

StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecuteSharded(
    absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
    const ExecuteOptions& options) {
  if (device_assignment_ == nullptr) {
    return InvalidArgument(
        "ExecuteSharded expects a non-null device_assignment");
  }
  for (int i = 0; i < addressable_devices_.size(); ++i) {
    if (addressable_devices_[i] == device) {
      VLOG(1) << "ExecuteSharded executes computation " << name()
              << " on assigned replica/partition on device "
              << device->DebugString();
      return ExecuteHelper(
          argument_handles, addressable_device_logical_ids_[i].replica,
          addressable_device_logical_ids_[i].partition, RunId(), options);
    }
  }
  return InvalidArgument(
      "ExecuteSharded attempted to execute on device id %d which is not "
      "addressable by this client",
      device->id());
}
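
// Example (illustrative sketch): ExecuteSharded launches only the shard of a
// replicated computation that the device assignment maps to `device`, which
// must be one of executable->addressable_devices(). The `executable` and
// `arg` names are assumptions made for the example.
//
//   PjRtDevice* device = executable->addressable_devices()[0];
//   ExecuteOptions options;
//   TF_ASSIGN_OR_RETURN(
//       std::vector<std::unique_ptr<PjRtBuffer>> outputs,
//       executable->ExecuteSharded({arg.get()}, device, options));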

StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecutePortable(
    absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
    const ExecuteOptions& options) {
  if (device_assignment_ != nullptr) {
    return InvalidArgument("ExecutePortable expects a portable executable");
  }
  if (num_replicas() != 1 || num_partitions() != 1) {
    return InvalidArgument(
        "ExecutePortable expects a single-core executable but got one with "
        "%d replicas and %d partitions",
        num_replicas(), num_partitions());
  }
  if (device == nullptr) {
    return InvalidArgument("ExecutePortable expects a device to be specified");
  }
  VLOG(1) << "ExecutePortable executes single-core portable executable "
          << name();
  return ExecuteHelper(argument_handles,
                       /*replica=*/0,
                       /*partition=*/0, RunId(), options, device);
}
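
// Example (illustrative sketch): a portable executable is compiled without a
// device assignment and bound to a device only at execution time. The
// `client`, `computation` and `arg` names are assumptions made for the
// example.
//
//   CompileOptions compile_options;
//   compile_options.compile_portable_executable = true;
//   TF_ASSIGN_OR_RETURN(auto executable,
//                       client->Compile(computation, compile_options));
//   ExecuteOptions options;
//   TF_ASSIGN_OR_RETURN(
//       auto outputs,
//       executable->ExecutePortable({arg.get()},
//                                   client->addressable_devices()[1],
//                                   options));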

StatusOr<std::vector<std::shared_ptr<HloModule>>>
PjRtStreamExecutorExecutable::GetHloModules() const {
  std::vector<std::shared_ptr<HloModule>> modules;
  modules.reserve(executables().size());
  for (const auto& local_exec : executables()) {
    if (!local_exec->executable()->has_module()) {
      return InvalidArgument("Executable does not have HLO modules.");
    }
    modules.push_back(local_exec->executable()->shared_module());
  }
  return std::move(modules);
}

StatusOr<PjRtStreamExecutorClient::ExecutableExtras>
PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
  ExecutableExtras extras;
  std::shared_ptr<DeviceAssignment>& device_assignment =
      extras.device_assignment;
  std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
      addressable_device_logical_ids = extras.addressable_device_logical_ids;
  std::vector<PjRtDevice*>& addressable_devices = extras.addressable_devices;

  ExecutableBuildOptions& build_options = options->executable_build_options;
  if (!build_options.compile_thread_pool()) {
    build_options.set_compile_thread_pool(thread_pool());
  }
  if (!build_options.device_allocator()) {
    build_options.set_device_allocator(allocator());
  }

  int num_replicas;
  int num_partitions;
  TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
      options->compile_portable_executable, &options->executable_build_options,
      [this](int num_replicas, int num_partitions) {
        return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
      },
      &num_replicas, &num_partitions, &device_assignment));

  // Find devices that are addressable by this client/task.
  if (device_assignment != nullptr) {
    addressable_device_logical_ids.reserve(num_replicas * num_partitions);
    addressable_devices.reserve(num_replicas * num_partitions);
    for (int replica = 0; replica < num_replicas; ++replica) {
      for (int partition = 0; partition < num_partitions; ++partition) {
        int device_id = (*device_assignment)(replica, partition);
        TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
        if (device->process_index() != process_index()) {
          VLOG(3) << "Non-local device: " << device_id;
          continue;
        }
        PjRtExecutable::LogicalDeviceIds logical_device_ids;
        logical_device_ids.replica = replica;
        logical_device_ids.partition = partition;
        addressable_device_logical_ids.push_back(
            std::move(logical_device_ids));
        addressable_devices.push_back(device);
      }
    }
    if (addressable_devices.empty()) {
      return InvalidArgument(
          "Device assignment (%s) does not have any local devices.",
          device_assignment->ToString());
    }

    if (build_options.device_ordinal() < 0) {
      build_options.set_device_ordinal(
          addressable_devices.front()->local_hardware_id());
    }
  }
  return extras;
}
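
// For illustration (a sketch, not tied to any particular platform): with
// num_replicas=2 and num_partitions=2, a default device assignment might look
// like
//
//   DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/2);
//   // assignment(replica, partition) == device id, e.g.
//   //   assignment(0, 0) == 0   assignment(0, 1) == 1
//   //   assignment(1, 0) == 2   assignment(1, 1) == 3
//
// and GetExecutableExtras keeps only the (replica, partition) pairs whose
// device id belongs to this process.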

StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
    const XlaComputation& computation, CompileOptions options) {
  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
  VLOG(1) << "PjRtStreamExecutorClient::Compile";

  TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options));
  std::shared_ptr<DeviceAssignment>& device_assignment =
      extras.device_assignment;
  std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
      addressable_device_logical_ids = extras.addressable_device_logical_ids;
  std::vector<PjRtDevice*>& addressable_devices = extras.addressable_devices;

  std::vector<const Shape*> argument_layout_pointers;
  TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
      computation,
      [local_client = client()](Shape shape) {
        return local_client->backend()
            .transfer_manager()
            ->ChooseCompactLayoutForShape(shape);
      },
      options.argument_layouts, &options.executable_build_options,
      &argument_layout_pointers));

  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<LocalExecutable>> local_executables,
      client()->Compile(computation, argument_layout_pointers,
                        options.executable_build_options));

  auto executable = absl::make_unique<PjRtStreamExecutorExecutable>(
      std::move(local_executables), options.parameter_is_tupled_arguments,
      std::move(device_assignment), std::move(addressable_device_logical_ids),
      std::move(addressable_devices), this);
  TF_RETURN_IF_ERROR(
      executable->SetUpDonation(options.parameter_is_tupled_arguments));
  return std::unique_ptr<PjRtExecutable>(std::move(executable));
}
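
// Example (illustrative sketch): compiling and running an XlaComputation
// end-to-end. The `client` and `literal` names are assumptions made for the
// example; any XlaComputation works.
//
//   XlaBuilder builder("add_one");
//   auto param = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
//   Add(param, ConstantR1<float>(&builder, {1.0f, 1.0f}));
//   TF_ASSIGN_OR_RETURN(XlaComputation computation, builder.Build());
//
//   TF_ASSIGN_OR_RETURN(auto executable,
//                       client->Compile(computation, CompileOptions()));
//   PjRtDevice* device = client->addressable_devices()[0];
//   TF_ASSIGN_OR_RETURN(auto input,
//                       client->BufferFromHostLiteral(literal, device));
//   TF_ASSIGN_OR_RETURN(
//       auto outputs,
//       executable->Execute({{input.get()}}, ExecuteOptions()));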

}  // namespace xla