1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
17
18 #include <cstdlib>
19 #include <functional>
20 #include <numeric>
21 #include <optional>
22 #include <vector>
23
24 #include "absl/algorithm/container.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
28 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/stream_executor/blas.h"
37
38 namespace xla {
39 namespace gpu {
40
41 namespace {
42
SetFortranLayout(Shape * shape)43 void SetFortranLayout(Shape* shape) {
44 LayoutUtil::SetToDefaultLayout(shape);
45 int n = shape->mutable_layout()->minor_to_major_size();
46 CHECK_GE(n, 2);
47 std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
48 shape->mutable_layout()->mutable_minor_to_major()->at(1));
49 }
50
// Builds the HLO replacing one kCholesky instruction: a custom call to the
// cuSolver batched potrf target, plus a select that replaces failed batch
// elements with NaNs.
//
// `context` is used only to query the potrf scratch-buffer size; `options`
// chooses the lower/upper triangle; `metadata` is copied onto the new custom
// call. Returns the instruction producing the final (NaN-masked) result.
StatusOr<HloInstruction*> CreateCholesky(GpuSolverContext* context,
                                         HloInstruction* operand,
                                         const CholeskyOptions& options,
                                         const OpMetadata& metadata) {
  HloComputation* computation = operand->parent();

  // The operand is a (possibly batched) stack of square matrices [..., n, n].
  Shape a_shape = operand->shape();
  int ndim = a_shape.dimensions_size();
  CHECK_GE(ndim, 2);
  int64_t n = a_shape.dimensions(ndim - 1);

  // Everything except the two minor-most dimensions is a batch dimension.
  std::vector<int64_t> batch_dims(a_shape.dimensions().begin(),
                                  a_shape.dimensions().end() - 2);
  std::vector<int64_t> batch_dim_ids(batch_dims.size());
  absl::c_iota(batch_dim_ids, 0);
  int64_t batch_size = absl::c_accumulate(batch_dims, 1, std::multiplies<>{});

  // Find the workspace size.
  se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower
                                              : se::blas::UpperLower::kUpper;
  int64_t workspace_size;  // Number of elements of size a_shape.element_type()
  TF_ASSIGN_OR_RETURN(
      workspace_size,
      context->PotrfBufferSize(a_shape.element_type(), uplo, n, n, batch_size));

  // TODO(phawkins): Ideally we would relax this constraint. What we actually
  // want is that:
  // a) the batch dimensions are major, in no particular order.
  // b) the two minor dimensions are in fortran (column-major) order.
  SetFortranLayout(&a_shape);

  // This call returns a tuple of (cholesky_result, workspace, info) where:
  // * cholesky_result is the result of the Cholesky decomposition,
  // * workspace is temporary scratch memory used by cuSolver.
  // * info contains the Potrf success/failure status.
  // Currently we have no meaningful way to report an error, so we simply
  // discard the success/failure information. Obviously this is suboptimal.
  Shape info_shape = ShapeUtil::MakeShape(S32, batch_dims);
  Shape call_shape = ShapeUtil::MakeTupleShape(
      {a_shape,
       ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}),
       info_shape});

  // The operand layout constraint (`{a_shape}`) forces the Fortran layout set
  // above onto the custom call's input.
  HloInstruction* custom_call =
      computation->AddInstruction(HloInstruction::CreateCustomCall(
          call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape}));
  custom_call->set_metadata(metadata);
  TF_RETURN_IF_ERROR(custom_call->set_backend_config(options));
  HloInstruction* out = computation->AddInstruction(
      HloInstruction::CreateGetTupleElement(a_shape, custom_call, 0));
  HloInstruction* info = computation->AddInstruction(
      HloInstruction::CreateGetTupleElement(info_shape, custom_call, 2));

  // If info was non-zero, indicating that the Cholesky decomposition failed,
  // returns an array full of NaNs for the corresponding batch element.
  HloInstruction* zero = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  HloInstruction* zeros =
      computation->AddInstruction(HloInstruction::CreateBroadcast(
          info_shape, zero, /*broadcast_dimensions=*/{}));
  // ok[batch...] = (info[batch...] == 0), then broadcast over the matrix
  // dimensions so it can drive the element-wise select below.
  HloInstruction* ok = computation->AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, batch_dims),
                                    info, zeros, ComparisonDirection::kEq));
  ok = computation->AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(PRED, a_shape.dimensions()), ok,
      /*broadcast_dimensions=*/batch_dim_ids));

  TF_ASSIGN_OR_RETURN(Literal nan_literal,
                      LiteralUtil::NanValue(a_shape.element_type()));
  HloInstruction* nan = computation->AddInstruction(
      HloInstruction::CreateConstant(std::move(nan_literal)));
  HloInstruction* nans =
      computation->AddInstruction(HloInstruction::CreateBroadcast(
          a_shape, nan, /*broadcast_dimensions=*/{}));

  // select(ok, out, nans): keep successful factorizations, NaN out failures.
  HloInstruction* select =
      computation->AddInstruction(HloInstruction::CreateTernary(
          a_shape, HloOpcode::kSelect, ok, out, nans));
  return select;
}
132
// Tries to rewrite a single kCholesky instruction into a custom call to
// cuSolver. Returns true if the instruction was replaced, false if it was
// not a Cholesky op. (The previous comment mentioning convolutions/cudnn was
// a copy-paste leftover.)
StatusOr<bool> RunOnInstruction(GpuSolverContext* context,
                                HloInstruction* instruction) {
  if (instruction->opcode() != HloOpcode::kCholesky) {
    return false;
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * custom_call,
      CreateCholesky(context, instruction->mutable_operand(0),
                     instruction->cholesky_options(), instruction->metadata()));

  VLOG(1) << "Replacing " << instruction->ToString() << " with "
          << custom_call->ToString();

  TF_RETURN_IF_ERROR(
      instruction->parent()->ReplaceInstruction(instruction, custom_call));
  return true;
}
152
153 } // namespace
154
155 // Rewrites the convolutions in the given computation into calls to cudnn.
156 // Returns true if it made any changes.
RunOnComputation(HloComputation * computation)157 StatusOr<bool> GpusolverRewriter::RunOnComputation(
158 HloComputation* computation) {
159 std::vector<HloInstruction*> cusolver_calls;
160 for (auto* hlo : computation->instructions()) {
161 if (hlo->opcode() == HloOpcode::kCholesky) {
162 cusolver_calls.push_back(hlo);
163 }
164 }
165
166 if (cusolver_calls.empty()) {
167 return false;
168 }
169
170 TF_ASSIGN_OR_RETURN(GpuSolverContext context,
171 GpuSolverContext::Create(/*stream=*/nullptr));
172
173 bool changed = false;
174 for (HloInstruction* instruction : cusolver_calls) {
175 TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(&context, instruction));
176 changed |= result;
177 }
178 return changed;
179 }
180
181 GpusolverRewriter::GpusolverRewriter() = default;
182
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)183 StatusOr<bool> GpusolverRewriter::Run(
184 HloModule* module,
185 const absl::flat_hash_set<absl::string_view>& execution_threads) {
186 bool changed = false;
187 for (HloComputation* computation :
188 module->MakeNonfusionComputations(execution_threads)) {
189 TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
190 changed |= result;
191 }
192 return changed;
193 }
194
195 } // namespace gpu
196 } // namespace xla
197