/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "absl/base/macros.h"
#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/lib/threadpool.h"
#include "tensorflow/compiler/xla/stream_executor/platform.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/rng.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
#include "tensorflow/compiler/xla/stream_executor/trace_listener.h"

namespace stream_executor {

class Stream;

// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64_t bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  std::string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer;

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. an executor is either a CUDA or an OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
class StreamExecutor {
 public:
  StreamExecutor(
      const Platform* platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation,
      int device_ordinal);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(DeviceOptions device_options);

  // Returns the kind of the platform that this StreamExecutor is acting upon.
  ABSL_DEPRECATED("Use platform() instead.")
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform* platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //    constant into an appropriate namespace. For example, see
  //    stream_executor::executor_sample::kKernelLoaderSpecs, from which a
  //    MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given Kernel
  //    instantiation should not be loaded more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, an error status is returned.
  port::Status GetKernel(const MultiKernelLoaderSpec& spec, KernelBase* kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase* kernel);

  // Loads a module for the platform this StreamExecutor is acting upon.
  //
  // `spec` describes the module to be loaded.  On success writes the handle for
  // the loaded module to `module_handle` and returns Status::OK.
  // Otherwise, returns the error which has occurred.
  port::Status LoadModule(const MultiModuleLoaderSpec& spec,
                          ModuleHandle* module_handle);

  // Unloads the module with handle `module_handle`.
  bool UnloadModule(ModuleHandle module_handle);

  // Creates a device-backed constant with the given contents, or returns a
  // handle to an existing shared constant with identical contents.
  port::StatusOr<std::shared_ptr<DeviceMemoryBase>> CreateOrShareConstant(
      Stream* stream, const std::vector<uint8_t>& content);

  // Synchronously allocates an array on the device of type T with element_count
  // elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64_t element_count,
                                int64_t memory_space = 0);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64_t element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }
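
  // Example of the allocation helpers above (an illustrative sketch only;
  // `executor` is assumed to be an already-initialized StreamExecutor*):
  //
  //   DeviceMemory<float> buf = executor->AllocateArray<float>(1024);
  //   if (buf.is_null()) { /* handle allocation failure */ }
  //   ...
  //   executor->Deallocate(&buf);
  //
  //   // Or let ScopedDeviceMemory free the buffer automatically:
  //   auto owned = executor->AllocateOwnedArray<float>(1024);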

  // Allocates a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may cause
  // use-after-free issues (the specific behavior is not consistent across
  // platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> GetSubBuffer(DeviceMemory<T>* parent, uint64_t element_offset,
                               uint64_t element_count);

  // Finds a symbol within the module corresponding to `module_handle` and
  // returns device memory allocated to the symbol. The caller must ensure
  // that the symbol's type and T match.
  // - Note: symbol_name should include its namespace as well. For example,
  //         pass "nms0::symbol" if referring to nms0::symbol.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string& symbol_name,
                                            ModuleHandle module_handle);

  // An untyped version of GetSymbol.
  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
      const std::string& symbol_name, ModuleHandle module_handle);
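
  // Example (an illustrative sketch; the module handle and symbol name are
  // hypothetical):
  //
  //   port::StatusOr<DeviceMemory<int>> sym =
  //       executor->GetSymbol<int>("nms0::symbol", module_handle);
  //   if (sym.ok()) {
  //     DeviceMemory<int> mem = sym.ValueOrDie();
  //   }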

  // Deallocates the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase* mem);

  // Retrieves a mapping of active opaque device memory pointer to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void*, AllocRecord>* records_out);

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  void* UnifiedMemoryAllocate(uint64_t bytes);

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  void UnifiedMemoryDeallocate(void* location);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void* HostMemoryAllocate(uint64_t size);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void* location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void* location, uint64_t size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void* location) SE_MUST_USE_RESULT;
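
  // Example of host memory registration (an illustrative sketch; `bytes` and
  // `host_ptr` are hypothetical, and error handling is elided):
  //
  //   // Either allocate pinned host memory through the executor...
  //   void* pinned = executor->HostMemoryAllocate(bytes);
  //   ...
  //   executor->HostMemoryDeallocate(pinned);
  //
  //   // ...or register memory allocated elsewhere before using it in
  //   // asynchronous memcpy operations, and unregister it before freeing it.
  //   CHECK(executor->HostMemoryRegister(host_ptr, bytes));
  //   ...
  //   CHECK(executor->HostMemoryUnregister(host_ptr));
  //   free(host_ptr);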

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
  // given location in device memory.
  port::Status SynchronousMemZero(DeviceMemoryBase* location,
                                  uint64_t size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                                 uint64_t size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
  bool SynchronousMemcpy(DeviceMemoryBase* device_dst, const void* host_src,
                         uint64_t size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
  bool SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& device_src,
                         uint64_t size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void* host_src, int64_t size,
                                    DeviceMemoryBase* device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(
      port::ArraySlice<T> host_src,  // non-absl ok
      DeviceMemoryBase* device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase& device_src,
                                    int64_t size, void* host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the host destination can accommodate the device
  // source size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T>& device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
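
  // Example of the synchronous copy helpers (an illustrative sketch only;
  // assumes an initialized `executor`, with errors simply CHECKed):
  //
  //   std::vector<float> host(16, 1.0f);
  //   DeviceMemory<float> device = executor->AllocateArray<float>(host.size());
  //   TF_CHECK_OK(executor->SynchronousMemcpyH2D(
  //       host.data(), host.size() * sizeof(float), &device));
  //   TF_CHECK_OK(executor->SynchronousMemcpyD2H(
  //       device, host.size() * sizeof(float), host.data()));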

  // Blocks the caller while a data segment of the given size is copied from the
  // device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase* device_dst,
                         const DeviceMemoryBase& device_src,
                         uint64_t size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
                       uint64_t size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be a 32-bit quantity
  // (i.e. evenly divisible by 4). Returns whether the operation was
  // successfully enqueued onto the stream.
  port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
                        uint32 pattern, uint64_t size);

  // Enables peer access from this StreamExecutor to memory
  // allocated by other, such that launched device code, memcpies, etc. may
  // access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // implementation (as in CUDA vs OpenCL).
  port::Status EnablePeerAccessTo(StreamExecutor* other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
  // this is more an up-front test as to whether it's expressly forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor* other);

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription& GetDeviceDescription() const;

  // If implemented, returns a device-specific measurement of load
  // (e.g. pending requests).
  int64_t GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is available.
  // If it is not available (false is returned), free/total may not be
  // initialized.
  //
  // Note: "free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64_t* free, int64_t* total) const;

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Returns the list of supported algorithms for the specified convolution
  // operation.
  bool GetConvolveAlgorithms(dnn::ConvolutionKind kind,
                             std::vector<dnn::AlgorithmDesc>* out_algorithms);

  // Returns the supported algorithms / execution plans for a convolution.
  port::Status GetConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, ScratchAllocator* scratch_allocator,
      std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans);

  port::Status GetFusedConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType bias_type,
      dnn::DataType output_type, double conv_input_scale,
      double side_input_scale, double leakyrelu_alpha, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, dnn::ActivationMode activation_mode,
      std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans);

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  // Returns the list of supported algorithms for the RNN operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc>* out_algorithms);

  // Returns the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(Stream* stream,
                             std::vector<blas::AlgorithmType>* out_algorithms);

  // Creates an RNN descriptor based on model shapes and configurations.
  // The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64_t seed, ScratchAllocator* state_allocator,
      bool use_padded_io);

  // Creates an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int>& seq_lengths,
                                    bool time_major, dnn::DataType data_type);

  // Creates an RNN state descriptor that specifies the input or hidden state.
  // The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor implementation.
  internal::StreamExecutorInterface* implementation();

  // Creates a kernel which can be launched with stream.ThenLaunch, such that
  // the types of the arguments provided at launch time have to match the
  // types of the arguments provided at creation time.
  //
  // The kernel has the name kernel_name, and is built from the provided PTX in
  // ptx and (optional) compiled PTX in cubin_data.
  // The canonical storage for both ptx and cubin_data should outlive the
  // lifetime of the kernel.
  template <typename... Args>
  port::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
      absl::string_view kernel_name, absl::string_view ptx,
      absl::Span<const uint8> cubin_data);
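
  // Example (an illustrative sketch; the kernel name, PTX string, and argument
  // types are hypothetical):
  //
  //   auto kernel = executor->CreateTypedKernel<DeviceMemory<float>, int>(
  //       "add_one", ptx_string, /*cubin_data=*/{});
  //   if (kernel.ok()) {
  //     stream.ThenLaunch(ThreadDim(256), BlockDim(16), *kernel.ValueOrDie(),
  //                       data, n);
  //   }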

  // Warning: use Stream::ThenLaunch instead; this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  port::Status Launch(Stream* stream, const ThreadDim& thread_dims,
                      const BlockDim& block_dims, const KernelBase& kernel,
                      const KernelArgsArrayBase& args);

  // Gets-or-creates (creates with memoization) a FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport* AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport* AsDnn();

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is typically
  // not user-facing, as users will use the Stream::ThenBlas* family of routines
  // to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by this
  // object for memoization. This BLAS interface is also only expected to be
  // used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport* AsBlas();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener* listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener* listener);

  // Returns allocator statistics.
  std::optional<AllocatorStats> GetAllocatorStats();

  // Clears the internal stats except for the `in_use` fields
  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`.
  bool ClearAllocatorStats();

  // Returns an allocator which delegates to this stream executor for memory
  // allocation.
  StreamExecutorMemoryAllocator* GetAllocator() { return &allocator_; }

  internal::StreamExecutorInterface* GetInternalExecutor() {
    return implementation_.get();
  }

  // Returns a stream allocated by this executor, or nullptr if not found.
  // Performs linear search over alive GPU streams.
  Stream* FindAllocatedStream(void* gpu_stream) {
    return implementation()->FindAllocatedStream(gpu_stream);
  }

 private:
  template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
            typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Synchronously allocates size bytes on the underlying platform and returns
  // a DeviceMemoryBase representing that allocation. In the case of failure,
  // a null DeviceMemoryBase is returned.
  DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space);

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport* AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream* stream);

  // Without blocking the device, retrieves the current stream status.
  port::Status GetStatus(Stream* stream);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const std::string& symbol_name, ModuleHandle module_handle,
                 void** mem, size_t* bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream* stream, void* host_dst,
              const DeviceMemoryBase& device_src, uint64_t size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a host memory source, with target size size.
  bool Memcpy(Stream* stream, DeviceMemoryBase* device_dst,
              const void* host_src, uint64_t size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a device source location, with target size size. Peer access should
  // have been enabled between the StreamExecutors owning the device memory
  // regions.
  bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* device_dst,
                            const DeviceMemoryBase& device_src, uint64_t size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream* stream, std::function<void()> callback);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  // This is the preferred form for a callback that may return an error.
  bool HostCallback(Stream* stream, std::function<port::Status()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event* event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event* event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream* stream, Event* event);

  // Waits for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream* stream, Event* event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event* event);

  // Allocates stream resources on the underlying platform and initializes its
  // internals.
  bool AllocateStream(Stream* stream);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream* stream);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream* dependent, Stream* other);

  // Allocates timer resources on the underlying platform and initializes its
  // internals.
  bool AllocateTimer(Timer* timer);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer* timer);

  // Records a start event for an interval timer.
  bool StartTimer(Stream* stream, Timer* timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream* stream, Timer* timer);

  // Allocates a new metadata object, appropriately populated, on the heap, with
  // ownership transfer to caller.
  std::unique_ptr<DeviceDescription> CreateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void* opaque, uint64_t bytes);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void* opaque);

  // Calls the relevant TraceListener routine to begin tracing for the specified
  // asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable absl::Mutex mu_;

  // Reference to the platform that created this executor.
  const Platform* platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void*, AllocRecord> mem_allocs_ ABSL_GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when asked
  // for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ ABSL_GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when asked
  // for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ ABSL_GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when asked
  // for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when asked
  // for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ ABSL_GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried from DeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      ABSL_GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static constexpr int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener*> listeners_ ABSL_GUARDED_BY(mu_);

  // Allocated memory in bytes.
  int64_t mem_alloc_bytes_;

  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
  int64_t memory_limit_bytes_;

  StreamExecutorMemoryAllocator allocator_;

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
class ScopedModuleHandle {
 public:
  explicit ScopedModuleHandle(StreamExecutor* executor,
                              ModuleHandle module_handle)
      : executor_(executor), module_handle_(module_handle) {}

  ScopedModuleHandle(ScopedModuleHandle&& other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
  }

  ScopedModuleHandle& operator=(ScopedModuleHandle&& other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
    return *this;
  }

  ~ScopedModuleHandle() {
    if (static_cast<bool>(module_handle_)) {
      CHECK(executor_->UnloadModule(module_handle_));
    }
  }

 private:
  StreamExecutor* executor_;
  ModuleHandle module_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
};
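
// Example (an illustrative sketch; `spec` is a hypothetical
// MultiModuleLoaderSpec and error handling is elided):
//
//   ModuleHandle handle;
//   TF_CHECK_OK(executor->LoadModule(spec, &handle));
//   ScopedModuleHandle scoped(executor, handle);  // Unloads on destruction.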

////////////
// Inlines

template <typename... Args>
inline port::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
                                  absl::string_view ptx,
                                  absl::Span<const uint8> cubin_data) {
  auto kernel_base = std::make_unique<TypedKernel<Args...>>(this);
  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
  }

  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64_t element_count,
                                                     int64_t memory_space) {
  uint64_t bytes = sizeof(T) * element_count;
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const std::string& symbol_name, ModuleHandle module_handle) {
  port::StatusOr<DeviceMemoryBase> untyped_symbol =
      GetUntypedSymbol(symbol_name, module_handle);
  if (!untyped_symbol.ok()) {
    return untyped_symbol.status();
  }
  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor* parent,
                                              DeviceMemoryBase value)
    : wrapped_(value),
      device_ordinal_(parent->device_ordinal()),
      allocator_(parent->GetAllocator()) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor* parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT*>(&local[0]),
                                   ptr()->size())) {
      TF_CHECK_OK(Free());
    }
  }
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  DeviceMemoryBase buf = Allocate(sizeof(T), /*memory_space=*/0);
  if (buf.is_null()) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result(buf);
  bool ok = SynchronousMemZero(&result, sizeof(T)).ok();
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::GetSubBuffer(DeviceMemory<T>* parent,
                                             uint64_t element_offset,
                                             uint64_t element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void* opaque = implementation_->GetSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count));
}

}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_