/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_

// Standard headers for types used below (std::array, std::function, std::map,
// std::tuple, std::pair).
#include <array>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/transpose.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

class PjRtStreamExecutorDevice : public PjRtDevice {
 public:
  explicit PjRtStreamExecutorDevice(
      int id, std::unique_ptr<LocalDeviceState> local_device_state,
      std::string device_kind, int process_index = 0)
      : id_(id),
        device_ordinal_(
            local_device_state ? local_device_state->device_ordinal() : -1),
        local_device_state_(std::move(local_device_state)),
        process_index_(process_index),
        device_kind_(std::move(device_kind)) {}
  ~PjRtStreamExecutorDevice() override {}

  // Must set client exactly once.
  void SetClient(PjRtClient* client) {
    CHECK(client_ == nullptr);
    client_ = client;
  }

  int process_index() const override { return process_index_; }

  // Return `platform_id` from client.
  PjRtPlatformId platform_id() const;

  // Return `platform_name` from client.
  absl::string_view platform_name() const;

  PjRtClient* client() const override { return client_; }

  int id() const override { return id_; }

  bool IsAddressable() const override { return device_ordinal_ != -1; }

  int local_hardware_id() const override { return device_ordinal_; }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns nullptr if the device
  // is not local to this host.
  LocalDeviceState* local_device_state() const {
    return local_device_state_.get();
  }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns an error if the device
  // is not local to this host.
  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;

  absl::string_view device_kind() const override { return device_kind_; }

  std::string DebugString() const override;

  Status TransferToInfeed(const LiteralSlice& literal) override;

  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;

 private:
  const int id_;
  const int device_ordinal_;  // -1 means not local.
  const std::unique_ptr<LocalDeviceState> local_device_state_;
  const int process_index_;
  const std::string device_kind_;
  PjRtClient* client_ = nullptr;
};

class PjRtStreamExecutorClient : public PjRtClient {
 public:
  // `allocator` may be null, in which case the platform default allocator is
  // used.
  explicit PjRtStreamExecutorClient(
      std::string platform_name, LocalClient* client,
      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
      int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
      bool should_stage_host_to_device_transfers,
      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
  ~PjRtStreamExecutorClient() override = default;

  int process_index() const override { return process_index_; }

  int device_count() const override { return devices_.size(); }
  int addressable_device_count() const override {
    return addressable_devices_.size();
  }
  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
    auto it = id_to_device_.find(device_id);
    if (it != id_to_device_.end()) {
      return it->second;
    }
    return InvalidArgument("No matching device found for device_id %d",
                           device_id);
  }

  StatusOr<PjRtDevice*> LookupAddressableDevice(
      int local_hardware_id) const override;

  PjRtPlatformId platform_id() const override { return platform_id_; }
  absl::string_view platform_name() const override { return platform_name_; }
  absl::string_view platform_version() const override { return "<unknown>"; }
  PjRtRuntimeType runtime_type() const override { return kStreamExecutor; }

  // Most platforms expect device-to-device transfers to be enqueued on the
  // source d2d stream, but some platforms use the destination d2d stream. This
  // function specifies which one the platform expects.
  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }

  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;

  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options) override;

  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
      const PjRtExecutable& executable) const override {
    return absl::optional<std::string>();
  }

  StatusOr<std::string> SerializeExecutable(
      const PjRtExecutable& executable) const override {
    return Unimplemented("SerializeExecutable not implemented on %s",
                         platform_name());
  }

  StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
      absl::string_view serialized, CompileOptions options) override {
    return Unimplemented("DeserializeExecutable not implemented on %s",
                         platform_name());
  }

  StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override;

  // Creates a buffer on the device without initializing or copying any data.
  // An optional `definition_event` may be specified that can be used to
  // ensure the buffer isn't referenced until some external mechanism has
  // initialized the data.
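  //
  // A minimal sketch of the `definition_event` pattern, assuming the caller
  // creates and later signals the event (neither step is defined in this
  // header):
  //
  //   std::shared_ptr<BufferSequencingEvent> event = ...;  // caller-managed
  //   TF_ASSIGN_OR_RETURN(
  //       auto buffer,
  //       client->CreateUninitializedBuffer(shape, device, event));
  //   // Fill the device memory by some external mechanism, then record the
  //   // event; only then may the buffer be safely read.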
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device) override;
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device,
      std::shared_ptr<BufferSequencingEvent> definition_event);

  StatusOr<std::unique_ptr<PjRtClient::AsyncBufferTransferManager>>
  CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
                                PjRtDevice* device) override {
    return Unimplemented("Async transfer to buffers not implemented");
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
      const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
      absl::optional<absl::Span<int64_t const>> byte_strides,
      HostBufferSemantics host_buffer_semantics,
      std::function<void()> on_done_with_host_buffer,
      PjRtDevice* device) override;

  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
      const LiteralSlice& literal, PjRtDevice* device) override;

  void MakeCrossHostReceiveBuffers(
      absl::Span<const Shape> shapes, PjRtDevice* device,
      PjRtCrossHostRecvNotifier&& notifier) override;

  void MakeCrossHostReceiveBuffersForGather(
      absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
      PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) override;

  StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
      void* device_ptr, const Shape& shape, PjRtDevice* device,
      std::function<void()> on_delete_callback) override;

  StatusOr<ChannelHandle> CreateChannelHandle() override {
    return client()->CreateChannelHandle();
  }
  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
    return client()->CreateDeviceToHostChannelHandle();
  }
  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
    return client()->CreateHostToDeviceChannelHandle();
  }

  // TODO(zhangqiaorjc): Experimental. Will be removed.
  Status Defragment() override {
    return Unimplemented("Defragment not implemented");
  }

  LocalDeviceState& device_state(int device_ordinal) const {
    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
                addressable_devices_.at(device_ordinal))
                ->local_device_state();
  }
  LocalClient* client() const { return client_; }
  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
  tensorflow::Allocator* host_memory_allocator() const {
    return host_memory_allocator_.get();
  }
  bool should_stage_host_to_device_transfers() const {
    return should_stage_host_to_device_transfers_;
  }

  gpu::GpuExecutableRunOptions* gpu_run_options() const {
    return gpu_run_options_.get();
  }

  tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }

 protected:
  friend class PjRtStreamExecutorBuffer;

  virtual void EnqueueCrossHostReceive(
      std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtCrossHostRecvNotifier&& notifier,
      absl::optional<std::vector<GatherDetails>> gather_details) const {
    notifier(Unimplemented("Cross host receives not implemented."));
  }

  virtual Status CopyToRemoteDevice(
      PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
    return Unimplemented("Cross host sends not implemented.");
  }

  virtual Status CopyToRemoteDeviceScattered(
      PjRtBuffer* buffer, absl::Span<const std::string> serialized_descriptors,
      const PjRtBuffer::ScatterDetails& scatter_details) const {
    return Unimplemented("Scattered cross host sends not implemented.");
  }

  virtual Status CopyRawSubBufferToHost(PjRtBuffer* buffer, void* dst,
                                        int64_t offset, int64_t transfer_size,
                                        std::function<void(Status)> on_ready) {
    return Unimplemented("Raw copies to host not implemented.");
  }

  // Helper function for creating PjRtStreamExecutorExecutables. Modifies
  // `options` in-place.
  struct ExecutableExtras {
    std::shared_ptr<DeviceAssignment> device_assignment;
    std::vector<PjRtExecutable::LogicalDeviceIds>
        addressable_device_logical_ids;
    std::vector<PjRtDevice*> addressable_devices;
  };
  StatusOr<ExecutableExtras> GetExecutableExtras(CompileOptions* options);

  const PjRtPlatformId platform_id_;
  const std::string platform_name_;
  LocalClient* client_;

  // Allocator to be used for staging memory transfers to devices.
  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;

  // Device memory allocator. If owned, the allocator must outlive the devices,
  // because it is the device destructor that waits for any outstanding work to
  // complete.
  se::DeviceMemoryAllocator* allocator_;
  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;

  // Includes all devices, including non-local devices on multi-host platforms.
  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
  // Pointers to `owned_devices_`.
  std::vector<PjRtDevice*> devices_;
  // Maps Device::id() to the corresponding Device. Includes all devices.
  std::map<int, PjRtDevice*> id_to_device_;
  // Local devices indexed by local device ordinal.
  std::vector<PjRtDevice*> addressable_devices_;
  int process_index_;

  // Should we always prefer to stage host-to-device transfers via memory
  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
  // transfer via pinned memory.
  bool should_stage_host_to_device_transfers_;

  std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;

  tensorflow::thread::ThreadPool thread_pool_;

  absl::Mutex transpose_mu_;
  TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_);
};
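
// A minimal usage sketch of the client API declared above. `client` is
// assumed to come from a platform-specific factory, and `computation` and
// `literal` to be built elsewhere; errors are propagated with
// TF_ASSIGN_OR_RETURN:
//
//   PjRtDevice* device = client->addressable_devices()[0];
//   TF_ASSIGN_OR_RETURN(auto executable,
//                       client->Compile(computation, CompileOptions()));
//   TF_ASSIGN_OR_RETURN(auto input,
//                       client->BufferFromHostLiteral(literal, device));
//   TF_ASSIGN_OR_RETURN(
//       auto outputs, executable->Execute({{input.get()}}, ExecuteOptions()));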

// Converts a 2D set of Device objects indexed by [replica][partition] into an
// xla::DeviceAssignment.
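//
// A small sketch, assuming `replica0` and `replica1` are PjRtDevice* values
// obtained elsewhere (two replicas, one partition each):
//
//   std::vector<std::vector<PjRtDevice*>> devices = {{replica0}, {replica1}};
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       DevicesToDeviceAssignment(devices));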
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
    absl::Span<const std::vector<PjRtDevice*>> devices);

class PjRtStreamExecutorBuffer : public PjRtBuffer {
 public:
  // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A ScopedHold
  // may not outlive its parent PjRtStreamExecutorBuffer.
  //
  // There are three types of hold, as follows:
  //
  // 1) Usage hold: a transient hold while an operation using the buffer is
  //    being enqueued onto a stream.
  // A client acquires a usage hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
  // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully the
  // hold should be released using a call to ConvertUsageHold. If the ScopedHold
  // is deleted without ConvertUsageHold being called, e.g., on error, the hold
  // is dropped. It is legal to drop a usage hold instead of calling
  // ConvertUsageHold, even if the buffer was successfully enqueued, as long as
  // the client ensures that all necessary synchronization has been done.
  //
  // 2) External hold: a potentially long-lived hold while the buffer is being
  //    shared by an external framework, e.g., NumPy.
  // A client acquires an external hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience
  // wrapper GetBufferWithExternalReference and releases it by deleting the
  // ScopedHold. The external framework should not modify the underlying buffer
  // unless it is confident via its own synchronization that modifications do
  // not race with reads from the PjRtStreamExecutorBuffer.
  //
  // 3) Donation hold: a transient hold while an execution that donates the
  //    buffer is being enqueued onto the compute stream.
  // A client acquires a donation hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
  // completes successfully the hold should be released using a call to
  // ConfirmDonation after which the buffer is invalid. If the ScopedHold is
  // deleted without ConfirmDonation being called, e.g., on error, the hold is
  // dropped and the buffer remains valid. If the buffer is successfully
  // enqueued the client *must* call ConfirmDonation.
  //
  // Donation holds behave like exclusive write locks: when a donation hold
  // has been acquired, any attempt to acquire another hold of any type will
  // block until the donation hold is dropped or confirmed. Acquiring a donation
  // hold will fail with an error if there is any outstanding external hold, and
  // will block if there are any outstanding usage holds until those holds are
  // dropped or converted.
  //
  // Calls to PjRtStreamExecutorBuffer::Release (and transitively to
  // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
  // block until all usage and donation holds are either deleted or
  // converted/confirmed.
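  //
  // A sketch of the usage-hold protocol described above; `stream` and `event`
  // are assumed to come from the caller's enqueue logic:
  //
  //   PjRtStreamExecutorBuffer::ScopedHold hold =
  //       buffer->GetBufferWithUsageHold();
  //   if (hold.ok()) {
  //     // ... enqueue an operation reading hold.buffer() on `stream` ...
  //     hold.ConvertUsageHold(stream, event, /*reference_held=*/true);
  //   }
  //   // If the enqueue failed, letting `hold` go out of scope simply drops
  //   // the hold. A donation hold follows the same shape, with
  //   // GetBufferWithHold(kDonation) and ConfirmDonation() instead.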
  class ScopedHold {
   public:
    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
    // Use a State enum instead of encoding the state in an error Status to
    // avoid creating Status values in non-error cases. Creating a Status
    // entails several allocations and can add O(us) to every use of a hold.
    enum State {
      kUninitialized = 0,
      kValid,
      kMoved,
      kConverted,
      kReleased,
      kDonated,
      kError
    };

    ~ScopedHold();
    ScopedHold(ScopedHold&& other);
    ScopedHold(const ScopedHold&) = delete;
    ScopedHold& operator=(const ScopedHold&) = delete;

    Type type() const { return type_; }

    Status status() const {
      // Lazily create Status values only when they are requested.
      switch (state_) {
        case kUninitialized:
          return InvalidArgument("Buffer has not been initialized");
        case kValid:
          return Status::OK();
        case kMoved:
          return InvalidArgument("Buffer has been moved.");
        case kConverted:
          return InvalidArgument("Buffer has been converted");
        case kReleased:
          return InvalidArgument("Buffer has been released");
        case kDonated:
          return InvalidArgument("Buffer has been donated");
        case kError:
          return status_;
        default:
          CHECK(false) << "Unexpected state value " << state_;
      }
    }
    bool ok() const { return state_ == kValid; }

    // Access to the underlying device buffer storage. Requires this->ok().
    const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
      CHECK_EQ(state_, kValid);
      CHECK_NE(buffer_, nullptr);
      return buffer_;
    }
    TrackedDeviceBuffer* operator->() const { return buffer().get(); }
    const TrackedDeviceBuffer& operator*() const { return *buffer(); }

    // Converts the hold into a usage event. Only valid for holds of type
    // kUsage.
    //
    //   usage_stream:   the stream that the buffer was used on.
    //   event:          an event that has been recorded on usage_stream after
    //                   the buffer was used.
    //   reference_held: true if and only if the caller has caused a
    //                   reference to this->buffer() to stay live until after
    //                   the host is sure that the usage (transfer or execution)
    //                   has completed.
    void ConvertUsageHold(se::Stream* usage_stream,
                          std::shared_ptr<BufferSequencingEvent> event,
                          bool reference_held);

    // Confirms that the buffer was successfully donated to an execution.
    // Only valid for holds of type kDonation. Causes the buffer to become
    // invalid.
    void ConfirmDonation();

    // Adds the held device buffers in order to 'iterator'. Used to add the
    // buffers to an ExecutionInput. We require but do not verify that
    // 'iterator' when passed in is pointing to a sub-tuple of the
    // ExecutionInput whose on_device_shape matches that of the
    // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
    // out of bounds. Donates the device buffers if the hold type is kDonation,
    // otherwise retains ownership of the device buffers.
    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
                    ExecutionInput* execution_input,
                    se::DeviceMemoryAllocator* allocator) const;

   private:
    friend class PjRtStreamExecutorBuffer;
    friend class PjRtStreamExecutorClient;

    // Helper struct that makes it possible to move a ScopedHold through a
    // closure.
    using ForClosure = std::tuple<PjRtStreamExecutorBuffer*, Type, State,
                                  Status, std::shared_ptr<TrackedDeviceBuffer>>;

    ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
        : parent_(parent), type_(type), state_(kUninitialized) {}
    explicit ScopedHold(const ForClosure& closure_helper)
        : parent_(std::get<0>(closure_helper)),
          type_(std::get<1>(closure_helper)),
          state_(std::get<2>(closure_helper)),
          status_(std::get<3>(closure_helper)),
          buffer_(std::get<4>(closure_helper)) {
      // Check the buffer is not in an error state.
      CHECK(status_.ok() && buffer_ != nullptr);
    }

    // Sets buffer state.
    void SetState(State state) { state_ = state; }

    // Sets buffer_ and status_. Called by parent_ to initialize the hold.
    void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
    // Releases the contents of *this, so *this can subsequently be
    // deleted without releasing the parent's hold. Should be passed to the
    // appropriate constructor of another ScopedHold, e.g., when a hold must be
    // passed through a closure that is incompatible with std::move.
    ForClosure ToClosure();

    PjRtStreamExecutorBuffer* const parent_;
    const Type type_;

    // There is an invariant that if ok() then buffer_ != nullptr.
    State state_;
    Status status_;
    std::shared_ptr<TrackedDeviceBuffer> buffer_;
  };

  PjRtStreamExecutorBuffer(Shape on_device_shape,
                           std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                           PjRtClient* client, PjRtDevice* device);
  ~PjRtStreamExecutorBuffer() override;

  PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;

  const Shape& on_device_shape() const override { return on_device_shape_; }
  StatusOr<Shape> logical_on_device_shape() override;
  PjRtStreamExecutorDevice* device() const override { return device_; }
  PjRtPlatformId platform_id() const { return client_->platform_id(); }
  absl::string_view platform_name() const { return client_->platform_name(); }
  PjRtStreamExecutorClient* client() const override { return client_; }
  bool IsEmptyTuple() const {
    return on_device_shape_.IsTuple() &&
           on_device_shape_.tuple_shapes_size() == 0;
  }

  StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
      override;

  StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
      bool wait_for_operations_to_complete) override;

  using PjRtBuffer::ToLiteral;
  void ToLiteral(MutableLiteralBase* literal,
                 std::function<void(Status)> on_ready) override;

  StatusOr<size_t> GetOnDeviceSizeInBytes() const override;

  Status CopyRawToHost(void* dst, int64_t offset, int64_t transfer_size,
                       std::function<void(Status)> on_ready) override;

  // Drops the buffer's reference to its associated device memory, leaving the
  // buffer in an invalid state. The memory will be freed lazily when all async
  // operations using the buffer have completed, according to the allocation
  // semantics of the underlying platform. Delete may briefly block if another
  // thread is in the process of enqueuing an operation on this buffer, but it
  // will never block for a stream operation to complete. If an external
  // framework holds a reference to the TrackedDeviceBuffer via
  // GetBufferWithExternalReference, the memory will not be freed until the
  // external framework drops the reference.
  void Delete() override;

  bool IsDeleted() override;

  // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
  // PjRtBuffer retains ownership of the device buffers.
  StatusOr<ShapedBuffer> AsShapedBuffer() const;

  // Returns a hold on the TrackedDeviceBuffer holding the device
  // buffers. See comment on ScopedHold.
  ScopedHold GetBufferWithHold(ScopedHold::Type type);
  ScopedHold GetBufferWithUsageHold() {
    return GetBufferWithHold(ScopedHold::kUsage);
  }
  ScopedHold GetBufferWithExternalReference() {
    return GetBufferWithHold(ScopedHold::kExternalReference);
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
      PjRtDevice* dst_device) override;

  Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;

  Status CopyToRemoteDeviceScattered(
      absl::Span<const std::string> serialized_descriptors,
      const ScatterDetails& scatter_details) override;

  Status BlockHostUntilReady() override;

  bool IsOnCpu() const override;

  // Similar to Delete, drops the buffer's reference to its associated device
  // memory, leaving the buffer in an invalid state, but returns the
  // TrackedDeviceBuffer rather than freeing the device memory, so that another
  // framework can take ownership of it. The buffer returned from Release may
  // be safely dropped at any time even if it still has pending async
  // operations. The client should call BlockHostUntilReady before calling
  // Release with wait_for_operations_to_complete=false, to ensure that the
  // host has synchronized past any outstanding write operations to the buffer.
  // If wait_for_operations_to_complete=true the host will block until any
  // potentially outstanding asynchronous operations have completed before
  // returning, in which case it is safe to read or mutate the returned buffer.
  // If the buffer was shared via an external reference it is the client's
  // responsibility that accesses via that reference do not interfere with
  // accesses via the buffer returned from Release.
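  //
  // A minimal sketch of handing device memory off to another framework once
  // all pending work is known to be complete:
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::shared_ptr<TrackedDeviceBuffer> tracked,
  //       buffer->Release(/*wait_for_operations_to_complete=*/true));
  //   // `tracked` now owns the device memory; `buffer` is invalid.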
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
      bool wait_for_operations_to_complete);

 private:
  friend class PjRtClient;

  // Blocks in mu_.Await until there are no more usage holds.
  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Blocks in mu_.Await until there is no donation hold.
  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
  // device_buffer_ is null, or if a donation hold was requested when there is
  // an outstanding external hold.
  // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHold()
  // must be called first.)
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
  // Initializes hold with an error if device_buffer_ is null, or if a donation
  // hold was requested when there is an outstanding external hold.
  // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHold()
  // must be called first.)
  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
  // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
  // device_buffer_ was successfully enqueued on a stream.
  void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
                        std::shared_ptr<BufferSequencingEvent> event,
                        bool reference_held);

  // Drops a donation hold and makes *this invalid for further use. Does a
  // sanity check that buffer==device_buffer_. Called after device_buffer_ was
  // successfully donated to an execution.
  void ConfirmDonation(TrackedDeviceBuffer* device_buffer);

  // Drops a hold without taking any other action. Does a sanity check that
  // buffer==device_buffer_ or device_buffer_==nullptr.
  void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);

  StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
                     std::shared_ptr<BufferSequencingEvent>>>
  CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
                     LocalDeviceState* transfer_local_device,
                     se::Stream* transfer_stream,
                     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);

  PjRtStreamExecutorClient* const client_;
  const Shape on_device_shape_;
  PjRtStreamExecutorDevice* const device_;

  mutable absl::Mutex mu_;
  std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
  // Count of holds on the buffer.
  std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
};

// Wraps one or more XLA LocalExecutables (one per partition, as specified by
// the build options).
class PjRtStreamExecutorExecutable : public PjRtExecutable {
 public:
  PjRtStreamExecutorExecutable(
      std::vector<std::unique_ptr<LocalExecutable>> executables,
      bool parameter_is_tupled_arguments,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<LogicalDeviceIds> addressable_device_logical_ids,
      std::vector<PjRtDevice*> addressable_devices,
      PjRtStreamExecutorClient* client);

  ~PjRtStreamExecutorExecutable() override = default;

  PjRtStreamExecutorClient* client() const override { return client_; }

  absl::string_view name() const override;

  int num_replicas() const override {
    return executables_[0]->build_options().num_replicas();
  }

  int num_partitions() const override {
    return executables_[0]->build_options().num_partitions();
  }

  int64_t SizeOfGeneratedCodeInBytes() const override {
    int64_t size = 0;
    for (auto& executable : executables_) {
      size += executable->executable()->SizeOfGeneratedCodeInBytes();
    }
    return size;
  }

  const DeviceAssignment& device_assignment() const override {
    return *device_assignment_;
  }

  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const override {
    return addressable_device_logical_ids_;
  }

  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  // Return an HloModule per partition.
  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const override;

  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
      const ExecuteOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) override;

  void Delete() override { executables_.clear(); }

  bool IsDeleted() override { return executables_.empty(); }

  absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
    return executables_;
  }

 protected:
  bool parameter_is_tupled_arguments() const {
    return parameter_is_tupled_arguments_;
  }

 private:
  friend class PjRtStreamExecutorClient;
  friend class PjRtTpuClient;
  friend class InternalPjRtTpuClient;
  // Initializes information about which arguments to which executables must be
  // donated due to aliases that were specified by the computation.
  Status SetUpDonation(bool tuple_inputs);

  // Returns a sorted list of the parameters that must be donated. Derived
  // classes may use custom logic.
  virtual absl::Span<int const> ParametersThatMustBeDonated(
      int executable_idx) const;

  virtual StatusOr<std::vector<ExecutionInput>>
  MakeExecutionInputsAndWaitForEvents(
      int device_ordinal, const ExecuteOptions& options,
      absl::Span<const Shape> executable_parameter_shapes,
      absl::Span<PjRtBuffer* const> argument_handles,
      absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
      absl::flat_hash_set<BufferSequencingEvent*>& events) const;

  StatusOr<ScopedShapedBuffer> EnqueueExecution(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, int executable_idx, const RunId& run_id,
      const ExecuteOptions& options, PjRtDevice* device,
      std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<std::function<void()>>& compute_callbacks) const;

  virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
      int device_ordinal, const ExecuteOptions& options,
      ScopedShapedBuffer result_buffer,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtDevice* device, std::vector<std::function<void()>>& compute_callbacks,
      std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
      const;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, const RunId& run_id, const ExecuteOptions& options,
      PjRtDevice* device = nullptr) const;

  // Create shared pointers so we can free them after the execution: with
  // asynchronous execution, the process being executed can outlive the
  // executable itself.
  PjRtStreamExecutorClient* const client_;
  // One executable per partition.
  std::vector<std::shared_ptr<LocalExecutable>> executables_;
  // On device shapes of the executable parameters.
  std::vector<std::vector<Shape>> on_device_executable_parameter_shapes_;
  // Per-executable sorted vector of parameters that have any aliased buffers
  // and thus must be donated when executing the computation.
  std::vector<std::vector<int>> parameters_that_must_be_donated_;
  std::shared_ptr<DeviceAssignment> device_assignment_;

  // True if the executables were compiled expecting arguments in a single
  // tuple.
  const bool parameter_is_tupled_arguments_;

  // The replica and partition indices of device_assignment_ to be run by this
  // client. On single-host platforms without partitioning, this is all
  // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may
  // not be the case on multi-host platforms. If there are 4 replicas and 2
  // partitions on a single host platform, size of
  // addressable_device_logical_ids_ is 4*2 = 8.
  std::vector<LogicalDeviceIds> addressable_device_logical_ids_;

  // addressable_devices_[i] is the Device to which
  // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
  // unique_ptrs to play well with the Python bindings (see xla.cc).
  std::vector<PjRtDevice*> addressable_devices_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_