1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
18
19 #include <memory>
20 #include <string>
21 #include <vector>
22
23 #include "absl/container/inlined_vector.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/synchronization/notification.h"
26 #include "absl/types/optional.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/compiler/xla/client/executable_build_options.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/layout.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/shape.h"
35 #include "tensorflow/compiler/xla/status.h"
36 #include "tensorflow/compiler/xla/statusor.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/platform/casts.h"
41 #include "tensorflow/core/platform/errors.h"
42 #include "tensorflow/core/platform/fingerprint.h"
43 #include "tensorflow/core/platform/thread_annotations.h"
44 #include "tensorflow/core/platform/types.h"
45
46 // API notes:
47 // PjRt stands for "Pretty much Just another RunTime".
48
49 namespace xla {
50
51 using PjRtPlatformId = uint64;
52
53 constexpr char kCpuName[] = "cpu";
54 constexpr char kGpuName[] = "gpu";
55 constexpr char kTpuName[] = "tpu";
56 static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(kCpuName);
57 static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(kGpuName);
58 static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(kTpuName);
59
// Identifies the type of runtime being used under a PjRtClient (see
// PjRtClient::runtime_type()).
enum PjRtRuntimeType {
  kStreamExecutor,
  kTfrt,
};
PjRtRuntimeTypeString(PjRtRuntimeType type)61 static constexpr absl::string_view PjRtRuntimeTypeString(PjRtRuntimeType type) {
62 switch (type) {
63 case kStreamExecutor:
64 return "stream_executor";
65 case kTfrt:
66 return "tfrt";
67 }
68 }
69
70 class PjRtClient;
71
72 class PjRtDevice {
73 public:
~PjRtDevice()74 virtual ~PjRtDevice() {}
75
76 // Return the client that owns this device.
77 virtual PjRtClient* client() const = 0;
78
79 // Whether client can issue command to this device.
80 virtual bool IsAddressable() const = 0;
81
82 // The ID of this device. IDs are unique among devices of this type
83 // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
84 // hosts' devices. This is the ID that should be used in a DeviceAssignment.
85 virtual int id() const = 0;
86
87 // The index of the process that this device belongs to, i.e. is addressable
88 // from. This is not always identical to PjRtClient::process_index() in a
89 // multi-process setting, where each client can see devices from all
90 // processes, but only a subset of them are addressable and have the same
91 // process_index as the client.
92 virtual int process_index() const = 0;
93
94 // Deprecated; please switch to process_index().
task_id()95 int task_id() const { return process_index(); }
96
97 // Opaque hardware ID, e.g., the CUDA device number, useful for identifying
98 // which GPU when interacting with non-JAX code. In general, not guaranteed to
99 // be dense, and -1 if undefined.
100 virtual int local_hardware_id() const = 0;
101
102 // A vendor-dependent string that uniquely identifies the kind of device,
103 // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
104 // compatible compilation.
105 virtual absl::string_view device_kind() const = 0;
106
107 virtual std::string DebugString() const = 0;
108
109 // Transfer the given literal to the infeed queue.
110 virtual Status TransferToInfeed(const LiteralSlice& literal) = 0;
111
112 // Transfer and return a value of the given shape from the outfeed queue.
113 virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) = 0;
114 };
115
116 // Forward declaration.
117 class PjRtBuffer;
118
119 // Helper struct for cross host transfers, returned by the callback from a call
120 // to PjRtBuffer::MakeCrossHostReceiveBuffers or
121 // PjRtBuffer::MakeCrossHostReceiveBuffersForGather.
122 struct PjRtCrossHostRecvBuffer {
123 // There is one serialized_descriptor per sub-buffer being gathered (i.e. a
124 // single descriptor if the buffer is returned from a call to
125 // MakeCrossHostReceiveBuffers). The descriptor should be transmitted to the
126 // sender(s) and passed to a call to src_buffer->CopyToRemoteDevice.
127 absl::InlinedVector<std::string, 1> serialized_descriptors;
128 // The buffer that will hold the result of the transfer.
129 std::unique_ptr<PjRtBuffer> buffer;
130 };
131 using PjRtCrossHostRecvNotifier =
132 std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
133
134 struct CompileOptions {
135 // The layouts of the arguments that the computation should expect.
136 absl::optional<std::vector<Shape>> argument_layouts;
137
138 // If true, the supplied computation expects its arguments to be wrapped in a
139 // tuple and passed as a single parameter.
140 bool parameter_is_tupled_arguments = false;
141
142 // XLA's compilation time options.
143 ExecutableBuildOptions executable_build_options;
144
145 // If true, the executable can be run on any device. May only be true if
146 // !executable_build_options.has_device_assignment(), so only applies to
147 // single-device executables. Beware: on GPUs, sometimes an executable
148 // compiled for one device doesn't run on another.
149 bool compile_portable_executable = false;
150 };
151
152 class PjRtExecutable;
153
154 // Encapsulates the state of Python session with XLA.
155 //
156 // It is the responsibility of the client of this API to keep the PjRtClient
157 // alive as long as any of the other runtime objects are alive.
158 class PjRtClient {
159 public:
160 virtual ~PjRtClient() = default;
161
162 // Return the process index of this client. Always 0 in single-process
163 // settings.
164 virtual int process_index() const = 0;
165
166 // Deprecated; please switch to process_index().
task_id()167 int task_id() const { return process_index(); }
168
169 // Return the number of devices in the entire computation. In multi-headed
170 // client setting, some are addressable by this client, some are not. In a
171 // single-client setting, this is equal to the number of addressable devices.
172 virtual int device_count() const = 0;
173
174 // Return number of addressable devices. Addressable devices are those that
175 // the client can issue commands to.
176 virtual int addressable_device_count() const = 0;
177
178 // Return all devices in the entire computation, including addressable and
179 // non-addressable devices.
180 virtual absl::Span<PjRtDevice* const> devices() const = 0;
181
182 // Return only addressable devices.
183 virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
184
185 // Lookup any PjRtDevice for a given PjRtDevice::id().
186 virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
187
188 // Return an addressable PjRtDevice for a given
189 // PjRtDevice::local_hardware_id().
190 virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
191 int local_hardware_id) const = 0;
192
193 // Return an ID that identifies the platform (CPU/GPU/TPU).
194 virtual PjRtPlatformId platform_id() const = 0;
195
196 // Returns a string that identifies the platform (CPU/GPU/TPU).
197 virtual absl::string_view platform_name() const = 0;
198
199 // Returns a string containing human-readable, platform-specific version info
200 // (e.g. the CUDA version on GPU or libtpu version on Cloud TPU).
201 virtual absl::string_view platform_version() const = 0;
202
203 // Returns an enum that identifies the type of runtime being used under this
204 // client.
205 virtual PjRtRuntimeType runtime_type() const = 0;
206
207 // Return a device-specific default device assignment, e.g., GPU and TPU may
208 // be different.
209 virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
210 int num_replicas, int num_partitions) const = 0;
211
212 // Returns a backend-specific HLO cost analysis visitor.
213 virtual StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() = 0;
214
215 // Compile `computation` with given `options`.
216 virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
217 const XlaComputation& computation, CompileOptions options) = 0;
218
219 // Generates a unique fingerprint for `executable`, may be absl::nullopt.
220 virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
221 const PjRtExecutable& executable) const = 0;
222
223 // Returns a platform-specific serialization of `executable`. The
224 // serialization is not guaranteed to be stable over time. `executable` must
225 // have been produced by this client.
226 virtual StatusOr<std::string> SerializeExecutable(
227 const PjRtExecutable& executable) const = 0;
228
229 // Deserializes a serialized executable as produced by
230 // SerializeExecutable(). `serialized` must have been produced by a client of
231 // the same platform and version as this one.
232 virtual StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
233 absl::string_view serialized, CompileOptions options) = 0;
234
235 // Creates a buffer on the device without initializing or copying any data.
236 virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
237 const Shape& shape, PjRtDevice* device) = 0;
238
239 // A client may want to create a buffer, and hand the buffer to other PjRt
240 // methods, before the data to store in the buffer is available to the client.
241 // This is supported using CreateBuffersForAsyncTransfer, which returns an
242 // AsyncBufferTransferManager helper object.
243 //
244 // The PjRtBuffers can be retrieved from the AsyncBufferTransferManager and
245 // safely passed immediately to downstream PjRt method calls. Subsequently the
246 // client can call methods on the AsyncBufferTransferManager object to copy
247 // data into the buffers, and once the data copies are complete, the buffers'
248 // definition events will automatically become ready, unblocking downstream
249 // consumers of the buffers.
250 //
251 // A single call to CreateBuffersForAsyncTransfer creates a "batch" of buffers
252 // that share a single definition event, which may amortize some performance
253 // overheads, but means that none of the buffers are available to downstream
254 // consumers until all the transfers have completed. Multiple calls to
255 // CreateBuffersForAsyncTransfer should be made if it is desirable for buffers
256 // to become available as soon as transfers into them complete.
257
258 // Helper class to all clients to asynchronously transfer data into buffers
259 // that are created uninitialized, see comments immediately above.
260 class AsyncBufferTransferManager {
261 public:
262 virtual ~AsyncBufferTransferManager() = default;
263
264 // Returns the number of buffers managed by this object.
265 virtual size_t buffer_count() const = 0;
266
267 // Returns the destination device of the transfers.
268 virtual PjRtDevice* device() const = 0;
269
270 // Returns buffer_index, which can be passed to downstream consumers
271 // immediately and will become available once transfers complete. May not
272 // be called more than once for a given buffer_index.
273 //
274 // RetrieveBuffer can be called at any convenient time; transfer methods
275 // can safely be called for a buffer index after RetrieveBuffer has been
276 // called.
277 virtual std::unique_ptr<PjRtBuffer> RetrieveBuffer(int buffer_index) = 0;
278
279 // Transfers 'literal' into buffer_index. No transfer calls into
280 // buffer_index can be made after this call. on_done is called when the
281 // transfer is complete but before the buffers are made available to
282 // their consumers. 'literal' must remain in scope until on_done is
283 // called.
284 virtual Status TransferLiteralToBuffer(int buffer_index,
285 const LiteralSlice& literal,
286 std::function<void()> on_done) = 0;
287
288 // Returns the on-device size in bytes of buffer buffer_index.
289 virtual size_t buffer_size(int buffer_index) const = 0;
290
291 // Transfers 'data' into buffer_index. 'data' must be already laid out in
292 // the correct on-device format, for example returned by a call to
293 // buffer->CopyRawToHost. No transfer calls into buffer_index can be made
294 // after this call. on_done is called when the transfer is complete but
295 // before the buffers are made available to their consumers. 'data' must
296 // remain in scope until on_done is called.
297 virtual Status TransferRawDataToBuffer(int buffer_index,
298 absl::string_view data,
299 std::function<void()> on_done) = 0;
300
301 // Transfers 'data' into a sub-buffer of buffer_index starting at offset, of
302 // length transfer_size. 'data' must be already laid out in the correct
303 // on-device format, for example returned by a call to
304 // buffer->CopyRawToHost. If is_last_transfer is false then the buffer
305 // remains unavailable to consumers after the transfer completes. If
306 // is_last_transfer is true then the buffer becomes available to consumers
307 // after the transfer completes, and no transfer calls into buffer_index can
308 // be made after this call. on_done is called when the transfer is complete
309 // but before the buffers are made available to their consumers. 'data' must
310 // remain in scope until on_done is called.
311 virtual Status TransferRawDataToSubBuffer(
312 int buffer_index, const void* data, int64_t offset,
313 int64_t transfer_size, bool is_last_transfer,
314 std::function<void()> on_done) = 0;
315
316 // Indicates that a client error occurred and the transfers will never
317 // complete. Puts all buffers in an error state. For the stream executor
318 // client, since error states are not well supported, this triggers a fatal
319 // error.
320 //
321 // SetTransferError may be called at most once, and may not be called unless
322 // at least one buffer has not yet had its final transfer initiated.
323 virtual void SetTransferError(Status error) = 0;
324 };
325
326 // Returns a manager for async transfers into a set of buffers with on-host
327 // shapes 'shapes'.
328 virtual StatusOr<std::unique_ptr<AsyncBufferTransferManager>>
329 CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
330 PjRtDevice* device) = 0;
331
332 // Describes the semantics the caller to BufferFromHostBuffer expects from the
333 // runtime, in a total order from most restrictive to least restrictive.
334 enum class HostBufferSemantics {
335 // The runtime may not hold references to `data` after the call to
336 // `BufferFromHostBuffer` completes. The caller promises that `data` is
337 // immutable and will not be freed only for the duration of the
338 // BufferFromHostBuffer call. `on_done_with_host_buffer` will be called
339 // before `BufferFromHostBuffer` returns.
340 kImmutableOnlyDuringCall,
341
342 // The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
343 // returns while the runtime completes a transfer to the device. The caller
344 // promises not to mutate or free `data` until the transfer completes, at
345 // which point the runtime will call `on_done_with_host_buffer`. It is also
346 // correct to wait on the host (directly or indirectly) for the buffer's
347 // definition event to complete.
348 kImmutableUntilTransferCompletes,
349
350 // The PjRtBuffer may alias `data` internally and the runtime may use the
351 // `data` contents as long as the buffer is alive. The caller promises to
352 // keep `data` alive and not to mutate its contents as long as the buffer is
353 // alive; to notify the caller that the buffer may be freed, the runtime
354 // will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On
355 // non-CPU platforms this acts identically to
356 // kImmutableUntilTransferCompletes.
357 kZeroCopy,
358 };
359
360 // on_done_with_host_buffer is optional and may be null.
361 // on_done_with_host_buffer will be called iff an OK status is returned.
362 //
363 // `data` points to the backing array of the host buffer. Caution:
364 // `byte_strides` are allowed to be negative, in which case `data` may need
365 // to point to the interior of the buffer, not necessarily its start.
366 //
367 // If byte_strides is omitted, the array is assumed to have a dense layout
368 // with dimensions in major-to-minor order.
369 virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
370 const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
371 absl::optional<absl::Span<int64_t const>> byte_strides,
372 HostBufferSemantics host_buffer_semantics,
373 std::function<void()> on_done_with_host_buffer, PjRtDevice* device) = 0;
374
375 // Note that literal must remain in scope until the transfer has completed, so
376 // the caller should, for example, wait for BlockHostUntilReady() completes on
377 // the return value before letting literal go out of scope.
378 virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
379 const LiteralSlice& literal, PjRtDevice* device) = 0;
380
381 // Creates a PjRtBuffer that is a non-owned view of an on-device
382 // buffer (typically allocated by another library).
383 // on_delete_callback is called when the PjRtBuffer is done with the on-device
384 // buffer. The buffer may be mutated, for example, if the buffer is donated
385 // to an Execute operation.
386 // TODO(phawkins): Currently this API assumes the buffer is ready to use
387 // immediately on the device. Extend it to support, for example, waiting for a
388 // CUDA stream/event.
389 virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
390 void* device_ptr, const Shape& shape, PjRtDevice* device,
391 std::function<void()> on_delete_callback) = 0;
392
393 // Returns platform-dependent address for the given buffer that is often but
394 // not guaranteed to be the physical/device address.
395 virtual StatusOr<std::uintptr_t> UnsafeBufferPointer(PjRtBuffer* buffer);
396
397 // Asynchronously makes a vector of PjRtBuffers that can be used to receive
398 // cross host transfers using `client` on `device'. `shapes` must be the exact
399 // shapes, with identical layouts, corresponding to the buffers that will be
400 // sent. When resources for the transfer are available, notifier will be
401 // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
402 // shape in `shapes`. Each struct contains a buffer that will contain the
403 // received value, and an opaque string that should be transmitted to the
404 // sending host and used in a call to CopyToRemoteDevice. None of the recv
405 // buffers will become ready until *all* of the sends have completed.
406 virtual void MakeCrossHostReceiveBuffers(
407 absl::Span<const Shape> shapes, PjRtDevice* device,
408 PjRtCrossHostRecvNotifier&& notifier) = 0;
409
410 // Asynchronously makes a vector of PjRtBuffers that can be used to receive
411 // cross host transfers, as in MakeCrossHostReceiveBuffers above, however
412 // each buffer expects to be "gathered" using multiple sends, one for each of
413 // a set of sub-slices of the destination buffer.
414 //
415 // For each value in shapes there is a corresponding FullGatherDetails struct
416 // that describes the sub-slices.
417 struct GatherDetails {
418 // The dimensions of the corresponding buffer that the gather slices
419 // into. These dimensions must be the major dimensions in the on-device
420 // layout of the buffer, and must all be untiled. The scatter acts as if
421 // the buffer were transposed/reshaped so that all of these dimensions were
422 // combined into a single dimension whose size is the product of the
423 // dimensions, and the slice indices correspond to indices in that single
424 // combined dimension.
425 //
426 // For example, if the shape is [3, 4, 128, 128] with [3, 4] as the major
427 // dimensions in the layout, and dimensions = {0, 1}, then the buffer is
428 // treated as if it were shape [12, 128, 128] and the indices in
429 // slice_boundaries range in [0, 12].
430 absl::InlinedVector<int, 3> dimensions;
431 // The cumulative indices in dimension of the slices. For example, if
432 // shape.dimensions(dimension)==10, setting slice_boundaries to {2, 5, 10}
433 // would correspond to 3 slices of sizes {2, 3, 5} respectively. If the last
434 // entry in slice_boundaries is less than the size of the combined gather
435 // dimension, the trailing data in the buffer is undefined after the receive
436 // completes.
437 std::vector<int64> slice_boundaries;
438 };
439 virtual void MakeCrossHostReceiveBuffersForGather(
440 absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
441 PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) = 0;
442
443 // Create ChannelHandles for XLA send/recv.
444 virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
445 virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
446 virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
447
448 // TODO(zhangqiaorjc): Experimental API to be removed.
449 // Defragment device memory.
450 virtual Status Defragment() = 0;
451 };
452
453 // Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
454 // can be either valid or invalid. An invalid buffer is one that has never been
455 // initialized, or a buffer that has been deleted (e.g., by calling Delete, or
456 // by donating it to a computation that aliases an input parameter to an
457 // output). We allow PjRtBuffer objects to outlive the underlying device
458 // buffers so we can decouple buffer lifetimes from the corresponding Python
459 // references if needed. Thread-safe.
460 class PjRtBuffer {
461 public:
462 virtual ~PjRtBuffer() = default;
463
464 virtual const Shape& on_device_shape() const = 0;
465
466 // Same as on_device_shape when the shape is static. When the shape is
467 // dynamic, it gathers the metadata from the device and returns a static shape
468 // representing the logical shape of the data. This approach is identical to
469 // how tensorflow and xrt setup the output buffer in the graph.
470 //
471 // Since this method actually acquires locks and communicate with the device,
472 // it does not have the const qualifier, similar to what ToLiteral does.
473 virtual StatusOr<Shape> logical_on_device_shape() = 0;
474 virtual PjRtDevice* device() const = 0;
475 virtual PjRtClient* client() const = 0;
476
477 // ExternalReference is a potentially long-lived reference held while a buffer
478 // is being shared by an external framework, e.g., NumPy. A client acquires an
479 // external reference by calling PjRtBuffer::AcquireExternalReference() and
480 // releases it by deleting the ExternalReference. The external framework
481 // should not modify the underlying buffer unless it is confident via its own
482 // synchronization that modifications do not race with reads from the
483 // PjRtBuffer.
484 class ExternalReference {
485 public:
486 virtual ~ExternalReference() = 0;
487 // Return opaque device memory pointer to root buffer.
OpaqueDeviceMemoryDataPointer()488 void* OpaqueDeviceMemoryDataPointer() const { return data_ptr_; }
489
490 protected:
491 void* data_ptr_;
492 };
493 virtual StatusOr<std::unique_ptr<ExternalReference>>
494 AcquireExternalReference() = 0;
495
496 // Copies the buffer's value into `literal`. Calls `on_ready` when the value
497 // (or an error) is ready. The transfer respects the layout of `literal`; to
498 // specify a particular layout, set the layout before calling `ToLiteral`.
499 virtual void ToLiteral(MutableLiteralBase* literal,
500 std::function<void(Status)> on_ready) = 0;
501
502 // Synchronous overload of ToLiteral, as a convenience.
ToLiteral(MutableLiteralBase * literal)503 Status ToLiteral(MutableLiteralBase* literal) {
504 absl::Notification done;
505 Status status;
506 ToLiteral(literal, [&](Status s) {
507 status = std::move(s);
508 done.Notify();
509 });
510 done.WaitForNotification();
511 return status;
512 }
513
514 // Convenience synchronous overload that allocates a literal with a default
515 // layout.
ToLiteral()516 StatusOr<std::shared_ptr<Literal>> ToLiteral() {
517 auto literal = std::make_shared<Literal>(
518 ShapeUtil::DeviceShapeToHostShape(on_device_shape()));
519 TF_RETURN_IF_ERROR(ToLiteral(literal.get()));
520 return literal;
521 }
522
523 // Returns the number of bytes of the buffer storage on the device.
524 virtual StatusOr<size_t> GetOnDeviceSizeInBytes() const = 0;
525
526 // Transfers a sub-range of the on-device representation of the buffer.
527 // offset+transfer_size must be less than GetOnDeviceSizeInBytes. on_ready
528 // is called if and only if CopyRawToHost returns OK. on_ready will be called
529 // with a non-OK status if the buffer asynchronously transitions to an error
530 // state.
531 virtual Status CopyRawToHost(void* dst, int64_t offset, int64_t transfer_size,
532 std::function<void(Status)> on_ready) = 0;
533
534 // Drops the buffer's reference to its associated device memory, leaving the
535 // buffer in an invalid state. The memory will be freed lazily when all async
536 // operations using the buffer have completed, according to the allocation
537 // semantics of the underlying platform. Delete may briefly block if another
538 // thread is in the process of enqueuing an operation on this buffer, but it
539 // will never block for a stream operation to complete. If an external
540 // framework holds a reference to the TrackedDeviceBuffer via
541 // GetBufferWithExternalReference, the memory will not be freed until the
542 // external framework drops the reference.
543 virtual void Delete() = 0;
544
545 // Similar to Delete, drops the buffer's reference to its associated device
546 // memory, leaving the buffer in an invalid state, but transfers the device
547 // memory ownership out via an ExternalReference rather than
548 // freeing the device memory, so that another framework can take ownership of
549 // it. A return value of nullptr indicates that PjRtBuffer has been
550 // deleted. The buffer returned from Release may be safely dropped at any time
551 // even if it still has pending async operations. The client should call
552 // BlockHostUntilReady before calling ReleaseDeviceMemoryOwnership with
553 // wait_for_operations_to_complete=false, to ensure that the host has
554 // synchronized past any outstanding write operations to the buffer. If
555 // wait_for_operations_to_complete=true the host will block until any
556 // potentially outstanding asynchronous operations have completed before
557 // returning, in which case it is safe to read or mutate the returned buffer.
558 // If the buffer was shared via an external reference it is the client's
559 // responsibility that accesses via that reference do not interfere with
560 // accesses via the buffer returned from ReleaseDeviceMemoryOwnership.
561 virtual StatusOr<std::unique_ptr<ExternalReference>>
562 ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) = 0;
563
564 // True if and only if Delete or Release has previously been called.
565 virtual bool IsDeleted() = 0;
566
567 // Copies the buffer to device `dst_device`, performing a d2d transfer when
568 // `dst_device` is sharing the same Client, and performing a d2h and h2d copy
569 // if `dst_device` lives on a different Client.
570 // Returns an error if the buffer is already on dst_device.
571 virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
572 PjRtDevice* dst_device) = 0;
573
574 // Copies the buffer to the remote device encoded in serialized_descriptor.
575 // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
576 // remote host's destination device. MakeCrossHostReceiveBuffers takes an
577 // array of shapes to construct the destination buffers, and a callback
578 // supplies an array containing both the destination buffers, and a serialized
579 // descriptor for each buffer. For each destination buffer there should be a
580 // matching call to src->CopyToRemoteDevice on a remote host for a src buffer
581 // of the corresponding shape. serialized_descriptor is the string returned by
582 // the callback along with the corresponding destination buffer.
583 virtual Status CopyToRemoteDevice(
584 absl::string_view serialized_descriptor) = 0;
585 struct ScatterDetails {
586 // The dimensions of the corresponding buffer that the scatter slices
587 // across. These dimensions must be the major dimensions in the on-device
588 // layout of the buffer, and must all be untiled. The scatter acts as if
589 // the buffer were transposed/reshaped so that all of these dimensions were
590 // combined into a single dimension whose size is the product of the
591 // dimensions, and the slice indices correspond to indices in that single
592 // combined dimension.
593 //
594 // For example, if the shape is [3, 4, 128, 128] with [3, 4] as the major
595 // dimensions in the layout, and dimensions = {0, 1}, then the buffer is
596 // treated as if it were shape [12, 128, 128] and the indices in slices
597 // range in [0, 12].
598 absl::InlinedVector<int, 3> dimensions;
599 // The start and end indices of the slices.
600 std::vector<std::pair<int64, int64>> slices;
601 };
602 virtual Status CopyToRemoteDeviceScattered(
603 absl::Span<const std::string> serialized_descriptors,
604 const ScatterDetails& scatter_details) = 0;
605
606 // Blocks the host until the buffer's value has been computed and is ready for
607 // immediate use on the device. Useful in particular for timing benchmarks.
608 virtual Status BlockHostUntilReady() = 0;
609
610 // Whether this buffer is on CPU and thus allows for certain optimizations.
611 virtual bool IsOnCpu() const = 0;
612 };
613
// Base class for the opaque context that can be attached to an execution via
// ExecuteOptions::context. It declares no members of its own besides a virtual
// destructor.
class ExecuteContext {
 public:
  virtual ~ExecuteContext() = default;
};
618
619 struct ExecuteOptions {
620 // If true, the client must pass a single PjRtBuffer which contains all of
621 // the arguments as a single XLA tuple, otherwise each argument must be
622 // passed in its own PjRtBuffer. May only be true if the executable was
623 // compiled with parameter_is_tupled_arguments==true.
624 bool arguments_are_tupled = false;
625 // If true, the computation must return a tuple, which will be destructured
626 // into its elements.
627 bool untuple_result = false;
628 // If non-zero, identifies this execution as part of a potentially
629 // multi-device launch. This can be used to detect scheduling errors, e.g. if
630 // multi-host programs are launched in different orders on different hosts,
631 // the launch IDs may be used by the runtime to detect the mismatch.
632 int32 launch_id = 0;
633 // If non-null, an opaque context passed to an execution that may be used to
634 // supply additional arguments to a derived class of PjRtExecutable.
635 const ExecuteContext* context = nullptr;
636 // If true, check that the PjRtBuffer argument shapes match the compiled
637 // shapes. Otherwise, any shape with the right size on device may be passed.
638 bool strict_shape_checking = true;
639 };
640
641 // Represents a compiled computation that can be executed given handles to
642 // device-allocated literals. If any input/output alias has been specified in
643 // the computation, the parameter containing the input buffer will be donated
644 // when passed to the execution.
645 class PjRtExecutable {
646 public:
647 virtual ~PjRtExecutable() = default;
648
649 virtual PjRtClient* client() const = 0;
650
651 // Unique name for this executable, e.g., HloModule name.
652 virtual absl::string_view name() const = 0;
653
654 virtual int num_replicas() const = 0;
655
656 virtual int num_partitions() const = 0;
657
658 virtual int64 SizeOfGeneratedCodeInBytes() const = 0;
659
660 virtual const DeviceAssignment& device_assignment() const = 0;
661
662 // The replica and partition indices of device_assignment to be run by this
663 // client. On single-host platforms without partitioning, this is all replicas
664 // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
665 // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
666 // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
667 struct LogicalDeviceIds {
668 int replica;
669 int partition;
670 };
671 virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
672 const = 0;
673
674 // An addressable_device is one which the client can issue commands to.
675 // addressable_devices()[i] is the Device to which
676 // addressable_device_logical_ids()[i] is assigned.
677 virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
678
679 // Return an HloModule (optimized) per partition.
680 virtual StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
681 const = 0;
682
683 // Executes on devices addressable by the client. Requires executable has a
684 // device_assignment and all devices in the device_assignment are addressable
685 // by the client.
686 // `argument_handles` is `[num_devices, num_args]`.
687 virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
688 Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
689 const ExecuteOptions& options) = 0;
690
691 // Execute the assigned replica/partition on a given `device`. Requires
692 // executable has a device_assignment, `device` is present in the
693 // device_assignment and addressable by the client.
694 virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
695 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
696 const ExecuteOptions& options) = 0;
697
698 // Execute on a given `device`. Requires `device` to be addressable by client.
699 // Requires executable has exactly 1 replica and 1 partition and no
700 // device_assignment (thus portable).
701 virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
702 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
703 const ExecuteOptions& options) = 0;
704
705 // Asynchronously free resources after the last execution completes.
706 virtual void Delete() = 0;
707
708 // True if on-device resources associated with the executable are freed.
709 virtual bool IsDeleted() = 0;
710 };
711
712 } // namespace xla
713
714 #endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
715