/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <memory>
#include <set>
#include <tuple>
#include <vector>

#include "absl/base/macros.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace stream_executor {

// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64 bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer;

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. it is either a CUDA or OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
class StreamExecutor {
 public:
  StreamExecutor(
      const Platform *platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation,
      int device_ordinal);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(DeviceOptions device_options);

  // Returns the platform kind that this StreamExecutor is acting upon.
  ABSL_DEPRECATED("Use platform() instead.")
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform *platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //    constant into an appropriate namespace. For example, see
  //    stream_executor::executor_sample::kKernelLoaderSpecs, from which a
  //    MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given kernel
  //    instantiation should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, an error status is returned.
  port::Status GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase *kernel);

  // Loads a module for the platform this StreamExecutor is acting upon.
  //
  // `spec` describes the module to be loaded. On success writes the handle for
  // the loaded module to `module_handle` and returns Status::OK.
  // Otherwise, returns the error which has occurred.
  port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                          ModuleHandle *module_handle);

  // Unloads the module with handle `module_handle`.
  bool UnloadModule(ModuleHandle module_handle);

  // Synchronously allocates an array on the device of type T with
  // element_count elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64 element_count, int64 memory_space = 0);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }
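
  // Illustrative usage sketch (not part of the API; assumes `executor` is an
  // initialized StreamExecutor* for the target device):
  //
  //   DeviceMemory<float> array = executor->AllocateArray<float>(1024);
  //   ScopedDeviceMemory<float> scalar = executor->AllocateOwnedScalar<float>();
  //   ...
  //   executor->Deallocate(&array);  // `scalar` is freed when it leaves scope.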

  // Allocate a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may cause
  // use-after-free issues (the specific behavior is not consistent across
  // platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> GetSubBuffer(DeviceMemory<T> *parent, uint64 element_offset,
                               uint64 element_count);

  // Finds a symbol and returns device memory allocated to the symbol. The
  // symbol is searched for in any kernels that were previously loaded through
  // GetKernel() before the GetSymbol() call. The user has to make sure that
  // the type of the symbol and T match.
  //  - Note: symbol_name should include its namespace as well. For example,
  //    pass "nms0::symbol" if referring to nms0::symbol.
  //
  // If `module_handle` is set then searches only within the module
  // corresponding to `module_handle`.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
                                            ModuleHandle module_handle = {});

  // An untyped version of GetSymbol.
  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
      const string &symbol_name, ModuleHandle module_handle = {});
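
  // Illustrative symbol-lookup sketch (assumes a module defining nms0::symbol
  // was loaded earlier and `executor` is as above):
  //
  //   port::StatusOr<DeviceMemory<int>> sym =
  //       executor->GetSymbol<int>("nms0::symbol");
  //   if (sym.ok()) {
  //     DeviceMemory<int> mem = sym.ValueOrDie();
  //   }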

  // Deallocate the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase *mem);

  // Retrieves a mapping of active opaque device memory pointer to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void *, AllocRecord> *records_out);

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  void *UnifiedMemoryAllocate(uint64 bytes);

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  void UnifiedMemoryDeallocate(void *location);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void *HostMemoryAllocate(uint64 size);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void *location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
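
  // Illustrative registration sketch (assumes `executor` as above; `buffer`
  // is a hypothetical caller-owned host allocation):
  //
  //   char *buffer = new char[4096];
  //   if (executor->HostMemoryRegister(buffer, 4096)) {
  //     ...  // `buffer` may now be used with Stream::ThenMemcpy.
  //     CHECK(executor->HostMemoryUnregister(buffer));
  //   }
  //   delete[] buffer;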

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
  // given location in device memory.
  port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                  uint64 size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                 uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
  bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
                                    DeviceMemoryBase *device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
                                    DeviceMemoryBase *device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
                                    int64 size, void *host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the host destination can accommodate the device
  // source size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
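
  // Illustrative host<->device round-trip sketch (assumes `executor` as above
  // and a DeviceMemory<float> `dst` with room for four floats):
  //
  //   std::vector<float> host = {1.f, 2.f, 3.f, 4.f};
  //   TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D<float>(host, &dst));
  //   std::vector<float> back(4);
  //   TF_RETURN_IF_ERROR(
  //       executor->SynchronousMemcpyD2H<float>(dst, absl::MakeSpan(back)));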

  // Blocks the caller while a data segment of the given size is copied from the
  // device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
                         const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                       uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be a 32-bit quantity
  // (i.e. evenly divisible by 4). Returns whether the operation was
  // successfully enqueued onto the stream.
  port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                        uint32 pattern, uint64 size);

  // Enables peer access from this StreamExecutor to memory
  // allocated by other, such that launched device code, memcpies, etc may
  // access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // (as in CUDA vs OpenCL) implementation.
  port::Status EnablePeerAccessTo(StreamExecutor *other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other
  // reasons; this is more an up-front test as to whether it's expressly
  // forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor *other);
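
  // Illustrative peer-access sketch (assumes `exec_a` and `exec_b` are
  // hypothetical executors backed by the same platform):
  //
  //   if (exec_a->CanEnablePeerAccessTo(exec_b)) {
  //     TF_RETURN_IF_ERROR(exec_a->EnablePeerAccessTo(exec_b));
  //   }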

  // Gets the preferred shared memory configuration for the device to which
  // this executor is bound.
  SharedMemoryConfig GetDeviceSharedMemoryConfig();

  // Sets the preferred shared memory configuration for the device to which
  // this executor is bound.
  port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config);

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription &GetDeviceDescription() const;

  // If implemented, returns a device specific measurement of load
  // (e.g. pending requests).
  int64 GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is
  // available. If it is not available (false is returned), free/total may not
  // be initialized.
  //
  // Note: "Free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64 *free, int64 *total) const;

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                             std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Returns the list of supported algorithms (with profiling results) for the
  // forward convolution operation, via MIOpen.
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, Stream *stream, dnn::DataType element_type,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      std::vector<dnn::ProfileResult> *out_algorithms);

  // Returns the list of supported algorithms for the RNN operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on
  // data.
  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on the
  // filter.
  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);

  // Creates an RNN descriptor based on model shapes and configurations.
  // The caller retains the ownership of the descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
      float dropout, uint64 seed, ScratchAllocator *state_allocator,
      bool use_padded_io);

  // Creates an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains the ownership of the returned
  // descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int> &seq_lengths,
                                    bool time_major, dnn::DataType data_type);

  // Creates an RNN state descriptor that specifies the input or hidden state.
  // The caller retains the ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor
  // implementation.
  internal::StreamExecutorInterface *implementation();

  // Creates a kernel which can be launched with stream.ThenLaunch, such that
  // the types of the arguments provided at launch time must match the types of
  // the arguments provided at creation time.
  //
  // The kernel has a name kernel_name, and is based on the provided PTX in
  // ptx, and (optional) compiled PTX in cubin_data.
  // The canonical storage for both ptx and cubin_data should outlive the
  // lifetime of the kernel.
  template <typename... Args>
  port::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
      absl::string_view kernel_name, absl::string_view ptx,
      absl::Span<const uint8> cubin_data);
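
  // Illustrative sketch (the kernel name, PTX string `kAddKernelPtx`, and
  // parameter types below are hypothetical):
  //
  //   auto kernel_or = executor->CreateTypedKernel<DeviceMemory<float>, int>(
  //       "AddKernel", kAddKernelPtx, /*cubin_data=*/{});
  //   if (kernel_or.ok()) {
  //     stream.ThenLaunch(ThreadDim(256), BlockDim(1), *kernel_or.ValueOrDie(),
  //                       data, n);
  //   }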

  // Warning: use Stream::ThenLaunch instead; this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                      const BlockDim &block_dims, const KernelBase &kernel,
                      const KernelArgsArrayBase &args);

  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport *AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport *AsDnn();

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is
  // typically not user-facing, as users will use the Stream::ThenBlas* family
  // of routines to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by
  // this object for memoization. This BLAS interface is also only expected to
  // be used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport *AsBlas();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener *listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener *listener);
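
  // Illustrative tracing sketch (assumes `listener` is a caller-owned
  // TraceListener implementation):
  //
  //   executor->RegisterTraceListener(&listener);
  //   executor->EnableTracing(true);
  //   ...  // Traced operations.
  //   executor->EnableTracing(false);
  //   executor->UnregisterTraceListener(&listener);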

  // Returns allocator statistics.
  absl::optional<AllocatorStats> GetAllocatorStats();

  // Returns an allocator which delegates to this stream executor for memory
  // allocation.
  StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }

 private:
  template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
            typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Synchronously allocates size bytes on the underlying platform and returns
  // a DeviceMemoryBase representing that allocation. In the case of failure,
  // nullptr is returned.
  DeviceMemoryBase Allocate(uint64 size, int64 memory_space);

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport *AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream *stream);

  // Without blocking the device, retrieve the current stream status.
  port::Status GetStatus(Stream *stream);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
                 void **mem, size_t *bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream *stream, void *host_dst,
              const DeviceMemoryBase &device_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination
  // location and a host memory source, with target size size.
  bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
              const void *host_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination
  // location and a device source location, with target size size. Peer access
  // should have been enabled between the StreamExecutors owning the device
  // memory regions.
  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
                            const DeviceMemoryBase &device_src, uint64 size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream *stream, std::function<void()> callback);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  // This is the preferred form for a callback that may return an error.
  bool HostCallback(Stream *stream, std::function<port::Status()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event *event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event *event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream *stream, Event *event);

  // Waits for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream *stream, Event *event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event *event);

  // Allocates stream resources on the underlying platform and initializes its
  // internals.
  bool AllocateStream(Stream *stream);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream *stream);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream *dependent, Stream *other);

  // Allocates timer resources on the underlying platform and initializes its
  // internals.
  bool AllocateTimer(Timer *timer);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer *timer);

  // Records a start event for an interval timer.
  bool StartTimer(Stream *stream, Timer *timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream *stream, Timer *timer);

  // Allocates a new metadata object, appropriately populated, on the heap,
  // with ownership transfer to the caller.
  std::unique_ptr<DeviceDescription> CreateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void *opaque, uint64 bytes);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void *opaque);

  // Calls the relevant TraceListener routine to begin tracing for the
  // specified asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable absl::Mutex mu_;

  // Reference to the platform that created this executor.
  const Platform *platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when
  // asked for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when
  // asked for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when
  // asked for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when
  // asked for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried via GetDeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static const int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener *> listeners_ GUARDED_BY(mu_);

  // Allocated memory in bytes.
  int64 mem_alloc_bytes_;

  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
  int64 memory_limit_bytes_;

  StreamExecutorMemoryAllocator allocator_;

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
class ScopedModuleHandle {
 public:
  explicit ScopedModuleHandle(StreamExecutor *executor,
                              ModuleHandle module_handle)
      : executor_(executor), module_handle_(module_handle) {}

  ScopedModuleHandle(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
  }

  ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
    return *this;
  }

  ~ScopedModuleHandle() {
    if (static_cast<bool>(module_handle_)) {
      CHECK(executor_->UnloadModule(module_handle_));
    }
  }

 private:
  StreamExecutor *executor_;
  ModuleHandle module_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
};

////////////
// Inlines

template <typename... Args>
inline port::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
                                  absl::string_view ptx,
                                  absl::Span<const uint8> cubin_data) {
  auto kernel_base = absl::make_unique<TypedKernel<Args...>>(this);
  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char *>(cubin_data.data()), kernel_name);
  }

  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count,
                                                     int64 memory_space) {
  uint64 bytes = sizeof(T) * element_count;
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  port::StatusOr<DeviceMemoryBase> untyped_symbol =
      GetUntypedSymbol(symbol_name, module_handle);
  if (!untyped_symbol.ok()) {
    return untyped_symbol.status();
  }
  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
                                              DeviceMemoryBase value)
    : wrapped_(value),
      device_ordinal_(parent->device_ordinal()),
      allocator_(parent->GetAllocator()) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor *parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
                                   ptr()->size())) {
      TF_CHECK_OK(Free());
    }
  }
}
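
// Illustrative sketch: device memory initialized from a host initializer list
// and freed automatically when `mem` leaves scope (assumes `executor` is an
// initialized StreamExecutor*):
//
//   ScopedDeviceMemory<float> mem(executor, {1.f, 2.f, 3.f});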

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  DeviceMemoryBase buf = Allocate(sizeof(T), /*memory_space=*/0);
  if (buf.is_null()) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result(buf);
  bool ok = SynchronousMemZero(&result, sizeof(T)).ok();
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::GetSubBuffer(DeviceMemory<T> *parent,
                                             uint64 element_offset,
                                             uint64 element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void *opaque = implementation_->GetSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count));
}
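
// Illustrative sketch (assumes `parent` is a hypothetical DeviceMemory<int>
// holding at least 16 elements allocated from `executor`):
//
//   DeviceMemory<int> sub = executor->GetSubBuffer(
//       &parent, /*element_offset=*/8, /*element_count=*/8);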

template <typename... Params, typename... Args>
inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
                                  const TypedKernel<Params...> &kernel,
                                  Args... args) {
  KernelInvocationChecker<std::tuple<Params...>,
                          std::tuple<Args...>>::CheckAllStaticAssert();
  if (ok()) {
    // This is the core that allows type-safe kernel launching.
    // Since the platforms take kernel arguments as tuples of (void *, size),
    // we pack the variadic parameters passed as ...args into the desired
    // tuple form and pass that packed form to the StreamExecutor::Launch()
    // implementation.
    KernelArgsArray<sizeof...(args)> kernel_args;
    kernel.PackParams(&kernel_args, args...);
    DCHECK(parent_ != nullptr);
    bool ok =
        parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args)
            .ok();
    if (!ok) {
      SetError();
      LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
    }
  }
  return *this;
}

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_