/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_

#include <array>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class PjRtStreamExecutorDevice : public PjRtDevice {
 public:
  explicit PjRtStreamExecutorDevice(
      int id, std::unique_ptr<LocalDeviceState> local_device_state,
      std::string device_kind, int task_id = 0)
      : id_(id),
        device_ordinal_(
            local_device_state ? local_device_state->device_ordinal() : -1),
        local_device_state_(std::move(local_device_state)),
        task_id_(task_id),
        device_kind_(std::move(device_kind)) {}
  ~PjRtStreamExecutorDevice() override {}

  // Must set client exactly once.
  void SetClient(PjRtClient* client) {
    CHECK(client_ == nullptr);
    client_ = client;
  }

  int task_id() const override { return task_id_; }

  // Return `platform_id` from client.
  PjRtPlatformId platform_id() const;

  // Return `platform_name` from client.
  absl::string_view platform_name() const;

  PjRtClient* client() const override { return client_; }

  int id() const override { return id_; }

  bool IsAddressable() const override { return device_ordinal_ != -1; }

  int local_hardware_id() const override { return device_ordinal_; }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns nullptr if the device is
  // not local to this host.
  LocalDeviceState* local_device_state() const {
    return local_device_state_.get();
  }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns an error if the device
  // is not local to this host.
  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;

  absl::string_view device_kind() const override { return device_kind_; }

  std::string DebugString() const override;

  Status TransferToInfeed(const LiteralSlice& literal) override;

  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;

 private:
  const int id_;
  const int device_ordinal_;  // -1 means not local.
  const std::unique_ptr<LocalDeviceState> local_device_state_;
  const int task_id_;
  const std::string device_kind_;
  PjRtClient* client_ = nullptr;
};

class PjRtStreamExecutorClient : public PjRtClient {
 public:
  // `allocator` may be null, in which case the platform default allocator is
  // used.
  explicit PjRtStreamExecutorClient(
      std::string platform_name, LocalClient* client,
      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
      int task_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
      bool should_stage_host_to_device_transfers,
      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
  ~PjRtStreamExecutorClient() override = default;

  int task_id() const override { return task_id_; }

  int device_count() const override { return devices_.size(); }
  int addressable_device_count() const override {
    return addressable_devices_.size();
  }
  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
    auto it = id_to_device_.find(device_id);
    if (it != id_to_device_.end()) {
      return it->second;
    }
    return InvalidArgument("No matching device found for device_id %d",
                           device_id);
  }

  StatusOr<PjRtDevice*> LookupAddressableDevice(
      int local_hardware_id) const override;

  PjRtPlatformId platform_id() const override { return platform_id_; }
  absl::string_view platform_name() const override { return platform_name_; }

  // Most platforms expect device-to-device transfers to be enqueued on the
  // source d2d stream, but some platforms use the destination d2d stream. This
  // function specifies which one the platform expects.
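  //
  // A hedged sketch of how a destination-stream platform might override this
  // (the subclass name is hypothetical; only the method below is from this
  // class):
  //
  //   class DstStreamClient : public PjRtStreamExecutorClient {
  //     bool EnqueueD2DTransfersOnSrcStream() const override { return false; }
  //   };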
  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }

  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;

  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options) override;

  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
      const PjRtExecutable& executable) const override {
    return absl::optional<std::string>();
  }

  StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override;

  // Creates a buffer on the device without initializing or copying any data.
  // An optional `definition_event` may be specified that can be used to
  // ensure the buffer isn't referenced until some external mechanism has
  // initialized the data.
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device) override;
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device,
      std::shared_ptr<BufferSequencingEvent> definition_event);
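  //
  // A hedged sketch of the definition-event pattern (`shape`, `device`, and
  // the event's producer are hypothetical; the overload is declared above):
  //
  //   std::shared_ptr<BufferSequencingEvent> event = /* created elsewhere */;
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<PjRtBuffer> buffer,
  //       client->CreateUninitializedBuffer(shape, device, event));
  //   // The buffer must not be read until `event` has been signaled.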
186 
187   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
188       const void* data, const Shape& shape,
189       HostBufferSemantics host_buffer_semantics,
190       std::function<void()> on_done_with_host_buffer,
191       PjRtDevice* device) override;
192 
193   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
194       const LiteralSlice& literal, PjRtDevice* device) override;
195 
196   void MakeCrossHostReceiveBuffers(
197       absl::Span<const Shape> shapes, PjRtDevice* device,
198       PjRtCrossHostRecvNotifier&& notifier) override;
199 
200   StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
201       void* device_ptr, const Shape& shape, PjRtDevice* device,
202       std::function<void()> on_delete_callback) override;
203 
CreateChannelHandle()204   StatusOr<ChannelHandle> CreateChannelHandle() override {
205     return client()->CreateChannelHandle();
206   }
CreateDeviceToHostChannelHandle()207   StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
208     return client()->CreateDeviceToHostChannelHandle();
209   }
CreateHostToDeviceChannelHandle()210   StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
211     return client()->CreateHostToDeviceChannelHandle();
212   }
213 
device_state(int device_ordinal)214   LocalDeviceState& device_state(int device_ordinal) const {
215     return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
216                 addressable_devices_.at(device_ordinal))
217                 ->local_device_state();
218   }
client()219   LocalClient* client() const { return client_; }
allocator()220   se::DeviceMemoryAllocator* allocator() const { return allocator_; }
host_memory_allocator()221   tensorflow::Allocator* host_memory_allocator() const {
222     return host_memory_allocator_.get();
223   }
should_stage_host_to_device_transfers()224   bool should_stage_host_to_device_transfers() const {
225     return should_stage_host_to_device_transfers_;
226   }
227 
gpu_run_options()228   gpu::GpuExecutableRunOptions* gpu_run_options() const {
229     return gpu_run_options_.get();
230   }
231 
thread_pool()232   tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }
233 
234  protected:
235   friend class PjRtStreamExecutorBuffer;
EnqueueCrossHostReceive(std::vector<std::unique_ptr<PjRtBuffer>> && buffers,std::shared_ptr<BufferSequencingEvent> definition_event,PjRtCrossHostRecvNotifier && notifier)236   virtual void EnqueueCrossHostReceive(
237       std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
238       std::shared_ptr<BufferSequencingEvent> definition_event,
239       PjRtCrossHostRecvNotifier&& notifier) const {
240     notifier(Unimplemented("Cross host receives not implemented."));
241   }
242 
CopyToRemoteDevice(PjRtBuffer * buffer,absl::string_view serialized_descriptor)243   virtual Status CopyToRemoteDevice(
244       PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
245     return Unimplemented("Cross host sends not implemented.");
246   }
247 
248   const PjRtPlatformId platform_id_;
249   const std::string platform_name_;
250   LocalClient* client_;
251 
252   // Allocator to be used for staging memory transfers to devices.
253   std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
254 
255   // Includes all devices, including non-local devices on multi-host platforms.
256   std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
257   // Pointers to `owned_devices_`.
258   std::vector<PjRtDevice*> devices_;
259   // Maps Device::id() to the corresponding Device. Includes all devices.
260   std::map<int, PjRtDevice*> id_to_device_;
261   // Local devices indexed by local device ordinal.
262   std::vector<PjRtDevice*> addressable_devices_;
263   int task_id_;
264 
265   se::DeviceMemoryAllocator* allocator_;
266   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
267 
268   // Should we always prefer to stage host-to-device transfers via memory
269   // allocated on host_memory_allocator_? True only on GPU, where we prefer to
270   // transfer via pinned memory.
271   bool should_stage_host_to_device_transfers_;
272 
273   std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
274 
275   tensorflow::thread::ThreadPool thread_pool_;
276 };
277 
278 // Converts a 2D set of Device objects indexed by [replica][partition] into an
279 // xla::DeviceAssignment.
280 StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
281     absl::Span<const std::vector<PjRtDevice*>> devices);
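//
// A hedged sketch (`d0` and `d1` are hypothetical addressable devices) for
// two replicas and one partition:
//
//   std::vector<std::vector<PjRtDevice*>> devices = {{d0}, {d1}};
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       DevicesToDeviceAssignment(devices));
//   // assignment(replica, partition) is the assigned device's id.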

class PjRtStreamExecutorBuffer : public PjRtBuffer {
 public:
  // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A ScopedHold
  // may not outlive its parent PjRtStreamExecutorBuffer.
  //
  // There are three types of hold, as follows:
  //
  // 1) Usage hold: a transient hold while an operation using the buffer is
  //    being enqueued onto a stream.
  // A client acquires a usage hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
  // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully the
  // hold should be released using a call to ConvertUsageHold. If the ScopedHold
  // is deleted without ConvertUsageHold being called, e.g., on error, the hold
  // is dropped. It is legal to drop a usage hold instead of calling
  // ConvertUsageHold, even if the buffer was successfully enqueued, as long as
  // the client ensures that all necessary synchronization has been done.
  //
  // 2) External hold: a potentially long-lived hold while the buffer is being
  //    shared by an external framework, e.g., NumPy.
  // A client acquires an external hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience
  // wrapper GetBufferWithExternalReference and releases it by deleting the
  // ScopedHold. The external framework should not modify the underlying buffer
  // unless it is confident via its own synchronization that modifications do
  // not race with reads from the PjRtStreamExecutorBuffer.
  //
  // 3) Donation hold: a transient hold while an execution that donates the
  //    buffer is being enqueued onto the compute stream.
  // A client acquires a donation hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
  // completes successfully the hold should be released using a call to
  // ConfirmDonation after which the buffer is invalid. If the ScopedHold is
  // deleted without ConfirmDonation being called, e.g., on error, the hold is
  // dropped and the buffer remains valid. If the buffer is successfully
  // enqueued the client *must* call ConfirmDonation.
  //
  // Donation holds behave like exclusive write locks: when a donation hold
  // has been acquired, any attempt to acquire another hold of any type will
  // block until the donation hold is dropped or confirmed. Acquiring a donation
  // hold will fail with an error if there is any outstanding external hold, and
  // will block if there are any outstanding usage holds until those holds are
  // dropped or converted.
  //
  // Calls to PjRtStreamExecutorBuffer::Release (and transitively to
  // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
  // block until all usage and donation holds are either deleted or
  // converted/confirmed.
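  //
  // A hedged sketch of the usage-hold protocol (`buffer`, `stream`, and
  // `event` are hypothetical; the hold API is from this class):
  //
  //   PjRtStreamExecutorBuffer::ScopedHold hold =
  //       buffer->GetBufferWithUsageHold();
  //   if (!hold.ok()) return hold.status();
  //   // ... enqueue an operation that reads hold.buffer() on `stream` ...
  //   hold.ConvertUsageHold(stream, event, /*reference_held=*/false);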
  class ScopedHold {
   public:
    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
    // Use a State enum instead of encoding the state in an error Status to
    // avoid creating Status values in non-error cases. Creating a Status
    // entails several allocations and can add O(us) to every use of a hold.
    enum State {
      kUninitialized = 0,
      kValid,
      kMoved,
      kConverted,
      kReleased,
      kDonated,
      kError
    };

    ~ScopedHold();
    ScopedHold(ScopedHold&& other);
    ScopedHold(const ScopedHold&) = delete;
    ScopedHold& operator=(const ScopedHold&) = delete;

    Type type() const { return type_; }

    Status status() const {
      // Lazily create Status values only when they are requested.
      switch (state_) {
        case kUninitialized:
          return InvalidArgument("Buffer has not been initialized");
        case kValid:
          return Status::OK();
        case kMoved:
          return InvalidArgument("Buffer has been moved.");
        case kConverted:
          return InvalidArgument("Buffer has been converted");
        case kReleased:
          return InvalidArgument("Buffer has been released");
        case kDonated:
          return InvalidArgument("Buffer has been donated");
        case kError:
          return status_;
        default:
          CHECK(false) << "Unexpected state value " << state_;
      }
    }
    bool ok() const { return state_ == kValid; }

    // Access to the underlying device buffer storage. Requires this->ok().
    const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
      CHECK_EQ(state_, kValid);
      CHECK_NE(buffer_, nullptr);
      return buffer_;
    }
    TrackedDeviceBuffer* operator->() const { return buffer().get(); }
    const TrackedDeviceBuffer& operator*() const { return *buffer(); }

    // Converts the hold into a usage event. Only valid for holds of type
    // kUsage.
    //
    //   usage_stream:   the stream that the buffer was used on.
    //   event:          an event that has been recorded on usage_stream after
    //                   the buffer was used.
    //   reference_held: true if and only if the caller has caused a
    //                   reference to this->buffer() to stay live until after
    //                   the host is sure that the usage (transfer or execution)
    //                   has completed.
    void ConvertUsageHold(se::Stream* usage_stream,
                          std::shared_ptr<BufferSequencingEvent> event,
                          bool reference_held);

    // Confirms that the buffer was successfully donated to an execution.
    // Only valid for holds of type kDonation. Causes the buffer to become
    // invalid.
    void ConfirmDonation();

    // Adds the held device buffers, in order, to 'iterator'. Used to add the
    // buffers to an ExecutionInput. We require but do not verify that
    // 'iterator' when passed in is pointing to a sub-tuple of the
    // ExecutionInput whose on_device_shape matches that of the
    // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
    // out of bounds. Donates the device buffers if the hold type is kDonation,
    // otherwise retains ownership of the device buffers.
    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
                    ExecutionInput* execution_input,
                    se::DeviceMemoryAllocator* allocator) const;

   private:
    friend class PjRtStreamExecutorBuffer;
    friend class PjRtStreamExecutorClient;

    // Helper struct that makes it possible to move a ScopedHold through a
    // closure.
    using ForClosure = std::tuple<PjRtStreamExecutorBuffer*, Type, State,
                                  Status, std::shared_ptr<TrackedDeviceBuffer>>;

    ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
        : parent_(parent), type_(type), state_(kUninitialized) {}
    explicit ScopedHold(const ForClosure& closure_helper)
        : parent_(std::get<0>(closure_helper)),
          type_(std::get<1>(closure_helper)),
          state_(std::get<2>(closure_helper)),
          status_(std::get<3>(closure_helper)),
          buffer_(std::get<4>(closure_helper)) {
      // Check the buffer is not in an error state.
      CHECK(status_.ok() && buffer_ != nullptr);
    }

    // Sets buffer state.
    void SetState(State state) { state_ = state; }

    // Sets buffer_ and status_. Called by parent_ to initialize the hold.
    void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
    // Releases the contents of *this, so *this can subsequently be
    // deleted without releasing the parent's hold. Should be passed to the
    // appropriate constructor of another ScopedHold, e.g., when a hold must be
    // passed through a closure that is incompatible with std::move.
    ForClosure ToClosure();

    PjRtStreamExecutorBuffer* const parent_;
    const Type type_;

    // There is an invariant that if ok() then buffer_ != nullptr.
    State state_;
    Status status_;
    std::shared_ptr<TrackedDeviceBuffer> buffer_;
  };

  PjRtStreamExecutorBuffer(Shape on_device_shape,
                           std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                           PjRtClient* client, PjRtDevice* device);
  ~PjRtStreamExecutorBuffer() override;

  PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;

  const Shape& on_device_shape() const override { return on_device_shape_; }
  PjRtStreamExecutorDevice* device() const override { return device_; }
  PjRtPlatformId platform_id() const { return client_->platform_id(); }
  absl::string_view platform_name() const { return client_->platform_name(); }
  PjRtStreamExecutorClient* client() const override { return client_; }
  bool IsEmptyTuple() const {
    return on_device_shape_.IsTuple() &&
           on_device_shape_.tuple_shapes_size() == 0;
  }

  int64 OnDeviceSizeInBytes() const override;

  StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
      override;

  StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
      bool wait_for_operations_to_complete) override;

  using PjRtBuffer::ToLiteral;
  void ToLiteral(MutableLiteralBase* literal,
                 std::function<void(Status)> on_ready) override;

  // Drops the buffer's reference to its associated device memory, leaving the
  // buffer in an invalid state. The memory will be freed lazily when all async
  // operations using the buffer have completed, according to the allocation
  // semantics of the underlying platform. Delete may briefly block if another
  // thread is in the process of enqueuing an operation on this buffer, but it
  // will never block for a stream operation to complete. If an external
  // framework holds a reference to the TrackedDeviceBuffer via
  // GetBufferWithExternalReference, the memory will not be freed until the
  // external framework drops the reference.
  void Delete() override;

  bool IsDeleted() override;

  // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
  // PjRtBuffer retains ownership of the device buffers.
  StatusOr<ShapedBuffer> AsShapedBuffer() const;

  // Returns a hold on the TrackedDeviceBuffer holding the device
  // buffers. See comment on ScopedHold.
  ScopedHold GetBufferWithHold(ScopedHold::Type type);
  ScopedHold GetBufferWithUsageHold() {
    return GetBufferWithHold(ScopedHold::kUsage);
  }
  ScopedHold GetBufferWithExternalReference() {
    return GetBufferWithHold(ScopedHold::kExternalReference);
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
      PjRtDevice* dst_device) override;

  Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;

  Status BlockHostUntilReady() override;

  bool IsOnCpu() const override;

  // Similar to Delete, drops the buffer's reference to its associated device
  // memory, leaving the buffer in an invalid state, but returns the
  // TrackedDeviceBuffer rather than freeing the device memory, so that another
  // framework can take ownership of it. The buffer returned from Release may
  // be safely dropped at any time even if it still has pending async
  // operations. The client should call BlockHostUntilReady before calling
  // Release with wait_for_operations_to_complete=false, to ensure that the host
  // has synchronized past any outstanding write operations to the buffer. If
  // wait_for_operations_to_complete=true the host will block until any
  // potentially outstanding asynchronous operations have completed before
  // returning, in which case it is safe to read or mutate the returned buffer.
  // If the buffer was shared via an external reference it is the client's
  // responsibility that accesses via that reference do not interfere with
  // accesses via the buffer returned from Release.
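  //
  // A hedged usage sketch (`buffer` is hypothetical; error handling elided):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::shared_ptr<TrackedDeviceBuffer> tracked,
  //       buffer->Release(/*wait_for_operations_to_complete=*/true));
  //   // `tracked` now owns the device memory; `buffer` is invalid.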
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
      bool wait_for_operations_to_complete);

 private:
  friend class PjRtClient;

  // Blocks in mu_.Await until there are no more usage holds.
  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Blocks in mu_.Await until there is no donation hold.
  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
  // device_buffer_ is null, or if a donation hold was requested when there is
  // an outstanding external hold.
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
  // Initializes hold with an error if device_buffer_ is null, or if a donation
  // hold was requested when there is an outstanding external hold.
  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
  // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
  // device_buffer_ was successfully enqueued on a stream.
  void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
                        std::shared_ptr<BufferSequencingEvent> event,
                        bool reference_held);

  // Drops a donation hold and makes *this invalid for further use. Does a
  // sanity check that buffer==device_buffer_. Called after device_buffer_ was
  // successfully donated to an execution.
  void ConfirmDonation(TrackedDeviceBuffer* device_buffer);

  // Drops a hold without taking any other action. Does a sanity check that
  // buffer==device_buffer_ or device_buffer_==nullptr.
  void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);

  StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
                     std::shared_ptr<BufferSequencingEvent>>>
  CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
                     LocalDeviceState* transfer_local_device,
                     se::Stream* transfer_stream,
                     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);

  PjRtStreamExecutorClient* const client_;
  const Shape on_device_shape_;
  PjRtStreamExecutorDevice* const device_;

  mutable absl::Mutex mu_;
  std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
  // Count of holds on the buffer.
  std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
  // Semaphore used to ensure there is only one outstanding donation hold.
  Semaphore donation_semaphore_;
};

// Wraps one or more XLA LocalExecutables (one per partition, as specified by
// the build options).
class PjRtStreamExecutorExecutable : public PjRtExecutable {
 public:
  PjRtStreamExecutorExecutable(
      std::vector<std::unique_ptr<LocalExecutable>> executables,
      bool parameter_is_tupled_arguments,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<LogicalDeviceIds> addressable_device_logical_ids,
      std::vector<PjRtDevice*> addressable_devices,
      PjRtStreamExecutorClient* client);

  ~PjRtStreamExecutorExecutable() override = default;

  PjRtStreamExecutorClient* client() const override { return client_; }

  absl::string_view name() const override;

  int num_replicas() const override {
    return executables_[0]->build_options().num_replicas();
  }

  int num_partitions() const override {
    return executables_[0]->build_options().num_partitions();
  }

  int64 SizeOfGeneratedCodeInBytes() const override {
    int64 size = 0;
    for (auto& executable : executables_) {
      size += executable->executable()->SizeOfGeneratedCodeInBytes();
    }
    return size;
  }

  const DeviceAssignment& device_assignment() const override {
    return *device_assignment_;
  }

  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const override {
    return addressable_device_logical_ids_;
  }

  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  // Return an HloModule per partition.
  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const override;

  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
      const ExecuteOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) override;

  void Delete() override { executables_.clear(); }

  absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
    return executables_;
  }

 protected:
  bool parameter_is_tupled_arguments() const {
    return parameter_is_tupled_arguments_;
  }

 private:
  friend class PjRtStreamExecutorClient;
  // Initializes information about which arguments to which executables must be
  // donated due to aliases that were specified by the computation.
  Status SetUpDonation(bool tuple_inputs);

  virtual bool MustDonateParameter(int executable_idx, int parameter) const;

  virtual StatusOr<std::vector<ExecutionInput>>
  MakeExecutionInputsAndWaitForEvents(
      int device_ordinal, const ExecuteOptions& options,
      absl::Span<PjRtBuffer* const> argument_handles,
      absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
      absl::flat_hash_set<BufferSequencingEvent*>& events) const;

  StatusOr<ScopedShapedBuffer> EnqueueExecution(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, int executable_idx, const RunId& run_id,
      const ExecuteOptions& options, PjRtDevice* device,
      std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
      std::shared_ptr<DeviceAssignment> device_assignment) const;

  virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
      int device_ordinal, const ExecuteOptions& options,
      ScopedShapedBuffer result_buffer,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtDevice* device) const;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, const RunId& run_id, const ExecuteOptions& options,
      PjRtDevice* device = nullptr) const;

  // Create shared pointers so we can free them after the execution: with
  // asynchronous execution, the process being executed can outlive the
  // executable itself.
  PjRtStreamExecutorClient* const client_;
  // One executable per partition.
  std::vector<std::shared_ptr<LocalExecutable>> executables_;
  // Per-executable set of parameters that have any aliased buffers and thus
  // must be donated when executing the computation.
  std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
  std::shared_ptr<DeviceAssignment> device_assignment_;

  // True if the executables were compiled expecting arguments in a single
  // tuple.
  const bool parameter_is_tupled_arguments_;

  // The replica and partition indices of device_assignment_ to be run by this
  // client. On single-host platforms without partitioning, this is all replicas
  // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
  // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
  // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
  std::vector<LogicalDeviceIds> addressable_device_logical_ids_;

  // addressable_devices_[i] is the Device to which
  // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
  // unique_ptrs to play well with the Python bindings (see xla.cc).
  std::vector<PjRtDevice*> addressable_devices_;
};

// Executables can donate buffers so that buffers can be aliased from inputs
// to outputs. This function returns the list of parameters that must be
// donated when the executable is run. tuple_inputs reflects the option that
// the executable was compiled with.
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
    const HloModule& hlo_module, bool tuple_inputs);
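//
// A hedged sketch (`module` and its alias configuration are hypothetical):
//
//   TF_ASSIGN_OR_RETURN(
//       absl::flat_hash_set<int> donated,
//       GetParametersThatMustBeDonated(module, /*tuple_inputs=*/false));
//   // `donated` holds the parameter numbers whose buffers are aliased to
//   // outputs and so must be donated at execution time.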

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_