/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"

#include "tensorflow/compiler/xla/util.h"
19
#if TENSORFLOW_USE_ROCM

// Thin call-through shims for the rocBLAS entry points used below.  In Google
// builds rocBLAS is linked in directly, so the shim forwards straight to the
// real symbol; in open-source builds the symbol is resolved lazily from the
// rocBLAS DSO the first time the wrapper is invoked.
namespace rocblas_wrap {

using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
using tensorflow::Env;

#ifdef PLATFORM_GOOGLE
// Statically linked case: the functor simply forwards to ::__name.
#define ROCBLAS_API_WRAPPER(__name)            \
  struct WrapperShim__##__name {               \
    static const char* kName;                  \
    template <typename... Args>                \
    rocblas_status operator()(Args... args) {  \
      return ::__name(args...);                \
    }                                          \
  } __name;                                    \
  const char* WrapperShim__##__name::kName = #__name;

#else

// Dynamic-loading case: look the symbol up in the rocBLAS DSO on first use
// and cache the resolved function pointer in a function-local static.
#define ROCBLAS_API_WRAPPER(__name)                                        \
  struct DynLoadShim__##__name {                                           \
    static const char* kName;                                              \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;           \
    static void* GetDsoHandle() {                                          \
      auto s = GetRocblasDsoHandle();                                      \
      return s.ValueOrDie();                                               \
    }                                                                      \
    static FuncPtrT LoadOrDie() {                                          \
      void* f;                                                             \
      auto s =                                                             \
          Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \
      CHECK(s.ok()) << "could not find " << kName                          \
                    << " in rocblas DSO; dlerror: " << s.error_message();  \
      return reinterpret_cast<FuncPtrT>(f);                                \
    }                                                                      \
    static FuncPtrT DynLoad() {                                            \
      static FuncPtrT f = LoadOrDie();                                     \
      return f;                                                            \
    }                                                                      \
    template <typename... Args>                                            \
    rocblas_status operator()(Args... args) {                              \
      return DynLoad()(args...);                                           \
    }                                                                      \
  } __name;                                                                \
  const char* DynLoadShim__##__name::kName = #__name;

#endif

// The handle-management functions this file needs from rocBLAS.
// clang-format off
#define FOREACH_ROCBLAS_API(__macro) \
  __macro(rocblas_create_handle)     \
  __macro(rocblas_destroy_handle)    \
  __macro(rocblas_set_stream)
// clang-format on

FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER)

}  // namespace rocblas_wrap

