/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style;
// i.e. the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.
#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/stream_executor/allocator_stats.h"
#include "tensorflow/stream_executor/device_description.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/kernel_spec.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/module_spec.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace stream_executor {

class Stream;
class Timer;

// An opaque handle to a loaded module.
//
// An instance of this is returned from StreamExecutor::GetModule.
class ModuleHandle {
 public:
  // Wraps a platform-specific module identifier. Deliberately implicit so a
  // raw id can be passed where a ModuleHandle is expected. The default
  // (nullptr) constructs an invalid handle.
  /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {}

  // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a
  // null pointer.
  void *id() const { return id_; }

  // True iff this handle refers to a loaded module, i.e. id() != nullptr.
  explicit operator bool() const { return id() != nullptr; }

 private:
  // Opaque platform-specific module identifier; not owned by this handle.
  void *id_;
};

namespace internal {

// Platform-dependent interface class for the generic Events interface, in
// the PIMPL style.
class EventInterface {
 public:
  // Default constructor for the abstract interface.
  EventInterface() {}

  // Default destructor for the abstract interface.
  virtual ~EventInterface() {}

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
};

// Pointer-to-implementation object type (i.e. the KernelBase class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any kernel data/resource info/functionality
// off of.
class KernelInterface {
 public:
  // Default constructor for the abstract interface.
  KernelInterface() {}

  // Default destructor for the abstract interface.
  virtual ~KernelInterface() {}

  // Returns the number of formal parameters that this kernel accepts.
  virtual unsigned Arity() const = 0;

  // Sets the preferred cache configuration.
  virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;

  // Gets the preferred cache configuration.
  virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
};

// Pointer-to-implementation object type (i.e. the Stream class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any kernel data/resource info/functionality
// off of.
class StreamInterface {
 public:
  // Default constructor for the abstract interface.
  StreamInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamInterface() {}

  // Returns the GPU stream associated with this platform's stream
  // implementation.
  //
  // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
  // causing a fatal error if it is not. This hack is made available solely for
  // use from distbelief code, which temporarily has strong ties to CUDA or
  // ROCm as a platform.
  //
  // The base implementation returns nullptr; GPU platforms override it.
  virtual void *GpuStreamHack() { return nullptr; }

  // See the above comment on GpuStreamHack -- this further breaks abstraction
  // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a
  // platform, and a historical attachment to a programming model which takes a
  // stream-slot rather than a stream-value.
  virtual void **GpuStreamMemberHack() { return nullptr; }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
};

// Pointer-to-implementation object type (i.e. the Timer class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any timer data/resource info/functionality
// off of.
class TimerInterface {
 public:
  // Default constructor for the abstract interface.
  TimerInterface() {}

  // Default destructor for the abstract interface.
  virtual ~TimerInterface() {}

  // Returns the number of microseconds elapsed in a completed timer.
  virtual uint64 Microseconds() const = 0;

  // Returns the number of nanoseconds elapsed in a completed timer.
  virtual uint64 Nanoseconds() const = 0;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
};

// Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
//
// Various platforms will provide an implementation that satisfy this interface.
// Methods with an inline default body are optional capabilities; platforms
// that lack the capability simply inherit the default (which reports
// "unimplemented"/false/nullptr as appropriate).
class StreamExecutorInterface {
 public:
  // Default constructor for the abstract interface.
  StreamExecutorInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamExecutorInterface() {}

  // Returns the (transitively) wrapped executor if this executor is
  // wrapping another executor; otherwise, returns this.
  virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; }

  // See the StreamExecutor interface for comments on the same-named methods.
  virtual port::Status Init(int device_ordinal,
                            DeviceOptions device_options) = 0;

