/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_

#include <array>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class PjRtStreamExecutorDevice : public PjRtDevice {
 public:
  explicit PjRtStreamExecutorDevice(
      int id, std::unique_ptr<LocalDeviceState> local_device_state,
      std::string device_kind, int task_id = 0)
      : id_(id),
        device_ordinal_(
            local_device_state ? local_device_state->device_ordinal() : -1),
        local_device_state_(std::move(local_device_state)),
        task_id_(task_id),
        device_kind_(std::move(device_kind)) {}
  ~PjRtStreamExecutorDevice() override {}

  // Must set client exactly once.
  void SetClient(PjRtClient* client) {
    CHECK(client_ == nullptr);
    client_ = client;
  }

  int task_id() const override { return task_id_; }

  // Return `platform_id` from client.
  PjRtPlatformId platform_id() const;

  // Return `platform_name` from client.
  absl::string_view platform_name() const;

  PjRtClient* client() const override { return client_; }

  int id() const override { return id_; }

  bool IsAddressable() const override { return device_ordinal_ != -1; }

  int local_hardware_id() const override { return device_ordinal_; }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns nullptr if the device
  // is not local to this host.
  LocalDeviceState* local_device_state() const {
    return local_device_state_.get();
  }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns an error if the device
  // is not local to this host.
  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;

  absl::string_view device_kind() const override { return device_kind_; }

  std::string DebugString() const override;

  Status TransferToInfeed(const LiteralSlice& literal) override;

  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;

 private:
  const int id_;
  const int device_ordinal_;  // -1 means not local.
  const std::unique_ptr<LocalDeviceState> local_device_state_;
  const int task_id_;
  const std::string device_kind_;
  PjRtClient* client_ = nullptr;
};

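// Example (illustrative sketch, not part of the API; `device` is assumed to
// be a PjRtStreamExecutorDevice* obtained from a client): callers typically
// check addressability before touching device-local state.
//
//   if (device->IsAddressable()) {
//     TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
//                         device->GetLocalDeviceState());
//     // ... enqueue work on local_device->compute_stream() ...
//   }
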
class PjRtStreamExecutorClient : public PjRtClient {
 public:
  // `allocator` may be null, in which case the platform default allocator is
  // used.
  explicit PjRtStreamExecutorClient(
      std::string platform_name, LocalClient* client,
      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
      int task_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
      bool should_stage_host_to_device_transfers,
      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
  ~PjRtStreamExecutorClient() override = default;

  int task_id() const override { return task_id_; }

  int device_count() const override { return devices_.size(); }
  int addressable_device_count() const override {
    return addressable_devices_.size();
  }
  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
    auto it = id_to_device_.find(device_id);
    if (it != id_to_device_.end()) {
      return it->second;
    }
    return InvalidArgument("No matching device found for device_id %d",
                           device_id);
  }

  StatusOr<PjRtDevice*> LookupAddressableDevice(
      int local_hardware_id) const override;

  PjRtPlatformId platform_id() const override { return platform_id_; }
  absl::string_view platform_name() const override { return platform_name_; }

  // Most platforms expect device-to-device transfers to be enqueued on the
  // source d2d stream, but some platforms use the destination d2d stream.
  // This function specifies which one the platform expects.
  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }

  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;

  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options) override;

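  // Example (hypothetical caller-side sketch; `client` and `computation` are
  // assumed to exist elsewhere): compiling an XlaComputation into an
  // executable.
  //
  //   CompileOptions options;
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
  //                       client->Compile(computation, options));
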
  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
      const PjRtExecutable& executable) const override {
    return absl::optional<std::string>();
  }

  StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override;

  // Creates a buffer on the device without initializing or copying any data.
  // An optional `definition_event` may be specified that can be used to
  // ensure the buffer isn't referenced until some external mechanism has
  // initialized the data.
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device) override;
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device,
      std::shared_ptr<BufferSequencingEvent> definition_event);

  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
      const void* data, const Shape& shape,
      HostBufferSemantics host_buffer_semantics,
      std::function<void()> on_done_with_host_buffer,
      PjRtDevice* device) override;

  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
      const LiteralSlice& literal, PjRtDevice* device) override;

  void MakeCrossHostReceiveBuffers(
      absl::Span<const Shape> shapes, PjRtDevice* device,
      PjRtCrossHostRecvNotifier&& notifier) override;

  StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
      void* device_ptr, const Shape& shape, PjRtDevice* device,
      std::function<void()> on_delete_callback) override;

  StatusOr<ChannelHandle> CreateChannelHandle() override {
    return client()->CreateChannelHandle();
  }
  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
    return client()->CreateDeviceToHostChannelHandle();
  }
  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
    return client()->CreateHostToDeviceChannelHandle();
  }

  LocalDeviceState& device_state(int device_ordinal) const {
    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
                addressable_devices_.at(device_ordinal))
                ->local_device_state();
  }
  LocalClient* client() const { return client_; }
  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
  tensorflow::Allocator* host_memory_allocator() const {
    return host_memory_allocator_.get();
  }
  bool should_stage_host_to_device_transfers() const {
    return should_stage_host_to_device_transfers_;
  }

  gpu::GpuExecutableRunOptions* gpu_run_options() const {
    return gpu_run_options_.get();
  }

  tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }

 protected:
  friend class PjRtStreamExecutorBuffer;

  virtual void EnqueueCrossHostReceive(
      std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtCrossHostRecvNotifier&& notifier) const {
    notifier(Unimplemented("Cross host receives not implemented."));
  }

  virtual Status CopyToRemoteDevice(
      PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
    return Unimplemented("Cross host sends not implemented.");
  }

  const PjRtPlatformId platform_id_;
  const std::string platform_name_;
  LocalClient* client_;

  // Allocator to be used for staging memory transfers to devices.
  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;

  // Includes all devices, including non-local devices on multi-host
  // platforms.
  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
  // Pointers to `owned_devices_`.
  std::vector<PjRtDevice*> devices_;
  // Maps Device::id() to the corresponding Device. Includes all devices.
  std::map<int, PjRtDevice*> id_to_device_;
  // Local devices indexed by local device ordinal.
  std::vector<PjRtDevice*> addressable_devices_;
  int task_id_;

  se::DeviceMemoryAllocator* allocator_;
  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;

  // Should we always prefer to stage host-to-device transfers via memory
  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
  // transfer via pinned memory.
  bool should_stage_host_to_device_transfers_;

  std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;

  tensorflow::thread::ThreadPool thread_pool_;
};

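// Example (hypothetical sketch; assumes a constructed client): staging a host
// literal onto the first addressable device and waiting for the transfer.
//
//   Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   PjRtDevice* device = client->addressable_devices()[0];
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> buffer,
//                       client->BufferFromHostLiteral(literal, device));
//   TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
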
// Converts a 2D set of Device objects indexed by [replica][partition] into an
// xla::DeviceAssignment.
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
    absl::Span<const std::vector<PjRtDevice*>> devices);

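// Example (illustrative): a 2-replica, 1-partition assignment built from the
// client's first two addressable devices.
//
//   std::vector<std::vector<PjRtDevice*>> devices = {
//       {client->addressable_devices()[0]},  // replica 0
//       {client->addressable_devices()[1]},  // replica 1
//   };
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       DevicesToDeviceAssignment(devices));
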
class PjRtStreamExecutorBuffer : public PjRtBuffer {
 public:
  // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A
  // ScopedHold may not outlive its parent PjRtStreamExecutorBuffer.
  //
  // There are three types of hold, as follows:
  //
  // 1) Usage hold: a transient hold while an operation using the buffer is
  //    being enqueued onto a stream.
  // A client acquires a usage hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
  // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully
  // the hold should be released using a call to ConvertUsageHold. If the
  // ScopedHold is deleted without ConvertUsageHold being called, e.g., on
  // error, the hold is dropped. It is legal to drop a usage hold instead of
  // calling ConvertUsageHold, even if the buffer was successfully enqueued,
  // as long as the client ensures that all necessary synchronization has
  // been done.
  //
  // 2) External hold: a potentially long-lived hold while the buffer is
  //    being shared by an external framework, e.g., NumPy.
  // A client acquires an external hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternalReference) or the
  // convenience wrapper GetBufferWithExternalReference and releases it by
  // deleting the ScopedHold. The external framework should not modify the
  // underlying buffer unless it is confident via its own synchronization
  // that modifications do not race with reads from the
  // PjRtStreamExecutorBuffer.
  //
  // 3) Donation hold: a transient hold while an execution that donates the
  //    buffer is being enqueued onto the compute stream.
  // A client acquires a donation hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
  // completes successfully the hold should be released using a call to
  // ConfirmDonation, after which the buffer is invalid. If the ScopedHold is
  // deleted without ConfirmDonation being called, e.g., on error, the hold
  // is dropped and the buffer remains valid. If the buffer is successfully
  // enqueued the client *must* call ConfirmDonation.
  //
  // Donation holds behave like exclusive write locks: when a donation hold
  // has been acquired, any attempt to acquire another hold of any type will
  // block until the donation hold is dropped or confirmed. Acquiring a
  // donation hold will fail with an error if there is any outstanding
  // external hold, and will block if there are any outstanding usage holds
  // until those holds are dropped or converted.
  //
  // Calls to PjRtStreamExecutorBuffer::Release (and transitively to
  // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
  // block until all usage and donation holds are either deleted or
  // converted/confirmed.
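  //
  // A sketch of the usage-hold lifecycle described above (`buffer`,
  // `usage_stream`, and `usage_event` are illustrative names, not part of
  // this API):
  //
  //   PjRtStreamExecutorBuffer::ScopedHold hold =
  //       buffer->GetBufferWithUsageHold();
  //   if (hold.ok()) {
  //     // ... enqueue a read of the held buffer on `usage_stream` ...
  //     hold.ConvertUsageHold(usage_stream, usage_event,
  //                           /*reference_held=*/false);
  //   }
  //   // If `hold` is destroyed without ConvertUsageHold, e.g., because the
  //   // enqueue failed, the hold is simply dropped and the buffer is
  //   // unaffected.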
  class ScopedHold {
   public:
    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
    // Use a State enum instead of encoding the state in an error Status to
    // avoid creating Status values in non-error cases. Creating a Status
    // entails several allocations and can add O(us) to every use of a hold.
    enum State {
      kUninitialized = 0,
      kValid,
      kMoved,
      kConverted,
      kReleased,
      kDonated,
      kError
    };

    ~ScopedHold();
    ScopedHold(ScopedHold&& other);
    ScopedHold(const ScopedHold&) = delete;
    ScopedHold& operator=(const ScopedHold&) = delete;

    Type type() const { return type_; }

    Status status() const {
      // Lazily create Status values only when they are requested.
      switch (state_) {
        case kUninitialized:
          return InvalidArgument("Buffer has not been initialized");
        case kValid:
          return Status::OK();
        case kMoved:
          return InvalidArgument("Buffer has been moved.");
        case kConverted:
          return InvalidArgument("Buffer has been converted");
        case kReleased:
          return InvalidArgument("Buffer has been released");
        case kDonated:
          return InvalidArgument("Buffer has been donated");
        case kError:
          return status_;
        default:
          CHECK(false) << "Unexpected state value " << state_;
      }
    }
    bool ok() const { return state_ == kValid; }

    // Access to the underlying device buffer storage. Requires this->ok().
    const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
      CHECK_EQ(state_, kValid);
      CHECK_NE(buffer_, nullptr);
      return buffer_;
    }
    TrackedDeviceBuffer* operator->() const { return buffer().get(); }
    const TrackedDeviceBuffer& operator*() const { return *buffer(); }

    // Converts the hold into a usage event. Only valid for holds of type
    // kUsage.
    //
    //   usage_stream:   the stream that the buffer was used on.
    //   event:          an event that has been recorded on usage_stream
    //                   after the buffer was used.
    //   reference_held: true if and only if the caller has caused a
    //                   reference to this->buffer() to stay live until after
    //                   the host is sure that the usage (transfer or
    //                   execution) has completed.
    void ConvertUsageHold(se::Stream* usage_stream,
                          std::shared_ptr<BufferSequencingEvent> event,
                          bool reference_held);

    // Confirms that the buffer was successfully donated to an execution.
    // Only valid for holds of type kDonation. Causes the buffer to become
    // invalid.
    void ConfirmDonation();

    // Adds the held device buffers, in order, to 'iterator'. Used to add the
    // buffers to an ExecutionInput. We require, but do not verify, that
    // 'iterator' when passed in is pointing to a sub-tuple of the
    // ExecutionInput whose on_device_shape matches that of the
    // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't
    // run out of bounds. Donates the device buffers if the hold type is
    // kDonation, otherwise retains ownership of the device buffers.
    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
                    ExecutionInput* execution_input,
                    se::DeviceMemoryAllocator* allocator) const;

   private:
    friend class PjRtStreamExecutorBuffer;
    friend class PjRtStreamExecutorClient;

    // Helper struct that makes it possible to move a ScopedHold through a
    // closure.
    using ForClosure =
        std::tuple<PjRtStreamExecutorBuffer*, Type, State, Status,
                   std::shared_ptr<TrackedDeviceBuffer>>;

    ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
        : parent_(parent), type_(type), state_(kUninitialized) {}
    explicit ScopedHold(const ForClosure& closure_helper)
        : parent_(std::get<0>(closure_helper)),
          type_(std::get<1>(closure_helper)),
          state_(std::get<2>(closure_helper)),
          status_(std::get<3>(closure_helper)),
          buffer_(std::get<4>(closure_helper)) {
      // Check the buffer is not in an error state.
      CHECK(status_.ok() && buffer_ != nullptr);
    }

    // Sets buffer state.
    void SetState(State state) { state_ = state; }

    // Sets buffer_ and status_. Called by parent_ to initialize the hold.
    void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
    // Releases the contents of *this, so *this can subsequently be
    // deleted without releasing the parent's hold. Should be passed to the
    // appropriate constructor of another ScopedHold, e.g., when a hold must
    // be passed through a closure that is incompatible with std::move.
    ForClosure ToClosure();

    PjRtStreamExecutorBuffer* const parent_;
    const Type type_;

    // There is an invariant that if ok() then buffer_ != nullptr.
    State state_;
    Status status_;
    std::shared_ptr<TrackedDeviceBuffer> buffer_;
  };

  PjRtStreamExecutorBuffer(Shape on_device_shape,
                           std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                           PjRtClient* client, PjRtDevice* device);
  ~PjRtStreamExecutorBuffer() override;

  PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) =
      delete;
  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;

  const Shape& on_device_shape() const override { return on_device_shape_; }
  PjRtStreamExecutorDevice* device() const override { return device_; }
  PjRtPlatformId platform_id() const { return client_->platform_id(); }
  absl::string_view platform_name() const { return client_->platform_name(); }
  PjRtStreamExecutorClient* client() const override { return client_; }
  bool IsEmptyTuple() const {
    return on_device_shape_.IsTuple() &&
           on_device_shape_.tuple_shapes_size() == 0;
  }

  int64 OnDeviceSizeInBytes() const override;

  StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
      override;

  StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
      bool wait_for_operations_to_complete) override;

  using PjRtBuffer::ToLiteral;
  void ToLiteral(MutableLiteralBase* literal,
                 std::function<void(Status)> on_ready) override;

  // Drops the buffer's reference to its associated device memory, leaving
  // the buffer in an invalid state. The memory will be freed lazily when all
  // async operations using the buffer have completed, according to the
  // allocation semantics of the underlying platform. Delete may briefly
  // block if another thread is in the process of enqueuing an operation on
  // this buffer, but it will never block for a stream operation to complete.
  // If an external framework holds a reference to the TrackedDeviceBuffer
  // via GetBufferWithExternalReference, the memory will not be freed until
  // the external framework drops the reference.
  void Delete() override;

  bool IsDeleted() override;

  // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
  // PjRtBuffer retains ownership of the device buffers.
  StatusOr<ShapedBuffer> AsShapedBuffer() const;

  // Returns a hold on the TrackedDeviceBuffer holding the device
  // buffers. See comment on ScopedHold.
  ScopedHold GetBufferWithHold(ScopedHold::Type type);
  ScopedHold GetBufferWithUsageHold() {
    return GetBufferWithHold(ScopedHold::kUsage);
  }
  ScopedHold GetBufferWithExternalReference() {
    return GetBufferWithHold(ScopedHold::kExternalReference);
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
      PjRtDevice* dst_device) override;

  Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;

  Status BlockHostUntilReady() override;

  bool IsOnCpu() const override;

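  // Example (hypothetical): asynchronously reading the buffer back to the
  // host. The literal must outlive the callback, hence the shared_ptr; the
  // host-shape conversion shown is an assumption about typical usage.
  //
  //   auto literal = std::make_shared<Literal>(
  //       ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape()));
  //   buffer->ToLiteral(literal.get(), [literal](Status status) {
  //     if (status.ok()) {
  //       // ... consume *literal ...
  //     }
  //   });
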
  // Similar to Delete, drops the buffer's reference to its associated device
  // memory, leaving the buffer in an invalid state, but returns the
  // TrackedDeviceBuffer rather than freeing the device memory, so that
  // another framework can take ownership of it. The buffer returned from
  // Release may be safely dropped at any time even if it still has pending
  // async operations. The client should call BlockHostUntilReady before
  // calling Release with wait_for_operations_to_complete=false, to ensure
  // that the host has synchronized past any outstanding write operations to
  // the buffer. If wait_for_operations_to_complete=true the host will block
  // until any potentially outstanding asynchronous operations have completed
  // before returning, in which case it is safe to read or mutate the
  // returned buffer. If the buffer was shared via an external reference it
  // is the client's responsibility that accesses via that reference do not
  // interfere with accesses via the buffer returned from Release.
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
      bool wait_for_operations_to_complete);

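  // Example (hypothetical): handing the device memory to another framework.
  // Passing true below blocks until outstanding operations complete, so no
  // prior BlockHostUntilReady call is needed.
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::shared_ptr<TrackedDeviceBuffer> device_buffer,
  //       buffer->Release(/*wait_for_operations_to_complete=*/true));
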
 private:
  friend class PjRtClient;

  // Blocks in mu_.Await until there are no more usage holds.
  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Blocks in mu_.Await until there is no donation hold.
  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
  // device_buffer_ is null, or if a donation hold was requested when there
  // is an outstanding external hold.
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
  // Initializes hold with an error if device_buffer_ is null, or if a
  // donation hold was requested when there is an outstanding external hold.
  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a
  // sanity check that buffer==device_buffer_ or device_buffer_==nullptr.
  // Called after device_buffer_ was successfully enqueued on a stream.
  void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
                        std::shared_ptr<BufferSequencingEvent> event,
                        bool reference_held);

  // Drops a donation hold and makes *this invalid for further use. Does a
  // sanity check that buffer==device_buffer_. Called after device_buffer_
  // was successfully donated to an execution.
  void ConfirmDonation(TrackedDeviceBuffer* device_buffer);

  // Drops a hold without taking any other action. Does a sanity check that
  // buffer==device_buffer_ or device_buffer_==nullptr.
  void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);

  StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
                     std::shared_ptr<BufferSequencingEvent>>>
  CopyToDeviceHelper(PjRtDevice* dst_device,
                     LocalDeviceState* dst_local_device,
                     LocalDeviceState* transfer_local_device,
                     se::Stream* transfer_stream,
                     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);

  PjRtStreamExecutorClient* const client_;
  const Shape on_device_shape_;
  PjRtStreamExecutorDevice* const device_;

  mutable absl::Mutex mu_;
  std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
  // Count of holds on the buffer.
  std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
  // Semaphore used to ensure there is only one outstanding donation hold.
  Semaphore donation_semaphore_;
};

// Wraps one or more XLA LocalExecutables (one per partition, as specified by
// the build options).
class PjRtStreamExecutorExecutable : public PjRtExecutable {
 public:
  PjRtStreamExecutorExecutable(
      std::vector<std::unique_ptr<LocalExecutable>> executables,
      bool parameter_is_tupled_arguments,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<LogicalDeviceIds> addressable_device_logical_ids,
      std::vector<PjRtDevice*> addressable_devices,
      PjRtStreamExecutorClient* client);

  ~PjRtStreamExecutorExecutable() override = default;

  PjRtStreamExecutorClient* client() const override { return client_; }

  absl::string_view name() const override;

  int num_replicas() const override {
    return executables_[0]->build_options().num_replicas();
  }

  int num_partitions() const override {
    return executables_[0]->build_options().num_partitions();
  }

  int64 SizeOfGeneratedCodeInBytes() const override {
    int64 size = 0;
    for (auto& executable : executables_) {
      size += executable->executable()->SizeOfGeneratedCodeInBytes();
    }
    return size;
  }

  const DeviceAssignment& device_assignment() const override {
    return *device_assignment_;
  }

  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const override {
    return addressable_device_logical_ids_;
  }

  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  // Return an HloModule per partition.
  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const override;

  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
      const ExecuteOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) override;

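  // Example (hypothetical): executing with two replicas, one argument each;
  // `argument_handles` is indexed [logical device][argument number] and the
  // argument names are illustrative.
  //
  //   ExecuteOptions options;
  //   TF_ASSIGN_OR_RETURN(
  //       auto results,
  //       executable->Execute({{replica0_arg}, {replica1_arg}}, options));
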
  void Delete() override { executables_.clear(); }

  absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
    return executables_;
  }

 protected:
  bool parameter_is_tupled_arguments() const {
    return parameter_is_tupled_arguments_;
  }

 private:
  friend class PjRtStreamExecutorClient;

  // Initializes information about which arguments to which executables must
  // be donated due to aliases that were specified by the computation.
  Status SetUpDonation(bool tuple_inputs);

  virtual bool MustDonateParameter(int executable_idx, int parameter) const;

  virtual StatusOr<std::vector<ExecutionInput>>
  MakeExecutionInputsAndWaitForEvents(
      int device_ordinal, const ExecuteOptions& options,
      absl::Span<PjRtBuffer* const> argument_handles,
      absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
      absl::flat_hash_set<BufferSequencingEvent*>& events) const;

  StatusOr<ScopedShapedBuffer> EnqueueExecution(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, int executable_idx, const RunId& run_id,
      const ExecuteOptions& options, PjRtDevice* device,
      std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
      std::shared_ptr<DeviceAssignment> device_assignment) const;

  virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
      int device_ordinal, const ExecuteOptions& options,
      ScopedShapedBuffer result_buffer,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtDevice* device) const;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, const RunId& run_id, const ExecuteOptions& options,
      PjRtDevice* device = nullptr) const;

  // Create shared pointers so we can free them after the execution: with
  // asynchronous execution, the process being executed can outlive the
  // executable itself.
  PjRtStreamExecutorClient* const client_;
  // One executable per partition.
  std::vector<std::shared_ptr<LocalExecutable>> executables_;
  // Per-executable set of parameters that have any aliased buffers and thus
  // must be donated when executing the computation.
  std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
  std::shared_ptr<DeviceAssignment> device_assignment_;

  // True if the executables were compiled expecting arguments in a single
  // tuple.
  const bool parameter_is_tupled_arguments_;

  // The replica and partition indices of device_assignment_ to be run by
  // this client. On single-host platforms without partitioning, this is all
  // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this
  // may not be the case on multi-host platforms. For example, with 4
  // replicas and 2 partitions on a single-host platform, the size of
  // addressable_device_logical_ids_ is 4*2 = 8.
  std::vector<LogicalDeviceIds> addressable_device_logical_ids_;

  // addressable_devices_[i] is the Device to which
  // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
  // unique_ptrs to play well with the Python bindings (see xla.cc).
  std::vector<PjRtDevice*> addressable_devices_;
};

// Executables can donate buffers so that buffers can be aliased from inputs
// to outputs. This function returns the list of parameters that must be
// donated when the executable is run. tuple_inputs reflects the option the
// executable was compiled with.
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
    const HloModule& hlo_module, bool tuple_inputs);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_