/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "absl/base/macros.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/trace_listener.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_dnn.h"
#endif  // GOOGLE_CUDA

namespace stream_executor {

class Stream;
// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64 bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  std::string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer;

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. it is either a CUDA or OpenCL executor.
//
// Thread-safe after initialization.
// StreamExecutor interface should not be invoked from a signal handler.
class StreamExecutor {
 public:
  StreamExecutor(
      const Platform *platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation,
      int device_ordinal);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(DeviceOptions device_options);

  // Returns the platform that this StreamExecutor is acting upon.
  ABSL_DEPRECATED("Use platform() instead.")
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform *platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //     constant into an appropriate namespace. For example, see
  //     stream_executor::executor_sample::kKernelLoaderSpecs, from which a
  //     MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given Kernel
  //     instantiation should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, error status is returned.
  port::Status GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase *kernel);

  // Loads a module for the platform this StreamExecutor is acting upon.
  //
  // `spec` describes the module to be loaded. On success writes the handle for
  // the loaded module to `module_handle` and returns Status::OK.
  // Otherwise, returns the error which has occurred.
  port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                          ModuleHandle *module_handle);

  // Unloads the module with handle `module_handle`.
  bool UnloadModule(ModuleHandle module_handle);

  // Synchronously allocates an array on the device of type T with
  // element_count elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64 element_count,
                                int64_t memory_space = 0);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }
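
  // Illustrative sketch (not part of the API): a typical allocate/use/free
  // cycle with the helpers above. `executor` is assumed to be an
  // already-initialized StreamExecutor for the current device.
  //
  //   // 256 floats of device memory; freed explicitly via Deallocate().
  //   DeviceMemory<float> raw = executor->AllocateArray<float>(256);
  //   // ... use `raw` ...
  //   executor->Deallocate(&raw);
  //
  //   // Same allocation, but freed automatically when `owned` goes out of
  //   // scope.
  //   ScopedDeviceMemory<float> owned =
  //       executor->AllocateOwnedArray<float>(256);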

  // Allocate a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may
  // cause use-after-free issues (the specific behavior is not consistent
  // across platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> GetSubBuffer(DeviceMemory<T> *parent, uint64 element_offset,
                               uint64 element_count);

  // Finds a symbol and returns device memory allocated to the symbol. The
  // symbol is searched in any kernels that were previously loaded through
  // GetKernel() before the GetSymbol() call. The user has to make sure that
  // the type of symbol and T match.
  //  - Note: symbol_name should include its namespace as well. For example,
  //    pass "nms0::symbol" if referring to nms0::symbol.
  //
  // If `module_handle` is set then searches only within the module
  // corresponding to `module_handle`.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string &symbol_name,
                                            ModuleHandle module_handle = {});

  // An untyped version of GetSymbol.
  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
      const std::string &symbol_name, ModuleHandle module_handle = {});
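
  // Illustrative sketch (not part of the API): looking up a module-scope
  // device variable after its module has been loaded. The symbol name
  // "foo::kLookupTable" and `module_handle` are hypothetical.
  //
  //   port::StatusOr<DeviceMemory<int>> table =
  //       executor->GetSymbol<int>("foo::kLookupTable", module_handle);
  //   if (table.ok()) {
  //     // table.ValueOrDie() now refers to the symbol's device storage.
  //   }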

  // Deallocate the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase *mem);

  // Retrieves a mapping of active opaque device memory pointer to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void *, AllocRecord> *records_out);

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  void *UnifiedMemoryAllocate(uint64 bytes);

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  void UnifiedMemoryDeallocate(void *location);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void *HostMemoryAllocate(uint64 size);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void *location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
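
  // Illustrative sketch (not part of the API): pinning caller-owned host
  // memory so it can participate in asynchronous copies. The buffer and sizes
  // are hypothetical.
  //
  //   std::vector<float> staging(1024);
  //   if (executor->HostMemoryRegister(staging.data(),
  //                                    staging.size() * sizeof(float))) {
  //     // ... enqueue Stream::ThenMemcpy calls that touch `staging` ...
  //     CHECK(executor->HostMemoryUnregister(staging.data()));
  //   }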

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at
  // the given location in device memory.
  port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                  uint64 size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                 uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
  bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void *host_src, int64_t size,
                                    DeviceMemoryBase *device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the device destination can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
                                    DeviceMemoryBase *device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
                                    int64_t size, void *host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the host destination can accommodate the device
  // source size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
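
  // Illustrative sketch (not part of the API): a blocking host-to-device and
  // device-to-host round trip using the slice overloads above. Sizes and
  // contents are hypothetical.
  //
  //   std::vector<float> host_in(64, 1.0f), host_out(64);
  //   DeviceMemory<float> dev = executor->AllocateArray<float>(64);
  //   TF_CHECK_OK(executor->SynchronousMemcpyH2D(
  //       port::ArraySlice<float>(host_in), &dev));
  //   TF_CHECK_OK(executor->SynchronousMemcpyD2H(
  //       dev, port::MutableArraySlice<float>(host_out.data(), 64)));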

  // Blocks the caller while a data segment of the given size is copied from
  // the device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
                         const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                       uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be 32-bit quantized
  // (i.e. evenly divisible by 4). Returns whether the operation was
  // successfully enqueued onto the stream.
  port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                        uint32 pattern, uint64 size);

  // Enables peer access from this StreamExecutor to memory allocated by
  // other, such that launched device code, memcpies, etc may access it
  // directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // (as in CUDA vs OpenCL) implementation.
  port::Status EnablePeerAccessTo(StreamExecutor *other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other
  // reasons; this is more an up-front test as to whether it's expressly
  // forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor *other);
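
  // Illustrative sketch (not part of the API): enabling peer access between
  // two executors on the same platform before issuing cross-device copies.
  // `exec0` and `exec1` are hypothetical executors for two devices.
  //
  //   if (exec0->CanEnablePeerAccessTo(exec1)) {
  //     TF_CHECK_OK(exec0->EnablePeerAccessTo(exec1));
  //   }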

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription &GetDeviceDescription() const;

  // If implemented, returns device specific measurement of load
  // (e.g. pending requests).
  int64 GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is
  // available. If it is not available (false is returned), free/total may not
  // be initialized.
  //
  // Note: "Free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64 *free, int64 *total) const;

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetConvolveAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Returns the supported execution plans for the convolution operation.
  bool GetConvolveExecutionPlans(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>> *out_exec_plans);

  // Returns the supported execution plans for a fused convolution operation.
  // Only implemented for the cuDNN backend; returns an unimplemented error
  // otherwise.
  port::Status GetFusedConvolveExecutionPlans(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      double conv_input_scale, double side_input_scale, Stream *stream,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::BatchDescriptor &bias_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      dnn::ActivationMode activation_mode,
      std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>
          *out_exec_plans) {
    dnn::DnnSupport *dnn_support = AsDnn();
    if (dnn_support) {
#if GOOGLE_CUDA
      gpu::CudnnSupport *cudnn_dnn =
          dynamic_cast<gpu::CudnnSupport *>(dnn_support);
      if (cudnn_dnn != nullptr) {
        return cudnn_dnn->GetFusedConvolveExecutionPlans(
            kind, element_type, conv_input_scale, side_input_scale, stream,
            input_descriptor, filter_descriptor, bias_descriptor,
            output_descriptor, convolution_descriptor, activation_mode,
            out_exec_plans);
      }
#endif  // GOOGLE_CUDA
    }
    return port::UnimplementedError("DNN library is not found.");
  }

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
      const dnn::BatchDescriptor &input_descriptor,
      DeviceMemoryBase input_data,
      const dnn::FilterDescriptor &filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      ScratchAllocator *scratch_allocator,
      std::vector<dnn::ProfileResult> *out_algorithms);

  // Returns the list of supported algorithms for the rnn operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for the backward convolution on data.
  bool GetConvolveBackwardDataAlgorithms(
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for the backward convolution on the
  // filter.
  bool GetConvolveBackwardFilterAlgorithms(
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);

  // Creates a backend-specific plan object for a blaslt matmul operation,
  // which can then be passed to DoBlasLtMatmul(). When possible, plans should
  // be created once and reused for multiple calls to DoBlasLtMatmul().
  // Returns a null pointer on failure.
  port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
  CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params);

  // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
  // returned in the order of increasing estimated compute time according to an
  // internal heuristic. The first returned algorithm can be used as the
  // default algorithm if no autotuning is to be performed.
  port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
  GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                            size_t max_workspace_size,
                            int max_algorithm_count);

  // Create an RNN descriptor based on model shapes and configurations.
  // The caller retains the ownership of the descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
      float dropout, uint64 seed, ScratchAllocator *state_allocator,
      bool use_padded_io);

  // Create an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains the ownership of the returned
  // descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int> &seq_lengths,
                                    bool time_major, dnn::DataType data_type);

  // Create an RNN state descriptor that specifies the input or hidden state.
  // The caller retains the ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor
  // implementation.
  internal::StreamExecutorInterface *implementation();

  // Creates a kernel which can be launched with stream.ThenLaunch; the types
  // of the arguments provided at launch time must match the types of the
  // arguments provided at creation time.
  //
  // The kernel has a name kernel_name, and is based on the provided PTX in
  // ptx, and (optional) compiled PTX in cubin_data.
  // The canonical storage for both ptx and cubin_data should outlive the
  // lifetime of the kernel.
  template <typename... Args>
  port::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
      absl::string_view kernel_name, absl::string_view ptx,
      absl::Span<const uint8> cubin_data);
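
  // Illustrative sketch (not part of the API): creating and launching a typed
  // kernel. The kernel name, PTX string, and launch shape are hypothetical;
  // see Stream::ThenLaunch for the launch API.
  //
  //   port::StatusOr<std::unique_ptr<TypedKernel<DeviceMemory<float>, int>>>
  //       kernel = executor->CreateTypedKernel<DeviceMemory<float>, int>(
  //           "scale_kernel", ptx_string, /*cubin_data=*/{});
  //   if (kernel.ok()) {
  //     stream.ThenLaunch(ThreadDim(256), BlockDim(16), *kernel.ValueOrDie(),
  //                       device_buffer, /*n=*/4096);
  //   }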

  // Warning: use Stream::ThenLaunch instead, this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                      const BlockDim &block_dims, const KernelBase &kernel,
                      const KernelArgsArrayBase &args);

  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport *AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport *AsDnn();

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is
  // typically not user-facing, as users will use the Stream::ThenBlas* family
  // of routines to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by
  // this object for memoization. This BLAS interface is also only expected to
  // be used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport *AsBlas();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener *listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener *listener);

  // Return allocator statistics.
  absl::optional<AllocatorStats> GetAllocatorStats();

  // Clears the internal stats except for the `in_use` fields
  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`.
  bool ClearAllocatorStats();

  // Return an allocator which delegates to this stream executor for memory
  // allocation.
  StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }

 private:
  template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
            typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Synchronously allocates size bytes on the underlying platform and returns
  // a DeviceMemoryBase representing that allocation. In the case of failure,
  // nullptr is returned.
  DeviceMemoryBase Allocate(uint64 size, int64_t memory_space);

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport *AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream *stream);

  // Without blocking the device, retrieve the current stream status.
  port::Status GetStatus(Stream *stream);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const std::string &symbol_name, ModuleHandle module_handle,
                 void **mem, size_t *bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream *stream, void *host_dst,
              const DeviceMemoryBase &device_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination
  // location and a host memory source, with target size size.
  bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
              const void *host_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination
  // location and a device source location, with target size size. Peer access
  // should have been enabled between the StreamExecutors owning the device
  // memory regions.
  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
                            const DeviceMemoryBase &device_src, uint64 size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream *stream, std::function<void()> callback);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  // This is the preferred form for a callback that may return an error.
  bool HostCallback(Stream *stream, std::function<port::Status()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event *event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event *event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream *stream, Event *event);

  // Wait for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream *stream, Event *event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event *event);

  // Allocates stream resources on the underlying platform and initializes its
  // internals.
  bool AllocateStream(Stream *stream);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream *stream);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream *dependent, Stream *other);

  // Allocates timer resources on the underlying platform and initializes its
  // internals.
  bool AllocateTimer(Timer *timer);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer *timer);

  // Records a start event for an interval timer.
  bool StartTimer(Stream *stream, Timer *timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream *stream, Timer *timer);

  // Allocates a new metadata object, appropriately populated, on the heap,
  // with ownership transfer to caller.
  std::unique_ptr<DeviceDescription> CreateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void *opaque, uint64 bytes);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void *opaque);

  // Calls the relevant TraceListener routine to begin tracing for the
  // specified asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT &&...args);

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable absl::Mutex mu_;

  // Reference to the platform that created this executor.
  const Platform *platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void *, AllocRecord> mem_allocs_ TF_GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when
  // asked for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ TF_GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when
  // asked for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ TF_GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when
  // asked for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when
  // asked for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ TF_GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried from DeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      TF_GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static constexpr int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener *> listeners_ TF_GUARDED_BY(mu_);

  // Allocated memory in bytes.
  int64 mem_alloc_bytes_;

  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
  int64 memory_limit_bytes_;

  StreamExecutorMemoryAllocator allocator_;

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
class ScopedModuleHandle {
 public:
  explicit ScopedModuleHandle(StreamExecutor *executor,
                              ModuleHandle module_handle)
      : executor_(executor), module_handle_(module_handle) {}

  ScopedModuleHandle(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
  }

  ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
    return *this;
  }

  ~ScopedModuleHandle() {
    if (static_cast<bool>(module_handle_)) {
      CHECK(executor_->UnloadModule(module_handle_));
    }
  }

 private:
  StreamExecutor *executor_;
  ModuleHandle module_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
};
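
// Illustrative sketch (not part of the API): loading a module and tying its
// lifetime to a scope. The MultiModuleLoaderSpec setup is elided and the
// names `executor` and `module_spec` are hypothetical.
//
//   ModuleHandle handle;
//   TF_CHECK_OK(executor->LoadModule(module_spec, &handle));
//   ScopedModuleHandle scoped_module(executor, handle);
//   // The module is unloaded when `scoped_module` goes out of scope.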

////////////
// Inlines

template <typename... Args>
inline port::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
                                  absl::string_view ptx,
                                  absl::Span<const uint8> cubin_data) {
  auto kernel_base = absl::make_unique<TypedKernel<Args...>>(this);
  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char *>(cubin_data.data()), kernel_name);
  }

  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count,
                                                     int64_t memory_space) {
  uint64 bytes = sizeof(T) * element_count;
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const std::string &symbol_name, ModuleHandle module_handle) {
  port::StatusOr<DeviceMemoryBase> untyped_symbol =
      GetUntypedSymbol(symbol_name, module_handle);
  if (!untyped_symbol.ok()) {
    return untyped_symbol.status();
  }
  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
                                              DeviceMemoryBase value)
    : wrapped_(value),
      device_ordinal_(parent->device_ordinal()),
      allocator_(parent->GetAllocator()) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor *parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
                                   ptr()->size())) {
      TF_CHECK_OK(Free());
    }
  }
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  DeviceMemoryBase buf = Allocate(sizeof(T), /*memory_space=*/0);
  if (buf.is_null()) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result(buf);
  bool ok = SynchronousMemZero(&result, sizeof(T)).ok();
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::GetSubBuffer(DeviceMemory<T> *parent,
                                             uint64 element_offset,
                                             uint64 element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void *opaque = implementation_->GetSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count));
}
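
// Illustrative sketch (not part of the API): carving a typed view out of an
// existing allocation with GetSubBuffer. The sizes are hypothetical.
//
//   DeviceMemory<float> whole = executor->AllocateArray<float>(1024);
//   // View of elements [256, 512) of `whole`; do not free `whole` while the
//   // view is still in use.
//   DeviceMemory<float> view = executor->GetSubBuffer(&whole, 256, 256);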

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_