/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 // This file wraps cuda runtime calls with dso loader so that we don't need to
17 // have explicit linking to libcuda.
18 
19 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
20 #include "tensorflow/stream_executor/lib/env.h"
21 #include "tensorflow/stream_executor/platform/dso_loader.h"
22 
23 namespace {
GetDsoHandle()24 void* GetDsoHandle() {
25   static auto handle = []() -> void* {
26     auto handle_or =
27         stream_executor::internal::DsoLoader::GetCudaRuntimeDsoHandle();
28     if (!handle_or.ok()) {
29       LOG(INFO) << "Ignore above cudart dlerror if you do not have a GPU set "
30                    "up on your machine.";
31       return nullptr;
32     }
33     return handle_or.ValueOrDie();
34   }();
35   return handle;
36 }
37 
38 template <typename T>
LoadSymbol(const char * symbol_name)39 T LoadSymbol(const char* symbol_name) {
40   void* symbol = nullptr;
41   auto env = stream_executor::port::Env::Default();
42   env->GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol).IgnoreError();
43   return reinterpret_cast<T>(symbol);
44 }
GetSymbolNotFoundError()45 cudaError_t GetSymbolNotFoundError() {
46   return cudaErrorSharedObjectSymbolNotFound;
47 }
48 }  // namespace
49 
// Neutralize two nvcc-header markers that are presumably referenced by the
// generated .inc stubs included below — TODO confirm against the .inc files:
// `__dv(v)` (default-value annotation) and `__CUDA_DEPRECATED` both expand to
// nothing here so the declarations compile as plain C++.
#define __dv(v)
#define __CUDA_DEPRECATED

