/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "absl/base/macros.h"
#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/lib/threadpool.h"
#include "tensorflow/compiler/xla/stream_executor/platform.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/rng.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
#include "tensorflow/compiler/xla/stream_executor/trace_listener.h"

namespace stream_executor {

class Stream;
// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64_t bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  std::string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer;

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g., it is either a CUDA or an OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
class StreamExecutor {
 public:
  StreamExecutor(
      const Platform* platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation,
      int device_ordinal);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(DeviceOptions device_options);

  // Returns the platform that this StreamExecutor is acting upon.
  ABSL_DEPRECATED("Use platform() instead.")
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform* platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //     constant into an appropriate namespace. For example, see
  //     stream_executor::executor_sample::kKernelLoaderSpecs, from which a
  //     MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given Kernel
  //     instantiation should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, error status is returned.
  port::Status GetKernel(const MultiKernelLoaderSpec& spec, KernelBase* kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase* kernel);
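
  // Illustrative sketch of the load/unload pairing above. `executor` and
  // `my_spec` are assumed to exist elsewhere, and the KernelBase construction
  // is only sketched; see kernel.h for the exact interface.
  //
  //   KernelBase kernel(executor);
  //   if (executor->GetKernel(my_spec, &kernel).ok()) {
  //     // ... enqueue launches on a Stream ...
  //     executor->UnloadKernel(&kernel);
  //   }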

  // Loads a module for the platform this StreamExecutor is acting upon.
  //
  // `spec` describes the module to be loaded. On success writes the handle for
  // the loaded module to `module_handle` and returns Status::OK.
  // Otherwise, returns the error which has occurred.
  port::Status LoadModule(const MultiModuleLoaderSpec& spec,
                          ModuleHandle* module_handle);

  // Unloads the module with handle `module_handle`.
  bool UnloadModule(ModuleHandle module_handle);
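
  // Illustrative sketch of module load/unload; `executor` and `module_spec`
  // are assumptions, not names defined by this API.
  //
  //   ModuleHandle handle;
  //   TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &handle));
  //   // ... resolve device symbols via GetSymbol()/GetUntypedSymbol() ...
  //   if (!executor->UnloadModule(handle)) {
  //     LOG(WARNING) << "failed to unload module";
  //   }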

  port::StatusOr<std::shared_ptr<DeviceMemoryBase>> CreateOrShareConstant(
      Stream* stream, const std::vector<uint8_t>& content);

  // Synchronously allocates an array on the device of type T with
  // element_count elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64_t element_count,
                                int64_t memory_space = 0);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64_t element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }
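
  // Illustrative sketch of array allocation; the element type and count are
  // arbitrary.
  //
  //   DeviceMemory<float> raw = executor->AllocateArray<float>(256);
  //   if (!raw.is_null()) {
  //     // ... use raw ...
  //     executor->Deallocate(&raw);
  //   }
  //   // Or let ScopedDeviceMemory<T> deallocate on scope exit:
  //   ScopedDeviceMemory<float> owned = executor->AllocateOwnedArray<float>(256);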

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }

  // Allocate a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may cause
  // use-after-free issues (the specific behavior is not consistent across
  // platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> GetSubBuffer(DeviceMemory<T>* parent, uint64_t element_offset,
                               uint64_t element_count);

  // Finds a symbol within the module corresponding to `module_handle` and
  // returns device memory allocated to the symbol. The user has to make sure
  // that the type of symbol and T match.
  //  - Note: symbol_name should include its namespace as well. For example,
  //    pass "nms0::symbol" if referring to nms0::symbol.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string& symbol_name,
                                            ModuleHandle module_handle);

  // An untyped version of GetSymbol.
  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
      const std::string& symbol_name, ModuleHandle module_handle);
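
  // Illustrative sketch of symbol lookup in a previously loaded module; the
  // symbol name and element type are assumptions.
  //
  //   port::StatusOr<DeviceMemory<float>> sym =
  //       executor->GetSymbol<float>("nms0::symbol", module_handle);
  //   if (sym.ok()) {
  //     DeviceMemory<float> mem = sym.ValueOrDie();
  //     // ... read/write mem via the memcpy routines below ...
  //   }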

  // Deallocate the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase* mem);

  // Retrieves a mapping of active opaque device memory pointer to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void*, AllocRecord>* records_out);

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  void* UnifiedMemoryAllocate(uint64_t bytes);

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  void UnifiedMemoryDeallocate(void* location);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void* HostMemoryAllocate(uint64_t size);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void* location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void* location, uint64_t size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void* location) SE_MUST_USE_RESULT;
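
  // Illustrative sketch of registering externally allocated host memory so it
  // can participate in asynchronous memcpy; the buffer and size are arbitrary.
  //
  //   std::vector<char> staging(4096);
  //   if (executor->HostMemoryRegister(staging.data(), staging.size())) {
  //     // ... Stream::ThenMemcpy to/from staging.data() ...
  //     CHECK(executor->HostMemoryUnregister(staging.data()));
  //   }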

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
  // given location in device memory.
  port::Status SynchronousMemZero(DeviceMemoryBase* location,
                                  uint64_t size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                                 uint64_t size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
  bool SynchronousMemcpy(DeviceMemoryBase* device_dst, const void* host_src,
                         uint64_t size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
  bool SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& device_src,
                         uint64_t size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void* host_src, int64_t size,
                                    DeviceMemoryBase* device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(
      port::ArraySlice<T> host_src,  // non-absl ok
      DeviceMemoryBase* device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase& device_src,
                                    int64_t size, void* host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T>& device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
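
  // Illustrative round-trip sketch for the synchronous copies above;
  // `device_buf` is assumed to be a DeviceMemory<float> of matching size.
  //
  //   std::vector<float> host(device_buf.ElementCount());
  //   TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D(
  //       host.data(), host.size() * sizeof(float), &device_buf));
  //   TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H(
  //       device_buf, host.size() * sizeof(float), host.data()));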

  // Blocks the caller while a data segment of the given size is copied from the
  // device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase* device_dst,
                         const DeviceMemoryBase& device_src,
                         uint64_t size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
                       uint64_t size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be 32-bit quantified
  // (i.e. evenly divisible by 4). Returns whether the operation was
  // successfully enqueued onto the stream.
  port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
                        uint32 pattern, uint64_t size);
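
  // Illustrative sketch; `stream` and `buf` are assumptions, and note that
  // Memset32 requires `size` to be a multiple of 4 bytes.
  //
  //   TF_RETURN_IF_ERROR(executor->MemZero(stream, &buf, buf.size()));
  //   TF_RETURN_IF_ERROR(
  //       executor->Memset32(stream, &buf, /*pattern=*/0xDEADBEEF, /*size=*/64));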

  // Enables peer access from this StreamExecutor to memory
  // allocated by other, such that launched device code, memcpies, etc may
  // access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // (as in CUDA vs OpenCL) implementation.
  port::Status EnablePeerAccessTo(StreamExecutor* other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
  // this is more an up-front test as to whether it's expressly forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor* other);
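
  // Illustrative sketch of the check-then-enable pattern; both executors are
  // assumed to be backed by the same platform.
  //
  //   if (this_exec->CanEnablePeerAccessTo(other_exec)) {
  //     TF_RETURN_IF_ERROR(this_exec->EnablePeerAccessTo(other_exec));
  //   }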

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription& GetDeviceDescription() const;

  // If implemented, returns device specific measurement of load
  // (e.g. pending requests).
  int64_t GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is available.
  // If it is not available (false is returned), free/total may not be
  // initialized.
  //
  // Note: "Free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64_t* free, int64_t* total) const;
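
  // Illustrative sketch:
  //
  //   int64_t free_bytes = 0, total_bytes = 0;
  //   if (executor->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
  //     LOG(INFO) << "device memory: " << free_bytes << " free of "
  //               << total_bytes << " total";
  //   }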

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Returns the list of supported algorithms for the specified convolution
  // operation.
  bool GetConvolveAlgorithms(dnn::ConvolutionKind kind,
                             std::vector<dnn::AlgorithmDesc>* out_algorithms);

  // Returns the supported algorithms / execution plans for a convolution.
  port::Status GetConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, ScratchAllocator* scratch_allocator,
      std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans);

  port::Status GetFusedConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType bias_type,
      dnn::DataType output_type, double conv_input_scale,
      double side_input_scale, double leakyrelu_alpha, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, dnn::ActivationMode activation_mode,
      std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans);

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  // Returns the list of supported algorithms for the RNN operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc>* out_algorithms);

  // Gets the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(Stream* stream,
                             std::vector<blas::AlgorithmType>* out_algorithms);

  // Creates an RNN descriptor based on model shapes and configurations.
  // The caller retains the ownership of the descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64_t seed, ScratchAllocator* state_allocator,
      bool use_padded_io);

  // Creates an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains the ownership of the returned
  // descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int>& seq_lengths,
                                    bool time_major, dnn::DataType data_type);

  // Creates an RNN state descriptor that specifies the input or hidden state.
  // The caller retains the ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor implementation.
  internal::StreamExecutorInterface* implementation();

  // Creates a kernel which can be launched with stream.ThenLaunch, such that
  // the types of the arguments provided for launch would have to match
  // types of the arguments provided at creation time.
  //
  // The kernel has a name kernel_name, is built from the provided PTX in ptx,
  // and (optionally) from compiled PTX in cubin_data.
  // The canonical storage for both ptx and cubin_data should outlive the
  // lifetime of the kernel.
  template <typename... Args>
  port::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
      absl::string_view kernel_name, absl::string_view ptx,
      absl::Span<const uint8> cubin_data);
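
  // Illustrative sketch; the kernel name, PTX text, and argument types are
  // assumptions, and launching goes through Stream::ThenLaunch.
  //
  //   port::StatusOr<std::unique_ptr<TypedKernel<DeviceMemory<float>, int>>>
  //       kernel = executor->CreateTypedKernel<DeviceMemory<float>, int>(
  //           "my_kernel", ptx_string, /*cubin_data=*/{});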

  // Warning: use Stream::ThenLaunch instead; this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  port::Status Launch(Stream* stream, const ThreadDim& thread_dims,
                      const BlockDim& block_dims, const KernelBase& kernel,
                      const KernelArgsArrayBase& args);

  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport* AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport* AsDnn();

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is typically
  // not user-facing, as users will use the Stream::ThenBlas* family of routines
  // to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by this
  // object for memoization. This BLAS interface is also only expected to be
  // used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport* AsBlas();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener* listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener* listener);

  // Returns allocator statistics.
  std::optional<AllocatorStats> GetAllocatorStats();

  // Clears the internal stats except for the `in_use` fields
  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`.
  bool ClearAllocatorStats();

  // Returns an allocator which delegates to this stream executor for memory
  // allocation.
  StreamExecutorMemoryAllocator* GetAllocator() { return &allocator_; }

  internal::StreamExecutorInterface* GetInternalExecutor() {
    return implementation_.get();
  }

  // Returns a stream allocated by this executor, or nullptr if not found.
  // Performs a linear search over alive GPU streams.
  Stream* FindAllocatedStream(void* gpu_stream) {
    return implementation()->FindAllocatedStream(gpu_stream);
  }
 private:
  template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
            typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Synchronously allocates size bytes on the underlying platform and returns
  // a DeviceMemoryBase representing that allocation. In the case of failure,
  // nullptr is returned.
  DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space);

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport* AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream* stream);

  // Without blocking the device, retrieves the current stream status.
  port::Status GetStatus(Stream* stream);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const std::string& symbol_name, ModuleHandle module_handle,
                 void** mem, size_t* bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream* stream, void* host_dst,
              const DeviceMemoryBase& device_src, uint64_t size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a host memory source, with target size size.
  bool Memcpy(Stream* stream, DeviceMemoryBase* device_dst,
              const void* host_src, uint64_t size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a device source location, with target size size. Peer access should
  // have been enabled between the StreamExecutors owning the device memory
  // regions.
  bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* device_dst,
                            const DeviceMemoryBase& device_src, uint64_t size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream* stream, std::function<void()> callback);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  // This is the preferred form for a callback that may return an error.
  bool HostCallback(Stream* stream, std::function<port::Status()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event* event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event* event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream* stream, Event* event);

  // Waits for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream* stream, Event* event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event* event);

  // Allocates stream resources on the underlying platform and initializes its
  // internals.
  bool AllocateStream(Stream* stream);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream* stream);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream* dependent, Stream* other);

  // Allocates timer resources on the underlying platform and initializes its
  // internals.
  bool AllocateTimer(Timer* timer);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer* timer);

  // Records a start event for an interval timer.
  bool StartTimer(Stream* stream, Timer* timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream* stream, Timer* timer);

  // Allocates a new metadata object, appropriately populated, on the heap, with
  // ownership transfer to caller.
  std::unique_ptr<DeviceDescription> CreateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void* opaque, uint64_t bytes);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void* opaque);

  // Calls the relevant TraceListener routine to begin tracing for the specified
  // asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable absl::Mutex mu_;

  // Reference to the platform that created this executor.
  const Platform* platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void*, AllocRecord> mem_allocs_ ABSL_GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when asked
  // for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ ABSL_GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when asked
  // for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ ABSL_GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when asked
  // for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when asked
  // for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ ABSL_GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried from GetDeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      ABSL_GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static constexpr int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener*> listeners_ ABSL_GUARDED_BY(mu_);

  // Allocated memory in bytes.
  int64_t mem_alloc_bytes_;

  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
  int64_t memory_limit_bytes_;

  StreamExecutorMemoryAllocator allocator_;

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
class ScopedModuleHandle {
 public:
  explicit ScopedModuleHandle(StreamExecutor* executor,
                              ModuleHandle module_handle)
      : executor_(executor), module_handle_(module_handle) {}

  ScopedModuleHandle(ScopedModuleHandle&& other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
  }

  ScopedModuleHandle& operator=(ScopedModuleHandle&& other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
    return *this;
  }

  ~ScopedModuleHandle() {
    if (static_cast<bool>(module_handle_)) {
      CHECK(executor_->UnloadModule(module_handle_));
    }
  }

 private:
  StreamExecutor* executor_;
  ModuleHandle module_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
};
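
// Illustrative sketch of RAII module management with ScopedModuleHandle;
// `executor` and `module_spec` are assumptions.
//
//   ModuleHandle handle;
//   TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &handle));
//   ScopedModuleHandle scoped_module(executor, handle);
//   // The module is unloaded when `scoped_module` goes out of scope.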

////////////
// Inlines

template <typename... Args>
inline port::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
                                  absl::string_view ptx,
                                  absl::Span<const uint8> cubin_data) {
  auto kernel_base = std::make_unique<TypedKernel<Args...>>(this);
  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
  }

  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64_t element_count,
                                                     int64_t memory_space) {
  uint64_t bytes = sizeof(T) * element_count;
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const std::string& symbol_name, ModuleHandle module_handle) {
  port::StatusOr<DeviceMemoryBase> untyped_symbol =
      GetUntypedSymbol(symbol_name, module_handle);
  if (!untyped_symbol.ok()) {
    return untyped_symbol.status();
  }
  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor* parent,
                                              DeviceMemoryBase value)
    : wrapped_(value),
      device_ordinal_(parent->device_ordinal()),
      allocator_(parent->GetAllocator()) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor* parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT*>(&local[0]),
                                   ptr()->size())) {
      TF_CHECK_OK(Free());
    }
  }
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  DeviceMemoryBase buf = Allocate(sizeof(T), /*memory_space=*/0);
  if (buf.is_null()) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result(buf);
  bool ok = SynchronousMemZero(&result, sizeof(T)).ok();
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::GetSubBuffer(DeviceMemory<T>* parent,
                                             uint64_t element_offset,
                                             uint64_t element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void* opaque = implementation_->GetSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count));
}

}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_