/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <memory>
#include <set>
#include <tuple>
#include <vector>

#include "absl/base/macros.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/trace_listener.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_dnn.h"
#endif  // GOOGLE_CUDA

namespace stream_executor {

class Stream;

// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64 bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  std::string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer;

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. it is either a CUDA or an OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
class StreamExecutor {
 public:
  StreamExecutor(
      const Platform *platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation,
      int device_ordinal);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(DeviceOptions device_options);

  // Returns the platform that this StreamExecutor is acting upon.
  ABSL_DEPRECATED("Use platform() instead.")
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform *platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //    constant into an appropriate namespace. For example, see
  //    stream_executor::executor_sample::kKernelLoaderSpecs, from which a
  //    MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given Kernel
  //    instantiation should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, error status is returned.
  port::Status GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase *kernel);

  // Loads a module for the platform this StreamExecutor is acting upon.
  //
  // `spec` describes the module to be loaded.  On success writes the handle for
  // the loaded module to `module_handle` and returns Status::OK.
  // Otherwise, returns the error which has occurred.
  port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                          ModuleHandle *module_handle);

  // Unloads the module with handle `module_handle`.
  bool UnloadModule(ModuleHandle module_handle);

  // Synchronously allocates an array on the device of type T with element_count
  // elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64 element_count, int64_t memory_space = 0);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }

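  // A minimal allocation sketch (illustrative only; assumes `executor` is an
  // initialized StreamExecutor*):
  //
  //   DeviceMemory<float> array = executor->AllocateArray<float>(1024);
  //   ScopedDeviceMemory<float> scalar = executor->AllocateOwnedScalar<float>();
  //   ...
  //   executor->Deallocate(&array);  // `scalar` frees itself on destruction.
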
  // Allocate a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may cause
  // use-after-free issues (the specific behavior is not consistent across
  // platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> GetSubBuffer(DeviceMemory<T> *parent, uint64 element_offset,
                               uint64 element_count);

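  // Illustrative sketch (assumes `executor`; the sub-buffer is a view of
  // elements [64, 64 + 32) of `parent`, which must outlive it):
  //
  //   DeviceMemory<int> parent = executor->AllocateArray<int>(256);
  //   DeviceMemory<int> sub = executor->GetSubBuffer(&parent, 64, 32);
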
  // Finds a symbol and returns device memory allocated to the symbol. The
  // symbol is searched for in any kernels that were previously loaded through
  // GetKernel() before the GetSymbol() call. The user has to make sure that the
  // types of the symbol and T match.
  // - Note: symbol_name should include its namespace as well. For example,
  //         pass "nms0::symbol" if referring to nms0::symbol.
  //
  // If `module_handle` is set then searches only within the module
  // corresponding to `module_handle`.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string &symbol_name,
                                            ModuleHandle module_handle = {});

  // An untyped version of GetSymbol.
  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
      const std::string &symbol_name, ModuleHandle module_handle = {});

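  // Illustrative sketch (assumes a previously loaded kernel or module defines
  // a float symbol named "nms0::symbol"):
  //
  //   port::StatusOr<DeviceMemory<float>> sym =
  //       executor->GetSymbol<float>("nms0::symbol");
  //   if (sym.ok()) { /* use sym.ValueOrDie() */ }
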
  // Deallocate the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase *mem);

  // Retrieves a mapping of active opaque device memory pointer to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void *, AllocRecord> *records_out);

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  void *UnifiedMemoryAllocate(uint64 bytes);

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  void UnifiedMemoryDeallocate(void *location);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy operations,
  // such as Stream::ThenMemcpy.
  void *HostMemoryAllocate(uint64 size);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void *location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;

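  // Illustrative sketch for registering host memory that was allocated outside
  // the StreamExecutor (error handling elided):
  //
  //   std::vector<char> staging(1 << 20);
  //   if (executor->HostMemoryRegister(staging.data(), staging.size())) {
  //     // ... use staging.data() as the host side of Stream::ThenMemcpy ...
  //     CHECK(executor->HostMemoryUnregister(staging.data()));
  //   }
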
  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
  // given location in device memory.
  port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                  uint64 size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                 uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
  bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void *host_src, int64_t size,
                                    DeviceMemoryBase *device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
                                    DeviceMemoryBase *device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
                                    int64_t size, void *host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the host destination can accommodate the device
  // source size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }

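  // Illustrative round-trip sketch (assumes `executor`; error handling is
  // collapsed into TF_CHECK_OK for brevity):
  //
  //   std::vector<float> host(16, 1.0f);
  //   DeviceMemory<float> device = executor->AllocateArray<float>(host.size());
  //   TF_CHECK_OK(executor->SynchronousMemcpyH2D(
  //       host.data(), host.size() * sizeof(float), &device));
  //   TF_CHECK_OK(executor->SynchronousMemcpyD2H(
  //       device, host.size() * sizeof(float), host.data()));
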
  // Blocks the caller while a data segment of the given size is copied from the
  // device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
                         const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                       uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be a whole number of
  // 32-bit words (i.e. evenly divisible by 4). Returns whether the operation
  // was successfully enqueued onto the stream.
  port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                        uint32 pattern, uint64 size);

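  // Illustrative sketch (assumes `stream` is a Stream bound to this executor;
  // the byte count passed to Memset32 must be a multiple of 4):
  //
  //   DeviceMemory<uint32> buf = executor->AllocateArray<uint32>(64);
  //   TF_CHECK_OK(
  //       executor->Memset32(stream, &buf, 0xdeadbeef, 64 * sizeof(uint32)));
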
  // Enables peer access from this StreamExecutor to memory allocated by other,
  // such that launched device code, memcpies, etc. may access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // (as in CUDA vs OpenCL) implementation.
  port::Status EnablePeerAccessTo(StreamExecutor *other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
  // this is more an up-front test as to whether it's expressly forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor *other);

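  // Illustrative sketch (assumes `exec0` and `exec1` manage two devices backed
  // by the same platform):
  //
  //   if (exec0->CanEnablePeerAccessTo(exec1)) {
  //     TF_CHECK_OK(exec0->EnablePeerAccessTo(exec1));
  //   }
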
  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription &GetDeviceDescription() const;

  // If implemented, returns a device-specific measurement of load
  // (e.g. pending requests).
  int64 GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is available.
  // If it is not available (false is returned), free/total may not be
  // initialized.
  //
  // Note: "Free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64 *free, int64 *total) const;

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetConvolveAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Returns the supported execution plans for the convolution operation.
  bool GetConvolveExecutionPlans(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>> *out_exec_plans);

  port::Status GetFusedConvolveExecutionPlans(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      double conv_input_scale, double side_input_scale, Stream *stream,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::BatchDescriptor &bias_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      dnn::ActivationMode activation_mode,
      std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>
          *out_exec_plans) {
    dnn::DnnSupport *dnn_support = AsDnn();
    if (dnn_support) {
#if GOOGLE_CUDA
      // Guard against a non-cuDNN DnnSupport implementation; fall through to
      // the error below if the cast fails.
      gpu::CudnnSupport *cudnn_dnn =
          dynamic_cast<gpu::CudnnSupport *>(dnn_support);
      if (cudnn_dnn != nullptr) {
        return cudnn_dnn->GetFusedConvolveExecutionPlans(
            kind, element_type, conv_input_scale, side_input_scale, stream,
            input_descriptor, filter_descriptor, bias_descriptor,
            output_descriptor, convolution_descriptor, activation_mode,
            out_exec_plans);
      }
#endif  // GOOGLE_CUDA
    }
    return port::UnimplementedError("DNN library is not found.");
  }

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
      const dnn::BatchDescriptor &input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor &filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      ScratchAllocator *scratch_allocator,
      std::vector<dnn::ProfileResult> *out_algorithms);

  // Returns the list of supported algorithms for the RNN operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for the backward convolution on data.
  bool GetConvolveBackwardDataAlgorithms(
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for the backward convolution on the
  // filter.
  bool GetConvolveBackwardFilterAlgorithms(
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);

  // Creates a backend-specific plan object for a blaslt matmul operation, which
  // can then be passed to DoBlasLtMatmul(). When possible, plans should be
  // created once and reused for multiple calls to DoBlasLtMatmul().
  // Returns a null pointer on failure.
  port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
  CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params);

  // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
  // returned in the order of increasing estimated compute time according to an
  // internal heuristic. The first returned algorithm can be used as the default
  // algorithm if no autotuning is to be performed.
  port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
  GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                            size_t max_workspace_size, int max_algorithm_count);

  // Create an RNN descriptor based on model shapes and configurations.
  // The caller retains the ownership of the descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
      float dropout, uint64 seed, ScratchAllocator *state_allocator,
      bool use_padded_io);

  // Create an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains the ownership of the returned
  // descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int> &seq_lengths,
                                    bool time_major, dnn::DataType data_type);

  // Create an RNN state descriptor that specifies the input or hidden state.
  // The caller retains the ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor implementation.
  internal::StreamExecutorInterface *implementation();

  // Creates a kernel which can be launched with stream.ThenLaunch, such that
  // the types of the arguments provided at launch must match the types of the
  // arguments provided at creation time.
  //
  // The kernel has a name kernel_name, and is based on the provided PTX in ptx
  // and (optionally) compiled PTX in cubin_data.
  // The canonical storage for both ptx and cubin_data should outlive the
  // lifetime of the kernel.
  template <typename... Args>
  port::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
      absl::string_view kernel_name, absl::string_view ptx,
      absl::Span<const uint8> cubin_data);

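  // Illustrative sketch (assumes `ptx` holds PTX text that defines a kernel
  // named "add_one" taking a single DeviceMemory<float> argument; the kernel
  // name and signature are hypothetical):
  //
  //   auto kernel = executor->CreateTypedKernel<DeviceMemory<float>>(
  //       "add_one", ptx, /*cubin_data=*/{});
  //   // On success, launch the kernel with Stream::ThenLaunch.
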
  // Warning: use Stream::ThenLaunch instead; this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                      const BlockDim &block_dims, const KernelBase &kernel,
                      const KernelArgsArrayBase &args);

  // Gets-or-creates (creates with memoization) a FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport *AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport *AsDnn();

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is typically
  // not user-facing, as users will use the Stream::ThenBlas* family of routines
  // to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by this
  // object for memoization. This BLAS interface is also only expected to be
  // used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport *AsBlas();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener *listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener *listener);

  // Returns allocator statistics.
  absl::optional<AllocatorStats> GetAllocatorStats();

  // Clears the internal stats except for the `in_use` fields
  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`.
  bool ClearAllocatorStats();

  // Returns an allocator that delegates to this stream executor for memory
  // allocation.
  StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }

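  // Illustrative sketch (assumes `executor`):
  //
  //   absl::optional<AllocatorStats> stats = executor->GetAllocatorStats();
  //   if (stats.has_value()) {
  //     // e.g. inspect stats->bytes_in_use and stats->peak_bytes_in_use.
  //   }
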
 private:
  template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
            typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Synchronously allocates size bytes on the underlying platform and returns
  // a DeviceMemoryBase representing that allocation. In the case of failure,
  // a null DeviceMemoryBase is returned.
  DeviceMemoryBase Allocate(uint64 size, int64_t memory_space);

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport *AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream *stream);

  // Without blocking the device, retrieves the current stream status.
  port::Status GetStatus(Stream *stream);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const std::string &symbol_name, ModuleHandle module_handle,
                 void **mem, size_t *bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream *stream, void *host_dst,
              const DeviceMemoryBase &device_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a host memory source, with target size size.
  bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
              const void *host_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a device source location, with target size size. Peer access should
  // have been enabled between the StreamExecutors owning the device memory
  // regions.
  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
                            const DeviceMemoryBase &device_src, uint64 size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream *stream, std::function<void()> callback);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  // This is the preferred form for a callback that may return an error.
  bool HostCallback(Stream *stream, std::function<port::Status()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event *event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event *event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream *stream, Event *event);

  // Waits for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream *stream, Event *event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event *event);

  // Allocates stream resources on the underlying platform and initializes its
  // internals.
  bool AllocateStream(Stream *stream);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream *stream);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream *dependent, Stream *other);

  // Allocates timer resources on the underlying platform and initializes its
  // internals.
  bool AllocateTimer(Timer *timer);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer *timer);

  // Records a start event for an interval timer.
  bool StartTimer(Stream *stream, Timer *timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream *stream, Timer *timer);

  // Allocates a new metadata object, appropriately populated, on the heap, with
  // ownership transfer to caller.
  std::unique_ptr<DeviceDescription> CreateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void *opaque, uint64 bytes);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void *opaque);

  // Calls the relevant TraceListener routine to begin tracing for the specified
  // asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT &&...args);

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable absl::Mutex mu_;

  // Reference to the platform that created this executor.
  const Platform *platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void *, AllocRecord> mem_allocs_ TF_GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when asked
  // for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ TF_GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when asked
  // for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ TF_GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when asked
  // for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when asked
  // for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ TF_GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried from DeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      TF_GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static constexpr int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener *> listeners_ TF_GUARDED_BY(mu_);

  // Allocated memory in bytes.
  int64 mem_alloc_bytes_;

  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
  int64 memory_limit_bytes_;

  StreamExecutorMemoryAllocator allocator_;

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
class ScopedModuleHandle {
 public:
  explicit ScopedModuleHandle(StreamExecutor *executor,
                              ModuleHandle module_handle)
      : executor_(executor), module_handle_(module_handle) {}

  ScopedModuleHandle(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
  }

  ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
    return *this;
  }

  ~ScopedModuleHandle() {
    if (static_cast<bool>(module_handle_)) {
      CHECK(executor_->UnloadModule(module_handle_));
    }
  }

 private:
  StreamExecutor *executor_;
  ModuleHandle module_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
};

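// Illustrative sketch (assumes `spec` is a MultiModuleLoaderSpec describing
// the module to load and `executor` is an initialized StreamExecutor*):
//
//   ModuleHandle handle;
//   TF_CHECK_OK(executor->LoadModule(spec, &handle));
//   ScopedModuleHandle scoped_module(executor, handle);  // Unloads on exit.
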
////////////
// Inlines

template <typename... Args>
inline port::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
                                  absl::string_view ptx,
                                  absl::Span<const uint8> cubin_data) {
  auto kernel_base = absl::make_unique<TypedKernel<Args...>>(this);
  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char *>(cubin_data.data()), kernel_name);
  }

  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count,
                                                     int64_t memory_space) {
  uint64 bytes = sizeof(T) * element_count;
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const std::string &symbol_name, ModuleHandle module_handle) {
  port::StatusOr<DeviceMemoryBase> untyped_symbol =
      GetUntypedSymbol(symbol_name, module_handle);
  if (!untyped_symbol.ok()) {
    return untyped_symbol.status();
  }
  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
                                              DeviceMemoryBase value)
    : wrapped_(value),
      device_ordinal_(parent->device_ordinal()),
      allocator_(parent->GetAllocator()) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor *parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
                                   ptr()->size())) {
      TF_CHECK_OK(Free());
    }
  }
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  DeviceMemoryBase buf = Allocate(sizeof(T), /*memory_space=*/0);
  if (buf.is_null()) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result(buf);
  bool ok = SynchronousMemZero(&result, sizeof(T)).ok();
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::GetSubBuffer(DeviceMemory<T> *parent,
                                             uint64 element_offset,
                                             uint64 element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void *opaque = implementation_->GetSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count));
}

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_