/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style;
// i.e. the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.
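//
// A minimal sketch of that delegation, for orientation (illustrative only;
// `implementation_` is assumed to name the StreamExecutor's pointer to its
// platform-specific StreamExecutorInterface):
//
//   void *StreamExecutor::HostMemoryAllocate(uint64 size) {
//     return implementation_->HostMemoryAllocate(size);
//   }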

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/stream_executor/allocator_stats.h"
#include "tensorflow/stream_executor/device_description.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/kernel_spec.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/module_spec.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace stream_executor {

class Stream;
class Timer;

// An opaque handle to a loaded module.
//
// An instance of this is returned from StreamExecutor::GetModule.
class ModuleHandle {
 public:
  /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {}

  // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a
  // null pointer.
  void *id() const { return id_; }

  explicit operator bool() const { return id() != nullptr; }

 private:
  void *id_;
};
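
// Example (illustrative sketch; the handle value would come from a prior
// StreamExecutor::GetModule call):
//
//   ModuleHandle handle = ...;
//   if (!handle) {
//     // Invalid handle: no module was found or loaded.
//   }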

namespace internal {

// Platform-dependent interface class for the generic Events interface, in
// the PIMPL style.
class EventInterface {
 public:
  EventInterface() {}
  virtual ~EventInterface() {}

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
};

// Pointer-to-implementation object type (i.e. the KernelBase class delegates
// to this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any kernel data/resource info/functionality
// off of.
class KernelInterface {
 public:
  // Default constructor for the abstract interface.
  KernelInterface() {}

  // Default destructor for the abstract interface.
  virtual ~KernelInterface() {}

  // Returns the number of formal parameters that this kernel accepts.
  virtual unsigned Arity() const = 0;

  // Sets the preferred cache configuration.
  virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;

  // Gets the preferred cache configuration.
  virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
};
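
// A minimal sketch of a platform-specific KernelInterface (the class name and
// members here are hypothetical; real implementations live in the
// per-platform directories):
//
//   class FooKernel : public KernelInterface {
//    public:
//     unsigned Arity() const override { return arity_; }
//     void SetPreferredCacheConfig(KernelCacheConfig config) override {
//       cache_config_ = config;
//     }
//     KernelCacheConfig GetPreferredCacheConfig() const override {
//       return cache_config_;
//     }
//
//    private:
//     unsigned arity_ = 0;
//     KernelCacheConfig cache_config_ = KernelCacheConfig::kNoPreference;
//   };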

// Pointer-to-implementation object type (i.e. the Stream class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any stream data/resource info/functionality
// off of.
class StreamInterface {
 public:
  // Default constructor for the abstract interface.
  StreamInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamInterface() {}

  // Returns the GPU stream associated with this platform's stream
  // implementation.
  //
  // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
  // causing a fatal error if it is not. This hack is made available solely for
  // use from distbelief code, which temporarily has strong ties to CUDA or
  // ROCm as a platform.
  virtual void *GpuStreamHack() { return nullptr; }

  // See the above comment on GpuStreamHack -- this further breaks abstraction
  // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a
  // platform, and a historical attachment to a programming model which takes a
  // stream-slot rather than a stream-value.
  virtual void **GpuStreamMemberHack() { return nullptr; }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
};

// Pointer-to-implementation object type (i.e. the Timer class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any timer data/resource info/functionality
// off of.
class TimerInterface {
 public:
  // Default constructor for the abstract interface.
  TimerInterface() {}

  // Default destructor for the abstract interface.
  virtual ~TimerInterface() {}

  // Returns the number of microseconds elapsed in a completed timer.
  virtual uint64 Microseconds() const = 0;

  // Returns the number of nanoseconds elapsed in a completed timer.
  virtual uint64 Nanoseconds() const = 0;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
};

// Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
//
// Various platforms will provide an implementation that satisfies this
// interface.
class StreamExecutorInterface {
 public:
  // Default constructor for the abstract interface.
  StreamExecutorInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamExecutorInterface() {}

  // Returns the (transitively) wrapped executor if this executor is
  // wrapping another executor; otherwise, returns this.
  virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; }

  // See the StreamExecutor interface for comments on the same-named methods.
  virtual port::Status Init(int device_ordinal,
                            DeviceOptions device_options) = 0;

  virtual port::Status GetKernel(const MultiKernelLoaderSpec &spec,
                                 KernelBase *kernel) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
  virtual port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                                  ModuleHandle *module_handle) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                              const BlockDim &block_dims, const KernelBase &k,
                              const KernelArgsArrayBase &args) {
    return port::UnimplementedError("Not Implemented");
  }
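
  // Illustrative call sequence (a sketch; `impl`, `spec`, `stream`, and
  // `args` are hypothetical locals): a platform that supports kernels loads
  // one and launches it through these hooks:
  //
  //   KernelBase kernel(...);
  //   port::Status status = impl->GetKernel(spec, &kernel);
  //   if (status.ok()) {
  //     status = impl->Launch(stream, ThreadDim(256), BlockDim(32), kernel,
  //                           args);
  //   }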

  // Releases any state associated with the kernel.
  virtual void UnloadKernel(const KernelBase *kernel) {}
  virtual DeviceMemoryBase Allocate(uint64 size, int64 memory_space) = 0;
  DeviceMemoryBase Allocate(uint64 size) {
    return Allocate(size, /*memory_space=*/0);
  }
  virtual void *GetSubBuffer(DeviceMemoryBase *parent, uint64 offset,
                             uint64 size) = 0;
  virtual void Deallocate(DeviceMemoryBase *mem) = 0;
  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; }

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  virtual void UnifiedMemoryDeallocate(void *mem) {}
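
  // Example (a sketch; `impl` and `bytes` are hypothetical): platforms that
  // do not support unified memory return nullptr, so callers must check:
  //
  //   void *um = impl->UnifiedMemoryAllocate(bytes);
  //   if (um != nullptr) {
  //     // ... use `um` from host or device ...
  //     impl->UnifiedMemoryDeallocate(um);
  //   }
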
  virtual void *HostMemoryAllocate(uint64 size) = 0;
  virtual void HostMemoryDeallocate(void *mem) = 0;
  virtual bool HostMemoryRegister(void *mem, uint64 size) = 0;
  virtual bool HostMemoryUnregister(void *mem) = 0;
  virtual bool SynchronizeAllActivity() = 0;
  virtual port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                          uint64 size) = 0;
  virtual port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                         uint64 size) = 0;
  virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
                                         const void *host_src, uint64 size) = 0;
  virtual port::Status SynchronousMemcpy(void *host_dst,
                                         const DeviceMemoryBase &gpu_src,
                                         uint64 size) = 0;
  virtual port::Status SynchronousMemcpyDeviceToDevice(
      DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
      uint64 size) = 0;
  virtual port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                               uint64 size) = 0;
  virtual port::Status Memset(Stream *stream, DeviceMemoryBase *location,
                              uint8 pattern, uint64 size) {
    return port::InternalError("Not implemented");
  }
  virtual port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                                uint32 pattern, uint64 size) = 0;
  virtual bool Memcpy(Stream *stream, void *host_dst,
                      const DeviceMemoryBase &gpu_src, uint64 size) = 0;
  virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
                      const void *host_src, uint64 size) = 0;
  virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
                                    const DeviceMemoryBase &gpu_src,
                                    uint64 size) = 0;
  virtual bool HostCallback(Stream *stream, std::function<void()> callback);
  virtual bool HostCallback(Stream *stream,
                            std::function<port::Status()> callback) = 0;
  virtual port::Status AllocateEvent(Event *event) = 0;
  virtual port::Status DeallocateEvent(Event *event) = 0;
  virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
  virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0;
  virtual Event::Status PollForEventStatus(Event *event) = 0;
  virtual bool AllocateStream(Stream *stream) = 0;
  virtual void DeallocateStream(Stream *stream) = 0;
  virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0;
  virtual bool AllocateTimer(Timer *timer) = 0;
  virtual void DeallocateTimer(Timer *timer) = 0;
  virtual bool StartTimer(Stream *stream, Timer *timer) = 0;
  virtual bool StopTimer(Stream *stream, Timer *timer) = 0;
  virtual port::Status BlockHostUntilDone(Stream *stream) = 0;
  virtual port::Status GetStatus(Stream *stream) {
    return port::Status(port::error::UNIMPLEMENTED,
                        "GetStatus is not supported on this executor.");
  }
  virtual int PlatformDeviceCount() = 0;
  virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0;
  virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0;
  virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0;
  virtual port::Status SetDeviceSharedMemoryConfig(
      SharedMemoryConfig config) = 0;

  virtual int64 GetDeviceLoad() { return -1; }

  virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const {
    return false;
  }

  // Retrieves device pointer and size for a symbol. The device pointer is
  // stored at mem, and the size is stored at bytes. Either mem or bytes can be
  // null, but they cannot both be null at the same time. To use constant
  // memory in CUDA, GetSymbol has to be used. Returns true if the symbol is
  // found.
  //
  // If ModuleHandle is set then we search for `symbol_name` only within the
  // module corresponding to `module_handle`.  Otherwise all loaded modules are
  // searched.
  virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
                         void **mem, size_t *bytes) {
    return false;
  }
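
  // Example (a sketch; `impl` and `module` are hypothetical): looking up a
  // device constant within a previously loaded module:
  //
  //   void *ptr = nullptr;
  //   size_t bytes = 0;
  //   if (impl->GetSymbol("my_constant", module, &ptr, &bytes)) {
  //     DeviceMemoryBase mem(ptr, bytes);
  //     // ... copy to/from `mem` ...
  //   }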

  // Creates a new DeviceDescription object. Ownership is transferred to the
  // caller.
  virtual port::StatusOr<std::unique_ptr<DeviceDescription>>
  CreateDeviceDescription() const = 0;

  // Attempts to register the provided TraceListener with the device-specific
  // Executor implementation. When this is called, the PIMPL interface has
  // already taken ownership of the object and is managing the generic tracing
  // events. The device-specific implementation must determine if the passed
  // listener is of a type appropriate for it to trace during registration (and
  // before dispatching events to it).
  // Returns true if the listener was successfully registered, false otherwise.
  // Does not take ownership of listener.
  virtual bool RegisterTraceListener(TraceListener *listener) { return false; }

  // Unregisters the specified listener from the device-specific Executor.
  // Returns true if the listener was successfully unregistered, false
  // otherwise.
  virtual bool UnregisterTraceListener(TraceListener *listener) {
    return false;
  }

  // Returns whether this StreamExecutor has BLAS support for its underlying
  // platform.
  virtual bool SupportsBlas() const { return false; }

  // Creates a new BlasSupport object, ownership is transferred to the caller.
  // If SupportsBlas() is false, this will always return null.
  //
  // If SupportsBlas() is true, this may return null, for example, if the BLAS
  // initialization fails.
  virtual blas::BlasSupport *CreateBlas() { return nullptr; }
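
  // Example (a sketch; `impl` is a hypothetical StreamExecutorInterface*):
  // the caller owns the result and must still handle a null return even when
  // SupportsBlas() is true:
  //
  //   if (impl->SupportsBlas()) {
  //     std::unique_ptr<blas::BlasSupport> blas(impl->CreateBlas());
  //     if (blas != nullptr) {
  //       // ... issue BLAS calls ...
  //     }
  //   }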

  // Returns whether this StreamExecutor has FFT support for its underlying
  // platform.
  virtual bool SupportsFft() const { return false; }

  // Creates a new fft::FftSupport object, ownership is transferred to the
  // caller.
  // If SupportsFft() is false, this will always return null.
  //
  // If SupportsFft() is true, this may return null, for example, if the FFT
  // initialization fails.
  virtual fft::FftSupport *CreateFft() { return nullptr; }

  // Returns whether this StreamExecutor has Random Number Generation support
  // for its underlying platform.
  virtual bool SupportsRng() const { return false; }

  // Returns whether this StreamExecutor has neural net support for its
  // underlying platform.
  virtual bool SupportsDnn() const { return false; }

  // Creates a new RngSupport object, ownership is transferred to the caller.
  // If SupportsRng() is false, this will always return null.
  //
  // If SupportsRng() is true, this may return null, for example, if the RNG
  // initialization fails.
  virtual rng::RngSupport *CreateRng() { return nullptr; }

  // Creates a new DnnSupport object, ownership is transferred to the caller.
  // If SupportsDnn() is false, this will always return null.
  //
  // If SupportsDnn() is true, this may return null, for example, if the DNN
  // initialization fails.
  virtual dnn::DnnSupport *CreateDnn() { return nullptr; }

  // Each call creates a new instance of the platform-specific implementation
  // of the corresponding interface type.
  virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
  virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
  virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
  virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;

  // Returns the CUDA or ROCm context associated with this StreamExecutor
  // platform implementation.
  //
  // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
  // causing a fatal error if it is not. This hack is made available solely for
  // use from distbelief code, which temporarily has strong ties to CUDA or ROCm
  // as a platform.
  virtual void *GpuContextHack() { return nullptr; }

  // Returns allocator statistics.
  virtual absl::optional<AllocatorStats> GetAllocatorStats() {
    return absl::nullopt;
  }

  // Clears the compilation cache from volatile memory. Returns OK if no
  // compilation cache exists or if clearing the compilation cache is
  // unsupported. Caches in non-volatile storage are unaffected.
  virtual port::Status FlushCompilationCache() { return port::Status::OK(); }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
};

}  // namespace internal
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_