1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_ 18 19 #include <complex> 20 21 #if !TENSORFLOW_USE_ROCM 22 #include "third_party/gpus/cuda/include/cusolverDn.h" 23 #endif 24 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/util.h" 28 #include "tensorflow/core/lib/core/status.h" 29 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 30 #include "tensorflow/stream_executor/blas.h" 31 32 namespace xla { 33 namespace gpu { 34 35 #if !TENSORFLOW_USE_ROCM 36 37 class CusolverContext { 38 public: 39 // stream may be nullptr, in which case the context can only be used for 40 // buffer size queries. 41 static StatusOr<CusolverContext> Create(se::Stream* stream); 42 CusolverContext() = default; 43 ~CusolverContext(); 44 45 CusolverContext(const CusolverContext&) = delete; 46 CusolverContext(CusolverContext&&); 47 CusolverContext& operator=(const CusolverContext&) = delete; 48 CusolverContext& operator=(CusolverContext&&); 49 50 // Computes the Cholesky factorization A = L * L^T for a single matrix. 51 // Returns Status::OK() if the kernel was launched successfully. See: 52 // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf 53 template <typename T, typename = std::enable_if_t< 54 std::is_same<T, float>::value || 55 std::is_same<T, double>::value || 56 std::is_same<T, std::complex<float>>::value || 57 std::is_same<T, std::complex<double>>::value>> 58 Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<T> dev_A, 59 int lda, se::DeviceMemory<int> dev_lapack_info, 60 se::DeviceMemory<T> workspace) = delete; 61 62 // Returns the size of the `workspace` required by Potrf, in number of 63 // elements of `type`. 64 StatusOr<int64> PotrfBufferSize(PrimitiveType type, se::blas::UpperLower uplo, 65 int n, int lda); 66 67 private: 68 CusolverContext(se::Stream* stream, cusolverDnHandle_t handle); 69 handle()70 cusolverDnHandle_t handle() const { return handle_; } 71 72 se::Stream* stream_ = nullptr; 73 cusolverDnHandle_t handle_ = nullptr; 74 }; 75 76 #define CALL_LAPACK_TYPES(m) \ 77 m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z) 78 #define POTRF_INSTANCE(T, type_prefix) \ 79 template <> \ 80 Status CusolverContext::Potrf<T>( \ 81 se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \ 82 se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace); 83 CALL_LAPACK_TYPES(POTRF_INSTANCE); 84 #undef POTRF_INSTANCE 85 #undef CALL_LAPACK_TYPES 86 87 #else 88 89 typedef void* cusolverDnHandle_t; 90 91 // TODO(cheshire): Remove this hack once we have ROCM implementation. 92 class CusolverContext { 93 public: 94 static StatusOr<CusolverContext> Create(se::Stream* stream) { 95 LOG(FATAL) << "Unimplemented"; 96 } 97 98 template <typename T, typename = std::enable_if_t< 99 std::is_same<T, float>::value || 100 std::is_same<T, double>::value || 101 std::is_same<T, std::complex<float>>::value || 102 std::is_same<T, std::complex<double>>::value>> 103 Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<T> dev_A, 104 int lda, se::DeviceMemory<int> dev_lapack_info, 105 se::DeviceMemory<T> workspace) { 106 LOG(FATAL) << "Unimplemented"; 107 } 108 109 StatusOr<int64> PotrfBufferSize(PrimitiveType type, se::blas::UpperLower uplo, 110 int n, int lda) { 111 LOG(FATAL) << "Unimplemented"; 112 } 113 }; 114 115 #endif 116 117 } // namespace gpu 118 } // namespace xla 119 120 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_ 121