• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
17 
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
27 #include "tensorflow/stream_executor/blas.h"
28 #include "tensorflow/stream_executor/device_memory.h"
29 
30 namespace xla {
31 namespace gpu {
32 
TriangularSolveThunk(ThunkInfo thunk_info,const TriangularSolveOptions & options,const BufferAllocation::Slice & a_buffer,const BufferAllocation::Slice & b_buffer,PrimitiveType type,int64_t batch_size,int64_t m,int64_t n,int64_t a_batch_stride,int64_t b_batch_stride)33 TriangularSolveThunk::TriangularSolveThunk(
34     ThunkInfo thunk_info, const TriangularSolveOptions& options,
35     const BufferAllocation::Slice& a_buffer,
36     const BufferAllocation::Slice& b_buffer, PrimitiveType type,
37     int64_t batch_size, int64_t m, int64_t n, int64_t a_batch_stride,
38     int64_t b_batch_stride)
39     : Thunk(Kind::kTriangularSolve, thunk_info),
40       uplo_(options.lower() ? se::blas::UpperLower::kLower
41                             : se::blas::UpperLower::kUpper),
42       side_(options.left_side() ? se::blas::Side::kLeft
43                                 : se::blas::Side::kRight),
44       unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit
45                                              : se::blas::Diagonal::kNonUnit),
46       a_buffer_(a_buffer),
47       b_buffer_(b_buffer),
48       type_(type),
49       batch_size_(batch_size),
50       m_(m),
51       n_(n),
52       a_batch_stride_(a_batch_stride),
53       b_batch_stride_(b_batch_stride) {
54   transpose_a_ = [&] {
55     switch (options.transpose_a()) {
56       case TriangularSolveOptions::NO_TRANSPOSE:
57         return se::blas::Transpose::kNoTranspose;
58       case TriangularSolveOptions::TRANSPOSE:
59         return se::blas::Transpose::kTranspose;
60       case TriangularSolveOptions::ADJOINT:
61         return se::blas::Transpose::kConjugateTranspose;
62       default:
63         LOG(ERROR) << "Invalid triangular solve transpose value "
64                    << options.transpose_a();
65         return se::blas::Transpose::kNoTranspose;
66     }
67   }();
68 }
69 
ExecuteOnStream(const ExecuteParams & params)70 Status TriangularSolveThunk::ExecuteOnStream(const ExecuteParams& params) {
71   auto& stream = *params.stream;
72   auto& buffer_allocations = *params.buffer_allocations;
73 
74   VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
75           << " side=" << se::blas::SideString(side_)
76           << " diagonal=" << se::blas::DiagonalString(unit_diagonal_)
77           << " batch_size=" << batch_size_ << " m=" << m_ << " n=" << n_
78           << " a_batch_stride=" << a_batch_stride_
79           << " b_batch_stride=" << b_batch_stride_;
80 
81   const int lda = side_ == se::blas::Side::kLeft ? m_ : n_;
82   const int ldb = m_;
83 
84   char* a_base = static_cast<char*>(
85       buffer_allocations.GetDeviceAddress(a_buffer_).opaque());
86   char* b_base = static_cast<char*>(
87       buffer_allocations.GetDeviceAddress(b_buffer_).opaque());
88   for (int64_t i = 0; i < batch_size_; ++i) {
89     bool launch_ok;
90     se::DeviceMemoryBase a_data =
91         se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
92     se::DeviceMemoryBase b_data =
93         se::DeviceMemoryBase(b_base + i * b_batch_stride_, b_batch_stride_);
94     switch (type_) {
95       case F32: {
96         se::DeviceMemory<float> b_data_typed(b_data);
97         launch_ok = stream
98                         .ThenBlasTrsm(side_, uplo_, transpose_a_,
99                                       unit_diagonal_, m_, n_, /*alpha=*/1.0f,
100                                       se::DeviceMemory<float>(a_data), lda,
101                                       &b_data_typed, ldb)
102                         .ok();
103         break;
104       }
105       case F64: {
106         se::DeviceMemory<double> b_data_typed(b_data);
107         launch_ok = stream
108                         .ThenBlasTrsm(side_, uplo_, transpose_a_,
109                                       unit_diagonal_, m_, n_, /*alpha=*/1.0,
110                                       se::DeviceMemory<double>(a_data), lda,
111                                       &b_data_typed, ldb)
112                         .ok();
113         break;
114       }
115       case C64: {
116         se::DeviceMemory<std::complex<float>> b_data_typed(b_data);
117         launch_ok =
118             stream
119                 .ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
120                               n_, /*alpha=*/1.0f,
121                               se::DeviceMemory<std::complex<float>>(a_data),
122                               lda, &b_data_typed, ldb)
123                 .ok();
124         break;
125       }
126       case C128: {
127         se::DeviceMemory<std::complex<double>> b_data_typed(b_data);
128         launch_ok =
129             stream
130                 .ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
131                               n_, /*alpha=*/1.0,
132                               se::DeviceMemory<std::complex<double>>(a_data),
133                               lda, &b_data_typed, ldb)
134                 .ok();
135         break;
136       }
137       default:
138         return InvalidArgument("Invalid type for triangular solve %d", type_);
139     }
140     if (!launch_ok) {
141       return InternalError("Unable to launch triangular solve for thunk %p",
142                            this);
143     }
144   }
145   return Status::OK();
146 }
147 
148 }  // namespace gpu
149 }  // namespace xla
150