/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"

#include "tensorflow/compiler/xla/util.h"

#if TENSORFLOW_USE_ROCM

namespace rocblas_wrap {

using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
using tensorflow::Env;

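// ROCBLAS_API_WRAPPER(__name) declares a callable shim object named __name.
// In PLATFORM_GOOGLE builds rocBLAS is linked in directly, so the shim
// forwards straight to ::__name; in other builds the symbol is resolved
// lazily from the rocBLAS DSO at first call.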
#ifdef PLATFORM_GOOGLE
#define ROCBLAS_API_WRAPPER(__name)           \
  struct WrapperShim__##__name {              \
    static const char* kName;                 \
    template <typename... Args>               \
    rocblas_status operator()(Args... args) { \
      return ::__name(args...);               \
    }                                         \
  } __name;                                   \
  const char* WrapperShim__##__name::kName = #__name;

#else

#define ROCBLAS_API_WRAPPER(__name)                                        \
  struct DynLoadShim__##__name {                                           \
    static const char* kName;                                              \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;           \
    static void* GetDsoHandle() {                                          \
      auto s = GetRocblasDsoHandle();                                      \
      return s.ValueOrDie();                                               \
    }                                                                      \
    static FuncPtrT LoadOrDie() {                                          \
      void* f;                                                             \
      auto s =                                                             \
          Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \
      CHECK(s.ok()) << "could not find " << kName                          \
                    << " in rocblas DSO; dlerror: " << s.error_message();  \
      return reinterpret_cast<FuncPtrT>(f);                                \
    }                                                                      \
    static FuncPtrT DynLoad() {                                            \
      static FuncPtrT f = LoadOrDie();                                     \
      return f;                                                            \
    }                                                                      \
    template <typename... Args>                                            \
    rocblas_status operator()(Args... args) {                              \
      return DynLoad()(args...);                                           \
    }                                                                      \
  } __name;                                                                \
  const char* DynLoadShim__##__name::kName = #__name;

#endif
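
// Either way, ROCBLAS_API_WRAPPER(foo) defines a global function object named
// `foo` whose operator() forwards to the real rocBLAS entry point. In the
// dynamic-loading flavor, DynLoad() uses a function-local static ("magic
// static"), so the symbol lookup happens exactly once, thread-safely, on
// first call.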

// clang-format off
#define FOREACH_ROCBLAS_API(__macro)            \
  __macro(rocblas_create_handle)                \
  __macro(rocblas_destroy_handle)               \
  __macro(rocblas_set_stream)
// clang-format on

FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER)

}  // namespace rocblas_wrap

#endif  // TENSORFLOW_USE_ROCM

namespace xla {
namespace gpu {

namespace {

// Type traits to get CUDA/ROCm complex types from std::complex<T>.
template <typename T>
struct GpuComplexT {
  typedef T type;
};
#if !defined(TENSORFLOW_USE_ROCM)

using gpuStream_t = cudaStream_t;

#define GpuSolverCreate cusolverDnCreate
#define GpuSolverSetStream cusolverDnSetStream
#define GpuSolverDestroy cusolverDnDestroy

template <>
struct GpuComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct GpuComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};

#else

using gpuStream_t = hipStream_t;

#define GpuSolverCreate rocblas_wrap::rocblas_create_handle
#define GpuSolverSetStream rocblas_wrap::rocblas_set_stream
#define GpuSolverDestroy rocblas_wrap::rocblas_destroy_handle

template <>
struct GpuComplexT<std::complex<float>> {
  typedef rocblas_float_complex type;
};
template <>
struct GpuComplexT<std::complex<double>> {
  typedef rocblas_double_complex type;
};
#endif
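
// With the aliases above, handle creation, stream binding, and destruction
// are spelled identically for CUDA (cusolverDn*) and ROCm
// (rocblas_wrap::rocblas_*) builds.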

template <typename T>
inline typename GpuComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
  return static_cast<typename GpuComplexT<T>::type*>(p.opaque());
}
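
// For example (with a hypothetical variable m of type
// se::DeviceMemory<std::complex<float>>), ToDevicePointer(m) yields a
// cuComplex* on CUDA builds and a rocblas_float_complex* on ROCm builds.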

#if !defined(TENSORFLOW_USE_ROCM)
cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return CUBLAS_FILL_MODE_UPPER;
    case se::blas::UpperLower::kLower:
      return CUBLAS_FILL_MODE_LOWER;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a cuSolver status to a Status.
Status ConvertStatus(cusolverStatus_t status) {
  switch (status) {
    case CUSOLVER_STATUS_SUCCESS:
      return Status::OK();
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return FailedPrecondition("cuSolver has not been initialized");
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return ResourceExhausted("cuSolver allocation failed");
    case CUSOLVER_STATUS_INVALID_VALUE:
      return InvalidArgument("cuSolver invalid value error");
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return FailedPrecondition("cuSolver architecture mismatch error");
    case CUSOLVER_STATUS_MAPPING_ERROR:
      return Unknown("cuSolver mapping error");
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return Unknown("cuSolver execution failed");
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return Internal("cuSolver internal error");
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return Unimplemented("cuSolver matrix type not supported error");
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      return Unimplemented("cuSolver not supported error");
    case CUSOLVER_STATUS_ZERO_PIVOT:
      return InvalidArgument("cuSolver zero pivot error");
    case CUSOLVER_STATUS_INVALID_LICENSE:
      return FailedPrecondition("cuSolver invalid license error");
    default:
      return Unknown("Unknown cuSolver error");
  }
}
#else
rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return rocblas_fill_upper;
    case se::blas::UpperLower::kLower:
      return rocblas_fill_lower;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a rocBLAS status to a Status.
Status ConvertStatus(rocblas_status status) {
  switch (status) {
    case rocblas_status_success:
      return Status::OK();
    case rocblas_status_invalid_handle:
      return FailedPrecondition("handle not initialized, invalid or null");
    case rocblas_status_not_implemented:
      return Internal("function is not implemented");
    case rocblas_status_invalid_pointer:
      return InvalidArgument("invalid pointer argument");
    case rocblas_status_invalid_size:
      return InvalidArgument("invalid size argument");
    case rocblas_status_memory_error:
      return Internal("failed internal memory allocation, copy or dealloc");
    case rocblas_status_internal_error:
      return Internal("other internal library failure");
    case rocblas_status_perf_degraded:
      return Internal("performance degraded due to low device memory");
    case rocblas_status_size_query_mismatch:
      return Unknown("unmatched start/stop size query");
    case rocblas_status_size_increased:
      return Unknown("queried device memory size increased");
    case rocblas_status_size_unchanged:
      return Unknown("queried device memory size unchanged");
    case rocblas_status_invalid_value:
      return InvalidArgument("passed argument not valid");
    case rocblas_status_continue:
      return Unknown("nothing preventing function from proceeding");
    default:
      return Unknown("Unknown rocsolver error");
  }
}
#endif

}  // namespace

StatusOr<GpuSolverContext> GpuSolverContext::Create(se::Stream* stream) {
  gpusolverHandle_t handle;
  TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverCreate(&handle)));
  GpuSolverContext context(stream, handle);

  if (stream) {
    // StreamExecutor really should just expose the CUDA stream to clients...
    const gpuStream_t* gpu_stream =
        CHECK_NOTNULL(reinterpret_cast<const gpuStream_t*>(
            stream->implementation()->GpuStreamMemberHack()));
    TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverSetStream(handle, *gpu_stream)));
  }

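  // Explicit std::move: GpuSolverContext is move-only, and (at least on older
  // compilers) implicit move-on-return does not apply to this conversion into
  // the StatusOr return value.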
  return std::move(context);
}

GpuSolverContext::GpuSolverContext(se::Stream* stream, gpusolverHandle_t handle)
    : stream_(stream), handle_(handle) {}

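// GpuSolverContext is move-only. The moved-from object gives up its handle so
// that exactly one destructor destroys it; move-assignment swaps instead, so
// the overwritten handle is cleaned up when `other` is destroyed.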
GpuSolverContext::GpuSolverContext(GpuSolverContext&& other) {
  handle_ = other.handle_;
  stream_ = other.stream_;
  other.handle_ = nullptr;
  other.stream_ = nullptr;
}

GpuSolverContext& GpuSolverContext::operator=(GpuSolverContext&& other) {
  std::swap(handle_, other.handle_);
  std::swap(stream_, other.stream_);
  return *this;
}

GpuSolverContext::~GpuSolverContext() {
  if (handle_) {
    Status status = ConvertStatus(GpuSolverDestroy(handle_));
    if (!status.ok()) {
      LOG(ERROR) << "GpuSolverDestroy failed: " << status;
    }
  }
}

#if !defined(TENSORFLOW_USE_ROCM)
#define CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method

#else

#define CALL_LAPACK_TYPES(m) \
  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)

#define DN_SOLVER_FN(method, type_prefix) \
  tensorflow::wrap::rocsolver_##type_prefix##method

#endif
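
// For example, DN_SOLVER_FN(potrf, S) names cusolverDnSpotrf on CUDA builds,
// while DN_SOLVER_FN(potrf, s) names tensorflow::wrap::rocsolver_spotrf on
// ROCm builds.
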
// Note: NVIDIA has promised that it is safe to pass 'nullptr' as the argument
// buffers to cuSolver buffer size methods, and that this will be documented
// behavior in a future cuSolver release.
StatusOr<int64> GpuSolverContext::PotrfBufferSize(PrimitiveType type,
                                                  se::blas::UpperLower uplo,
                                                  int n, int lda) {
#if !defined(TENSORFLOW_USE_ROCM)
  int size = -1;
  switch (type) {
    case F32: {
      TF_RETURN_IF_ERROR(ConvertStatus(cusolverDnSpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case F64: {
      TF_RETURN_IF_ERROR(ConvertStatus(cusolverDnDpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case C64: {
      TF_RETURN_IF_ERROR(ConvertStatus(cusolverDnCpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case C128: {
      TF_RETURN_IF_ERROR(ConvertStatus(cusolverDnZpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    default:
      return InvalidArgument("Invalid type for Cholesky decomposition: %s",
                             PrimitiveType_Name(type));
  }
  return size;
#else
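  // rocSOLVER's potrf variants manage their own scratch space (the ROCm
  // POTRF_INSTANCE below passes no workspace), so no buffer is needed here.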
  return 0;
#endif
}

#if !defined(TENSORFLOW_USE_ROCM)
#define POTRF_INSTANCE(T, type_prefix)                                    \
  template <>                                                             \
  Status GpuSolverContext::Potrf<T>(                                      \
      se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,   \
      se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
    return ConvertStatus(DN_SOLVER_FN(potrf, type_prefix)(                \
        handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(A), lda,    \
        ToDevicePointer(workspace), workspace.ElementCount(),             \
        ToDevicePointer(lapack_info)));                                   \
  }
#else
#define POTRF_INSTANCE(T, type_prefix)                                    \
  template <>                                                             \
  Status GpuSolverContext::Potrf<T>(                                      \
      se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,   \
      se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
    return ConvertStatus(DN_SOLVER_FN(potrf, type_prefix)(                \
        handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(A), lda,    \
        ToDevicePointer(lapack_info)));                                   \
  }
#endif

CALL_LAPACK_TYPES(POTRF_INSTANCE);
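
// CALL_LAPACK_TYPES stamps out Potrf<float>, Potrf<double>,
// Potrf<std::complex<float>>, and Potrf<std::complex<double>>.
//
// A rough caller-side sketch (hypothetical names; assumes a valid
// se::Stream* stream and suitably allocated device buffers a, info, and
// workspace):
//
//   TF_ASSIGN_OR_RETURN(GpuSolverContext ctx,
//                       GpuSolverContext::Create(stream));
//   TF_ASSIGN_OR_RETURN(int64 workspace_elems,
//                       ctx.PotrfBufferSize(F32, se::blas::UpperLower::kLower,
//                                           n, lda));
//   // ... allocate `workspace` with at least `workspace_elems` floats ...
//   TF_RETURN_IF_ERROR(ctx.Potrf<float>(se::blas::UpperLower::kLower, n, a,
//                                       lda, info, workspace));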

}  // namespace gpu
}  // namespace xla