/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {
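// Translates the HLO TriangularSolveOptions into StreamExecutor BLAS enums up
// front, so ExecuteOnStream can pass them straight to trsm.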
TriangularSolveThunk::TriangularSolveThunk(
    ThunkInfo thunk_info, const TriangularSolveOptions& options,
    const BufferAllocation::Slice& a_buffer,
    const BufferAllocation::Slice& b_buffer, PrimitiveType type,
    int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
    int64 b_batch_stride)
    : Thunk(Kind::kTriangularSolve, thunk_info),
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      side_(options.left_side() ? se::blas::Side::kLeft
                                : se::blas::Side::kRight),
      unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit
                                             : se::blas::Diagonal::kNonUnit),
      a_buffer_(a_buffer),
      b_buffer_(b_buffer),
      type_(type),
      batch_size_(batch_size),
      m_(m),
      n_(n),
      a_batch_stride_(a_batch_stride),
      b_batch_stride_(b_batch_stride) {
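  // Map the XLA transpose option onto the BLAS transpose op. ADJOINT is the
  // conjugate transpose; an unrecognized value is logged as an error and
  // falls back to no transpose.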
  transpose_a_ = [&] {
    switch (options.transpose_a()) {
      case TriangularSolveOptions::NO_TRANSPOSE:
        return se::blas::Transpose::kNoTranspose;
      case TriangularSolveOptions::TRANSPOSE:
        return se::blas::Transpose::kTranspose;
      case TriangularSolveOptions::ADJOINT:
        return se::blas::Transpose::kConjugateTranspose;
      default:
        LOG(ERROR) << "Invalid triangular solve transpose value "
                   << options.transpose_a();
        return se::blas::Transpose::kNoTranspose;
    }
  }();
}

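// Enqueues one BLAS trsm call per batch element on the execution stream, and
// returns an error if the element type is unsupported or any launch fails.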
Status TriangularSolveThunk::ExecuteOnStream(const ExecuteParams& params) {
  auto& stream = *params.stream;
  auto& buffer_allocations = *params.buffer_allocations;

  VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
          << " side=" << se::blas::SideString(side_)
          << " diagonal=" << se::blas::DiagonalString(unit_diagonal_)
          << " batch_size=" << batch_size_ << " m=" << m_ << " n=" << n_
          << " a_batch_stride=" << a_batch_stride_
          << " b_batch_stride=" << b_batch_stride_;

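  // Column-major leading dimensions: the triangular matrix a is m x m when it
  // multiplies b from the left and n x n when it multiplies from the right,
  // while the right-hand side b is always m x n.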
  const int lda = side_ == se::blas::Side::kLeft ? m_ : n_;
  const int ldb = m_;

  char* a_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(a_buffer_).opaque());
  char* b_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(b_buffer_).opaque());
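  // Solve each batch element with its own trsm call on the same stream. The
  // a and b operands for batch i live batch_stride bytes past those of
  // batch i - 1, so the per-element slices are computed by pointer offset.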
  for (int64 i = 0; i < batch_size_; ++i) {
    bool launch_ok;
    se::DeviceMemoryBase a_data =
        se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
    se::DeviceMemoryBase b_data =
        se::DeviceMemoryBase(b_base + i * b_batch_stride_, b_batch_stride_);
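    // Dispatch on the element type: ThenBlasTrsm is overloaded separately for
    // each of the four supported types, so the untyped device memory is
    // wrapped in the matching se::DeviceMemory<T> before the call.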
    switch (type_) {
      case F32: {
        se::DeviceMemory<float> b_data_typed(b_data);
        launch_ok = stream
                        .ThenBlasTrsm(side_, uplo_, transpose_a_,
                                      unit_diagonal_, m_, n_, /*alpha=*/1.0f,
                                      se::DeviceMemory<float>(a_data), lda,
                                      &b_data_typed, ldb)
                        .ok();
        break;
      }
      case F64: {
        se::DeviceMemory<double> b_data_typed(b_data);
        launch_ok = stream
                        .ThenBlasTrsm(side_, uplo_, transpose_a_,
                                      unit_diagonal_, m_, n_, /*alpha=*/1.0,
                                      se::DeviceMemory<double>(a_data), lda,
                                      &b_data_typed, ldb)
                        .ok();
        break;
      }
      case C64: {
        se::DeviceMemory<std::complex<float>> b_data_typed(b_data);
        launch_ok =
            stream
                .ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
                              n_, /*alpha=*/1.0f,
                              se::DeviceMemory<std::complex<float>>(a_data),
                              lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      case C128: {
        se::DeviceMemory<std::complex<double>> b_data_typed(b_data);
        launch_ok =
            stream
                .ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
                              n_, /*alpha=*/1.0,
                              se::DeviceMemory<std::complex<double>>(a_data),
                              lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      default:
        return InvalidArgument("Invalid type for triangular solve %d", type_);
    }
    if (!launch_ok) {
      return InternalError("Unable to launch triangular solve for thunk %p",
                           this);
    }
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla