• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file wraps CUDA runtime calls with the DSO loader so that we don't
// need to link explicitly against libcudart (the CUDA runtime library).
18 
19 #include "cuda/include/cuda_runtime_api.h"
20 #include "tensorflow/stream_executor/lib/env.h"
21 #include "tensorflow/stream_executor/platform/dso_loader.h"
22 
23 namespace {
GetDsoHandle()24 void* GetDsoHandle() {
25   static auto handle = []() -> void* {
26     auto handle_or =
27         stream_executor::internal::DsoLoader::GetCudaRuntimeDsoHandle();
28     if (!handle_or.ok()) return nullptr;
29     return handle_or.ValueOrDie();
30   }();
31   return handle;
32 }
33 
34 template <typename T>
LoadSymbol(const char * symbol_name)35 T LoadSymbol(const char* symbol_name) {
36   void* symbol = nullptr;
37   auto env = stream_executor::port::Env::Default();
38   env->GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol).IgnoreError();
39   return reinterpret_cast<T>(symbol);
40 }
GetSymbolNotFoundError()41 cudaError_t GetSymbolNotFoundError() {
42   return cudaErrorSharedObjectSymbolNotFound;
43 }
GetSymbolNotFoundStrError()44 const char* GetSymbolNotFoundStrError() {
45   return "cudaErrorSharedObjectSymbolNotFound";
46 }
47 }  // namespace
48 
49 // Code below is auto-generated.
50 extern "C" {
cudaFree(void * devPtr)51 cudaError_t CUDART_CB cudaFree(void* devPtr) {
52   using FuncPtr = cudaError_t (*)(void* devPtr);
53   static auto func_ptr = LoadSymbol<FuncPtr>("cudaFree");
54   if (!func_ptr) return GetSymbolNotFoundError();
55   return func_ptr(devPtr);
56 }
57 
cudaGetDevice(int * device)58 cudaError_t CUDART_CB cudaGetDevice(int* device) {
59   using FuncPtr = cudaError_t (*)(int* device);
60   static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDevice");
61   if (!func_ptr) return GetSymbolNotFoundError();
62   return func_ptr(device);
63 }
64 
cudaGetDeviceProperties(cudaDeviceProp * prop,int device)65 cudaError_t CUDART_CB cudaGetDeviceProperties(cudaDeviceProp* prop,
66                                               int device) {
67   using FuncPtr = cudaError_t (*)(cudaDeviceProp * prop, int device);
68   static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceProperties");
69   if (!func_ptr) return GetSymbolNotFoundError();
70   return func_ptr(prop, device);
71 }
72 
cudaGetErrorString(cudaError_t error)73 const char* CUDART_CB cudaGetErrorString(cudaError_t error) {
74   using FuncPtr = const char* (*)(cudaError_t error);
75   static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetErrorString");
76   if (!func_ptr) return GetSymbolNotFoundStrError();
77   return func_ptr(error);
78 }
79 
cudaSetDevice(int device)80 cudaError_t CUDART_CB cudaSetDevice(int device) {
81   using FuncPtr = cudaError_t (*)(int device);
82   static auto func_ptr = LoadSymbol<FuncPtr>("cudaSetDevice");
83   if (!func_ptr) return GetSymbolNotFoundError();
84   return func_ptr(device);
85 }
86 
cudaStreamAddCallback(cudaStream_t stream,cudaStreamCallback_t callback,void * userData,unsigned int flags)87 cudaError_t CUDART_CB cudaStreamAddCallback(cudaStream_t stream,
88                                             cudaStreamCallback_t callback,
89                                             void* userData,
90                                             unsigned int flags) {
91   using FuncPtr =
92       cudaError_t (*)(cudaStream_t stream, cudaStreamCallback_t callback,
93                       void* userData, unsigned int flags);
94   static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamAddCallback");
95   if (!func_ptr) return GetSymbolNotFoundError();
96   return func_ptr(stream, callback, userData, flags);
97 }
98 
cudaGetDeviceCount(int * count)99 cudaError_t CUDART_CB cudaGetDeviceCount(int* count) {
100   using FuncPtr = cudaError_t (*)(int* count);
101   static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceCount");
102   if (!func_ptr) return GetSymbolNotFoundError();
103   return func_ptr(count);
104 }
105 
cudaPointerGetAttributes(struct cudaPointerAttributes * attributes,const void * ptr)106 cudaError_t CUDART_CB cudaPointerGetAttributes(
107     struct cudaPointerAttributes* attributes, const void* ptr) {
108   using FuncPtr = cudaError_t (*)(struct cudaPointerAttributes * attributes,
109                                   const void* ptr);
110   static auto func_ptr = LoadSymbol<FuncPtr>("cudaPointerGetAttributes");
111   if (!func_ptr) return GetSymbolNotFoundError();
112   return func_ptr(attributes, ptr);
113 }
114 
cudaGetLastError()115 cudaError_t CUDART_CB cudaGetLastError() {
116   using FuncPtr = cudaError_t (*)();
117   static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetLastError");
118   if (!func_ptr) return GetSymbolNotFoundError();
119   return func_ptr();
120 }
121 }  // extern "C"
122