• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
18 
19 #include <atomic>
20 #include <memory>
21 #include <set>
22 #include <tuple>
23 #include <vector>
24 
25 #include "absl/base/macros.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/stream_executor/lib/status.h"
28 #include "tensorflow/stream_executor/lib/statusor.h"
29 #include "tensorflow/stream_executor/lib/threadpool.h"
30 #include "tensorflow/stream_executor/platform.h"
31 #include "tensorflow/stream_executor/platform/logging.h"
32 #include "tensorflow/stream_executor/platform/mutex.h"
33 #include "tensorflow/stream_executor/platform/port.h"
34 #include "tensorflow/stream_executor/platform/thread_annotations.h"
35 #include "tensorflow/stream_executor/rng.h"
36 #include "tensorflow/stream_executor/shared_memory_config.h"
37 #include "tensorflow/stream_executor/stream.h"
38 #include "tensorflow/stream_executor/stream_executor_internal.h"
39 #include "tensorflow/stream_executor/trace_listener.h"
40 
41 namespace stream_executor {
42 
43 // Structure used for device memory leak checking.
44 struct AllocRecord {
45   // The requested allocation size of the buffer.
46   uint64 bytes;
47 
48   // Holds a representation of the stack at the time the associated buffer was
49   // allocated. Produced in a form described in
50   // //util/symbolize/symbolized_stacktrace.h.
51   string stack_trace;
52 };
53 
// ScopedTracer is only forward-declared here so that StreamExecutor can name it
// as a (private) friend class template.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer;
58 
59 // A StreamExecutor manages a single device, in terms of executing work (kernel
60 // launches) and memory management (allocation/deallocation, memory copies to
61 // and from the device). It is conceptually the "handle" for a device -- Stream
62 // objects, which are used to enqueue work to run on the
63 // coprocessor have a StreamExecutor instance as their "parent" object.
64 //
65 // StreamExecutor objects have an underlying platform that is specified up
66 // front;
67 // e.g. either it is a CUDA or OpenCL executor.
68 //
69 // Thread-safe after initialization.
70 // StreamExecutor interface should not be invoked from a signal handler.
71 class StreamExecutor {
72  public:
73   explicit StreamExecutor(PlatformKind kind,
74                           const PluginConfig &plugin_config = PluginConfig());
75 
76   StreamExecutor(
77       const Platform *platform,
78       std::unique_ptr<internal::StreamExecutorInterface> implementation);
79 
80   ~StreamExecutor();
81 
82   port::Status Init();
83   port::Status Init(int device_ordinal, DeviceOptions device_options);
84 
85   // Returns the platform that this StreamExecutor is acting upon.
86   ABSL_DEPRECATED("Use platform() instead.")
platform_kind()87   PlatformKind platform_kind() const { return platform_kind_; }
88 
89   // Returns a reference to the platform that created this executor.
platform()90   const Platform *platform() const { return platform_; }
91 
92   // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
93   // upon, if one exists.
94   //
95   // Parameters:
96   //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
97   //    constant into an appropriate namespace. For example, see
98   //    stream_executor::executor_sample::kKernelLoaderSpecs, from which a
99   //    MultiKernelLoaderSpec is selected.
100   //   kernel: Outparam that the kernel is loaded into. A given Kernel
101   //    instantiation should not be loaded into more than once.
102   //
103   // If an error occurs, or there is no kernel available for the StreamExecutor
104   // platform, false is returned.
105   bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);
106 
107   // Releases any state associated with the previously loaded kernel.
108   void UnloadKernel(const KernelBase *kernel);
109 
110   // Loads a module for the platform this StreamExecutor is acting upon.
111   //
112   // `spec` describes the module to be loaded.  On success writes the handle for
113   // the loaded module to `module_handle` and returns true.  Else returns false.
114   bool LoadModule(const MultiModuleLoaderSpec &spec,
115                   ModuleHandle *module_handle);
116 
117   // Unloads the module with handle `module_handle`.
118   bool UnloadModule(ModuleHandle module_handle);
119 
120   // Synchronously allocates an array on the device of type T with element_count
121   // elements.
122   template <typename T>
123   DeviceMemory<T> AllocateArray(uint64 element_count);
124 
125   // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
126   template <typename T>
AllocateOwnedArray(uint64 element_count)127   ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
128     return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
129   }
130 
131   // Convenience wrapper that allocates space for a single element of type T in
132   // device memory.
133   template <typename T>
AllocateScalar()134   DeviceMemory<T> AllocateScalar() {
135     return AllocateArray<T>(1);
136   }
137 
138   // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
139   template <typename T>
AllocateOwnedScalar()140   ScopedDeviceMemory<T> AllocateOwnedScalar() {
141     return AllocateOwnedArray<T>(1);
142   }
143 
144   // Synchronously allocates a scalar of type T on the device that is (POD)
145   // zero-byte initialized.
146   template <typename T>
147   DeviceMemory<T> AllocateZeroed();
148 
149   // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
150   template <typename T>
AllocateOwnedZeroed()151   ScopedDeviceMemory<T> AllocateOwnedZeroed() {
152     return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
153   }
154 
155   // Allocate a memory region inside another allocated memory region.
156   // Offset and size are specified in terms of T elements.
157   // Warning: Do not free a parent buffer before its sub-buffers; this may cause
158   // use-after-free issues (the specific behavior is not consistent across
159   // platforms).
160   //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
161   //    sub-buffer after parent deallocation is expected to be safe. This will
162   //    render your code non-platform-portable, however.
163   template <typename T>
164   DeviceMemory<T> AllocateSubBuffer(DeviceMemory<T> *parent,
165                                     uint64 element_offset,
166                                     uint64 element_count);
167 
168   // As AllocateSubBuffer(), but returns a ScopedDeviceMemory<T>.
169   template <typename T>
AllocateOwnedSubBuffer(DeviceMemory<T> * parent,uint64 element_offset,uint64 element_count)170   ScopedDeviceMemory<T> AllocateOwnedSubBuffer(DeviceMemory<T> *parent,
171                                                uint64 element_offset,
172                                                uint64 element_count) {
173     return ScopedDeviceMemory<T>(
174         this, AllocateSubBuffer<T>(parent, element_offset, element_count));
175   }
176 
177   // Finds a symbol and returns device memory allocated to the symbol. The
178   // symbol is searched in any kernels that were previously loaded through
179   // GetKernel() before the GetSymbol() call. The user has to make sure that the
180   // type of symbol and T match.
181   // - Note: symbol_name should include its namespace as well. For example,
182   //         pass "nms0::symbol" if referring to nms0::symbol.
183   //
184   // If `module_handle` is set then searches only within the module
185   // corresponding to `module_handle`.
186   template <typename T>
187   port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
188                                             ModuleHandle module_handle = {});
189 
190   // An untyped version of GetSymbol.
191   port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
192       const string &symbol_name, ModuleHandle module_handle = {});
193 
194   // Deallocate the DeviceMemory previously allocated via this interface.
195   // Deallocation of a nullptr-representative value is permitted.
196   //
197   // Resets the internal contents of mem to be null-representative, but this
198   // null-out effect should not be relied upon in client code.
199   //
200   // TODO(jlebar): Change this to accept a DeviceMemoryBase by value, see
201   // discussion in cl/195744342.
202   void Deallocate(DeviceMemoryBase *mem);
203 
204   // Retrieves a mapping of active opaque device memory pointer to a string
205   // representation of the [allocating thread's] stack at the time the pointer
206   // was allocated. Useful for tracking device memory leaks.
207   //
208   // Note: this will only be populated if --check_device_leaks flag is
209   // activated.
210   void GetMemAllocs(std::map<void *, AllocRecord> *records_out);
211 
212   // Allocates unified memory space of the given size, if supported.
213   // See
214   // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
215   // for more details on unified memory.
216   void *UnifiedMemoryAllocate(uint64 bytes);
217 
218   // Deallocates unified memory space previously allocated with
219   // UnifiedMemoryAllocate.
220   void UnifiedMemoryDeallocate(void *location);
221 
222   // Allocates a region of host memory and registers it with the platform API.
223   // Memory allocated in this manner (or allocated and registered with
224   // HostMemoryRegister()) is required for use in asynchronous memcpy operations,
225   // such as Stream::ThenMemcpy.
226   void *HostMemoryAllocate(uint64 bytes);
227 
228   // Deallocates a region of host memory allocated by HostMemoryAllocate().
229   void HostMemoryDeallocate(void *location);
230 
231   // Registers a region of host memory with the platform API. Registered memory
232   // (or memory allocated with HostMemoryAllocate) is required for use with
233   // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
234   // is used to register memory allocated outside the StreamExecutor;
235   // HostMemoryAllocate implicitly registers its allocations and
236   // HostMemoryDeallocate implicitly deregisters on deallocation.
237   bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;
238 
239   // Unregisters a region of host memory registered with HostMemoryRegister.
240   // This should be done before deallocating the region with delete[]/free/etc.
241   bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
242 
243   // Synchronizes all activity occurring in the StreamExecutor's context (most
244   // likely a whole device).
245   bool SynchronizeAllActivity() SE_MUST_USE_RESULT;
246 
247   // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
248   // given location in device memory.
249   bool SynchronousMemZero(DeviceMemoryBase *location,
250                           uint64 size) SE_MUST_USE_RESULT;
251 
252   // Blocks the caller while "size" bytes are initialized to "value" (in POD
253   // fashion) at the given location in device memory.
254   bool SynchronousMemSet(DeviceMemoryBase *location, int value,
255                          uint64 size) SE_MUST_USE_RESULT;
256 
257   // [deprecated] Blocks the caller while a data segment of the given size is
258   // copied from the host source to the device destination.
259   ABSL_DEPRECATED(
260       "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
261   bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
262                          uint64 size) SE_MUST_USE_RESULT;
263 
264   // [deprecated] Blocks the caller while a data segment of the given size is
265   // copied from the device source to the host destination.
266   ABSL_DEPRECATED(
267       "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
268   bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
269                          uint64 size) SE_MUST_USE_RESULT;
270 
271   // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
272   port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
273                                     DeviceMemoryBase *device_dst);
274 
275   // Alternative interface for memcpying from host to device that takes an
276   // array slice. Checks that the destination size can accommodate the host
277   // slice size.
278   template <class T>
SynchronousMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemoryBase * device_dst)279   port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
280                                     DeviceMemoryBase *device_dst) {
281     auto host_size = host_src.size() * sizeof(T);
282     CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
283     return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
284   }
285 
286   // Same as SynchronousMemcpy(void*, ...) above.
287   port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
288                                     int64 size, void *host_dst);
289 
290   // Alternative interface for memcpying from device to host that takes an
291   // array slice. Checks that the destination size can accommodate the host
292   // slice size.
293   template <typename T>
SynchronousMemcpyD2H(const DeviceMemory<T> & device_src,port::MutableArraySlice<T> host_dst)294   port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
295                                     port::MutableArraySlice<T> host_dst) {
296     auto host_size = host_dst.size() * sizeof(T);
297     CHECK(device_src.size() == 0 || host_size >= device_src.size());
298     return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
299   }
300 
301   // Blocks the caller while a data segment of the given size is copied from the
302   // device source to the device destination.
303   bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
304                          const DeviceMemoryBase &device_src,
305                          uint64 size) SE_MUST_USE_RESULT;
306 
307   // Enqueues an operation onto stream to zero out size bytes at the given
308   // device memory location. Neither stream nor location may be null. Returns
309   // whether the operation was successfully enqueued onto the stream.
310   bool MemZero(Stream *stream, DeviceMemoryBase *location,
311                uint64 size) SE_MUST_USE_RESULT;
312 
313   // Enqueues an operation onto stream to set 32-bit patterns starting at
314   // location, for byte count given by size. size must be 32-bit quantified
315   // (i.e. evenly divisible by 4). Returns whether the operation was
316   // successfully enqueued onto the stream.
317   bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
318                 uint64 size) SE_MUST_USE_RESULT;
319 
320   // Enables peer access from this StreamExecutor to memory
321   // allocated by other, such that launched device code, memcpies, etc may
322   // access it directly.
323   //
324   // Both this StreamExecutor and other must be backed by the same platform (as
325   // in
326   // CUDA vs OpenCL) implementation.
327   port::Status EnablePeerAccessTo(StreamExecutor *other);
328 
329   // Returns whether it's possible to enable peer access from this
330   // StreamExecutor
331   // to memory allocated by another.
332   //
333   // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
334   // this is more an up-front test as to whether it's expressly forbidden.
335   bool CanEnablePeerAccessTo(StreamExecutor *other);
336 
337   // Gets the preferred shared memory configuration for the device to which this
338   // executor is bound.
339   SharedMemoryConfig GetDeviceSharedMemoryConfig();
340 
341   // Sets the preferred shared memory configuration for the device to which this
342   // executor is bound.
343   port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config);
344 
345   // Obtains metadata about the underlying device.
346   // The value is cached on first use.
347   const DeviceDescription &GetDeviceDescription() const;
348 
349   // If implemented, returns device specific measurement of load
350   // (e.g. pending requests).
351   int64 GetDeviceLoad() const;
352 
353   // Returns the underlying device memory usage information, if it is available.
354   // If it is not available (false is returned), free/total may not be
355   // initialized.
356   //
357   // Note: "Free" reflects the amount of free memory on the underlying device,
358   // so allocations via other StreamExecutors that have the same underlying
359   // device
360   // will be reflected in "free".
361   bool DeviceMemoryUsage(int64 *free, int64 *total) const;
362 
363   // The device count reported by this StreamExecutor's platform.
364   // Note: on OpenCL we implicitly select platform zero at the moment.
365   int PlatformDeviceCount() const;
366 
367   // Returns whether the StreamExecutor supports BLAS routines for the platform
368   // that underlies this interface.
369   bool SupportsBlas() const;
370 
371   // Returns whether the StreamExecutor supports FFT routines for the platform
372   // that underlies this interface.
373   bool SupportsFft() const;
374 
375   // Returns whether the StreamExecutor supports RNG routines for the platform
376   // that underlies this interface.
377   bool SupportsRng() const;
378 
379   // Returns whether the StreamExecutor supports neural net routines for the
380   // platform that underlies this interface.
381   bool SupportsDnn() const;
382 
383   // Returns the list of supported algorithms for the forward convolution
384   // operation.
385   bool GetConvolveAlgorithms(bool with_winograd_nonfused,
386                              std::vector<dnn::AlgorithmDesc> *out_algorithms);
387 
388   // Returns the list of supported algorithms for rnn operation.
389   bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);
390 
391   // Get the list of supported algorithms for the backward convolution on data.
392   bool GetConvolveBackwardDataAlgorithms(
393       bool with_winograd_nonfused,
394       std::vector<dnn::AlgorithmDesc> *out_algorithms);
395 
396   // Get the list of supported algorithms for the backward convolution on the
397   // filter.
398   bool GetConvolveBackwardFilterAlgorithms(
399       bool with_winograd_nonfused,
400       std::vector<dnn::AlgorithmDesc> *out_algorithms);
401 
402   // Get the list of supported algorithms for BLAS gemm.
403   bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
404 
405   // Create an RNN descriptor based on model shapes and configurations.
406   // The caller retains the ownership of the descriptor.
407   port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
408       int num_layers, int hidden_size, int input_size, int batch_size,
409       dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
410       dnn::RnnMode rnn_mode, dnn::DataType data_type,
411       const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
412       ScratchAllocator *state_allocator);
413 
414   // Create a RNN sequence descriptor that specifies either the input or output
415   // sequence. The caller retains the ownership of the returned descriptor.
416   port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
417   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
418                                     int data_size, dnn::DataType data_type);
419 
420   port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
421   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
422                                     int data_size,
423                                     const absl::Span<const int> &seq_lengths,
424                                     bool time_major, dnn::DataType data_type);
425 
426   // Create an RNN state descriptor that specifies the input or hidden state.
427   // The caller retains the ownership of the returned descriptor.
428   port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
429   createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
430                                  dnn::DataType data_type);
431 
432   // Returns the device ordinal that this StreamExecutor was initialized with.
433   // Meaningless before initialization.
device_ordinal()434   int device_ordinal() const { return device_ordinal_; }
435 
436   // Returns a borrowed pointer to the underlying StreamExecutor implementation.
437   internal::StreamExecutorInterface *implementation();
438 
439   // Warning: use Stream::ThenLaunch instead, this method is not for general
440   // consumption. However, this is the only way to launch a kernel for which
441   // the type signature is only known at runtime; say, if an application
442   // supports loading/launching kernels with arbitrary type signatures.
443   // In this case, the application is expected to know how to do parameter
444   // packing that obeys the contract of the underlying platform implementation.
445   //
446   // Launches a data parallel kernel with the given thread/block
447   // dimensionality and already-packed args/sizes to pass to the underlying
448   // platform driver.
449   //
450   // This is called by Stream::Launch() to delegate to the platform's launch
451   // implementation in StreamExecutorInterface::Launch().
452   bool Launch(Stream *stream, const ThreadDim &thread_dims,
453               const BlockDim &block_dims, const KernelBase &kernel,
454               const KernelArgsArrayBase &args);
455 
456   // Gets-or-creates (creates with memoization) a FftSupport datatype that can
457   // be used to execute FFT routines on the current platform.
458   //
459   // Ownership and user-facing is the same as AsBlas() below.
460   //
461   // Returns null if there was an error initializing the FFT support for the
462   // underlying platform.
463   fft::FftSupport *AsFft();
464 
465   // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
466   // be used for neural network routines on the current platform.
467   //
468   // Ownership and user-facing is the same as AsBlas() below.
469   //
470   // Returns null if there was an error initializing the DNN support for the
471   // underlying platform.
472   dnn::DnnSupport *AsDnn();
473 
474   // Turns StreamExecutor operation tracing on or off.
475   void EnableTracing(bool enable);
476 
477   // Registers a trace listener to receive callbacks for only a single
478   // StreamExecutor instance.
479   // To register a listener for all executors for a given platform, see
480   // Platform::RegisterTraceListener().
481   // Does not take ownership of listener.
482   void RegisterTraceListener(TraceListener* listener);
483 
484   // Removes a TraceListener from this StreamExecutor instance.
485   // Returns false (and logs) in cases where the argument listener was not
486   // previously registered.
487   bool UnregisterTraceListener(TraceListener* listener);
488 
489   // Return allocator statistics.
490   absl::optional<AllocatorStats> GetAllocatorStats();
491 
492  private:
493   template <typename BeginCallT, typename CompleteCallT,
494             typename ReturnT, typename... BeginArgsT>
495   friend class ScopedTracer;
496   friend class Event;
497   friend class Stream;
498   friend class Timer;
499   template <typename... Params>
500   friend class TypedKernel;
501   template <typename... Args>
502   friend struct ThenBlasImpl;
503 
504   // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
505   // be used to execute BLAS routines on the current platform. This is typically
506   // not user-facing, as users will use the Stream::ThenBlas* family of routines
507   // to entrain BLAS operations. See blas.h for additional details.
508   //
509   // Ownership is not transferred to the caller -- ownership is retained by this
510   // object for memoization. This BLAS interface is also only expected to be
511   // used by a Stream for entraining calls to BLAS functionality.
512   //
513   // Returns null if there was an error initializing the BLAS support for the
514   // underlying platform.
515   blas::BlasSupport *AsBlas();
516 
517   // Gets-or-creates (creates with memoization) an RngSupport datatype that can
518   // be used for random-number-generation routines on the current platform.
519   //
520   // Ownership and user-facing is the same as AsBlas() above.
521   //
522   // Returns null if there was an error initializing the RNG support for the
523   // underlying platform.
524   rng::RngSupport *AsRng();
525 
526   // Causes the host code to synchronously wait for operations entrained onto
527   // stream to complete. Effectively a join on the asynchronous device
528   // operations enqueued on the stream before this program point.
529   port::Status BlockHostUntilDone(Stream *stream);
530 
531   // Without blocking the device, retrieve the current stream status.
532   port::Status GetStatus(Stream *stream);
533 
534   // Synchronously allocates size bytes on the underlying platform and returns
535   // an opaque void* representing that allocation. In the case of failure,
536   // nullptr is returned.
537   void *Allocate(uint64 size);
538 
539   // Finds and retrieves device memory for the symbol on the underlying
540   // platform.
541   bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
542                  void **mem, size_t *bytes);
543 
544   // Entrains a memcpy operation onto stream, with a host destination location
545   // host_dst and a device memory source, with target size size.
546   bool Memcpy(Stream *stream, void *host_dst,
547               const DeviceMemoryBase &device_src, uint64 size);
548 
549   // Entrains a memcpy operation onto stream, with a device destination location
550   // and a host memory source, with target size size.
551   bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
552               const void *host_src, uint64 size);
553 
554   // Entrains a memcpy operation onto stream, with a device destination location
555   // and a device source location, with target size size. Peer access should
556   // have been enabled between the StreamExecutors owning the device memory
557   // regions.
558   bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
559                             const DeviceMemoryBase &device_src, uint64 size);
560 
561   // Entrains on a stream a user-specified function to be run on the host.
562   // See Stream::ThenDoHostCallback for full details.
563   bool HostCallback(Stream *stream, std::function<void()> callback);
564 
565   // Entrains on a stream a user-specified function to be run on the host.
566   // See Stream::ThenDoHostCallback for full details.
567   // This is the preferred form for a callback that may return an error.
568   bool HostCallback(Stream *stream, std::function<port::Status()> callback);
569 
570   // Performs platform-specific allocation and initialization of an event.
571   port::Status AllocateEvent(Event *event);
572 
573   // Performs platform-specific deallocation and cleanup of an event.
574   port::Status DeallocateEvent(Event *event);
575 
576   // Inserts the specified event at the end of the specified stream.
577   port::Status RecordEvent(Stream *stream, Event *event);
578 
579   // Wait for the specified event at the end of the specified stream.
580   port::Status WaitForEvent(Stream *stream, Event *event);
581 
582   // Requests the current status of the event from the underlying platform.
583   Event::Status PollForEventStatus(Event *event);
584 
585   // Allocates stream resources on the underlying platform for subject and
586   // initializes its internals.
587   bool AllocateStream(Stream *subject);
588 
589   // Deallocates stream resources on the underlying platform.
590   void DeallocateStream(Stream *subject);
591 
592   // Causes dependent to not begin execution until other has finished its
593   // last-enqueued work.
594   bool CreateStreamDependency(Stream *dependent, Stream *other);
595 
596   // Allocates timer resources on the underlying platform for subject and
597   // initializes its internals.
598   bool AllocateTimer(Timer *subject);
599 
600   // Deallocates timer resources on the underlying platform.
601   void DeallocateTimer(Timer *subject);
602 
603   // Records a start event for an interval timer.
604   bool StartTimer(Stream *stream, Timer *timer);
605 
606   // Records a stop event for an interval timer.
607   bool StopTimer(Stream *stream, Timer *timer);
608 
609   // Allocates a new metadata object, appropriately populated, on the heap, with
610   // ownership transfer to caller.
611   DeviceDescription *PopulateDeviceDescription() const;
612 
613   // Adds a task to the port::ThreadPool work queue. These tasks must be
614   // fire-and-forget and have no external data or timing dependencies; their
615   // execution order and completion time have no guarantees.
616   // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
617   // there, temporary internal buffers are freed using this method.
618   void EnqueueOnBackgroundThread(std::function<void()> task);
619 
620   // Adds an AllocRecord for 'opaque' of size 'size' to the record map, for
621   // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
622   // tracked.
623   void CreateAllocRecord(void *opaque, uint64 size);
624 
625   // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
626   // pointers will not be erased (as they're not tracked, per above).
627   void EraseAllocRecord(void *opaque);
628 
629   // Calls the relevant TraceListener routine to begin tracing for the specified
630   // asynchronous method.
631   template <typename TraceCallT, typename... ArgsT>
632   void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);
633 
634   // Reader/writer lock for class-static StreamExecutor members.
635   static mutex static_mu_;
636 
637   // Reader/writer lock for mutable data structures on this StreamExecutor.
638   //
639   // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
640   // can acquire the lock on their first (mutating) call as well.
641   mutable mutex mu_;
642 
643   // Reference to the platform that created this executor.
644   const Platform *platform_;
645 
646   // Pointer to the platform-specific-interface implementation. This is
647   // delegated to by the interface routines in pointer-to-implementation
648   // fashion.
649   std::unique_ptr<internal::StreamExecutorInterface> implementation_;
650 
651   // A mapping of pointer (to device memory) to string representation of the
652   // stack (of the allocating thread) at the time at which the pointer was
653   // allocated.
654   std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);
655 
656   // Memoized BLAS support object -- we only want to create this once when asked
657   // for a BLAS interface.
658   std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);
659 
660   // Memoized DNN support object -- we only want to create this once when asked
  // for a DNN interface.
