/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

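// TriangularSolveThunk performs a (possibly batched) triangular solve,
// op(A) * X = B for a left-side solve or X * op(A) = B for a right-side one,
// by dispatching to the StreamExecutor BLAS TRSM routine. The HLO
// TriangularSolveOptions are translated into BLAS enums once, at construction
// time, so ExecuteOnStream only has to issue the per-batch BLAS calls.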
TriangularSolveThunk::TriangularSolveThunk(
    ThunkInfo thunk_info, const TriangularSolveOptions& options,
    const BufferAllocation::Slice& a_buffer,
    const BufferAllocation::Slice& b_buffer, PrimitiveType type,
    int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
    int64 b_batch_stride)
    : Thunk(Kind::kTriangularSolve, thunk_info),
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      side_(options.left_side() ? se::blas::Side::kLeft
                                : se::blas::Side::kRight),
      unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit
                                             : se::blas::Diagonal::kNonUnit),
      a_buffer_(a_buffer),
      b_buffer_(b_buffer),
      type_(type),
      batch_size_(batch_size),
      m_(m),
      n_(n),
      a_batch_stride_(a_batch_stride),
      b_batch_stride_(b_batch_stride) {
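  // Translate the HLO transpose option into the BLAS transpose mode. ADJOINT
  // maps to conjugate-transpose; an out-of-range value is logged and treated
  // as no-transpose.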
  transpose_a_ = [&] {
    switch (options.transpose_a()) {
      case TriangularSolveOptions::NO_TRANSPOSE:
        return se::blas::Transpose::kNoTranspose;
      case TriangularSolveOptions::TRANSPOSE:
        return se::blas::Transpose::kTranspose;
      case TriangularSolveOptions::ADJOINT:
        return se::blas::Transpose::kConjugateTranspose;
      default:
        LOG(ERROR) << "Invalid triangular solve transpose value "
                   << options.transpose_a();
        return se::blas::Transpose::kNoTranspose;
    }
  }();
}

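// Enqueues one TRSM call per batch element on the given stream. TRSM solves
// in place, so on success the solution X overwrites the contents of
// b_buffer_.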
Status TriangularSolveThunk::ExecuteOnStream(const ExecuteParams& params) {
  auto& stream = *params.stream;
  auto& buffer_allocations = *params.buffer_allocations;

  VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
          << " side=" << se::blas::SideString(side_)
          << " diagonal=" << se::blas::DiagonalString(unit_diagonal_)
          << " batch_size=" << batch_size_ << " m=" << m_ << " n=" << n_
          << " a_batch_stride=" << a_batch_stride_
          << " b_batch_stride=" << b_batch_stride_;

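  // The triangular matrix A is m x m for a left-side solve and n x n for a
  // right-side one, while B is always m x n. BLAS assumes column-major
  // storage, and the matrices here are densely packed, so each leading
  // dimension is simply the matrix's row count.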
  const int lda = side_ == se::blas::Side::kLeft ? m_ : n_;
  const int ldb = m_;

  char* a_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(a_buffer_).opaque());
  char* b_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(b_buffer_).opaque());
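  // The batch strides are measured in bytes, so the per-element slices are
  // computed with pointer arithmetic on the char* base addresses.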
  for (int64 i = 0; i < batch_size_; ++i) {
    bool launch_ok;
    se::DeviceMemoryBase a_data =
        se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
    se::DeviceMemoryBase b_data =
        se::DeviceMemoryBase(b_base + i * b_batch_stride_, b_batch_stride_);
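    // Dispatch on the element type. TRSM computes op(A) * X = alpha * B (or
    // the right-side variant); XLA's triangular solve has no scaling factor,
    // so alpha is always 1.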
    switch (type_) {
      case F32: {
        se::DeviceMemory<float> b_data_typed(b_data);
        launch_ok = stream
                        .ThenBlasTrsm(side_, uplo_, transpose_a_,
                                      unit_diagonal_, m_, n_, /*alpha=*/1.0f,
                                      se::DeviceMemory<float>(a_data), lda,
                                      &b_data_typed, ldb)
                        .ok();
        break;
      }
      case F64: {
        se::DeviceMemory<double> b_data_typed(b_data);
        launch_ok = stream
                        .ThenBlasTrsm(side_, uplo_, transpose_a_,
                                      unit_diagonal_, m_, n_, /*alpha=*/1.0,
                                      se::DeviceMemory<double>(a_data), lda,
                                      &b_data_typed, ldb)
                        .ok();
        break;
      }
      case C64: {
        se::DeviceMemory<std::complex<float>> b_data_typed(b_data);
        launch_ok =
            stream
                .ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
                              n_, /*alpha=*/1.0f,
                              se::DeviceMemory<std::complex<float>>(a_data),
                              lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      case C128: {
        se::DeviceMemory<std::complex<double>> b_data_typed(b_data);
        launch_ok =
            stream
                .ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
                              n_, /*alpha=*/1.0,
                              se::DeviceMemory<std::complex<double>>(a_data),
                              lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      default:
        return InvalidArgument("Invalid type for triangular solve %d", type_);
    }
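    // ThenBlasTrsm only enqueues work on the stream; .ok() reports whether
    // the call was successfully enqueued, not whether the solve has run to
    // completion.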
    if (!launch_ok) {
      return InternalError("Unable to launch triangular solve for thunk %p",
                           this);
    }
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla