/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <memory>
#include <set>
#include <tuple>
#include <vector>

#include "absl/base/macros.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace stream_executor {

// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64 bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT,
          typename ReturnT, typename... BeginArgsT>
class ScopedTracer;
// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. it is either a CUDA or an OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
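//
// A minimal usage sketch (error handling elided; `platform` is assumed to
// have been obtained elsewhere, e.g. from MultiPlatformManager):
//
//   StreamExecutor *executor = platform->ExecutorForDevice(0).ValueOrDie();
//   DeviceMemory<float> device_buf = executor->AllocateArray<float>(64);
//   float *host_buf = static_cast<float *>(
//       executor->HostMemoryAllocate(64 * sizeof(float)));
//   Stream stream(executor);
//   stream.Init();
//   stream.ThenMemcpy(&device_buf, host_buf, 64 * sizeof(float));
//   TF_CHECK_OK(stream.BlockHostUntilDone());
//   executor->HostMemoryDeallocate(host_buf);
//   executor->Deallocate(&device_buf);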
class StreamExecutor {
 public:
  StreamExecutor(
      const Platform *platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation,
      int device_ordinal);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(DeviceOptions device_options);

  // Returns the kind of platform that this StreamExecutor is acting upon.
  ABSL_DEPRECATED("Use platform() instead.")
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform *platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //    constant into an appropriate namespace. For example, see
  //    stream_executor::executor_sample::kKernelLoaderSpecs, from which a
  //    MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given Kernel
  //    instantiation should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, an error status is returned.
  port::Status GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase *kernel);

  // Loads a module for the platform this StreamExecutor is acting upon.
  //
  // `spec` describes the module to be loaded.  On success writes the handle for
  // the loaded module to `module_handle` and returns Status::OK.
  // Otherwise, returns the error which has occurred.
  port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                          ModuleHandle *module_handle);

  // Unloads the module with handle `module_handle`.
  bool UnloadModule(ModuleHandle module_handle);

  // Synchronously allocates an array on the device of type T with
  // element_count elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64 element_count, int64 memory_space = 0);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }
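
  // As an illustration, allocating a small owned device array (a sketch;
  // `executor` is a StreamExecutor* and the memory is freed when `scoped`
  // goes out of scope):
  //
  //   ScopedDeviceMemory<float> scoped =
  //       executor->AllocateOwnedArray<float>(16);
  //   DeviceMemory<float> *buf = scoped.ptr();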

  // Allocates a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may cause
  // use-after-free issues (the specific behavior is not consistent across
  // platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> GetSubBuffer(DeviceMemory<T> *parent, uint64 element_offset,
                               uint64 element_count);

  // Finds a symbol and returns device memory allocated to the symbol. The
  // symbol is searched for in any kernels that were previously loaded through
  // GetKernel() before the GetSymbol() call. The user has to make sure that the
  // type of the symbol and T match.
  // - Note: symbol_name should include its namespace as well. For example,
  //         pass "nms0::symbol" if referring to nms0::symbol.
  //
  // If `module_handle` is set then searches only within the module
  // corresponding to `module_handle`.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
                                            ModuleHandle module_handle = {});

  // An untyped version of GetSymbol.
  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
      const string &symbol_name, ModuleHandle module_handle = {});
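
  // For example, looking up a typed symbol in a loaded module (a sketch;
  // `executor` and `module_handle` are assumed to exist, and "nms0::symbol"
  // is a placeholder name):
  //
  //   port::StatusOr<DeviceMemory<int>> sym =
  //       executor->GetSymbol<int>("nms0::symbol", module_handle);
  //   if (sym.ok()) {
  //     DeviceMemory<int> symbol_mem = sym.ValueOrDie();
  //   }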

  // Deallocate the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase *mem);

  // Retrieves a mapping of active opaque device memory pointer to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void *, AllocRecord> *records_out);

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  void *UnifiedMemoryAllocate(uint64 bytes);

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  void UnifiedMemoryDeallocate(void *location);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void *HostMemoryAllocate(uint64 size);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void *location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
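
  // For example, registering an externally allocated host staging buffer for
  // use with asynchronous copies (a sketch; `executor` is a StreamExecutor*):
  //
  //   std::vector<uint8> staging(4096);
  //   if (executor->HostMemoryRegister(staging.data(), staging.size())) {
  //     // ... enqueue asynchronous memcpys that use staging.data() ...
  //     CHECK(executor->HostMemoryUnregister(staging.data()));
  //   }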

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
  // given location in device memory.
  port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                  uint64 size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                 uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
  bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
                                    DeviceMemoryBase *device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
                                    DeviceMemoryBase *device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
                                    int64 size, void *host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the host destination can accommodate the device
  // source size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
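
  // A host<->device round trip using the slice overloads might look like this
  // (a sketch; `executor` and a DeviceMemory<float> `device_buf` of matching
  // size are assumed to exist):
  //
  //   std::vector<float> host_in(device_buf.ElementCount(), 1.0f);
  //   std::vector<float> host_out(host_in.size());
  //   TF_CHECK_OK(executor->SynchronousMemcpyH2D<float>(host_in, &device_buf));
  //   TF_CHECK_OK(executor->SynchronousMemcpyD2H<float>(
  //       device_buf, absl::MakeSpan(host_out)));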

  // Blocks the caller while a data segment of the given size is copied from the
  // device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
                         const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                       uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be evenly divisible by
  // 4 (i.e. a whole number of 32-bit patterns). Returns whether the operation
  // was successfully enqueued onto the stream.
  port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                        uint32 pattern, uint64 size);

  // Enables peer access from this StreamExecutor to memory
  // allocated by other, such that launched device code, memcpies, etc may
  // access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // (as in CUDA vs OpenCL) implementation.
  port::Status EnablePeerAccessTo(StreamExecutor *other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
  // this is more an up-front test as to whether it's expressly forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor *other);

  // Gets the preferred shared memory configuration for the device to which this
  // executor is bound.
  SharedMemoryConfig GetDeviceSharedMemoryConfig();

  // Sets the preferred shared memory configuration for the device to which this
  // executor is bound.
  port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config);

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription &GetDeviceDescription() const;

  // If implemented, returns a device-specific measurement of load
  // (e.g. pending requests).
  int64 GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is available.
  // If it is not available (false is returned), free/total may not be
  // initialized.
  //
  // Note: "free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64 *free, int64 *total) const;

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                             std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Returns the list of supported MIOpen algorithms for the given convolution
  // operation.
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, Stream *stream, dnn::DataType element_type,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      std::vector<dnn::ProfileResult> *out_algorithms);

  // Returns the list of supported algorithms for the RNN operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on data.
  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on the
  // filter.
  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);

  // Creates an RNN descriptor based on model shapes and configurations.
  // The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
      float dropout, uint64 seed, ScratchAllocator *state_allocator,
      bool use_padded_io);

  // Creates an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int> &seq_lengths,
                                    bool time_major, dnn::DataType data_type);

  // Creates an RNN state descriptor that specifies the input or hidden state.
  // The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor implementation.
  internal::StreamExecutorInterface *implementation();

  // Creates a kernel which can be launched with stream.ThenLaunch; the types
  // of the arguments provided for launch must match the types of the
  // arguments provided at creation time.
  //
  // The kernel has a name kernel_name, is based on the PTX provided in ptx,
  // and (optionally) on the compiled PTX in cubin_data.
  // The canonical storage for both ptx and cubin_data should outlive the
  // lifetime of the kernel.
  template <typename... Args>
  port::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
      absl::string_view kernel_name, absl::string_view ptx,
      absl::Span<const uint8> cubin_data);
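
  // For instance, creating a kernel from PTX and launching it through a stream
  // (a sketch; `kernel_ptx`, `stream`, and the DeviceMemory arguments `in` and
  // `out` are assumed to exist, and no cubin is supplied):
  //
  //   auto kernel = executor->CreateTypedKernel<DeviceMemory<float>,
  //                                             DeviceMemory<float>>(
  //       "my_kernel", kernel_ptx, /*cubin_data=*/{});
  //   TF_CHECK_OK(kernel.status());
  //   stream.ThenLaunch(ThreadDim(32), BlockDim(1), *kernel.ValueOrDie(),
  //                     in, out);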

  // Warning: use Stream::ThenLaunch instead; this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                      const BlockDim &block_dims, const KernelBase &kernel,
                      const KernelArgsArrayBase &args);

  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport *AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport *AsDnn();

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is typically
  // not user-facing, as users will use the Stream::ThenBlas* family of routines
  // to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by this
  // object for memoization. This BLAS interface is also only expected to be
  // used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport *AsBlas();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener* listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener* listener);

  // Returns allocator statistics.
  absl::optional<AllocatorStats> GetAllocatorStats();

  // Returns an allocator which delegates to this stream executor for memory
  // allocation.
  StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }

 private:
  template <typename BeginCallT, typename CompleteCallT,
            typename ReturnT, typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Synchronously allocates size bytes on the underlying platform and returns
  // a DeviceMemoryBase representing that allocation. In the case of failure,
  // nullptr is returned.
  DeviceMemoryBase Allocate(uint64 size, int64 memory_space);

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport *AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream *stream);

  // Without blocking the device, retrieves the current stream status.
  port::Status GetStatus(Stream *stream);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
                 void **mem, size_t *bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream *stream, void *host_dst,
              const DeviceMemoryBase &device_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a host memory source, with target size size.
  bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
              const void *host_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a device source location, with target size size. Peer access should
  // have been enabled between the StreamExecutors owning the device memory
  // regions.
  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
                            const DeviceMemoryBase &device_src, uint64 size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream *stream, std::function<void()> callback);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  // This is the preferred form for a callback that may return an error.
  bool HostCallback(Stream *stream, std::function<port::Status()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event *event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event *event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream *stream, Event *event);

  // Waits for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream *stream, Event *event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event *event);

  // Allocates stream resources on the underlying platform and initializes its
  // internals.
  bool AllocateStream(Stream *stream);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream *stream);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream *dependent, Stream *other);

  // Allocates timer resources on the underlying platform and initializes its
  // internals.
  bool AllocateTimer(Timer *timer);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer *timer);

  // Records a start event for an interval timer.
  bool StartTimer(Stream *stream, Timer *timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream *stream, Timer *timer);

  // Allocates a new metadata object, appropriately populated, on the heap, with
  // ownership transfer to the caller.
  std::unique_ptr<DeviceDescription> CreateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void *opaque, uint64 bytes);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void *opaque);

  // Calls the relevant TraceListener routine to begin tracing for the specified
  // asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable absl::Mutex mu_;

  // Reference to the platform that created this executor.
  const Platform *platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when asked
  // for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when asked
  // for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when asked
  // for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when asked
  // for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried via GetDeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static const int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener*> listeners_ GUARDED_BY(mu_);

  // Allocated memory in bytes.
  int64 mem_alloc_bytes_;

  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
  int64 memory_limit_bytes_;

  StreamExecutorMemoryAllocator allocator_;

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
class ScopedModuleHandle {
 public:
  explicit ScopedModuleHandle(StreamExecutor *executor,
                              ModuleHandle module_handle)
      : executor_(executor), module_handle_(module_handle) {}

  ScopedModuleHandle(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
  }

  ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
    return *this;
  }

  ~ScopedModuleHandle() {
    if (static_cast<bool>(module_handle_)) {
      CHECK(executor_->UnloadModule(module_handle_));
    }
  }

 private:
  StreamExecutor *executor_;
  ModuleHandle module_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
};
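
// For example, a module can be loaded and unloaded via RAII like so (a
// sketch; `executor` and a populated MultiModuleLoaderSpec `module_spec` are
// assumed to exist):
//
//   ModuleHandle module_handle;
//   TF_CHECK_OK(executor->LoadModule(module_spec, &module_handle));
//   ScopedModuleHandle scoped_module(executor, module_handle);
//   // Symbols in the module stay resolvable until `scoped_module` is
//   // destroyed, at which point the module is unloaded.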

////////////
// Inlines

template <typename... Args>
inline port::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
                                  absl::string_view ptx,
                                  absl::Span<const uint8> cubin_data) {
  auto kernel_base = absl::make_unique<TypedKernel<Args...>>(this);
  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char *>(cubin_data.data()), kernel_name);
  }

  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count,
                                                     int64 memory_space) {
  uint64 bytes = sizeof(T) * element_count;
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  port::StatusOr<DeviceMemoryBase> untyped_symbol =
      GetUntypedSymbol(symbol_name, module_handle);
  if (!untyped_symbol.ok()) {
    return untyped_symbol.status();
  }
  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
                                              DeviceMemoryBase value)
    : wrapped_(value),
      device_ordinal_(parent->device_ordinal()),
      allocator_(parent->GetAllocator()) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor *parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
                                   ptr()->size())) {
      TF_CHECK_OK(Free());
    }
  }
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  DeviceMemoryBase buf = Allocate(sizeof(T), /*memory_space=*/0);
  if (buf.is_null()) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result(buf);
  bool ok = SynchronousMemZero(&result, sizeof(T)).ok();
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::GetSubBuffer(DeviceMemory<T> *parent,
                                             uint64 element_offset,
                                             uint64 element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void *opaque = implementation_->GetSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count));
}

template <typename... Params, typename... Args>
inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
                                  const TypedKernel<Params...> &kernel,
                                  Args... args) {
  KernelInvocationChecker<std::tuple<Params...>,
                          std::tuple<Args...>>::CheckAllStaticAssert();
  if (ok()) {
    // This is the core that allows type-safe kernel launching.
    // Since the platforms take kernel arguments as tuples of (void *, size),
    // we pack the variadic parameters passed as ...args into the desired
    // tuple form and pass that packed form to the StreamExecutor::Launch()
    // implementation.
    KernelArgsArray<sizeof...(args)> kernel_args;
    kernel.PackParams(&kernel_args, args...);
    DCHECK(parent_ != nullptr);
    bool ok =
        parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args)
            .ok();
    if (!ok) {
      SetError();
      LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
    }
  }
  return *this;
}

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_