1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"
17
18 #include "tensorflow/compiler/xla/util.h"
19
20 namespace xla {
21 namespace gpu {
22
23 namespace {
24
25 // Type traits to get CUDA complex types from std::complex<T>.
26 template <typename T>
27 struct CUDAComplexT {
28 typedef T type;
29 };
30 template <>
31 struct CUDAComplexT<std::complex<float>> {
32 typedef cuComplex type;
33 };
34 template <>
35 struct CUDAComplexT<std::complex<double>> {
36 typedef cuDoubleComplex type;
37 };
38
39 template <typename T>
ToDevicePointer(se::DeviceMemory<T> p)40 inline typename CUDAComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
41 return static_cast<typename CUDAComplexT<T>::type*>(p.opaque());
42 }
43
CUDABlasUpperLower(se::blas::UpperLower uplo)44 cublasFillMode_t CUDABlasUpperLower(se::blas::UpperLower uplo) {
45 switch (uplo) {
46 case se::blas::UpperLower::kUpper:
47 return CUBLAS_FILL_MODE_UPPER;
48 case se::blas::UpperLower::kLower:
49 return CUBLAS_FILL_MODE_LOWER;
50 default:
51 LOG(FATAL) << "Invalid value of blas::UpperLower.";
52 }
53 }
54
// Converts a cuSolver status to a Status.
//
// CUSOLVER_STATUS_SUCCESS maps to OK; each known error code maps to the
// closest XLA error category (FailedPrecondition, ResourceExhausted, ...)
// with a human-readable message. Unrecognized codes map to Unknown.
Status CusolverStatusToStatus(cusolverStatus_t status) {
  switch (status) {
    case CUSOLVER_STATUS_SUCCESS:
      return Status::OK();
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return FailedPrecondition("cuSolver has not been initialized");
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return ResourceExhausted("cuSolver allocation failed");
    case CUSOLVER_STATUS_INVALID_VALUE:
      return InvalidArgument("cuSolver invalid value error");
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return FailedPrecondition("cuSolver architecture mismatch error");
    case CUSOLVER_STATUS_MAPPING_ERROR:
      return Unknown("cuSolver mapping error");
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return Unknown("cuSolver execution failed");
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return Internal("cuSolver internal error");
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return Unimplemented("cuSolver matrix type not supported error");
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      return Unimplemented("cuSolver not supported error");
    case CUSOLVER_STATUS_ZERO_PIVOT:
      return InvalidArgument("cuSolver zero pivot error");
    case CUSOLVER_STATUS_INVALID_LICENSE:
      return FailedPrecondition("cuSolver invalid license error");
    default:
      return Unknown("Unknown cuSolver error");
  }
}
86
87 } // namespace
88
// Creates a cuSolver dense-solver handle and, if `stream` is non-null, binds
// all cuSolver work issued through this context to that stream.
StatusOr<CusolverContext> CusolverContext::Create(se::Stream* stream) {
  cusolverDnHandle_t handle;
  TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCreate(&handle)));
  // From here on `context` owns `handle`: if setting the stream below fails
  // and we return early, the context's destructor destroys the handle.
  CusolverContext context(stream, handle);

  if (stream) {
    // StreamExecutor really should just expose the Cuda stream to clients...
    const cudaStream_t* cuda_stream =
        CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(
            stream->implementation()->GpuStreamMemberHack()));
    TF_RETURN_IF_ERROR(
        CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream)));
  }

  // std::move is required: CusolverContext is move-only, and `context` is a
  // named local, so it would otherwise be copied into the StatusOr.
  return std::move(context);
}
105
CusolverContext(se::Stream * stream,cusolverDnHandle_t handle)106 CusolverContext::CusolverContext(se::Stream* stream, cusolverDnHandle_t handle)
107 : stream_(stream), handle_(handle) {}
108
CusolverContext(CusolverContext && other)109 CusolverContext::CusolverContext(CusolverContext&& other) {
110 handle_ = other.handle_;
111 stream_ = other.stream_;
112 other.handle_ = nullptr;
113 other.stream_ = nullptr;
114 }
115
// Move assignment via swap: exchanges state with `other`, so any handle we
// previously owned is destroyed when `other` is destroyed. This also makes
// self-move-assignment a harmless no-op.
CusolverContext& CusolverContext::operator=(CusolverContext&& other) {
  std::swap(handle_, other.handle_);
  std::swap(stream_, other.stream_);
  return *this;
}
121
~CusolverContext()122 CusolverContext::~CusolverContext() {
123 if (handle_) {
124 Status status = CusolverStatusToStatus(cusolverDnDestroy(handle_));
125 if (!status.ok()) {
126 LOG(ERROR) << "cusolverDnDestroy failed: " << status;
127 }
128 }
129 }
130
// X-macro: invokes `m` once per supported LAPACK element type, passing the
// C++ type and the single-letter cuSolver routine prefix (S/D/C/Z).
#define CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

// Token-pastes the cuSolver dense-API entry point for `method`, e.g.
// DN_SOLVER_FN(potrf, S) expands to cusolverDnSpotrf.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
135
136 // Note: NVidia have promised that it is safe to pass 'nullptr' as the argument
137 // buffers to cuSolver buffer size methods and this will be a documented
138 // behavior in a future cuSolver release.
PotrfBufferSize(PrimitiveType type,se::blas::UpperLower uplo,int n,int lda)139 StatusOr<int64> CusolverContext::PotrfBufferSize(PrimitiveType type,
140 se::blas::UpperLower uplo,
141 int n, int lda) {
142 int size = -1;
143 switch (type) {
144 case F32: {
145 TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnSpotrf_bufferSize(
146 handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
147 break;
148 }
149 case F64: {
150 TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnDpotrf_bufferSize(
151 handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
152 break;
153 }
154 case C64: {
155 TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCpotrf_bufferSize(
156 handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
157 break;
158 }
159 case C128: {
160 TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnZpotrf_bufferSize(
161 handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
162 break;
163 }
164 default:
165 return InvalidArgument("Invalid type for cholesky decomposition: %s",
166 PrimitiveType_Name(type));
167 }
168 return size;
169 }
170
// Defines the explicit specialization CusolverContext::Potrf<T>, which runs
// the in-place Cholesky factorization cusolverDn{S,D,C,Z}potrf on device
// memory. `workspace` must hold at least PotrfBufferSize(...) elements;
// `lapack_info` receives the routine's device-side info/status output.
#define POTRF_INSTANCE(T, type_prefix)                                    \
  template <>                                                             \
  Status CusolverContext::Potrf<T>(                                       \
      se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,   \
      se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
    return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(       \
        handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda,   \
        ToDevicePointer(workspace), workspace.ElementCount(),             \
        ToDevicePointer(lapack_info)));                                   \
  }

// Instantiates Potrf for float, double, complex<float>, and complex<double>.
CALL_LAPACK_TYPES(POTRF_INSTANCE);
183
184 } // namespace gpu
185 } // namespace xla
186