/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"

#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

namespace {

// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};

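// Reinterprets a DeviceMemory buffer as a raw device pointer of the
// corresponding CUDA type (e.g. DeviceMemory<std::complex<float>> becomes
// cuComplex*).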
template <typename T>
inline typename CUDAComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
  return static_cast<typename CUDAComplexT<T>::type*>(p.opaque());
}

// Converts a StreamExecutor blas::UpperLower to the cuBLAS fill mode enum.
cublasFillMode_t CUDABlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return CUBLAS_FILL_MODE_UPPER;
    case se::blas::UpperLower::kLower:
      return CUBLAS_FILL_MODE_LOWER;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a cuSolver status to a Status.
Status CusolverStatusToStatus(cusolverStatus_t status) {
  switch (status) {
    case CUSOLVER_STATUS_SUCCESS:
      return Status::OK();
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return FailedPrecondition("cuSolver has not been initialized");
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return ResourceExhausted("cuSolver allocation failed");
    case CUSOLVER_STATUS_INVALID_VALUE:
      return InvalidArgument("cuSolver invalid value error");
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return FailedPrecondition("cuSolver architecture mismatch error");
    case CUSOLVER_STATUS_MAPPING_ERROR:
      return Unknown("cuSolver mapping error");
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return Unknown("cuSolver execution failed");
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return Internal("cuSolver internal error");
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return Unimplemented("cuSolver matrix type not supported error");
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      return Unimplemented("cuSolver not supported error");
    case CUSOLVER_STATUS_ZERO_PIVOT:
      return InvalidArgument("cuSolver zero pivot error");
    case CUSOLVER_STATUS_INVALID_LICENSE:
      return FailedPrecondition("cuSolver invalid license error");
    default:
      return Unknown("Unknown cuSolver error");
  }
}

}  // namespace

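// Creates a cuSolver handle and, if `stream` is non-null, directs all work
// issued through the handle onto that stream. Sketch of intended usage:
//
//   TF_ASSIGN_OR_RETURN(CusolverContext context,
//                       CusolverContext::Create(stream));
//   TF_RETURN_IF_ERROR(context.Potrf<float>(uplo, n, a, lda, info, work));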
StatusOr<CusolverContext> CusolverContext::Create(se::Stream* stream) {
  cusolverDnHandle_t handle;
  TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCreate(&handle)));
  CusolverContext context(stream, handle);

  if (stream) {
    // StreamExecutor really should just expose the Cuda stream to clients...
    const cudaStream_t* cuda_stream =
        CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(
            stream->implementation()->GpuStreamMemberHack()));
    TF_RETURN_IF_ERROR(
        CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream)));
  }

  return std::move(context);
}

CusolverContext::CusolverContext(se::Stream* stream, cusolverDnHandle_t handle)
    : stream_(stream), handle_(handle) {}

CusolverContext::CusolverContext(CusolverContext&& other) {
  handle_ = other.handle_;
  stream_ = other.stream_;
  other.handle_ = nullptr;
  other.stream_ = nullptr;
}

CusolverContext& CusolverContext::operator=(CusolverContext&& other) {
  // Swap members so that `other`'s destructor releases any handle this
  // context previously owned.
  std::swap(handle_, other.handle_);
  std::swap(stream_, other.stream_);
  return *this;
}

CusolverContext::~CusolverContext() {
  if (handle_) {
    Status status = CusolverStatusToStatus(cusolverDnDestroy(handle_));
    if (!status.ok()) {
      LOG(ERROR) << "cusolverDnDestroy failed: " << status;
    }
  }
}

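// Invokes the macro `m` once for each supported element type, paired with
// the LAPACK-style type prefix (S, D, C, Z) that appears in the
// corresponding cuSolver function names.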
#define CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

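// Pastes together the name of a dense cuSolver entry point; for example,
// DN_SOLVER_FN(potrf, S) expands to cusolverDnSpotrf.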
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method

// Note: NVIDIA has promised that it is safe to pass 'nullptr' as the argument
// buffers to cuSolver buffer size methods, and that this will be documented
// behavior in a future cuSolver release.
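// Returns the workspace size, in elements of the given type, that cuSolver
// requires for a Cholesky factorization of an n x n matrix with leading
// dimension lda.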
StatusOr<int64> CusolverContext::PotrfBufferSize(PrimitiveType type,
                                                 se::blas::UpperLower uplo,
                                                 int n, int lda) {
  int size = -1;
  switch (type) {
    case F32: {
      TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnSpotrf_bufferSize(
          handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case F64: {
      TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnDpotrf_bufferSize(
          handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case C64: {
      TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCpotrf_bufferSize(
          handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case C128: {
      TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnZpotrf_bufferSize(
          handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    default:
      return InvalidArgument("Invalid type for Cholesky decomposition: %s",
                             PrimitiveType_Name(type));
  }
  return size;
}

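// Stamps out an explicit specialization of CusolverContext::Potrf for the
// element type T, forwarding to the matching cusolverDn<prefix>potrf routine.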
#define POTRF_INSTANCE(T, type_prefix)                                    \
  template <>                                                             \
  Status CusolverContext::Potrf<T>(                                       \
      se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,   \
      se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
    return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(       \
        handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda,   \
        ToDevicePointer(workspace), workspace.ElementCount(),             \
        ToDevicePointer(lapack_info)));                                   \
  }

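// Instantiates Potrf<float>, Potrf<double>, Potrf<std::complex<float>>, and
// Potrf<std::complex<double>>.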
CALL_LAPACK_TYPES(POTRF_INSTANCE);

}  // namespace gpu
}  // namespace xla