• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
17 
18 #include <complex>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/base/call_once.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/service/gpu/precompiled_kernels.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
32 #include "tensorflow/stream_executor/blas.h"
33 #include "tensorflow/stream_executor/device_memory.h"
34 
35 namespace xla {
36 namespace gpu {
37 
38 namespace {
39 
GetContext(se::Stream * stream)40 StatusOr<GpuSolverContext*> GetContext(se::Stream* stream) {
41   // TODO(b/214454412): This global hashtable is incorrect (ABA bug if a Stream
42   // is added to the hasthable, then deleted, and then a new Stream is created
43   // at the same address).  It also leaks memory!
44   static absl::Mutex mu(absl::kConstInit);
45   static auto contexts =
46       new absl::flat_hash_map<se::Stream*, GpuSolverContext> ABSL_GUARDED_BY(
47           mu);
48 
49   absl::MutexLock lock(&mu);
50   auto result = contexts->emplace(stream, GpuSolverContext());
51   if (result.second) {
52     TF_ASSIGN_OR_RETURN(result.first->second, GpuSolverContext::Create(stream));
53   }
54   return &result.first->second;
55 }
56 
57 template <typename T>
DoPotrfBatched(const se::GpuAsmOpts & asm_opts,CholeskyParams * params,se::Stream * stream,GpuSolverContext * context)58 Status DoPotrfBatched(const se::GpuAsmOpts& asm_opts, CholeskyParams* params,
59                       se::Stream* stream, GpuSolverContext* context) {
60   T* a_base = static_cast<T*>(params->a_buffer.opaque());
61   se::DeviceMemory<int> infos(params->info_buffer);
62 #if TENSORFLOW_USE_ROCSOLVER
63   // hipsolver is not supported so allocate a GPU buffer
64   se::ScopedDeviceMemory<T*> ptrs =
65       stream->parent()->AllocateOwnedArray<T*>(batch_size_);
66   auto as = *ptrs;
67 #else
68   se::DeviceMemory<T*> as(params->workspace_buffer);
69 #endif
70 
71   CHECK_GE(as.size(), params->batch_size);
72   CHECK_GE(infos.size(), params->batch_size);
73 
74   // Run a kernel that sets as[i] = &a_base[i * stride].
75   const int64_t stride_bytes = params->n * params->n * sizeof(T);
76   TF_RETURN_IF_ERROR(MakeBatchPointers(
77       stream, asm_opts, se::DeviceMemoryBase(a_base), stride_bytes,
78       static_cast<int>(params->batch_size), se::DeviceMemoryBase(as)));
79 
80   // Now that we've set up the `as` array, we can call cusolver.
81   return context->PotrfBatched(params->uplo, params->n, as, params->n, infos,
82                                params->batch_size);
83 }
84 
85 template <typename T>
DoPotrfUnbatched(CholeskyParams * params,GpuSolverContext * context)86 Status DoPotrfUnbatched(CholeskyParams* params, GpuSolverContext* context) {
87   T* a_base = static_cast<T*>(params->a_buffer.opaque());
88   int* info_base = static_cast<int*>(params->info_buffer.opaque());
89 
90   int64_t stride = params->n * params->n;
91   for (int64_t i = 0; i < params->batch_size; ++i) {
92     se::DeviceMemory<T> a_data(
93         se::DeviceMemoryBase(&a_base[i * stride], sizeof(T) * stride));
94     se::DeviceMemory<int> info_data(
95         se::DeviceMemoryBase(&info_base[i], sizeof(int)));
96     TF_RETURN_IF_ERROR(context->Potrf(params->uplo, params->n, a_data,
97                                       params->n, info_data,
98                                       params->workspace_buffer));
99   }
100   return Status::OK();
101 }
102 
103 }  // namespace
104 
// Constructs a Cholesky thunk.  `a_buffer` holds batch_size contiguous n x n
// matrices of `type` to factor in place; `info_buffer` receives one int status
// per matrix; `workspace_buffer` is scratch for the solver call.
CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
                             const CholeskyOptions& options,
                             const se::GpuAsmOpts asm_opts,
                             BufferAllocation::Slice a_buffer,
                             BufferAllocation::Slice workspace_buffer,
                             BufferAllocation::Slice info_buffer,
                             PrimitiveType type, int64_t batch_size, int64_t n)
    : Thunk(Kind::kCholesky, thunk_info),
      asm_opts_(asm_opts),
      // Translate the proto's lower/upper flag into the BLAS uplo enum.
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      a_buffer_(a_buffer),
      workspace_buffer_(workspace_buffer),
      info_buffer_(info_buffer),
      type_(type),
      batch_size_(batch_size),
      n_(n) {}
122 
ExecuteOnStream(const ExecuteParams & params)123 Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {
124   VLOG(3) << "type=" << PrimitiveType_Name(type_)
125           << " uplo=" << se::blas::UpperLowerString(uplo_)
126           << " batch_size=" << batch_size_ << " n=" << n_
127           << " a=" << a_buffer_.ToString()
128           << " workspace=" << workspace_buffer_.ToString()
129           << " info=" << info_buffer_.ToString();
130 
131   se::DeviceMemoryBase a_buffer =
132       params.buffer_allocations->GetDeviceAddress(a_buffer_);
133   se::DeviceMemoryBase info_buffer =
134       params.buffer_allocations->GetDeviceAddress(info_buffer_);
135   se::DeviceMemoryBase workspace_buffer =
136       params.buffer_allocations->GetDeviceAddress(workspace_buffer_);
137   CholeskyParams cholesky_params{n_,       batch_size_,      uplo_,
138                                  a_buffer, workspace_buffer, info_buffer};
139   return RunCholesky(asm_opts_, type_, &cholesky_params, params.stream);
140 }
141 
RunCholesky(const se::GpuAsmOpts & asm_opts,PrimitiveType type,CholeskyParams * cholesky_params,se::Stream * stream)142 Status RunCholesky(const se::GpuAsmOpts& asm_opts, PrimitiveType type,
143                    CholeskyParams* cholesky_params, se::Stream* stream) {
144   TF_ASSIGN_OR_RETURN(GpuSolverContext * context, GetContext(stream));
145   if (context->SupportsPotrfBatched()) {
146     switch (type) {
147       case F32:
148         return DoPotrfBatched<float>(asm_opts, cholesky_params, stream,
149                                      context);
150       case F64:
151         return DoPotrfBatched<double>(asm_opts, cholesky_params, stream,
152                                       context);
153       case C64:
154         return DoPotrfBatched<std::complex<float>>(asm_opts, cholesky_params,
155                                                    stream, context);
156       case C128:
157         return DoPotrfBatched<std::complex<double>>(asm_opts, cholesky_params,
158                                                     stream, context);
159       default:
160         return InvalidArgument("Invalid type for cholesky %s",
161                                PrimitiveType_Name(type));
162     }
163   } else {
164     switch (type) {
165       case F32:
166         return DoPotrfUnbatched<float>(cholesky_params, context);
167       case F64:
168         return DoPotrfUnbatched<double>(cholesky_params, context);
169       case C64:
170         return DoPotrfUnbatched<std::complex<float>>(cholesky_params, context);
171       case C128:
172         return DoPotrfUnbatched<std::complex<double>>(cholesky_params, context);
173       default:
174         return InvalidArgument("Invalid type for cholesky %s",
175                                PrimitiveType_Name(type));
176     }
177   }
178 }
179 
180 }  // namespace gpu
181 }  // namespace xla
182