662   std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_);
663 
664   // Memoized FFT support object -- we only want to create this once when asked
  // for an FFT interface.
666   std::unique_ptr<fft::FftSupport> fft_;
667 
668   // Memoized RNG support object -- we only want to create this once when asked
669   // for an RNG interface.
670   std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_);
671 
672   // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried from DeviceDescription().
674   mutable std::unique_ptr<DeviceDescription> device_description_
675       GUARDED_BY(mu_);
676 
677   // The kind of the underlying platform that is being targeted, as passed
678   // during construction.
679   //
680   // Immutable post-initialization.
681   PlatformKind platform_kind_;
682 
683   // The device ordinal that this object was initialized with.
684   //
685   // Immutable post-initialization.
686   int device_ordinal_;
687 
688   // Executor for handling host callback work that cannot be performed
689   // by a host callback thread - for example, cleanup after a host BLAS routine
690   // (which may make device API calls). This work cannot block the host
691   // callback thread, will be completed asynchronously, and should be treated
692   // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
693   // here.
694   //
695   // Immutable post-initialization. Object is thread-safe.
696   std::unique_ptr<port::ThreadPool> background_threads_;
697 
698   // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
703   std::atomic_int_fast32_t live_stream_count_;
704 
705   // Only one worker thread is needed; little work will be done by the
706   // executor.
707   static const int kNumBackgroundThreads = 1;
708 
709   // Indicates if StreamExecutor operation tracing should be performed.
710   bool tracing_enabled_;
711 
712   // The set of TraceListeners registered for this StreamExecutor.
713   std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
714 
715   // Allocated memory in bytes.
716   int64 mem_alloc_bytes_;
717 
  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
