1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"
29
30 namespace xla {
31 namespace gpu {
32
33 static tensorflow::mutex contexts_mu(tensorflow::LINKER_INITIALIZED);
34 static auto contexts =
35 new absl::flat_hash_map<se::Stream*, GpuSolverContext> TF_GUARDED_BY(
36 contexts_mu);
37
CholeskyThunk(ThunkInfo thunk_info,const CholeskyOptions & options,BufferAllocation::Slice a_buffer,BufferAllocation::Slice workspace_buffer,BufferAllocation::Slice info_buffer,PrimitiveType type,int64_t batch_size,int64_t n)38 CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
39 const CholeskyOptions& options,
40 BufferAllocation::Slice a_buffer,
41 BufferAllocation::Slice workspace_buffer,
42 BufferAllocation::Slice info_buffer,
43 PrimitiveType type, int64_t batch_size, int64_t n)
44 : Thunk(Kind::kCholesky, thunk_info),
45 uplo_(options.lower() ? se::blas::UpperLower::kLower
46 : se::blas::UpperLower::kUpper),
47 a_buffer_(a_buffer),
48 workspace_buffer_(workspace_buffer),
49 info_buffer_(info_buffer),
50 type_(type),
51 batch_size_(batch_size),
52 a_batch_stride_(n * n * ShapeUtil::ByteSizeOfPrimitiveType(type)),
53 n_(n) {}
54
ExecuteOnStream(const ExecuteParams & params)55 Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {
56 VLOG(3) << "type=" << PrimitiveType_Name(type_)
57 << " uplo=" << se::blas::UpperLowerString(uplo_)
58 << " batch_size=" << batch_size_ << " n=" << n_
59 << " a=" << a_buffer_.ToString()
60 << " workspace=" << workspace_buffer_.ToString()
61 << " info=" << info_buffer_.ToString();
62
63 GpuSolverContext* context;
64 {
65 tensorflow::mutex_lock lock(contexts_mu);
66 auto result = contexts->emplace(params.stream, GpuSolverContext());
67 if (result.second) {
68 TF_ASSIGN_OR_RETURN(result.first->second,
69 GpuSolverContext::Create(params.stream));
70 }
71 context = &result.first->second;
72 }
73
74 char* a_base = static_cast<char*>(
75 params.buffer_allocations->GetDeviceAddress(a_buffer_).opaque());
76 int* info_base = static_cast<int*>(
77 params.buffer_allocations->GetDeviceAddress(info_buffer_).opaque());
78 se::DeviceMemoryBase workspace_data =
79 params.buffer_allocations->GetDeviceAddress(workspace_buffer_);
80 for (int64_t i = 0; i < batch_size_; ++i) {
81 se::DeviceMemoryBase a_data =
82 se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
83 se::DeviceMemory<int> info_data(
84 se::DeviceMemoryBase(info_base + i, sizeof(int)));
85 switch (type_) {
86 case F32: {
87 TF_RETURN_IF_ERROR(
88 context->Potrf(uplo_, n_, se::DeviceMemory<float>(a_data), n_,
89 info_data, se::DeviceMemory<float>(workspace_data)));
90 break;
91 }
92 case F64: {
93 TF_RETURN_IF_ERROR(context->Potrf(
94 uplo_, n_, se::DeviceMemory<double>(a_data), n_, info_data,
95 se::DeviceMemory<double>(workspace_data)));
96 break;
97 }
98 case C64: {
99 TF_RETURN_IF_ERROR(context->Potrf(
100 uplo_, n_, se::DeviceMemory<std::complex<float>>(a_data), n_,
101 info_data, se::DeviceMemory<std::complex<float>>(workspace_data)));
102 break;
103 }
104 case C128: {
105 TF_RETURN_IF_ERROR(context->Potrf(
106 uplo_, n_, se::DeviceMemory<std::complex<double>>(a_data), n_,
107 info_data, se::DeviceMemory<std::complex<double>>(workspace_data)));
108 break;
109 }
110 default:
111 return InvalidArgument("Invalid type for cholesky %s",
112 PrimitiveType_Name(type_));
113 }
114 }
115 return Status::OK();
116 }
117
118 } // namespace gpu
119 } // namespace xla
120