1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file wraps CUDA runtime calls with the DSO loader so that we don't need
// to link explicitly against the CUDA runtime library.
18
19 #include "cuda/include/cuda_runtime_api.h"
20 #include "tensorflow/stream_executor/lib/env.h"
21 #include "tensorflow/stream_executor/platform/dso_loader.h"
22
23 namespace {
GetDsoHandle()24 void* GetDsoHandle() {
25 static auto handle = []() -> void* {
26 auto handle_or =
27 stream_executor::internal::DsoLoader::GetCudaRuntimeDsoHandle();
28 if (!handle_or.ok()) return nullptr;
29 return handle_or.ValueOrDie();
30 }();
31 return handle;
32 }
33
34 template <typename T>
LoadSymbol(const char * symbol_name)35 T LoadSymbol(const char* symbol_name) {
36 void* symbol = nullptr;
37 auto env = stream_executor::port::Env::Default();
38 env->GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol).IgnoreError();
39 return reinterpret_cast<T>(symbol);
40 }
GetSymbolNotFoundError()41 cudaError_t GetSymbolNotFoundError() {
42 return cudaErrorSharedObjectSymbolNotFound;
43 }
// Textual counterpart of GetSymbolNotFoundError(), for stubs whose return type
// is a C string rather than a cudaError_t.
const char* GetSymbolNotFoundStrError() {
  static const char kSymbolNotFoundName[] =
      "cudaErrorSharedObjectSymbolNotFound";
  return kSymbolNotFoundName;
}
47 } // namespace
48
49 // Code below is auto-generated.
50 extern "C" {
cudaFree(void * devPtr)51 cudaError_t CUDART_CB cudaFree(void* devPtr) {
52 using FuncPtr = cudaError_t (*)(void* devPtr);
53 static auto func_ptr = LoadSymbol<FuncPtr>("cudaFree");
54 if (!func_ptr) return GetSymbolNotFoundError();
55 return func_ptr(devPtr);
56 }
57
cudaGetDevice(int * device)58 cudaError_t CUDART_CB cudaGetDevice(int* device) {
59 using FuncPtr = cudaError_t (*)(int* device);
60 static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDevice");
61 if (!func_ptr) return GetSymbolNotFoundError();
62 return func_ptr(device);
63 }
64
cudaGetDeviceProperties(cudaDeviceProp * prop,int device)65 cudaError_t CUDART_CB cudaGetDeviceProperties(cudaDeviceProp* prop,
66 int device) {
67 using FuncPtr = cudaError_t (*)(cudaDeviceProp * prop, int device);
68 static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceProperties");
69 if (!func_ptr) return GetSymbolNotFoundError();
70 return func_ptr(prop, device);
71 }
72
cudaGetErrorString(cudaError_t error)73 const char* CUDART_CB cudaGetErrorString(cudaError_t error) {
74 using FuncPtr = const char* (*)(cudaError_t error);
75 static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetErrorString");
76 if (!func_ptr) return GetSymbolNotFoundStrError();
77 return func_ptr(error);
78 }
79
cudaSetDevice(int device)80 cudaError_t CUDART_CB cudaSetDevice(int device) {
81 using FuncPtr = cudaError_t (*)(int device);
82 static auto func_ptr = LoadSymbol<FuncPtr>("cudaSetDevice");
83 if (!func_ptr) return GetSymbolNotFoundError();
84 return func_ptr(device);
85 }
86
cudaStreamAddCallback(cudaStream_t stream,cudaStreamCallback_t callback,void * userData,unsigned int flags)87 cudaError_t CUDART_CB cudaStreamAddCallback(cudaStream_t stream,
88 cudaStreamCallback_t callback,
89 void* userData,
90 unsigned int flags) {
91 using FuncPtr =
92 cudaError_t (*)(cudaStream_t stream, cudaStreamCallback_t callback,
93 void* userData, unsigned int flags);
94 static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamAddCallback");
95 if (!func_ptr) return GetSymbolNotFoundError();
96 return func_ptr(stream, callback, userData, flags);
97 }
98
cudaGetDeviceCount(int * count)99 cudaError_t CUDART_CB cudaGetDeviceCount(int* count) {
100 using FuncPtr = cudaError_t (*)(int* count);
101 static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceCount");
102 if (!func_ptr) return GetSymbolNotFoundError();
103 return func_ptr(count);
104 }
105
cudaPointerGetAttributes(struct cudaPointerAttributes * attributes,const void * ptr)106 cudaError_t CUDART_CB cudaPointerGetAttributes(
107 struct cudaPointerAttributes* attributes, const void* ptr) {
108 using FuncPtr = cudaError_t (*)(struct cudaPointerAttributes * attributes,
109 const void* ptr);
110 static auto func_ptr = LoadSymbol<FuncPtr>("cudaPointerGetAttributes");
111 if (!func_ptr) return GetSymbolNotFoundError();
112 return func_ptr(attributes, ptr);
113 }
114
cudaGetLastError()115 cudaError_t CUDART_CB cudaGetLastError() {
116 using FuncPtr = cudaError_t (*)();
117 static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetLastError");
118 if (!func_ptr) return GetSymbolNotFoundError();
119 return func_ptr();
120 }
121 } // extern "C"
122