720   int64 memory_limit_bytes_;
721 
722   SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
723 };
724 
725 // A wrapper around ModuleHandle that uses RAII to manage its lifetime.
726 class ScopedModuleHandle {
727  public:
ScopedModuleHandle(StreamExecutor * executor,ModuleHandle module_handle)728   explicit ScopedModuleHandle(StreamExecutor *executor,
729                               ModuleHandle module_handle)
730       : executor_(executor), module_handle_(module_handle) {}
731 
ScopedModuleHandle(ScopedModuleHandle && other)732   ScopedModuleHandle(ScopedModuleHandle &&other) {
733     executor_ = other.executor_;
734     module_handle_ = other.module_handle_;
735     other.executor_ = nullptr;
736     other.module_handle_ = ModuleHandle();
737   }
738 
739   ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
740     executor_ = other.executor_;
741     module_handle_ = other.module_handle_;
742     other.executor_ = nullptr;
743     other.module_handle_ = ModuleHandle();
744     return *this;
745   }
746 
~ScopedModuleHandle()747   ~ScopedModuleHandle() {
748     if (static_cast<bool>(module_handle_)) {
749       CHECK(executor_->UnloadModule(module_handle_));
750     }
751   }
752 
753  private:
754   StreamExecutor *executor_;
755   ModuleHandle module_handle_;
756 
757   TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
758 };
759 
760 ////////////
761 // Inlines
762 
763 template <typename T>
AllocateArray(uint64 element_count)764 inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
765   uint64 bytes = sizeof(T) * element_count;
766   void *opaque = Allocate(bytes);
767   return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
768 }
769 
770 template <typename T>
GetSymbol(const string & symbol_name,ModuleHandle module_handle)771 inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
772     const string &symbol_name, ModuleHandle module_handle) {
773   port::StatusOr<DeviceMemoryBase> untyped_symbol =
774       GetUntypedSymbol(symbol_name, module_handle);
775   if (!untyped_symbol.ok()) {
776     return untyped_symbol.status();
777   }
778   return DeviceMemory<T>(untyped_symbol.ValueOrDie());
779 }
780 
781 template <typename ElemT>
ScopedDeviceMemory()782 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory()
783     : wrapped_(DeviceMemoryBase()), parent_(nullptr) {}
784 
785 template <typename ElemT>
ScopedDeviceMemory(StreamExecutor * parent,DeviceMemoryBase value)786 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
787                                               DeviceMemoryBase value)
788     : wrapped_(value), parent_(parent) {}
789 
790 template <typename ElemT>
ScopedDeviceMemory(StreamExecutor * parent,std::initializer_list<ElemT> values)791 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
792     StreamExecutor *parent, std::initializer_list<ElemT> values)
793     : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
794   if (ptr() != nullptr) {
795     std::vector<ElemT> local(values);
796     if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
797                                    ptr()->size())) {
798       Reset(nullptr);
799     }
800   }
801 }
802 
803 template <typename ElemT>
~ScopedDeviceMemory()804 ScopedDeviceMemory<ElemT>::~ScopedDeviceMemory() {
805   if (wrapped_ == nullptr) return;
806   DCHECK(parent_ != nullptr);
807   parent_->Deallocate(&wrapped_);
808 }
809 
810 template <typename ElemT>
Reset(DeviceMemory<ElemT> updated)811 void ScopedDeviceMemory<ElemT>::Reset(DeviceMemory<ElemT> updated) {
812   if (wrapped_ != nullptr) {
813     DCHECK(parent_ != nullptr);
814     parent_->Deallocate(&wrapped_);
815   }
816   wrapped_ = updated;
817 }
818 
819 template <typename ElemT>
Reset(std::nullptr_t)820 void ScopedDeviceMemory<ElemT>::Reset(std::nullptr_t) {
821   if (wrapped_ != nullptr) {
822     DCHECK(parent_ != nullptr);
823     parent_->Deallocate(&wrapped_);
824   }
825   wrapped_ = DeviceMemory<ElemT>{};
826 }
827 
828 template <typename T>
AllocateZeroed()829 DeviceMemory<T> StreamExecutor::AllocateZeroed() {
830   void *opaque = Allocate(sizeof(T));
831   if (opaque == nullptr) {
832     return DeviceMemory<T>{};
833   }
834 
835   DeviceMemory<T> result = DeviceMemory<T>::MakeFromByteSize(opaque, sizeof(T));
836   bool ok = SynchronousMemZero(&result, sizeof(T));
837   if (!ok) {
838     Deallocate(&result);
839     return DeviceMemory<T>{};
840   }
841 
842   return result;
843 }
844 
845 template <typename T>
AllocateSubBuffer(DeviceMemory<T> * parent,uint64 element_offset,uint64 element_count)846 DeviceMemory<T> StreamExecutor::AllocateSubBuffer(DeviceMemory<T> *parent,
847                                                   uint64 element_offset,
848                                                   uint64 element_count) {
849   if (element_offset + element_count > parent->ElementCount()) {
850     LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
851                << "than parent allocation size: (" << element_offset << " + "
852                << element_count << ") vs. (" << parent->ElementCount() << ")";
853     return DeviceMemory<T>{};
854   }
855 
856   void *opaque = implementation_->AllocateSubBuffer(
857       parent, sizeof(T) * element_offset, sizeof(T) * element_count);
858   if (opaque == nullptr) {
859     return DeviceMemory<T>{};
860   }
861   CreateAllocRecord(opaque, sizeof(T) * element_count);
862   return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count,
863                                           true /* = is_sub_buffer */));
864 }
865 
866 template <typename... Params, typename... Args>
ThenLaunch(ThreadDim thread_dims,BlockDim block_dims,const TypedKernel<Params...> & kernel,Args...args)867 inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
868                                   const TypedKernel<Params...> &kernel,
869                                   Args... args) {
870   KernelInvocationChecker<std::tuple<Params...>,
871                           std::tuple<Args...>>::CheckAllStaticAssert();
872   if (ok()) {
873     // This is the core that allows type-safe kernel launching.
874     // Since the platforms take kernel arguments as tuples of (void *, size),
875     // we pack the variadic parameters passed as ...args into the desired
876     // tuple form and pass that packed form to the StreamExecutor::Launch()
877     // implementation.
878     KernelArgsArray<sizeof...(args)> kernel_args;
879     kernel.PackParams(&kernel_args, args...);
880     DCHECK(parent_ != nullptr);
881     bool ok =
882         parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args);
883     if (!ok) {
884       SetError();
885       LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
886     }
887   }
888   return *this;
889 }
890 
891 }  // namespace stream_executor
892 
893 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
894