/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style; i.e.
// the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.
20 21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 23 24 #include <functional> 25 #include <map> 26 #include <memory> 27 #include <utility> 28 #include <vector> 29 30 #include "absl/types/optional.h" 31 #include "tensorflow/stream_executor/allocator_stats.h" 32 #include "tensorflow/stream_executor/device_description.h" 33 #include "tensorflow/stream_executor/device_memory.h" 34 #include "tensorflow/stream_executor/device_options.h" 35 #include "tensorflow/stream_executor/dnn.h" 36 #include "tensorflow/stream_executor/event.h" 37 #include "tensorflow/stream_executor/kernel.h" 38 #include "tensorflow/stream_executor/kernel_cache_config.h" 39 #include "tensorflow/stream_executor/kernel_spec.h" 40 #include "tensorflow/stream_executor/launch_dim.h" 41 #include "tensorflow/stream_executor/lib/status.h" 42 #include "tensorflow/stream_executor/lib/statusor.h" 43 #include "tensorflow/stream_executor/module_spec.h" 44 #include "tensorflow/stream_executor/platform.h" 45 #include "tensorflow/stream_executor/platform/port.h" 46 #include "tensorflow/stream_executor/plugin_registry.h" 47 #include "tensorflow/stream_executor/shared_memory_config.h" 48 #include "tensorflow/stream_executor/trace_listener.h" 49 50 namespace stream_executor { 51 52 class Stream; 53 class Timer; 54 55 // An opaque handle to a loaded module. 56 // 57 // An instance of this is returned from StreamExecutor::GetModule. 58 class ModuleHandle { 59 public: id_(id)60 /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {} 61 62 // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a 63 // null pointer. id()64 void *id() const { return id_; } 65 66 explicit operator bool() const { return id() != nullptr; } 67 68 private: 69 void *id_; 70 }; 71 72 namespace internal { 73 74 // Platform-dependent interface class for the generic Events interface, in 75 // the PIMPL style. 
76 class EventInterface { 77 public: EventInterface()78 EventInterface() {} ~EventInterface()79 virtual ~EventInterface() {} 80 81 private: 82 SE_DISALLOW_COPY_AND_ASSIGN(EventInterface); 83 }; 84 85 // Pointer-to-implementation object type (i.e. the KernelBase class delegates to 86 // this interface) with virtual destruction. This class exists for the 87 // platform-dependent code to hang any kernel data/resource info/functionality 88 // off of. 89 class KernelInterface { 90 public: 91 // Default constructor for the abstract interface. KernelInterface()92 KernelInterface() {} 93 94 // Default destructor for the abstract interface. ~KernelInterface()95 virtual ~KernelInterface() {} 96 97 // Returns the number of formal parameters that this kernel accepts. 98 virtual unsigned Arity() const = 0; 99 100 // Sets the preferred cache configuration. 101 virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; 102 103 // Gets the preferred cache configuration. 104 virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; 105 106 private: 107 SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface); 108 }; 109 110 // Pointer-to-implementation object type (i.e. the Stream class delegates to 111 // this interface) with virtual destruction. This class exists for the 112 // platform-dependent code to hang any kernel data/resource info/functionality 113 // off of. 114 class StreamInterface { 115 public: 116 // Default constructor for the abstract interface. StreamInterface()117 StreamInterface() {} 118 119 // Default destructor for the abstract interface. ~StreamInterface()120 virtual ~StreamInterface() {} 121 122 // Returns the GPU stream associated with this platform's stream 123 // implementation. 124 // 125 // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, 126 // causing a fatal error if it is not. This hack is made available solely for 127 // use from distbelief code, which temporarily has strong ties to CUDA or 128 // ROCm as a platform. 
GpuStreamHack()129 virtual void *GpuStreamHack() { return nullptr; } 130 131 // See the above comment on GpuStreamHack -- this further breaks abstraction 132 // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a 133 // platform, and a historical attachment to a programming model which takes a 134 // stream-slot rather than a stream-value. GpuStreamMemberHack()135 virtual void **GpuStreamMemberHack() { return nullptr; } 136 137 private: 138 SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); 139 }; 140 141 // Pointer-to-implementation object type (i.e. the Timer class delegates to 142 // this interface) with virtual destruction. This class exists for the 143 // platform-dependent code to hang any timer data/resource info/functionality 144 // off of. 145 class TimerInterface { 146 public: 147 // Default constructor for the abstract interface. TimerInterface()148 TimerInterface() {} 149 150 // Default destructor for the abstract interface. ~TimerInterface()151 virtual ~TimerInterface() {} 152 153 // Returns the number of microseconds elapsed in a completed timer. 154 virtual uint64 Microseconds() const = 0; 155 156 // Returns the number of nanoseconds elapsed in a completed timer. 157 virtual uint64 Nanoseconds() const = 0; 158 159 private: 160 SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface); 161 }; 162 163 // Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL). 164 // 165 // Various platforms will provide an implementation that satisfy this interface. 166 class StreamExecutorInterface { 167 public: 168 // Default constructor for the abstract interface. StreamExecutorInterface()169 StreamExecutorInterface() {} 170 171 // Default destructor for the abstract interface. ~StreamExecutorInterface()172 virtual ~StreamExecutorInterface() {} 173 174 // Returns the (transitively) wrapped executor if this executor is 175 // wrapping another executor; otherwise, returns this. 
GetUnderlyingExecutor()176 virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; } 177 178 // See the StreamExecutor interface for comments on the same-named methods. 179 virtual port::Status Init(int device_ordinal, 180 DeviceOptions device_options) = 0; 181 GetKernel(const MultiKernelLoaderSpec & spec,KernelBase * kernel)182 virtual bool GetKernel(const MultiKernelLoaderSpec &spec, 183 KernelBase *kernel) { 184 return false; 185 } LoadModule(const MultiModuleLoaderSpec & spec,ModuleHandle * module_handle)186 virtual bool LoadModule(const MultiModuleLoaderSpec &spec, 187 ModuleHandle *module_handle) { 188 return false; 189 } UnloadModule(ModuleHandle module_handle)190 virtual bool UnloadModule(ModuleHandle module_handle) { return false; } Launch(Stream * stream,const ThreadDim & thread_dims,const BlockDim & block_dims,const KernelBase & k,const KernelArgsArrayBase & args)191 virtual bool Launch(Stream *stream, const ThreadDim &thread_dims, 192 const BlockDim &block_dims, const KernelBase &k, 193 const KernelArgsArrayBase &args) { 194 return false; 195 } 196 // Releases any state associated with the kernel. UnloadKernel(const KernelBase * kernel)197 virtual void UnloadKernel(const KernelBase *kernel) {} 198 virtual void *Allocate(uint64 size) = 0; 199 virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset, 200 uint64 size) = 0; 201 virtual void Deallocate(DeviceMemoryBase *mem) = 0; 202 // Allocates unified memory space of the given size, if supported. 203 // See 204 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd 205 // for more details on unified memory. UnifiedMemoryAllocate(uint64 size)206 virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; } 207 208 // Deallocates unified memory space previously allocated with 209 // UnifiedMemoryAllocate. 
UnifiedMemoryDeallocate(void * mem)210 virtual void UnifiedMemoryDeallocate(void *mem) {} 211 virtual void *HostMemoryAllocate(uint64 size) = 0; 212 virtual void HostMemoryDeallocate(void *mem) = 0; 213 virtual bool HostMemoryRegister(void *mem, uint64 size) = 0; 214 virtual bool HostMemoryUnregister(void *mem) = 0; 215 virtual bool SynchronizeAllActivity() = 0; 216 virtual bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) = 0; 217 virtual bool SynchronousMemSet(DeviceMemoryBase *location, int value, 218 uint64 size) = 0; 219 virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst, 220 const void *host_src, uint64 size) = 0; 221 virtual port::Status SynchronousMemcpy(void *host_dst, 222 const DeviceMemoryBase &gpu_src, 223 uint64 size) = 0; 224 virtual port::Status SynchronousMemcpyDeviceToDevice( 225 DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, 226 uint64 size) = 0; 227 virtual bool MemZero(Stream *stream, DeviceMemoryBase *location, 228 uint64 size) = 0; Memset(Stream * stream,DeviceMemoryBase * location,uint8 pattern,uint64 size)229 virtual bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern, 230 uint64 size) { 231 return false; 232 } 233 virtual bool Memset32(Stream *stream, DeviceMemoryBase *location, 234 uint32 pattern, uint64 size) = 0; 235 virtual bool Memcpy(Stream *stream, void *host_dst, 236 const DeviceMemoryBase &gpu_src, uint64 size) = 0; 237 virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, 238 const void *host_src, uint64 size) = 0; 239 virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst, 240 const DeviceMemoryBase &gpu_src, 241 uint64 size) = 0; 242 virtual bool HostCallback(Stream *stream, std::function<void()> callback); 243 virtual bool HostCallback(Stream *stream, 244 std::function<port::Status()> callback) = 0; 245 virtual port::Status AllocateEvent(Event *event) = 0; 246 virtual port::Status DeallocateEvent(Event *event) = 0; 247 virtual port::Status 
RecordEvent(Stream *stream, Event *event) = 0; 248 virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0; 249 virtual Event::Status PollForEventStatus(Event *event) = 0; 250 virtual bool AllocateStream(Stream *stream) = 0; 251 virtual void DeallocateStream(Stream *stream) = 0; 252 virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0; 253 virtual bool AllocateTimer(Timer *timer) = 0; 254 virtual void DeallocateTimer(Timer *timer) = 0; 255 virtual bool StartTimer(Stream *stream, Timer *timer) = 0; 256 virtual bool StopTimer(Stream *stream, Timer *timer) = 0; 257 virtual port::Status BlockHostUntilDone(Stream *stream) = 0; GetStatus(Stream * stream)258 virtual port::Status GetStatus(Stream *stream) { 259 return port::Status(port::error::UNIMPLEMENTED, 260 "GetStatus is not supported on this executor."); 261 } 262 virtual int PlatformDeviceCount() = 0; 263 virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0; 264 virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0; 265 virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0; 266 virtual port::Status SetDeviceSharedMemoryConfig( 267 SharedMemoryConfig config) = 0; 268 GetDeviceLoad()269 virtual int64 GetDeviceLoad() { return -1; } 270 DeviceMemoryUsage(int64 * free,int64 * total)271 virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const { 272 return false; 273 } 274 275 // Retrieves device pointer and size for a symbol. The device pointer is 276 // stored at mem, and the size is stored at size. Either mem or bytes can be 277 // null, however, both of them cannot be null at the same time. To use 278 // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol 279 // is found. 280 // 281 // If ModuleHandle is set then we search for `symbol_name` only within the 282 // module corresponding to `module_handle`. Otherwise all loaded modules are 283 // searched. 
GetSymbol(const string & symbol_name,ModuleHandle module_handle,void ** mem,size_t * bytes)284 virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, 285 void **mem, size_t *bytes) { 286 return false; 287 } 288 289 // Creates a new DeviceDescription object. Ownership is transferred to the 290 // caller. 291 virtual DeviceDescription *PopulateDeviceDescription() const = 0; 292 293 // Attempts to register the provided TraceListener with the device-specific 294 // Executor implementation. When this is called, the PIMPL interface has 295 // already taken ownership of the object and is managing the generic tracing 296 // events. The device-specific implementation must determine if the passed 297 // listener is of a type appropriate for it to trace during registration (and 298 // before dispatching events to it). 299 // Returns true if the listener was successfully registered, false otherwise. 300 // Does not take ownership of listener. RegisterTraceListener(TraceListener * listener)301 virtual bool RegisterTraceListener(TraceListener* listener) { return false; } 302 303 // Unregisters the specified listener from the device-specific Executor. 304 // Returns true if the listener was successfully registered, false otherwise. UnregisterTraceListener(TraceListener * listener)305 virtual bool UnregisterTraceListener(TraceListener* listener) { 306 return false; 307 } 308 309 // Returns whether this StreamExecutor has BLAS support for its underlying 310 // platform. SupportsBlas()311 virtual bool SupportsBlas() const { return false; } 312 313 // Creates a new BlasSupport object, ownership is transferred to the caller. 314 // If SupportsBlas() is false, this will always return null. 315 // 316 // If SupportsBlas() is true, this may return null, for example, if the BLAS 317 // initialization fails. 
CreateBlas()318 virtual blas::BlasSupport *CreateBlas() { return nullptr; } 319 320 // Returns whether this StreamExecutor has FFT support for its underlying 321 // platform. SupportsFft()322 virtual bool SupportsFft() const { return false; } 323 324 // Creates a new fft::FftSupport object, ownership is transferred to the 325 // caller. 326 // If SupportsFft() is false, this will always return null. 327 // 328 // If SupportsFft() is true, this may return null, for example, if the FFT 329 // initialization fails. CreateFft()330 virtual fft::FftSupport *CreateFft() { return nullptr; } 331 332 // Returns whether this StreamExecutor has Random Number Generation support 333 // for 334 // its underlying platform. SupportsRng()335 virtual bool SupportsRng() const { return false; } 336 337 // Returns whether this StreamExecutor has neural net support for its 338 // underlying 339 // platform. SupportsDnn()340 virtual bool SupportsDnn() const { return false; } 341 342 // Creates a new RngSupport object, ownership is transferred to the caller. 343 // If SupportsRng() is false, this will always return null. 344 // 345 // If SupportsRng() is true, this may return null, for example, if the RNG 346 // initialization fails. CreateRng()347 virtual rng::RngSupport *CreateRng() { return nullptr; } 348 349 // Creates a new DnnSupport object, ownership is transferred to the caller. 350 // If SupportsDnn() is false, this will always return null. 351 // 352 // If SupportsDnn() is true, this may return null, for example, if the DNN 353 // initialization fails. CreateDnn()354 virtual dnn::DnnSupport *CreateDnn() { return nullptr; } 355 356 // Each call creates a new instance of the platform-specific implementation of 357 // the corresponding interface type. 
358 virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0; 359 virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0; 360 virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0; 361 virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0; 362 363 // Returns the CUDA or ROCm context associated with this StreamExecutor 364 // platform implementation. 365 // 366 // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, 367 // causing a fatal error if it is not. This hack is made available solely for 368 // use from distbelief code, which temporarily has strong ties to CUDA or ROCm 369 // as a platform. GpuContextHack()370 virtual void *GpuContextHack() { return nullptr; } 371 372 // Return allocator statistics. GetAllocatorStats()373 virtual absl::optional<AllocatorStats> GetAllocatorStats() { 374 return absl::nullopt; 375 } 376 377 private: 378 SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); 379 }; 380 381 using StreamExecutorFactory = 382 std::function<StreamExecutorInterface *(const PluginConfig &)>; 383 using EventFactory = std::function<EventInterface *(StreamExecutor *)>; 384 using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>; 385 using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>; 386 using KernelFactory = std::function<KernelInterface*()>; 387 388 StreamExecutorFactory *MakeCUDAExecutorImplementation(); 389 390 StreamExecutorFactory *MakeROCMExecutorImplementation(); 391 392 StreamExecutorFactory *MakeOpenCLExecutorImplementation(); 393 394 extern StreamExecutorFactory MakeHostExecutorImplementation; 395 396 397 } // namespace internal 398 } // namespace stream_executor 399 400 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 401