  // Loads the kernel described by `spec` into `kernel`. Optional: the default
  // reports UNIMPLEMENTED.
  virtual port::Status GetKernel(const MultiKernelLoaderSpec &spec,
                                 KernelBase *kernel) {
    return port::UnimplementedError("Not Implemented");
  }
  // Unloads a previously loaded module; the default reports failure.
  virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
  virtual port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                                  ModuleHandle *module_handle) {
    return port::UnimplementedError("Not Implemented");
  }
  // Launches kernel `k` on `stream` with the given grid/block geometry.
  virtual port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                              const BlockDim &block_dims, const KernelBase &k,
                              const KernelArgsArrayBase &args) {
    return port::UnimplementedError("Not Implemented");
  }

  // Releases any state associated with the kernel.
  virtual void UnloadKernel(const KernelBase *kernel) {}
  virtual DeviceMemoryBase Allocate(uint64 size, int64 memory_space) = 0;
  // Convenience overload: allocates in the default memory space (0).
  DeviceMemoryBase Allocate(uint64 size) {
    return Allocate(size, /*memory_space=*/0);
  }
  virtual void *GetSubBuffer(DeviceMemoryBase *parent, uint64 offset,
                             uint64 size) = 0;
  virtual void Deallocate(DeviceMemoryBase *mem) = 0;
  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory. Returns nullptr when unsupported.
  virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; }

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  virtual void UnifiedMemoryDeallocate(void *mem) {}
  virtual void *HostMemoryAllocate(uint64 size) = 0;
  virtual void HostMemoryDeallocate(void *mem) = 0;
  virtual bool HostMemoryRegister(void *mem, uint64 size) = 0;
  virtual bool HostMemoryUnregister(void *mem) = 0;
  virtual bool SynchronizeAllActivity() = 0;

  // Synchronous (blocking) memory operations.
  virtual port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                          uint64 size) = 0;
  virtual port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                         uint64 size) = 0;
  virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
                                         const void *host_src, uint64 size) = 0;
  virtual port::Status SynchronousMemcpy(void *host_dst,
                                         const DeviceMemoryBase &gpu_src,
                                         uint64 size) = 0;
  virtual port::Status SynchronousMemcpyDeviceToDevice(
      DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
      uint64 size) = 0;

  // Stream-ordered (asynchronous) memory operations.
  virtual port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                               uint64 size) = 0;
  // NOTE(review): unlike the other optional methods above, this default
  // reports INTERNAL rather than UNIMPLEMENTED -- callers that match on the
  // error code should be aware of the discrepancy.
  virtual port::Status Memset(Stream *stream, DeviceMemoryBase *location,
                              uint8 pattern, uint64 size) {
    return port::InternalError("Not implemented");
  }
  virtual port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                                uint32 pattern, uint64 size) = 0;
  virtual bool Memcpy(Stream *stream, void *host_dst,
                      const DeviceMemoryBase &gpu_src, uint64 size) = 0;
  virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
                      const void *host_src, uint64 size) = 0;
  virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
                                    const DeviceMemoryBase &gpu_src,
                                    uint64 size) = 0;
  // Non-pure overload: declared here, presumably defined out of line --
  // verify against the accompanying implementation file.
  virtual bool HostCallback(Stream *stream, std::function<void()> callback);
  virtual bool HostCallback(Stream *stream,
                            std::function<port::Status()> callback) = 0;

  // Event management.
  virtual port::Status AllocateEvent(Event *event) = 0;
  virtual port::Status DeallocateEvent(Event *event) = 0;
  virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
  virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0;
  virtual Event::Status PollForEventStatus(Event *event) = 0;

  // Stream and timer management.
  virtual bool AllocateStream(Stream *stream) = 0;
  virtual void DeallocateStream(Stream *stream) = 0;
  virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0;
  virtual bool AllocateTimer(Timer *timer) = 0;
  virtual void DeallocateTimer(Timer *timer) = 0;
  virtual bool StartTimer(Stream *stream, Timer *timer) = 0;
  virtual bool StopTimer(Stream *stream, Timer *timer) = 0;
  virtual port::Status BlockHostUntilDone(Stream *stream) = 0;
  virtual port::Status GetStatus(Stream *stream) {
    return port::Status(port::error::UNIMPLEMENTED,
                        "GetStatus is not supported on this executor.");
  }
  virtual int PlatformDeviceCount() = 0;
  virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0;
  virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0;
  virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0;
  virtual port::Status SetDeviceSharedMemoryConfig(
      SharedMemoryConfig config) = 0;

  // Returns a platform-defined load metric; -1 means "not supported".
  virtual int64 GetDeviceLoad() { return -1; }

  // Reports free/total device memory in bytes; returns false when the
  // platform cannot provide the information.
  virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const {
    return false;
  }

  // Retrieves device pointer and size for a symbol. The device pointer is
  // stored at mem, and the size is stored at size. Either mem or bytes can be
  // null, however, both of them cannot be null at the same time. To use
  // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
  // is found.
  //
  // If ModuleHandle is set then we search for `symbol_name` only within the
  // module corresponding to `module_handle`. Otherwise all loaded modules are
  // searched.
  virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
                         void **mem, size_t *bytes) {
    return false;
  }

  // Creates a new DeviceDescription object. Ownership is transferred to the
  // caller.
  virtual port::StatusOr<std::unique_ptr<DeviceDescription>>
  CreateDeviceDescription() const = 0;

  // Attempts to register the provided TraceListener with the device-specific
  // Executor implementation. When this is called, the PIMPL interface has
  // already taken ownership of the object and is managing the generic tracing
  // events. The device-specific implementation must determine if the passed
  // listener is of a type appropriate for it to trace during registration (and
  // before dispatching events to it).
  // Returns true if the listener was successfully registered, false otherwise.
  // Does not take ownership of listener.
  virtual bool RegisterTraceListener(TraceListener *listener) { return false; }

  // Unregisters the specified listener from the device-specific Executor.
  // Returns true if the listener was successfully registered, false otherwise.
  virtual bool UnregisterTraceListener(TraceListener *listener) {
    return false;
  }

  // Returns whether this StreamExecutor has BLAS support for its underlying
  // platform.
  virtual bool SupportsBlas() const { return false; }

  // Creates a new BlasSupport object, ownership is transferred to the caller.
  // If SupportsBlas() is false, this will always return null.
  //
  // If SupportsBlas() is true, this may return null, for example, if the BLAS
  // initialization fails.
  virtual blas::BlasSupport *CreateBlas() { return nullptr; }

  // Returns whether this StreamExecutor has FFT support for its underlying
  // platform.
  virtual bool SupportsFft() const { return false; }

  // Creates a new fft::FftSupport object, ownership is transferred to the
  // caller.
  // If SupportsFft() is false, this will always return null.
  //
  // If SupportsFft() is true, this may return null, for example, if the FFT
  // initialization fails.
  virtual fft::FftSupport *CreateFft() { return nullptr; }

  // Returns whether this StreamExecutor has random number generation support
  // for its underlying platform.
  virtual bool SupportsRng() const { return false; }

  // Returns whether this StreamExecutor has neural net support for its
  // underlying platform.
  virtual bool SupportsDnn() const { return false; }

  // Creates a new RngSupport object, ownership is transferred to the caller.
  // If SupportsRng() is false, this will always return null.
  //
  // If SupportsRng() is true, this may return null, for example, if the RNG
  // initialization fails.
  virtual rng::RngSupport *CreateRng() { return nullptr; }

  // Creates a new DnnSupport object, ownership is transferred to the caller.
  // If SupportsDnn() is false, this will always return null.
  //
  // If SupportsDnn() is true, this may return null, for example, if the DNN
  // initialization fails.
  virtual dnn::DnnSupport *CreateDnn() { return nullptr; }

  // Each call creates a new instance of the platform-specific implementation
  // of the corresponding interface type.
  virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
  virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
  virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
  virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;

  // Returns the CUDA or ROCm context associated with this StreamExecutor
  // platform implementation.
  //
  // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
  // causing a fatal error if it is not. This hack is made available solely for
  // use from distbelief code, which temporarily has strong ties to CUDA or ROCm
  // as a platform.
  virtual void *GpuContextHack() { return nullptr; }

  // Return allocator statistics. The default reports "no stats available".
  virtual absl::optional<AllocatorStats> GetAllocatorStats() {
    return absl::nullopt;
  }

  // Clears the compilation cache from volatile memory. Returns OK if no
  // compilation cache exists or if clearing the compilation cache is
  // unsupported. Caches in non-volatile storage are unaffected.
  virtual port::Status FlushCompilationCache() { return port::Status::OK(); }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
};

}  // namespace internal
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_