• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file wraps hipsolver API calls with a DSO loader so that TF does not
// need to link explicitly against libhipsolver (PLATFORM_GOOGLE builds instead
// forward each call directly to the global hipsolver symbol). All TF hipsolver
// API usage should route through this wrapper.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSOLVER_WRAPPER_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSOLVER_WRAPPER_H_
22 
23 #include "rocm/rocm_config.h"
24 
25 #if TF_ROCM_VERSION >= 40500
26 
27 #include "rocm/include/hipsolver.h"
28 #include "tensorflow/stream_executor/lib/env.h"
29 #include "tensorflow/stream_executor/platform/dso_loader.h"
30 #include "tensorflow/stream_executor/platform/port.h"
31 
32 namespace tensorflow {
33 namespace wrap {
34 
35 #ifdef PLATFORM_GOOGLE
36 
// PLATFORM_GOOGLE build: no DSO loader lookup. Each wrapper is a variadic
// template that forwards its arguments unchanged to the global hipsolver
// function with the same name and returns whatever that function returns.
#define HIPSOLVER_API_WRAPPER(api_name)                                    \
  template <typename... ArgTs>                                             \
  auto api_name(ArgTs... call_args) -> decltype(::api_name(call_args...)) { \
    return ::api_name(call_args...);                                       \
  }
42 
43 #else
44 
// Two-level stringize so that the argument is macro-expanded before `#`
// turns it into a string literal.
#define TO_STR_(x) #x
#define TO_STR(x) TO_STR_(x)

// Defines `api_name(args...)` in the enclosing namespace as a forwarder to the
// hipsolver function of the same name. The function pointer is resolved from
// the hipsolver shared library the first time a given instantiation of the
// wrapper template is called, and is cached in a function-local static after
// that. The process aborts via CHECK if the library handle cannot be obtained
// or if the symbol is not found in it.
#define HIPSOLVER_API_WRAPPER(api_name)                                        \
  template <typename... Args>                                                  \
  auto api_name(Args... args) -> decltype(::api_name(args...)) {               \
    using FuncPtrT = std::add_pointer<decltype(::api_name)>::type;             \
    static FuncPtrT loaded = []() -> FuncPtrT {                                \
      static constexpr const char kName[] = TO_STR(api_name);                  \
      auto dso_handle =                                                        \
          stream_executor::internal::CachedDsoLoader::GetHipsolverDsoHandle(); \
      CHECK(dso_handle.ok()) << "could not load hipsolver lib while resolving " \
                             << kName << ": "                                  \
                             << dso_handle.status().error_message();           \
      void* f = nullptr;                                                       \
      auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary(    \
          dso_handle.ValueOrDie(), kName, &f);                                 \
      CHECK(s.ok()) << "could not find " << kName                              \
                    << " in hipsolver lib; dlerror: " << s.error_message();    \
      return reinterpret_cast<FuncPtrT>(f);                                    \
    }();                                                                       \
    return loaded(args...);                                                    \
  }
65 
66 #endif
67 
68 // clang-format off
69 #define FOREACH_HIPSOLVER_API(__macro)       \
70   __macro(hipsolverCreate)                   \
71   __macro(hipsolverDestroy)                  \
72   __macro(hipsolverSetStream)                \
73   __macro(hipsolverCgetrf)                   \
74   __macro(hipsolverCgetrf_bufferSize)        \
75   __macro(hipsolverDgetrf)                   \
76   __macro(hipsolverDgetrf_bufferSize)        \
77   __macro(hipsolverSgetrf)                   \
78   __macro(hipsolverSgetrf_bufferSize)        \
79   __macro(hipsolverZgetrf)                   \
80   __macro(hipsolverZgetrf_bufferSize)        \
81   __macro(hipsolverCgetrs)                   \
82   __macro(hipsolverCgetrs_bufferSize)        \
83   __macro(hipsolverDgetrs)                   \
84   __macro(hipsolverDgetrs_bufferSize)        \
85   __macro(hipsolverSgetrs)                   \
86   __macro(hipsolverSgetrs_bufferSize)        \
87   __macro(hipsolverZgetrs)                   \
88   __macro(hipsolverZgetrs_bufferSize)        \
89   __macro(hipsolverCpotrf)                   \
90   __macro(hipsolverCpotrf_bufferSize)        \
91   __macro(hipsolverDpotrf)                   \
92   __macro(hipsolverDpotrf_bufferSize)        \
93   __macro(hipsolverSpotrf)                   \
94   __macro(hipsolverSpotrf_bufferSize)        \
95   __macro(hipsolverZpotrf)                   \
96   __macro(hipsolverZpotrf_bufferSize)        \
97   __macro(hipsolverCpotrfBatched)            \
98   __macro(hipsolverCpotrfBatched_bufferSize) \
99   __macro(hipsolverDpotrfBatched)            \
100   __macro(hipsolverDpotrfBatched_bufferSize) \
101   __macro(hipsolverSpotrfBatched)            \
102   __macro(hipsolverSpotrfBatched_bufferSize) \
103   __macro(hipsolverZpotrfBatched)            \
104   __macro(hipsolverZpotrfBatched_bufferSize) \
105   __macro(hipsolverCgeqrf)                   \
106   __macro(hipsolverCgeqrf_bufferSize)        \
107   __macro(hipsolverDgeqrf)                   \
108   __macro(hipsolverDgeqrf_bufferSize)        \
109   __macro(hipsolverSgeqrf)                   \
110   __macro(hipsolverSgeqrf_bufferSize)        \
111   __macro(hipsolverZgeqrf)                   \
112   __macro(hipsolverZgeqrf_bufferSize)        \
113   __macro(hipsolverCunmqr)                   \
114   __macro(hipsolverCunmqr_bufferSize)        \
115   __macro(hipsolverZunmqr)                   \
116   __macro(hipsolverZunmqr_bufferSize)        \
117   __macro(hipsolverCungqr)                   \
118   __macro(hipsolverCungqr_bufferSize)        \
119   __macro(hipsolverZungqr)                   \
120   __macro(hipsolverZungqr_bufferSize)        \
121   __macro(hipsolverCheevd)                   \
122   __macro(hipsolverCheevd_bufferSize)        \
123   __macro(hipsolverZheevd)                   \
124   __macro(hipsolverZheevd_bufferSize)
125 // clang-format on
126 
127 FOREACH_HIPSOLVER_API(HIPSOLVER_API_WRAPPER)
128 
129 #undef TO_STR_
130 #undef TO_STR
131 #undef FOREACH_HIPSOLVER_API
132 #undef HIPSOLVER_API_WRAPPER
133 
134 }  // namespace wrap
135 }  // namespace tensorflow
136 
137 #endif  // TF_ROCM_VERSION
138 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSOLVER_WRAPPER_H_
139