/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

// API notes:
// PjRt stands for "Pretty much Just another RunTime".

namespace xla {

using PjRtPlatformId = uint64;

constexpr char kCpuName[] = "cpu";
constexpr char kGpuName[] = "gpu";
constexpr char kTpuName[] = "tpu";
static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(kCpuName);
static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(kGpuName);
static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(kTpuName);

class PjRtClient;

class PjRtDevice {
 public:
  virtual ~PjRtDevice() {}

  // Return the client that owns this device.
  virtual PjRtClient* client() const = 0;

  // Whether the client can issue commands to this device.
  virtual bool IsAddressable() const = 0;

  // The ID of this device. IDs are unique among devices of this type
  // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
  // hosts' devices. This is the ID that should be used in a DeviceAssignment.
  virtual int id() const = 0;

  // The task ID of this device according to TpuTopology. This is not always
  // identical to PjRtClient::task_id() in a multi-task setting, where each
  // client can see devices from all tasks, but only a subset of them are
  // addressable and have the same task_id as the client.
  virtual int task_id() const = 0;

  // Opaque hardware ID, e.g., the CUDA device number. Useful for identifying
  // which GPU is in use when interacting with non-JAX code. In general, not
  // guaranteed to be dense, and -1 if undefined.
  virtual int local_hardware_id() const = 0;

  // A vendor-dependent string that uniquely identifies the kind of device,
  // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs
  // are compatible for compilation purposes.
  virtual absl::string_view device_kind() const = 0;

  virtual std::string DebugString() const = 0;

  // Transfer the given literal to the infeed queue.
  virtual Status TransferToInfeed(const LiteralSlice& literal) = 0;

  // Transfer and return a value of the given shape from the outfeed queue.
  virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) = 0;
};

// Forward declaration.
class PjRtBuffer;

// Helper struct for cross-host transfers, returned by the callback from a call
// to PjRtClient::MakeCrossHostReceiveBuffers.
struct PjRtCrossHostRecvBuffer {
  // serialized_descriptor should be transmitted to the sender and passed to a
  // call to src_buffer->CopyToRemoteDevice.
  std::string serialized_descriptor;
  // The buffer that will hold the result of the transfer.
  std::unique_ptr<PjRtBuffer> buffer;
};
using PjRtCrossHostRecvNotifier =
    std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;

struct CompileOptions {
  // The layouts of the arguments that the computation should expect.
  absl::optional<std::vector<Shape>> argument_layouts;

  // If true, the supplied computation expects its arguments to be wrapped in a
  // tuple and passed as a single parameter.
  bool parameter_is_tupled_arguments = false;

  // XLA's compilation time options.
  ExecutableBuildOptions executable_build_options;

  // If true, the executable can be run on any device. May only be true if
  // !executable_build_options.has_device_assignment(), so only applies to
  // single-device executables. Beware: on GPUs, sometimes an executable
  // compiled for one device doesn't run on another.
  bool compile_portable_executable = false;
};

class PjRtExecutable;

// Encapsulates the state of a Python session with XLA.
//
// It is the responsibility of the client of this API to keep the PjRtClient
// alive as long as any of the other runtime objects are alive.
class PjRtClient {
 public:
  virtual ~PjRtClient() = default;

  // Return the task ID of this client. In a single-task setting, always 0.
  virtual int task_id() const = 0;

  // Return the number of devices in the entire computation. In a multi-headed
  // client setting, some are addressable by this client and some are not. In a
  // single-client setting, this is equal to the number of addressable devices.
  virtual int device_count() const = 0;

  // Return the number of addressable devices. Addressable devices are those
  // that the client can issue commands to.
  virtual int addressable_device_count() const = 0;

  // Return all devices in the entire computation, including addressable and
  // non-addressable devices.
  virtual absl::Span<PjRtDevice* const> devices() const = 0;

  // Return only addressable devices.
  virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;

  // Lookup any PjRtDevice for a given PjRtDevice::id().
  virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
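
  // Illustrative usage sketch (comments only, not part of the API): `client`
  // is assumed to be a concrete PjRtClient obtained from a platform-specific
  // factory.
  //
  //   for (PjRtDevice* device : client->addressable_devices()) {
  //     // device->id() is unique across all hosts' devices of this type;
  //     // device->task_id() identifies the task that can address it.
  //   }
  //   TF_ASSIGN_OR_RETURN(PjRtDevice * device,
  //                       client->LookupDevice(/*device_id=*/0));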

  // Return an addressable PjRtDevice for a given
  // PjRtDevice::local_hardware_id().
  virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
      int local_hardware_id) const = 0;

  // Return an ID that identifies the platform (CPU/GPU/TPU).
  virtual PjRtPlatformId platform_id() const = 0;

  // Returns a string that identifies the platform (CPU/GPU/TPU).
  virtual absl::string_view platform_name() const = 0;

  // Return a device-specific default device assignment, e.g., GPU and TPU may
  // be different.
  virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const = 0;

  // Returns a backend-specific HLO cost analysis visitor.
  virtual StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() = 0;

  // Compile `computation` with the given `options`.
  virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options) = 0;

  // Generates a unique fingerprint for `executable`. May be absl::nullopt.
  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
      const PjRtExecutable& executable) const = 0;

  // Creates a buffer on the device without initializing or copying any data.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device) = 0;

  // Describes the semantics the caller of BufferFromHostBuffer expects from
  // the runtime, in a total order from most restrictive to least restrictive.
  enum class HostBufferSemantics {
    // The runtime may not hold references to `data` after the call to
    // `BufferFromHostBuffer` completes. The caller promises only that `data`
    // is immutable and will not be freed for the duration of the
    // BufferFromHostBuffer call. `on_done_with_host_buffer` will be called
    // before `BufferFromHostBuffer` returns.
    kImmutableOnlyDuringCall,

    // The runtime may hold onto `data` after the call to
    // `BufferFromHostBuffer` returns while the runtime completes a transfer
    // to the device. The caller promises not to mutate or free `data` until
    // the transfer completes, at which point the runtime will call
    // `on_done_with_host_buffer`. It is also correct to wait on the host
    // (directly or indirectly) for the buffer's definition event to complete.
    kImmutableUntilTransferCompletes,

    // The PjRtBuffer may alias `data` internally and the runtime may use the
    // `data` contents as long as the buffer is alive. The caller promises to
    // keep `data` alive and not to mutate its contents as long as the buffer
    // is alive; to notify the caller that the buffer may be freed, the runtime
    // will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On
    // non-CPU platforms this acts identically to
    // kImmutableUntilTransferCompletes.
    kZeroCopy,
  };
  // on_done_with_host_buffer is optional and may be null.
  // on_done_with_host_buffer will be called iff an OK status is returned.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
      const void* data, const Shape& shape,
      HostBufferSemantics host_buffer_semantics,
      std::function<void()> on_done_with_host_buffer, PjRtDevice* device) = 0;
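
  // Illustrative usage sketch (comments only, not part of the API): uploading
  // a host array with kImmutableOnlyDuringCall, which lets the caller free
  // `data` as soon as the call returns. `client` and `device` are assumed to
  // come from a platform-specific factory.
  //
  //   std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
  //   Shape shape = ShapeUtil::MakeShape(F32, {4});
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<PjRtBuffer> buffer,
  //       client->BufferFromHostBuffer(
  //           data.data(), shape,
  //           PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
  //           /*on_done_with_host_buffer=*/nullptr, device));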

  // Note that `literal` must remain in scope until the transfer has completed,
  // so the caller should, for example, wait until BlockHostUntilReady()
  // completes on the return value before letting `literal` go out of scope.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
      const LiteralSlice& literal, PjRtDevice* device) = 0;

  // Creates a PjRtBuffer that is a non-owned view of an on-device
  // buffer (typically allocated by another library).
  // on_delete_callback is called when the PjRtBuffer is done with the
  // on-device buffer. The buffer may be mutated, for example, if the buffer is
  // donated to an Execute operation.
  // TODO(phawkins): Currently this API assumes the buffer is ready to use
  // immediately on the device. Extend it to support, for example, waiting for
  // a CUDA stream/event.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
      void* device_ptr, const Shape& shape, PjRtDevice* device,
      std::function<void()> on_delete_callback) = 0;

  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
  // cross-host transfers using `client` on `device`. `shapes` must be the
  // exact shapes, with identical layouts, corresponding to the buffers that
  // will be sent. When resources for the transfer are available, `notifier`
  // will be called with a vector of PjRtCrossHostRecvBuffer structs, one for
  // each shape in `shapes`. Each struct contains a buffer that will contain
  // the received value, and an opaque string that should be transmitted to the
  // sending host and used in a call to CopyToRemoteDevice. None of the recv
  // buffers will become ready until *all* of the sends have completed.
  virtual void MakeCrossHostReceiveBuffers(
      absl::Span<const Shape> shapes, PjRtDevice* device,
      PjRtCrossHostRecvNotifier&& notifier) = 0;

  // Create ChannelHandles for XLA send/recv.
  virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
};
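
// Illustrative usage sketch (comments only, not part of the API): the
// receiving side of a cross-host transfer. Each serialized descriptor handed
// back by the notifier must be communicated out of band to the sending host,
// which passes it to PjRtBuffer::CopyToRemoteDevice. `client`, `device`,
// `shapes`, and `SendDescriptorToSender` are assumptions, not part of this
// header.
//
//   client->MakeCrossHostReceiveBuffers(
//       shapes, device,
//       [&](StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&& recv_buffers) {
//         if (!recv_buffers.ok()) return;  // or surface the error
//         for (PjRtCrossHostRecvBuffer& rb : *recv_buffers) {
//           SendDescriptorToSender(rb.serialized_descriptor);  // hypothetical
//           // rb.buffer becomes ready only after *all* sends complete.
//         }
//       });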

// Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
// can be either valid or invalid. An invalid buffer is one that has never been
// initialized, or a buffer that has been deleted (e.g., by calling Delete, or
// by donating it to a computation that aliases an input parameter to an
// output). We allow PjRtBuffer objects to outlive the underlying device
// buffers so we can decouple buffer lifetimes from the corresponding Python
// references if needed. Thread-safe.
class PjRtBuffer {
 public:
  virtual ~PjRtBuffer() = default;

  virtual const Shape& on_device_shape() const = 0;
  virtual PjRtDevice* device() const = 0;
  virtual PjRtClient* client() const = 0;

  // Returns the size of the on-device representation of this buffer in bytes.
  virtual int64 OnDeviceSizeInBytes() const = 0;

  // ExternalReference is a potentially long-lived reference held while a
  // buffer is being shared by an external framework, e.g., NumPy. A client
  // acquires an external reference by calling
  // PjRtBuffer::AcquireExternalReference() and releases it by deleting the
  // ExternalReference. The external framework should not modify the underlying
  // buffer unless it is confident via its own synchronization that
  // modifications do not race with reads from the PjRtBuffer.
  class ExternalReference {
   public:
    virtual ~ExternalReference() = 0;
    // Return an opaque device memory pointer to the root buffer.
    void* OpaqueDeviceMemoryDataPointer() const { return data_ptr_; }

   protected:
    void* data_ptr_;
  };
  virtual StatusOr<std::unique_ptr<ExternalReference>>
  AcquireExternalReference() = 0;

  // Copies the buffer's value into `literal`. Calls `on_ready` when the value
  // (or an error) is ready. The transfer respects the layout of `literal`; to
  // specify a particular layout, set the layout before calling `ToLiteral`.
  virtual void ToLiteral(MutableLiteralBase* literal,
                         std::function<void(Status)> on_ready) = 0;

  // Synchronous overload of ToLiteral, as a convenience.
  Status ToLiteral(MutableLiteralBase* literal) {
    absl::Notification done;
    Status status;
    ToLiteral(literal, [&](Status s) {
      status = std::move(s);
      done.Notify();
    });
    done.WaitForNotification();
    return status;
  }

  // Convenience synchronous overload that allocates a literal with a default
  // layout.
  StatusOr<std::shared_ptr<Literal>> ToLiteral() {
    auto literal = std::make_shared<Literal>(
        ShapeUtil::DeviceShapeToHostShape(on_device_shape()));
    TF_RETURN_IF_ERROR(ToLiteral(literal.get()));
    return literal;
  }

  // Drops the buffer's reference to its associated device memory, leaving the
  // buffer in an invalid state. The memory will be freed lazily when all async
  // operations using the buffer have completed, according to the allocation
  // semantics of the underlying platform. Delete may briefly block if another
  // thread is in the process of enqueuing an operation on this buffer, but it
  // will never block for a stream operation to complete. If an external
  // framework holds a reference to the TrackedDeviceBuffer via
  // GetBufferWithExternalReference, the memory will not be freed until the
  // external framework drops the reference.
  virtual void Delete() = 0;

  // Similar to Delete, drops the buffer's reference to its associated device
  // memory, leaving the buffer in an invalid state, but transfers the device
  // memory ownership out via an ExternalReference rather than freeing the
  // device memory, so that another framework can take ownership of it. A
  // return value of nullptr indicates that the PjRtBuffer has been deleted.
  // The buffer returned from ReleaseDeviceMemoryOwnership may be safely
  // dropped at any time even if it still has pending async operations. The
  // client should call BlockHostUntilReady before calling
  // ReleaseDeviceMemoryOwnership with wait_for_operations_to_complete=false,
  // to ensure that the host has synchronized past any outstanding write
  // operations to the buffer. If wait_for_operations_to_complete=true the host
  // will block until any potentially outstanding asynchronous operations have
  // completed before returning, in which case it is safe to read or mutate the
  // returned buffer. If the buffer was shared via an external reference it is
  // the client's responsibility that accesses via that reference do not
  // interfere with accesses via the buffer returned from
  // ReleaseDeviceMemoryOwnership.
  virtual StatusOr<std::unique_ptr<ExternalReference>>
  ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) = 0;
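
  // Illustrative usage sketch (comments only, not part of the API): sharing
  // the underlying device memory with another framework without giving up
  // ownership. The raw pointer is valid only while `ref` is alive; compare
  // with ReleaseDeviceMemoryOwnership above, which transfers ownership out.
  //
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer::ExternalReference> ref,
  //                       buffer->AcquireExternalReference());
  //   void* device_ptr = ref->OpaqueDeviceMemoryDataPointer();
  //   // ... hand device_ptr to external code; destroy `ref` when done ...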

  // True if and only if Delete or ReleaseDeviceMemoryOwnership has previously
  // been called.
  virtual bool IsDeleted() = 0;

  // Copies the buffer to device `dst_device`, performing a d2d transfer when
  // `dst_device` shares the same Client, and a d2h and h2d copy if
  // `dst_device` lives on a different Client.
  // Returns an error if the buffer is already on dst_device.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
      PjRtDevice* dst_device) = 0;

  // Copies the buffer to the remote device encoded in serialized_descriptor.
  // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
  // remote host's destination device. MakeCrossHostReceiveBuffers takes an
  // array of shapes to construct the destination buffers, and a callback
  // supplies an array containing both the destination buffers and a serialized
  // descriptor for each buffer. For each destination buffer there should be a
  // matching call to src->CopyToRemoteDevice on a remote host for a src buffer
  // of the corresponding shape. serialized_descriptor is the string returned
  // by the callback along with the corresponding destination buffer.
  virtual Status CopyToRemoteDevice(
      absl::string_view serialized_descriptor) = 0;

  // Blocks the host until the buffer's value has been computed and is ready
  // for immediate use on the device. Useful in particular for timing
  // benchmarks.
  virtual Status BlockHostUntilReady() = 0;

  // Whether this buffer is on CPU and thus allows for certain optimizations.
  virtual bool IsOnCpu() const = 0;
};

class ExecuteContext {
 public:
  virtual ~ExecuteContext() = default;
};

struct ExecuteOptions {
  // If true, the client must pass a single PjRtBuffer which contains all of
  // the arguments as a single XLA tuple, otherwise each argument must be
  // passed in its own PjRtBuffer. May only be true if the executable was
  // compiled with parameter_is_tupled_arguments==true.
  bool arguments_are_tupled = false;
  // If true, the computation must return a tuple, which will be destructured
  // into its elements.
  bool untuple_result = false;
  // If non-zero, identifies this execution as part of a potentially
  // multi-device launch. This can be used to detect scheduling errors, e.g. if
  // multi-host programs are launched in different orders on different hosts,
  // the launch IDs may be used by the runtime to detect the mismatch.
  int32 launch_id = 0;
  // If non-null, an opaque context passed to an execution that may be used to
  // supply additional arguments to a derived class of PjRtExecutable.
  const ExecuteContext* context = nullptr;
};

// Represents a compiled computation that can be executed given handles to
// device-allocated literals. If any input/output alias has been specified in
// the computation, the parameter containing the input buffer will be donated
// when passed to the execution.
class PjRtExecutable {
 public:
  virtual ~PjRtExecutable() = default;

  virtual PjRtClient* client() const = 0;

  // Unique name for this executable, e.g., HloModule name.
  virtual absl::string_view name() const = 0;

  virtual int num_replicas() const = 0;

  virtual int num_partitions() const = 0;

  virtual int64 SizeOfGeneratedCodeInBytes() const = 0;

  virtual const DeviceAssignment& device_assignment() const = 0;
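
  // Illustrative usage sketch (comments only, not part of the API): compiling
  // an XlaComputation and running it on the devices covered by the device
  // assignment, using the Execute method declared further down in this class.
  // `client`, `computation`, and `args` are assumptions, not part of this
  // header.
  //
  //   CompileOptions options;
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
  //                       client->Compile(computation, options));
  //   ExecuteOptions run_options;
  //   TF_ASSIGN_OR_RETURN(auto results,
  //                       executable->Execute(args, run_options));
  //   // `results` has one vector of output buffers per addressable device.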

  // The replica and partition indices of device_assignment to be run by this
  // client. On single-host platforms without partitioning, this is all
  // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may
  // not be the case on multi-host platforms. If there are 4 replicas and 2
  // partitions on a single-host platform, the size of
  // addressable_device_logical_ids_ is 4*2 = 8.
  struct LogicalDeviceIds {
    int replica;
    int partition;
  };
  virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const = 0;

  // An addressable_device is one which the client can issue commands to.
  // addressable_devices()[i] is the Device to which
  // addressable_device_logical_ids()[i] is assigned.
  virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;

  // Return an HloModule (optimized) per partition.
  virtual StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const = 0;

  // Executes on devices addressable by the client. Requires that the
  // executable has a device_assignment and that all devices in the
  // device_assignment are addressable by the client.
  virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
  Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
          const ExecuteOptions& options) = 0;

  // Execute the assigned replica/partition on a given `device`. Requires that
  // the executable has a device_assignment and that `device` is present in the
  // device_assignment and addressable by the client.
  virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) = 0;

  // Execute on a given `device`. Requires `device` to be addressable by the
  // client. Requires that the executable has exactly 1 replica and 1 partition
  // and no device_assignment (thus portable).
  virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) = 0;

  // Asynchronously free resources after the last execution completes.
  virtual void Delete() = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_