/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file wraps CUDA runtime (cudart) calls with the DSO loader so that we
// do not need to link explicitly against libcudart.
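// Each stub looks up the real entry point in the runtime library the first
// time it is called and forwards to it, or fails gracefully when no CUDA
// runtime is available on the machine.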

#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"

namespace {
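// Loads the CUDA runtime DSO at most once and caches the handle. Returns
// nullptr when libcudart cannot be found, e.g. on machines without a GPU.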
void* GetDsoHandle() {
  static auto handle = []() -> void* {
    auto handle_or =
        stream_executor::internal::DsoLoader::GetCudaRuntimeDsoHandle();
    if (!handle_or.ok()) {
      LOG(INFO) << "Ignore above cudart dlerror if you do not have a GPU set "
                   "up on your machine.";
      return nullptr;
    }
    return handle_or.ValueOrDie();
  }();
  return handle;
}

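// Resolves |symbol_name| from the CUDA runtime DSO and returns it as a T, or
// nullptr if either the DSO or the symbol is unavailable.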
template <typename T>
T LoadSymbol(const char* symbol_name) {
  void* symbol = nullptr;
  auto env = stream_executor::port::Env::Default();
  env->GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol).IgnoreError();
  return reinterpret_cast<T>(symbol);
}
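// Error code returned by stubs with a cudaError_t result when the requested
// entry point cannot be resolved.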
cudaError_t GetSymbolNotFoundError() {
  return cudaErrorSharedObjectSymbolNotFound;
}
}  // namespace

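// The generated .inc files below annotate parameters with __dv(<default>) and
// deprecated entry points with __CUDA_DEPRECATED; both expand to nothing here.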
#define __dv(v)
#define __CUDA_DEPRECATED

// New symbols were added in CUDA 10 and in each release after it, so select
// the wrapper .inc that matches the cudart version this build targets.
#if CUDART_VERSION < 10000
#include "tensorflow/stream_executor/cuda/cuda_runtime_9_0.inc"
#elif CUDART_VERSION < 10010
#include "tensorflow/stream_executor/cuda/cuda_runtime_10_0.inc"
#elif CUDART_VERSION < 10020
#include "tensorflow/stream_executor/cuda/cuda_runtime_10_1.inc"
#elif CUDART_VERSION < 11000
#include "tensorflow/stream_executor/cuda/cuda_runtime_10_2.inc"
#elif CUDART_VERSION < 11020
#include "tensorflow/stream_executor/cuda/cuda_runtime_11_0.inc"
#else
#include "tensorflow/stream_executor/cuda/cuda_runtime_11_2.inc"
#endif
#undef __dv
#undef __CUDA_DEPRECATED

extern "C" {

// The following are private libcudart symbols that nvcc-generated host code
// references when registering fatbinaries, kernels, and device variables.
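// When the runtime cannot be loaded, each stub below becomes a no-op, returns
// a default value, or reports cudaErrorSharedObjectSymbolNotFound, so that
// CPU-only processes can still start.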
extern void CUDARTAPI __cudaRegisterFunction(
    void **fatCubinHandle, const char *hostFun, char *deviceFun,
    const char *deviceName, int thread_limit, uint3 *tid, uint3 *bid,
    dim3 *bDim, dim3 *gDim, int *wSize) {
  using FuncPtr = void(CUDARTAPI *)(void **fatCubinHandle, const char *hostFun,
                                    char *deviceFun, const char *deviceName,
                                    int thread_limit, uint3 *tid, uint3 *bid,
                                    dim3 *bDim, dim3 *gDim, int *wSize);
  static auto func_ptr = LoadSymbol<FuncPtr>("__cudaRegisterFunction");
  if (!func_ptr) return;
  func_ptr(fatCubinHandle, hostFun, deviceFun, deviceName, thread_limit, tid,
           bid, bDim, gDim, wSize);
}

extern void CUDARTAPI __cudaUnregisterFatBinary(void **fatCubinHandle) {
  using FuncPtr = void(CUDARTAPI *)(void **fatCubinHandle);
  static auto func_ptr = LoadSymbol<FuncPtr>("__cudaUnregisterFatBinary");
  if (!func_ptr) return;
  func_ptr(fatCubinHandle);
}

extern void CUDARTAPI __cudaRegisterVar(void **fatCubinHandle, char *hostVar,
                                        char *deviceAddress,
                                        const char *deviceName, int ext,
                                        size_t size, int constant, int global) {
  using FuncPtr = void(CUDARTAPI *)(
      void **fatCubinHandle, char *hostVar, char *deviceAddress,
      const char *deviceName, int ext, size_t size, int constant, int global);
  static auto func_ptr = LoadSymbol<FuncPtr>("__cudaRegisterVar");
  if (!func_ptr) return;
  func_ptr(fatCubinHandle, hostVar, deviceAddress, deviceName, ext, size,
           constant, global);
}

extern void **CUDARTAPI __cudaRegisterFatBinary(void *fatCubin) {
  using FuncPtr = void **(CUDARTAPI *)(void *fatCubin);
  static auto func_ptr = LoadSymbol<FuncPtr>("__cudaRegisterFatBinary");
  if (!func_ptr) return nullptr;
  return (void **)func_ptr(fatCubin);
}

extern cudaError_t CUDARTAPI __cudaPopCallConfiguration(dim3 *gridDim,
                                                        dim3 *blockDim,
                                                        size_t *sharedMem,
                                                        void *stream) {
  using FuncPtr = cudaError_t(CUDARTAPI *)(dim3 *gridDim, dim3 *blockDim,
                                           size_t *sharedMem, void *stream);
  static auto func_ptr = LoadSymbol<FuncPtr>("__cudaPopCallConfiguration");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(gridDim, blockDim, sharedMem, stream);
}

extern __host__ __device__ unsigned CUDARTAPI __cudaPushCallConfiguration(
    dim3 gridDim, dim3 blockDim, size_t sharedMem = 0, void *stream = 0) {
  using FuncPtr = unsigned(CUDARTAPI *)(dim3 gridDim, dim3 blockDim,
                                        size_t sharedMem, void *stream);
  static auto func_ptr = LoadSymbol<FuncPtr>("__cudaPushCallConfiguration");
  if (!func_ptr) return 0;
  return func_ptr(gridDim, blockDim, sharedMem, stream);
}

extern char CUDARTAPI __cudaInitModule(void **fatCubinHandle) {
  using FuncPtr = char(CUDARTAPI *)(void **fatCubinHandle);
  static auto func_ptr = LoadSymbol<FuncPtr>("__cudaInitModule");
  if (!func_ptr) return 0;
  return func_ptr(fatCubinHandle);
}

#if CUDART_VERSION >= 10010
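// __cudaRegisterFatBinaryEnd only exists in CUDA 10.1 and newer runtimes.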
extern void CUDARTAPI __cudaRegisterFatBinaryEnd(void **fatCubinHandle) {
  using FuncPtr = void(CUDARTAPI *)(void **fatCubinHandle);
  static auto func_ptr = LoadSymbol<FuncPtr>("__cudaRegisterFatBinaryEnd");
  if (!func_ptr) return;
  func_ptr(fatCubinHandle);
}
#endif
}  // extern "C"