/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style;
// i.e. the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.
20 21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 23 24 #include <functional> 25 #include <map> 26 #include <memory> 27 #include <utility> 28 #include <vector> 29 30 #include "absl/types/optional.h" 31 #include "tensorflow/stream_executor/allocator_stats.h" 32 #include "tensorflow/stream_executor/device_description.h" 33 #include "tensorflow/stream_executor/device_memory.h" 34 #include "tensorflow/stream_executor/device_options.h" 35 #include "tensorflow/stream_executor/dnn.h" 36 #include "tensorflow/stream_executor/event.h" 37 #include "tensorflow/stream_executor/kernel.h" 38 #include "tensorflow/stream_executor/kernel_cache_config.h" 39 #include "tensorflow/stream_executor/kernel_spec.h" 40 #include "tensorflow/stream_executor/launch_dim.h" 41 #include "tensorflow/stream_executor/lib/status.h" 42 #include "tensorflow/stream_executor/lib/statusor.h" 43 #include "tensorflow/stream_executor/module_spec.h" 44 #include "tensorflow/stream_executor/platform.h" 45 #include "tensorflow/stream_executor/platform/port.h" 46 #include "tensorflow/stream_executor/plugin_registry.h" 47 #include "tensorflow/stream_executor/trace_listener.h" 48 49 namespace stream_executor { 50 51 class Stream; 52 class Timer; 53 54 // An opaque handle to a loaded module. 55 // 56 // An instance of this is returned from StreamExecutor::GetModule. 57 class ModuleHandle { 58 public: id_(id)59 /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {} 60 61 // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a 62 // null pointer. id()63 void *id() const { return id_; } 64 65 explicit operator bool() const { return id() != nullptr; } 66 67 private: 68 void *id_; 69 }; 70 71 namespace internal { 72 73 // Platform-dependent interface class for the generic Events interface, in 74 // the PIMPL style. 
75 class EventInterface { 76 public: EventInterface()77 EventInterface() {} ~EventInterface()78 virtual ~EventInterface() {} 79 80 private: 81 SE_DISALLOW_COPY_AND_ASSIGN(EventInterface); 82 }; 83 84 // Pointer-to-implementation object type (i.e. the KernelBase class delegates to 85 // this interface) with virtual destruction. This class exists for the 86 // platform-dependent code to hang any kernel data/resource info/functionality 87 // off of. 88 class KernelInterface { 89 public: 90 // Default constructor for the abstract interface. KernelInterface()91 KernelInterface() {} 92 93 // Default destructor for the abstract interface. ~KernelInterface()94 virtual ~KernelInterface() {} 95 96 // Returns the number of formal parameters that this kernel accepts. 97 virtual unsigned Arity() const = 0; 98 99 // Sets the preferred cache configuration. 100 virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; 101 102 // Gets the preferred cache configuration. 103 virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; 104 105 private: 106 SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface); 107 }; 108 109 // Pointer-to-implementation object type (i.e. the Stream class delegates to 110 // this interface) with virtual destruction. This class exists for the 111 // platform-dependent code to hang any kernel data/resource info/functionality 112 // off of. 113 class StreamInterface { 114 public: 115 // Default constructor for the abstract interface. StreamInterface()116 StreamInterface() {} 117 118 // Default destructor for the abstract interface. ~StreamInterface()119 virtual ~StreamInterface() {} 120 121 // Returns the GPU stream associated with this platform's stream 122 // implementation. 123 // 124 // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, 125 // causing a fatal error if it is not. This hack is made available solely for 126 // use from distbelief code, which temporarily has strong ties to CUDA or 127 // ROCm as a platform. 
GpuStreamHack()128 virtual void *GpuStreamHack() { return nullptr; } 129 130 // See the above comment on GpuStreamHack -- this further breaks abstraction 131 // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a 132 // platform, and a historical attachment to a programming model which takes a 133 // stream-slot rather than a stream-value. GpuStreamMemberHack()134 virtual void **GpuStreamMemberHack() { return nullptr; } 135 136 private: 137 SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); 138 }; 139 140 // Pointer-to-implementation object type (i.e. the Timer class delegates to 141 // this interface) with virtual destruction. This class exists for the 142 // platform-dependent code to hang any timer data/resource info/functionality 143 // off of. 144 class TimerInterface { 145 public: 146 // Default constructor for the abstract interface. TimerInterface()147 TimerInterface() {} 148 149 // Default destructor for the abstract interface. ~TimerInterface()150 virtual ~TimerInterface() {} 151 152 // Returns the number of microseconds elapsed in a completed timer. 153 virtual uint64 Microseconds() const = 0; 154 155 // Returns the number of nanoseconds elapsed in a completed timer. 156 virtual uint64 Nanoseconds() const = 0; 157 158 private: 159 SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface); 160 }; 161 162 // Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL). 163 // 164 // Various platforms will provide an implementation that satisfy this interface. 165 class StreamExecutorInterface { 166 public: 167 // Default constructor for the abstract interface. StreamExecutorInterface()168 StreamExecutorInterface() {} 169 170 // Default destructor for the abstract interface. ~StreamExecutorInterface()171 virtual ~StreamExecutorInterface() {} 172 173 // Returns the (transitively) wrapped executor if this executor is 174 // wrapping another executor; otherwise, returns this. 
GetUnderlyingExecutor()175 virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; } 176 177 // See the StreamExecutor interface for comments on the same-named methods. 178 virtual port::Status Init(int device_ordinal, 179 DeviceOptions device_options) = 0; 180 GetKernel(const MultiKernelLoaderSpec & spec,KernelBase * kernel)181 virtual port::Status GetKernel(const MultiKernelLoaderSpec &spec, 182 KernelBase *kernel) { 183 return port::UnimplementedError("Not Implemented"); 184 } UnloadModule(ModuleHandle module_handle)185 virtual bool UnloadModule(ModuleHandle module_handle) { return false; } LoadModule(const MultiModuleLoaderSpec & spec,ModuleHandle * module_handle)186 virtual port::Status LoadModule(const MultiModuleLoaderSpec &spec, 187 ModuleHandle *module_handle) { 188 return port::UnimplementedError("Not Implemented"); 189 } Launch(Stream * stream,const ThreadDim & thread_dims,const BlockDim & block_dims,const KernelBase & k,const KernelArgsArrayBase & args)190 virtual port::Status Launch(Stream *stream, const ThreadDim &thread_dims, 191 const BlockDim &block_dims, const KernelBase &k, 192 const KernelArgsArrayBase &args) { 193 return port::UnimplementedError("Not Implemented"); 194 } 195 196 // Releases any state associated with the kernel. UnloadKernel(const KernelBase * kernel)197 virtual void UnloadKernel(const KernelBase *kernel) {} 198 virtual DeviceMemoryBase Allocate(uint64 size, int64 memory_space) = 0; Allocate(uint64 size)199 DeviceMemoryBase Allocate(uint64 size) { 200 return Allocate(size, /*memory_space=*/0); 201 } 202 virtual void *GetSubBuffer(DeviceMemoryBase *parent, uint64 offset, 203 uint64 size) = 0; 204 virtual void Deallocate(DeviceMemoryBase *mem) = 0; 205 // Allocates unified memory space of the given size, if supported. 206 // See 207 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd 208 // for more details on unified memory. 
UnifiedMemoryAllocate(uint64 size)209 virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; } 210 211 // Deallocates unified memory space previously allocated with 212 // UnifiedMemoryAllocate. UnifiedMemoryDeallocate(void * mem)213 virtual void UnifiedMemoryDeallocate(void *mem) {} 214 virtual void *HostMemoryAllocate(uint64 size) = 0; 215 virtual void HostMemoryDeallocate(void *mem) = 0; 216 virtual bool HostMemoryRegister(void *mem, uint64 size) = 0; 217 virtual bool HostMemoryUnregister(void *mem) = 0; 218 virtual bool SynchronizeAllActivity() = 0; 219 virtual port::Status SynchronousMemZero(DeviceMemoryBase *location, 220 uint64 size) = 0; 221 virtual port::Status SynchronousMemSet(DeviceMemoryBase *location, int value, 222 uint64 size) = 0; 223 virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst, 224 const void *host_src, uint64 size) = 0; 225 virtual port::Status SynchronousMemcpy(void *host_dst, 226 const DeviceMemoryBase &gpu_src, 227 uint64 size) = 0; 228 virtual port::Status SynchronousMemcpyDeviceToDevice( 229 DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, 230 uint64 size) = 0; 231 virtual port::Status MemZero(Stream *stream, DeviceMemoryBase *location, 232 uint64 size) = 0; Memset(Stream * stream,DeviceMemoryBase * location,uint8 pattern,uint64 size)233 virtual port::Status Memset(Stream *stream, DeviceMemoryBase *location, 234 uint8 pattern, uint64 size) { 235 return port::InternalError("Not implemented"); 236 } 237 virtual port::Status Memset32(Stream *stream, DeviceMemoryBase *location, 238 uint32 pattern, uint64 size) = 0; 239 virtual bool Memcpy(Stream *stream, void *host_dst, 240 const DeviceMemoryBase &gpu_src, uint64 size) = 0; 241 virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, 242 const void *host_src, uint64 size) = 0; 243 virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst, 244 const DeviceMemoryBase &gpu_src, 245 uint64 size) = 0; 246 virtual bool 
HostCallback(Stream *stream, std::function<void()> callback); 247 virtual bool HostCallback(Stream *stream, 248 std::function<port::Status()> callback) = 0; 249 virtual port::Status AllocateEvent(Event *event) = 0; 250 virtual port::Status DeallocateEvent(Event *event) = 0; 251 virtual port::Status RecordEvent(Stream *stream, Event *event) = 0; 252 virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0; 253 virtual Event::Status PollForEventStatus(Event *event) = 0; 254 virtual bool AllocateStream(Stream *stream) = 0; 255 virtual void DeallocateStream(Stream *stream) = 0; 256 virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0; 257 virtual bool AllocateTimer(Timer *timer) = 0; 258 virtual void DeallocateTimer(Timer *timer) = 0; 259 virtual bool StartTimer(Stream *stream, Timer *timer) = 0; 260 virtual bool StopTimer(Stream *stream, Timer *timer) = 0; 261 virtual port::Status BlockHostUntilDone(Stream *stream) = 0; GetStatus(Stream * stream)262 virtual port::Status GetStatus(Stream *stream) { 263 return port::Status(port::error::UNIMPLEMENTED, 264 "GetStatus is not supported on this executor."); 265 } 266 virtual int PlatformDeviceCount() = 0; 267 virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0; 268 virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0; 269 GetDeviceLoad()270 virtual int64 GetDeviceLoad() { return -1; } 271 DeviceMemoryUsage(int64 * free,int64 * total)272 virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const { 273 return false; 274 } 275 276 // Retrieves device pointer and size for a symbol. The device pointer is 277 // stored at mem, and the size is stored at size. Either mem or bytes can be 278 // null, however, both of them cannot be null at the same time. To use 279 // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol 280 // is found. 
281 // 282 // If ModuleHandle is set then we search for `symbol_name` only within the 283 // module corresponding to `module_handle`. Otherwise all loaded modules are 284 // searched. GetSymbol(const std::string & symbol_name,ModuleHandle module_handle,void ** mem,size_t * bytes)285 virtual bool GetSymbol(const std::string &symbol_name, 286 ModuleHandle module_handle, void **mem, 287 size_t *bytes) { 288 return false; 289 } 290 291 // Creates a new DeviceDescription object. Ownership is transferred to the 292 // caller. 293 virtual port::StatusOr<std::unique_ptr<DeviceDescription>> 294 CreateDeviceDescription() const = 0; 295 296 // Attempts to register the provided TraceListener with the device-specific 297 // Executor implementation. When this is called, the PIMPL interface has 298 // already taken ownership of the object and is managing the generic tracing 299 // events. The device-specific implementation must determine if the passed 300 // listener is of a type appropriate for it to trace during registration (and 301 // before dispatching events to it). 302 // Returns true if the listener was successfully registered, false otherwise. 303 // Does not take ownership of listener. RegisterTraceListener(TraceListener * listener)304 virtual bool RegisterTraceListener(TraceListener *listener) { return false; } 305 306 // Unregisters the specified listener from the device-specific Executor. 307 // Returns true if the listener was successfully registered, false otherwise. UnregisterTraceListener(TraceListener * listener)308 virtual bool UnregisterTraceListener(TraceListener *listener) { 309 return false; 310 } 311 312 // Returns whether this StreamExecutor has BLAS support for its underlying 313 // platform. SupportsBlas()314 virtual bool SupportsBlas() const { return false; } 315 316 // Creates a new BlasSupport object, ownership is transferred to the caller. 317 // If SupportsBlas() is false, this will always return null. 
318 // 319 // If SupportsBlas() is true, this may return null, for example, if the BLAS 320 // initialization fails. CreateBlas()321 virtual blas::BlasSupport *CreateBlas() { return nullptr; } 322 323 // Returns whether this StreamExecutor has FFT support for its underlying 324 // platform. SupportsFft()325 virtual bool SupportsFft() const { return false; } 326 327 // Creates a new fft::FftSupport object, ownership is transferred to the 328 // caller. 329 // If SupportsFft() is false, this will always return null. 330 // 331 // If SupportsFft() is true, this may return null, for example, if the FFT 332 // initialization fails. CreateFft()333 virtual fft::FftSupport *CreateFft() { return nullptr; } 334 335 // Returns whether this StreamExecutor has Random Number Generation support 336 // for 337 // its underlying platform. SupportsRng()338 virtual bool SupportsRng() const { return false; } 339 340 // Returns whether this StreamExecutor has neural net support for its 341 // underlying 342 // platform. SupportsDnn()343 virtual bool SupportsDnn() const { return false; } 344 345 // Creates a new RngSupport object, ownership is transferred to the caller. 346 // If SupportsRng() is false, this will always return null. 347 // 348 // If SupportsRng() is true, this may return null, for example, if the RNG 349 // initialization fails. CreateRng()350 virtual rng::RngSupport *CreateRng() { return nullptr; } 351 352 // Creates a new DnnSupport object, ownership is transferred to the caller. 353 // If SupportsDnn() is false, this will always return null. 354 // 355 // If SupportsDnn() is true, this may return null, for example, if the DNN 356 // initialization fails. CreateDnn()357 virtual dnn::DnnSupport *CreateDnn() { return nullptr; } 358 359 // Each call creates a new instance of the platform-specific implementation of 360 // the corresponding interface type. 
361 virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0; 362 virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0; 363 virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0; 364 virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0; 365 366 // Returns the CUDA or ROCm context associated with this StreamExecutor 367 // platform implementation. 368 // 369 // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, 370 // causing a fatal error if it is not. This hack is made available solely for 371 // use from distbelief code, which temporarily has strong ties to CUDA or ROCm 372 // as a platform. GpuContextHack()373 virtual void *GpuContextHack() { return nullptr; } 374 375 // Return allocator statistics. GetAllocatorStats()376 virtual absl::optional<AllocatorStats> GetAllocatorStats() { 377 return absl::nullopt; 378 } 379 380 // Clears the compilation cache from volatile memory. Returns OK if no 381 // compilation cache exists or if clearing the compilation cache is 382 // unsupported. Caches in non-volatile storage are unaffected. FlushCompilationCache()383 virtual port::Status FlushCompilationCache() { return port::Status::OK(); } 384 385 private: 386 SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); 387 }; 388 389 } // namespace internal 390 } // namespace stream_executor 391 392 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 393