1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
18
19 #include <atomic>
20 #include <memory>
21 #include <set>
22 #include <tuple>
23 #include <vector>
24
25 #include "absl/base/macros.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/stream_executor/lib/status.h"
28 #include "tensorflow/stream_executor/lib/statusor.h"
29 #include "tensorflow/stream_executor/lib/threadpool.h"
30 #include "tensorflow/stream_executor/platform.h"
31 #include "tensorflow/stream_executor/platform/logging.h"
32 #include "tensorflow/stream_executor/platform/mutex.h"
33 #include "tensorflow/stream_executor/platform/port.h"
34 #include "tensorflow/stream_executor/platform/thread_annotations.h"
35 #include "tensorflow/stream_executor/rng.h"
36 #include "tensorflow/stream_executor/shared_memory_config.h"
37 #include "tensorflow/stream_executor/stream.h"
38 #include "tensorflow/stream_executor/stream_executor_internal.h"
39 #include "tensorflow/stream_executor/trace_listener.h"
40
41 namespace stream_executor {
42
43 // Structure used for device memory leak checking.
44 struct AllocRecord {
45 // The requested allocation size of the buffer.
46 uint64 bytes;
47
48 // Holds a representation of the stack at the time the associated buffer was
49 // allocated. Produced in a form described in
50 // //util/symbolize/symbolized_stacktrace.h.
51 string stack_trace;
52 };
53
54 // Forward declaration of private friend class.
55 template <typename BeginCallT, typename CompleteCallT,
56 typename ReturnT, typename... BeginArgsT>
57 class ScopedTracer;
58
59 // A StreamExecutor manages a single device, in terms of executing work (kernel
60 // launches) and memory management (allocation/deallocation, memory copies to
61 // and from the device). It is conceptually the "handle" for a device -- Stream
62 // objects, which are used to enqueue work to run on the
63 // coprocessor, have a StreamExecutor instance as their "parent" object.
64 //
65 // StreamExecutor objects have an underlying platform that is specified up
66 // front; e.g. it is either a CUDA or an OpenCL executor.
68 //
69 // Thread-safe after initialization.
70 // StreamExecutor interface should not be invoked from a signal handler.
71 class StreamExecutor {
72 public:
73 explicit StreamExecutor(PlatformKind kind,
74 const PluginConfig &plugin_config = PluginConfig());
75
76 StreamExecutor(
77 const Platform *platform,
78 std::unique_ptr<internal::StreamExecutorInterface> implementation);
79
80 ~StreamExecutor();
81
82 port::Status Init();
83 port::Status Init(int device_ordinal, DeviceOptions device_options);
84
85 // Returns the platform that this StreamExecutor is acting upon.
86 ABSL_DEPRECATED("Use platform() instead.")
87 PlatformKind platform_kind() const { return platform_kind_; }
88
89 // Returns a reference to the platform that created this executor.
90 const Platform *platform() const { return platform_; }
91
92 // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
93 // upon, if one exists.
94 //
95 // Parameters:
96 // spec: The MultiKernelLoaderSpec is usually generated as a compile-time
97 // constant into an appropriate namespace. For example, see
98 // stream_executor::executor_sample::kKernelLoaderSpecs, from which a
99 // MultiKernelLoaderSpec is selected.
100 // kernel: Outparam that the kernel is loaded into. A given Kernel
101 // instantiation should not be loaded into more than once.
102 //
103 // If an error occurs, or there is no kernel available for the StreamExecutor
104 // platform, false is returned.
105 bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);
106
107 // Releases any state associated with the previously loaded kernel.
108 void UnloadKernel(const KernelBase *kernel);
109
110 // Loads a module for the platform this StreamExecutor is acting upon.
111 //
112 // `spec` describes the module to be loaded. On success writes the handle for
113 // the loaded module to `module_handle` and returns true. Else returns false.
114 bool LoadModule(const MultiModuleLoaderSpec &spec,
115 ModuleHandle *module_handle);
116
117 // Unloads the module with handle `module_handle`.
118 bool UnloadModule(ModuleHandle module_handle);
119
120 // Synchronously allocates an array on the device of type T with element_count
121 // elements.
122 template <typename T>
123 DeviceMemory<T> AllocateArray(uint64 element_count);
124
125 // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
126 template <typename T>
127 ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
128 return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
129 }
130
131 // Convenience wrapper that allocates space for a single element of type T in
132 // device memory.
133 template <typename T>
134 DeviceMemory<T> AllocateScalar() {
135 return AllocateArray<T>(1);
136 }
137
138 // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
139 template <typename T>
140 ScopedDeviceMemory<T> AllocateOwnedScalar() {
141 return AllocateOwnedArray<T>(1);
142 }
143
144 // Synchronously allocates a scalar of type T on the device that is (POD)
145 // zero-byte initialized.
146 template <typename T>
147 DeviceMemory<T> AllocateZeroed();
148
149 // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
150 template <typename T>
151 ScopedDeviceMemory<T> AllocateOwnedZeroed() {
152 return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
153 }
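
// Illustrative sketch of the allocation helpers above (not part of the API).
// It assumes `executor` is a valid StreamExecutor* obtained from a Platform,
// and the element counts are arbitrary:
//
//   // Manually managed allocation; must be released via Deallocate().
//   DeviceMemory<float> raw = executor->AllocateArray<float>(1024);
//   ...
//   executor->Deallocate(&raw);
//
//   // RAII-managed alternatives; freed when the scoped object is destroyed.
//   ScopedDeviceMemory<float> owned =
//       executor->AllocateOwnedArray<float>(1024);
//   DeviceMemory<int> zeroed_scalar = executor->AllocateZeroed<int>();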
154
155 // Allocate a memory region inside another allocated memory region.
156 // Offset and size are specified in terms of T elements.
157 // Warning: Do not free a parent buffer before its sub-buffers; this may cause
158 // use-after-free issues (the specific behavior is not consistent across
159 // platforms).
160 // - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
161 // sub-buffer after parent deallocation is expected to be safe. This will
162 // render your code non-platform-portable, however.
163 template <typename T>
164 DeviceMemory<T> AllocateSubBuffer(DeviceMemory<T> *parent,
165 uint64 element_offset,
166 uint64 element_count);
167
168 // As AllocateSubBuffer(), but returns a ScopedDeviceMemory<T>.
169 template <typename T>
170 ScopedDeviceMemory<T> AllocateOwnedSubBuffer(DeviceMemory<T> *parent,
171 uint64 element_offset,
172 uint64 element_count) {
173 return ScopedDeviceMemory<T>(
174 this, AllocateSubBuffer<T>(parent, element_offset, element_count));
175 }
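
// Illustrative sketch of sub-buffer allocation (assumes `executor` is a valid
// StreamExecutor*; the offset and count are arbitrary). Per the warning
// above, the parent must outlive the sub-buffer for portable behavior:
//
//   DeviceMemory<float> parent = executor->AllocateArray<float>(256);
//   // View of elements [64, 64 + 32) of `parent`.
//   DeviceMemory<float> view = executor->AllocateSubBuffer<float>(
//       &parent, /*element_offset=*/64, /*element_count=*/32);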
176
177 // Finds a symbol and returns device memory allocated to the symbol. The
178 // symbol is searched in any kernels that were previously loaded through
179 // GetKernel() before the GetSymbol() call. The user has to make sure that the
180 // type of symbol and T match.
181 // - Note: symbol_name should include its namespace as well. For example,
182 // pass "nms0::symbol" if referring to nms0::symbol.
183 //
184 // If `module_handle` is set then searches only within the module
185 // corresponding to `module_handle`.
186 template <typename T>
187 port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
188 ModuleHandle module_handle = {});
189
190 // An untyped version of GetSymbol.
191 port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
192 const string &symbol_name, ModuleHandle module_handle = {});
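
// Illustrative sketch of a symbol lookup (the symbol name is hypothetical;
// `module` would come from a prior LoadModule() call):
//
//   port::StatusOr<DeviceMemory<float>> sym =
//       executor->GetSymbol<float>("some_namespace::some_symbol", module);
//   if (sym.ok()) {
//     DeviceMemory<float> symbol_mem = sym.ValueOrDie();
//   }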
193
194 // Deallocate the DeviceMemory previously allocated via this interface.
195 // Deallocation of a nullptr-representative value is permitted.
196 //
197 // Resets the internal contents of mem to be null-representative, but this
198 // null-out effect should not be relied upon in client code.
199 //
200 // TODO(jlebar): Change this to accept a DeviceMemoryBase by value, see
201 // discussion in cl/195744342.
202 void Deallocate(DeviceMemoryBase *mem);
203
204 // Retrieves a mapping of active opaque device memory pointer to a string
205 // representation of the [allocating thread's] stack at the time the pointer
206 // was allocated. Useful for tracking device memory leaks.
207 //
208 // Note: this will only be populated if --check_device_leaks flag is
209 // activated.
210 void GetMemAllocs(std::map<void *, AllocRecord> *records_out);
211
212 // Allocates unified memory space of the given size, if supported.
213 // See
214 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
215 // for more details on unified memory.
216 void *UnifiedMemoryAllocate(uint64 bytes);
217
218 // Deallocates unified memory space previously allocated with
219 // UnifiedMemoryAllocate.
220 void UnifiedMemoryDeallocate(void *location);
221
222 // Allocates a region of host memory and registers it with the platform API.
223 // Memory allocated in this manner (or allocated and registered with
224 // HostMemoryRegister()) is required for use in asynchronous memcpy operations,
225 // such as Stream::ThenMemcpy.
226 void *HostMemoryAllocate(uint64 bytes);
227
228 // Deallocates a region of host memory allocated by HostMemoryAllocate().
229 void HostMemoryDeallocate(void *location);
230
231 // Registers a region of host memory with the platform API. Registered memory
232 // (or memory allocated with HostMemoryAllocate) is required for use with
233 // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
234 // is used to register memory allocated outside the StreamExecutor;
235 // HostMemoryAllocate implicitly registers its allocations and
236 // HostMemoryDeallocate implicitly deregisters on deallocation.
237 bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;
238
239 // Unregisters a region of host memory registered with HostMemoryRegister.
240 // This should be done before deallocating the region with delete[]/free/etc.
241 bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
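
// Illustrative sketch of the two ways to obtain registered ("pinned") host
// memory for asynchronous copies (assumes `executor` is valid; the sizes are
// arbitrary):
//
//   // Option 1: allocate-and-register through the StreamExecutor.
//   void *pinned = executor->HostMemoryAllocate(4096);
//   ...
//   executor->HostMemoryDeallocate(pinned);
//
//   // Option 2: register memory allocated elsewhere; unregister it before
//   // freeing it.
//   std::vector<char> host(4096);
//   if (executor->HostMemoryRegister(host.data(), host.size())) {
//     ...
//     CHECK(executor->HostMemoryUnregister(host.data()));
//   }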
242
243 // Synchronizes all activity occurring in the StreamExecutor's context (most
244 // likely a whole device).
245 bool SynchronizeAllActivity() SE_MUST_USE_RESULT;
246
247 // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
248 // given location in device memory.
249 bool SynchronousMemZero(DeviceMemoryBase *location,
250 uint64 size) SE_MUST_USE_RESULT;
251
252 // Blocks the caller while "size" bytes are initialized to "value" (in POD
253 // fashion) at the given location in device memory.
254 bool SynchronousMemSet(DeviceMemoryBase *location, int value,
255 uint64 size) SE_MUST_USE_RESULT;
256
257 // [deprecated] Blocks the caller while a data segment of the given size is
258 // copied from the host source to the device destination.
259 ABSL_DEPRECATED(
260 "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
261 bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
262 uint64 size) SE_MUST_USE_RESULT;
263
264 // [deprecated] Blocks the caller while a data segment of the given size is
265 // copied from the device source to the host destination.
266 ABSL_DEPRECATED(
267 "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
268 bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
269 uint64 size) SE_MUST_USE_RESULT;
270
271 // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
272 port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
273 DeviceMemoryBase *device_dst);
274
275 // Alternative interface for memcpying from host to device that takes an
276 // array slice. Checks that the destination size can accommodate the host
277 // slice size.
278 template <class T>
279 port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
280 DeviceMemoryBase *device_dst) {
281 auto host_size = host_src.size() * sizeof(T);
282 CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
283 return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
284 }
285
286 // Same as SynchronousMemcpy(void*, ...) above.
287 port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
288 int64 size, void *host_dst);
289
290 // Alternative interface for memcpying from device to host that takes an
291 // array slice. Checks that the destination size can accommodate the host
292 // slice size.
293 template <typename T>
294 port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
295 port::MutableArraySlice<T> host_dst) {
296 auto host_size = host_dst.size() * sizeof(T);
297 CHECK(device_src.size() == 0 || host_size >= device_src.size());
298 return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
299 }
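
// Illustrative host/device round trip using the slice-based helpers above
// (assumes `executor` is valid and that ArraySlice/MutableArraySlice convert
// from std::vector in the usual absl::Span fashion):
//
//   std::vector<float> host_src(128, 1.0f);
//   DeviceMemory<float> device_buf = executor->AllocateArray<float>(128);
//   port::Status h2d = executor->SynchronousMemcpyH2D(
//       port::ArraySlice<float>(host_src), &device_buf);
//
//   std::vector<float> host_dst(128);
//   port::Status d2h = executor->SynchronousMemcpyD2H(
//       device_buf, port::MutableArraySlice<float>(host_dst));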
300
301 // Blocks the caller while a data segment of the given size is copied from the
302 // device source to the device destination.
303 bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
304 const DeviceMemoryBase &device_src,
305 uint64 size) SE_MUST_USE_RESULT;
306
307 // Enqueues an operation onto stream to zero out size bytes at the given
308 // device memory location. Neither stream nor location may be null. Returns
309 // whether the operation was successfully enqueued onto the stream.
310 bool MemZero(Stream *stream, DeviceMemoryBase *location,
311 uint64 size) SE_MUST_USE_RESULT;
312
313 // Enqueues an operation onto stream to set 32-bit patterns starting at
314 // location, for byte count given by size. size must be 32-bit quantified
315 // (i.e. evenly divisible by 4). Returns whether the operation was
316 // successfully enqueued onto the stream.
317 bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
318 uint64 size) SE_MUST_USE_RESULT;
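
// Illustrative sketch (assumes `stream` is a valid Stream created on this
// executor and `buffer` is a DeviceMemory<uint32> of at least 256 bytes; the
// byte count is arbitrary but, as required, a multiple of 4):
//
//   bool zeroed = executor->MemZero(stream, &buffer, 256);
//   bool filled = executor->Memset32(stream, &buffer, 0xDEADBEEF, 256);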
319
320 // Enables peer access from this StreamExecutor to memory
321 // allocated by other, such that launched device code, memcpies, etc may
322 // access it directly.
323 //
324 // Both this StreamExecutor and other must be backed by the same platform
325 // implementation (as in CUDA vs OpenCL).
327 port::Status EnablePeerAccessTo(StreamExecutor *other);
328
329 // Returns whether it's possible to enable peer access from this
330 // StreamExecutor to memory allocated by another.
332 //
333 // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
334 // this is more an up-front test as to whether it's expressly forbidden.
335 bool CanEnablePeerAccessTo(StreamExecutor *other);
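
// Illustrative sketch of enabling peer access between two executors backed by
// the same platform (assumes `exec0` and `exec1` are valid StreamExecutor*):
//
//   if (exec0->CanEnablePeerAccessTo(exec1)) {
//     port::Status status = exec0->EnablePeerAccessTo(exec1);
//     // Even when CanEnablePeerAccessTo() returns true, this call may still
//     // fail for other reasons.
//   }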
336
337 // Gets the preferred shared memory configuration for the device to which this
338 // executor is bound.
339 SharedMemoryConfig GetDeviceSharedMemoryConfig();
340
341 // Sets the preferred shared memory configuration for the device to which this
342 // executor is bound.
343 port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config);
344
345 // Obtains metadata about the underlying device.
346 // The value is cached on first use.
347 const DeviceDescription &GetDeviceDescription() const;
348
349 // If implemented, returns device specific measurement of load
350 // (e.g. pending requests).
351 int64 GetDeviceLoad() const;
352
353 // Returns the underlying device memory usage information, if it is available.
354 // If it is not available (false is returned), free/total may not be
355 // initialized.
356 //
357 // Note: "Free" reflects the amount of free memory on the underlying device,
358 // so allocations via other StreamExecutors that have the same underlying
359 // device will be reflected in "free".
361 bool DeviceMemoryUsage(int64 *free, int64 *total) const;
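
// Illustrative sketch (assumes `executor` is valid):
//
//   int64 free_bytes = 0, total_bytes = 0;
//   if (executor->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
//     VLOG(1) << "device memory free/total: " << free_bytes << "/"
//             << total_bytes;
//   }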
362
363 // The device count reported by this StreamExecutor's platform.
364 // Note: on OpenCL we implicitly select platform zero at the moment.
365 int PlatformDeviceCount() const;
366
367 // Returns whether the StreamExecutor supports BLAS routines for the platform
368 // that underlies this interface.
369 bool SupportsBlas() const;
370
371 // Returns whether the StreamExecutor supports FFT routines for the platform
372 // that underlies this interface.
373 bool SupportsFft() const;
374
375 // Returns whether the StreamExecutor supports RNG routines for the platform
376 // that underlies this interface.
377 bool SupportsRng() const;
378
379 // Returns whether the StreamExecutor supports neural net routines for the
380 // platform that underlies this interface.
381 bool SupportsDnn() const;
382
383 // Returns the list of supported algorithms for the forward convolution
384 // operation.
385 bool GetConvolveAlgorithms(bool with_winograd_nonfused,
386 std::vector<dnn::AlgorithmDesc> *out_algorithms);
387
388 // Returns the list of supported algorithms for the RNN operation.
389 bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);
390
391 // Get the list of supported algorithms for the backward convolution on data.
392 bool GetConvolveBackwardDataAlgorithms(
393 bool with_winograd_nonfused,
394 std::vector<dnn::AlgorithmDesc> *out_algorithms);
395
396 // Get the list of supported algorithms for the backward convolution on the
397 // filter.
398 bool GetConvolveBackwardFilterAlgorithms(
399 bool with_winograd_nonfused,
400 std::vector<dnn::AlgorithmDesc> *out_algorithms);
401
402 // Get the list of supported algorithms for BLAS gemm.
403 bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
404
405 // Create an RNN descriptor based on model shapes and configurations.
406 // The caller retains the ownership of the descriptor.
407 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
408 int num_layers, int hidden_size, int input_size, int batch_size,
409 dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
410 dnn::RnnMode rnn_mode, dnn::DataType data_type,
411 const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
412 ScratchAllocator *state_allocator);
413
414 // Create an RNN sequence descriptor that specifies either the input or output
415 // sequence. The caller retains the ownership of the returned descriptor.
416 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
417 createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
418 int data_size, dnn::DataType data_type);
419
420 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
421 createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
422 int data_size,
423 const absl::Span<const int> &seq_lengths,
424 bool time_major, dnn::DataType data_type);
425
426 // Create an RNN state descriptor that specifies the input or hidden state.
427 // The caller retains the ownership of the returned descriptor.
428 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
429 createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
430 dnn::DataType data_type);
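
// Illustrative sketch of creating the descriptor family above. The shapes
// and dnn:: enum values are placeholders rather than recommendations, and
// `scratch_allocator` is assumed to be a valid ScratchAllocator*:
//
//   auto rnn_desc = executor->createRnnDescriptor(
//       /*num_layers=*/1, /*hidden_size=*/128, /*input_size=*/64,
//       /*batch_size=*/32, dnn::RnnInputMode::kRnnLinearSkip,
//       dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
//       dnn::DataType::kFloat, dnn::AlgorithmConfig(), /*dropout=*/0.0f,
//       /*seed=*/0, scratch_allocator);
//   auto input_desc = executor->createRnnSequenceTensorDescriptor(
//       /*max_seq_length=*/20, /*batch_size=*/32, /*data_size=*/64,
//       dnn::DataType::kFloat);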
431
432 // Returns the device ordinal that this StreamExecutor was initialized with.
433 // Meaningless before initialization.
434 int device_ordinal() const { return device_ordinal_; }
435
436 // Returns a borrowed pointer to the underlying StreamExecutor implementation.
437 internal::StreamExecutorInterface *implementation();
438
439 // Warning: use Stream::ThenLaunch instead, this method is not for general
440 // consumption. However, this is the only way to launch a kernel for which
441 // the type signature is only known at runtime; say, if an application
442 // supports loading/launching kernels with arbitrary type signatures.
443 // In this case, the application is expected to know how to do parameter
444 // packing that obeys the contract of the underlying platform implementation.
445 //
446 // Launches a data parallel kernel with the given thread/block
447 // dimensionality and already-packed args/sizes to pass to the underlying
448 // platform driver.
449 //
450 // This is called by Stream::Launch() to delegate to the platform's launch
451 // implementation in StreamExecutorInterface::Launch().
452 bool Launch(Stream *stream, const ThreadDim &thread_dims,
453 const BlockDim &block_dims, const KernelBase &kernel,
454 const KernelArgsArrayBase &args);
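
// Illustrative sketch of this low-level path (most code should prefer
// Stream::ThenLaunch, as noted above). `kernel_base` and `device_buf` are
// hypothetical, and the argument-packing calls assume the KernelArgsArray
// helpers declared in kernel.h:
//
//   KernelArgsArray<2> packed_args;
//   packed_args.add_device_memory_argument(device_buf);
//   packed_args.add_argument(42);
//   bool launched = executor->Launch(stream, ThreadDim(256), BlockDim(1),
//                                    kernel_base, packed_args);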
455
456 // Gets-or-creates (creates with memoization) an FftSupport datatype that can
457 // be used to execute FFT routines on the current platform.
458 //
459 // Ownership and user-facing semantics are the same as for AsBlas() below.
460 //
461 // Returns null if there was an error initializing the FFT support for the
462 // underlying platform.
463 fft::FftSupport *AsFft();
464
465 // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
466 // be used for neural network routines on the current platform.
467 //
468 // Ownership and user-facing semantics are the same as for AsBlas() below.
469 //
470 // Returns null if there was an error initializing the DNN support for the
471 // underlying platform.
472 dnn::DnnSupport *AsDnn();
473
474 // Turns StreamExecutor operation tracing on or off.
475 void EnableTracing(bool enable);
476
477 // Registers a trace listener to receive callbacks for only a single
478 // StreamExecutor instance.
479 // To register a listener for all executors for a given platform, see
480 // Platform::RegisterTraceListener().
481 // Does not take ownership of listener.
482 void RegisterTraceListener(TraceListener* listener);
483
484 // Removes a TraceListener from this StreamExecutor instance.
485 // Returns false (and logs) in cases where the argument listener was not
486 // previously registered.
487 bool UnregisterTraceListener(TraceListener* listener);
488
489 // Return allocator statistics.
490 absl::optional<AllocatorStats> GetAllocatorStats();
491
492 private:
493 template <typename BeginCallT, typename CompleteCallT,
494 typename ReturnT, typename... BeginArgsT>
495 friend class ScopedTracer;
496 friend class Event;
497 friend class Stream;
498 friend class Timer;
499 template <typename... Params>
500 friend class TypedKernel;
501 template <typename... Args>
502 friend struct ThenBlasImpl;
503
504 // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
505 // be used to execute BLAS routines on the current platform. This is typically
506 // not user-facing, as users will use the Stream::ThenBlas* family of routines
507 // to entrain BLAS operations. See blas.h for additional details.
508 //
509 // Ownership is not transferred to the caller -- ownership is retained by this
510 // object for memoization. This BLAS interface is also only expected to be
511 // used by a Stream for entraining calls to BLAS functionality.
512 //
513 // Returns null if there was an error initializing the BLAS support for the
514 // underlying platform.
515 blas::BlasSupport *AsBlas();
516
517 // Gets-or-creates (creates with memoization) an RngSupport datatype that can
518 // be used for random-number-generation routines on the current platform.
519 //
520 // Ownership and user-facing semantics are the same as for AsBlas() above.
521 //
522 // Returns null if there was an error initializing the RNG support for the
523 // underlying platform.
524 rng::RngSupport *AsRng();
525
526 // Causes the host code to synchronously wait for operations entrained onto
527 // stream to complete. Effectively a join on the asynchronous device
528 // operations enqueued on the stream before this program point.
529 port::Status BlockHostUntilDone(Stream *stream);
530
531 // Without blocking the device, retrieve the current stream status.
532 port::Status GetStatus(Stream *stream);
533
534 // Synchronously allocates size bytes on the underlying platform and returns
535 // an opaque void* representing that allocation. In the case of failure,
536 // nullptr is returned.
537 void *Allocate(uint64 size);
538
539 // Finds and retrieves device memory for the symbol on the underlying
540 // platform.
541 bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
542 void **mem, size_t *bytes);
543
544 // Entrains a memcpy operation onto stream, with a host destination location
545 // host_dst and a device memory source, with target size size.
546 bool Memcpy(Stream *stream, void *host_dst,
547 const DeviceMemoryBase &device_src, uint64 size);
548
549 // Entrains a memcpy operation onto stream, with a device destination location
550 // and a host memory source, with target size size.
551 bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
552 const void *host_src, uint64 size);
553
554 // Entrains a memcpy operation onto stream, with a device destination location
555 // and a device source location, with target size size. Peer access should
556 // have been enabled between the StreamExecutors owning the device memory
557 // regions.
558 bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
559 const DeviceMemoryBase &device_src, uint64 size);
560
561 // Entrains on a stream a user-specified function to be run on the host.
562 // See Stream::ThenDoHostCallback for full details.
563 bool HostCallback(Stream *stream, std::function<void()> callback);
564
565 // Entrains on a stream a user-specified function to be run on the host.
566 // See Stream::ThenDoHostCallback for full details.
567 // This is the preferred form for a callback that may return an error.
568 bool HostCallback(Stream *stream, std::function<port::Status()> callback);
569
570 // Performs platform-specific allocation and initialization of an event.
571 port::Status AllocateEvent(Event *event);
572
573 // Performs platform-specific deallocation and cleanup of an event.
574 port::Status DeallocateEvent(Event *event);
575
576 // Inserts the specified event at the end of the specified stream.
577 port::Status RecordEvent(Stream *stream, Event *event);
578
579 // Wait for the specified event at the end of the specified stream.
580 port::Status WaitForEvent(Stream *stream, Event *event);
581
582 // Requests the current status of the event from the underlying platform.
583 Event::Status PollForEventStatus(Event *event);
584
585 // Allocates stream resources on the underlying platform for subject and
586 // initializes its internals.
587 bool AllocateStream(Stream *subject);
588
589 // Deallocates stream resources on the underlying platform.
590 void DeallocateStream(Stream *subject);
591
592 // Causes dependent to not begin execution until other has finished its
593 // last-enqueued work.
594 bool CreateStreamDependency(Stream *dependent, Stream *other);
595
596 // Allocates timer resources on the underlying platform for subject and
597 // initializes its internals.
598 bool AllocateTimer(Timer *subject);
599
600 // Deallocates timer resources on the underlying platform.
601 void DeallocateTimer(Timer *subject);
602
603 // Records a start event for an interval timer.
604 bool StartTimer(Stream *stream, Timer *timer);
605
606 // Records a stop event for an interval timer.
607 bool StopTimer(Stream *stream, Timer *timer);
608
609 // Allocates a new metadata object, appropriately populated, on the heap, with
610 // ownership transfer to caller.
611 DeviceDescription *PopulateDeviceDescription() const;
612
613 // Adds a task to the port::ThreadPool work queue. These tasks must be
614 // fire-and-forget and have no external data or timing dependencies; their
615 // execution order and completion time have no guarantees.
616 // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
617 // there, temporary internal buffers are freed using this method.
618 void EnqueueOnBackgroundThread(std::function<void()> task);
619
620 // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
621 // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
622 // tracked.
623 void CreateAllocRecord(void *opaque, uint64 size);
624
625 // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
626 // pointers will not be erased (as they're not tracked, per above).
627 void EraseAllocRecord(void *opaque);
628
629 // Calls the relevant TraceListener routine to begin tracing for the specified
630 // asynchronous method.
631 template <typename TraceCallT, typename... ArgsT>
632 void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);
633
634 // Reader/writer lock for class-static StreamExecutor members.
635 static mutex static_mu_;
636
637 // Reader/writer lock for mutable data structures on this StreamExecutor.
638 //
639 // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
640 // can acquire the lock on their first (mutating) call as well.
641 mutable mutex mu_;
642
643 // Reference to the platform that created this executor.
644 const Platform *platform_;
645
646 // Pointer to the platform-specific-interface implementation. This is
647 // delegated to by the interface routines in pointer-to-implementation
648 // fashion.
649 std::unique_ptr<internal::StreamExecutorInterface> implementation_;
650
651 // A mapping of pointer (to device memory) to string representation of the
652 // stack (of the allocating thread) at the time at which the pointer was
653 // allocated.
654 std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);
655
656 // Memoized BLAS support object -- we only want to create this once when asked
657 // for a BLAS interface.
658 std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);
659
660 // Memoized DNN support object -- we only want to create this once when asked
661 // for an DNN interface.
662 std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_);
663
664 // Memoized FFT support object -- we only want to create this once when asked
665 // for an FFT interface.
666 std::unique_ptr<fft::FftSupport> fft_;
667
668 // Memoized RNG support object -- we only want to create this once when asked
669 // for an RNG interface.
670 std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_);
671
672 // Slot to cache the owned DeviceDescription for the underlying device
673 // once it has been queried from GetDeviceDescription().
674 mutable std::unique_ptr<DeviceDescription> device_description_
675 GUARDED_BY(mu_);
676
677 // The kind of the underlying platform that is being targeted, as passed
678 // during construction.
679 //
680 // Immutable post-initialization.
681 PlatformKind platform_kind_;
682
683 // The device ordinal that this object was initialized with.
684 //
685 // Immutable post-initialization.
686 int device_ordinal_;
687
688 // Executor for handling host callback work that cannot be performed
689 // by a host callback thread - for example, cleanup after a host BLAS routine
690 // (which may make device API calls). This work cannot block the host
691 // callback thread, will be completed asynchronously, and should be treated
692 // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
693 // here.
694 //
695 // Immutable post-initialization. Object is thread-safe.
696 std::unique_ptr<port::ThreadPool> background_threads_;
697
698 // Counter for the current number of live streams. This is used to check
699 // for accidentally-outstanding streams at StreamExecutor teardown time, as
700 // well as to indicate leaks (via a large outstanding count being logged)
701 // in the case we can't allocate more streams.
703 std::atomic_int_fast32_t live_stream_count_;
704
705 // Only one worker thread is needed; little work will be done by the
706 // executor.
707 static const int kNumBackgroundThreads = 1;
708
709 // Indicates if StreamExecutor operation tracing should be performed.
710 bool tracing_enabled_;
711
712 // The set of TraceListeners registered for this StreamExecutor.
713 std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
714
715 // Allocated memory in bytes.
716 int64 mem_alloc_bytes_;
717
718 // Memory limit in bytes. Value less or equal to 0 indicates there is no
719 // limit.
720 int64 memory_limit_bytes_;
721
722 SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
723 };
724
725 // A wrapper around ModuleHandle that uses RAII to manage its lifetime.
726 class ScopedModuleHandle {
727 public:
728 explicit ScopedModuleHandle(StreamExecutor *executor,
729 ModuleHandle module_handle)
730 : executor_(executor), module_handle_(module_handle) {}
731
732 ScopedModuleHandle(ScopedModuleHandle &&other) {
733 executor_ = other.executor_;
734 module_handle_ = other.module_handle_;
735 other.executor_ = nullptr;
736 other.module_handle_ = ModuleHandle();
737 }
738
739 ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
740 executor_ = other.executor_;
741 module_handle_ = other.module_handle_;
742 other.executor_ = nullptr;
743 other.module_handle_ = ModuleHandle();
744 return *this;
745 }
746
747 ~ScopedModuleHandle() {
748 if (static_cast<bool>(module_handle_)) {
749 CHECK(executor_->UnloadModule(module_handle_));
750 }
751 }
752
753 private:
754 StreamExecutor *executor_;
755 ModuleHandle module_handle_;
756
757 TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
758 };
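
// Illustrative use of the RAII wrapper above (assumes `executor` is a valid
// StreamExecutor* and `spec` is a MultiModuleLoaderSpec describing a loadable
// module):
//
//   ModuleHandle handle;
//   if (executor->LoadModule(spec, &handle)) {
//     ScopedModuleHandle scoped_module(executor, handle);
//     ...  // The module is unloaded when `scoped_module` goes out of scope.
//   }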
759
760 ////////////
761 // Inlines
762
763 template <typename T>
764 inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
765 uint64 bytes = sizeof(T) * element_count;
766 void *opaque = Allocate(bytes);
767 return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
768 }
769
770 template <typename T>
771 inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
772 const string &symbol_name, ModuleHandle module_handle) {
773 port::StatusOr<DeviceMemoryBase> untyped_symbol =
774 GetUntypedSymbol(symbol_name, module_handle);
775 if (!untyped_symbol.ok()) {
776 return untyped_symbol.status();
777 }
778 return DeviceMemory<T>(untyped_symbol.ValueOrDie());
779 }
780
781 template <typename ElemT>
782 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory()
783 : wrapped_(DeviceMemoryBase()), parent_(nullptr) {}
784
785 template <typename ElemT>
786 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
787 DeviceMemoryBase value)
788 : wrapped_(value), parent_(parent) {}
789
790 template <typename ElemT>
791 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
792 StreamExecutor *parent, std::initializer_list<ElemT> values)
793 : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
794 if (ptr() != nullptr) {
795 std::vector<ElemT> local(values);
796 if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
797 ptr()->size())) {
798 Reset(nullptr);
799 }
800 }
801 }
802
803 template <typename ElemT>
804 ScopedDeviceMemory<ElemT>::~ScopedDeviceMemory() {
805 if (wrapped_ == nullptr) return;
806 DCHECK(parent_ != nullptr);
807 parent_->Deallocate(&wrapped_);
808 }
809
810 template <typename ElemT>
811 void ScopedDeviceMemory<ElemT>::Reset(DeviceMemory<ElemT> updated) {
812 if (wrapped_ != nullptr) {
813 DCHECK(parent_ != nullptr);
814 parent_->Deallocate(&wrapped_);
815 }
816 wrapped_ = updated;
817 }
818
819 template <typename ElemT>
820 void ScopedDeviceMemory<ElemT>::Reset(std::nullptr_t) {
821 if (wrapped_ != nullptr) {
822 DCHECK(parent_ != nullptr);
823 parent_->Deallocate(&wrapped_);
824 }
825 wrapped_ = DeviceMemory<ElemT>{};
826 }
827
828 template <typename T>
829 DeviceMemory<T> StreamExecutor::AllocateZeroed() {
830 void *opaque = Allocate(sizeof(T));
831 if (opaque == nullptr) {
832 return DeviceMemory<T>{};
833 }
834
835 DeviceMemory<T> result = DeviceMemory<T>::MakeFromByteSize(opaque, sizeof(T));
836 bool ok = SynchronousMemZero(&result, sizeof(T));
837 if (!ok) {
838 Deallocate(&result);
839 return DeviceMemory<T>{};
840 }
841
842 return result;
843 }
844
845 template <typename T>
846 DeviceMemory<T> StreamExecutor::AllocateSubBuffer(DeviceMemory<T> *parent,
847 uint64 element_offset,
848 uint64 element_count) {
849 if (element_offset + element_count > parent->ElementCount()) {
850 LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
851 << "than parent allocation size: (" << element_offset << " + "
852 << element_count << ") vs. (" << parent->ElementCount() << ")";
853 return DeviceMemory<T>{};
854 }
855
856 void *opaque = implementation_->AllocateSubBuffer(
857 parent, sizeof(T) * element_offset, sizeof(T) * element_count);
858 if (opaque == nullptr) {
859 return DeviceMemory<T>{};
860 }
861 CreateAllocRecord(opaque, sizeof(T) * element_count);
862 return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count,
863 true /* = is_sub_buffer */));
864 }
865
866 template <typename... Params, typename... Args>
867 inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
868 const TypedKernel<Params...> &kernel,
869 Args... args) {
870 KernelInvocationChecker<std::tuple<Params...>,
871 std::tuple<Args...>>::CheckAllStaticAssert();
872 if (ok()) {
873 // This is the core that allows type-safe kernel launching.
874 // Since the platforms take kernel arguments as tuples of (void *, size),
875 // we pack the variadic parameters passed as ...args into the desired
876 // tuple form and pass that packed form to the StreamExecutor::Launch()
877 // implementation.
878 KernelArgsArray<sizeof...(args)> kernel_args;
879 kernel.PackParams(&kernel_args, args...);
880 DCHECK(parent_ != nullptr);
881 bool ok =
882 parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args);
883 if (!ok) {
884 SetError();
885 LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
886 }
887 }
888 return *this;
889 }
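
// Illustrative sketch of the preferred, type-safe launch path (the PTX text,
// kernel name, and launch dimensions are hypothetical placeholders; `in` and
// `out` are DeviceMemory<float> buffers allocated via AllocateArray):
//
//   TypedKernel<DeviceMemory<float>, DeviceMemory<float>> add_one(executor);
//   MultiKernelLoaderSpec spec(/*arity=*/2);
//   spec.AddCudaPtxInMemory(ptx_text, "add_one");
//   if (executor->GetKernel(spec, &add_one)) {
//     stream.ThenLaunch(ThreadDim(256), BlockDim(1), add_one, in, out);
//   }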
890
891 } // namespace stream_executor
892
893 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
894