// A bunch of new symbols were introduced in version 10
// Pick the generated stub set matching the CUDART_VERSION this build compiles
// against; each .inc covers the API surface of one minor runtime release.
#if CUDART_VERSION < 10000
#include "tensorflow/stream_executor/cuda/cuda_runtime_9_0.inc"
#elif CUDART_VERSION < 10010
#include "tensorflow/stream_executor/cuda/cuda_runtime_10_0.inc"
#elif CUDART_VERSION < 10020
#include "tensorflow/stream_executor/cuda/cuda_runtime_10_1.inc"
#elif CUDART_VERSION < 11000
#include "tensorflow/stream_executor/cuda/cuda_runtime_10_2.inc"
#elif CUDART_VERSION < 11020
#include "tensorflow/stream_executor/cuda/cuda_runtime_11_0.inc"
#else
#include "tensorflow/stream_executor/cuda/cuda_runtime_11_2.inc"
#endif
// Undo the helper macros so they cannot leak into later translation units.
#undef __dv
#undef __CUDA_DEPRECATED
69 
70 extern "C" {
71 
72 // Following are private symbols in libcudart that got inserted by nvcc.
__cudaRegisterFunction(void ** fatCubinHandle,const char * hostFun,char * deviceFun,const char * deviceName,int thread_limit,uint3 * tid,uint3 * bid,dim3 * bDim,dim3 * gDim,int * wSize)73 extern void CUDARTAPI __cudaRegisterFunction(
74     void **fatCubinHandle, const char *hostFun, char *deviceFun,
75     const char *deviceName, int thread_limit, uint3 *tid, uint3 *bid,
76     dim3 *bDim, dim3 *gDim, int *wSize) {
77   using FuncPtr = void(CUDARTAPI *)(void **fatCubinHandle, const char *hostFun,
78                                     char *deviceFun, const char *deviceName,
79                                     int thread_limit, uint3 *tid, uint3 *bid,
80                                     dim3 *bDim, dim3 *gDim, int *wSize);
81   static auto func_ptr = LoadSymbol<FuncPtr>("__cudaRegisterFunction");
82   if (!func_ptr) return;
83   func_ptr(fatCubinHandle, hostFun, deviceFun, deviceName, thread_limit, tid,
84            bid, bDim, gDim, wSize);
85 }
86 
__cudaUnregisterFatBinary(void ** fatCubinHandle)87 extern void CUDARTAPI __cudaUnregisterFatBinary(void **fatCubinHandle) {
88   using FuncPtr = void(CUDARTAPI *)(void **fatCubinHandle);
89   static auto func_ptr = LoadSymbol<FuncPtr>("__cudaUnregisterFatBinary");
90   if (!func_ptr) return;
91   func_ptr(fatCubinHandle);
92 }
93 
__cudaRegisterVar(void ** fatCubinHandle,char * hostVar,char * deviceAddress,const char * deviceName,int ext,size_t size,int constant,int global)94 extern void CUDARTAPI __cudaRegisterVar(void **fatCubinHandle, char *hostVar,
95                                         char *deviceAddress,
96                                         const char *deviceName, int ext,
97                                         size_t size, int constant, int global) {
98   using FuncPtr = void(CUDARTAPI *)(
99       void **fatCubinHandle, char *hostVar, char *deviceAddress,
100       const char *deviceName, int ext, size_t size, int constant, int global);
101   static auto func_ptr = LoadSymbol<FuncPtr>("__cudaRegisterVar");
102   if (!func_ptr) return;
103   func_ptr(fatCubinHandle, hostVar, deviceAddress, deviceName, ext, size,
104            constant, global);
105 }
106 
__cudaRegisterFatBinary(void * fatCubin)107 extern void **CUDARTAPI __cudaRegisterFatBinary(void *fatCubin) {
108   using FuncPtr = void **(CUDARTAPI *)(void *fatCubin);
109   static auto func_ptr = LoadSymbol<FuncPtr>("__cudaRegisterFatBinary");
110   if (!func_ptr) return nullptr;
111   return (void **)func_ptr(fatCubin);
112 }
113 
__cudaPopCallConfiguration(dim3 * gridDim,dim3 * blockDim,size_t * sharedMem,void * stream)114 extern cudaError_t CUDARTAPI __cudaPopCallConfiguration(dim3 *gridDim,
115                                                         dim3 *blockDim,
116                                                         size_t *sharedMem,
117                                                         void *stream) {
118   using FuncPtr = cudaError_t(CUDARTAPI *)(dim3 * gridDim, dim3 * blockDim,
119                                            size_t * sharedMem, void *stream);
120   static auto func_ptr = LoadSymbol<FuncPtr>("__cudaPopCallConfiguration");
121   if (!func_ptr) return GetSymbolNotFoundError();
122   return func_ptr(gridDim, blockDim, sharedMem, stream);
123 }
124 
__cudaPushCallConfiguration(dim3 gridDim,dim3 blockDim,size_t sharedMem=0,void * stream=0)125 extern __host__ __device__ unsigned CUDARTAPI __cudaPushCallConfiguration(
126     dim3 gridDim, dim3 blockDim, size_t sharedMem = 0, void *stream = 0) {
127   using FuncPtr = unsigned(CUDARTAPI *)(dim3 gridDim, dim3 blockDim,
128                                         size_t sharedMem, void *stream);
129   static auto func_ptr = LoadSymbol<FuncPtr>("__cudaPushCallConfiguration");
130   if (!func_ptr) return 0;
131   return func_ptr(gridDim, blockDim, sharedMem, stream);
132 }
133 
__cudaInitModule(void ** fatCubinHandle)134 extern char CUDARTAPI __cudaInitModule(void **fatCubinHandle) {
135   using FuncPtr = char(CUDARTAPI *)(void **fatCubinHandle);
136   static auto func_ptr = LoadSymbol<FuncPtr>("__cudaInitModule");
137   if (!func_ptr) return 0;
138   return func_ptr(fatCubinHandle);
139 }
140 
141 #if CUDART_VERSION >= 10010
__cudaRegisterFatBinaryEnd(void ** fatCubinHandle)142 extern void CUDARTAPI __cudaRegisterFatBinaryEnd(void **fatCubinHandle) {
143   using FuncPtr = void(CUDARTAPI *)(void **fatCubinHandle);
144   static auto func_ptr = LoadSymbol<FuncPtr>("__cudaRegisterFatBinaryEnd");
145   if (!func_ptr) return;
146   func_ptr(fatCubinHandle);
147 }
148 #endif
149 }  // extern "C"
150