/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <memory>
#include <set>
#include <tuple>
#include <vector>

#include "absl/base/macros.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace stream_executor {

class Stream;
// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64 bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  std::string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer;
// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. it is either a CUDA or an OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
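//
// A minimal usage sketch, not part of this header: the names "platform",
// Platform::ExecutorForDevice, and the Stream calls below are assumed from
// the surrounding StreamExecutor library and are illustrative only.
//
//   StreamExecutor *executor = platform->ExecutorForDevice(0).ValueOrDie();
//   DeviceMemory<float> device_buf = executor->AllocateArray<float>(1024);
//   std::vector<float> host(1024, 1.0f);
//   Stream stream(executor);
//   stream.Init();
//   stream.ThenMemcpy(&device_buf, host.data(), host.size() * sizeof(float));
//   TF_CHECK_OK(stream.BlockHostUntilDone());
//   executor->Deallocate(&device_buf);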
class StreamExecutor {
 public:
  StreamExecutor(
      const Platform *platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation,
      int device_ordinal);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(DeviceOptions device_options);

  // Returns the platform that this StreamExecutor is acting upon.
  ABSL_DEPRECATED("Use platform() instead.")
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform *platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //    constant into an appropriate namespace. For example, see
  //    stream_executor::executor_sample::kKernelLoaderSpecs, from which a
  //    MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given Kernel
  //    instantiation should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, error status is returned.
  port::Status GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase *kernel);

  // Loads a module for the platform this StreamExecutor is acting upon.
  //
  // `spec` describes the module to be loaded. On success writes the handle for
  // the loaded module to `module_handle` and returns Status::OK.
  // Otherwise, returns the error which has occurred.
  port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                          ModuleHandle *module_handle);

  // Unloads the module with handle `module_handle`.
  bool UnloadModule(ModuleHandle module_handle);

  // Synchronously allocates an array on the device of type T with
  // element_count elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64 element_count, int64 memory_space = 0);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }
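
  // Example (sketch) of the allocation helpers above; "executor" stands in
  // for a StreamExecutor*:
  //
  //   DeviceMemory<float> array = executor->AllocateArray<float>(16);
  //   auto counter = executor->AllocateOwnedScalar<int32>();
  //   ...
  //   executor->Deallocate(&array);  // "counter" frees itself at scope exit.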

  // Allocates a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may
  // cause use-after-free issues (the specific behavior is not consistent
  // across platforms).
  // - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //   sub-buffer after parent deallocation is expected to be safe. This will
  //   render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> GetSubBuffer(DeviceMemory<T> *parent, uint64 element_offset,
                               uint64 element_count);

  // Finds a symbol and returns device memory allocated to the symbol. The
  // symbol is searched for in any kernels that were previously loaded through
  // GetKernel() before the GetSymbol() call. The user has to make sure that
  // the type of the symbol and T match.
  // - Note: symbol_name should include its namespace as well. For example,
  //   pass "nms0::symbol" if referring to nms0::symbol.
  //
  // If `module_handle` is set, then the search is restricted to the module
  // corresponding to `module_handle`.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string &symbol_name,
                                            ModuleHandle module_handle = {});

  // An untyped version of GetSymbol.
  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
      const std::string &symbol_name, ModuleHandle module_handle = {});
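
  // Example (sketch) of symbol lookup; "executor" and "module_handle" are
  // illustrative, and the symbol name must be fully qualified as noted above:
  //
  //   port::StatusOr<DeviceMemory<int>> sym =
  //       executor->GetSymbol<int>("nms0::symbol", module_handle);
  //   if (sym.ok()) {
  //     DeviceMemory<int> mem = sym.ValueOrDie();
  //   }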

  // Deallocates the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase *mem);

  // Retrieves a mapping of active opaque device memory pointers to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void *, AllocRecord> *records_out);

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  void *UnifiedMemoryAllocate(uint64 bytes);

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  void UnifiedMemoryDeallocate(void *location);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void *HostMemoryAllocate(uint64 size);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void *location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at
  // the given location in device memory.
  port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                  uint64 size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                 uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
  bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
                                    DeviceMemoryBase *device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
                                    DeviceMemoryBase *device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
                                    int64 size, void *host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
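
  // Example (sketch) of the typed synchronous copies above; "executor" stands
  // in for a StreamExecutor*:
  //
  //   std::vector<float> host_src(8, 1.0f);
  //   DeviceMemory<float> device_buf = executor->AllocateArray<float>(8);
  //   TF_CHECK_OK(executor->SynchronousMemcpyH2D(
  //       port::ArraySlice<float>(host_src), &device_buf));
  //   std::vector<float> host_dst(8);
  //   TF_CHECK_OK(executor->SynchronousMemcpyD2H(
  //       device_buf, port::MutableArraySlice<float>(host_dst)));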

  // Blocks the caller while a data segment of the given size is copied from
  // the device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
                         const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                       uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be 32-bit quantified
  // (i.e. evenly divisible by 4). Returns whether the operation was
  // successfully enqueued onto the stream.
  port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                        uint32 pattern, uint64 size);

  // Enables peer access from this StreamExecutor to memory
  // allocated by other, such that launched device code, memcpies, etc. may
  // access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // implementation (as in CUDA vs OpenCL).
  port::Status EnablePeerAccessTo(StreamExecutor *other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other
  // reasons; this is more an up-front test as to whether it's expressly
  // forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor *other);
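
  // Example (sketch) of enabling peer access between two executors on the
  // same platform ("exec0" and "exec1" are illustrative):
  //
  //   if (exec0->CanEnablePeerAccessTo(exec1)) {
  //     TF_CHECK_OK(exec0->EnablePeerAccessTo(exec1));
  //   }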

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription &GetDeviceDescription() const;

  // If implemented, returns a device-specific measurement of load
  // (e.g. pending requests).
  int64 GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is
  // available. If it is not available (false is returned), free/total may not
  // be initialized.
  //
  // Note: "Free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64 *free, int64 *total) const;

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                             std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
      const dnn::BatchDescriptor &input_descriptor,
      DeviceMemoryBase input_data,
      const dnn::FilterDescriptor &filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      ScratchAllocator *scratch_allocator,
      std::vector<dnn::ProfileResult> *out_algorithms);

  // Returns the list of supported algorithms for the RNN operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on
  // data.
  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on the
  // filter.
  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);

  // Creates a backend-specific plan object for a blaslt matmul operation,
  // which can then be passed to DoBlasLtMatmul(). When possible, plans should
  // be created once and reused for multiple calls to DoBlasLtMatmul().
  // Returns a null pointer on failure.
  port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
  CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params);

  // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
  // returned in the order of increasing estimated compute time according to an
  // internal heuristic. The first returned algorithm can be used as the
  // default algorithm if no autotuning is to be performed.
  port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
  GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                            size_t max_workspace_size,
                            int max_algorithm_count);

  // Creates an RNN descriptor based on model shapes and configurations.
  // The caller retains the ownership of the descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
      float dropout, uint64 seed, ScratchAllocator *state_allocator,
      bool use_padded_io);

  // Creates an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains the ownership of the returned
  // descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int> &seq_lengths,
                                    bool time_major, dnn::DataType data_type);

  // Creates an RNN state descriptor that specifies the input or hidden state.
  // The caller retains the ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor
  // implementation.
  internal::StreamExecutorInterface *implementation();

  // Creates a kernel which can be launched with stream.ThenLaunch, such that
  // the argument types provided at launch time must match the argument types
  // provided at creation time.
  //
  // The kernel has a name kernel_name, and is based on the provided PTX in
  // ptx and (optional) compiled PTX in cubin_data.
  // The canonical storage for both ptx and cubin_data should outlive the
  // lifetime of the kernel.
  template <typename... Args>
  port::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
      absl::string_view kernel_name, absl::string_view ptx,
      absl::Span<const uint8> cubin_data);
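
  // Example (sketch): creating a typed kernel from PTX and launching it on a
  // stream. "executor", "ptx_string", "device_buf", and "stream" are
  // illustrative, and the kernel's signature must match the template
  // arguments:
  //
  //   auto kernel = executor->CreateTypedKernel<DeviceMemory<float>, float>(
  //       "scale_kernel", ptx_string, /*cubin_data=*/{});
  //   if (kernel.ok()) {
  //     stream.ThenLaunch(ThreadDim(32), BlockDim(1), *kernel.ValueOrDie(),
  //                       device_buf, 2.0f);
  //   }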

  // Warning: use Stream::ThenLaunch instead; this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                      const BlockDim &block_dims, const KernelBase &kernel,
                      const KernelArgsArrayBase &args);

  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport *AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport *AsDnn();

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is
  // typically not user-facing, as users will use the Stream::ThenBlas* family
  // of routines to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by
  // this object for memoization. This BLAS interface is also only expected to
  // be used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport *AsBlas();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener *listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener *listener);

  // Returns allocator statistics.
  absl::optional<AllocatorStats> GetAllocatorStats();

  // Returns an allocator which delegates to this stream executor for memory
  // allocation.
  StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }

 private:
  template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
            typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Synchronously allocates size bytes on the underlying platform and returns
  // a DeviceMemoryBase representing that allocation. In the case of failure,
  // nullptr is returned.
  DeviceMemoryBase Allocate(uint64 size, int64 memory_space);

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport *AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream *stream);

  // Without blocking the device, retrieve the current stream status.
  port::Status GetStatus(Stream *stream);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const std::string &symbol_name, ModuleHandle module_handle,
                 void **mem, size_t *bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream *stream, void *host_dst,
              const DeviceMemoryBase &device_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination
  // location and a host memory source, with target size size.
  bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
              const void *host_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination
  // location and a device source location, with target size size. Peer access
  // should have been enabled between the StreamExecutors owning the device
  // memory regions.
  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
                            const DeviceMemoryBase &device_src, uint64 size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream *stream, std::function<void()> callback);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  // This is the preferred form for a callback that may return an error.
  bool HostCallback(Stream *stream, std::function<port::Status()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event *event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event *event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream *stream, Event *event);

  // Waits for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream *stream, Event *event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event *event);

  // Allocates stream resources on the underlying platform and initializes its
  // internals.
  bool AllocateStream(Stream *stream);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream *stream);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream *dependent, Stream *other);

  // Allocates timer resources on the underlying platform and initializes its
  // internals.
  bool AllocateTimer(Timer *timer);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer *timer);

  // Records a start event for an interval timer.
  bool StartTimer(Stream *stream, Timer *timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream *stream, Timer *timer);

  // Allocates a new metadata object, appropriately populated, on the heap,
  // with ownership transfer to caller.
  std::unique_ptr<DeviceDescription> CreateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void *opaque, uint64 bytes);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void *opaque);

  // Calls the relevant TraceListener routine to begin tracing for the
  // specified asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT &&...args);

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable absl::Mutex mu_;

  // Reference to the platform that created this executor.
  const Platform *platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void *, AllocRecord> mem_allocs_ TF_GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when
  // asked for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ TF_GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when
  // asked for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ TF_GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when
  // asked for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when
  // asked for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ TF_GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried via GetDeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      TF_GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static constexpr int kNumBackgroundThreads = 1;

  // Indicates whether StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener *> listeners_ TF_GUARDED_BY(mu_);

  // Allocated memory in bytes.
  int64 mem_alloc_bytes_;

  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
  int64 memory_limit_bytes_;

  StreamExecutorMemoryAllocator allocator_;

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
class ScopedModuleHandle {
 public:
  explicit ScopedModuleHandle(StreamExecutor *executor,
                              ModuleHandle module_handle)
      : executor_(executor), module_handle_(module_handle) {}

  ScopedModuleHandle(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
  }

  ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
    return *this;
  }

  ~ScopedModuleHandle() {
    if (static_cast<bool>(module_handle_)) {
      CHECK(executor_->UnloadModule(module_handle_));
    }
  }

 private:
  StreamExecutor *executor_;
  ModuleHandle module_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
};
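
// Example (sketch): tying a loaded module's lifetime to a scope. "executor"
// and "spec" are illustrative:
//
//   ModuleHandle handle;
//   TF_CHECK_OK(executor->LoadModule(spec, &handle));
//   ScopedModuleHandle scoped_module(executor, handle);
//   // The module is unloaded when scoped_module goes out of scope.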

////////////
// Inlines

template <typename... Args>
inline port::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
                                  absl::string_view ptx,
                                  absl::Span<const uint8> cubin_data) {
  auto kernel_base = absl::make_unique<TypedKernel<Args...>>(this);
  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char *>(cubin_data.data()), kernel_name);
  }

  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count,
                                                     int64 memory_space) {
  uint64 bytes = sizeof(T) * element_count;
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const std::string &symbol_name, ModuleHandle module_handle) {
  port::StatusOr<DeviceMemoryBase> untyped_symbol =
      GetUntypedSymbol(symbol_name, module_handle);
  if (!untyped_symbol.ok()) {
    return untyped_symbol.status();
  }
  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
                                              DeviceMemoryBase value)
    : wrapped_(value),
      device_ordinal_(parent->device_ordinal()),
      allocator_(parent->GetAllocator()) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor *parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
                                   ptr()->size())) {
      TF_CHECK_OK(Free());
    }
  }
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  DeviceMemoryBase buf = Allocate(sizeof(T), /*memory_space=*/0);
  if (buf.is_null()) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result(buf);
  bool ok = SynchronousMemZero(&result, sizeof(T)).ok();
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::GetSubBuffer(DeviceMemory<T> *parent,
                                             uint64 element_offset,
                                             uint64 element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void *opaque = implementation_->GetSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count));
}

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_