1 #pragma once
2 
3 // This header provides C++ wrappers around commonly used CUDA API functions.
4 // The benefit of using C++ here is that we can raise an exception in the
5 // event of an error, rather than explicitly pass around error codes.  This
6 // leads to more natural APIs.
7 //
8 // The naming convention used here matches the naming convention of torch.cuda
9 
10 #include <c10/core/Device.h>
11 #include <c10/core/impl/GPUTrace.h>
12 #include <c10/cuda/CUDAException.h>
13 #include <c10/cuda/CUDAMacros.h>
14 #include <cuda_runtime_api.h>
15 namespace c10::cuda {
16 
17 // NB: In the past, we were inconsistent about whether or not this reported
18 // an error if there were driver problems are not.  Based on experience
19 // interacting with users, it seems that people basically ~never want this
20 // function to fail; it should just return zero if things are not working.
21 // Oblige them.
22 // It still might log a warning for user first time it's invoked
23 C10_CUDA_API DeviceIndex device_count() noexcept;
24 
25 // Version of device_count that throws is no devices are detected
26 C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
27 
28 C10_CUDA_API DeviceIndex current_device();
29 
30 C10_CUDA_API void set_device(DeviceIndex device);
31 
32 C10_CUDA_API void device_synchronize();
33 
34 C10_CUDA_API void warn_or_error_on_sync();
35 
36 // Raw CUDA device management functions
37 C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
38 
39 C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
40 
41 C10_CUDA_API cudaError_t SetDevice(DeviceIndex device);
42 
43 C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);
44 
45 C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device);
46 
47 C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device);
48 
49 C10_CUDA_API void SetTargetDevice();
50 
// Modes for the CUDA synchronization debug check.  Enumerators are numbered
// consecutively, starting from L_DISABLED = 0.
enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };

// this is a holder for c10 global state (similar to at GlobalContext)
// currently it's used to store cuda synchronization warning state,
// but can be expanded to hold other related global state, e.g. to
// record stream usage
//
// NOTE(review): the stored mode is a plain member with no synchronization;
// confirm that callers do not set and read it concurrently.
class WarningState {
 public:
  // Replaces the stored sync debug mode with `l`.
  void set_sync_debug_mode(SyncDebugMode l) {
    sync_debug_mode = l;
  }

  // Returns the currently stored sync debug mode.  Const-qualified because it
  // only reads state, so it is also callable through a const WarningState.
  SyncDebugMode get_sync_debug_mode() const {
    return sync_debug_mode;
  }

 private:
  // Stays at L_DISABLED until set_sync_debug_mode() stores another value.
  SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
};
70 
warning_state()71 C10_CUDA_API __inline__ WarningState& warning_state() {
72   static WarningState warning_state_;
73   return warning_state_;
74 }
75 // the subsequent functions are defined in the header because for performance
76 // reasons we want them to be inline
memcpy_and_sync(void * dst,const void * src,int64_t nbytes,cudaMemcpyKind kind,cudaStream_t stream)77 C10_CUDA_API void __inline__ memcpy_and_sync(
78     void* dst,
79     const void* src,
80     int64_t nbytes,
81     cudaMemcpyKind kind,
82     cudaStream_t stream) {
83   if (C10_UNLIKELY(
84           warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
85     warn_or_error_on_sync();
86   }
87   const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
88   if (C10_UNLIKELY(interp)) {
89     (*interp)->trace_gpu_stream_synchronization(
90         c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
91   }
92 #if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
93   C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
94 #else
95   C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
96   C10_CUDA_CHECK(cudaStreamSynchronize(stream));
97 #endif
98 }
99 
stream_synchronize(cudaStream_t stream)100 C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
101   if (C10_UNLIKELY(
102           warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
103     warn_or_error_on_sync();
104   }
105   const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
106   if (C10_UNLIKELY(interp)) {
107     (*interp)->trace_gpu_stream_synchronization(
108         c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
109   }
110   C10_CUDA_CHECK(cudaStreamSynchronize(stream));
111 }
112 
113 C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
114 C10_CUDA_API std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();
115 
116 } // namespace c10::cuda
117