• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style; i.e.
// the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.
20 
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
23 
24 #include <functional>
25 #include <map>
26 #include <memory>
27 #include <utility>
28 #include <vector>
29 
30 #include "absl/types/optional.h"
31 #include "tensorflow/stream_executor/allocator_stats.h"
32 #include "tensorflow/stream_executor/device_description.h"
33 #include "tensorflow/stream_executor/device_memory.h"
34 #include "tensorflow/stream_executor/device_options.h"
35 #include "tensorflow/stream_executor/dnn.h"
36 #include "tensorflow/stream_executor/event.h"
37 #include "tensorflow/stream_executor/kernel.h"
38 #include "tensorflow/stream_executor/kernel_cache_config.h"
39 #include "tensorflow/stream_executor/kernel_spec.h"
40 #include "tensorflow/stream_executor/launch_dim.h"
41 #include "tensorflow/stream_executor/lib/status.h"
42 #include "tensorflow/stream_executor/lib/statusor.h"
43 #include "tensorflow/stream_executor/module_spec.h"
44 #include "tensorflow/stream_executor/platform.h"
45 #include "tensorflow/stream_executor/platform/port.h"
46 #include "tensorflow/stream_executor/plugin_registry.h"
47 #include "tensorflow/stream_executor/trace_listener.h"
48 
49 namespace stream_executor {
50 
51 class Stream;
52 class Timer;
53 
// An opaque handle to a loaded module.
//
// An instance of this is returned from StreamExecutor::GetModule.
class ModuleHandle {
 public:
  // Implicitly constructible from a raw platform identifier; the default
  // (nullptr) yields an invalid handle.
  /*implicit*/ ModuleHandle(void *id = nullptr) : id_{id} {}

  // The raw identifier backing this handle. A ModuleHandle with
  // id() == nullptr is an invalid module handle, akin to a null pointer.
  void *id() const { return id_; }

  // True iff the handle refers to a loaded module (id() is non-null).
  explicit operator bool() const { return id_ != nullptr; }

 private:
  void *id_;  // Opaque platform-specific module identifier; may be null.
};
70 
71 namespace internal {
72 
73 // Platform-dependent interface class for the generic Events interface, in
74 // the PIMPL style.
75 class EventInterface {
76  public:
EventInterface()77   EventInterface() {}
~EventInterface()78   virtual ~EventInterface() {}
79 
80  private:
81   SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
82 };
83 
84 // Pointer-to-implementation object type (i.e. the KernelBase class delegates to
85 // this interface) with virtual destruction. This class exists for the
86 // platform-dependent code to hang any kernel data/resource info/functionality
87 // off of.
88 class KernelInterface {
89  public:
90   // Default constructor for the abstract interface.
KernelInterface()91   KernelInterface() {}
92 
93   // Default destructor for the abstract interface.
~KernelInterface()94   virtual ~KernelInterface() {}
95 
96   // Returns the number of formal parameters that this kernel accepts.
97   virtual unsigned Arity() const = 0;
98 
99   // Sets the preferred cache configuration.
100   virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;
101 
102   // Gets the preferred cache configuration.
103   virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;
104 
105  private:
106   SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
107 };
108 
109 // Pointer-to-implementation object type (i.e. the Stream class delegates to
110 // this interface) with virtual destruction. This class exists for the
111 // platform-dependent code to hang any kernel data/resource info/functionality
112 // off of.
113 class StreamInterface {
114  public:
115   // Default constructor for the abstract interface.
StreamInterface()116   StreamInterface() {}
117 
118   // Default destructor for the abstract interface.
~StreamInterface()119   virtual ~StreamInterface() {}
120 
121   // Returns the GPU stream associated with this platform's stream
122   // implementation.
123   //
124   // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
125   // causing a fatal error if it is not. This hack is made available solely for
126   // use from distbelief code, which temporarily has strong ties to CUDA or
127   // ROCm as a platform.
GpuStreamHack()128   virtual void *GpuStreamHack() { return nullptr; }
129 
130   // See the above comment on GpuStreamHack -- this further breaks abstraction
131   // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a
132   // platform, and a historical attachment to a programming model which takes a
133   // stream-slot rather than a stream-value.
GpuStreamMemberHack()134   virtual void **GpuStreamMemberHack() { return nullptr; }
135 
136  private:
137   SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
138 };
139 
140 // Pointer-to-implementation object type (i.e. the Timer class delegates to
141 // this interface) with virtual destruction. This class exists for the
142 // platform-dependent code to hang any timer data/resource info/functionality
143 // off of.
144 class TimerInterface {
145  public:
146   // Default constructor for the abstract interface.
TimerInterface()147   TimerInterface() {}
148 
149   // Default destructor for the abstract interface.
~TimerInterface()150   virtual ~TimerInterface() {}
151 
152   // Returns the number of microseconds elapsed in a completed timer.
153   virtual uint64 Microseconds() const = 0;
154 
155   // Returns the number of nanoseconds elapsed in a completed timer.
156   virtual uint64 Nanoseconds() const = 0;
157 
158  private:
159   SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
160 };
161 
// Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
//
// Various platforms will provide an implementation that satisfy this interface.
// Most methods mirror the public StreamExecutor API; see that class for the
// caller-facing contracts. Methods given inline bodies below are optional
// overrides, and the comments note their default behavior.
class StreamExecutorInterface {
 public:
  // Default constructor for the abstract interface.
  StreamExecutorInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamExecutorInterface() {}

  // Returns the (transitively) wrapped executor if this executor is
  // wrapping another executor; otherwise, returns this.
  virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; }

  // See the StreamExecutor interface for comments on the same-named methods.
  virtual port::Status Init(int device_ordinal,
                            DeviceOptions device_options) = 0;

  // Kernel/module loading and launching hooks. The defaults report failure
  // (UNIMPLEMENTED or false) so platforms without kernel-launch support need
  // not override them.
  virtual port::Status GetKernel(const MultiKernelLoaderSpec &spec,
                                 KernelBase *kernel) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
  virtual port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                                  ModuleHandle *module_handle) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                              const BlockDim &block_dims, const KernelBase &k,
                              const KernelArgsArrayBase &args) {
    return port::UnimplementedError("Not Implemented");
  }

  // Releases any state associated with the kernel. Default: no-op.
  virtual void UnloadKernel(const KernelBase *kernel) {}

  // Device-memory allocation. `memory_space` selects a platform-specific
  // memory space; the single-argument overload delegates with space 0.
  virtual DeviceMemoryBase Allocate(uint64 size, int64 memory_space) = 0;
  DeviceMemoryBase Allocate(uint64 size) {
    return Allocate(size, /*memory_space=*/0);
  }
  virtual void *GetSubBuffer(DeviceMemoryBase *parent, uint64 offset,
                             uint64 size) = 0;
  virtual void Deallocate(DeviceMemoryBase *mem) = 0;

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  // Default: returns nullptr (unified memory not supported).
  virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; }

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate. Default: no-op.
  virtual void UnifiedMemoryDeallocate(void *mem) {}

  // Host-side memory allocation and registration hooks; the bool-returning
  // registration methods report success/failure.
  virtual void *HostMemoryAllocate(uint64 size) = 0;
  virtual void HostMemoryDeallocate(void *mem) = 0;
  virtual bool HostMemoryRegister(void *mem, uint64 size) = 0;
  virtual bool HostMemoryUnregister(void *mem) = 0;
  virtual bool SynchronizeAllActivity() = 0;

  // Synchronous (host-blocking) set/copy operations on device memory; the
  // DeviceMemoryBase parameters name device-side buffers.
  virtual port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                          uint64 size) = 0;
  virtual port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                         uint64 size) = 0;
  virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
                                         const void *host_src, uint64 size) = 0;
  virtual port::Status SynchronousMemcpy(void *host_dst,
                                         const DeviceMemoryBase &gpu_src,
                                         uint64 size) = 0;
  virtual port::Status SynchronousMemcpyDeviceToDevice(
      DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
      uint64 size) = 0;

  // Stream-ordered variants of the memory operations, enqueued on `stream`.
  virtual port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                               uint64 size) = 0;
  // Default: reports INTERNAL "Not implemented" — note this differs from the
  // UNIMPLEMENTED defaults used by the kernel/module hooks above.
  virtual port::Status Memset(Stream *stream, DeviceMemoryBase *location,
                              uint8 pattern, uint64 size) {
    return port::InternalError("Not implemented");
  }
  virtual port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                                uint32 pattern, uint64 size) = 0;
  virtual bool Memcpy(Stream *stream, void *host_dst,
                      const DeviceMemoryBase &gpu_src, uint64 size) = 0;
  virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
                      const void *host_src, uint64 size) = 0;
  virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
                                    const DeviceMemoryBase &gpu_src,
                                    uint64 size) = 0;

  // Host callback hooks. The void() overload is declared here without an
  // inline body (its definition lives outside this header); the
  // Status-returning overload is the one platforms must implement.
  virtual bool HostCallback(Stream *stream, std::function<void()> callback);
  virtual bool HostCallback(Stream *stream,
                            std::function<port::Status()> callback) = 0;

  // Event lifetime and ordering hooks (see Event and Stream).
  virtual port::Status AllocateEvent(Event *event) = 0;
  virtual port::Status DeallocateEvent(Event *event) = 0;
  virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
  virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0;
  virtual Event::Status PollForEventStatus(Event *event) = 0;

  // Stream and timer lifetime/control hooks.
  virtual bool AllocateStream(Stream *stream) = 0;
  virtual void DeallocateStream(Stream *stream) = 0;
  virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0;
  virtual bool AllocateTimer(Timer *timer) = 0;
  virtual void DeallocateTimer(Timer *timer) = 0;
  virtual bool StartTimer(Stream *stream, Timer *timer) = 0;
  virtual bool StopTimer(Stream *stream, Timer *timer) = 0;
  virtual port::Status BlockHostUntilDone(Stream *stream) = 0;

  // Default: reports UNIMPLEMENTED for executors that cannot query a
  // stream's status.
  virtual port::Status GetStatus(Stream *stream) {
    return port::Status(port::error::UNIMPLEMENTED,
                        "GetStatus is not supported on this executor.");
  }
  virtual int PlatformDeviceCount() = 0;

  // Peer access between executors (devices) of the same platform.
  virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0;
  virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0;

  // Platform-specific device-load metric; the default of -1 signals that the
  // metric is unavailable.
  virtual int64 GetDeviceLoad() { return -1; }

  // Presumably fills *free/*total with device memory statistics when
  // supported (TODO confirm against platform implementations). Default:
  // returns false without writing to the out-params.
  virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const {
    return false;
  }

  // Retrieves device pointer and size for a symbol. The device pointer is
  // stored at mem, and the size is stored at size. Either mem or bytes can be
  // null, however, both of them cannot be null at the same time. To use
  // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
  // is found. Default: returns false (symbol lookup unsupported).
  //
  // If ModuleHandle is set then we search for `symbol_name` only within the
  // module corresponding to `module_handle`.  Otherwise all loaded modules are
  // searched.
  virtual bool GetSymbol(const std::string &symbol_name,
                         ModuleHandle module_handle, void **mem,
                         size_t *bytes) {
    return false;
  }

  // Creates a new DeviceDescription object. Ownership is transferred to the
  // caller.
  virtual port::StatusOr<std::unique_ptr<DeviceDescription>>
  CreateDeviceDescription() const = 0;

  // Attempts to register the provided TraceListener with the device-specific
  // Executor implementation. When this is called, the PIMPL interface has
  // already taken ownership of the object and is managing the generic tracing
  // events. The device-specific implementation must determine if the passed
  // listener is of a type appropriate for it to trace during registration (and
  // before dispatching events to it).
  // Returns true if the listener was successfully registered, false otherwise.
  // Does not take ownership of listener. Default: returns false.
  virtual bool RegisterTraceListener(TraceListener *listener) { return false; }

  // Unregisters the specified listener from the device-specific Executor.
  // Returns true if the listener was successfully registered, false otherwise.
  // Default: returns false.
  virtual bool UnregisterTraceListener(TraceListener *listener) {
    return false;
  }

  // Returns whether this StreamExecutor has BLAS support for its underlying
  // platform. Default: false.
  virtual bool SupportsBlas() const { return false; }

  // Creates a new BlasSupport object, ownership is transferred to the caller.
  // If SupportsBlas() is false, this will always return null.
  //
  // If SupportsBlas() is true, this may return null, for example, if the BLAS
  // initialization fails.
  virtual blas::BlasSupport *CreateBlas() { return nullptr; }

  // Returns whether this StreamExecutor has FFT support for its underlying
  // platform. Default: false.
  virtual bool SupportsFft() const { return false; }

  // Creates a new fft::FftSupport object, ownership is transferred to the
  // caller.
  // If SupportsFft() is false, this will always return null.
  //
  // If SupportsFft() is true, this may return null, for example, if the FFT
  // initialization fails.
  virtual fft::FftSupport *CreateFft() { return nullptr; }

  // Returns whether this StreamExecutor has Random Number Generation support
  // for its underlying platform. Default: false.
  virtual bool SupportsRng() const { return false; }

  // Returns whether this StreamExecutor has neural net support for its
  // underlying platform. Default: false.
  virtual bool SupportsDnn() const { return false; }

  // Creates a new RngSupport object, ownership is transferred to the caller.
  // If SupportsRng() is false, this will always return null.
  //
  // If SupportsRng() is true, this may return null, for example, if the RNG
  // initialization fails.
  virtual rng::RngSupport *CreateRng() { return nullptr; }

  // Creates a new DnnSupport object, ownership is transferred to the caller.
  // If SupportsDnn() is false, this will always return null.
  //
  // If SupportsDnn() is true, this may return null, for example, if the DNN
  // initialization fails.
  virtual dnn::DnnSupport *CreateDnn() { return nullptr; }

  // Each call creates a new instance of the platform-specific implementation of
  // the corresponding interface type.
  virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
  virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
  virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
  virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;

  // Returns the CUDA or ROCm context associated with this StreamExecutor
  // platform implementation.
  //
  // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
  // causing a fatal error if it is not. This hack is made available solely for
  // use from distbelief code, which temporarily has strong ties to CUDA or ROCm
  // as a platform.
  virtual void *GpuContextHack() { return nullptr; }

  // Return allocator statistics. Default: nullopt (no statistics available).
  virtual absl::optional<AllocatorStats> GetAllocatorStats() {
    return absl::nullopt;
  }

  // Clears the compilation cache from volatile memory. Returns OK if no
  // compilation cache exists or if clearing the compilation cache is
  // unsupported. Caches in non-volatile storage are unaffected.
  virtual port::Status FlushCompilationCache() { return port::Status::OK(); }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
};
388 
389 }  // namespace internal
390 }  // namespace stream_executor
391 
392 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
393