#endif  // TENSORFLOW_USE_ROCM
81
82 namespace xla {
83 namespace gpu {
84
85 namespace {
86
87 // Type traits to get CUDA complex types from std::complex<T>.
88 template <typename T>
89 struct GpuComplexT {
90 typedef T type;
91 };
92 #if !defined(TENSORFLOW_USE_ROCM)
93
94 using gpuStream_t = cudaStream_t;
95
96 #define GpuSolverCreate cusolverDnCreate
97 #define GpuSolverSetStream cusolverDnSetStream
98 #define GpuSolverDestroy cusolverDnDestroy
99
100 template <>
101 struct GpuComplexT<std::complex<float>> {
102 typedef cuComplex type;
103 };
104 template <>
105 struct GpuComplexT<std::complex<double>> {
106 typedef cuDoubleComplex type;
107 };
108
109 #else
110
111 using gpuStream_t = hipStream_t;
112
113 #define GpuSolverCreate rocblas_wrap::rocblas_create_handle
114 #define GpuSolverSetStream rocblas_wrap::rocblas_set_stream
115 #define GpuSolverDestroy rocblas_wrap::rocblas_destroy_handle
116
117 template <>
118 struct GpuComplexT<std::complex<float>> {
119 typedef rocblas_float_complex type;
120 };
121 template <>
122 struct GpuComplexT<std::complex<double>> {
123 typedef rocblas_double_complex type;
124 };
125 #endif
126
127 template <typename T>
ToDevicePointer(se::DeviceMemory<T> p)128 inline typename GpuComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
129 return static_cast<typename GpuComplexT<T>::type*>(p.opaque());
130 }
131
132 #if !defined(TENSORFLOW_USE_ROCM)
GpuBlasUpperLower(se::blas::UpperLower uplo)133 cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) {
134 switch (uplo) {
135 case se::blas::UpperLower::kUpper:
136 return CUBLAS_FILL_MODE_UPPER;
137 case se::blas::UpperLower::kLower:
138 return CUBLAS_FILL_MODE_LOWER;
139 default:
140 LOG(FATAL) << "Invalid value of blas::UpperLower.";
141 }
142 }
143
144 // Converts a cuSolver status to a Status.
ConvertStatus(cusolverStatus_t status)145 Status ConvertStatus(cusolverStatus_t status) {
146 switch (status) {
147 case CUSOLVER_STATUS_SUCCESS:
148 return Status::OK();
149 case CUSOLVER_STATUS_NOT_INITIALIZED:
150 return FailedPrecondition("cuSolver has not been initialized");
151 case CUSOLVER_STATUS_ALLOC_FAILED:
152 return ResourceExhausted("cuSolver allocation failed");
153 case CUSOLVER_STATUS_INVALID_VALUE:
154 return InvalidArgument("cuSolver invalid value error");
155 case CUSOLVER_STATUS_ARCH_MISMATCH:
156 return FailedPrecondition("cuSolver architecture mismatch error");
157 case CUSOLVER_STATUS_MAPPING_ERROR:
158 return Unknown("cuSolver mapping error");
159 case CUSOLVER_STATUS_EXECUTION_FAILED:
160 return Unknown("cuSolver execution failed");
161 case CUSOLVER_STATUS_INTERNAL_ERROR:
162 return Internal("cuSolver internal error");
163 case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
164 return Unimplemented("cuSolver matrix type not supported error");
165 case CUSOLVER_STATUS_NOT_SUPPORTED:
166 return Unimplemented("cuSolver not supported error");
167 case CUSOLVER_STATUS_ZERO_PIVOT:
168 return InvalidArgument("cuSolver zero pivot error");
169 case CUSOLVER_STATUS_INVALID_LICENSE:
170 return FailedPrecondition("cuSolver invalid license error");
171 default:
172 return Unknown("Unknown cuSolver error");
173 }
174 }
175 #else
GpuBlasUpperLower(se::blas::UpperLower uplo)176 rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) {
177 switch (uplo) {
178 case se::blas::UpperLower::kUpper:
179 return rocblas_fill_upper;
180 case se::blas::UpperLower::kLower:
181 return rocblas_fill_lower;
182 default:
183 LOG(FATAL) << "Invalid value of blas::UpperLower.";
184 }
185 }
186
187 // Converts a cuSolver status to a Status.
ConvertStatus(rocblas_status status)188 Status ConvertStatus(rocblas_status status) {
189 switch (status) {
190 case rocblas_status_success:
191 return Status::OK();
192 case rocblas_status_invalid_handle:
193 return FailedPrecondition("handle not initialized, invalid or null");
194 case rocblas_status_not_implemented:
195 return Internal("function is not implemented");
196 case rocblas_status_invalid_pointer:
197 return InvalidArgument("invalid pointer argument");
198 case rocblas_status_invalid_size:
199 return InvalidArgument("invalid size argument");
200 case rocblas_status_memory_error:
201 return Internal("failed internal memory allocation, copy or dealloc");
202 case rocblas_status_internal_error:
203 return Internal("other internal library failure");
204 case rocblas_status_perf_degraded:
205 return Internal("performance degraded due to low device memory");
206 case rocblas_status_size_query_mismatch:
207 return Unknown("unmatched start/stop size query");
208 case rocblas_status_size_increased:
209 return Unknown("queried device memory size increased");
210 case rocblas_status_size_unchanged:
211 return Unknown("queried device memory size unchanged");
212 case rocblas_status_invalid_value:
213 return InvalidArgument("passed argument not valid");
214 case rocblas_status_continue:
215 return Unknown("nothing preventing function to proceed");
216 default:
217 return Unknown("Unknown rocsolver error");
218 }
219 }
220 #endif
221
222 } // namespace
223
Create(se::Stream * stream)224 StatusOr<GpuSolverContext> GpuSolverContext::Create(se::Stream* stream) {
225 gpusolverHandle_t handle;
226 TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverCreate(&handle)));
227 GpuSolverContext context(stream, handle);
228
229 if (stream) {
230 // StreamExecutor really should just expose the Cuda stream to clients...
231 const gpuStream_t* gpu_stream =
232 CHECK_NOTNULL(reinterpret_cast<const gpuStream_t*>(
233 stream->implementation()->GpuStreamMemberHack()));
234 TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverSetStream(handle, *gpu_stream)));
235 }
236
237 return std::move(context);
238 }
239
GpuSolverContext(se::Stream * stream,gpusolverHandle_t handle)240 GpuSolverContext::GpuSolverContext(se::Stream* stream, gpusolverHandle_t handle)
241 : stream_(stream), handle_(handle) {}
242
GpuSolverContext(GpuSolverContext && other)243 GpuSolverContext::GpuSolverContext(GpuSolverContext&& other) {
244 handle_ = other.handle_;
245 stream_ = other.stream_;
246 other.handle_ = nullptr;
247 other.stream_ = nullptr;
248 }
249
operator =(GpuSolverContext && other)250 GpuSolverContext& GpuSolverContext::operator=(GpuSolverContext&& other) {
251 std::swap(handle_, other.handle_);
252 std::swap(stream_, other.stream_);
253 return *this;
254 }
255
~GpuSolverContext()256 GpuSolverContext::~GpuSolverContext() {
257 if (handle_) {
258 Status status = ConvertStatus(GpuSolverDestroy(handle_));
259 if (!status.ok()) {
260 LOG(ERROR) << "GpuSolverDestroy failed: " << status;
261 }
262 }
263 }
264
#if !defined(TENSORFLOW_USE_ROCM)
// Expands m(T, prefix) once per supported LAPACK scalar type, where `prefix`
// is the type letter used in the library's function names (upper case for
// cuSolver, lower case for rocSOLVER).
#define CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

