• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Interfaces for platform-dependent implementations to satisfy. This are
17 // delegated to from the StreamExecutor in pointer-to-implementation style; i.e.
18 // the StreamExecutor is just a husk that delegates calls to the
19 // platform-specific objects which implement the interfaces defined here.
20 
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
23 
24 #include <functional>
25 #include <map>
26 #include <memory>
27 #include <utility>
28 #include <vector>
29 
30 #include "absl/types/optional.h"
31 #include "tensorflow/stream_executor/allocator_stats.h"
32 #include "tensorflow/stream_executor/device_description.h"
33 #include "tensorflow/stream_executor/device_memory.h"
34 #include "tensorflow/stream_executor/device_options.h"
35 #include "tensorflow/stream_executor/dnn.h"
36 #include "tensorflow/stream_executor/event.h"
37 #include "tensorflow/stream_executor/kernel.h"
38 #include "tensorflow/stream_executor/kernel_cache_config.h"
39 #include "tensorflow/stream_executor/kernel_spec.h"
40 #include "tensorflow/stream_executor/launch_dim.h"
41 #include "tensorflow/stream_executor/lib/status.h"
42 #include "tensorflow/stream_executor/lib/statusor.h"
43 #include "tensorflow/stream_executor/module_spec.h"
44 #include "tensorflow/stream_executor/platform.h"
45 #include "tensorflow/stream_executor/platform/port.h"
46 #include "tensorflow/stream_executor/plugin_registry.h"
47 #include "tensorflow/stream_executor/shared_memory_config.h"
48 #include "tensorflow/stream_executor/trace_listener.h"
49 
50 namespace stream_executor {
51 
52 class Stream;
53 class Timer;
54 
// Opaque identifier for a module that has been loaded onto a device.
//
// An instance of this is returned from StreamExecutor::GetModule. A handle
// whose id() is nullptr denotes "no module", analogous to a null pointer.
class ModuleHandle {
 public:
  // Implicit so that a raw opaque pointer may be used wherever a handle is
  // expected; defaults to the invalid (null) handle.
  /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {}

  // Platform-specific opaque identifier; nullptr marks an invalid handle.
  void *id() const { return id_; }

  // True iff this handle refers to a loaded module.
  explicit operator bool() const { return id_ != nullptr; }

 private:
  void *id_;
};
71 
72 namespace internal {
73 
74 // Platform-dependent interface class for the generic Events interface, in
75 // the PIMPL style.
76 class EventInterface {
77  public:
EventInterface()78   EventInterface() {}
~EventInterface()79   virtual ~EventInterface() {}
80 
81  private:
82   SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
83 };
84 
85 // Pointer-to-implementation object type (i.e. the KernelBase class delegates to
86 // this interface) with virtual destruction. This class exists for the
87 // platform-dependent code to hang any kernel data/resource info/functionality
88 // off of.
89 class KernelInterface {
90  public:
91   // Default constructor for the abstract interface.
KernelInterface()92   KernelInterface() {}
93 
94   // Default destructor for the abstract interface.
~KernelInterface()95   virtual ~KernelInterface() {}
96 
97   // Returns the number of formal parameters that this kernel accepts.
98   virtual unsigned Arity() const = 0;
99 
100   // Sets the preferred cache configuration.
101   virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;
102 
103   // Gets the preferred cache configuration.
104   virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;
105 
106  private:
107   SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
108 };
109 
110 // Pointer-to-implementation object type (i.e. the Stream class delegates to
111 // this interface) with virtual destruction. This class exists for the
112 // platform-dependent code to hang any kernel data/resource info/functionality
113 // off of.
114 class StreamInterface {
115  public:
116   // Default constructor for the abstract interface.
StreamInterface()117   StreamInterface() {}
118 
119   // Default destructor for the abstract interface.
~StreamInterface()120   virtual ~StreamInterface() {}
121 
122   // Returns the GPU stream associated with this platform's stream
123   // implementation.
124   //
125   // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
126   // causing a fatal error if it is not. This hack is made available solely for
127   // use from distbelief code, which temporarily has strong ties to CUDA or
128   // ROCm as a platform.
GpuStreamHack()129   virtual void *GpuStreamHack() { return nullptr; }
130 
131   // See the above comment on GpuStreamHack -- this further breaks abstraction
132   // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a
133   // platform, and a historical attachment to a programming model which takes a
134   // stream-slot rather than a stream-value.
GpuStreamMemberHack()135   virtual void **GpuStreamMemberHack() { return nullptr; }
136 
137  private:
138   SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
139 };
140 
141 // Pointer-to-implementation object type (i.e. the Timer class delegates to
142 // this interface) with virtual destruction. This class exists for the
143 // platform-dependent code to hang any timer data/resource info/functionality
144 // off of.
145 class TimerInterface {
146  public:
147   // Default constructor for the abstract interface.
TimerInterface()148   TimerInterface() {}
149 
150   // Default destructor for the abstract interface.
~TimerInterface()151   virtual ~TimerInterface() {}
152 
153   // Returns the number of microseconds elapsed in a completed timer.
154   virtual uint64 Microseconds() const = 0;
155 
156   // Returns the number of nanoseconds elapsed in a completed timer.
157   virtual uint64 Nanoseconds() const = 0;
158 
159  private:
160   SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
161 };
162 
163 // Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
164 //
165 // Various platforms will provide an implementation that satisfy this interface.
166 class StreamExecutorInterface {
167  public:
168   // Default constructor for the abstract interface.
StreamExecutorInterface()169   StreamExecutorInterface() {}
170 
171   // Default destructor for the abstract interface.
~StreamExecutorInterface()172   virtual ~StreamExecutorInterface() {}
173 
174   // Returns the (transitively) wrapped executor if this executor is
175   // wrapping another executor; otherwise, returns this.
GetUnderlyingExecutor()176   virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; }
177 
178   // See the StreamExecutor interface for comments on the same-named methods.
179   virtual port::Status Init(int device_ordinal,
180                             DeviceOptions device_options) = 0;
181 
GetKernel(const MultiKernelLoaderSpec & spec,KernelBase * kernel)182   virtual bool GetKernel(const MultiKernelLoaderSpec &spec,
183                          KernelBase *kernel) {
184     return false;
185   }
LoadModule(const MultiModuleLoaderSpec & spec,ModuleHandle * module_handle)186   virtual bool LoadModule(const MultiModuleLoaderSpec &spec,
187                           ModuleHandle *module_handle) {
188     return false;
189   }
UnloadModule(ModuleHandle module_handle)190   virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
Launch(Stream * stream,const ThreadDim & thread_dims,const BlockDim & block_dims,const KernelBase & k,const KernelArgsArrayBase & args)191   virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
192                       const BlockDim &block_dims, const KernelBase &k,
193                       const KernelArgsArrayBase &args) {
194     return false;
195   }
196   // Releases any state associated with the kernel.
UnloadKernel(const KernelBase * kernel)197   virtual void UnloadKernel(const KernelBase *kernel) {}
198   virtual void *Allocate(uint64 size) = 0;
199   virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset,
200                                   uint64 size) = 0;
201   virtual void Deallocate(DeviceMemoryBase *mem) = 0;
202   // Allocates unified memory space of the given size, if supported.
203   // See
204   // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
205   // for more details on unified memory.
UnifiedMemoryAllocate(uint64 size)206   virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; }
207 
208   // Deallocates unified memory space previously allocated with
209   // UnifiedMemoryAllocate.
UnifiedMemoryDeallocate(void * mem)210   virtual void UnifiedMemoryDeallocate(void *mem) {}
211   virtual void *HostMemoryAllocate(uint64 size) = 0;
212   virtual void HostMemoryDeallocate(void *mem) = 0;
213   virtual bool HostMemoryRegister(void *mem, uint64 size) = 0;
214   virtual bool HostMemoryUnregister(void *mem) = 0;
215   virtual bool SynchronizeAllActivity() = 0;
216   virtual bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) = 0;
217   virtual bool SynchronousMemSet(DeviceMemoryBase *location, int value,
218                                  uint64 size) = 0;
219   virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
220                                          const void *host_src, uint64 size) = 0;
221   virtual port::Status SynchronousMemcpy(void *host_dst,
222                                          const DeviceMemoryBase &gpu_src,
223                                          uint64 size) = 0;
224   virtual port::Status SynchronousMemcpyDeviceToDevice(
225       DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
226       uint64 size) = 0;
227   virtual bool MemZero(Stream *stream, DeviceMemoryBase *location,
228                        uint64 size) = 0;
Memset(Stream * stream,DeviceMemoryBase * location,uint8 pattern,uint64 size)229   virtual bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
230                       uint64 size) {
231     return false;
232   }
233   virtual bool Memset32(Stream *stream, DeviceMemoryBase *location,
234                         uint32 pattern, uint64 size) = 0;
235   virtual bool Memcpy(Stream *stream, void *host_dst,
236                       const DeviceMemoryBase &gpu_src, uint64 size) = 0;
237   virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
238                       const void *host_src, uint64 size) = 0;
239   virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
240                                     const DeviceMemoryBase &gpu_src,
241                                     uint64 size) = 0;
242   virtual bool HostCallback(Stream *stream, std::function<void()> callback);
243   virtual bool HostCallback(Stream *stream,
244                             std::function<port::Status()> callback) = 0;
245   virtual port::Status AllocateEvent(Event *event) = 0;
246   virtual port::Status DeallocateEvent(Event *event) = 0;
247   virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
248   virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0;
249   virtual Event::Status PollForEventStatus(Event *event) = 0;
250   virtual bool AllocateStream(Stream *stream) = 0;
251   virtual void DeallocateStream(Stream *stream) = 0;
252   virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0;
253   virtual bool AllocateTimer(Timer *timer) = 0;
254   virtual void DeallocateTimer(Timer *timer) = 0;
255   virtual bool StartTimer(Stream *stream, Timer *timer) = 0;
256   virtual bool StopTimer(Stream *stream, Timer *timer) = 0;
257   virtual port::Status BlockHostUntilDone(Stream *stream) = 0;
GetStatus(Stream * stream)258   virtual port::Status GetStatus(Stream *stream) {
259     return port::Status(port::error::UNIMPLEMENTED,
260                         "GetStatus is not supported on this executor.");
261   }
262   virtual int PlatformDeviceCount() = 0;
263   virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0;
264   virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0;
265   virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0;
266   virtual port::Status SetDeviceSharedMemoryConfig(
267       SharedMemoryConfig config) = 0;
268 
GetDeviceLoad()269   virtual int64 GetDeviceLoad() { return -1; }
270 
DeviceMemoryUsage(int64 * free,int64 * total)271   virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const {
272     return false;
273   }
274 
275   // Retrieves device pointer and size for a symbol. The device pointer is
276   // stored at mem, and the size is stored at size. Either mem or bytes can be
277   // null, however, both of them cannot be null at the same time. To use
278   // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
279   // is found.
280   //
281   // If ModuleHandle is set then we search for `symbol_name` only within the
282   // module corresponding to `module_handle`.  Otherwise all loaded modules are
283   // searched.
GetSymbol(const string & symbol_name,ModuleHandle module_handle,void ** mem,size_t * bytes)284   virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
285                          void **mem, size_t *bytes) {
286     return false;
287   }
288 
289   // Creates a new DeviceDescription object. Ownership is transferred to the
290   // caller.
291   virtual DeviceDescription *PopulateDeviceDescription() const = 0;
292 
293   // Attempts to register the provided TraceListener with the device-specific
294   // Executor implementation. When this is called, the PIMPL interface has
295   // already taken ownership of the object and is managing the generic tracing
296   // events. The device-specific implementation must determine if the passed
297   // listener is of a type appropriate for it to trace during registration (and
298   // before dispatching events to it).
299   // Returns true if the listener was successfully registered, false otherwise.
300   // Does not take ownership of listener.
RegisterTraceListener(TraceListener * listener)301   virtual bool RegisterTraceListener(TraceListener* listener) { return false; }
302 
303   // Unregisters the specified listener from the device-specific Executor.
304   // Returns true if the listener was successfully registered, false otherwise.
UnregisterTraceListener(TraceListener * listener)305   virtual bool UnregisterTraceListener(TraceListener* listener) {
306     return false;
307   }
308 
309   // Returns whether this StreamExecutor has BLAS support for its underlying
310   // platform.
SupportsBlas()311   virtual bool SupportsBlas() const { return false; }
312 
313   // Creates a new BlasSupport object, ownership is transferred to the caller.
314   // If SupportsBlas() is false, this will always return null.
315   //
316   // If SupportsBlas() is true, this may return null, for example, if the BLAS
317   // initialization fails.
CreateBlas()318   virtual blas::BlasSupport *CreateBlas() { return nullptr; }
319 
320   // Returns whether this StreamExecutor has FFT support for its underlying
321   // platform.
SupportsFft()322   virtual bool SupportsFft() const { return false; }
323 
324   // Creates a new fft::FftSupport object, ownership is transferred to the
325   // caller.
326   // If SupportsFft() is false, this will always return null.
327   //
328   // If SupportsFft() is true, this may return null, for example, if the FFT
329   // initialization fails.
CreateFft()330   virtual fft::FftSupport *CreateFft() { return nullptr; }
331 
332   // Returns whether this StreamExecutor has Random Number Generation support
333   // for
334   // its underlying platform.
SupportsRng()335   virtual bool SupportsRng() const { return false; }
336 
337   // Returns whether this StreamExecutor has neural net support for its
338   // underlying
339   // platform.
SupportsDnn()340   virtual bool SupportsDnn() const { return false; }
341 
342   // Creates a new RngSupport object, ownership is transferred to the caller.
343   // If SupportsRng() is false, this will always return null.
344   //
345   // If SupportsRng() is true, this may return null, for example, if the RNG
346   // initialization fails.
CreateRng()347   virtual rng::RngSupport *CreateRng() { return nullptr; }
348 
349   // Creates a new DnnSupport object, ownership is transferred to the caller.
350   // If SupportsDnn() is false, this will always return null.
351   //
352   // If SupportsDnn() is true, this may return null, for example, if the DNN
353   // initialization fails.
CreateDnn()354   virtual dnn::DnnSupport *CreateDnn() { return nullptr; }
355 
356   // Each call creates a new instance of the platform-specific implementation of
357   // the corresponding interface type.
358   virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
359   virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
360   virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
361   virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;
362 
363   // Returns the CUDA or ROCm context associated with this StreamExecutor
364   // platform implementation.
365   //
366   // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
367   // causing a fatal error if it is not. This hack is made available solely for
368   // use from distbelief code, which temporarily has strong ties to CUDA or ROCm
369   // as a platform.
GpuContextHack()370   virtual void *GpuContextHack() { return nullptr; }
371 
372   // Return allocator statistics.
GetAllocatorStats()373   virtual absl::optional<AllocatorStats> GetAllocatorStats() {
374     return absl::nullopt;
375   }
376 
377  private:
378   SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
379 };
380 
381 using StreamExecutorFactory =
382     std::function<StreamExecutorInterface *(const PluginConfig &)>;
383 using EventFactory = std::function<EventInterface *(StreamExecutor *)>;
384 using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>;
385 using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>;
386 using KernelFactory = std::function<KernelInterface*()>;
387 
388 StreamExecutorFactory *MakeCUDAExecutorImplementation();
389 
390 StreamExecutorFactory *MakeROCMExecutorImplementation();
391 
392 StreamExecutorFactory *MakeOpenCLExecutorImplementation();
393 
394 extern StreamExecutorFactory MakeHostExecutorImplementation;
395 
396 
397 }  // namespace internal
398 }  // namespace stream_executor
399 
400 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
401