/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/transpose.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

class PjRtStreamExecutorDevice : public PjRtDevice {
 public:
  explicit PjRtStreamExecutorDevice(
      int id, std::unique_ptr<LocalDeviceState> local_device_state,
      std::string device_kind, int process_index = 0)
      : id_(id),
        device_ordinal_(
            local_device_state ? local_device_state->device_ordinal() : -1),
        local_device_state_(std::move(local_device_state)),
        process_index_(process_index),
        device_kind_(std::move(device_kind)) {}

  ~PjRtStreamExecutorDevice() override {}

  // Must set client exactly once.
  void SetClient(PjRtClient* client) {
    CHECK(client_ == nullptr);
    client_ = client;
  }

  int process_index() const override { return process_index_; }

  // Return `platform_id` from client.
  PjRtPlatformId platform_id() const;

  // Return `platform_name` from client.
  absl::string_view platform_name() const;

  PjRtClient* client() const override { return client_; }

  int id() const override { return id_; }

  bool IsAddressable() const override { return device_ordinal_ != -1; }

  int local_hardware_id() const override { return device_ordinal_; }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns nullptr if the device
  // is not local to this host.
  LocalDeviceState* local_device_state() const {
    return local_device_state_.get();
  }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns an error if the device
  // is not local to this host.
  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;

  absl::string_view device_kind() const override { return device_kind_; }

  std::string DebugString() const override;

  Status TransferToInfeed(const LiteralSlice& literal) override;

  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;

 private:
  const int id_;
  const int device_ordinal_;  // -1 means not local.
  const std::unique_ptr<LocalDeviceState> local_device_state_;
  const int process_index_;
  const std::string device_kind_;
  PjRtClient* client_ = nullptr;
};
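
// Example: constructing a device and attaching it to a client (illustrative
// sketch only; real code obtains devices from a platform-specific factory,
// and `client` is assumed to be a PjRtStreamExecutorClient created
// elsewhere). A null LocalDeviceState marks the device as non-local.
//
//   auto device = std::make_unique<PjRtStreamExecutorDevice>(
//       /*id=*/0, /*local_device_state=*/nullptr, /*device_kind=*/"gpu",
//       /*process_index=*/1);
//   device->SetClient(client);        // Must be called exactly once.
//   CHECK(!device->IsAddressable());  // No LocalDeviceState: not local.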

class PjRtStreamExecutorClient : public PjRtClient {
 public:
  // `allocator` may be null, in which case the platform default allocator is
  // used.
  explicit PjRtStreamExecutorClient(
      std::string platform_name, LocalClient* client,
      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
      int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
      bool should_stage_host_to_device_transfers,
      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
  ~PjRtStreamExecutorClient() override = default;

  int process_index() const override { return process_index_; }

  int device_count() const override { return devices_.size(); }
  int addressable_device_count() const override {
    return addressable_devices_.size();
  }
  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
    auto it = id_to_device_.find(device_id);
    if (it != id_to_device_.end()) {
      return it->second;
    }
    return InvalidArgument("No matching device found for device_id %d",
                           device_id);
  }

  StatusOr<PjRtDevice*> LookupAddressableDevice(
      int local_hardware_id) const override;

  PjRtPlatformId platform_id() const override { return platform_id_; }
  absl::string_view platform_name() const override { return platform_name_; }
  absl::string_view platform_version() const override { return "<unknown>"; }
  PjRtRuntimeType runtime_type() const override { return kStreamExecutor; }

  // Most platforms expect device-to-device transfers to be enqueued on the
  // source d2d stream, but some platforms use the destination d2d stream.
  // This function specifies which one the platform expects.
  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }

  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;

  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options) override;

  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
      const PjRtExecutable& executable) const override {
    return absl::optional<std::string>();
  }

  StatusOr<std::string> SerializeExecutable(
      const PjRtExecutable& executable) const override {
    return Unimplemented("SerializeExecutable not implemented on %s",
                         platform_name());
  }

  StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
      absl::string_view serialized, CompileOptions options) override {
    return Unimplemented("DeserializeExecutable not implemented on %s",
                         platform_name());
  }

  StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override;
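
  // Example: compiling a computation (illustrative sketch only;
  // `BuildComputation` is a hypothetical helper returning an XlaComputation,
  // and `client` is assumed to be a PjRtStreamExecutorClient obtained from a
  // platform-specific factory).
  //
  //   XlaComputation computation = BuildComputation();
  //   CompileOptions options;  // Defaults to one replica, one partition.
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
  //                       client->Compile(computation, options));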

  // Creates a buffer on the device without initializing or copying any data.
  // An optional `definition_event` may be specified that can be used to
  // ensure the buffer isn't referenced until some external mechanism has
  // initialized the data.
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device) override;
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device,
      std::shared_ptr<BufferSequencingEvent> definition_event);

  StatusOr<std::unique_ptr<PjRtClient::AsyncBufferTransferManager>>
  CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
                                PjRtDevice* device) override {
    return Unimplemented("Async transfer to buffers not implemented");
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
      const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
      absl::optional<absl::Span<int64_t const>> byte_strides,
      HostBufferSemantics host_buffer_semantics,
      std::function<void()> on_done_with_host_buffer,
      PjRtDevice* device) override;

  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
      const LiteralSlice& literal, PjRtDevice* device) override;

  void MakeCrossHostReceiveBuffers(
      absl::Span<const Shape> shapes, PjRtDevice* device,
      PjRtCrossHostRecvNotifier&& notifier) override;

  void MakeCrossHostReceiveBuffersForGather(
      absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
      PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) override;

  StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
      void* device_ptr, const Shape& shape, PjRtDevice* device,
      std::function<void()> on_delete_callback) override;

  StatusOr<ChannelHandle> CreateChannelHandle() override {
    return client()->CreateChannelHandle();
  }
  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
    return client()->CreateDeviceToHostChannelHandle();
  }
  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
    return client()->CreateHostToDeviceChannelHandle();
  }
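
  // Example: staging a host array onto a device (illustrative sketch only;
  // error handling is elided and `client` is assumed to have at least one
  // addressable device).
  //
  //   std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
  //   PjRtDevice* device = client->addressable_devices()[0];
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<PjRtBuffer> buffer,
  //       client->BufferFromHostBuffer(
  //           data.data(), F32, /*dims=*/{4}, /*byte_strides=*/absl::nullopt,
  //           HostBufferSemantics::kImmutableUntilTransferCompletes,
  //           /*on_done_with_host_buffer=*/nullptr, device));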

  // TODO(zhangqiaorjc): Experimental. Will be removed.
  Status Defragment() override {
    return Unimplemented("Defragment not implemented");
  }

  LocalDeviceState& device_state(int device_ordinal) const {
    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
                addressable_devices_.at(device_ordinal))
                ->local_device_state();
  }
  LocalClient* client() const { return client_; }
  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
  tensorflow::Allocator* host_memory_allocator() const {
    return host_memory_allocator_.get();
  }
  bool should_stage_host_to_device_transfers() const {
    return should_stage_host_to_device_transfers_;
  }

  gpu::GpuExecutableRunOptions* gpu_run_options() const {
    return gpu_run_options_.get();
  }

  tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }

 protected:
  friend class PjRtStreamExecutorBuffer;

  virtual void EnqueueCrossHostReceive(
      std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtCrossHostRecvNotifier&& notifier,
      absl::optional<std::vector<GatherDetails>> gather_details) const {
    notifier(Unimplemented("Cross host receives not implemented."));
  }

  virtual Status CopyToRemoteDevice(
      PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
    return Unimplemented("Cross host sends not implemented.");
  }

  virtual Status CopyToRemoteDeviceScattered(
      PjRtBuffer* buffer, absl::Span<const std::string> serialized_descriptors,
      const PjRtBuffer::ScatterDetails& scatter_details) const {
    return Unimplemented("Scattered cross host sends not implemented.");
  }

  virtual Status CopyRawSubBufferToHost(PjRtBuffer* buffer, void* dst,
                                        int64_t offset, int64_t transfer_size,
                                        std::function<void(Status)> on_ready) {
    return Unimplemented("Raw copies to host not implemented.");
  }

  // Helper function for creating PjRtStreamExecutorExecutables. Modifies
  // `options` in-place.
  struct ExecutableExtras {
    std::shared_ptr<DeviceAssignment> device_assignment;
    std::vector<PjRtExecutable::LogicalDeviceIds>
        addressable_device_logical_ids;
    std::vector<PjRtDevice*> addressable_devices;
  };
  StatusOr<ExecutableExtras> GetExecutableExtras(CompileOptions* options);

  const PjRtPlatformId platform_id_;
  const std::string platform_name_;
  LocalClient* client_;

  // Allocator to be used for staging memory transfers to devices.
  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;

  // Device memory allocator. If owned, the allocator must outlive the
  // devices, because it is the device destructor that waits for any
  // outstanding work to complete.
  se::DeviceMemoryAllocator* allocator_;
  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;

  // Includes all devices, including non-local devices on multi-host
  // platforms.
  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
  // Pointers to `owned_devices_`.
  std::vector<PjRtDevice*> devices_;
  // Maps Device::id() to the corresponding Device. Includes all devices.
  std::map<int, PjRtDevice*> id_to_device_;
  // Local devices indexed by local device ordinal.
  std::vector<PjRtDevice*> addressable_devices_;
  int process_index_;

  // Should we always prefer to stage host-to-device transfers via memory
  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
  // transfer via pinned memory.
  bool should_stage_host_to_device_transfers_;

  std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;

  tensorflow::thread::ThreadPool thread_pool_;

  absl::Mutex transpose_mu_;
  TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_);
};

// Converts a 2D set of Device objects indexed by [replica][partition] into an
// xla::DeviceAssignment.
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
    absl::Span<const std::vector<PjRtDevice*>> devices);
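
// Example: building an assignment for two replicas with one partition each
// (illustrative sketch only; `client` is assumed to have at least two
// addressable devices).
//
//   std::vector<std::vector<PjRtDevice*>> devices = {
//       {client->addressable_devices()[0]},
//       {client->addressable_devices()[1]}};
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       DevicesToDeviceAssignment(devices));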

class PjRtStreamExecutorBuffer : public PjRtBuffer {
 public:
  // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A
  // ScopedHold may not outlive its parent PjRtStreamExecutorBuffer.
  //
  // There are three types of hold, as follows:
  //
  // 1) Usage hold: a transient hold while an operation using the buffer is
  // being enqueued onto a stream.
  // A client acquires a usage hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
  // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully
  // the hold should be released using a call to ConvertUsageHold. If the
  // ScopedHold is deleted without ConvertUsageHold being called, e.g., on
  // error, the hold is dropped. It is legal to drop a usage hold instead of
  // calling ConvertUsageHold, even if the buffer was successfully enqueued,
  // as long as the client ensures that all necessary synchronization has been
  // done.
  //
  // 2) External hold: a potentially long-lived hold while the buffer is being
  // shared by an external framework, e.g., NumPy.
  // A client acquires an external hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience
  // wrapper GetBufferWithExternalReference and releases it by deleting the
  // ScopedHold. The external framework should not modify the underlying
  // buffer unless it is confident via its own synchronization that
  // modifications do not race with reads from the PjRtStreamExecutorBuffer.
  //
  // 3) Donation hold: a transient hold while an execution that donates the
  // buffer is being enqueued onto the compute stream.
  // A client acquires a donation hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
  // completes successfully the hold should be released using a call to
  // ConfirmDonation, after which the buffer is invalid. If the ScopedHold is
  // deleted without ConfirmDonation being called, e.g., on error, the hold is
  // dropped and the buffer remains valid. If the buffer is successfully
  // enqueued the client *must* call ConfirmDonation.
  //
  // Donation holds behave like exclusive write locks: when a donation hold
  // has been acquired, any attempt to acquire another hold of any type will
  // block until the donation hold is dropped or confirmed. Acquiring a
  // donation hold will fail with an error if there is any outstanding
  // external hold, and will block if there are any outstanding usage holds
  // until those holds are dropped or converted.
  //
  // Calls to PjRtStreamExecutorBuffer::Release (and transitively to
  // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
  // block until all usage and donation holds are either deleted or
  // converted/confirmed.
  class ScopedHold {
   public:
    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
    // Use a State enum instead of encoding the state in an error Status to
    // avoid creating Status values in non-error cases. Creating a Status
    // entails several allocations and can add O(us) to every use of a hold.
    enum State {
      kUninitialized = 0,
      kValid,
      kMoved,
      kConverted,
      kReleased,
      kDonated,
      kError
    };

    ~ScopedHold();
    ScopedHold(ScopedHold&& other);
    ScopedHold(const ScopedHold&) = delete;
    ScopedHold& operator=(const ScopedHold&) = delete;

    Type type() const { return type_; }

    Status status() const {
      // Lazily create Status values only when they are requested.
      switch (state_) {
        case kUninitialized:
          return InvalidArgument("Buffer has not been initialized");
        case kValid:
          return Status::OK();
        case kMoved:
          return InvalidArgument("Buffer has been moved.");
        case kConverted:
          return InvalidArgument("Buffer has been converted");
        case kReleased:
          return InvalidArgument("Buffer has been released");
        case kDonated:
          return InvalidArgument("Buffer has been donated");
        case kError:
          return status_;
        default:
          CHECK(false) << "Unexpected state value " << state_;
      }
    }
    bool ok() const { return state_ == kValid; }

    // Access to the underlying device buffer storage. Requires this->ok().
    const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
      CHECK_EQ(state_, kValid);
      CHECK_NE(buffer_, nullptr);
      return buffer_;
    }
    TrackedDeviceBuffer* operator->() const { return buffer().get(); }
    const TrackedDeviceBuffer& operator*() const { return *buffer(); }

    // Converts the hold into a usage event. Only valid for holds of type
    // kUsage.
    //
    // usage_stream:   the stream that the buffer was used on.
    // event:          an event that has been recorded on usage_stream after
    //                 the buffer was used.
    // reference_held: true if and only if the caller has caused a
    //                 reference to this->buffer() to stay live until after
    //                 the host is sure that the usage (transfer or execution)
    //                 has completed.
    void ConvertUsageHold(se::Stream* usage_stream,
                          std::shared_ptr<BufferSequencingEvent> event,
                          bool reference_held);

    // Confirms that the buffer was successfully donated to an execution.
    // Only valid for holds of type kDonation. Causes the buffer to become
    // invalid.
    void ConfirmDonation();
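
    // Example: the usage-hold lifecycle described above (illustrative sketch
    // only; `stream` and `event` are assumed to have been created by the
    // caller via the platform's stream and event machinery).
    //
    //   PjRtStreamExecutorBuffer::ScopedHold hold =
    //       buffer->GetBufferWithUsageHold();
    //   TF_RETURN_IF_ERROR(hold.status());
    //   // ... enqueue an operation that reads the held buffer on `stream`,
    //   // then record `event` on `stream` ...
    //   hold.ConvertUsageHold(stream, event, /*reference_held=*/false);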

    // Adds the held device buffers in order to 'iterator'. Used to add the
    // buffers to an ExecutionInput. We require but do not verify that
    // 'iterator' when passed in is pointing to a sub-tuple of the
    // ExecutionInput whose on_device_shape matches that of the
    // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
    // out of bounds. Donates the device buffers if the hold type is
    // kDonation, otherwise retains ownership of the device buffers.
    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
                    ExecutionInput* execution_input,
                    se::DeviceMemoryAllocator* allocator) const;

   private:
    friend class PjRtStreamExecutorBuffer;
    friend class PjRtStreamExecutorClient;

    // Helper struct that makes it possible to move a ScopedHold through a
    // closure.
    using ForClosure = std::tuple<PjRtStreamExecutorBuffer*, Type, State,
                                  Status, std::shared_ptr<TrackedDeviceBuffer>>;

    ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
        : parent_(parent), type_(type), state_(kUninitialized) {}
    explicit ScopedHold(const ForClosure& closure_helper)
        : parent_(std::get<0>(closure_helper)),
          type_(std::get<1>(closure_helper)),
          state_(std::get<2>(closure_helper)),
          status_(std::get<3>(closure_helper)),
          buffer_(std::get<4>(closure_helper)) {
      // Check the buffer is not in an error state.
      CHECK(status_.ok() && buffer_ != nullptr);
    }

    // Sets buffer state.
    void SetState(State state) { state_ = state; }

    // Sets buffer_ and status_. Called by parent_ to initialize the hold.
    void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
    // Releases the contents of *this, so *this can subsequently be
    // deleted without releasing the parent's hold. Should be passed to the
    // appropriate constructor of another ScopedHold, e.g., when a hold must
    // be passed through a closure that is incompatible with std::move.
    ForClosure ToClosure();

    PjRtStreamExecutorBuffer* const parent_;
    const Type type_;

    // There is an invariant that if ok() then buffer_ != nullptr.
    State state_;
    Status status_;
    std::shared_ptr<TrackedDeviceBuffer> buffer_;
  };

  PjRtStreamExecutorBuffer(Shape on_device_shape,
                           std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                           PjRtClient* client, PjRtDevice* device);
  ~PjRtStreamExecutorBuffer() override;

  PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) =
      delete;
  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;

  const Shape& on_device_shape() const override { return on_device_shape_; }
  StatusOr<Shape> logical_on_device_shape() override;
  PjRtStreamExecutorDevice* device() const override { return device_; }
  PjRtPlatformId platform_id() const { return client_->platform_id(); }
  absl::string_view platform_name() const { return client_->platform_name(); }
  PjRtStreamExecutorClient* client() const override { return client_; }
  bool IsEmptyTuple() const {
    return on_device_shape_.IsTuple() &&
           on_device_shape_.tuple_shapes_size() == 0;
  }

  StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
      override;

  StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
      bool wait_for_operations_to_complete) override;

  using PjRtBuffer::ToLiteral;
  void ToLiteral(MutableLiteralBase* literal,
                 std::function<void(Status)> on_ready) override;

  StatusOr<size_t> GetOnDeviceSizeInBytes() const override;

  Status CopyRawToHost(void* dst, int64_t offset, int64_t transfer_size,
                       std::function<void(Status)> on_ready) override;

  // Drops the buffer's reference to its associated device memory, leaving
  // the buffer in an invalid state. The memory will be freed lazily when all
  // async operations using the buffer have completed, according to the
  // allocation semantics of the underlying platform. Delete may briefly
  // block if another thread is in the process of enqueuing an operation on
  // this buffer, but it will never block for a stream operation to complete.
  // If an external framework holds a reference to the TrackedDeviceBuffer
  // via GetBufferWithExternalReference, the memory will not be freed until
  // the external framework drops the reference.
  void Delete() override;

  bool IsDeleted() override;

  // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
  // PjRtBuffer retains ownership of the device buffers.
  StatusOr<ShapedBuffer> AsShapedBuffer() const;
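
  // Example: asynchronously reading a buffer back to the host (illustrative
  // sketch only; assumes ShapeUtil::DeviceShapeToHostShape for computing the
  // host shape, and the literal must stay alive until the callback runs).
  //
  //   auto literal = std::make_shared<Literal>(
  //       ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape()));
  //   buffer->ToLiteral(literal.get(), [literal](Status s) {
  //     // On success, `*literal` now holds the buffer contents.
  //   });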

  // Returns a hold on the TrackedDeviceBuffer holding the device
  // buffers. See comment on ScopedHold.
  ScopedHold GetBufferWithHold(ScopedHold::Type type);
  ScopedHold GetBufferWithUsageHold() {
    return GetBufferWithHold(ScopedHold::kUsage);
  }
  ScopedHold GetBufferWithExternalReference() {
    return GetBufferWithHold(ScopedHold::kExternalReference);
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
      PjRtDevice* dst_device) override;

  Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;

  Status CopyToRemoteDeviceScattered(
      absl::Span<const std::string> serialized_descriptors,
      const ScatterDetails& scatter_details) override;

  Status BlockHostUntilReady() override;

  bool IsOnCpu() const override;

  // Similar to Delete, drops the buffer's reference to its associated device
  // memory, leaving the buffer in an invalid state, but returns the
  // TrackedDeviceBuffer rather than freeing the device memory, so that
  // another framework can take ownership of it. The buffer returned from
  // Release may be safely dropped at any time even if it still has pending
  // async operations. The client should call BlockHostUntilReady before
  // calling Release with wait_for_operations_to_complete=false, to ensure
  // that the host has synchronized past any outstanding write operations to
  // the buffer. If wait_for_operations_to_complete=true the host will block
  // until any potentially outstanding asynchronous operations have completed
  // before returning, in which case it is safe to read or mutate the
  // returned buffer. If the buffer was shared via an external reference it
  // is the client's responsibility that accesses via that reference do not
  // interfere with accesses via the buffer returned from Release.
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
      bool wait_for_operations_to_complete);
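
  // Example: handing device memory to another framework (illustrative sketch
  // only; error handling elided). With wait_for_operations_to_complete=true
  // the call blocks until pending operations finish, so the returned buffer
  // is safe to read or mutate.
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::shared_ptr<TrackedDeviceBuffer> tracked,
  //       buffer->Release(/*wait_for_operations_to_complete=*/true));
  //   // `buffer` is now invalid; `tracked` owns the device memory.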

 private:
  friend class PjRtClient;

  // Blocks in mu_.Await until there are no more usage holds.
  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Blocks in mu_.Await until there is no donation hold.
  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
  // device_buffer_ is null, or if a donation hold was requested when there
  // is an outstanding external hold.
  // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHold()
  // must be called first.)
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
  // Initializes hold with an error if device_buffer_ is null, or if a
  // donation hold was requested when there is an outstanding external hold.
  // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHold()
  // must be called first.)
  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a
  // sanity check that buffer==device_buffer_ or device_buffer_==nullptr.
  // Called after device_buffer_ was successfully enqueued on a stream.
  void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
                        std::shared_ptr<BufferSequencingEvent> event,
                        bool reference_held);

  // Drops a donation hold and makes *this invalid for further use. Does a
  // sanity check that buffer==device_buffer_. Called after device_buffer_
  // was successfully donated to an execution.
  void ConfirmDonation(TrackedDeviceBuffer* device_buffer);

  // Drops a hold without taking any other action. Does a sanity check that
  // buffer==device_buffer_ or device_buffer_==nullptr.
  void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);

  StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
                     std::shared_ptr<BufferSequencingEvent>>>
  CopyToDeviceHelper(PjRtDevice* dst_device,
                     LocalDeviceState* dst_local_device,
                     LocalDeviceState* transfer_local_device,
                     se::Stream* transfer_stream,
                     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);

  PjRtStreamExecutorClient* const client_;
  const Shape on_device_shape_;
  PjRtStreamExecutorDevice* const device_;

  mutable absl::Mutex mu_;
  std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
  // Count of holds on the buffer.
  std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
};

// Wraps one or more XLA LocalExecutables (one per partition, as specified by
// the build options).
class PjRtStreamExecutorExecutable : public PjRtExecutable {
 public:
  PjRtStreamExecutorExecutable(
      std::vector<std::unique_ptr<LocalExecutable>> executables,
      bool parameter_is_tupled_arguments,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<LogicalDeviceIds> addressable_device_logical_ids,
      std::vector<PjRtDevice*> addressable_devices,
      PjRtStreamExecutorClient* client);

  ~PjRtStreamExecutorExecutable() override = default;

  PjRtStreamExecutorClient* client() const override { return client_; }

  absl::string_view name() const override;

  int num_replicas() const override {
    return executables_[0]->build_options().num_replicas();
  }

  int num_partitions() const override {
    return executables_[0]->build_options().num_partitions();
  }

  int64_t SizeOfGeneratedCodeInBytes() const override {
    int64_t size = 0;
    for (auto& executable : executables_) {
      size += executable->executable()->SizeOfGeneratedCodeInBytes();
    }
    return size;
  }

  const DeviceAssignment& device_assignment() const override {
    return *device_assignment_;
  }

  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const override {
    return addressable_device_logical_ids_;
  }

  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  // Return an HloModule per partition.
  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const override;

  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
      const ExecuteOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) override;

  void Delete() override { executables_.clear(); }

  bool IsDeleted() override { return executables_.empty(); }

  absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
    return executables_;
  }
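
  // Example: running a one-replica, one-partition executable (illustrative
  // sketch only; `executable` comes from PjRtStreamExecutorClient::Compile
  // and `input` is a PjRtBuffer on the executable's addressable device).
  //
  //   ExecuteOptions options;
  //   TF_ASSIGN_OR_RETURN(
  //       std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> results,
  //       executable->Execute({{input.get()}}, options));
  //   // results[0] holds the output buffers for replica 0.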

 protected:
  bool parameter_is_tupled_arguments() const {
    return parameter_is_tupled_arguments_;
  }

 private:
  friend class PjRtStreamExecutorClient;
  friend class PjRtTpuClient;
  friend class InternalPjRtTpuClient;

  // Initializes information about which arguments to which executables must
  // be donated due to aliases that were specified by the computation.
  Status SetUpDonation(bool tuple_inputs);

  // Returns a sorted list of the parameters that must be donated. Derived
  // classes may use custom logic.
  virtual absl::Span<int const> ParametersThatMustBeDonated(
      int executable_idx) const;

  virtual StatusOr<std::vector<ExecutionInput>>
  MakeExecutionInputsAndWaitForEvents(
      int device_ordinal, const ExecuteOptions& options,
      absl::Span<const Shape> executable_parameter_shapes,
      absl::Span<PjRtBuffer* const> argument_handles,
      absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
      absl::flat_hash_set<BufferSequencingEvent*>& events) const;

  StatusOr<ScopedShapedBuffer> EnqueueExecution(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, int executable_idx, const RunId& run_id,
      const ExecuteOptions& options, PjRtDevice* device,
      std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<std::function<void()>>& compute_callbacks) const;

  virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
      int device_ordinal, const ExecuteOptions& options,
      ScopedShapedBuffer result_buffer,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtDevice* device,
      std::vector<std::function<void()>>& compute_callbacks,
      std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
      const;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, const RunId& run_id, const ExecuteOptions& options,
      PjRtDevice* device = nullptr) const;

  // Create shared pointers so we can free them after the execution: with
  // asynchronous execution, the process being executed can outlive the
  // executable itself.
  PjRtStreamExecutorClient* const client_;
  // One executable per partition.
  std::vector<std::shared_ptr<LocalExecutable>> executables_;
  // On-device shapes of the executable parameters.
  std::vector<std::vector<Shape>> on_device_executable_parameter_shapes_;
  // Per-executable sorted vector of parameters that have any aliased buffers
  // and thus must be donated when executing the computation.
  std::vector<std::vector<int>> parameters_that_must_be_donated_;
  std::shared_ptr<DeviceAssignment> device_assignment_;

  // True if the executables were compiled expecting arguments in a single
  // tuple.
  const bool parameter_is_tupled_arguments_;

  // The replica and partition indices of device_assignment_ to be run by
  // this client. On single-host platforms without partitioning, this is all
  // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this
  // may not be the case on multi-host platforms. If there are 4 replicas and
  // 2 partitions on a single-host platform, the size of
  // addressable_device_logical_ids_ is 4*2 = 8.
  std::vector<LogicalDeviceIds> addressable_device_logical_ids_;

  // addressable_devices_[i] is the Device to which
  // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
  // unique_ptrs to play well with the Python bindings (see xla.cc).
  std::vector<PjRtDevice*> addressable_devices_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_