// Builds the library-specific dense-solver function name, e.g.
// DN_SOLVER_FN(potrf, S) expands to cusolverDnSpotrf.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method

#else

#define CALL_LAPACK_TYPES(m) \
  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)

#define DN_SOLVER_FN(method, type_prefix) \
  tensorflow::wrap::rocsolver_##type_prefix##method

#endif
280 // Note: NVidia have promised that it is safe to pass 'nullptr' as the argument
281 // buffers to cuSolver buffer size methods and this will be a documented
282 // behavior in a future cuSolver release.
PotrfBufferSize(PrimitiveType type,se::blas::UpperLower uplo,int n,int lda)283 StatusOr<int64> GpuSolverContext::PotrfBufferSize(PrimitiveType type,
284 se::blas::UpperLower uplo,
285 int n, int lda) {
286 #if !defined(TENSORFLOW_USE_ROCM)
287 int size = -1;
288 switch (type) {
289 case F32: {
290 TF_RETURN_IF_ERROR(ConvertStatus(cusolverDnSpotrf_bufferSize(
291 handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
292 break;
293 }
294 case F64: {
295 TF_RETURN_IF_ERROR(ConvertStatus(cusolverDnDpotrf_bufferSize(
296 handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
297 break;
298 }
299 case C64: {
300 TF_RETURN_IF_ERROR(ConvertStatus(cusolverDnCpotrf_bufferSize(
301 handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
302 break;
303 }
304 case C128: {
305 TF_RETURN_IF_ERROR(ConvertStatus(cusolverDnZpotrf_bufferSize(
306 handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
307 break;
308 }
309 default:
310 return InvalidArgument("Invalid type for cholesky decomposition: %s",
311 PrimitiveType_Name(type));
312 }
313 return size;
314 #else
315 return 0;
316 #endif
317 }
318
319 #if !defined(TENSORFLOW_USE_ROCM)
320 #define POTRF_INSTANCE(T, type_prefix) \
321 template <> \
322 Status GpuSolverContext::Potrf<T>( \
323 se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \
324 se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
325 return ConvertStatus(DN_SOLVER_FN(potrf, type_prefix)( \
326 handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(A), lda, \
327 ToDevicePointer(workspace), workspace.ElementCount(), \
328 ToDevicePointer(lapack_info))); \
329 }
330 #else
331 #define POTRF_INSTANCE(T, type_prefix) \
332 template <> \
333 Status GpuSolverContext::Potrf<T>( \
334 se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \
335 se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
336 return ConvertStatus(DN_SOLVER_FN(potrf, type_prefix)( \
337 handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(A), lda, \
338 ToDevicePointer(lapack_info))); \
339 }
340 #endif
341
342 CALL_LAPACK_TYPES(POTRF_INSTANCE);
343
344 } // namespace gpu
345 } // namespace xla
346