1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file wraps hipsolver API calls with dso loader so that we don't need to 17 // have explicit linking to libhipsolver. All TF hipsolver API usage should 18 // route through this wrapper. 19 20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSOLVER_WRAPPER_H_ 21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSOLVER_WRAPPER_H_ 22 23 #include "rocm/rocm_config.h" 24 25 #if TF_ROCM_VERSION >= 40500 26 27 #include "rocm/include/hipsolver.h" 28 #include "tensorflow/stream_executor/lib/env.h" 29 #include "tensorflow/stream_executor/platform/dso_loader.h" 30 #include "tensorflow/stream_executor/platform/port.h" 31 32 namespace tensorflow { 33 namespace wrap { 34 35 #ifdef PLATFORM_GOOGLE 36 37 #define HIPSOLVER_API_WRAPPER(api_name) \ 38 template <typename... Args> \ 39 auto api_name(Args... args)->decltype(::api_name(args...)) { \ 40 return ::api_name(args...); \ 41 } 42 43 #else 44 45 #define TO_STR_(x) #x 46 #define TO_STR(x) TO_STR_(x) 47 48 #define HIPSOLVER_API_WRAPPER(api_name) \ 49 template <typename... Args> \ 50 auto api_name(Args... args)->decltype(::api_name(args...)) { \ 51 using FuncPtrT = std::add_pointer<decltype(::api_name)>::type; \ 52 static FuncPtrT loaded = []() -> FuncPtrT { \ 53 static const char* kName = TO_STR(api_name); \ 54 void* f; \ 55 auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \ 56 stream_executor::internal::CachedDsoLoader::GetHipsolverDsoHandle() \ 57 .ValueOrDie(), \ 58 kName, &f); \ 59 CHECK(s.ok()) << "could not find " << kName \ 60 << " in hipsolver lib; dlerror: " << s.error_message(); \ 61 return reinterpret_cast<FuncPtrT>(f); \ 62 }(); \ 63 return loaded(args...); \ 64 } 65 66 #endif 67 68 // clang-format off 69 #define FOREACH_HIPSOLVER_API(__macro) \ 70 __macro(hipsolverCreate) \ 71 __macro(hipsolverDestroy) \ 72 __macro(hipsolverSetStream) \ 73 __macro(hipsolverCgetrf) \ 74 __macro(hipsolverCgetrf_bufferSize) \ 75 __macro(hipsolverDgetrf) \ 76 __macro(hipsolverDgetrf_bufferSize) \ 77 __macro(hipsolverSgetrf) \ 78 __macro(hipsolverSgetrf_bufferSize) \ 79 __macro(hipsolverZgetrf) \ 80 __macro(hipsolverZgetrf_bufferSize) \ 81 __macro(hipsolverCgetrs) \ 82 __macro(hipsolverCgetrs_bufferSize) \ 83 __macro(hipsolverDgetrs) \ 84 __macro(hipsolverDgetrs_bufferSize) \ 85 __macro(hipsolverSgetrs) \ 86 __macro(hipsolverSgetrs_bufferSize) \ 87 __macro(hipsolverZgetrs) \ 88 __macro(hipsolverZgetrs_bufferSize) \ 89 __macro(hipsolverCpotrf) \ 90 __macro(hipsolverCpotrf_bufferSize) \ 91 __macro(hipsolverDpotrf) \ 92 __macro(hipsolverDpotrf_bufferSize) \ 93 __macro(hipsolverSpotrf) \ 94 __macro(hipsolverSpotrf_bufferSize) \ 95 __macro(hipsolverZpotrf) \ 96 __macro(hipsolverZpotrf_bufferSize) \ 97 __macro(hipsolverCpotrfBatched) \ 98 __macro(hipsolverCpotrfBatched_bufferSize) \ 99 __macro(hipsolverDpotrfBatched) \ 100 __macro(hipsolverDpotrfBatched_bufferSize) \ 101 __macro(hipsolverSpotrfBatched) \ 102 __macro(hipsolverSpotrfBatched_bufferSize) \ 103 __macro(hipsolverZpotrfBatched) \ 104 __macro(hipsolverZpotrfBatched_bufferSize) \ 105 __macro(hipsolverCgeqrf) \ 106 __macro(hipsolverCgeqrf_bufferSize) \ 107 __macro(hipsolverDgeqrf) \ 108 __macro(hipsolverDgeqrf_bufferSize) \ 109 __macro(hipsolverSgeqrf) \ 110 __macro(hipsolverSgeqrf_bufferSize) \ 111 __macro(hipsolverZgeqrf) \ 112 __macro(hipsolverZgeqrf_bufferSize) \ 113 __macro(hipsolverCunmqr) \ 114 __macro(hipsolverCunmqr_bufferSize) \ 115 __macro(hipsolverZunmqr) \ 116 __macro(hipsolverZunmqr_bufferSize) \ 117 __macro(hipsolverCungqr) \ 118 __macro(hipsolverCungqr_bufferSize) \ 119 __macro(hipsolverZungqr) \ 120 __macro(hipsolverZungqr_bufferSize) \ 121 __macro(hipsolverCheevd) \ 122 __macro(hipsolverCheevd_bufferSize) \ 123 __macro(hipsolverZheevd) \ 124 __macro(hipsolverZheevd_bufferSize) 125 // clang-format on 126 127 FOREACH_HIPSOLVER_API(HIPSOLVER_API_WRAPPER) 128 129 #undef TO_STR_ 130 #undef TO_STR 131 #undef FOREACH_HIPSOLVER_API 132 #undef HIPSOLVER_API_WRAPPER 133 134 } // namespace wrap 135 } // namespace tensorflow 136 137 #endif // TF_ROCM_VERSION 138 #endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSOLVER_WRAPPER_H_ 139