/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"

#include <complex>
#include <memory>
#include <string>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/gpu/precompiled_kernels.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

namespace {
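// Returns a cached GpuSolverContext for `stream`, creating one on first use.
// Contexts live in a process-wide table keyed by the raw Stream pointer; see
// the TODO below for the lifetime caveats of that scheme.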
StatusOr<GpuSolverContext*> GetContext(se::Stream* stream) {
  // TODO(b/214454412): This global hashtable is incorrect (ABA bug if a Stream
  // is added to the hashtable, then deleted, and then a new Stream is created
  // at the same address). It also leaks memory!
  static absl::Mutex mu(absl::kConstInit);
  static auto contexts =
      new absl::flat_hash_map<se::Stream*, GpuSolverContext> ABSL_GUARDED_BY(
          mu);

  absl::MutexLock lock(&mu);
  auto result = contexts->emplace(stream, GpuSolverContext());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second, GpuSolverContext::Create(stream));
  }
  return &result.first->second;
}

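// Factors all `batch_size` matrices with a single batched potrf call. The
// solver's batched API takes an array of device pointers, one per matrix, so
// we first populate `as` with pointers into the contiguous `a` buffer.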
template <typename T>
Status DoPotrfBatched(const se::GpuAsmOpts& asm_opts, CholeskyParams* params,
                      se::Stream* stream, GpuSolverContext* context) {
  T* a_base = static_cast<T*>(params->a_buffer.opaque());
  se::DeviceMemory<int> infos(params->info_buffer);
#if TENSORFLOW_USE_ROCSOLVER
  // hipsolver is not supported, so allocate a GPU buffer to hold the batch
  // pointers.
  se::ScopedDeviceMemory<T*> ptrs =
      stream->parent()->AllocateOwnedArray<T*>(params->batch_size);
  auto as = *ptrs;
#else
  se::DeviceMemory<T*> as(params->workspace_buffer);
#endif

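  // Each batch element needs one pointer slot in `as` and one status slot in
  // `infos`.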
  CHECK_GE(as.size(), params->batch_size);
  CHECK_GE(infos.size(), params->batch_size);

  // Run a kernel that sets as[i] = &a_base[i * stride].
  const int64_t stride_bytes = params->n * params->n * sizeof(T);
  TF_RETURN_IF_ERROR(MakeBatchPointers(
      stream, asm_opts, se::DeviceMemoryBase(a_base), stride_bytes,
      static_cast<int>(params->batch_size), se::DeviceMemoryBase(as)));

  // Now that we've set up the `as` array, we can call cusolver.
  return context->PotrfBatched(params->uplo, params->n, as, params->n, infos,
                               params->batch_size);
}

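// Fallback for solvers without a batched potrf: factors each n x n matrix
// with its own potrf call, writing the status of the i-th factorization to
// the i-th entry of the info buffer.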
template <typename T>
Status DoPotrfUnbatched(CholeskyParams* params, GpuSolverContext* context) {
  T* a_base = static_cast<T*>(params->a_buffer.opaque());
  int* info_base = static_cast<int*>(params->info_buffer.opaque());

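  // The matrices are packed contiguously in the `a` buffer; matrix i begins
  // n*n elements after matrix i-1.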
  int64_t stride = params->n * params->n;
  for (int64_t i = 0; i < params->batch_size; ++i) {
    se::DeviceMemory<T> a_data(
        se::DeviceMemoryBase(&a_base[i * stride], sizeof(T) * stride));
    se::DeviceMemory<int> info_data(
        se::DeviceMemoryBase(&info_base[i], sizeof(int)));
    TF_RETURN_IF_ERROR(context->Potrf(params->uplo, params->n, a_data,
                                      params->n, info_data,
                                      params->workspace_buffer));
  }
  return Status::OK();
}

}  // namespace

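// `asm_opts` is retained because the batched path launches the
// MakeBatchPointers kernel, which needs it.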
CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
                             const CholeskyOptions& options,
                             const se::GpuAsmOpts asm_opts,
                             BufferAllocation::Slice a_buffer,
                             BufferAllocation::Slice workspace_buffer,
                             BufferAllocation::Slice info_buffer,
                             PrimitiveType type, int64_t batch_size, int64_t n)
    : Thunk(Kind::kCholesky, thunk_info),
      asm_opts_(asm_opts),
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      a_buffer_(a_buffer),
      workspace_buffer_(workspace_buffer),
      info_buffer_(info_buffer),
      type_(type),
      batch_size_(batch_size),
      n_(n) {}

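// Resolves this thunk's buffer slices to device addresses for the current
// execution and forwards them to RunCholesky, which performs the type and
// batching dispatch.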
Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {
  VLOG(3) << "type=" << PrimitiveType_Name(type_)
          << " uplo=" << se::blas::UpperLowerString(uplo_)
          << " batch_size=" << batch_size_ << " n=" << n_
          << " a=" << a_buffer_.ToString()
          << " workspace=" << workspace_buffer_.ToString()
          << " info=" << info_buffer_.ToString();

  se::DeviceMemoryBase a_buffer =
      params.buffer_allocations->GetDeviceAddress(a_buffer_);
  se::DeviceMemoryBase info_buffer =
      params.buffer_allocations->GetDeviceAddress(info_buffer_);
  se::DeviceMemoryBase workspace_buffer =
      params.buffer_allocations->GetDeviceAddress(workspace_buffer_);
  CholeskyParams cholesky_params{n_, batch_size_, uplo_,
                                 a_buffer, workspace_buffer, info_buffer};
  return RunCholesky(asm_opts_, type_, &cholesky_params, params.stream);
}

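// Runs the Cholesky factorization(s) described by `cholesky_params` on
// `stream`, choosing the batched implementation when the solver supports it
// and specializing on the element type of the input matrices.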
Status RunCholesky(const se::GpuAsmOpts& asm_opts, PrimitiveType type,
                   CholeskyParams* cholesky_params, se::Stream* stream) {
  TF_ASSIGN_OR_RETURN(GpuSolverContext * context, GetContext(stream));
  if (context->SupportsPotrfBatched()) {
    switch (type) {
      case F32:
        return DoPotrfBatched<float>(asm_opts, cholesky_params, stream,
                                     context);
      case F64:
        return DoPotrfBatched<double>(asm_opts, cholesky_params, stream,
                                      context);
      case C64:
        return DoPotrfBatched<std::complex<float>>(asm_opts, cholesky_params,
                                                   stream, context);
      case C128:
        return DoPotrfBatched<std::complex<double>>(asm_opts, cholesky_params,
                                                    stream, context);
      default:
        return InvalidArgument("Invalid type for cholesky %s",
                               PrimitiveType_Name(type));
    }
  } else {
    switch (type) {
      case F32:
        return DoPotrfUnbatched<float>(cholesky_params, context);
      case F64:
        return DoPotrfUnbatched<double>(cholesky_params, context);
      case C64:
        return DoPotrfUnbatched<std::complex<float>>(cholesky_params, context);
      case C128:
        return DoPotrfUnbatched<std::complex<double>>(cholesky_params, context);
      default:
        return InvalidArgument("Invalid type for cholesky %s",
                               PrimitiveType_Name(type));
    }
  }
}

}  // namespace gpu
}  // namespace xla