/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <memory>
#include <set>
#include <tuple>
#include <vector>

#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace perftools {
namespace gputools {

// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64 bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT,
          typename ReturnT, typename... BeginArgsT>
class ScopedTracer;

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. it is either a CUDA or OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
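//
// A rough usage sketch, in comments only -- the platform name, device ordinal,
// and acquisition path below are illustrative assumptions rather than part of
// this header:
//
//   Platform *platform =
//       MultiPlatformManager::PlatformWithName("CUDA").ValueOrDie();
//   StreamExecutor *executor = platform->ExecutorForDevice(0).ValueOrDie();
//   Stream stream(executor);
//   stream.Init();
//   // ... enqueue work via stream.Then*() calls, then ...
//   stream.BlockHostUntilDone();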
class StreamExecutor {
 public:
  explicit StreamExecutor(PlatformKind kind,
                          const PluginConfig &plugin_config = PluginConfig());

  StreamExecutor(
      const Platform *platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(int device_ordinal, DeviceOptions device_options);

  // DEPRECATED: Do not use; use platform() instead.
  // Returns the platform kind that this StreamExecutor is acting upon.
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform *platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //    constant into an appropriate namespace. For example, see
  //    perftools::gputools::executor_sample::kKernelLoaderSpecs, from which a
  //    MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given Kernel
  //    instantiation should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, false is returned.
  bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase *kernel);


  // Synchronously allocates an array on the device of type T with element_count
  // elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64 element_count);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }
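
  // Illustrative sketch (comments only; the executor pointer and element count
  // are hypothetical):
  //
  //   // 1024 floats, released manually via Deallocate():
  //   DeviceMemory<float> raw = executor->AllocateArray<float>(1024);
  //   // Same allocation, released automatically when `owned` is destroyed:
  //   ScopedDeviceMemory<float> owned =
  //       executor->AllocateOwnedArray<float>(1024);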

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }

  // Allocate a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may cause
  // use-after-free issues (the specific behavior is not consistent across
  // platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> AllocateSubBuffer(DeviceMemory<T> *parent,
                                    uint64 element_offset,
                                    uint64 element_count);

  // As AllocateSubBuffer(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedSubBuffer(DeviceMemory<T> *parent,
                                               uint64 element_offset,
                                               uint64 element_count) {
    return ScopedDeviceMemory<T>(
        this, AllocateSubBuffer<T>(parent, element_offset, element_count));
  }

  // Finds a symbol and returns device memory allocated to the symbol. The
  // symbol is searched in any kernels that were previously loaded through
  // GetKernel() before the GetSymbol() call. The user has to make sure that
  // the type of the symbol and T match.
  //  - Note: symbol_name should include its namespace as well. For example,
  //    pass "nms0::symbol" if referring to nms0::symbol.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name);
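
  // Illustrative sketch (comments only; the symbol name is hypothetical):
  //
  //   port::StatusOr<DeviceMemory<float>> sym =
  //       executor->GetSymbol<float>("nms0::symbol");
  //   if (sym.ok()) {
  //     DeviceMemory<float> mem = sym.ValueOrDie();
  //   }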

  // Deallocate the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase *mem);

  // Retrieves a mapping of active opaque device memory pointers to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void *, AllocRecord> *records_out);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void *HostMemoryAllocate(uint64 bytes);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void *location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
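
  // Sketch of registering externally allocated host memory so it can be used
  // with asynchronous copies (comments only; the buffer and its use on a
  // Stream are hypothetical):
  //
  //   std::vector<float> pinned(1024);
  //   if (executor->HostMemoryRegister(pinned.data(),
  //                                    pinned.size() * sizeof(float))) {
  //     // ... use pinned.data() with Stream::ThenMemcpy ...
  //     bool unregistered = executor->HostMemoryUnregister(pinned.data());
  //   }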

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
  // given location in device memory.
  bool SynchronousMemZero(DeviceMemoryBase *location,
                          uint64 size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  bool SynchronousMemSet(DeviceMemoryBase *location, int value,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  //
  // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  //
  // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
  bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
                                    DeviceMemoryBase *device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
                                    DeviceMemoryBase *device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
                                    int64 size, void *host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
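
  // Sketch of a synchronous host<->device round trip using the slice overloads
  // above (comments only; the executor pointer and sizes are hypothetical):
  //
  //   std::vector<float> host_src(16, 1.0f), host_dst(16);
  //   DeviceMemory<float> dev = executor->AllocateArray<float>(16);
  //   port::Status h2d = executor->SynchronousMemcpyH2D(
  //       port::ArraySlice<float>(host_src), &dev);
  //   port::Status d2h = executor->SynchronousMemcpyD2H(
  //       dev, port::MutableArraySlice<float>(host_dst.data(), host_dst.size()));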

  // Blocks the caller while a data segment of the given size is copied from the
  // device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
                         const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  bool MemZero(Stream *stream, DeviceMemoryBase *location,
               uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be 32-bit quantized
  // (i.e. evenly divisible by 4). Returns whether the operation was
  // successfully enqueued onto the stream.
  bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
                uint64 size) SE_MUST_USE_RESULT;
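
  // Sketch (comments only; `stream` and `device_buf` are hypothetical, and the
  // byte count must remain divisible by 4):
  //
  //   // Fill 256 bytes (64 uint32 words) with the pattern 0xDEADBEEF.
  //   bool enqueued =
  //       executor->Memset32(&stream, &device_buf, 0xDEADBEEF, /*size=*/256);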

  // Enables peer access from this StreamExecutor to memory
  // allocated by other, such that launched device code, memcpies, etc may
  // access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // (as in CUDA vs OpenCL) implementation.
  port::Status EnablePeerAccessTo(StreamExecutor *other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
  // this is more an up-front test as to whether it's expressly forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor *other);
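
  // Sketch (comments only; executor0/executor1 are hypothetical executors for
  // two devices backed by the same platform):
  //
  //   if (executor0->CanEnablePeerAccessTo(executor1)) {
  //     port::Status status = executor0->EnablePeerAccessTo(executor1);
  //     // On success, kernels and memcpies issued through executor0 may
  //     // touch memory allocated through executor1.
  //   }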

  // Gets the preferred shared memory configuration for the device to which
  // this executor is bound.
  SharedMemoryConfig GetDeviceSharedMemoryConfig();

  // Sets the preferred shared memory configuration for the device to which
  // this executor is bound.
  port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config);

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription &GetDeviceDescription() const;

  // If implemented, returns device specific measurement of load
  // (e.g. pending requests).
  int64 GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is available.
  // If it is not available (false is returned), free/total may not be
  // initialized.
  //
  // Note: "Free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64 *free, int64 *total) const;

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Get the list of supported algorithms for the forward convolution operation.
  bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                             std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for the backward convolution on data.
  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for the backward convolution on the
  // filter.
  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Get the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);

  // Create an RNN descriptor based on model shapes and configurations.
  // The caller retains the ownership of the descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size,
      dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
      dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
      uint64 seed, ScratchAllocator *state_allocator);

  // Create a RNN sequence descriptor that specifies either the input or output
  // sequence. The caller retains the ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  // Create an RNN state descriptor that specifies the input or hidden state.
  // The caller retains the ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor implementation.
  internal::StreamExecutorInterface *implementation();

  // Warning: use Stream::ThenLaunch instead, this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  bool Launch(Stream *stream, const ThreadDim &thread_dims,
              const BlockDim &block_dims, const KernelBase &kernel,
              const KernelArgsArrayBase &args);

  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport *AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport *AsDnn();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener* listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener* listener);
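
  // Sketch of per-executor tracing (comments only; MyListener is a
  // hypothetical TraceListener subclass owned by the caller):
  //
  //   MyListener listener;
  //   executor->RegisterTraceListener(&listener);
  //   executor->EnableTracing(true);
  //   // ... enqueue traced work ...
  //   executor->EnableTracing(false);
  //   executor->UnregisterTraceListener(&listener);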

 private:
  template <typename BeginCallT, typename CompleteCallT,
            typename ReturnT, typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is typically
  // not user-facing, as users will use the Stream::ThenBlas* family of routines
  // to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by this
  // object for memoization. This BLAS interface is also only expected to be
  // used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport *AsBlas();

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport *AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream *stream);

  // Synchronously allocates size bytes on the underlying platform and returns
  // an opaque void* representing that allocation. In the case of failure,
  // nullptr is returned.
  void *Allocate(uint64 size);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const string &symbol_name, void **mem, size_t *bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream *stream, void *host_dst,
              const DeviceMemoryBase &device_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a host memory source, with target size size.
  bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
              const void *host_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a device source location, with target size size. Peer access should
  // have been enabled between the StreamExecutors owning the device memory
  // regions.
  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
                            const DeviceMemoryBase &device_src, uint64 size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream *stream, std::function<void()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event *event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event *event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream *stream, Event *event);

  // Wait for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream *stream, Event *event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event *event);

  // Allocates stream resources on the underlying platform for subject and
  // initializes its internals.
  bool AllocateStream(Stream *subject);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream *subject);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream *dependent, Stream *other);

  // Allocates timer resources on the underlying platform for subject and
  // initializes its internals.
  bool AllocateTimer(Timer *subject);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer *subject);

  // Records a start event for an interval timer.
  bool StartTimer(Stream *stream, Timer *timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream *stream, Timer *timer);

  // Allocates a new metadata object, appropriately populated, on the heap, with
  // ownership transfer to caller.
  DeviceDescription *PopulateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void *opaque, uint64 size);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void *opaque);

  // Calls the relevant TraceListener routine to begin tracing for the specified
  // asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);

  // Reader/writer lock for class-static StreamExecutor members.
  static mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable mutex mu_;

  // Reference to the platform that created this executor.
  const Platform *platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when asked
  // for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when asked
  // for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when asked
  // for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when asked
  // for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried via GetDeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static const int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener*> listeners_ GUARDED_BY(mu_);

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

////////////
// Inlines

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
  uint64 bytes = sizeof(T) * element_count;
  void *opaque = Allocate(bytes);
  return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const string &symbol_name) {
  // If we fail to get the symbol, opaque/bytes are unchanged. Initialize them
  // to be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, &opaque, &bytes)) {
    CHECK_EQ(bytes % sizeof(T), 0);
    return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
  }
  return port::Status(
      port::error::NOT_FOUND,
      port::StrCat("Check if kernel using the symbol is loaded: ",
                   symbol_name));
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory()
    : wrapped_(DeviceMemoryBase()), parent_(nullptr) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
                                              DeviceMemoryBase value)
    : wrapped_(value), parent_(parent) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor *parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
                                   ptr()->size())) {
      Reset(nullptr);
    }
  }
}
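
// Illustrative sketch of the initializer_list constructor above (comments
// only; the executor pointer is hypothetical):
//
//   // Allocates three floats on the device and copies {1, 2, 3} into them;
//   // the allocation is released when `scoped` goes out of scope.
//   ScopedDeviceMemory<float> scoped(executor, {1.0f, 2.0f, 3.0f});
//   if (scoped.ptr() == nullptr) { /* allocation or copy failed */ }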

template <typename ElemT>
ScopedDeviceMemory<ElemT>::~ScopedDeviceMemory() {
  if (wrapped_ == nullptr) return;
  DCHECK(parent_ != nullptr);
  parent_->Deallocate(&wrapped_);
}

template <typename ElemT>
void ScopedDeviceMemory<ElemT>::Reset(DeviceMemory<ElemT> updated) {
  if (wrapped_ != nullptr) {
    DCHECK(parent_ != nullptr);
    parent_->Deallocate(&wrapped_);
  }
  wrapped_ = updated;
}

template <typename ElemT>
void ScopedDeviceMemory<ElemT>::Reset(std::nullptr_t) {
  if (wrapped_ != nullptr) {
    DCHECK(parent_ != nullptr);
    parent_->Deallocate(&wrapped_);
  }
  wrapped_ = DeviceMemory<ElemT>{};
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  void *opaque = Allocate(sizeof(T));
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result = DeviceMemory<T>::MakeFromByteSize(opaque, sizeof(T));
  bool ok = SynchronousMemZero(&result, sizeof(T));
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateSubBuffer(DeviceMemory<T> *parent,
                                                  uint64 element_offset,
                                                  uint64 element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void *opaque = implementation_->AllocateSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  CreateAllocRecord(opaque, sizeof(T) * element_count);
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count,
                                          true /* = is_sub_buffer */));
}

template <typename... Params, typename... Args>
inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
                                  const TypedKernel<Params...> &kernel,
                                  Args... args) {
  KernelInvocationChecker<std::tuple<Params...>,
                          std::tuple<Args...>>::CheckAllStaticAssert();
  if (ok()) {
    // This is the core that allows type-safe kernel launching.
    // Since the platforms take kernel arguments as tuples of (void *, size),
    // we pack the variadic parameters passed as ...args into the desired
    // tuple form and pass that packed form to the StreamExecutor::Launch()
    // implementation.
    KernelArgsArray<sizeof...(args)> kernel_args;
    kernel.PackParams(&kernel_args, args...);
    DCHECK(parent_ != nullptr);
    bool ok =
        parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args);
    if (!ok) {
      SetError();
      LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
    }
  }
  return *this;
}
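
// Illustrative sketch of the type-safe launch path above (comments only; the
// kernel's parameter list, its loading via GetKernel(), and the launch
// dimensions are hypothetical):
//
//   typedef TypedKernel<DeviceMemory<float>, float> AddScalarKernel;
//   AddScalarKernel kernel(executor);
//   // ... populate the kernel via a MultiKernelLoaderSpec + GetKernel() ...
//   stream.ThenLaunch(ThreadDim(256), BlockDim(16), kernel, device_buf, 2.0f);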

}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_